# GRPO Training with Qwen 0.5B for Mathematical Reasoning

This notebook trains a small Qwen model (Qwen2.5-0.5B) with Group Relative Policy Optimization (GRPO) to improve its ability to solve grade-school math word problems.

## Features:
- Downloads and loads Qwen 0.5B model
- Uses GSM8K math dataset for training
- Custom prompt formatting for math problems
- Math-specific reward function
- Tracks training metrics
- Saves checkpoints

## 1. Setup and Installation

In [None]:
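# Install required packages if they are not already importable
import importlib
import subprocess
import sys

def install_if_needed(package, import_name=None):
    """pip-install `package` unless `import_name` (defaults to `package`) is importable."""
    try:
        importlib.import_module(import_name or package)
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

for package in ["torch", "transformers", "datasets", "accelerate", "matplotlib"]:
    install_if_needed(package)

# Install simple_rl in editable mode; assumes this notebook sits one directory below the repo root
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".."])

print("Setup complete!")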

## 2. Import Libraries

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Optional
import json
import os
from pathlib import Path

# Import GRPO from simple_rl
from simple_rl.algorithms.grpo import GRPO
from simple_rl.utils.config import Config

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

## 3. Load Qwen 0.5B Model

In [None]:
# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-0.5B"  # Small 0.5B parameter model

print(f"Loading model: {MODEL_NAME}")
print("This may take a few minutes on first run...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Add padding token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
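# Load model. On GPU, prefer bfloat16: it keeps float32's exponent range, so tiny gradients
# and AdamW's epsilon (1e-8) do not underflow to zero as they can in float16.
# float32 is the safest, but largest, option. device_map="auto" needs the `accelerate` package.
if torch.cuda.is_available():
    model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    model_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=model_dtype,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True
)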

if not torch.cuda.is_available():
    model = model.to(device)

print(f"Model loaded successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
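
GRPO keeps a frozen reference model next to the policy (it is printed as `grpo.reference_model` in section 7), so memory use is more than the weights alone. The rough estimate below is a sketch under stated assumptions: 16-bit weights for both policy and reference, 16-bit gradients for the policy, and two fp32 AdamW moment buffers per trainable parameter. It ignores activations, the KV cache used while sampling, and any fp32 master weights, so treat it as a lower bound. `estimate_grpo_memory_gb` is a hypothetical helper, not part of simple_rl.

```python
def estimate_grpo_memory_gb(num_params: int, weight_bytes: int = 2, optimizer_bytes: int = 8) -> float:
    """Lower-bound memory (GB) for GRPO fine-tuning, excluding activations.

    Hypothetical accounting (assumptions, not taken from simple_rl):
    - policy weights and a frozen reference copy, `weight_bytes` per parameter each
    - policy gradients, `weight_bytes` per parameter
    - AdamW state: two fp32 moment buffers, `optimizer_bytes` per parameter
    """
    policy = num_params * weight_bytes
    reference = num_params * weight_bytes
    gradients = num_params * weight_bytes
    optimizer_state = num_params * optimizer_bytes
    return (policy + reference + gradients + optimizer_state) / 1e9

# A round 500M parameters for illustration; substitute the count printed above
print(f"{estimate_grpo_memory_gb(500_000_000):.1f} GB")  # 7.0 GB
```

On the CPU path the model stays in float32, so `weight_bytes=4` is the appropriate setting there.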

## 4. Load and Prepare Dataset

In [None]:
# Load GSM8K math dataset
print("Loading GSM8K math dataset...")

# Load GSM8K, a dataset of grade-school math word problems (Hugging Face Hub id: openai/gsm8k)
dataset = load_dataset("openai/gsm8k", "main", split="train[:200]")  # First 200 training examples for the demo

# Extract math problems and prepare prompts
math_prompts = []
math_answers = []  # Store answers for reward calculation

for item in dataset:
    # GSM8K format: question and answer with step-by-step solution
    question = item['question']
    answer = item['answer']
    
    # Store just the question as prompt (without formatting yet)
    math_prompts.append(question)
    
    # Extract the final numerical answer (after "####"), dropping thousands separators
    final_answer = answer.split("####")[-1].strip() if "####" in answer else answer.strip()
    math_answers.append(final_answer.replace(",", ""))

print(f"Loaded {len(math_prompts)} math problems")
print("\nExample problems:")
for i in range(min(3, len(math_prompts))):
    print(f"\n{i+1}. Question: {math_prompts[i][:150]}...")
    print(f"   Answer: {math_answers[i]}")

# Store answers for later use in reward function
PROBLEM_ANSWERS = dict(zip(math_prompts, math_answers))
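
GSM8K answers interleave the worked solution with calculator annotations of the form `<<expression=result>>` and end with a `#### <answer>` line; the loop above keeps only the final answer. The hypothetical helper below also recovers the cleaned reasoning steps, which could serve as reference solutions. The sample string is written for illustration in the same format, not copied from the dataset, and stripping commas from the final answer is a defensive assumption.

```python
import re

def parse_gsm8k_answer(answer: str):
    """Split a GSM8K-style answer into (reasoning_steps, final_answer).

    Removes <<expr=result>> calculator annotations from the reasoning and
    drops thousands separators from the final answer so float() can parse it.
    """
    reasoning, sep, final = answer.rpartition("####")
    if not sep:
        # No "####" marker: treat the whole string as the answer
        return [], answer.strip().replace(",", "")
    cleaned = re.sub(r"<<[^>]*>>", "", reasoning)
    steps = [line.strip() for line in cleaned.splitlines() if line.strip()]
    return steps, final.strip().replace(",", "")

# Illustrative string in the GSM8K format (written for this example)
sample = "Each box holds 6 * 4 = <<6*4=24>>24 eggs.\nSo 50 boxes hold 24 * 50 = <<24*50=1200>>1,200 eggs.\n#### 1,200"
steps, final = parse_gsm8k_answer(sample)
print(final)  # 1200
print(steps)
```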

## 5. Configure GRPO Parameters

In [None]:
# GRPO Configuration with math-specific prompt formatting
grpo_config = {
    "algorithm": {
        "name": "grpo",
        "group_size": 4,  # Number of completions per prompt
        "kl_coef": 0.05,  # KL divergence coefficient
        "clip_range": 0.2,  # PPO-style clipping
        "normalize_rewards": True,  # Normalize rewards per group
        "update_epochs": 2,  # PPO update epochs
        "minibatch_size": 4,  # Minibatch size for updates
    },
    "training": {
        "batch_size": 8,  # Completions per step (must be divisible by group_size)
        "learning_rate": 1e-5,  # Learning rate
        "gradient_clip": 1.0,  # Gradient clipping
        "max_new_tokens": 150,  # Room for multi-step reasoning
        "temperature": 0.7,  # Generation temperature
        "num_episodes": 30,  # Number of training steps (one batch per episode)
    },
    "generation": {
        # Math-specific prompt formatting
        "system_prompt": "You are a helpful math tutor. Solve the problem step by step, showing your work clearly.",
        "prompt_template": "Problem: {prompt}\n\nSolution: Let me solve this step by step.\n",
        "response_prefix": "",
    },
    "model": {
        "max_length": 512,  # Maximum sequence length
    },
    "logging": {
        "log_interval": 2,  # Log every 2 episodes
        "save_interval": 10,  # Save checkpoint every 10 episodes
    },
    "wandb": {
        "enabled": False,  # Disable W&B for notebook demo
    }
}

# Create config object
config = Config(**grpo_config)

print("GRPO Configuration for Math Training:")
print(json.dumps(grpo_config, indent=2))

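# Preview of a formatted prompt. This assumes the system prompt and template are joined
# by a blank line; the authoritative format is whatever grpo.format_prompt() produces (section 7).
example_prompt = math_prompts[0] if math_prompts else "What is 2 + 2?"
print("\nExample formatted prompt (approximate):")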
print("-" * 50)
formatted_example = grpo_config["generation"]["system_prompt"] + "\n\n" + \
                   grpo_config["generation"]["prompt_template"].replace("{prompt}", example_prompt)
print(formatted_example)
print("-" * 50)
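
Because every prompt is expanded into `group_size` completions, `batch_size` counts completions, and the training loop in section 8 samples `batch_size // group_size` prompts per step. A small check like the hypothetical `prompts_per_step` below catches an invalid combination before training starts. Requiring at least two completions per group is an assumption based on how group-relative advantages work: a group of one has no spread to normalize against.

```python
def prompts_per_step(batch_size: int, group_size: int) -> int:
    """Number of distinct prompts sampled per training step.

    Each prompt yields `group_size` completions, so `batch_size`
    completions correspond to batch_size // group_size prompts.
    """
    if group_size < 2:
        raise ValueError("group_size must be >= 2 for group-relative advantages")
    if batch_size % group_size != 0:
        raise ValueError(f"batch_size={batch_size} is not divisible by group_size={group_size}")
    return batch_size // group_size

print(prompts_per_step(8, 4))  # 2 prompts x 4 completions = 8 completions per step
```

With the values in `grpo_config`, `prompts_per_step(grpo_config["training"]["batch_size"], grpo_config["algorithm"]["group_size"])` gives 2.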

## 6. Define Reward Function

In [None]:
import re

def extract_number(text):
    """Extract the final numerical answer from a model completion."""
    # Drop thousands separators so "1,234" is read as 1234 rather than 1
    text = text.replace(",", "")
    number = r"(-?\d+(?:\.\d+)?)"  # Integer or decimal, without a trailing period
    # Cue phrases in priority order, matched case-insensitively
    patterns = [
        r"answer is:?\s*\$?" + number,
        r"equals?\s*\$?" + number,
        r"=\s*\$?" + number,
        r"total of\s*\$?" + number,
        r"result is:?\s*\$?" + number,
        r"\btherefore\s*\$?" + number,
        r"\bso\s*\$?" + number,
    ]
    for pattern in patterns:
        # Use the LAST match: earlier matches are usually intermediate steps
        matches = re.findall(pattern, text, flags=re.IGNORECASE)
        if matches:
            return matches[-1]

    # Otherwise fall back to the last number in the text (covers endings like "8 apples.")
    numbers = re.findall(number, text)
    return numbers[-1] if numbers else None

def compute_math_reward(prompt: str, completion: str) -> float:
    """
    Reward function for mathematical problem solving.
    
    Rewards based on:
    - Correct final answer (most important)
    - Showing work/steps
    - Mathematical reasoning indicators
    - Proper formatting
    """
    reward = 0.0
    
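    # Look up the reference answer. Keys are raw GSM8K questions; in case the trainer
    # passes the formatted prompt instead, fall back to finding a known question inside it.
    correct_answer = PROBLEM_ANSWERS.get(prompt)
    if correct_answer is None:
        correct_answer = next((ans for q, ans in PROBLEM_ANSWERS.items() if q in prompt), None)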
    
    if correct_answer:
        # Extract the model's answer
        model_answer = extract_number(completion)
        
        if model_answer:
            try:
                # Check if the answer is correct (allowing for small float differences)
                correct_float = float(correct_answer)
                model_float = float(model_answer)
                
                if abs(correct_float - model_float) < 0.01:
                    reward += 2.0  # Big reward for correct answer
                else:
                    # Partial credit for being close
                    if abs(correct_float - model_float) / max(abs(correct_float), 1) < 0.1:
                        reward += 0.5
                    else:
                        reward -= 0.5  # Penalty for wrong answer
            except ValueError:
                # If conversion fails, do string comparison
                if model_answer == correct_answer:
                    reward += 2.0
                else:
                    reward -= 0.3
        else:
            # No answer found
            reward -= 1.0
    
    # Reward for showing mathematical work
    math_indicators = [
        "step", "first", "then", "next", "finally",
        "calculate", "multiply", "divide", "add", "subtract",
        "=", "+", "-", "*", "/", "×", "÷"
    ]
    
    work_shown = sum(1 for indicator in math_indicators if indicator in completion.lower())
    if work_shown >= 3:
        reward += 0.5  # Reward for showing work
    elif work_shown == 0:
        reward -= 0.3  # Penalty for no work shown
    
    # Reward for reasonable length (not too short, not too long)
    completion_length = len(completion.split())
    if 20 <= completion_length <= 200:
        reward += 0.2
    elif completion_length < 10:
        reward -= 0.5  # Too short
    elif completion_length > 300:
        reward -= 0.2  # Too verbose
    
    # Penalty for repetition (skip empty fragments, e.g. the one after a final period)
    sentences = [s.strip() for s in completion.split('.') if s.strip()]
    if len(sentences) > 1:
        unique_sentences = len(set(sentences))
        if unique_sentences / len(sentences) < 0.7:
            reward -= 0.5  # Repetitive
    
    # Penalty for obvious errors or nonsense
    if "error" in completion.lower() or "sorry" in completion.lower():
        reward -= 0.5
    
    return reward

# Test the reward function
test_cases = [
    ("What is 5 + 3?", "5 + 3 = 8. The answer is 8."),
    ("What is 5 + 3?", "Let me calculate: 5 + 3 equals 7."),
    ("What is 5 + 3?", "8"),
    ("What is 5 + 3?", "First, I'll add 5 and 3. 5 + 3 = 8. Therefore, the answer is 8."),
]

# Set up test answer
PROBLEM_ANSWERS["What is 5 + 3?"] = "8"

print("Testing math reward function:")
for prompt, completion in test_cases:
    reward = compute_math_reward(prompt, completion)
    print(f"\nPrompt: {prompt}")
    print(f"Completion: {completion}")
    print(f"Reward: {reward:.2f}")
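
In the usual GRPO formulation, the rewards above are not used directly: each completion is compared with the other completions for the same prompt by subtracting the group mean and dividing by the group standard deviation. `normalize_rewards: True` in the config suggests simple_rl does something similar, but the population standard deviation, the `eps` value, and the assumption that each prompt's completions are contiguous are choices made for this sketch.

```python
import math

def group_relative_advantages(rewards, group_size, eps=1e-6):
    """Normalize rewards within each consecutive group of `group_size` completions.

    advantage_i = (r_i - mean(group)) / (std(group) + eps)
    """
    if len(rewards) % group_size != 0:
        raise ValueError("number of rewards must be a multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = sum(group) / group_size
        std = math.sqrt(sum((r - mean) ** 2 for r in group) / group_size)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages

# Two prompts with four completions each (illustrative reward values)
rewards = [2.7, -0.8, 2.7, -0.8, 0.2, 0.2, 0.2, 0.2]
print([round(a, 3) for a in group_relative_advantages(rewards, group_size=4)])
```

A group in which every completion earns the same reward gets zero advantage everywhere and contributes no policy-gradient signal, which is why reward terms that separate otherwise-equal completions can matter with small groups.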

## 7. Initialize GRPO with Custom Model

In [None]:
# Initialize GRPO with the loaded model and math reward function
print("Initializing GRPO with math-specific configuration...")

# Create GRPO instance with math reward
grpo = GRPO(
    model=model,
    config=config.to_dict(),
    tokenizer=tokenizer,
    reward_fn=compute_math_reward,  # Using our math-specific reward function
    use_wandb=False
)

print("GRPO initialized successfully!")
print(f"Policy model parameters: {sum(p.numel() for p in grpo.model.parameters()) / 1e6:.1f}M")
print(f"Reference model parameters: {sum(p.numel() for p in grpo.reference_model.parameters()) / 1e6:.1f}M")

# Show how prompts will be formatted
print("\nPrompt formatting example:")
sample_problem = "If John has 5 apples and Mary gives him 3 more, how many apples does John have?"
formatted = grpo.format_prompt(sample_problem)
print("Original:", sample_problem)
print("Formatted:", formatted)
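
`format_prompt` belongs to simple_rl, and its exact joining rules are not shown in this notebook. As a mental model only, the hypothetical function below combines the three `generation` settings the way the preview in section 5 assumes: system prompt, a blank line, the filled template, then the response prefix. Compare its output with the real `grpo.format_prompt` before relying on it.

```python
from typing import Optional

def format_prompt_sketch(problem: str, system_prompt: Optional[str],
                         prompt_template: str, response_prefix: str = "") -> str:
    """Hypothetical stand-in for grpo.format_prompt (not simple_rl's implementation)."""
    # str.replace avoids str.format errors when the problem text contains braces
    body = prompt_template.replace("{prompt}", problem) + response_prefix
    return f"{system_prompt}\n\n{body}" if system_prompt else body

print(format_prompt_sketch(
    "What is 5 + 3?",
    "You are a helpful math tutor.",
    "Problem: {prompt}\n\nSolution: Let me solve this step by step.\n",
))
```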

## 7.5 Experiment with Different Math Prompt Formats

You can change how math problems are presented to the model at runtime:

In [None]:
# Experiment with different prompt formats for math problems
test_problem = "A store sells pencils for $0.50 each. If Sarah buys 8 pencils, how much does she pay?"

print("Testing different prompt formats for math problems:\n")

# Format 1: Step-by-step instruction
print("1. Step-by-step format:")
grpo.set_generation_prompt(
    system_prompt="Solve this math problem step by step. Show all your work.",
    prompt_template="Question: {prompt}\n\nStep-by-step solution:",
    response_prefix="\n"
)
print(grpo.format_prompt(test_problem))
print("\n" + "="*50 + "\n")

# Format 2: Chain-of-thought prompting
print("2. Chain-of-thought format:")
grpo.set_generation_prompt(
    system_prompt="Think through this problem carefully.",
    prompt_template="{prompt}\n\nLet's think step by step:",
    response_prefix=" "
)
print(grpo.format_prompt(test_problem))
print("\n" + "="*50 + "\n")

# Format 3: Structured math format
print("3. Structured format:")
grpo.set_generation_prompt(
    system_prompt=None,  # No system prompt
    prompt_template="Math Problem:\n{prompt}\n\nGiven information:\n- \n\nCalculation:\n",
    response_prefix=""
)
print(grpo.format_prompt(test_problem))
print("\n" + "="*50 + "\n")

# Reset to original configuration for training
print("Resetting to training configuration...")
grpo.set_generation_prompt(
    system_prompt=grpo_config["generation"]["system_prompt"],
    prompt_template=grpo_config["generation"]["prompt_template"],
    response_prefix=grpo_config["generation"]["response_prefix"]
)
print("Ready for training!")

## 8. Training Loop with Metrics Tracking

In [None]:
# Training metrics storage
training_metrics = {
    "episode": [],
    "total_loss": [],
    "pg_loss": [],
    "kl_div": [],
    "mean_reward": [],
    "reward_std": [],
}

# Create checkpoint directory
checkpoint_dir = Path("checkpoints/grpo_qwen_math")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print(f"Starting GRPO training for {grpo_config['training']['num_episodes']} episodes...")
print("Training on math problems from GSM8K dataset")
print("=" * 50)

# Training loop
for episode in range(grpo_config['training']['num_episodes']):
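    # Sample math problems: each prompt yields group_size completions,
    # so batch_size completions correspond to batch_size // group_size prompts
    num_prompts = grpo_config['training']['batch_size'] // grpo_config['algorithm']['group_size']
    batch_indices = np.random.choice(len(math_prompts), num_prompts, replace=True)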
    batch_prompts = [math_prompts[i] for i in batch_indices]
    
    # Prepare batch data (prompts will be formatted internally by GRPO)
    batch_data = {"prompts": batch_prompts}
    
    # Training step
    metrics = grpo.train_step(batch_data)
    
    # Store metrics as plain floats so they can be serialized to JSON later
    training_metrics["episode"].append(episode)
    training_metrics["total_loss"].append(float(metrics["total_loss"]))
    training_metrics["pg_loss"].append(float(metrics["pg_loss"]))
    training_metrics["kl_div"].append(float(metrics["kl_div"]))
    training_metrics["mean_reward"].append(float(metrics["mean_reward"]))
    training_metrics["reward_std"].append(float(metrics.get("reward_std", 0.0)))
    
    # Logging
    if episode % grpo_config['logging']['log_interval'] == 0:
        print(f"Episode {episode:3d} | "
              f"Loss: {metrics['total_loss']:7.4f} | "
              f"PG Loss: {metrics['pg_loss']:7.4f} | "
              f"KL: {metrics['kl_div']:7.4f} | "
              f"Reward: {metrics['mean_reward']:6.3f} ± {metrics.get('reward_std', 0.0):5.3f}")
    
    # Save checkpoint
    if (episode + 1) % grpo_config['logging']['save_interval'] == 0:
        checkpoint_path = checkpoint_dir / f"checkpoint_episode_{episode+1}.pt"
        grpo.save_checkpoint(str(checkpoint_path))
        print(f"  → Saved checkpoint to {checkpoint_path}")

print("=" * 50)
print("Training complete!")
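
The `pg_loss` and `kl_div` values logged above come from simple_rl's objective, which is not shown here. For intuition, this is a per-token sketch of the clipped GRPO objective as it is commonly written: a PPO-style probability ratio clipped to `[1 - clip_range, 1 + clip_range]`, plus `kl_coef` times the non-negative KL estimator `exp(ref - new) - (ref - new) - 1`. Whether the library uses this particular estimator and sign convention is an assumption.

```python
import math

def grpo_token_loss(logp_new, logp_old, logp_ref, advantage, clip_range=0.2, kl_coef=0.05):
    """Per-token GRPO loss (to be minimized) under the assumptions stated above.

    logp_*: log-probability of the sampled token under the current policy,
    the policy that generated the sample, and the frozen reference model.
    """
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
    # Pessimistic (min) surrogate, as in PPO
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # Unbiased, always non-negative estimator of KL(policy || reference)
    log_ratio_ref = logp_ref - logp_new
    kl = math.exp(log_ratio_ref) - log_ratio_ref - 1.0
    return -surrogate + kl_coef * kl

# On-policy and identical to the reference: the loss is just -advantage
print(grpo_token_loss(-1.0, -1.0, -1.0, advantage=1.0))  # -1.0
```

Clipping caps how far a single update can push the ratio for tokens with positive advantage, while the KL term grows as the policy drifts from the reference, which is what `kl_coef` trades off against reward.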

## 9. Visualize Training Metrics

In [None]:
# Plot training metrics
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
fig.suptitle('GRPO Training Metrics', fontsize=16)

# Total Loss
axes[0, 0].plot(training_metrics["episode"], training_metrics["total_loss"], 'b-', alpha=0.7)
axes[0, 0].set_xlabel('Episode')
axes[0, 0].set_ylabel('Total Loss')
axes[0, 0].set_title('Total Loss over Training')
axes[0, 0].grid(True, alpha=0.3)

# Policy Gradient Loss
axes[0, 1].plot(training_metrics["episode"], training_metrics["pg_loss"], 'g-', alpha=0.7)
axes[0, 1].set_xlabel('Episode')
axes[0, 1].set_ylabel('PG Loss')
axes[0, 1].set_title('Policy Gradient Loss')
axes[0, 1].grid(True, alpha=0.3)

# KL Divergence
axes[1, 0].plot(training_metrics["episode"], training_metrics["kl_div"], 'r-', alpha=0.7)
axes[1, 0].set_xlabel('Episode')
axes[1, 0].set_ylabel('KL Divergence')
axes[1, 0].set_title('KL Divergence from Reference')
axes[1, 0].grid(True, alpha=0.3)

# Mean Reward
axes[1, 1].plot(training_metrics["episode"], training_metrics["mean_reward"], 'purple', alpha=0.7, label='Mean')
axes[1, 1].fill_between(
    training_metrics["episode"],
    np.array(training_metrics["mean_reward"]) - np.array(training_metrics["reward_std"]),
    np.array(training_metrics["mean_reward"]) + np.array(training_metrics["reward_std"]),
    alpha=0.3, color='purple'
)
axes[1, 1].set_xlabel('Episode')
axes[1, 1].set_ylabel('Reward')
axes[1, 1].set_title('Mean Reward ± Std')
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].legend()

plt.tight_layout()
plt.show()

# Print summary statistics
print("\nTraining Summary:")
print(f"Final Total Loss: {training_metrics['total_loss'][-1]:.4f}")
print(f"Final KL Divergence: {training_metrics['kl_div'][-1]:.4f}")
print(f"Final Mean Reward: {training_metrics['mean_reward'][-1]:.3f}")
print(f"Average Reward (last 5 episodes): {np.mean(training_metrics['mean_reward'][-5:]):.3f}")
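
With only two prompts per step, the per-episode reward curve is noisy. A trailing moving average, a presentation choice rather than something simple_rl computes, makes trends easier to see; the default window of 5 is arbitrary.

```python
def moving_average(values, window=5):
    """Trailing mean over up to `window` most recent values (same length as input)."""
    if window < 1:
        raise ValueError("window must be >= 1")
    smoothed, running_sum = [], 0.0
    for i, v in enumerate(values):
        running_sum += v
        if i >= window:
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed

print(moving_average([1.0, 2.0, 3.0, 4.0], window=2))  # [1.0, 1.5, 2.5, 3.5]
```

To overlay it on the reward panel, plot `moving_average(training_metrics["mean_reward"])` against `training_metrics["episode"]` on `axes[1, 1]`.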

## 10. Generate Sample Outputs

In [None]:
# Test the trained model with math problems
test_math_problems = [
    "If a box contains 12 cookies and you eat 3, how many are left?",
    "John has 5 apples. Mary gives him 7 more apples. How many apples does John have now?",
    "A shirt costs $15 and pants cost $25. What is the total cost?",
    "If you have 20 candies and share them equally among 4 friends, how many does each friend get?",
]

# Store test answers for reward calculation
test_answers = ["9", "12", "40", "5"]
for prob, ans in zip(test_math_problems, test_answers):
    PROBLEM_ANSWERS[prob] = ans

print("Testing trained model on new math problems:")
print("=" * 50)

for i, problem in enumerate(test_math_problems, 1):
    print(f"\nProblem {i}: {problem}")
    print("-" * 40)
    
    # Format the prompt using GRPO's formatter
    formatted_prompt = grpo.format_prompt(problem)
    
    # Tokenize prompt
    inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=256)
    inputs = {k: v.to(grpo.model.device) for k, v in inputs.items()}
    
    # Generate completion
    with torch.no_grad():
        outputs = grpo.model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens; generate() returns prompt + completion
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    
    print(f"Model's solution: {response[:200]}...")  # Show first 200 chars
    
    # Extract answer and check correctness
    model_answer = extract_number(response)
    correct_answer = test_answers[i-1]
    
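    if model_answer:
        print(f"Extracted answer: {model_answer}")
        print(f"Correct answer: {correct_answer}")
        # Compare numerically so "12.0" and "12" count as the same answer
        try:
            is_correct = abs(float(model_answer) - float(correct_answer)) < 0.01
        except ValueError:
            is_correct = model_answer == correct_answer
        print("✓ CORRECT!" if is_correct else "✗ INCORRECT")
    else:
        print("Could not extract numerical answer")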
    
    # Compute reward
    reward = compute_math_reward(problem, response)
    print(f"Reward Score: {reward:.3f}")

print("\n" + "=" * 50)
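
Four hand-written problems with one sampled completion each say little about accuracy. A sturdier check is to score more problems (for instance a slice of GSM8K's `test` split) and compare answers numerically rather than as strings. `answers_match` and `accuracy` are hypothetical helpers; the `0.01` tolerance mirrors the one in `compute_math_reward`.

```python
from typing import Optional, Sequence

def answers_match(predicted: Optional[str], reference: str, tol: float = 0.01) -> bool:
    """Numeric comparison with a string fallback for non-numeric answers."""
    if predicted is None:
        return False
    try:
        return abs(float(predicted.replace(",", "")) - float(reference.replace(",", ""))) < tol
    except ValueError:
        return predicted.strip() == reference.strip()

def accuracy(predictions: Sequence[Optional[str]], references: Sequence[str]) -> float:
    """Fraction of predictions that match their reference answers."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    return sum(answers_match(p, r) for p, r in zip(predictions, references)) / len(references)

print(accuracy(["9", "12.0", None, "4"], ["9", "12", "40", "5"]))  # 0.5
```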

## 11. Save Final Model

In [None]:
# Save the final trained model
final_model_path = "models/grpo_qwen_trained"
os.makedirs(final_model_path, exist_ok=True)

print(f"Saving final model to {final_model_path}...")

# Save model and tokenizer
grpo.model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)

# Save training config and metrics
with open(f"{final_model_path}/training_config.json", "w") as f:
    json.dump(grpo_config, f, indent=2)

with open(f"{final_model_path}/training_metrics.json", "w") as f:
    json.dump(training_metrics, f, indent=2)

print("Model saved successfully!")
print(f"\nTo load the model later:")
print(f"model = AutoModelForCausalLM.from_pretrained('{final_model_path}')")
print(f"tokenizer = AutoTokenizer.from_pretrained('{final_model_path}')")

## 12. Experiment with Different Parameters

You can modify the GRPO parameters to see how they affect training:

In [None]:
# Experimental configurations to try
experiments = {
    "High KL Penalty": {
        "algorithm": {"kl_coef": 0.2},  # Stronger KL penalty
    },
    "Large Group Size": {
        "algorithm": {"group_size": 8},  # More completions per prompt
        "training": {"batch_size": 16},  # Adjust batch size accordingly
    },
    "No Clipping": {
        "algorithm": {"clip_range": None},  # Disable PPO clipping
    },
    "High Temperature": {
        "training": {"temperature": 1.2},  # More diverse generations
    },
}

print("Experimental configurations available:")
for name, params in experiments.items():
    print(f"\n{name}:")
    print(json.dumps(params, indent=2))

print("\nTo use an experimental config, modify grpo_config in section 5 and re-run from section 3 so training starts from freshly loaded pretrained weights.")
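
The experiment dictionaries list only the keys that change. A small recursive merge, a hypothetical helper not provided by simple_rl, can apply one of them to a copy of the base configuration without modifying the original.

```python
import copy

def apply_overrides(base: dict, overrides: dict) -> dict:
    """Return a deep copy of `base` with nested `overrides` merged in."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = apply_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged

base = {"algorithm": {"group_size": 4, "kl_coef": 0.05}, "training": {"batch_size": 8}}
large_groups = apply_overrides(base, {"algorithm": {"group_size": 8}, "training": {"batch_size": 16}})
print(large_groups["algorithm"])  # {'group_size': 8, 'kl_coef': 0.05}
print(base["algorithm"])          # {'group_size': 4, 'kl_coef': 0.05}
```

In this notebook the call would be `apply_overrides(grpo_config, experiments["Large Group Size"])`, with the model reloaded so the new run does not continue from already-trained weights.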

## Summary

This notebook demonstrated GRPO training for mathematical reasoning:

### Key Features Implemented:
1. **Math Dataset**: Used GSM8K dataset with grade school math problems
2. **Custom Prompt Formatting**: 
   - System prompts for math tutoring context
   - Structured problem presentation
   - Dynamic prompt format switching
3. **Math-Specific Reward Function**:
   - Rewards correct numerical answers (highest weight)
   - Rewards showing mathematical work/steps
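   - Penalizes wrong or missing answers and very short, overly long, or repetitive outputs
4. **Training Loop**: Sampled GSM8K problems each episode, ran GRPO updates, and logged loss, KL divergence, and reward
5. **Evaluation**: Sampled solutions to four new hand-written problems and checked the extracted answers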

### Prompt Customization Features:
- `format_prompt()`: Applies system prompt, template, and prefix
- `set_generation_prompt()`: Dynamically change formats during runtime
- `use_formatting` parameter: Control when formatting is applied
- Multiple format styles: Step-by-step, Chain-of-thought, Structured

### Next Steps:
- Scale to larger math datasets (full GSM8K, MATH dataset)
- Implement more sophisticated answer extraction
- Add curriculum learning (easy to hard problems)
- Use a dedicated math reward model
- Fine-tune on specific math domains (algebra, geometry, etc.)
- Experiment with different prompt formats for better performance