# A simple notebook to load the saved model and test dataset, then compute metrics.

# 1. Imports

In [None]:
import pandas as pd
import torch
from datasets import Dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# 2. Load the test data

In [None]:
# Read the processed test split from disk and convert it into a Hugging Face Dataset.
test_csv_path = "../data/processed/test.csv"
test_df = pd.read_csv(test_csv_path)
test_dataset = Dataset.from_pandas(test_df)

# 3. Load the saved model

In [None]:
# Restore the tokenizer and the sequence-classification model from the same
# saved checkpoint directory.
model_path = "../model_output/multilingual_model"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_path),
    AutoModelForSequenceClassification.from_pretrained(model_path),
)
# Put the model in evaluation mode (e.g. dropout disabled) before predicting.
model.eval()

def tokenize_fn(example):
    """Encode the ``full_text`` column of a batch produced by ``Dataset.map``.

    Args:
        example: A batch handed over by ``Dataset.map(..., batched=True)``;
            ``example["full_text"]`` holds the texts to encode.

    Returns:
        The tokenizer's encoding for the batch. Each text is truncated and
        padded to exactly 256 tokens, so every row has the same length.
    """
    texts = example["full_text"]
    return tokenizer(texts, max_length=256, padding="max_length", truncation=True)

# Tokenize every row, expose the target column under the name the model
# expects ("labels"), then make indexing return torch tensors restricted to the
# model inputs and the target.
test_dataset = (
    test_dataset
    .map(tokenize_fn, batched=True)
    .rename_column("category_label", "labels")
)
test_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)


# 4. Evaluate

In [None]:
# Evaluate the model on the test set: batched inference, then accuracy and
# support-weighted precision / recall / F1 over the collected predictions.
from torch.utils.data import DataLoader

# NOTE(review): `datasets.load_metric` is deprecated in datasets 2.x and removed
# in 3.0; migrating to `evaluate.load(...)` also requires changing the import
# cell at the top of the notebook.
accuracy_metric = load_metric("accuracy")
precision_metric = load_metric("precision")
recall_metric = load_metric("recall")
f1_metric = load_metric("f1")

# Score rows in batches rather than one at a time. Every row was padded to the
# same length (max_length=256), so the default collate can stack them.
eval_batch_size = 32
# Inputs must be on the same device as the model's weights. The model is
# deliberately not moved here, so its placement is unchanged for later cells.
device = next(model.parameters()).device

all_preds = []
all_labels = []

with torch.no_grad():
    # shuffle defaults to False, so predictions stay aligned with dataset order.
    for batch in DataLoader(test_dataset, batch_size=eval_batch_size):
        outputs = model(
            batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
        )
        # logits: (batch_size, num_labels) -> predicted class id per row.
        all_preds.extend(torch.argmax(outputs.logits, dim=-1).tolist())
        all_labels.extend(batch["labels"].tolist())

if not all_labels:
    raise ValueError("Test set is empty; there is nothing to evaluate.")

acc = accuracy_metric.compute(predictions=all_preds, references=all_labels)["accuracy"]
prec = precision_metric.compute(predictions=all_preds, references=all_labels, average="weighted")["precision"]
rec = recall_metric.compute(predictions=all_preds, references=all_labels, average="weighted")["recall"]
f1 = f1_metric.compute(predictions=all_preds, references=all_labels, average="weighted")["f1"]

print(f"Accuracy: {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall: {rec:.4f}")
print(f"F1 Score: {f1:.4f}")

# Next, build a confusion matrix (using scikit-learn) to see how each category is predicted.