# Lab 3.1.5: Direct Preference Optimization (DPO) Training

**Module:** 3.1 - Large Language Model Fine-Tuning  
**Time:** 3 hours  
**Difficulty:** ⭐⭐⭐⭐☆

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Understand the theory behind DPO and how it differs from RLHF
- [ ] Create preference pair datasets from raw data
- [ ] Implement the DPO loss function from scratch
- [ ] Train a model with DPO using TRL
- [ ] Evaluate and compare DPO-trained vs SFT-only models

---

## Prerequisites

- Completed: Labs 3.1.1-3.1.4
- Knowledge of: Supervised fine-tuning, basic reinforcement learning concepts

---

## Real-World Context

### The Problem with SFT Alone

Supervised Fine-Tuning (SFT) teaches models *what* to say, but not *how to choose* between alternatives. When there are multiple valid responses, how does the model know which one is *better*?

**Real example:** "Write a poem about the ocean"
- Response A: Short, simple, rhymes perfectly
- Response B: Long, metaphorical, more creative

Both are valid! But different users might prefer different styles. **Preference optimization** teaches the model which responses are *preferred* by humans.

### The Traditional Approach: RLHF

RLHF (Reinforcement Learning from Human Feedback), the approach used to train ChatGPT, has three stages:
1. Collect human preferences
2. Train a reward model
3. Use RL (PPO) to optimize against the reward model

**Problem:** Complex, unstable, requires lots of compute.

### DPO: The Elegant Solution

DPO (Direct Preference Optimization) achieves similar results with:
- No reward model needed
- No RL training loop
- Stable, simple supervised learning

**That's why it's become so popular!**

---

## ELI5: What is DPO?

> **Imagine you're training a chef** (your model) to make better dishes.
>
> **SFT (Supervised Fine-Tuning):**  
> You show the chef recipes and say "Cook exactly like this." They learn to follow recipes, but can't improve on their own.
>
> **RLHF (Reinforcement Learning):**  
> You hire a food critic (reward model), have the chef cook many dishes, and the critic rates each one. Complex: you need to train the critic, and the chef learns through trial and error.
>
> **DPO (Direct Preference Optimization):**  
> You simply show the chef pairs of dishes: "This one is better than that one." No food critic needed! The chef directly learns what makes one dish preferred over another.
>
> **In technical terms:** DPO uses pairs of (chosen, rejected) responses to directly teach the model to assign higher probability to preferred responses, without needing an explicit reward model.

---

## Part 1: Understanding the DPO Loss

### The Math Behind DPO

The DPO loss is:

$$\mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right]$$

Where:
- $\pi_\theta$ is our trainable policy (the model we're training)
- $\pi_{ref}$ is the reference policy (frozen copy of initial model)
- $y_w$ is the winning/chosen response
- $y_l$ is the losing/rejected response
- $\beta$ controls how strongly the policy is tied to the reference (typically 0.1; larger values keep the policy closer to $\pi_{ref}$)
- $\sigma$ is the sigmoid function

### Intuition

The loss encourages:
- **Higher probability for chosen responses** relative to reference
- **Lower probability for rejected responses** relative to reference
- **Staying close to the reference model** (the $\pi_{ref}$ terms)
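
One way to make these three effects concrete is to rewrite the loss in terms of an *implicit reward*. The symbol $\hat{r}_\theta$ below is notation borrowed from the DPO paper rather than something defined elsewhere in this lab, so treat it as an assumption of this sketch:

$$\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}$$

$$\mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \hat{r}_\theta(x, y_w) - \hat{r}_\theta(x, y_l) \right) \right]$$

Differentiating with respect to $\theta$ (using $\frac{d}{du} \log \sigma(u) = \sigma(-u)$) gives:

$$\nabla_\theta \mathcal{L}_{DPO} = -\beta \, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \sigma \left( \hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w) \right) \left( \nabla_\theta \log \pi_\theta(y_w|x) - \nabla_\theta \log \pi_\theta(y_l|x) \right) \right]$$

The sigmoid factor is close to 1 when the model currently ranks the pair the wrong way round and close to 0 once the chosen response is comfortably ahead, so updates concentrate on the pairs the model still gets wrong.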

In [None]:
# Setup
import warnings
# Suppress verbose warnings from transformers/PEFT that clutter notebook output
# (e.g., deprecation warnings, tokenizer warnings that don't affect functionality)
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional
import json
import gc

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the DPO loss from scratch.
    
    Args:
        policy_chosen_logps: Log probabilities of chosen responses under policy
        policy_rejected_logps: Log probabilities of rejected responses under policy
        reference_chosen_logps: Log probabilities of chosen responses under reference
        reference_rejected_logps: Log probabilities of rejected responses under reference
        beta: Temperature parameter controlling deviation from reference
    
    Returns:
        Tuple of (loss, chosen_rewards, rejected_rewards)
    """
    # Compute log ratios
    # These represent how much more/less likely the response is under policy vs reference
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps
    
    # Compute implicit rewards
    # Higher is better for chosen, lower is better for rejected
    chosen_rewards = beta * chosen_logratios
    rejected_rewards = beta * rejected_logratios
    
    # DPO loss is negative log sigmoid of reward difference
    # We want chosen_rewards > rejected_rewards
    logits = chosen_rewards - rejected_rewards
    loss = -F.logsigmoid(logits).mean()
    
    return loss, chosen_rewards.mean(), rejected_rewards.mean()


# Demonstrate with simple example
batch_size = 4

# Simulated log probabilities (more negative = less likely)
policy_chosen = torch.tensor([-2.0, -1.5, -2.5, -1.8])
policy_rejected = torch.tensor([-3.0, -2.5, -3.5, -2.8])
ref_chosen = torch.tensor([-2.5, -2.0, -2.5, -2.0])
ref_rejected = torch.tensor([-2.5, -2.0, -3.0, -2.5])

loss, chosen_rew, rejected_rew = dpo_loss(
    policy_chosen, policy_rejected, ref_chosen, ref_rejected
)

print("DPO Loss Example:")
print(f"  Loss: {loss.item():.4f}")
print(f"  Mean chosen reward: {chosen_rew.item():.4f}")
print(f"  Mean rejected reward: {rejected_rew.item():.4f}")
print(f"  Reward margin: {(chosen_rew - rejected_rew).item():.4f}")
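
The inputs to `dpo_loss` are sequence-level log probabilities, and it may help to see where those come from. The sketch below is written in plain Python (no PyTorch) so the arithmetic stays visible; the helper names `log_softmax` and `sequence_logp`, and the toy logits, are assumptions made for illustration. In a real pipeline the same idea is usually expressed with a batched log-softmax followed by a gather over the label ids.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def sequence_logp(
    step_logits: List[List[float]],
    target_ids: List[int],
    response_mask: List[int],
) -> float:
    """Sum log P(token_t | prefix) over response positions only.

    Prompt positions (mask 0) are skipped so that only the response
    contributes to the preference comparison.
    """
    total = 0.0
    for logits, token_id, keep in zip(step_logits, target_ids, response_mask):
        if keep:
            total += log_softmax(logits)[token_id]
    return total


# Toy example: vocabulary of 3 tokens, 3 positions. The first position is a
# prompt token (mask 0); the last two belong to the response (mask 1).
toy_logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 3.0]]
toy_targets = [0, 1, 2]
toy_mask = [0, 1, 1]

logp = sequence_logp(toy_logits, toy_targets, toy_mask)
print(f"Response log-probability: {logp:.4f}")
```

Running this once with the policy's logits and once with the reference model's logits, for both the chosen and the rejected response, yields the four tensors that `dpo_loss` expects.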

In [None]:
# Visualize DPO loss behavior
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Loss vs reward margin
margins = np.linspace(-5, 5, 100)
losses = np.logaddexp(0, -margins)  # -log(sigmoid(m)) = log(1 + exp(-m)), computed stably

axes[0].plot(margins, losses, 'b-', linewidth=2)
axes[0].axvline(x=0, color='gray', linestyle='--', alpha=0.5)
axes[0].fill_between(margins, losses, alpha=0.3, where=(margins > 0))
axes[0].set_xlabel('Reward Margin (chosen - rejected)')
axes[0].set_ylabel('Loss')
axes[0].set_title('DPO Loss vs Reward Margin')
axes[0].grid(True, alpha=0.3)
axes[0].annotate('Good: chosen > rejected', xy=(2, 0.2), fontsize=10)
axes[0].annotate('Bad: rejected > chosen', xy=(-4, 2), fontsize=10)

# Plot 2: Effect of beta
log_ratios = np.linspace(-2, 2, 100)
betas = [0.05, 0.1, 0.2, 0.5]

for beta in betas:
    rewards = beta * log_ratios
    axes[1].plot(log_ratios, rewards, label=f'β={beta}')

axes[1].set_xlabel('Log Ratio (log π_θ / π_ref)')
axes[1].set_ylabel('Implicit Reward')
axes[1].set_title('Effect of β on Reward Scaling')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('dpo_loss_visualization.png', dpi=150)
plt.show()
plt.close(fig)  # Release memory

print("\nKey Insights:")
print("1. Loss is low when chosen response has higher reward than rejected")
print("2. β scales the implicit reward: higher β gives a larger reward for the same log ratio")
print("3. So higher β keeps the model closer to the reference policy; lower β lets it drift further")
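
To see why β acts as a leash on the policy, the following sketch computes how large the policy-vs-reference log-ratio gap must be before a pair counts as 90% "solved". The helper names `required_logratio_gap` and `pair_loss` are illustrative and do not come from TRL. The smaller β is, the further the policy has to move away from the reference to reach the same loss.

```python
import math


def required_logratio_gap(beta: float, target_prob: float = 0.9) -> float:
    """Log-ratio gap at which sigmoid(beta * gap) reaches target_prob."""
    logit = math.log(target_prob / (1.0 - target_prob))
    return logit / beta


def pair_loss(beta: float, gap: float) -> float:
    """DPO loss for a single pair: -log(sigmoid(beta * gap))."""
    return math.log1p(math.exp(-beta * gap))


for beta in [0.05, 0.1, 0.5]:
    gap = required_logratio_gap(beta)
    print(
        f"beta={beta:<4} gap needed for 90% preference: {gap:6.2f}   "
        f"loss at that gap: {pair_loss(beta, gap):.4f}"
    )
```

The loss at the target is identical for every β; only the distance from the reference needed to get there changes.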

---

## Part 2: Creating Preference Datasets

DPO requires pairs of (chosen, rejected) responses for the same prompt.

In [None]:
# Sample preference dataset
preference_data = [
    {
        "prompt": "Explain what machine learning is.",
        "chosen": "Machine learning is a subset of artificial intelligence that enables computers to learn patterns from data without being explicitly programmed. By analyzing examples, ML algorithms can identify patterns and make predictions or decisions. It's the technology behind recommendation systems, voice assistants, and many other applications we use daily.",
        "rejected": "Machine learning is when computers learn stuff. It's like AI but different. You put in data and it learns patterns or whatever."
    },
    {
        "prompt": "How do I improve my Python programming skills?",
        "chosen": "Here are effective ways to improve your Python skills:\n\n1. **Practice daily** - Solve problems on LeetCode, HackerRank, or Codewars\n2. **Build projects** - Create real applications that solve problems you care about\n3. **Read quality code** - Study open-source projects on GitHub\n4. **Learn the standard library** - Python's batteries-included philosophy offers powerful built-in tools\n5. **Write tests** - Test-driven development improves code quality\n6. **Review others' code** - Learn different approaches and patterns",
        "rejected": "Just practice a lot. Code every day and you'll get better eventually."
    },
    {
        "prompt": "What's the difference between a list and a tuple in Python?",
        "chosen": "Lists and tuples are both sequence types in Python, but they have key differences:\n\n**Mutability:**\n- Lists are mutable (can be modified after creation)\n- Tuples are immutable (cannot be changed after creation)\n\n**Syntax:**\n- Lists: `[1, 2, 3]`\n- Tuples: `(1, 2, 3)`\n\n**Performance:**\n- Tuples are slightly faster due to immutability\n- Tuples use less memory\n\n**Use Cases:**\n- Lists for collections that may change\n- Tuples for fixed data like coordinates or function returns",
        "rejected": "Lists use [] and tuples use (). Lists can change, tuples can't."
    },
    {
        "prompt": "Write a function to calculate factorial.",
        "chosen": """Here's a well-documented factorial function with error handling:

```python
def factorial(n: int) -> int:
    \"\"\"Calculate the factorial of a non-negative integer.
    
    Args:
        n: A non-negative integer
    
    Returns:
        The factorial of n (n!)
    
    Raises:
        ValueError: If n is negative
    
    Examples:
        >>> factorial(5)
        120
        >>> factorial(0)
        1
    \"\"\" 
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers")
    if n <= 1:
        return 1
    return n * factorial(n - 1)
```

This handles edge cases and includes type hints and documentation.""",
        "rejected": "def f(n): return 1 if n<=1 else n*f(n-1)"
    },
    {
        "prompt": "What is overfitting in machine learning?",
        "chosen": "Overfitting occurs when a machine learning model learns the training data too well, including its noise and random fluctuations. The result is a model that performs excellently on training data but poorly on new, unseen data.\n\n**Signs of overfitting:**\n- Very low training error but high validation error\n- Large gap between training and validation performance\n\n**Solutions:**\n- Regularization (L1, L2, dropout)\n- More training data\n- Simpler model architecture\n- Cross-validation\n- Early stopping",
        "rejected": "It's when the model memorizes the data instead of learning. Not good."
    },
    # Edge cases: Similar quality responses where preference is subtle
    {
        "prompt": "How do I center a div in CSS?",
        "chosen": "There are several modern ways to center a div:\n\n**Flexbox (recommended):**\n```css\n.container {\n  display: flex;\n  justify-content: center;\n  align-items: center;\n}\n```\n\n**Grid:**\n```css\n.container {\n  display: grid;\n  place-items: center;\n}\n```\n\nBoth methods work for horizontal and vertical centering.",
        "rejected": "You can use flexbox with display: flex and then justify-content: center and align-items: center on the parent element."
    },
    # Edge case: Handling a request that could be misinterpreted
    {
        "prompt": "Tell me a joke about programmers.",
        "chosen": "Why do programmers prefer dark mode? Because light attracts bugs!\n\nHere's another: A programmer's wife tells them, 'Go to the store and buy milk. If they have eggs, get a dozen.' The programmer comes home with 12 cartons of milk. When asked why, they say, 'They had eggs!'",
        "rejected": "Programmers are nerdy and don't go outside much. Ha ha."
    },
    # Edge case: Technical accuracy matters more than length
    {
        "prompt": "What is the time complexity of binary search?",
        "chosen": "Binary search has O(log n) time complexity, where n is the number of elements. This is because each comparison eliminates half of the remaining elements.",
        "rejected": "Binary search is really fast because it keeps dividing the array in half. It's much faster than linear search which checks every element one by one. Binary search only works on sorted arrays though, which is important to remember. The algorithm compares the target with the middle element and decides which half to search next."
    },
]

print(f"Created {len(preference_data)} preference pairs")
print(f"\nIncludes {len(preference_data) - 5} edge case examples demonstrating:")
print("  - Similar quality responses with subtle differences")
print("  - Humor handling")
print("  - Accuracy vs verbosity tradeoffs")
print(f"\nSample pair:")
print(f"Prompt: {preference_data[0]['prompt']}")
print(f"\nChosen (length: {len(preference_data[0]['chosen'])})")
print(f"Rejected (length: {len(preference_data[0]['rejected'])})")
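
Before training on hand-written pairs like these, a quick sanity check can catch records that would silently hurt training. The sketch below is one possible validator; `validate_pairs`, `REQUIRED_KEYS`, and the example records are hypothetical names introduced here, and the two rules it enforces (no empty fields, chosen differs from rejected) are a minimal assumption rather than a complete quality filter.

```python
from typing import Dict, List

REQUIRED_KEYS = ("prompt", "chosen", "rejected")


def validate_pairs(pairs: List[Dict[str, str]]) -> List[str]:
    """Return human-readable problems; an empty list means no issues were found."""
    problems = []
    for i, pair in enumerate(pairs):
        # A field counts as missing if it is absent or only whitespace
        missing = [k for k in REQUIRED_KEYS if not str(pair.get(k, "")).strip()]
        if missing:
            problems.append(f"pair {i}: missing or empty field(s) {missing}")
            continue
        # Identical responses carry no preference signal
        if pair["chosen"].strip() == pair["rejected"].strip():
            problems.append(f"pair {i}: chosen and rejected are identical")
    return problems


example_pairs = [
    {"prompt": "What is AI?", "chosen": "A detailed answer.", "rejected": "A vague answer."},
    {"prompt": "What is AI?", "chosen": "Same text", "rejected": "Same text"},
    {"prompt": "", "chosen": "Answer", "rejected": "Other"},
]

for problem in validate_pairs(example_pairs):
    print(problem)
```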

In [None]:
class PreferenceDataset:
    """
    Dataset class for DPO training.
    """
    
    def __init__(self, data: List[Dict], tokenizer, max_length: int = 512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def format_prompt(self, prompt: str, response: str) -> str:
        """Format prompt and response for the model."""
        return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{response}<|eot_id|>"""
    
    def __getitem__(self, idx: int) -> Dict:
        item = self.data[idx]
        
        # Format chosen and rejected
        chosen_text = self.format_prompt(item['prompt'], item['chosen'])
        rejected_text = self.format_prompt(item['prompt'], item['rejected'])
        
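        # Tokenize. The template above already contains <|begin_of_text|>, so
        # add_special_tokens=False stops the tokenizer from adding a second BOS.
        # padding='max_length' requires tokenizer.pad_token to be set.
        chosen_tokens = self.tokenizer(
            chosen_text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            add_special_tokens=False,
            return_tensors='pt'
        )
        
        rejected_tokens = self.tokenizer(
            rejected_text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            add_special_tokens=False,
            return_tensors='pt'
        )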
        
        return {
            'chosen_input_ids': chosen_tokens['input_ids'].squeeze(),
            'chosen_attention_mask': chosen_tokens['attention_mask'].squeeze(),
            'rejected_input_ids': rejected_tokens['input_ids'].squeeze(),
            'rejected_attention_mask': rejected_tokens['attention_mask'].squeeze(),
        }

### Creating Preference Data from Existing Datasets

Often you need to create preference pairs from existing data. Here are strategies:

In [None]:
class PreferenceDataGenerator:
    """
    Generate preference pairs from various sources.
    """
    
    @staticmethod
    def from_quality_scores(
        responses: List[Dict],
        min_score_diff: float = 0.5
    ) -> List[Dict]:
        """
        Create pairs from responses with quality scores.
        
        Input format: [{"prompt": str, "response": str, "score": float}, ...]
        """
        # Group by prompt
        by_prompt = {}
        for r in responses:
            prompt = r['prompt']
            if prompt not in by_prompt:
                by_prompt[prompt] = []
            by_prompt[prompt].append(r)
        
        pairs = []
        for prompt, group in by_prompt.items():
            # Sort by score (a separate name avoids shadowing the `responses` argument)
            ranked = sorted(group, key=lambda x: x['score'], reverse=True)
            
            # Create pairs from high/low scores
            for i, high in enumerate(ranked):
                for low in ranked[i+1:]:
                    if high['score'] - low['score'] >= min_score_diff:
                        pairs.append({
                            'prompt': prompt,
                            'chosen': high['response'],
                            'rejected': low['response']
                        })
        
        return pairs
    
    @staticmethod
    def from_length_preference(
        data: List[Dict],
        prefer_longer: bool = True
    ) -> List[Dict]:
        """
        Create pairs based on response length.
        Useful for teaching the model to give more/less detailed responses.
        """
        pairs = []
        
        for item in data:
            if 'responses' not in item:
                continue
            
            responses = sorted(item['responses'], key=len, reverse=True)
            
            if len(responses) >= 2:
                if prefer_longer:
                    chosen, rejected = responses[0], responses[-1]
                else:
                    chosen, rejected = responses[-1], responses[0]
                
                pairs.append({
                    'prompt': item['prompt'],
                    'chosen': chosen,
                    'rejected': rejected
                })
        
        return pairs
    
    @staticmethod
    def from_human_rankings(
        rankings: List[Dict]
    ) -> List[Dict]:
        """
        Create pairs from human rankings.
        
        Input: [{"prompt": str, "responses": [str], "ranking": [int]}, ...]
        ranking[i] = position of responses[i] (1 = best)
        """
        pairs = []
        
        for item in rankings:
            responses = item['responses']
            ranking = item['ranking']
            
            # Create all pairs where chosen is ranked higher
            for i in range(len(responses)):
                for j in range(len(responses)):
                    if ranking[i] < ranking[j]:  # lower number = better
                        pairs.append({
                            'prompt': item['prompt'],
                            'chosen': responses[i],
                            'rejected': responses[j]
                        })
        
        return pairs


# Example with quality scores
scored_data = [
    {"prompt": "What is AI?", "response": "AI is advanced computers.", "score": 0.3},
    {"prompt": "What is AI?", "response": "Artificial Intelligence (AI) refers to computer systems designed to perform tasks that typically require human intelligence, such as visual perception, speech recognition, decision-making, and language translation.", "score": 0.9},
    {"prompt": "What is AI?", "response": "AI means Artificial Intelligence. It's technology that mimics human thinking.", "score": 0.6},
]

generated_pairs = PreferenceDataGenerator.from_quality_scores(scored_data, min_score_diff=0.2)
print(f"Generated {len(generated_pairs)} preference pairs from scored data")

---

## Part 3: DPO Training with TRL

Let's use the TRL library to train with DPO on a real model.

In [None]:
# Import DPO trainer with version compatibility
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset

# Handle TRL API changes across versions
# TRL >=0.8.0 uses DPOConfig, older versions use TrainingArguments
from trl import DPOTrainer

# Log TRL version for debugging
try:
    import trl
    print(f"TRL version: {trl.__version__}")
except AttributeError:
    print("TRL version: unknown (pre-0.5)")

# DPOConfig was introduced in TRL 0.8+ with different parameter names
# We need to detect the TRL version to handle this correctly
USE_DPO_CONFIG = False
DPO_BETA_IN_CONFIG = True  # Whether beta goes in config vs trainer

try:
    from trl import DPOConfig
    USE_DPO_CONFIG = True
    DPO_BETA_IN_CONFIG = True
    print("Using DPOConfig from TRL (recommended for TRL >=0.8)")
except ImportError:
    try:
        from trl import DPOTrainingArguments as DPOConfig
        USE_DPO_CONFIG = True
        DPO_BETA_IN_CONFIG = True
        print("Using DPOTrainingArguments as DPOConfig (TRL 0.7.x)")
    except ImportError:
        # Fallback: older TRL versions use TrainingArguments + pass beta to trainer
        USE_DPO_CONFIG = False
        DPO_BETA_IN_CONFIG = False
        print("Note: Using TrainingArguments (older TRL <0.7 - beta passed to DPOTrainer directly)")
        print("  Consider upgrading TRL: pip install --upgrade trl")

print("\nTRL DPO imports successful!")

In [None]:
# Configuration for DPO training
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # Or smaller model for faster testing
# Alternative: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" for quick testing

# Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

print("Configuration ready!")

In [None]:
# Create dataset in the format expected by TRL's DPOTrainer
def prepare_dpo_dataset(data: List[Dict]) -> Dataset:
    """
    Prepare dataset for TRL DPOTrainer.
    
    Expected columns: 'prompt', 'chosen', 'rejected'
    """
    return Dataset.from_list(data)

# Prepare our preference data
dpo_dataset = prepare_dpo_dataset(preference_data)
print(f"DPO Dataset: {len(dpo_dataset)} examples")
print(f"Columns: {dpo_dataset.column_names}")

In [None]:
# Function to load model and run DPO training
def train_with_dpo(
    model_name: str,
    dataset: Dataset,
    output_dir: str = "./dpo-finetuned",
    num_epochs: int = 1,
    beta: float = 0.1,
    learning_rate: float = 5e-5,
):
    """
    Complete DPO training pipeline.
    
    Handles API differences between TRL versions automatically.
    """
    print(f"Loading model: {model_name}")
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
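    # Some tokenizers (Llama's among them) have no pad token; reuse EOS only if none is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # Important for generation
    
    # The reference model is the frozen starting policy.
    # With a PEFT model and no ref_model argument, DPOTrainer computes reference
    # log probs with the LoRA adapters disabled, so no second copy is loaded.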
    
    # Handle TRL API differences
    if USE_DPO_CONFIG:
        # Modern TRL with DPOConfig - beta goes in config
        dpo_config = DPOConfig(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            beta=beta,  # DPO beta parameter in config
            max_length=512,
            max_prompt_length=256,
            bf16=True,
            gradient_checkpointing=True,
            logging_steps=1,
            save_strategy="epoch",
            report_to="none",
        )
        
        # Try new API first (processing_class), fall back to tokenizer
        try:
            trainer = DPOTrainer(
                model=model,
                args=dpo_config,
                train_dataset=dataset,
                processing_class=tokenizer,
            )
        except TypeError:
            # Older TRL version uses 'tokenizer' parameter
            trainer = DPOTrainer(
                model=model,
                args=dpo_config,
                train_dataset=dataset,
                tokenizer=tokenizer,
            )
    else:
        # Older TRL: use TrainingArguments and pass beta to trainer
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            bf16=True,
            gradient_checkpointing=True,
            logging_steps=1,
            save_strategy="epoch",
            report_to="none",
        )
        
        trainer = DPOTrainer(
            model=model,
            args=training_args,
            beta=beta,  # Pass beta directly to trainer in older TRL
            train_dataset=dataset,
            tokenizer=tokenizer,
            max_length=512,
            max_prompt_length=256,
        )
    
    print("\nStarting DPO training...")
    trainer.train()
    
    # Save model
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    print(f"\nModel saved to {output_dir}")
    return model, tokenizer

print("DPO training function ready!")
print("\nTo train, run:")
print("model, tokenizer = train_with_dpo(MODEL_NAME, dpo_dataset)")

In [None]:
# Uncomment to actually run training (requires model access and significant compute)
# model, tokenizer = train_with_dpo(MODEL_NAME, dpo_dataset)

# For demonstration, we'll show illustrative training output (your numbers will differ)
print("""Illustrative training output:

Loading model: meta-llama/Llama-3.1-8B-Instruct
Starting DPO training...

{'loss': 0.693, 'rewards/chosen': 0.012, 'rewards/rejected': -0.015, ...}
{'loss': 0.651, 'rewards/chosen': 0.045, 'rewards/rejected': -0.032, ...}
{'loss': 0.589, 'rewards/chosen': 0.089, 'rewards/rejected': -0.078, ...}
...

What to look for:
- loss should start near 0.693 (ln 2, i.e. zero margin) and decrease over time
- rewards/chosen should rise relative to rewards/rejected
- The gap between chosen and rejected rewards should grow
- Both rewards can drift downward together; the margin matters more than the sign
""")
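
The "gap should grow" check can be made precise with two summary numbers: the fraction of pairs where the chosen reward beats the rejected one, and the mean margin. The sketch below computes both from per-example implicit rewards. The function name `reward_metrics` and the sample values are hypothetical; TRL logs comparable statistics (named `rewards/accuracies` and `rewards/margins` in the versions we have seen), but this code does not depend on TRL.

```python
from typing import List, Tuple


def reward_metrics(
    chosen_rewards: List[float],
    rejected_rewards: List[float],
) -> Tuple[float, float]:
    """Return (accuracy, mean_margin) over a batch of per-example implicit rewards."""
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]
    accuracy = sum(m > 0 for m in margins) / len(margins)
    mean_margin = sum(margins) / len(margins)
    return accuracy, mean_margin


# Hypothetical per-example rewards for a batch of four pairs
batch_chosen = [0.10, 0.02, -0.05, 0.20]
batch_rejected = [-0.08, 0.04, -0.15, 0.05]

acc, margin = reward_metrics(batch_chosen, batch_rejected)
print(f"reward accuracy: {acc:.2f}, mean margin: {margin:.3f}")
```

Note the third pair: its chosen reward is negative, yet it still counts as correct because the rejected reward is lower. An accuracy near 0.5 that never moves suggests the model is not picking up the preference signal.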

---

## Part 4: Comparing SFT vs DPO

Let's compare models trained with only SFT vs SFT + DPO.

In [None]:
class ModelComparator:
    """
    Compare responses from SFT-only vs DPO-trained models.
    """
    
    def __init__(self, sft_model, dpo_model, tokenizer):
        self.sft_model = sft_model
        self.dpo_model = dpo_model
        self.tokenizer = tokenizer
    
    def generate(self, model, prompt: str, max_tokens: int = 200) -> str:
        """Generate response from a model."""
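        formatted = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        
        # The template already starts with <|begin_of_text|>, so skip the
        # tokenizer's automatic BOS to avoid a duplicated token.
        inputs = self.tokenizer(
            formatted, return_tensors="pt", add_special_tokens=False
        ).to(model.device)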
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        
        generated = outputs[0][inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True)
    
    def compare(self, prompt: str) -> Dict:
        """Compare responses from both models."""
        return {
            "prompt": prompt,
            "sft_response": self.generate(self.sft_model, prompt),
            "dpo_response": self.generate(self.dpo_model, prompt),
        }
    
    def run_comparison(self, prompts: List[str]) -> List[Dict]:
        """Compare responses for multiple prompts."""
        results = []
        for prompt in prompts:
            result = self.compare(prompt)
            results.append(result)
            
            print(f"\n{'='*60}")
            print(f"Prompt: {prompt}")
            print(f"\n--- SFT Response ---")
            print(result['sft_response'][:500])
            print(f"\n--- DPO Response ---")
            print(result['dpo_response'][:500])
        
        return results

In [None]:
# Evaluation criteria for manual comparison
evaluation_criteria = """
EVALUATION CRITERIA FOR SFT vs DPO COMPARISON
=============================================

When comparing responses, evaluate on these dimensions:

1. HELPFULNESS (0-5)
   - Does it actually answer the question?
   - Is the information accurate and complete?
   - Does it provide actionable guidance?

2. CLARITY (0-5)
   - Is it well-organized and easy to follow?
   - Does it use appropriate formatting?
   - Is the language clear and precise?

3. DEPTH (0-5)
   - Does it provide sufficient detail?
   - Does it cover edge cases or nuances?
   - Does it offer examples where helpful?

4. STYLE (0-5)
   - Is the tone appropriate?
   - Is it engaging without being overly casual?
   - Does it match the complexity to the question?

5. SAFETY (0-5)
   - Does it avoid harmful content?
   - Does it acknowledge uncertainty appropriately?
   - Does it avoid hallucination?

EXPECTED DIFFERENCES:
- DPO should produce responses more aligned with human preferences
- DPO often produces more detailed, better-structured responses
- DPO may better match the preferred style from training data
- SFT alone may produce valid but "blander" responses
"""

print(evaluation_criteria)
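
Once responses have been scored against this rubric, the scores need to be aggregated into something comparable. The sketch below sums the five dimensions per response and reports a per-prompt win rate. The names `DIMENSIONS`, `total_score`, and `win_rate`, as well as the example scores, are hypothetical, and equal weighting of the dimensions is an assumption you may want to revisit (for instance, by treating a low safety score as disqualifying).

```python
from typing import Dict, List

DIMENSIONS = ["helpfulness", "clarity", "depth", "style", "safety"]


def total_score(scores: Dict[str, int]) -> int:
    """Sum the 0-5 scores across all five rubric dimensions (maximum 25)."""
    return sum(scores[d] for d in DIMENSIONS)


def win_rate(
    sft_scores: List[Dict[str, int]],
    dpo_scores: List[Dict[str, int]],
) -> Dict[str, float]:
    """Fraction of prompts won by each model; ties are reported separately."""
    wins = {"dpo": 0, "sft": 0, "tie": 0}
    for sft, dpo in zip(sft_scores, dpo_scores):
        s, d = total_score(sft), total_score(dpo)
        wins["dpo" if d > s else "sft" if s > d else "tie"] += 1
    n = len(sft_scores)
    return {name: count / n for name, count in wins.items()}


# Hypothetical annotator scores for two prompts
sft_scores = [
    {"helpfulness": 3, "clarity": 3, "depth": 2, "style": 3, "safety": 5},
    {"helpfulness": 4, "clarity": 4, "depth": 3, "style": 4, "safety": 5},
]
dpo_scores = [
    {"helpfulness": 4, "clarity": 4, "depth": 4, "style": 3, "safety": 5},
    {"helpfulness": 4, "clarity": 4, "depth": 3, "style": 4, "safety": 5},
]

print(win_rate(sft_scores, dpo_scores))
```

For a fair comparison, annotators should not know which model produced which response, and the presentation order should be shuffled per prompt.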

---

## Part 5: DPO Variants and Extensions

Several improvements to DPO have been proposed. Let's explore the main ones.

In [None]:
# DPO Variants
dpo_variants = """
DPO VARIANTS AND EXTENSIONS
===========================

1. IPO (Identity Preference Optimization)
   - Uses a different loss formulation
   - Less prone to overfitting on preference data
   - Better for smaller datasets

2. cDPO (Conservative DPO)
   - Adds label smoothing to prevent overconfidence
   - More robust when preference labels are noisy

3. ORPO (Odds Ratio Preference Optimization)
   - Combines SFT and preference optimization in a single training stage
   - More memory efficient (no reference model needed)
   - Loss: L_SFT + λ * L_odds_ratio

4. SimPO (Simple Preference Optimization)
   - Removes need for reference model
   - Uses length-normalized rewards
   - Simpler to implement and train

5. KTO (Kahneman-Tversky Optimization)
   - Doesn't require paired preferences
   - Only needs good/bad labels
   - Based on prospect theory

CHOOSING THE RIGHT VARIANT:
- DPO: Standard choice, well-tested
- ORPO: When memory is limited
- SimPO: When you want simplicity
- KTO: When you only have thumbs up/down data
"""

print(dpo_variants)
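
The first two variants differ from DPO only in the per-pair loss, which makes them easy to compare side by side. The sketch below is a plain-Python illustration under stated assumptions: `logratio_gap` stands for the chosen log ratio minus the rejected log ratio (policy vs reference), the function names are hypothetical, and the cDPO and IPO formulas follow our reading of the respective papers rather than any particular library's implementation.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_pair_loss(logratio_gap: float, beta: float = 0.1) -> float:
    """Standard DPO: -log sigmoid(beta * gap)."""
    return -log_sigmoid(beta * logratio_gap)


def cdpo_pair_loss(
    logratio_gap: float, beta: float = 0.1, label_smoothing: float = 0.1
) -> float:
    """Conservative DPO: assume each label is flipped with probability label_smoothing."""
    z = beta * logratio_gap
    return -(1 - label_smoothing) * log_sigmoid(z) - label_smoothing * log_sigmoid(-z)


def ipo_pair_loss(logratio_gap: float, beta: float = 0.1) -> float:
    """IPO: squared distance between the gap and the target 1 / (2 * beta)."""
    return (logratio_gap - 1.0 / (2.0 * beta)) ** 2


for gap in [0.0, 5.0, 22.0, 50.0]:
    print(
        f"gap={gap:5.1f}  DPO={dpo_pair_loss(gap):.4f}  "
        f"cDPO={cdpo_pair_loss(gap):.4f}  IPO={ipo_pair_loss(gap):.1f}"
    )
```

DPO's loss keeps falling as the gap grows, so it always rewards pushing further. cDPO's loss bottoms out and then rises again, and IPO penalizes any gap beyond its target, which is the mechanism behind the "less overconfident" and "less prone to overfitting" claims above.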

In [None]:
# Implement ORPO loss for comparison
def orpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    lambda_weight: float = 0.1,
) -> torch.Tensor:
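    """
    Compute ORPO loss (no reference model needed!).
    
    ORPO compares the odds of each response, odds(y) = P(y) / (1 - P(y)),
    and maximizes the log odds ratio log(odds(chosen) / odds(rejected)).
    
    Inputs should be length-averaged log probabilities (strictly negative).
    """
    # log odds(y) = log P(y) - log(1 - P(y)), computed in log space.
    # log1p(-exp(logp)) equals log(1 - P(y)) and requires logp < 0.
    chosen_log_odds = policy_chosen_logps - torch.log1p(-torch.exp(policy_chosen_logps))
    rejected_log_odds = policy_rejected_logps - torch.log1p(-torch.exp(policy_rejected_logps))
    
    # Note: this is NOT the same as log P(chosen) - log P(rejected)
    log_odds_ratio = chosen_log_odds - rejected_log_odds
    
    # Odds-ratio term: negative log sigmoid of the log odds ratio
    odds_loss = -F.logsigmoid(log_odds_ratio).mean()
    
    # SFT loss component (maximize chosen log prob)
    sft_loss = -policy_chosen_logps.mean()
    
    # Combined loss
    total_loss = sft_loss + lambda_weight * odds_loss
    
    return total_loss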

# Test ORPO
orpo_l = orpo_loss(policy_chosen, policy_rejected)
print(f"ORPO Loss: {orpo_l.item():.4f}")
print("\nORPO advantage: No reference model needed!")

---

## Try It Yourself: Exercises

### Exercise 1: Create a Preference Dataset

Create at least 20 preference pairs for a specific domain (e.g., customer support, medical Q&A).

In [None]:
# Exercise 1: Your preference dataset
your_preferences = [
    # Add your pairs here
]

### Exercise 2: Implement SimPO Loss

SimPO uses length-normalized rewards and no reference model. Implement it!

In [None]:
# Exercise 2: SimPO implementation
def simpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    chosen_lengths: torch.Tensor,
    rejected_lengths: torch.Tensor,
    beta: float = 2.0,
    gamma: float = 0.5,
) -> torch.Tensor:
    """
    Implement SimPO loss.
    
    Hint: SimPO uses length-normalized log probs and adds a margin.
    r_chosen = beta * (logp_chosen / len_chosen)
    r_rejected = beta * (logp_rejected / len_rejected)
    loss = -log_sigmoid(r_chosen - r_rejected - gamma)
    """
    # Your implementation here
    pass

---

## Common Mistakes

### Mistake 1: Imbalanced Preference Pairs

```python
# ❌ Wrong: Chosen always much longer
pairs = [
    {"prompt": "...", "chosen": "Very long detailed response...", "rejected": "Short."},
    # Model learns to just output longer responses!
]

# ✅ Right: Mix of lengths, focus on quality
pairs = [
    {"prompt": "...", "chosen": "Concise but complete.", "rejected": "Verbose but unhelpful..."},
    {"prompt": "...", "chosen": "Detailed when needed...", "rejected": "Too brief."},
]
```
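
A quick audit can reveal this kind of length bias before training. The sketch below reports how often the chosen response is longer and the average chosen-to-rejected length ratio; `length_bias` and the sample pairs are hypothetical, and character counts are used as a rough stand-in for token counts.

```python
from typing import Dict, List


def length_bias(pairs: List[Dict[str, str]]) -> Dict[str, float]:
    """Summarize how often, and by how much, the chosen response is longer."""
    longer = sum(len(p["chosen"]) > len(p["rejected"]) for p in pairs)
    # max(..., 1) guards against division by zero for empty rejected responses
    ratios = [len(p["chosen"]) / max(len(p["rejected"]), 1) for p in pairs]
    return {
        "chosen_longer_fraction": longer / len(pairs),
        "mean_length_ratio": sum(ratios) / len(ratios),
    }


sample_pairs = [
    {"prompt": "q1", "chosen": "a" * 300, "rejected": "b" * 50},
    {"prompt": "q2", "chosen": "a" * 40, "rejected": "b" * 200},
    {"prompt": "q3", "chosen": "a" * 120, "rejected": "b" * 60},
]

stats = length_bias(sample_pairs)
print(stats)
```

A `chosen_longer_fraction` close to 1.0 is a warning sign. The sample `preference_data` in Part 2 is itself skewed this way: the chosen response is longer in seven of its eight pairs.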

### Mistake 2: Wrong Beta Value

```python
# ❌ Wrong: Beta too high
beta = 1.0  # Policy is held too close to the reference and barely learns preferences

# ❌ Wrong: Beta too low
beta = 0.01  # Policy can drift far from the reference; output quality may degrade

# ✅ Right: Standard range
beta = 0.1  # Good starting point
```

### Mistake 3: Noisy Preference Labels

```python
# ❌ Wrong: Inconsistent preferences
pairs = [
    {"prompt": "...", "chosen": "Response A", "rejected": "Response B"},
    {"prompt": "...", "chosen": "Response B", "rejected": "Response A"},  # Contradicts!
]

# ✅ Right: Consistent, clear preferences
# Use multiple annotators and filter disagreements
```
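
Direct contradictions like the one above can be found mechanically. The sketch below flags any pair whose mirror image (same prompt, chosen and rejected swapped) appeared earlier in the data; `find_contradictions` and the sample records are hypothetical, and exact string matching is an assumption that will miss near-duplicates.

```python
from typing import Dict, List, Set, Tuple


def find_contradictions(pairs: List[Dict[str, str]]) -> List[Tuple[str, str, str]]:
    """Return (prompt, chosen, rejected) triples that later appear with the order flipped."""
    seen: Set[Tuple[str, str, str]] = set()
    contradictions = []
    for p in pairs:
        key = (p["prompt"], p["chosen"], p["rejected"])
        flipped = (p["prompt"], p["rejected"], p["chosen"])
        if flipped in seen:
            contradictions.append(flipped)
        seen.add(key)
    return contradictions


noisy_pairs = [
    {"prompt": "Explain recursion.", "chosen": "Response A", "rejected": "Response B"},
    {"prompt": "Explain recursion.", "chosen": "Response B", "rejected": "Response A"},
    {"prompt": "Explain loops.", "chosen": "Response C", "rejected": "Response D"},
]

print(find_contradictions(noisy_pairs))
```

What to do with a flagged pair is a judgment call: dropping both orderings is the conservative choice, while keeping the majority ordering makes sense when several annotators labeled the same pair.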

---

## Checkpoint

You've learned:
- ✅ The theory behind DPO and how it simplifies RLHF
- ✅ How to create and structure preference datasets
- ✅ How to implement DPO loss from scratch
- ✅ How to use TRL's DPOTrainer
- ✅ DPO variants (ORPO, SimPO, KTO)

---

## Cleanup

In [None]:
# Cleanup
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")

---

## Further Reading

- [DPO Paper: Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
- [ORPO Paper](https://arxiv.org/abs/2403.07691)
- [SimPO Paper](https://arxiv.org/abs/2405.14734)
- [TRL Documentation](https://huggingface.co/docs/trl)

---

## Next Steps

**[Lab 3.1.6: LLaMA Factory Exploration](06-llama-factory-exploration.ipynb)**

Learn to use GUI-based fine-tuning with LLaMA Factory!