In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and seed the random number generators

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# GRPO from Scratch -- Vizuara

In this notebook, we will build Group Relative Policy Optimization (GRPO) from scratch. GRPO is the RL algorithm behind DeepSeek-R1 that teaches language models to reason. By the end, you will have a working GRPO implementation that you understand line by line.

**What you will build:** A complete GRPO training loop that computes group-relative advantages, clipped surrogate losses, and KL penalties.
```

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

print(f"PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

torch.manual_seed(42)
np.random.seed(42)

## 1. Why Does This Matter?

Supervised fine-tuning teaches a model the *format* of reasoning, but not the *quality*. To teach a model which reasoning strategies actually lead to correct answers, we need reinforcement learning.

GRPO (Group Relative Policy Optimization) is the algorithm that made DeepSeek-R1 possible. It does not require a separate value network (critic). Instead, it samples a group of completions for each prompt and uses the mean and standard deviation of their rewards as the baseline, which removes an entire network from the training setup.

By the end of this notebook, you will:
- Understand how GRPO computes advantages without a critic
- Implement the clipped surrogate objective from scratch
- Visualize how clipping prevents catastrophic policy updates
- Implement the KL divergence penalty
```

## 2. Building Intuition

Imagine you are a teacher grading a math exam. You have 4 students who each attempted the same problem:
- Student A: Correct answer (reward = 1)
- Student B: Wrong answer (reward = 0)
- Student C: Correct answer (reward = 1)
- Student D: Wrong answer (reward = 0)

The class average is 0.5. Students A and C are **above average** — you want to encourage their approaches. Students B and D are **below average** — you want to discourage theirs.

This is exactly what GRPO does! For each prompt, it generates multiple completions (the "group"), scores them, and then encourages the good ones while discouraging the bad ones — all relative to the group average.
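
The grading analogy can be written out in a few lines of plain Python. This is only a sketch of the centering step (subtracting the group mean); the student labels and the `centered` dictionary are illustrative names, and the full GRPO advantage defined in Section 3 also divides by the group's standard deviation.

```python
# Rewards for four attempts at the same problem (1 = correct, 0 = wrong).
rewards = {"A": 1.0, "B": 0.0, "C": 1.0, "D": 0.0}

# The group mean plays the role of the baseline.
baseline = sum(rewards.values()) / len(rewards)

# Positive values mark above-average attempts, negative values below-average ones.
centered = {student: r - baseline for student, r in rewards.items()}

for student, score in centered.items():
    action = "encourage" if score > 0 else "discourage" if score < 0 else "no signal"
    print(f"Student {student}: reward - mean = {score:+.1f} -> {action}")
```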

### Think About This
- Why is using the group average better than using a fixed baseline of 0?
- What happens if ALL completions in the group get the same reward?
```

## 3. The Mathematics

### Group-Relative Advantages

For a group of $G$ completions with rewards $\mathbf{r} = [r_1, r_2, \ldots, r_G]$, the advantage of the $i$-th completion is:

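$$\hat{A}_i = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r}) + \delta}$$

**Computational meaning:** This is z-score normalization. Completions above the group mean get positive advantages (encouraged); those below get negative advantages (discouraged). Dividing by the standard deviation normalizes the scale, and $\delta$ is a small constant (`1e-8` in the code) that avoids division by zero. It is unrelated to the clipping parameter $\epsilon$ used in the clipped objective.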

### The Clipped Surrogate Objective

$$\mathcal{L}_{\text{GRPO}} = -\frac{1}{G}\sum_{i=1}^{G} \min\left(\frac{\pi_\theta(y_i|x)}{\pi_{\text{old}}(y_i|x)} \hat{A}_i,\; \text{clip}\left(\frac{\pi_\theta(y_i|x)}{\pi_{\text{old}}(y_i|x)}, 1-\epsilon, 1+\epsilon\right) \hat{A}_i\right)$$

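**Computational meaning:** The ratio $\frac{\pi_\theta}{\pi_{\text{old}}}$ measures how much the policy has changed for completion $y_i$, and $\epsilon$ is the clipping parameter (0.2 in the code). Once the ratio moves outside $[1-\epsilon, 1+\epsilon]$ in the direction the advantage favors, the clipped term is selected and that completion contributes no further gradient, which prevents catastrophic updates.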

### KL Divergence Penalty

$$R_{\text{total}} = R_{\text{outcome}} - \beta \cdot D_{\text{KL}}(\pi_\theta \| \pi_{\text{ref}})$$

**Computational meaning:** We subtract a penalty proportional to how much the current policy has diverged from the reference. This prevents the model from "forgetting" useful behaviors.
```
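
In practice the KL term is estimated from the log-probabilities of tokens sampled from the current policy, not computed exactly. The sketch below compares two per-token estimators in plain Python: the simple log-ratio, and the non-negative estimator $\frac{\pi_{\text{ref}}}{\pi_\theta} - \log\frac{\pi_{\text{ref}}}{\pi_\theta} - 1$ that the GRPO formulation in the DeepSeekMath paper uses. The function names and sample values are assumptions made for illustration.

```python
import math

def kl_log_ratio(logp_cur, logp_ref):
    """Per-token log-ratio estimate: log pi_theta - log pi_ref (can be negative)."""
    return [c - r for c, r in zip(logp_cur, logp_ref)]

def kl_nonnegative(logp_cur, logp_ref):
    """Per-token estimate exp(d) - d - 1 with d = log pi_ref - log pi_theta (never negative)."""
    estimates = []
    for c, r in zip(logp_cur, logp_ref):
        d = r - c
        estimates.append(math.exp(d) - d - 1.0)
    return estimates

# Hypothetical log-probs of four sampled tokens under each policy.
logp_cur = [-0.5, -0.8, -0.3, -1.0]
logp_ref = [-0.4, -0.9, -0.3, -1.5]

k1 = kl_log_ratio(logp_cur, logp_ref)
k3 = kl_nonnegative(logp_cur, logp_ref)

print("log-ratio    :", [round(v, 4) for v in k1])
print("non-negative :", [round(v, 4) for v in k3])
```

The log-ratio estimate is negative for any token to which the reference assigns higher probability, so its sum over a short sample can fall below zero even though a true KL divergence never does.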

In [None]:
# Let us compute group-relative advantages manually
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])

mean_r = rewards.mean()
std_r = rewards.std()
advantages = (rewards - mean_r) / (std_r + 1e-8)

print("=== Group-Relative Advantage Computation ===")
print(f"Rewards:    {rewards.tolist()}")
print(f"Mean:       {mean_r:.2f}")
print(f"Std:        {std_r:.2f}")
print(f"Advantages: {advantages.tolist()}")
print(f"\nPositive advantages -> encourage these completions")
print(f"Negative advantages -> discourage these completions")
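
One detail worth checking in the cell above: `torch.Tensor.std()` applies Bessel's correction (dividing by $G-1$) by default, so the printed advantages are about ±0.87, not ±1. The sketch below reproduces both conventions with Python's `statistics` module. Which convention a given GRPO implementation uses is a design choice that this notebook does not fix; the variable names are illustrative.

```python
import statistics

rewards = [1.0, 0.0, 1.0, 0.0]
mean_r = sum(rewards) / len(rewards)
eps = 1e-8  # same stabilizer as in the notebook code

# Sample standard deviation (divides by G - 1), matching torch's default.
sample_std = statistics.stdev(rewards)
# Population standard deviation (divides by G).
population_std = statistics.pstdev(rewards)

adv_sample = [(r - mean_r) / (sample_std + eps) for r in rewards]
adv_population = [(r - mean_r) / (population_std + eps) for r in rewards]

print(f"sample std     = {sample_std:.4f} -> {[round(a, 3) for a in adv_sample]}")
print(f"population std = {population_std:.4f} -> {[round(a, 3) for a in adv_population]}")
```

The gap between the two conventions shrinks as the group size grows, so it matters most for small groups such as the four-completion example here.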

## 4. Let's Build It -- Component by Component

### 4.1 Group-Relative Advantage Function
```

In [None]:
def compute_group_advantages(rewards):
    """
    Compute group-relative advantages via z-score normalization.

    Args:
        rewards: Tensor of shape (G,) with reward for each completion

    Returns:
        advantages: Tensor of shape (G,) with normalized advantages
    """
    mean = rewards.mean()
    std = rewards.std()

    # If all rewards are the same, advantages are all zero
    if std < 1e-8:
        return torch.zeros_like(rewards)

    advantages = (rewards - mean) / (std + 1e-8)
    return advantages

# Test with different reward distributions
test_cases = [
    torch.tensor([1.0, 0.0, 1.0, 0.0]),   # Binary: half correct
    torch.tensor([1.0, 1.0, 1.0, 0.0]),   # Mostly correct
    torch.tensor([1.0, 0.0, 0.0, 0.0]),   # Mostly wrong
    torch.tensor([1.0, 1.0, 1.0, 1.0]),   # All correct
]

for rewards in test_cases:
    adv = compute_group_advantages(rewards)
    print(f"Rewards: {rewards.tolist()} -> Advantages: {[f'{a:.2f}' for a in adv.tolist()]}")

### Visualization Checkpoint: How Advantages Change with Group Composition
```

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
labels = ["2 correct\n2 wrong", "3 correct\n1 wrong", "1 correct\n3 wrong", "All correct"]

for ax, rewards, label in zip(axes, test_cases, labels):
    adv = compute_group_advantages(rewards)
    colors = ['#4CAF50' if a > 0 else '#F44336' if a < 0 else '#9E9E9E'
              for a in adv.tolist()]
    ax.bar(range(len(adv)), adv.tolist(), color=colors)
    ax.axhline(y=0, color='black', linewidth=0.5)
    ax.set_title(label, fontsize=11)
    ax.set_xlabel("Completion")
    ax.set_ylabel("Advantage")
    ax.set_ylim(-2.5, 2.5)

plt.suptitle("Group-Relative Advantages for Different Reward Distributions", fontsize=13, y=1.02)
plt.tight_layout()
plt.show()

### 4.2 Probability Ratio and Clipping

The probability ratio tells us how much the policy has changed for a given completion.
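
Policies are usually compared in log space, so the ratio is obtained by exponentiating a difference of log-probabilities. A minimal sketch, assuming a sequence-level ratio in which each completion's log-probability is the sum of its per-token log-probabilities; the token values are made up for illustration.

```python
import math

# Hypothetical per-token log-probs of one completion under each policy.
token_logps_new = [-0.5, -0.8, -0.3]
token_logps_old = [-0.6, -0.9, -0.4]

# Sequence log-probability = sum of token log-probabilities.
logp_new = sum(token_logps_new)
logp_old = sum(token_logps_old)

# ratio = pi_new(y|x) / pi_old(y|x), computed without forming the tiny probabilities.
ratio = math.exp(logp_new - logp_old)

print(f"log pi_new = {logp_new:.2f}, log pi_old = {logp_old:.2f}")
print(f"ratio = {ratio:.3f}")
```

A ratio above 1 means the new policy makes this completion more likely than the old one did; a ratio of exactly 1 means the policy has not changed for this completion.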
```

In [None]:
def compute_clipped_objective(ratios, advantages, epsilon=0.2):
    """
    Compute the clipped surrogate objective.

    Args:
        ratios: Probability ratios pi_new / pi_old, shape (G,)
        advantages: Group-relative advantages, shape (G,)
        epsilon: Clipping parameter (default 0.2)

    Returns:
        objective: Per-completion clipped objective, shape (G,)
    """
    # Unclipped objective
    unclipped = ratios * advantages

    # Clipped objective
    clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
    clipped = clipped_ratios * advantages

    # Take the minimum (pessimistic bound)
    objective = torch.min(unclipped, clipped)

    return objective

# Demonstrate with a concrete example
ratio = torch.tensor(1.5)  # Policy changed significantly
advantage = torch.tensor(1.0)  # This was a good completion
epsilon = 0.2

obj = compute_clipped_objective(ratio, advantage, epsilon)
print(f"Ratio: {ratio.item():.1f}, Advantage: {advantage.item():.1f}")
print(f"Unclipped: {ratio.item() * advantage.item():.1f}")
print(f"Clipped ratio: {torch.clamp(ratio, 1-epsilon, 1+epsilon).item():.1f}")
print(f"Clipped objective: {obj.item():.1f}")
print(f"\nClipping prevented the objective from being {ratio.item() * advantage.item():.1f}")
print(f"and capped it at {obj.item():.1f} — trust region enforced!")

### Visualization Checkpoint: The Clipping Function
```

In [None]:
# Visualize clipping for both positive and negative advantages
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ratios_range = torch.linspace(0.3, 2.5, 200)
epsilon = 0.2

# Positive advantage
A_pos = 1.0
unclipped_pos = ratios_range * A_pos
clipped_pos = compute_clipped_objective(ratios_range, torch.full_like(ratios_range, A_pos), epsilon)

ax1.plot(ratios_range, unclipped_pos, '--', color='#2196F3', linewidth=2, label='Unclipped')
ax1.plot(ratios_range, clipped_pos, '-', color='#FF9800', linewidth=2.5, label='Clipped')
ax1.axvline(x=1-epsilon, color='gray', linewidth=0.5, linestyle=':')
ax1.axvline(x=1+epsilon, color='gray', linewidth=0.5, linestyle=':')
ax1.axvline(x=1, color='black', linewidth=0.5, linestyle='-')
ax1.fill_betweenx([0, 3], 1-epsilon, 1+epsilon, alpha=0.1, color='green')
ax1.set_title('Positive Advantage (A > 0)', fontsize=13)
ax1.set_xlabel('Probability Ratio', fontsize=11)
ax1.set_ylabel('Objective', fontsize=11)
ax1.legend(fontsize=11)
ax1.set_ylim(-0.5, 3)
ax1.grid(True, alpha=0.3)

# Negative advantage
A_neg = -1.0
unclipped_neg = ratios_range * A_neg
clipped_neg = compute_clipped_objective(ratios_range, torch.full_like(ratios_range, A_neg), epsilon)

ax2.plot(ratios_range, unclipped_neg, '--', color='#2196F3', linewidth=2, label='Unclipped')
ax2.plot(ratios_range, clipped_neg, '-', color='#FF9800', linewidth=2.5, label='Clipped')
ax2.axvline(x=1-epsilon, color='gray', linewidth=0.5, linestyle=':')
ax2.axvline(x=1+epsilon, color='gray', linewidth=0.5, linestyle=':')
ax2.axvline(x=1, color='black', linewidth=0.5, linestyle='-')
ax2.fill_betweenx([-3, 0], 1-epsilon, 1+epsilon, alpha=0.1, color='green')
ax2.set_title('Negative Advantage (A < 0)', fontsize=13)
ax2.set_xlabel('Probability Ratio', fontsize=11)
ax2.set_ylabel('Objective', fontsize=11)
ax2.legend(fontsize=11)
ax2.set_ylim(-3, 0.5)
ax2.grid(True, alpha=0.3)

plt.suptitle('PPO/GRPO Clipping Mechanism', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
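
The flat segments of the clipped curves in the plot above are where the objective stops responding to the ratio, so the gradient with respect to the ratio is zero there. A small sketch that checks this numerically with a central finite difference; `clipped_objective` and `slope` are scalar helpers written for this illustration, separate from the tensor version defined earlier.

```python
def clipped_objective(ratio, advantage, epsilon=0.2):
    """Scalar version of min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = min(max(ratio, 1 - epsilon), 1 + epsilon)
    return min(ratio * advantage, clipped_ratio * advantage)

def slope(ratio, advantage, epsilon=0.2, h=1e-6):
    """Central finite-difference estimate of d(objective)/d(ratio)."""
    upper = clipped_objective(ratio + h, advantage, epsilon)
    lower = clipped_objective(ratio - h, advantage, epsilon)
    return (upper - lower) / (2 * h)

# (ratio, advantage) pairs on both sides of the clipping interval [0.8, 1.2].
cases = [(0.5, 1.0), (1.0, 1.0), (1.5, 1.0), (0.5, -1.0), (1.0, -1.0), (1.5, -1.0)]

for ratio, advantage in cases:
    s = slope(ratio, advantage)
    status = "clipped (no gradient)" if abs(s) < 1e-6 else "active"
    print(f"ratio={ratio:.1f}, A={advantage:+.1f}: slope={s:+.2f} -> {status}")
```

Only two of the six cases are clipped: a high ratio with a positive advantage, and a low ratio with a negative advantage. In the other direction the objective keeps its full slope, so the policy can still be pulled back toward the old one.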

### 4.3 KL Divergence Penalty
```

In [None]:
def compute_kl_penalty(log_probs_current, log_probs_reference):
    """
    Estimate the KL divergence between the current and reference policy
    from the log-probs of the sampled tokens.

    Args:
        log_probs_current: Log probs from current policy, shape (seq_len,)
        log_probs_reference: Log probs from reference policy, shape (seq_len,)

    Returns:
        kl_div: Scalar KL estimate, summed over the sequence
    """
    # Exact KL(pi_theta || pi_ref) = sum(pi_theta * (log pi_theta - log pi_ref))
    # over the whole vocabulary. With only the sampled tokens available, we use
    # the Monte Carlo estimate sum(log_p_current - log_p_ref). Unlike the exact
    # KL, this estimate can be negative for a particular sample.
    kl_div = (log_probs_current - log_probs_reference).sum()
    return kl_div

# Example
log_probs_current = torch.tensor([-0.5, -0.8, -0.3, -1.0])
log_probs_reference = torch.tensor([-0.6, -0.9, -0.4, -1.1])

kl = compute_kl_penalty(log_probs_current, log_probs_reference)
print(f"Current log probs:   {log_probs_current.tolist()}")
print(f"Reference log probs: {log_probs_reference.tolist()}")
print(f"KL divergence: {kl.item():.3f}")
print(f"\nSmall KL -> policy has not changed much (good)")
print(f"Large KL -> policy has drifted far (penalized)")

## 5. Your Turn -- TODO Exercises

### TODO 1: Complete the GRPO Loss Function

Combine all components into a single GRPO loss function.
```

In [None]:
def grpo_loss(log_probs_new, log_probs_old, log_probs_ref,
              rewards, epsilon=0.2, beta=0.1):
    """
    Complete GRPO loss with group-relative advantages, clipping, and KL penalty.

    Args:
        log_probs_new: Log probs under current policy, shape (G,)
        log_probs_old: Log probs under old policy (before update), shape (G,)
        log_probs_ref: Log probs under reference policy (frozen), shape (G,)
        rewards: Rewards for each completion, shape (G,)
        epsilon: Clipping parameter
        beta: KL penalty weight

    Returns:
        loss: Scalar loss to minimize
    """
    # ============ TODO ============
    # Step 1: Compute group-relative advantages from the raw rewards
    #         (they are recomputed in Step 5 once the KL penalty is applied)
    advantages = ???  # YOUR CODE HERE

    # Step 2: Compute probability ratios
    ratios = ???  # YOUR CODE HERE: exp(log_new - log_old)

    # Step 3: Compute KL penalty for each completion
    kl_penalties = ???  # YOUR CODE HERE: log_new - log_ref

    # Step 4: Adjust rewards with KL penalty
    adjusted_rewards = rewards - beta * kl_penalties.detach()

    # Step 5: Recompute advantages with adjusted rewards
    advantages = compute_group_advantages(adjusted_rewards)

    # Step 6: Compute clipped surrogate objective
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages
    surrogate = torch.min(unclipped, clipped)

    # Step 7: The loss is the negative mean (we want to maximize the objective)
    loss = -surrogate.mean()
    # ==============================

    return loss

In [None]:
# Verification: Test your GRPO loss function
log_probs_new = torch.tensor([-1.0, -2.0, -0.5, -3.0])
log_probs_old = torch.tensor([-1.1, -1.9, -0.6, -2.8])
log_probs_ref = torch.tensor([-1.2, -2.1, -0.7, -3.1])
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])

try:
    loss = grpo_loss(log_probs_new, log_probs_old, log_probs_ref, rewards)
    print(f"GRPO Loss: {loss.item():.4f}")
    assert loss.dim() == 0, "Loss should be a scalar (0-dim tensor)"
    assert torch.isfinite(loss), "Loss should be finite (no NaN or inf)"
    print("Your GRPO loss function runs and returns a finite scalar.")
except Exception as e:
    print(f"Error: {e}")
    print("Check your implementation above.")

### TODO 2: Visualize How KL Penalty Affects Total Reward

Vary beta from 0 to 1 and plot how the total reward changes for a fixed outcome reward and KL divergence.
```

In [None]:
# ============ TODO ============
# Plot total_reward = R_outcome - beta * D_KL
# for beta in [0, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0]
# and D_KL values of [1.0, 5.0, 10.0]
#
# betas = ???
# kl_values = ???
# R_outcome = 1.0
#
# plt.figure(figsize=(10, 5))
# for kl in kl_values:
#     total_rewards = ???
#     plt.plot(betas, total_rewards, label=f'KL = {kl}')
# plt.xlabel('Beta')
# plt.ylabel('Total Reward')
# plt.legend()
# plt.title('Effect of KL Penalty on Total Reward')
# plt.grid(True, alpha=0.3)
# plt.show()
# ==============================

## 6. Putting It All Together

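Let us run 100 simulated GRPO steps to see all the pieces working together. No model is trained here: the rewards and log-probabilities are generated synthetically, and the rewards improve by construction.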
```

In [None]:
# Simulated GRPO loop (no parameters are updated; rewards improve by construction)
G = 8  # Group size
NUM_STEPS = 100

# Simulate a model that gradually improves
reward_history = []
loss_history = []

for step in range(NUM_STEPS):
    # Simulate improving rewards over time
    base_prob = min(0.3 + step * 0.005, 0.8)
    rewards = torch.bernoulli(torch.full((G,), base_prob))

    # Simulated log probabilities (magnitudes shrink as the step count grows)
    log_probs_new = -torch.rand(G) * (2.0 - step * 0.01)
    log_probs_old = log_probs_new - 0.1 * torch.randn(G)
    log_probs_ref = -torch.rand(G) * 2.0  # unused below: this simulation omits the KL term

    # Compute GRPO components
    advantages = compute_group_advantages(rewards)
    ratios = torch.exp(log_probs_new - log_probs_old)
    obj = compute_clipped_objective(ratios, advantages)

    loss = -obj.mean()

    reward_history.append(rewards.mean().item())
    loss_history.append(loss.item())

print(f"Initial avg reward: {sum(reward_history[:5])/5:.3f}")
print(f"Final avg reward:   {sum(reward_history[-5:])/5:.3f}")

### Visualization Checkpoint: Simulated GRPO Training
```

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Reward curve
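window = 10
# Trailing moving average over at most `window` points (slice length matches the divisor)
smoothed_rewards = [sum(reward_history[max(0, i - window + 1):i + 1]) / min(i + 1, window)
                    for i in range(len(reward_history))]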
ax1.plot(smoothed_rewards, color='#4CAF50', linewidth=2)
ax1.set_xlabel('Training Step', fontsize=11)
ax1.set_ylabel('Average Reward', fontsize=11)
ax1.set_title('Reward Over Training', fontsize=13)
ax1.grid(True, alpha=0.3)

# Loss curve
smoothed_loss = [sum(loss_history[max(0, i - window + 1):i + 1]) / min(i + 1, window)
                 for i in range(len(loss_history))]
ax2.plot(smoothed_loss, color='#2196F3', linewidth=2)
ax2.set_xlabel('Training Step', fontsize=11)
ax2.set_ylabel('GRPO Loss', fontsize=11)
ax2.set_title('Loss Over Training', fontsize=13)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Training and Results

In a real setting, the GRPO training loop would:
1. Sample prompts from the GSM8K dataset
2. Generate G completions from the current model
3. Score each completion with the verifiable reward
4. Compute advantages, ratios, and the clipped loss
5. Backpropagate and update the model

The key insight: because GRPO uses a group-relative baseline instead of a learned value function, there is no critic network to store or train, which saves memory and removes a source of complexity. The group itself provides the baseline.
```
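
The five steps listed in Section 7 can be sketched end to end on a toy problem. This is not a language-model training loop: the "policy" is a softmax over four candidate answers to a single fixed prompt, the reward is 1 for one designated answer, and all names and hyperparameters (`K`, `TARGET`, `G`, `LR`, `STEPS`) are hypothetical. Because each group is used for exactly one update, the probability ratio equals 1 at the point of the update, so the gradient of the clipped surrogate reduces to $\hat{A}_i \nabla \log \pi_\theta(y_i)$. The KL term is omitted.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def group_advantages(rewards, eps=1e-8):
    """Z-score the rewards within one group; all-equal groups give zeros."""
    mean = sum(rewards) / len(rewards)
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / len(rewards))
    if std < eps:
        return [0.0] * len(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

rng = random.Random(0)
K, TARGET, G, LR, STEPS = 4, 2, 8, 0.5, 200  # hypothetical toy settings
logits = [0.0] * K                            # uniform initial policy
initial_p = softmax(logits)[TARGET]

for _ in range(STEPS):
    # 1. A single fixed "prompt": the policy is just a distribution over K answers.
    probs = softmax(logits)
    # 2. Sample a group of G "completions" (here: single actions).
    actions = rng.choices(range(K), weights=probs, k=G)
    # 3. Verifiable reward: 1 if the action is the designated correct answer.
    rewards = [1.0 if a == TARGET else 0.0 for a in actions]
    # 4. Group-relative advantages.
    advs = group_advantages(rewards)
    # 5. Gradient ascent on mean(A_i * log pi(a_i)),
    #    using d log pi(a) / d logit_k = 1[k == a] - p_k.
    grad = [0.0] * K
    for a, adv in zip(actions, advs):
        for k in range(K):
            grad[k] += adv * ((1.0 if k == a else 0.0) - probs[k]) / G
    logits = [l + LR * g for l, g in zip(logits, grad)]

final_p = softmax(logits)[TARGET]
print(f"P(correct answer): {initial_p:.2f} -> {final_p:.2f}")
```

Groups in which every sample earns the same reward produce zero advantages and therefore no update, which is the toy-scale version of the "all completions get the same reward" case raised in Section 2.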

## 8. Final Output

We have built GRPO from scratch! Here is a summary of the components:

| Component | Purpose |
|-----------|---------|
| Group-relative advantages | Replace the value network with simple statistics |
| Probability ratio | Measure how much the policy changed |
| Clipping | Prevent catastrophic policy updates |
| KL penalty | Prevent drift from the reference model |
```

## 9. Reflection and Next Steps

### Think About This
1. What happens if the group size G is very small (e.g., 2)? Very large (e.g., 64)?
2. How does the choice of epsilon affect training stability vs speed?
3. Why is GRPO particularly well-suited for reasoning tasks where rewards are binary?

### What Comes Next
In Notebook 3, we will put everything together — combining SFT from Notebook 1 with GRPO from this notebook to train a full reasoning model on the GSM8K math dataset.

### Key Takeaway
GRPO's genius is its simplicity: by using group statistics instead of a learned critic, it eliminates an entire network while maintaining the benefits of policy optimization with trust regions.
```