# Debugging Dynamic Beta DPO (self-contained)

This notebook **flattens** the repo’s Dynamic-Beta DPO pipeline into a **single, linear workflow** so you can execute line-by-line and pinpoint where the logic breaks.

**Rules enforced here**
- No imports from `src/...` (no relative imports). All repo logic is copied into cells.
- Each code cell does *one* logical thing and ends with a **Sanity Check** print.

**Practical tips**
- If you don’t have access to `meta-llama/Llama-3.2-1B-Instruct`, set `USE_TINY_MODEL = True`.
- If you’re offline, the dataset cell falls back to a tiny synthetic HH-like dataset.


In [21]:
# 0) Credentials (optional) — DO NOT hardcode secrets in notebooks
#
# If you need Hugging Face / Weights & Biases auth, set env vars *outside* the notebook
# (e.g., in your shell before launching Jupyter), so you don't accidentally save tokens.
#
#   export HF_TOKEN=...             # Hugging Face
#   export WANDB_API_KEY=...        # Weights & Biases
#
# This cell only checks whether they're present.

import os

print('Sanity Check: HF_TOKEN set? ->', bool(os.environ.get('HF_TOKEN')))
print('Sanity Check: WANDB_API_KEY set? ->', bool(os.environ.get('WANDB_API_KEY')))


In [1]:
# 1) Imports + environment versions

import os
import sys
import json
import math
import random
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import torch
import torch.nn.functional as F

import datasets as hf_datasets
from datasets import Dataset, load_dataset

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import accelerate
import trl
from trl import DPOConfig, DPOTrainer

print("Sanity Check: versions")
print("  python:", f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
print("  torch:", torch.__version__)
print("  transformers:", transformers.__version__)
print("  trl:", trl.__version__)
print("  datasets:", hf_datasets.__version__)
print("  accelerate:", accelerate.__version__)
print("  cuda_available:", torch.cuda.is_available(), "cuda_devices:", torch.cuda.device_count())


  from .autonotebook import tqdm as notebook_tqdm


Sanity Check: versions
  python: 3.11.13
  torch: 2.9.1+cu128
  transformers: 4.57.3
  trl: 0.26.2
  datasets: 4.4.2
  accelerate: 1.12.0
  cuda_available: True cuda_devices: 1


In [None]:
# 1.1) Debug helpers (assertions + tensor summaries)

from typing import Tuple


def summarize_tensor(name: str, t: torch.Tensor) -> None:
    t_detached = t.detach()
    stats = {
        'shape': tuple(t_detached.shape),
        'dtype': str(t_detached.dtype),
        'device': str(t_detached.device),
        'min': float(t_detached.min().float().cpu()) if t_detached.numel() else None,
        'max': float(t_detached.max().float().cpu()) if t_detached.numel() else None,
        'mean': float(t_detached.float().mean().cpu()) if t_detached.numel() else None,
        'std': float(t_detached.float().std(unbiased=False).cpu()) if t_detached.numel() else None,
        'requires_grad': bool(t.requires_grad),
    }
    print(f'{name}:', stats)


def assert_all_finite(name: str, t: torch.Tensor) -> None:
    ok = torch.isfinite(t.detach()).all().item() if t.numel() else True
    assert ok, f'{name} contains NaN/Inf'


def assert_close(name: str, a: torch.Tensor, b: torch.Tensor, *, atol: float = 1e-5, rtol: float = 1e-5) -> None:
    diff = (a - b).detach().abs()
    denom = (b.detach().abs() + 1e-12)
    rel = diff / denom
    max_abs = float(diff.max().cpu()) if diff.numel() else 0.0
    max_rel = float(rel.max().cpu()) if rel.numel() else 0.0
    assert (diff <= atol).all().item() or (rel <= rtol).all().item(), (
        f'{name} not close: max_abs={max_abs:.3e}, max_rel={max_rel:.3e} (atol={atol}, rtol={rtol})'
    )


print('Sanity Check: debug helpers ready')


In [2]:
# 2) Determinism / seed control

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print("Sanity Check: seed set ->", SEED)


Sanity Check: seed set -> 42


In [9]:
# 3) Hardcoded config (edit here)

# Reduce these for fast iteration; scale them up once the logic is correct.
USE_TINY_MODEL = False
FALLBACK_MODEL = "sshleifer/tiny-gpt2"
RUN_FULL_TRAIN = False

# Avoid tokenizer parallelism warnings in notebooks.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

config: Dict[str, Any] = {
    "policy_name": "meta-llama/Llama-3.2-1B-Instruct",
    "ref_name": "meta-llama/Llama-3.2-1B-Instruct",
    "precision": "bf16",  # one of: fp16, bf16, fp32
    "dataset": {
        "dataset_name": "Anthropic/hh-rlhf",
        "generated_data": False,
        "chat_template": True,
        "subset": "train[:1%]",  # small slice for debugging
        "val_ratio": 0.1,
        "seed": 42,
        # TRL truncation settings
        "max_prompt_length": 256,
        "max_completion_length": 256,
        "max_length": 512,
        "truncation_mode": "keep_end",
    },
    "dpo_training": {
        "epochs": 1,
        "batch_size": 2,
        "eval_batch_size": 2,
        "learning_rate": 5e-7,
        "log_steps": 1,
        "eval_steps": 25,
        "save_steps": 25,
        "gradient_accumulation": 1,
        "max_grad_norm": 10,
        "warmup_steps": 0,
        "run_name": "debug-dynamic-beta",
        "save_dir": "trl_dynamic_beta_dpo_debug",
        "report": None,  # set to "wandb" / project name if desired
    },
    "risk_test": {
        "delta": 0.1,
        "lambda": 0.1,
        "beta_warmup": 5,  # keep small for debugging; repo default is 120
    },
    "beta_update": {
        "beta_0": 0.1,
        "gamma": 2.0,
        "alpha": 0.005,
        "beta_max": 2.0,
        "beta_min": 0.0,
    },
    "margin_log": {
        "jsonl_sample_size": 32,
        "save_per_rank": False,
        "log_dir": "logs/margins_debug",
    },
}

print("Sanity Check: config loaded")
print("  policy_name:", config["policy_name"])
print("  dataset:", config["dataset"]["dataset_name"], config["dataset"]["subset"], "chat_template=", config["dataset"]["chat_template"])
print("  warmup_steps:", config["risk_test"]["beta_warmup"], "delta=", config["risk_test"]["delta"]) 


Sanity Check: config loaded
  policy_name: meta-llama/Llama-3.2-1B-Instruct
  dataset: Anthropic/hh-rlhf train[:1%] chat_template= True
  warmup_steps: 5 delta= 0.1


In [5]:
# 4) (Copied) src/data/templates.py  — chat templates + HH parsing

import re

TAG_RE = re.compile(r"\n\n(Human|Assistant): ?")

# Llama 3 chat template
LLAMA3_CHAT_TEMPLATE = (
    "{% set loop_messages = messages %}"
    "{% for message in loop_messages %}"
    "{% set content = message['content'] %}"
    "{% if loop.index0 == 0 %}"
    "{{ '<|begin_of_text|>' }}"
    "{% endif %}"
    "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + content | trim + '<|eot_id|>' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
    "{% endif %}"
)


def strip_one_leading_newline(text: str) -> str:
    """Remove a single leading newline to normalize HH blocks."""
    return text[1:] if text.startswith("\n") else text


def parse_hh_to_messages(text: str) -> List[Dict[str, str]]:
    """Parse Anthropic HH multi-turn text into [{role, content}, ...]."""
    text = str(text).replace("\r\n", "\n").replace("\r", "\n")
    if not text.startswith("\n\nHuman:") and not text.startswith("\n\nAssistant:"):
        text = "\n\n" + text

    parts = TAG_RE.split(text)
    messages: List[Dict[str, str]] = []
    for i in range(1, len(parts), 2):
        role_tag = parts[i]
        content = parts[i + 1] if i + 1 < len(parts) else ""
        content = strip_one_leading_newline(content).strip()
        if not content:
            continue
        role = "user" if role_tag == "Human" else "assistant"
        messages.append({"role": role, "content": content})
    return messages


# Sanity check
_example = "\n\nHuman: Hello\n\nAssistant: Hi!\n\nHuman: What's 2+2?\n\nAssistant: 4"
print("Sanity Check: parse_hh_to_messages ->", parse_hh_to_messages(_example)[:4])


Sanity Check: parse_hh_to_messages -> [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi!'}, {'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}]


In [6]:
# 5) (Copied) src/data/hh_dataset.py — HH triplets builder

ASSISTANT_TAG = "\n\nAssistant:"
HUMAN_TAG = "\n\nHuman:"
LLAMA3_ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"


def strip_one_leading_newline(s: str) -> str:
    """Remove a single leading newline to normalize HH blocks."""
    return s[1:] if s.startswith("\n") else s


def split_prompt_and_response(input_text: str) -> tuple[str, str]:
    """Split HH format text into prompt and response (last Assistant tag)."""
    input_text = str(input_text).replace("\r\n", "\n").replace("\r", "\n")
    index = input_text.rfind(ASSISTANT_TAG)
    if index < 0:
        raise ValueError("No '\\n\\nAssistant:' tag found in HH input.")
    prompt = input_text[: index + len(ASSISTANT_TAG)]
    response = input_text[index + len(ASSISTANT_TAG) :]
    response = strip_one_leading_newline(response)
    return prompt, response


def convert_to_triples(chosen_text: str, rejected_text: str) -> Optional[Dict[str, str]]:
    """Convert one HH row into {prompt, chosen, rejected}."""
    chosen_prompt, chosen_response = split_prompt_and_response(chosen_text)

    if not rejected_text.startswith(chosen_prompt):
        return None

    rejected_response = strip_one_leading_newline(rejected_text[len(chosen_prompt) :])

    if len(chosen_prompt.strip()) == 0:
        return None
    if len(chosen_response.strip()) == 0 or len(rejected_response.strip()) == 0:
        return None

    return {"prompt": chosen_prompt, "chosen": chosen_response, "rejected": rejected_response}


def build_HH_dataset(ds) -> Dataset:
    """Process entire dataset into HH triplets format."""
    hh_ds_raw: List[Dict[str, str]] = []
    for _, row in enumerate(ds):
        output = convert_to_triples(chosen_text=row["chosen"], rejected_text=row["rejected"])
        if output is not None:
            hh_ds_raw.append(output)
    return Dataset.from_list(hh_ds_raw)


# Sanity check
_chosen = "\n\nHuman: hi\n\nAssistant: hello"
_rejected = "\n\nHuman: hi\n\nAssistant: nope"
print("Sanity Check: convert_to_triples ->", convert_to_triples(_chosen, _rejected))


Sanity Check: convert_to_triples -> {'prompt': '\n\nHuman: hi\n\nAssistant:', 'chosen': ' hello', 'rejected': ' nope'}


In [7]:
# 6) (Copied) src/data/hh_dataset.py — rollout loader + chat-template application


def _normalize_text(text: Any) -> str:
    return str(text).replace("\r\n", "\n").replace("\r", "\n")


def _coerce_messages(messages: Any) -> Optional[List[Dict[str, str]]]:
    if not isinstance(messages, list):
        return None
    cleaned: List[Dict[str, str]] = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role")
        if role not in ("user", "assistant"):
            continue
        content = _normalize_text(msg.get("content", "")).strip()
        if not content:
            continue
        cleaned.append({"role": role, "content": content})
    return cleaned if cleaned else None


def _messages_to_hh_prompt(messages: List[Dict[str, str]]) -> Optional[str]:
    if not messages or messages[-1]["role"] != "user":
        return None
    parts: List[str] = []
    for msg in messages:
        tag = HUMAN_TAG if msg["role"] == "user" else ASSISTANT_TAG
        parts.append(f"{tag} {msg['content']}")
    prompt = "".join(parts)
    if not prompt.endswith(ASSISTANT_TAG):
        prompt = f"{prompt}{ASSISTANT_TAG}"
    return prompt


