In [None]:
## !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers

from transformers import XLNetModel, XLNetTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
MAX_LEN = 512
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 8
TEST_BATCH_SIZE = 8
EPOCHS = 15
LEARNING_RATE = 1e-05
THRESHOLD = 0.5 # threshold for the sigmoid

## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_len, target_list):
        self.tokenizer = tokenizer
        self.df = df
        self.title = list(df['File Contents'])
        self.targets = self.df[target_list].values
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, index):
        title = str(self.title[index])
        title = " ".join(title.split())
        inputs = self.tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'token_type_ids': inputs["token_type_ids"].flatten(),
            'targets': torch.FloatTensor(self.targets[index]),
            'title': title
        }
    
## Data
train_file = '../data/ohsumed/train.csv'
val_file = '../data/ohsumed/val.csv'
test_file = '../data/ohsumed/test.csv'
train_df = pd.read_csv(train_file)
val_df = pd.read_csv(val_file)
test_df = pd.read_csv(test_file)

target_list = list(train_df.columns[1:])

## Tokenizer
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
train_dataset = CustomDataset(train_df, tokenizer, MAX_LEN, target_list)
valid_dataset = CustomDataset(val_df, tokenizer, MAX_LEN, target_list)
test_dataset = CustomDataset(test_df, tokenizer, MAX_LEN, target_list)
#print(train_dataset[0])

## Data Loader
train_data_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

val_data_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

