In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
# Token length every encoded text is padded or truncated to.
MAX_LEN = 512
# The train, validation and test loaders all use the same batch size.
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 32
# Number of passes over the training split.
EPOCHS = 12
# Learning rate handed to the optimizer.
LEARNING_RATE = 1e-5
# Threshold for the sigmoid.
THRESHOLD = 0.5
## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenised text with multi-label targets.

    The text of each item comes from the ``'File Contents'`` column of ``df``
    and its label vector from the columns named in ``target_list``.
    """

    def __init__(self, df, tokenizer, max_len, target_list):
        """Store the texts, the label matrix and the tokenisation settings.

        Args:
            df: DataFrame holding a ``'File Contents'`` column plus the
                label columns.
            tokenizer: Callable tokenizer (Hugging Face style) that accepts
                the text and keyword options and returns a mapping with
                ``input_ids``, ``attention_mask`` and ``token_type_ids``.
            max_len: Length every encoded sequence is padded/truncated to.
            target_list: Names of the label columns in ``df``.
        """
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df['File Contents'])
        # Convert the labels to float32 once, so each item is already the
        # float vector the loss expects whatever dtype the CSV produced.
        self.targets = self.df[target_list].to_numpy(dtype='float32')
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts."""
        return len(self.title)

    def __getitem__(self, index):
        """Encode one text and return it with its targets.

        Returns:
            Dict with flattened ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` tensors, a float ``targets`` tensor and the
            whitespace-normalised text under ``title``.
        """
        # Collapse every run of whitespace (including newlines) to one space.
        title = " ".join(str(self.title[index]).split())
        # Call the tokenizer directly; ``encode_plus`` is the deprecated
        # spelling of the same single-text encoding.
        inputs = self.tokenizer(
            title,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            # return_tensors='pt' adds a leading batch axis; flatten drops it.
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.tensor(self.targets[index], dtype=torch.float),
            'title': title,
        }

## Data
train_file = '/content/train.csv'
test_file = '/content/test.csv'
train_val_df, test_df = pd.read_csv(train_file), pd.read_csv(test_file)

# Hold out 20% of the training CSV for validation; the fixed seed keeps the
# split reproducible between runs.
train_df, val_df = train_test_split(train_val_df, random_state=42, test_size=0.2)

# Every column after the first one is used as a label column.
target_list = train_df.columns[1:].tolist()

## Tokenizer
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

# All three splits share the tokenizer, sequence length and label columns.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(split_df, tokenizer, MAX_LEN, target_list)
    for split_df in (train_df, val_df, test_df)
)

#print(train_dataset[0])

## Data Loader
# Only the training loader shuffles; validation and test keep file order.
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
val_data_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

## Device
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Model

class BioBERT(torch.nn.Module):
    """Pretrained BERT-style encoder with a dropout + linear classification head.

    ``forward`` returns raw logits (one per class); no sigmoid is applied.
    """

    def __init__(self, num_classes, model_name='dmis-lab/biobert-base-cased-v1.1', dropout=0.3):
        """Load the pretrained encoder and build the classification head.

        Args:
            num_classes: Number of output logits.
            model_name: Checkpoint passed to ``AutoModel.from_pretrained``.
            dropout: Dropout probability applied before the linear layer.
        """
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(model_name, return_dict=True)
        self.dropout = torch.nn.Dropout(dropout)
        # Take the head's input width from the encoder config instead of
        # hard-coding 768, so a checkpoint with another hidden size also works.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return the logits computed from the first token's final hidden state."""
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids
        )
        # Using the [CLS] token representation from the last hidden state
        cls_representation = encoded.last_hidden_state[:, 0, :]
        return self.linear(self.dropout(cls_representation))

## Setting the model
# One output logit per label column. Module.to() returns the module itself,
# so construction and device placement can be chained.
model = BioBERT(num_classes=len(target_list)).to(device)

## Loss Fn
def loss_fn(outputs, targets):
    """Return the mean binary cross-entropy of raw logits against targets."""
    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits
    return bce_with_logits(outputs, targets)

# optimizer
# AdamW over every model parameter (encoder and head alike).
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)

## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=1):
    """Train ``model`` for one epoch and report label-wise accuracy and mean loss.

    Args:
        training_loader: Sized iterable of batches. Each batch is a dict with
            'input_ids', 'attention_mask', 'token_type_ids' and 'targets'
            tensors.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns logits.
        optimizer: Optimizer updating the parameters of ``model``.
        accumulation_steps: Number of batches whose gradients are summed
            before each optimizer step. The default of 1 steps after every
            batch.

    Returns:
        Tuple ``(model, accuracy, mean_loss)``. ``accuracy`` is the fraction
        of individual label decisions (rounded sigmoid) that equal the
        targets; ``mean_loss`` is the average of the per-batch losses. For an
        empty loader the accuracy is 0.0 and the loss is NaN.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")

    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    model.train()
    # Start from clean gradients so nothing left over from earlier backward
    # passes leaks into the first update of this epoch.
    optimizer.zero_grad()

    for batch_idx, data in enumerate(training_loader, start=1):
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        # Forward pass without mixed precision
        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Training accuracy, apply sigmoid, round (apply threshold 0.5)
        predictions = torch.sigmoid(outputs).detach().cpu().numpy().round()
        labels = targets.detach().cpu().numpy()
        correct_predictions += np.sum(predictions == labels)
        num_samples += labels.size

        # Scale each batch so a full accumulation window averages its
        # gradients; dividing by 1 leaves the default path unchanged.
        (loss / accumulation_steps).backward()

        # Step once per window, and on the final batch so a trailing partial
        # window is not thrown away.
        if batch_idx % accumulation_steps == 0 or batch_idx == total_batches:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    # The per-batch torch.cuda.empty_cache() call was dropped: it does not
    # change results and only forces the allocator to re-request memory.
    accuracy = float(correct_predictions) / num_samples if num_samples else 0.0
    mean_loss = np.mean(losses) if losses else float('nan')
    return model, accuracy, mean_loss


def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate ``model`` on a loader, print multi-label metrics, return F1 and loss.

    Args:
        validation_loader: Iterable of batches. Each batch is a dict with
            'input_ids', 'attention_mask', 'token_type_ids' and 'targets'
            tensors.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns logits.
        threshold: Sigmoid probabilities greater than or equal to this value
            count as positive predictions.
        target_list: Optional label names shown in the classification report.

    Returns:
        Tuple ``(f1, average_loss)``: the weighted F1 score and the mean of
        the per-batch losses.
    """
    model.eval()
    batch_targets = []
    batch_probs = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            # Forward pass
            outputs = model(ids, mask, token_type_ids)
            losses.append(loss_fn(outputs, targets).item())

            batch_probs.append(torch.sigmoid(outputs).cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

    final_probs = np.concatenate(batch_probs, axis=0)
    final_targets = np.concatenate(batch_targets, axis=0)
    # Threshold exactly once, on the probabilities.
    final_outputs = final_probs >= threshold

    # zero_division=0 yields the same scores as the default for labels that
    # are never predicted, without emitting a warning for each metric call.
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    try:
        auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    except ValueError:
        # Raised when some label has a single class among the targets; report
        # NaN rather than aborting the evaluation and losing the other metrics.
        auc_roc = float('nan')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list, zero_division=0))

    return f1, average_loss


#Learning Rate Scheduler
# Anneal over exactly EPOCHS scheduler steps. With a shorter T_max the
# learning rate reaches zero before training ends (an epoch that changes
# nothing) and then climbs again for the remaining epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Training & Evaluation Loop
best_model_path = "ohsumed_bioBERT_CLS_32_best.bin"
start = time.time()

history = defaultdict(list)
# Start below any attainable F1 so the first epoch always writes a checkpoint,
# even when its F1 is exactly 0.0; the test step below needs that file.
best_f1 = float('-inf')

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_f1, val_loss = eval_model(
        val_data_loader, model, threshold=THRESHOLD, target_list=target_list
    )

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Stepping scheduler after each epoch

    # Keep only the weights of the best validation F1 seen so far.
    if val_f1 > best_f1:
        torch.save(model.state_dict(), best_model_path)
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
model = BioBERT(num_classes=len(target_list))
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(best_model_path, map_location=device))
model = model.to(device)

start = time.time()
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=target_list)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

