# Experiment B: Focal Loss

Replace `BCEWithLogitsLoss` with `FocalLoss` (gamma searched as hyperparameter).
Focal loss down-weights easy, already well-classified examples (here mostly the
abundant negatives), concentrating the gradient on hard/borderline cases — which is
particularly useful at the ~10:1 class imbalance in this dataset.

No extra features: pure DeBERTa fine-tuning with corrected hyperparameters held fixed
(VAL_FRACTION=0.15, BATCH_SIZE=32, NUM_EPOCHS=12, PATIENCE=4); the pooling strategy is searched.

In [None]:
import os
import sys
import random
import logging
import gc
import json
import math

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
import optuna
from optuna.visualization.matplotlib import (
    plot_optimization_history,
    plot_param_importances,
    plot_parallel_coordinate,
)
import matplotlib.pyplot as plt

sys.path.insert(0, "..")
from utils.data import load_data
from utils.split import split_train_val
from utils.dataloaders import make_dataloaders
from utils.pcl_deberta import PCLDeBERTa, PoolingStrategy
from utils.losses import FocalLoss
from utils.optim import get_cosine_schedule_with_warmup, compute_pos_weight
from utils.early_stopping import EarlyStopping
from utils.eval import evaluate, find_best_threshold

# --- Reproducibility and filesystem locations -------------------------------
SEED = 42
DATA_DIR = "../data"
OUT_DIR = "out"

# --- Encoder / tokenisation --------------------------------------------------
MODEL_NAME = "microsoft/deberta-v3-base"
MAX_LENGTH = 256

# --- Training schedule and search budget -------------------------------------
VAL_FRACTION = 0.15   # passed as val_frac to split_train_val
BATCH_SIZE = 32
NUM_EPOCHS = 12
PATIENCE = 4
N_EVAL_STEPS = 35     # passed as eval_every_n_steps to the training function
N_TRIALS = 20         # number of Optuna trials

# Query CUDA once and reuse the answer for both device choice and seeding.
_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _HAS_CUDA else torch.device("cpu")

# Seed every RNG used in this notebook with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if _HAS_CUDA:
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s:\t%(message)s",
)
LOG = logging.getLogger(__name__)
LOG.info(f"Device: {DEVICE}")

# Checkpoints, parameter dumps and plots are all written under OUT_DIR.
os.makedirs(OUT_DIR, exist_ok=True)

## 1. Data Loading

In [None]:
# load_data returns two frames; the first is split again into a training and a
# validation subset, the second (dev) is used as returned.
train_df, dev_df = load_data(DATA_DIR)
train_sub_df, val_sub_df = split_train_val(
    train_df,
    seed=SEED,
    val_frac=VAL_FRACTION,
)

# Tokeniser matching the encoder checkpoint named by MODEL_NAME.
tokeniser = AutoTokenizer.from_pretrained(MODEL_NAME)

LOG.info(f"Train: {len(train_sub_df)}, Val: {len(val_sub_df)}, Dev: {len(dev_df)}")

## 2. Focal-Loss Training Function

Same as `train_model` but uses `FocalLoss(gamma, pos_weight)` instead of `BCEWithLogitsLoss`.

