# DeepSeek-R1: Group Relative Policy Optimization (GRPO)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/deepseek_r1_grpo.ipynb)

DeepSeek-R1 is trained with a reinforcement learning method called **GRPO**, first proposed in the DeepSeekMath paper. GRPO removes the need for a separate Critic model (value function), which reduces the memory and compute cost of RL training for reasoning tasks.

### Key Idea
Instead of estimating a per-token advantage $A_t$ with a separate Critic network, as PPO does, GRPO samples a **group** of $G$ outputs for the *same* prompt and scores each one with a reward $r_i$.

The advantage of the $i$-th output is simply its normalized score relative to the group:
$$ A_i = \frac{r_i - \text{mean}(R)}{\text{std}(R)} $$

where $R = \{r_1, r_2, \dots, r_G\}$ is the set of rewards for the group.

This pushes the model toward answers that score above the average of its own attempts on that prompt, and away from answers that score below it.
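As a quick illustration of the formula above, here is a minimal sketch that standardizes rewards within each group. The helper name `group_advantages`, the `eps` value, and the binary rewards are assumptions made for this example. Note what happens when every completion in a group gets the same reward: all advantages are zero, so that prompt contributes no learning signal.

```python
import torch

def group_advantages(rewards, eps=1e-8):
    """Standardize rewards within each group (the last dimension).

    rewards: (num_prompts, G) tensor, one row per prompt.
    eps is a small constant (an assumption here) that avoids division by zero.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)  # sample std, torch's default
    return (rewards - mean) / (std + eps)

# Two hypothetical prompts, G = 4 completions each, binary correctness rewards.
rewards = torch.tensor([
    [1.0, 0.0, 0.0, 1.0],  # mixed group: correct answers get positive advantage
    [1.0, 1.0, 1.0, 1.0],  # all rewards equal: every advantage is zero
])
advantages = group_advantages(rewards)
print(advantages)
```

Within the mixed group the advantages are symmetric around zero, so the update raises the probability of the correct completions and lowers it for the incorrect ones by the same weight.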

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Minimal Policy Model
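A toy stand-in for the language model we want to train: an embedding followed by a per-token MLP. Unlike a real GPT it has no attention, so each position only sees its own token. That is enough to exercise the loss, but not to model text.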

In [None]:
class TinyPolicy(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, vocab_size)
        )
    
    def forward(self, x):
        # x: (batch, seq_len) integer token ids
        emb = self.embedding(x)      # (batch, seq_len, d_model)
        logits = self.net(emb)       # (batch, seq_len, vocab_size)
        return logits

## 2. GRPO Loss Function
This computes a PPO-style clipped policy-gradient loss, using the group-normalized advantages in place of Critic estimates.
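To see what the clipping does in isolation, here is a small sketch on hand-picked probability ratios. The helper `clipped_surrogate` and the ratio values are illustrative assumptions, not part of the notebook's training code.

```python
import torch

def clipped_surrogate(ratio, advantage, epsilon=0.2):
    """Element-wise PPO-style clipped objective (to be maximized)."""
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
    return torch.min(unclipped, clipped)

# Hypothetical ratios: policy made the token less likely, unchanged, more likely.
ratio = torch.tensor([0.5, 1.0, 1.5])

pos = clipped_surrogate(ratio, torch.tensor(1.0))   # good completion
neg = clipped_surrogate(ratio, torch.tensor(-1.0))  # bad completion

print(pos)  # tensor([0.5000, 1.0000, 1.2000])
print(neg)  # tensor([-0.8000, -1.0000, -1.5000])
```

For a positive advantage the objective stops growing once the ratio exceeds `1 + epsilon`, so there is no incentive to push a good token's probability further in one update. For a negative advantage it stops changing once the ratio drops below `1 - epsilon`. In both cases the `min` keeps the more pessimistic of the two terms.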

In [None]:
def grpo_loss(logits, old_log_probs, actions, advantages, epsilon=0.2):
    """
    logits: (Batch, Seq_Len, Vocab_Size) - Current policy outputs
    old_log_probs: (Batch, Seq_Len) - Log probs of actions from old policy (before update)
    actions: (Batch, Seq_Len) - The token indices generated
    advantages: (Batch,) - Group relative advantage for each sequence
    epsilon: PPO clipping parameter
    """
    
    # 1. Calculate current log probs
    log_probs = F.log_softmax(logits, dim=-1)
    # Gather log prob of the actual taken actions
    action_log_probs = torch.gather(log_probs, dim=2, index=actions.unsqueeze(-1)).squeeze(-1)
    
    # 2. Probability Ratio r_t(theta) = exp(log_p_new - log_p_old)
    ratio = torch.exp(action_log_probs - old_log_probs)
    
    # 3. PPO-style Clipped Objective
    # Each sequence has a single scalar advantage; broadcast it across all of
    # that sequence's tokens so it matches the (Batch, Seq_Len) ratio tensor.
    adv_expanded = advantages.unsqueeze(1).expand_as(ratio)
    
    surr1 = ratio * adv_expanded
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv_expanded
    
    # Maximize objective => Minimize -objective
    loss = -torch.min(surr1, surr2).mean()
    
    # Optional: the full GRPO objective also adds a KL penalty against a frozen
    # reference model; it is omitted in this minimal version.
    # loss += beta * KL(policy || ref)
    
    return loss
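
The KL term mentioned in the comment above could be sketched as follows. This uses the per-token estimator `exp(d) - d - 1` with `d = log p_ref - log p_policy`, which is non-negative and is the form I understand the DeepSeekMath GRPO formulation to use. The function name `kl_penalty`, the reference log-probs, and the weight `beta` are all assumptions for illustration.

```python
import torch

def kl_penalty(log_probs, ref_log_probs):
    """Per-token estimate of KL(policy || ref) on tokens sampled from the policy.

    log_probs:     (Batch, Seq_Len) log-probs of the taken tokens under the policy
    ref_log_probs: (Batch, Seq_Len) log-probs of the same tokens under a frozen
                   reference model (hypothetical here)
    """
    delta = ref_log_probs - log_probs
    return torch.exp(delta) - delta - 1.0

# Dummy log-probs standing in for real model outputs.
torch.manual_seed(0)
log_probs = -torch.rand(4, 10)
ref_log_probs = -torch.rand(4, 10)

beta = 0.04  # hypothetical penalty weight
kl_term = beta * kl_penalty(log_probs, ref_log_probs).mean()
print(f"KL penalty term: {kl_term.item():.6f}")
```

Inside `grpo_loss` this term would be added to the clipped loss, with `action_log_probs` playing the role of `log_probs`. The penalty is zero when the policy and reference agree on a token and grows as they diverge.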

## 3. Training Loop Simulation
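We run a single GRPO update on dummy data as a smoke test that the shapes line up and the gradients flow. Random tokens and hand-picked rewards stand in for real generations.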

In [None]:
# Setup
group_size = 4 # G in the paper
vocab_size = 100
seq_len = 10

model = TinyPolicy(vocab_size, 32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# SIMULATE A BATCH OF GENERATIONS
# Assume we had 1 prompt, and we generated 'group_size' completions.
actions = torch.randint(0, vocab_size, (group_size, seq_len)).to(device)

# SIMULATE REWARDS
# Assume an oracle reward model or verifier (e.g. a math answer checker) scored each completion.
# Completions 0 and 2 were "correct" (high reward); the others were wrong.
rewards = torch.tensor([1.0, 0.1, 1.0, 0.2]).to(device)

# 1. Compute Group Advantages
mean_r = rewards.mean()
std_r = rewards.std() + 1e-8 # Sample std (torch default); epsilon prevents div by zero when all rewards are equal
advantages = (rewards - mean_r) / std_r

print("Rewards:", rewards)
print("Group Advantages:", advantages)

# 2. Get 'Old' Log Probs. In this demo they come from the same weights we are about to update,
# so the probability ratio is exactly 1 on this first step.
with torch.no_grad():
    logits_old = model(actions)
    log_probs_old_all = F.log_softmax(logits_old, dim=-1)
    old_log_probs = torch.gather(log_probs_old_all, dim=2, index=actions.unsqueeze(-1)).squeeze(-1)

# 3. Optimization Step
logits_new = model(actions)
loss = grpo_loss(logits_new, old_log_probs, actions, advantages)

optimizer.zero_grad()
loss.backward()
optimizer.step()

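# With ratio == 1 the loss reduces to -mean(advantages), which is ~0 because the
# advantages are zero-centered. The gradient is still non-zero, so the step updates the model.
print(f"GRPO Step Loss: {loss.item():.4f}")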
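
The single step above never moves the ratio away from 1, so the clipping is not exercised. A minimal sketch of the more typical pattern, where the old log-probs are frozen once and reused for several updates, is shown below. The two-layer `policy`, the learning rate, and the number of inner steps are assumptions chosen to keep the example self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for this illustration.
G, T, V = 4, 10, 100

# A stand-in policy: embedding followed by a linear head (no attention).
policy = nn.Sequential(nn.Embedding(V, 32), nn.Linear(32, V))
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)

actions = torch.randint(0, V, (G, T))
rewards = torch.tensor([1.0, 0.1, 1.0, 0.2])
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

def token_log_probs(model, tokens):
    """Log-probability the model assigns to each token in `tokens`."""
    log_probs = F.log_softmax(model(tokens), dim=-1)
    return log_probs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)

# Freeze the sampling-time log-probs once, then reuse them for several updates.
with torch.no_grad():
    old_log_probs = token_log_probs(policy, actions)

losses, max_ratio_dev = [], []
for step in range(4):
    ratio = torch.exp(token_log_probs(policy, actions) - old_log_probs)
    adv = advantages.unsqueeze(1)  # (G, 1), broadcasts over the token axis
    loss = -torch.min(ratio * adv, ratio.clamp(0.8, 1.2) * adv).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    max_ratio_dev.append((ratio - 1).abs().max().item())
    print(f"step {step}: loss={losses[-1]:+.4f}  max|ratio-1|={max_ratio_dev[-1]:.4f}")
```

At step 0 the ratio is 1 everywhere and the loss is approximately zero. From step 1 onward the weights have moved, so the ratio drifts from 1 and the clipped objective starts to limit how far each token's probability can change relative to the frozen old policy.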