# Baseline Evaluation — DistilRoBERTa Emotion Model

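This notebook evaluates `j-hartmann/emotion-english-distilroberta-base` on the GoEmotions validation split. The 28 GoEmotions labels are collapsed into the model's 7 classes (anger, disgust, fear, joy, neutral, sadness, surprise); only single-label examples are kept, capped at 200 per class.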

In [None]:
# Third-party packages used by this notebook. Uncomment the line below to
# install them if they are missing from the current environment.
# !pip install transformers torch datasets scikit-learn matplotlib seaborn

In [None]:
import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from transformers import pipeline

# Load the model as a text-classification pipeline.
# `top_k=None` asks the pipeline for the score of every class on each input;
# it replaces the deprecated `return_all_scores=True` flag. For a list of
# inputs the result is still one list of {"label", "score"} dicts per input.
# `device=-1` keeps inference on the CPU.
classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
    device=-1
)
print("Model loaded.")

In [None]:
# Label mapping from GoEmotions 28 → 7 classes
# ORIGINAL_LABELS is indexed by the integer label ids stored in each row's
# `labels` field, so the order of this list must not be changed.
ORIGINAL_LABELS = [
    'admiration','amusement','anger','annoyance','approval','caring','confusion',
    'curiosity','desire','disappointment','disapproval','disgust','embarrassment',
    'excitement','fear','gratitude','grief','joy','love','nervousness','optimism',
    'pride','realization','relief','remorse','sadness','surprise','neutral'
]

# Grouping follows the Ekman mapping published with GoEmotions, with `neutral`
# kept as its own class. In particular `desire` and `caring` belong to joy and
# `realization` to surprise; placing them elsewhere distorts the ground truth.
LABEL_MAP_7 = {
    'anger':   ['anger','annoyance','disapproval'],
    'disgust': ['disgust'],
    'fear':    ['fear','nervousness'],
    'joy':     ['joy','amusement','approval','excitement','gratitude','love','optimism',
                'relief','pride','admiration','desire','caring'],
    'neutral': ['neutral'],
    'sadness': ['sadness','disappointment','embarrassment','grief','remorse'],
    'surprise':['surprise','realization','confusion','curiosity']
}

# Reverse map: original label string → 7-class label
reverse_map = {}
for target, sources in LABEL_MAP_7.items():
    for source in sources:
        # A label listed under two targets would silently take the later one.
        if source in reverse_map:
            raise ValueError(f"Label {source!r} is mapped to more than one class")
        reverse_map[source] = target

CLASSES_7 = ['anger','disgust','fear','joy','neutral','sadness','surprise']

# Sanity checks: every fine-grained label is mapped, nothing unknown is mapped,
# and the class list agrees with the mapping's targets.
assert set(reverse_map) == set(ORIGINAL_LABELS), "label mapping does not cover the 28 labels exactly"
assert set(LABEL_MAP_7) == set(CLASSES_7), "CLASSES_7 and LABEL_MAP_7 disagree"
print("Mapping ready.")

In [None]:
# Load validation split.
# The "simplified" configuration is the one that provides train/validation/test
# splits with a `labels` field of integer ids. The "raw" configuration only has
# a train split and stores one column per emotion, so `ds["validation"]` and
# `row["labels"]` would both fail with it.
ds = load_dataset("google-research-datasets/go_emotions", "simplified")
val = ds["validation"]

# Filter to single-label examples with a clear 7-class mapping
samples = []
for row in val:
    if len(row["labels"]) == 1:
        orig = ORIGINAL_LABELS[row["labels"][0]]
        if orig in reverse_map:
            samples.append({"text": row["text"], "true": reverse_map[orig]})

# Take a balanced sample (up to MAX_PER_CLASS per class for speed).
# Examples are taken in dataset order, so the selection is deterministic.
MAX_PER_CLASS = 200
per_class = defaultdict(list)
for s in samples:
    per_class[s["true"]].append(s)
balanced = []
for cls, items in per_class.items():
    balanced.extend(items[:MAX_PER_CLASS])

print(f"Evaluation samples: {len(balanced)}")
for cls in CLASSES_7:
    # .get() avoids inserting empty entries into the defaultdict for absent classes.
    print(f"  {cls}: {len(per_class.get(cls, [])[:MAX_PER_CLASS])}")

In [None]:
# Run inference
BATCH_SIZE = 32  # number of texts handed to the pipeline per call

texts  = [s["text"] for s in balanced]
truths = [s["true"] for s in balanced]

preds = []
for start in range(0, len(texts), BATCH_SIZE):
    batch = texts[start:start + BATCH_SIZE]
    # truncation=True clips any text longer than the model's maximum input
    # length instead of letting the forward pass fail on it.
    results = classifier(batch, truncation=True)
    for scores in results:
        # `scores` holds one {"label", "score"} dict per class; keep the best.
        top = max(scores, key=lambda x: x["score"])
        preds.append(top["label"].lower())
    if start % (10 * BATCH_SIZE) == 0:
        print(f"  {start}/{len(texts)}")

# Guard against silent misalignment before computing metrics.
assert len(preds) == len(truths), "prediction count does not match number of samples"
unexpected = set(preds) - set(CLASSES_7)
assert not unexpected, f"model produced labels outside CLASSES_7: {sorted(unexpected)}"

print("Inference complete.")

In [None]:
# Metrics
# zero_division=0 reports 0.0 (without a warning) for a class that is never
# predicted; this is the same value the default "warn" setting would produce.
acc = accuracy_score(truths, preds)
report = classification_report(
    truths, preds, labels=CLASSES_7, output_dict=True, zero_division=0
)

print(f"Accuracy: {acc:.3f}")
print(classification_report(truths, preds, labels=CLASSES_7, zero_division=0))

# Save to artifacts. Values are cast to plain floats so they serialise to
# JSON regardless of the numeric type scikit-learn returns.
metrics_out = {
    "model": "j-hartmann/emotion-english-distilroberta-base",
    "overall_accuracy": round(float(acc), 4),
    "macro_f1": round(float(report["macro avg"]["f1-score"]), 4),
    "per_class": {
        cls: {"precision": round(float(report[cls]["precision"]), 4),
              "recall":    round(float(report[cls]["recall"]), 4),
              "f1":        round(float(report[cls]["f1-score"]), 4)}
        for cls in CLASSES_7
    }
}

# Create the output directory first so a fresh checkout does not fail here.
os.makedirs("artifacts", exist_ok=True)
with open("artifacts/metrics.json", "w", encoding="utf-8") as f:
    json.dump(metrics_out, f, indent=2)
print("Saved to artifacts/metrics.json")

In [None]:
# Confusion matrix
cm = confusion_matrix(truths, preds, labels=CLASSES_7)

# Row-normalise so each row shows how one true class was distributed over the
# predicted classes. A class with no true samples keeps an all-zero row
# instead of producing NaNs from a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm, row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0
)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm_norm,
    annot=True, fmt=".2f",
    xticklabels=CLASSES_7, yticklabels=CLASSES_7,
    cmap="magma", linewidths=0.5,
    ax=ax
)
ax.set_title("Normalised Confusion Matrix", fontsize=13)
ax.set_ylabel("True")
ax.set_xlabel("Predicted")
fig.tight_layout()

# Create the output directory first so saving works on a fresh checkout.
os.makedirs("artifacts", exist_ok=True)
fig.savefig("artifacts/confusion_matrix.png", dpi=150, bbox_inches="tight")
plt.show()