In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import RobertaModel, RobertaTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
# Token length every encoded document is padded or truncated to.
MAX_LEN = 512
# The train, validation and test loaders all use the same batch size.
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 32
EPOCHS = 12  # number of passes over the training split
LEARNING_RATE = 1e-5  # initial learning rate handed to the optimizer
THRESHOLD = 0.5  # threshold for the sigmoid

## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing one tokenized text column with multi-label targets.

    Args:
        df: DataFrame holding the text column and one column per label.
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method
            that returns 'input_ids', 'attention_mask' and 'token_type_ids'.
        max_len: Length every encoded sequence is padded or truncated to.
        target_list: Names of the label columns in ``df``.
        text_column: Name of the column that holds the raw text. Defaults to
            ``'File Contents'``, the column this script was written for.
    """

    def __init__(self, df, tokenizer, max_len, target_list, text_column='File Contents'):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df[text_column])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        """Return the number of documents."""
        return len(self.title)

    def __getitem__(self, index):
        """Return the encoded tensors, float label vector and cleaned text for one row."""
        # Collapse every run of whitespace (including newlines) to a single space.
        title = " ".join(str(self.title[index]).split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Label columns of mixed dtype (e.g. bool next to int) come out of
        # pandas as an object array, which torch cannot convert directly;
        # cast to float32 with numpy first.
        labels = np.asarray(self.targets[index], dtype=np.float32)
        return {
            # return_tensors='pt' adds a leading batch axis of size 1; drop it.
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.tensor(labels, dtype=torch.float),
            'title': title
        }

## Data
# Two CSV files: one is split 80/20 into train/validation (fixed seed), the
# other is held out for the final test evaluation.
train_file, test_file = 'train.csv', 'test.csv'
train_val_df = pd.read_csv(train_file)
test_df = pd.read_csv(test_file)

train_df, val_df = train_test_split(train_val_df, test_size=0.2, random_state=42)

# Every column after the first one is treated as a label column.
target_list = list(train_df.columns[1:])

## Tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# One dataset per split, all sharing the tokenizer, length and label columns.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(frame, tokenizer, MAX_LEN, target_list)
    for frame in (train_df, val_df, test_df)
)

## Data Loader
# Only the training loader shuffles; validation and test keep file order.
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0
)
val_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=VALID_BATCH_SIZE, shuffle=False, num_workers=0
)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=0
)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Model

class RobertaBase(nn.Module):
    """Pretrained 'roberta-base' encoder followed by dropout and a linear head.

    Args:
        num_classes: Number of output logits produced per input sequence.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')

        # Dropout layer
        self.drop = nn.Dropout(0.3)

        # Fully connected layer for classification. The input width is read
        # from the encoder's config rather than hard-coded to 768, so the head
        # stays correct if the checkpoint name above is ever changed.
        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return raw (pre-sigmoid) logits of shape (batch_size, num_classes).

        ``token_type_ids`` is accepted so callers can pass it positionally,
        but it is not forwarded to the encoder.
        """
        # RoBERTa features
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Take the [CLS] token representation for classification (first token)
        cls_token = outputs.last_hidden_state[:, 0, :]  # (batch_size, hidden)

        # Apply dropout to the vector that is actually used; dropping out the
        # whole (batch_size, seq_length, hidden) tensor only to discard all but
        # one position wastes memory and compute for the same regularisation.
        cls_token = self.drop(cls_token)

        # Final classification
        return self.fc(cls_token)  # (batch_size, num_classes)


## Setting the model
# One logit per label column; Module.to() moves the parameters and returns the
# module itself, so construction and device placement can be chained.
model = RobertaBase(num_classes=len(target_list)).to(device)

## Loss Fn
def loss_fn(outputs, targets):
    """Return the mean binary cross-entropy between raw logits and targets.

    The sigmoid is applied inside the loss, so ``outputs`` must be logits,
    not probabilities.
    """
    return torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

# optimizer: AdamW over every model parameter with a small weight decay
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)


def train_model(training_loader, model, optimizer, accumulation_steps=1):
    """Run one training epoch.

    Args:
        training_loader: Sized iterable of batches; each batch is a dict with
            'input_ids', 'attention_mask', 'token_type_ids' and 'targets'
            tensors.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns raw logits.
        optimizer: Optimizer over ``model``'s parameters.
        accumulation_steps: Number of consecutive batches whose gradients are
            accumulated before each optimizer step. The default of 1 steps
            after every batch.

    Returns:
        Tuple ``(model, accuracy, mean_loss)`` where ``accuracy`` is the
        fraction of individual label decisions (at a 0.5 cut-off) that match
        the targets and ``mean_loss`` is the mean of the per-batch losses.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be a positive integer")

    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    # Set model to training mode (activate dropout, batch norm)
    model.train()
    # Start from clean gradients so nothing left over from earlier calls leaks
    # into the first optimizer step.
    optimizer.zero_grad()

    for batch_idx, data in enumerate(training_loader, start=1):
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Element-wise accuracy at a 0.5 cut-off. ">=" is used (instead of
        # numpy's round-half-to-even) so a probability of exactly 0.5 is
        # treated the same way as in eval_model.
        predictions = (torch.sigmoid(outputs.detach()) >= 0.5).to(targets.dtype)
        correct_predictions += (predictions == targets).sum().item()
        num_samples += targets.numel()

        # Scale so the accumulated gradient is an average over the group of
        # batches rather than a sum (a no-op when accumulation_steps == 1).
        (loss / accumulation_steps).backward()

        # Step once per full group, and once more for a trailing partial group.
        if batch_idx % accumulation_steps == 0 or batch_idx == total_batches:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Clear GPU cache
        torch.cuda.empty_cache()

    # An empty loader has no decisions to score; report 0 accuracy and a NaN
    # loss instead of dividing by zero.
    accuracy = float(correct_predictions) / num_samples if num_samples else 0.0
    mean_loss = np.mean(losses) if losses else float('nan')
    return model, accuracy, mean_loss



def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate ``model`` on a loader, print multi-label metrics, return F1 and loss.

    Args:
        validation_loader: Iterable of batches; each batch is a dict with
            'input_ids', 'attention_mask', 'token_type_ids' and 'targets'
            tensors.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns raw logits.
        threshold: Sigmoid probability at or above which a label is predicted.
        target_list: Optional label names shown in the classification report.

    Returns:
        Tuple ``(f1, average_loss)``: the support-weighted F1 score and the
        mean of the per-batch losses.

    Raises:
        ValueError: If ``validation_loader`` yields no batches.
    """
    model.eval()
    batch_probs = []
    batch_targets = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            batch_probs.append(torch.sigmoid(outputs).cpu().detach().numpy())
            batch_targets.append(targets.cpu().detach().numpy())

            torch.cuda.empty_cache()

    if not batch_probs:
        raise ValueError("validation_loader produced no batches to evaluate")

    final_probs = np.concatenate(batch_probs, axis=0)
    final_targets = np.concatenate(batch_targets, axis=0)
    # Threshold exactly once, on the probabilities. Comparing the already
    # boolean predictions against the threshold a second time is redundant
    # for thresholds in (0, 1] and wrong outside that range.
    final_outputs = final_probs >= threshold

    # zero_division=0 yields the same scores as the default for labels that
    # are never predicted, without emitting a warning per metric per call.
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    try:
        auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    except ValueError:
        # ROC AUC is undefined when a label has only one class in this split;
        # report NaN rather than aborting the whole evaluation.
        auc_roc = float('nan')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print(
        "\nClassification Report:\n",
        classification_report(final_targets, final_outputs, target_names=target_list, zero_division=0),
    )

    return f1, average_loss


# Anneal over the whole run. With T_max smaller than EPOCHS the cosine schedule
# reaches a learning rate of zero before training ends, so one epoch performs
# no updates at all (its validation metrics repeat the previous epoch's).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# File the best-scoring weights are written to and later reloaded from.
BEST_MODEL_PATH = "ohsumed_roberta_32_base_best.bin"

# Training & Evaluation Loop
start = time.time()

history = defaultdict(list)
# Start below any reachable F1 so the first epoch always writes a checkpoint,
# even if validation F1 is exactly 0; otherwise the reload below would fail.
best_f1 = float('-inf')

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    # Pass the configured threshold and the label names so the constant at the
    # top of the file is honoured and the report rows are labelled.
    val_f1, val_loss = eval_model(
        val_data_loader, model, threshold=THRESHOLD, target_list=target_list
    )

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Stepping scheduler after each epoch

    # save the best model
    if val_f1 > best_f1:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
model = RobertaBase(num_classes=len(target_list))
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# weights_only restricts unpickling to tensors, which is all a state dict holds.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device, weights_only=True))
model = model.to(device)

start = time.time()
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=target_list)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/12


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.046104928457869634
F1 Score: 0.0879735219702536
Precision: 0.20180376495542374
Recall: 0.057345971563981045
Hamming Loss: 0.06939932259625355
Average Loss: 0.2195033598691225
AUC-ROC: 0.7328640895651005
AUPR: 0.3490395347044521

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.83      0.31      0.45       234
           4       0.00      0.00      0.00        51
           5       0.00      0.00      0.00       109
           6       0.00      0.00      0.00        16
           7       0.00      0.00      0.00        99
           8       0.00      0.00      0.00        20
           9       0.00      0.00      0.00       122
          10       0.00      0.00      0.00        26
          11       0.00      0.00      0.00       110
          12       0.00      0.0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.19475357710651828
F1 Score: 0.3027961300163896
Precision: 0.6166675781500146
Recall: 0.26492890995260665
Hamming Loss: 0.05792493260523951
Average Loss: 0.1750894282013178
AUC-ROC: 0.854189050401663
AUPR: 0.5716870439204056

Classification Report:
               precision    recall  f1-score   support

           0       1.00      0.01      0.02       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.83      0.78      0.80       234
           4       0.00      0.00      0.00        51
           5       0.80      0.32      0.46       109
           6       0.00      0.00      0.00        16
           7       0.82      0.28      0.42        99
           8       0.00      0.00      0.00        20
           9       0.75      0.02      0.05       122
          10       0.00      0.00      0.00        26
          11       0.00      0.00      0.00       110
          12       1.00      0.02   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2901430842607313
F1 Score: 0.4791338400683947
Precision: 0.6961461320750684
Recall: 0.3962085308056872
Hamming Loss: 0.0508398423999447
Average Loss: 0.15408915579319
AUC-ROC: 0.8822518234244461
AUPR: 0.6323693101416972

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.19      0.31       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.84      0.78      0.81       234
           4       0.58      0.14      0.22        51
           5       0.71      0.59      0.64       109
           6       0.00      0.00      0.00        16
           7       0.74      0.53      0.62        99
           8       0.00      0.00      0.00        20
           9       0.82      0.37      0.51       122
          10       0.00      0.00      0.00        26
          11       0.83      0.48      0.61       110
          12       0.75      0.38      0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3322734499205087
F1 Score: 0.5556813384119675
Precision: 0.7444016400999455
Recall: 0.4786729857819905
Hamming Loss: 0.047038086680030414
Average Loss: 0.14296302739530803
AUC-ROC: 0.8963438519330413
AUPR: 0.6636767166617751

Classification Report:
               precision    recall  f1-score   support

           0       0.81      0.47      0.60       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.85      0.80      0.83       234
           4       0.56      0.43      0.49        51
           5       0.71      0.64      0.68       109
           6       0.00      0.00      0.00        16
           7       0.75      0.59      0.66        99
           8       0.67      0.10      0.17        20
           9       0.74      0.43      0.55       122
          10       1.00      0.23      0.38        26
          11       0.86      0.54      0.66       110
          12       0.78      0.45  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3251192368839428
F1 Score: 0.546396719569408
Precision: 0.7437917344504839
Recall: 0.4663507109004739
Hamming Loss: 0.047072648095666
Average Loss: 0.13766265101730824
AUC-ROC: 0.8996243771263752
AUPR: 0.6785059350194833

Classification Report:
               precision    recall  f1-score   support

           0       0.87      0.46      0.60       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.86      0.82      0.84       234
           4       0.61      0.37      0.46        51
           5       0.72      0.64      0.68       109
           6       0.00      0.00      0.00        16
           7       0.82      0.42      0.56        99
           8       0.67      0.10      0.17        20
           9       0.76      0.44      0.56       122
          10       0.82      0.35      0.49        26
          11       0.85      0.66      0.74       110
          12       0.81      0.45      

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.35532591414944353
F1 Score: 0.5813962000563936
Precision: 0.7241758824942692
Recall: 0.5137440758293839
Hamming Loss: 0.04558650722333587
Average Loss: 0.13359336648136377
AUC-ROC: 0.9066409941713416
AUPR: 0.6916083333262981

Classification Report:
               precision    recall  f1-score   support

           0       0.79      0.55      0.65       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.87      0.80      0.84       234
           4       0.60      0.41      0.49        51
           5       0.67      0.74      0.70       109
           6       0.00      0.00      0.00        16
           7       0.76      0.54      0.63        99
           8       0.57      0.20      0.30        20
           9       0.78      0.48      0.60       122
          10       0.72      0.50      0.59        26
          11       0.83      0.78      0.81       110
          12       0.78      0.53  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.35930047694753575
F1 Score: 0.6039412163037716
Precision: 0.7309809312489227
Recall: 0.5374407582938389
Hamming Loss: 0.04492984032625976
Average Loss: 0.13101671375334262
AUC-ROC: 0.9091932635445376
AUPR: 0.6992452500492868

Classification Report:
               precision    recall  f1-score   support

           0       0.81      0.61      0.69       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.84      0.85      0.85       234
           4       0.56      0.49      0.52        51
           5       0.72      0.70      0.71       109
           6       0.00      0.00      0.00        16
           7       0.74      0.55      0.63        99
           8       0.57      0.20      0.30        20
           9       0.71      0.53      0.61       122
          10       0.71      0.46      0.56        26
          11       0.85      0.72      0.78       110
          12       0.86      0.40  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.36327503974562797
F1 Score: 0.6014640808946583
Precision: 0.7390944167383863
Recall: 0.5293838862559241
Hamming Loss: 0.04458422616990392
Average Loss: 0.12955872174352406
AUC-ROC: 0.9101716145683584
AUPR: 0.7035519033809362

Classification Report:
               precision    recall  f1-score   support

           0       0.90      0.54      0.67       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.86      0.81      0.83       234
           4       0.56      0.47      0.51        51
           5       0.72      0.67      0.70       109
           6       0.00      0.00      0.00        16
           7       0.76      0.52      0.61        99
           8       0.50      0.15      0.23        20
           9       0.76      0.51      0.61       122
          10       0.71      0.46      0.56        26
          11       0.85      0.75      0.80       110
          12       0.82      0.57  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3672496025437202
F1 Score: 0.6105117808775365
Precision: 0.7364847895719587
Recall: 0.5436018957345972
Hamming Loss: 0.04437685767609041
Average Loss: 0.1293625757098198
AUC-ROC: 0.9110731876937206
AUPR: 0.7046113440826196

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.56      0.67       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.87      0.79      0.83       234
           4       0.55      0.51      0.53        51
           5       0.71      0.73      0.72       109
           6       0.00      0.00      0.00        16
           7       0.72      0.59      0.65        99
           8       0.62      0.25      0.36        20
           9       0.76      0.53      0.62       122
          10       0.71      0.46      0.56        26
          11       0.82      0.80      0.81       110
          12       0.82      0.60    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.36645468998410174
F1 Score: 0.6118220893116417
Precision: 0.7409025901822134
Recall: 0.5407582938388625
Hamming Loss: 0.044238612013548076
Average Loss: 0.1287667192518711
AUC-ROC: 0.9112238561563754
AUPR: 0.7052220859307424

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.57      0.68       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.86      0.80      0.83       234
           4       0.56      0.49      0.52        51
           5       0.72      0.71      0.71       109
           6       0.00      0.00      0.00        16
           7       0.74      0.57      0.64        99
           8       0.57      0.20      0.30        20
           9       0.77      0.52      0.62       122
          10       0.71      0.46      0.56        26
          11       0.84      0.78      0.81       110
          12       0.82      0.60  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.36645468998410174
F1 Score: 0.6118220893116417
Precision: 0.7409025901822134
Recall: 0.5407582938388625
Hamming Loss: 0.044238612013548076
Average Loss: 0.1287667192518711
AUC-ROC: 0.9112238561563754
AUPR: 0.7052220859307424

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.57      0.68       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.86      0.80      0.83       234
           4       0.56      0.49      0.52        51
           5       0.72      0.71      0.71       109
           6       0.00      0.00      0.00        16
           7       0.74      0.57      0.64        99
           8       0.57      0.20      0.30        20
           9       0.77      0.52      0.62       122
          10       0.71      0.46      0.56        26
          11       0.84      0.78      0.81       110
          12       0.82      0.60  