# Lab 3.1.8: SimPO vs ORPO - Solutions

Complete solutions for modern preference optimization exercises.

## Exercise 1: SimPO Loss Implementation

In [None]:
import torch
import torch.nn.functional as F

def simpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    chosen_lengths: torch.Tensor,
    rejected_lengths: torch.Tensor,
    beta: float = 2.0,
    gamma: float = 0.5
) -> tuple:
    """
    SimPO Loss: Simple Preference Optimization.

    Key differences from DPO:
    1. No reference model needed
    2. Length-normalized log probs
    3. Target reward margin (gamma)

    L_SimPO = -log(σ(β * (avg_logp(y_w) - avg_logp(y_l)) - γ))

    Where avg_logp = log_prob / sequence_length

    Args:
        policy_chosen_logps: Summed (total) log-probability of each chosen
            response under the policy, one value per example.
        policy_rejected_logps: Same, for each rejected response.
        chosen_lengths: Length of each chosen response. Values below 1 are
            clamped to 1, so an empty response cannot produce inf/NaN.
        rejected_lengths: Same, for each rejected response.
        beta: Scale β applied to the length-normalized log-prob difference.
        gamma: Target reward margin γ subtracted from the scaled difference.

    Returns:
        (loss, metrics): ``loss`` is the batch-mean SimPO loss (a scalar
        tensor). ``metrics`` is a dict of Python floats:
          - accuracy: fraction of pairs with β·Δ - γ > 0, i.e. the pair
            already clears the target margin.
          - avg_margin: mean unscaled Δ = avg_logp(y_w) - avg_logp(y_l).
          - chosen_avg_logp / rejected_avg_logp: mean normalized log-probs.
          - reward_accuracy: fraction of pairs with Δ > 0, i.e. the chosen
            response is simply preferred, ignoring the margin.
    """
    # Length-normalized log probabilities. clamp(min=1) guards against a
    # zero-length response, which would otherwise divide by zero.
    chosen_avg_logps = policy_chosen_logps / chosen_lengths.clamp(min=1)
    rejected_avg_logps = policy_rejected_logps / rejected_lengths.clamp(min=1)

    avg_margins = chosen_avg_logps - rejected_avg_logps

    # Compute logits with target margin
    logits = beta * avg_margins - gamma

    # Loss
    losses = -F.logsigmoid(logits)

    # Metrics
    accuracy = (logits > 0).float().mean()               # margin-aware
    reward_accuracy = (avg_margins > 0).float().mean()   # plain preference

    return losses.mean(), {
        "accuracy": accuracy.item(),
        "avg_margin": avg_margins.mean().item(),
        "chosen_avg_logp": chosen_avg_logps.mean().item(),
        "rejected_avg_logp": rejected_avg_logps.mean().item(),
        "reward_accuracy": reward_accuracy.item(),
    }

# Smoke test: push a small random batch through simpo_loss
torch.manual_seed(42)
batch_size = 4

# Simulated sequence-level (summed) log-probs:
# chosen ~ N(-200, 100^2), rejected ~ N(-250, 100^2)
policy_chosen = torch.randn(batch_size) * 100 - 200
policy_rejected = torch.randn(batch_size) * 100 - 250

# Token lengths drawn from [50, 200); chosen is drawn first, then rejected,
# so the RNG is consumed in the same order as two separate calls
chosen_lens, rejected_lens = (
    torch.randint(50, 200, (batch_size,)).float() for _ in range(2)
)

loss, metrics = simpo_loss(
    policy_chosen,
    policy_rejected,
    chosen_lens,
    rejected_lens,
    beta=2.0,
    gamma=0.5,
)

print(f"SimPO Loss: {loss:.4f}", f"Metrics: {metrics}", sep="\n")

## Exercise 2: ORPO Loss Implementation

In [None]:
def _log1mexp(x: torch.Tensor) -> torch.Tensor:
    """Return log(1 - exp(x)) elementwise for log-probabilities x <= 0.

    Computed as log(-expm1(x)), which stays accurate as x -> 0, where
    1 - exp(x) would cancel catastrophically. x is capped at -eps so that a
    probability of exactly 1 gives a large finite value instead of
    log(0) = -inf.
    """
    x = x.clamp(max=-torch.finfo(x.dtype).eps)
    return torch.log(-torch.expm1(x))


def orpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    policy_chosen_logits: torch.Tensor,  # For SFT loss
    chosen_labels: torch.Tensor,
    lambda_weight: float = 0.1
) -> tuple:
    """
    ORPO Loss: Odds Ratio Preference Optimization.

    L_ORPO = L_SFT + λ * L_OR

    Where L_OR = -log(σ(log(odds(y_w)) - log(odds(y_l))))
    And odds(y) = p(y) / (1 - p(y))

    This combines SFT and preference alignment in one step.

    Args:
        policy_chosen_logps: Per-example log-probability of the chosen
            response. The ORPO paper uses the length-normalized (mean
            per-token) log-probability here. With summed log-probs,
            exp(logp) is ~0 and the odds term degenerates to logp itself.
        policy_rejected_logps: Same, for the rejected responses.
        policy_chosen_logits: Logits over the vocabulary for the chosen
            sequences; the last dimension is the vocabulary.
        chosen_labels: Target token ids aligned position-by-position with
            ``policy_chosen_logits``; -100 marks positions to ignore.
        lambda_weight: Weight λ of the odds-ratio term.

    Returns:
        (total_loss, metrics): ``total_loss`` is a scalar tensor. ``metrics``
        holds Python floats for sft_loss, or_loss, total_loss and accuracy
        (fraction of pairs whose log-odds ratio favours the chosen response).
    """
    # SFT loss (standard cross-entropy). reshape (not view) also accepts
    # non-contiguous logits/labels.
    # NOTE(review): logits and labels are compared at the same position,
    # with no causal shift. If chosen_labels are the raw input ids, the
    # caller must pass logits[:, :-1] and labels[:, 1:].
    sft_loss = F.cross_entropy(
        policy_chosen_logits.reshape(-1, policy_chosen_logits.size(-1)),
        chosen_labels.reshape(-1),
        ignore_index=-100,
    )

    # Odds ratio loss: log odds(y) = log p - log(1 - p), evaluated exactly.
    # Do not clamp p: clamping it (e.g. at 0.99) makes every confident
    # response look the same and erases the preference signal.
    log_odds_chosen = policy_chosen_logps - _log1mexp(policy_chosen_logps)
    log_odds_rejected = policy_rejected_logps - _log1mexp(policy_rejected_logps)

    or_logits = log_odds_chosen - log_odds_rejected
    or_loss = -F.logsigmoid(or_logits).mean()

    # Combined loss
    total_loss = sft_loss + lambda_weight * or_loss

    # Metrics
    accuracy = (or_logits > 0).float().mean()

    return total_loss, {
        "sft_loss": sft_loss.item(),
        "or_loss": or_loss.item(),
        "total_loss": total_loss.item(),
        "accuracy": accuracy.item()
    }

print("ORPO combines SFT + Odds Ratio in single training stage.")
print("Benefit: no reference model needed (less memory than DPO)")

## Exercise 3: Side-by-Side Comparison

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def compare_methods(save_path: str = 'method_comparison.png', show: bool = True):
    """
    Visual comparison of DPO, SimPO, and ORPO.

    Draws three side-by-side panels: relative memory, number of training
    stages, and benchmark scores. Saves the figure to ``save_path`` and
    prints a summary table to stdout.

    NOTE: every number plotted here is a hard-coded, illustrative value for
    teaching purposes, not a measured result.

    Args:
        save_path: Output path passed to ``plt.savefig``.
        show: If True (default), call ``plt.show()`` after saving; otherwise
            the figure is closed. Pass False for headless or batch runs.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    methods = ['DPO', 'SimPO', 'ORPO']
    colors = ['#e74c3c', '#3498db', '#2ecc71']

    # 1. Memory comparison (rough: DPO additionally holds a frozen reference)
    memory = [100, 50, 50]  # Relative %
    axes[0].bar(methods, memory, color=colors)
    axes[0].set_ylabel('Relative Memory (%)')
    axes[0].set_title('Memory Usage')
    axes[0].set_ylim(0, 120)
    for i, v in enumerate(memory):
        axes[0].text(i, v + 3, f'{v}%', ha='center', fontweight='bold')

    # 2. Training stages. Bar length and annotation both come from the same
    #    stage list, so they cannot disagree.
    stages = {
        'DPO': ['SFT', 'DPO'],
        'SimPO': ['SFT', 'SimPO'],
        'ORPO': ['ORPO'],
    }
    for row, (method, color) in enumerate(zip(methods, colors)):
        n_stages = len(stages[method])
        axes[1].barh([method], [n_stages], color=color)
        if n_stages > 1:
            axes[1].text(n_stages + 0.1, row, ' + '.join(stages[method]), va='center')
        else:
            axes[1].text(n_stages + 0.1, row, 'Single stage!', va='center',
                         fontweight='bold')
    axes[1].set_xlim(0, 3)  # keep the annotations inside the axes
    axes[1].set_xlabel('Training Stages')
    axes[1].set_title('Training Complexity')

    # 3. Benchmark comparison. These are illustrative placeholder scores.
    #    MT-Bench is on a 0-10 scale; the other two are percentages.
    benchmarks = ['AlpacaEval', 'MT-Bench', 'Arena Hard']
    scores = {
        'DPO': [40.2, 7.5, 35.0],
        'SimPO': [44.7, 7.8, 38.0],
        'ORPO': [42.5, 7.6, 36.5],
    }
    x = np.arange(len(benchmarks))
    width = 0.25
    for offset, method, color in zip((-width, 0.0, width), methods, colors):
        axes[2].bar(x + offset, scores[method], width, label=method, color=color)

    axes[2].set_ylabel('Score')
    axes[2].set_title('Benchmark Performance (illustrative)')
    axes[2].set_xticks(x)
    axes[2].set_xticklabels(benchmarks)
    axes[2].legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    if show:
        plt.show()
    else:
        plt.close(fig)

    print("\nMethod Summary:")
    print("=" * 60)
    print(f"{'Method':<10} {'Ref Model':<12} {'Stages':<10} {'Best For':<30}")
    print("-" * 60)
    print(f"{'DPO':<10} {'Yes':<12} {'2':<10} {'Standard preference alignment':<30}")
    print(f"{'SimPO':<10} {'No':<12} {'2':<10} {'Best performance, length control':<30}")
    print(f"{'ORPO':<10} {'No':<12} {'1':<10} {'Memory-constrained training':<30}")

compare_methods()

## Exercise 4: Complete SimPO Training

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import CPOTrainer, CPOConfig
from datasets import Dataset

def create_simpo_training(
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    output_dir: str = "./simpo-output",
    beta: float = 2.0,
    gamma: float = 0.5,
):
    """
    Complete SimPO training setup.

    Note: SimPO is implemented in TRL as CPO with loss_type="simpo".
    In CPOConfig the SimPO target margin γ is ``simpo_gamma``. ``cpo_alpha``
    is the weight of an extra NLL (behaviour-cloning) term, not γ. It is set
    to 0.0 here so the objective is pure SimPO rather than CPO-SimPO.

    Args:
        model_id: Hugging Face model id, loaded in 4-bit with LoRA adapters.
        output_dir: Output directory passed to CPOConfig.
        beta: SimPO reward scale β (CPOConfig ``beta``). The default matches
            ``simpo_loss``.
        gamma: SimPO target reward margin γ (CPOConfig ``simpo_gamma``).

    Returns:
        A configured CPOTrainer; call ``.train()`` on it to start training.
    """
    print("Setting up SimPO Training...")

    # Quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load model (no reference model needed!)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # LoRA
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Sample data
    data = Dataset.from_list([
        {"prompt": "Explain AI", "chosen": "AI is...", "rejected": "dunno"},
    ])

    # SimPO config via CPO
    training_args = CPOConfig(
        output_dir=output_dir,
        loss_type="simpo",   # Key setting!
        cpo_alpha=0.0,       # no NLL/BC term -> pure SimPO (not CPO-SimPO)
        simpo_gamma=gamma,   # gamma in SimPO paper
        beta=beta,           # beta in SimPO paper
        per_device_train_batch_size=2,
        bf16=True,
        report_to="none",
    )

    # NOTE(review): newer TRL releases rename `tokenizer=` to
    # `processing_class=`; adjust to the installed version.
    trainer = CPOTrainer(
        model=model,
        args=training_args,
        train_dataset=data,
        tokenizer=tokenizer,
    )

    print("\nSimPO Key Settings:")
    print("  loss_type: simpo")
    print(f"  beta: {training_args.beta}")
    print(f"  gamma (simpo_gamma): {training_args.simpo_gamma}")
    print(f"  cpo_alpha (NLL weight): {training_args.cpo_alpha}")
    print("  No reference model needed!")

    return trainer

# Uncomment to run
# trainer = create_simpo_training()
# trainer.train()

## Exercise 5: Complete ORPO Training

In [None]:
from trl import ORPOTrainer, ORPOConfig

def create_orpo_training(
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    output_dir: str = "./orpo-output",
    beta: float = 0.1,
):
    """
    Complete ORPO training setup.

    ORPO = SFT + Odds Ratio in single stage!

    Args:
        model_id: Hugging Face model id, loaded in 4-bit with LoRA adapters.
        output_dir: Output directory passed to ORPOConfig.
        beta: Weight of the odds-ratio term, i.e. λ in the ORPO paper
            (ORPOConfig ``beta``).

    Returns:
        A configured ORPOTrainer; call ``.train()`` on it to start training.
    """
    print("Setting up ORPO Training...")

    # Quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # LoRA
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Sample data
    data = Dataset.from_list([
        {"prompt": "Explain AI", "chosen": "AI is...", "rejected": "dunno"},
    ])

    # ORPO config
    training_args = ORPOConfig(
        output_dir=output_dir,
        beta=beta,  # Odds ratio weight (lambda in paper)
        per_device_train_batch_size=2,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",
    )

    # NOTE(review): newer TRL releases rename `tokenizer=` to
    # `processing_class=`; adjust to the installed version.
    trainer = ORPOTrainer(
        model=model,
        args=training_args,
        train_dataset=data,
        tokenizer=tokenizer,
    )

    print("\nORPO Key Settings:")
    print(f"  beta (lambda): {training_args.beta}")
    print("  Single training stage (SFT + preference combined)")
    print("  No reference model needed!")
    # The actual saving vs DPO depends on setup (e.g. LoRA-based DPO can
    # reuse the base weights as its reference), so no fixed % is claimed.
    print("  Lower memory than DPO (no frozen reference model)")

    return trainer

# Uncomment to run
# trainer = create_orpo_training()
# trainer.train()

## Decision Guide

```
┌─────────────────────────────────────────┐
│          Which method to use?           │
└────────────────────┬────────────────────┘
                     │
           ┌─────────┴─────────┐
           │  Memory limited?  │
           └─────────┬─────────┘
              Yes    │    No
           ┌─────────┴─────────┐
           │                   │
    ┌──────┴──────┐     ┌──────┴──────┐
    │    ORPO     │     │  Need best  │
    │  (single    │     │  quality?   │
    │   stage)    │     └──────┬──────┘
    └─────────────┘       Yes  │  No
                        ┌──────┴──────┐
                        │             │
               ┌────────┴────────┐ ┌──┴──┐
               │      SimPO      │ │ DPO │
               │ (up to +6.4 pts │ └─────┘
               │  AlpacaEval 2)  │
               └─────────────────┘
```

## Key Takeaways

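1. **SimPO**: No reference model, length-normalized reward, and a target margin. The SimPO paper reports gains of up to 6.4 points over DPO on AlpacaEval 2.
2. **ORPO**: Single stage (SFT + preference) with no reference model, so it needs less memory than DPO. The exact saving depends on the setup; with LoRA, DPO can reuse the frozen base weights as its reference.
3. **Both are reported to match or beat DPO** while avoiding a separate reference model.
4. **Use SimPO** for best quality, **ORPO** for memory constraints or when you want to skip a separate SFT stage.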