In [15]:
from pathlib import Path
import json, random, numpy as np
import torch
from collections import defaultdict, Counter
from typing import Dict, Any, Tuple
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [16]:
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

# Paths are resolved relative to the notebook's working directory.
CUR_DIR = Path.cwd()
SRC_DIR = CUR_DIR.parent

DATA_DIR = SRC_DIR / "data" / "processed"
PAIRS_PATH = DATA_DIR / "preference_pairs.json"
#TEST_QNA_PATH = DATA_DIR / "test_qna_with_confidence.json"

TRAIN_PATH  = DATA_DIR / "preference_pairs_train.json"
VAL_PATH   = DATA_DIR / "preference_pairs_val.json"

CKPT_DIR = CUR_DIR / "checkpoints"
METRIC_DIR = CUR_DIR / "metrics"

DROP_WEAK = False  # set True for cleaner supervision (drops weak/uninformative pairs)

# Split ratios; must sum to 1.0 (checked inside split_preference_pairs).
TRAIN_RATIO = 0.9
VAL_RATIO = 0.1
TEST_RATIO = 0.0

# Seed every RNG in play: Python `random`, NumPy, and torch. torch drives
# DataLoader shuffling and the random init of the new reward head, so without
# it runs are not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [17]:
# Create data, checkpoint and metric folders up front (no-op if they exist).
for folder in (DATA_DIR, CKPT_DIR, METRIC_DIR):
    folder.mkdir(exist_ok=True, parents=True)

def _read_json(p: Path):
    return json.loads(p.read_text(encoding="utf-8"))