In [None]:
def train_model_focal(
    model, device, train_loader, val_loader, dev_loader,
    pos_weight, lr, weight_decay, num_epochs, warmup_fraction,
    patience, gamma, head_lr_multiplier=3.0, label_smoothing=0.0,
    eval_every_n_steps=50, trial=None,
) -> dict:
    """Train ``model`` with ``FocalLoss`` and score it on the dev loader.

    The model is validated every ``eval_every_n_steps`` optimiser steps. The
    weights with the highest validation F1 seen at those checkpoints are
    snapshotted on CPU and restored after training. A decision threshold is
    then searched on ``val_loader`` and applied when evaluating ``dev_loader``.

    Args:
        model: Network exposing ``backbone`` and ``classifier`` sub-modules.
            Only the parameters of those two sub-modules are handed to the
            optimiser.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` entries.
        val_loader: Loader used for periodic validation and threshold search.
        dev_loader: Loader evaluated once, after training.
        pos_weight: Forwarded unchanged to ``FocalLoss``.
        lr: Learning rate of the backbone parameter group.
        weight_decay: Weight decay passed to ``AdamW``.
        num_epochs: Maximum number of passes over ``train_loader``.
        warmup_fraction: Fraction of the total step count used as warm-up
            steps for the cosine schedule.
        patience: Early-stopping patience; multiplied by the number of
            evaluations per epoch before being given to ``EarlyStopping``.
        gamma: Forwarded unchanged to ``FocalLoss``.
        head_lr_multiplier: The classifier group trains at
            ``lr * head_lr_multiplier``.
        label_smoothing: When > 0, targets become
            ``y * (1 - s) + 0.5 * s``.
        eval_every_n_steps: Number of optimiser steps between validations.
            Must be >= 1.
        trial: Optional Optuna trial. When given, each validation F1 is
            reported to it and the trial may be pruned.

    Returns:
        dict with keys ``best_val_f1`` (the F1 returned by
        ``find_best_threshold``), ``best_threshold``, ``dev_metrics`` (the
        result of ``evaluate`` on ``dev_loader``) and ``train_losses`` (mean
        training loss of each evaluation window).

    Raises:
        ValueError: If ``eval_every_n_steps`` is smaller than 1.
        optuna.exceptions.TrialPruned: If ``trial`` asks for pruning.
    """
    if eval_every_n_steps < 1:
        # Previously surfaced as ZeroDivisionError (0) or a nonsensical
        # evaluation cadence (negative values).
        raise ValueError("eval_every_n_steps must be >= 1")

    criterion = FocalLoss(gamma=gamma, pos_weight=pos_weight)

    # Two parameter groups so the classification head can use a larger LR.
    # NOTE(review): parameters outside model.backbone / model.classifier (if
    # any exist) are not optimised — confirm against the model definition.
    backbone_params = list(model.backbone.parameters())
    head_params = list(model.classifier.parameters())
    optimizer = AdamW([
        {"params": backbone_params, "lr": lr},
        {"params": head_params, "lr": lr * head_lr_multiplier},
    ], weight_decay=weight_decay)

    total_steps = len(train_loader) * num_epochs
    warmup_steps = int(total_steps * warmup_fraction)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # EarlyStopping is stepped once per evaluation, so express the
    # epoch-denominated patience in evaluation rounds.
    evals_per_epoch = max(1, len(train_loader) // eval_every_n_steps)
    patience_in_evals = patience * evals_per_epoch
    early_stopper = EarlyStopping(patience=patience_in_evals)

    model.train()
    global_step = 0
    train_losses = []
    best_val_f1 = 0.0
    best_state_dict = None
    running_loss = 0.0

    for epoch in range(num_epochs):
        LOG.info(f"Epoch {epoch+1}/{num_epochs}")
        for batch in train_loader:
            input_ids      = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels         = batch["labels"].to(device)

            if label_smoothing > 0:
                # Pull hard targets towards 0.5 by the smoothing amount.
                labels = labels * (1 - label_smoothing) + 0.5 * label_smoothing

            optimizer.zero_grad()
            scores = model(input_ids=input_ids, attention_mask=attention_mask).squeeze(-1)
            loss = criterion(scores, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every_n_steps == 0:
                avg_loss = running_loss / eval_every_n_steps
                train_losses.append(avg_loss)
                running_loss = 0.0

                val_metrics = evaluate(model, device, val_loader, criterion=criterion)
                # evaluate() is defined elsewhere. If it leaves the model in
                # eval mode, dropout would stay disabled for the remainder of
                # training, so re-enter train mode explicitly (a no-op if
                # evaluate() already restores it).
                model.train()

                val_f1 = val_metrics["f1"]
                LOG.info(f"Step {global_step} | Loss: {avg_loss:.4f} | Val F1: {val_f1:.4f}")

                if val_f1 > best_val_f1:
                    best_val_f1 = val_f1
                    # CPU copy so the snapshot does not occupy device memory.
                    best_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}

                if trial is not None:
                    trial.report(val_f1, global_step)
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

                if early_stopper.step(val_f1):
                    LOG.info(f"Early stopping at step {global_step}")
                    break

        # The inner ``break`` only leaves the batch loop; leave the epoch loop too.
        if early_stopper.should_stop:
            break

    # Roll back to the best checkpoint; if no evaluation ever beat F1 = 0 the
    # final weights are kept.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    # Train mode is now restored after every in-loop evaluation, so switch to
    # eval mode here to keep threshold search and dev scoring free of dropout
    # regardless of what the helpers do internally.
    model.eval()
    best_thresh, thresh_val_f1 = find_best_threshold(model, device, val_loader)
    dev_metrics = evaluate(model, device, dev_loader, criterion=criterion, threshold=best_thresh)
    LOG.info(f"Threshold: {best_thresh:.3f} | Dev F1: {dev_metrics['f1']:.4f}")

    return {
        "best_val_f1": thresh_val_f1,
        "best_threshold": best_thresh,
        "dev_metrics": dev_metrics,
        "train_losses": train_losses,
    }

## 3. Hyperparameter Search

In [None]:
# Translates the string sampled for the "pooling" hyperparameter into the
# PoolingStrategy member handed to PCLDeBERTa.
POOLING_MAP = dict(
    cls=PoolingStrategy.CLS,
    mean=PoolingStrategy.MEAN,
    max=PoolingStrategy.MAX,
    cls_mean=PoolingStrategy.CLS_MEAN,
)

# Experiment tag used in log lines, the study name and output file names.
EXP_NAME = "B_focal"


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train one focal-loss model and return its val F1.

    Samples the hyperparameters, builds fresh dataloaders and a fresh
    ``PCLDeBERTa``, trains it with ``train_model_focal`` and records the
    resulting metrics as user attributes on ``trial``. If the trial's score
    beats the study's current best value, the model weights and the trial
    configuration are written to ``OUT_DIR``.

    Args:
        trial: The Optuna trial supplying hyperparameter suggestions; it is
            also forwarded to the training function for pruning.

    Returns:
        The ``best_val_f1`` entry reported by ``train_model_focal``.

    Raises:
        optuna.exceptions.TrialPruned: Propagated from ``train_model_focal``.
    """
    lr              = trial.suggest_float("lr", 4e-6, 6e-5, log=True)
    warmup_fraction = trial.suggest_float("warmup_fraction", 0.03, 0.20, step=0.01)
    hidden_dim      = trial.suggest_categorical("hidden_dim", [0, 128, 256, 512])
    # dropout_rate is only sampled (and only appears in trial.params) when a
    # hidden layer is requested.
    dropout_rate    = trial.suggest_float("dropout_rate", 0.0, 0.4, step=0.05) if hidden_dim > 0 else 0.0
    weight_decay    = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    head_lr_mult    = trial.suggest_categorical("head_lr_multiplier", [1, 3, 5, 10])
    label_smoothing = trial.suggest_float("label_smoothing", 0.0, 0.15, step=0.025)
    pooling_name    = trial.suggest_categorical("pooling", ["cls", "mean", "max", "cls_mean"])
    gamma           = trial.suggest_categorical("gamma", [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    pooling = POOLING_MAP[pooling_name]

    LOG.info(f"[{EXP_NAME}] Trial {trial.number}: lr={lr:.2e}, gamma={gamma}, pool={pooling_name}")

    # Pre-bind the heavy objects so the ``finally`` clause can always delete
    # them, even if construction fails part-way.
    model = train_loader = val_loader = dev_loader = None
    try:
        train_loader, val_loader, dev_loader = make_dataloaders(
            train_sub_df, val_sub_df, dev_df, BATCH_SIZE, MAX_LENGTH, tokeniser
        )

        model = PCLDeBERTa(
            hidden_dim=hidden_dim, dropout_rate=dropout_rate, pooling=pooling
        ).to(DEVICE)

        pos_weight = compute_pos_weight(train_sub_df, DEVICE)

        results = train_model_focal(
            model=model, device=DEVICE,
            train_loader=train_loader, val_loader=val_loader, dev_loader=dev_loader,
            pos_weight=pos_weight, lr=lr, weight_decay=weight_decay,
            num_epochs=NUM_EPOCHS, warmup_fraction=warmup_fraction,
            patience=PATIENCE, gamma=gamma,
            head_lr_multiplier=head_lr_mult,
            label_smoothing=label_smoothing,
            eval_every_n_steps=N_EVAL_STEPS,
            trial=trial,
        )

        trial.set_user_attr("best_val_f1",    results["best_val_f1"])
        trial.set_user_attr("best_threshold", results["best_threshold"])
        trial.set_user_attr("dev_f1",         results["dev_metrics"]["f1"])
        trial.set_user_attr("dev_precision",  results["dev_metrics"]["precision"])
        trial.set_user_attr("dev_recall",     results["dev_metrics"]["recall"])

        # A ValueError from best_value is treated as "no previous best"
        # (presumably raised while no trial has completed yet).
        try:
            prev_best = trial.study.best_value
        except ValueError:
            prev_best = -float("inf")

        # Persist weights and configuration only when this trial takes the lead.
        if results["best_val_f1"] > prev_best:
            torch.save(
                {k: v.cpu() for k, v in model.state_dict().items()},
                os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_model.pt")
            )
            config = {**trial.params, "batch_size": BATCH_SIZE, "num_epochs": NUM_EPOCHS,
                      "patience": PATIENCE, "best_threshold": results["best_threshold"]}
            with open(os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_params.json"), "w") as f:
                json.dump(config, f, indent=2)
            LOG.info(f"[{EXP_NAME}] New best saved (val F1={results['best_val_f1']:.4f})")

        return results["best_val_f1"]
    finally:
        # Run the cleanup on every exit path — including TrialPruned raised by
        # train_model_focal — instead of only after a successful trial.
        del model, train_loader, val_loader, dev_loader
        gc.collect()
        torch.cuda.empty_cache()

## 4. Run Experiment

In [None]:
# Start the search from a clean allocator state.
gc.collect()
torch.cuda.empty_cache()

# Seeded TPE sampler plus a median pruner with 6 start-up trials and
# 300 warm-up steps.
tpe_sampler = optuna.samplers.TPESampler(seed=SEED)
median_pruner = optuna.pruners.MedianPruner(n_startup_trials=6, n_warmup_steps=300)

study = optuna.create_study(
    study_name=f"pcl_deberta_exp_{EXP_NAME}",
    direction="maximize",
    sampler=tpe_sampler,
    pruner=median_pruner,
)
study.optimize(objective, n_trials=N_TRIALS)

# Summarise the winning trial.
best = study.best_trial
LOG.info(f"Best trial: {best.number}")
LOG.info(f"Val F1: {best.user_attrs['best_val_f1']:.4f} | Dev F1: {best.user_attrs['dev_f1']:.4f}")
LOG.info(f"Best params: {best.params}")

## 5. Results

In [None]:
# Draw each Optuna diagnostic plot, save it under OUT_DIR and display it.
optuna_plots = {
    "history": plot_optimization_history,
    "importances": plot_param_importances,
    "parallel": plot_parallel_coordinate,
}
for suffix, plot_fn in optuna_plots.items():
    plot_fn(study)
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/{EXP_NAME}_optuna_{suffix}.png", dpi=300)
    plt.show()

# Rebuild the architecture of the best trial. dropout_rate is absent from the
# params when it was not sampled, hence the 0.0 fallback.
best = study.best_trial
best_params = best.params
pooling = POOLING_MAP[best_params["pooling"]]

model = PCLDeBERTa(
    hidden_dim=best_params["hidden_dim"],
    dropout_rate=best_params.get("dropout_rate", 0.0),
    pooling=pooling,
)
model = model.to(DEVICE)

# Load the weights saved during the search for the leading trial.
checkpoint_path = os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_model.pt")
state_dict = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(state_dict)

# Only the dev loader is needed for the final report.
_unused_train_loader, _unused_val_loader, dev_loader = make_dataloaders(
    train_sub_df, val_sub_df, dev_df, BATCH_SIZE, MAX_LENGTH, tokeniser
)
pos_weight = compute_pos_weight(train_sub_df, DEVICE)
focal = FocalLoss(gamma=best_params["gamma"], pos_weight=pos_weight)

# Score dev with the threshold recorded on the best trial.
best_threshold = best.user_attrs["best_threshold"]
dev_metrics = evaluate(
    model,
    DEVICE,
    dev_loader,
    criterion=focal,
    threshold=best_threshold,
)

print(f"\n{'='*60}")
print(f"{EXP_NAME.upper()} — Dev Set Results (threshold={best.user_attrs['best_threshold']:.3f})")
print(f"{'='*60}")
report = classification_report(
    dev_metrics["labels"],
    dev_metrics["preds"],
    target_names=["Non-PCL", "PCL"],
)
print(report)
for k, v in best_params.items():
    print(f"  {k}: {v}")

# Release the evaluation model.
del model
gc.collect()
torch.cuda.empty_cache()