In this Colab notebook, you'll do on-policy distillation (OPD) on Qwen3-0.6b using Qwen3-4b-Instruct-2507, to make it better at [GSM8K](https://huggingface.co/datasets/openai/gsm8k) (a dataset of math problems).

Unlike standard supervised fine-tuning (SFT), the student model (Qwen-3-0.6b) learns from its own generated outputs rather than fixed gold data — reducing exposure bias and better matching the inference-time distribution.

You'll need to connect an A100 GPU (40 GB Ram) or better. You might be able to get away with smaller GPUs if you change some of the config parameters, like samples_per_prompt and max_new_tokens!

Inspired by [Thinking Machines](https://thinkingmachines.ai/blog/on-policy-distillation/) and prior art like [Agarwal et al](https://arxiv.org/abs/2306.13649).

In [1]:
#@title 🛠️ Setup
!nvidia-smi -L || true

import os, sys, random, numpy as np, torch, json, time, platform, math
print("Python:", sys.version)
print("CUDA available:", torch.cuda.is_available())

# Qwen3 requires Transformers >= 4.51
try:
    get_ipython().run_line_magic("uv", "pip -q install transformers==4.51.3 accelerate==1.4.0 peft==0.14.0 datasets==3.3.2 evaluate==0.4.3 sentencepiece protobuf tqdm matplotlib > /dev/null")
except Exception:
    get_ipython().run_line_magic("pip", "-q install transformers==4.51.3 accelerate==1.4.0 peft==0.14.0 datasets==3.3.2 evaluate==0.4.3 sentencepiece protobuf tqdm matplotlib > /dev/null")

import transformers, datasets, peft, accelerate, matplotlib
print("Transformers:", transformers.__version__)
print("Accelerate:", accelerate.__version__)
print("PEFT:", peft.__version__)
print("matplotlib:", matplotlib.__version__)

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "Please connect a GPU (A100+ recommended)."

def print_header():
    print("== Environment ==")
    print(dict(
        python=sys.version,
        torch=torch.__version__,
        transformers=transformers.__version__,
        accelerate=accelerate.__version__,
        peft=peft.__version__,
        cuda=torch.version.cuda if torch.cuda.is_available() else "cpu",
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
        platform=platform.platform(),
        seed=SEED
    ))
print_header()


GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-84373428-e84d-9a3b-e6a4-2771d4c1a0d0)
Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
CUDA available: True
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.12.0 which is incompatible.[0m[31m
[0mTransformers: 4.51.3
Accelerate: 1.4.0
PEFT: 0.14.0
matplotlib: 3.10.0
== Environment ==
{'python': '3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]', 'torch': '2.8.0+cu126', 'transformers': '4.51.3', 'accelerate': '1.4.0', 'peft': '0.14.0', 'cuda': '12.6', 'device_name': 'NVIDIA A100-SXM4-40GB', 'platform': 'Linux-6.6.105+-x86_64-with-glibc2.35', 'seed': 42}


In [2]:
import os, sys, time, json, random, platform
from dataclasses import dataclass
from typing import Optional, List, Dict

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

import pandas as pd
from IPython.display import display

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import re

# --------------------------
# Reproducibility & device
# --------------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "A CUDA GPU is required."


Try toying around with these settings. You can also try using a different teacher model, even if it doesn't use the same tokenizer.

In [3]:
# --------------------------
# Config
# --------------------------
@dataclass
class Config:
    # Models
    student_id: str = "Qwen/Qwen3-0.6B-Base"
    teacher_id: str = "Qwen/Qwen3-4B-Instruct-2507"

    # Prompting
    prompt_template: str = (
        "Solve step by step.\n"
        "Give ONLY ONE final numeric answer (no units), inside square brackets.\n"
        "Problem: {question}\n\nSolution:"
    )
    max_new_tokens: int = 256

    # Generation temps
    eval_temperature: float = 0.0   # greedy for eval
    train_temperature: float = 0.7  # sampling for on-policy data

    # Training schedule
    steps: int = 50
    batch_prompts: int = 4
    samples_per_prompt: int = 4 # increase this if you have a H200/B200
    lr: float = 1e-4
    weight_decay: float = 0.0
    grad_accum: int = 1

    # Micro-batching
    student_mb: int = 8

    # Monitoring
    log_every: int = 10
    val_every: int = 10
    val_sample_n: int = 100
    ema_momentum: float = 0.9

    # Validation size
    val_rows: Optional[int] = None  # if None, uses min(200, len(train))

    # Output dir
    run_root: str = f"./run_opd_{int(time.time())}"

cfg = Config()
os.makedirs(cfg.run_root, exist_ok=True)

def print_env():
    import transformers, accelerate, peft, matplotlib
    print("== Environment ==")
    print({
        "python": sys.version,
        "torch": torch.__version__,
        "transformers": transformers.__version__,
        "accelerate": accelerate.__version__,
        "peft": peft.__version__,
        "cuda": torch.version.cuda if torch.cuda.is_available() else "cpu",
        "device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
        "platform": platform.platform(),
        "seed": SEED
    })
print_env()

# --------------------------
# Data: GSM8K
# --------------------------
def render_prompt(question: str) -> str:
    return cfg.prompt_template.format(question=question)

def parse_gold(answer_text: str) -> Optional[str]:
    m = re.search(r"####\s*(-?\d+(?:\.\d+)?)", answer_text)
    if m: return m.group(1).strip()
    nums = re.findall(r"-?\d+(?:\.\d+)?", answer_text)
    return nums[-1].strip() if nums else None

def parse_pred(text: str) -> Optional[str]:
    m = re.search(r"\[\s*(-?\d+(?:\.\d+)?)\s*\]", text)
    if m: return m.group(1).strip()
    nums = re.findall(r"-?\d+(?:\.\d+)?", text)
    return nums[-1].strip() if nums else None

print("Loading GSM8K…")
ds_train_full = load_dataset("openai/gsm8k", "main", split="train")
ds_test       = load_dataset("openai/gsm8k", "main", split="test")

if cfg.val_rows is None:
    val_rows = min(200, len(ds_train_full))
else:
    val_rows = min(cfg.val_rows, len(ds_train_full))

ds_val   = ds_train_full.select(range(val_rows))
ds_train = ds_train_full.select(range(val_rows, len(ds_train_full)))

print(f"Splits: {len(ds_train)} train | {len(ds_val)} val | {len(ds_test)} test")

# --------------------------
# Tokenizers & Models
# --------------------------
def load_tokenizers(student_id: str, teacher_id: str):
    tok_s = AutoTokenizer.from_pretrained(student_id, use_fast=True)
    tok_t = AutoTokenizer.from_pretrained(teacher_id, use_fast=True)
    for tok in (tok_s, tok_t):
        if tok.pad_token is None and tok.eos_token is not None:
            tok.pad_token = tok.eos_token
        tok.padding_side = "left"  # decoder-only: left padding
    return tok_s, tok_t

def make_lora_student(model_id: str) -> torch.nn.Module:
    base = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    base.config.use_cache = False  # off for training
    lora_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    )
    return get_peft_model(base, lora_cfg)

def load_teacher(model_id: str) -> torch.nn.Module:
    m = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()
    for p in m.parameters():
        p.requires_grad_(False)
    return m

tok_s, tok_t = load_tokenizers(cfg.student_id, cfg.teacher_id)
print("Padding sides:", tok_s.padding_side, tok_t.padding_side)

== Environment ==
{'python': '3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]', 'torch': '2.8.0+cu126', 'transformers': '4.51.3', 'accelerate': '1.4.0', 'peft': '0.14.0', 'cuda': '12.6', 'device_name': 'NVIDIA A100-SXM4-40GB', 'platform': 'Linux-6.6.105+-x86_64-with-glibc2.35', 'seed': 42}
Loading GSM8K…


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Splits: 7273 train | 200 val | 1319 test


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Padding sides: left left


Let's establish baseline scores on the test set of GSM8K, so we know if we're able to improve the student or not.

EM is Exact Match (accuracy score) on the val/test sets of GSM8K. Note that OPD doesn't use answer accuracy to inform weight updates during training! But it's still relevant for us to know since we do care about the accuracy.

In [4]:

# --------------------------
# Evaluation utils
# --------------------------
def _encode_cuda(tokenizer, texts: List[str], max_length=2048) -> Dict[str, torch.Tensor]:
    enc = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    return {k: v.to("cuda") for k, v in enc.items()}

@torch.no_grad()
def evaluate(model, tokenizer, dataset, *, num_examples: Optional[int] = None,
             temperature: float = 0.0, max_new_tokens: int = 256,
             batch_size: int = 32, desc: str = "Eval") -> float:
    """Exact-match accuracy against bracketed numeric answer."""
    n = len(dataset) if num_examples is None else min(num_examples, len(dataset))
    rows = dataset.select(range(n))
    correct = 0

    was_cache = getattr(model.config, "use_cache", True)
    model.eval(); model.config.use_cache = True

    for i in tqdm(range(0, n, batch_size), desc=desc):
        batch = rows.select(range(i, min(i + batch_size, n)))
        prompts = [render_prompt(ex["question"]) for ex in batch]
        enc = _encode_cuda(tokenizer, prompts)

        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True
        )
        if temperature and temperature > 0.0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        else:
            gen_kwargs.update(do_sample=False)

        outs = model.generate(**enc, **gen_kwargs)
        txts = tokenizer.batch_decode(outs, skip_special_tokens=True)

        for ex, text in zip(batch, txts):
            pred = parse_pred(text) or ""
            gold = parse_gold(ex["answer"]) or ""
            correct += int(pred == gold)

    model.config.use_cache = was_cache
    return correct / max(n, 1)

@torch.no_grad()
def preview(model, tokenizer, dataset, k=2, temperature=0.0, max_new_tokens=256):
    model.eval(); model.config.use_cache = True
    rows = dataset.select(range(min(k, len(dataset))))
    for ex in rows:
        prompt = render_prompt(ex["question"])
        enc = tokenizer([prompt], return_tensors="pt").to(model.device)
        gen_kwargs = dict(max_new_tokens=max_new_tokens, use_cache=True,
                          pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
        if temperature and temperature > 0.0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        out = model.generate(**enc, **gen_kwargs)
        text = tokenizer.decode(out[0], skip_special_tokens=True)
        print("="*80); print(prompt); print("-"*80); print(text)
        print("-"*80, f"\nParsed: [{parse_pred(text)}] | Gold: [{parse_gold(ex['answer'])}]")

# --------------------------
# Baselines (before training)
# --------------------------
print("\n== Loading models for baseline evals ==")
student_for_eval = make_lora_student(cfg.student_id)
teacher = load_teacher(cfg.teacher_id)

print("\nPreview (student, greedy)…")
preview(student_for_eval, tok_s, ds_test, k=1, temperature=0.0, max_new_tokens=cfg.max_new_tokens)

print("\nComputing baselines (greedy)…")
baseline_student_em = evaluate(student_for_eval, tok_s, ds_test,
                               temperature=cfg.eval_temperature, max_new_tokens=cfg.max_new_tokens,
                               batch_size=128, desc="Student baseline (test)")
baseline_teacher_em = evaluate(teacher, tok_t, ds_test,
                               temperature=cfg.eval_temperature, max_new_tokens=cfg.max_new_tokens,
                               batch_size=64, desc="Teacher baseline (test)")
print(f"Student (0.6B base) EM: {baseline_student_em:.4f}")
print(f"Teacher (4B instruct) EM: {baseline_teacher_em:.4f}")

with open(os.path.join(cfg.run_root, "baselines.json"), "w") as f:
    json.dump(dict(student_0p6b=baseline_student_em, teacher_4b=baseline_teacher_em), f, indent=2)

# Free the temporary student used just for baselines
del student_for_eval; torch.cuda.empty_cache()


== Loading models for baseline evals ==


config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/238 [00:00<?, ?B/s]


Preview (student, greedy)…
Solve step by step.
Give ONLY ONE final numeric answer (no units), inside square brackets.
Problem: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Solution:
--------------------------------------------------------------------------------
Solve step by step.
Give ONLY ONE final numeric answer (no units), inside square brackets.
Problem: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Solution: 16 - 3 - 4 = 9 ducks are left to sell. 9 * 2 = $18.
-------------------------------------------------------

Student baseline (test):   0%|          | 0/11 [00:00<?, ?it/s]

Teacher baseline (test):   0%|          | 0/21 [00:00<?, ?it/s]



Student (0.6B base) EM: 0.3867
Teacher (4B instruct) EM: 0.8309


Now let's actually do the OPD training!

As we train, we'll output the following metrics every 10 steps.

| Column        | Meaning                                                                                                                                                                                                         | How to interpret                                                                      |
| :------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------ |
| **step**      | The current training iteration (out of the total configured `cfg.steps`, e.g. 100). Each step processes one batch of sampled prompts.                                                                           | Training progresses along this axis.                                                  |
| **loss**      | The instantaneous batch loss (reverse-KL–style policy-gradient objective). Negative values are expected because we optimize `-E[(log p_T − log p_S) · log p_S]`; good updates drive this lower (more negative). | Lower (more negative) → better alignment with teacher on that batch.                  |
| **loss_ema**  | Exponential moving average (EMA) of the loss across steps (with momentum = `cfg.ema_momentum`, e.g. 0.9). Smooths noise to show trend.                                                                          | Downward trend = overall improvement; steadiness = convergence.                       |
| **revkl**     | The *reverse-KL divergence estimate* for this batch: roughly E[ log p_S − log p_T ]. It measures how far the student’s token distribution is from the teacher’s on its own rollouts.                            | Smaller (approaching 0) → student policy is closer to teacher.                        |
| **revkl_ema** | EMA-smoothed version of reverse-KL, again for trend stability.                                                                                                                                                  | Should decrease steadily if distillation is working.                                  |
| **tokens**    | Cumulative count of *valid graded tokens* seen so far — all pre-EOS generated tokens used for loss computation. It grows as training proceeds.                                                                  | Measures total learning signal processed; roughly proportional to compute/throughput. |
| **val_em**    | Validation exact-match accuracy (fraction of GSM8K val examples where the parsed numeric answer matches gold). Evaluated every `cfg.val_every` steps on `cfg.val_sample_n` examples.                            | Direct measure of task performance; higher = better reasoning accuracy.               |


In [6]:

# --------------------------
# OPD utilities
# --------------------------
def mask_before_first_eos(next_ids: torch.Tensor, eos_id: int) -> torch.Tensor:
    """Mask tokens strictly before the first EOS in each sequence."""
    is_eos = (next_ids == eos_id)
    csum = is_eos.cumsum(dim=1)
    return csum.eq(0)

def student_logp_batched(student, pad_id, full_ids, next_ids, T, micro_bsz):
    """Return student log p(a_t) for the last T tokens (with grad)."""
    outs = []
    for s in range(0, full_ids.size(0), micro_bsz):
        chunk = full_ids[s:s+micro_bsz]; nxt = next_ids[s:s+micro_bsz]
        out = student(input_ids=chunk[:, :-1],
                      attention_mask=(chunk[:, :-1] != pad_id))
        logits = out.logits[:, -T:, :]
        logp = F.log_softmax(logits, dim=-1).gather(-1, nxt.unsqueeze(-1)).squeeze(-1)
        outs.append(logp)
        del out, logits
    return torch.cat(outs, dim=0)

# ---------- Cross-tokenizer teacher scoring ----------
def _decode_token_str(tokenizer, token_id: int) -> str:
    # Decode one token to text exactly as-is (keep spaces/prefixes).
    return tokenizer.decode([int(token_id)],
                            skip_special_tokens=False,
                            clean_up_tokenization_spaces=False)

def _encode_text_ids(tokenizer, text: str):
    return tokenizer(text,
                     add_special_tokens=False,
                     return_tensors="pt").input_ids[0].tolist()

@torch.no_grad()
def teacher_logp_grouped_by_student_tokens(
    teacher, tok_teacher, tok_student, prompts: List[str], next_ids: torch.Tensor, max_len: Optional[int] = None
):
    """
    For each sample b and student step t:
      1) Decode the student's token id next_ids[b,t] → text piece
      2) Tokenize that text with the teacher tokenizer (may become multiple tokens)
      3) Sum teacher log-probs over that group
    Returns: Tensor [B, T] on CUDA with per-student-step teacher log-probs.
    """
    device = teacher.device if hasattr(teacher, "device") else "cuda"
    B, T = next_ids.shape
    out = torch.zeros((B, T), device=device, dtype=torch.float32)

    if max_len is None:
        max_len = int(getattr(teacher.config, "max_position_embeddings", 2048))

    for b in range(B):
        prompt_text = prompts[b]
        ctx_ids = _encode_text_ids(tok_teacher, prompt_text)

        groups = []
        for t in range(T):
            s_tok_id = int(next_ids[b, t].item())
            piece = _decode_token_str(tok_student, s_tok_id)
            ids_t = _encode_text_ids(tok_teacher, piece)
            groups.append(ids_t)

        flat_gen = [tid for g in groups for tid in g]
        if len(flat_gen) == 0:
            continue

        # Respect the teacher's context length by trimming left context
        total = len(ctx_ids) + len(flat_gen)
        if total > max_len:
            overflow = total - max_len
            ctx_ids = ctx_ids[overflow:]

        # Build teacher-forcing inputs; labels are shifted by one token
        full = ctx_ids + flat_gen
        if len(full) < 2:
            continue

        input_ids = torch.tensor(full[:-1], device=device).unsqueeze(0)
        labels    = torch.tensor(full[1:],  device=device).unsqueeze(0)
        attn_mask = torch.ones_like(input_ids, device=device)

        outputs = teacher(input_ids=input_ids, attention_mask=attn_mask)
        logprobs = F.log_softmax(outputs.logits, dim=-1)
        tok_lp = logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)[0]  # [L]

        # Labels index full[1:], so generated part starts at k = len(ctx_ids)-1
        start = max(len(ctx_ids) - 1, 0)
        gen_lp = tok_lp[start : start + len(flat_gen)]

        # Sum back per student step
        off = 0
        for t, g in enumerate(groups):
            k = len(g)
            if k > 0:
                out[b, t] = gen_lp[off : off + k].sum()
            off += k

    return out

class LiveTable:
    def __init__(self, title: str = "Training metrics", max_rows: int = 200):
        self.title = title
        self.max_rows = max_rows
        self.rows = []
        empty = pd.DataFrame(columns=["step","loss","loss_ema","revkl","revkl_ema","tokens","val_em"])
        self.handle = display(self._styled(empty), display_id=True)

    def _styled(self, df: pd.DataFrame):
        styler = (
            df.style
              .set_caption(self.title)
              .format({
                  "loss": "{:.4f}",
                  "loss_ema": "{:.4f}",
                  "revkl": "{:.4f}",
                  "revkl_ema": "{:.4f}",
                  "val_em": (lambda v: "" if pd.isna(v) else f"{v:.3f}"),
              })
        )
        try:
            styler = styler.hide(axis="index")
            return styler
        except Exception:
            pass
        return styler.set_table_styles([
            {"selector": "th.row_heading", "props": [("display", "none")]},
            {"selector": "th.blank",       "props": [("display", "none")]},
        ])

    def update(self, *, step, loss, loss_ema, revkl, revkl_ema, tokens, val_em=None):
        self.rows.append(dict(
            step=int(step),
            loss=float(loss),
            loss_ema=(None if loss_ema is None else float(loss_ema)),
            revkl=float(revkl),
            revkl_ema=(None if revkl_ema is None else float(revkl_ema)),
            tokens=int(tokens),
            val_em=(None if val_em is None else float(val_em)),
        ))
        rows = self.rows[-self.max_rows:]
        df = pd.DataFrame(rows, columns=["step","loss","loss_ema","revkl","revkl_ema","tokens","val_em"])
        self.handle.update(self._styled(df))

# --------------------------
# Training loop (verbose)
# --------------------------
def ema(prev, new, beta):
    return new if prev is None else (beta * prev + (1 - beta) * new)

def run_training(run_dir: str):
    os.makedirs(run_dir, exist_ok=True)

    # Fresh student (LoRA adapters)
    student = make_lora_student(cfg.student_id)
    optimizer = torch.optim.AdamW(student.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    prompts_all = [render_prompt(x["question"]) for x in ds_train]
    ema_loss = None
    ema_revkl = None
    tokens_graded_cum = 0
    logs = []

    table = LiveTable(title="On-Policy Distillation")

    pbar = tqdm(range(cfg.steps), desc=f"OPD [{os.path.basename(run_dir)}]")
    for step in pbar:
        # Deterministic batch selection per step
        rng = np.random.default_rng(SEED + step)
        idxs = rng.choice(len(prompts_all), size=cfg.batch_prompts, replace=False)
        prompts = [prompts_all[i] for i in idxs]
        prompts_rep = sum(([p] * cfg.samples_per_prompt for p in prompts), [])
        enc = tok_s(prompts_rep, padding=True, truncation=True, max_length=2048, return_tensors="pt").to("cuda")

        # 1) Student rollouts (no grad) to get sequences and step count
        with torch.no_grad():
            gen_out = student.generate(
                **enc,
                do_sample=True, temperature=cfg.train_temperature, top_p=0.9,
                max_new_tokens=cfg.max_new_tokens,
                eos_token_id=tok_s.eos_token_id, pad_token_id=tok_s.pad_token_id,
                return_dict_in_generate=True, output_scores=True
            )
            seqs = gen_out.sequences
            scores_list = list(gen_out.scores)  # per-step logits if you want to inspect
            T = len(scores_list)
            next_ids = seqs[:, -T:]
            valid_mask = mask_before_first_eos(next_ids, eos_id=tok_s.eos_token_id).float()

        # 2) Student log-probs with grad
        student.train(); student.config.use_cache = False
        logp_s = student_logp_batched(
            student, tok_s.pad_token_id, seqs, next_ids, T, cfg.student_mb
        )

        # 3) Teacher log-probs (no grad), robust to tokenizer mismatches
        teacher.eval()
        logp_t = teacher_logp_grouped_by_student_tokens(
            teacher=teacher,
            tok_teacher=tok_t,
            tok_student=tok_s,
            prompts=prompts_rep,                  # the prompts used for these rollouts
            next_ids=next_ids,                    # [B, T] student IDs
            max_len=getattr(teacher.config, "max_position_embeddings", 2048),
        )

        # 4) Reverse-KL-style policy gradient
        adv = (logp_t - logp_s).clamp(-5, 5)  # detach below for stability
        denom = valid_mask.sum().clamp_min(1.0)
        loss = - ((valid_mask * adv.detach()) * logp_s).sum() / denom

        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        with torch.no_grad():
            rev_kl = ((logp_s - logp_t) * valid_mask).sum().item() / float(denom)
            tokens_graded_cum += int(denom.item())
            ema_loss = ema(ema_loss, float(loss.item()), cfg.ema_momentum)
            ema_revkl = ema(ema_revkl, float(rev_kl), cfg.ema_momentum)

        # Periodic validation EM (greedy)
        val_em = None
        if (step % cfg.val_every == 0) or (step == cfg.steps - 1):
            student.eval(); student.config.use_cache = True
            val_em = evaluate(student, tok_s, ds_val,
                              num_examples=min(cfg.val_sample_n, len(ds_val)),
                              temperature=0.0, max_new_tokens=cfg.max_new_tokens,
                              batch_size=32, desc="VAL EM")
            student.train(); student.config.use_cache = False

        row = dict(
            step=int(step),
            train_loss=float(loss.item()),
            train_loss_ema=float(ema_loss) if ema_loss is not None else None,
            train_revkl=float(rev_kl),
            train_revkl_ema=float(ema_revkl) if ema_revkl is not None else None,
            tokens_graded=int(tokens_graded_cum),
            **({"val_em": float(val_em)} if val_em is not None else {})
        )
        logs.append(row)

        if (step % cfg.log_every == 0) or (val_em is not None):
            table.update(
                step=row["step"],
                loss=row["train_loss"],
                loss_ema=row["train_loss_ema"],
                revkl=row["train_revkl"],
                revkl_ema=row["train_revkl_ema"],
                tokens=row["tokens_graded"],
                val_em=row.get("val_em", None),
            )

        postfix = {
            "loss": f"{loss.item():.3f}",
            "ema": f"{(ema_loss if ema_loss is not None else loss.item()):.3f}",
            "rkl": f"{rev_kl:.3f}",
            "toks": tokens_graded_cum
        }
        if val_em is not None:
            postfix["val"] = f"{val_em:.3f}"
        pbar.set_postfix(**postfix)

        del scores_list
        torch.cuda.empty_cache()

    # Save logs
    try:
        pd.DataFrame(logs).to_csv(os.path.join(run_dir, "train_logs.csv"), index=False)
    except Exception:
        with open(os.path.join(run_dir, "train_logs.jsonl"), "w") as f:
            for r in logs:
                f.write(json.dumps(r) + "\n")

    # Final test EM (greedy)
    student.eval(); student.config.use_cache = True
    test_em = evaluate(student, tok_s, ds_test,
                       temperature=0.0, max_new_tokens=cfg.max_new_tokens,
                       batch_size=64, desc=f"Test EM [{os.path.basename(run_dir)}]")

    # Save adapters and summary
    save_dir = os.path.join(run_dir, "adapters_lora")
    os.makedirs(save_dir, exist_ok=True)
    student.save_pretrained(save_dir)

    summary = dict(
        steps=cfg.steps,
        batch_prompts=cfg.batch_prompts,
        samples_per_prompt=cfg.samples_per_prompt,
        max_new_tokens=cfg.max_new_tokens,
        train_tokens_graded=tokens_graded_cum,
        test_em=float(test_em)
    )
    with open(os.path.join(run_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    # Free GPU
    del student; torch.cuda.empty_cache()
    return summary, logs

# --------------------------
# Run training (single run)
# --------------------------
print("\n== Training (OPD) ==")
run_dir = os.path.join(cfg.run_root, "opd")
summary, _ = run_training(run_dir)

print("\n== Final Results ==")
print(json.dumps(dict(
    env=dict(
        python=sys.version,
        torch=torch.__version__,
        cuda=torch.version.cuda if torch.cuda.is_available() else "cpu",
        device=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
    ),
    prompt_template=cfg.prompt_template,
    baselines=dict(student_0p6b=baseline_student_em, teacher_4b=baseline_teacher_em),
    final=summary,
), indent=2))
print("Artifacts saved to:", cfg.run_root)


== Training (OPD) ==


step,loss,loss_ema,revkl,revkl_ema,tokens,val_em
0,-0.3254,-0.3254,0.3572,0.3572,1748,0.44
10,-0.2809,-0.252,0.2494,0.2808,30496,0.64
20,-0.4044,-0.281,0.2459,0.227,59299,0.74
30,-0.3436,-0.3018,0.2464,0.2223,87205,0.69
40,-0.3429,-0.2922,0.2059,0.195,115343,0.77
49,-0.2595,-0.26,0.122,0.1631,141744,0.73


OPD [opd]:   0%|          | 0/50 [00:00<?, ?it/s]

VAL EM:   0%|          | 0/4 [00:00<?, ?it/s]

VAL EM:   0%|          | 0/4 [00:00<?, ?it/s]

VAL EM:   0%|          | 0/4 [00:00<?, ?it/s]

VAL EM:   0%|          | 0/4 [00:00<?, ?it/s]

VAL EM:   0%|          | 0/4 [00:00<?, ?it/s]

VAL EM:   0%|          | 0/4 [00:00<?, ?it/s]

Test EM [opd]:   0%|          | 0/21 [00:00<?, ?it/s]


== Final Results ==
{
  "env": {
    "python": "3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]",
    "torch": "2.8.0+cu126",
    "cuda": "12.6",
    "device": "NVIDIA A100-SXM4-40GB"
  },
  "prompt_template": "Solve step by step.\nGive ONLY ONE final numeric answer (no units), inside square brackets.\nProblem: {question}\n\nSolution:",
  "baselines": {
    "student_0p6b": 0.3866565579984837,
    "teacher_4b": 0.8309325246398787
  },
  "final": {
    "steps": 50,
    "batch_prompts": 4,
    "samples_per_prompt": 4,
    "max_new_tokens": 256,
    "train_tokens_graded": 141744,
    "test_em": 0.6186504927975739
  }
}
Artifacts saved to: ./run_opd_1761781081
