In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
# Tokens per document after padding / truncation.
MAX_LEN = 512

# Mini-batch size used for every split.
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 32

# Optimisation schedule.
EPOCHS = 12
LEARNING_RATE = 1e-5

# Decision threshold applied to the sigmoid probabilities.
THRESHOLD = 0.5

## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one document per item.

    Args:
        df: DataFrame with a ``'File Contents'`` text column plus one column
            per name in ``target_list``.
        tokenizer: object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        max_len: every sequence is padded / truncated to this many tokens.
        target_list: names of the label columns, in model-output order.

    Each item is a dict with flattened ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` tensors, a float ``targets`` tensor, and the
    whitespace-normalised text under ``title``.
    """

    def __init__(self, df, tokenizer, max_len, target_list):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Raw documents (the attribute keeps its historical name ``title``).
        self.title = list(df['File Contents'])
        # One row of label values per document.
        self.targets = self.df[target_list].values

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        # Collapse runs of whitespace (newlines, tabs, repeated spaces) to one space.
        text = " ".join(str(self.title[index]).split())
        encoded = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields a leading batch axis of size 1; flatten drops it.
        item = {
            key: encoded[key].flatten()
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        }
        item['targets'] = torch.tensor(self.targets[index], dtype=torch.float)
        item['title'] = text
        return item

## Data
train_file = '/content/train.csv'
test_file = '/content/test.csv'
train_val_df = pd.read_csv(train_file)
test_df = pd.read_csv(test_file)

# 80/20 train/validation split with a fixed seed for reproducibility.
train_df, val_df = train_test_split(train_val_df, test_size=0.2, random_state=42)

# Every column after the first is used as a label column.
# NOTE(review): assumes the first column is the 'File Contents' text column — confirm.
target_list = list(train_df.columns[1:])

## Tokenizer
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')

train_dataset = CustomDataset(train_df, tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(val_df, tokenizer, MAX_LEN, target_list)
test_dataset = CustomDataset(test_df, tokenizer, MAX_LEN, target_list)

## Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Data Loader
# The train/eval loops copy batches with `.to(device, non_blocking=True)`; that
# host-to-GPU copy is only asynchronous from page-locked (pinned) memory, so pin
# batches whenever the target device is a GPU.
PIN_MEMORY = device.type == 'cuda'

train_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=PIN_MEMORY,
)

val_data_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=PIN_MEMORY,
)

test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=PIN_MEMORY,
)

# Notebook cell echo: displays the selected device.
device

## Model

class InceptionBlock(nn.Module):
    """
    Inception-style multi-kernel block for 1D sequences.

    Four parallel branches, each Conv1d -> BatchNorm1d -> ReLU, use kernel
    sizes 2, 3, 5 and 7 with no built-in padding. Each branch output is then
    zero-padded back to the input length, and the four results are
    concatenated on the channel axis.

    Input:  (batch, in_channels, seq_len)
    Output: (batch, 4 * branch_out, seq_len)
    """
    def __init__(self, in_channels=768, branch_out=32):
        super().__init__()

        def make_branch(kernel_size):
            return nn.Sequential(
                nn.Conv1d(in_channels, branch_out, kernel_size=kernel_size),
                nn.BatchNorm1d(branch_out),
                nn.ReLU(),
            )

        # Attribute names and the Sequential layout define the state_dict keys
        # (e.g. 'branch2.0.weight'), so they must stay stable for saved checkpoints.
        self.branch2 = make_branch(2)
        self.branch3 = make_branch(3)
        self.branch5 = make_branch(5)
        self.branch7 = make_branch(7)

    def forward(self, x):
        # A valid convolution shortens the sequence by (kernel_size - 1);
        # the (left, right) zero-pads below restore the original length.
        branches_and_pads = (
            (self.branch2, (0, 1)),
            (self.branch3, (1, 1)),
            (self.branch5, (2, 2)),
            (self.branch7, (3, 3)),
        )
        padded = [F.pad(branch(x), pad) for branch, pad in branches_and_pads]
        return torch.cat(padded, dim=1)

class BioBERTInceptionAttention(nn.Module):
    """Multi-label text classifier: BioBERT + Inception convolutions + self-attention.

    Pipeline:
      1. BioBERT token states ``(batch, seq_len, hidden)`` followed by dropout.
      2. ``InceptionBlock`` over the sequence -> ``(batch, seq_len, 4 * branch_out)``.
      3. Concatenate 1 and 2 on the feature axis -> ``(batch, seq_len, fused_dim)``.
      4. Multi-head self-attention over the fused features; padded keys masked.
      5. Mean-pool over the non-padded tokens only.
      6. Dense block (Linear -> ReLU -> LayerNorm), dropout, linear classifier.

    ``forward`` returns raw logits ``(batch, num_classes)``; apply a sigmoid
    for per-label probabilities.
    """

    def __init__(self, num_classes):
        super().__init__()

        # BioBERT encoder (pretrained weights are fetched by from_pretrained).
        self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.1', return_dict=True)
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.3)

        # Inception block: four parallel conv branches of `branch_out` channels each.
        self.branch_out = 128
        self.inception = InceptionBlock(
            in_channels=hidden_size,
            branch_out=self.branch_out
        )

        # Feature width after concatenating BERT states with the Inception output.
        self.fused_dim = hidden_size + 4 * self.branch_out

        # Self-attention over fused token features (fused_dim must divide by num_heads).
        self.attention = nn.MultiheadAttention(embed_dim=self.fused_dim, num_heads=8, batch_first=True)

        # Dense block: Linear -> ReLU -> LayerNorm.
        self.dense = nn.Sequential(
            nn.Linear(self.fused_dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512)
        )

        # Final dropout and classification head.
        self.final_dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Compute per-label logits.

        Args:
            input_ids: token ids, ``(batch, seq_len)``.
            attention_mask: 1 for real tokens, 0 for padding, ``(batch, seq_len)``.
            token_type_ids: optional segment ids, forwarded to BERT.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        hidden_states = self.dropout(outputs.last_hidden_state)  # (batch, seq_len, hidden)

        # Conv1d wants channels first: (batch, hidden, seq_len).
        # NOTE: padded positions still enter the convolutions near the sequence end.
        inception_out = self.inception(hidden_states.permute(0, 2, 1))  # (batch, 4*branch_out, seq_len)
        inception_out = inception_out.permute(0, 2, 1)                  # (batch, seq_len, 4*branch_out)

        fused_features = torch.cat([hidden_states, inception_out], dim=2)  # (batch, seq_len, fused_dim)

        # key_padding_mask uses True for "ignore"; attention_mask uses 1 for real tokens.
        key_padding_mask = ~(attention_mask.bool())

        attn_output, _ = self.attention(
            fused_features,
            fused_features,
            fused_features,
            key_padding_mask=key_padding_mask,
            need_weights=False,  # weights are discarded; skip computing them
        )

        # Masked mean over real tokens. Padded *query* positions still receive
        # attention outputs, so a plain mean over seq_len would let padding
        # leak into the pooled features.
        token_mask = attention_mask.unsqueeze(-1).to(attn_output.dtype)  # (batch, seq_len, 1)
        summed = (attn_output * token_mask).sum(dim=1)                   # (batch, fused_dim)
        counts = token_mask.sum(dim=1).clamp(min=1.0)                    # avoid division by zero
        pooled_features = summed / counts

        dense_out = self.dense(pooled_features)      # (batch, 512)
        dense_out = self.final_dropout(dense_out)
        logits = self.fc(dense_out)                  # (batch, num_classes)

        return logits






## Setting the model
# One output logit per label column; nn.Module.to() moves parameters in place
# and returns the same module, so the assignment keeps the original object.
model = BioBERTInceptionAttention(num_classes=len(target_list)).to(device)

## Loss Fn
def loss_fn(outputs, targets):
    """Return the mean binary cross-entropy between logits and multi-label targets.

    The sigmoid is applied internally in a numerically stable form, so
    ``outputs`` must be raw logits, not probabilities.
    """
    return F.binary_cross_entropy_with_logits(outputs, targets, reduction='mean')

# Optimizer: AdamW (decoupled weight decay) over every model parameter.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)


def train_model(training_loader, model, optimizer, accumulation_steps=1):
    """Train ``model`` for one epoch and return ``(model, label_accuracy, mean_loss)``.

    Args:
        training_loader: iterable of batch dicts holding 'input_ids',
            'attention_mask', 'token_type_ids' and 'targets' tensors.
        model: module called as ``model(ids, mask, token_type_ids)`` that
            returns raw logits shaped like ``targets``.
        optimizer: optimizer over ``model``'s parameters.
        accumulation_steps: number of consecutive batches whose gradients are
            accumulated before each optimizer step. The default (1) steps after
            every batch. Each batch loss is divided by this value, so the
            accumulated gradient is a per-batch average.

    Returns:
        ``model`` (trained in place), the label-wise accuracy (fraction of
        individual label decisions at probability >= 0.5 that equal the
        target), and the mean per-batch loss. Both metrics are ``nan`` when
        the loader yields no batches.

    Raises:
        ValueError: if ``accumulation_steps`` is smaller than 1.

    Relies on the module-level ``device`` and ``loss_fn``.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    model.train()
    optimizer.zero_grad(set_to_none=True)

    for batch_idx, data in enumerate(training_loader, start=1):
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Label-wise training accuracy. Uses ">= 0.5" (not round(), which maps
        # exactly 0.5 to 0) so it agrees with eval_model's thresholding.
        preds = (torch.sigmoid(outputs.detach()) >= 0.5).to(targets.dtype)
        correct_predictions += (preds == targets).sum().item()
        num_samples += targets.numel()

        (loss / accumulation_steps).backward()

        # Step after every `accumulation_steps` batches, and on the final batch
        # so no accumulated gradient is left unused.
        if batch_idx % accumulation_steps == 0 or batch_idx == total_batches:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

    # No per-batch torch.cuda.empty_cache(): releasing the caching allocator
    # every step only forces costly re-allocations on the next batch.
    if not losses:
        return model, float('nan'), float('nan')
    return model, correct_predictions / num_samples, float(np.mean(losses))



def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate ``model`` on a loader, print metrics, and return ``(f1, mean_loss)``.

    Args:
        validation_loader: iterable of batch dicts (same keys as for training).
        model: module returning raw logits; it is switched to eval mode here.
        threshold: a label is predicted when its sigmoid probability is
            ``>= threshold``.
        target_list: optional label names for the classification report
            (must match the number of label columns).

    Returns:
        Support-weighted F1 score and mean per-batch loss.

    Raises:
        ValueError: if the loader yields no batches.

    Precision/recall/F1 for labels with no predicted (or no true) samples are
    reported as 0 via ``zero_division=0``. That is the same value sklearn's
    default ``'warn'`` setting returns, but without the warnings. AUC-ROC is
    reported as ``nan`` when some label column has a single class in the
    targets (sklearn raises ValueError there).

    Relies on the module-level ``device`` and ``loss_fn``.
    """
    model.eval()
    batch_targets = []
    batch_probs = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            outputs = model(ids, mask, token_type_ids)
            losses.append(loss_fn(outputs, targets).item())

            batch_probs.append(torch.sigmoid(outputs).cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

    if not losses:
        raise ValueError("eval_model() received a loader that yielded no batches")

    final_probs = np.concatenate(batch_probs)
    final_targets = np.concatenate(batch_targets)
    # Threshold once: a label is predicted when its probability reaches the threshold.
    final_outputs = final_probs >= threshold

    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    try:
        auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    except ValueError:
        # AUC is undefined when a label column has only one class in y_true;
        # report nan instead of aborting the whole evaluation.
        auc_roc = float('nan')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list, zero_division=0))

    return f1, average_loss


# Learning Rate Scheduler
# T_max matches EPOCHS so the LR follows one half-cosine from LEARNING_RATE
# toward 0 over the whole run. With T_max < EPOCHS, CosineAnnealingLR would
# start raising the LR again during the final epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Training & Evaluation Loop
start = time.time()

history = defaultdict(list)
# -inf guarantees the first epoch is checkpointed, so the testing cell always
# finds the file even if validation F1 never rises above 0.
best_f1 = float('-inf')

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    # Use the configured decision threshold and real label names in the report.
    val_f1, val_loss = eval_model(val_data_loader, model, threshold=THRESHOLD, target_list=target_list)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # one scheduler step per epoch

    # Keep the weights with the best validation F1 seen so far.
    if val_f1 > best_f1:
        torch.save(model.state_dict(), "ohsumed_BioBERTIncDNet128_32_best.bin")
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
model = BioBERTInceptionAttention(num_classes=len(target_list))
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors/containers (a state_dict).
state_dict = torch.load(
    "ohsumed_BioBERTIncDNet128_32_best.bin",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model = model.to(device)

start = time.time()
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=target_list)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

Epoch 1/12


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.31319554848966613
F1 Score: 0.47455558840744616
Precision: 0.7210394766612231
Recall: 0.4156398104265403
Hamming Loss: 0.04980299993087717
Average Loss: 0.17587030306458473
AUC-ROC: 0.897820488889788
AUPR: 0.6647323445137895

Classification Report:
               precision    recall  f1-score   support

           0       0.90      0.26      0.41       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.84      0.86      0.85       234
           4       0.68      0.25      0.37        51
           5       0.70      0.63      0.66       109
           6       0.00      0.00      0.00        16
           7       0.77      0.47      0.59        99
           8       1.00      0.05      0.10        20
           9       0.76      0.30      0.43       122
          10       0.68      0.50      0.58        26
          11       0.88      0.73      0.80       110
          12       0.81      0.55  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.3728139904610493
F1 Score: 0.5638805614094747
Precision: 0.8077235733699644
Recall: 0.4909952606635071
Hamming Loss: 0.044134927766641324
Average Loss: 0.14307936932891607
AUC-ROC: 0.919367817026774
AUPR: 0.7429608222743578

Classification Report:
               precision    recall  f1-score   support

           0       0.91      0.28      0.43       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.91      0.78      0.84       234
           4       0.80      0.24      0.36        51
           5       0.73      0.69      0.71       109
           6       0.75      0.19      0.30        16
           7       0.80      0.61      0.69        99
           8       0.86      0.30      0.44        20
           9       0.69      0.48      0.57       122
          10       0.75      0.69      0.72        26
          11       0.88      0.84      0.86       110
          12       0.71      0.64   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.417329093799682
F1 Score: 0.6520113270054844
Precision: 0.777535699548599
Recall: 0.5890995260663507
Hamming Loss: 0.04022948779982028
Average Loss: 0.12846425715833903
AUC-ROC: 0.9270150919740872
AUPR: 0.7649781847839447

Classification Report:
               precision    recall  f1-score   support

           0       0.87      0.65      0.74       102
           1       0.00      0.00      0.00        36
           2       1.00      0.44      0.61        16
           3       0.88      0.83      0.85       234
           4       0.71      0.57      0.63        51
           5       0.72      0.74      0.73       109
           6       0.75      0.56      0.64        16
           7       0.78      0.61      0.68        99
           8       0.80      0.40      0.53        20
           9       0.66      0.58      0.62       122
          10       0.74      0.65      0.69        26
          11       0.87      0.82      0.85       110
          12       0.68      0.68     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.43402225755166934
F1 Score: 0.6873029179909524
Precision: 0.7789373586577494
Recall: 0.638388625592417
Hamming Loss: 0.03815580286168521
Average Loss: 0.12214596215635538
AUC-ROC: 0.9301843909255713
AUPR: 0.7799764396103451

Classification Report:
               precision    recall  f1-score   support

           0       0.82      0.74      0.78       102
           1       0.56      0.25      0.35        36
           2       0.77      0.62      0.69        16
           3       0.87      0.87      0.87       234
           4       0.70      0.55      0.62        51
           5       0.73      0.74      0.74       109
           6       0.83      0.62      0.71        16
           7       0.76      0.61      0.67        99
           8       0.82      0.45      0.58        20
           9       0.66      0.60      0.63       122
          10       0.81      0.65      0.72        26
          11       0.88      0.86      0.87       110
          12       0.71      0.74   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4507154213036566
F1 Score: 0.7100573139680958
Precision: 0.7826909245994483
Recall: 0.6597156398104266
Hamming Loss: 0.036842469067533
Average Loss: 0.11651075910776854
AUC-ROC: 0.9324530876343607
AUPR: 0.7855470257491025

Classification Report:
               precision    recall  f1-score   support

           0       0.84      0.70      0.76       102
           1       0.58      0.19      0.29        36
           2       0.77      0.62      0.69        16
           3       0.90      0.85      0.87       234
           4       0.68      0.59      0.63        51
           5       0.74      0.73      0.74       109
           6       0.67      0.75      0.71        16
           7       0.75      0.70      0.72        99
           8       0.74      0.70      0.72        20
           9       0.69      0.61      0.65       122
          10       0.79      0.73      0.76        26
          11       0.86      0.86      0.86       110
          12       0.76      0.72     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.45310015898251194
F1 Score: 0.7159076666608477
Precision: 0.7857028157062798
Recall: 0.6701421800947868
Hamming Loss: 0.036427732079905996
Average Loss: 0.11374791748821736
AUC-ROC: 0.9339580965875196
AUPR: 0.7883440527760716

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.71      0.77       102
           1       0.91      0.28      0.43        36
           2       0.77      0.62      0.69        16
           3       0.90      0.86      0.88       234
           4       0.67      0.61      0.64        51
           5       0.71      0.75      0.73       109
           6       0.86      0.75      0.80        16
           7       0.72      0.71      0.71        99
           8       0.68      0.65      0.67        20
           9       0.70      0.61      0.66       122
          10       0.86      0.69      0.77        26
          11       0.87      0.85      0.86       110
          12       0.74      0.74 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4562798092209857
F1 Score: 0.7197723126626487
Precision: 0.7857155166533363
Recall: 0.6729857819905213
Hamming Loss: 0.036116679339185734
Average Loss: 0.11182483769953251
AUC-ROC: 0.9336796635876417
AUPR: 0.7878510888335226

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.72      0.78       102
           1       0.75      0.33      0.46        36
           2       0.77      0.62      0.69        16
           3       0.88      0.87      0.88       234
           4       0.71      0.57      0.63        51
           5       0.73      0.74      0.74       109
           6       0.86      0.75      0.80        16
           7       0.72      0.68      0.70        99
           8       0.69      0.55      0.61        20
           9       0.70      0.61      0.65       122
          10       0.82      0.69      0.75        26
          11       0.86      0.85      0.85       110
          12       0.76      0.72  

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4523052464228935
F1 Score: 0.7216333298430874
Precision: 0.7702086143260163
Recall: 0.6867298578199053
Hamming Loss: 0.0366351005737195
Average Loss: 0.11353404577821494
AUC-ROC: 0.933756714238236
AUPR: 0.7860704988250564

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.73      0.78       102
           1       0.70      0.39      0.50        36
           2       0.77      0.62      0.69        16
           3       0.89      0.87      0.88       234
           4       0.68      0.63      0.65        51
           5       0.71      0.77      0.74       109
           6       0.80      0.75      0.77        16
           7       0.71      0.70      0.70        99
           8       0.68      0.65      0.67        20
           9       0.69      0.59      0.63       122
          10       0.82      0.69      0.75        26
          11       0.86      0.84      0.85       110
          12       0.76      0.72     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4491255961844197
F1 Score: 0.7257038368432268
Precision: 0.7669302422623534
Recall: 0.6938388625592417
Hamming Loss: 0.0366351005737195
Average Loss: 0.11246686615049839
AUC-ROC: 0.9334877510316402
AUPR: 0.7858432268161528

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.72      0.78       102
           1       0.64      0.39      0.48        36
           2       0.77      0.62      0.69        16
           3       0.89      0.87      0.88       234
           4       0.68      0.63      0.65        51
           5       0.73      0.76      0.74       109
           6       0.80      0.75      0.77        16
           7       0.72      0.69      0.70        99
           8       0.68      0.65      0.67        20
           9       0.68      0.59      0.63       122
          10       0.82      0.69      0.75        26
          11       0.86      0.85      0.85       110
          12       0.74      0.72    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4546899841017488
F1 Score: 0.7261357140205343
Precision: 0.7730105287053073
Recall: 0.6890995260663507
Hamming Loss: 0.036427732079905996
Average Loss: 0.1117443472146988
AUC-ROC: 0.9335619304216025
AUPR: 0.7858862421455691

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.71      0.77       102
           1       0.64      0.39      0.48        36
           2       0.77      0.62      0.69        16
           3       0.90      0.85      0.87       234
           4       0.70      0.63      0.66        51
           5       0.72      0.76      0.74       109
           6       0.80      0.75      0.77        16
           7       0.72      0.69      0.70        99
           8       0.68      0.65      0.67        20
           9       0.69      0.59      0.64       122
          10       0.82      0.69      0.75        26
          11       0.86      0.85      0.85       110
          12       0.77      0.72   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.45548489666136727
F1 Score: 0.7256001506852059
Precision: 0.7736755447234518
Recall: 0.6876777251184835
Hamming Loss: 0.036496854911177164
Average Loss: 0.11178542468696832
AUC-ROC: 0.9336354006439798
AUPR: 0.7856741164511889

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.72      0.78       102
           1       0.65      0.42      0.51        36
           2       0.77      0.62      0.69        16
           3       0.91      0.83      0.87       234
           4       0.69      0.61      0.65        51
           5       0.71      0.77      0.74       109
           6       0.80      0.75      0.77        16
           7       0.72      0.69      0.70        99
           8       0.68      0.65      0.67        20
           9       0.70      0.59      0.64       122
          10       0.82      0.69      0.75        26
          11       0.86      0.85      0.85       110
          12       0.77      0.72 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.4578696343402226
F1 Score: 0.7245540450746778
Precision: 0.7741964101392148
Recall: 0.685781990521327
Hamming Loss: 0.03639317066427041
Average Loss: 0.11167171150445938
AUC-ROC: 0.9335586888498398
AUPR: 0.7856191488174253

Classification Report:
               precision    recall  f1-score   support

           0       0.86      0.72      0.78       102
           1       0.65      0.42      0.51        36
           2       0.77      0.62      0.69        16
           3       0.89      0.86      0.87       234
           4       0.69      0.61      0.65        51
           5       0.72      0.74      0.73       109
           6       0.80      0.75      0.77        16
           7       0.72      0.69      0.70        99
           8       0.68      0.65      0.67        20
           9       0.70      0.59      0.64       122
          10       0.82      0.69      0.75        26
          11       0.86      0.85      0.85       110
          12       0.77      0.72    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
