# DeepSeek-R1: Group Relative Policy Optimization (GRPO)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/deepseek_r1_grpo.ipynb)

## 1. Mathematical Foundations: GRPO Algorithm

DeepSeek-R1 is trained with GRPO, which avoids the need for a learned value function (critic) by computing advantages relative to a group of sampled outputs.

For each prompt $q$, the model samples a group of $G$ outputs $\{o_1, o_2, \dots, o_G\}$ from the old policy $\pi_{\theta_{old}}$.

### Group Advantage
The advantage for the $i$-th output is computed by normalizing the rewards within the group:

$$ A_i = \frac{r_i - \text{mean}(\{r_1, \dots, r_G\})}{\text{std}(\{r_1, \dots, r_G\})} $$
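
As a quick check of the formula, the sketch below normalizes rewards for two hypothetical groups of four outputs using only the Python standard library. The helper name `group_advantages` is illustrative, and using the sample standard deviation (`statistics.stdev`, which divides by $G-1$ like PyTorch's default `Tensor.std`) is an assumption; the population standard deviation is an equally common choice.

```python
import statistics


def group_advantages(rewards, eps=1e-8):
    """Group-relative advantages: (r_i - mean) / (std + eps).

    Uses the sample standard deviation (divides by G - 1), matching the default
    of torch.Tensor.std. eps avoids division by zero when all rewards are equal.
    """
    mean_r = statistics.fmean(rewards)
    std_r = statistics.stdev(rewards)  # requires at least two rewards
    return [(r - mean_r) / (std_r + eps) for r in rewards]


# Two correct (1) and two incorrect (0) outputs for the same prompt.
mixed = group_advantages([1.0, 0.0, 0.0, 1.0])
print([round(a, 4) for a in mixed])  # [0.866, -0.866, -0.866, 0.866]

# Every output receives the same reward: no output is better than its peers.
uniform = group_advantages([1.0, 1.0, 1.0, 1.0])
print(uniform)  # [0.0, 0.0, 0.0, 0.0]
```

When every output in a group gets the same reward, all advantages are zero, so that group contributes no gradient through the advantage-weighted term; a group is only informative if its outcomes vary.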

### GRPO Objective
GRPO maximizes the PPO-style clipped surrogate, computed with this group advantage, minus a KL penalty that keeps the policy close to a reference model $\pi_{ref}$:

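$$ \mathcal{L}_{GRPO}(\theta) = \frac{1}{G} \sum_{i=1}^G \left( \min \left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)} A_i,\ \text{clip}\left( \frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i \right) - \beta\, D_{KL}\left(\pi_\theta \,\|\, \pi_{ref}\right) \right) $$

where $\epsilon$ is the clipping range (the `epsilon` argument of `grpo_loss` below) and $\beta$ sets the strength of the KL penalty.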
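
The effect of the `min` and `clip` terms is easiest to see on a single output. The sketch below evaluates the clipped surrogate for a few hypothetical ratio/advantage pairs with $\epsilon = 0.2$; the function name `clipped_surrogate` is illustrative.

```python
def clipped_surrogate(ratio, advantage, epsilon=0.2):
    """min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A) for one output."""
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped_ratio * advantage)


cases = [
    (1.5, 1.0),   # good output already made much more likely: gain capped at 1.2
    (0.5, 1.0),   # good output made less likely: unclipped term 0.5 is kept
    (0.5, -1.0),  # bad output already made much less likely: capped at -0.8
    (1.5, -1.0),  # bad output made more likely: unclipped penalty -1.5 is kept
]
for ratio, adv in cases:
    print(ratio, adv, clipped_surrogate(ratio, adv))
```

In each case the objective keeps the more pessimistic of the two terms, so clipping only removes the incentive to push the ratio further outside $[1-\epsilon, 1+\epsilon]$ in the direction the advantage favors.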

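By removing the critic, GRPO no longer has to train and store a separate value model (typically comparable in size to the policy), which saves significant memory and compute. It also avoids relying on a learned value estimate, which is hard to make accurate, and therefore stable, for reasoning tasks where judging partial steps is difficult.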

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
```

## 2. Policy Model Implementation

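```python
class TinyPolicy(nn.Module):
    """Toy policy: embeds each input token and maps it to next-token logits.

    Each position sees only its own input token (no attention or recurrence),
    so this behaves like a bigram model. It stands in for an LLM policy.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, vocab_size)
        )

    def forward(self, x):
        # x: (batch, seq_len) token indices
        emb = self.embedding(x)   # (batch, seq_len, d_model)
        logits = self.net(emb)    # (batch, seq_len, vocab_size)
        return logits
```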

## 3. GRPO Loss Function

```python
def grpo_loss(logits, old_log_probs, actions, advantages, epsilon=0.2):
    """
    Clipped GRPO surrogate loss (the KL penalty beta * D_KL is omitted for simplicity).

    logits: (Batch, Seq_Len, Vocab_Size) - current policy outputs
    old_log_probs: (Batch, Seq_Len) - log probs of the taken actions under the old policy (no grad)
    actions: (Batch, Seq_Len) - the generated token indices
    advantages: (Batch,) - group-relative advantage for each sequence
    epsilon: PPO clipping parameter

    Returns the negative objective, so minimizing it maximizes the surrogate.
    """

    # 1. Current log probs of the tokens that were actually generated
    log_probs = F.log_softmax(logits, dim=-1)
    action_log_probs = torch.gather(log_probs, dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)

    # 2. Probability ratio r_t(theta) = exp(log_p_new - log_p_old)
    ratio = torch.exp(action_log_probs - old_log_probs)

    # 3. PPO-style clipped objective
    # Every token in a sequence shares that sequence's scalar advantage,
    # so broadcast advantages from (Batch,) to (Batch, Seq_Len).
    adv_expanded = advantages.unsqueeze(1).expand_as(ratio)

    surr1 = ratio * adv_expanded
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv_expanded

    # Average over all tokens in the group; negate because optimizers minimize.
    loss = -torch.min(surr1, surr2).mean()

    return loss
```
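
`grpo_loss` implements only the clipped term. The sketch below adds the KL penalty in plain Python, assuming per-token log-probabilities are already available as lists, and uses the per-token estimator $D_{KL} \approx \frac{\pi_{ref}}{\pi_\theta} - \log\frac{\pi_{ref}}{\pi_\theta} - 1$ from the DeepSeekMath paper that introduced GRPO. It also normalizes by each output's own length $|o_i|$ before averaging over the group. The function names `kl_estimate` and `grpo_objective` and the value $\beta = 0.04$ are illustrative assumptions.

```python
import math


def kl_estimate(logp_new, logp_ref):
    """Per-token estimate of KL(pi_theta || pi_ref): ratio - log(ratio) - 1,
    with ratio = pi_ref / pi_theta. It is >= 0 and equals 0 when the policies agree."""
    log_ratio = logp_ref - logp_new
    return math.exp(log_ratio) - log_ratio - 1.0


def grpo_objective(new_lp, old_lp, ref_lp, advantages, epsilon=0.2, beta=0.04):
    """Token-level GRPO objective for one group (to be maximized).

    new_lp, old_lp, ref_lp: G lists of per-token log-probs, one inner list per
    output o_i (outputs may differ in length). advantages: G floats, one per output.
    beta=0.04 is an illustrative value, not a recommendation.
    """
    total = 0.0
    for lp_n, lp_o, lp_r, adv in zip(new_lp, old_lp, ref_lp, advantages):
        per_token = []
        for n, o, r in zip(lp_n, lp_o, lp_r):
            ratio = math.exp(n - o)
            clipped = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
            surrogate = min(ratio * adv, clipped * adv)
            per_token.append(surrogate - beta * kl_estimate(n, r))
        total += sum(per_token) / len(per_token)  # 1/|o_i| length normalization
    return total / len(advantages)  # 1/G average over the group


# Hypothetical group of G = 2 outputs with lengths 3 and 2.
old = [[-1.0, -2.0, -0.5], [-1.5, -0.7]]
new = [lp[:] for lp in old]  # first inner step: pi_theta == pi_theta_old
ref = [lp[:] for lp in old]  # policy still equal to the reference
adv = [1.0, -1.0]

obj = grpo_objective(new, old, ref, adv)
loss = -obj  # minimize the negative, as grpo_loss does
print(obj)  # 0.0: ratios are 1, the KL term vanishes, and the advantages cancel
```

When every output has the same length, as with the fixed `seq_len` in the simulation, the per-output $1/|o_i|$ normalization followed by the $1/G$ average equals the plain `.mean()` over all tokens used in `grpo_loss`.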

## 4. Simulation with Advantage Visualization

```python
# Setup
group_size = 64  # G: a large group makes the advantage histogram easy to read
vocab_size = 100
seq_len = 10

model = TinyPolicy(vocab_size, 32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# SIMULATE A GROUP OF G OUTPUTS FOR ONE PROMPT
# actions: sampled tokens (mocked here with uniform random tokens)
actions = torch.randint(0, vocab_size, (group_size, seq_len), device=device)
# Shift right so the logits at position t score token t given token t-1.
# Token 0 serves as a stand-in BOS/prompt token (an assumption of this toy setup).
bos = torch.zeros(group_size, 1, dtype=torch.long, device=device)
inputs = torch.cat([bos, actions[:, :-1]], dim=1)

# SIMULATE REWARDS
# Sparse binary rewards, common in reasoning tasks: most answers wrong (0), a few right (1).
# Here int(64 * 0.1) = 6 of the 64 outputs are correct.
rewards = torch.zeros(group_size, device=device)
rewards[:int(group_size * 0.1)] = 1.0
rewards = rewards[torch.randperm(group_size, device=device)]  # shuffle

# 1. Compute group-relative advantages
mean_r = rewards.mean()
std_r = rewards.std() + 1e-8  # unbiased std; epsilon guards against an all-equal group
advantages = (rewards - mean_r) / std_r

print(f"Mean Reward: {mean_r:.4f}, Std Reward: {std_r:.4f}")

# 2. Log probs under the old policy (a frozen snapshot taken before the update)
with torch.no_grad():
    logits_old = model(inputs)
    log_probs_old_all = F.log_softmax(logits_old, dim=-1)
    old_log_probs = torch.gather(log_probs_old_all, dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)

# 3. One optimization step
# The old and current policies are identical on this first step, so ratio == 1,
# no clipping occurs, and the loss is about -mean(A) = 0; the gradient is still nonzero.
logits_new = model(inputs)
loss = grpo_loss(logits_new, old_log_probs, actions, advantages)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"GRPO Loss: {loss.item():.4f}")

# Plot Advantage Distribution
plt.figure(figsize=(8, 4))
plt.hist(advantages.cpu().numpy(), bins=10, alpha=0.7, color='purple', edgecolor='black')
plt.title(f"Distribution of Advantages $A_i$ (Group Size G={group_size})")
plt.xlabel("Advantage Value (Normalized Reward)")
plt.ylabel("Frequency")
plt.grid(axis='y', alpha=0.3)
plt.show()

print("Notice how binary rewards {0, 1} become advantages centered at 0: about -0.32 for wrong answers and +3.08 for the rare correct ones.")
```
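
The two bars in the histogram can be predicted in closed form. For $k$ correct answers out of $G$ binary rewards, with $p = k/G$ and the unbiased standard deviation $\sigma = \sqrt{\tfrac{G}{G-1}\,p(1-p)}$ computed by `rewards.std()`, correct outputs get $A^+ = (1-p)/\sigma$ and incorrect ones get $A^- = -p/\sigma$. The sketch below plugs in the simulation's values ($G = 64$, $k = 6$).

```python
import math

G = 64
k = int(G * 0.1)  # 6 correct answers, as in the simulation above
p = k / G         # group mean reward

# Unbiased (G - 1) standard deviation of k ones and G - k zeros.
sigma = math.sqrt(G / (G - 1) * p * (1 - p))

a_pos = (1 - p) / sigma  # advantage of each correct output
a_neg = -p / sigma       # advantage of each incorrect output

print(f"mean={p:.4f} std={sigma:.4f} A+={a_pos:.4f} A-={a_neg:.4f}")
# The count-weighted advantages sum to zero, so the group is centered at 0.
print(k * a_pos + (G - k) * a_neg)
```

Rare successes therefore receive a large positive advantage (about +3.08) while the many failures are pushed down only slightly (about -0.32), so in sparse-reward settings the normalization concentrates the learning signal on the infrequent correct answers.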