In [1]:
import sys
import os
sys.path.append(os.path.abspath("..")) 

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from models.distilbert_model import DistilBERTClassifier
from data.dataset_loader import load_dataset
from utils.config import NUM_CLASSES

  from .autonotebook import tqdm as notebook_tqdm


Tokens: ['this', 'animal', 'has', 'black', 'and', 'white', 'stripes', ',', 'lives', 'in', 'africa', ',', 'and', 'eats', 'grass', '.']
Total tokens: 16


In [2]:
# Rebuild the classifier and restore its saved weights onto the CPU.
model = DistilBERTClassifier(num_classes=NUM_CLASSES)

# torch.load unpickles the file; weights_only=True limits deserialization to
# tensors and plain containers, which is all a state dict needs, so a tampered
# checkpoint cannot run arbitrary code while loading.
state_dict = torch.load(
    "../saved_models/distilbert_animal_classifier.pth",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)

# Switch to inference mode; as the last expression this also displays the
# module structure in the cell output.
model.eval()

DistilBERTClassifier(
  (encoder): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1)

In [3]:
# load_dataset returns two items; only the second one (kept as the validation
# loader) is used in this notebook, so the first is discarded.
dataset_csv = "../data/processed/animal_dataset.csv"
_, val_loader = load_dataset(dataset_csv)

In [5]:
def evaluate_on_validation(model, val_loader, device):
    """Run ``model`` over ``val_loader`` and report classification metrics.

    The model is moved to ``device`` and put into eval mode; both changes
    persist after the call. Every batch must be a mapping with the entries
    ``'input_ids'``, ``'attention_mask'`` and ``'labels'``. The model is
    invoked as ``model(input_ids, attention_mask)`` and its output is
    arg-maxed over dimension 1 to obtain the predicted class indices.

    Args:
        model: Module to evaluate.
        val_loader: Iterable yielding batches as described above.
        device: Device on which inference is executed.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall`` and ``f1`` as floats
        (the last three macro-averaged). The same values are printed with
        four decimals. When the call is the last expression of a notebook
        cell the dict is echoed as well; append ``;`` to suppress that.
    """
    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().tolist())
            # Labels are only consumed by the CPU-side metric functions, so
            # they are never copied to `device`.
            all_labels.extend(batch['labels'].cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0 yields the same scores as the default setting (an
    # undefined per-class value counts as 0) without emitting
    # UndefinedMetricWarning when some class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='macro', zero_division=0
    )

    print("\nEvaluation on Validation Set:")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")

    return {
        "accuracy": float(acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }

In [6]:
# Evaluate on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

evaluate_on_validation(model, val_loader, device)


Evaluation on Validation Set:
Accuracy:  0.8944
Precision: 0.8952
Recall:    0.8944
F1 Score:  0.8939
