# GRPO Training — On-Policy RL

Run Group Relative Policy Optimization (GRPO) on the SFT model to improve reasoning quality.

**Time estimate:** ~4-5 hours on Kaggle TPU

**Prerequisites:** Completed SFT training (notebook 02)

In [None]:
import os
import json
import time
import re
from datetime import datetime

import jax
import jax.numpy as jnp
import numpy as np
from transformers import AutoTokenizer

# Report how many accelerator devices JAX can see.
device_count = jax.device_count()
print(f"JAX devices: {device_count}")

## 1. GRPO Configuration

In [None]:
CFG = {
    # --- Model ---
    'model_name': 'google/gemma-3-1b-it',
    'sft_checkpoint': 'checkpoints/sft/sft_epoch_2',

    # --- GRPO ---
    'num_generations': 4,    # G: sampled responses per prompt
    'temperature': 0.8,
    'learning_rate': 3e-6,
    'kl_coef': 0.03,
    'clip_epsilon': 0.2,

    # --- Training ---
    'num_updates': 2000,
    'batch_size': 2,         # prompts per update
    'max_length': 1024,

    # --- Reward weights (combined in composite_reward) ---
    'w_correct': 0.60,
    'w_trace': 0.25,
    'w_conf': 0.15,

    # --- Logging / checkpointing ---
    'log_steps': 20,
    'save_minutes': 20,
    'eval_steps': 100,
    'seed': 42,
}

