# Lab 3.1.5 Solutions: Direct Preference Optimization (DPO)

**Module:** 3.1 - Large Language Model Fine-Tuning  
**Difficulty:** ⭐⭐⭐⭐☆ (Advanced)  
**Exercises:** 3 (DPO Loss from Scratch, DPO Variants Comparison, Preference Dataset Generator)

This notebook contains solutions for the DPO training exercises.

---

## Exercise 1 Solution: DPO Loss Implementation from Scratch

**Task:** Implement the DPO loss function with detailed comments.
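
For reference, the objective implemented in the cell below can be written out as follows. This is the docstring's formula restated in LaTeX; $\sigma$ (the logistic sigmoid) and $\mathcal{D}$ (the preference dataset) are notation assumed for this note rather than symbols defined in the notebook.

```latex
\mathcal{L}_{\mathrm{DPO}}
  = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}
    \left[ \log \sigma\!\left(
      \beta \left(
        \log \frac{\pi(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
        - \log \frac{\pi(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
      \right)
    \right) \right]
```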

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Optional

class DPOLoss(nn.Module):
    """
    Direct Preference Optimization (DPO) Loss.
    
    DPO trains the policy directly on preference pairs: it increases the
    log-probability of the chosen response relative to the rejected one,
    with both measured against a frozen reference model.
    
    Loss = -log(sigmoid(β * (log π(y_w|x) - log π(y_l|x) - log π_ref(y_w|x) + log π_ref(y_l|x))))
    
    Where:
    - π is the policy model being trained
    - π_ref is the frozen reference model
    - y_w is the chosen (winner) response
    - y_l is the rejected (loser) response
    - x is the prompt
    - β scales the log-ratio difference; higher values keep the policy closer to π_ref
    
    Reference: https://arxiv.org/abs/2305.18290
    """
    
    def __init__(
        self,
        beta: float = 0.1,
        label_smoothing: float = 0.0,
        reference_free: bool = False,
    ):
        """
        Initialize DPO loss.
        
        Args:
            beta: Scale on the log-ratio difference (higher = policy stays
                closer to the reference model)
            label_smoothing: Assumed probability that a preference label is
                flipped (0 = standard DPO)
            reference_free: If True, treat reference log-probs as zero
                (no reference model)
        """
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.reference_free = reference_free
    
    def compute_log_probs(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute per-sequence log probabilities.
        
        Args:
            logits: Model output logits (batch, seq_len, vocab_size)
            labels: Target token IDs (batch, seq_len)
            attention_mask: Mask for valid tokens (batch, seq_len)
        
        Returns:
            Per-sequence sum of log probabilities (batch,)
        """
        # Shift logits and labels for causal LM
        # logits[t] predicts labels[t+1]
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        shift_mask = attention_mask[:, 1:].contiguous()
        
        # Compute log softmax
        log_probs = F.log_softmax(shift_logits, dim=-1)
        
        # Gather the log probs for the actual tokens
        # Shape: (batch, seq_len-1)
        per_token_log_probs = torch.gather(
            log_probs,
            dim=-1,
            index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)
        
        # Mask out padding and sum per sequence
        per_token_log_probs = per_token_log_probs * shift_mask
        sequence_log_probs = per_token_log_probs.sum(dim=-1)
        
        # Optional: normalize by length
        # sequence_log_probs = sequence_log_probs / shift_mask.sum(dim=-1).clamp(min=1)
        
        return sequence_log_probs
    
    def forward(
        self,
        policy_chosen_logits: torch.Tensor,
        policy_rejected_logits: torch.Tensor,
        chosen_labels: torch.Tensor,
        rejected_labels: torch.Tensor,
        chosen_mask: torch.Tensor,
        rejected_mask: torch.Tensor,
        reference_chosen_logits: Optional[torch.Tensor] = None,
        reference_rejected_logits: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute DPO loss.
        
        Args:
            policy_chosen_logits: Policy model logits for chosen responses
            policy_rejected_logits: Policy model logits for rejected responses
            chosen_labels: Token IDs for chosen responses
            rejected_labels: Token IDs for rejected responses
            chosen_mask: Attention mask for chosen responses
            rejected_mask: Attention mask for rejected responses
            reference_chosen_logits: Reference model logits for chosen (optional)
            reference_rejected_logits: Reference model logits for rejected (optional)
        
        Returns:
            Tuple of (loss, metrics_dict)
        """
        # Step 1: Compute policy log probabilities
        policy_chosen_log_probs = self.compute_log_probs(
            policy_chosen_logits, chosen_labels, chosen_mask
        )
        policy_rejected_log_probs = self.compute_log_probs(
            policy_rejected_logits, rejected_labels, rejected_mask
        )
        
        # Step 2: Compute reference log probabilities (if not reference-free)
        if self.reference_free or reference_chosen_logits is None:
            ref_chosen_log_probs = torch.zeros_like(policy_chosen_log_probs)
            ref_rejected_log_probs = torch.zeros_like(policy_rejected_log_probs)
        else:
            with torch.no_grad():
                ref_chosen_log_probs = self.compute_log_probs(
                    reference_chosen_logits, chosen_labels, chosen_mask
                )
                ref_rejected_log_probs = self.compute_log_probs(
                    reference_rejected_logits, rejected_labels, rejected_mask
                )
        
        # Step 3: Compute log ratios
        # chosen_log_ratio   = log π(y_w|x) - log π_ref(y_w|x)
        # rejected_log_ratio = log π(y_l|x) - log π_ref(y_l|x)
        chosen_log_ratio = policy_chosen_log_probs - ref_chosen_log_probs
        rejected_log_ratio = policy_rejected_log_probs - ref_rejected_log_probs
        
        # Step 4: Compute the DPO loss
        # logits = β * (chosen_log_ratio - rejected_log_ratio)
        logits = self.beta * (chosen_log_ratio - rejected_log_ratio)
        
        # Apply label smoothing if needed
        if self.label_smoothing > 0:
            # Soft labels: (1 - ε, ε) instead of (1, 0)
            loss = (
                (1 - self.label_smoothing) * F.logsigmoid(logits) +
                self.label_smoothing * F.logsigmoid(-logits)
            )
            loss = -loss.mean()
        else:
            # Standard DPO loss: -log(sigmoid(logits))
            loss = -F.logsigmoid(logits).mean()
        
        # Step 5: Compute metrics
        with torch.no_grad():
            # Accuracy: how often does the model prefer chosen over rejected?
            accuracy = (logits > 0).float().mean()
            
            # Reward margins
            chosen_rewards = self.beta * chosen_log_ratio
            rejected_rewards = self.beta * rejected_log_ratio
            reward_margin = (chosen_rewards - rejected_rewards).mean()
        
        metrics = {
            "loss": loss.detach(),
            "accuracy": accuracy,
            "chosen_reward": chosen_rewards.mean(),
            "rejected_reward": rejected_rewards.mean(),
            "reward_margin": reward_margin,
            # Detach so logged values do not keep the autograd graph alive
            "chosen_log_prob": policy_chosen_log_probs.detach().mean(),
            "rejected_log_prob": policy_rejected_log_probs.detach().mean(),
        }
        
        return loss, metrics


# Demo
print("DPO Loss Implementation")
print("=" * 50)

# Create dummy data
batch_size = 4
seq_len = 32
vocab_size = 1000

policy_chosen_logits = torch.randn(batch_size, seq_len, vocab_size)
policy_rejected_logits = torch.randn(batch_size, seq_len, vocab_size)
chosen_labels = torch.randint(0, vocab_size, (batch_size, seq_len))
rejected_labels = torch.randint(0, vocab_size, (batch_size, seq_len))
chosen_mask = torch.ones(batch_size, seq_len)
rejected_mask = torch.ones(batch_size, seq_len)

# Compute loss
dpo_loss = DPOLoss(beta=0.1)
loss, metrics = dpo_loss(
    policy_chosen_logits,
    policy_rejected_logits,
    chosen_labels,
    rejected_labels,
    chosen_mask,
    rejected_mask,
)

print(f"Loss: {loss.item():.4f}")
print(f"Accuracy: {metrics['accuracy'].item():.2%}")
print(f"Reward margin: {metrics['reward_margin'].item():.4f}")
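
A quick sanity check on the loss above: when the policy equals the reference, every logit is zero and the loss should be exactly log 2 ≈ 0.6931. The sketch below works on per-sequence log-probabilities directly instead of full logits; `dpo_loss_from_logps` is a hypothetical helper written for this check, and the log-probability values are made up.

```python
import math
import torch
import torch.nn.functional as F

def dpo_loss_from_logps(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss from per-sequence log-probabilities (hypothetical helper)."""
    logits = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    return -F.logsigmoid(logits).mean()

# Made-up per-sequence log-probabilities for a batch of 3 preference pairs
ref_chosen = torch.tensor([-12.0, -8.5, -20.0])
ref_rejected = torch.tensor([-11.0, -9.0, -18.0])

# Case 1: policy identical to the reference -> all logits are 0 -> loss = log 2
loss_at_init = dpo_loss_from_logps(ref_chosen, ref_rejected, ref_chosen, ref_rejected)

# Case 2: policy raises chosen by 2 nats and lowers rejected by 2 nats
loss_improved = dpo_loss_from_logps(
    ref_chosen + 2.0, ref_rejected - 2.0, ref_chosen, ref_rejected
)

print(f"loss at init:  {loss_at_init.item():.4f} (log 2 = {math.log(2):.4f})")
print(f"loss improved: {loss_improved.item():.4f}")
```

A freshly initialized DPO run, where the policy starts as a copy of the reference, should therefore report a first-step loss close to 0.693. A very different value usually points to a masking or label-alignment problem.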

---

## Exercise 2 Solution: DPO Variants Comparison

**Task:** Implement IPO, KTO, and ORPO loss functions.
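
Two of the three objectives have compact closed forms, restated here in LaTeX from the docstrings in the cell below. The shorthand $h$ for the log-ratio difference and the name $\mathcal{L}_{\mathrm{SFT}}$ are notation assumed for this note. KTO is left out because its docstring gives no formula.

```latex
\begin{aligned}
h(x, y_w, y_l) &= \log\frac{\pi(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
                - \log\frac{\pi(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \\[4pt]
\mathcal{L}_{\mathrm{IPO}} &= \left( h(x, y_w, y_l) - \frac{1}{2\beta} \right)^{2} \\[4pt]
\mathcal{L}_{\mathrm{ORPO}} &= \mathcal{L}_{\mathrm{SFT}}
  - \lambda \log \sigma\!\left( \log \frac{\mathrm{odds}(y_w \mid x)}{\mathrm{odds}(y_l \mid x)} \right),
  \qquad \mathrm{odds}(y \mid x) = \frac{p(y \mid x)}{1 - p(y \mid x)}
\end{aligned}
```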

In [None]:
class IPOLoss(nn.Module):
    """
    Identity Preference Optimization (IPO) Loss.
    
    IPO replaces DPO's log-sigmoid with a squared-error loss that pulls the
    log-ratio difference toward a fixed target, 1/(2β), rather than pushing
    it ever higher. This limits overfitting to (possibly noisy) preferences.
    
    Loss = (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x)) - 1/(2β))²
    
    Reference: https://arxiv.org/abs/2310.12036
    """
    
    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta
        self.target = 1.0 / (2 * beta)
    
    def forward(
        self,
        chosen_log_ratio: torch.Tensor,
        rejected_log_ratio: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute IPO loss.
        
        Args:
            chosen_log_ratio: log π(y_w|x) - log π_ref(y_w|x)
            rejected_log_ratio: log π(y_l|x) - log π_ref(y_l|x)
        
        Returns:
            IPO loss
        """
        log_ratio_diff = chosen_log_ratio - rejected_log_ratio
        loss = (log_ratio_diff - self.target) ** 2
        return loss.mean()


class KTOLoss(nn.Module):
    """
    Kahneman-Tversky Optimization (KTO) Loss.
    
    KTO doesn't require paired preferences - it works with just
    positive/negative labels for individual responses.
    Based on prospect theory's asymmetric value function.
    
    Reference: https://arxiv.org/abs/2402.01306
    """
    
    def __init__(
        self,
        beta: float = 0.1,
        desirable_weight: float = 1.0,
        undesirable_weight: float = 1.0,
    ):
        super().__init__()
        self.beta = beta
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight
    
    def forward(
        self,
        log_ratios: torch.Tensor,
        is_desirable: torch.Tensor,
        kl_penalty: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute KTO loss.
        
        Args:
            log_ratios: log π(y|x) - log π_ref(y|x) for each response
            is_desirable: Binary mask (1 for chosen, 0 for rejected)
            kl_penalty: KL divergence term for regularization
        
        Returns:
            KTO loss
        """
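        # Implied reward relative to the KL baseline.
        # Note: this is a simplified form. The KTO paper scales the whole
        # difference, β * (log_ratio - z0), and uses sigmoid values rather
        # than log-sigmoid in the loss.
        rewards = self.beta * log_ratios - kl_penalty
        
        # Desirable: maximize log sigmoid(reward)
        # Undesirable: maximize log(1 - sigmoid(reward)) = log sigmoid(-reward)
        desirable_loss = -F.logsigmoid(rewards)
        undesirable_loss = -F.logsigmoid(-rewards)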
        
        # Weight and combine
        loss = (
            is_desirable * self.desirable_weight * desirable_loss +
            (1 - is_desirable) * self.undesirable_weight * undesirable_loss
        )
        
        return loss.mean()


class ORPOLoss(nn.Module):
    """
    Odds Ratio Preference Optimization (ORPO) Loss.
    
    ORPO combines SFT and preference optimization into a single loss,
    eliminating the need for a separate reference model.
    
    Loss = SFT_loss - λ * log(sigmoid(log(odds_w / odds_l)))
    
    Where odds = p / (1 - p) for each response.
    
    Reference: https://arxiv.org/abs/2403.07691
    """
    
    def __init__(self, lambda_orpo: float = 0.1):
        super().__init__()
        self.lambda_orpo = lambda_orpo
    
    def compute_odds(self, log_probs: torch.Tensor) -> torch.Tensor:
        """
        Compute log-odds from log probabilities.
        
        odds = p / (1 - p) = exp(log_p) / (1 - exp(log_p))
        log_odds = log_p - log(1 - exp(log_p))
        """
        # Clamp just below 0 so that 1 - exp(log_p) stays strictly positive,
        # and use log1p for accuracy when exp(log_p) is small
        log_odds = log_probs - torch.log1p(-torch.exp(log_probs.clamp(max=-1e-6)))
        return log_odds
    
    def forward(
        self,
        chosen_logits: torch.Tensor,
        rejected_logits: torch.Tensor,
        chosen_labels: torch.Tensor,
        rejected_labels: torch.Tensor,
        chosen_mask: torch.Tensor,
        rejected_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Compute ORPO loss.
        
        Returns:
            Tuple of (total_loss, metrics_dict)
        """
        # SFT loss on chosen responses
        shift_logits = chosen_logits[:, :-1, :].contiguous()
        shift_labels = chosen_labels[:, 1:].contiguous()
        shift_mask = chosen_mask[:, 1:].contiguous()
        
        sft_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='none'
        ).view(shift_labels.shape)
        sft_loss = (sft_loss * shift_mask).sum() / shift_mask.sum()
        
        # Compute log probabilities
        chosen_log_probs = F.log_softmax(shift_logits, dim=-1)
        chosen_per_token = torch.gather(
            chosen_log_probs, -1, shift_labels.unsqueeze(-1)
        ).squeeze(-1)
        chosen_avg_log_prob = (chosen_per_token * shift_mask).sum(-1) / shift_mask.sum(-1)
        
        # Same for rejected
        shift_rej_logits = rejected_logits[:, :-1, :].contiguous()
        shift_rej_labels = rejected_labels[:, 1:].contiguous()
        shift_rej_mask = rejected_mask[:, 1:].contiguous()
        
        rejected_log_probs = F.log_softmax(shift_rej_logits, dim=-1)
        rejected_per_token = torch.gather(
            rejected_log_probs, -1, shift_rej_labels.unsqueeze(-1)
        ).squeeze(-1)
        rejected_avg_log_prob = (rejected_per_token * shift_rej_mask).sum(-1) / shift_rej_mask.sum(-1)
        
        # Compute log odds ratio
        chosen_log_odds = self.compute_odds(chosen_avg_log_prob)
        rejected_log_odds = self.compute_odds(rejected_avg_log_prob)
        log_odds_ratio = chosen_log_odds - rejected_log_odds
        
        # ORPO preference loss
        orpo_loss = -F.logsigmoid(log_odds_ratio).mean()
        
        # Total loss
        total_loss = sft_loss + self.lambda_orpo * orpo_loss
        
        metrics = {
            "sft_loss": sft_loss.detach(),
            "orpo_loss": orpo_loss.detach(),
            "total_loss": total_loss.detach(),
            "log_odds_ratio": log_odds_ratio.mean().detach(),
            "accuracy": (log_odds_ratio > 0).float().mean(),
        }
        
        return total_loss, metrics


# Compare the loss functions
print("DPO Variants Comparison")
print("=" * 50)

# Generate synthetic data
batch_size = 8
chosen_log_ratio = torch.randn(batch_size) + 0.5  # Slightly positive
rejected_log_ratio = torch.randn(batch_size) - 0.5  # Slightly negative

# Standard DPO
dpo_logits = 0.1 * (chosen_log_ratio - rejected_log_ratio)
dpo_loss = -F.logsigmoid(dpo_logits).mean()

# IPO
ipo = IPOLoss(beta=0.1)
ipo_loss = ipo(chosen_log_ratio, rejected_log_ratio)

# KTO (simplified)
kto = KTOLoss(beta=0.1)
log_ratios = torch.cat([chosen_log_ratio, rejected_log_ratio])
is_desirable = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)])
kl_penalty = torch.zeros_like(log_ratios)
kto_loss = kto(log_ratios, is_desirable, kl_penalty)

print(f"DPO Loss: {dpo_loss.item():.4f}")
print(f"IPO Loss: {ipo_loss.item():.4f}")
print(f"KTO Loss: {kto_loss.item():.4f}")

print("\nKey Differences:")
print("- DPO: Uses sigmoid, requires paired preferences")
print("- IPO: Uses squared loss, more robust to noise")
print("- KTO: Works with unpaired data, asymmetric weighting")
print("- ORPO: Combines SFT + preference, no reference model")
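
The difference between the sigmoid and squared losses is easiest to see by sweeping the log-ratio difference `h = chosen_log_ratio - rejected_log_ratio` over a few values. The sketch below uses the same formulas as the cell above with β = 0.1; the margin values are chosen purely for illustration.

```python
import torch
import torch.nn.functional as F

beta = 0.1
ipo_target = 1.0 / (2 * beta)  # 5.0 for beta = 0.1

# Illustrative values of h = chosen_log_ratio - rejected_log_ratio
margins = torch.tensor([0.0, 2.5, 5.0, 10.0, 20.0])

# DPO: keeps decreasing as h grows, so larger margins are always rewarded
dpo_per_margin = -F.logsigmoid(beta * margins)

# IPO: zero at h = 1/(2*beta) and growing on both sides of the target
ipo_per_margin = (margins - ipo_target) ** 2

for h, d, i in zip(margins.tolist(), dpo_per_margin.tolist(), ipo_per_margin.tolist()):
    print(f"h = {h:5.1f}   DPO = {d:.4f}   IPO = {i:8.2f}")
```

Under DPO the policy can always lower the loss by widening the margin further, whereas IPO penalizes overshooting the target. That bounded target is the mechanism behind IPO's robustness claim.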

---

## Exercise 3 Solution: Preference Dataset Generator

**Task:** Create a utility to generate preference pairs from raw data.

In [None]:
from dataclasses import dataclass
from typing import List, Dict, Callable, Optional
import random
import json

@dataclass
class PreferencePair:
    """A single preference comparison."""
    prompt: str
    chosen: str
    rejected: str
    chosen_score: Optional[float] = None
    rejected_score: Optional[float] = None
    metadata: Optional[Dict] = None
    
    def to_dict(self) -> Dict:
        return {
            "prompt": self.prompt,
            "chosen": self.chosen,
            "rejected": self.rejected,
            "chosen_score": self.chosen_score,
            "rejected_score": self.rejected_score,
            "metadata": self.metadata or {},
        }


class PreferenceDatasetGenerator:
    """
    Generate preference pairs for DPO training.
    
    Supports multiple strategies:
    - Score-based pairing (if responses have scores)
    - Length-based heuristics
    - Quality rubric-based
    - LLM-as-judge
    """
    
    def __init__(self, min_score_diff: float = 0.5):
        """
        Initialize generator.
        
        Args:
            min_score_diff: Minimum score difference for valid pairs
        """
        self.min_score_diff = min_score_diff
        self.quality_scorers: List[Callable[[str, str], float]] = []
    
    def add_scorer(self, scorer: Callable[[str, str], float]):
        """
        Add a quality scoring function.
        
        Args:
            scorer: Function(prompt, response) -> score
        """
        self.quality_scorers.append(scorer)
    
    def compute_quality_score(
        self, 
        prompt: str, 
        response: str
    ) -> float:
        """Compute aggregate quality score using all scorers."""
        if not self.quality_scorers:
            return 0.0
        scores = [scorer(prompt, response) for scorer in self.quality_scorers]
        return sum(scores) / len(scores)
    
    def from_scored_responses(
        self,
        data: List[Dict],
        prompt_key: str = "prompt",
        response_key: str = "response",
        score_key: str = "score",
    ) -> List[PreferencePair]:
        """
        Generate pairs from responses with existing scores.
        
        Groups responses by prompt, then pairs high vs low scoring.
        """
        # Group by prompt
        by_prompt: Dict[str, List[Dict]] = {}
        for item in data:
            prompt = item[prompt_key]
            if prompt not in by_prompt:
                by_prompt[prompt] = []
            by_prompt[prompt].append(item)
        
        # Generate pairs
        pairs = []
        for prompt, responses in by_prompt.items():
            if len(responses) < 2:
                continue
            
            # Sort by score
            sorted_resp = sorted(
                responses, 
                key=lambda x: x.get(score_key, 0), 
                reverse=True
            )
            
            # Pair the i-th best with the i-th worst response
            for i in range(len(sorted_resp) // 2):
                chosen = sorted_resp[i]
                rejected = sorted_resp[-(i+1)]
                
                score_diff = chosen.get(score_key, 0) - rejected.get(score_key, 0)
                if score_diff >= self.min_score_diff:
                    pairs.append(PreferencePair(
                        prompt=prompt,
                        chosen=chosen[response_key],
                        rejected=rejected[response_key],
                        chosen_score=chosen.get(score_key),
                        rejected_score=rejected.get(score_key),
                    ))
        
        return pairs
    
    def from_single_responses(
        self,
        data: List[Dict],
        prompt_key: str = "prompt",
        response_key: str = "response",
        rejection_strategy: str = "perturb",
    ) -> List[PreferencePair]:
        """
        Generate pairs from single good responses by creating rejections.
        
        Args:
            data: List of {prompt, response} items
            rejection_strategy: How to create rejected responses
                - "perturb": Add noise/errors to good response
                - "truncate": Cut the response short
                - "shuffle": Randomly shuffle sentences
        """
        pairs = []
        
        for item in data:
            prompt = item[prompt_key]
            chosen = item[response_key]
            
            # Generate rejected response
            if rejection_strategy == "truncate":
                # Keep roughly the first third of the words
                words = chosen.split()
                cut_point = max(1, len(words) // 3)
                rejected = " ".join(words[:cut_point])
                
            elif rejection_strategy == "shuffle":
                # Shuffle sentences
                sentences = chosen.split(". ")
                if len(sentences) > 1:
                    random.shuffle(sentences)
                    rejected = ". ".join(sentences)
                else:
                    rejected = chosen[::-1]  # Single sentence: reverse the characters instead
                    
            elif rejection_strategy == "perturb":
                # Add filler words and make it less helpful
                fillers = ["um", "like", "basically", "I guess", "maybe"]
                words = chosen.split()
                rejected_words = []
                for i, word in enumerate(words):
                    rejected_words.append(word)
                    if i % 5 == 0 and random.random() > 0.5:
                        rejected_words.append(random.choice(fillers))
                rejected = " ".join(rejected_words)
            else:
                rejected = "I don't know."
            
            pairs.append(PreferencePair(
                prompt=prompt,
                chosen=chosen,
                rejected=rejected,
                metadata={"strategy": rejection_strategy},
            ))
        
        return pairs
    
    def validate_pairs(self, pairs: List[PreferencePair]) -> List[PreferencePair]:
        """Filter out invalid pairs."""
        valid = []
        for pair in pairs:
            # Check minimum lengths
            if len(pair.chosen) < 10 or len(pair.rejected) < 10:
                continue
            
            # Check they're not identical
            if pair.chosen.strip() == pair.rejected.strip():
                continue
            
            valid.append(pair)
        
        return valid
    
    def to_huggingface_format(
        self, 
        pairs: List[PreferencePair]
    ) -> List[Dict]:
        """
        Convert to HuggingFace TRL format.
        
        Returns list of {prompt, chosen, rejected} dicts.
        """
        return [
            {
                "prompt": p.prompt,
                "chosen": p.chosen,
                "rejected": p.rejected,
            }
            for p in pairs
        ]


# Demo the generator
generator = PreferenceDatasetGenerator(min_score_diff=0.3)

# Add some quality scorers
def length_scorer(prompt: str, response: str) -> float:
    """Score based on response length (prefer longer, up to a point)."""
    words = len(response.split())
    if words < 10:
        return 0.2
    elif words < 50:
        return 0.5 + words * 0.01
    elif words < 200:
        return 1.0
    else:
        return 0.8  # Too long

def specificity_scorer(prompt: str, response: str) -> float:
    """Score based on specificity (prefer concrete details)."""
    vague_words = ["thing", "stuff", "something", "basically", "just"]
    # Match whole words so that e.g. "something" is not also counted as "thing"
    tokens = {w.strip(".,!?;:\"'()") for w in response.lower().split()}
    vague_count = sum(1 for w in vague_words if w in tokens)
    return max(0.1, 1.0 - vague_count * 0.2)

generator.add_scorer(length_scorer)
generator.add_scorer(specificity_scorer)

# Test with sample data
sample_data = [
    {
        "prompt": "What is machine learning?",
        "response": "Machine learning is a comprehensive field of artificial intelligence that enables computer systems to automatically learn and improve from experience without being explicitly programmed. It focuses on developing algorithms that can access data, learn from it, and make predictions or decisions."
    },
    {
        "prompt": "Explain Python.",
        "response": "Python is a high-level, interpreted programming language known for its clear syntax and readability. It supports multiple programming paradigms including procedural, object-oriented, and functional programming."
    },
]

# Generate pairs using truncation strategy
pairs = generator.from_single_responses(
    sample_data,
    rejection_strategy="truncate"
)

print("Generated Preference Pairs:")
print("=" * 60)
for i, pair in enumerate(pairs):
    print(f"\nPair {i+1}:")
    print(f"Prompt: {pair.prompt}")
    print(f"Chosen ({len(pair.chosen)} chars): {pair.chosen[:100]}...")
    print(f"Rejected ({len(pair.rejected)} chars): {pair.rejected[:100]}...")
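
Once pairs are in the `{prompt, chosen, rejected}` shape returned by `to_huggingface_format`, a common next step is to persist them as JSON Lines. The sketch below is one way to do that with the standard library only; `save_pairs_jsonl`, `load_pairs_jsonl`, and the file name are hypothetical, and the record is a made-up stand-in for real generator output.

```python
import json
import os
import tempfile
from typing import Dict, List

def save_pairs_jsonl(records: List[Dict], path: str) -> None:
    """Write one JSON object per line (hypothetical helper)."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

def load_pairs_jsonl(path: str) -> List[Dict]:
    """Read the file back into a list of dicts, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Stand-in for the output of to_huggingface_format(); the text is made up
records = [
    {
        "prompt": "What is machine learning?",
        "chosen": "Machine learning is a field of AI in which systems learn from data.",
        "rejected": "Machine learning is a",
    },
]

path = os.path.join(tempfile.mkdtemp(), "preference_pairs.jsonl")
save_pairs_jsonl(records, path)
reloaded = load_pairs_jsonl(path)
print(f"Wrote and reloaded {len(reloaded)} pair(s)")
```

A file in this layout can typically be loaded with the `datasets` library via `load_dataset("json", data_files=path)`. Whether the column names match what a given trainer expects should be checked against that trainer's documentation.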

---

## Summary

These solutions demonstrate:

1. **DPO Loss**: Core implementation with log probability computation

2. **DPO Variants**: IPO, KTO, and ORPO for different scenarios

3. **Dataset Generation**: Creating preference pairs from various data sources

### Key Takeaways

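- **DPO** is simpler than PPO-based RLHF (no separate reward model or RL loop) and is often competitive with it in practice
- **β (beta)** controls how strongly the policy is kept close to the reference model; 0.1 is a common starting value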
- **ORPO** eliminates the reference model requirement
- **KTO** works without paired preferences
- **Data quality** is crucial for preference learning
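
To make the ORPO takeaway concrete, the sketch below computes only the preference term of the ORPO loss from length-averaged log-probabilities under the policy, with no reference model involved. It reuses the log-odds formula from `ORPOLoss.compute_odds`; the standalone `log_odds` function and the log-probability values are assumptions made for illustration.

```python
import math
import torch
import torch.nn.functional as F

def log_odds(avg_log_prob: torch.Tensor) -> torch.Tensor:
    """log(p / (1 - p)) computed from log p (hypothetical standalone helper)."""
    return avg_log_prob - torch.log1p(-torch.exp(avg_log_prob.clamp(max=-1e-6)))

# Made-up length-averaged log-probabilities under the policy only
chosen_avg_log_prob = torch.tensor([-0.5, -0.8])
rejected_avg_log_prob = torch.tensor([-2.0, -1.5])

# Positive values mean the policy assigns higher odds to the chosen response
log_odds_ratio = log_odds(chosen_avg_log_prob) - log_odds(rejected_avg_log_prob)

# Preference term only; the full ORPO loss adds the SFT loss on chosen responses
preference_term = -F.logsigmoid(log_odds_ratio).mean()

print(f"log odds ratio:  {[round(v, 4) for v in log_odds_ratio.tolist()]}")
print(f"preference term: {preference_term.item():.4f}")
```

Every quantity here comes from a single forward pass of the policy per response, so no second frozen model has to be kept in memory.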