In [1]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers

from transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time
## Hyperparameters
MAX_LEN = 512
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 8
TEST_BATCH_SIZE = 8
EPOCHS = 15
LEARNING_RATE = 1e-05
THRESHOLD = 0.5 # threshold for the sigmoid
## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len, target_list):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df['File Contents'])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        title = str(self.title[index])
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.FloatTensor(self.targets[index]),
            'title': title
        }
## Data
train_file = '/kaggle/input/caves-data/caves/caves_train.csv'
val_file = '/kaggle/input/caves-data/caves/caves_val.csv'
test_file = '/kaggle/input/caves-data/caves/caves_test.csv'
train_df = pd.read_csv(train_file)
val_df = pd.read_csv(val_file)
test_df = pd.read_csv(test_file)

target_list = list(train_df.columns[1:])
## Tokenizer
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
train_dataset = CustomDataset(train_df, tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(val_df, tokenizer, MAX_LEN, target_list)
test_dataset = CustomDataset(test_df, tokenizer, MAX_LEN, target_list)

#print(train_dataset[0])
## Data Loader
train_data_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

val_data_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_data_loader = torch.utils.data.DataLoader(test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)
## Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device
## Model
class XLNETBase(nn.Module):
    def __init__(self, num_classes):
        super(XLNETBase, self).__init__()
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        self.drop = nn.Dropout(0.3)
        self.out = nn.Linear(768, num_classes)  # Assuming 23 classes for classification

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Get the full output
        outputs = self.xlnet(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # The last hidden state is the first element of the output
        last_hidden_state = outputs[0]

        # Pooling operation if needed, for example, using the [CLS] token's embedding
        # Assuming [CLS] is the first token, similar to BERT. Adjust as needed.
        pooled_output = last_hidden_state[:, 0]

        output = self.drop(pooled_output)
        return self.out(output)
## Setting the model
model = XLNETBase(num_classes=len(target_list))
model.to(device)
## Loss & Optimizer
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

# define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=4):
    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    # Set model to training mode (activate dropout, batch norm)
    model.train()

    # Mixed precision
    scaler = GradScaler()

    for batch_idx, data in enumerate(training_loader):
        ids = data['input_ids'].to(device, dtype=torch.long)
        mask = data['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.float)

        # Forward pass with mixed precision
        with autocast():
            outputs = model(ids, mask, token_type_ids)  # (batch, predict) = (8, 8)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

        # Training accuracy, apply sigmoid, round (apply threshold 0.5)
        outputs = torch.sigmoid(outputs).cpu().detach().numpy().round()
        targets = targets.cpu().detach().numpy()
        correct_predictions += np.sum(outputs == targets)
        num_samples += targets.size  # Total number of elements in the 2D array

        # Backward pass with gradient accumulation
        loss = loss / accumulation_steps  # Normalize loss to account for accumulation
        scaler.scale(loss).backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step optimizer every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Clear GPU cache
        torch.cuda.empty_cache()

    # Perform the final optimizer step if not done already
    if (batch_idx + 1) % accumulation_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # Returning: trained model, model accuracy, mean loss
    return model, float(correct_predictions) / num_samples, np.mean(losses)
## Evaluator Function
def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    model.eval()
    final_targets = []
    final_outputs = []
    final_probs = []
    losses = []

    # Mixed precision
    scaler = torch.cuda.amp.GradScaler()

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.float)

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                outputs = model(ids, mask, token_type_ids)
                loss = loss_fn(outputs, targets)
                losses.append(loss.item())

                probs = torch.sigmoid(outputs).cpu().detach().numpy()
                targets = targets.cpu().detach().numpy()
                final_outputs.extend(probs >= threshold)
                final_probs.extend(probs)
                final_targets.extend(targets)

            # Clear GPU cache
            torch.cuda.empty_cache()

    final_outputs = np.array(final_outputs) >= threshold
    final_probs = np.array(final_probs)
    final_targets = np.array(final_targets)

    # Calculating metrics
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted')  # Consider using 'macro' or 'weighted' based on your problem
    precision = precision_score(final_targets, final_outputs, average='weighted')
    recall = recall_score(final_targets, final_outputs, average='weighted')
    hamming = hamming_loss(final_targets, final_outputs)

    auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    # Detailed classification report
#     if target_list:
#         print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list))
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list))

    print("\n\n")
    return f1, average_loss
## Training & Evaluation
# recording starting time
start = time.time()

history = defaultdict(list)
best_f1 = 0.0

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_f1, val_loss = eval_model(val_data_loader, model)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)
    # save the best model
    if val_f1 > best_f1:
        torch.save(model.state_dict(), "caves_XLNet_8_MLTC_model_state.bin")
        best_f1 = val_f1

# recording end time
end = time.time()
print(f"Total training and validation time: {end - start} seconds")
## Testing
# Loading pretrained model (best model)
print("\n\nTesting\n\n")
model = XLNETBase(num_classes=len(target_list))
model.load_state_dict(torch.load("caves_XLNet_8_MLTC_model_state.bin"))
model = model.to(device)

