In [None]:
import torch
from torch.utils.data import DataLoader
import seaborn as sns
import matplotlib.pyplot as plt

from noisy_intents.data import DydaDA, NoDA
from noisy_intents.training import autodetect_device
from noisy_intents.eval import compute_metrics
from transformers import BertTokenizer
from torchmetrics import Accuracy, Precision, Recall, F1Score, ConfusionMatrix

In [None]:
# Tokenizer handed to both dataset constructors further down.
# NOTE(review): presumably this must match the vocabulary the fine-tuned
# checkpoint was trained with - confirm against the training code.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

## Load model

In [None]:
# Device the model is moved to and that is handed to compute_metrics below.
# NOTE(review): autodetect_device is a project helper; presumably it prefers
# an accelerator when one is available - confirm in noisy_intents.training.
device = autodetect_device()

In [None]:
# Load the fine-tuned model from disk.
#
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved on a GPU can still be opened on a host without that GPU.
# NOTE(review): assumes `device` is a torch.device or device string - confirm.
#
# SECURITY: the result is used directly as a model object below, so the file
# holds a fully pickled module rather than a state dict. Unpickling can execute
# arbitrary code - only load checkpoints from a trusted source.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which rejects fully pickled modules; pass weights_only=False when upgrading.
model = torch.load("best_bert_finetuned.pt", map_location=device)

In [None]:
# Place the weights on the evaluation device.
model.to(device)
# Switch to inference mode so dropout is disabled and the reported metrics are
# deterministic. NOTE(review): compute_metrics may already do this internally;
# calling eval() again is harmless. The trailing ';' hides the module repr.
model.eval();

## Evaluation on the test set

In [None]:
# "test" split of the DydaDA dataset; the tokenizer and max_len=128 are
# forwarded to the project's dataset constructor.
test_data = DydaDA.from_hugging_face("test", tokenizer, max_len=128)

# drop_last must be False for evaluation: with drop_last=True the final
# incomplete batch (up to batch_size - 1 samples) is silently discarded, so
# every reported metric would be computed on a truncated test set.
test_loader = DataLoader(
    test_data,
    batch_size=64,
    num_workers=8,
    drop_last=False,
)

In [None]:
# Confusion matrix of the model on the test loader (4 classes).
test_confusion_metric = ConfusionMatrix("multiclass", num_classes=4)
test_confusion_results = compute_metrics(
    model,
    test_loader,
    device,
    metrics=[test_confusion_metric],
)

# Only one metric was requested, so its value sits at index 0. Convert it to an
# integer NumPy array on the host for display and plotting.
cm = test_confusion_results[0].cpu().numpy().astype(int)
cm

In [None]:
# Display names for the four classes.
# NOTE(review): the order is assumed to match the class indices produced by the
# dataset - confirm against the label mapping in noisy_intents.data.
labels = ["Commissive", "Directive", "Question", "Inform"]

# Normalise by the total count so each cell shows its share of all samples.
cm_share = cm / cm.sum()
ax = sns.heatmap(cm_share, annot=True, fmt=".2%", cmap="Blues")

# Replace the numeric tick labels with class names on both axes.
ax.xaxis.set_ticklabels(labels)
ax.yaxis.set_ticklabels(labels)

# The axis labels treat columns as predictions and rows as ground truth.
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")

In [None]:
# Settings shared by every metric below: 4-way multiclass classification.
multiclass_kwargs = {"task": "multiclass", "num_classes": 4}

# Per-class scores come first (average=None applies no reduction over classes).
per_class_metrics = [
    metric_cls(average=None, **multiclass_kwargs)
    for metric_cls in (Accuracy, Precision, Recall, F1Score)
]

# Aggregated scores follow, in the order the results are read afterwards.
averaged_metrics = [
    Accuracy(average="micro", **multiclass_kwargs),
    Accuracy(average="macro", **multiclass_kwargs),
    Precision(average="macro", **multiclass_kwargs),
    Recall(average="macro", **multiclass_kwargs),
    F1Score(average="macro", **multiclass_kwargs),
    F1Score(average="micro", **multiclass_kwargs),
]

# Full list handed to compute_metrics: per-class entries, then averaged ones.
metrics = per_class_metrics + averaged_metrics

In [None]:
# torchmetrics objects accumulate state across update() calls. Reset them
# first so re-running this cell, or running it after another evaluation that
# used the same `metrics` list, does not fold earlier batches into the result.
# NOTE(review): redundant if compute_metrics already resets its metrics.
for metric in metrics:
    metric.reset()

# Per-class and averaged scores on the test loader, in the order of `metrics`.
compute_metrics(model, test_loader, device, metrics=metrics)

## Evaluation on NoDA

In [None]:
# NoDA dataset built with the same tokenizer and max_len=128 as the test set.
noda_data = NoDA(tokenizer, max_len=128)

# drop_last must be False for evaluation: with drop_last=True the final
# incomplete batch (up to batch_size - 1 samples) is silently discarded, so
# every reported metric would be computed on a truncated dataset.
noda_loader = DataLoader(
    noda_data,
    batch_size=64,
    num_workers=8,
    drop_last=False,
)

In [None]:
# Confusion matrix of the model on the NoDA loader (4 classes).
noda_confusion_metric = ConfusionMatrix("multiclass", num_classes=4)
noda_confusion_results = compute_metrics(
    model,
    noda_loader,
    device,
    metrics=[noda_confusion_metric],
)

# Only one metric was requested, so its value sits at index 0. Convert it to an
# integer NumPy array on the host for display and plotting.
cm2 = noda_confusion_results[0].cpu().numpy().astype(int)
cm2

In [None]:
# Display names for the four classes.
# NOTE(review): the order is assumed to match the class indices produced by the
# dataset - confirm against the label mapping in noisy_intents.data.
labels = ["Commissive", "Directive", "Question", "Inform"]

# Normalise by the total count so each cell shows its share of all samples.
cm2_share = cm2 / cm2.sum()
ax = sns.heatmap(cm2_share, annot=True, fmt=".2%", cmap="Blues")

# Replace the numeric tick labels with class names on both axes.
ax.xaxis.set_ticklabels(labels)
ax.yaxis.set_ticklabels(labels)

# The axis labels treat columns as predictions and rows as ground truth.
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")

In [None]:
# The same `metrics` objects were already updated on the test set, and
# torchmetrics objects accumulate state across update() calls. Reset them so
# the NoDA scores are computed from NoDA batches only.
# NOTE(review): redundant if compute_metrics already resets its metrics.
for metric in metrics:
    metric.reset()

# Per-class and averaged scores on the NoDA loader, in the order of `metrics`.
compute_metrics(model, noda_loader, device, metrics=metrics)