In [None]:
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
# Sequence length handed to the tokenizer (inputs are padded/truncated to it).
MAX_LEN = 128
# The three splits share one batch size.
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 32
EPOCHS = 12
LEARNING_RATE = 1e-5
# Cut-off applied to sigmoid probabilities when turning them into 0/1 labels.
THRESHOLD = 0.5
## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized text with multi-label targets.

    Args:
        df: DataFrame holding the text column and one column per label.
        tokenizer: Callable tokenizer following the Hugging Face
            ``tokenizer(text, **kwargs)`` convention; its result must expose
            ``input_ids``, ``attention_mask`` and ``token_type_ids``.
        max_len: Length passed to the tokenizer as ``max_length`` (used
            together with ``padding='max_length'`` and ``truncation=True``).
        target_list: Names of the label columns in ``df``.
        text_column: Name of the column holding the raw text. Defaults to
            ``'File Contents'``, the column this script was written for.
    """

    def __init__(self, df, tokenizer, max_len, target_list,
                 text_column='File Contents'):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df[text_column])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        """Return the number of text rows."""
        return len(self.title)

    def __getitem__(self, index):
        """Tokenize row ``index`` and return its tensors, targets and text.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` (tokenizer outputs flattened to 1-D),
            ``targets`` (float tensor built from the label columns) and
            ``title`` (the whitespace-normalized text).
        """
        # Collapse every run of whitespace (including newlines) to one space.
        title = " ".join(str(self.title[index]).split())

        # Calling the tokenizer directly replaces the deprecated
        # ``encode_plus``; the keyword arguments are unchanged.
        inputs = self.tokenizer(
            title,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            # return_tensors='pt' adds a leading batch axis; flatten drops it.
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            # Explicit float dtype so integer 0/1 label columns also work
            # with the BCE-with-logits loss used by this script.
            'targets': torch.tensor(self.targets[index], dtype=torch.float),
            'title': title
        }

## Data
# CSV locations of the three splits.
train_file, val_file, test_file = (
    '/content/train.csv',
    '/content/val.csv',
    '/content/test.csv',
)
# Read the splits in train / val / test order.
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (train_file, val_file, test_file)
)

# Label names: every column of the training frame after the first one.
target_list = [*train_df.columns[1:]]

## Tokenizer
tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')

# One CustomDataset per split, all sharing tokenizer, MAX_LEN and label names.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(split_df, tokenizer, MAX_LEN, target_list)
    for split_df in (train_df, val_df, test_df)
)

# Debug aid: print(train_dataset[0])

## Data Loader
# Only the training loader shuffles; validation and test keep dataset order.
# num_workers=0 loads batches in the main process.
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=0,
)

val_data_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    batch_size=VALID_BATCH_SIZE,
    num_workers=0,
)

test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=TEST_BATCH_SIZE,
    num_workers=0,
)
## Device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Model



class CTBERT(torch.nn.Module):
    """Pretrained encoder followed by dropout and a linear multi-label head.

    Args:
        num_classes: Number of output logits (one per label).
        model_name: Hugging Face checkpoint loaded through
            ``AutoModel.from_pretrained``. Defaults to CT-BERT.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, num_classes,
                 model_name='digitalepidemiologylab/covid-twitter-bert',
                 dropout=0.3):
        super().__init__()
        # Load the pretrained encoder (CT-BERT by default).
        self.bert_model = AutoModel.from_pretrained(model_name, return_dict=True)
        self.dropout = torch.nn.Dropout(dropout)
        # Take the head's input width from the encoder config instead of
        # hard-coding 1024, so the head always matches the loaded checkpoint.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return raw logits for a batch (no sigmoid is applied here).

        Args:
            input_ids: Token ids forwarded to the encoder.
            attn_mask: Attention mask forwarded as ``attention_mask``.
            token_type_ids: Segment ids forwarded to the encoder.
        """
        output = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids
        )
        # The encoder's pooled output feeds the classification head.
        output_dropout = self.dropout(output.pooler_output)
        output = self.linear(output_dropout)
        return output



## Setting the model
# One logit per label column; Module.to() returns the module itself, so the
# construction and the device move can be chained.
model = CTBERT(num_classes=len(target_list)).to(device)

## Loss & Optimizer
def loss_fn(outputs, targets):
    """Return the binary cross-entropy between raw logits and 0/1 targets.

    ``outputs`` are un-normalized logits; BCEWithLogitsLoss applies the
    sigmoid internally and averages over all elements (default reduction).
    """
    criterion = torch.nn.BCEWithLogitsLoss()
    return criterion(outputs, targets)

# Optimizer: AdamW over every model parameter.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)

## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=1):
    """Run one training pass and report element-wise accuracy and mean loss.

    Uses the module-level ``device`` and ``loss_fn``.

    Args:
        training_loader: Iterable of batches; each batch is a mapping with
            ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``targets`` tensors.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns logits with the same shape as ``targets``.
        optimizer: Optimizer over ``model``'s parameters.
        accumulation_steps: Number of consecutive batches whose gradients are
            summed before one optimizer step. Each loss is divided by this
            value before ``backward``. A trailing group with fewer batches
            is still applied at the end of the pass. The default of 1 steps
            after every batch.

    Returns:
        ``(model, accuracy, mean_loss)`` where ``accuracy`` is the fraction
        of individual label entries predicted correctly (sigmoid output
        rounded to 0/1) and ``mean_loss`` is the mean of the per-batch losses.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1 or the loader
            yields no batches.
    """
    if accumulation_steps < 1:
        raise ValueError(
            f"accumulation_steps must be >= 1, got {accumulation_steps}"
        )

    losses = []
    correct_predictions = 0
    num_samples = 0
    pending_batches = 0  # batches back-propagated since the last step

    def _apply_update():
        # Clip the accumulated gradient norm, step, then reset gradients.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    model.train()
    optimizer.zero_grad()  # never start a pass on stale gradients

    for data in training_loader:
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Training accuracy: sigmoid, then round (i.e. a 0.5 cut-off).
        # Bookkeeping only, so keep it out of the autograd graph.
        with torch.no_grad():
            predictions = torch.sigmoid(outputs).round()
            correct_predictions += int((predictions == targets).sum().item())
            num_samples += targets.numel()

        # Scale so the summed gradient of a full group matches the gradient
        # of the group's average loss.
        (loss / accumulation_steps).backward()
        pending_batches += 1

        if pending_batches == accumulation_steps:
            _apply_update()
            pending_batches = 0

    # Apply whatever is left from an incomplete final group.
    if pending_batches:
        _apply_update()

    if not losses:
        raise ValueError("training_loader yielded no batches")

    # Release cached GPU blocks once per pass rather than once per batch.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, float(correct_predictions) / num_samples, np.mean(losses)


## Evaluator Function
def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate ``model`` on a loader, print multi-label metrics, return F1 and loss.

    Uses the module-level ``device`` and ``loss_fn``.

    Args:
        validation_loader: Iterable of batches; each batch is a mapping with
            ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``targets`` tensors.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns logits.
        threshold: Sigmoid probabilities ``>= threshold`` count as positive.
        target_list: Optional label names shown in the classification
            report; when ``None`` the report falls back to label indices.

    Returns:
        ``(f1, average_loss)``: the weighted-average F1 score and the mean of
        the per-batch losses.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    batch_probs = []
    batch_targets = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            # Forward pass
            outputs = model(ids, mask, token_type_ids)
            losses.append(loss_fn(outputs, targets).item())

            batch_probs.append(torch.sigmoid(outputs).cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

    if not losses:
        raise ValueError("validation_loader yielded no batches")

    # Release cached GPU blocks once per call rather than once per batch.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    final_probs = np.concatenate(batch_probs, axis=0)
    final_targets = np.concatenate(batch_targets, axis=0)
    # Threshold exactly once, on the raw probabilities.
    final_outputs = final_probs >= threshold

    # zero_division=0 yields the same values as scikit-learn's default for
    # labels that are never predicted, without emitting a warning per call.
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    # Ranking metrics use the probabilities, not the thresholded labels.
    # (multi_class is only consulted for multiclass targets, so it is omitted.)
    auc_roc = roc_auc_score(final_targets, final_probs, average='weighted')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(
        final_targets, final_outputs, target_names=target_list, zero_division=0))

    return f1, average_loss


# Learning Rate Scheduler
# Anneal over exactly EPOCHS scheduler steps. With a shorter T_max (it was 10
# for 12 epochs) the rate reaches eta_min (0 by default) before training ends,
# so one epoch runs with a zero learning rate and the rate then rises again.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Training & Evaluation Loop
start = time.time()

history = defaultdict(list)
# Start below any attainable F1 so the first epoch always writes a
# checkpoint, even if the validation F1 never rises above 0.
best_f1 = float('-inf')

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    # Use the configured threshold and show label names in the report.
    val_f1, val_loss = eval_model(
        val_data_loader, model, threshold=THRESHOLD, target_list=target_list
    )

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Step scheduler after each epoch

    # Keep only the weights with the best validation F1 seen so far.
    if val_f1 > best_f1:
        torch.save(model.state_dict(), "caves_CTBERT_best.bin")
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
model = CTBERT(num_classes=len(target_list))
# map_location loads the checkpoint onto the current device even if it was
# saved elsewhere; weights_only=True limits unpickling to tensor data, which
# is all a state_dict contains.
best_state = torch.load("caves_CTBERT_best.bin", map_location=device, weights_only=True)
model.load_state_dict(best_state)
model = model.to(device)

start = time.time()
# Same threshold and label names as used during validation.
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=target_list)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/421 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.35G [00:00<?, ?B/s]

Epoch 1/12


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5035460992907801
F1 Score: 0.5806338867269609
Precision: 0.7584333952897363
Recall: 0.5095486111111112
Hamming Loss: 0.06235608363267938
Average Loss: 0.1896053442070561
AUC-ROC: 0.9098187539796538
AUPR: 0.7246612523498058

Classification Report:
               precision    recall  f1-score   support

           0       0.78      0.14      0.24        49
           1       0.00      0.00      0.00        20
           2       0.81      0.60      0.69       167
           3       0.00      0.00      0.00        44
           4       0.78      0.58      0.66        78
           5       0.69      0.74      0.71       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.79      0.51      0.62       147
           9       0.96      0.70      0.81       379
          10       1.00      0.01      0.03        72

   micro avg       0.84      0.51      0.63      1152
   macro avg       0.53      0.30   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.585612968591692
F1 Score: 0.6961152325529797
Precision: 0.7875949834027441
Recall: 0.6484375
Hamming Loss: 0.05369807497467072
Average Loss: 0.15658142970454308
AUC-ROC: 0.9343218951297961
AUPR: 0.7849512009351999

Classification Report:
               precision    recall  f1-score   support

           0       0.62      0.20      0.31        49
           1       0.00      0.00      0.00        20
           2       0.83      0.65      0.73       167
           3       1.00      0.50      0.67        44
           4       0.88      0.65      0.75        78
           5       0.73      0.63      0.68       127
           6       0.79      0.17      0.29        63
           7       0.00      0.00      0.00         6
           8       0.75      0.78      0.76       147
           9       0.87      0.83      0.85       379
          10       0.60      0.50      0.55        72

   micro avg       0.81      0.65      0.72      1152
   macro avg       0.64      0.45      0.51  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6099290780141844
F1 Score: 0.7326791313611047
Precision: 0.7996937361184382
Recall: 0.6840277777777778
Hamming Loss: 0.05065856129685917
Average Loss: 0.14621644539217796
AUC-ROC: 0.938700779885375
AUPR: 0.7918029386343081

Classification Report:
               precision    recall  f1-score   support

           0       0.70      0.43      0.53        49
           1       0.73      0.40      0.52        20
           2       0.81      0.68      0.74       167
           3       0.88      0.64      0.74        44
           4       0.90      0.69      0.78        78
           5       0.74      0.64      0.68       127
           6       0.76      0.40      0.52        63
           7       0.00      0.00      0.00         6
           8       0.74      0.73      0.74       147
           9       0.88      0.84      0.86       379
          10       0.62      0.44      0.52        72

   micro avg       0.81      0.68      0.74      1152
   macro avg       0.70      0.54   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5987841945288754
F1 Score: 0.7209280967852405
Precision: 0.7732806582422469
Recall: 0.6892361111111112
Hamming Loss: 0.053421755549415126
Average Loss: 0.15064434899437812
AUC-ROC: 0.9365410759917481
AUPR: 0.7847325600686182

Classification Report:
               precision    recall  f1-score   support

           0       0.64      0.43      0.51        49
           1       0.62      0.25      0.36        20
           2       0.75      0.75      0.75       167
           3       0.79      0.70      0.75        44
           4       0.83      0.69      0.76        78
           5       0.70      0.66      0.68       127
           6       0.81      0.33      0.47        63
           7       0.00      0.00      0.00         6
           8       0.72      0.71      0.72       147
           9       0.87      0.84      0.85       379
          10       0.68      0.42      0.52        72

   micro avg       0.78      0.69      0.73      1152
   macro avg       0.67      0.53 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6068895643363729
F1 Score: 0.7300787667751735
Precision: 0.7538572266570207
Recall: 0.7178819444444444
Hamming Loss: 0.05351386202450032
Average Loss: 0.15606019790134124
AUC-ROC: 0.9362258953036969
AUPR: 0.7733314286661869

Classification Report:
               precision    recall  f1-score   support

           0       0.68      0.43      0.53        49
           1       0.75      0.45      0.56        20
           2       0.74      0.75      0.75       167
           3       0.88      0.64      0.74        44
           4       0.82      0.72      0.77        78
           5       0.67      0.69      0.68       127
           6       0.59      0.38      0.46        63
           7       0.00      0.00      0.00         6
           8       0.72      0.75      0.74       147
           9       0.83      0.88      0.86       379
          10       0.68      0.44      0.54        72

   micro avg       0.76      0.72      0.74      1152
   macro avg       0.67      0.56  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5977710233029382
F1 Score: 0.7321388391397178
Precision: 0.7449539522855203
Recall: 0.7248263888888888
Hamming Loss: 0.0548954591507783
Average Loss: 0.16525652499929552
AUC-ROC: 0.9326697022072832
AUPR: 0.7595437722397705

Classification Report:
               precision    recall  f1-score   support

           0       0.56      0.45      0.50        49
           1       0.62      0.50      0.56        20
           2       0.73      0.73      0.73       167
           3       0.80      0.73      0.76        44
           4       0.80      0.71      0.75        78
           5       0.67      0.69      0.68       127
           6       0.56      0.44      0.50        63
           7       1.00      0.50      0.67         6
           8       0.71      0.77      0.74       147
           9       0.84      0.87      0.85       379
          10       0.69      0.49      0.57        72

   micro avg       0.75      0.72      0.74      1152
   macro avg       0.73      0.62   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6008105369807497
F1 Score: 0.735870965878094
Precision: 0.7559680211976514
Recall: 0.7274305555555556
Hamming Loss: 0.05360596849958552
Average Loss: 0.16590229373785756
AUC-ROC: 0.9321570390615275
AUPR: 0.7582302936159854

Classification Report:
               precision    recall  f1-score   support

           0       0.56      0.47      0.51        49
           1       0.61      0.55      0.58        20
           2       0.70      0.78      0.74       167
           3       0.86      0.70      0.78        44
           4       0.79      0.74      0.77        78
           5       0.72      0.65      0.68       127
           6       0.69      0.35      0.46        63
           7       1.00      0.50      0.67         6
           8       0.69      0.78      0.73       147
           9       0.85      0.86      0.85       379
          10       0.69      0.50      0.58        72

   micro avg       0.76      0.73      0.74      1152
   macro avg       0.74      0.63   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 9/12
Accuracy: 0.5886524822695035
F1 Score: 0.7295128597846752
Precision: 0.7398931109438551
Recall: 0.7265625
Hamming Loss: 0.05572441742654509
Average Loss: 0.16836591329305403
AUC-ROC: 0.9313741291183634
AUPR: 0.7553554952557588

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.45      0.51        49
           1       0.62      0.50      0.56        20
           2       0.69      0.78      0.73       167
           3       0.86      0.70      0.78        44
           4       0.79      0.73      0.76        78
           5       0.68      0.66      0.67       127
           6       0.62      0.38      0.47        63
           7       1.00      0.50      0.67         6
           8       0.70      0.78      0.73       147
           9       0.84      0.85      0.85       379
          10       0.59      0.53      0.56        72

   micro avg       0.74      0.73      0.73      1152
   macro avg       0.73      0.62

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5927051671732523
F1 Score: 0.7331315163324478
Precision: 0.7471182134172166
Recall: 0.7274305555555556
Hamming Loss: 0.0546191397255227
Average Loss: 0.1684319787929135
AUC-ROC: 0.9310970575818964
AUPR: 0.7549894764502862

Classification Report:
               precision    recall  f1-score   support

           0       0.61      0.47      0.53        49
           1       0.62      0.50      0.56        20
           2       0.70      0.77      0.73       167
           3       0.84      0.70      0.77        44
           4       0.80      0.72      0.76        78
           5       0.70      0.66      0.68       127
           6       0.65      0.38      0.48        63
           7       1.00      0.50      0.67         6
           8       0.70      0.78      0.73       147
           9       0.84      0.86      0.85       379
          10       0.63      0.53      0.58        72

   micro avg       0.75      0.73      0.74      1152
   macro avg       0.73      0.62    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5927051671732523
F1 Score: 0.7331315163324478
Precision: 0.7471182134172166
Recall: 0.7274305555555556
Hamming Loss: 0.0546191397255227
Average Loss: 0.1684319787929135
AUC-ROC: 0.9310970575818964
AUPR: 0.7549894764502862

Classification Report:
               precision    recall  f1-score   support

           0       0.61      0.47      0.53        49
           1       0.62      0.50      0.56        20
           2       0.70      0.77      0.73       167
           3       0.84      0.70      0.77        44
           4       0.80      0.72      0.76        78
           5       0.70      0.66      0.68       127
           6       0.65      0.38      0.48        63
           7       1.00      0.50      0.67         6
           8       0.70      0.78      0.73       147
           9       0.84      0.86      0.85       379
          10       0.63      0.53      0.58        72

   micro avg       0.75      0.73      0.74      1152
   macro avg       0.73      0.62    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5886524822695035
F1 Score: 0.7309472814711141
Precision: 0.7401806533997076
Recall: 0.7291666666666666
Hamming Loss: 0.05544809800128949
Average Loss: 0.1695630033650706
AUC-ROC: 0.930544101797167
AUPR: 0.7529184559716603

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.47      0.52        49
           1       0.62      0.50      0.56        20
           2       0.70      0.77      0.73       167
           3       0.84      0.70      0.77        44
           4       0.80      0.72      0.76        78
           5       0.71      0.66      0.68       127
           6       0.63      0.38      0.48        63
           7       1.00      0.50      0.67         6
           8       0.69      0.78      0.73       147
           9       0.84      0.87      0.85       379
          10       0.58      0.53      0.55        72

   micro avg       0.74      0.73      0.74      1152
   macro avg       0.73      0.63    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  model.load_state_dict(torch.load("caves_CTBERT_128_32_base_best.bin"))


Accuracy: 0.6064744562468386
F1 Score: 0.7466974702597419
Precision: 0.766829959480748
Recall: 0.7350649350649351
Hamming Loss: 0.0521911068193314
Average Loss: 0.1618671289855434
AUC-ROC: 0.9316296378393872
AUPR: 0.7664330895451599

Classification Report:
               precision    recall  f1-score   support

           0       0.65      0.57      0.60        97
           1       0.81      0.62      0.70        40
           2       0.69      0.79      0.74       334
           3       0.69      0.72      0.71        87
           4       0.74      0.72      0.73       157
           5       0.74      0.66      0.70       255
           6       0.70      0.52      0.60       125
           7       1.00      0.31      0.47        13
           8       0.71      0.78      0.75       295
           9       0.87      0.83      0.85       762
          10       0.71      0.51      0.59       145

   micro avg       0.76      0.74      0.75      2310
   macro avg       0.76      0.64     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
