# Experiment D: Scalar Mix Layer Pooling

Adds `PoolingStrategy.SCALAR_MIX` to the pooling search space: a learned convex
combination of all 13 DeBERTa hidden states (embedding layer + 12 transformer layers),
followed by attention-mask-weighted mean pooling.

Rationale: DeBERTa lower layers encode syntax, upper layers encode semantics;
detection of patronising and condescending language (PCL) may benefit from
intermediate representations.

No extra features (isolate the effect of pooling).
Fixed corrections: VAL_FRACTION=0.15, BATCH_SIZE=32, NUM_EPOCHS=12, PATIENCE=4.

In [None]:
import os
import sys
import random
import logging
import gc
import json

import numpy as np
import torch
from transformers import AutoTokenizer
from sklearn.metrics import classification_report
import optuna
from optuna.visualization.matplotlib import (
    plot_optimization_history,
    plot_param_importances,
    plot_parallel_coordinate,
)
import matplotlib.pyplot as plt

sys.path.insert(0, "..")
from utils.data import load_data
from utils.split import split_train_val
from utils.dataloaders import make_dataloaders
from utils.pcl_deberta import PCLDeBERTa, PoolingStrategy
from utils.optim import compute_pos_weight
from utils.training_loop import train_model
from utils.eval import evaluate

SEED = 42
DATA_DIR = "../data"
OUT_DIR = "out"
MODEL_NAME = "microsoft/deberta-v3-base"
MAX_LENGTH = 256
VAL_FRACTION = 0.15
BATCH_SIZE = 32
N_TRIALS = 20
NUM_EPOCHS = 12
PATIENCE = 4
N_EVAL_STEPS = 35
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s:\t%(message)s")
LOG = logging.getLogger(__name__)
LOG.info(f"Device: {DEVICE}")
os.makedirs(OUT_DIR, exist_ok=True)

## 1. Data Loading

In [None]:
# Tokeniser for the checkpoint named by MODEL_NAME; shared by every dataloader.
tokeniser = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the training and dev sets, then carve a validation subset out of the
# training set (fraction VAL_FRACTION, seeded with SEED).
train_df, dev_df = load_data(DATA_DIR)
train_sub_df, val_sub_df = split_train_val(
    train_df,
    val_frac=VAL_FRACTION,
    seed=SEED,
)

LOG.info(f"Train: {len(train_sub_df)}, Val: {len(val_sub_df)}, Dev: {len(dev_df)}")

## 2. Hyperparameter Search

The pooling search space is extended to include `scalar_mix` alongside the standard four strategies.

In [None]:
# Tag used in log lines, the study name and every output file name.
EXP_NAME = "D_layer_pooling"

