In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
MAX_LEN = 128
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
EPOCHS = 12
LEARNING_RATE = 1e-05
THRESHOLD = 0.5 # threshold for the sigmoid
## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len, target_list):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df['File Contents'])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        title = str(self.title[index])
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.FloatTensor(self.targets[index]),
            'title': title
        }

## Data
train_file = '/content/train.csv'
val_file = '/content/val.csv'
test_file = '/content/test.csv'
train_df = pd.read_csv(train_file)
val_df = pd.read_csv(val_file)
test_df = pd.read_csv(test_file)

target_list = list(train_df.columns[1:])

## Tokenizer

tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')


train_dataset = CustomDataset(train_df, tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(val_df, tokenizer, MAX_LEN, target_list)
test_dataset = CustomDataset(test_df, tokenizer, MAX_LEN, target_list)

#print(train_dataset[0])

## Data Loader
train_data_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

val_data_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_data_loader = torch.utils.data.DataLoader(test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)
## Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

## Model

class CTBERTInceptionAttention(nn.Module):
    def __init__(self, num_classes):
        super(CTBERTInceptionAttention, self).__init__()

        # Initialize HateBERT model
        self.bert = AutoModel.from_pretrained('digitalepidemiologylab/covid-twitter-bert', return_dict=True)

        # Dropout layer after HateBERT output
        self.dropout = nn.Dropout(0.3)

        # Inception block with fewer channels and smaller kernel sizes
        self.conv2 = nn.Conv1d(in_channels=1024, out_channels=16, kernel_size=2, padding=0)
        self.conv3 = nn.Conv1d(in_channels=1024, out_channels=16, kernel_size=3, padding=0)
        self.conv5 = nn.Conv1d(in_channels=1024, out_channels=16, kernel_size=5, padding=0)

        # Self-attention layer after Inception block
        self.attention = nn.MultiheadAttention(embed_dim=1072, num_heads=4, batch_first=True)  # 768 (BERT) + 48 (Inception)

        # Additional dense layer with LayerNorm for refined feature interaction
        self.dense = nn.Sequential(
            nn.Linear(1072, 512),
            nn.ReLU(),
            nn.LayerNorm(512)
        )

        # Final dropout and classification layer
        self.final_dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # HateBERT branch
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_states = outputs.last_hidden_state  # Shape: (batch_size, seq_length, 768)

        # Apply dropout to RoBERTa embeddings
        hidden_states = self.dropout(hidden_states)

        # Inception block with manual padding after convolutions
        hidden_states = hidden_states.permute(0, 2, 1)  # Shape: (batch_size, 768, seq_length)

        # Apply convolutions without padding, then pad manually
        conv2_output = F.pad(self.conv2(hidden_states), (0, 1))  # Padding to match max seq length
        conv3_output = F.pad(self.conv3(hidden_states), (1, 1))  # Adjust to max seq length
        conv5_output = F.pad(self.conv5(hidden_states), (2, 2))  # Adjust to max seq length

        # Concatenate along the channel dimension
        inception_output = torch.cat([conv2_output, conv3_output, conv5_output], dim=1)  # Shape: (batch_size, 128, seq_length)
        inception_output = inception_output.permute(0, 2, 1)  # Back to (batch_size, seq_length, 128)

        # Concatenate Inception outputs with original RoBERTa embeddings
        concatenated_features = torch.cat([hidden_states.permute(0, 2, 1), inception_output], dim=2)  # Shape: (batch_size, seq_length, 896)


        # Apply multi-head self-attention after gating
        key_padding_mask = ~attention_mask.bool()  # Shape: (batch_size, seq_length)
        attn_output, _ = self.attention(
            concatenated_features,
            concatenated_features,
            concatenated_features,
            key_padding_mask=key_padding_mask
        )  # Shape: (batch_size, seq_length, 896)

        # Global mean pooling over sequence length
        pooled_output = F.adaptive_avg_pool1d(attn_output.permute(0, 2, 1), output_size=1).squeeze(-1)

        # Additional dense layer with LayerNorm
        dense_output = self.dense(pooled_output)

        # Final dropout and classification layer
        dense_output = self.final_dropout(dense_output)
        logits = self.fc(dense_output)  # Shape: (batch_size, num_classes)

        return logits






## Setting the model
model = CTBERTInceptionAttention(num_classes=len(target_list))
model.to(device)

## Loss & Optimizer
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

# define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)


def train_model(training_loader, model, optimizer, accumulation_steps=1):  # Removed accumulation_steps
    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    model.train()

    for batch_idx, data in enumerate(training_loader):
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Training accuracy, apply sigmoid, round (apply threshold 0.5)
        outputs = torch.sigmoid(outputs).cpu().detach().numpy().round()
        targets = targets.cpu().detach().numpy()
        correct_predictions += np.sum(outputs == targets)
        num_samples += targets.size

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        optimizer.zero_grad()

        # Clear GPU cache
        torch.cuda.empty_cache()

    return model, float(correct_predictions) / num_samples, np.mean(losses)


def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    model.eval()
    final_targets = []
    final_outputs = []
    final_probs = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            # Forward pass
            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            probs = torch.sigmoid(outputs).cpu().detach().numpy()
            targets = targets.cpu().detach().numpy()
            final_outputs.extend(probs >= threshold)
            final_probs.extend(probs)
            final_targets.extend(targets)

            torch.cuda.empty_cache()

    final_outputs = np.array(final_outputs) >= threshold
    final_probs = np.array(final_probs)
    final_targets = np.array(final_targets)

    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted')
    precision = precision_score(final_targets, final_outputs, average='weighted')
    recall = recall_score(final_targets, final_outputs, average='weighted')
    hamming = hamming_loss(final_targets, final_outputs)

    auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list))

    return f1, average_loss


