# GoEmotions – DistilBERT Baseline (Colab)
*Generated on 2025-09-28 11:50 UTC*

This notebook trains a **multi-label** emotion classifier on **GoEmotions** using Hugging Face Transformers.
It supports both the **raw (27 + neutral)** and **simplified (6 + neutral)** schema, auto-creates a validation split if needed, and can generate a one-click **PDF report** with metrics and per-label F1 charts.

**What you get:**
- Reproducible baseline (DistilBERT) with simple args
- Micro/Macro/Weighted F1 on validation & test
- Optional threshold sweep to optimize Macro-F1
- Efficiency snapshot (trainable params, basic latency)
- Auto-generated PDF report

In [None]:
import sys, platform, torch

# Environment report: interpreter and framework versions, then the accelerator
# that training will run on.
print("Python:", sys.version)
print("PyTorch:", torch.__version__)

cuda_ready = torch.cuda.is_available()
print("CUDA available:", cuda_ready)
if not cuda_ready:
    print("Running on CPU – training will be slower.")
else:
    # Device index 0 is the one reported.
    print("GPU:", torch.cuda.get_device_name(0))

In [None]:
!pip -q install transformers==4.44.2 datasets==2.21.0 scikit-learn==1.5.1                accelerate==0.34.2 pandas==2.2.2 matplotlib==3.9.2 reportlab==4.2.2

## Configure run
Update the arguments below as needed. Common tweaks:
- `dataset_config`: `"raw"` vs `"simplified"`
- `epochs`: try 3–5 for a quick baseline
- `eval_threshold`: try 0.2–0.7 if you *don’t* run the sweep

In [None]:
from dataclasses import dataclass

@dataclass
class Args:
    """Run configuration for the notebook.

    Every field has a default, so ``Args()`` yields a ready-to-run baseline;
    the later cells read these values through the module-level ``args``.
    """

    # Checkpoint name passed to the tokenizer and the classification model.
    model_name: str = "distilbert-base-uncased"
    # Dataset name and configuration passed to ``load_dataset``.
    dataset_name: str = "go_emotions"
    dataset_config: str = "raw"   # "raw" or "simplified"
    # Truncation length used when tokenizing and when measuring latency.
    max_length: int = 128
    # Per-device batch size for both training and evaluation.
    batch_size: int = 32
    # Number of training epochs.
    epochs: int = 3
    # Optimizer learning rate and weight decay.
    lr: float = 5e-5
    weight_decay: float = 0.01
    # Directory receiving checkpoints, JSON metrics and the PDF report.
    output_dir: str = "/content/outputs_distilbert_goemotions"
    # Seed for torch, numpy, the train/validation split and the Trainer.
    seed: int = 42
    # Sigmoid probability at or above which a label counts as predicted.
    eval_threshold: float = 0.5
    # Forwarded to ``TrainingArguments(report_to=...)``.
    report_to: str = "none"
    # Share of ``train`` held out when the dataset has no validation split.
    val_frac: float = 0.1

# Instantiate with defaults; the bare ``args`` expression is there so the
# notebook cell echoes the configuration.
args = Args()
args

## Train & Evaluate

In [None]:
import os, json, time
from typing import List

import numpy as np
import torch
from datasets import load_dataset, DatasetDict, Sequence
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          DataCollatorWithPadding, Trainer, TrainingArguments)
from sklearn.metrics import f1_score, classification_report

# Column names searched for when dataset_config == "raw" and the dataset has
# no "labels" column (one column per emotion, 28 entries including "neutral").
# The order of this list becomes the label-index order for the model.
RAW_EMOTIONS = [
    "admiration","amusement","anger","annoyance","approval","caring","confusion",
    "curiosity","desire","disappointment","disapproval","disgust","embarrassment",
    "excitement","fear","gratitude","grief","joy","love","nervousness","optimism",
    "pride","realization","relief","remorse","sadness","surprise","neutral"
]
# Column names searched for with any other dataset_config (7 entries including
# "neutral"); only consulted when the dataset has no "labels" column.
SIMPLIFIED_EMOTIONS = ["anger","disgust","fear","joy","sadness","surprise","neutral"]

def ensure_validation(ds: DatasetDict, seed: int, val_frac: float) -> DatasetDict:
    """Return splits that are guaranteed to include ``validation`` and ``test``.

    If ``ds`` already has a ``validation`` split it is returned untouched.
    Otherwise ``val_frac`` of ``train`` is held out as validation. When ``ds``
    has no ``test`` split either, a second slice of the same size is carved
    out of the *remaining* training rows, so validation and test never share
    examples. Reusing the validation rows as "test" would make the reported
    test metrics a copy of the validation metrics and would let any
    validation-tuned threshold leak into the test score.

    Args:
        ds: Dataset splits; must contain ``train`` when ``validation`` is absent.
        seed: Seed forwarded to ``train_test_split`` so splits are reproducible.
        val_frac: Hold-out size as accepted by ``train_test_split(test_size=...)``:
            a float fraction of the original ``train`` split or an int row
            count. A float must be below 0.5 when a test split also has to be
            created, otherwise ``train_test_split`` rejects the second split.

    Returns:
        ``ds`` itself, or a new ``DatasetDict`` with ``train``, ``validation``
        and ``test``.
    """
    if "validation" in ds:
        return ds

    # Same call as before, so the validation rows are unchanged.
    first = ds["train"].train_test_split(test_size=val_frac, seed=seed)
    if "test" in ds:
        return DatasetDict(train=first["train"], validation=first["test"], test=ds["test"])

    # No test split: take a disjoint slice from what is left of train.
    # For a float val_frac, rescale so the test split is val_frac of the
    # ORIGINAL train size rather than val_frac of the remainder.
    if isinstance(val_frac, int):
        test_size = val_frac
    else:
        test_size = val_frac / (1.0 - val_frac)
    second = first["train"].train_test_split(test_size=test_size, seed=seed)
    return DatasetDict(train=second["train"], validation=first["test"], test=second["test"])

def detect_schema_and_label_cols(ds, config: str):
    """Work out how the training split stores its labels.

    Returns:
        ``("list", None)`` when the split has a ``labels`` column, or
        ``("wide", cols)`` when it instead has one column per emotion, where
        ``cols`` are the expected emotion names that are present, in the order
        of the expected list (``RAW_EMOTIONS`` for ``config == "raw"``,
        ``SIMPLIFIED_EMOTIONS`` otherwise).

    Raises:
        KeyError: neither a ``labels`` column nor any expected emotion column
            exists.
    """
    available = ds["train"].column_names
    if "labels" not in available:
        if config != "raw":
            candidates = SIMPLIFIED_EMOTIONS
        else:
            candidates = RAW_EMOTIONS
        found = [name for name in candidates if name in available]
        if not found:
            raise KeyError(f"Could not detect labels. Columns found: {available}")
        return "wide", found
    return "list", None

def get_label_names(ds, schema: str, label_cols, config: str) -> List[str]:
    """Return the label names, ordered by label index.

    Args:
        ds: Dataset splits; only ``ds["train"]`` is read, and only for the
            ``"list"`` schema.
        schema: ``"list"`` (a ``labels`` column of label ids) or anything else
            for the wide, one-column-per-label layout.
        label_cols: Column names of the wide layout; returned as-is for it.
        config: Unused; kept so existing call sites keep working.

    Returns:
        For the wide layout, ``label_cols``. For the list layout, the names
        declared by the ``labels`` feature when it exposes them; otherwise the
        stringified ids ``"0" .. str(max_id)`` found in the training labels.
    """
    if schema != "list":
        return label_cols

    feat = ds["train"].features["labels"]
    # Duck-typed lookup: a sequence-style feature exposes its element feature
    # as `.feature`, a list-style spec holds it as its first item. Either may
    # or may not carry `.names`, so never assume the attribute exists.
    if isinstance(feat, (list, tuple)) and feat:
        inner = feat[0]
    else:
        inner = getattr(feat, "feature", None)
    names = getattr(inner, "names", None)
    if names is not None:
        return names

    # Fallback: infer the label count from the data. Every row is scanned,
    # because a rare id that first appears late would otherwise be left out
    # and its labels silently dropped when the targets are built.
    max_id = 0
    for ex in ds["train"]["labels"]:
        if isinstance(ex, (list, tuple)) and len(ex) > 0:
            max_id = max(max_id, max(ex))
    return [str(i) for i in range(max_id + 1)]

def attach_labels(examples, schema: str, label_names: List[str], label_cols=None):
    """Build float multi-hot targets for a batch of examples.

    For the ``"list"`` schema each entry of ``examples["labels"]`` is a list of
    label ids; ids outside ``[0, len(label_names))`` are ignored. For any other
    schema the column named after each label is copied into the matching
    target column. ``label_cols`` is accepted but not read.

    Returns:
        ``{"labels": rows}`` where ``rows`` is a list (one per example) of
        lists of ``0.0``/``1.0`` floats, one per label.
    """
    num_labels = len(label_names)
    targets = np.zeros((len(examples["text"]), num_labels), dtype=np.float32)

    if schema != "list":
        # Wide layout: one source column per label.
        for col, name in enumerate(label_names):
            targets[:, col] = np.array(examples[name], dtype=np.float32)
        return {"labels": targets.tolist()}

    # List layout: switch on the position of every in-range id.
    for row, ids in enumerate(examples["labels"]):
        for label_id in ids:
            if label_id < 0 or label_id >= num_labels:
                continue
            targets[row, label_id] = 1.0
    return {"labels": targets.tolist()}

def compute_metrics_builder(threshold: float, label_names: List[str], out_dir: str):
    """Create the ``compute_metrics`` callback handed to the Trainer.

    The returned function turns logits into sigmoid probabilities, marks a
    label as predicted when its probability is at least ``threshold``, and
    returns micro / macro / weighted F1. As a side effect, every call rewrites
    ``classification_report.json`` inside ``out_dir`` with the per-label
    report for the batch of predictions it was given.
    """
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        probabilities = 1 / (1 + np.exp(-logits))
        predicted = (probabilities >= threshold).astype(int)

        # Same three averages, same order: micro, macro, weighted.
        scores = {
            f"f1_{avg}": float(f1_score(labels, predicted, average=avg, zero_division=0))
            for avg in ("micro", "macro", "weighted")
        }

        per_label = classification_report(
            labels,
            predicted,
            target_names=label_names,
            zero_division=0,
            output_dict=True,
        )
        report_file = os.path.join(out_dir, "classification_report.json")
        with open(report_file, "w") as f:
            json.dump(per_label, f, indent=2)

        return scores

    return compute_metrics

def count_trainable_parameters(model: torch.nn.Module) -> int:
    """Count the scalar parameters of ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def measure_latency(model, tokenizer, device: str, max_length: int = 128, batch_size: int = 32, iters: int = 30):
    """Measure the average forward-pass latency of ``model`` in milliseconds.

    A batch of ``batch_size`` copies of one fixed sentence is tokenized once
    and moved to ``device``; after 5 untimed warm-up passes, ``iters`` forward
    passes are timed under ``torch.no_grad()``.

    Args:
        model: Callable module accepting the tokenizer output as keyword
            arguments. It is switched to eval mode and left there.
        tokenizer: Callable returning an object with ``.to(device)`` that can
            be unpacked with ``**``.
        device: Device the inputs are moved to (e.g. ``"cuda"`` or ``"cpu"``).
        max_length: Truncation length passed to the tokenizer.
        batch_size: Number of sentences per forward pass.
        iters: Number of timed forward passes.

    Returns:
        Mean wall-clock milliseconds per forward pass.
    """
    model.eval()
    sents = ["This is a sample sentence about feelings."] * batch_size
    # Tokenize outside the timed region: the input never changes, and the
    # figure reported is a forward-pass latency, not tokenizer throughput.
    batch = tokenizer(sents, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_length).to(device)
    # CUDA kernels are launched asynchronously; without a synchronize the
    # timer would stop before the queued work has actually run.
    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

    with torch.no_grad():
        # warmup
        for _ in range(5):
            model(**batch)
        if on_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(iters):
            model(**batch)
        if on_cuda:
            torch.cuda.synchronize()
        end = time.perf_counter()
    return (end - start) * 1000.0 / iters

# Create the output directory up front (later cells write JSON and the PDF
# there) and seed torch and numpy from the run configuration.
os.makedirs(args.output_dir, exist_ok=True)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Load the configured dataset, then make sure a validation split exists
# (ensure_validation returns the dataset unchanged when it already has one).
ds = load_dataset(args.dataset_name, args.dataset_config)
ds = ensure_validation(ds, seed=args.seed, val_frac=args.val_frac)

# Detect how labels are stored ("list" column vs. one column per emotion) and
# derive the ordered label names; their order fixes the model's label indices.
schema, label_cols = detect_schema_and_label_cols(ds, args.dataset_config)
label_names = get_label_names(ds, schema, label_cols, args.dataset_config)

# Tokenizer matching the model checkpoint; used by preprocess() below and by
# the latency measurement.
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)

def preprocess(batch):
    """Tokenize a batch of texts and add their multi-hot ``labels``.

    Reads the module-level ``tokenizer``, ``args``, ``schema``,
    ``label_names`` and ``label_cols``. Padding is not applied here.
    """
    features = tokenizer(batch["text"], truncation=True, max_length=args.max_length)
    multi_hot = attach_labels(batch, schema, label_names, label_cols)
    features.update(multi_hot)
    return features

# Drop every original column except "text"; preprocess() supplies the token
# fields and the new multi-hot "labels" column.
remove_cols = [c for c in ds["train"].column_names if c != "text"]
encoded = ds.map(preprocess, batched=True, remove_columns=remove_cols)

# Dynamic-padding collator (padded length rounded up to a multiple of 8 when
# CUDA is available) and a multi-label classification head with one output
# per label name.
collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8 if torch.cuda.is_available() else None)
model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name,
    num_labels=len(label_names),
    problem_type="multi_label_classification",
    id2label={i: n for i, n in enumerate(label_names)},
    label2id={n: i for i, n in enumerate(label_names)},
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Training hyperparameters, all taken from the run configuration; training
# loss is logged every 100 steps.
targs = TrainingArguments(
    output_dir=args.output_dir,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    num_train_epochs=args.epochs,
    learning_rate=args.lr,
    weight_decay=args.weight_decay,
    logging_steps=100,
    report_to=args.report_to,
    seed=args.seed,
)

# The metrics callback thresholds sigmoid probabilities at args.eval_threshold
# and writes classification_report.json into the output directory.
trainer = Trainer(
    model=model,
    args=targs,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics_builder(args.eval_threshold, label_names, args.output_dir),
)

# train() returns an output object whose .metrics dict is saved below.
# NOTE(review): the metrics callback rewrites classification_report.json on
# every call, so the test evaluation presumably replaces the validation
# report on disk — confirm which split the per-label report should describe.
train_metrics = trainer.train()
val_metrics = trainer.evaluate()
test_metrics = trainer.evaluate(encoded["test"]) if "test" in encoded else {}

# Persist all metrics together with the label order and run configuration;
# the report cell reads this file.
with open(os.path.join(args.output_dir, "metrics.json"), "w") as f:
    json.dump({"train": train_metrics.metrics, "val": val_metrics, "test": test_metrics,
               "label_names": label_names, "args": vars(args)}, f, indent=2)

# Efficiency snapshot: trainable parameter count and average latency for a
# fixed batch of 32 (the batch size is part of the JSON key name).
params = count_trainable_parameters(model)
latency_ms = measure_latency(model, tokenizer, device=device, max_length=args.max_length, batch_size=32, iters=30)
with open(os.path.join(args.output_dir, "efficiency_snapshot.json"), "w") as f:
    json.dump({"trainable_params": int(params), "avg_latency_ms_per_batch32": float(latency_ms)}, f, indent=2)

print("Done. Saved outputs in:", args.output_dir)

## (Optional) Threshold sweep for Macro-F1
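Run this to find a better decision threshold on the **validation** set. The cell re-runs prediction on the validation split to obtain fresh logits & labels, then scores 17 evenly spaced thresholds from 0.1 to 0.9 and keeps the one with the highest Macro-F1.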

In [None]:
import os, json, numpy as np
from sklearn.metrics import f1_score

# Fresh prediction pass over the validation split: raw logits plus the
# reference multi-hot labels, converted to per-label sigmoid probabilities.
eval_out = trainer.predict(encoded["validation"])
logits, labels = eval_out.predictions, eval_out.label_ids
probs = 1 / (1 + np.exp(-logits))


def eval_at(th):
    """Score the validation predictions at decision threshold ``th``.

    Returns a dict with the micro, macro and weighted F1 (in that order)
    followed by the threshold itself under ``"th"``.
    """
    preds = (probs >= th).astype(int)
    result = {}
    for avg in ("micro", "macro", "weighted"):
        result[avg] = f1_score(labels, preds, average=avg, zero_division=0)
    result["th"] = th
    return result


# 17 evenly spaced thresholds in [0.1, 0.9], rounded to 3 decimals.
ths = np.round(np.linspace(0.1, 0.9, 17), 3)
scores = list(map(eval_at, ths))
# max() keeps the first entry on ties, i.e. the lowest threshold.
best = max(scores, key=lambda entry: entry["macro"])

print("Best (by Macro-F1):", best)
with open(os.path.join(args.output_dir, "threshold_sweep.json"), "w") as f:
    json.dump({"scores": scores, "best": best}, f, indent=2)

## Generate PDF report

In [None]:
import os, json
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors

# Locate the three JSON artifacts written by the training cell.
outputs_dir = args.output_dir
metrics_path = os.path.join(outputs_dir, "metrics.json")
report_path  = os.path.join(outputs_dir, "classification_report.json")
eff_path     = os.path.join(outputs_dir, "efficiency_snapshot.json")

# Fail early with a readable message if the training cell has not been run.
assert os.path.exists(metrics_path), "metrics.json not found"
assert os.path.exists(report_path), "classification_report.json not found"
assert os.path.exists(eff_path), "efficiency_snapshot.json not found"

# Read each file inside a context manager so the handle is closed right away
# instead of lingering until garbage collection.
with open(metrics_path, encoding="utf-8") as f:
    METRICS = json.load(f)
with open(report_path, encoding="utf-8") as f:
    REPORT = json.load(f)
with open(eff_path, encoding="utf-8") as f:
    EFF = json.load(f)

def safe_get(metric_block, keys):
    """Return the value of the first key in ``keys`` that is present and not None.

    Keys are tried in the given order; ``None`` is returned when none of them
    yields a value. Falsy values such as ``0`` count as found.
    """
    present = (metric_block[key] for key in keys if key in metric_block)
    return next((value for value in present if value is not None), None)

# Metric blocks for each split; a missing or null block becomes an empty dict.
val = METRICS.get("val") or {}
test = METRICS.get("test") or {}

# Each metric may be stored under its plain name or with an "eval_" prefix;
# the plain name is tried first.
val_micro, val_macro, val_weighted, val_loss = [
    safe_get(val, [key, f"eval_{key}"])
    for key in ("f1_micro", "f1_macro", "f1_weighted", "loss")
]
t_micro, t_macro, t_weighted, t_loss = [
    safe_get(test, [key, f"eval_{key}"])
    for key in ("f1_micro", "f1_macro", "f1_weighted", "loss")
]

# Label order and run configuration as recorded at training time.
label_names = METRICS.get("label_names", [])
cfg = METRICS.get("args", {})

# Build the per-label F1 chart.
# Collect precision / recall / F1 / support for every label that has an entry
# with an "f1-score" in the classification report.
rows = []
for lbl in label_names:
    stats = REPORT.get(lbl, {})
    if isinstance(stats, dict) and "f1-score" in stats:
        rows.append({"label": lbl,
                     "precision": stats.get("precision", 0.0),
                     "recall": stats.get("recall", 0.0),
                     "f1": stats.get("f1-score", 0.0),
                     "support": stats.get("support", 0)})

# chart_path stays None when there is nothing to plot; the PDF step checks it.
chart_path = None
if rows:
    df = pd.DataFrame(rows).sort_values("f1", ascending=False).reset_index(drop=True)
    top_k = 10
    plt.figure(figsize=(7,5))
    if len(df) > 2 * top_k:
        # Enough labels for two disjoint groups: best 10, a "..." gap, worst 10.
        top = df.head(top_k); bottom = df.tail(top_k)
        tick_labels = list(top["label"]) + ["..."] + list(bottom["label"])
        positions = list(range(len(tick_labels)))
        # The gap position comes from the group size, not from searching for
        # "...", so a label that is literally "..." cannot misplace it.
        sep_idx = len(top)
        plt.bar(positions[:sep_idx], list(top["f1"]))
        plt.bar(positions[sep_idx+1:], list(bottom["f1"]))
        plt.title("Per-label F1 (Top-10 & Bottom-10)")
    else:
        # With 20 labels or fewer, head(10) and tail(10) would overlap and the
        # same label would be drawn twice, so plot every label exactly once.
        tick_labels = list(df["label"])
        positions = list(range(len(tick_labels)))
        plt.bar(positions, list(df["f1"]))
        plt.title("Per-label F1 (all labels)")
    # Chart-local names (tick_labels / positions) keep this cell from
    # rebinding the notebook-level `labels` array that the threshold-sweep
    # cell's eval_at() reads.
    plt.xticks(positions, tick_labels, rotation=45, ha="right")
    plt.tight_layout()
    chart_path = os.path.join(outputs_dir, "label_f1_chart.png")
    plt.savefig(chart_path, dpi=200)
    plt.close()

# Build PDF: document shell, title, timestamp and run summary.
# Local import: only this cell needs timezone, for an aware UTC timestamp.
from datetime import timezone

pdf_path = os.path.join(outputs_dir, "goemotions_baseline_report.pdf")
styles = getSampleStyleSheet()
title, h2, body = styles["Title"], styles["Heading2"], styles["BodyText"]
doc = SimpleDocTemplate(pdf_path, pagesize=A4, leftMargin=2*cm, rightMargin=2*cm, topMargin=1.5*cm, bottomMargin=1.5*cm)

# `story` is the ordered list of flowables that doc.build() lays out.
story = []
story.append(Paragraph("GoEmotions Baseline Report", title))
# datetime.utcnow() is deprecated since Python 3.12; an aware "now" in UTC
# renders the same string.
generated_at = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
story.append(Paragraph(generated_at, body))
story.append(Spacer(1,12))

# Summary of the configuration recorded in metrics.json. Each value falls back
# to the alternative key name, then to the notebook default.
summary = f"""
<b>Model:</b> {cfg.get('model_name')}<br/>
<b>Dataset:</b> {cfg.get('dataset_name')} ({cfg.get('dataset_config')})<br/>
<b>Training:</b> epochs={cfg.get('epochs', cfg.get('num_train_epochs', 3))},
batch_size={cfg.get('batch_size', cfg.get('per_device_train_batch_size', 32))},
max_length={cfg.get('max_length', 128)},
threshold={cfg.get('eval_threshold', 0.5)}<br/>
"""
story.append(Paragraph("Run Summary", h2))
story.append(Paragraph(summary, body))
story.append(Spacer(1,8))

def fmt(x):
    """Format ``x`` with four decimals, or return ``"nan"`` if it is not numeric.

    Only conversion failures (``None``, non-numeric strings, integers too
    large for a float) are turned into ``"nan"``; anything else, such as
    ``KeyboardInterrupt``, propagates instead of being silently swallowed.
    """
    try:
        return f"{float(x):.4f}"
    except (TypeError, ValueError, OverflowError):
        return "nan"

# Metrics table: header row, then one row per split with values formatted by
# fmt() (missing values show up as "nan").
tbl = [
    ["", "F1 (micro)", "F1 (macro)", "F1 (weighted)", "Loss"],
    ["Validation", *map(fmt, (val_micro, val_macro, val_weighted, val_loss))],
    ["Test", *map(fmt, (t_micro, t_macro, t_weighted, t_loss))],
]
table = Table(tbl, hAlign="LEFT")
# Grey bold header, thin grid everywhere, centred numeric cells.
table.setStyle(
    TableStyle(
        [
            ("BACKGROUND", (0, 0), (-1, 0), colors.lightgrey),
            ("GRID", (0, 0), (-1, -1), 0.25, colors.grey),
            ("FONTNAME", (0, 0), (-1, 0), "Helvetica-Bold"),
            ("ALIGN", (1, 1), (-1, -1), "CENTER"),
        ]
    )
)
story.extend([Paragraph("Evaluation Metrics", h2), table, Spacer(1, 8)])

# Efficiency section. Numeric format specs (",", ".2f") raise ValueError when
# applied to a string, so a missing or non-numeric entry is rendered as "?"
# here instead of being passed through the numeric format.
trainable = EFF.get('trainable_params')
latency = EFF.get('avg_latency_ms_per_batch32')
trainable_txt = f"{trainable:,}" if isinstance(trainable, (int, float)) else "?"
latency_txt = f"{latency:.2f}" if isinstance(latency, (int, float)) else "?"

eff_txt = f"""
<b>Trainable parameters:</b> {trainable_txt}<br/>
<b>Avg latency (ms) per batch of 32:</b> {latency_txt}
"""
story.append(Paragraph("Efficiency Snapshot", h2))
story.append(Paragraph(eff_txt, body)); story.append(Spacer(1,8))

# Embed the per-label chart only when it was produced and is still on disk.
chart_available = bool(chart_path) and os.path.exists(chart_path)
if chart_available:
    story.extend([
        Paragraph("Per-label F1 — Top & Bottom 10", h2),
        Image(chart_path, width=16*cm, height=10*cm),
        Spacer(1, 8),
    ])

# Closing caveats shown at the end of the report.
notes = """
<b>Notes:</b><br/>
• Fixed decision threshold by default. Run the threshold sweep cell to optimize Macro-F1.<br/>
• Latency is a simple forward pass benchmark; serving latency varies by hardware and batch size.
"""
story.extend([Paragraph("Notes", h2), Paragraph(notes, body)])

# Lay out all collected flowables and write the PDF file.
doc.build(story)
print("PDF written to:", pdf_path)