In [None]:
from transformers import BertModel, BertTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time
import pickle

## Hyperparameters
# --- Tokenization / batching ---
MAX_LEN = 512  # every example is padded/truncated to this many tokens
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 8  # one batch size shared by all splits

# --- Optimization ---
EPOCHS = 15
LEARNING_RATE = 1e-5

# Probability cut-off applied to sigmoid outputs to decide whether a label is present.
THRESHOLD = 0.5
## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing one tokenized text column with multi-label targets.

    Args:
        df: DataFrame containing the text column and one column per label.
        tokenizer: Hugging Face-style tokenizer; it is called directly (``tokenizer(text, ...)``)
            with ``return_tensors='pt'``.
        max_len: Length every example is padded / truncated to.
        target_list: Names of the label columns, in the order the targets are emitted.
        text_column: Name of the column holding the raw text (default ``'File Contents'``).

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``
    (flattened 1-D tensors, ``max_len`` long for a tokenizer honouring ``padding='max_length'``),
    ``targets`` (float32 tensor with one entry per label) and ``title`` (the
    whitespace-normalised text that was tokenized).
    """

    def __init__(self, df, tokenizer, max_len, target_list, text_column='File Contents'):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df[text_column])
        # Convert once to float32 so __getitem__ only has to wrap a row; the array is
        # positional, so targets[i] lines up with title[i] regardless of df's index.
        self.targets = df[list(target_list)].to_numpy(dtype=np.float32)
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        # Collapse all runs of whitespace (newlines, tabs, repeated spaces) to single spaces.
        title = " ".join(str(self.title[index]).split())
        # Calling the tokenizer directly replaces the deprecated encode_plus API.
        inputs = self.tokenizer(
            title,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            # return_tensors='pt' yields shape (1, max_len); flatten drops the batch dim.
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs['token_type_ids'].flatten(),
            'targets': torch.tensor(self.targets[index], dtype=torch.float),
            'title': title,
        }
    
## Data
# One CSV per split.
train_file = '../data/caves/train.csv'
val_file = '../data/caves/val.csv'
test_file = '../data/caves/test.csv'
train_df, val_df, test_df = (pd.read_csv(path) for path in (train_file, val_file, test_file))

# Label columns: every column after the first one of the training CSV.
# NOTE(review): assumes the first column is the 'File Contents' text column — confirm against the CSVs.
target_list = train_df.columns[1:].tolist()

## Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(frame, tokenizer, MAX_LEN, target_list)
    for frame in (train_df, val_df, test_df)
)

## Data Loader
# Only the training split is shuffled so evaluation order is stable; all loading
# happens in the main process (num_workers=0).
train_data_loader, val_data_loader, test_data_loader = (
    torch.utils.data.DataLoader(split, batch_size=size, shuffle=shuffle, num_workers=0)
    for split, size, shuffle in (
        (train_dataset, TRAIN_BATCH_SIZE, True),
        (valid_dataset, VALID_BATCH_SIZE, False),
        (test_dataset, TEST_BATCH_SIZE, False),
    )
)

## Device
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # notebook display of the selected device

## Model
class BERTClass(torch.nn.Module):
    """Pre-trained BERT encoder + dropout + linear head emitting one raw logit per label.

    Args:
        num_classes: Number of output labels.
        model_name: Hugging Face checkpoint to load (default ``'bert-base-cased'``).
        dropout: Dropout probability applied to the pooled output (default 0.3).

    ``forward`` returns unnormalised logits; no sigmoid is applied here.
    """

    def __init__(self, num_classes, model_name='bert-base-cased', dropout=0.3):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(model_name, return_dict=True)
        self.dropout = torch.nn.Dropout(dropout)
        # Size the head from the encoder config rather than hard-coding 768, so other
        # checkpoints (e.g. bert-large) work without edits.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, num_classes)

    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        # pooler_output: the encoder's pooled [CLS] representation.
        output_dropout = self.dropout(output.pooler_output)
        return self.linear(output_dropout)
    
    
## Setting the model
# One output logit per label column; Module.to moves parameters in place and returns the module.
model = BERTClass(num_classes=len(target_list)).to(device)
model  # notebook display of the model architecture

## Loss & Optimizer
def loss_fn(outputs, targets, pos_weight=None):
    """Binary cross-entropy on raw logits, averaged over every (sample, label) entry.

    Args:
        outputs: Logits with the same shape as ``targets``.
        targets: 0/1 float tensor.
        pos_weight: Optional per-label weight for positive examples (broadcast along the
            last dimension), useful for imbalanced labels. ``None`` gives plain BCE,
            identical to ``torch.nn.BCEWithLogitsLoss()``.

    Returns:
        Scalar loss tensor.
    """
    # Functional form avoids constructing a new loss module on every call.
    return torch.nn.functional.binary_cross_entropy_with_logits(
        outputs, targets, pos_weight=pos_weight
    )

# AdamW (decoupled weight decay) over every model parameter.
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=1e-4,
    lr=LEARNING_RATE,
)


## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=4, scaler=None):
    """Run one training epoch with mixed precision and gradient accumulation.

    Args:
        training_loader: DataLoader yielding dicts with 'input_ids', 'attention_mask',
            'token_type_ids' and 'targets' (as produced by CustomDataset).
        model: Module called as ``model(ids, mask, token_type_ids)`` returning logits.
            Batches are moved to the device of the model's parameters.
        optimizer: Optimizer over ``model``'s parameters.
        accumulation_steps: Number of micro-batches whose gradients are combined before
            each optimizer step.
        scaler: Optional ``torch.amp.GradScaler`` to reuse across epochs so the learned
            loss scale carries over; a fresh one is created when ``None``.

    Returns:
        ``(model, train_accuracy, mean_loss)``. ``train_accuracy`` is the fraction of
        individual (sample, label) entries whose rounded sigmoid output equals the target.

    Raises:
        ValueError: if ``accumulation_steps < 1`` or the loader yields no batches.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    total_batches = len(training_loader)
    if total_batches == 0:
        raise ValueError("training_loader yielded no batches")

    run_device = next(model.parameters()).device
    # fp16 autocast and loss scaling are only enabled on CUDA; on CPU both become no-ops.
    use_amp = run_device.type == 'cuda'
    if scaler is None:
        scaler = GradScaler('cuda', enabled=use_amp)

    # Training mode enables dropout.
    model.train()
    optimizer.zero_grad(set_to_none=True)

    losses = []
    correct_predictions = 0
    num_samples = 0

    for batch_idx, data in enumerate(training_loader):
        ids = data['input_ids'].to(run_device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(run_device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(run_device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(run_device, dtype=torch.float, non_blocking=True)

        with autocast(run_device.type, enabled=use_amp):
            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Training accuracy: sigmoid, then round (0.5 cut-off), compared entry-wise.
        preds = torch.sigmoid(outputs.detach().float()).cpu().numpy().round()
        truth = targets.cpu().numpy()
        correct_predictions += np.sum(preds == truth)
        num_samples += truth.size

        # The last group may contain fewer than accumulation_steps micro-batches; divide
        # by its real size so every optimizer step uses a mean over the group.
        group_start = batch_idx - batch_idx % accumulation_steps
        group_size = min(accumulation_steps, total_batches - group_start)
        scaler.scale(loss / group_size).backward()

        end_of_group = (batch_idx + 1) % accumulation_steps == 0 or batch_idx + 1 == total_batches
        if end_of_group:
            # Unscale before clipping: clipping the *scaled* gradients (norm ~ scale x true
            # norm) to 1.0 would shrink every real update by roughly the loss-scale factor.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

    return model, float(correct_predictions) / num_samples, np.mean(losses)

## Evaluator Function
def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate a multi-label classifier, printing summary metrics and a per-label report.

    Args:
        validation_loader: DataLoader yielding CustomDataset-style dicts.
        model: Module returning one logit per label; batches are moved to the device of
            its parameters.
        threshold: Probability cut-off for predicting a label as present.
        target_list: Optional label names used in the classification report.

    Returns:
        ``(weighted_f1, mean_loss)``.

    Raises:
        ValueError: if the loader yields no batches.
    """
    run_device = next(model.parameters()).device
    use_amp = run_device.type == 'cuda'

    # Eval mode disables dropout.
    model.eval()
    batch_probs = []
    batch_targets = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(run_device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(run_device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(run_device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(run_device, dtype=torch.float, non_blocking=True)

            with autocast(run_device.type, enabled=use_amp):
                outputs = model(ids, mask, token_type_ids)
                loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            # Sigmoid in float32: under autocast the logits may be float16, whose coarse
            # resolution near 0 and 1 creates ties that distort AUC / AUPR.
            batch_probs.append(torch.sigmoid(outputs.float()).cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

    if not losses:
        raise ValueError("validation_loader yielded no batches")

    final_probs = np.concatenate(batch_probs)
    final_targets = np.concatenate(batch_targets)
    final_outputs = final_probs >= threshold

    # zero_division=0 keeps the value sklearn already uses for labels with no predicted /
    # true positives, but states it explicitly and suppresses UndefinedMetricWarning.
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    try:
        auc_roc = roc_auc_score(final_targets, final_probs, average='weighted')
    except ValueError:
        # Raised when some label column contains only one class in y_true.
        auc_roc = float('nan')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(
        final_targets, final_outputs, target_names=target_list, zero_division=0
    ))

    return f1, average_loss

## Training & Evaluation
# recording starting time
start = time.time()

history = defaultdict(list)
# Start below any achievable F1 so the first epoch always writes a checkpoint. With 0.0,
# epochs whose val F1 is exactly 0 never save, and the testing section below then
# cannot open the checkpoint file.
best_f1 = float('-inf')

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    print(f"Train loss: {train_loss:.4f}  Train acc: {train_acc:.4f}")
    val_f1, val_loss = eval_model(val_data_loader, model, threshold=THRESHOLD, target_list=target_list)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)
    # save the best model (pickles the whole module, so BERTClass must be importable at load time)
    if val_f1 > best_f1:
        with open("caves_BERT_8.pkl", "wb") as f:
            pickle.dump(model, f)
        best_f1 = val_f1


# recording end time
end = time.time()
print(f"Total training and validation time: {end - start} seconds")


## Testing
# Loading pretrained model (best model)
print("\n\nTesting\n\n")
# Load the best model using pickle. pickle.load can execute arbitrary code, so this is only
# acceptable because the file is a checkpoint this notebook wrote itself; never point it
# at an untrusted file.
with open("caves_BERT_8.pkl", "rb") as f:
    model = pickle.load(f)
model = model.to(device)

# recording starting time
start = time.time()
# Evaluate on the test split with the same cut-off and label names used during validation.
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=target_list)
# recording end time
end = time.time()
print(f"Test-set evaluation time: {end - start} seconds")

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]



Epoch 1/15


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.5190125497118119
AUC-ROC: 0.45749096666295735
AUPR: 0.18387809543653175

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.3735157325863838
AUC-ROC: 0.46954728711547183
AUPR: 0.18858231979176188

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.3153043532323453
AUC-ROC: 0.5224249262416327
AUPR: 0.22489886023236125

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.3000697297674994
AUC-ROC: 0.5489657967883658
AUPR: 0.254188442755599

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.29538225959385594
AUC-ROC: 0.5728844198330751
AUPR: 0.2735879395748361

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.29328973963856697
AUC-ROC: 0.5819148094940798
AUPR: 0.2817438138936968

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.06382978723404255
F1 Score: 0.0
Precision: 0.0
Recall: 0.0
Hamming Loss: 0.10610665929814866
Average Loss: 0.2921154108499327
AUC-ROC: 0.5975194287516167
AUPR: 0.28759570721263994

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.00      0.00      0.00       379
          10       0.00      0.00      0.00        72

   micro avg       0.00      0.00      0.00      1152
   macro avg       0.00      0.00      0.00      1152
weighted avg       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.0790273556231003
F1 Score: 0.025050232656514387
Precision: 0.3289930555555556
Recall: 0.013020833333333336
Hamming Loss: 0.10472506217187068
Average Loss: 0.2910406511397131
AUC-ROC: 0.6002208972599947
AUPR: 0.29191305428399505

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       1.00      0.04      0.08       379
          10       0.00      0.00      0.00        72

   micro avg       1.00      0.01      0.03      1152
   macro avg       0.09      0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.13576494427558258
F1 Score: 0.10966435185185185
Precision: 0.27358369883040934
Recall: 0.0685763888888889
Hamming Loss: 0.10030395136778116
Average Loss: 0.28978377808966943
AUC-ROC: 0.614110131188071
AUPR: 0.30258791167399757

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.83      0.21      0.33       379
          10       0.00      0.00      0.00        72

   micro avg       0.83      0.07      0.13      1152
   macro avg       0.08      0.0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.08105369807497467
F1 Score: 0.02817572767982088
Precision: 0.31071566358024694
Recall: 0.014756944444444444
Hamming Loss: 0.10463295569678549
Average Loss: 0.28893613370676197
AUC-ROC: 0.6643603514965084
AUPR: 0.33445558110103696

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.94      0.04      0.09       379
          10       0.00      0.00      0.00        72

   micro avg       0.94      0.01      0.03      1152
   macro avg       0.09      

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.19351570415400202
F1 Score: 0.1741384201697257
Precision: 0.2664128547705314
Recall: 0.1293402777777778
Hamming Loss: 0.09560652113843603
Average Loss: 0.28556140247852574
AUC-ROC: 0.6576916022279418
AUPR: 0.3337369112522539

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.81      0.39      0.53       379
          10       0.00      0.00      0.00        72

   micro avg       0.81      0.13      0.22      1152
   macro avg       0.07      0.04 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2188449848024316
F1 Score: 0.19644529159978596
Precision: 0.2507897882513661
Recall: 0.16145833333333334
Hamming Loss: 0.09431703048724326
Average Loss: 0.28305026624471913
AUC-ROC: 0.6796653004900698
AUPR: 0.3507621246722612

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.76      0.49      0.60       379
          10       0.00      0.00      0.00        72

   micro avg       0.76      0.16      0.27      1152
   macro avg       0.07      0.04

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.25430597771023306
F1 Score: 0.21964842484612895
Precision: 0.2454052433948607
Recall: 0.1987847222222222
Hamming Loss: 0.09219858156028368
Average Loss: 0.2808427634018083
AUC-ROC: 0.6856811558638944
AUPR: 0.36136863180757495

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.75      0.60      0.67       379
          10       0.00      0.00      0.00        72

   micro avg       0.75      0.20      0.31      1152
   macro avg       0.07      0.05

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.22391084093211752
F1 Score: 0.20714377572016462
Precision: 0.28614744832041344
Recall: 0.1623263888888889
Hamming Loss: 0.0914617297596021
Average Loss: 0.2769227985653185
AUC-ROC: 0.7205713666519793
AUPR: 0.3921152663827183

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.87      0.49      0.63       379
          10       0.00      0.00      0.00        72

   micro avg       0.87      0.16      0.27      1152
   macro avg       0.08      0.04 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.27152988855116517
F1 Score: 0.2405852300538424
Precision: 0.2712558636865342
Recall: 0.21614583333333337
Hamming Loss: 0.08805379018144975
Average Loss: 0.27394073216184495
AUC-ROC: 0.7195168139397773
AUPR: 0.40015757691590964

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.82      0.66      0.73       379
          10       0.00      0.00      0.00        72

   micro avg       0.82      0.22      0.34      1152
   macro avg       0.07      0.0

  return torch.load(io.BytesIO(b))


Accuracy: 0.26454223571067276
F1 Score: 0.2302417235943247
Precision: 0.2561531715872552
Recall: 0.20909090909090908
Hamming Loss: 0.09040327401480663
Average Loss: 0.27614808352964537
AUC-ROC: 0.7027746034732724
AUPR: 0.3864396337050024

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        97
           1       0.00      0.00      0.00        40
           2       0.00      0.00      0.00       334
           3       0.00      0.00      0.00        87
           4       0.00      0.00      0.00       157
           5       0.00      0.00      0.00       255
           6       0.00      0.00      0.00       125
           7       0.00      0.00      0.00        13
           8       0.00      0.00      0.00       295
           9       0.78      0.63      0.70       762
          10       0.00      0.00      0.00       145

   micro avg       0.78      0.21      0.33      2310
   macro avg       0.07      0.06

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