test_data_loader = torch.utils.data.DataLoader(test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

## Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

## Model
class XLNETWithCNN(nn.Module):
    def __init__(self, num_classes=23):
        super(XLNETWithCNN, self).__init__()
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        self.drop = nn.Dropout(0.3)
        
        # CNN Layer
        self.conv1 = nn.Conv1d(in_channels=768, out_channels=128, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2)
        
        # Fully connected layer
        self.fc = nn.Linear(128 * (MAX_LEN // 2), num_classes)  # Adjust based on the pooling size

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Get the last hidden state from XLNet
        outputs = self.xlnet(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden_state = outputs[0]  # Shape: (batch_size, sequence_length, hidden_size)
        
        # Apply dropout
        x = self.drop(last_hidden_state)
        
        # Permute to match the input dimensions required by Conv1d: (batch_size, hidden_size, sequence_length)
        x = x.permute(0, 2, 1)
        
        # Apply CNN
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        
        # Flatten the output from the CNN
        x = x.view(x.size(0), -1)
        
        # Fully connected layer for classification
        x = self.fc(x)
        
        return x
    

## Setting the model
model = XLNETWithCNN(num_classes=len(target_list))
model.to(device)


## Loss & Optimizer
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

# define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)


## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=4):
    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    # Set model to training mode (activate dropout, batch norm)
    model.train()

    # Mixed precision
    scaler = GradScaler()

    for batch_idx, data in enumerate(training_loader):
        ids = data['input_ids'].to(device, dtype=torch.long)
        mask = data['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.float)

        # Forward pass with mixed precision
        with autocast():
            outputs = model(ids, mask, token_type_ids)  # (batch, predict) = (8, 8)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

        # Training accuracy, apply sigmoid, round (apply threshold 0.5)
        outputs = torch.sigmoid(outputs).cpu().detach().numpy().round()
        targets = targets.cpu().detach().numpy()
        correct_predictions += np.sum(outputs == targets)
        num_samples += targets.size  # Total number of elements in the 2D array

        # Backward pass with gradient accumulation
        loss = loss / accumulation_steps  # Normalize loss to account for accumulation
        scaler.scale(loss).backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Step optimizer every accumulation_steps
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Clear GPU cache
        torch.cuda.empty_cache()

    # Perform the final optimizer step if not done already
    if (batch_idx + 1) % accumulation_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # Returning: trained model, model accuracy, mean loss
    return model, float(correct_predictions) / num_samples, np.mean(losses)


## Evaluator Function
def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    model.eval()
    final_targets = []
    final_outputs = []
    final_probs = []
    losses = []

    # Mixed precision
    scaler = torch.cuda.amp.GradScaler()

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.float)

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                outputs = model(ids, mask, token_type_ids)
                loss = loss_fn(outputs, targets)
                losses.append(loss.item())

                probs = torch.sigmoid(outputs).cpu().detach().numpy()
                targets = targets.cpu().detach().numpy()
                final_outputs.extend(probs >= threshold)
                final_probs.extend(probs)
                final_targets.extend(targets)

            # Clear GPU cache
            torch.cuda.empty_cache()

    final_outputs = np.array(final_outputs) >= threshold
    final_probs = np.array(final_probs)
    final_targets = np.array(final_targets)

    # Calculating metrics
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted')  # Consider using 'macro' or 'weighted' based on your problem
    precision = precision_score(final_targets, final_outputs, average='weighted')
    recall = recall_score(final_targets, final_outputs, average='weighted')
    hamming = hamming_loss(final_targets, final_outputs)

    auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list))

    print("\n\n")
    return f1, average_loss


## Training & Evaluation
# recording starting time
import time
start = time.time()

history = defaultdict(list)
best_f1 = 0.0

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_f1, val_loss = eval_model(val_data_loader, model)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)
    # save the best model
    if val_f1 > best_f1:
        torch.save(model.state_dict(), "ohsumed_XLNET_CNN_8_MLTC_model_state.bin")
        best_f1 = val_f1

# recording end time
end = time.time()
print(f"Total training time: {end - start} seconds")


## Testing
# Loading pretrained model (best model)
print("\n\nTesting\n\n")
model = XLNETWithCNN(num_classes=len(target_list))
model.load_state_dict(torch.load("ohsumed_XLNET_CNN_8_MLTC_model_state.bin"))
model = model.to(device)

# recording starting time
start = time.time()
# Evaluate the model using the test data
eval_model(test_data_loader, model)
# recording end time
end = time.time()
print(f"Total evaluation time: {end - start} seconds")

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/760 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/467M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


Epoch 1/15


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.0047694753577106515
F1 Score: 0.01855631005397886
Precision: 0.11020397745090173
Recall: 0.010184287099903006
Hamming Loss: 0.07192230593765121
Average Loss: 0.2517439649452137
AUC-ROC: 0.5721875017939394
AUPR: 0.17183175133407744

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        69
           1       0.00      0.00      0.00        29
           2       0.00      0.00      0.00        12
           3       0.00      0.00      0.00       249
           4       0.00      0.00      0.00        41
           5       0.00      0.00      0.00       116
           6       0.00      0.00      0.00        24
           7       0.00      0.00      0.00        85
           8       0.00      0.00      0.00        24
           9       0.00      0.00      0.00       118
          10       0.00      0.00      0.00        30
          11       0.00      0.00      0.00        97
          12       0.00      

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.03577106518282989
F1 Score: 0.07182741420866161
Precision: 0.24205705937418143
Recall: 0.04704170708050436
Hamming Loss: 0.07002142807769406
Average Loss: 0.24342705818671215
AUC-ROC: 0.6838804207864287
AUPR: 0.27518809980416375

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        69
           1       0.00      0.00      0.00        29
           2       0.00      0.00      0.00        12
           3       0.67      0.02      0.03       249
           4       0.00      0.00      0.00        41
           5       0.00      0.00      0.00       116
           6       0.00      0.00      0.00        24
           7       0.00      0.00      0.00        85
           8       0.00      0.00      0.00        24
           9       0.00      0.00      0.00       118
          10       0.00      0.00      0.00        30
          11       0.00      0.00      0.00        97
          12       0.00      0.

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.16136724960254373
F1 Score: 0.24959375705782735
Precision: 0.47113655411229016
Recall: 0.22356935014548981
Hamming Loss: 0.06224510955968757
Average Loss: 0.2143736616531505
AUC-ROC: 0.8091690268500153
AUPR: 0.4420403891586283

Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        69
           1       0.00      0.00      0.00        29
           2       0.00      0.00      0.00        12
           3       0.75      0.74      0.74       249
           4       0.00      0.00      0.00        41
           5       0.90      0.08      0.14       116
           6       0.00      0.00      0.00        24
           7       0.00      0.00      0.00        85
           8       0.00      0.00      0.00        24
           9       0.67      0.10      0.18       118
          10       0.00      0.00      0.00        30
          11       0.00      0.00      0.00        97
          12       0.00      0.00

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.21939586645468998
F1 Score: 0.41555330975326343
Precision: 0.5378023915309864
Recall: 0.3874878758486906
Hamming Loss: 0.06103546001244211
Average Loss: 0.19904032000635244
AUC-ROC: 0.8478496519835965
AUPR: 0.5233377527419069

Classification Report:
               precision    recall  f1-score   support

           0       0.28      0.45      0.34        69
           1       0.00      0.00      0.00        29
           2       0.00      0.00      0.00        12
           3       0.79      0.78      0.79       249
           4       0.56      0.12      0.20        41
           5       0.53      0.64      0.58       116
           6       0.00      0.00      0.00        24
           7       0.59      0.41      0.49        85
           8       0.00      0.00      0.00        24
           9       0.56      0.42      0.48       118
          10       0.00      0.00      0.00        30
          11       0.59      0.34      0.43        97
          12       0.71      0.30 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2631160572337043
F1 Score: 0.5077112608154141
Precision: 0.5546920560360751
Recall: 0.4980601357904947
Hamming Loss: 0.057717564111426006
Average Loss: 0.18372770054619522
AUC-ROC: 0.8692929176681525
AUPR: 0.569559607213265

Classification Report:
               precision    recall  f1-score   support

           0       0.54      0.29      0.38        69
           1       0.28      0.28      0.28        29
           2       0.00      0.00      0.00        12
           3       0.74      0.84      0.79       249
           4       0.43      0.49      0.45        41
           5       0.52      0.69      0.59       116
           6       0.00      0.00      0.00        24
           7       0.53      0.68      0.60        85
           8       0.00      0.00      0.00        24
           9       0.58      0.44      0.50       118
          10       0.74      0.57      0.64        30
          11       0.79      0.51      0.62        97
          12       0.62      0.48   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2543720190779014
F1 Score: 0.5419680484880375
Precision: 0.5937950762843834
Recall: 0.5528612997090203
Hamming Loss: 0.05782124835833276
Average Loss: 0.18340968647146527
AUC-ROC: 0.877833259151297
AUPR: 0.589671352341957

Classification Report:
               precision    recall  f1-score   support

           0       0.42      0.45      0.43        69
           1       0.30      0.38      0.33        29
           2       0.00      0.00      0.00        12
           3       0.79      0.82      0.80       249
           4       0.49      0.49      0.49        41
           5       0.55      0.70      0.62       116
           6       0.00      0.00      0.00        24
           7       0.55      0.64      0.59        85
           8       1.00      0.08      0.15        24
           9       0.52      0.59      0.56       118
          10       0.64      0.70      0.67        30
          11       0.75      0.66      0.70        97
          12       0.62      0.56     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2543720190779014
F1 Score: 0.5652118083096273
Precision: 0.5947337742913359
Recall: 0.5935984481086324
Hamming Loss: 0.05837423100850211
Average Loss: 0.17895966409882413
AUC-ROC: 0.8852286913307773
AUPR: 0.6138154335008627

Classification Report:
               precision    recall  f1-score   support

           0       0.32      0.54      0.40        69
           1       0.32      0.55      0.41        29
           2       0.00      0.00      0.00        12
           3       0.78      0.83      0.81       249
           4       0.43      0.63      0.51        41
           5       0.48      0.78      0.60       116
           6       0.50      0.04      0.08        24
           7       0.56      0.67      0.61        85
           8       1.00      0.12      0.22        24
           9       0.63      0.43      0.51       118
          10       0.67      0.73      0.70        30
          11       0.60      0.75      0.67        97
          12       0.51      0.70   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.20349761526232116
F1 Score: 0.5768971671614951
Precision: 0.5774024564239305
Recall: 0.6246362754607178
Hamming Loss: 0.06272896937858574
Average Loss: 0.18337917917325527
AUC-ROC: 0.8910144680185305
AUPR: 0.6284059823368177

Classification Report:
               precision    recall  f1-score   support

           0       0.38      0.52      0.44        69
           1       0.25      0.62      0.36        29
           2       0.33      0.08      0.13        12
           3       0.77      0.84      0.81       249
           4       0.37      0.71      0.49        41
           5       0.42      0.81      0.55       116
           6       0.62      0.21      0.31        24
           7       0.57      0.68      0.62        85
           8       0.62      0.21      0.31        24
           9       0.50      0.65      0.57       118
          10       0.55      0.77      0.64        30
          11       0.45      0.80      0.58        97
          12       0.63      0.64  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.23052464228934816
F1 Score: 0.6023116758056284
Precision: 0.5807974501268467
Recall: 0.6610087293889427
Hamming Loss: 0.06086265293426419
Average Loss: 0.17615443868916245
AUC-ROC: 0.8941452714596642
AUPR: 0.6421268597265007

Classification Report:
               precision    recall  f1-score   support

           0       0.43      0.57      0.49        69
           1       0.26      0.69      0.38        29
           2       0.50      0.08      0.14        12
           3       0.81      0.83      0.82       249
           4       0.47      0.68      0.55        41
           5       0.41      0.87      0.56       116
           6       0.70      0.29      0.41        24
           7       0.50      0.76      0.60        85
           8       0.75      0.38      0.50        24
           9       0.61      0.53      0.57       118
          10       0.71      0.73      0.72        30
          11       0.60      0.79      0.68        97
          12       0.56      0.76  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2893481717011129
F1 Score: 0.6101913737840925
Precision: 0.6096782174946653
Recall: 0.6372453928225025
Hamming Loss: 0.053743001313333795
Average Loss: 0.17009071435166312
AUC-ROC: 0.896501194946869
AUPR: 0.6532017243732999

Classification Report:
               precision    recall  f1-score   support

           0       0.51      0.59      0.55        69
           1       0.34      0.48      0.40        29
           2       0.45      0.42      0.43        12
           3       0.78      0.87      0.82       249
           4       0.47      0.66      0.55        41
           5       0.62      0.76      0.68       116
           6       0.59      0.54      0.57        24
           7       0.57      0.76      0.65        85
           8       0.44      0.50      0.47        24
           9       0.60      0.54      0.57       118
          10       0.65      0.73      0.69        30
          11       0.74      0.80      0.77        97
          12       0.53      0.76   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.28537360890302066
F1 Score: 0.6163788186431793
Precision: 0.6031860616978545
Recall: 0.6542192046556741
Hamming Loss: 0.055436510679477434
Average Loss: 0.170186214290465
AUC-ROC: 0.9007409294782063
AUPR: 0.6660992971452214

Classification Report:
               precision    recall  f1-score   support

           0       0.56      0.55      0.55        69
           1       0.33      0.52      0.40        29
           2       0.33      0.50      0.40        12
           3       0.80      0.84      0.82       249
           4       0.44      0.76      0.55        41
           5       0.53      0.79      0.64       116
           6       0.69      0.46      0.55        24
           7       0.50      0.76      0.60        85
           8       0.55      0.50      0.52        24
           9       0.59      0.61      0.60       118
          10       0.64      0.77      0.70        30
          11       0.55      0.84      0.67        97
          12       0.65      0.68   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2631160572337043
F1 Score: 0.6312851290790543
Precision: 0.5958310946775425
Recall: 0.6896217264791464
Hamming Loss: 0.05605861616091795
Average Loss: 0.16611345744208444
AUC-ROC: 0.9040233141279717
AUPR: 0.6754637525118645

Classification Report:
               precision    recall  f1-score   support

           0       0.46      0.77      0.58        69
           1       0.30      0.62      0.40        29
           2       0.55      0.50      0.52        12
           3       0.78      0.87      0.82       249
           4       0.48      0.68      0.57        41
           5       0.47      0.86      0.61       116
           6       0.61      0.58      0.60        24
           7       0.57      0.75      0.65        85
           8       0.52      0.50      0.51        24
           9       0.62      0.58      0.60       118
          10       0.70      0.77      0.73        30
          11       0.71      0.78      0.75        97
          12       0.46      0.84   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2774244833068362
F1 Score: 0.633186715098302
Precision: 0.613420821889722
Recall: 0.6774975751697381
Hamming Loss: 0.05405405405405406
Average Loss: 0.16046133528970466
AUC-ROC: 0.9070949932362251
AUPR: 0.68476502836801

Classification Report:
               precision    recall  f1-score   support

           0       0.43      0.78      0.56        69
           1       0.34      0.52      0.41        29
           2       0.75      0.50      0.60        12
           3       0.79      0.86      0.82       249
           4       0.40      0.78      0.53        41
           5       0.52      0.81      0.63       116
           6       0.62      0.62      0.62        24
           7       0.57      0.74      0.65        85
           8       0.38      0.62      0.48        24
           9       0.66      0.55      0.60       118
          10       0.68      0.77      0.72        30
          11       0.64      0.84      0.72        97
          12       0.62      0.72      0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.26391096979332274
F1 Score: 0.6475411919854721
Precision: 0.5981944117304525
Recall: 0.7216294859359845
Hamming Loss: 0.05654247597981613
Average Loss: 0.1645116158112695
AUC-ROC: 0.9082125606843937
AUPR: 0.6894947137991257

Classification Report:
               precision    recall  f1-score   support

           0       0.49      0.72      0.58        69
           1       0.31      0.72      0.44        29
           2       0.32      0.50      0.39        12
           3       0.80      0.86      0.83       249
           4       0.44      0.73      0.55        41
           5       0.53      0.87      0.66       116
           6       0.67      0.58      0.62        24
           7       0.52      0.78      0.62        85
           8       0.40      0.67      0.50        24
           9       0.60      0.63      0.61       118
          10       0.66      0.77      0.71        30
          11       0.77      0.79      0.78        97
          12       0.58      0.86   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.2647058823529412
F1 Score: 0.6522153294509412
Precision: 0.6112539662822588
Recall: 0.7177497575169738
Hamming Loss: 0.055160019354392756
Average Loss: 0.15770172940779337
AUC-ROC: 0.910852744558916
AUPR: 0.6951222291018185

Classification Report:
               precision    recall  f1-score   support

           0       0.45      0.81      0.58        69
           1       0.33      0.66      0.44        29
           2       0.50      0.50      0.50        12
           3       0.81      0.85      0.83       249
           4       0.45      0.73      0.56        41
           5       0.62      0.77      0.69       116
           6       0.44      0.71      0.54        24
           7       0.64      0.66      0.65        85
           8       0.44      0.62      0.52        24
           9       0.71      0.53      0.60       118
          10       0.72      0.77      0.74        30
          11       0.53      0.88      0.66        97
          12       0.58      0.86   

In [2]:
## Testing
# Loading pretrained model (best model)
print("\n\nTesting\n\n")
model = XLNETWithCNN(num_classes=len(target_list))
model.load_state_dict(torch.load("ohsumed_XLNET_CNN_8_MLTC_model_state.bin"))
model = model.to(device)

# recording starting time
start = time.time()
# Evaluate the model using the test data
eval_model(test_data_loader, model)
# recording end time
end = time.time()
print(f"Total evaluation time: {end - start} seconds")



Testing


Accuracy: 0.23995813162370797
F1 Score: 0.6140723407789481
Precision: 0.5825537336176128
Recall: 0.6664572371004477
Hamming Loss: 0.061323518536427195
Average Loss: 0.17157131110446597
AUC-ROC: 0.8920789519777913
AUPR: 0.6462183040770687

Classification Report:
               precision    recall  f1-score   support

           0       0.41      0.82      0.55       506
           1       0.39      0.70      0.50       233
           2       0.77      0.57      0.66        70
           3       0.81      0.84      0.82      1467
           4       0.49      0.66      0.56       429
           5       0.68      0.73      0.70       632
           6       0.34      0.53      0.41       146
           7       0.70      0.65      0.67       600
           8       0.41      0.62      0.50       129
           9       0.69      0.55      0.61       941
          10       0.67      0.77      0.72       202
          11       0.46      0.84      0.60       548
          12       0.5

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
