# 3.0 Fine-tuning: DistilBERT

In [None]:
# Imports & global configuration
import os
import random
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    set_seed,
)
import evaluate
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
    classification_report,
    brier_score_loss,
)
from sklearn.calibration import calibration_curve

import optuna

plt.style.use("default")
plt.rcParams["figure.figsize"] = (6, 4)

SEED = 42
VAL_SIZE = 0.10   # fraction of each dataset held out for validation
TEST_SIZE = 0.10  # fraction of each dataset held out for the final test

# Default fine-tuning hyperparameters; the Optuna cell overwrites lr,
# batch_size, weight_decay, warmup_ratio and epochs with tuned values.
BASE_CFG = {
    "model_name": "distilbert-base-uncased",
    "max_length": 256,
    "lr": 2e-5,
    "batch_size": 16,
    "epochs": 3,
    "weight_decay": 0.01,
    "warmup_ratio": 0.06,
}

# Seed the Python and NumPy RNGs, then transformers' set_seed for the rest.
random.seed(SEED)
np.random.seed(SEED)
# NOTE: setting PYTHONHASHSEED here only affects Python processes launched
# afterwards; the running kernel's hash seed was fixed when it started.
os.environ["PYTHONHASHSEED"] = str(SEED)
set_seed(SEED)

# Module-level device handle for placing models/tensors; prefers a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load datasets
DATASET_IMDB = "imdb"
DATASET_RT = "rotten_tomatoes"

# Fetch both benchmark datasets via `datasets.load_dataset`, IMDB first.
ds_imdb, ds_rt = (load_dataset(name) for name in (DATASET_IMDB, DATASET_RT))

def to_df(ds, text_key="text", label_key="label"):
    """Copy one dataset split's text/label columns into a DataFrame.

    ``text_key`` and ``label_key`` name the source columns in ``ds``; the
    resulting DataFrame always has exactly the columns ``"text"`` and
    ``"label"``, in that order.
    """
    columns = {
        "text": ds[text_key],
        "label": ds[label_key],
    }
    return pd.DataFrame(columns)

# IMDB ships train/test halves; pool them so the data can be re-split 80/10/10.
imdb_train_df, imdb_test_df = (to_df(ds_imdb[split]) for split in ("train", "test"))
imdb_full = pd.concat([imdb_train_df, imdb_test_df], ignore_index=True)

# Rotten Tomatoes ships train/validation/test; pool all three the same way.
rt_train_df, rt_val_df, rt_test_df = (
    to_df(ds_rt[split]) for split in ("train", "validation", "test")
)
rt_full = pd.concat([rt_train_df, rt_val_df, rt_test_df], ignore_index=True)

print("IMDB full size:", len(imdb_full))
print("RT full size:", len(rt_full))
imdb_full.head()

In [None]:
# Create fixed 80/10/10 stratified splits
def stratified_indices(df: pd.DataFrame,
                       val_size: float,
                       test_size: float,
                       seed: int) -> Dict[str, np.ndarray]:
    """Split ``df``'s index labels into stratified train/val/test groups.

    Two stratified ``train_test_split`` passes are made, both stratified on
    ``df["label"]`` and both using ``seed``:

    1. The combined ``val_size + test_size`` fraction is carved off.
    2. That hold-out is divided so that ``test_size`` of the *whole* dataset
       ends up in test and the rest goes to validation.

    Returns:
        ``{"train": ..., "val": ..., "test": ...}``, each a sorted array of
        index labels (suitable for ``df.loc``).
    """
    holdout_frac = val_size + test_size
    row_ids = df.index.values
    labels = df["label"].values

    # Pass 1: train vs. (val + test).
    train_ids, holdout_ids, _, holdout_labels = train_test_split(
        row_ids, labels,
        test_size=holdout_frac,
        stratify=labels,
        random_state=seed,
    )

    # Pass 2: express test_size as a fraction of the hold-out and split it.
    val_ids, test_ids, _, _ = train_test_split(
        holdout_ids, holdout_labels,
        test_size=test_size / holdout_frac,
        stratify=holdout_labels,
        random_state=seed,
    )

    parts = (("train", train_ids), ("val", val_ids), ("test", test_ids))
    return {name: np.sort(ids) for name, ids in parts}