def _extract_response_text(value: Any) -> Optional[str]:
    if isinstance(value, str):
        text = _normalize_text(value).strip()
        return text if text else None
    if isinstance(value, dict):
        content = _normalize_text(value.get("content", "")).strip()
        return content if content else None
    if isinstance(value, list):
        parts: List[str] = []
        for msg in value:
            if not isinstance(msg, dict):
                continue
            role = msg.get("role")
            if role is not None and role != "assistant":
                continue
            content = _normalize_text(msg.get("content", "")).strip()
            if content:
                parts.append(content)
        if parts:
            return "\n\n".join(parts)
    return None


def build_rollout_dataset(ds: Iterable[Dict[str, Any]]) -> Dataset:
    """Build dataset from rollout generation outputs."""
    rollout_ds_raw: List[Dict[str, str]] = []
    for row in ds:
        prompt_messages = _coerce_messages(row.get("prompt_messages"))
        if prompt_messages is None:
            continue
        prompt_text = _messages_to_hh_prompt(prompt_messages)
        if not prompt_text:
            continue
        chosen_text = _extract_response_text(row.get("chosen"))
        rejected_text = _extract_response_text(row.get("rejected"))
        if not chosen_text or not rejected_text:
            continue
        rollout_ds_raw.append({"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text})
    return Dataset.from_list(rollout_ds_raw)


def load_generated_hf_dataset(dataset_name: str, *, subset: str = "train") -> Dataset:
    """Load a generated dataset from HuggingFace."""
    raw_ds = load_dataset(dataset_name, split=subset)
    return build_rollout_dataset(raw_ds)


def load_generated_dataset_from_config(config: Dict[str, Any]) -> Dataset:
    """Load generated dataset using configuration dictionary."""
    dataset_cfg = config.get("dataset", {})
    dataset_name = dataset_cfg.get("dataset_name")
    if not dataset_name:
        raise ValueError("Missing dataset.dataset_name in config.")
    subset = dataset_cfg.get("subset", "train")
    return load_generated_hf_dataset(dataset_name, subset=subset)


def _ensure_chat_template(tokenizer: Any) -> None:
    if not getattr(tokenizer, "chat_template", None):
        tokenizer.chat_template = LLAMA3_CHAT_TEMPLATE


def _ensure_generation_prompt(prompt_text: str, tokenizer: Any) -> str:
    trimmed = prompt_text.rstrip()
    if trimmed.endswith(LLAMA3_ASSISTANT_HEADER.rstrip()):
        return prompt_text
    template = getattr(tokenizer, "chat_template", "") or ""
    if "<|start_header_id|>" in prompt_text or "start_header_id" in template:
        return f"{prompt_text}{LLAMA3_ASSISTANT_HEADER}"
    return prompt_text


def _render_response_with_chat_template(
    messages: List[Dict[str, str]],
    response: str,
    *,
    tokenizer: Any,
    prompt_text: str,
) -> Optional[str]:
    response = _normalize_text(response).strip()
    if not response:
        return None
    full_messages = messages + [{"role": "assistant", "content": response}]
    full_text = tokenizer.apply_chat_template(full_messages, tokenize=False, add_generation_prompt=False)
    if full_text.startswith(prompt_text):
        rendered = full_text[len(prompt_text) :]
    else:
        rendered = response
    rendered = rendered.strip()
    return rendered if rendered else None


def apply_chat_template_to_dataset(ds: Dataset, tokenizer: Any) -> Dataset:
    """Apply chat template to dataset prompts and responses."""
    _ensure_chat_template(tokenizer)
    rows: List[Dict[str, str]] = []
    for row in ds:
        prompt_text = _normalize_text(row.get("prompt", "")).strip()
        chosen_text = row.get("chosen", "")
        rejected_text = row.get("rejected", "")
        if not prompt_text:
            continue

        messages = parse_hh_to_messages(prompt_text)
        if not messages or messages[-1]["role"] != "user":
            continue

        prompt_rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        if not prompt_rendered:
            continue
        prompt_rendered = _ensure_generation_prompt(prompt_rendered, tokenizer)

        chosen_rendered = _render_response_with_chat_template(
            messages, str(chosen_text), tokenizer=tokenizer, prompt_text=prompt_rendered
        )
        rejected_rendered = _render_response_with_chat_template(
            messages, str(rejected_text), tokenizer=tokenizer, prompt_text=prompt_rendered
        )
        if not chosen_rendered or not rejected_rendered:
            continue

        rows.append({"prompt": prompt_rendered, "chosen": chosen_rendered, "rejected": rejected_rendered})
    return Dataset.from_list(rows)


# Sanity check
_msgs = [{"role": "user", "content": "hello"}]
print("Sanity Check: _messages_to_hh_prompt ->", _messages_to_hh_prompt(_msgs))


Sanity Check: _messages_to_hh_prompt -> 

Human: hello

Assistant:


In [10]:
# 7) Load raw dataset (HF) OR fall back to tiny synthetic data

dataset_cfg = config["dataset"]

dataset_name = dataset_cfg["dataset_name"]
subset = dataset_cfg.get("subset", "train")

try:
    raw_ds = load_dataset(dataset_name, split=subset)
    print("Loaded HF dataset:", dataset_name, subset)
except Exception as e:
    print("WARNING: failed to load HF dataset; using tiny synthetic fallback.")
    print("  error:", repr(e))

    raw_ds = Dataset.from_list(
        [
            {
                "chosen": "\n\nHuman: Say hello\n\nAssistant: Hello!",
                "rejected": "\n\nHuman: Say hello\n\nAssistant: ",
            },
            {
                "chosen": "\n\nHuman: What's 2+2?\n\nAssistant: 4",
                "rejected": "\n\nHuman: What's 2+2?\n\nAssistant: 5",
            },
        ]
    )

print("Sanity Check: raw_ds")
print("  type:", type(raw_ds))
print("  len:", len(raw_ds))
print("  columns:", getattr(raw_ds, "column_names", None))
print("  first_row_keys:", list(raw_ds[0].keys()))


Loaded HF dataset: Anthropic/hh-rlhf train[:1%]
Sanity Check: raw_ds
  type: <class 'datasets.arrow_dataset.Dataset'>
  len: 1608
  columns: ['chosen', 'rejected']
  first_row_keys: ['chosen', 'rejected']


In [14]:
# 8) Convert raw HH (chosen/rejected) -> triplets (prompt/chosen/rejected)

if bool(dataset_cfg.get("generated_data", False)):
    hh_ds = load_generated_dataset_from_config(config)
    source_kind = "generated"
else:
    hh_ds = build_HH_dataset(raw_ds)
    source_kind = "hh-rlhf"

print("Sanity Check: hh_ds built ->", source_kind)
print("  len:", len(hh_ds))
print("  columns:", hh_ds.column_names)
if len(hh_ds) > 0:
    print("  sample.prompt[:500]:", hh_ds[0]["prompt"][:1000])
    print("  sample.chosen[:200]:", hh_ds[0]["chosen"][:200])
    print("  sample.rejected[:200]:", hh_ds[0]["rejected"][:200])


Sanity Check: hh_ds built -> hh-rlhf
  len: 1606
  columns: ['prompt', 'chosen', 'rejected']
  sample.prompt[:500]: 

Human: What are some cuss words in english?

Assistant: Here’s an incomplete list.

Ass, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama’s boy, faggot, pervert, queer, scumbag, bitch,

Human: What's your favorite one?

Assistant:
  sample.chosen[:200]:  I haven't even thought about it.
  sample.rejected[:200]:  Ass.


In [15]:
# 9) Load tokenizer (+ pad token fix)

policy_name = config["policy_name"]
ref_name = config["ref_name"]

if USE_TINY_MODEL:
    policy_name = ref_name = FALLBACK_MODEL
    print("Using tiny model:", policy_name)

try:
    tok = AutoTokenizer.from_pretrained(policy_name, use_fast=True)
except Exception as e:
    print("WARNING: tokenizer load failed; falling back to", FALLBACK_MODEL)
    print("  error:", repr(e))
    policy_name = ref_name = FALLBACK_MODEL
    tok = AutoTokenizer.from_pretrained(policy_name, use_fast=True)

if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

print("Sanity Check: tokenizer")
print("  name:", policy_name)
print("  class:", tok.__class__.__name__)
print("  pad_token_id:", tok.pad_token_id, "eos_token_id:", tok.eos_token_id)
print("  has_chat_template:", bool(getattr(tok, "chat_template", None)))


Sanity Check: tokenizer
  name: meta-llama/Llama-3.2-1B-Instruct
  class: PreTrainedTokenizerFast
  pad_token_id: 128009 eos_token_id: 128009
  has_chat_template: True


In [18]:
# 10) Apply chat template to the dataset (optional but matches repo default)

if bool(dataset_cfg.get("chat_template", False)):
    hh_ds_templated = apply_chat_template_to_dataset(hh_ds, tok)
    applied = True
else:
    hh_ds_templated = hh_ds
    applied = False

print("Sanity Check: chat template applied ->", applied)
print("  before_len:", len(hh_ds), "after_len:", len(hh_ds_templated))
if len(hh_ds_templated) > 0:
    print("  prompt[:120]:", hh_ds_templated[0]["prompt"][:1200])
    print("  chosen[:120]:", hh_ds_templated[0]["chosen"][:120])
    print("  rejected[:120]:", hh_ds_templated[0]["rejected"][:120])


Sanity Check: chat template applied -> True
  before_len: 1606 after_len: 1606
  prompt[:120]: <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 27 Jan 2026

<|eot_id|><|start_header_id|>user<|end_header_id|>

What are some cuss words in english?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here’s an incomplete list.

Ass, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama’s boy, faggot, pervert, queer, scumbag, bitch,<|eot_id|><|start_header_id|>user<|end_header_id|>

What's your favorite one?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


  chosen[:120]: I haven't even thought abou

In [19]:
# 11) Train/val split

val_ratio = float(dataset_cfg.get("val_ratio", 0.1))
seed = int(dataset_cfg.get("seed", 42))

split = hh_ds_templated.train_test_split(test_size=val_ratio, seed=seed)
train_ds = split["train"]
eval_ds = split["test"]

print("Sanity Check: split sizes")
print("  train:", len(train_ds), "eval:", len(eval_ds))
if len(train_ds) > 0:
    print("  train sample keys:", list(train_ds[0].keys()))


Sanity Check: split sizes
  train: 1445 eval: 161
  train sample keys: ['prompt', 'chosen', 'rejected']


In [None]:
# 12) Load policy model

prec = str(config.get("precision", "fp32")).lower()
if prec == "fp16":
    torch_dtype = torch.float16
elif prec == "bf16":
    torch_dtype = torch.bfloat16
else:
    torch_dtype = None

model_load_kwargs: Dict[str, Any] = {}
if torch_dtype is not None:
    model_load_kwargs["torch_dtype"] = torch_dtype

# For notebooks, device_map="auto" is usually convenient.
if torch.cuda.is_available():
    model_load_kwargs["device_map"] = "auto"

try:
    policy = AutoModelForCausalLM.from_pretrained(policy_name, **model_load_kwargs)
except Exception as e:
    print("WARNING: policy model load failed; falling back to", FALLBACK_MODEL)
    print("  error:", repr(e))
    policy_name = FALLBACK_MODEL
    ref_name = FALLBACK_MODEL
    tok = AutoTokenizer.from_pretrained(policy_name, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    policy = AutoModelForCausalLM.from_pretrained(policy_name, **model_load_kwargs)

policy.config.use_cache = False

num_params = sum(p.numel() for p in policy.parameters())
trainable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)

print("Sanity Check: policy model")
print("  name:", policy_name)
print("  num_params:", num_params)
print("  trainable_params:", trainable_params)
print("  first_param.device:", next(policy.parameters()).device)
print("  first_param.dtype:", next(policy.parameters()).dtype)


`torch_dtype` is deprecated! Use `dtype` instead!


Sanity Check: policy model
  name: meta-llama/Llama-3.2-1B-Instruct
  num_params: 1235814400
  trainable_params: 1235814400
  first_param.device: cuda:0
  first_param.dtype: torch.bfloat16


In [22]:
# 13) Load reference model (freeze)

ref_load_kwargs = dict(model_load_kwargs)

ref_model = AutoModelForCausalLM.from_pretrained(ref_name, **ref_load_kwargs)
ref_model.config.use_cache = False
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)

frozen = sum(p.numel() for p in ref_model.parameters() if not p.requires_grad)

print("Sanity Check: reference model")
print("  name:", ref_name)
print("  frozen_params:", frozen)
print("  first_param.device:", next(ref_model.parameters()).device)


Sanity Check: reference model
  name: meta-llama/Llama-3.2-1B-Instruct
  frozen_params: 1235814400
  first_param.device: cuda:0


In [23]:
# 14) (Copied) src/losses/dpo_loss.py — compute_log_prob


def compute_log_prob(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Calculate the log probability of labels given logits."""
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()

    loss_mask = labels != -100
    labels[labels == -100] = 0

    per_token_log_prob = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    return (per_token_log_prob * loss_mask).sum(-1)


# Sanity check
_logits = torch.randn(2, 5, 11)
_labels = torch.tensor([
    [-100, 1, 2, 3, 4],
    [-100, 5, 6, -100, 7],
])
_lp = compute_log_prob(_logits, _labels)
print("Sanity Check: compute_log_prob")
print("  shape:", _lp.shape)
print("  values:", _lp)


Sanity Check: compute_log_prob
  shape: torch.Size([2])
  values: tensor([-10.0262,  -8.3727])


In [24]:
# 15) (Copied) src/losses/dpo_loss.py — dpo_loss


def dpo_loss(
    policy_chosen_log_prob: torch.Tensor,
    policy_rejected_log_prob: torch.Tensor,
    ref_chosen_log_prob: torch.Tensor,
    ref_rejected_log_prob: torch.Tensor,
    beta: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Calculate the DPO loss."""
    chosen_log_prob = policy_chosen_log_prob - ref_chosen_log_prob
    rejected_log_prob = policy_rejected_log_prob - ref_rejected_log_prob

    loss = -F.logsigmoid(beta * (chosen_log_prob - rejected_log_prob))

    chosen_rewards = (beta * chosen_log_prob).detach()
    rejected_rewards = (beta * rejected_log_prob).detach()

    return loss, chosen_rewards, rejected_rewards


# Sanity check: margin=0 -> loss ~ log(2)
_loss, _cr, _rr = dpo_loss(
    policy_chosen_log_prob=torch.tensor([0.0]),
    policy_rejected_log_prob=torch.tensor([0.0]),
    ref_chosen_log_prob=torch.tensor([0.0]),
    ref_rejected_log_prob=torch.tensor([0.0]),
    beta=1.0,
)
print("Sanity Check: dpo_loss")
print("  loss:", float(_loss.item()), "(expected ~0.6931)")


Sanity Check: dpo_loss
  loss: 0.6931471824645996 (expected ~0.6931)


In [25]:
# 16) (Copied) src/losses/margin.py — margin + risk helpers


def margin_compute(
    policy_chosen_log_prob: torch.Tensor,
    policy_rejected_log_prob: torch.Tensor,
    ref_chosen_log_prob: torch.Tensor,
    ref_rejected_log_prob: torch.Tensor,
) -> torch.Tensor:
    """Compute the model margin."""
    policy_diff = policy_chosen_log_prob - policy_rejected_log_prob
    ref_diff = ref_chosen_log_prob - ref_rejected_log_prob
    model_margin = (policy_diff - ref_diff).detach()
    return model_margin


def empirical_over_threshold_proportion(margins: torch.Tensor, threshold: float) -> float:
    """Proportion of margins >= threshold."""
    return (margins >= threshold).float().mean().item()


def risk_test(p_hat: float, delta: float) -> bool:
    """True if p_hat > delta."""
    return p_hat > delta


# Sanity check
_m = margin_compute(torch.tensor([1.0]), torch.tensor([0.0]), torch.tensor([1.0]), torch.tensor([0.0]))
print("Sanity Check: margin_compute ->", _m.item())
print("  p_hat@tau=0:", empirical_over_threshold_proportion(torch.tensor([-1.0, 0.5, 2.0]), 0.0))
print("  risk_test(0.2,0.1):", risk_test(0.2, 0.1))


Sanity Check: margin_compute -> 0.0
  p_hat@tau=0: 0.6666666865348816
  risk_test(0.2,0.1): True


In [26]:
# 17) (Copied) src/losses/beta_update.py — update_beta (Dynamic Beta core)

import math


def update_beta(
    beta: float,
    p_hat: float,
    delta: float,
    alpha: float,
    n: int,
    gamma: float,
    beta_min: float,
    beta_max: float,
) -> tuple[float, float, float, float]:
    """Update beta using the risk-based update equation."""
    u_k = (p_hat - delta) / math.sqrt((delta * (1 - delta)) / n)
    s_k = math.tanh(gamma * u_k)
    beta_new = beta * math.exp(alpha * s_k)
    beta_new = max(beta_min, min(beta_new, beta_max))
    return beta_new, u_k, s_k, alpha


# Sanity check: dummy cases
beta0 = 0.5
n = 128
delta = 0.1
alpha = 0.01
gamma = 2.0

beta_up, u_up, s_up, _ = update_beta(
    beta=beta0, p_hat=0.3, delta=delta, alpha=alpha, n=n, gamma=gamma, beta_min=0.0, beta_max=2.0
)

beta_down, u_down, s_down, _ = update_beta(
    beta=beta0, p_hat=0.01, delta=delta, alpha=alpha, n=n, gamma=gamma, beta_min=0.0, beta_max=2.0
)

print("Sanity Check: update_beta")
print("  beta0:", beta0)
print("  p_hat > delta -> beta:", beta_up, "u_k:", u_up, "s_k:", s_up)
print("  p_hat < delta -> beta:", beta_down, "u_k:", u_down, "s_k:", s_down)


Sanity Check: update_beta
  beta0: 0.5
  p_hat > delta -> beta: 0.5050250835420832 u_k: 7.542472332656505 s_k: 0.9999999999998421
  p_hat < delta -> beta: 0.49502492944874754 u_k: -3.394112549695428 s_k: -0.9999974598928385


In [27]:
# 18) (Copied) src/quantile/accumulator.py — WarmupQuantileAccumulator


class WarmupQuantileAccumulator:
    """Accumulate margins during warmup and estimate initial quantile threshold (tau_0)."""

    def __init__(self, q: float):
        self.q = q
        self._buf: list[torch.Tensor] = []

    @torch.no_grad()
    def update(self, batch_margins: torch.Tensor) -> None:
        t = batch_margins.detach().float().view(-1)
        if t.numel() == 0:
            return
        self._buf.append(t.cpu())

    def finalize(self) -> float:
        if len(self._buf) == 0:
            return 0.0
        all_m = torch.cat(self._buf, dim=0)
        tau0 = torch.quantile(all_m, self.q).item()
        return float(tau0)


# Sanity check
acc = WarmupQuantileAccumulator(q=0.9)
acc.update(torch.tensor([0.0, 1.0, 2.0]))
acc.update(torch.tensor([3.0, 4.0]))
print("Sanity Check: WarmupQuantileAccumulator tau0 ->", acc.finalize())


Sanity Check: WarmupQuantileAccumulator tau0 -> 3.5999999046325684


In [28]:
# 19) (Copied) src/quantile/accumulator.py — EMAUpdate


class EMAUpdate:
    """Exponential moving average update for threshold tau."""

    def __init__(self, tau_0: float, q: float, momentum: float):
        self.tau = tau_0
        self.q = q
        self.lam = momentum

    def update_tau(self, batch_margins: torch.Tensor) -> float:
        t = batch_margins.detach().float().view(-1)
        if t.numel() == 0:
            return self.tau
        batch_tau = torch.quantile(t, self.q).item()
        self.tau = (1.0 - self.lam) * self.tau + self.lam * batch_tau
        return self.tau


# Sanity check
ema = EMAUpdate(tau_0=1.0, q=0.9, momentum=0.1)
print("Sanity Check: EMA tau")
print("  before:", ema.tau)
print("  after:", ema.update_tau(torch.tensor([0.0, 1.0, 10.0])))


Sanity Check: EMA tau
  before: 1.0
  after: 1.7199999809265138


In [29]:
# 20) (Copied) src/utils/logging.py — compute_and_log_model_margin


def compute_and_log_model_margin(
    model_margin: torch.Tensor,
    epoch_dir: str,
    epoch: int,
    step: int,
    jsonl_path: str,
) -> None:
    """Compute and log model margin statistics."""
    import os
    import json
    import numpy as np

    m = model_margin.detach().float().cpu().numpy()

    npy_path = os.path.join(epoch_dir, f"step_{step:05d}.npy")
    np.save(npy_path, m)

    p10, p50, p90 = np.percentile(m, [10, 50, 90])

    record = {
        "epoch": int(epoch),
        "step": int(step),
        "batch_size": int(m.shape[0]),
        "mean": float(m.mean()),
        "std": float(m.std(ddof=0)),
        "min": float(m.min()),
        "p10": float(p10),
        "median": float(p50),
        "p90": float(p90),
        "max": float(m.max()),
        "pos_frac": float((m > 0).mean()),
        "npy": npy_path,
        "sample": [float(x) for x in m[:]],
    }

    with open(jsonl_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


# Sanity check (writes to a temp folder)
import tempfile

with tempfile.TemporaryDirectory() as d:
    j = os.path.join(d, "margins.jsonl")
    compute_and_log_model_margin(torch.tensor([0.0, 1.0, 2.0]), epoch_dir=d, epoch=0, step=0, jsonl_path=j)
    print("Sanity Check: margin log files")
    print("  dir:", d)
    print("  files:", sorted(os.listdir(d)))


Sanity Check: margin log files
  dir: /tmp/tmplj9836h4
  files: ['margins.jsonl', 'step_00000.npy']


In [30]:
# 21) (Copied) src/trainers/dynamic_beta_dpo.py — DynamicBetaDPOConfig


@dataclass
class DynamicBetaDPOConfig:
    """Configuration for dynamic beta DPO training."""

    # Risk control
    delta: float = 0.1
    momentum: float = 0.05

    # Beta update
    beta_0: float = 0.1
    alpha: float = 0.005
    gamma: float = 2.0
    beta_min: float = 0.0
    beta_max: float = 2.0

    # Warmup
    warmup_steps: int = 120

    # Margin logging
    log_margins: bool = True
    log_dir: str = "logs/margins"
    jsonl_name: str = "margins.jsonl"
    save_per_rank: bool = False
    jsonl_sample_size: int = 32


# Sanity check
_dyn = DynamicBetaDPOConfig(warmup_steps=int(config["risk_test"]["beta_warmup"]))
print("Sanity Check: DynamicBetaDPOConfig ->", _dyn)


Sanity Check: DynamicBetaDPOConfig -> DynamicBetaDPOConfig(delta=0.1, momentum=0.05, beta_0=0.1, alpha=0.005, gamma=2.0, beta_min=0.0, beta_max=2.0, warmup_steps=5, log_margins=True, log_dir='logs/margins', jsonl_name='margins.jsonl', save_per_rank=False, jsonl_sample_size=32)


In [40]:
# 22) (Split) DynamicBetaDPOTrainer — class shell

class DynamicBetaDPOTrainer(DPOTrainer):
    """DPO Trainer with dynamic beta adjustment based on risk control.

    This class is assembled step-by-step across multiple cells so you can
    inspect and debug each method independently.
    """

    pass

print('Sanity Check: class shell defined ->', DynamicBetaDPOTrainer)


Sanity Check: class shell defined -> <class '__main__.DynamicBetaDPOTrainer'>


In [41]:
# 22.1) DynamicBetaDPOTrainer._get_rank

def _dynamic_beta__get_rank(self) -> int:
    acc = getattr(self, 'accelerator', None)
    if acc is not None:
        try:
            return int(acc.process_index)
        except Exception:
            pass
    return int(os.environ.get('RANK', '0'))

DynamicBetaDPOTrainer._get_rank = _dynamic_beta__get_rank

print('Sanity Check: attached -> DynamicBetaDPOTrainer._get_rank')


Sanity Check: attached -> DynamicBetaDPOTrainer._get_rank


In [42]:
# 22.2) DynamicBetaDPOTrainer._maybe_log_margins

def _dynamic_beta__maybe_log_margins(self, model_margin: torch.Tensor) -> None:
    if not self.dynamic_cfg.log_margins:
        return
    if (not self.dynamic_cfg.save_per_rank) and self._rank != 0:
        return

    epoch = getattr(self.state, 'epoch', None)
    epoch_i = int(epoch) if epoch is not None else 0
    epoch_dir = os.path.join(self._margin_base_dir, f'epoch_{epoch_i:03d}')
    os.makedirs(epoch_dir, exist_ok=True)

    jsonl_path = os.path.join(epoch_dir, self.dynamic_cfg.jsonl_name)
    step = int(getattr(self.state, 'global_step', 0))

    if self.dynamic_cfg.jsonl_sample_size and self.dynamic_cfg.jsonl_sample_size > 0:
        m = model_margin.detach().float().cpu().numpy()
        npy_path = os.path.join(epoch_dir, f'step_{step:05d}.npy')
        np.save(npy_path, m)

        p10, p50, p90 = np.percentile(m, [10, 50, 90])
        rec = {
            'epoch': int(epoch_i),
            'step': int(step),
            'batch_size': int(m.shape[0]),
            'mean': float(m.mean()),
            'std': float(m.std(ddof=0)),
            'min': float(m.min()),
            'p10': float(p10),
            'median': float(p50),
            'p90': float(p90),
            'max': float(m.max()),
            'pos_frac': float((m > 0).mean()),
            'npy': npy_path,
            'sample': [float(x) for x in m[: self.dynamic_cfg.jsonl_sample_size]],
        }
        with open(jsonl_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(rec, ensure_ascii=False) + '\n')
    else:
        compute_and_log_model_margin(
            model_margin=model_margin,
            epoch_dir=epoch_dir,
            epoch=epoch_i,
            step=step,
            jsonl_path=jsonl_path,
        )

DynamicBetaDPOTrainer._maybe_log_margins = _dynamic_beta__maybe_log_margins

print('Sanity Check: attached -> DynamicBetaDPOTrainer._maybe_log_margins')


Sanity Check: attached -> DynamicBetaDPOTrainer._maybe_log_margins


In [43]:
# 22.3) DynamicBetaDPOTrainer.__init__ (state + logging paths)

def _dynamic_beta__init__(self, *args, dynamic_cfg: DynamicBetaDPOConfig, **kwargs):
    # NOTE: since we're attaching this method post-class-definition, we call the base
    # initializer explicitly instead of using super().
    DPOTrainer.__init__(self, *args, **kwargs)

    self.dynamic_cfg = dynamic_cfg

    # Beta adjustment state
    self.beta = float(dynamic_cfg.beta_0)
    self._warmup_done = False
    self._warmup_count = 0
    self.warmup_threshold = WarmupQuantileAccumulator(q=(1 - self.dynamic_cfg.delta))
    self._ema: Optional[EMAUpdate] = None
    self.tau = 0.0

    # Bookkeeping
    self._last_stats: Dict[str, Any] = {}

    # Rank/process info (Accelerate)
    self._rank = self._get_rank()

    # Margin logging paths
    if self.dynamic_cfg.log_margins:
        base = self.dynamic_cfg.log_dir
        if self.dynamic_cfg.save_per_rank:
            base = os.path.join(base, f'rank_{self._rank}')
        os.makedirs(base, exist_ok=True)
        self._margin_base_dir = base
    else:
        self._margin_base_dir = None

DynamicBetaDPOTrainer.__init__ = _dynamic_beta__init__

print('Sanity Check: attached -> DynamicBetaDPOTrainer.__init__')


Sanity Check: attached -> DynamicBetaDPOTrainer.__init__


In [49]:
# 22.4) DynamicBetaDPOTrainer._concatenate_and_build_labels

def _dynamic_beta__concatenate_and_build_labels(
    self,
    prompt_input_ids: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    completion_input_ids: torch.Tensor,
    completion_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Concatenate prompt + completion and build labels for log prob computation."""
    input_ids = torch.cat([prompt_input_ids, completion_input_ids], dim=1)
    attention_mask = torch.cat([prompt_attention_mask, completion_attention_mask], dim=1)

    labels = input_ids.clone()
    prompt_len = prompt_input_ids.shape[1]
    labels[:, :prompt_len] = -100
    labels[attention_mask == 0] = -100

    return input_ids, attention_mask, labels

DynamicBetaDPOTrainer._concatenate_and_build_labels = _dynamic_beta__concatenate_and_build_labels

print('Sanity Check: attached -> DynamicBetaDPOTrainer._concatenate_and_build_labels')


Sanity Check: attached -> DynamicBetaDPOTrainer._concatenate_and_build_labels


In [50]:
# 22.5) DynamicBetaDPOTrainer._build_concatenated_inputs

def _dynamic_beta__build_concatenated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Build concatenated (prompt+completion) sequences and labels for chosen/rejected."""
    chosen_input_ids, chosen_attention_mask, chosen_labels = self._concatenate_and_build_labels(
        prompt_input_ids=inputs['prompt_input_ids'],
        prompt_attention_mask=inputs['prompt_attention_mask'],
        completion_input_ids=inputs['chosen_input_ids'],
        completion_attention_mask=inputs['chosen_attention_mask'],
    )
    rejected_input_ids, rejected_attention_mask, rejected_labels = self._concatenate_and_build_labels(
        prompt_input_ids=inputs['prompt_input_ids'],
        prompt_attention_mask=inputs['prompt_attention_mask'],
        completion_input_ids=inputs['rejected_input_ids'],
        completion_attention_mask=inputs['rejected_attention_mask'],
    )

    return {
        'chosen_input_ids': chosen_input_ids,
        'chosen_attention_mask': chosen_attention_mask,
        'chosen_labels': chosen_labels,
        'rejected_input_ids': rejected_input_ids,
        'rejected_attention_mask': rejected_attention_mask,
        'rejected_labels': rejected_labels,
    }

DynamicBetaDPOTrainer._build_concatenated_inputs = _dynamic_beta__build_concatenated_inputs

print('Sanity Check: attached -> DynamicBetaDPOTrainer._build_concatenated_inputs')


Sanity Check: attached -> DynamicBetaDPOTrainer._build_concatenated_inputs


In [47]:
# 22.6) DynamicBetaDPOTrainer._debug_print_concat_shapes

def _dynamic_beta__debug_print_concat_shapes(self, inputs: Dict[str, Any], c: Dict[str, torch.Tensor]) -> None:
    """Print shapes/length stats to catch masking/truncation bugs."""
    prompt_padded_len = inputs['prompt_input_ids'].shape[1]
    chosen_comp_padded_len = inputs['chosen_input_ids'].shape[1]
    rejected_comp_padded_len = inputs['rejected_input_ids'].shape[1]

    prompt_actual = inputs['prompt_attention_mask'].sum(dim=1).float()
    chosen_comp_actual = inputs['chosen_attention_mask'].sum(dim=1).float()
    rejected_comp_actual = inputs['rejected_attention_mask'].sum(dim=1).float()

    chosen_concat_padded = c['chosen_input_ids'].shape[1]
    rejected_concat_padded = c['rejected_input_ids'].shape[1]
    chosen_concat_actual = c['chosen_attention_mask'].sum(dim=1).float()
    rejected_concat_actual = c['rejected_attention_mask'].sum(dim=1).float()

    chosen_valid = (c['chosen_labels'] != -100).sum(dim=1).float()
    rejected_valid = (c['rejected_labels'] != -100).sum(dim=1).float()

    print(f"\n[DEBUG Step {self._warmup_count}] Tensor shapes before forward pass:")
    print(f"  Batch size: {c['chosen_input_ids'].shape[0]}")
    print(f"  --- Input from batch ---")
    print(
        f"  prompt_input_ids: padded={prompt_padded_len}, actual={prompt_actual.mean():.1f} (min={prompt_actual.min():.0f}, max={prompt_actual.max():.0f})"
    )
    print(
        f"  chosen_completion: padded={chosen_comp_padded_len}, actual={chosen_comp_actual.mean():.1f} (min={chosen_comp_actual.min():.0f}, max={chosen_comp_actual.max():.0f})"
    )
    print(
        f"  rejected_completion: padded={rejected_comp_padded_len}, actual={rejected_comp_actual.mean():.1f} (min={rejected_comp_actual.min():.0f}, max={rejected_comp_actual.max():.0f})"
    )
    print(f"  --- After concatenation ---")
    print(
        f"  chosen_concat: padded={chosen_concat_padded}, actual={chosen_concat_actual.mean():.1f}, valid_for_loss={chosen_valid.mean():.1f} (min={chosen_valid.min():.0f}, max={chosen_valid.max():.0f})"
    )
    print(
        f"  rejected_concat: padded={rejected_concat_padded}, actual={rejected_concat_actual.mean():.1f}, valid_for_loss={rejected_valid.mean():.1f} (min={rejected_valid.min():.0f}, max={rejected_valid.max():.0f})"
    )

DynamicBetaDPOTrainer._debug_print_concat_shapes = _dynamic_beta__debug_print_concat_shapes

print('Sanity Check: attached -> DynamicBetaDPOTrainer._debug_print_concat_shapes')


Sanity Check: attached -> DynamicBetaDPOTrainer._debug_print_concat_shapes


In [48]:
# 22.7) DynamicBetaDPOTrainer._forward_policy_and_ref

def _dynamic_beta__forward_policy_and_ref(self, model: torch.nn.Module, c: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Run forward passes for policy and reference; returns logits."""
    policy_chosen_logits = model(
        input_ids=c['chosen_input_ids'],
        attention_mask=c['chosen_attention_mask'],
    ).logits

    policy_rejected_logits = model(
        input_ids=c['rejected_input_ids'],
        attention_mask=c['rejected_attention_mask'],
    ).logits

    with torch.no_grad():
        ref_chosen_logits = self.ref_model(
            input_ids=c['chosen_input_ids'],
            attention_mask=c['chosen_attention_mask'],
        ).logits
        ref_rejected_logits = self.ref_model(
            input_ids=c['rejected_input_ids'],
            attention_mask=c['rejected_attention_mask'],
        ).logits

    return {
        'policy_chosen_logits': policy_chosen_logits,
        'policy_rejected_logits': policy_rejected_logits,
        'ref_chosen_logits': ref_chosen_logits,
        'ref_rejected_logits': ref_rejected_logits,
    }

DynamicBetaDPOTrainer._forward_policy_and_ref = _dynamic_beta__forward_policy_and_ref

print('Sanity Check: attached -> DynamicBetaDPOTrainer._forward_policy_and_ref')


Sanity Check: attached -> DynamicBetaDPOTrainer._forward_policy_and_ref


In [51]:
# 22.8) DynamicBetaDPOTrainer._compute_log_probs

def _dynamic_beta__compute_log_probs(self, logits: Dict[str, torch.Tensor], c: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Compute per-sequence log probabilities for chosen/rejected under policy/ref."""
    policy_chosen_log_prob = compute_log_prob(logits=logits['policy_chosen_logits'], labels=c['chosen_labels'])
    policy_rejected_log_prob = compute_log_prob(logits=logits['policy_rejected_logits'], labels=c['rejected_labels'])
    ref_chosen_log_prob = compute_log_prob(logits=logits['ref_chosen_logits'], labels=c['chosen_labels'])
    ref_rejected_log_prob = compute_log_prob(logits=logits['ref_rejected_logits'], labels=c['rejected_labels'])

    return {
        'policy_chosen_log_prob': policy_chosen_log_prob,
        'policy_rejected_log_prob': policy_rejected_log_prob,
        'ref_chosen_log_prob': ref_chosen_log_prob,
        'ref_rejected_log_prob': ref_rejected_log_prob,
    }

DynamicBetaDPOTrainer._compute_log_probs = _dynamic_beta__compute_log_probs

print('Sanity Check: attached -> DynamicBetaDPOTrainer._compute_log_probs')


Sanity Check: attached -> DynamicBetaDPOTrainer._compute_log_probs


In [52]:
# 22.9) DynamicBetaDPOTrainer._debug_print_log_probs

def _dynamic_beta__debug_print_log_probs(self, lp: Dict[str, torch.Tensor]) -> None:
    chosen_ratio = lp['policy_chosen_log_prob'] - lp['ref_chosen_log_prob']
    rejected_ratio = lp['policy_rejected_log_prob'] - lp['ref_rejected_log_prob']
    margin = chosen_ratio - rejected_ratio

    print(f"\n[DEBUG Step {self._warmup_count}]")
    print(f"  policy_chosen_log_prob: {lp['policy_chosen_log_prob'].mean():.4f} (should be large negative)")
    print(f"  policy_rejected_log_prob: {lp['policy_rejected_log_prob'].mean():.4f}")
    print(f"  ref_chosen_log_prob: {lp['ref_chosen_log_prob'].mean():.4f}")
    print(f"  ref_rejected_log_prob: {lp['ref_rejected_log_prob'].mean():.4f}")
    print(f"  chosen_ratio (policy-ref): {chosen_ratio.mean():.4f} (should be ~0 at start)")
    print(f"  rejected_ratio (policy-ref): {rejected_ratio.mean():.4f} (should be ~0 at start)")
    print(f"  margin (chosen-rejected ratio): {margin.mean():.4f}")
    print(f"  margin min/max: {margin.min():.4f} / {margin.max():.4f}")
    print(f"  expected loss at margin=0: {0.693:.4f} (log(2))")

DynamicBetaDPOTrainer._debug_print_log_probs = _dynamic_beta__debug_print_log_probs

print('Sanity Check: attached -> DynamicBetaDPOTrainer._debug_print_log_probs')


Sanity Check: attached -> DynamicBetaDPOTrainer._debug_print_log_probs


In [53]:
# 22.10) DynamicBetaDPOTrainer._compute_dpo_loss

def _dynamic_beta__compute_dpo_loss(self, lp: Dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute (loss_tensor_per_sample, chosen_rewards, rejected_rewards)."""
    loss_ten, chosen_rewards, rejected_rewards = dpo_loss(
        policy_chosen_log_prob=lp['policy_chosen_log_prob'],
        policy_rejected_log_prob=lp['policy_rejected_log_prob'],
        ref_chosen_log_prob=lp['ref_chosen_log_prob'],
        ref_rejected_log_prob=lp['ref_rejected_log_prob'],
        beta=float(self.beta),
    )
    return loss_ten, chosen_rewards, rejected_rewards

DynamicBetaDPOTrainer._compute_dpo_loss = _dynamic_beta__compute_dpo_loss

print('Sanity Check: attached -> DynamicBetaDPOTrainer._compute_dpo_loss')


Sanity Check: attached -> DynamicBetaDPOTrainer._compute_dpo_loss


In [54]:
# 22.11) DynamicBetaDPOTrainer._compute_margin

def _dynamic_beta__compute_margin(self, lp: Dict[str, torch.Tensor]) -> torch.Tensor:
    return margin_compute(
        policy_chosen_log_prob=lp['policy_chosen_log_prob'],
        policy_rejected_log_prob=lp['policy_rejected_log_prob'],
        ref_chosen_log_prob=lp['ref_chosen_log_prob'],
        ref_rejected_log_prob=lp['ref_rejected_log_prob'],
    )

DynamicBetaDPOTrainer._compute_margin = _dynamic_beta__compute_margin

print('Sanity Check: attached -> DynamicBetaDPOTrainer._compute_margin')


Sanity Check: attached -> DynamicBetaDPOTrainer._compute_margin


In [55]:
# 22.12) DynamicBetaDPOTrainer._dynamic_beta_update (warmup/EMA + beta update)

def _dynamic_beta__dynamic_beta_update(self, model_margin: torch.Tensor, loss: torch.Tensor) -> None:
    """Update warmup/EMA thresholding and beta; mirrors repo trainer semantics."""
    with torch.no_grad():
        self._warmup_count += 1

        if (not self._warmup_done) and (self._warmup_count <= self.dynamic_cfg.warmup_steps):
            self.warmup_threshold.update(model_margin)

            if self._warmup_count == self.dynamic_cfg.warmup_steps:
                tau0 = self.warmup_threshold.finalize()
                self._ema = EMAUpdate(
                    tau_0=tau0,
                    q=1.0 - float(self.dynamic_cfg.delta),
                    momentum=float(self.dynamic_cfg.momentum),
                )
                self.tau = float(tau0)
                self._warmup_done = True
        else:
            # Post-warmup: update tau + beta.
            if self._ema is not None:
                self.tau = float(self._ema.update_tau(model_margin))

            if self.tau is not None:
                p_hat = empirical_over_threshold_proportion(model_margin, self.tau)
            else:
                p_hat = 0.0

            fail = risk_test(p_hat, float(self.dynamic_cfg.delta))

            beta_new, u_k, s_k, alpha = update_beta(
                beta=float(self.beta),
                p_hat=float(p_hat),
                delta=float(self.dynamic_cfg.delta),
                alpha=float(self.dynamic_cfg.alpha),
                n=int(model_margin.numel()),
                gamma=float(self.dynamic_cfg.gamma),
                beta_min=float(self.dynamic_cfg.beta_min),
                beta_max=float(self.dynamic_cfg.beta_max),
            )
            self.beta = float(beta_new)

            self._last_stats = {
                'p_hat': float(p_hat),
                'tau': float(self.tau) if self.tau is not None else None,
                'fail': int(fail),
                'u_k': float(u_k),
                's_k': float(s_k),
                'alpha': float(alpha),
            }

        # Margin logging (matches repo: logged during warmup too)
        self._maybe_log_margins(model_margin)

        # Trainer logging (matches repo)
        log_payload = {
            'dpo/beta': float(self.beta),
            'dpo/margin_mean': float(model_margin.mean().item()),
            'dpo/loss': float(loss.detach().float().item()),
        }
        if self._warmup_done and self._last_stats:
            log_payload.update(
                {
                    'risk/p_hat': self._last_stats.get('p_hat', 0.0),
                    'risk/tau': self._last_stats.get('tau', 0.0),
                    'risk/fail': self._last_stats.get('fail', 0),
                    'risk/u_k': self._last_stats.get('u_k', 0.0),
                    'risk/s_k': self._last_stats.get('s_k', 0.0),
                }
            )

        try:
            self.log(log_payload)
        except Exception:
            pass

DynamicBetaDPOTrainer._dynamic_beta_update = _dynamic_beta__dynamic_beta_update

print('Sanity Check: attached -> DynamicBetaDPOTrainer._dynamic_beta_update')


Sanity Check: attached -> DynamicBetaDPOTrainer._dynamic_beta_update


In [56]:
# 22.13) DynamicBetaDPOTrainer.compute_loss (orchestrator)

def _dynamic_beta__compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
    c = self._build_concatenated_inputs(inputs)

    if self._warmup_count <= 3 or self._warmup_count % 50 == 0:
        self._debug_print_concat_shapes(inputs, c)

    logits = self._forward_policy_and_ref(model, c)
    lp = self._compute_log_probs(logits, c)

    if self._warmup_count <= 3 or self._warmup_count % 50 == 0:
        self._debug_print_log_probs(lp)

    loss_ten, chosen_rewards, rejected_rewards = self._compute_dpo_loss(lp)
    loss = loss_ten.mean()

    model_margin = self._compute_margin(lp)

    self._dynamic_beta_update(model_margin=model_margin, loss=loss)

    if return_outputs:
        return loss, {'chosen': chosen_rewards, 'rejected': rejected_rewards}
    return loss

DynamicBetaDPOTrainer.compute_loss = _dynamic_beta__compute_loss

print('Sanity Check: attached -> DynamicBetaDPOTrainer.compute_loss')


Sanity Check: attached -> DynamicBetaDPOTrainer.compute_loss


In [57]:
# 22.14) DynamicBetaDPOTrainer.evaluate

def _dynamic_beta__evaluate(self, eval_dataset=None, **kwargs):
    metrics = DPOTrainer.evaluate(self, eval_dataset=eval_dataset, **kwargs)
    if self.args.report_to and 'wandb' in self.args.report_to:
        try:
            self.log({'eval/loss': float(metrics['eval_loss'])})
        except Exception:
            pass
    return metrics

DynamicBetaDPOTrainer.evaluate = _dynamic_beta__evaluate

print('Sanity Check: attached -> DynamicBetaDPOTrainer.evaluate')


Sanity Check: attached -> DynamicBetaDPOTrainer.evaluate


In [58]:
# 22.15) Final sanity check: class surface area

_expected = [
    '__init__',
    '_get_rank',
    '_maybe_log_margins',
    '_concatenate_and_build_labels',
    '_build_concatenated_inputs',
    '_debug_print_concat_shapes',
    '_forward_policy_and_ref',
    '_compute_log_probs',
    '_debug_print_log_probs',
    '_compute_dpo_loss',
    '_compute_margin',
    '_dynamic_beta_update',
    'compute_loss',
    'evaluate',
]
missing = [name for name in _expected if not hasattr(DynamicBetaDPOTrainer, name)]

print('Sanity Check: DynamicBetaDPOTrainer assembled')
print('  missing_methods:', missing)


Sanity Check: DynamicBetaDPOTrainer assembled
  missing_methods: []


In [59]:
# 23) Build TRL DPOConfig (training args)

dpo_train_args = config["dpo_training"]

# Match CLI behavior: fp16/bf16 flags derived from config["precision"].
fp16 = str(config.get("precision", "")).lower() == "fp16"
bf16 = str(config.get("precision", "")).lower() == "bf16"

training_args = DPOConfig(
    learning_rate=float(dpo_train_args["learning_rate"]),
    per_device_train_batch_size=int(dpo_train_args["batch_size"]),
    per_device_eval_batch_size=int(dpo_train_args["eval_batch_size"]),
    num_train_epochs=int(dpo_train_args["epochs"]),
    logging_steps=int(dpo_train_args["log_steps"]),
    eval_strategy="steps",
    eval_steps=int(dpo_train_args["eval_steps"]),
    save_strategy="steps",
    save_steps=int(dpo_train_args["save_steps"]),
    fp16=fp16,
    bf16=bf16,
    gradient_accumulation_steps=int(dpo_train_args["gradient_accumulation"]),
    max_grad_norm=float(dpo_train_args["max_grad_norm"]),
    warmup_steps=int(dpo_train_args["warmup_steps"]),
    report_to=[],
    run_name=str(dpo_train_args["run_name"]),
    remove_unused_columns=False,
    output_dir=str(dpo_train_args["save_dir"]),
    max_prompt_length=int(dataset_cfg.get("max_prompt_length", 512)),
    max_completion_length=int(dataset_cfg.get("max_completion_length", 256)),
    max_length=int(dataset_cfg.get("max_length", 1024)),
    truncation_mode=str(dataset_cfg.get("truncation_mode", "keep_end")),
)

print("Sanity Check: training_args")
print("  output_dir:", training_args.output_dir)
print("  batch_size:", training_args.per_device_train_batch_size)
print("  max_length:", training_args.max_length)


Sanity Check: training_args
  output_dir: trl_dynamic_beta_dpo_debug
  batch_size: 2
  max_length: 512


In [60]:
# 24) Build DynamicBetaDPOConfig from config dict

risk = config["risk_test"]
beta_up = config["beta_update"]
margin_log = config["margin_log"]

# Default to no disk writes in notebooks; flip to True if you want margin .npy/.jsonl logs.
LOG_MARGINS = False

# IMPORTANT: warmup_steps kept small by default in this notebook.

dyn_cfg = DynamicBetaDPOConfig(
    delta=float(risk["delta"]),
    momentum=float(risk["lambda"]),
    warmup_steps=int(risk["beta_warmup"]),
    beta_0=float(beta_up["beta_0"]),
    
    alpha=float(beta_up["alpha"]),
    gamma=float(beta_up["gamma"]),
    beta_min=float(beta_up["beta_min"]),
    beta_max=float(beta_up["beta_max"]),
    log_margins=bool(LOG_MARGINS),
    log_dir=str(margin_log["log_dir"]),
    jsonl_sample_size=int(margin_log["jsonl_sample_size"]),
    save_per_rank=bool(margin_log["save_per_rank"]),
)

print("Sanity Check: dyn_cfg")
print("  warmup_steps:", dyn_cfg.warmup_steps)
print("  delta:", dyn_cfg.delta, "momentum:", dyn_cfg.momentum)
print("  beta0:", dyn_cfg.beta_0, "alpha:", dyn_cfg.alpha, "gamma:", dyn_cfg.gamma)


Sanity Check: dyn_cfg
  warmup_steps: 5
  delta: 0.1 momentum: 0.1
  beta0: 0.1 alpha: 0.005 gamma: 2.0


In [61]:
# 25) Initialize the DynamicBetaDPOTrainer

trainer = DynamicBetaDPOTrainer(
    model=policy,
    ref_model=ref_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    dynamic_cfg=dyn_cfg,
    processing_class=tok,
)

print("Sanity Check: trainer")
print("  trainer.beta:", trainer.beta)
print("  warmup_done:", trainer._warmup_done, "warmup_count:", trainer._warmup_count)
print("  train_len:", len(trainer.train_dataset), "eval_len:", len(trainer.eval_dataset))


Extracting prompt in train dataset: 100%|██████████| 1445/1445 [00:00<00:00, 4171.66 examples/s]
Applying chat template to train dataset: 100%|██████████| 1445/1445 [00:00<00:00, 16376.25 examples/s]
Tokenizing train dataset: 100%|██████████| 1445/1445 [-1:59:59<00:00, -809.13 examples/s]
Extracting prompt in eval dataset: 100%|██████████| 161/161 [00:00<00:00, 9134.95 examples/s]
Applying chat template to eval dataset: 100%|██████████| 161/161 [00:00<00:00, 16971.17 examples/s]
Tokenizing eval dataset: 100%|██████████| 161/161 [00:00<00:00, 1517.10 examples/s]
The model is already on multiple devices. Skipping the move to device specified in `args`.


Sanity Check: trainer
  trainer.beta: 0.1
  warmup_done: False warmup_count: 0
  train_len: 1445 eval_len: 161


In [62]:
# 26) Inspect one collated batch (keys + tensor shapes)

loader = trainer.get_train_dataloader()
batch = next(iter(loader))

print("Sanity Check: batch keys")
print("  keys:", sorted(batch.keys()))
for k, v in batch.items():
    if torch.is_tensor(v):
        print(f"  {k}: shape={tuple(v.shape)} dtype={v.dtype}")
    else:
        print(f"  {k}: type={type(v)}")


Sanity Check: batch keys
  keys: ['chosen_attention_mask', 'chosen_input_ids', 'prompt_attention_mask', 'prompt_input_ids', 'rejected_attention_mask', 'rejected_input_ids']
  prompt_input_ids: shape=(2, 106) dtype=torch.int64
  prompt_attention_mask: shape=(2, 106) dtype=torch.int64
  chosen_input_ids: shape=(2, 17) dtype=torch.int64
  chosen_attention_mask: shape=(2, 17) dtype=torch.int64
  rejected_input_ids: shape=(2, 43) dtype=torch.int64
  rejected_attention_mask: shape=(2, 43) dtype=torch.int64


In [64]:
# 27) Move batch to the right device(s) (Trainer internal)

batch_prepared = trainer._prepare_inputs(batch)

print("Sanity Check: prepared batch tensors")
for k, v in batch_prepared.items():
    if torch.is_tensor(v):
        print(f"  {k}: device={v.device} dtype={v.dtype} shape={tuple(v.shape)}")


Sanity Check: prepared batch tensors
  prompt_input_ids: device=cuda:0 dtype=torch.int64 shape=(2, 106)
  prompt_attention_mask: device=cuda:0 dtype=torch.int64 shape=(2, 106)
  chosen_input_ids: device=cuda:0 dtype=torch.int64 shape=(2, 17)
  chosen_attention_mask: device=cuda:0 dtype=torch.int64 shape=(2, 17)
  rejected_input_ids: device=cuda:0 dtype=torch.int64 shape=(2, 43)
  rejected_attention_mask: device=cuda:0 dtype=torch.int64 shape=(2, 43)


In [65]:
# 28.1) Build concatenated inputs (prompt+chosen / prompt+rejected)

c = trainer._build_concatenated_inputs(batch_prepared)

print('Sanity Check: concatenated tensors')
for k in [
    'chosen_input_ids',
    'chosen_attention_mask',
    'chosen_labels',
    'rejected_input_ids',
    'rejected_attention_mask',
    'rejected_labels',
]:
    v = c[k]
    print(f'  {k}: shape={tuple(v.shape)} dtype={v.dtype} device={v.device}')

chosen_valid = (c['chosen_labels'] != -100).sum(dim=1)
rejected_valid = (c['rejected_labels'] != -100).sum(dim=1)
print('  chosen_valid_tokens(mean/min/max):', float(chosen_valid.float().mean()), int(chosen_valid.min()), int(chosen_valid.max()))
print('  rejected_valid_tokens(mean/min/max):', float(rejected_valid.float().mean()), int(rejected_valid.min()), int(rejected_valid.max()))


Sanity Check: concatenated tensors
  chosen_input_ids: shape=(2, 123) dtype=torch.int64 device=cuda:0
  chosen_attention_mask: shape=(2, 123) dtype=torch.int64 device=cuda:0
  chosen_labels: shape=(2, 123) dtype=torch.int64 device=cuda:0
  rejected_input_ids: shape=(2, 149) dtype=torch.int64 device=cuda:0
  rejected_attention_mask: shape=(2, 149) dtype=torch.int64 device=cuda:0
  rejected_labels: shape=(2, 149) dtype=torch.int64 device=cuda:0
  chosen_valid_tokens(mean/min/max): 12.0 7 17
  rejected_valid_tokens(mean/min/max): 35.0 27 43


In [66]:
# 28.1a) Audit: label masking invariants (prompt masked, padding masked)

# Expected invariants:
# - All prompt tokens should be masked (-100) in labels.
# - All padding tokens (attention_mask==0) should be masked (-100).
# - Completion tokens should have some unmasked labels (otherwise log_prob becomes ~0).

prompt_len_padded = int(batch_prepared['prompt_input_ids'].shape[1])

for side in ['chosen', 'rejected']:
    labels = c[f'{side}_labels']
    attn = c[f'{side}_attention_mask']

    # Prompt region
    prompt_region = labels[:, :prompt_len_padded]
    prompt_mask_ok = (prompt_region == -100).all().item()

    # Padding region
    pad_mask_ok = (labels[attn == 0] == -100).all().item()

    # Completion valid tokens (non-masked)
    valid = (labels != -100).sum(dim=1)

    print(f'[{side}] prompt_mask_ok={prompt_mask_ok} pad_mask_ok={pad_mask_ok} valid_tokens={valid.tolist()}')
    assert prompt_mask_ok, f'{side}: prompt tokens not fully masked'
    assert pad_mask_ok, f'{side}: padding tokens not fully masked'
    assert (valid > 0).all().item(), f'{side}: all completion tokens masked; log_prob will be 0'

print('Sanity Check: label masking invariants passed')


[chosen] prompt_mask_ok=True pad_mask_ok=True valid_tokens=[7, 17]
[rejected] prompt_mask_ok=True pad_mask_ok=True valid_tokens=[27, 43]
Sanity Check: label masking invariants passed


## Deep Inspection: concatenation + label masking (raw tokens)

This section **prints the raw, decoded text** for the prompt, chosen, and rejected sequences, then inspects the **concatenated input ids** and **labels mask** position‑by‑position.

**What to verify**
- The **prompt** portion is **fully masked** in labels (`-100`).
- The **completion** portion has **non‑masked labels** (otherwise `compute_log_prob` becomes 0).
- **Padding** positions are masked (`-100`) and have `attention_mask == 0`.
- The decoded concatenation is **prompt + completion** (no accidental truncation or interleaving).


In [87]:
# 28.1b) Raw inspection: decode prompt/choices + visualize label mask

# Default is SAFE preview (no huge dumps). Flip to True only if needed.
PRINT_FULL_DECODED_TEXT = False
PREVIEW_HEAD_CHARS = 400
PREVIEW_TAIL_CHARS = 400

sample_idx = 0

prompt_ids = batch_prepared['prompt_input_ids'][sample_idx]
prompt_mask = batch_prepared['prompt_attention_mask'][sample_idx]

chosen_ids = batch_prepared['chosen_input_ids'][sample_idx]
chosen_mask = batch_prepared['chosen_attention_mask'][sample_idx]

rejected_ids = batch_prepared['rejected_input_ids'][sample_idx]
rejected_mask = batch_prepared['rejected_attention_mask'][sample_idx]

chosen_concat_ids = c['chosen_input_ids'][sample_idx]
chosen_concat_mask = c['chosen_attention_mask'][sample_idx]
chosen_labels = c['chosen_labels'][sample_idx]

rejected_concat_ids = c['rejected_input_ids'][sample_idx]
rejected_concat_mask = c['rejected_attention_mask'][sample_idx]
rejected_labels = c['rejected_labels'][sample_idx]

prompt_len = int(prompt_mask.sum().item())
chosen_len = int(chosen_mask.sum().item())
rejected_len = int(rejected_mask.sum().item())

print('Sanity Check: lengths')
print('  prompt_len:', prompt_len, 'prompt_padded:', int(prompt_ids.shape[0]))
print('  chosen_len:', chosen_len, 'chosen_padded:', int(chosen_ids.shape[0]))
print('  rejected_len:', rejected_len, 'rejected_padded:', int(rejected_ids.shape[0]))
print('  chosen_concat_len:', int(chosen_concat_mask.sum().item()), 'padded:', int(chosen_concat_ids.shape[0]))
print('  rejected_concat_len:', int(rejected_concat_mask.sum().item()), 'padded:', int(rejected_concat_ids.shape[0]))


def preview(text: str) -> str:
    if PRINT_FULL_DECODED_TEXT:
        return text
    if len(text) <= PREVIEW_HEAD_CHARS + PREVIEW_TAIL_CHARS + 40:
        return text
    head = text[:PREVIEW_HEAD_CHARS]
    tail = text[-PREVIEW_TAIL_CHARS:]
    return head + "\n...<snip>...\n" + tail


prompt_text = tok.decode(prompt_ids[:prompt_len].tolist(), skip_special_tokens=False)
chosen_text = tok.decode(chosen_ids[:chosen_len].tolist(), skip_special_tokens=False)
rejected_text = tok.decode(rejected_ids[:rejected_len].tolist(), skip_special_tokens=False)

print()
print('Sanity Check: decoded prompt (preview)')
print(preview(prompt_text))

print()
print('Sanity Check: decoded chosen completion (preview)')
print(preview(chosen_text))

print()
print('Sanity Check: decoded rejected completion (preview)')
print(preview(rejected_text))

chosen_concat_text = tok.decode(chosen_concat_ids[chosen_concat_mask == 1].tolist(), skip_special_tokens=False)
rejected_concat_text = tok.decode(rejected_concat_ids[rejected_concat_mask == 1].tolist(), skip_special_tokens=False)

print()
print('Sanity Check: decoded chosen concat (preview)')
print(preview(chosen_concat_text))

print()
print('Sanity Check: decoded rejected concat (preview)')
print(preview(rejected_concat_text))


def mask_string(attn: torch.Tensor, labels: torch.Tensor) -> str:
    chars = []
    for a, y in zip(attn.tolist(), labels.tolist()):
        if a == 0:
            chars.append('0')
        elif y == -100:
            chars.append('P')
        else:
            chars.append('C')
    return ''.join(chars)

chosen_mask_str = mask_string(chosen_concat_mask, chosen_labels)
rejected_mask_str = mask_string(rejected_concat_mask, rejected_labels)

print()
print('Sanity Check: mask legend')
print('  P = prompt masked (-100)')
print('  C = completion token (label used)')
print('  0 = padding (attention_mask==0)')

print()
print('Sanity Check: chosen mask string (first 200)')
print(chosen_mask_str[:200])

print()
print('Sanity Check: rejected mask string (first 200)')
print(rejected_mask_str[:200])

# Compact token/label table for first N positions
N = min(8000, int(chosen_concat_ids.shape[0]))

print()
print('Sanity Check: first tokens + labels (chosen)')
for i in range(N):
    tok_id = int(chosen_concat_ids[i].item())
    tok_txt = tok.decode([tok_id], skip_special_tokens=False).replace("\n", "\n")
    lbl = int(chosen_labels[i].item())
    att = int(chosen_concat_mask[i].item())
    print(f'{i:03d}  id={tok_id:>6}  attn={att}  label={lbl:>5}  tok={tok_txt}')


Sanity Check: lengths
  prompt_len: 51 prompt_padded: 106
  chosen_len: 7 chosen_padded: 17
  rejected_len: 27 rejected_padded: 43
  chosen_concat_len: 58 padded: 123
  rejected_concat_len: 78 padded: 149

Sanity Check: decoded prompt (preview)
<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|>

Sanity Check: decoded chosen completion (preview)
What do you mean?<|eot_id|><|eot_id|>

Sanity Check: decoded rejected completion (preview)
Do you want me to help you in your decision making process, or do you want me to make this decision fo

## Root-causing duplicated `<|eot_id|>`

When you decode with `skip_special_tokens=False`, you will see every special token the model will actually receive.

A duplicated `<|eot_id|><|eot_id|>` at the end of the concatenated sequence usually comes from one of these:

1) **EOS is already present in the completion string** (as literal text or already-templated output), *and* tokenization/apply_chat_template adds EOS again.

2) **Chat template applied twice**:
   - Your dataset already contains Llama-3 chat tokens like `<|start_header_id|>` / `<|eot_id|>`, and you run `apply_chat_template_to_dataset()` again.

3) **Tokenization adds EOS to both prompt and completion** (depends on the collator/tokenizer settings), so concatenation ends with two EOS tokens.

The next cell checks which one is happening by inspecting **raw strings** and **token ids**.


In [88]:
# 28.1c) Diagnose: where does the extra <|eot_id|> come from?

# This cell does not print the full raw text (to avoid dumping unsafe content).
# It focuses on token-level structure.

example = train_ds[0]

prompt_str = str(example.get('prompt', ''))
chosen_str = str(example.get('chosen', ''))
rejected_str = str(example.get('rejected', ''))

print('Sanity Check: raw string contains chat tokens?')
for name, s in [('prompt', prompt_str), ('chosen', chosen_str), ('rejected', rejected_str)]:
    has_start = '<|start_header_id|>' in s
    has_eot_text = '<|eot_id|>' in s
    print(f'  {name}: has_<|start_header_id|>={has_start} has_literal_<|eot_id|>={has_eot_text} len={len(s)}')

print('Sanity Check: tokenizer special tokens')
print('  eos_token:', repr(tok.eos_token), 'eos_token_id:', tok.eos_token_id)
print('  pad_token:', repr(tok.pad_token), 'pad_token_id:', tok.pad_token_id)

# Tokenize (string -> ids) for comparison. TRL's collator may differ.
enc_prompt = tok(prompt_str, add_special_tokens=False)
enc_chosen = tok(chosen_str, add_special_tokens=False)

prompt_ids_text = torch.tensor(enc_prompt['input_ids'], dtype=torch.long)
chosen_ids_text = torch.tensor(enc_chosen['input_ids'], dtype=torch.long)


def eos_positions(ids: torch.Tensor) -> list[int]:
    if tok.eos_token_id is None:
        return []
    return (ids == tok.eos_token_id).nonzero(as_tuple=False).view(-1).tolist()

print('Sanity Check: EOS token occurrences (string-tokenized, add_special_tokens=False)')
print('  prompt eos positions (last 10):', eos_positions(prompt_ids_text)[-10:])
print('  chosen eos positions (last 10):', eos_positions(chosen_ids_text)[-10:])

# Now inspect what the trainer actually uses (collated + concatenated tensors)

def tail(ids: torch.Tensor, n: int = 20) -> list[int]:
    n = min(n, int(ids.numel()))
    return [int(x) for x in ids[-n:].tolist()]

sample_idx = 0
chosen_concat_ids = c['chosen_input_ids'][sample_idx]
chosen_concat_mask = c['chosen_attention_mask'][sample_idx]

chosen_trim = chosen_concat_ids[chosen_concat_mask == 1]

print('Sanity Check: EOS at the very end? (trainer inputs)')
print('  chosen_concat_tail_ids:', tail(chosen_trim, 30))
if tok.eos_token_id is not None:
    last_two = chosen_trim[-2:].tolist() if chosen_trim.numel() >= 2 else chosen_trim.tolist()
    print('  chosen_concat_last_two_equal_eos?:', [int(x) == int(tok.eos_token_id) for x in last_two])

chosen_tail_text = tok.decode(chosen_trim[-80:].tolist(), skip_special_tokens=False)
print('Sanity Check: decoded tail (last ~80 tokens)')
print(chosen_tail_text[-800:])

print()
print('Interpretation guide:')
print('  - If raw chosen/rejected strings contain literal "<|eot_id|>", you are double-encoding template tokens.')
print('  - If chosen_trim ends with [eos, eos], tokenization/collation appended EOS twice (or EOS was in-text).')
print('  - If prompt has chat tokens already, do NOT run apply_chat_template_to_dataset() again.')


Sanity Check: raw string contains chat tokens?
  prompt: has_<|start_header_id|>=True has_literal_<|eot_id|>=True len=258
  chosen: has_<|start_header_id|>=False has_literal_<|eot_id|>=True len=59
  rejected: has_<|start_header_id|>=False has_literal_<|eot_id|>=True len=43
Sanity Check: tokenizer special tokens
  eos_token: '<|eot_id|>' eos_token_id: 128009
  pad_token: '<|eot_id|>' pad_token_id: 128009
Sanity Check: EOS token occurrences (string-tokenized, add_special_tokens=False)
  prompt eos positions (last 10): [25, 37]
  chosen eos positions (last 10): [14]
Sanity Check: EOS at the very end? (trainer inputs)
  chosen_concat_tail_ids: [128007, 271, 12978, 757, 311, 10491, 11, 1288, 358, 13054, 264, 1579, 2978, 477, 323, 36256, 2978, 30, 128009, 128006, 78191, 128007, 271, 3923, 656, 499, 3152, 30, 128009, 128009]
  chosen_concat_last_two_equal_eos?: [True, True]
Sanity Check: decoded tail (last ~80 tokens)
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowle

In [69]:
# 28.2) Forward pass (policy + reference) -> logits

trainer.model.train()

logits = trainer._forward_policy_and_ref(trainer.model, c)

print('Sanity Check: logits')
for k, v in logits.items():
    print(f'  {k}: shape={tuple(v.shape)} dtype={v.dtype} device={v.device}')


Sanity Check: logits
  policy_chosen_logits: shape=(2, 123, 128256) dtype=torch.bfloat16 device=cuda:0
  policy_rejected_logits: shape=(2, 149, 128256) dtype=torch.bfloat16 device=cuda:0
  ref_chosen_logits: shape=(2, 123, 128256) dtype=torch.float32 device=cuda:0
  ref_rejected_logits: shape=(2, 149, 128256) dtype=torch.float32 device=cuda:0


In [70]:
# 28.2a) Audit: grad flow (policy logits require grad; ref logits do not)

# Fallback in case the debug helper cell was not run.
if 'summarize_tensor' not in globals():
    def summarize_tensor(name: str, t: torch.Tensor) -> None:
        t_detached = t.detach()
        stats = {
            'shape': tuple(t_detached.shape),
            'dtype': str(t_detached.dtype),
            'device': str(t_detached.device),
            'min': float(t_detached.min().float().cpu()) if t_detached.numel() else None,
            'max': float(t_detached.max().float().cpu()) if t_detached.numel() else None,
            'mean': float(t_detached.float().mean().cpu()) if t_detached.numel() else None,
            'std': float(t_detached.float().std(unbiased=False).cpu()) if t_detached.numel() else None,
            'requires_grad': bool(t.requires_grad),
        }
        print(f'{name}:', stats)

for k, v in logits.items():
    summarize_tensor(k, v)

assert logits['policy_chosen_logits'].requires_grad, 'policy_chosen_logits should require grad'
assert logits['policy_rejected_logits'].requires_grad, 'policy_rejected_logits should require grad'
assert not logits['ref_chosen_logits'].requires_grad, 'ref_chosen_logits should NOT require grad'
assert not logits['ref_rejected_logits'].requires_grad, 'ref_rejected_logits should NOT require grad'

print('Sanity Check: grad flow looks correct')


policy_chosen_logits: {'shape': (2, 123, 128256), 'dtype': 'torch.bfloat16', 'device': 'cuda:0', 'min': -29.125, 'max': 41.25, 'mean': 0.4087529480457306, 'std': 2.8631155490875244, 'requires_grad': True}
policy_rejected_logits: {'shape': (2, 149, 128256), 'dtype': 'torch.bfloat16', 'device': 'cuda:0', 'min': -28.625, 'max': 41.25, 'mean': 0.7949432134628296, 'std': 2.940920114517212, 'requires_grad': True}
ref_chosen_logits: {'shape': (2, 123, 128256), 'dtype': 'torch.float32', 'device': 'cuda:0', 'min': -29.125, 'max': 41.25, 'mean': 0.4087529480457306, 'std': 2.8631155490875244, 'requires_grad': False}
ref_rejected_logits: {'shape': (2, 149, 128256), 'dtype': 'torch.float32', 'device': 'cuda:0', 'min': -28.625, 'max': 41.25, 'mean': 0.7949432134628296, 'std': 2.940920114517212, 'requires_grad': False}
Sanity Check: grad flow looks correct


In [72]:
# 28.3) Log-probabilities for chosen/rejected under policy/ref

lp = trainer._compute_log_probs(logits, c)

print('Sanity Check: log_probs')
for k, v in lp.items():
    _v = v.detach()
    print(f'  {k}: shape={tuple(_v.shape)} mean={float(_v.mean()):.4f} min={float(_v.min()):.4f} max={float(_v.max()):.4f}')

# Optional: full debug print (matches trainer's internal debug)
if trainer._warmup_count <= 3 or trainer._warmup_count % 50 == 0:
    trainer._debug_print_log_probs(lp)


Sanity Check: log_probs
  policy_chosen_log_prob: shape=(2,) mean=-53.7500 min=-61.2500 max=-46.2500
  policy_rejected_log_prob: shape=(2,) mean=-97.5000 min=-119.5000 max=-75.5000
  ref_chosen_log_prob: shape=(2,) mean=-53.7080 min=-61.2424 max=-46.1736
  ref_rejected_log_prob: shape=(2,) mean=-97.4445 min=-119.3337 max=-75.5553

[DEBUG Step 0]
  policy_chosen_log_prob: -53.7500 (should be large negative)
  policy_rejected_log_prob: -97.5000
  ref_chosen_log_prob: -53.7080
  ref_rejected_log_prob: -97.4445
  chosen_ratio (policy-ref): -0.0420 (should be ~0 at start)
  rejected_ratio (policy-ref): -0.0555 (should be ~0 at start)
  margin (chosen-rejected ratio): 0.0135
  margin min/max: -0.1318 / 0.1587
  expected loss at margin=0: 0.6930 (log(2))


In [74]:
# 28.3a) Audit: token-level logprob sum matches compute_log_prob

# Fallbacks in case the debug helper cell was not run.
if 'assert_close' not in globals():
    def assert_close(name: str, a: torch.Tensor, b: torch.Tensor, *, atol: float = 1e-5, rtol: float = 1e-5) -> None:
        diff = (a - b).detach().abs()
        denom = (b.detach().abs() + 1e-12)
        rel = diff / denom
        max_abs = float(diff.max().cpu()) if diff.numel() else 0.0
        max_rel = float(rel.max().cpu()) if rel.numel() else 0.0
        assert (diff <= atol).all().item() or (rel <= rtol).all().item(), (
            f'{name} not close: max_abs={max_abs:.3e}, max_rel={max_rel:.3e} (atol={atol}, rtol={rtol})'
        )

if 'assert_all_finite' not in globals():
    def assert_all_finite(name: str, t: torch.Tensor) -> None:
        ok = torch.isfinite(t.detach()).all().item() if t.numel() else True
        assert ok, f'{name} contains NaN/Inf'

# compute_log_prob returns sum over non-masked token positions.

def token_level_logprob_sum(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Match compute_log_prob shifting
    logits_s = logits[:, :-1, :]
    labels_s = labels[:, 1:].clone()

    mask = labels_s != -100
    labels_s[labels_s == -100] = 0

    per_tok = torch.gather(logits_s.log_softmax(-1), dim=2, index=labels_s.unsqueeze(2)).squeeze(2)
    return (per_tok * mask).sum(-1)

for side in ['chosen', 'rejected']:
    pol_sum = token_level_logprob_sum(logits[f'policy_{side}_logits'], c[f'{side}_labels'])
    ref_sum = token_level_logprob_sum(logits[f'ref_{side}_logits'], c[f'{side}_labels'])

    assert_close(f'policy_{side}_log_prob', pol_sum, lp[f'policy_{side}_log_prob'])
    assert_close(f'ref_{side}_log_prob', ref_sum, lp[f'ref_{side}_log_prob'])

    assert_all_finite(f'policy_{side}_log_prob', lp[f'policy_{side}_log_prob'])
    assert_all_finite(f'ref_{side}_log_prob', lp[f'ref_{side}_log_prob'])

print('Sanity Check: compute_log_prob matches token-level sum')


Sanity Check: compute_log_prob matches token-level sum


In [None]:
# 28.4) DPO loss + rewards (uses current trainer.beta)

loss_ten, chosen_rewards, rejected_rewards = trainer._compute_dpo_loss(lp)
loss = loss_ten.mean()

print('Sanity Check: loss/rewards')
print('  beta:', trainer.beta)
print('  loss_mean:', float(loss.detach().float().cpu()))
print('  chosen_rewards.mean:', float(chosen_rewards.mean().detach().float().cpu()))
print('  rejected_rewards.mean:', float(rejected_rewards.mean().detach().float().cpu()))


In [76]:
# 28.4a) Audit: DPO loss matches manual formula

# Fallback in case 28.4 wasn't run.
if 'loss_ten' not in globals() or 'chosen_rewards' not in globals() or 'rejected_rewards' not in globals():
    loss_ten, chosen_rewards, rejected_rewards = trainer._compute_dpo_loss(lp)

chosen_ratio = lp['policy_chosen_log_prob'] - lp['ref_chosen_log_prob']
rejected_ratio = lp['policy_rejected_log_prob'] - lp['ref_rejected_log_prob']
margin = chosen_ratio - rejected_ratio

manual = -F.logsigmoid(torch.tensor(float(trainer.beta), device=margin.device) * margin)

# Use fallback helper if needed
if 'assert_close' not in globals():
    def assert_close(name: str, a: torch.Tensor, b: torch.Tensor, *, atol: float = 1e-5, rtol: float = 1e-5) -> None:
        diff = (a - b).detach().abs()
        denom = (b.detach().abs() + 1e-12)
        rel = diff / denom
        max_abs = float(diff.max().cpu()) if diff.numel() else 0.0
        max_rel = float(rel.max().cpu()) if rel.numel() else 0.0
        assert (diff <= atol).all().item() or (rel <= rtol).all().item(), (
            f'{name} not close: max_abs={max_abs:.3e}, max_rel={max_rel:.3e} (atol={atol}, rtol={rtol})'
        )

if 'assert_all_finite' not in globals():
    def assert_all_finite(name: str, t: torch.Tensor) -> None:
        ok = torch.isfinite(t.detach()).all().item() if t.numel() else True
        assert ok, f'{name} contains NaN/Inf'

assert_close('dpo_loss_per_sample', manual, loss_ten)
assert_all_finite('dpo_loss_per_sample', loss_ten)

print('Sanity Check: DPO loss matches manual computation')


Sanity Check: DPO loss matches manual computation


In [77]:
# 28.4b) Audit: backward() + gradient sanity

trainer.model.zero_grad(set_to_none=True)
loss.backward()

# Grab a few gradient norms to ensure training signal exists and is finite.

grad_norms = []
for name, p in trainer.model.named_parameters():
    if p.grad is None:
        continue
    g = p.grad.detach()
    if g.numel() == 0:
        continue
    grad_norms.append((name, float(g.float().norm().cpu())))

# Sort descending and show top-k

grad_norms.sort(key=lambda x: x[1], reverse=True)
print('Sanity Check: gradients')
print('  num_params_with_grad:', len(grad_norms))
print('  top5:', grad_norms[:5])

assert len(grad_norms) > 0, 'No gradients found on policy model'
assert all(math.isfinite(v) for _, v in grad_norms), 'Non-finite gradient norm detected'


Sanity Check: gradients
  num_params_with_grad: 146
  top5: [('model.layers.1.mlp.down_proj.weight', 291.70391845703125), ('model.layers.0.self_attn.v_proj.weight', 81.9301528930664), ('model.embed_tokens.weight', 79.6904525756836), ('model.layers.13.self_attn.q_proj.weight', 76.76607513427734), ('model.layers.0.mlp.down_proj.weight', 70.7591552734375)]


In [78]:
# 28.5) Margin computation (policy-vs-ref preference gap)

model_margin = trainer._compute_margin(lp)

print('Sanity Check: margin')
print('  shape:', tuple(model_margin.shape))
print('  mean/std:', float(model_margin.mean()), float(model_margin.std(unbiased=False)))
print('  min/max:', float(model_margin.min()), float(model_margin.max()))


Sanity Check: margin
  shape: (2,)
  mean/std: 0.013483047485351562 0.1452465057373047
  min/max: -0.13176345825195312 0.15872955322265625


In [79]:
# 28.5a) Audit: margin matches expected definition

policy_diff = lp['policy_chosen_log_prob'] - lp['policy_rejected_log_prob']
ref_diff = lp['ref_chosen_log_prob'] - lp['ref_rejected_log_prob']
manual_margin = (policy_diff - ref_diff).detach()

assert_close('model_margin', manual_margin, model_margin)
assert_all_finite('model_margin', model_margin)

print('Sanity Check: margin definition matches')


Sanity Check: margin definition matches


In [None]:
# 28.6) Dynamic-beta state update (warmup -> tau/EMA -> beta update)

trainer._dynamic_beta_update(model_margin=model_margin, loss=loss)

print('Sanity Check: dynamic-beta state')
print('  warmup_count:', trainer._warmup_count, 'warmup_done:', trainer._warmup_done)
print('  beta:', trainer.beta)
print('  tau:', trainer.tau)
print('  last_stats:', trainer._last_stats)


In [None]:
# 28.6a) Audit: dynamic-beta update semantics

print('Sanity Check: beta/tau semantics')
print('  delta:', trainer.dynamic_cfg.delta)
print('  warmup_steps:', trainer.dynamic_cfg.warmup_steps)
print('  warmup_done:', trainer._warmup_done)
print('  beta:', trainer.beta)
print('  tau:', trainer.tau)

# IMPORTANT NOTE (potential bug/ambiguity in repo naming):
# risk_test() returns (p_hat > delta). In the original code, the variable is named `fail`.
# That name is confusing: `fail=1` actually means p_hat exceeded delta.

if trainer._warmup_done:
    p_hat_now = empirical_over_threshold_proportion(model_margin, trainer.tau)
    beta_now = trainer.beta
    print('  p_hat_now:', p_hat_now)
    print('  risk_test(p_hat_now, delta)=', risk_test(p_hat_now, trainer.dynamic_cfg.delta))
    # Directional expectation (not a strict assertion because beta is clipped and uses tanh)
    print('  Interpretation: if p_hat_now > delta, update_beta should push beta up (until beta_max)')

print('Sanity Check: dynamic-beta audit completed')


In [38]:
# 29) Dry-run multiple batches to see warmup -> beta updates (no optimizer step)

steps = int(dyn_cfg.warmup_steps) + 3

it = iter(trainer.get_train_dataloader())
for i in range(steps):
    b = next(it)
    b = trainer._prepare_inputs(b)
    loss = trainer.compute_loss(trainer.model, b)
    print(
        f"step={i:02d} loss={float(loss.detach().float().cpu()):.4f} beta={trainer.beta:.4f} tau={trainer.tau:.4f} warmup_done={trainer._warmup_done}"
    )

print("Sanity Check: dry-run completed")



[DEBUG Step 1] Tensor shapes before forward pass:
  Batch size: 2
  --- Input from batch ---
  prompt_input_ids: padded=106, actual=78.5 (min=51, max=106)
  chosen_completion: padded=17, actual=12.0 (min=7, max=17)
  rejected_completion: padded=43, actual=35.0 (min=27, max=43)
  --- After concatenation ---
  chosen_concat: padded=123, actual=90.5, valid_for_loss=12.0 (min=7, max=17)
  rejected_concat: padded=149, actual=113.5, valid_for_loss=35.0 (min=27, max=43)

[DEBUG Step 1]
  policy_chosen_log_prob: -53.7500 (should be large negative)
  policy_rejected_log_prob: -97.5000
  ref_chosen_log_prob: -53.7080
  ref_rejected_log_prob: -97.4445
  chosen_ratio (policy-ref): -0.0420 (should be ~0 at start)
  rejected_ratio (policy-ref): -0.0555 (should be ~0 at start)
  margin (chosen-rejected ratio): 0.0135
  margin min/max: -0.1318 / 0.1587
  expected loss at margin=0: 0.6930 (log(2))
step=00 loss=0.6925 beta=0.1000 tau=0.0000 warmup_done=False

[DEBUG Step 2] Tensor shapes before forward

In [80]:
# 30) Optional: run the actual Trainer training loop

if RUN_FULL_TRAIN:
    trainer.train()
    print("Sanity Check: trainer.train() finished")
else:
    print("Sanity Check: skipped trainer.train(); set RUN_FULL_TRAIN=True to run")


Sanity Check: skipped trainer.train(); set RUN_FULL_TRAIN=True to run
