In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
MAX_LEN = 512
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
EPOCHS = 12
LEARNING_RATE = 1e-05
THRESHOLD = 0.5 # threshold for the sigmoid

## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len, target_list):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df['File Contents'])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        title = str(self.title[index])
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.FloatTensor(self.targets[index]),
            'title': title
        }

## Data
train_file = '/content/train.csv'
test_file = '/content/test.csv'
train_val_df = pd.read_csv(train_file)
test_df = pd.read_csv(test_file)

train_df, val_df = train_test_split(train_val_df, test_size=0.2, random_state=42)

target_list = list(train_df.columns[1:])

## Tokenizer
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-large-cased-v1.1')


train_dataset = CustomDataset(train_df, tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(val_df, tokenizer, MAX_LEN, target_list)
test_dataset = CustomDataset(test_df, tokenizer, MAX_LEN, target_list)

#print(train_dataset[0])

## Data Loader
train_data_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

val_data_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_data_loader = torch.utils.data.DataLoader(test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)
## Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

## Model

class BioBERT(torch.nn.Module):
    def __init__(self, num_classes):
        super(BioBERT, self).__init__()
        # Load BioBERT model
        self.bert_model = AutoModel.from_pretrained('dmis-lab/biobert-large-cased-v1.1', return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(1024, num_classes) 

    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids
        )
        # Using the [CLS] token representation from the last hidden state
        cls_representation = output.last_hidden_state[:, 0, :] 
        output_dropout = self.dropout(cls_representation)
        output = self.linear(output_dropout)
        return output

model = BioBERT(num_classes=len(target_list))
model.to(device)

## Loss
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

# define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=1): 
    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    model.train()

    for batch_idx, data in enumerate(training_loader):
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        outputs = torch.sigmoid(outputs).cpu().detach().numpy().round()
        targets = targets.cpu().detach().numpy()
        correct_predictions += np.sum(outputs == targets)
        num_samples += targets.size

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        optimizer.zero_grad()

        # Clearing GPU cache
        torch.cuda.empty_cache()

    return model, float(correct_predictions) / num_samples, np.mean(losses)


## Evaluator Function
def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    model.eval()
    final_targets = []
    final_outputs = []
    final_probs = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            # Forward pass
            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            probs = torch.sigmoid(outputs).cpu().detach().numpy()
            targets = targets.cpu().detach().numpy()
            final_outputs.extend(probs >= threshold)
            final_probs.extend(probs)
            final_targets.extend(targets)

            torch.cuda.empty_cache()

    final_outputs = np.array(final_outputs) >= threshold
    final_probs = np.array(final_probs)
    final_targets = np.array(final_targets)

    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted')
    precision = precision_score(final_targets, final_outputs, average='weighted')
    recall = recall_score(final_targets, final_outputs, average='weighted')
    hamming = hamming_loss(final_targets, final_outputs)

    auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list))

    return f1, average_loss


# Learning Rate Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Training & Evaluation Loop
# recording starting time
start = time.time()

history = defaultdict(list)
best_f1 = 0.0

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_f1, val_loss = eval_model(val_data_loader, model)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Step scheduler after each epoch

    # save the best model
    if val_f1 > best_f1:
        torch.save(model.state_dict(), "ohsumed_bioBERT_LARRGE_CLS_32_best.bin")
        best_f1 = val_f1

# recording end time
end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
# Loading pretrained model (best model)
print("\n\nTesting\n\n")
model = BioBERT(num_classes=len(target_list))
model.load_state_dict(torch.load("ohsumed_bioBERT_LARRGE_CLS_32_best.bin"))
model = model.to(device)

# recording starting time
start = time.time()
# Evaluate the model using the test data
eval_model(test_data_loader, model)
# recording end time
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/289 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/467k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Epoch 1/12


model.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.31160572337042924
F1 Score: 0.5281519603631152
Precision: 0.7005648491686095
Recall: 0.4568380213385063
Hamming Loss: 0.047521946498928595
Average Loss: 0.1458230186253786
AUC-ROC: 0.9079200914996227
AUPR: 0.6973152756983906

Classification Report:
               precision    recall  f1-score   support

           0       0.74      0.51      0.60        69
           1       0.00      0.00      0.00        29
           2       0.00      0.00      0.00        12
           3       0.88      0.78      0.83       249
           4       0.67      0.05      0.09        41
           5       0.81      0.72      0.76       116
           6       0.00      0.00      0.00        24
           7       0.94      0.39      0.55        85
           8       0.00      0.00      0.00        24
           9       0.76      0.46      0.57       118
          10       0.85      0.37      0.51        30
          11       0.90      0.65      0.75        97
          12       0.88      0.42  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4379968203497615
F1 Score: 0.6477374674624365
Precision: 0.8817916136507769
Recall: 0.5732298739088264
Hamming Loss: 0.03521808253266054
Average Loss: 0.10576158110052347
AUC-ROC: 0.9450489590696193
AUPR: 0.827171889251378

