# RLHF & Direct Preference Optimization (DPO)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/rlhf_dpo.ipynb)

## 1. Mathematical Derivation of DPO

The standard RLHF objective maximizes the expected reward while penalizing deviation from the reference model:

$$ \max_{\pi} \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi(\cdot|x)} [r(x, y)] - \beta \mathbb{D}_{KL}[\pi(y|x) || \pi_{ref}(y|x)] $$

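The optimal solution to this maximization problem has a closed form (a Gibbs distribution), where $Z(x) = \sum_y \pi_{ref}(y|x) \exp\left( \frac{1}{\beta} r(x,y) \right)$ is the partition function that normalizes it: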

$$ \pi^*(y|x) = \frac{1}{Z(x)} \pi_{ref}(y|x) \exp\left( \frac{1}{\beta} r(x,y) \right) $$
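The closed form can be checked numerically on a toy problem. The sketch below assumes a single prompt $x$ with only `K` possible responses, a random reward vector `r` and a random categorical reference policy `pi_ref` (all hypothetical). It evaluates the KL-regularized objective at $\pi^*$ and at many random policies. Substituting $\pi^*$ into the objective shows that its optimal value is exactly $\beta \log Z(x)$, which the last line prints for comparison.

```python
import torch

torch.manual_seed(0)

# Toy setting (assumption for illustration): a single prompt x with K possible
# responses, a fixed reward vector r and a categorical reference policy pi_ref.
K, beta = 5, 0.5
r = torch.randn(K)
pi_ref = torch.softmax(torch.randn(K), dim=0)

def rlhf_objective(pi):
    """E_{y~pi}[r(x, y)] - beta * KL(pi || pi_ref) for a categorical policy pi."""
    kl = (pi * (pi.log() - pi_ref.log())).sum()
    return ((pi * r).sum() - beta * kl).item()

# Closed-form optimum: pi*(y|x) = pi_ref(y|x) * exp(r(x, y) / beta) / Z(x)
unnormalized = pi_ref * torch.exp(r / beta)
Z = unnormalized.sum()
pi_star = unnormalized / Z

best = rlhf_objective(pi_star)
random_best = max(rlhf_objective(torch.softmax(torch.randn(K), dim=0)) for _ in range(1000))
print(f"objective at pi*: {best:.4f}")
print(f"best of 1000 random policies: {random_best:.4f}")
print(f"beta * log Z: {beta * Z.log().item():.4f}")  # matches the optimal objective
```

No random policy should exceed the Gibbs solution, because the objective is strictly concave in $\pi$ and $\pi^*$ is its unique maximizer.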

Taking the logarithm of both sides and rearranging, Rafailov et al. (2023) express the reward function in terms of the optimal policy:

$$ r(x,y) = \beta \log \frac{\pi^*(y|x)}{\pi_{ref}(y|x)} + \beta \log Z(x) $$
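A small numerical check of this inversion, under the same kind of toy assumptions (one prompt, `K` hypothetical responses, random `r` and `pi_ref`): building $\pi^*$ from $r$ and then applying the formula recovers $r$ exactly. Dropping the $\beta \log Z(x)$ term shifts every response's reward by the same constant, which is why it is harmless for pairwise comparisons.

```python
import torch

torch.manual_seed(0)

# Toy setting (assumption): one prompt x, K candidate responses.
K, beta = 5, 0.5
r = torch.randn(K)
pi_ref = torch.softmax(torch.randn(K), dim=0)

# Gibbs optimum pi* and its partition function Z(x)
unnormalized = pi_ref * torch.exp(r / beta)
Z = unnormalized.sum()
pi_star = unnormalized / Z

# Invert: r(x, y) = beta * log(pi*(y|x) / pi_ref(y|x)) + beta * log Z(x)
log_ratio = pi_star.log() - pi_ref.log()
r_recovered = beta * log_ratio + beta * Z.log()
print(torch.allclose(r, r_recovered, atol=1e-5))

# Without beta * log Z, the result differs from r by the same constant for every y.
shift = r - beta * log_ratio
print(shift)  # every entry equals beta * log Z(x)
```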

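Substituting this into the Bradley-Terry preference model $P(y_w \succ y_l | x) = \sigma(r(x, y_w) - r(x, y_l))$, the term $\beta \log Z(x)$ cancels because it depends only on $x$, not on the response. Replacing $\pi^*$ with the trainable policy $\pi_\theta$ and minimizing the negative log-likelihood of the observed preferences yields the DPO loss: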

$$ \mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right] $$
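The cancellation can be verified on a toy problem. This sketch assumes one prompt with `K` responses and a few hypothetical preference pairs given as index tensors `y_w` and `y_l`. It computes the Bradley-Terry probability once from the true rewards and once from the implicit reward $\beta \log \frac{\pi^*}{\pi_{ref}}$, in which $Z(x)$ never appears; at $\pi_\theta = \pi^*$ the DPO loss therefore equals the Bradley-Terry negative log-likelihood.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setting (assumption): one prompt x, K responses, Gibbs-optimal policy pi*.
K, beta = 5, 0.5
r = torch.randn(K)
pi_ref = torch.softmax(torch.randn(K), dim=0)
pi_star = pi_ref * torch.exp(r / beta)
pi_star = pi_star / pi_star.sum()

# Hypothetical preference pairs (indices of winning / losing responses)
y_w = torch.tensor([0, 1, 2])
y_l = torch.tensor([3, 4, 0])

# Bradley-Terry probability from the true rewards
p_bt = torch.sigmoid(r[y_w] - r[y_l])

# Same probability from implicit rewards; no Z(x) is needed
log_ratio = pi_star.log() - pi_ref.log()
dpo_logits = beta * (log_ratio[y_w] - log_ratio[y_l])
p_dpo = torch.sigmoid(dpo_logits)
print(torch.allclose(p_bt, p_dpo, atol=1e-5))

# DPO loss evaluated at pi* equals the Bradley-Terry negative log-likelihood
dpo_loss_at_opt = -F.logsigmoid(dpo_logits).mean()
bt_nll = -torch.log(p_bt).mean()
print(f"DPO loss at pi*: {dpo_loss_at_opt.item():.5f}, BT NLL: {bt_nll.item():.5f}")
```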

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 2. DPO Loss Implementation

In [None]:
def dpo_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """
    Computes DPO loss.
    Input logps: (batch_size,)
    """
    
    # Calculate log-ratios
    # pi_theta(y|x) / pi_ref(y|x) in log space is: log(pi_theta) - log(pi_ref)
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    
    # Difference of log-ratios; beta * logits is the implicit reward difference (preference log-odds)
    logits = chosen_logratios - rejected_logratios
    
    # Loss = -log(sigmoid(beta * logits))
    losses = -F.logsigmoid(beta * logits)
    
    # Rewards (implicit) for tracking
    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    
    return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()
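Two properties of this loss are easy to check by hand. When the policy equals the reference, every log-ratio is zero, so the loss is $\log 2 \approx 0.693$. Differentiating gives $\partial \mathcal{L} / \partial \log \pi_\theta(y_w|x) = -\frac{\beta}{B}\,\sigma(-\beta \cdot \text{logits})$ per pair, so pairs the policy already ranks correctly receive a small weight and mis-ranked pairs a large one. The sketch below recomputes the same formula inline on hypothetical log-probabilities (illustrative numbers, not outputs of any model) to confirm both points.

```python
import math
import torch
import torch.nn.functional as F

beta = 0.1
# Hypothetical summed log-probs for 4 preference pairs (illustrative values only).
ref_chosen = torch.tensor([-12.0, -8.0, -15.0, -10.0])
ref_rejected = torch.tensor([-11.0, -9.0, -14.0, -10.5])

# Case 1: policy == reference, so every log-ratio is 0 and the loss is log 2.
policy_chosen = ref_chosen.clone().requires_grad_(True)
policy_rejected = ref_rejected.clone().requires_grad_(True)
logits = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
loss = -F.logsigmoid(beta * logits).mean()
loss.backward()
print(f"loss at init: {loss.item():.4f} (log 2 = {math.log(2):.4f})")

# Analytic gradient w.r.t. each chosen log-prob: -(beta / B) * sigmoid(-beta * logits)
expected = -(beta / len(logits)) * torch.sigmoid(-beta * logits.detach())
print(torch.allclose(policy_chosen.grad, expected))

# Case 2: per-pair weights sigmoid(-beta * logits) for hypothetical logits.
# Pair 0 is already strongly preferred, pair 1 is mis-ranked, pairs 2-3 are neutral.
z = torch.tensor([20.0, -20.0, 0.0, 0.0])
weights = torch.sigmoid(-beta * z)
print(weights)  # the mis-ranked pair gets the largest weight
```

This weighting is what distinguishes DPO from naively raising the chosen likelihood and lowering the rejected one: the update focuses on pairs the implicit reward still gets wrong.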

## 3. Mock Training Loop with Margin Visualization

In [None]:
# Mock model: token embedding + linear head (a bigram-style stand-in for a GPT, no attention)
class TinyGPT(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        
    def forward(self, idx):
        x = self.token_emb(idx)
        return self.lm_head(x)

vocab_size = 100
d_model = 32
policy_model = TinyGPT(vocab_size, d_model).to(device)
ref_model = TinyGPT(vocab_size, d_model).to(device)
# The reference model starts as an exact copy of the policy and stays frozen
ref_model.load_state_dict(policy_model.state_dict())
for p in ref_model.parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(policy_model.parameters(), lr=0.05) # High LR for demo

# Mock data: random token sequences stand in for chosen / rejected responses (no prompt x)
batch_size = 16
seq_len = 10
chosen_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
rejected_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)

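def get_logps(model, input_ids):
    """Sum of next-token log-probs log p(x_{t+1} | x_<=t) over each sequence."""
    logits = model(input_ids)[:, :-1, :]  # position t predicts token t+1
    labels = input_ids[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)
    gathered_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
    return gathered_log_probs.sum(dim=-1)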

margins = []

print("Training DPO Step...")
for i in range(50):
    # 1. Forward pass policy
    policy_chosen_logps = get_logps(policy_model, chosen_ids)
    policy_rejected_logps = get_logps(policy_model, rejected_ids)

    # 2. Forward pass reference (no grad)
    with torch.no_grad():
        ref_chosen_logps = get_logps(ref_model, chosen_ids)
        ref_rejected_logps = get_logps(ref_model, rejected_ids)

    # 3. Loss
    loss, r_chosen, r_rejected = dpo_loss(
        policy_chosen_logps, policy_rejected_logps,
        ref_chosen_logps, ref_rejected_logps
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    margin = r_chosen - r_rejected
    margins.append(margin.item())
    
    if i % 10 == 0:
        print(f"Step {i}: Loss {loss.item():.4f}, Margin {margin.item():.4f}")

plt.figure(figsize=(8, 4))
plt.plot(margins)
plt.title("Reward Margin ($r_{chosen} - r_{rejected}$)")
plt.xlabel("Step")
plt.ylabel("Margin")
plt.grid(True)
plt.show()
print("Positive margin means model prefers chosen response over rejected.")