# RLHF & Direct Preference Optimization (DPO)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/rlhf_dpo.ipynb)

Aligning language models with human preferences using DPO (a stable alternative to PPO).

Key Concepts:
1. **Preference Data:** Tuples of $(x, y_w, y_l)$, where $x$ is the prompt, $y_w$ is the preferred ("winning") response, and $y_l$ is the rejected ("losing") response.
2. **Policy vs Reference:** We raise the probability of winning responses relative to losing ones, while keeping the policy close to the original (reference) model to prevent degenerate behavior such as mode collapse.
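3. **DPO Loss:** DPO uses the closed-form solution of the KL-regularized RL objective to train directly on preference pairs, without a separate Reward Model or PPO. In the loss below, $\sigma$ is the logistic sigmoid and $\beta$ controls how strongly the policy is kept close to the reference model.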

$$ \mathcal{L}_{DPO} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right] $$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when CUDA reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. DPO Loss Implementation

In [None]:
def dpo_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Compute the DPO loss and the implicit rewards for a batch of preference pairs.

    Args:
        policy_chosen_logps: Policy log-probabilities of the chosen responses,
            shape (batch_size,).
        policy_rejected_logps: Policy log-probabilities of the rejected
            responses, shape (batch_size,).
        ref_chosen_logps: Reference log-probabilities of the chosen responses,
            shape (batch_size,).
        ref_rejected_logps: Reference log-probabilities of the rejected
            responses, shape (batch_size,).
        beta: Scale applied to the log-ratio margin and to the implicit rewards.

    Returns:
        Tuple ``(loss, chosen_reward, rejected_reward)``; each element is the
        batch mean. The two rewards are detached and only meant for logging.
    """
    # log(pi_theta / pi_ref) for each response, computed as a difference of logs.
    log_ratio_w = policy_chosen_logps - ref_chosen_logps
    log_ratio_l = policy_rejected_logps - ref_rejected_logps

    # How much more the policy (relative to the reference) favors the chosen response.
    margin = log_ratio_w - log_ratio_l

    # -log(sigmoid(beta * margin)); logsigmoid is the numerically stable form
    # of log(sigmoid(.)).
    per_pair_loss = -F.logsigmoid(beta * margin)

    # Implicit rewards are tracked for monitoring only, so no gradient flows through them.
    reward_w = beta * log_ratio_w.detach()
    reward_l = beta * log_ratio_l.detach()

    return per_pair_loss.mean(), reward_w.mean(), reward_l.mean()

## 2. Mock Training Loop

In [None]:
# Mock Model (Tiny GPT)
class TinyGPT(nn.Module):
    """Minimal stand-in language model: a token embedding followed by a linear head.

    There is no attention or positional information, so the logits produced
    at a position depend only on the token at that position.
    """

    def __init__(self, vocab_size, d_model):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: Number of distinct token ids.
            d_model: Width of the embedding vectors.
        """
        super().__init__()
        # Embedding is created before the head, keeping seeded initialisation order stable.
        self.token_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.lm_head = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, idx):
        """Return vocabulary logits for every token id in ``idx``."""
        return self.lm_head(self.token_emb(idx))

# Sizes of the toy model.
vocab_size = 100
d_model = 32
policy_model = TinyGPT(vocab_size, d_model).to(device)
ref_model = TinyGPT(vocab_size, d_model).to(device)
# In DPO the reference is the model the policy starts from, so copy the
# policy's initial weights instead of leaving the reference independently
# random. With identical weights the log-ratios start at exactly zero, which
# gives a loss of ln(2) and zero implicit rewards on the first step.
ref_model.load_state_dict(policy_model.state_dict())
# Ref model is frozen: no gradients, and eval mode so it stays deterministic
# if layers such as dropout are ever added.
ref_model.requires_grad_(False)
ref_model.eval()

# Only the policy's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-3)

# Mock Batch of Data (Indices): random token ids stand in for real
# chosen/rejected responses.
batch_size = 4
seq_len = 10
chosen_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)
rejected_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to(device)

def get_logps(model, input_ids):
    """Return the summed next-token log-probability of each sequence.

    The logits produced at position ``t`` are treated as the prediction for
    the token at position ``t + 1`` (standard causal-LM scoring). The first
    token has no preceding prediction and is therefore not scored.

    Args:
        model: Callable mapping token ids of shape (B, L) to logits of shape
            (B, L, V).
        input_ids: Integer token ids of shape (B, L).

    Returns:
        Tensor of shape (B,) holding, for each sequence, the sum over
        positions 1..L-1 of the log-probability assigned to the actual token.
        A sequence of length 1 scores 0.
    """
    logits = model(input_ids)

    # Align predictions with their targets: drop the last position's logits
    # (nothing follows it) and the first token (nothing predicts it).
    pred_logits = logits[:, :-1, :]
    targets = input_ids[:, 1:]

    log_probs = F.log_softmax(pred_logits, dim=-1)

    # Pick out the log-probability of each actual next token.
    # log_probs: (B, L-1, V), targets: (B, L-1)
    token_logps = torch.gather(log_probs, 2, targets.unsqueeze(-1)).squeeze(-1)

    return token_logps.sum(dim=-1)

print("Training DPO Step...")

# Start from clean gradients before building this step's graph.
optimizer.zero_grad()

# Reference log-probs are constants in the loss, so skip graph construction.
with torch.no_grad():
    ref_chosen_logps = get_logps(ref_model, chosen_ids)
    ref_rejected_logps = get_logps(ref_model, rejected_ids)

# Policy log-probs carry the gradients that update the policy.
policy_chosen_logps = get_logps(policy_model, chosen_ids)
policy_rejected_logps = get_logps(policy_model, rejected_ids)

# DPO loss plus the mean implicit rewards (beta is left at dpo_loss's default).
loss, r_chosen, r_rejected = dpo_loss(
    policy_chosen_logps=policy_chosen_logps,
    policy_rejected_logps=policy_rejected_logps,
    ref_chosen_logps=ref_chosen_logps,
    ref_rejected_logps=ref_rejected_logps,
)

# Single optimisation step on the policy.
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")
print(f"Reward (Chosen): {r_chosen.item():.4f}")
print(f"Reward (Rejected): {r_rejected.item():.4f}")