# recording starting time
start = time.time()
# Evaluate the model using the test data
eval_model(test_data_loader, model)
# recording end time
end = time.time()
print(f"Test-set evaluation time: {end - start} seconds")

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/760 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/467M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


Epoch 1/15


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.1590678824721378
F1 Score: 0.13896096570397112
Precision: 0.21995535714285713
Recall: 0.10156250000000001
Hamming Loss: 0.10067237726812195
Average Loss: 0.2906408027535485
AUC-ROC: 0.6250638868759285
AUPR: 0.3033188378716097

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.67      0.31      0.42       379
          10       0.00      0.00      0.00        72

   micro avg       0.67      0.10      0.18      1152
   macro avg       0.06      0.03

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.18439716312056736
F1 Score: 0.16561555177626605
Precision: 0.23297115895800108
Recall: 0.1284722222222222
Hamming Loss: 0.0980933959657364
Average Loss: 0.28779745654713723
AUC-ROC: 0.6705857859252513
AUPR: 0.3482417567371526

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.71      0.39      0.50       379
          10       0.00      0.00      0.00        72

   micro avg       0.71      0.13      0.22      1152
   macro avg       0.06      0.04

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2137791286727457
F1 Score: 0.1909238387978142
Precision: 0.252085588023088
Recall: 0.15364583333333334
Hamming Loss: 0.09477756286266925
Average Loss: 0.27685812848710245
AUC-ROC: 0.7126719682602415
AUPR: 0.38708550179210754

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.77      0.47      0.58       379
          10       0.00      0.00      0.00        72

   micro avg       0.77      0.15      0.26      1152
   macro avg       0.07      0.04 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.23404255319148937
F1 Score: 0.20670792293070672
Precision: 0.2517295349326599
Recall: 0.1753472222222222
Hamming Loss: 0.09321175278622088
Average Loss: 0.2685635958708102
AUC-ROC: 0.7461139664357399
AUPR: 0.4193629083365489

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.00      0.00      0.00       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.77      0.53      0.63       379
          10       0.00      0.00      0.00        72

   micro avg       0.77      0.18      0.29      1152
   macro avg       0.07      0.05 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2553191489361702
F1 Score: 0.23013389787309102
Precision: 0.3774618755087505
Recall: 0.19618055555555555
Hamming Loss: 0.08998802615823892
Average Loss: 0.257328080313821
AUC-ROC: 0.7745685693231372
AUPR: 0.4532929547611265

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.75      0.02      0.04       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.00      0.00      0.00       147
           9       0.82      0.59      0.68       379
          10       0.00      0.00      0.00        72

   micro avg       0.82      0.20      0.32      1152
   macro avg       0.14      0.06  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2867274569402229
F1 Score: 0.2790565704172592
Precision: 0.4042265464248971
Recall: 0.2378472222222222
Hamming Loss: 0.08768536428110896
Average Loss: 0.24628591405287867
AUC-ROC: 0.8025759872621434
AUPR: 0.49040213359270596

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.65      0.19      0.29       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.00      0.00      0.00       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.33      0.01      0.01       147
           9       0.81      0.64      0.72       379
          10       0.00      0.00      0.00        72

   micro avg       0.79      0.24      0.37      1152
   macro avg       0.16      0.08 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2978723404255319
F1 Score: 0.2792359577882352
Precision: 0.5330788665530253
Recall: 0.2526041666666667
Hamming Loss: 0.08713272543059777
Average Loss: 0.23605143855656346
AUC-ROC: 0.8251817435922217
AUPR: 0.5294449299153112

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.67      0.02      0.05       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.88      0.12      0.21       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.67      0.03      0.05       147
           9       0.77      0.71      0.74       379
          10       0.00      0.00      0.00        72

   micro avg       0.77      0.25      0.38      1152
   macro avg       0.27      0.08  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3323201621073962
F1 Score: 0.3481709713142991
Precision: 0.5047859185558878
Recall: 0.3315972222222222
Hamming Loss: 0.08584323477940499
Average Loss: 0.23364977322278485
AUC-ROC: 0.8391514713076219
AUPR: 0.5581107584963994

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.75      0.14      0.24       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.73      0.17      0.28       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.67      0.16      0.26       147
           9       0.70      0.82      0.76       379
          10       0.00      0.00      0.00        72

   micro avg       0.70      0.33      0.45      1152
   macro avg       0.26      0.12  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3434650455927052
F1 Score: 0.3782530104820947
Precision: 0.5297552025994483
Recall: 0.3324652777777778
Hamming Loss: 0.08151423045040067
Average Loss: 0.22292567398999968
AUC-ROC: 0.852015368977245
AUPR: 0.5785451976061612

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.76      0.23      0.36       167
           3       0.00      0.00      0.00        44
           4       0.00      0.00      0.00        78
           5       0.70      0.15      0.25       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.63      0.22      0.33       147
           9       0.79      0.77      0.78       379
          10       0.00      0.00      0.00        72

   micro avg       0.77      0.33      0.46      1152
   macro avg       0.26      0.13   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3586626139817629