def apply_indices(df: pd.DataFrame,
                  idx: Dict[str, np.ndarray]) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Materialise the train/val/test frames selected by ``idx``.

    Rows are looked up by index *label* (``df.loc``). Each resulting frame
    gets a fresh 0..n-1 index.

    Returns:
        ``(train, val, test)`` DataFrames, in that order.
    """
    frames = [df.loc[idx[part]].reset_index(drop=True) for part in ("train", "val", "test")]
    return tuple(frames)

imdb_idx = stratified_indices(imdb_full, VAL_SIZE, TEST_SIZE, SEED)
rt_idx   = stratified_indices(rt_full,   VAL_SIZE, TEST_SIZE, SEED)

imdb_tr, imdb_va, imdb_te = apply_indices(imdb_full, imdb_idx)
rt_tr,   rt_va,   rt_te   = apply_indices(rt_full,   rt_idx)

# Sanity-check split sizes. The separator was previously the mojibake "â€“"
# (a UTF-8 en dash decoded as cp1252); a plain ASCII hyphen avoids a repeat.
for name, (tr, va, te) in {
    "IMDB": (imdb_tr, imdb_va, imdb_te),
    "RT":   (rt_tr,  rt_va,  rt_te),
}.items():
    print(f"{name} splits - train: {len(tr)}, val: {len(va)}, test: {len(te)}")

In [None]:
# Tokeniser and HF dataset builder
# Tokenizer for the checkpoint named in BASE_CFG (use_fast=True requests the
# fast implementation). It is a module-level global used by build_hf and
# handed to the Trainer instances in this notebook.
tokenizer = AutoTokenizer.from_pretrained(BASE_CFG["model_name"], use_fast=True)

def build_hf(train_df, val_df, test_df, max_length):
    """
    Turn the train/val/test DataFrames into tokenised HF Datasets for Trainer.

    Each frame's "text" column is tokenised with the module-level
    ``tokenizer``: truncated to ``max_length`` tokens, but not padded (padding
    presumably happens per batch in the Trainer's data collator). Afterwards
    only the ``input_ids``, ``attention_mask`` and ``label`` columns are kept.

    Returns:
        A ``(train, val, test)`` tuple of ``datasets.Dataset`` objects.
    """
    wanted = {"input_ids", "attention_mask", "label"}

    def encode(batch):
        return tokenizer(batch["text"], truncation=True, max_length=max_length)

    def convert(frame):
        ds = Dataset.from_pandas(frame, preserve_index=False)
        ds = ds.map(encode, batched=True)
        extra = [col for col in ds.column_names if col not in wanted]
        return ds.remove_columns(extra)

    return tuple(convert(frame) for frame in (train_df, val_df, test_df))

In [None]:
# Metrics
# Hugging Face `evaluate` metric objects used inside compute_metrics.
acc_metric, f1_metric = evaluate.load("accuracy"), evaluate.load("f1")

def compute_metrics(eval_pred):
    """Score one Trainer evaluation pass: accuracy, macro-F1 and Brier score.

    ``eval_pred`` unpacks to ``(logits, labels)``. The logits are turned into
    class probabilities with a max-shifted (numerically stable) softmax over
    axis 1. Column 1 is treated as P(positive) for the Brier score, so a
    binary (n, 2) logit array is assumed.
    """
    raw_logits, gold = eval_pred
    shifted = np.array(raw_logits)
    shifted = shifted - shifted.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    probs = exps / exps.sum(axis=1, keepdims=True)
    predicted = probs.argmax(axis=1)

    accuracy = acc_metric.compute(predictions=predicted, references=gold)["accuracy"]
    macro_f1 = f1_metric.compute(predictions=predicted, references=gold, average="macro")["f1"]
    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "brier": brier_score_loss(gold, probs[:, 1]),
    }

In [None]:
# TrainingArguments helper
def make_training_args(cfg, save_checkpoints=False):
    """Build ``TrainingArguments`` from a run config dict.

    Args:
        cfg: dict providing "lr", "batch_size", "epochs", "weight_decay" and
            "warmup_ratio". Values are coerced to float/int.
        save_checkpoints: if True, save a checkpoint every epoch and reload the
            best one at the end of training (no ``metric_for_best_model`` is
            set, so the Trainer default applies). If False, nothing is saved.

    Returns:
        A ``TrainingArguments`` that evaluates and logs once per epoch, writes
        to ``_finetune_tmp``, reports to no integrations and uses ``SEED``.
    """
    import inspect

    # transformers renamed ``evaluation_strategy`` to ``eval_strategy``. Newer
    # releases reject the old keyword and older ones lack the new one, so pass
    # whichever keyword the installed version accepts.
    accepted = inspect.signature(TrainingArguments).parameters
    eval_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    return TrainingArguments(
        output_dir="_finetune_tmp",
        save_strategy="epoch" if save_checkpoints else "no",
        logging_strategy="epoch",
        learning_rate=float(cfg["lr"]),
        per_device_train_batch_size=int(cfg["batch_size"]),
        per_device_eval_batch_size=int(cfg["batch_size"]),
        num_train_epochs=int(cfg["epochs"]),
        weight_decay=float(cfg["weight_decay"]),
        warmup_ratio=float(cfg["warmup_ratio"]),
        # Only valid when save and eval strategies match; both are "epoch" here.
        load_best_model_at_end=save_checkpoints,
        report_to=[],
        seed=SEED,
        **{eval_key: "epoch"},
    )

In [None]:
# Optuna: hyperparameter search on IMDB
def objective(trial, train_df, val_df, cfg_base):
    """Optuna objective: fine-tune one sampled configuration and return its val loss.

    Samples lr, batch_size, weight_decay, warmup_ratio and epochs, which
    override a shallow copy of ``cfg_base`` (the argument is not mutated).
    It then trains a fresh ``cfg["model_name"]`` classifier on ``train_df``
    and evaluates it on ``val_df``.

    Returns:
        ``eval_loss`` from ``Trainer.evaluate()`` as a float; the study minimises it.
    """
    cfg = dict(cfg_base)
    # suggest_float(..., log=True) and suggest_float(...) are the current
    # spellings of the deprecated suggest_loguniform / suggest_uniform; the
    # search space is unchanged.
    cfg["lr"] = trial.suggest_float("lr", 1e-5, 5e-4, log=True)
    cfg["batch_size"] = trial.suggest_categorical("batch_size", [8, 16, 32])
    cfg["weight_decay"] = trial.suggest_float("weight_decay", 1e-4, 1e-1, log=True)
    cfg["warmup_ratio"] = trial.suggest_float("warmup_ratio", 0.0, 0.2)
    cfg["epochs"] = trial.suggest_int("epochs", 2, 4)

    # build_hf always returns three splits; val_df fills the unused test slot.
    hf_tr, hf_va, _ = build_hf(train_df, val_df, val_df, cfg["max_length"])

    # No explicit .to(device): Trainer moves the model to args.device itself.
    model = AutoModelForSequenceClassification.from_pretrained(
        cfg["model_name"],
        num_labels=2,
    )

    trainer = Trainer(
        model=model,
        args=make_training_args(cfg, save_checkpoints=False),
        train_dataset=hf_tr,
        eval_dataset=hf_va,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    eval_results = trainer.evaluate()
    return float(eval_results["eval_loss"])

# Seed the sampler so the sequence of sampled hyperparameters is reproducible
# across runs; the default sampler is created without a seed.
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(
    lambda t: objective(t, imdb_tr, imdb_va, BASE_CFG),
    n_trials=8,
    show_progress_bar=True,
)

print("Best trial:", study.best_trial.number)
print("Best value (val loss):", study.best_value)
print("Best params:", study.best_params)

# Start from the base config and overwrite the keys tuned by the study.
CFG_TUNED = dict(BASE_CFG)
CFG_TUNED.update(study.best_params)
CFG_TUNED

In [None]:
# Fine-tune and evaluate on one dataset
def _softmax_rows(logits):
    """Return a row-wise, max-shifted (numerically stable) softmax of 2-D logits."""
    logits = np.asarray(logits)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)


def _plot_confusion_matrix(tag, y_true, y_pred):
    """Display the Negative/Positive confusion matrix for the test predictions."""
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(cm, display_labels=["Negative", "Positive"])
    disp.plot(values_format="d")
    plt.title(f"Fine-tune: {tag} Confusion Matrix")
    plt.tight_layout()
    plt.show()


def _plot_reliability_curve(tag, y_true, y_prob):
    """Plot the empirical positive rate against mean P(Positive) per quantile bin."""
    fracs, means = calibration_curve(y_true, y_prob, n_bins=10, strategy="quantile")
    plt.figure()
    plt.plot([0, 1], [0, 1], "k--", label="Perfect calibration")
    plt.plot(means, fracs, marker="o", label="Fine-tuned")
    plt.xlabel("Predicted probability (bin avg)")
    plt.ylabel("Empirical positive rate")
    plt.title(f"Fine-tune: {tag} Reliability Curve")
    plt.legend()
    plt.tight_layout()
    plt.show()


def _print_misclassified(tag, texts, y_true, y_pred, y_prob, n_examples):
    """Print up to ``n_examples`` misclassified texts (first 400 chars of each)."""
    print(f"\nMisclassified examples for {tag} (up to {n_examples}):\n")
    errors = [
        (t, yt, yp, p)
        for t, yt, yp, p in zip(texts, y_true, y_pred, y_prob)
        if yt != yp
    ]
    for i, (t, yt, yp, p) in enumerate(errors[:n_examples], 1):
        true_lbl = "Positive" if yt == 1 else "Negative"
        pred_lbl = "Positive" if yp == 1 else "Negative"
        print(f"Example {i}:")
        print(f"  True label : {true_lbl}")
        print(f"  Pred label : {pred_lbl}")
        print(f"  P(Positive): {p:.3f}")
        print("  Text:", t[:400].replace("\n", " "))
        if len(t) > 400:
            print("  ...")
        print()


def finetune_and_eval(tag: str,
                      train_df: pd.DataFrame,
                      val_df: pd.DataFrame,
                      test_df: pd.DataFrame,
                      cfg: dict,
                      n_error_examples: int = 10) -> dict:
    """Fine-tune a fresh model on one dataset and report test-set results.

    Trains ``cfg["model_name"]`` with a 2-label head on ``train_df`` and
    evaluates it on ``val_df`` each epoch. It then predicts ``test_df`` and:

    - prints accuracy, macro-F1, Brier score and a classification report;
    - shows a confusion matrix and a reliability curve;
    - prints misclassified examples.

    Args:
        tag: dataset name used in printed headers and plot titles.
        train_df, val_df, test_df: DataFrames with "text" and 0/1 "label" columns.
        cfg: run config consumed by build_hf / make_training_args, including
            "model_name" and "max_length".
        n_error_examples: maximum number of misclassified examples to print.

    Returns:
        dict with keys "dataset", "accuracy", "macro_f1", "brier", "n_test".
    """
    hf_tr, hf_va, hf_te = build_hf(train_df, val_df, test_df, cfg["max_length"])

    # No explicit .to(device): Trainer places the model on args.device itself.
    model = AutoModelForSequenceClassification.from_pretrained(
        cfg["model_name"],
        num_labels=2,
    )

    trainer = Trainer(
        model=model,
        args=make_training_args(cfg, save_checkpoints=False),
        train_dataset=hf_tr,
        eval_dataset=hf_va,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    # Test-set predictions: P(Positive) and hard labels from the softmax.
    probs = _softmax_rows(trainer.predict(hf_te).predictions)
    y_prob = probs[:, 1]
    y_pred = probs.argmax(axis=1)
    y_true = test_df["label"].to_numpy()

    acc = accuracy_score(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, average="macro")
    brier = brier_score_loss(y_true, y_prob)

    print(f"\nTest metrics for {tag}:")
    print(f"Accuracy : {acc:.4f}")
    print(f"Macro-F1: {f1m:.4f}")
    print(f"Brier   : {brier:.4f}")
    print("\nClassification report:")
    # labels=[0, 1] keeps the report aligned with target_names even when a
    # class is absent from both y_true and y_pred (otherwise sklearn raises).
    print(classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        target_names=["Negative", "Positive"],
    ))

    _plot_confusion_matrix(tag, y_true, y_pred)
    _plot_reliability_curve(tag, y_true, y_prob)
    _print_misclassified(
        tag, test_df["text"].tolist(), y_true, y_pred, y_prob, n_error_examples,
    )

    return {
        "dataset": tag,
        "accuracy": acc,
        "macro_f1": f1m,
        "brier": brier,
        "n_test": len(y_true),
    }

In [None]:
# Run final fine-tuning
# One result dict per dataset, both trained with the tuned configuration.
metrics = [
    finetune_and_eval("IMDB", imdb_tr, imdb_va, imdb_te, CFG_TUNED),
    finetune_and_eval("Rotten Tomatoes", rt_tr, rt_va, rt_te, CFG_TUNED),
]

pd.DataFrame(metrics)