# Inspect DPO label building, masking, and tensor specs (from training code)

Goal: reproduce the **exact** label/mask building used in training (TRL tokenization + collator + our trainer's `_concatenate_and_build_labels`) and dump **raw records** (IDs/masks/labels + token strings) for inspection.

This notebook does **not** run training. It only builds one (or a few) batches and inspects them.


In [None]:
# 0) Imports + repo path setup
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch


import sys

# Locate the repository root: the current working directory if it holds a
# `src/` folder, otherwise its parent when the parent does. When neither
# matches, keep the cwd so the assertion below reports the failure.
repo_root = next(
    (
        candidate
        for candidate in (Path.cwd().resolve(), Path.cwd().resolve().parent)
        if (candidate / "src").exists()
    ),
    Path.cwd().resolve(),
)
assert (repo_root / "src").exists(), f"Could not find repo root from cwd={Path.cwd()}"

# Put the repo root first on the import path so `src.*` imports resolve.
sys.path.insert(0, str(repo_root))

print("repo_root:", repo_root)
print("torch:", torch.__version__)


In [None]:
# 1) Load config (optional) + pick model/tokenizer
#
# NOTE: For Llama models, you may need HF auth + local cache.
# If you only want to inspect token/label mechanics without downloading large models,
# set MODEL_NAME to something small that is already cached.

from src.config.loader import load_yaml

# The DPO config file is expected directly under the repo root.
CONFIG_PATH = repo_root / "config_dpo.yaml"
# load_yaml is a project helper; presumably returns a dict-like mapping
# (it is used with `.get(...)` below) -- TODO confirm.
config = load_yaml(str(CONFIG_PATH))

# Fall back to "gpt2" when the config has no "policy_name" entry.
MODEL_NAME = config.get("policy_name", "gpt2")

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Tokenizers without a pad token reuse EOS as the pad token so that a
# pad_token_id is available to later cells.
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

print("MODEL_NAME:", MODEL_NAME)
print("pad_token_id:", tok.pad_token_id)
print("eos_token_id:", tok.eos_token_id)


In [None]:
# 2) Build a tiny preference dataset (prompt/chosen/rejected)
#
# By default we use a tiny in-memory dataset so you can inspect mechanics without HF dataset downloads.
# Flip USE_HF_DATASET=True to load the real HH dataset (requires network/cached dataset).

# Toggle: False -> use the in-memory examples below; True -> build rows from
# the dataset described by config["dataset"].
USE_HF_DATASET = False

# Default rows: each record carries the three text fields consumed by the
# tokenization cell ("prompt", "chosen", "rejected").
rows = [
    {
        "prompt": "\n\nHuman: Give me a one-sentence tip to focus at work.\n\nAssistant:",
        "chosen": " Try the Pomodoro method: 25 minutes focused, 5 minutes break.",
        "rejected": " You should never take breaks; just work nonstop.",
    },
    {
        "prompt": "\n\nHuman: What is 2+2?\n\nAssistant:",
        "chosen": " 4.",
        "rejected": " 22.",
    },
]

if USE_HF_DATASET:
    dataset_cfg = config.get("dataset", {})

    if bool(dataset_cfg.get("generated_data", False)):
        # The generated dataset is loaded from the config alone, so the raw
        # dataset is deliberately not fetched in this branch.
        from src.data.hh_dataset import load_generated_dataset_from_config

        hh_ds = load_generated_dataset_from_config(config)
    else:
        from datasets import load_dataset

        from src.data.hh_dataset import build_HH_dataset

        # NOTE(review): the config key "subset" is passed as the `split`
        # argument -- confirm this naming is intentional.
        raw_ds = load_dataset(dataset_cfg["dataset_name"], split=dataset_cfg.get("subset", "train[:1%]"))
        hh_ds = build_HH_dataset(raw_ds)

    if bool(dataset_cfg.get("chat_template", False)):
        from src.data.hh_dataset import apply_chat_template_to_dataset

        hh_ds = apply_chat_template_to_dataset(hh_ds, tok)

    # Keep at most four records, fetched one index at a time.
    rows = [hh_ds[i] for i in range(min(4, len(hh_ds)))]

from datasets import Dataset

ds = Dataset.from_list(rows)
ds

In [None]:
# 3) Tokenize exactly like TRL DPOTrainer.tokenize_row (prompt/chosen/rejected separated)

from trl.trainer.dpo_trainer import DPOTrainer

# Length limits come from config["dataset"]; each defaults to 256 when absent.
max_prompt_length, max_completion_length = (
    int(config.get("dataset", {}).get(limit_key, 256))
    for limit_key in ("max_prompt_length", "max_completion_length")
)


def tokenize_row(row: Dict[str, str]) -> Dict[str, List[int]]:
    """Tokenize one preference record via ``DPOTrainer.tokenize_row``.

    Uses the notebook-level tokenizer ``tok`` and the configured prompt /
    completion length limits, with ``add_special_tokens=False``.
    """
    length_limits = {
        "max_prompt_length": max_prompt_length,
        "max_completion_length": max_completion_length,
    }
    return DPOTrainer.tokenize_row(
        row,
        processing_class=tok,
        add_special_tokens=False,
        **length_limits,
    )


tok_rows = list(map(tokenize_row, rows))
# Cell output: the keys of the first tokenized row and each field's length.
tok_rows[0].keys(), {k: len(v) for k, v in tok_rows[0].items()}

In [None]:
# 4) Collate exactly like TRL's DataCollatorForPreference (prompt left-pad, completions right-pad)

from trl.trainer.dpo_trainer import DataCollatorForPreference

collator = DataCollatorForPreference(pad_token_id=int(tok.pad_token_id))
batch = collator(tok_rows)

# Report shape/dtype only for entries that expose a `.shape` attribute.
for k, v in batch.items():
    if not hasattr(v, "shape"):
        continue
    print(f"{k:>24s}", "shape=", tuple(v.shape), "dtype=", v.dtype)

# Cell output: the names of all collated fields.
batch.keys()

In [None]:
# 5) Build concatenated tensors + labels using the *training* method
#    (prompt tokens masked to -100; padding masked to -100)

from src.trainers.dynamic_beta_dpo import DynamicBetaDPOTrainer

# The trainer method is called unbound with `None` standing in for `self`,
# so no trainer instance (and no model) has to be constructed.
# NOTE(review): this presumes the method never reads instance state -- confirm
# against src.trainers.dynamic_beta_dpo.
chosen_input_ids, chosen_attention_mask, chosen_labels = DynamicBetaDPOTrainer._concatenate_and_build_labels(
    None,
    prompt_input_ids=batch["prompt_input_ids"],
    prompt_attention_mask=batch["prompt_attention_mask"],
    completion_input_ids=batch["chosen_input_ids"],
    completion_attention_mask=batch["chosen_attention_mask"],
)

# Same call for the rejected completions; the prompt tensors are shared.
rejected_input_ids, rejected_attention_mask, rejected_labels = DynamicBetaDPOTrainer._concatenate_and_build_labels(
    None,
    prompt_input_ids=batch["prompt_input_ids"],
    prompt_attention_mask=batch["prompt_attention_mask"],
    completion_input_ids=batch["rejected_input_ids"],
    completion_attention_mask=batch["rejected_attention_mask"],
)

# Quick shape summary of the concatenated ids and their labels.
print("chosen_input_ids", chosen_input_ids.shape)
print("chosen_labels", chosen_labels.shape)
print("rejected_input_ids", rejected_input_ids.shape)
print("rejected_labels", rejected_labels.shape)


In [None]:
# 6) Invariants: prompt is masked; padding is masked

def assert_label_invariants(
    prompt_input_ids: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    concat_attention_mask: torch.Tensor,
    labels: torch.Tensor,
) -> None:
    """Assert the label-masking invariants on one concatenated batch.

    Checks, in order:
      1. every label in the first ``prompt_input_ids.shape[1]`` columns is -100;
      2. every label where ``concat_attention_mask`` is 0 is -100;
      3. every label past the prompt columns where the attention mask is 1
         is not -100.

    ``prompt_attention_mask`` is accepted but not consulted. Raises
    ``AssertionError`` on the first violated invariant; returns ``None``.
    """
    n_prompt_cols = prompt_input_ids.shape[1]

    prompt_labels = labels[:, :n_prompt_cols]
    assert (prompt_labels == -100).all().item(), "Prompt tokens should be all -100"

    padded_positions = concat_attention_mask == 0
    assert (labels[padded_positions] == -100).all().item(), "Padding tokens should be -100"

    # Work on a copy so the caller's attention mask is left untouched, then
    # blank out the prompt columns to isolate attended completion positions.
    attended_completion = concat_attention_mask.clone()
    attended_completion[:, :n_prompt_cols] = 0
    if not attended_completion.any().item():
        return
    assert (labels[attended_completion == 1] != -100).all().item(), "Completion tokens should be unmasked"


# Run the invariant checks on both concatenated batches; "OK" is printed only
# if neither call raised an AssertionError.
assert_label_invariants(batch["prompt_input_ids"], batch["prompt_attention_mask"], chosen_attention_mask, chosen_labels)
assert_label_invariants(batch["prompt_input_ids"], batch["prompt_attention_mask"], rejected_attention_mask, rejected_labels)
print("OK")


In [None]:
# 7) Raw per-token inspection (IDs, tokens, attention_mask, label mask)

import pandas as pd


def token_table(
    *,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
    sample_idx: int,
    tok,
) -> pd.DataFrame:
    """Build a per-position table for one sample of a batch.

    Row ``sample_idx`` is taken from each tensor and converted to a list.
    The resulting frame has one row per position with the columns ``pos``,
    ``token_id``, ``token`` (from ``tok.convert_ids_to_tokens``), ``attn``,
    ``label`` and ``label_masked`` (True where the label equals -100).
    """
    row_ids, row_attn, row_labels = (
        tensor[sample_idx].tolist() for tensor in (input_ids, attention_mask, labels)
    )
    columns = {
        "pos": [*range(len(row_ids))],
        "token_id": row_ids,
        "token": tok.convert_ids_to_tokens(row_ids),
        "attn": row_attn,
        "label": row_labels,
        "label_masked": [value == -100 for value in row_labels],
    }
    return pd.DataFrame(columns)


# Index of the batch row (and matching entry in `rows`) to inspect.
SAMPLE_IDX = 0

# Raw text first, so the token tables below can be read against it.
print("RAW TEXT")
print("prompt:", rows[SAMPLE_IDX]["prompt"])
print("chosen:", rows[SAMPLE_IDX]["chosen"])
print("rejected:", rows[SAMPLE_IDX]["rejected"])

# Per-token table for the prompt+chosen sequence.
chosen_df = token_table(
    input_ids=chosen_input_ids,
    attention_mask=chosen_attention_mask,
    labels=chosen_labels,
    sample_idx=SAMPLE_IDX,
    tok=tok,
)
# Per-token table for the prompt+rejected sequence.
rejected_df = token_table(
    input_ids=rejected_input_ids,
    attention_mask=rejected_attention_mask,
    labels=rejected_labels,
    sample_idx=SAMPLE_IDX,
    tok=tok,
)

# NOTE(review): `display` is not imported in this cell; presumably the
# IPython/Jupyter builtin -- this will fail outside a notebook kernel.
display(chosen_df)
display(rejected_df)


In [None]:
# 8) Show what compute_log_prob will consider "valid" (labels shifted by 1)

from src.losses.dpo_loss import compute_log_prob


def build_loss_mask(labels: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of label positions that count after the shift.

    Drops the first column of ``labels`` (a shift by one along dim 1) and
    marks every remaining position whose label is not -100. Per the note in
    this notebook, this mirrors the ``labels[:, 1:]`` shift used by
    ``compute_log_prob()``.

    Args:
        labels: 2-D tensor of label ids, with -100 marking ignored positions.

    Returns:
        Boolean tensor with one column fewer than ``labels``. The input is
        not modified.
    """
    # The comparison already allocates a fresh tensor, so no clone is needed.
    return labels[:, 1:] != -100


# Shifted validity masks for both sequences (one column fewer than the labels).
chosen_loss_mask = build_loss_mask(chosen_labels)
rejected_loss_mask = build_loss_mask(rejected_labels)

# "valid tokens" is the count of True entries across the whole batch.
print("chosen_loss_mask shape:", chosen_loss_mask.shape, "valid tokens:", int(chosen_loss_mask.sum().item()))
print("rejected_loss_mask shape:", rejected_loss_mask.shape, "valid tokens:", int(rejected_loss_mask.sum().item()))

# (Optional) You need logits to actually run compute_log_prob; this cell just mirrors its masking logic.


In [None]:
# 9) Dump raw records (IDs/masks/labels + token strings) to JSONL
#
# Default output is under /tmp so we don't create repo artifacts.

def dump_records_jsonl(path: Path, *, rows: List[Dict[str, str]]) -> None:
    """Write one JSON line per row with raw text, lengths, ids, masks and labels.

    Besides ``rows``, this reads the notebook-level globals ``batch``, ``tok``
    and the ``chosen_*`` / ``rejected_*`` tensors built in earlier cells;
    record ``i`` is taken from position ``i`` of those tensors.

    Args:
        path: Output file. Parent directories are created if missing and an
            existing file is overwritten.
        rows: Raw records providing the "prompt", "chosen" and "rejected" text.

    Raises:
        IndexError: If ``rows`` has more entries than the batch has rows.
            Raised before the output file is opened, so a previous dump is
            not truncated by a call that cannot succeed.
    """
    batch_rows = int(batch["prompt_input_ids"].shape[0])
    if len(rows) > batch_rows:
        raise IndexError(
            f"rows has {len(rows)} entries but the collated batch only has {batch_rows}"
        )

    # Padded widths are properties of the whole batch, so compute them once.
    prompt_padded_len = int(batch["prompt_input_ids"].shape[1])
    chosen_comp_padded_len = int(batch["chosen_input_ids"].shape[1])
    rejected_comp_padded_len = int(batch["rejected_input_ids"].shape[1])
    chosen_concat_padded_len = int(chosen_input_ids.shape[1])
    rejected_concat_padded_len = int(rejected_input_ids.shape[1])

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for i, r in enumerate(rows):
            # "Actual" lengths are the sums of the attention-mask row.
            prompt_actual_len = int(batch["prompt_attention_mask"][i].sum().item())
            chosen_comp_actual_len = int(batch["chosen_attention_mask"][i].sum().item())
            rejected_comp_actual_len = int(batch["rejected_attention_mask"][i].sum().item())
            chosen_concat_actual_len = int(chosen_attention_mask[i].sum().item())
            rejected_concat_actual_len = int(rejected_attention_mask[i].sum().item())

            # Labels other than -100, before and after dropping the first position.
            chosen_valid_tokens = int((chosen_labels[i] != -100).sum().item())
            rejected_valid_tokens = int((rejected_labels[i] != -100).sum().item())
            chosen_valid_tokens_shifted = int((chosen_labels[i, 1:] != -100).sum().item())
            rejected_valid_tokens_shifted = int((rejected_labels[i, 1:] != -100).sum().item())

            rec = {
                "idx": i,
                "raw": {"prompt": r["prompt"], "chosen": r["chosen"], "rejected": r["rejected"]},
                "lengths": {
                    "prompt_padded_len": prompt_padded_len,
                    "prompt_actual_len": prompt_actual_len,
                    "chosen_completion_padded_len": chosen_comp_padded_len,
                    "chosen_completion_actual_len": chosen_comp_actual_len,
                    "rejected_completion_padded_len": rejected_comp_padded_len,
                    "rejected_completion_actual_len": rejected_comp_actual_len,
                    "chosen_concat_padded_len": chosen_concat_padded_len,
                    "chosen_concat_actual_len": chosen_concat_actual_len,
                    "rejected_concat_padded_len": rejected_concat_padded_len,
                    "rejected_concat_actual_len": rejected_concat_actual_len,
                    "chosen_valid_tokens": chosen_valid_tokens,
                    "rejected_valid_tokens": rejected_valid_tokens,
                    "chosen_valid_tokens_shifted": chosen_valid_tokens_shifted,
                    "rejected_valid_tokens_shifted": rejected_valid_tokens_shifted,
                },
                "prompt": {
                    "input_ids": batch["prompt_input_ids"][i].tolist(),
                    "attention_mask": batch["prompt_attention_mask"][i].tolist(),
                },
                "chosen_completion": {
                    "input_ids": batch["chosen_input_ids"][i].tolist(),
                    "attention_mask": batch["chosen_attention_mask"][i].tolist(),
                },
                "rejected_completion": {
                    "input_ids": batch["rejected_input_ids"][i].tolist(),
                    "attention_mask": batch["rejected_attention_mask"][i].tolist(),
                },
                "chosen_concat": {
                    "input_ids": chosen_input_ids[i].tolist(),
                    "attention_mask": chosen_attention_mask[i].tolist(),
                    "labels": chosen_labels[i].tolist(),
                    "tokens": tok.convert_ids_to_tokens(chosen_input_ids[i].tolist()),
                },
                "rejected_concat": {
                    "input_ids": rejected_input_ids[i].tolist(),
                    "attention_mask": rejected_attention_mask[i].tolist(),
                    "labels": rejected_labels[i].tolist(),
                    "tokens": tok.convert_ids_to_tokens(rejected_input_ids[i].tolist()),
                },
            }
            # ensure_ascii=False keeps non-ASCII token strings readable in the file.
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


# Write under /tmp so the dump does not land inside the repository.
OUT_PATH = Path("/tmp/dpo_label_mask_debug.jsonl")
dump_records_jsonl(OUT_PATH, rows=rows)
print("wrote:", OUT_PATH)