F1 Score: 0.4026556256150847
Precision: 0.6081194218130928
Recall: 0.3550347222222222
Hamming Loss: 0.07875103619784471
Average Loss: 0.21496566096621175
AUC-ROC: 0.8621768520464247
AUPR: 0.6009134493797806

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.71      0.40      0.51       167
           3       0.00      0.00      0.00        44
           4       1.00      0.03      0.05        78
           5       0.74      0.16      0.26       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.69      0.16      0.26       147
           9       0.82      0.78      0.80       379
          10       0.00      0.00      0.00        72

   micro avg       0.79      0.36      0.49      1152
   macro avg       0.36      0.14  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.4072948328267477
F1 Score: 0.4886294641291514
Precision: 0.5878176116374988
Recall: 0.4305555555555556
Hamming Loss: 0.07432992539375519
Average Loss: 0.21152669846290542
AUC-ROC: 0.8667748219544953
AUPR: 0.6154608966256389

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.64      0.51      0.57       167
           3       0.00      0.00      0.00        44
           4       0.79      0.28      0.42        78
           5       0.66      0.33      0.44       127
           6       0.00      0.00      0.00        63
           7       0.00      0.00      0.00         6
           8       0.68      0.40      0.50       147
           9       0.86      0.76      0.81       379
          10       0.00      0.00      0.00        72

   micro avg       0.77      0.43      0.55      1152
   macro avg       0.33      0.21  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.4083080040526849
F1 Score: 0.4907550610114868
Precision: 0.6270071130513524
Recall: 0.4262152777777778
Hamming Loss: 0.0735009671179884
Average Loss: 0.20603212627071527
AUC-ROC: 0.8721918591292841
AUPR: 0.6257464686949217

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.74      0.46      0.56       167
           3       0.00      0.00      0.00        44
           4       0.83      0.19      0.31        78
           5       0.66      0.39      0.49       127
           6       0.43      0.05      0.09        63
           7       0.00      0.00      0.00         6
           8       0.67      0.41      0.51       147
           9       0.85      0.76      0.80       379
          10       0.00      0.00      0.00        72

   micro avg       0.78      0.43      0.55      1152
   macro avg       0.38      0.20   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.4072948328267477
F1 Score: 0.48084393620947086
Precision: 0.7678793558384434
Recall: 0.4123263888888889
Hamming Loss: 0.0734088606429032
Average Loss: 0.20447729834385456
AUC-ROC: 0.8759906624784399
AUPR: 0.6374178664107837

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        49
           1       0.00      0.00      0.00        20
           2       0.78      0.37      0.50       167
           3       1.00      0.09      0.17        44
           4       0.87      0.26      0.40        78
           5       0.70      0.37      0.48       127
           6       1.00      0.03      0.06        63
           7       0.00      0.00      0.00         6
           8       0.73      0.26      0.38       147
           9       0.82      0.79      0.81       379
          10       1.00      0.04      0.08        72

   micro avg       0.80      0.41      0.54      1152
   macro avg       0.63      0.20  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.4275582573454914
F1 Score: 0.5175727951329647
Precision: 0.767087192372533
Recall: 0.4583333333333333
Hamming Loss: 0.07202726351662522
Average Loss: 0.19877881142160586
AUC-ROC: 0.8831161021376785
AUPR: 0.6478932535949913

Classification Report:
               precision    recall  f1-score   support

           0       1.00      0.02      0.04        49
           1       0.00      0.00      0.00        20
           2       0.66      0.58      0.62       167
           3       1.00      0.05      0.09        44
           4       0.89      0.21      0.33        78
           5       0.73      0.42      0.53       127
           6       0.40      0.06      0.11        63
           7       0.00      0.00      0.00         6
           8       0.67      0.41      0.51       147
           9       0.85      0.77      0.81       379
          10       1.00      0.04      0.08        72

   micro avg       0.77      0.46      0.57      1152
   macro avg       0.65      0.23   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.4508611955420466
F1 Score: 0.5475770396900372
Precision: 0.7299888425483986
Recall: 0.5112847222222222
Hamming Loss: 0.07110619876577323
Average Loss: 0.20116241862096132
AUC-ROC: 0.886445553601455
AUPR: 0.6578009938736611

Classification Report:
               precision    recall  f1-score   support

           0       1.00      0.02      0.04        49
           1       0.00      0.00      0.00        20
           2       0.78      0.50      0.61       167
           3       1.00      0.09      0.17        44
           4       0.84      0.33      0.48        78
           5       0.70      0.37      0.48       127
           6       0.40      0.06      0.11        63
           7       0.00      0.00      0.00         6
           8       0.60      0.65      0.63       147
           9       0.79      0.84      0.81       379
          10       0.73      0.11      0.19        72

   micro avg       0.74      0.51      0.60      1152
   macro avg       0.62      0.27   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
