# **Reward Model Trainer**

This notebook implements an end-to-end training pipeline for a Reward Model used in preference-based fine-tuning and reinforcement learning for language models.

It covers data preparation, tokenisation, model setup, training loop and evaluation metrics.

**Author:** Crescent - DSA4213 Group 18

---

## **Table of Contents**
1. [Environment Setup](#env-setup)
2. [Imports](#imports)
3. [Configuration](#config)
4. [Utilities](#utils)
5. [Training & Evaluation](#train-eval)
6. [Visualisations](#viz)

## **Overview**

**Main objectives:**
1. Load and preprocess preference-pair datasets.
2. Tokenise and batch inputs using the Hugging Face `transformers` library.
3. Initialise and fine-tune a `Llama 3.2`-based reward model.
4. Compute and log training loss, validation metrics, and save model checkpoints.
5. Visualise loss trends and model confidence distributions.


# 1. Environment Setup <a id='env-setup'></a>

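Installs dependencies automatically, based on the detected platform:
- On Windows (e.g. VS Code on a CPU-only machine): the default `torch` wheel, treated as the CPU build
- On Linux, including RONIN GPU machines: the CUDA 12.1 build of PyTorch
- On any other platform: the default `torch` wheel, as a CPU fallback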

In [4]:
import platform
import subprocess
import sys

system_name = platform.system().lower()
is_ronin = "ronin" in platform.node().lower()

# Run pip through the interpreter that backs this kernel so packages land in the
# kernel's own environment, not in whichever `pip` happens to be first on PATH.
PIP_INSTALL = [sys.executable, "-m", "pip", "install"]


def _pip_install(*args: str) -> subprocess.CompletedProcess:
    """Run `pip install <args>` best-effort; report a non-zero exit code instead of hiding it."""
    result = subprocess.run([*PIP_INSTALL, *args], check=False)
    if result.returncode != 0:
        print(f"⚠️ pip install {' '.join(args)} exited with code {result.returncode}")
    return result


if system_name == "windows":
    print("🖥️ Detected Windows environment. Installing CPU build...")
    _pip_install("torch")
elif is_ronin or system_name == "linux":
    print("🚀 Detected Linux/RONIN GPU environment. Installing CUDA 12.1 build...")
    _pip_install("torch", "--index-url", "https://download.pytorch.org/whl/cu121")
else:
    print("⚙️ Unknown environment — installing CPU build by default.")
    _pip_install("torch")

common_packages = [
    "transformers",
    "tqdm",
    "matplotlib",
    "numpy",
    "torchinfo"
]

print(f"Installing common dependencies: {', '.join(common_packages)}")
_pip_install(*common_packages)

🖥️ Detected Windows environment. Installing CPU build...
Installing common dependencies: transformers, tqdm, matplotlib, numpy, torchinfo


CompletedProcess(args=['pip', 'install', 'transformers', 'tqdm', 'matplotlib', 'numpy', 'torchinfo'], returncode=0)

# 2. Imports <a id='imports'></a>

In [5]:
from collections import Counter, defaultdict, OrderedDict
from pathlib import Path
from typing import Any, Dict, Tuple
import contextlib, glob, json, math, random, time, numpy as np, matplotlib.pyplot as plt
import torch
from torchinfo import summary
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# 3. Configuration <a id='config'></a>
Set constants, hyperparameters, and paths.

In [23]:
# ╭────────────────────────────── Model & Directories ───────────────────────────────╮
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

# Current and source directories
CUR_DIR = Path.cwd()
SRC_DIR = CUR_DIR.parent

# Data directories and files
DATA_DIR = SRC_DIR / "data" / "processed"
PAIRS_PATH = DATA_DIR / "preference_pairs.json"
#TEST_QNA_PATH = DATA_DIR / "test_qna_with_confidence.json"
TRAIN_PATH = DATA_DIR / "preference_pairs_train.json"
VAL_PATH   = DATA_DIR / "preference_pairs_val.json"

# Checkpoint and metric directories
CKPT_DIR   = CUR_DIR / "checkpoints"
METRIC_DIR = CUR_DIR / "metrics"
TRAIN_LOG = METRIC_DIR / "train_steps.jsonl"         # per-step loss
VAL_EPOCH_LOG = METRIC_DIR / "val_epochs.jsonl"      # per-epoch metrics
VAL_DELTAS_DIR = METRIC_DIR / "val_deltas"           # per-epoch margins for histogram

# Ensure directories exist.
# VAL_DELTAS_DIR is included because dump_val_deltas writes files directly into it;
# without the directory the first validation epoch fails with FileNotFoundError.
for p in (DATA_DIR, CKPT_DIR, METRIC_DIR, VAL_DELTAS_DIR):
    p.mkdir(parents=True, exist_ok=True)

# ╭────────────────────────────── Training Config ──────────────────────────────╮
# Data split ratios (split_preference_pairs asserts that they sum to 1.0)
TRAIN_RATIO = 0.9
VAL_RATIO   = 0.1
TEST_RATIO  = 0.0

# Flags for data processing
DROP_WEAK = False   # Drop low-confidence preference pairs (set True for cleaner supervision)

# Hyperparameters
MAX_LEN    = 1024    # Maximum tokenised sequence length
BETA       = 3.0     # Scales confidence penalty into a sample weight
BATCH_SIZE = 2       # Training batch size
EPOCHS     = 5
GRAD_CLIP  = 1.0
LR         = 2e-5
WEIGHT_DECAY = 0.0

# Reproducibility: seed every RNG the notebook touches (Python, NumPy, PyTorch CPU/GPU)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# 4. Utilities <a id='utils'></a>
Helper functions for tokenisation, loss computation, logging, etc.

## JSON I/O

In [24]:
def _read_json(p: Path):
    return json.loads(p.read_text(encoding="utf-8"))

def _write_json(obj, p: Path):
    p.write_text(json.dumps(obj, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"[saved] {p}")

## Data splitting

In [11]:
def split_preference_pairs(
    pairs_path: Path,
    out_dir: Path,
    train_ratio: float = 0.9,
    val_ratio: float = 0.1,
    test_ratio: float = 0.0,
    seed: int = 42,
    drop_weak: bool = False,
) -> Dict[str, Any]:
    """
    Split preference_pairs.json into train/val(/test) without question leakage.

    All pairs that share a `question_id` go to the same split. Questions are
    bucketed by their majority `risk_level` and each bucket is split separately,
    so every split keeps a similar risk mix.

    Args:
        pairs_path: path to preference_pairs.json
        out_dir: directory to write split files
        train_ratio, val_ratio, test_ratio: must sum to 1.0 (test optional)
        seed: RNG seed for reproducibility (a private RNG is used, so the global
            `random` state is left untouched)
        drop_weak: if True, drops pairs with score_difference==0 and pairs where both ratings == 1

    Writes:
        preference_pairs_train.json
        preference_pairs_val.json
        (only if test_ratio > 0) preference_pairs_test.json
        preference_pairs_split_stats.json

    Returns:
        stats dict with counts, distributions, and dropped info.
    """
    assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-6, "Ratios must sum to 1.0"
    out_dir.mkdir(parents=True, exist_ok=True)

    def _rating_score(ans: Dict[str, Any]) -> int:
        # Prefer explicit numeric 'score'; fallback parse from 'rating' like "3-Incomplete"
        if "score" in ans and ans["score"] is not None:
            return int(ans["score"])
        r = ans.get("rating", "")
        return int(r[0]) if isinstance(r, str) and r[:1].isdigit() else -1  # -1 if unknown

    # Load data
    pairs = _read_json(pairs_path)

    # Optional filtering
    dropped = {"score_diff_zero": 0, "both_score1": 0}  # how many pairs were removed for each reason
    if drop_weak:
        clean = []
        for ex in pairs:
            if ex.get("score_difference", None) == 0:
                dropped["score_diff_zero"] += 1
                continue
            sp = _rating_score(ex.get("preferred_answer", {}))
            sr = _rating_score(ex.get("rejected_answer", {}))
            if sp == 1 and sr == 1:
                dropped["both_score1"] += 1
                continue
            clean.append(ex)  # only store if passes filters
    else:
        clean = pairs

    # Group by question_id (no leakage across splits)
    by_qid = defaultdict(list)
    for ex in clean:
        by_qid[ex["question_id"]].append(ex)

    # Private RNG: same shuffle sequence as seeding the global module, without
    # clobbering the global random state as a side effect.
    rng = random.Random(seed)
    qids = list(by_qid.keys())
    rng.shuffle(qids)

    # Stratify by risk level (keeps similar risk mix across splits)
    def majority_risk(items):
        return Counter(e.get("risk_level", "UNKNOWN") for e in items).most_common(1)[0][0]

    risk_to_qids = defaultdict(list)
    for qid in qids:
        risk_to_qids[majority_risk(by_qid[qid])].append(qid)

    # Lists (not sets) keep the shuffled order, so the written files are identical
    # from run to run regardless of string-hash randomisation.
    train_qids, val_qids, test_qids = [], [], []
    has_test = test_ratio > 0
    for bucket in risk_to_qids.values():
        rng.shuffle(bucket)
        n = len(bucket)
        n_train = min(n, int(round(n * train_ratio)))
        if has_test:
            n_val = min(n - n_train, int(round(n * val_ratio)))
        else:
            # No test split requested: the rounding remainder must go to validation,
            # otherwise those questions would land in an unused test slice and be lost.
            n_val = n - n_train
        train_qids.extend(bucket[:n_train])
        val_qids.extend(bucket[n_train:n_train + n_val])
        test_qids.extend(bucket[n_train + n_val:])  # empty whenever has_test is False

    # Materialise splits
    train_pairs = [e for q in train_qids for e in by_qid[q]]
    val_pairs   = [e for q in val_qids   for e in by_qid[q]]
    test_pairs  = [e for q in test_qids  for e in by_qid[q]]

    # Save
    _write_json(train_pairs, out_dir / "preference_pairs_train.json")
    _write_json(val_pairs,   out_dir / "preference_pairs_val.json")
    if has_test:
        _write_json(test_pairs, out_dir / "preference_pairs_test.json")

    def _dist(lst):
        return dict(Counter(e.get("risk_level", "UNKNOWN") for e in lst))

    stats = {
        "seed": seed,
        "drop_weak": drop_weak,
        "dropped": dropped,
        "counts": {
            "train": len(train_pairs),
            "val":   len(val_pairs),
            "test":  len(test_pairs),
        },
        "unique_qids": {
            "train": len(train_qids),
            "val":   len(val_qids),
            "test":  len(test_qids),
        },
        "risk_dist": {
            "train": _dist(train_pairs),
            "val":   _dist(val_pairs),
            "test":  _dist(test_pairs),
        },
    }

    _write_json(stats, out_dir / "preference_pairs_split_stats.json")
    return stats

In [None]:
# Write the split files to DATA_DIR; every knob comes from the configuration cell.
stats = split_preference_pairs(
    PAIRS_PATH,
    DATA_DIR,
    seed=SEED,
    drop_weak=DROP_WEAK,      # change to True if want cleaner pairs
    train_ratio=TRAIN_RATIO,  # 90/10 train/val
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,    # set to 0.1 if want a test split
)
stats

In [None]:
# Read the train/val split files back from disk and show how many pairs each holds
train_pairs, val_pairs = (_read_json(split_path) for split_path in (TRAIN_PATH, VAL_PATH))
(len(train_pairs), len(val_pairs))

(27775, 3327)

## Dataset + Dataloader

In [14]:
# Load the tokenizer that belongs to MODEL_NAME.
# collate_fn pads batches with tokenizer.pad_token_id, so a pad token must exist:
# when the tokenizer defines none, the EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [15]:
def build_input(q, a):
    """Render a question and one candidate answer as the single text prompt that gets tokenised."""
    return "Question:\n{}\n\nAnswer:\n{}".format(q, a)

# Dataset and collate
class PrefPairDataset(Dataset):
    """
    Map-style dataset over preference-pair records.

    Each item tokenises the (question, preferred answer) and (question, rejected
    answer) prompts with the module-level `tokenizer`, truncated to MAX_LEN, and
    carries a per-pair loss weight plus the pair's risk label.
    """

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        record = self.items[i]
        question = record["question_text"]

        def encode(answer_text):
            # One un-padded sequence of shape (1, T); padding happens later in collate_fn
            return tokenizer(
                build_input(question, answer_text),
                truncation=True,
                max_length=MAX_LEN,
                return_tensors="pt",
            )

        encoded_pos = encode(record["preferred_answer"]["answer_text"])
        encoded_neg = encode(record["rejected_answer"]["answer_text"])

        # Weight grows linearly with the pair's confidence penalty (0 penalty -> weight 1)
        penalty = float(record.get("confidence_penalty", 0.0))
        return {
            "pos": encoded_pos,
            "neg": encoded_neg,
            "weight": torch.tensor(1.0 + BETA * penalty, dtype=torch.float),
            "risk": record.get("risk_level", "UNKNOWN"),
        }

def collate_fn(batch):
    """
    Merge PrefPairDataset items into one padded batch.

    For each side ("pos"/"neg") the per-item sequences are right-padded to the
    longest one: token ids with tokenizer.pad_token_id, attention masks with 0.
    Weights are stacked into one tensor and risk labels are kept as a list.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def stack(side):
        token_rows, mask_rows = [], []
        for sample in batch:
            encoded = sample[side]
            # drop the leading batch dimension added by return_tensors="pt"
            token_rows.append(encoded["input_ids"].squeeze(0))
            mask_rows.append(encoded["attention_mask"].squeeze(0))
        return {
            "input_ids": pad(token_rows, batch_first=True, padding_value=tokenizer.pad_token_id),
            "attention_mask": pad(mask_rows, batch_first=True, padding_value=0),
        }

    return {
        "pos": stack("pos"),
        "neg": stack("neg"),
        "weight": torch.stack([sample["weight"] for sample in batch]),
        "risk": [sample["risk"] for sample in batch],
    }

# Build loaders: training pairs are reshuffled every epoch, validation keeps file order
train_loader = DataLoader(
    PrefPairDataset(train_pairs),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    PrefPairDataset(val_pairs),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
(len(train_loader), len(val_loader))

(13888, 1664)

In [None]:
# Load the base checkpoint with a single-output classification head; step_batch and
# evaluate_riskwise read that one logit per sequence as the scalar reward.
# NOTE(review): the call below still passes `torch_dtype`. Newer transformers releases
# reportedly prefer `dtype` — confirm against the installed version before renaming.
# On GPU the weights are loaded in fp16 with device_map="auto" and a 13GB cap on device 0;
# on CPU they are loaded in fp32 with no device map.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device_map = "auto" if torch.cuda.is_available() else None
max_memory = {0: "13GB"} if torch.cuda.is_available() else None
reward_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=1,
    torch_dtype=dtype,
    device_map=device_map,
    low_cpu_mem_usage=True,
    max_memory=max_memory
)

# Tell the model which token id is padding (the tokenizer cell guarantees one exists)
reward_model.config.pad_token_id = tokenizer.pad_token_id
# Cast the parameters to FP32 regardless of the load dtype; the training cell then
# applies autocast + GradScaler on top of these FP32 weights when CUDA is available.
reward_model = reward_model.float()  # Ensure model is in FP32

# align special tokens / training toggles (currently disabled)
# reward_model.config.eos_token_id = tokenizer.eos_token_id
# reward_model.config.bos_token_id = getattr(tokenizer, "bos_token_id", tokenizer.eos_token_id)
# reward_model.config.use_cache = False  # safer during training

# Move the whole model to the single target device.
# NOTE(review): on GPU the model was already placed by device_map="auto"; confirm this
# extra .to(device) is accepted by the installed accelerate/transformers versions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
reward_model = reward_model.to(device)

print("Device:", device)
print("Loaded:", MODEL_NAME)

Loading checkpoint shards: 100%|██████████| 2/2 [00:18<00:00,  9.30s/it]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-3B-Instruct and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Device: CPU
Loaded: meta-llama/Llama-3.2-3B-Instruct


## Check model architecture

In [18]:
# Synthetic batch for torchinfo: one sequence at the maximum tokenised length
batch_size, seq_len = 1, MAX_LEN

# Random token ids drawn from the tokenizer vocabulary, with every position attended to
dummy_input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_len), dtype=torch.long)
dummy_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# Layer-by-layer summary; `depth` controls how far nested modules are expanded
summary(
    reward_model,
    input_data=dict(input_ids=dummy_input_ids, attention_mask=dummy_attention_mask),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    depth=3,  # raise/lower for more/less detail
    row_settings=("var_names",),  # show the attribute name of each layer
)

Layer (type (var_name))                                           Input Shape               Output Shape              Param #                   Trainable
LlamaForSequenceClassification (LlamaForSequenceClassification)   --                        --                        --                        True
├─LlamaModel (model)                                              [1, 1024]                 --                        --                        True
│    └─Embedding (embed_tokens)                                   [1, 1024]                 [1, 1024, 3072]           394,002,432               True
│    └─LlamaRotaryEmbedding (rotary_emb)                          [1, 1024, 3072]           [1, 1024, 128]            --                        --
│    └─ModuleList (layers)                                        --                        --                        --                        True
│    │    └─LlamaDecoderLayer (0)                                 [1, 1024, 3072]           [1, 1024, 3

In [19]:
def _num_params(module):
    total = sum(p.numel() for p in module.parameters(recurse=False))
    train = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
    return total, train

def _shape_str(x):
    if isinstance(x, torch.Tensor):
        return tuple(x.shape)
    if isinstance(x, (list, tuple)):
        # show first few shapes if it's a tuple/list
        parts = []
        for i, it in enumerate(x[:3]):
            if isinstance(it, torch.Tensor):
                parts.append(str(tuple(it.shape)))
            else:
                parts.append(type(it).__name__)
        if len(x) > 3:
            parts.append("...")
        return "[" + ", ".join(parts) + "]"
    return type(x).__name__

def model_summary(model, tokenizer, seq_len=64, verbose=True):
    """
    Prints a table:
      Layer (path) | Type | Output Shape | Param # (trainable)
    Uses a dummy forward pass with batch_size=1 and input length = seq_len.

    Only leaf modules (modules without children) are hooked, so the totals are the
    sum over leaf modules that actually ran in the forward pass; parameters held
    directly by container modules are not counted.

    The forward pass runs in eval mode under no_grad. The model's previous
    train/eval mode is restored afterwards and all hooks are removed even if the
    forward pass raises.

    Args:
        model: module accepting `input_ids` and `attention_mask` keyword tensors
        tokenizer: callable tokenizer exposing `pad_token_id`
        seq_len: length of the dummy input sequence
        verbose: print the table when True

    Returns:
        dict with "layers" (OrderedDict, in execution order) and "totals"
        (parameter counts, dtype string, sizes in MB, bytes per parameter, seq_len).
    """
    first_param = next(model.parameters())
    device = first_param.device
    dtype = first_param.dtype
    # Actual storage size of one element, so float64 (8 bytes) is reported correctly too
    bytes_per = first_param.element_size()

    # 1) Prepare a tiny dummy batch: pad tokens everywhere, a few real tokens at the start
    toks = tokenizer("dummy", return_tensors="pt", add_special_tokens=True)
    input_ids = torch.full((1, seq_len), fill_value=tokenizer.pad_token_id, dtype=torch.long)
    attn_mask = torch.zeros((1, seq_len), dtype=torch.long)
    L = min(toks["input_ids"].shape[1], seq_len)
    input_ids[:, :L] = toks["input_ids"][:, :L]
    attn_mask[:, :L] = 1
    batch = {"input_ids": input_ids.to(device), "attention_mask": attn_mask.to(device)}

    # 2) Collect module info (param counts) and register forward hooks for output shapes
    names = {m: n for n, m in model.named_modules()}  # module -> qualified name
    layer_infos = OrderedDict()  # preserve execution order
    handles = []

    def hook(module, inputs, outputs):
        name = names.get(module, module.__class__.__name__)
        if name == "":
            name = module.__class__.__name__
        if name not in layer_infos:  # a module called several times is recorded once
            p_total, p_train = _num_params(module)
            layer_infos[name] = {
                "type": module.__class__.__name__,
                "params": p_total,
                "trainable": p_train,
                "out": _shape_str(outputs),
            }

    was_training = model.training
    try:
        # only hook leaf modules to avoid noisy duplicates
        for m in model.modules():
            if len(list(m.children())) == 0:
                try:
                    handles.append(m.register_forward_hook(hook))
                except Exception:
                    # best-effort: a module that rejects hooks is simply left out of the table
                    continue

        # 3) Forward pass (to populate shapes)
        model.eval()
        with torch.no_grad():
            model(**batch)
    finally:
        for h in handles:
            h.remove()
        # give the model back in the mode the caller had it in
        model.train(was_training)

    # 4) Print table
    total_params = sum(v["params"] for v in layer_infos.values())
    train_params = sum(v["trainable"] for v in layer_infos.values())
    frozen_params = total_params - train_params
    total_mb   = (total_params * bytes_per) / (1024**2)
    train_mb   = (train_params * bytes_per) / (1024**2)
    frozen_mb  = (frozen_params * bytes_per) / (1024**2)

    header = f"{'Layer (path)':60} {'Type':22} {'Output Shape':28} {'Param #':>12} {'Trainable':>10}"
    line   = "-" * len(header)
    if verbose:
        print(header)
        print(line)
        for name, info in layer_infos.items():
            print(f"{name:60} {info['type']:22} {str(info['out'])[:28]:28} {info['params']:12,} {info['trainable']:10,}")
        print(line)
        print(f"Total params:     {total_params:,}  (~{total_mb:.1f} MB at {str(dtype).replace('torch.','')})")
        print(f"Trainable params: {train_params:,}  (~{train_mb:.1f} MB)")
        print(f"Frozen params:    {frozen_params:,}  (~{frozen_mb:.1f} MB)")
        print(f"Seq len used for summary: {seq_len}, batch size: 1")

    # 5) Return a dict if you want to use programmatically
    return {
        "layers": layer_infos,
        "totals": {
            "total_params": total_params,
            "trainable_params": train_params,
            "frozen_params": frozen_params,
            "dtype": str(dtype),
            "mb_total": total_mb,
            "mb_trainable": train_mb,
            "mb_frozen": frozen_mb,
            "bytes_per_param": bytes_per,
            "seq_len": seq_len,
        }
    }

# Print the per-layer table for the loaded reward model using a short 64-token dummy
# input; the returned dict is assigned to `_` because only the printed table is needed.
_ = model_summary(reward_model, tokenizer, seq_len=64)

Layer (path)                                                 Type                   Output Shape                      Param #  Trainable
----------------------------------------------------------------------------------------------------------------------------------------
model.embed_tokens                                           Embedding              (1, 64, 3072)                 394,002,432 394,002,432
model.rotary_emb                                             LlamaRotaryEmbedding   [(1, 64, 128), (1, 64, 128)]            0          0
model.layers.0.input_layernorm                               LlamaRMSNorm           (1, 64, 3072)                       3,072      3,072
model.layers.0.self_attn.q_proj                              Linear                 (1, 64, 3072)                   9,437,184  9,437,184
model.layers.0.self_attn.k_proj                              Linear                 (1, 64, 1024)                   3,145,728  3,145,728
model.layers.0.self_attn.v_proj         

In [None]:
# ===== freeze/unfreeze helpers =====
# Example: freeze all base layers, train only the classification head
def freeze_all_but_head(m):
    """
    Freeze every parameter of `m` except those of its classification head.

    The head is looked up as attribute `score` (HF sequence-classification models)
    and, if that attribute is absent, as `classifier`. When neither exists the
    whole model ends up frozen and a warning is issued, because an optimiser
    built on it would have nothing to train.

    Args:
        m: torch.nn.Module to modify in place.
    """
    import warnings

    for p in m.parameters():
        p.requires_grad = False

    # Explicit None checks: `a or b` would depend on the module's truthiness
    # (container modules define __len__, so an empty one is falsy).
    head = getattr(m, "score", None)
    if head is None:
        head = getattr(m, "classifier", None)
    if head is None:
        warnings.warn(
            "freeze_all_but_head: no 'score' or 'classifier' head found; "
            "all parameters are now frozen.",
            stacklevel=2,
        )
        return

    for p in head.parameters():
        p.requires_grad = True

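# Example usage (reports the new counts via model_summary, defined above):
# freeze_all_but_head(reward_model)
# print("After freezing:")
# totals = model_summary(reward_model, tokenizer, seq_len=64, verbose=False)["totals"]
# print(f"Trainable params:  {totals['trainable_params']:,}  (~{totals['mb_trainable']:.1f} MB)")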

## Training step (pairwise Bradley-Terry loss)

In [21]:
def step_batch(model, batch):
    """
    Perform one training step using the pairwise Bradley-Terry (logistic) loss.

    Commonly used in Reward Model training for RLHF or preference learning,
    where the model learns to assign higher scalar rewards to preferred samples
    (positive responses) than to non-preferred ones (negative responses).

    Intuition:
    - For each (positive, negative) pair, compute the model scores:
        r_pos = model(preferred response)
        r_neg = model(non-preferred response)
    - The model should learn r_pos > r_neg.
    - The probability that the model ranks the positive higher is:
        P(pos > neg) = σ(r_pos - r_neg), where σ(x) is the sigmoid function.
    - The loss encourages this probability to be close to 1.

    The per-pair loss -log σ(delta) is evaluated with `logsigmoid` in float32.
    This stays exact for large |delta|, whereas log(σ(delta) + eps) saturates at
    -log(eps) and passes no gradient to the most badly mis-ranked pairs.

    Args:
        model: the reward model (a HF seq-classification model outputting scalar logits)
        batch: a dict with keys:
            - "pos": tokenized batch of preferred responses
            - "neg": tokenized batch of non-preferred responses
            - "weight": tensor of shape (B,) with sample weights (e.g., from confidence penalties)
          The dict itself is not modified.

    Returns:
        loss: the computed loss (a scalar tensor)
        delta: tensor of shape (B,) with r_pos - r_neg for each pair (detached from graph)
    """
    # Move all inputs to the model device without mutating the caller's batch
    device = model.device
    pos = {k: v.to(device) for k, v in batch["pos"].items()}
    neg = {k: v.to(device) for k, v in batch["neg"].items()}
    w = batch["weight"].to(device)

    # Forward pass: compute scalar rewards for positive and negative samples
    r_pos = model(**pos).logits.squeeze(-1)  # shape: (B,)
    r_neg = model(**neg).logits.squeeze(-1)  # shape: (B,)

    # Pairwise Bradley-Terry (logistic) loss: -log σ(delta), numerically stable
    delta = r_pos - r_neg
    loss_core = -torch.nn.functional.logsigmoid(delta.float())
    loss = (w * loss_core).mean()
    return loss, delta.detach().cpu()

## Validation (pairwise accuracy + ECE)

In [None]:
def evaluate_riskwise(model, loader, bins=10):
    """
    Evaluate pairwise accuracy, calibration (ECE), validation loss, and accuracy by risk level.

    Pairwise setting:
    - The model outputs scalar rewards r(x). For each preference pair (pos, neg),
      compute delta = r_pos - r_neg and the probability p = σ(delta) that the positive
      item is preferred.

    Metrics:

    - Pairwise accuracy:
        acc = (1/N) * Σ 1{delta_i > 0}
        i.e., the fraction of pairs where r_pos > r_neg.

    - Expected Calibration Error (ECE):
        ECE = Σ_b (n_b / N) * | acc_b - conf_b |
      where for bin b:
        - n_b = number of samples in bin b
        - acc_b = average correctness in bin b
        - conf_b = average model confidence (p) in bin b
      Bins partition [0, 1] into `bins` equal-width intervals.
      Empty bins are skipped.

    - Validation loss (`val_loss`):
        (1/N) * Σ w_i * (-log σ(delta_i)), the same weighted Bradley-Terry loss
        that step_batch minimises. w_i comes from the batch's "weight" entry and
        defaults to 1 when a batch has none.

    - Accuracy by risk level (`acc_by_risk`):
        Dictionary mapping each risk bucket/value to its mean pairwise accuracy.

    The model is switched to eval mode and left there; batches are not modified.

    Args:
        model: the reward model
        loader: DataLoader providing batches of preference pairs
        bins: number of bins for ECE calculation

    Returns:
        {
            "pairwise_accuracy": float,
            "ece": float,
            "val_loss": float,
            "acc_by_risk": Dict[Any, float],
            "ece_detail": {
                "bin_edges": np.ndarray,  # shape (bins+1,)
                "bin_counts": np.ndarray, # shape (bins,)
                "bin_acc": np.ndarray,    # shape (bins,)
                "bin_conf": np.ndarray,   # shape (bins,)
            },
            "deltas": np.ndarray,   # shape (N,), r_pos - r_neg for each pair
            "probs": np.ndarray,    # shape (N,), σ(delta)
            "corrects": np.ndarray, # shape (N,), 1 if delta>0 else 0
        }
        For an empty loader every scalar metric is 0.0 and the arrays are empty.
    """
    model.eval()
    device = model.device

    deltas_list = []
    probs_list = []
    corrects_list = []
    weights_list = []
    risks_all = []

    with torch.no_grad():
        for batch in loader:
            # Move inputs to the model device without mutating the loader's batch
            pos = {k: v.to(device) for k, v in batch["pos"].items()}
            neg = {k: v.to(device) for k, v in batch["neg"].items()}

            # Forward pass: compute scalar rewards for positive and negative samples
            r_pos = model(**pos).logits.squeeze(-1)  # shape: (B,)
            r_neg = model(**neg).logits.squeeze(-1)  # shape: (B,)

            # Pairwise probabilities and correctness labels
            delta = r_pos - r_neg
            delta_np = delta.detach().cpu().numpy()
            p = torch.sigmoid(delta).cpu().numpy()    # shape: (B,)
            c = (delta_np > 0).astype(np.float32)     # shape: (B,)

            deltas_list.append(delta_np)
            probs_list.append(p)
            corrects_list.append(c)

            # Sample weights for the validation loss (1.0 when the batch carries none)
            if "weight" in batch:
                weights_list.append(batch["weight"].detach().cpu().numpy().astype(np.float64))
            else:
                weights_list.append(np.ones(delta_np.shape, dtype=np.float64))

            # Risks from collate_fn
            if "risk" in batch:
                risk_obj = batch["risk"]
                if isinstance(risk_obj, torch.Tensor):
                    risks_all.extend(risk_obj.detach().cpu().numpy().tolist())
                else:
                    risks_all.extend(list(risk_obj))

    if len(probs_list) == 0:
        # empty loader edge case
        return {
            "pairwise_accuracy": 0.0,
            "ece": 0.0,
            "val_loss": 0.0,
            "acc_by_risk": {},
            "ece_detail": {
                "bin_edges": np.linspace(0.0, 1.0, bins + 1),
                "bin_counts": np.zeros(bins, dtype=int),
                "bin_acc": np.zeros(bins, dtype=float),
                "bin_conf": np.zeros(bins, dtype=float),
            },
            "deltas": np.array([], dtype=np.float32),
            "probs": np.array([], dtype=np.float32),
            "corrects": np.array([], dtype=np.float32),
        }

    # Concatenate
    deltas = np.concatenate(deltas_list).astype(np.float32)      # (N,)
    probs = np.concatenate(probs_list).astype(np.float32)        # (N,)
    corrects = np.concatenate(corrects_list).astype(np.float32)  # (N,)
    weights = np.concatenate(weights_list)                       # (N,), float64
    N = probs.shape[0]

    # Pairwise accuracy
    pairwise_acc = float(corrects.mean()) if N > 0 else 0.0

    # Weighted Bradley-Terry loss; logaddexp(0, -d) == -log σ(d) without overflow
    per_pair_loss = np.logaddexp(0.0, -deltas.astype(np.float64))
    val_loss = float(np.mean(weights * per_pair_loss)) if N > 0 else 0.0

    # ECE bins (equal-width on [0,1])
    edges = np.linspace(0.0, 1.0, bins + 1)  # (bins+1,)
    # Map probs in [0,1] to bin ids 0..bins-1 (p == 1.0 is folded into the last bin)
    bin_ids = np.minimum((probs * bins).astype(int), bins - 1)

    bin_counts = np.bincount(bin_ids, minlength=bins).astype(np.int64)
    bin_correct_sum = np.bincount(bin_ids, weights=corrects, minlength=bins).astype(np.float64)
    bin_conf_sum = np.bincount(bin_ids, weights=probs, minlength=bins).astype(np.float64)

    # Averages per bin; avoid divide-by-zero
    bin_acc = np.divide(
        bin_correct_sum, bin_counts,
        out=np.zeros_like(bin_correct_sum, dtype=np.float64),
        where=bin_counts > 0
    )
    bin_conf = np.divide(
        bin_conf_sum, bin_counts,
        out=np.zeros_like(bin_conf_sum, dtype=np.float64),
        where=bin_counts > 0
    )

    nonempty = bin_counts > 0
    ece = float(np.sum((bin_counts[nonempty] / max(N, 1)) * np.abs(bin_acc[nonempty] - bin_conf[nonempty])))

    # Accuracy by risk level (only when every pair came with a risk label)
    acc_by_risk = {}
    if len(risks_all) == N and N > 0:
        # Sort by string form so mixed-type labels still get a stable order
        unique_risks = sorted(set(risks_all), key=lambda x: str(x))
        for r in unique_risks:
            idx = np.fromiter((rr == r for rr in risks_all), dtype=bool, count=N)
            if idx.any():
                acc_by_risk[r] = float(corrects[idx].mean())

    return {
        "pairwise_accuracy": pairwise_acc,
        "ece": ece,
        "val_loss": val_loss,
        "acc_by_risk": acc_by_risk,
        "ece_detail": {
            "bin_edges": edges,
            "bin_counts": bin_counts,
            "bin_acc": bin_acc,
            "bin_conf": bin_conf,
        },
        "deltas": deltas,
        "probs": probs,
        "corrects": corrects,
    }

## Saving/logging functions

In [25]:
def save_ckpt(model, sub="epoch"):
    """Save the tokenizer and then `model` into CKPT_DIR/<sub> (created if missing) and print the path."""
    target_dir = CKPT_DIR / sub
    target_dir.mkdir(parents=True, exist_ok=True)
    # tokenizer first, then the model weights/config
    for artefact in (tokenizer, model):
        artefact.save_pretrained(target_dir)
    print(f"[saved] {target_dir}")

def log_train_step(step:int, loss:float, epoch:int):
    """Append one JSON line (timestamp, step, epoch, loss) to TRAIN_LOG."""
    with TRAIN_LOG.open("a", encoding="utf-8") as f:
        record = {"ts": time.time(), "step": step, "epoch": epoch, "loss": float(loss)}
        f.write(f"{json.dumps(record)}\n")

def _json_default(obj):
    """json.dumps fallback: convert NumPy arrays and scalars to plain Python values."""
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def log_val_epoch(epoch:int, metrics:dict):
    """
    Append one JSON line with the epoch's validation metrics to VAL_EPOCH_LOG.

    `metrics` may contain NumPy arrays/scalars (as returned by evaluate_riskwise);
    they are written as JSON lists/numbers.
    """
    rec = {"ts": time.time(), "epoch": epoch, **metrics}
    # Serialise before opening the file so a failure cannot leave a partial line behind
    line = json.dumps(rec, default=_json_default)
    with VAL_EPOCH_LOG.open("a", encoding="utf-8") as f:
        f.write(line + "\n")

def dump_val_deltas(epoch:int, deltas:list[float]):
    """
    Write the epoch's validation margins to VAL_DELTAS_DIR/deltas_epochNN.json.

    Accepts a list or a NumPy array; creates VAL_DELTAS_DIR if needed.
    Returns the path of the written file.
    """
    VAL_DELTAS_DIR.mkdir(parents=True, exist_ok=True)
    out = VAL_DELTAS_DIR / f"deltas_epoch{epoch:02d}.json"
    out.write_text(json.dumps({"epoch": epoch, "deltas": deltas}, default=_json_default), encoding="utf-8")
    return out

def save_metrics(d: dict, name: str):
    """Write metrics dict `d` (NumPy values allowed) as indented JSON to METRIC_DIR/<name> and print the path."""
    p = METRIC_DIR / name
    p.write_text(json.dumps(d, indent=2, default=_json_default), encoding="utf-8")
    print(f"[saved] {p}")

# 5. Training & Evaluation <a id='train-eval'></a>
Training loop, evaluation metrics, checkpoints, and logging.

In [None]:
# Optimiser & AMP (automatic mixed precision) setup
optim = AdamW(reward_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
use_cuda = torch.cuda.is_available()
scaler = torch.amp.GradScaler(enabled=use_cuda)

# dump_val_deltas writes its per-epoch files into this directory
VAL_DELTAS_DIR.mkdir(parents=True, exist_ok=True)


def _jsonable(obj):
    """Recursively replace NumPy arrays/scalars inside nested dicts with plain Python values."""
    if isinstance(obj, dict):
        return {k: _jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    return obj


# Early stopping targets: maximise pairwise accuracy, minimise ECE
patience, since_best = 2, 0
best_score = (-1.0, float("inf"))
global_step = 0

for epoch in range(1, EPOCHS+1):
    reward_model.train()
    pbar = tqdm(train_loader, desc=f"epoch {epoch}")

    for batch in pbar:
        optim.zero_grad(set_to_none=True)

        # autocast only on CUDA
        ctx = (torch.amp.autocast(device_type='cuda', enabled=True) if use_cuda
               else contextlib.nullcontext())  # disable on CPU
        with ctx:
            # Pairwise Bradley-Terry (logistic) loss
            loss, _ = step_batch(reward_model, batch) # returns (loss, delta_cpu)

        # backprop with gradient scaling
        scaler.scale(loss).backward()

        # unscale before clipping so clipping is applied on real magnitudes
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(reward_model.parameters(), GRAD_CLIP)

        # optimiser step through the scaler, then update scale for next iteration
        scaler.step(optim)
        scaler.update()

        # logging per step
        pbar.set_postfix(loss=f"{loss.item():.4f}")
        global_step += 1
        log_train_step(global_step, loss.item(), epoch)

    # Validation & compute metrics
    val_metrics = evaluate_riskwise(reward_model, val_loader, bins=10)
    all_deltas = val_metrics["deltas"] # for histograms
    # evaluate_riskwise returns NumPy arrays, which json.dumps rejects; log a
    # plain-Python copy so the JSON writers below cannot fail on them.
    val_record = _jsonable(val_metrics)
    save_metrics(val_record, f"reward_eval_val_epoch{epoch:02d}.json")
    log_val_epoch(epoch, val_record)
    dump_val_deltas(epoch, val_record["deltas"])
    print(f"val (epoch {epoch}): {val_metrics['pairwise_accuracy']:.4f}, ECE={val_metrics['ece']:.4f}")

    # Always keep a per-epoch checkpoint for traceability. This runs before the
    # early-stopping check so the epoch that triggers the stop is saved as well.
    save_ckpt(reward_model, sub=f"epoch{epoch:02d}")

    # Model selection & early stopping
    # Prefer higher accuracy and lower ECE; compare via (acc, -ece)
    score = (val_metrics["pairwise_accuracy"], -val_metrics["ece"])
    if score > best_score:
        best_score = score
        save_ckpt(reward_model, sub="best")
        since_best = 0
    else:
        since_best += 1
        if since_best >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# 6. Visualisations <a id='viz'></a>

## Training Loss by Steps (Smoothed)

In [None]:
# Read the per-step training log (one JSON object per line) into parallel lists
steps, losses = [], []
with (METRIC_DIR / "train_steps.jsonl").open("r", encoding="utf-8") as f:
    for line in f:
        rec = json.loads(line)
        steps += [int(rec["step"])]
        losses += [float(rec["loss"])]

# inline moving average
def _moving_avg(x, w=50):
    if len(x) < w: 
        return np.array(x, dtype=float)
    c = np.cumsum(np.insert(x, 0, 0.0))
    return (c[w:] - c[:-w]) / float(w)

# Smooth the raw per-step losses with a 50-step moving average and plot them.
if not steps:
    # Fail with an explicit message instead of an opaque error further down.
    raise ValueError("No training steps found in train_steps.jsonl.")

sm = _moving_avg(losses, w=50)

# Align every smoothed value with the step that closes its averaging window.
# _moving_avg yields one value per full window (or the raw series when there
# are fewer samples than the window), so the last len(sm) logged steps are the
# matching x positions. Stretching the curve over the whole step range would
# draw points at steps they do not summarise.
# NOTE(review): assumes the log lists steps in increasing order - confirm.
xs = steps[len(steps) - len(sm):]

plt.figure(figsize=(7,4))
plt.plot(xs, sm, label="Train Loss (smoothed)")
plt.title("Training Loss (by step, moving avg)")
plt.xlabel("Training Step")
plt.ylabel("Loss")
plt.grid(True, linestyle="--", alpha=0.4)
plt.legend()
plt.tight_layout()

# Persist the figure next to the metric logs before showing it inline.
save_path = METRIC_DIR / "train_loss_by_step.png"
plt.savefig(save_path, dpi=300)
plt.show()
print(f"Saved: {save_path}")

## Training VS Validation Loss by Epoch

In [None]:
# Build per-epoch running totals from the step-level JSONL log:
# epoch_sum[ep] accumulates the losses, epoch_cnt[ep] counts the steps.
epoch_sum, epoch_cnt = {}, {}
with open(METRIC_DIR / "train_steps.jsonl", "r", encoding="utf-8") as f:
    for rec in map(json.loads, f):
        ep = int(rec["epoch"])
        epoch_sum[ep] = epoch_sum.get(ep, 0.0) + float(rec["loss"])
        epoch_cnt[ep] = epoch_cnt.get(ep, 0) + 1

# Mean training loss for each epoch, listed in ascending epoch order.
train_epochs = sorted(epoch_sum)
train_epoch_loss = [epoch_sum[ep] / epoch_cnt[ep] for ep in train_epochs]

# Load one validation point per epoch from the per-epoch metric files.
# Use the logged val_loss when present; otherwise fall back to
# 1 - pairwise_accuracy as a rough proxy and remember that we did so,
# because that proxy is an error rate rather than a loss.
val_points = []  # (epoch, value) pairs
used_proxy = False
for jpath in METRIC_DIR.glob("reward_eval_val_epoch*.json"):
    with open(jpath, "r", encoding="utf-8") as f:
        rec = json.load(f)
    # The epoch number is the suffix after "epoch" in the file stem.
    ep = int(jpath.stem.split("epoch")[-1])
    if "val_loss" in rec:
        vloss = float(rec["val_loss"])
    else:
        # proxy if didn't log val loss
        vloss = 1.0 - float(rec["pairwise_accuracy"])
        used_proxy = True
    val_points.append((ep, vloss))

# Order by the parsed epoch number; a lexical sort of the file names
# mis-orders epochs once the number outgrows its zero padding.
val_points.sort(key=lambda point: point[0])
val_epochs = [point[0] for point in val_points]
val_losses = [point[1] for point in val_points]

# Say so in the legend when the curve is the accuracy-based proxy.
val_label = "Validation Loss (proxy: 1 - accuracy)" if used_proxy else "Validation Loss"

plt.figure(figsize=(7,4))
plt.plot(train_epochs, train_epoch_loss, "o-", label="Train Loss (per epoch)")
if val_losses:
    plt.plot(val_epochs, val_losses, "o-", label=val_label, alpha=0.9)
plt.title("Train vs Validation Loss (per epoch)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True, linestyle="--", alpha=0.4)
plt.legend()
plt.tight_layout()

# Persist the figure next to the metric logs before showing it inline.
save_path = METRIC_DIR / "train_val_loss_by_epoch.png"
plt.savefig(save_path, dpi=300)
plt.show()
print(f"Saved: {save_path}")

## Validation Pairwise Accuracy & ECE (By Epoch)

In [None]:
# Read the per-epoch validation log (one JSON object per line) and pull out
# the three series plotted below.
with open(METRIC_DIR / "val_epochs.jsonl", "r", encoding="utf-8") as f:
    val_records = [json.loads(line) for line in f]

epochs = [val_rec["epoch"] for val_rec in val_records]
accs = [val_rec["pairwise_accuracy"] for val_rec in val_records]
eces = [val_rec["ece"] for val_rec in val_records]

# Both figures share the same layout; only the series, colour, labels and
# output file differ, so they are described as data and drawn in one loop.
val_curves = (
    # Validation pairwise accuracy by epoch
    (accs, "tab:blue", "Validation Pairwise Accuracy by Epoch", "Accuracy", "val_accuracy_by_epoch.png"),
    # Validation ECE by epoch
    (eces, "tab:orange", "Validation Calibration (ECE, 10 bins)", "ECE", "val_ece_by_epoch.png"),
)

for curve_values, curve_color, curve_title, curve_ylabel, curve_file in val_curves:
    plt.figure(figsize=(7,4))
    plt.plot(epochs, curve_values, marker="o", color=curve_color)
    plt.title(curve_title)
    plt.xlabel("Epoch")
    plt.ylabel(curve_ylabel)
    plt.xticks(epochs)  # one tick per logged epoch
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.tight_layout()

    save_path = METRIC_DIR / curve_file
    plt.savefig(save_path, dpi=300)
    plt.show()
    print(f"Saved: {save_path}")

## Accuracy by Risk (Latest Epoch)

In [None]:
# Parse every record of the validation log and keep the final one, which is
# treated as the latest epoch.
with open(METRIC_DIR / "val_epochs.jsonl", "r", encoding="utf-8") as f:
    risk_records = [json.loads(line) for line in f]
last = risk_records[-1] if risk_records else None

if last is None:
    raise ValueError("No validation record found in val_epochs.jsonl.")

# Per-risk-bin accuracy (risk_bin -> accuracy); a missing key yields an empty chart.
acc_by_risk = last.get("acc_by_risk", {})
labels = list(acc_by_risk)
vals = list(acc_by_risk.values())

# Bar chart of accuracy per risk bin, with the y-axis fixed to [0, 1].
fig, ax = plt.subplots(figsize=(6,4))
ax.bar(labels, vals, color="tab:blue", alpha=0.8)
ax.set_title(f"Accuracy by Risk (Epoch {last['epoch']})")
ax.set_xlabel("Risk Bin")
ax.set_ylabel("Accuracy")
ax.set_ylim(0, 1)
ax.grid(axis="y", linestyle="--", alpha=0.4)
fig.tight_layout()

# Persist the figure next to the metric logs before showing it inline.
save_path = METRIC_DIR / f"accuracy_by_risk_epoch{last['epoch']:02d}.png"
fig.savefig(save_path, dpi=300)
plt.show()
print(f"Saved: {save_path}")

## Validation Δ = r_pos - r_neg Histogram (Latest Epoch)

In [None]:
# Load the latest delta file (the last path in sorted file-name order).
delta_files = sorted(glob.glob(str(VAL_DELTAS_DIR / "deltas_epoch*.json")))
if not delta_files:
    raise FileNotFoundError(f"No delta files found in: {VAL_DELTAS_DIR}")

with open(delta_files[-1], "r", encoding="utf-8") as f:
    rec = json.load(f)

deltas = rec.get("deltas", [])
epoch = rec.get("epoch", "unknown")

# Zero-pad only real integer epochs: applying the "02d" format spec to the
# "unknown" fallback string would raise ValueError when building the file name.
epoch_tag = f"{epoch:02d}" if isinstance(epoch, int) else str(epoch)

# Plot histogram of Δ margins
plt.figure(figsize=(7,4))
plt.hist(deltas, bins=50, color="tab:purple", alpha=0.75, edgecolor="black")
plt.title(f"Validation Δ = r_pos - r_neg (Epoch {epoch})")
plt.xlabel("Δ margin (r_pos - r_neg)")
plt.ylabel("Count")
plt.grid(True, linestyle="--", alpha=0.4)
plt.tight_layout()

# Persist the figure next to the metric logs before showing it inline.
save_path = METRIC_DIR / f"val_deltas_hist_epoch{epoch_tag}.png"
plt.savefig(save_path, dpi=300)
plt.show()
print(f"Saved: {save_path}")