In [1]:
import sys
sys.path.append("../")
from utils.models import DistilBertBaseUncased
from utils.evaluate import evaluate_slm_performance
from utils.nlp import TextDatasetSLM

import torch
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, hamming_loss, classification_report

In [2]:
model = DistilBertBaseUncased()
model.load_state_dict(torch.load("../Models/distilbert.pth"))
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

Number of Parameters: 66.97M


In [3]:
xtest = np.load("../Data/Testing/x_test.npy", allow_pickle=True)
ytest = np.load("../Data/Testing/y_test.npy", allow_pickle=True)

In [4]:
device= 'cuda'
test_set = TextDatasetSLM(xtest, ytest, tokenizer)
loader = DataLoader(test_set, 32, True)
true_labels = []
pred_labels = []
model.to(device)
with torch.inference_mode():
    for batch in loader:
        tokens = batch['input_ids'].to(device, dtype=torch.long)
        masks = batch['attention_mask'].to(device, dtype=torch.long)
        labels = batch['targets'].to(device, dtype=torch.float)
        logits = model(tokens, masks)
        preds = torch.round(torch.sigmoid(logits))
        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())

true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)

h_loss = hamming_loss(true_labels, pred_labels)
accuracy = accuracy_score(true_labels, pred_labels)
print(classification_report(true_labels, pred_labels))
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f'Hamming Loss: {h_loss:.4f}')

              precision    recall  f1-score   support

           0       0.90      0.41      0.56        22
           1       0.60      0.51      0.55       251
           2       0.76      0.62      0.68       163
           3       0.00      0.00      0.00        22
           4       0.59      0.50      0.54       158
           5       0.70      0.83      0.76       693
           6       0.75      0.64      0.69       148
           7       0.65      0.62      0.63        52
           8       0.67      0.39      0.49        46
           9       0.55      0.55      0.55        77
          10       0.62      0.64      0.63       352
          11       0.72      0.65      0.68       502
          12       0.59      0.56      0.57       212
          13       0.42      0.32      0.36       125
          14       0.66      0.66      0.66       304
          15       0.79      0.58      0.67       145
          16       0.70      0.58      0.64       168
          17       0.59    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