def _write_json(obj, p: Path):
    p.write_text(json.dumps(obj, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"[saved] {p}")

# Data splitting

In [18]:
def split_preference_pairs(
    pairs_path: Path,
    out_dir: Path,
    train_ratio: float = 0.9,
    val_ratio: float = 0.1,
    test_ratio: float = 0.0,
    seed: int = 42,
    drop_weak: bool = False,
) -> Dict[str, Any]:
    """
    Split preference_pairs.json into train/val(/test) without question leakage.

    Every pair sharing a ``question_id`` lands in the same split. Questions are
    stratified by the majority ``risk_level`` of their pairs, so each split
    keeps a similar risk mix. Within each risk bucket, rounding remainders go
    to the test split when one is requested (``test_ratio > 0``) and to val
    otherwise, so no question is ever silently discarded.

    Args:
        pairs_path: path to preference_pairs.json
        out_dir: directory to write split files (created if missing)
        train_ratio, val_ratio, test_ratio: must sum to 1.0 (test optional)
        seed: RNG seed for reproducibility. A private ``random.Random`` is
            used, so the global ``random`` state is left untouched.
        drop_weak: if True, drops pairs with score_difference==0 and pairs where both ratings == 1

    Writes:
        preference_pairs_train.json
        preference_pairs_val.json
        (only if test_ratio > 0) preference_pairs_test.json
        preference_pairs_split_stats.json

    Returns:
        stats dict with counts, distributions, and dropped info.

    Raises:
        AssertionError: if the three ratios do not sum to 1.0.
    """
    assert abs((train_ratio + val_ratio + test_ratio) - 1.0) < 1e-6, "Ratios must sum to 1.0"
    out_dir.mkdir(parents=True, exist_ok=True)
    rng = random.Random(seed)  # same sequence as random.seed(seed), without global side effects

    def _rating_score(ans: Dict[str, Any]) -> int:
        # Prefer explicit numeric 'score'; fallback parse from 'rating' like "3-Incomplete"
        if "score" in ans and ans["score"] is not None:
            return int(ans["score"])
        r = ans.get("rating", "")
        return int(r[0]) if isinstance(r, str) and r[:1].isdigit() else -1  # -1 if unknown

    # Load data
    pairs = _read_json(pairs_path)

    # Optional filtering
    dropped = {"score_diff_zero": 0, "both_score1": 0}  # counts how many pairs were removed for each reason
    if drop_weak:
        clean = []
        for ex in pairs:  # ex = preference pair dict
            if ex.get("score_difference", None) == 0:
                dropped["score_diff_zero"] += 1
                continue
            sp = _rating_score(ex.get("preferred_answer", {}))
            sr = _rating_score(ex.get("rejected_answer", {}))
            if sp == 1 and sr == 1:
                dropped["both_score1"] += 1
                continue
            clean.append(ex)  # only store if passes filters
    else:
        clean = pairs

    # Group by question_id (no leakage across splits)
    by_qid = defaultdict(list)
    for ex in clean:
        by_qid[ex["question_id"]].append(ex)

    qids = list(by_qid.keys())
    rng.shuffle(qids)

    # Stratify by risk level (keeps similar risk mix across splits)
    def majority_risk(items):
        return Counter(e.get("risk_level", "UNKNOWN") for e in items).most_common(1)[0][0]

    risk_to_qids = defaultdict(list)
    for qid in qids:
        risk_to_qids[majority_risk(by_qid[qid])].append(qid)

    # Lists (not sets) keep the materialised output order deterministic; set
    # iteration order over string ids depends on PYTHONHASHSEED.
    train_qids, val_qids, test_qids = [], [], []
    for bucket in risk_to_qids.values():
        rng.shuffle(bucket)
        n = len(bucket)
        n_train = min(n, int(round(n * train_ratio)))
        if test_ratio > 0:
            # Remainder after train/val goes to test.
            n_val = min(n - n_train, int(round(n * val_ratio)))
        else:
            # No test split requested: val absorbs the rounding remainder
            # (e.g. n=5 -> round(4.5)=4, round(0.5)=0 would otherwise lose 1).
            n_val = n - n_train
        train_qids.extend(bucket[:n_train])
        val_qids.extend(bucket[n_train:n_train + n_val])
        test_qids.extend(bucket[n_train + n_val:])

    # Materialise splits
    train_pairs = [e for q in train_qids for e in by_qid[q]]
    val_pairs   = [e for q in val_qids   for e in by_qid[q]]
    test_pairs  = [e for q in test_qids  for e in by_qid[q]]

    # Save
    _write_json(train_pairs, out_dir / "preference_pairs_train.json")
    _write_json(val_pairs,   out_dir / "preference_pairs_val.json")
    if test_ratio > 0:
        _write_json(test_pairs, out_dir / "preference_pairs_test.json")

    def _dist(lst):
        return dict(Counter(e.get("risk_level", "UNKNOWN") for e in lst))

    stats = {
        "seed": seed,
        "drop_weak": drop_weak,
        "dropped": dropped,
        "counts": {
            "train": len(train_pairs),
            "val":   len(val_pairs),
            "test":  len(test_pairs),
        },
        "unique_qids": {
            "train": len(train_qids),
            "val":   len(val_qids),
            "test":  len(test_qids),
        },
        "risk_dist": {
            "train": _dist(train_pairs),
            "val":   _dist(val_pairs),
            "test":  _dist(test_pairs),
        },
    }

    _write_json(stats, out_dir / "preference_pairs_split_stats.json")
    return stats

In [19]:
# 90/10 train/val split by question; set TEST_RATIO (e.g. 0.1) for a test split
# and DROP_WEAK=True for cleaner pairs.
stats = split_preference_pairs(
    PAIRS_PATH,
    DATA_DIR,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
    seed=SEED,
    drop_weak=DROP_WEAK,
)
stats

[saved] c:\Users\Crescent\OneDrive\Personal\my files\3. nus\Y3S1\DSA4213 Natural Language Processing for Data Science\Project\GitHub\SelfTraining-Reward-MedQA\src\data\processed\preference_pairs_train.json
[saved] c:\Users\Crescent\OneDrive\Personal\my files\3. nus\Y3S1\DSA4213 Natural Language Processing for Data Science\Project\GitHub\SelfTraining-Reward-MedQA\src\data\processed\preference_pairs_val.json
[saved] c:\Users\Crescent\OneDrive\Personal\my files\3. nus\Y3S1\DSA4213 Natural Language Processing for Data Science\Project\GitHub\SelfTraining-Reward-MedQA\src\data\processed\preference_pairs_split_stats.json


{'seed': 42,
 'drop_weak': False,
 'dropped': {'score_diff_zero': 0, 'both_score1': 0},
 'counts': {'train': 27775, 'val': 3327, 'test': 0},
 'unique_qids': {'train': 92, 'val': 11, 'test': 0},
 'risk_dist': {'train': {'Low Risk': 6389,
   'High Risk': 11715,
   'Medium Risk': 9671},
  'val': {'High Risk': 1833, 'Medium Risk': 878, 'Low Risk': 616},
  'test': {}}}

# Load splits

In [20]:
# TRAIN_PATH / VAL_PATH come from the configuration cell; they are not
# redefined here so the two definitions can never drift apart.
train_pairs = _read_json(TRAIN_PATH)
val_pairs = _read_json(VAL_PATH)

# The split is by question_id, so no question may appear in both splits.
_leaked_qids = {ex["question_id"] for ex in train_pairs} & {ex["question_id"] for ex in val_pairs}
assert not _leaked_qids, f"question leakage between train and val: {list(_leaked_qids)[:5]}"

len(train_pairs), len(val_pairs)

(27775, 3327)

# Dataset + Dataloader

In [21]:
MAX_LEN, BETA, BATCH_SIZE = 1024, 3.0, 2
# MAX_LEN    : token truncation length for each (question, answer) input
# BETA       : scales confidence_penalty into a sample weight (w = 1 + BETA * penalty)
# BATCH_SIZE : preference pairs per batch

In [22]:
# Load tokenizer; if it has no PAD token, reuse EOS so batches can be padded.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
missing_pad = tokenizer.pad_token is None
if missing_pad:
    tokenizer.pad_token = tokenizer.eos_token

In [23]:
def build_input(q, a):
    """Render a question/answer pair into the plain-text prompt scored by the reward model."""
    template = "Question:\n{}\n\nAnswer:\n{}"
    return template.format(q, a)

# Dataset and collate
class PrefPairDataset(Dataset):
    """
    Map-style dataset over preference-pair dicts.

    Each item is tokenised on the fly into a preferred ("pos") and a rejected
    ("neg") encoding plus a scalar sample weight
    ``1 + BETA * confidence_penalty`` (penalty defaults to 0.0).
    """

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        record = self.items[i]
        question = record["question_text"]
        encoded = {}
        for side, answer_key in (("pos", "preferred_answer"), ("neg", "rejected_answer")):
            text = build_input(question, record[answer_key]["answer_text"])
            encoded[side] = tokenizer(text, truncation=True, max_length=MAX_LEN, return_tensors="pt")
        penalty = float(record.get("confidence_penalty", 0.0))
        encoded["weight"] = torch.tensor(1.0 + BETA * penalty, dtype=torch.float)
        return encoded

def collate_fn(batch):
    """
    Right-pad a list of dataset items into batched tensors.

    Returns {"pos": {...}, "neg": {...}, "weight": (B,)}, where each side holds
    ``input_ids`` padded with the tokenizer's pad id and ``attention_mask``
    padded with 0.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def _pad_side(side):
        input_ids = pad(
            [item[side]["input_ids"].squeeze(0) for item in batch],
            batch_first=True,
            padding_value=tokenizer.pad_token_id,
        )
        attention_mask = pad(
            [item[side]["attention_mask"].squeeze(0) for item in batch],
            batch_first=True,
            padding_value=0,
        )
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return {
        "pos": _pad_side("pos"),
        "neg": _pad_side("neg"),
        "weight": torch.stack([item["weight"] for item in batch]),
    }

# build loaders (only the training split is shuffled)
train_loader = DataLoader(
    PrefPairDataset(train_pairs),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    PrefPairDataset(val_pairs),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
len(train_loader), len(val_loader)

(13888, 1664)

In [24]:
# Reward model (1 scalar output)
# NOTE(review): on CUDA the weights are loaded in FP16. torch's GradScaler cannot
# unscale FP16 gradients, and pure-FP16 AdamW is fragile; prefer FP32 master
# weights or bf16 if memory allows.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # CPU requires float32
device_map = "auto" if torch.cuda.is_available() else None
max_memory = {0: "13GB"} if torch.cuda.is_available() else None  # cap GPU 0 usage

reward_model  = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=1,
    torch_dtype=dtype,
    device_map=device_map,
    low_cpu_mem_usage=True,
    max_memory=max_memory
)

# align special tokens / training toggles
reward_model.config.pad_token_id = tokenizer.pad_token_id
reward_model.config.eos_token_id = tokenizer.eos_token_id
# HF tokenizers always define `bos_token_id` (possibly as None), so a getattr
# default alone never triggers the fallback; check for None explicitly.
bos_id = getattr(tokenizer, "bos_token_id", None)
reward_model.config.bos_token_id = bos_id if bos_id is not None else tokenizer.eos_token_id
reward_model.config.use_cache = False  # safer during training

print("Device:", "CUDA" if torch.cuda.is_available() else "CPU")
print("Loaded:", MODEL_NAME)

Loading checkpoint shards: 100%|██████████| 2/2 [00:20<00:00, 10.32s/it]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-3B-Instruct and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Device: CPU
Loaded: meta-llama/Llama-3.2-3B-Instruct


## Check model architecture

In [30]:
# ===== Keras-like summary for HF models =====
import torch, math
from collections import OrderedDict

def _num_params(module):
    total = sum(p.numel() for p in module.parameters(recurse=False))
    train = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
    return total, train

def _shape_str(x):
    if isinstance(x, torch.Tensor):
        return tuple(x.shape)
    if isinstance(x, (list, tuple)):
        # show first few shapes if it's a tuple/list
        parts = []
        for i, it in enumerate(x[:3]):
            if isinstance(it, torch.Tensor):
                parts.append(str(tuple(it.shape)))
            else:
                parts.append(type(it).__name__)
        if len(x) > 3:
            parts.append("...")
        return "[" + ", ".join(parts) + "]"
    return type(x).__name__

def model_summary(model, tokenizer, seq_len=64, verbose=True):
    """
    Prints a Keras-like table:
      Layer (path) | Type | Output Shape | Param # (trainable)
    Uses a dummy forward pass with batch_size=1 and input length = seq_len.

    Only leaf modules (no children) that run during the dummy forward pass get
    a row, and the printed totals are sums over those rows. Forward hooks are
    always removed and the model's original train/eval mode is restored, even
    if the forward pass raises.

    Args:
        model: torch module whose forward accepts ``input_ids`` and
            ``attention_mask`` keyword arguments.
        tokenizer: callable tokenizer with a non-None ``pad_token_id``.
        seq_len: length of the dummy sequence (right-padded with pad tokens).
        verbose: print the table when True.

    Returns:
        dict with "layers" (per-module info in first-execution order) and
        "totals" (parameter counts, sizes in MB, dtype, seq_len).
    """
    first_param = next(model.parameters())
    device = first_param.device
    dtype = first_param.dtype
    bytes_per = first_param.element_size()  # exact for every dtype (fp32=4, fp16/bf16=2, ...)

    # 1) Dummy batch: the tokenised word "dummy", right-padded to seq_len.
    toks = tokenizer("dummy", return_tensors="pt", add_special_tokens=True)
    input_ids = torch.full((1, seq_len), fill_value=tokenizer.pad_token_id, dtype=torch.long)
    attn_mask = torch.zeros((1, seq_len), dtype=torch.long)
    n_real = min(toks["input_ids"].shape[1], seq_len)
    input_ids[:, :n_real] = toks["input_ids"][:, :n_real]
    attn_mask[:, :n_real] = 1
    batch = {"input_ids": input_ids.to(device), "attention_mask": attn_mask.to(device)}

    # 2) Collect module info (param counts) and register forward hooks for output shapes
    names = {m: n for n, m in model.named_modules()}  # module -> qualified name
    layer_infos = OrderedDict()      # preserve execution order
    handles = []

    def hook(module, inputs, outputs):
        # The root module has the empty qualified name; fall back to its class name.
        name = names.get(module, module.__class__.__name__) or module.__class__.__name__
        if name not in layer_infos:  # first execution wins if a module runs twice
            p_total, p_train = _num_params(module)
            layer_infos[name] = {
                "type": module.__class__.__name__,
                "params": p_total,
                "trainable": p_train,
                "out": _shape_str(outputs),
            }

    # only hook leaf modules to avoid noisy duplicates
    for m in model.modules():
        if len(list(m.children())) == 0:
            try:
                handles.append(m.register_forward_hook(hook))
            except Exception:
                pass  # best effort: skip modules that refuse hooks

    # 3) Forward pass (to populate shapes); always clean up hooks and mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(**batch)
    finally:
        for h in handles:
            h.remove()
        model.train(was_training)

    # 4) Print table
    total_params = sum(v["params"] for v in layer_infos.values())
    train_params = sum(v["trainable"] for v in layer_infos.values())
    frozen_params = total_params - train_params
    total_mb   = (total_params * bytes_per) / (1024**2)
    train_mb   = (train_params * bytes_per) / (1024**2)
    frozen_mb  = (frozen_params * bytes_per) / (1024**2)

    header = f"{'Layer (path)':60} {'Type':22} {'Output Shape':28} {'Param #':>12} {'Trainable':>10}"
    line   = "-" * len(header)
    if verbose:
        print(header)
        print(line)
        for name, info in layer_infos.items():
            print(f"{name:60} {info['type']:22} {str(info['out'])[:28]:28} {info['params']:12,} {info['trainable']:10,}")
        print(line)
        print(f"Total params:     {total_params:,}  (~{total_mb:.1f} MB at {str(dtype).replace('torch.','')})")
        print(f"Trainable params: {train_params:,}  (~{train_mb:.1f} MB)")
        print(f"Frozen params:    {frozen_params:,}  (~{frozen_mb:.1f} MB)")
        print(f"Seq len used for summary: {seq_len}, batch size: 1")

    # 5) Return a dict if you want to use programmatically
    return {
        "layers": layer_infos,
        "totals": {
            "total_params": total_params,
            "trainable_params": train_params,
            "frozen_params": frozen_params,
            "dtype": str(dtype),
            "mb_total": total_mb,
            "mb_trainable": train_mb,
            "mb_frozen": frozen_mb,
            "bytes_per_param": bytes_per,
            "seq_len": seq_len,
        }
    }

# ---- run it
# Bind the result to a real name: assigning to `_` would shadow IPython's
# last-output variable for the rest of the session.
summary_info = model_summary(reward_model, tokenizer, seq_len=64)


Layer (path)                                                 Type                   Output Shape                      Param #  Trainable
----------------------------------------------------------------------------------------------------------------------------------------
model.embed_tokens                                           Embedding              (1, 64, 3072)                 394,002,432 394,002,432
model.rotary_emb                                             LlamaRotaryEmbedding   [(1, 64, 128), (1, 64, 128)]            0          0
model.layers.0.input_layernorm                               LlamaRMSNorm           (1, 64, 3072)                       3,072      3,072
model.layers.0.self_attn.q_proj                              Linear                 (1, 64, 3072)                   9,437,184  9,437,184
model.layers.0.self_attn.k_proj                              Linear                 (1, 64, 1024)                   3,145,728  3,145,728
model.layers.0.self_attn.v_proj         

In [26]:
# ===== Model introspection (counts, layers, arch) =====

def count_params(model):
    """Return (total, trainable, frozen) parameter element counts for ``model``."""
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(n for n, _ in sizes)
    trainable = sum(n for n, needs_grad in sizes if needs_grad)
    return total, trainable, total - trainable

def human_mb(n_params, dtype=None):  # params → MB at the element size of `dtype`
    """
    Convert a parameter count into mebibytes of storage.

    Args:
        n_params: number of parameters (elements).
        dtype: torch dtype used to size each element. Defaults to the dtype of
            the first parameter of the global ``reward_model``, as before.

    Returns:
        float: ``n_params * bytes_per_element / 1024**2``. The element size
        comes from torch itself, so float64/int8/etc. are sized correctly
        rather than assumed to be 4 bytes.
    """
    if dtype is None:
        dtype = next(reward_model.parameters()).dtype
    bytes_per = torch.empty((), dtype=dtype).element_size()
    return (n_params * bytes_per) / (1024**2)

total, trainable, frozen = count_params(reward_model)
# Labels are left-aligned to a 19-character column so the counts line up.
for _label, _count in (("Total params:", total), ("Trainable params:", trainable), ("Frozen params:", frozen)):
    print(f"{_label:<19}{_count:,}  (~{human_mb(_count):.1f} MB)")

Total params:      3,212,752,896  (~12255.7 MB)
Trainable params:  3,212,752,896  (~12255.7 MB)
Frozen params:     0  (~0.0 MB)


In [27]:
# List trainable layers with param counts (top-N largest first)
train_layers = sorted(
    ((name, param.numel()) for name, param in reward_model.named_parameters() if param.requires_grad),
    key=lambda item: item[1],
    reverse=True,
)

print("\nTop trainable layers (name → #params):")
for layer_name, layer_size in train_layers[:20]:  # top 20 keeps the output short
    print(f"  {layer_name:60s} {layer_size:,}")


Top trainable layers (name → #params):
  model.embed_tokens.weight                                    394,002,432
  model.layers.0.mlp.gate_proj.weight                          25,165,824
  model.layers.0.mlp.up_proj.weight                            25,165,824
  model.layers.0.mlp.down_proj.weight                          25,165,824
  model.layers.1.mlp.gate_proj.weight                          25,165,824
  model.layers.1.mlp.up_proj.weight                            25,165,824
  model.layers.1.mlp.down_proj.weight                          25,165,824
  model.layers.2.mlp.gate_proj.weight                          25,165,824
  model.layers.2.mlp.up_proj.weight                            25,165,824
  model.layers.2.mlp.down_proj.weight                          25,165,824
  model.layers.3.mlp.gate_proj.weight                          25,165,824
  model.layers.3.mlp.up_proj.weight                            25,165,824
  model.layers.3.mlp.down_proj.weight                          25,165,8

In [28]:
# Show the classification head shape ('score' is checked before 'classifier')
for head_attr in ("score", "classifier"):
    if hasattr(reward_model, head_attr):
        head_weight = getattr(reward_model, head_attr).weight
        print("\nClassification head (reward scalar) weight shape:", tuple(head_weight.shape))
        break


Classification head (reward scalar) weight shape: (1, 3072)


In [29]:
# Short architecture sketch: direct children of the model only
print("\nTop-level modules:")
for child_name, child in reward_model.named_children():
    print(f"  - {child_name}: {type(child).__name__}")


Top-level modules:
  - model: LlamaModel
  - score: Linear


In [None]:
# ===== OPTIONAL: freeze/unfreeze helpers =====
# Example: freeze all base layers, train only the classification head
def freeze_all_but_head(m):
    """
    Freeze every parameter of ``m`` except those of its classification head.

    The head is looked up as ``m.score`` (the head name used by this
    notebook's sequence-classification model) or, if absent, ``m.classifier``.

    Returns:
        The head module whose parameters were left trainable, or None if
        neither attribute exists. In that case ALL parameters stay frozen and a
        warning is emitted, instead of silently leaving nothing to train.
    """
    for p in m.parameters():
        p.requires_grad = False
    # Explicit None checks: `a or b` on modules relies on truthiness, which
    # container modules (e.g. an empty nn.Sequential) define through __len__.
    head = getattr(m, "score", None)
    if head is None:
        head = getattr(m, "classifier", None)
    if head is None:
        import warnings
        warnings.warn("freeze_all_but_head: no 'score' or 'classifier' head found; every parameter is now frozen.")
        return None
    for p in head.parameters():
        p.requires_grad = True
    return head

# Example usage:
# freeze_all_but_head(reward_model)
# print("After freezing:")
# total, trainable, frozen = count_params(reward_model)
# print(f"Trainable params:  {trainable:,}  (~{human_mb(trainable):.1f} MB)")

# Training step (pairwise Bradley-Terry loss)

In [31]:
from torch.optim import AdamW
from tqdm.auto import tqdm
import numpy as np, math, json

def step_batch(model, batch):
    """
    Weighted Bradley-Terry (pairwise logistic) loss for one collated batch.

    Moves the ``pos``/``neg`` encodings and ``weight`` to ``model.device``
    (the ``batch`` dict is updated in place), scores both sides, and returns
    the weighted mean of ``-log sigmoid(r_pos - r_neg)``.

    Args:
        model: callable returning an object with ``.logits`` of shape (B, 1);
            must expose a ``.device`` attribute.
        batch: dict with "pos"/"neg" (dicts of tensors accepted by ``model``)
            and "weight" (B,) per-pair weights.

    Returns:
        (loss, delta): the scalar loss tensor (still attached to the graph)
        and the detached float32 CPU margins ``r_pos - r_neg`` of shape (B,).
    """
    device = model.device
    for side in ("pos", "neg"):
        for k in batch[side]:
            batch[side][k] = batch[side][k].to(device)
    w = batch["weight"].to(device)

    r_pos = model(**batch["pos"]).logits.squeeze(-1)  # (B,)
    r_neg = model(**batch["neg"]).logits.squeeze(-1)
    delta = r_pos - r_neg
    # logsigmoid is numerically stable. log(sigmoid(d) + eps) returns -inf once
    # sigmoid underflows in fp16, and it saturates at -log(eps) (~18.4) in fp32,
    # which zeroes the gradient for exactly the most mis-ranked pairs.
    loss_core = -torch.nn.functional.logsigmoid(delta.float())
    loss = (w * loss_core).mean()
    # float32 so callers can always call .numpy() (bf16 tensors cannot).
    return loss, delta.detach().float().cpu()

EPOCHS = 5
GRAD_CLIP = 1.0                   # max global grad-norm before each optimizer step
best_val = -1.0
best_path = CKPT_DIR / "best"
optim = AdamW(reward_model.parameters(), lr=2e-5)

def save_ckpt(model, sub="epoch"):
    """Save the global tokenizer and ``model`` into CKPT_DIR/<sub> (created if missing), then log the path."""
    target_dir = CKPT_DIR / sub
    target_dir.mkdir(exist_ok=True, parents=True)
    tokenizer.save_pretrained(target_dir)
    model.save_pretrained(target_dir)
    print(f"[saved] {target_dir}")


# Validation (pairwise accuracy + ECE)

In [32]:
def evaluate(model, loader, bins=10):
    """
    Pairwise accuracy and expected calibration error (ECE) of a reward model.

    For each pair the margin ``delta = r_pos - r_neg`` is computed with
    ``step_batch``. The labelled winner is always the ``pos`` side, so a pair
    is correct when ``delta > 0``. Calibration is the standard binary ECE:
    confidence of the model's *predicted* winner, ``max(p, 1 - p)`` with
    ``p = sigmoid(delta)``, against whether that prediction is correct.
    (Comparing ``p`` with ``p > 0.5`` instead would just measure the mean of
    ``min(p, 1 - p)`` and reward overconfidence.)

    Args:
        model: reward model passed to ``step_batch``; switched to eval mode.
        loader: iterable of collated preference batches.
        bins: number of equal-width confidence bins on [0, 1].

    Returns:
        {"pairwise_accuracy": float, "ece": float}; both NaN for an empty loader.
    """
    model.eval()
    probs, corrects = [], []
    with torch.no_grad():
        for batch in loader:
            _, delta = step_batch(model, batch)
            delta = delta.float()  # .tolist()/.numpy() need a non-bf16 dtype
            probs.extend(torch.sigmoid(delta).tolist())
            corrects.extend((delta > 0).float().tolist())

    if not probs:
        return {"pairwise_accuracy": float("nan"), "ece": float("nan")}

    probs = np.asarray(probs, dtype=np.float64)
    corrects = np.asarray(corrects, dtype=np.float64)
    acc = float(corrects.mean())

    # Confidence in the predicted winner (pos if p > 0.5, else neg).
    conf = np.maximum(probs, 1.0 - probs)

    # ECE (Expected Calibration Error)
    edges = np.linspace(0.0, 1.0, bins + 1)
    ece = 0.0
    for i in range(bins):
        lo, hi = edges[i], edges[i + 1]
        # Half-open bins, except the last, which must include conf == 1.0
        # (sigmoid saturates to exactly 1.0 for large margins).
        below_hi = conf <= hi if i == bins - 1 else conf < hi
        in_bin = (conf >= lo) & below_hi
        if in_bin.any():
            ece += abs(corrects[in_bin].mean() - conf[in_bin].mean()) * in_bin.mean()
    return {"pairwise_accuracy": acc, "ece": float(ece)}


# Train loop + save metrics & best

## Lightweight logger (JSONL for steps, JSON for epochs)

In [33]:
import json, time
from pathlib import Path

LOG_DIR = METRIC_DIR  # the metrics/ folder doubles as the log folder
LOG_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_LOG = LOG_DIR / "train_steps.jsonl"         # one JSON line per optimizer step (loss)
VAL_EPOCH_LOG = LOG_DIR / "val_epochs.jsonl"      # one JSON line per epoch (val metrics)
VAL_DELTAS_DIR = LOG_DIR / "val_deltas"           # per-epoch margin dumps for histograms
VAL_DELTAS_DIR.mkdir(exist_ok=True)

def log_train_step(step:int, loss:float, epoch:int):
    """Append one {ts, step, epoch, loss} JSON line to TRAIN_LOG."""
    with TRAIN_LOG.open("a", encoding="utf-8") as fh:
        record = {"ts": time.time(), "step": step, "epoch": epoch, "loss": float(loss)}
        fh.write(json.dumps(record) + "\n")

def log_val_epoch(epoch:int, metrics:dict):
    """Append one JSON line {ts, epoch, **metrics} to VAL_EPOCH_LOG."""
    entry = dict(ts=time.time(), epoch=epoch)
    entry.update(metrics)
    serialized = json.dumps(entry) + "\n"
    with VAL_EPOCH_LOG.open("a", encoding="utf-8") as fh:
        fh.write(serialized)

def dump_val_deltas(epoch:int, deltas:list[float]):
    """Write this epoch's validation margins to VAL_DELTAS_DIR/deltas_epochNN.json and return that path."""
    target = VAL_DELTAS_DIR / f"deltas_epoch{epoch:02d}.json"
    payload = {"epoch": epoch, "deltas": deltas}
    target.write_text(json.dumps(payload), encoding="utf-8")
    return target

## Model training: log the training loss per step and store validation margins (Δ) each epoch

In [34]:
def save_metrics(d: dict, name: str):
    """Write metrics dict ``d`` as indented JSON to METRIC_DIR/name and log the path."""
    target = METRIC_DIR / name
    text = json.dumps(d, indent=2)
    target.write_text(text, encoding="utf-8")
    print(f"[saved] {target}")

In [None]:
import contextlib, math, random, numpy as np, torch

use_cuda = torch.cuda.is_available()
# GradScaler.unscale_() raises "Attempting to unscale FP16 gradients" when the
# weights themselves are FP16, so loss scaling is only enabled for FP32 weights.
# NOTE(review): pure-FP16 training with AdamW remains fragile; prefer FP32
# master weights (or bf16) if memory allows.
weights_fp32 = next(reward_model.parameters()).dtype == torch.float32
scaler = torch.amp.GradScaler(enabled=use_cuda and weights_fp32)

patience, since_best = 2, 0
best_score = (-float("inf"), -float("inf"))  # compared as (acc, -ece); larger is better
global_step = 0

# Start each run from empty logs. global_step restarts at 0, so appending to a
# previous run's logs would interleave two runs in the plots below, and stale
# per-epoch delta files could be picked up as the "latest" epoch.
TRAIN_LOG.write_text("", encoding="utf-8")
VAL_EPOCH_LOG.write_text("", encoding="utf-8")
for stale in VAL_DELTAS_DIR.glob("deltas_epoch*.json"):
    stale.unlink()

for epoch in range(1, EPOCHS+1):
    reward_model.train()
    pbar = tqdm(train_loader, desc=f"epoch {epoch}")

    for batch in pbar:
        optim.zero_grad(set_to_none=True)

        # autocast only on CUDA
        ctx = (torch.amp.autocast(device_type='cuda', enabled=True) if use_cuda
               else contextlib.nullcontext())  # disable on CPU
        with ctx:
            loss, _ = step_batch(reward_model, batch) # returns (loss, delta_cpu)

        # backprop with gradient scaling (a pass-through when the scaler is disabled)
        scaler.scale(loss).backward()
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(reward_model.parameters(), GRAD_CLIP)
        scaler.step(optim)
        scaler.update()

        loss_val = loss.item()  # single host sync per step
        pbar.set_postfix(loss=f"{loss_val:.4f}")
        global_step += 1
        log_train_step(global_step, loss_val, epoch)

    # validation (also collect deltas for histogram)
    # NOTE(review): this pass and evaluate() below each run the full val set;
    # folding them into one pass would halve validation time.
    reward_model.eval()
    all_deltas = []
    with torch.no_grad():
        for batch in val_loader:
            _, delta = step_batch(reward_model, batch)   # delta already detached
            all_deltas.extend(delta.float().cpu().numpy().tolist())

    val_metrics = evaluate(reward_model, val_loader, bins=10)  # or evaluate_riskwise(...)
    save_metrics(val_metrics, f"reward_eval_val_epoch{epoch:02d}.json")
    log_val_epoch(epoch, val_metrics)
    dump_val_deltas(epoch, all_deltas)
    print("val:", val_metrics)

    # select best by (acc ↑, ECE ↓)
    score = (val_metrics["pairwise_accuracy"], -val_metrics["ece"])
    if score > best_score:
        best_score = score
        save_ckpt(reward_model, sub="best")
        since_best = 0
    else:
        since_best += 1
        if since_best >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    save_ckpt(reward_model, sub=f"epoch{epoch:02d}")

epoch 1:   0%|          | 0/13888 [00:00<?, ?it/s]

epoch 1:   0%|          | 0/13888 [02:04<?, ?it/s]


KeyboardInterrupt: 

## Plot training loss (validation loss is not logged; see the validation-metric plots below)

In [None]:
import json, numpy as np, matplotlib.pyplot as plt

# load step losses (one JSON record per training step)
with open(METRIC_DIR / "train_steps.jsonl", "r", encoding="utf-8") as f:
    step_records = [json.loads(line) for line in f]
steps = [record["step"] for record in step_records]
losses = [record["loss"] for record in step_records]

# smooth a bit for readability (moving avg window=50)
def moving_avg(x, w=50):
    """Trailing moving average over windows of ``w``; inputs shorter than ``w`` come back unsmoothed as a float array."""
    if len(x) < w:
        return np.array(x, dtype=float)
    running = np.cumsum(np.concatenate(([0.0], np.asarray(x, dtype=float))))
    window_sums = running[w:] - running[:-w]
    return window_sums / float(w)

sm = moving_avg(losses, w=50)
# Label each smoothed point with the real global step at the END of its
# window. When there are fewer than w losses, moving_avg returns them
# unsmoothed and this slice covers every step.
xs = np.asarray(steps[len(steps) - len(sm):])

# No validation loss is logged (evaluate() reports accuracy/ECE only), so this
# figure shows training loss alone; validation metrics are plotted separately.
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(xs, sm)
ax.set_title("Training loss (moving avg)")
ax.set_xlabel("Global step")
ax.set_ylabel("Loss")
plt.show()


## Plot validation pairwise accuracy & ECE vs. epoch

In [None]:
import json, matplotlib.pyplot as plt

with open(METRIC_DIR / "val_epochs.jsonl", "r", encoding="utf-8") as f:
    val_records = [json.loads(line) for line in f]
epochs = [record["epoch"] for record in val_records]
accs = [record["pairwise_accuracy"] for record in val_records]
eces = [record["ece"] for record in val_records]


def _plot_epoch_metric(xs, values, title, ylabel):
    """Line plot (with markers) of one per-epoch validation metric."""
    plt.figure(figsize=(7,4))
    plt.plot(xs, values, marker="o")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.xticks(xs)
    plt.show()


_plot_epoch_metric(epochs, accs, "Validation Pairwise Accuracy", "Accuracy")
_plot_epoch_metric(epochs, eces, "Validation Calibration (ECE, 10 bins)", "ECE")


## Bar chart: validation accuracy by risk bucket (per latest epoch)

In [None]:
import json, matplotlib.pyplot as plt

# read the latest epoch record
last = None
with open(METRIC_DIR / "val_epochs.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():  # tolerate trailing blank lines
            last = json.loads(line)

acc_by_risk = (last or {}).get("acc_by_risk", {})
if not acc_by_risk:
    # NOTE: evaluate() as written returns only pairwise_accuracy and ece, so
    # this key is absent until per-risk accuracy is added to the logged metrics.
    print("No 'acc_by_risk' in the latest validation record; nothing to plot.")
else:
    labels = list(acc_by_risk.keys())
    vals = [acc_by_risk[k] for k in labels]

    plt.figure(figsize=(6,4))
    plt.bar(labels, vals)
    plt.title(f"Accuracy by risk (epoch {last['epoch']})")
    plt.ylabel("Accuracy")
    plt.ylim(0,1)
    plt.show()

## histogram of validation margins Δ = r⁺−r⁻ (latest epoch)

In [None]:
import json, glob, matplotlib.pyplot as plt

# Epoch numbers are zero-padded (deltas_epochNN.json), so a lexicographic
# sort puts the latest epoch last.
delta_files = sorted(glob.glob(str(VAL_DELTAS_DIR / "deltas_epoch*.json")))
if not delta_files:
    print(f"No validation delta files found in {VAL_DELTAS_DIR}; run training first.")
else:
    with open(delta_files[-1], "r", encoding="utf-8") as f:
        rec = json.load(f)
    deltas = rec["deltas"]

    plt.figure(figsize=(7,4))
    plt.hist(deltas, bins=50)
    plt.title(f"Validation Δ = r_pos - r_neg (epoch {rec['epoch']})")
    plt.xlabel("Δ margin")
    plt.ylabel("Count")
    plt.show()