Epoch 1/12


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.0
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.07292458699108316
Average Loss: 0.23318374902009964
AUC-ROC: 0.7210210639169193
AUPR: 0.3313652885725276

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.00      0.00      0.00       234
           4       0.00      0.00      0.00        51
           5       0.00      0.00      0.00       109
           6       0.00      0.00      0.00        16
           7       0.00      0.00      0.00        99
           8       0.00      0.00      0.00        20
           9       0.00      0.00      0.00       122
          10       0.00      0.00      0.00        26
          11       0.00      0.00      0.00       110
          12       0.00      0.00      0.00        47
          13       0.00      0.00      0.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.17090620031796502
F1 Score: 0.22459215948333483
Precision: 0.3838731264577965
Recall: 0.21184834123222748
Hamming Loss: 0.06058616160917951
Average Loss: 0.1865213230252266
AUC-ROC: 0.8636867052915098
AUPR: 0.581128766389068

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.84      0.88      0.86       234
           4       0.00      0.00      0.00        51
           5       0.82      0.30      0.44       109
           6       0.00      0.00      0.00        16
           7       0.00      0.00      0.00        99
           8       0.00      0.00      0.00        20
           9       0.00      0.00      0.00       122
          10       0.00      0.00      0.00        26
          11       0.00      0.00      0.00       110
          12       0.00      0.00  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.2972972972972973
F1 Score: 0.45492715814410173
Precision: 0.6756668744239211
Recall: 0.3876777251184834
Hamming Loss: 0.05052878965922444
Average Loss: 0.1574738420546055
AUC-ROC: 0.8977121828387509
AUPR: 0.6620357894745106

Classification Report:
               precision    recall  f1-score   support

           0       0.76      0.16      0.26       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.83      0.86       234
           4       0.71      0.10      0.17        51
           5       0.72      0.62      0.67       109
           6       0.00      0.00      0.00        16
           7       0.81      0.38      0.52        99
           8       0.00      0.00      0.00        20
           9       0.74      0.28      0.40       122
          10       1.00      0.23      0.38        26
          11       0.86      0.73      0.79       110
          12       0.68      0.36   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.3314785373608903
F1 Score: 0.49680924799145904
Precision: 0.7191195436639487
Recall: 0.4308056872037915
Hamming Loss: 0.04793668348655561
Average Loss: 0.14324754551053048
AUC-ROC: 0.9114018869381508
AUPR: 0.7048640741619302

Classification Report:
               precision    recall  f1-score   support

           0       0.74      0.30      0.43       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.88      0.84      0.86       234
           4       0.75      0.24      0.36        51
           5       0.75      0.59      0.66       109
           6       0.00      0.00      0.00        16
           7       0.85      0.51      0.63        99
           8       0.00      0.00      0.00        20
           9       0.71      0.34      0.46       122
          10       0.80      0.31      0.44        26
          11       0.87      0.82      0.84       110
          12       0.76      0.47  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.35930047694753575
F1 Score: 0.5480314323763472
Precision: 0.7565101157431134
Recall: 0.4791469194312796
Hamming Loss: 0.04486071749498859
Average Loss: 0.13376149777323007
AUC-ROC: 0.9180027718371889
AUPR: 0.7254676484053327

Classification Report:
               precision    recall  f1-score   support

           0       0.80      0.55      0.65       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.90      0.81      0.86       234
           4       0.70      0.51      0.59        51
           5       0.74      0.67      0.71       109
           6       0.00      0.00      0.00        16
           7       0.86      0.55      0.67        99
           8       0.00      0.00      0.00        20
           9       0.72      0.34      0.47       122
          10       0.89      0.31      0.46        26
          11       0.85      0.80      0.82       110
          12       0.85      0.47  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.3847376788553259
F1 Score: 0.6053183066937462
Precision: 0.7375788387857766
Recall: 0.5445497630331754
Hamming Loss: 0.04289071680376028
Average Loss: 0.12817178722470998
AUC-ROC: 0.9238211164610882
AUPR: 0.7376621242729903

Classification Report:
               precision    recall  f1-score   support

           0       0.80      0.67      0.73       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.82      0.86       234
           4       0.74      0.49      0.59        51
           5       0.69      0.75      0.72       109
           6       0.00      0.00      0.00        16
           7       0.76      0.63      0.69        99
           8       1.00      0.05      0.10        20
           9       0.74      0.47      0.57       122
          10       0.77      0.65      0.71        26
          11       0.86      0.88      0.87       110
          12       0.74      0.62   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.3815580286168522
