# Lesson 11: Direct Preference Optimization (DPO)

In many real-world applications, we care about *preferences*, not just next-token prediction.
Preference optimization teaches a model to produce responses that humans prefer by comparing a **chosen**
answer against a **rejected** answer for the same prompt.

DPO (Direct Preference Optimization) is a simple, stable way to do this without explicit RL loops:
it pushes the model to assign higher likelihood to preferred answers than to dispreferred ones.

Goals in this notebook:
- Understand preference optimization at a high level.
- Fine-tune a small model (GPT-2) with DPO using TRL if available.
- Fall back to a tiny PyTorch-only pairwise loss if TRL is missing.
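
To make the objective concrete, here is a minimal sketch of the DPO loss for a single preference pair, written in plain Python. It assumes you already have the summed answer log-probabilities under the model being trained (the policy) and under a frozen reference copy; the function name `dpo_loss` and the example numbers are illustrative, not taken from TRL.

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one pair, given summed answer log-probs (illustrative helper)."""
    # How much more the policy favors each answer than the reference does
    chosen_gain = policy_chosen - ref_chosen
    rejected_gain = policy_rejected - ref_rejected
    z = beta * (chosen_gain - rejected_gain)
    # -log(sigmoid(z)) written as a numerically stable softplus(-z)
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))

# Policy identical to the reference: the margin is zero, so the loss is log(2)
start = dpo_loss(-10.0, -12.0, -10.0, -12.0)
# Policy has raised the chosen answer and lowered the rejected one (made-up numbers)
later = dpo_loss(-5.0, -17.0, -10.0, -12.0)

print(f"loss at start: {start:.4f}")  # 0.6931
print(f"loss later:    {later:.4f}")  # 0.3133
```

The loss only falls when the policy widens the chosen-versus-rejected gap *relative to the reference*, which is what keeps the tuned model from drifting arbitrarily far from its starting point.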


## 1) Setup
We will use `transformers` for the model and tokenizer, `datasets` for data, and optionally `trl` for DPO.
If TRL is missing, we will still run a simplified pairwise objective.
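
Since TRL is optional and its API has changed between releases, it can help to print what is installed before running anything else. The sketch below uses only the standard library; the helper name `installed_version` is illustrative.

```python
from importlib import metadata

def installed_version(package):
    """Return the installed version string for `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

# Report the libraries this notebook relies on
for name in ("torch", "transformers", "datasets", "trl"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```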


In [None]:
import random
import math
import os

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, Dataset

# Optional TRL for real DPO training
try:
    from trl import DPOTrainer, DPOConfig
    HAS_TRL = True  # TRL availability flag
except Exception:
    HAS_TRL = False  # TRL availability flag
    print("TRL not installed. You can install it with: pip install trl")

torch.manual_seed(0)
random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

## 2) Preference Dataset (chosen vs rejected)
We will try to load a small preference dataset from Hugging Face. If that fails, we will create a tiny
synthetic dataset so the notebook still runs end-to-end.
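
Whichever source is used, every row is normalized to the same three string fields. The sketch below shows that record shape with a small sanity check; `is_valid_pair` is a hypothetical helper, not part of `datasets` or TRL, and the example row is made up.

```python
def is_valid_pair(row):
    """Return True if `row` has string prompt/chosen/rejected fields,
    both answers are non-empty, and the two answers differ."""
    keys = ("prompt", "chosen", "rejected")
    if not all(isinstance(row.get(k), str) for k in keys):
        return False
    chosen, rejected = row["chosen"].strip(), row["rejected"].strip()
    if not chosen or not rejected:
        return False
    # A pair with identical answers carries no preference signal
    return chosen != rejected

example = {
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "I do not know.",
}
print(is_valid_pair(example))  # True
```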


In [None]:
def split_prompt_answer(text):
    """Best-effort split for datasets without explicit prompt column."""
    marker = "Assistant:"
    if marker in text:
        # Split on the last marker so multi-turn dialogues keep earlier turns in the prompt
        prompt, answer = text.rsplit(marker, 1)
        return (prompt + marker).strip() + "\n", answer.strip()
    # Fallback: no split
    return "", text.strip()

def try_load_preference_dataset(max_rows=200):
    candidates = [
        ("Dahoas/rm-static", "train"),
        ("Anthropic/hh-rlhf", "train"),
    ]

    for name, split in candidates:
        try:
            ds = load_dataset(name, split=f"{split}[:1%]")
            # Normalize columns into prompt/chosen/rejected
            if "prompt" in ds.column_names:
                def mapper(ex):
                    return {
                        "prompt": ex["prompt"],
                        "chosen": ex["chosen"],
                        "rejected": ex["rejected"],
                    }
            else:
                def mapper(ex):
                    c_prompt, c_ans = split_prompt_answer(ex["chosen"])
                    r_prompt, r_ans = split_prompt_answer(ex["rejected"])
                    prompt = c_prompt if c_prompt else r_prompt
                    return {
                        "prompt": prompt,
                        "chosen": c_ans,
                        "rejected": r_ans,
                    }

            ds = ds.map(mapper, remove_columns=ds.column_names)
            ds = ds.filter(lambda x: len(x["chosen"]) > 0 and len(x["rejected"]) > 0)
            ds = ds.shuffle(seed=0)
            ds = ds.select(range(min(max_rows, len(ds))))
            print(f"Loaded dataset: {name} with {len(ds)} pairs")
            return ds
        except Exception as e:
            print(f"Failed to load {name}: {e}")
    return None

def build_synthetic_pairs(num_pairs=200):
    questions = [
        "What is the capital of France?",
        "Explain photosynthesis in one sentence.",
        "What is 2 + 2?",
        "Name one benefit of regular exercise.",
        "What does CPU stand for?",
        "How do you boil an egg?",
        "What is the largest planet in our solar system?",
        "Give a short definition of gravity.",
        "What is a triangle?",
        "What is the boiling point of water in Celsius?",
    ]

    good_answers = [
        "The capital of France is Paris.",
        "Photosynthesis is the process by which plants use sunlight to make food from carbon dioxide and water.",
        "2 + 2 equals 4.",
        "Regular exercise improves health and fitness, such as stronger muscles and a healthier heart.",
        "CPU stands for Central Processing Unit.",
        "Place eggs in boiling water for about 8 to 10 minutes, then cool them in cold water.",
        "Jupiter is the largest planet in our solar system.",
        "Gravity is the force that pulls objects toward each other.",
        "A triangle is a three-sided polygon.",
        "Water boils at 100 degrees Celsius at sea level.",
    ]

    bad_answers = [
        "I do not know.",
        "Maybe.",
        "No idea.",
        "Because it just is.",
        "It is not important.",
    ]

    rows = []
    for i in range(num_pairs):
        idx = i % len(questions)
        prompt = questions[idx]
        chosen = good_answers[idx]
        # Make a bad answer by picking something short and unhelpful
        rejected = random.choice(bad_answers)
        rows.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return Dataset.from_list(rows)

dataset = try_load_preference_dataset(max_rows=200)
if dataset is None:
    print("Falling back to a synthetic preference dataset.")
    dataset = build_synthetic_pairs(num_pairs=200)

# Simple train/test split
dataset = dataset.shuffle(seed=0)
split_idx = int(0.9 * len(dataset))
train_ds = dataset.select(range(split_idx))
eval_ds = dataset.select(range(split_idx, len(dataset)))

print(train_ds[0])
print(f"Train pairs: {len(train_ds)}, Eval pairs: {len(eval_ds)}")

## 3) Tokenization
GPT-2 does not have a padding token, so we reuse the end-of-sequence token as padding.
We will build prompt + answer sequences and mark prompt tokens so we do not score them when computing
the pairwise preference loss.
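
The masking idea can be seen on plain lists of token ids, without a real tokenizer. This is a sketch: the ids below are made up, and `mask_prompt` and `pad_example` are illustrative helpers. The value `-100` is the ignore label this notebook uses for positions that should not be scored.

```python
IGNORE_INDEX = -100  # label value for positions excluded from scoring

def mask_prompt(prompt_ids, answer_ids):
    """Concatenate prompt and answer ids; hide the prompt positions in the labels."""
    input_ids = prompt_ids + answer_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
    return input_ids, labels

def pad_example(input_ids, labels, length, pad_token_id):
    """Right-pad to `length`; padding is ignored by both attention and labels."""
    pad_len = length - len(input_ids)
    return (
        input_ids + [pad_token_id] * pad_len,
        labels + [IGNORE_INDEX] * pad_len,
        [1] * len(input_ids) + [0] * pad_len,
    )

# Made-up token ids: three prompt tokens, two answer tokens
input_ids, labels = mask_prompt([11, 12, 13], [21, 22])
padded_ids, padded_labels, attention_mask = pad_example(input_ids, labels, 7, pad_token_id=0)

print(padded_ids)      # [11, 12, 13, 21, 22, 0, 0]
print(padded_labels)   # [-100, -100, -100, 21, 22, -100, -100]
print(attention_mask)  # [1, 1, 1, 1, 1, 0, 0]
```

With a real subword tokenizer, characters at the prompt/answer boundary may merge into one token, so counting the prompt's tokens separately is an approximation of where the answer starts.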


In [None]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

max_prompt_length = 128
max_length = 256

def encode_prompt_answer(prompt, answer):
    # Tokenize the prompt alone so we know how many leading tokens to mask
    prompt_ids = tokenizer(prompt, truncation=True, max_length=max_prompt_length)["input_ids"]
    full = tokenizer(prompt + answer, truncation=True, max_length=max_length)
    input_ids = full["input_ids"]
    labels = input_ids.copy()
    # Mask prompt tokens so we only score the answer. Clamp the count so the
    # slice assignment can never make `labels` longer than `input_ids`.
    n_prompt = min(len(prompt_ids), len(input_ids))
    labels[:n_prompt] = [-100] * n_prompt
    attention_mask = [1] * len(input_ids)
    return input_ids, labels, attention_mask

## 4) DPO Training
If TRL is installed, we will use `DPOTrainer`, which implements the DPO loss directly.
Otherwise, we fall back to a tiny PyTorch loop that maximizes the log-probability gap between
chosen and rejected answers. This fallback drops DPO's frozen reference model, so it is only a
simplified version of the DPO idea.
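
Both paths depend on one scoring step: summing the log-probabilities of the answer tokens only. The sketch below shows that step on dummy tensors, so no model is needed. The name `masked_sequence_logp` and the toy shapes are illustrative. Masked labels are swapped for a valid index before `gather`, because `-100` cannot be used as an index.

```python
import math
import torch

def masked_sequence_logp(logits, labels, ignore_index=-100):
    """Sum log-probs of the label tokens, skipping positions equal to ignore_index.

    logits: (batch, seq_len, vocab) floats; labels: (batch, seq_len) ints.
    """
    # Position t predicts token t + 1, so shift both tensors by one
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    mask = shift_labels != ignore_index
    # Replace ignored labels with a valid index; the mask zeroes them out later
    safe_labels = shift_labels.masked_fill(~mask, 0)
    log_probs = torch.log_softmax(shift_logits, dim=-1)
    token_logp = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    return (token_logp * mask).sum(dim=-1)

# All-zero logits over a vocabulary of 4 give every token probability 1/4
logits = torch.zeros(1, 5, 4)
labels = torch.tensor([[-100, -100, 2, 3, 1]])  # two prompt positions, three answer tokens
seq_logp = masked_sequence_logp(logits, labels)

print(seq_logp.item(), 3 * math.log(0.25))  # both are about -4.1589
```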


In [None]:
# Load the model for training
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

if HAS_TRL:
    # TRL DPO training
    ref_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    config = DPOConfig(
        output_dir="dpo_outputs",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        logging_steps=5,
        max_length=max_length,
        max_prompt_length=max_prompt_length,
        report_to=[],
    )
    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=config,
        train_dataset=train_ds,
        tokenizer=tokenizer,  # newer TRL releases rename this argument to `processing_class`
    )
    train_result = trainer.train()
    print(train_result)
else:
    # PyTorch fallback: pairwise preference loss (DPO-like)
    beta = 0.1
    batch_size = 4
    lr = 1e-5
    num_steps = 50

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    def collate_fn(batch):
        def encode_pair(prompt, answer):
            input_ids, labels, attention_mask = encode_prompt_answer(prompt, answer)
            return input_ids, labels, attention_mask

        chosen_inputs = []
        rejected_inputs = []

        for ex in batch:
            c = encode_pair(ex["prompt"], ex["chosen"])
            r = encode_pair(ex["prompt"], ex["rejected"])
            chosen_inputs.append(c)
            rejected_inputs.append(r)

        def pad_batch(seqs, pad_token_id, label_pad_id=-100):
            max_len = max(len(s[0]) for s in seqs)
            input_ids = []
            labels = []
            attention_mask = []

            for ids, lab, mask in seqs:
                pad_len = max_len - len(ids)
                input_ids.append(ids + [pad_token_id] * pad_len)
                labels.append(lab + [label_pad_id] * pad_len)
                attention_mask.append(mask + [0] * pad_len)

            return (
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(labels, dtype=torch.long),
                torch.tensor(attention_mask, dtype=torch.long),
            )

        c_input_ids, c_labels, c_attention = pad_batch(chosen_inputs, tokenizer.pad_token_id)
        r_input_ids, r_labels, r_attention = pad_batch(rejected_inputs, tokenizer.pad_token_id)

        return {
            "c_input_ids": c_input_ids,
            "c_labels": c_labels,
            "c_attention": c_attention,
            "r_input_ids": r_input_ids,
            "r_labels": r_labels,
            "r_attention": r_attention,
        }

    loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

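    def sequence_logp(input_ids, labels, attention_mask):
        # Compute log-probability of only the answer tokens (labels != -100)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Position t predicts token t + 1, so shift logits and labels by one
        logits = outputs.logits[:, :-1, :]
        target = labels[:, 1:]
        mask = (target != -100).float()
        # gather() cannot index with -100, so point masked positions at token 0;
        # their values are zeroed out by the mask below
        safe_target = target.clamp(min=0)

        log_probs = torch.log_softmax(logits, dim=-1)
        token_logp = log_probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
        # Sum over answer tokens
        return (token_logp * mask).sum(dim=-1)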

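    step = 0
    # One pass over the data may be shorter than num_steps, so run enough epochs to reach it
    num_epochs = math.ceil(num_steps / max(1, len(loader)))
    for epoch in range(num_epochs):
        for batch in loader:
            step += 1
            if step > num_steps:
                break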

            c_input_ids = batch["c_input_ids"].to(device)
            c_labels = batch["c_labels"].to(device)
            c_attention = batch["c_attention"].to(device)
            r_input_ids = batch["r_input_ids"].to(device)
            r_labels = batch["r_labels"].to(device)
            r_attention = batch["r_attention"].to(device)

            chosen_logp = sequence_logp(c_input_ids, c_labels, c_attention)
            rejected_logp = sequence_logp(r_input_ids, r_labels, r_attention)

            # DPO-like loss: encourage chosen > rejected
            loss = -torch.nn.functional.logsigmoid(beta * (chosen_logp - rejected_logp)).mean()
            acc = (chosen_logp > rejected_logp).float().mean().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print(f"step {step:03d} | loss {loss.item():.4f} | pairwise_acc {acc:.2f}")

        if step > num_steps:
            break

## 5) Inference: before vs after
We compare a few prompts before and after training. The outputs will be noisy (GPT-2 is small and
generation uses sampling), so expect at most a slight shift toward more helpful answers.
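
Because sampled text is hard to compare by eye, a numeric check can complement it: score the chosen and rejected answers under each model and compare. The sketch below assumes you already have those summed log-probabilities as plain lists; `preference_metrics` is an illustrative helper and the numbers are made up, not results from this notebook.

```python
def preference_metrics(chosen_logps, rejected_logps):
    """Pairwise accuracy and mean margin over a set of scored preference pairs."""
    if len(chosen_logps) != len(rejected_logps) or not chosen_logps:
        raise ValueError("need two non-empty lists of equal length")
    margins = [c - r for c, r in zip(chosen_logps, rejected_logps)]
    return {
        # Fraction of pairs where the chosen answer is scored higher
        "pairwise_acc": sum(m > 0 for m in margins) / len(margins),
        "mean_margin": sum(margins) / len(margins),
    }

# Made-up scores for three held-out pairs, before and after tuning
base = preference_metrics([-20.0, -15.0, -30.0], [-18.0, -16.0, -25.0])
tuned = preference_metrics([-17.0, -14.0, -26.0], [-19.0, -18.0, -25.0])

print("base: ", base)
print("tuned:", tuned)
```

Running this on held-out pairs rather than training pairs gives a fairer picture of whether the preference actually generalizes.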


In [None]:
def generate_text(model, prompt, max_new_tokens=60):
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,  # set explicitly; GPT-2 has no pad token of its own
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

test_prompts = [train_ds[i]["prompt"] for i in range(3)]

# Load a fresh base model for comparison
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

for p in test_prompts:
    print("Prompt:", p)
    print("Base:", generate_text(base_model, p))
    print("Tuned:", generate_text(model, p))
    print("-" * 60)

## 6) Scaling Notes
- Larger models generally have more capacity to benefit from richer preference data.
- Better preference pairs (more diverse prompts, higher-quality judgments) matter more than raw count.
- Safety: preference data can encode bias; always review and filter sources.
- In production, consider evaluation benchmarks and human review loops.
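
One concrete way to start reviewing a preference dataset is to look for simple surface patterns, such as chosen answers that are almost always longer than rejected ones, or rows that repeat. The sketch below is only a first-pass check under that assumption; `review_pairs` is a hypothetical helper and the sample rows are made up.

```python
def review_pairs(pairs):
    """Summarize simple surface statistics of prompt/chosen/rejected rows."""
    n = len(pairs)
    if n == 0:
        return {"n": 0, "chosen_longer_frac": 0.0, "duplicate_rows": 0}
    # Share of pairs where the preferred answer is simply the longer one
    longer = sum(len(p["chosen"]) > len(p["rejected"]) for p in pairs)
    # Rows that repeat an earlier (prompt, chosen, rejected) triple exactly
    unique = {(p["prompt"], p["chosen"], p["rejected"]) for p in pairs}
    return {
        "n": n,
        "chosen_longer_frac": longer / n,
        "duplicate_rows": n - len(unique),
    }

sample = [
    {"prompt": "What is 2 + 2?", "chosen": "2 + 2 equals 4.", "rejected": "Maybe."},
    {"prompt": "What is 2 + 2?", "chosen": "2 + 2 equals 4.", "rejected": "Maybe."},
    {"prompt": "What is a triangle?", "chosen": "A polygon.", "rejected": "Because it just is."},
]
print(review_pairs(sample))
```

A `chosen_longer_frac` close to 1.0 suggests a model could learn "longer is better" instead of "more helpful is better", which is worth checking by hand.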


## 7) Exercises
1) Increase `num_steps` (used by the PyTorch fallback loop) and observe whether the pairwise accuracy changes.
2) Replace GPT-2 with a slightly larger model (e.g., `gpt2-medium`) and compare outputs.
3) Try a different dataset and inspect how prompt formatting changes outcomes.
4) Add your own prompt/answer pairs and see whether the model shifts.
