# GRPO (Group Relative Policy Optimization) Training Loop

This notebook demonstrates GRPO training on a simple arithmetic task. GRPO is a variant of PPO that replaces the learned value function with group-based advantage estimation.

**Key idea**: Generate multiple responses per prompt, compute advantages relative to the group mean, then optimize with PPO-style clipping.

## Setup and Dependencies

In [1]:
!pip install transformers



Basic config. The token limit is kept small because answers have at most two digits. This global limit is used during evaluation; `generate_group` has its own default of 3.

In [2]:
device="cuda"
max_new_tokens = 2

In [3]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM

## Model Setup

Load two copies of the same model:
- `model`: trainable policy we'll optimize
- `ref_model`: frozen reference policy for KL regularization

In [4]:
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"  # a small model that fits comfortably on T4

model = AutoModelForCausalLM.from_pretrained(model_id)
model.to(device)
model.train()

ref_model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model.to(device)
ref_model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)



## Task: Simple Arithmetic

Generate random arithmetic problems (addition/multiplication) with single digits. Simple but non-trivial for small models.

In [5]:
import random
def get_batch(size=10):
  batch = []
  for _ in range(size):
    a = random.randint(0, 9)
    b = random.randint(0, 9)
    op = random.choice(["+", "*"])
    if op == "+":
      target = a + b
    else:
      target = a * b
    batch.append((f"Solve {a}{op}{b}=", str(target)))
  return batch

## Group Generation

Core GRPO component: generate multiple responses per prompt. The mask tracks which tokens are newly generated (for computing log probabilities).

In [6]:
@torch.no_grad()
def generate_group(model, tokenizer, prompt, group_size=4, temperature=0.7,
                   max_new_tokens=3):
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
  prompt_len = inputs["input_ids"].shape[1]
  output = model.generate(**inputs, do_sample=True,
                          num_return_sequences=group_size,
                          temperature=temperature,
                          max_new_tokens=max_new_tokens)
  # Mark everything after the prompt as generated. Slicing by prompt length
  # stays correct even if every sequence stops before max_new_tokens.
  mask = torch.zeros(output.size(), device=device)
  mask[:, prompt_len:] = 1
  return output, mask

## Reward and Evaluation

Binary reward: 1.0 if the model's answer matches the target, 0.0 otherwise. Simple but effective for this task.

In [7]:
def extract_answer(response):
  if "=" not in response:
    return None
  return "".join(response.split("=")[1:]).strip()

def reward(response, target):
  return 1.0 if extract_answer(response) == target else 0.0
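
A quick check of how the binary reward behaves on a few decoded strings may help. The sketch below repeats the two functions from the cell above so that it runs on its own; the example strings are made up for illustration and are not model outputs. The match is exact, so any extra character generated after the answer yields zero reward.

```python
def extract_answer(response):
    # Everything after the first "=" is treated as the model's answer
    if "=" not in response:
        return None
    return "".join(response.split("=")[1:]).strip()

def reward(response, target):
    return 1.0 if extract_answer(response) == target else 0.0

# Hypothetical decoded responses paired with their targets
examples = [
    ("Solve 3+5=8", "8"),     # exact match
    ("Solve 3+5= 8\n", "8"),  # surrounding whitespace is stripped
    ("Solve 3+5=8.", "8"),    # trailing character breaks the exact match
    ("Solve 3+5", "8"),       # no "=" at all
]
scores = [reward(resp, target) for resp, target in examples]
print(scores)  # [1.0, 1.0, 0.0, 0.0]
```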

Evaluation on a random batch

In [8]:
@torch.no_grad()
def eval(model, batch):
  rewards = []
  responses = []
  input = tokenizer([b[0] for b in batch], return_tensors="pt").to(device)
  output = model.generate(**input, max_new_tokens=max_new_tokens, do_sample=False)  # greedy generation
  for i, o in enumerate(output):
    response = tokenizer.decode(o, skip_special_tokens=True)
    responses.append(response)
    rewards.append(reward(response, batch[i][1]))

  acc = sum(rewards) / len(rewards)
  print(f"Average accuracy = {acc}\n")
  
  # Show 10 random examples
  sample_indices = random.sample(range(len(batch)), min(10, len(batch)))
  for i in sample_indices:
    is_correct = rewards[i] == 1.0
    symbol = '✓' if is_correct else '✗'
    print(f"{symbol} {responses[i].strip()}. Target: {batch[i][1]}")
    
  
  return acc

## GRPO Loss Function

The GRPO loss from Shao et al. 2024 ("DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models"):

$$J_{GRPO}(\theta) = \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left\{ \min\left(\frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})} \hat{A}_{i,t}, \text{clip}\left(\frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})}, 1-\varepsilon, 1+\varepsilon\right) \hat{A}_{i,t}\right) - \beta D_{KL}(\pi_\theta || \pi_{ref}) \right\}$$

**Variable definitions:**
- $q$: input query/prompt (e.g., "Solve 3+5=")
- $o_i$: $i$-th output/response in the group (e.g., "8")
- $G$: group size (number of responses per prompt)
- $\pi_\theta(o_{i,t}|q, o_{i,<t})$: probability of token $t$ in response $i$ under current policy
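- $\pi_{\theta_{old}}$: old policy that generated the responses (held fixed during the gradient steps)
- $\pi_{ref}$: reference policy (frozen throughout training)
- $\hat{A}_{i,t}$: advantage of token $t$ in response $i$; with a single reward per response, every token of response $i$ shares the same group-normalized value
- $\varepsilon$: clipping range for the probability ratio
- $\beta$: weight of the KL penalty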
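
To make the clipping term concrete, here is a minimal sketch of the per-sample surrogate in plain Python. The function name `clipped_surrogate` and the example ratios are illustrative assumptions, not part of the notebook's code. It shows that clipping limits the gain when the ratio moves in the direction the advantage favors, while the unclipped value is kept when the ratio moves the other way.

```python
def clipped_surrogate(ratio, advantage, clip_eps=0.2):
    """Per-sample PPO-style term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped_ratio * advantage)

# (ratio, advantage) pairs chosen to hit all four cases
cases = [(1.5, 1.0), (0.5, 1.0), (1.5, -1.0), (0.5, -1.0)]
results = [clipped_surrogate(r, a) for r, a in cases]

for (r, a), value in zip(cases, results):
    print(f"ratio={r}, advantage={a:+.1f} -> surrogate={value:+.2f}")
```

With a positive advantage, a ratio of 1.5 is capped at 1.2, so there is no incentive to push the probability further up. With a negative advantage, a ratio of 0.5 is capped at 0.8 for the same reason in the opposite direction.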

**KL divergence estimator** (unbiased, Schulman 2020):
$$D_{KL}(\pi_\theta || \pi_{ref}) = \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_\theta(o_{i,t}|q, o_{i,<t})} - \log\frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_\theta(o_{i,t}|q, o_{i,<t})} - 1$$
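
The unbiasedness claim can be checked on a toy example. The sketch below uses two made-up next-token distributions over a three-token vocabulary (the numbers and the helper name `k3` are assumptions for illustration) and computes the expectation of the estimator under $\pi_\theta$ exactly, then compares it with the exact KL divergence.

```python
import math

def k3(logp_theta, logp_ref):
    """Per-token estimator r - log(r) - 1 with r = pi_ref / pi_theta."""
    log_r = logp_ref - logp_theta
    return math.exp(log_r) - log_r - 1.0

# Toy next-token distributions (made-up numbers)
pi_theta = [0.7, 0.2, 0.1]
pi_ref = [0.5, 0.3, 0.2]

# Expectation of the estimator when tokens are sampled from pi_theta
expected_k3 = sum(p * k3(math.log(p), math.log(q))
                  for p, q in zip(pi_theta, pi_ref))

# Exact KL(pi_theta || pi_ref)
exact_kl = sum(p * math.log(p / q) for p, q in zip(pi_theta, pi_ref))

# Each per-token value is non-negative, unlike the naive log-ratio estimate
per_token = [k3(math.log(p), math.log(q)) for p, q in zip(pi_theta, pi_ref)]

print(expected_k3, exact_kl)
print(per_token)
```

The two printed values agree up to floating-point error, and every per-token value is non-negative, which is the main practical reason to prefer this form over the plain log ratio.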

In [9]:
def get_logprobs(model, tokens, mask):
  # Compute log probabilities for autoregressive language modeling
  logits = model(tokens).logits # (B, T, V) - raw logits for all vocab tokens
  log_probs = F.log_softmax(logits, dim=-1) # (B, T, V) - log π(token | context) for all tokens
  
  # Shift for next-token prediction: predict token t+1 from context up to t
  log_probs = log_probs[:, :-1, :] # (B, T-1, V) - remove last position (no next token to predict)
  shift_tokens = tokens[:, 1:] # (B, T-1) - target tokens (what we're trying to predict)
  shift_mask = mask[:, 1:] # (B, T-1) - mask for generated tokens only
  
  # Extract log π(actual_token | context) for each generated token
  tok_logps = log_probs.gather(-1, shift_tokens.unsqueeze(-1)).squeeze(-1) # (B, T-1)
  tok_logps = tok_logps * shift_mask # Zero out prompt tokens, keep only generated tokens
  
  # For GRPO formula: sum over output tokens to get sequence log probability
  # This implements: log π(o_i | q) = Σ_t log π(o_{i,t} | q, o_{i,<t})
  seq_logps = tok_logps.sum(dim=-1) # (B) - used for PPO ratio calculation

  # Return both: seq_logps for ratios as approximation, tok_logps for KL divergence (token-level)
  return seq_logps, tok_logps
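
The one-position shift is easy to get wrong, so here is a toy illustration in plain Python. The token strings are placeholders and do not reflect how the real tokenizer splits the prompt. It shows which model output positions end up scoring the generated tokens once the tokens and the mask are shifted.

```python
prompt_len, gen_len = 3, 2
tokens = ["Solve", "3+5", "=", "8", "<eos>"]   # placeholder tokens
mask = [0] * prompt_len + [1] * gen_len        # 1 marks generated tokens

# The model output at position t scores token t+1, so the last position is dropped
predicting_positions = list(range(len(tokens) - 1))
shift_tokens = tokens[1:]
shift_mask = mask[1:]

# Keep only (position, target token) pairs where the target was generated
scored = [(pos, tok)
          for pos, tok, m in zip(predicting_positions, shift_tokens, shift_mask)
          if m == 1]
print(scored)  # [(2, '8'), (3, '<eos>')]
```

The first generated token is scored by the output at the last prompt position, which is why the mask must be shifted together with the tokens.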

## GRPO Training Step

**Key GRPO algorithm:**
1. Generate group of responses per prompt
2. Compute advantages relative to group mean: `(reward - group_mean) / group_std`
3. Apply PPO clipping with KL regularization

This avoids needing a separate value function - the group statistics provide the baseline.
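
As a small worked example of step 2, the sketch below computes group-relative advantages for binary rewards in plain Python. The helper name `group_advantages` is an assumption for illustration; it uses the population standard deviation and a small epsilon, matching `std(unbiased=False)` and `1e-8` in the training code.

```python
import math

def group_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)  # population variance
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]

mixed = group_advantages([1.0, 0.0, 0.0, 0.0])    # one correct answer out of four
uniform = group_advantages([0.0, 0.0, 0.0, 0.0])  # all answers wrong

print([round(a, 3) for a in mixed])  # [1.732, -0.577, -0.577, -0.577]
print(uniform)                       # [0.0, 0.0, 0.0, 0.0]
```

When every response in a group receives the same reward, all advantages are zero, so that prompt contributes no policy gradient through the surrogate term (only the KL penalty remains).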

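**policy_steps**: Number of gradient updates per batch. With `policy_steps=1`, each batch is used for a single update; the current and old policies are then identical when the loss is computed, so the ratio is 1 and clipping has no effect. Multiple steps (e.g., 2) reuse the same samples for more aggressive optimization but risk overfitting to the current batch.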

In [None]:
def grpo_step(model, ref_model, tokenizer, batch, optimizer, group_size=6,
              temperature=1.0, clip_eps=0.6, beta=0.02, policy_steps=2):
  # GRPO Step 1: Generate groups of responses and compute group-relative advantages
  samples = []
  all_rewards = []
  
  for prompt, target in batch:
    # Generate G responses for this prompt (implements: {o_i}_{i=1}^G ~ π_θ_old(O|q))
    tokens, mask = generate_group(model, tokenizer, prompt,
                                  group_size=group_size,
                                  temperature=temperature)

    # Compute rewards for each response in the group
    group_rewards = []
    for t in tokens:
      response = tokenizer.decode(t, skip_special_tokens=True)
      group_rewards.append(reward(response, target))
    group_rewards = torch.tensor(group_rewards, device=device)
    all_rewards.append(group_rewards)

    # GRPO key insight: compute advantages relative to group statistics
    # This replaces the value function baseline in standard PPO
    r_avg = group_rewards.mean()  # Group mean reward \bar{R}
    r_std = group_rewards.std(unbiased=False)  # Group std σ_R
    advantages = (group_rewards - r_avg) / (r_std + 1e-8)  # Normalized advantages \hat{A}_i (one per response)

    # Store samples for batch processing
    for i in range(len(tokens)):
      samples.append({
          "a": advantages[i].detach(), # Advantage for this response
          "tokens": tokens[i],
          "mask": mask[i]
      })

  # GRPO Step 2: Batch all samples for efficient processing
  N = len(samples)  # Total number of responses across all prompts
  max_T = max([sample["tokens"].shape[0] for sample in samples])

  # Pad sequences to same length for batching
  tokens = torch.full((N, max_T), fill_value=tokenizer.pad_token_id,
                      dtype=torch.long, device=device)
  mask = torch.zeros((N, max_T), device=device)
  a = torch.empty((N,), dtype=torch.float32, device=device)

  for i, sample in enumerate(samples):
    T = sample["tokens"].shape[0]
    tokens[i, :T] = sample["tokens"]
    mask[i, :T] = sample["mask"]
    a[i] = sample["a"]

  # GRPO Step 3: Compute baseline probabilities (frozen during optimization)
  with torch.no_grad():
      seq_logps_old, _ = get_logprobs(model, tokens, mask)  # π_θ_old for PPO ratio
      _, ref_tok_logps = get_logprobs(ref_model, tokens, mask)  # π_ref for KL

  # GRPO Step 4: Policy optimization loop
  for _ in range(policy_steps):
    # Compute current policy probabilities (with gradients)
    seq_logps, tok_logps = get_logprobs(model, tokens, mask)
    
    # PPO clipped surrogate loss (using sequence-level ratios as approximation)
    # Note: Paper uses token-level ratios, this is a common simplification
    ratio = torch.exp(seq_logps - seq_logps_old)  # π_θ(o_i|q) / π_θ_old(o_i|q)
    surr1 = ratio * a  # Unclipped surrogate
    surr2 = torch.clamp(ratio, 1-clip_eps, 1+clip_eps) * a  # Clipped surrogate
    surr = torch.minimum(surr1, surr2)  # PPO clipping: min(surr1, surr2)

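    # KL penalty with the per-token estimator exp(x) - x - 1, x = log(π_ref/π_θ) (Schulman 2020)
    # Applied token by token, then averaged over generated tokens only
    log_ratio = ref_tok_logps - tok_logps  # (N, T-1), zero at masked positions
    tok_kl = torch.exp(log_ratio) - log_ratio - 1  # also zero at masked positions
    gen_mask = mask[:, 1:]
    D_kl = (tok_kl * gen_mask).sum(dim=1) / gen_mask.sum(dim=1).clamp(min=1)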
    
    # Final GRPO loss: -PPO_objective + β * KL_penalty. Averaged over all outputs for all inputs in a batch
    loss = (-surr + beta * D_kl).mean()

    # Standard optimization step
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
    optimizer.step()

  return {"mean_rewards": (torch.cat(all_rewards)).mean().cpu()}

## Training Loop

Run GRPO for 50 steps. Each step processes a batch of 16 prompts, generating 6 responses per prompt (96 sampled responses per step).

In [11]:
optimizer = AdamW(model.parameters(), lr=5e-6)

for step in range(1, 51):
  batch = get_batch(size=16)
  info = grpo_step(model, ref_model, tokenizer, batch, optimizer, group_size=6,
              temperature=0.7, clip_eps=0.2, beta=0.01)
  if step%5 == 0:
    print(f'{step}: reward: {info["mean_rewards"]:.3f}')

5: reward: 0.125
10: reward: 0.344
15: reward: 0.365
20: reward: 0.594
25: reward: 0.458
30: reward: 0.792
35: reward: 0.667
40: reward: 0.667
45: reward: 0.875
50: reward: 0.854


## Evaluation

Compare the reference model (pretrained, no GRPO updates) with the GRPO-trained model on the same batch of 100 arithmetic problems.

In [12]:
eval_batch = get_batch(100)

In [13]:
ref_acc = eval(ref_model, eval_batch)

Average accuracy = 0.14

✓ Solve 2+4=6. Target: 6
✓ Solve 4*4=16. Target: 16
✗ Solve 6*5=20. Target: 30
✗ Solve 3+3=7. Target: 6
✗ Solve 3+0=10. Target: 3
✗ Solve 4+7=10. Target: 11
✗ Solve 6+0=10. Target: 6
✗ Solve 0*9=10. Target: 0
✗ Solve 4+1=6. Target: 5
✗ Solve 1*9=10. Target: 9


In [14]:
trained_acc = eval(model, eval_batch)

Average accuracy = 0.82

✓ Solve 0*2=0. Target: 0
✓ Solve 8+9=17. Target: 17
✗ Solve 5*3=13. Target: 15
✓ Solve 8+2=10. Target: 10
✓ Solve 2+1=3. Target: 3
✓ Solve 1*4=4. Target: 4
✗ Solve 6+6=14. Target: 12
✓ Solve 0*6=0. Target: 0
✓ Solve 0+9=9. Target: 9
✓ Solve 2+7=9. Target: 9
