# Experiment G: Weighted Random Sampling

Address the ~10:1 class imbalance via `WeightedRandomSampler` on the training
DataLoader rather than via `pos_weight` in the loss.

**Key design choices:**
- `target_pos_frac` (searched) controls the expected fraction of PCL samples
  per batch — 0.5 = balanced, 0.095 ≈ natural distribution.
- `pos_weight = 1.0` in BCE loss — sampler already handles imbalance; using
  both would double-compensate and over-emphasise the minority class.
- Sampler uses `replacement=True` with `num_samples=len(train_sub_df)` so
  each epoch sees approximately the same number of gradient steps as before.
- `train_model` from `utils/training_loop.py` is used **unchanged**.
- Val and dev loaders are unweighted (evaluate on true distribution).

Fixed hyperparameters (not searched):
- VAL_FRACTION=0.15, BATCH_SIZE=32, NUM_EPOCHS=12, PATIENCE=4
- `pooling=MEAN` — best default for DeBERTa-v3 (its RTD pretraining places no sentence-level objective on the `[CLS]` token, so that token has no special summary role)
- `warmup_fraction=0.10`, `label_smoothing=0.0` — fixed to save trials for key params

Searched: `lr`, `weight_decay`, `hidden_dim ∈ {0, 256}`, `dropout_rate ∈ {0.1, 0.3}`,
`head_lr_multiplier ∈ {1, 3, 5, 10}`, `target_pos_frac ∈ {0.2, 0.35, 0.5}`

In [None]:
import os
import sys
import random
import logging
import gc
import json

import numpy as np
import torch
from transformers import AutoTokenizer
from sklearn.metrics import classification_report
import optuna
from optuna.visualization.matplotlib import (
    plot_optimization_history,
    plot_param_importances,
    plot_parallel_coordinate,
)
import matplotlib.pyplot as plt

sys.path.insert(0, "..")
from utils.data import load_data
from utils.split import split_train_val
from utils.dataloaders import make_weighted_dataloaders
from utils.pcl_deberta import PCLDeBERTa, PoolingStrategy
from utils.training_loop import train_model
from utils.eval import evaluate

# --- Experiment-wide configuration ------------------------------------------
SEED = 42
DATA_DIR = "../data"
OUT_DIR = "out"
MODEL_NAME = "microsoft/deberta-v3-base"
MAX_LENGTH = 256        # tokeniser max length passed to the dataloader factory
VAL_FRACTION = 0.15     # share of the training set held out as the validation split
BATCH_SIZE = 32
N_TRIALS = 20           # number of Optuna trials
NUM_EPOCHS = 12
PATIENCE = 4            # early-stopping patience (unit as interpreted by train_model)
N_EVAL_STEPS = 35       # forwarded to train_model as eval_every_n_steps
_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _HAS_CUDA else "cpu")

# --- Reproducibility: seed every RNG the pipeline touches --------------------
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if _HAS_CUDA:
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True

# --- Logging and output directory --------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s:\t%(message)s",
)
LOG = logging.getLogger(__name__)
LOG.info(f"Device: {DEVICE}")
os.makedirs(OUT_DIR, exist_ok=True)

## 1. Data Loading

In [None]:
# Load the raw splits, carve a validation subset out of train, and load the tokeniser.
train_df, dev_df = load_data(DATA_DIR)
train_sub_df, val_sub_df = split_train_val(train_df, val_frac=VAL_FRACTION, seed=SEED)
tokeniser = AutoTokenizer.from_pretrained(MODEL_NAME)

# Log the class balance of the training subset ("binary_label" summed = PCL count).
n_train = len(train_sub_df)
n_pos = int(train_sub_df["binary_label"].sum())
n_neg = n_train - n_pos
LOG.info(
    f"Train: {n_train} ({n_pos} PCL / {n_neg} non-PCL, "
    f"natural pos frac={n_pos / n_train:.3f})"
)
LOG.info(f"Val: {len(val_sub_df)}, Dev: {len(dev_df)}")

## 2. Hyperparameter Search

`target_pos_frac` is the key new hyperparameter. We search three values:
- `0.5` — fully balanced batches
- `0.35` — moderately oversampled
- `0.2` — mild oversampling (roughly 2× natural frequency)

