In [1]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, classification_report
from torch.utils.data import (TensorDataset, DataLoader, RandomSampler,
                              SequentialSampler)

# Device that evaluate_model moves every test batch to. Nothing in this
# notebook moves the model itself, so the model is presumably expected to
# already live on this device — TODO confirm before switching to 'cuda'.
device = 'cpu'

In [2]:
def data_loader(test_inputs, test_labels, batch_size=50):
    """Wrap a test set in a DataLoader that yields batches in original order.

    Args:
        test_inputs: array-like of model inputs; converted with ``torch.tensor``.
        test_labels: array-like of labels; converted with ``torch.tensor``.
        batch_size: number of samples per batch (default 50).

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of ``(inputs, labels)``,
        driven by a ``SequentialSampler`` so no shuffling takes place.
    """
    dataset = TensorDataset(torch.tensor(test_inputs), torch.tensor(test_labels))
    return DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=batch_size,
    )


In [None]:
def _roc_auc(labels, probs):
    """Return the ROC AUC for class-probability scores, or ``nan`` if undefined.

    Args:
        labels: 1-D array of integer class labels.
        probs: 2-D array of shape ``(n_samples, n_classes)`` whose rows are
            class probabilities (e.g. a softmax over logits).

    With two columns, the positive-class column is scored directly, which is
    sklearn's binary form. With more columns, the macro one-vs-one average is
    used; it requires rows that sum to 1.
    """
    try:
        if probs.shape[1] == 2:
            return roc_auc_score(labels, probs[:, 1])
        return roc_auc_score(labels, probs, multi_class='ovo')
    except ValueError:
        # sklearn raises when AUC is undefined, e.g. only one class occurs in
        # ``labels`` or some probability columns have no matching samples.
        return float('nan')


def evaluate_model(model, test_inputs, test_labels, loss_fn):
    """Evaluate a classifier on a test set, print summary metrics, return them.

    The whole test set is run through ``model`` in eval mode with gradients
    disabled. Predictions, probabilities and labels from *every* batch are
    accumulated, so all metrics describe the full test set.

    Args:
        model: module mapping a batch of inputs to logits. It is assumed to
            produce shape ``(batch, n_classes)``, since argmax and softmax
            are taken over dim 1.
        test_inputs: array-like of model inputs, passed to ``data_loader``.
        test_labels: array-like of integer class labels, passed to
            ``data_loader``.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar
            tensor. It is assumed to average over the batch, e.g.
            ``nn.CrossEntropyLoss()`` with its default ``reduction='mean'``.

    Returns:
        dict with keys ``loss``, ``accuracy``, ``precision``, ``recall``,
        ``f1``, ``auc``, ``confusion_matrix`` and ``report``. ``auc`` is
        ``nan`` when it cannot be computed.

    Raises:
        ValueError: if the test set is empty.
    """
    test_dataloader = data_loader(test_inputs, test_labels)
    n_samples = len(test_dataloader.dataset)
    if n_samples == 0:
        # Without this guard the loop never runs, and the averages and
        # metrics below would fail with unhelpful errors.
        raise ValueError("evaluate_model: the test set is empty")

    model.eval()

    total_loss = 0.0
    batch_preds, batch_probs, batch_labels = [], [], []

    with torch.no_grad():
        for batch in test_dataloader:
            # Move the batch to the configured ``device``.
            b_input_ids, b_labels = tuple(t.to(device) for t in batch)

            logits = model(b_input_ids)

            # Scale each batch's mean loss by its size so the final average
            # is per sample, even when the last batch is shorter.
            loss = loss_fn(logits, b_labels)
            total_loss += loss.item() * b_labels.size(0)

            # Keep every batch's outputs; metrics are computed on the full set.
            batch_preds.append(torch.argmax(logits, dim=1).cpu())
            # float64 softmax so rows sum to 1 within sklearn's tolerance.
            batch_probs.append(torch.softmax(logits.double(), dim=1).cpu())
            batch_labels.append(b_labels.cpu())

    preds = torch.cat(batch_preds).numpy()
    probs = torch.cat(batch_probs).numpy()
    labels = torch.cat(batch_labels).numpy()

    avg_loss = total_loss / n_samples
    accuracy = accuracy_score(labels, preds)
    confusion_mat = confusion_matrix(labels, preds)
    report = classification_report(labels, preds)
    f1_score_val = f1_score(labels, preds, average='macro')
    precision_val = precision_score(labels, preds, average='macro')
    recall_val = recall_score(labels, preds, average='macro')
    # AUC must be computed from scores/probabilities, not hard predictions.
    auc_val = _roc_auc(labels, probs)

    print(f"Average test loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision_val:.4f}, Recall: {recall_val:.4f}, F1-score: {f1_score_val:.4f}, AUC-ROC: {auc_val:.4f}")
    print(f"Confusion Matrix: \n{confusion_mat}")
    print(f"Classification Report: \n{report}")

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision_val,
        'recall': recall_val,
        'f1': f1_score_val,
        'auc': auc_val,
        'confusion_matrix': confusion_mat,
        'report': report,
    }