In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import RobertaModel, RobertaTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
# Length (in tokens) every sample is padded or truncated to by the tokenizer.
MAX_LEN = 128
# Batch sizes of the training / validation / test DataLoaders.
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
# Number of train + validate passes executed by the main loop.
EPOCHS = 12
# Learning rate handed to the AdamW optimizer.
LEARNING_RATE = 1e-05
THRESHOLD = 0.5 # decision threshold intended for the sigmoid probabilities
## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Multi-label text dataset that tokenizes one DataFrame row per item.

    Args:
        df: DataFrame holding the text column and one column per label.
        tokenizer: Tokenizer exposing ``encode_plus`` (e.g. ``RobertaTokenizer``).
        max_len: Length every sample is padded or truncated to.
        target_list: Names of the label columns in ``df``.
        text_column: Name of the column holding the raw text. Defaults to
            ``'File Contents'`` so existing four-argument callers keep working.
    """

    def __init__(self, df, tokenizer, max_len, target_list,
                 text_column='File Contents'):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df[text_column])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples (rows of the text column)."""
        return len(self.title)

    def __getitem__(self, index):
        """Tokenize sample ``index`` and return its tensors, labels and text.

        Returns:
            dict with 1-D ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` tensors, a float32 ``targets`` tensor and the
            whitespace-normalized text under ``title``.
        """
        # Collapse every whitespace run (newlines, tabs, ...) to a single space.
        title = " ".join(str(self.title[index]).split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Explicit float32 conversion: works for int, bool and object-typed
        # label rows alike, and torch.tensor() copies so the returned tensor
        # never aliases the dataset's own label array.
        targets = torch.tensor(
            np.asarray(self.targets[index], dtype=np.float32)
        )
        return {
            # return_tensors='pt' yields a leading batch axis; flatten drops it.
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': targets,
            'title': title
        }

## Data
# Locations of the three CSV splits.
train_file = '/content/train.csv'
val_file = '/content/val.csv'
test_file = '/content/test.csv'
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (train_file, val_file, test_file)
)

# Every column after the first one is treated as a label column.
target_list = train_df.columns[1:].tolist()

## Tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


# One dataset per split, all sharing the tokenizer, max length and label set.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(split_df, tokenizer, MAX_LEN, target_list)
    for split_df in (train_df, val_df, test_df)
)

#print(train_dataset[0])

## Data Loader
# Only the training split is shuffled; validation and test keep dataset order.
# num_workers=0 loads batches in the main process.
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

val_data_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)
## Device
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Model

class RobertaBase(nn.Module):
    """RoBERTa encoder with a dropout + linear head emitting one logit per class.

    Args:
        num_classes: Number of output logits produced by the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')

        # Dropout layer
        self.drop = nn.Dropout(0.3)

        # Fully connected layer for classification. The input width is read
        # from the encoder config instead of being hard-coded to 768, so the
        # head always matches the loaded checkpoint.
        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return raw (pre-sigmoid) logits of shape (batch_size, num_classes).

        ``token_type_ids`` is accepted so callers can pass it positionally,
        but it is not forwarded to the encoder.
        """
        # RoBERTa features
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Representation of the first token, used as the sequence summary
        # (the [CLS] position).  (batch_size, hidden_size)
        cls_token = outputs.last_hidden_state[:, 0, :]

        # Dropout is applied to the selected vector only: the remaining
        # sequence positions never reach the head, so masking them was
        # wasted work.
        cls_token = self.drop(cls_token)

        # Final classification  (batch_size, num_classes)
        return self.fc(cls_token)


## Setting the model
# One output logit per label column; nn.Module.to() returns the module itself,
# so construction and device placement can be chained.
model = RobertaBase(num_classes=len(target_list)).to(device)

## Loss & Optimizer
def loss_fn(outputs, targets):
    """Return the mean binary cross-entropy between raw logits and targets.

    ``outputs`` are pre-sigmoid logits; the sigmoid is applied inside the
    loss module.
    """
    criterion = nn.BCEWithLogitsLoss()
    return criterion(outputs, targets)

