# dpo.py

Auto-generated implementation from the Agentic RL PhD codebase.

### Original Implementations & References
The following link points to the official reference implementation of the paper covered in this file:

- https://github.com/eric-mitchell/direct-preference-optimization

*Note: The code below is a simplified pedagogical implementation.*

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Paper: "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (Rafailov et al., 2023)
# Category: Alignment / RLHF Alternative

def dpo_loss(policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor, ref_rejected_logps: torch.Tensor,
             beta: float = 0.1, label_smoothing: float = 0.0):
    """
    Compute the per-example DPO loss and the implicit rewards used for logging.

    All four log-prob arguments are combined element-wise, so they must be
    broadcastable against each other (typically one value per example).

    Args:
        policy_chosen_logps: Log prob of the 'chosen' response under the policy model.
        policy_rejected_logps: Log prob of the 'rejected' response under the policy model.
        ref_chosen_logps: Log prob of the 'chosen' response under the reference (frozen) model.
        ref_rejected_logps: Log prob of the 'rejected' response under the reference (frozen) model.
        beta: Temperature parameter (controls deviation from reference).
        label_smoothing: Assumed probability that a preference label is flipped
            (conservative DPO). The default 0.0 gives the standard DPO loss.

    Returns:
        A 3-tuple ``(losses, chosen_rewards, rejected_rewards)``:
        losses: The DPO loss for each example in the batch (not reduced).
        chosen_rewards: Implicit reward of each chosen response, detached from the graph.
        rejected_rewards: Implicit reward of each rejected response, detached from the graph.

    Raises:
        ValueError: If ``label_smoothing`` is outside ``[0, 1]``.
    """
    if not 0.0 <= label_smoothing <= 1.0:
        raise ValueError(
            f"label_smoothing must be in [0, 1], got {label_smoothing}"
        )

    # Implicit reward from the paper (Eq. 5, dropping the prompt-only log Z(x)
    # term, which cancels in the chosen-minus-rejected difference):
    #   r(x, y) = beta * (log pi(y|x) - log ref(y|x))

    # Log ratios of chosen vs. rejected under each model.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps

    # Argument of the sigmoid (before scaling by beta):
    #   log(pi_chosen / ref_chosen) - log(pi_rejected / ref_rejected)
    logits = pi_logratios - ref_logratios
    scaled_logits = beta * logits

    # Standard DPO loss (Eq. 7): -log(sigmoid(beta * logits)).
    losses = -F.logsigmoid(scaled_logits)
    if label_smoothing:
        # Conservative DPO: mix in the loss of the flipped preference, weighted
        # by the probability that the label is wrong. Skipped entirely when
        # label_smoothing == 0 so the default path is numerically unchanged.
        flipped_losses = -F.logsigmoid(-scaled_logits)
        losses = (1.0 - label_smoothing) * losses + label_smoothing * flipped_losses

    # Implicit rewards are for tracking progress only, so no gradient flows through them.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards

class DPOTrainer(nn.Module):
    """
    Pair a trainable policy with a frozen reference model and compute the DPO loss.

    Args:
        model: The policy being optimised.
        ref_model: The reference model. It should be a separate instance from
            ``model``; when it is an ``nn.Module`` its parameters are frozen and
            it is kept in eval mode.
        beta: Temperature passed to ``dpo_loss``.

    NOTE(review): both models are assumed to be callable as ``model(input_ids)``
    and to return next-token logits of shape (batch, seq_len, vocab), either
    directly or as a ``.logits`` attribute on the output -- confirm against the
    models actually used.
    """

    def __init__(self, model, ref_model, beta=0.1):
        super().__init__()
        self.model = model
        self.ref_model = ref_model
        self.beta = beta
        # The reference only anchors the loss: it must not receive gradient
        # updates, and dropout would make its log-probs noisy.
        if isinstance(ref_model, nn.Module):
            ref_model.requires_grad_(False)
            ref_model.eval()

    def train(self, mode=True):
        """Switch train/eval mode for the policy while keeping the reference in eval mode."""
        super().train(mode)
        # super().train() recurses into every registered sub-module, including
        # the reference model, so put it back into eval mode afterwards.
        if isinstance(self.ref_model, nn.Module):
            self.ref_model.eval()
        return self

    @staticmethod
    def _sequence_logps(model, input_ids):
        """Return the summed log prob of each sequence's tokens given their prefixes."""
        outputs = model(input_ids)
        # Accept either a raw logits tensor or an output object exposing `.logits`.
        logits = getattr(outputs, "logits", outputs)
        # Position t predicts token t + 1: drop the last prediction and the first target.
        token_logps = F.log_softmax(logits[:, :-1, :], dim=-1)
        targets = input_ids[:, 1:].unsqueeze(-1)
        return token_logps.gather(-1, targets).squeeze(-1).sum(dim=-1)

    def forward(self, chosen_ids, rejected_ids):
        """
        Compute the DPO loss for a batch of preference pairs.

        Args:
            chosen_ids: Token ids of the preferred sequences; assumed to be an
                integer tensor of shape (batch, seq_len) -- TODO confirm.
            rejected_ids: Token ids of the dispreferred sequences, same layout.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)`` exactly as returned
            by ``dpo_loss`` (per-example, not reduced).

        NOTE(review): this simplified version applies no padding or prompt mask,
        so every next-token position contributes to the sequence log prob.
        Padded batches need masking before this is used for real training.
        """
        policy_chosen_logps = self._sequence_logps(self.model, chosen_ids)
        policy_rejected_logps = self._sequence_logps(self.model, rejected_ids)

        # No graph is needed for the reference model.
        with torch.no_grad():
            ref_chosen_logps = self._sequence_logps(self.ref_model, chosen_ids)
            ref_rejected_logps = self._sequence_logps(self.ref_model, rejected_ids)

        return dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            ref_chosen_logps,
            ref_rejected_logps,
            beta=self.beta,
        )
