In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability, report library versions, and seed the random number generators

import torch
import sys

# Pick the compute device: probe CUDA once and reuse the answer
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')

if not has_gpu:
    # CPU fallback: everything still runs, only slower
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU available: {gpu_name}")
    print(f"   Memory: {gpu_memory_gb:.1f} GB")

# Report interpreter and framework versions for reproducibility
python_version = sys.version.split()[0]
print(f"\n📦 Python {python_version}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42

# Seed Python, NumPy and PyTorch (in that order) with the same value
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)

# Seed every CUDA device as well, but only when one is available
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Implementing the GRPO Loss Function -- Vizuara

## 1. Why Does This Matter?

In the previous notebook, we learned how GRPO computes advantages by normalizing rewards within a group. But advantages alone do not train a model -- we need a **loss function** that uses those advantages to update the policy.

The GRPO loss has three key components:
1. **Clipped surrogate objective** (same idea as PPO -- prevent overly large updates)
2. **Group-relative advantages** (what we built in Notebook 1)
3. **KL divergence penalty** (prevent the policy from drifting too far from a reference)

In this notebook, you will build each component from scratch, verify them with numerical examples, and combine them into the complete GRPO training objective.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

print("Ready to build the GRPO loss!")

## 2. Building Intuition

### The Three Guards of GRPO

Think of training a language model like steering a ship:

1. **The Compass (Advantages):** Points toward better responses. GRPO's compass is the group-relative advantage -- it tells us which responses in the batch were better than average.

2. **The Anchor (Clipping):** Prevents the ship from turning too sharply. Even if a response looks amazing, we clip the update to prevent destabilizing the model.

3. **The Tether (KL Penalty):** Keeps us from drifting too far from port. The KL divergence penalty ensures the policy stays close to a reference model, preventing reward hacking.

Let us build each guard one at a time.

In [None]:
# Visualize the three components conceptually: one panel per "guard"
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_adv, ax_clip, ax_kl = axes

# Panel 1 -- advantages: green where a response is reinforced, red where penalized
x = np.linspace(-3, 3, 100)
is_positive = x > 0
is_negative = x < 0
ax_adv.fill_between(x[is_positive], 0, x[is_positive], alpha=0.3, color='green', label='Reinforce')
ax_adv.fill_between(x[is_negative], x[is_negative], 0, alpha=0.3, color='red', label='Penalize')
ax_adv.axhline(y=0, color='black', linewidth=0.5)
ax_adv.set_title("1. Group-Relative Advantages", fontweight='bold')
ax_adv.set_xlabel("Advantage Value")
ax_adv.legend()

# Panel 2 -- clipping: the ratio is held inside [1 - epsilon, 1 + epsilon]
clip_low, clip_high = 0.8, 1.2
ratio = np.linspace(0.5, 2.0, 100)
clipped = np.clip(ratio, clip_low, clip_high)
ax_clip.plot(ratio, ratio, 'b--', alpha=0.5, label='Unclipped ratio')
ax_clip.plot(ratio, clipped, 'r-', linewidth=2, label='Clipped ratio')
for bound in (clip_low, clip_high):
    ax_clip.axvline(x=bound, color='gray', linestyle=':', alpha=0.5)
ax_clip.set_title("2. Ratio Clipping (epsilon=0.2)", fontweight='bold')
ax_clip.set_xlabel("Policy Ratio")
ax_clip.legend()

# Panel 3 -- KL penalty: r - log(r) - 1 is zero at r = 1 and positive elsewhere
r = np.linspace(0.1, 5, 100)
kl = r - np.log(r) - 1
ax_kl.plot(r, kl, 'purple', linewidth=2)
ax_kl.axvline(x=1.0, color='green', linestyle='--', alpha=0.5, label='No drift (KL=0)')
ax_kl.set_title("3. KL Divergence Penalty", fontweight='bold')
ax_kl.set_xlabel("pi_ref / pi_theta")
ax_kl.set_ylabel("KL value")
ax_kl.legend()

plt.tight_layout()
plt.savefig("grpo_three_components.png", dpi=150, bbox_inches='tight')
plt.show()

## 3. The Mathematics

### Component 1: The Importance Sampling Ratio

For each token at position $t$ in response $i$:

$$r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}$$

This ratio tells us how much more (or less) likely the current policy is to produce this token compared to the old policy.

Let us work through a numerical example. Suppose:
- Current policy probability: $\pi_\theta = 0.3$
- Old policy probability: $\pi_{\theta_\text{old}} = 0.2$

$$r_{i,t} = \frac{0.3}{0.2} = 1.5$$

The new policy is 50% more likely to produce this token.

In log space (which we use in practice for numerical stability):

$$\log r_{i,t} = \log \pi_\theta - \log \pi_{\theta_\text{old}} = \log(0.3) - \log(0.2) = -1.204 - (-1.609) = 0.405$$

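$$r_{i,t} = e^{0.405} \approx 1.5$$

This matches the ratio we computed directly from the probabilities, so the log-space form gives the same answer while avoiding a division of very small numbers.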

### Component 2: The Clipped Surrogate

$$L_{\text{clip}} = \min\left(r_{i,t} \cdot \hat{A}_i, \; \text{clip}(r_{i,t}, 1-\epsilon, 1+\epsilon) \cdot \hat{A}_i\right)$$

With $\epsilon = 0.2$, the ratio is clipped to $[0.8, 1.2]$.

Numerical example with $r_{i,t} = 1.5$, $\hat{A}_i = 2.0$:
- Unclipped: $1.5 \times 2.0 = 3.0$
- Clipped ratio: $\text{clip}(1.5, 0.8, 1.2) = 1.2$
- Clipped: $1.2 \times 2.0 = 2.4$
- $\min(3.0, 2.4) = 2.4$ (clipping reduces the update)

### Component 3: The KL Penalty

$$D_{\text{KL}} = \frac{\pi_{\text{ref}}}{\pi_\theta} - \log\frac{\pi_{\text{ref}}}{\pi_\theta} - 1$$

Verification when $\pi_\theta = \pi_{\text{ref}}$:
$$D_{\text{KL}} = 1 - \log(1) - 1 = 0$$

When $\frac{\pi_{\text{ref}}}{\pi_\theta} = 2$ (policy has drifted):
$$D_{\text{KL}} = 2 - \log(2) - 1 = 2 - 0.693 - 1 = 0.307$$

The penalty grows as the policy drifts further from the reference.

In [None]:
# Verify the KL divergence formula numerically
def kl_divergence_approx(ref_prob, theta_prob):
    """KL approximation used in GRPO: r - log(r) - 1 with r = pi_ref / pi_theta.

    Args:
        ref_prob: Probability under the reference policy (tensor, NumPy value
            or plain Python number).
        theta_prob: Probability under the current policy (same kinds accepted).

    Returns:
        A tensor holding the penalty; it is 0 when the two probabilities are
        equal and positive otherwise.
    """
    # torch.log only accepts tensors, so wrap the ratio; an existing tensor
    # is passed through unchanged by as_tensor.
    ratio = torch.as_tensor(ref_prob / theta_prob)
    return ratio - torch.log(ratio) - 1.0

# Case 1: No drift
kl_no_drift = kl_divergence_approx(torch.tensor(0.3), torch.tensor(0.3))
print(f"KL when pi_ref = pi_theta: {kl_no_drift.item():.6f} (should be ~0)")

# Case 2: Small drift
kl_small = kl_divergence_approx(torch.tensor(0.3), torch.tensor(0.25))
print(f"KL with small drift:       {kl_small.item():.6f}")

# Case 3: Large drift
kl_large = kl_divergence_approx(torch.tensor(0.3), torch.tensor(0.1))
print(f"KL with large drift:       {kl_large.item():.6f}")

# Check the ordering before announcing it, so the message below cannot lie
assert kl_no_drift.item() < kl_small.item() < kl_large.item(), \
    "KL penalty should grow with the amount of drift"
print("\nAs expected, KL increases as the policy drifts further from reference.")

## 4. Let's Build It -- Component by Component

### Building the Complete GRPO Loss

In [None]:
def compute_grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Compute group-relative advantages (from Notebook 1).

    Each reward is standardized against its own group:
    ``A_i = (r_i - mean(r)) / std(r)``.

    Args:
        rewards: Rewards of the responses sampled for one prompt.

    Returns:
        A tensor shaped like ``rewards``. It is all zeros when the group
        carries no learning signal: fewer than two responses, or (near-)
        identical rewards.
    """
    # The sample std of fewer than two values is NaN, and NaN would slip past
    # the `< 1e-8` guard below (every comparison with NaN is False).
    if rewards.numel() < 2:
        return torch.zeros_like(rewards)

    mean_r = rewards.mean()
    std_r = rewards.std()
    # Identical rewards: nothing to prefer, and dividing would blow up
    if std_r < 1e-8:
        return torch.zeros_like(rewards)
    return (rewards - mean_r) / std_r


def compute_grpo_loss(
    log_probs: torch.Tensor,       # (G, T) log probs under current policy
    old_log_probs: torch.Tensor,   # (G, T) log probs under old policy
    ref_log_probs: torch.Tensor,   # (G, T) log probs under reference policy
    advantages: torch.Tensor,       # (G,) per-response advantages
    mask: torch.Tensor,             # (G, T) attention mask (1 for valid, 0 for padding)
    epsilon: float = 0.2,
    beta: float = 0.04,
) -> dict:
    """
    Compute the full GRPO loss with all three components.

    The per-token objective is
    ``min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A) - beta * KL``.
    It is averaged over the valid tokens of each response, then over the
    group, and negated so that it can be minimized.

    Args:
        log_probs: Token log-probabilities under the policy being trained.
        old_log_probs: Token log-probabilities under the sampling policy.
        ref_log_probs: Token log-probabilities under the reference policy.
        advantages: One advantage per response, broadcast over its tokens.
        mask: 1 for real tokens, 0 for padding.
        epsilon: Half-width of the ratio clipping interval.
        beta: Weight of the KL penalty.

    Returns:
        A dict with the total ``loss`` plus ``policy_loss``, ``kl_div`` and
        ``mean_ratio`` (token-level masked means) for logging. A response
        whose mask is entirely zero contributes 0 to the group average
        instead of turning the loss into NaN.
    """
    # --- Component 1: Importance sampling ratio ---
    ratio = torch.exp(log_probs - old_log_probs)  # (G, T)

    # --- Component 2: Clipped surrogate ---
    adv = advantages.unsqueeze(1)  # (G, 1) -> broadcasts to (G, T)
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv
    policy_loss = torch.min(surr1, surr2)  # (G, T)

    # --- Component 3: KL divergence penalty ---
    # Work with log(pi_ref / pi_theta) directly: log(exp(x)) is just x, and
    # skipping the round trip avoids inf - inf = NaN when exp overflows.
    log_kl_ratio = ref_log_probs - log_probs
    kl_div = torch.exp(log_kl_ratio) - log_kl_ratio - 1.0  # (G, T)

    # --- Combine ---
    per_token_objective = policy_loss - beta * kl_div  # (G, T)

    # Float mask so bool/int masks behave the same as 0./1. masks
    mask = mask.to(per_token_objective.dtype)
    # Clamp the token counts: an all-padding row would otherwise give 0 / 0
    tokens_per_response = mask.sum(dim=1).clamp(min=1.0)
    total_tokens = mask.sum().clamp(min=1.0)

    # Average over valid tokens, then over group
    per_response = (per_token_objective * mask).sum(dim=1) / tokens_per_response
    loss = -per_response.mean()  # Negate for minimization

    return {
        'loss': loss,
        'policy_loss': -(policy_loss * mask).sum() / total_tokens,
        'kl_div': (kl_div * mask).sum() / total_tokens,
        'mean_ratio': (ratio * mask).sum() / total_tokens,
    }

print("GRPO loss function defined!")

### Testing with Synthetic Data

In [None]:
# Create synthetic data to test the loss function
# NOTE: the tensors below are drawn from torch's global RNG right after it is
# re-seeded, so the printed numbers depend on the order of the randn calls.
G = 4   # Group size
T = 10  # Sequence length

torch.manual_seed(42)

# Simulate log probabilities
# Current policy, old policy, and reference policy
base_log_probs = torch.randn(G, T) * 0.5 - 2.0  # Base log probs, centred at -2 with std 0.5
old_log_probs = base_log_probs.detach().clone()   # Old policy = independent copy of the base
ref_log_probs = base_log_probs + torch.randn(G, T) * 0.1  # Ref = base plus small noise

# Simulate a small policy update: current = old plus small noise
log_probs = old_log_probs + torch.randn(G, T) * 0.05
log_probs.requires_grad_(True)  # Track gradients so the loss can be backpropagated into log_probs

# Rewards (one per response) and their group-relative advantages
rewards = torch.tensor([0.8, 0.3, 0.9, 0.1])
advantages = compute_grpo_advantages(rewards)

# Mask (all tokens valid in this example)
mask = torch.ones(G, T)

# Compute loss with the default epsilon and beta
result = compute_grpo_loss(log_probs, old_log_probs, ref_log_probs, advantages, mask)

# Report the total loss and each logged component
print("=== GRPO Loss Components ===")
print(f"Total loss:    {result['loss'].item():.4f}")
print(f"Policy loss:   {result['policy_loss'].item():.4f}")
print(f"KL divergence: {result['kl_div'].item():.4f}")
print(f"Mean ratio:    {result['mean_ratio'].item():.4f}")
print(f"\nAdvantages:    {advantages.numpy().round(3)}")
print(f"Rewards:       {rewards.numpy()}")

In [None]:
# Verify gradients flow correctly
# Backpropagate the scalar loss into log_probs (the only tensor tracking gradients here)
result['loss'].backward()

# Check the gradient before reporting success: it must exist, be finite, and be non-zero
grad = log_probs.grad
assert grad is not None, "backward() did not populate log_probs.grad"
assert torch.isfinite(grad).all(), "gradient contains NaN or inf"
grad_norm = grad.norm().item()
assert grad_norm > 0, "gradient is identically zero"

print("Gradient norm:", grad_norm)
print("Gradients are flowing correctly!")

## 5. Your Turn

### TODO 1: Visualize the Effect of Epsilon on Clipping

In [None]:
def visualize_clipping_effect(advantages=None, epsilon_values=(0.1, 0.2, 0.3, 0.5)):
    """
    Plot the clipped surrogate objective against the policy ratio for several
    epsilon values, once for a positive and once for a negative advantage.

    The surrogate is min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A):

    For a POSITIVE advantage (A = 1.0):
    - It follows ratio * A up to 1 + eps and is flat above it, so pushing the
      ratio higher earns nothing extra.

    For a NEGATIVE advantage (A = -1.0):
    - It is flat below 1 - eps and follows ratio * A above it: min() keeps the
      more pessimistic value, so clipping only takes effect on the low side.

    Args:
        advantages: Unused; kept so existing calls keep working. Both panels
            use the fixed advantages +1.0 and -1.0.
        epsilon_values: Clipping half-widths to compare, one curve each.

    Side effects:
        Saves the figure to "clipping_effect.png" and shows it.
    """
    ratios = torch.linspace(0.5, 2.0, 200)
    A_pos = 1.0   # Positive advantage
    A_neg = -1.0  # Negative advantage

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for eps in epsilon_values:
        clipped_ratios = torch.clamp(ratios, 1 - eps, 1 + eps)

        # Same formula for both signs; only the advantage differs
        surr_pos = torch.min(ratios * A_pos, clipped_ratios * A_pos)
        surr_neg = torch.min(ratios * A_neg, clipped_ratios * A_neg)

        axes[0].plot(ratios.numpy(), surr_pos.numpy(), label=f'eps={eps}', linewidth=2)
        axes[1].plot(ratios.numpy(), surr_neg.numpy(), label=f'eps={eps}', linewidth=2)

    # Both panels share everything except the title
    panel_titles = ("Positive Advantage (A=1.0)", "Negative Advantage (A=-1.0)")
    for ax, title in zip(axes, panel_titles):
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.set_xlabel("Policy Ratio")
        ax.set_ylabel("Surrogate Objective")
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle("Effect of Epsilon on Clipped Surrogate", fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig("clipping_effect.png", dpi=150, bbox_inches='tight')
    plt.show()

visualize_clipping_effect(None)

### TODO 2: Implement the KL-Annealing Schedule

In [None]:
def kl_annealing_schedule(step: int, total_steps: int, beta_init: float = 0.001, beta_final: float = 0.1):
    """
    Linearly anneal the KL coefficient from beta_init to beta_final.

    Early in training: low beta = allow exploration
    Late in training: high beta = stay close to reference

    Args:
        step: Current training step
        total_steps: Total number of training steps
        beta_init: Initial beta value
        beta_final: Final beta value
    Returns:
        beta: Current beta value. It equals beta_init for step <= 0 and
        beta_final for step >= total_steps. A non-positive total_steps is
        treated as a schedule that has already finished, so beta_final is
        returned.
    """
    # A zero-length schedule has nothing to interpolate (and would divide by zero)
    if total_steps <= 0:
        return beta_final

    # Clamp on both sides so out-of-range steps never extrapolate
    progress = min(max(step / total_steps, 0.0), 1.0)
    return beta_init + (beta_final - beta_init) * progress

# Verify the schedule: evaluate beta at every step of a 1000-step run
steps = list(range(1000))
betas = [kl_annealing_schedule(current_step, 1000) for current_step in steps]

# Plot beta against the training step
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(steps, betas, linewidth=2, color='purple')
ax.set_xlabel("Training Step", fontsize=12)
ax.set_ylabel("Beta (KL coefficient)", fontsize=12)
ax.set_title("KL Annealing Schedule", fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.savefig("kl_annealing.png", dpi=150, bbox_inches='tight')
plt.show()

# Spot-check the start, the midpoint, and the last evaluated step
for checkpoint in (0, 500, 999):
    label = f"Beta at step {checkpoint}:"
    print(f"{label:<17} {betas[checkpoint]:.4f}")

## 6. Putting It All Together

In [None]:
# Demonstrate the full GRPO loss on varying scenarios
scenarios = {
    "All good responses": torch.tensor([0.8, 0.9, 0.85, 0.95]),
    "Mixed responses": torch.tensor([0.1, 0.9, 0.3, 0.7]),
    "All bad responses": torch.tensor([0.1, 0.15, 0.05, 0.2]),
    "One outlier": torch.tensor([0.2, 0.2, 0.2, 0.95]),
}

print("=" * 70)
print(f"{'Scenario':<25} {'Advantages':<35} {'Loss':<10}")
print("=" * 70)

for name, rewards in scenarios.items():
    advantages = compute_grpo_advantages(rewards)

    # Create synthetic log probs (the order of the randn calls fixes the numbers)
    G, T = len(rewards), 10
    log_probs = torch.randn(G, T) * 0.5 - 2.0
    log_probs.requires_grad_(True)
    # Old policy == current policy, so every importance ratio is exactly 1.
    # With zero-mean advantages the policy term then averages out, leaving
    # the reported loss dominated by beta * KL.
    old_log_probs = log_probs.detach()
    ref_log_probs = log_probs.detach() + torch.randn(G, T) * 0.1
    mask = torch.ones(G, T)

    result = compute_grpo_loss(log_probs, old_log_probs, ref_log_probs, advantages, mask)
    adv_str = ", ".join([f"{a:.2f}" for a in advantages.tolist()])
    # Pad the bracketed advantages to the header's column width so "Loss" lines up
    adv_cell = f"[{adv_str}]"
    print(f"{name:<25} {adv_cell:<35} {result['loss'].item():<10.4f}")

## 7. Training and Results

In [None]:
# Simulate how the GRPO loss guides training
# We will optimize a simple parameter to demonstrate:
# theta is the mean of a unit-variance Gaussian "policy" whose samples are
# rewarded for landing close to 3.

torch.manual_seed(42)

# Goal: learn theta such that reward is maximized
theta = torch.tensor([0.0], requires_grad=True)
optimizer = torch.optim.Adam([theta], lr=0.01)

loss_history = []
theta_history = []

for step in range(300):
    # Simulate G=8 responses: samples drawn around the current theta.
    # The samples themselves are data, so they are detached from the graph.
    noise = torch.randn(8)
    samples = theta.detach() + noise
    rewards = -torch.abs(samples - 3.0)  # Optimal at theta=3

    advantages = compute_grpo_advantages(rewards)

    # Policy-gradient surrogate: advantage-weighted log-density of each sample
    # under N(theta, 1), up to an additive constant. theta has to appear in
    # the loss graph or backward() has nothing to differentiate; the gradient
    # of this loss w.r.t. theta is -(advantages * noise).mean().
    sample_log_probs = -0.5 * (samples - theta) ** 2
    loss = -(advantages * sample_log_probs).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())
    theta_history.append(theta.item())

# Two panels: parameter trajectory on the left, loss curve on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
theta_ax, loss_ax = axes

# Left: theta over time against the reward-maximizing value
theta_ax.plot(theta_history, linewidth=2, color='blue')
theta_ax.axhline(y=3.0, color='green', linestyle='--', label='Optimal')
theta_ax.set_title("Parameter Convergence", fontweight='bold')
theta_ax.set_xlabel("Step")
theta_ax.set_ylabel("Theta")
theta_ax.legend()
theta_ax.grid(True, alpha=0.3)

# Right: raw loss (faint) with a moving average on top
loss_ax.plot(loss_history, linewidth=1, color='red', alpha=0.5)
window = 20
averaging_kernel = np.ones(window) / window
smoothed = np.convolve(loss_history, averaging_kernel, mode='valid')
# 'valid' mode drops the first window-1 points, so shift the x-axis to match
smoothed_steps = range(window - 1, len(loss_history))
loss_ax.plot(smoothed_steps, smoothed, linewidth=2, color='red')
loss_ax.set_title("Training Loss", fontweight='bold')
loss_ax.set_xlabel("Step")
loss_ax.set_ylabel("Loss")
loss_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("grpo_loss_training.png", dpi=150, bbox_inches='tight')
plt.show()

final_theta = theta_history[-1]
print(f"Final theta: {final_theta:.3f} (target: 3.000)")

## 8. Final Output

In [None]:
# Assemble the recap line by line and emit it with a single print
summary_rule = "=" * 60
summary_lines = [
    summary_rule,
    "GRPO Loss Function Summary",
    summary_rule,
    "",
    "The GRPO objective combines three components:",
    "",
    "1. CLIPPED SURROGATE: min(ratio * A, clip(ratio) * A)",
    "   - Prevents overly large policy updates",
    "   - Epsilon = 0.2 is standard",
    "",
    "2. GROUP-RELATIVE ADVANTAGES: A_i = (r_i - mean) / std",
    "   - No critic needed",
    "   - Computed per-response, not per-token",
    "",
    "3. KL PENALTY: beta * (pi_ref/pi_theta - log(pi_ref/pi_theta) - 1)",
    "   - Keeps policy close to reference",
    "   - beta = 0.04 is typical",
    "",
    "Full loss: -E[min(ratio*A, clip(ratio)*A) - beta*KL]",
    summary_rule,
]
print("\n".join(summary_lines))

## 9. Reflection and Next Steps

**Key takeaways from this notebook:**

1. The GRPO loss has three components: clipped surrogate, group-relative advantages, and KL penalty.
2. The importance sampling ratio compares current and old policy probabilities per-token.
3. Clipping prevents catastrophically large updates (same mechanism as PPO).
4. The KL penalty is added directly to the loss (unlike PPO-based RLHF, where it is typically folded into the per-token reward).
5. The advantage is per-response (not per-token), which is simpler but coarser.

**Reflection questions:**
- Why is the KL divergence computed per-token but the advantage is per-response? What are the implications?
- If you increase beta, what happens to the policy's behavior? What about training speed?
- How does the choice of epsilon interact with the group size G?

**Next notebook:** We will bring everything together and train a small language model to solve math problems using GRPO with verifiable rewards -- the same approach DeepSeek used for R1.