# AdamW over every model parameter, with a small weight decay.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)

## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=1):
    """Run one training epoch and return ``(model, accuracy, mean_loss)``.

    Relies on the module-level ``device`` and ``loss_fn``.

    Args:
        training_loader: Sized iterable of batches with the keys
            ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``targets``.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns raw logits.
        optimizer: Optimizer updating ``model``'s parameters.
        accumulation_steps: Number of consecutive batches whose gradients are
            accumulated before each optimizer step. The default of 1 steps
            after every batch.

    Returns:
        Tuple of the model, the element-wise accuracy of the thresholded
        (0.5) sigmoid outputs against the targets, and the mean per-batch
        loss. For an empty loader the accuracy is 0.0 and the loss is NaN.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be a positive integer")

    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    model.train()
    # Start from clean gradients so nothing left over from an earlier call
    # leaks into the first update.
    optimizer.zero_grad()

    for batch_idx, data in enumerate(training_loader, start=1):

        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        # Log the unscaled loss so reported values do not depend on
        # accumulation_steps.
        losses.append(loss.item())

        # Training accuracy: sigmoid then ">= 0.5", the same comparison that
        # eval_model applies, computed without tracking gradients.
        with torch.no_grad():
            predictions = (torch.sigmoid(outputs) >= 0.5).to(targets.dtype)
            correct_predictions += (predictions == targets).sum().item()
            num_samples += targets.numel()

        # Scale so the accumulated gradient is an average over the window
        # (a trailing, shorter window is scaled by the same factor).
        (loss / accumulation_steps).backward()

        # Update at the end of each window and on the final batch, so a
        # trailing partial window is not dropped.
        if batch_idx % accumulation_steps == 0 or batch_idx == total_batches:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    # Release cached GPU blocks once per epoch; doing it after every batch
    # only slows training down.
    torch.cuda.empty_cache()

    accuracy = float(correct_predictions) / num_samples if num_samples else 0.0
    mean_loss = np.mean(losses) if losses else float('nan')
    return model, accuracy, mean_loss



def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate ``model`` on a loader, print the metrics and return two of them.

    Relies on the module-level ``device`` and ``loss_fn``.

    Args:
        validation_loader: Iterable of batches with the keys ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``targets``.
        model: Module called as ``model(ids, mask, token_type_ids)`` that
            returns raw logits.
        threshold: Sigmoid probabilities ``>= threshold`` count as positive.
        target_list: Optional label names for the classification report;
            with ``None`` the report uses sklearn's default naming.

    Returns:
        Tuple ``(weighted F1, mean per-batch loss)``.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    batch_probs = []
    batch_targets = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            # Forward pass
            outputs = model(ids, mask, token_type_ids)
            losses.append(loss_fn(outputs, targets).item())

            batch_probs.append(torch.sigmoid(outputs).cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

    if not batch_probs:
        raise ValueError("validation_loader yielded no batches")

    final_probs = np.concatenate(batch_probs)
    final_targets = np.concatenate(batch_targets)
    # Threshold exactly once, on the probabilities.
    final_outputs = final_probs >= threshold

    # zero_division=0 yields the same 0.0 score sklearn falls back to for
    # labels without predictions, but without one warning per metric call.
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    try:
        auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    except ValueError:
        # ROC-AUC is undefined when a label has a single class in this split;
        # report NaN instead of aborting the whole evaluation.
        auc_roc = float('nan')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list, zero_division=0))

    return f1, average_loss



# Cosine schedule spanning the whole run. The scheduler is stepped once per
# epoch, so T_max has to equal EPOCHS: with a shorter period the learning
# rate reaches zero before the last epochs and an entire epoch is trained
# without changing the weights.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Training & Evaluation Loop
BEST_MODEL_PATH = "caves_roberta_32_best_final.bin"

start = time.time()

history = defaultdict(list)
# -inf guarantees the first epoch writes a checkpoint, so the load in the
# testing section below cannot fail even if validation F1 stays at 0.
best_f1 = float('-inf')

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    # Pass the configured threshold and the label names so the report shows
    # real label names instead of positional indices.
    val_f1, val_loss = eval_model(
        val_data_loader, model, threshold=THRESHOLD, target_list=target_list
    )

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Step scheduler after each epoch

    # Keep only the weights of the best epoch by validation F1.
    if val_f1 > best_f1:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
model = RobertaBase(num_classes=len(target_list))
# weights_only=True restricts unpickling to tensors/containers, and
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(
    torch.load(BEST_MODEL_PATH, map_location=device, weights_only=True)
)
model = model.to(device)

start = time.time()
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=target_list)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/12


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.414387031408308
F1 Score: 0.44435147964202926
Precision: 0.5578402017170582
Recall: 0.3967013888888889
Hamming Loss: 0.07626416137054436
Average Loss: 0.21791814123430558
AUC-ROC: 0.8703359312933248
AUPR: 0.6278748580238175

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.61      0.71      0.65       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.76      0.25      0.38       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.72      0.24      0.37       147
           9       0.89      0.72      0.79       379
          10       0.00      0.00      0.00        72

   micro avg       0.77      0.40      0.52      1152
   macro avg       0.27      0.17  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.49341438703140833
F1 Score: 0.5906586084443995
Precision: 0.7458586970826071
Recall: 0.5269097222222222
Hamming Loss: 0.06382978723404255
Average Loss: 0.18251887877141276
AUC-ROC: 0.9101615941017767
AUPR: 0.7059443880339271

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.74      0.62      0.67       167
           3       0.79      0.25      0.38        44
           4       0.85      0.42      0.56        78
           5       0.71      0.52      0.60       127
           6       0.40      0.06      0.11        63
           7       0.00      0.00      0.00         6
           8       0.69      0.64      0.66       147
           9       0.91      0.77      0.83       379
          10       1.00      0.08      0.15        72

   micro avg       0.80      0.53      0.64      1152
   macro avg       0.55      0.31 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.513677811550152
F1 Score: 0.630594571587192
Precision: 0.7332680232342585
Recall: 0.5946180555555556
Hamming Loss: 0.06254029658284978
Average Loss: 0.16870342867989693
AUC-ROC: 0.9205469742870195
AUPR: 0.7461461852500484

Classification Report:
               precision    recall  f1-score   support

           0       0.67      0.08      0.15        49
           1       0.00      0.00      0.00        20
           2       0.72      0.65      0.68       167
           3       0.90      0.20      0.33        44
           4       0.84      0.53      0.65        78
           5       0.72      0.56      0.63       127
           6       0.57      0.13      0.21        63
           7       0.00      0.00      0.00         6
           8       0.69      0.66      0.68       147
           9       0.82      0.86      0.84       379
          10       0.63      0.31      0.41        72

   micro avg       0.76      0.59      0.67      1152
   macro avg       0.60      0.36    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5491388044579534
F1 Score: 0.6719835918183139
Precision: 0.74991198834918
Recall: 0.6267361111111112
Hamming Loss: 0.05821129225384544
Average Loss: 0.1594374634085163
AUC-ROC: 0.9264650062743596
AUPR: 0.7689934198139188

Classification Report:
               precision    recall  f1-score   support

           0       0.64      0.18      0.29        49
           1       0.00      0.00      0.00        20
           2       0.76      0.66      0.71       167
           3       0.79      0.59      0.68        44
           4       0.84      0.60      0.70        78
           5       0.76      0.56      0.64       127
           6       0.65      0.24      0.35        63
           7       0.00      0.00      0.00         6
           8       0.71      0.61      0.66       147
           9       0.85      0.85      0.85       379
          10       0.60      0.40      0.48        72

   micro avg       0.78      0.63      0.70      1152
   macro avg       0.60      0.43     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5582573454913881
F1 Score: 0.689241184711692
Precision: 0.7762080738480632
Recall: 0.6432291666666666
Hamming Loss: 0.05655337570231187
Average Loss: 0.1591108381267517
AUC-ROC: 0.9281214089281193
AUPR: 0.7751983029743861

Classification Report:
               precision    recall  f1-score   support

           0       0.77      0.20      0.32        49
           1       0.83      0.25      0.38        20
           2       0.73      0.72      0.72       167
           3       0.82      0.61      0.70        44
           4       0.85      0.60      0.71        78
           5       0.75      0.60      0.66       127
           6       0.60      0.29      0.39        63
           7       0.00      0.00      0.00         6
           8       0.68      0.69      0.69       147
           9       0.87      0.82      0.85       379
          10       0.73      0.33      0.46        72

   micro avg       0.78      0.64      0.71      1152
   macro avg       0.69      0.47    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5643363728470111
F1 Score: 0.6971870065541466
Precision: 0.7608459974216779
Recall: 0.6579861111111112
Hamming Loss: 0.056921801602652665
Average Loss: 0.15805932567004236
AUC-ROC: 0.9277512876803071
AUPR: 0.7763064937047974

Classification Report:
               precision    recall  f1-score   support

           0       0.69      0.22      0.34        49
           1       0.89      0.40      0.55        20
           2       0.77      0.66      0.71       167
           3       0.85      0.66      0.74        44
           4       0.82      0.65      0.73        78
           5       0.71      0.65      0.67       127
           6       0.62      0.29      0.39        63
           7       0.00      0.00      0.00         6
           8       0.73      0.66      0.69       147
           9       0.85      0.84      0.84       379
          10       0.53      0.47      0.50        72

   micro avg       0.77      0.66      0.71      1152
   macro avg       0.68      0.50 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5754812563323202
F1 Score: 0.7110874248519212
Precision: 0.7586708942948442
Recall: 0.6779513888888888
Hamming Loss: 0.05554020447637469
Average Loss: 0.15720542279943342
AUC-ROC: 0.9291360030030784
AUPR: 0.7782528461496884

Classification Report:
               precision    recall  f1-score   support

           0       0.61      0.39      0.47        49
           1       0.89      0.40      0.55        20
           2       0.76      0.69      0.72       167
           3       0.84      0.70      0.77        44
           4       0.86      0.69      0.77        78
           5       0.73      0.63      0.68       127
           6       0.58      0.35      0.44        63
           7       0.00      0.00      0.00         6
           8       0.72      0.69      0.70       147
           9       0.83      0.85      0.84       379
          10       0.64      0.42      0.50        72

   micro avg       0.77      0.68      0.72      1152
   macro avg       0.68      0.53  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.569402228976697
F1 Score: 0.7055123014075962
Precision: 0.7479953709309548
Recall: 0.6762152777777778
Hamming Loss: 0.05701390807773787
Average Loss: 0.1602559680900266
AUC-ROC: 0.9276271898085783
AUPR: 0.7749916265053498

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.35      0.44        49
           1       0.80      0.40      0.53        20
           2       0.75      0.68      0.71       167
           3       0.87      0.59      0.70        44
           4       0.84      0.69      0.76        78
           5       0.74      0.65      0.69       127
           6       0.57      0.38      0.46        63
           7       0.00      0.00      0.00         6
           8       0.74      0.65      0.70       147
           9       0.81      0.86      0.84       379
          10       0.56      0.44      0.50        72

   micro avg       0.76      0.68      0.72      1152
   macro avg       0.66      0.52    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5775075987841946
F1 Score: 0.7059591199874851
Precision: 0.758794205656052
Recall: 0.6727430555555556
Hamming Loss: 0.05572441742654509
Average Loss: 0.15927291709569194
AUC-ROC: 0.9280482957006781
AUPR: 0.7767831484754527

Classification Report:
               precision    recall  f1-score   support

           0       0.58      0.29      0.38        49
           1       0.89      0.40      0.55        20
           2       0.75      0.69      0.72       167
           3       0.90      0.64      0.75        44
           4       0.85      0.72      0.78        78
           5       0.75      0.61      0.68       127
           6       0.59      0.38      0.46        63
           7       0.00      0.00      0.00         6
           8       0.75      0.65      0.70       147
           9       0.81      0.87      0.84       379
          10       0.63      0.38      0.47        72

   micro avg       0.77      0.67      0.72      1152
   macro avg       0.68      0.51   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5683890577507599
F1 Score: 0.7049133377586694
Precision: 0.7458100709736084
Recall: 0.6770833333333334
Hamming Loss: 0.05710601455282306
Average Loss: 0.159194057747241
AUC-ROC: 0.9278896871884899
AUPR: 0.7762556807446318

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.33      0.42        49
           1       0.80      0.40      0.53        20
           2       0.74      0.69      0.72       167
           3       0.88      0.64      0.74        44
           4       0.84      0.68      0.75        78
           5       0.73      0.65      0.69       127
           6       0.55      0.33      0.42        63
           7       0.00      0.00      0.00         6
           8       0.72      0.71      0.71       147
           9       0.83      0.85      0.84       379
          10       0.55      0.43      0.48        72

   micro avg       0.76      0.68      0.72      1152
   macro avg       0.66      0.52    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.5683890577507599
F1 Score: 0.7049133377586694
Precision: 0.7458100709736084
Recall: 0.6770833333333334
Hamming Loss: 0.05710601455282306
Average Loss: 0.159194057747241
AUC-ROC: 0.9278896871884899
AUPR: 0.7762556807446318

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.33      0.42        49
           1       0.80      0.40      0.53        20
           2       0.74      0.69      0.72       167
           3       0.88      0.64      0.74        44
           4       0.84      0.68      0.75        78
           5       0.73      0.65      0.69       127
           6       0.55      0.33      0.42        63
           7       0.00      0.00      0.00         6
           8       0.72      0.71      0.71       147
           9       0.83      0.85      0.84       379
          10       0.55      0.43      0.48        72

   micro avg       0.76      0.68      0.72      1152
   macro avg       0.66      0.52    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.569402228976697
F1 Score: 0.7108419348510346
Precision: 0.7481835704616963
Recall: 0.6848958333333334
Hamming Loss: 0.05627705627705628
Average Loss: 0.15956395335735812
AUC-ROC: 0.9279964087553374
AUPR: 0.7759286639723506

Classification Report:
               precision    recall  f1-score   support

           0       0.57      0.33      0.42        49
           1       0.80      0.40      0.53        20
           2       0.75      0.69      0.72       167
           3       0.85      0.64      0.73        44
           4       0.85      0.74      0.79        78
           5       0.73      0.63      0.68       127
           6       0.59      0.38      0.46        63
           7       0.00      0.00      0.00         6
           8       0.72      0.73      0.72       147
           9       0.83      0.85      0.84       379
          10       0.55      0.43      0.48        72

   micro avg       0.76      0.68      0.72      1152
   macro avg       0.66      0.53   

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load("caves_roberta_32_best_final.bin"))


Accuracy: 0.577137076378351
F1 Score: 0.711835052005773
Precision: 0.7619901915791534
Recall: 0.674025974025974
Hamming Loss: 0.05577780843334713
Average Loss: 0.1592050079376467
AUC-ROC: 0.9259840531812296
AUPR: 0.7712575510760469

Classification Report:
               precision    recall  f1-score   support

           0       0.63      0.37      0.47        97
           1       0.82      0.45      0.58        40
           2       0.75      0.72      0.73       334
           3       0.66      0.64      0.65        87
           4       0.77      0.65      0.70       157
           5       0.75      0.60      0.66       255
           6       0.65      0.50      0.57       125
           7       0.00      0.00      0.00        13
           8       0.75      0.70      0.72       295
           9       0.85      0.81      0.83       762
          10       0.68      0.43      0.53       145

   micro avg       0.77      0.67      0.72      2310
   macro avg       0.66      0.53      

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