Classification Report:
               precision    recall  f1-score   support

           0       0.89      0.74      0.81        69
           1       0.00      0.00      0.00        29
           2       1.00      0.08      0.15        12
           3       0.87      0.90      0.88       249
           4       0.94      0.37      0.53        41
           5       0.90      0.69      0.78       116
           6       1.00      0.17      0.29        24
           7       0.88      0.75      0.81        85
           8       1.00      0.12      0.22        24
           9       0.86      0.57      0.68       118
          10       0.92      0.77      0.84        30
          11       0.93      0.88      0.90        97
          12       0.81      0.68    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5397456279809221
F1 Score: 0.7298110390535205
Precision: 0.9249105344913698
Recall: 0.6697381183317168
Hamming Loss: 0.02768369392410313
Average Loss: 0.08828795403242111
AUC-ROC: 0.9620047352175191
AUPR: 0.8823081950890589

Classification Report:
               precision    recall  f1-score   support

           0       0.91      0.74      0.82        69
           1       1.00      0.07      0.13        29
           2       0.88      0.58      0.70        12
           3       0.94      0.94      0.94       249
           4       0.80      0.80      0.80        41
           5       0.92      0.80      0.86       116
           6       0.93      0.54      0.68        24
           7       0.96      0.82      0.89        85
           8       1.00      0.46      0.63        24
           9       0.93      0.63      0.75       118
          10       0.97      0.93      0.95        30
          11       0.97      0.89      0.92        97
          12       0.84      0.82   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6406995230524642
F1 Score: 0.8294017031334153
Precision: 0.9387598701588965
Recall: 0.7657613967022309
Hamming Loss: 0.019907375406096633
Average Loss: 0.06931810518726707
AUC-ROC: 0.9754582730414434
AUPR: 0.9274344387635985

Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.78      0.86        69
           1       1.00      0.38      0.55        29
           2       1.00      0.83      0.91        12
           3       0.95      0.94      0.94       249
           4       0.97      0.85      0.91        41
           5       0.88      0.91      0.89       116
           6       0.95      0.75      0.84        24
           7       0.96      0.88      0.92        85
           8       0.86      0.75      0.80        24
           9       0.95      0.66      0.78       118
          10       0.97      0.93      0.95        30
          11       0.97      0.89      0.92        97
          12       0.94      0.88  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.739268680445151
F1 Score: 0.887690022004038
Precision: 0.9572609191808376
Recall: 0.8365664403491756
Hamming Loss: 0.013824566254233773
Average Loss: 0.05591625003144145
AUC-ROC: 0.9834041868562626
AUPR: 0.9518581044066805

Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.84      0.90        69
           1       1.00      0.59      0.74        29
           2       1.00      0.75      0.86        12
           3       0.96      0.95      0.95       249
           4       0.97      0.83      0.89        41
           5       0.95      0.95      0.95       116
           6       0.95      0.83      0.89        24
           7       0.98      0.94      0.96        85
           8       0.83      0.83      0.83        24
           9       0.98      0.84      0.90       118
          10       0.97      0.97      0.97        30
          11       1.00      0.88      0.93        97
          12       0.98      0.88    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.814785373608903
F1 Score: 0.9234280155801423
Precision: 0.967355317451571
Recall: 0.8884578079534433
Hamming Loss: 0.009642634962328056
Average Loss: 0.04727101037278771
AUC-ROC: 0.9895635349249646
AUPR: 0.9667596165464135

Classification Report:
               precision    recall  f1-score   support

           0       0.94      0.90      0.92        69
           1       0.96      0.83      0.89        29
           2       1.00      0.83      0.91        12
           3       0.94      0.95      0.95       249
           4       1.00      0.88      0.94        41
           5       0.98      0.97      0.97       116
           6       0.95      0.83      0.89        24
           7       0.98      0.94      0.96        85
           8       0.95      0.83      0.89        24
           9       0.96      0.92      0.94       118
          10       1.00      0.97      0.98        30
          11       1.00      0.96      0.98        97
          12       0.98      0.94    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.8513513513513513
F1 Score: 0.9411943216102948
Precision: 0.9792664215146245
Recall: 0.9078564500484966
Hamming Loss: 0.007430704361650653
Average Loss: 0.040324658155441284
AUC-ROC: 0.9934455359963164
AUPR: 0.9749517680067493

Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.88      0.92        69
           1       0.95      0.69      0.80        29
           2       1.00      0.92      0.96        12
           3       1.00      0.93      0.96       249
           4       1.00      0.93      0.96        41
           5       1.00      0.97      0.99       116
           6       0.95      0.83      0.89        24
           7       0.98      0.94      0.96        85
           8       0.96      0.92      0.94        24
           9       0.99      0.92      0.95       118
          10       1.00      0.97      0.98        30
          11       1.00      0.96      0.98        97
          12       0.98      0.96 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.884737678855326