#Learning Rate Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Training & Evaluation Loop
start = time.time()

history = defaultdict(list)
best_f1 = 0.0

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_f1, val_loss = eval_model(val_data_loader, model)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Step scheduler after each epoch

    if val_f1 > best_f1:
        torch.save(model.state_dict(), "caves_inceptiveCTBERT_16_best.bin")
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
model = CTBERTInceptionAttention(num_classes=len(target_list))
model.load_state_dict(torch.load("caves_inceptiveCTBERT_16_best.bin"))
model = model.to(device)

start = time.time()
eval_model(test_data_loader, model)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/421 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.35G [00:00<?, ?B/s]

Epoch 1/12


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5379939209726444
F1 Score: 0.6493056543148418
Precision: 0.8163566253386522
Recall: 0.5685763888888888
Hamming Loss: 0.05867182462927144
Average Loss: 0.1613867189134321
AUC-ROC: 0.9317058584725552
AUPR: 0.782170978729099

Classification Report:
               precision    recall  f1-score   support

           0       0.73      0.22      0.34        49
           1       0.50      0.05      0.09        20
           2       0.83      0.57      0.68       167
           3       1.00      0.48      0.65        44
           4       0.87      0.62      0.72        78
           5       0.80      0.44      0.57       127
           6       0.80      0.13      0.22        63
           7       0.00      0.00      0.00         6
           8       0.68      0.75      0.71       147
           9       0.90      0.74      0.81       379
          10       0.69      0.33      0.45        72

   micro avg       0.82      0.57      0.67      1152
   macro avg       0.71      0.39    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5835866261398176
F1 Score: 0.7261870961223316
Precision: 0.7781664618515959
Recall: 0.6927083333333334
Hamming Loss: 0.052869116698903935
Average Loss: 0.14681897216266201
AUC-ROC: 0.9416548016592005
AUPR: 0.8067756590772996

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.47      0.52        49
           1       0.64      0.35      0.45        20
           2       0.81      0.63      0.71       167
           3       0.94      0.66      0.77        44
           4       0.87      0.74      0.80        78
           5       0.70      0.76      0.73       127
           6       0.68      0.33      0.45        63
           7       0.00      0.00      0.00         6
           8       0.70      0.78      0.74       147
           9       0.87      0.83      0.85       379
          10       0.66      0.43      0.52        72

   micro avg       0.78      0.69      0.74      1152
   macro avg       0.68      0.54 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5947315096251267
F1 Score: 0.7386685865305863
Precision: 0.7657556907639314
Recall: 0.7256944444444444
Hamming Loss: 0.05250069079856314
Average Loss: 0.15022634762910106
AUC-ROC: 0.9398393544106729
AUPR: 0.7847536467748791