# Search-space name -> PoolingStrategy member. Each key is the lower-case
# spelling of the member it maps to (e.g. "cls_mean" -> PoolingStrategy.CLS_MEAN).
POOLING_MAP = {
    key: getattr(PoolingStrategy, key.upper())
    for key in ("cls", "mean", "max", "cls_mean", "scalar_mix")
}


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train one sampled configuration and return its val F1.

    Samples optimiser, head and pooling hyperparameters from ``trial``, builds
    fresh dataloaders and a ``PCLDeBERTa`` model, and trains it with
    ``train_model``. Validation/dev metrics are stored on the trial as user
    attributes. When this trial's validation F1 beats the best value of all
    previously completed trials, the model weights and the configuration are
    written to ``OUT_DIR``.

    Args:
        trial: The Optuna trial to sample from; it is also forwarded to
            ``train_model``.

    Returns:
        ``results["best_val_f1"]`` as reported by ``train_model``.
    """
    lr = trial.suggest_float("lr", 4e-6, 6e-5, log=True)
    warmup_fraction = trial.suggest_float("warmup_fraction", 0.03, 0.20, step=0.01)
    hidden_dim = trial.suggest_categorical("hidden_dim", [0, 128, 256, 512])
    # Dropout is only searched when a hidden layer is requested (hidden_dim > 0).
    dropout_rate = (
        trial.suggest_float("dropout_rate", 0.0, 0.4, step=0.05) if hidden_dim > 0 else 0.0
    )
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    head_lr_mult = trial.suggest_categorical("head_lr_multiplier", [1, 3, 5, 10])
    label_smoothing = trial.suggest_float("label_smoothing", 0.0, 0.15, step=0.025)
    # Derive the choices from POOLING_MAP so the search space and the lookup
    # below can never drift apart (dict order matches the original list).
    pooling_name = trial.suggest_categorical("pooling", list(POOLING_MAP))

    pooling = POOLING_MAP[pooling_name]

    LOG.info(f"[{EXP_NAME}] Trial {trial.number}: lr={lr:.2e}, pool={pooling_name}, hidden={hidden_dim}")

    # Pre-bind so the `del` in `finally` is valid even if construction fails.
    model = train_loader = val_loader = dev_loader = None
    try:
        train_loader, val_loader, dev_loader = make_dataloaders(
            train_sub_df, val_sub_df, dev_df, BATCH_SIZE, MAX_LENGTH, tokeniser
        )

        model = PCLDeBERTa(
            hidden_dim=hidden_dim, dropout_rate=dropout_rate, pooling=pooling
        ).to(DEVICE)

        pos_weight = compute_pos_weight(train_sub_df, DEVICE)

        results = train_model(
            model=model, device=DEVICE,
            train_loader=train_loader, val_loader=val_loader, dev_loader=dev_loader,
            pos_weight=pos_weight, lr=lr, weight_decay=weight_decay,
            num_epochs=NUM_EPOCHS, warmup_fraction=warmup_fraction,
            patience=PATIENCE, head_lr_multiplier=head_lr_mult,
            label_smoothing=label_smoothing, eval_every_n_steps=N_EVAL_STEPS,
            trial=trial,
        )

        trial.set_user_attr("best_val_f1", results["best_val_f1"])
        trial.set_user_attr("best_threshold", results["best_threshold"])
        trial.set_user_attr("dev_f1", results["dev_metrics"]["f1"])
        trial.set_user_attr("dev_precision", results["dev_metrics"]["precision"])
        trial.set_user_attr("dev_recall", results["dev_metrics"]["recall"])

        # Study.best_value raises ValueError while no trial has completed yet.
        try:
            prev_best = trial.study.best_value
        except ValueError:
            prev_best = -float("inf")

        if results["best_val_f1"] > prev_best:
            # NOTE(review): this saves whatever weights `model` holds after
            # train_model returns -- confirm train_model restores the best
            # checkpoint rather than leaving the last-step weights in place.
            torch.save(
                {k: v.cpu() for k, v in model.state_dict().items()},
                os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_model.pt")
            )
            config = {**trial.params, "batch_size": BATCH_SIZE, "num_epochs": NUM_EPOCHS,
                      "patience": PATIENCE, "best_threshold": results["best_threshold"]}
            with open(os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_params.json"), "w") as f:
                json.dump(config, f, indent=2)
            LOG.info(f"[{EXP_NAME}] New best saved (val F1={results['best_val_f1']:.4f})")

        return results["best_val_f1"]
    finally:
        # Best-effort memory release on every exit path, including when
        # training raises (presumably how train_model prunes a trial, given
        # that `trial` is passed in -- TODO confirm).
        del model, train_loader, val_loader, dev_loader
        gc.collect()
        torch.cuda.empty_cache()

## 3. Run Experiment

In [None]:
# Release anything left over from earlier cells before the search starts.
gc.collect()
torch.cuda.empty_cache()

# Seeded TPE sampler so the sequence of suggested configurations is repeatable.
tpe_sampler = optuna.samplers.TPESampler(seed=SEED)
# Median pruning is disabled for the first 6 trials and, within a trial,
# for the first 300 reported steps.
median_pruner = optuna.pruners.MedianPruner(n_startup_trials=6, n_warmup_steps=300)

study = optuna.create_study(
    study_name=f"pcl_deberta_exp_{EXP_NAME}",
    direction="maximize",  # the objective returns validation F1
    sampler=tpe_sampler,
    pruner=median_pruner,
)
study.optimize(objective, n_trials=N_TRIALS)

# Summarise the winning trial.
best = study.best_trial
LOG.info(f"Best trial: {best.number}")
LOG.info(f"Val F1: {best.user_attrs['best_val_f1']:.4f} | Dev F1: {best.user_attrs['dev_f1']:.4f}")
LOG.info(f"Best params: {best.params}")

## 4. Results

In [None]:
# File-name suffix -> Optuna (matplotlib backend) plotting function.
optuna_plots = {
    "history": plot_optimization_history,
    "importances": plot_param_importances,
    "parallel": plot_parallel_coordinate,
}

# Draw each diagnostic plot, save it under OUT_DIR, then display it.
for suffix, plot_fn in optuna_plots.items():
    plot_fn(study)
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/{EXP_NAME}_optuna_{suffix}.png", dpi=300)
    plt.show()

# Rebuild the best trial's architecture, reload its saved weights and report
# dev-set metrics at the threshold recorded for that trial.
best = study.best_trial
best_params = best.params
best_threshold = best.user_attrs["best_threshold"]
pooling = POOLING_MAP[best_params["pooling"]]

model = PCLDeBERTa(
    hidden_dim=best_params["hidden_dim"],
    # "dropout_rate" is only sampled when hidden_dim > 0, so it may be absent.
    dropout_rate=best_params.get("dropout_rate", 0.0),
    pooling=pooling,
).to(DEVICE)

state_dict = torch.load(
    os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_model.pt"), map_location=DEVICE
)
model.load_state_dict(state_dict)
# A freshly constructed module starts in training mode. Whether `evaluate`
# switches modes is not visible here, so disable dropout explicitly.
model.eval()

# Only the dev loader is needed; the train/val loaders are discarded.
_, _, dev_loader = make_dataloaders(train_sub_df, val_sub_df, dev_df, BATCH_SIZE, MAX_LENGTH, tokeniser)
dev_metrics = evaluate(model, DEVICE, dev_loader, threshold=best_threshold)

print(f"\n{'='*60}")
# The header previously contained a mis-encoded em dash ("â€”").
print(f"{EXP_NAME.upper()} — Dev Set Results (threshold={best_threshold:.3f})")
print(f"{'='*60}")
print(classification_report(dev_metrics["labels"], dev_metrics["preds"], target_names=["Non-PCL", "PCL"]))
for k, v in best_params.items():
    print(f"  {k}: {v}")

# Inspect learned layer weights if scalar_mix was selected
if best_params["pooling"] == "scalar_mix":
    # Softmax turns the raw per-layer parameters into mixing weights that sum to 1.
    weights = torch.softmax(model.layer_weights, dim=0).detach().cpu().numpy()
    print("\nLearned layer weights (embedding=0, transformer=1..12):")
    for i, w in enumerate(weights):
        print(f"  Layer {i:2d}: {w:.4f}")

# Free the evaluation model before any later cells allocate memory.
del model
gc.collect()
torch.cuda.empty_cache()