# Lab 3.1.7: DPO Training - Solutions

Complete solutions for Direct Preference Optimization exercises.

## Exercise 1: DPO Loss Implementation

In [None]:
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1
) -> tuple:
    """
    Direct Preference Optimization loss, written out by hand.

    For each preference pair the loss is

        -log σ(β * [(log π(y_w|x) - log π_ref(y_w|x))
                    - (log π(y_l|x) - log π_ref(y_l|x))])

    with π the trainable policy, π_ref the frozen reference, y_w the chosen
    response, y_l the rejected response and β the scaling factor applied to
    the log-ratio difference.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen responses.
        policy_rejected_logps: Policy log-probabilities of the rejected responses.
        reference_chosen_logps: Reference log-probabilities of the chosen responses.
        reference_rejected_logps: Reference log-probabilities of the rejected responses.
        beta: Multiplier applied to the log-ratios.

    Returns:
        A ``(loss, metrics)`` pair. ``loss`` is the mean loss as a tensor that
        still carries gradients; ``metrics`` is a dict of Python floats with
        the keys ``chosen_rewards``, ``rejected_rewards``, ``reward_margin``
        and ``accuracy``.
    """
    # Policy-vs-reference log-ratio for each side of the pair.
    log_ratio_w = policy_chosen_logps - reference_chosen_logps
    log_ratio_l = policy_rejected_logps - reference_rejected_logps

    # A positive logit means the policy prefers the chosen response more
    # strongly than the reference does.
    preference_logits = beta * (log_ratio_w - log_ratio_l)

    # logsigmoid avoids the overflow/underflow of an explicit log(sigmoid(x)).
    per_pair_loss = -F.logsigmoid(preference_logits)

    # Implicit rewards are reporting-only, so they are detached from the graph.
    reward_w = beta * log_ratio_w.detach()
    reward_l = beta * log_ratio_l.detach()

    stats = {
        "chosen_rewards": reward_w.mean().item(),
        "rejected_rewards": reward_l.mean().item(),
        "reward_margin": (reward_w - reward_l).mean().item(),
        "accuracy": (preference_logits > 0).float().mean().item()
    }
    return per_pair_loss.mean(), stats

# Test
# Smoke-test dpo_loss on a tiny synthetic batch. The fixed seed makes the
# random draws below reproducible, so the printed numbers are stable.
torch.manual_seed(42)
batch_size = 4

# Simulate log probabilities
# Each tensor is standard-normal noise shifted by a constant: the policy's
# chosen values are centred at -1 and its rejected values at -2, while both
# reference tensors are centred at -1.5.
policy_chosen = torch.randn(batch_size) - 1  # Higher for chosen
policy_rejected = torch.randn(batch_size) - 2  # Lower for rejected
ref_chosen = torch.randn(batch_size) - 1.5
ref_rejected = torch.randn(batch_size) - 1.5

# Positional order must match the signature: policy pair first, then reference pair.
loss, metrics = dpo_loss(
    policy_chosen, policy_rejected,
    ref_chosen, ref_rejected,
    beta=0.1
)

# loss is a 0-dim tensor; metrics is a dict of plain floats.
print(f"DPO Loss: {loss:.4f}")
print(f"Metrics: {metrics}")

## Exercise 2: Complete DPO Training Pipeline

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import DPOTrainer, DPOConfig
from datasets import Dataset

def create_dpo_pipeline(
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    output_dir: str = "./dpo-output",
    beta: float = 0.1,
    epochs: int = 1
):
    """
    Assemble a ready-to-train DPO setup (4-bit base model + LoRA adapters).

    The checkpoint is loaded twice with the same quantization settings: once
    as the policy, which is prepared for k-bit training and wrapped with LoRA,
    and once as the untouched reference model handed to the trainer. A small
    built-in set of three preference pairs serves as the training data.

    Args:
        model_id: Checkpoint identifier passed to ``from_pretrained``.
        output_dir: Directory given to ``DPOConfig`` as ``output_dir``.
        beta: DPO beta forwarded to ``DPOConfig``.
        epochs: Number of training epochs.

    Returns:
        ``(trainer, model, tokenizer)``. The reference model is not returned;
        it is only passed to the trainer.
    """
    banner = "=" * 60
    print(banner)
    print("DPO TRAINING PIPELINE")
    print(banner)

    # 1. Quantization: 4-bit NF4 weights with bfloat16 compute. The same
    #    loader arguments are reused for the policy and the reference model.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    loader_kwargs = {
        "quantization_config": quant_cfg,
        "device_map": "auto",
        "trust_remote_code": True,
    }

    # 2. Policy model and tokenizer
    print("\n1. Loading model...")
    policy = AutoModelForCausalLM.from_pretrained(model_id, **loader_kwargs)
    policy = prepare_model_for_kbit_training(policy)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The EOS token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # 3. LoRA adapters on the attention projections
    print("2. Configuring LoRA...")
    adapter_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    policy = get_peft_model(policy, adapter_cfg)
    policy.print_trainable_parameters()

    # 4. Reference model: a second load of the same checkpoint, no adapters
    print("3. Creating reference model...")
    reference = AutoModelForCausalLM.from_pretrained(model_id, **loader_kwargs)

    # 5. Built-in sample preference data as (prompt, chosen, rejected) triples
    print("4. Preparing dataset...")
    triples = [
        (
            "Explain quantum computing in simple terms.",
            "Quantum computing uses quantum bits that can be 0, 1, or both at once, "
            "allowing it to process many possibilities simultaneously.",
            "Quantum computing is about computers using quantum stuff.",
        ),
        (
            "What are the benefits of exercise?",
            "Regular exercise improves cardiovascular health, boosts mood through "
            "endorphin release, enhances sleep quality, and increases energy levels.",
            "Exercise is good for you.",
        ),
        (
            "How do I make a good first impression?",
            "Make eye contact, smile genuinely, use a firm handshake, listen actively, "
            "and show interest in others by asking thoughtful questions.",
            "Just be yourself I guess.",
        ),
    ]
    dataset = Dataset.from_list(
        [
            {"prompt": prompt, "chosen": chosen, "rejected": rejected}
            for prompt, chosen, rejected in triples
        ]
    )

    # 6. DPO training configuration
    print("5. Setting up DPO training...")
    training_args = DPOConfig(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=5e-5,
        # DPO specific
        beta=beta,
        loss_type="sigmoid",  # or "hinge", "ipo"
        # Memory
        bf16=True,
        gradient_checkpointing=True,
        # Logging
        logging_steps=1,
        report_to="none",
    )

    # 7. Trainer wiring policy, reference, data and tokenizer together
    trainer = DPOTrainer(
        model=policy,
        ref_model=reference,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("\nConfiguration:")
    print(f"  β (beta): {beta}")
    print(f"  Loss type: {training_args.loss_type}")
    print(f"  Samples: {len(dataset)}")
    print(banner)

    return trainer, policy, tokenizer

# Uncomment to run
# trainer, model, tokenizer = create_dpo_pipeline()
# trainer.train()

## Exercise 3: Beta Parameter Study

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def beta_sensitivity_analysis():
    """
    Visualize how beta affects DPO behavior.

    Draws two side-by-side plots over a range of log-ratio differences
    (chosen log-ratio minus rejected log-ratio) for several beta values:
    the per-pair DPO loss, and the magnitude of its gradient with respect to
    the log-ratio difference. The figure is saved to ``beta_analysis.png`` in
    the current working directory and shown, then a short set of beta tuning
    guidelines is printed.

    Returns:
        None.
    """
    # Simulate log probability differences
    log_ratio_diff = np.linspace(-5, 5, 100)  # chosen_logratio - rejected_logratio

    betas = [0.01, 0.1, 0.5, 1.0]

    _, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Loss curves
    for beta in betas:
        logits = beta * log_ratio_diff
        # -log(sigmoid(z)) == log(1 + exp(-z)); logaddexp evaluates this
        # without overflowing for large |z|.
        loss = np.logaddexp(0.0, -logits)
        axes[0].plot(log_ratio_diff, loss, label=f'β={beta}', linewidth=2)

    axes[0].set_xlabel('Log Ratio Difference (chosen - rejected)')
    axes[0].set_ylabel('DPO Loss')
    axes[0].set_title('DPO Loss vs Log Ratio Difference')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    axes[0].axvline(x=0, color='black', linestyle='--', alpha=0.5)

    # Plot 2: Gradient magnitude
    for beta in betas:
        logits = beta * log_ratio_diff
        sigmoid = 1 / (1 + np.exp(-logits))
        gradient = -beta * (1 - sigmoid)  # d/d(log_ratio) of -log(sigmoid)
        axes[1].plot(log_ratio_diff, np.abs(gradient), label=f'β={beta}', linewidth=2)

    axes[1].set_xlabel('Log Ratio Difference')
    axes[1].set_ylabel('|Gradient|')
    axes[1].set_title('Gradient Magnitude (Learning Signal)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('beta_analysis.png', dpi=150)
    plt.show()

    # β scales the implicit KL penalty toward the reference model. A larger β
    # gives a bigger gradient near zero margin, but the loss saturates after a
    # much smaller log-ratio gap, so the policy ends up closer to the
    # reference, not further from it.
    print("\nBeta Guidelines:")
    print("  β = 0.01-0.05: Weak constraint, policy may drift far from the reference")
    print("  β = 0.1:       Default, good starting point")
    print("  β = 0.2-0.5:   Stronger pull toward the reference, more conservative")
    print("  β > 0.5:       Very conservative, policy stays close to the reference")

# Runs on cell execution: shows the plots, writes beta_analysis.png, prints the guidelines.
beta_sensitivity_analysis()

## Exercise 4: Evaluation Metrics

In [None]:
def _response_logprob(lm, encoded, prompt_len: int) -> torch.Tensor:
    """Return the summed log-probability ``lm`` assigns to the response tokens.

    ``encoded`` is the tokenized ``prompt + response`` text and ``prompt_len``
    the number of leading tokens that belong to the prompt. Only tokens after
    the prompt are scored, giving log π(response | prompt).
    """
    input_ids = encoded["input_ids"]
    logits = lm(**encoded).logits

    # Logits at position t predict token t + 1, so align logits[:-1] with
    # input_ids[1:] and pick the log-probability of each realized token.
    log_probs = logits[:, :-1, :].float().log_softmax(-1)
    token_logps = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

    # Shifted index j scores original token j + 1. The very first token has no
    # preceding context, so scoring never starts before original position 1.
    first_scored = max(prompt_len, 1) - 1
    return token_logps[:, first_scored:].sum()


def evaluate_dpo_model(model, tokenizer, ref_model, eval_data: list):
    """
    Evaluate DPO model performance.

    For every preference pair the policy and the reference model score the
    chosen and rejected responses by the summed log-probability of the
    response tokens given the prompt. The implicit reward of a response is
    the policy log-probability minus the reference log-probability (not
    scaled by beta).

    Args:
        model: Trained policy; must expose ``.eval()``, ``.device`` and return
            an object with ``.logits`` when called on the tokenizer output.
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt")``.
        ref_model: Frozen reference model with the same call interface.
        eval_data: List of dicts with ``"prompt"``, ``"chosen"`` and
            ``"rejected"`` string entries.

    Returns:
        Dict with ``preference_accuracy`` (fraction of pairs where the chosen
        reward exceeds the rejected reward), ``avg_reward_margin`` (mean
        chosen-minus-rejected reward) and ``avg_kl_divergence``. The last one
        is the mean policy/reference log-ratio on the chosen responses: a
        rough drift indicator, not a true KL divergence, which would need
        samples drawn from the policy.
    """
    model.eval()
    ref_model.eval()

    metrics = {
        "preference_accuracy": [],
        "reward_margin": [],
        "kl_divergence": []
    }

    for item in eval_data:
        prompt = item["prompt"]
        chosen = item["chosen"]
        rejected = item["rejected"]

        # Prompt length in tokens, used to exclude prompt tokens from the
        # score. This is approximate if the tokenizer merges characters across
        # the prompt/response boundary.
        prompt_len = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]

        # Tokenize
        chosen_input = tokenizer(prompt + chosen, return_tensors="pt").to(model.device)
        rejected_input = tokenizer(prompt + rejected, return_tensors="pt").to(model.device)

        with torch.no_grad():
            # Policy log probs
            policy_chosen_logp = _response_logprob(model, chosen_input, prompt_len)
            policy_rejected_logp = _response_logprob(model, rejected_input, prompt_len)

            # Reference log probs
            ref_chosen_logp = _response_logprob(ref_model, chosen_input, prompt_len)
            ref_rejected_logp = _response_logprob(ref_model, rejected_input, prompt_len)

        # Compute metrics
        chosen_reward = policy_chosen_logp - ref_chosen_logp
        rejected_reward = policy_rejected_logp - ref_rejected_logp

        correct = (chosen_reward > rejected_reward).float().item()
        margin = (chosen_reward - rejected_reward).item()
        kl = chosen_reward.item()  # log-ratio proxy, see docstring

        metrics["preference_accuracy"].append(correct)
        metrics["reward_margin"].append(margin)
        metrics["kl_divergence"].append(kl)

    # Aggregate
    results = {
        "preference_accuracy": np.mean(metrics["preference_accuracy"]),
        "avg_reward_margin": np.mean(metrics["reward_margin"]),
        "avg_kl_divergence": np.mean(metrics["kl_divergence"])
    }

    print("\nDPO Evaluation Results")
    print("=" * 40)
    for key, value in results.items():
        print(f"{key}: {value:.4f}")

    return results

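# Uncomment after training. Define `ref_model` (a frozen copy of the base model)
# and `eval_data` (a list of {"prompt", "chosen", "rejected"} dicts) first.
# eval_results = evaluate_dpo_model(model, tokenizer, ref_model, eval_data)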

## Key Takeaways

1. **DPO Loss**: -log(σ(β * (chosen_logratio - rejected_logratio)))
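2. **Reference Model**: Frozen copy of the starting policy; it anchors the implicit KL constraint so the policy cannot drift arbitrarily far from its initial behavior (which helps prevent mode collapse)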
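3. **Beta Tuning**: 0.1 is a good default; a larger β keeps the policy closer to the reference model (more conservative), while a smaller β lets the preference data move it further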
4. **Evaluation**: Track preference accuracy and KL divergence