Classification Report:
               precision    recall  f1-score   support

           0       0.64      0.47      0.54        49
           1       0.60      0.60      0.60        20
           2       0.70      0.78      0.74       167
           3       0.90      0.61      0.73        44
           4       0.84      0.74      0.79        78
           5       0.74      0.68      0.70       127
           6       0.68      0.37      0.47        63
           7       1.00      0.50      0.67         6
           8       0.70      0.81      0.75       147
           9       0.86      0.85      0.85       379
          10       0.65      0.43      0.52        72

   micro avg       0.77      0.73      0.75      1152
   macro avg       0.75      0.62  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 5/12
Accuracy: 0.6038500506585613
F1 Score: 0.7370792118563334
Precision: 0.767948260243302
Recall: 0.7161458333333334
Hamming Loss: 0.05250069079856314
Average Loss: 0.1536640258085343
AUC-ROC: 0.9381778095393857
AUPR: 0.7828030166485109

Classification Report:
               precision    recall  f1-score   support

           0       0.62      0.43      0.51        49
           1       0.67      0.50      0.57        20
           2       0.70      0.81      0.75       167
           3       0.86      0.70      0.78        44
           4       0.82      0.71      0.76        78
           5       0.76      0.68      0.72       127
           6       0.60      0.40      0.48        63
           7       1.00      0.50      0.67         6
           8       0.76      0.73      0.74       147
           9       0.86      0.84      0.85       379
          10       0.63      0.44      0.52        72

   micro avg       0.77      0.72      0.74      1152
   macro avg       0.75   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6008105369807497
F1 Score: 0.7375418563218399
Precision: 0.7623804595647932
Recall: 0.7213541666666666
Hamming Loss: 0.052869116698903935
Average Loss: 0.1581627515054518
AUC-ROC: 0.9348211332797087
AUPR: 0.7683138949165086

Classification Report:
               precision    recall  f1-score   support

           0       0.65      0.41      0.50        49
           1       0.67      0.60      0.63        20
           2       0.76      0.71      0.73       167
           3       0.88      0.68      0.77        44
           4       0.82      0.74      0.78        78
           5       0.76      0.68      0.72       127
           6       0.57      0.37      0.45        63
           7       1.00      0.50      0.67         6
           8       0.70      0.80      0.75       147
           9       0.84      0.85      0.85       379
          10       0.59      0.54      0.57        72

   micro avg       0.77      0.72      0.74      1152
   macro avg       0.75      0.63  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6048632218844985
F1 Score: 0.7406110162998771
Precision: 0.7552673557150337
Recall: 0.7326388888888888
Hamming Loss: 0.05305332964907433
Average Loss: 0.16088732956878601
AUC-ROC: 0.9347007961321153
AUPR: 0.7613777411655521

Classification Report:
               precision    recall  f1-score   support

           0       0.61      0.45      0.52        49
           1       0.65      0.55      0.59        20
           2       0.72      0.77      0.74       167
           3       0.89      0.70      0.78        44
           4       0.81      0.71      0.75        78
           5       0.75      0.70      0.72       127
           6       0.59      0.38      0.46        63
           7       1.00      0.50      0.67         6
           8       0.72      0.79      0.75       147
           9       0.85      0.86      0.86       379
          10       0.55      0.51      0.53        72

   micro avg       0.76      0.73      0.75      1152
   macro avg       0.74      0.63  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 8/12
Accuracy: 0.5987841945288754
F1 Score: 0.7396392316450235
Precision: 0.7533190981656749
Recall: 0.7326388888888888
Hamming Loss: 0.053329649074329924
Average Loss: 0.1635822055320586
AUC-ROC: 0.933224894229357
AUPR: 0.7589042577020855

Classification Report:
               precision    recall  f1-score   support

           0       0.63      0.45      0.52        49
           1       0.63      0.60      0.62        20
           2       0.69      0.77      0.73       167
           3       0.88      0.66      0.75        44
           4       0.80      0.78      0.79        78
           5       0.74      0.69      0.72       127
           6       0.58      0.40      0.47        63
           7       1.00      0.50      0.67         6
           8       0.73      0.78      0.75       147
           9       0.84      0.86      0.85       379
          10       0.61      0.51      0.56        72

   micro avg       0.76      0.73      0.74      1152
   macro avg       0.74  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.601823708206687
F1 Score: 0.7418506581226296
Precision: 0.7521628425989895
Recall: 0.7387152777777778
Hamming Loss: 0.05305332964907433
Average Loss: 0.16529147326946259
AUC-ROC: 0.9325249301035993
AUPR: 0.7571736545696836

