# Inspect DPO label building, masking, and tensor specs (from training code)

Goal: reproduce the **exact** label/mask building used in training (TRL tokenization + collator + our trainer's `_concatenate_and_build_labels`) and dump **raw records** (IDs/masks/labels + token strings) for inspection.

This notebook does **not** run training. It only builds one (or a few) batches and inspects them.


In [3]:
# 0) Imports + repo path setup
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch


# Locate the repo root (the directory containing `src/`): either the current
# working directory itself or its direct parent.
repo_root = Path.cwd().resolve()
if not (repo_root / "src").exists() and (repo_root.parent / "src").exists():
    repo_root = repo_root.parent
if not (repo_root / "src").exists():
    # Explicit raise instead of `assert` so the check survives `python -O`.
    raise FileNotFoundError(f"Could not find repo root from cwd={Path.cwd()}")

import sys

# Make `src.*` importable with the repo root taking precedence. Only insert when
# it is not already first, so re-running this cell does not keep growing sys.path.
if sys.path[:1] != [str(repo_root)]:
    sys.path.insert(0, str(repo_root))

print("repo_root:", repo_root)
print("torch:", torch.__version__)


repo_root: /home/feng/github/dynamic-dpo-v1
torch: 2.9.1+cu128


In [4]:
# 1) Load config + tokenizer (same setup as src/cli.py)
#
# NOTE: For large models, make sure they are cached or you are logged into HF.

from src.config.loader import load_yaml
from transformers import AutoTokenizer

CONFIG_PATH = repo_root / "config_dpo.yaml"
config = load_yaml(str(CONFIG_PATH))

# Policy checkpoint comes from config (defaulting to Llama; avoid gpt2 in this
# notebook). The reference model falls back to the policy when `ref_name` is unset.
POLICY_NAME = config.get("policy_name", "meta-llama/Llama-3.2-1B-Instruct")
REF_NAME = config.get("ref_name", POLICY_NAME)

tok = AutoTokenizer.from_pretrained(POLICY_NAME, use_fast=True)
# Tokenizers without a pad token reuse EOS for padding, which makes
# pad_token_id == eos_token_id from here on.
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

for label, value in (("policy_name:", POLICY_NAME), ("ref_name:", REF_NAME)):
    print(label, value)
print("pad_token_id:", tok.pad_token_id, "eos_token_id:", tok.eos_token_id)


policy_name: meta-llama/Llama-3.2-1B-Instruct
ref_name: meta-llama/Llama-3.2-1B-Instruct
pad_token_id: 128009 eos_token_id: 128009


In [5]:
# 2) Build a tiny preference dataset (prompt/chosen/rejected)
#
# By default we use a tiny in-memory dataset so you can inspect mechanics without HF dataset downloads.
# Flip USE_HF_DATASET=True to load the real HH dataset (requires network/cached dataset).

USE_HF_DATASET = False

rows = [
    {
        "prompt": "\n\nHuman: Give me a one-sentence tip to focus at work.\n\nAssistant:",
        "chosen": " Try the Pomodoro method: 25 minutes focused, 5 minutes break.",
        "rejected": " You should never take breaks; just work nonstop.",
    },
    {
        "prompt": "\n\nHuman: What is 2+2?\n\nAssistant:",
        "chosen": " 4.",
        "rejected": " 22.",
    },
]

if USE_HF_DATASET:
    from src.data.hh_dataset import build_HH_dataset, apply_chat_template_to_dataset

    dataset_cfg = config.get("dataset", {})

    if bool(dataset_cfg.get("generated_data", False)):
        from src.data.hh_dataset import load_generated_dataset_from_config

        # Generated data is loaded from config alone; the raw HH split is not
        # needed, so don't download it.
        hh_ds = load_generated_dataset_from_config(config)
    else:
        from datasets import load_dataset

        # `subset` is passed as the `split` spec to load_dataset (e.g. "train[:1%]").
        raw_ds = load_dataset(dataset_cfg["dataset_name"], split=dataset_cfg.get("subset", "train[:1%]"))
        hh_ds = build_HH_dataset(raw_ds)

    if bool(dataset_cfg.get("chat_template", False)):
        hh_ds = apply_chat_template_to_dataset(hh_ds, tok)

    # Keep at most the first 4 examples for inspection.
    rows = [hh_ds[i] for i in range(min(4, len(hh_ds)))]

from datasets import Dataset

ds = Dataset.from_list(rows)


In [6]:
# 3) Build the DynamicBetaDPOTrainer (same pipeline as CLI)
#    We won't train; we only use its dataloader and collator.

from transformers import AutoModelForCausalLM
from trl import DPOConfig
from src.trainers.dynamic_beta_dpo import DynamicBetaDPOConfig, DynamicBetaDPOTrainer

# Split train/val like CLI (for completeness).
# NOTE: train_test_split shuffles before splitting by default, so the order of
# train_ds -- and therefore of the batches below -- does not follow `rows`.
from datasets import Dataset

dataset_cfg = config.get("dataset", {})
val_ratio = float(dataset_cfg.get("val_ratio", 0.1))
seed = int(dataset_cfg.get("seed", 0))

split = ds.train_test_split(test_size=val_ratio, seed=seed)
train_ds, eval_ds = split["train"], split["test"]

# Policy + frozen reference model (same as CLI).
policy = AutoModelForCausalLM.from_pretrained(POLICY_NAME)
ref_model = AutoModelForCausalLM.from_pretrained(REF_NAME)
ref_model.eval()
for ref_param in ref_model.parameters():
    ref_param.requires_grad_(False)

# Mixed-precision flags: "fp16" -> fp16, "bf16" -> bf16, anything else -> neither.
prec = str(config.get("precision", "fp16")).lower()
fp16, bf16 = prec == "fp16", prec == "bf16"

# Training args mirroring the CLI defaults. Optimizer/schedule/logging values come
# from the `dpo_training` section, truncation limits from the `dataset` section.
train_cfg = config.get("dpo_training", {})
training_args = DPOConfig(
    learning_rate=float(train_cfg.get("learning_rate", 5e-6)),
    per_device_train_batch_size=int(train_cfg.get("batch_size", 2)),
    per_device_eval_batch_size=int(train_cfg.get("eval_batch_size", 2)),
    num_train_epochs=int(train_cfg.get("epochs", 1)),
    logging_steps=int(train_cfg.get("log_steps", 10)),
    eval_strategy="steps",
    eval_steps=int(train_cfg.get("eval_steps", 50)),
    save_strategy="steps",
    save_steps=int(train_cfg.get("save_steps", 50)),
    fp16=fp16,
    bf16=bf16,
    gradient_accumulation_steps=int(train_cfg.get("gradient_accumulation", 1)),
    max_grad_norm=float(train_cfg.get("max_grad_norm", 1.0)),
    warmup_steps=int(train_cfg.get("warmup_steps", 0)),
    report_to=["wandb"] if train_cfg.get("report") else [],
    run_name=str(train_cfg.get("run_name", "debug")),
    # Keep all dataset columns (the custom trainer/collator decides what to use).
    remove_unused_columns=False,
    output_dir=str(train_cfg.get("save_dir", "trl_dynamic_beta_dpo")),
    # --- truncation settings (from the dataset section) ---
    max_prompt_length=int(dataset_cfg.get("max_prompt_length", 256)),
    max_completion_length=int(dataset_cfg.get("max_completion_length", 256)),
    max_length=int(dataset_cfg.get("max_length", 1024)),
    truncation_mode=str(dataset_cfg.get("truncation_mode", "keep_end")),
)

# Dynamic-beta config, assembled from three config sections.
risk = config.get("risk_test", {})
beta_up = config.get("beta_update", {})
margin_log = config.get("margin_log", {})

dyn_cfg = DynamicBetaDPOConfig(
    # --- risk test (`risk_test`) ---
    delta=float(risk.get("delta", 0.1)),
    momentum=float(risk.get("lambda", 0.05)),
    warmup_steps=int(risk.get("beta_warmup", 120)),
    # --- beta update rule (`beta_update`) ---
    beta_0=float(beta_up.get("beta_0", 0.1)),
    alpha=float(beta_up.get("alpha", 0.005)),
    gamma=float(beta_up.get("gamma", 2.0)),
    beta_min=float(beta_up.get("beta_min", 0.0)),
    beta_max=float(beta_up.get("beta_max", 2.0)),
    # --- margin logging (`margin_log`) ---
    log_margins=bool(margin_log.get("log_margins", True)),
    log_dir=str(margin_log.get("log_dir", "logs/margins")),
    jsonl_sample_size=int(margin_log.get("jsonl_sample_size", 32)),
    save_per_rank=bool(margin_log.get("save_per_rank", False)),
)

trainer = DynamicBetaDPOTrainer(
    model=policy,
    ref_model=ref_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    dynamic_cfg=dyn_cfg,
    processing_class=tok,
)

print("trainer ready; train_dataset size:", len(train_ds))


: 

In [None]:
# 4) Pull one collated batch from the trainer's train dataloader

train_loader = trainer.get_train_dataloader()
batch = next(iter(train_loader))

# Report shape/dtype for every tensor-like entry; entries without `.shape` are skipped.
for key, value in ((k, v) for k, v in batch.items() if hasattr(v, "shape")):
    print(f"{key:>24s}", "shape=", tuple(value.shape), "dtype=", value.dtype)

batch.keys()


        prompt_input_ids shape= (2, 17) dtype= torch.int64
   prompt_attention_mask shape= (2, 17) dtype= torch.int64
        chosen_input_ids shape= (2, 18) dtype= torch.int64
   chosen_attention_mask shape= (2, 18) dtype= torch.int64
      rejected_input_ids shape= (2, 12) dtype= torch.int64
 rejected_attention_mask shape= (2, 12) dtype= torch.int64


dict_keys(['prompt_input_ids', 'prompt_attention_mask', 'chosen_input_ids', 'chosen_attention_mask', 'rejected_input_ids', 'rejected_attention_mask'])

In [None]:
# 5) Build concatenated [prompt | completion] tensors + labels using the trainer method
#    (prompt tokens masked to -100; padding masked to -100)


def _concat_side(completion_ids, completion_mask):
    # Concatenate the shared batch prompt with one completion side and build its labels.
    return trainer._concatenate_and_build_labels(
        prompt_input_ids=batch["prompt_input_ids"],
        prompt_attention_mask=batch["prompt_attention_mask"],
        completion_input_ids=completion_ids,
        completion_attention_mask=completion_mask,
    )


chosen_input_ids, chosen_attention_mask, chosen_labels = _concat_side(
    batch["chosen_input_ids"], batch["chosen_attention_mask"]
)
rejected_input_ids, rejected_attention_mask, rejected_labels = _concat_side(
    batch["rejected_input_ids"], batch["rejected_attention_mask"]
)

for name, tensor in (
    ("chosen_input_ids", chosen_input_ids),
    ("chosen_labels", chosen_labels),
    ("rejected_input_ids", rejected_input_ids),
    ("rejected_labels", rejected_labels),
):
    print(name, tensor.shape)


chosen_input_ids torch.Size([2, 35])
chosen_labels torch.Size([2, 35])
rejected_input_ids torch.Size([2, 29])
rejected_labels torch.Size([2, 29])


In [None]:
# 6) Invariants: prompt is masked; padding is masked

def assert_label_invariants(
    prompt_input_ids: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    concat_attention_mask: torch.Tensor,
    labels: torch.Tensor,
) -> None:
    """Assert the label-masking contract of a concatenated [prompt | completion] batch.

    Checks that:
      * every label in the first ``prompt_input_ids.shape[1]`` columns is -100;
      * every label where ``concat_attention_mask == 0`` (padding) is -100;
      * every label at an attended position after the prompt columns is not -100.

    ``prompt_attention_mask`` is part of the signature but is not read.

    Raises:
        AssertionError: if any of the checks fails.
    """
    n_prompt_cols = prompt_input_ids.shape[1]

    prompt_region_masked = (labels[:, :n_prompt_cols] == -100).all().item()
    assert prompt_region_masked, "Prompt tokens should be all -100"

    is_padding = concat_attention_mask == 0
    assert (labels[is_padding] == -100).all().item(), "Padding tokens should be -100"

    # Attended positions to the right of the prompt columns are completion tokens.
    is_completion = concat_attention_mask == 1
    is_completion[:, :n_prompt_cols] = False
    if is_completion.any().item():
        assert (labels[is_completion] != -100).all().item(), "Completion tokens should be unmasked"


# Apply the same checks to both completion sides against the shared prompt.
for side_mask, side_labels in (
    (chosen_attention_mask, chosen_labels),
    (rejected_attention_mask, rejected_labels),
):
    assert_label_invariants(batch["prompt_input_ids"], batch["prompt_attention_mask"], side_mask, side_labels)
print("OK")


OK


In [None]:
# 7) Raw per-token inspection (IDs, tokens, attention_mask, label mask)

import pandas as pd


def _first_eos_index(ids, eos_id):
    if eos_id is None:
        return None
    for i, tid in enumerate(ids):
        if tid == eos_id:
            return i
    return None


def _tensor_to_list(value, idx):
    if value is None:
        return []
    if hasattr(value, "tolist"):
        if value.ndim == 1:
            return value.tolist()
        return value[idx].tolist()
    if isinstance(value, (list, tuple)):
        if not value:
            return []
        first = value[0]
        if isinstance(first, (list, tuple)):
            return list(value[idx])
        return list(value)
    return []


def token_table(
    *,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
    sample_idx: int,
    tok,
    trim_to_first_eos: bool = True,
) -> tuple[pd.DataFrame, list[int]]:
    """Build a per-token table (id, token string, attention, label) for one batch row.

    Args:
        input_ids: batched concatenated input IDs; row ``sample_idx`` is shown.
        attention_mask: matching attention mask (0 marks padding).
        labels: matching labels (-100 marks positions excluded from the loss).
        sample_idx: which batch row to tabulate.
        tok: tokenizer; only ``convert_ids_to_tokens`` and ``eos_token_id`` are used.
        trim_to_first_eos: if True, end the table at the first EOS token whose
            attention-mask value is non-zero. Padding positions are skipped on
            purpose. When the pad token is the EOS token (as for the Llama
            tokenizer loaded in this notebook), a padding token would otherwise
            count as "the first EOS" and cut the table off early. An attended EOS
            inside the prompt still ends the table.

    Returns:
        ``(df, eos_positions)``. ``df`` has columns pos, token_id, token, attn,
        label, label_masked and is_eos. ``eos_positions`` lists every EOS position
        in the untrimmed row, padding included.
    """
    ids_full = _tensor_to_list(input_ids, sample_idx)
    attn_full = _tensor_to_list(attention_mask, sample_idx)
    labs_full = _tensor_to_list(labels, sample_idx)
    toks_full = tok.convert_ids_to_tokens(ids_full)

    eos_id = tok.eos_token_id
    eos_positions = [i for i, tid in enumerate(ids_full) if tid == eos_id] if eos_id is not None else []

    # First EOS that is a real (attended) token rather than padding.
    first_eos = next((i for i in eos_positions if i < len(attn_full) and attn_full[i]), None)

    if trim_to_first_eos and first_eos is not None:
        end = first_eos + 1
        ids, attn, labs, toks = ids_full[:end], attn_full[:end], labs_full[:end], toks_full[:end]
    else:
        ids, attn, labs, toks = ids_full, attn_full, labs_full, toks_full

    df = pd.DataFrame(
        {
            "pos": list(range(len(ids))),
            "token_id": ids,
            "token": toks,
            "attn": attn,
            "label": labs,
            "label_masked": [x == -100 for x in labs],
            "is_eos": [x == eos_id for x in ids] if eos_id is not None else [False] * len(ids),
        }
    )
    return df, eos_positions


SAMPLE_IDX = 0

# `batch` is drawn from `train_ds`, which train_test_split shuffled (and the
# train dataloader may reorder it again), so batch row SAMPLE_IDX is generally
# NOT rows[SAMPLE_IDX]. Decode the text from the batch itself so it describes the
# same sample as the token tables below. Padding is removed via the attention
# mask; special tokens are kept so BOS/EOS placement is visible.
print("TEXT DECODED FROM BATCH")
for part, ids_key, mask_key in (
    ("prompt:", "prompt_input_ids", "prompt_attention_mask"),
    ("chosen:", "chosen_input_ids", "chosen_attention_mask"),
    ("rejected:", "rejected_input_ids", "rejected_attention_mask"),
):
    sample_ids = batch[ids_key][SAMPLE_IDX]
    sample_mask = batch[mask_key][SAMPLE_IDX].bool()
    print(part, tok.decode(sample_ids[sample_mask].tolist()))

chosen_df, chosen_eos_positions = token_table(
    input_ids=chosen_input_ids,
    attention_mask=chosen_attention_mask,
    labels=chosen_labels,
    sample_idx=SAMPLE_IDX,
    tok=tok,
)
rejected_df, rejected_eos_positions = token_table(
    input_ids=rejected_input_ids,
    attention_mask=rejected_attention_mask,
    labels=rejected_labels,
    sample_idx=SAMPLE_IDX,
    tok=tok,
)

print("eos_token:", tok.eos_token, "eos_token_id:", tok.eos_token_id)
print("chosen eos positions (full):", chosen_eos_positions)
print("rejected eos positions (full):", rejected_eos_positions)

display(chosen_df)
display(rejected_df)


In [None]:
# 8) Show what compute_log_prob will consider "valid" (labels shifted by 1)

from src.losses.dpo_loss import compute_log_prob


def build_loss_mask(labels: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of the label positions that contribute to the log-prob.

    Mirrors the shift in compute_log_prob(), which uses ``labels[:, 1:]``.
    Column ``t`` of the result refers to ``labels[:, t + 1]`` and is True unless
    that label is -100. The result has one fewer column than ``labels``.
    """
    next_token_labels = labels[:, 1:]
    return next_token_labels != -100


chosen_loss_mask = build_loss_mask(chosen_labels)
rejected_loss_mask = build_loss_mask(rejected_labels)

# Number of positions compute_log_prob() would score, summed over the whole batch.
n_chosen_valid = int(chosen_loss_mask.sum().item())
n_rejected_valid = int(rejected_loss_mask.sum().item())
print("chosen_loss_mask shape:", chosen_loss_mask.shape, "valid tokens:", n_chosen_valid)
print("rejected_loss_mask shape:", rejected_loss_mask.shape, "valid tokens:", n_rejected_valid)

# (Optional) You need logits to actually run compute_log_prob; this cell just mirrors its masking logic.


chosen_loss_mask shape: torch.Size([2, 34]) valid tokens: 22
rejected_loss_mask shape: torch.Size([2, 28]) valid tokens: 16


In [None]:
# 9) Dump raw records (IDs/masks/labels + token strings) to JSONL
#
# Default output is under /tmp so we don't create repo artifacts.


def dump_records_jsonl(path: Path, *, rows=None):
    """Write one JSON record per row of the current ``batch`` to ``path`` (JSON Lines).

    Reads the notebook globals ``tok``, ``batch``, ``chosen_input_ids``,
    ``chosen_attention_mask``, ``chosen_labels`` and the matching ``rejected_*``
    tensors. Each record contains:

    * ``raw``: prompt/chosen/rejected text decoded from the batch tokens, with
      padding and special tokens removed. It always describes the same sample
      as the tensors in the record, and approximates the original strings.
    * ``lengths``: padded vs. attended lengths, plus counts of labels != -100,
      both unshifted and after dropping position 0 (the shift used by
      ``compute_log_prob``).
    * IDs / attention masks of the prompt and both completions, and IDs /
      masks / labels / token strings of both concatenated sequences.

    Args:
        path: output file. Missing parent directories are created and an
            existing file is overwritten.
        rows: accepted for backward compatibility and ignored. ``batch`` comes
            from the shuffled ``train_test_split`` train split, so ``rows[i]``
            need not be batch row ``i``, and ``rows`` may be longer than the batch.
    """

    def _decode(ids: torch.Tensor, mask: torch.Tensor) -> str:
        # Keep attended positions only (drop padding) and strip special tokens.
        return tok.decode(ids[mask.bool()].tolist(), skip_special_tokens=True)

    # Padded widths are the same for every row of the batch, so compute them once.
    batch_size = int(batch["prompt_input_ids"].shape[0])
    prompt_padded_len = int(batch["prompt_input_ids"].shape[1])
    chosen_comp_padded_len = int(batch["chosen_input_ids"].shape[1])
    rejected_comp_padded_len = int(batch["rejected_input_ids"].shape[1])
    chosen_concat_padded_len = int(chosen_input_ids.shape[1])
    rejected_concat_padded_len = int(rejected_input_ids.shape[1])

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        # Iterate over the batch rows themselves: `rows` is neither aligned with
        # the batch order nor guaranteed to have the batch's length.
        for i in range(batch_size):
            rec = {
                "idx": i,
                "raw": {
                    "prompt": _decode(batch["prompt_input_ids"][i], batch["prompt_attention_mask"][i]),
                    "chosen": _decode(batch["chosen_input_ids"][i], batch["chosen_attention_mask"][i]),
                    "rejected": _decode(batch["rejected_input_ids"][i], batch["rejected_attention_mask"][i]),
                },
                "lengths": {
                    "prompt_padded_len": prompt_padded_len,
                    "prompt_actual_len": int(batch["prompt_attention_mask"][i].sum().item()),
                    "chosen_completion_padded_len": chosen_comp_padded_len,
                    "chosen_completion_actual_len": int(batch["chosen_attention_mask"][i].sum().item()),
                    "rejected_completion_padded_len": rejected_comp_padded_len,
                    "rejected_completion_actual_len": int(batch["rejected_attention_mask"][i].sum().item()),
                    "chosen_concat_padded_len": chosen_concat_padded_len,
                    "chosen_concat_actual_len": int(chosen_attention_mask[i].sum().item()),
                    "rejected_concat_padded_len": rejected_concat_padded_len,
                    "rejected_concat_actual_len": int(rejected_attention_mask[i].sum().item()),
                    "chosen_valid_tokens": int((chosen_labels[i] != -100).sum().item()),
                    "rejected_valid_tokens": int((rejected_labels[i] != -100).sum().item()),
                    # Same counts after the 1-token shift (position 0 is never scored).
                    "chosen_valid_tokens_shifted": int((chosen_labels[i, 1:] != -100).sum().item()),
                    "rejected_valid_tokens_shifted": int((rejected_labels[i, 1:] != -100).sum().item()),
                },
                "prompt": {
                    "input_ids": batch["prompt_input_ids"][i].tolist(),
                    "attention_mask": batch["prompt_attention_mask"][i].tolist(),
                },
                "chosen_completion": {
                    "input_ids": batch["chosen_input_ids"][i].tolist(),
                    "attention_mask": batch["chosen_attention_mask"][i].tolist(),
                },
                "rejected_completion": {
                    "input_ids": batch["rejected_input_ids"][i].tolist(),
                    "attention_mask": batch["rejected_attention_mask"][i].tolist(),
                },
                "chosen_concat": {
                    "input_ids": chosen_input_ids[i].tolist(),
                    "attention_mask": chosen_attention_mask[i].tolist(),
                    "labels": chosen_labels[i].tolist(),
                    "tokens": tok.convert_ids_to_tokens(chosen_input_ids[i].tolist()),
                },
                "rejected_concat": {
                    "input_ids": rejected_input_ids[i].tolist(),
                    "attention_mask": rejected_attention_mask[i].tolist(),
                    "labels": rejected_labels[i].tolist(),
                    "tokens": tok.convert_ids_to_tokens(rejected_input_ids[i].tolist()),
                },
            }
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


# Output goes under /tmp so the repo stays free of debug artifacts.
OUT_PATH = Path("/tmp/dpo_label_mask_debug.jsonl")
dump_records_jsonl(path=OUT_PATH, rows=rows)
print("wrote:", OUT_PATH)


wrote: /tmp/dpo_label_mask_debug.jsonl