Secondary hyperparameters are narrowed to avoid wasting trials:
- `pooling` fixed to MEAN (DeBERTa-v3's RTD pretraining places no sentence-level objective on the `[CLS]` token)
- `warmup_fraction` fixed to 0.10, `label_smoothing` fixed to 0.0
- `hidden_dim ∈ {0, 256}` (0 = single linear, 256 = MLP)
- `dropout_rate ∈ {0.1, 0.3}` (only sampled when hidden_dim=256)
- `head_lr_multiplier ∈ {1, 3, 5, 10}`

`pos_weight=1.0` is passed to `train_model` — the sampler handles class balance.

In [None]:
POOLING = PoolingStrategy.MEAN   # fixed: best default for DeBERTa-v3 RTD pretraining
EXP_NAME = "G_weighted_sampling"


def _save_if_best(
    trial: optuna.trial.Trial,
    model,
    results: dict,
    warmup_fraction: float,
    label_smoothing: float,
) -> None:
    """Checkpoint ``model`` and its config if this trial beats the study's best so far.

    The comparison is against ``trial.study.best_value``. When the study has
    no best value yet, that access raises ``ValueError``, and the current
    trial is then saved unconditionally.
    """
    try:
        prev_best = trial.study.best_value
    except ValueError:  # no best value recorded yet
        prev_best = -float("inf")
    if not results["best_val_f1"] > prev_best:
        return

    # Tensors are moved to CPU so the checkpoint loads without the training device.
    torch.save(
        {k: v.cpu() for k, v in model.state_dict().items()},
        os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_model.pt"),
    )
    config = {
        **trial.params,
        "pooling": "mean",
        "warmup_fraction": warmup_fraction,
        "label_smoothing": label_smoothing,
        "batch_size": BATCH_SIZE, "num_epochs": NUM_EPOCHS, "patience": PATIENCE,
        "best_threshold": results["best_threshold"],
    }
    params_path = os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_params.json")
    with open(params_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    LOG.info(f"[{EXP_NAME}] New best saved (val F1={results['best_val_f1']:.4f})")


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train one weighted-sampling configuration.

    Samples the searched hyperparameters, builds the weighted-sampler
    dataloaders and a ``PCLDeBERTa`` model, and trains it with
    ``train_model``, using ``pos_weight=1`` because the sampler already
    rebalances classes. It then records val/dev metrics as trial user
    attributes and checkpoints the model if it is the best so far.

    Returns:
        The best validation F1 reported by ``train_model``.
    """
    lr              = trial.suggest_float("lr", 4e-6, 6e-5, log=True)
    weight_decay    = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    hidden_dim      = trial.suggest_categorical("hidden_dim", [0, 256])
    # dropout only matters for the MLP head; it is not a trial param when hidden_dim == 0
    dropout_rate    = trial.suggest_categorical("dropout_rate", [0.1, 0.3]) if hidden_dim > 0 else 0.0
    head_lr_mult    = trial.suggest_categorical("head_lr_multiplier", [1, 3, 5, 10])
    target_pos_frac = trial.suggest_categorical("target_pos_frac", [0.2, 0.35, 0.5])

    # Fixed — not worth spending trials on
    warmup_fraction = 0.10
    label_smoothing = 0.0

    LOG.info(f"[{EXP_NAME}] Trial {trial.number}: lr={lr:.2e}, hidden={hidden_dim}, "
             f"target_pos_frac={target_pos_frac}, head_lr_mult={head_lr_mult}")

    model = train_loader = val_loader = dev_loader = pos_weight = None
    try:
        train_loader, val_loader, dev_loader = make_weighted_dataloaders(
            train_sub_df, val_sub_df, dev_df, BATCH_SIZE, MAX_LENGTH, tokeniser,
            target_pos_frac=target_pos_frac,
        )

        model = PCLDeBERTa(
            hidden_dim=hidden_dim, dropout_rate=dropout_rate, pooling=POOLING
        ).to(DEVICE)

        # pos_weight=1.0 — sampler handles imbalance; no double-compensation
        pos_weight = torch.ones(1, device=DEVICE)

        results = train_model(
            model=model, device=DEVICE,
            train_loader=train_loader, val_loader=val_loader, dev_loader=dev_loader,
            pos_weight=pos_weight, lr=lr, weight_decay=weight_decay,
            num_epochs=NUM_EPOCHS, warmup_fraction=warmup_fraction,
            patience=PATIENCE,
            head_lr_multiplier=head_lr_mult,
            label_smoothing=label_smoothing,
            eval_every_n_steps=N_EVAL_STEPS,
            trial=trial,
        )

        trial.set_user_attr("best_val_f1",    results["best_val_f1"])
        trial.set_user_attr("best_threshold", results["best_threshold"])
        trial.set_user_attr("dev_f1",         results["dev_metrics"]["f1"])
        trial.set_user_attr("dev_precision",  results["dev_metrics"]["precision"])
        trial.set_user_attr("dev_recall",     results["dev_metrics"]["recall"])

        _save_if_best(trial, model, results, warmup_fraction, label_smoothing)
    finally:
        # Runs on success AND when anything above raises (presumably including
        # optuna.TrialPruned from train_model via `trial` — confirm). A raising
        # frame stays alive through the traceback, so rebinding the locals is
        # what actually frees the model/loaders before emptying the CUDA cache.
        model = train_loader = val_loader = dev_loader = pos_weight = None
        gc.collect()
        torch.cuda.empty_cache()

    return results["best_val_f1"]

## 3. Run Experiment

In [None]:
# Start from a clean allocator state before the search.
gc.collect()
torch.cuda.empty_cache()

# Seeded TPE for reproducible suggestions; median pruning after 6 startup trials.
tpe_sampler = optuna.samplers.TPESampler(seed=SEED)
median_pruner = optuna.pruners.MedianPruner(n_startup_trials=6, n_warmup_steps=300)
study = optuna.create_study(
    study_name=f"pcl_deberta_exp_{EXP_NAME}",
    direction="maximize",
    sampler=tpe_sampler,
    pruner=median_pruner,
)
study.optimize(objective, n_trials=N_TRIALS)

# Summarise the winning trial from the user attributes recorded by the objective.
best = study.best_trial
best_attrs = best.user_attrs
LOG.info(f"Best trial: {best.number}")
LOG.info(f"Val F1: {best_attrs['best_val_f1']:.4f} | Dev F1: {best_attrs['dev_f1']:.4f}")
LOG.info(f"Best params: {best.params}")

## 4. Results

In [None]:
# Save and display the Optuna diagnostic plots for the finished study.
for plot_fn, suffix in [
    (plot_optimization_history, "history"),
    (plot_param_importances, "importances"),
    (plot_parallel_coordinate, "parallel"),
]:
    plot_fn(study)
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/{EXP_NAME}_optuna_{suffix}.png", dpi=300)
    plt.show()
    plt.close()  # not every backend closes the figure on show(); avoid piling them up

# Rebuild the best trial's architecture and load its saved weights.
best = study.best_trial
best_params = best.params

model = PCLDeBERTa(
    hidden_dim=best_params["hidden_dim"],
    # dropout_rate is only suggested when hidden_dim > 0, so it may be absent
    dropout_rate=best_params.get("dropout_rate", 0.0),
    pooling=POOLING,
).to(DEVICE)

# The objective saves a plain {name: tensor} dict, so the restricted
# weights-only unpickler is sufficient (and avoids the FutureWarning /
# arbitrary-code unpickling of the default mode on older torch).
state_dict = torch.load(
    os.path.join(OUT_DIR, f"exp_{EXP_NAME}_best_model.pt"),
    map_location=DEVICE,
    weights_only=True,
)
model.load_state_dict(state_dict)

# Only the dev loader is used here. Per the notebook's design, val/dev loaders
# are unweighted, so target_pos_frac should not affect it.
_, _, dev_loader = make_weighted_dataloaders(
    train_sub_df, val_sub_df, dev_df, BATCH_SIZE, MAX_LENGTH, tokeniser,
    target_pos_frac=best_params["target_pos_frac"],
)

# Evaluate at the decision threshold tuned on validation during training.
best_threshold = best.user_attrs["best_threshold"]
dev_metrics = evaluate(model, DEVICE, dev_loader, threshold=best_threshold)

rule = "=" * 60
print(f"\n{rule}")
print(f"{EXP_NAME.upper()} — Dev Set Results (threshold={best_threshold:.3f})")
print(rule)
print(classification_report(dev_metrics["labels"], dev_metrics["preds"],
                            target_names=["Non-PCL", "PCL"]))
print("Best hyperparams:")
for k, v in best_params.items():
    print(f"  {k}: {v}")
print("  pooling: mean (fixed)")
print("  warmup_fraction: 0.10 (fixed)")
print("  label_smoothing: 0.0 (fixed)")

del model
gc.collect()
torch.cuda.empty_cache()