F1 Score: 0.9537982805398404
Precision: 0.9817346537541988
Recall: 0.9282250242483027
Hamming Loss: 0.005840879242413769
Average Loss: 0.03705003559589386
AUC-ROC: 0.9951624731940196
AUPR: 0.9795640018206381

Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.90      0.93        69
           1       0.96      0.83      0.89        29
           2       1.00      0.92      0.96        12
           3       1.00      0.94      0.97       249
           4       1.00      0.93      0.96        41
           5       1.00      0.98      0.99       116
           6       0.95      0.83      0.89        24
           7       1.00      0.96      0.98        85
           8       0.96      0.92      0.94        24
           9       0.99      0.92      0.96       118
          10       1.00      1.00      1.00        30
          11       1.00      0.99      0.99        97
          12       0.98      0.96   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.8871224165341812
F1 Score: 0.9545306276399445
Precision: 0.9838403049861845
Recall: 0.9282250242483027
Hamming Loss: 0.005702633579871432
Average Loss: 0.035367491329088806
AUC-ROC: 0.9958802957254752
AUPR: 0.9817754934746962

Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.86      0.91        69
           1       1.00      0.79      0.88        29
           2       1.00      0.92      0.96        12
           3       0.99      0.96      0.97       249
           4       1.00      0.93      0.96        41
           5       1.00      0.98      0.99       116
           6       0.95      0.83      0.89        24
           7       1.00      0.96      0.98        85
           8       0.96      0.92      0.94        24
           9       0.99      0.93      0.96       118
          10       1.00      1.00      1.00        30
          11       1.00      0.99      0.99        97
          12       0.98      0.98 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.8926868044515104
F1 Score: 0.9567970763483203
Precision: 0.9843144983208871
Recall: 0.9316197866149369
Hamming Loss: 0.005426142254786756
Average Loss: 0.03458684836514294
AUC-ROC: 0.9960811164135719
AUPR: 0.9826619704698923

Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.86      0.91        69
           1       1.00      0.83      0.91        29
           2       1.00      0.92      0.96        12
           3       1.00      0.94      0.97       249
           4       1.00      0.93      0.96        41
           5       1.00      0.98      0.99       116
           6       0.95      0.83      0.89        24
           7       1.00      0.96      0.98        85
           8       0.96      0.92      0.94        24
           9       0.99      0.93      0.96       118
          10       1.00      1.00      1.00        30
          11       1.00      0.99      0.99        97
          12       0.98      0.96  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.8926868044515104
F1 Score: 0.9567970763483203
Precision: 0.9843144983208871
Recall: 0.9316197866149369
Hamming Loss: 0.005426142254786756
Average Loss: 0.03458684836514294
AUC-ROC: 0.9960811164135719
AUPR: 0.9826619704698923

Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.86      0.91        69
           1       1.00      0.83      0.91        29
           2       1.00      0.92      0.96        12
           3       1.00      0.94      0.97       249
           4       1.00      0.93      0.96        41
           5       1.00      0.98      0.99       116
           6       0.95      0.83      0.89        24
           7       1.00      0.96      0.98        85
           8       0.96      0.92      0.94        24
           9       0.99      0.93      0.96       118
          10       1.00      1.00      1.00        30
          11       1.00      0.99      0.99        97
          12       0.98      0.96  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.8966613672496025
F1 Score: 0.958132670440013
Precision: 0.9838752451053424
Recall: 0.934529582929195
Hamming Loss: 0.005253335176608834
Average Loss: 0.0343077898491174
AUC-ROC: 0.9962164698565605
AUPR: 0.9830701717157347

Classification Report:
               precision    recall  f1-score   support

           0       0.97      0.87      0.92        69
           1       1.00      0.83      0.91        29
           2       1.00      0.92      0.96        12
           3       0.99      0.95      0.97       249
           4       1.00      0.93      0.96        41
           5       1.00      0.98      0.99       116
           6       0.95      0.83      0.89        24
           7       1.00      0.96      0.98        85
           8       0.96      0.92      0.94        24
           9       0.99      0.95      0.97       118
          10       1.00      1.00      1.00        30
          11       1.00      0.99      0.99        97
          12       0.98      0.96     

  model.load_state_dict(torch.load("ohsumed_bioBERT_LARRGE_CLS_32_best.bin"))


Accuracy: 0.47834619913646476
F1 Score: 0.743025961800741
Precision: 0.7797440532041131
Recall: 0.7134218173250608
Hamming Loss: 0.03443901495542952
Average Loss: 0.10353758462770962
AUC-ROC: 0.9390825463380256
AUPR: 0.7838317193432489

Classification Report:
               precision    recall  f1-score   support

           0       0.80      0.72      0.76       506
           1       0.67      0.39      0.50       233
           2       0.97      0.80      0.88        70
           3       0.88      0.86      0.87      1467
           4       0.75      0.76      0.75       429
           5       0.80      0.84      0.82       632
           6       0.77      0.66      0.71       146
           7       0.83      0.71      0.77       600
           8       0.70      0.76      0.73       129
           9       0.77      0.70      0.73       941
          10       0.84      0.76      0.79       202
          11       0.85      0.88      0.86       548
          12       0.78      0.76   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