F1 Score: 0.5989950701620743
Precision: 0.7408466217900476
Recall: 0.5308056872037915
Hamming Loss: 0.043098085297573786
Average Loss: 0.12473903186619281
AUC-ROC: 0.924449266687356
AUPR: 0.7429209488009816

Classification Report:
               precision    recall  f1-score   support

           0       0.84      0.60      0.70       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.88      0.86      0.87       234
           4       0.79      0.53      0.64        51
           5       0.71      0.69      0.70       109
           6       0.00      0.00      0.00        16
           7       0.76      0.62      0.68        99
           8       1.00      0.05      0.10        20
           9       0.73      0.49      0.59       122
          10       0.74      0.54      0.62        26
          11       0.86      0.81      0.83       110
          12       0.80      0.60   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.38950715421303655
F1 Score: 0.6151760011683944
Precision: 0.7628744410436682
Recall: 0.5521327014218009
Hamming Loss: 0.042234049906684175
Average Loss: 0.12331730090081691
AUC-ROC: 0.92561131547284
AUPR: 0.7442419437238351

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.59      0.70       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.88      0.85      0.87       234
           4       0.71      0.53      0.61        51
           5       0.71      0.68      0.69       109
           6       0.00      0.00      0.00        16
           7       0.77      0.61      0.68        99
           8       1.00      0.05      0.10        20
           9       0.76      0.49      0.60       122
          10       0.75      0.69      0.72        26
          11       0.85      0.87      0.86       110
          12       0.77      0.64   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.3918918918918919
F1 Score: 0.6270905505557185
Precision: 0.7709310463262496
Recall: 0.5601895734597157
Hamming Loss: 0.041404575931430154
Average Loss: 0.12206144835799933
AUC-ROC: 0.9261698511150714
AUPR: 0.7453198433677191

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.61      0.71       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.84      0.86       234
           4       0.71      0.53      0.61        51
           5       0.71      0.71      0.71       109
           6       0.00      0.00      0.00        16
           7       0.76      0.61      0.67        99
           8       1.00      0.10      0.18        20
           9       0.75      0.52      0.61       122
          10       0.76      0.62      0.68        26
          11       0.85      0.85      0.85       110
          12       0.81      0.64  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.39507154213036566
F1 Score: 0.626811960181168
Precision: 0.7693329630259514
Recall: 0.5616113744075829
Hamming Loss: 0.041404575931430154
Average Loss: 0.12207885030657054
AUC-ROC: 0.926187635842868
AUPR: 0.7454961425932319

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.61      0.71       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.85      0.87       234
           4       0.71      0.53      0.61        51
           5       0.72      0.72      0.72       109
           6       0.00      0.00      0.00        16
           7       0.76      0.62      0.68        99
           8       1.00      0.10      0.18        20
           9       0.76      0.51      0.61       122
          10       0.78      0.69      0.73        26
          11       0.85      0.86      0.86       110
          12       0.75      0.64   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.39507154213036566
F1 Score: 0.626811960181168
Precision: 0.7693329630259514
Recall: 0.5616113744075829
Hamming Loss: 0.041404575931430154
Average Loss: 0.12207885030657054
AUC-ROC: 0.926187635842868
AUPR: 0.7454961425932319

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.61      0.71       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.85      0.87       234
           4       0.71      0.53      0.61        51
           5       0.72      0.72      0.72       109
           6       0.00      0.00      0.00        16
           7       0.76      0.62      0.68        99
           8       1.00      0.10      0.18        20
           9       0.76      0.51      0.61       122
          10       0.78      0.69      0.73        26
          11       0.85      0.86      0.86       110
          12       0.75      0.64   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.39507154213036566
F1 Score: 0.6316419497967869
Precision: 0.7728852489215937
Recall: 0.5644549763033175
Hamming Loss: 0.04116264602198106
Average Loss: 0.12179693840444088
AUC-ROC: 0.926384158122587
AUPR: 0.7462783340507204

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.61      0.71       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.84      0.86       234
           4       0.71      0.53      0.61        51
           5       0.71      0.71      0.71       109
           6       0.00      0.00      0.00        16
           7       0.76      0.62      0.68        99
           8       1.00      0.10      0.18        20
           9       0.76      0.51      0.61       122
          10       0.76      0.62      0.68        26
          11       0.85      0.86      0.86       110
          12       0.77      0.64   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