Classification Report:
               precision    recall  f1-score   support

           0       0.63      0.45      0.52        49
           1       0.63      0.60      0.62        20
           2       0.71      0.75      0.73       167
           3       0.89      0.70      0.78        44
           4       0.83      0.76      0.79        78
           5       0.74      0.71      0.72       127
           6       0.60      0.38      0.47        63
           7       1.00      0.50      0.67         6
           8       0.71      0.80      0.75       147
           9       0.83      0.87      0.85       379
          10       0.58      0.53      0.55        72

   micro avg       0.76      0.74      0.75      1152
   macro avg       0.74      0.64   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 10/12
Accuracy: 0.6028368794326241
F1 Score: 0.7428966791939172
Precision: 0.7530285503902208
Recall: 0.7413194444444444
Hamming Loss: 0.052777010223818734
Average Loss: 0.1662094427212592
AUC-ROC: 0.9321424731352241
AUPR: 0.7550593022789501

Classification Report:
               precision    recall  f1-score   support

           0       0.62      0.43      0.51        49
           1       0.63      0.60      0.62        20
           2       0.71      0.77      0.74       167
           3       0.89      0.73      0.80        44
           4       0.83      0.77      0.80        78
           5       0.74      0.70      0.72       127
           6       0.65      0.38      0.48        63
           7       1.00      0.50      0.67         6
           8       0.71      0.80      0.75       147
           9       0.83      0.87      0.85       379
          10       0.58      0.53      0.55        72

   micro avg       0.76      0.74      0.75      1152
   macro avg       0.75

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 11/12
Accuracy: 0.6028368794326241
F1 Score: 0.7428966791939172
Precision: 0.7530285503902208
Recall: 0.7413194444444444
Hamming Loss: 0.052777010223818734
Average Loss: 0.1662094427212592
AUC-ROC: 0.9321424731352241
AUPR: 0.7550593022789501

Classification Report:
               precision    recall  f1-score   support

           0       0.62      0.43      0.51        49
           1       0.63      0.60      0.62        20
           2       0.71      0.77      0.74       167
           3       0.89      0.73      0.80        44
           4       0.83      0.77      0.80        78
           5       0.74      0.70      0.72       127
           6       0.65      0.38      0.48        63
           7       1.00      0.50      0.67         6
           8       0.71      0.80      0.75       147
           9       0.83      0.87      0.85       379
          10       0.58      0.53      0.55        72

   micro avg       0.76      0.74      0.75      1152
   macro avg       0.75

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.6048632218844985
F1 Score: 0.7438959085321062
Precision: 0.7500446692235551
Recall: 0.7456597222222222
Hamming Loss: 0.052869116698903935
Average Loss: 0.1673333608815747
AUC-ROC: 0.932286761151577
AUPR: 0.7548960301557022

Classification Report:
               precision    recall  f1-score   support

           0       0.63      0.45      0.52        49
           1       0.63      0.60      0.62        20
           2       0.70      0.77      0.73       167
           3       0.89      0.73      0.80        44
           4       0.82      0.77      0.79        78
           5       0.74      0.71      0.73       127
           6       0.62      0.38      0.47        63
           7       1.00      0.50      0.67         6
           8       0.70      0.80      0.75       147
           9       0.83      0.88      0.85       379
          10       0.59      0.53      0.56        72

   micro avg       0.75      0.75      0.75      1152
   macro avg       0.74      0.65   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Total training and evaluation time: 1461.7659289836884 seconds


Testing




  model.load_state_dict(torch.load("caves_CTBERTIncDNet16_32_base_best_final.bin"))


Accuracy: 0.6125442589782498
F1 Score: 0.7562036379523234
Precision: 0.7672779944575234
Recall: 0.7493506493506493
Hamming Loss: 0.05053570607440107
Average Loss: 0.16397577199724414
AUC-ROC: 0.9337180429636642
AUPR: 0.7666566862724324

Classification Report:
               precision    recall  f1-score   support

           0       0.67      0.53      0.59        97
           1       0.77      0.60      0.68        40
           2       0.72      0.80      0.76       334
           3       0.71      0.71      0.71        87
           4       0.75      0.70      0.73       157
           5       0.73      0.67      0.70       255
           6       0.64      0.55      0.59       125
           7       0.88      0.54      0.67        13
           8       0.75      0.81      0.78       295
           9       0.86      0.85      0.86       762
          10       0.71      0.57      0.63       145

   micro avg       0.77      0.75      0.76      2310
   macro avg       0.75      0.67  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