print(" GRPO Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CFG.items()))

## 2. Load Model & Data

In [None]:
tokenizer = AutoTokenizer.from_pretrained(CFG['model_name'])

# Load SFT checkpoint.
# TODO: the actual load is still a placeholder. Until it is implemented, report
# the real status instead of claiming the checkpoint was loaded.
# model = load_checkpoint(CFG['sft_checkpoint'])
if os.path.isdir(CFG['sft_checkpoint']):
    print(f" SFT checkpoint found (loading not implemented yet): {CFG['sft_checkpoint']}")
else:
    print(f" WARNING: SFT checkpoint directory not found: {CFG['sft_checkpoint']}")

In [None]:
# Load training prompts with references
def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Blank (whitespace-only) lines, such as a trailing empty line at the end of
    the file, are skipped. Passing them to ``json.loads`` would raise
    ``json.JSONDecodeError`` on the empty string.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # json.loads tolerates surrounding whitespace, so no explicit strip.
        return [json.loads(line) for line in f if line.strip()]

# Train split first, then validation (tuple elements evaluate left to right).
train_data, val_data = (
    load_jsonl('data/prepared/train.jsonl'),
    load_jsonl('data/prepared/valid.jsonl'),
)

print(f" Loaded {len(train_data)} train, {len(val_data)} val prompts")

## 3. Reward Functions

In [None]:
# Reward functions (from src/rewards.py)

# Discourse markers looked for in the <reasoning> section by
# trace_structure_score; each listed word contributes at most once to the count.
TRANSITION_WORDS = [
    'therefore', 'thus', 'hence', 'so', 'because', 'first', 'second',
    'step', 'next', 'then', 'finally', 'since', 'given'
]

def extract_answer(text):
    """Return the stripped contents of the first <answer>...</answer> span, or None.

    Tag matching is case-insensitive, and the contents may span several lines.
    """
    found = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
    if found is None:
        return None
    return found.group(1).strip()

def extract_reasoning(text):
    """Return the stripped body of the first <reasoning>...</reasoning> block, or None.

    Tags are matched case-insensitively; the body may contain newlines.
    """
    flags = re.DOTALL | re.IGNORECASE
    found = re.search(r'<reasoning>(.*?)</reasoning>', text, flags)
    return None if found is None else found.group(1).strip()

def correctness_score(pred_text, ref_answer):
    """Score a completion's final answer against the reference answer.

    The predicted answer is the first ``<answer>...</answer>`` span of
    ``pred_text``. Both sides are lower-cased and stripped. The comparison
    then works as follows:

    * If both sides still parse as floats after removing every character
      other than digits, ``.`` and ``-``, they match when they differ by
      less than 0.01. For example, ``"$1,000"`` becomes ``1000``.
    * Otherwise the normalized strings must be equal.

    Args:
        pred_text: Full model completion.
        ref_answer: Reference answer. ``None`` or an empty/whitespace string
            means no reference is available. Non-string values (e.g. a number
            decoded from JSON) are converted with ``str``.

    Returns:
        1.0 for a match, else 0.0. Completions without an answer span, and
        examples without a usable reference, score 0.0.
    """
    pred_ans = extract_answer(pred_text)
    if pred_ans is None:
        return 0.0

    # Without a reference there is nothing to compare against. Previously an
    # empty <answer></answer> "matched" a missing reference ('' == '').
    if ref_answer is None:
        return 0.0
    ref_norm = str(ref_answer).lower().strip()
    if not ref_norm:
        return 0.0
    pred_norm = pred_ans.lower().strip()

    try:
        pred_num = float(re.sub(r'[^\d.\-]', '', pred_norm))
        ref_num = float(re.sub(r'[^\d.\-]', '', ref_norm))
    except ValueError:
        # At least one side is not numeric: fall back to exact string match.
        return 1.0 if pred_norm == ref_norm else 0.0
    return 1.0 if abs(pred_num - ref_num) < 0.01 else 0.0

def trace_structure_score(text):
    """Score how well-structured a completion's reasoning trace is, in [0, 1].

    Components:
      * 0.2 if both ``<reasoning>`` and ``</reasoning>`` appear (case-insensitive);
      * 0.2 if both ``<answer>`` and ``</answer>`` appear (case-insensitive);
      * up to 0.4 for "steps" in the reasoning section. Steps are fragments
        split on ``.``/newline that are longer than 10 characters; the score
        saturates at 3 steps;
      * up to 0.2 for distinct TRANSITION_WORDS used in the reasoning
        section, saturating at 3 words.

    Transition words are matched as whole words. Plain substring matching let
    ``'so'`` fire on "also"/"reason" and ``'then'`` on "strengthen", which
    inflated the score.
    """
    lowered = text.lower()
    score = 0.0
    if '<reasoning>' in lowered and '</reasoning>' in lowered:
        score += 0.2
    if '<answer>' in lowered and '</answer>' in lowered:
        score += 0.2

    reasoning = extract_reasoning(text)
    if reasoning:
        steps = [s for s in re.split(r'[.\n]', reasoning) if len(s.strip()) > 10]
        score += min(0.4, len(steps) / 3.0 * 0.4)

        reasoning_lower = reasoning.lower()
        trans_count = sum(
            1 for w in TRANSITION_WORDS
            if re.search(rf'\b{re.escape(w)}\b', reasoning_lower)
        )
        score += min(0.2, trans_count / 3.0 * 0.2)

    return min(1.0, score)

def composite_reward(pred_text, ref_answer):
    """Combine the individual reward terms into one scalar.

    Returns:
        ``(total, components)``. ``components`` is
        ``{'correct': ..., 'trace': ..., 'conf': ...}``, and ``total`` is the
        sum of those terms weighted by CFG's ``w_correct``/``w_trace``/``w_conf``.
    """
    components = {
        'correct': correctness_score(pred_text, ref_answer),
        'trace': trace_structure_score(pred_text),
        'conf': 0.5,  # placeholder until a real confidence signal exists
    }
    total = (CFG['w_correct'] * components['correct']
             + CFG['w_trace'] * components['trace']
             + CFG['w_conf'] * components['conf'])
    return total, components

In [None]:
# Sanity-check the reward on a well-formed, correct completion.
test_output = """<reasoning>
Step 1: We have 45 apples.
Step 2: We sell 12 apples.
Step 3: Therefore, 45 - 12 = 33 apples remain.
</reasoning>
<answer>33</answer>"""

reward, components = composite_reward(test_output, "33")
print(f"Test reward: {reward:.3f}\nComponents: {components}")

## 4. GRPO Training Loop

In [None]:
# Training state
# Seed NumPy's global RNG so prompt sampling (np.random.choice in sample_batch)
# is reproducible. CFG['seed'] is not applied anywhere else in this notebook.
np.random.seed(CFG['seed'])

start_time = time.time()
last_save_time = start_time
global_step = 0
best_reward = -float('inf')  # best eval composite score seen so far

# Metrics tracking (one entry per train_step)
reward_history = []
accuracy_history = []

os.makedirs('checkpoints/rl', exist_ok=True)
os.makedirs('logs/eval_reports', exist_ok=True)

In [None]:
def sample_batch(data, batch_size):
    """Draw ``batch_size`` distinct examples from ``data`` via NumPy's global RNG.

    Sampling is without replacement, so NumPy raises ValueError when
    ``batch_size`` exceeds ``len(data)``.
    """
    chosen = []
    for idx in np.random.choice(len(data), batch_size, replace=False):
        chosen.append(data[idx])
    return chosen

def generate_responses(prompts, num_gen=4):
    """Return ``num_gen`` sampled completions for every prompt.

    The result is a list with one entry per prompt, each a list of
    ``num_gen`` strings.
    NOTE: placeholder. Every completion is the same canned string until real
    sampling from the policy is wired in.
    """
    canned = "<reasoning>Step 1: Calculate.</reasoning><answer>42</answer>"
    # TODO: response = model.generate(prompt, temperature=CFG['temperature'])
    return [[canned for _ in range(num_gen)] for _ in prompts]

def compute_advantages(rewards, normalize_std=False, eps=1e-8):
    """Compute group-relative advantages for the G responses to one prompt.

    The baseline is the group mean, so ``A_i = r_i - mean(r)`` and the
    advantages of a group sum to zero. With ``normalize_std=True`` they are
    additionally divided by the group's population standard deviation plus
    ``eps``, as in the original GRPO formulation.

    Args:
        rewards: Scalar rewards for the responses to a single prompt.
        normalize_std: Whether to divide by the group std. The default
            (False) keeps the plain mean-centred behaviour.
        eps: Added to the std so a group of tied rewards does not divide by 0.

    Returns:
        A list with one advantage per reward; ``[]`` for an empty group
        (previously a ZeroDivisionError).
    """
    rewards = list(rewards)
    if not rewards:
        return []
    mean_reward = sum(rewards) / len(rewards)
    centred = [r - mean_reward for r in rewards]
    if not normalize_std:
        return centred
    std = (sum(c * c for c in centred) / len(centred)) ** 0.5
    return [c / (std + eps) for c in centred]

def grpo_update(prompts, responses, advantages):
    """Apply one GRPO policy update and return the scalar loss.

    NOTE: placeholder. No parameters are updated; a constant loss of 0.1 is
    returned until the real update is implemented.
    """
    placeholder_loss = 0.1
    return placeholder_loss

In [None]:
def train_step():
    """Run one GRPO update on a freshly sampled batch of prompts.

    The step runs this pipeline:

    1. Sample prompts.
    2. Generate ``CFG['num_generations']`` responses per prompt.
    3. Score each response with ``composite_reward``.
    4. Compute advantages within each prompt's group.
    5. Call ``grpo_update`` on the flattened (prompt, response, advantage)
       lists, which all share the same order.

    Side effects: increments the module-level ``global_step`` and appends the
    batch means to ``reward_history`` and ``accuracy_history``.

    Returns:
        ``(loss, avg_reward, avg_correct)`` for this batch.
    """
    global global_step

    num_gen = CFG['num_generations']

    # Prompt = text before the first 'A:\n' marker, with the marker
    # re-appended (it is appended even when the text contains no marker).
    examples = sample_batch(train_data, CFG['batch_size'])
    prompts = []
    refs = []
    for ex in examples:
        prompts.append(ex['text'].split('A:\n')[0] + 'A:\n')
        refs.append(ex.get('reference_answer', ''))

    groups = generate_responses(prompts, num_gen)

    # Score every response; rewards stay grouped per prompt.
    group_rewards = []
    components = []
    for idx, group in enumerate(groups):
        scored = [composite_reward(resp, refs[idx]) for resp in group]
        group_rewards.append([total for total, _ in scored])
        components.extend(comp for _, comp in scored)

    # Group-relative advantages, flattened in the same order as the responses.
    advantages = [adv for rewards in group_rewards for adv in compute_advantages(rewards)]

    flat_responses = [resp for group in groups for resp in group]
    flat_prompts = [p for p in prompts for _ in range(num_gen)]
    loss = grpo_update(flat_prompts, flat_responses, advantages)

    flat_rewards = [r for rewards in group_rewards for r in rewards]
    avg_reward = np.mean(flat_rewards)
    avg_correct = np.mean([comp['correct'] for comp in components])

    global_step += 1
    reward_history.append(avg_reward)
    accuracy_history.append(avg_correct)
    return loss, avg_reward, avg_correct

In [None]:
def save_checkpoint(name, metrics=None):
    """Write checkpoint metadata to ``checkpoints/rl/<name>/metadata.json``.

    The metadata records the current ``global_step``, a local timestamp, the
    full ``CFG``, and ``metrics`` (``{}`` when omitted).
    NOTE: model weights are not saved yet; ``model.save_pretrained`` is still
    a TODO.
    """
    ckpt_dir = f"checkpoints/rl/{name}"
    os.makedirs(ckpt_dir, exist_ok=True)

    meta = dict(
        step=global_step,
        time=datetime.now().isoformat(),
        config=CFG,
        metrics=metrics or {},
    )
    with open(f"{ckpt_dir}/metadata.json", 'w') as fh:
        json.dump(meta, fh, indent=2)

    # TODO: model.save_pretrained(ckpt_dir)
    print(f" Saved: {ckpt_dir}")

def evaluate(num_samples=50):
    """Evaluate the current policy on the first ``num_samples`` validation prompts.

    Returns:
        A dict with:

        * ``'accuracy'``: fraction of samples with ``correctness_score > 0.5``.
        * ``'format_rate'``: fraction of samples containing a complete
          ``<reasoning>...</reasoning>`` and ``<answer>...</answer>`` pair,
          parsed case-insensitively the same way the reward parses them.
        * ``'avg_trace'``: mean ``trace_structure_score``.

        All three are 0.0 when there is nothing to evaluate.
    """
    samples = val_data[:num_samples]
    if not samples:
        # Avoid ZeroDivisionError and np.mean([]) -> nan on an empty split.
        return {'accuracy': 0.0, 'format_rate': 0.0, 'avg_trace': 0.0}

    correct = 0
    format_ok = 0
    trace_scores = []
    for ex in samples:
        prompt = ex['text'].split('A:\n')[0] + 'A:\n'
        ref = ex.get('reference_answer', '')

        # TODO: output = model.generate(prompt)
        output = "<reasoning>Step.</reasoning><answer>42</answer>"

        if correctness_score(output, ref) > 0.5:
            correct += 1
        # Require both tag pairs to be well-formed. The previous check only
        # looked for the opening tags, and did so case-sensitively.
        if extract_reasoning(output) is not None and extract_answer(output) is not None:
            format_ok += 1
        trace_scores.append(trace_structure_score(output))

    n = len(samples)
    return {
        'accuracy': correct / n,
        'format_rate': format_ok / n,
        'avg_trace': np.mean(trace_scores)
    }

In [None]:
# Main training loop
print(" Starting GRPO Training...")
print(f"   Updates: {CFG['num_updates']}")
print(f"   G (generations/prompt): {CFG['num_generations']}")
print()

for _ in range(CFG['num_updates']):
    loss, reward, accuracy = train_step()

    # Log
    if global_step % CFG['log_steps'] == 0:
        elapsed = (time.time() - start_time) / 60
        print(f"Step {global_step} | Reward: {reward:.3f} | Acc: {accuracy:.2%} | Time: {elapsed:.1f}m")

    # Evaluate; keep the checkpoint with the best eval composite score.
    if global_step % CFG['eval_steps'] == 0:
        metrics = evaluate()
        print(f"   Eval - Acc: {metrics['accuracy']:.2%}, Format: {metrics['format_rate']:.2%}")

        # Use the same weights as the training reward; format_rate takes the
        # w_conf slot. These used to be hard-coded (0.6/0.25/0.15) and would
        # silently diverge from CFG if the weights were changed.
        composite = (CFG['w_correct'] * metrics['accuracy']
                     + CFG['w_trace'] * metrics['avg_trace']
                     + CFG['w_conf'] * metrics['format_rate'])
        if composite > best_reward:
            best_reward = composite
            save_checkpoint('best', metrics)

    # Periodic save
    if (time.time() - last_save_time) > CFG['save_minutes'] * 60:
        save_checkpoint(f"step_{global_step}")
        last_save_time = time.time()

# Always checkpoint the final state. Otherwise, steps run after the last
# periodic save (up to `save_minutes` of training) would never be saved.
save_checkpoint('final')

print("\n GRPO Training complete!")

## 5. Training Summary

In [None]:
total_time = (time.time() - start_time) / 60

print("\n" + "="*50)
print("GRPO TRAINING COMPLETE")
print("="*50)
print(f"Total time: {total_time:.1f} minutes")
print(f"Steps: {global_step}")

# best_reward stays at -inf when no evaluation ran during training.
if best_reward > -float('inf'):
    print(f"Best composite score: {best_reward:.3f}")
else:
    print("Best composite score: n/a (no evaluation ran)")

# np.mean([]) warns and yields nan when no training steps ran.
if reward_history:
    print(f"Final avg reward: {np.mean(reward_history[-100:]):.3f}")
    print(f"Final accuracy: {np.mean(accuracy_history[-100:]):.2%}")
else:
    print("Final avg reward: n/a (no training steps)")
    print("Final accuracy: n/a (no training steps)")
print("\n Proceed to: 04_evaluation_and_export.ipynb")