In [16]:
import torch
import pandas as pd
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [17]:
# Local fine-tuned checkpoint: tokenizer and sequence-classification model are
# loaded from the same directory so their vocabularies agree.
model_path = "../models/bert-ag-news"
tokenizer_ft = AutoTokenizer.from_pretrained(model_path)
model_ft = AutoModelForSequenceClassification.from_pretrained(model_path)

# Inference only: switch modules such as dropout to evaluation behaviour.
model_ft.eval()

# Class names indexed by class id (the id -> name lookup used throughout the notebook).
label_names = ["World", "Sports", "Business", "Sci/Tech"]

# The hard-coded names must line up with the checkpoint's classification head;
# otherwise predicted ids would be mapped to the wrong (or a missing) name.
assert model_ft.config.num_labels == len(label_names), (
    f"checkpoint has {model_ft.config.num_labels} labels, "
    f"but label_names lists {len(label_names)}"
)

Loading weights:   0%|          | 0/201 [00:00<?, ?it/s]

In [18]:
# How many test examples to evaluate; set to None to use the whole test split.
# NOTE: this takes the first N rows in split order, not a random sample, so the
# class balance of the subset is whatever those rows happen to contain.
N_EVAL_SAMPLES = 600

dataset = load_dataset("ag_news")
test_split = dataset["test"]
if N_EVAL_SAMPLES is not None:
    # select() keeps only the requested rows instead of materialising whole columns;
    # min() avoids an IndexError if N exceeds the split size.
    test_split = test_split.select(range(min(N_EVAL_SAMPLES, len(test_split))))

# Plain Python lists, parallel by position: test_labels[i] is the gold id for test_texts[i].
test_texts = list(test_split["text"])
test_labels = list(test_split["label"])

In [19]:
def predict(text, model, tokenizer, labels=None):
    """Predict the class of a single text.

    Args:
        text: One input string. The prediction is read from the first (and only)
            row of the logits, so batched input is not supported.
        model: Sequence-classification model whose output exposes ``logits``;
            inputs are moved to ``model.device`` before the forward pass.
        tokenizer: Tokenizer returning the model's keyword inputs as tensors.
        labels: Optional list of class names indexed by class id. When omitted,
            the notebook-level ``label_names`` is used.

    Returns:
        ``(pred_id, label_name, confidence)``: the argmax class index (Python
        ``int``), ``labels[pred_id]``, and that class's softmax probability.
    """
    names = label_names if labels is None else labels

    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Keep inputs on the model's device so this still works after model.to("cuda").
    device = model.device
    inputs = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Take the single row explicitly; a flattened argmax over a (batch, classes)
    # tensor only equals the class id when batch == 1.
    probs = torch.softmax(outputs.logits[0], dim=-1)
    pred_id = int(probs.argmax())
    confidence = probs[pred_id].item()

    return pred_id, names[pred_id], confidence

In [20]:
# Run the fine-tuned model over every sampled test text, keeping only the
# predicted class id (first element of predict's return tuple).
predictions = [
    predict(sample_text, model_ft, tokenizer_ft)[0]
    for sample_text in test_texts
]
print(f"Predictions are done: {len(predictions)}")

Predictions are done: 600


In [22]:
# One row per evaluated example: raw text, gold label name, predicted label name,
# plus a boolean flag marking misclassifications.
df = pd.DataFrame(
    {
        "text": test_texts,
        "true_label": [label_names[label_id] for label_id in test_labels],
        "pred_label": [label_names[label_id] for label_id in predictions],
    }
).assign(is_error=lambda frame: frame["true_label"] != frame["pred_label"])

# Keep only the misclassified rows (original index preserved for traceability).
errors = df.loc[df["is_error"]]
print(f"Total errors: {len(errors)} / {len(df)}")
errors.head(10)

Total errors: 71 / 600


Unnamed: 0,text,true_label,pred_label,is_error
9,"Card fraud unit nets 36,000 cards In its first...",Sci/Tech,Business,True
20,IBM to hire even more new workers By the end o...,Sci/Tech,Business,True
23,Some People Not Eligible to Get in on Google I...,Sci/Tech,Business,True
24,Rivals Try to Turn Tables on Charles Schwab By...,Sci/Tech,Business,True
27,Tougher rules won't soften Law's game FOXBOROU...,Sports,World,True
28,Shoppach doesn't appear ready to hit the next ...,Sports,World,True
37,1994 Law Designed to Preserve Guard Jobs (AP) ...,World,Sci/Tech,True
56,India's Tata expands regional footprint via Na...,World,Business,True
79,Live: Olympics day four Richard Faulds and Ste...,World,Sports,True
83,Intel to delay product aimed for high-definiti...,Business,Sci/Tech,True


In [24]:
# Count each (true label, predicted label) confusion among the errors and
# order the pairs from most to least frequent.
pair_counts = errors.groupby(["true_label", "pred_label"]).size()
error_pairs = pair_counts.sort_values(ascending=False)
print("Most common errors:")
print(error_pairs.head(10))

Most common errors:
true_label  pred_label
Sci/Tech    Business      21
Sports      World         12
Business    Sci/Tech      10
World       Business      10
Business    World          9
World       Sports         3
            Sci/Tech       2
Business    Sports         1
Sci/Tech    Sports         1
            World          1
dtype: int64
