# Advanced GRPO Fine-tuning for Mathematical Reasoning with Multi-Reward Training

_Authored by: [Behrooz Azarkhalili](https://github.com/behroozazarkhalili)_

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/behroozazarkhalili/GRPO-Qwen-Finetuning-Unsloth/blob/master/TRL_GRPO_Reasoning.ipynb)

## 🔗 Related Cookbook Examples

This notebook builds upon the existing [Post training an LLM for reasoning with GRPO in TRL](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) example. While the basic GRPO cookbook demonstrates the fundamentals of GRPO training, this advanced tutorial focuses on:

### Key Differences from the Basic GRPO Example:

| **Basic GRPO Example** | **This Advanced Example** |
|-------------------------|---------------------------|
| Single reward function | **Multiple reward functions** (4 different rewards) |
| Basic format checking | **Advanced format validation** with regex patterns |
| Simple training setup | **Interactive training dashboard** with real-time metrics |
| Standard evaluation | **Comprehensive testing** with detailed error analysis |
| Basic dataset processing | **Advanced dataset preparation** with custom tokenization |

## 🎯 What You'll Learn

This comprehensive tutorial demonstrates **advanced GRPO (Group Relative Policy Optimization)** techniques for mathematical reasoning, specifically:

1. **Multi-Reward System Design**: Learn to create and balance multiple reward functions for complex tasks
2. **Advanced Format Validation**: Implement sophisticated regex-based reward functions for structured outputs
3. **Interactive Training Monitoring**: Build real-time training dashboards with metrics visualization
4. **Production-Ready Implementation**: Design robust training pipelines with comprehensive error handling
5. **Performance Optimization**: Use memory-efficient techniques like 4-bit quantization and LoRA adapters

## 🧠 Deep Dive: What is GRPO?

**Group Relative Policy Optimization (GRPO)** is a sophisticated reinforcement learning technique that improves upon traditional policy optimization methods by:

### Core Principles:
- **Relative Comparison**: Unlike absolute reward systems, GRPO compares responses within groups, making training more stable and robust
- **Group-Based Learning**: Multiple candidate responses are generated and ranked relatively, reducing variance in training
- **Multi-Objective Optimization**: Can handle multiple reward signals simultaneously, perfect for complex tasks like mathematical reasoning

### Why GRPO for Mathematical Reasoning?
Mathematical reasoning requires:
1. **Format Adherence**: Solutions must follow specific structural patterns
2. **Logical Consistency**: Step-by-step reasoning must be coherent
3. **Accuracy**: Final answers must be mathematically correct
4. **Clarity**: Explanations should be understandable

GRPO excels at this because it can optimize for all these criteria simultaneously through multiple reward functions.

## 🎯 Advanced Features of This Implementation

### 🔥 Multi-Reward Training System
This notebook implements **4 specialized reward functions**:

1. **`match_format_exactly`**: Ensures perfect adherence to the required output format
2. **`match_format_approximately`**: Provides partial rewards for near-correct formatting
3. **`check_answer_correctness`**: Validates mathematical accuracy with fuzzy matching
4. **`check_numbers_extraction`**: Verifies numerical answer extraction and parsing

### 📊 Interactive Training Dashboard
Features a **HuggingFace-style training table** that displays:
- Real-time loss and reward metrics
- Running averages and standard deviations
- Best reward tracking across training steps
- Gradient norms and KL divergence monitoring
- JSON logging for complete training history

### 🧠 Structured Reasoning Format
The model learns to generate responses in this specific format:
```
<start_working_out>
[Step-by-step mathematical reasoning]
<end_working_out>

<SOLUTION>
[Final numerical answer]
</SOLUTION>
```

This structure ensures that the model not only provides correct answers but also shows its reasoning process, making it interpretable and trustworthy.

## 📚 Dataset: GSM8K Mathematical Reasoning

We use the **GSM8K** (Grade School Math 8K) dataset - a carefully curated collection of:
- **8,000+ mathematical word problems** requiring multi-step reasoning
- **Linguistically diverse problems** covering various mathematical concepts
- **Step-by-step solutions** that require logical reasoning chains
- **Numerical answers** that can be objectively evaluated

The dataset is perfect for GRPO training because it provides clear success criteria (correct numerical answers) while requiring complex reasoning processes that benefit from multi-reward optimization.

## 🛠️ Technical Implementation Details

### Memory Efficiency
- **4-bit Quantization**: Using NF4 quantization for significant memory savings
- **LoRA Adapters**: Parameter-efficient fine-tuning targeting specific attention layers
- **Gradient Checkpointing**: Reduced memory usage during backpropagation
- **Optimized Batch Sizes**: Configured for consumer GPUs (16-24GB VRAM)

### Training Stability
- **Advanced Logging**: Comprehensive metrics tracking with automated visualization
- **Error Handling**: Robust error handling for malformed outputs and edge cases
- **Gradient Clipping**: Preventing gradient explosion in RL training
- **Learning Rate Scheduling**: Cosine annealing for stable convergence

Let's dive into the implementation! 🚀

## Installation

In [None]:
# Install required packages
!pip install transformers datasets trl bitsandbytes peft

## GPU Setup and Environment Configuration

**Colab Compatibility**: This notebook automatically detects and configures the available GPU environment. In Colab, typically only one GPU is available, so we ensure optimal single-GPU training setup.

In [1]:
# GPU Setup - Colab Compatible
import torch

# Auto-detect GPU configuration for Colab compatibility
# In Colab, there's typically only one GPU available, so we don't need to specify CUDA_VISIBLE_DEVICES
# Only set CUDA_VISIBLE_DEVICES if you're running on a multi-GPU system and want to restrict to specific GPUs

# For multi-GPU systems (uncomment if needed):
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use first available GPU

# Verify GPU setup
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️  No GPU available. This notebook requires a GPU for efficient training.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")

CUDA available: True
Number of GPUs: 2
Current GPU: 0
GPU name: NVIDIA H100 NVL
GPU memory: 100.0 GB


## Imports and Setup

In [2]:
import torch
import re
from trl import GRPOConfig, GRPOTrainer

from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset

import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

INFO 07-29 06:16:33 [__init__.py:244] Automatically detected platform cuda.


In [3]:
# Model Configuration and Loading
# This section demonstrates loading a quantized model with LoRA for memory-efficient training

model_name = "Qwen/Qwen2.5-3B-Instruct"  # Compact model suitable for Colab/consumer GPUs
# Alternative models you can try:
# model_name = "microsoft/DialoGPT-small"  # Even smaller for very limited memory
# model_name = "google/gemma-2b"           # Google's efficient 2B parameter model

max_seq_length = 2048  # Sequence length - reduce if you encounter memory issues

# 4-bit Quantization Configuration
# This reduces memory usage by ~75% with minimal performance loss
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                    # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",           # Use NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.float16, # Use float16 for computations
    bnb_4bit_use_double_quant=True,      # Double quantization for additional memory savings
)

print(f"Loading model: {model_name}")
print("Using 4-bit quantization for memory efficiency...")

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # Automatically distribute model across available devices
    trust_remote_code=True,
    torch_dtype=torch.float16,  # Use float16 to save memory
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

# Ensure tokenizer has required tokens
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"✅ Model loaded successfully!")
print(f"📍 Model device: {next(model.parameters()).device}")
print(f"📊 Tokenizer vocab size: {len(tokenizer):,}")
print(f"🔢 Model parameters: ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# Check GPU memory usage
if torch.cuda.is_available():
    print(f"🔥 GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"📈 GPU memory cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

Loading model: Qwen/Qwen2.5-3B-Instruct
Using 4-bit quantization for memory efficiency...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Model loaded successfully!
📍 Model device: cuda:0
📊 Tokenizer vocab size: 151,665
🔢 Model parameters: ~1698.7M
🔥 GPU memory allocated: 0.90 GB
📈 GPU memory cached: 1.17 GB


## LoRA Configuration

**LoRA (Low-Rank Adaptation)** is a parameter-efficient fine-tuning technique that:
- Only trains ~0.1% of the model parameters (dramatically reducing memory requirements)
- Achieves performance comparable to full fine-tuning
- Allows easy switching between different task-specific adaptations

**Key LoRA Parameters:**
- `r`: Rank determines adaptation capacity (higher = more parameters but better adaptation)
- `alpha`: Scaling factor that controls the magnitude of LoRA updates
- `target_modules`: Which attention layers to adapt (query and value projections work best)
- `dropout`: Regularization to prevent overfitting

In [4]:
# LoRA Configuration for Mathematical Reasoning
# Optimized for mathematical reasoning tasks that require precise attention patterns

lora_config = LoraConfig(
    r=16,                              # Rank: Balance between efficiency and adaptation capacity
    lora_alpha=32,                     # Alpha: 2x rank is a good starting point
    target_modules=["q_proj", "v_proj"], # Target query and value projections for attention
    lora_dropout=0.1,                  # Dropout for regularization
    bias="none",                       # No bias adaptation to keep it simple
    task_type=TaskType.CAUSAL_LM,      # Causal language modeling task
)

print("Applying LoRA configuration to model...")

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# Display trainable parameters
print("📊 LoRA Training Parameters Summary:")
model.print_trainable_parameters()

# Calculate the percentage of trainable parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
percentage = (trainable_params / total_params) * 100

print(f"\n🎯 Training efficiency:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Percentage trainable: {percentage:.3f}%")
print(f"   Memory reduction: ~{100-percentage:.1f}% compared to full fine-tuning")

Applying LoRA configuration to model...
📊 LoRA Training Parameters Summary:
trainable params: 3,686,400 || all params: 3,089,625,088 || trainable%: 0.1193

🎯 Training efficiency:
   Total parameters: 1,702,359,040
   Trainable parameters: 3,686,400
   Percentage trainable: 0.217%
   Memory reduction: ~99.8% compared to full fine-tuning


## Dataset Preparation: GSM8K Mathematical Reasoning

**About GSM8K**: The Grade School Math 8K dataset contains 8,500+ mathematical word problems that require multi-step reasoning. Each problem:
- Tests elementary mathematical concepts (addition, subtraction, multiplication, division)
- Requires 2-8 reasoning steps to solve
- Has a clear numerical answer that can be objectively evaluated
- Includes step-by-step solution explanations

**Our Structured Format**: We train the model to generate responses with clear reasoning sections:
1. `<start_working_out>` ... `<end_working_out>`: Step-by-step mathematical reasoning
2. `<SOLUTION>` ... `</SOLUTION>`: Final numerical answer

This structure ensures the model shows its work and provides interpretable solutions.

In [5]:
# Dataset Loading and Processing
# We structure the dataset to promote clear mathematical reasoning

# Define reasoning format tokens for structured output
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# System prompt that teaches the model our desired reasoning format
system_prompt = f"""You are a mathematical reasoning assistant.
When given a math problem:
1. Show your step-by-step work between {reasoning_start} and {reasoning_end}
2. Provide your final numerical answer between {solution_start} and {solution_end}
3. Be precise and show all calculation steps clearly."""

def extract_hash_answer(text):
    """Extract numerical answer from GSM8K format (after #### marker)"""
    if "####" not in text:
        return None
    # GSM8K answers are formatted as "#### 42" 
    return text.split("####")[1].strip()

def process_dataset_example(example):
    """Convert GSM8K example to our training format"""
    question = example["question"]
    answer = extract_hash_answer(example["answer"])
    
    # Create conversation format with system prompt
    prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    
    return {
        "prompt": prompt,
        "answer": answer,  # Ground truth numerical answer for evaluation
    }

# Load GSM8K dataset
print("🔄 Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", "main", split="train")

# Process the dataset for our training format
dataset = dataset.map(process_dataset_example)

print(f"✅ Dataset loaded and processed!")
print(f"📊 Training examples: {len(dataset):,}")
print(f"🎯 Sample question: {dataset[0]['prompt'][1]['content'][:100]}...")
print(f"🎯 Sample answer: {dataset[0]['answer']}")

# Show a complete example
print(f"\n📋 Complete example format:")
print(f"Question: {dataset[0]['prompt'][1]['content']}")
print(f"Expected Answer: {dataset[0]['answer']}")
print(f"System Prompt: {system_prompt}")

🔄 Loading GSM8K dataset...


Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

✅ Dataset loaded and processed!
📊 Training examples: 7,473
🎯 Sample question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How m...
🎯 Sample answer: 72

📋 Complete example format:
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Expected Answer: 72
System Prompt: You are a mathematical reasoning assistant.
When given a math problem:
1. Show your step-by-step work between <start_working_out> and <end_working_out>
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>
3. Be precise and show all calculation steps clearly.


## Multi-Reward System: The Heart of Advanced GRPO

**Why Multiple Rewards?** Traditional RL uses a single reward signal, but mathematical reasoning requires multiple criteria:

1. **Format Compliance**: Does the response follow the required structure?
2. **Mathematical Accuracy**: Is the final answer numerically correct?
3. **Reasoning Quality**: Are the intermediate steps logical and clear?
4. **Robustness**: Can the model handle edge cases and variations?

**Our 4-Reward System:**
- `match_format_exactly`: High reward (3.0) for perfect format adherence
- `match_format_approximately`: Partial rewards (±0.5) for near-correct formatting  
- `check_answer_correctness`: Graduated rewards (3.0 → 1.5 → 0.5 → -1.0) based on answer quality
- `check_numbers_extraction`: Binary reward (1.5 or 0.0) for successful number extraction

This multi-objective approach ensures the model learns both **structure** and **content** simultaneously.

In [6]:
# Regex patterns for reward functions
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL
)

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags=re.MULTILINE | re.DOTALL
)

def match_format_exactly(completions, **kwargs):
    """Reward for exact format matching"""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 3.0 if match_format.search(response) is not None else 0.0
        scores.append(score)
    return scores

def match_format_approximately(completions, **kwargs):
    """Reward for approximate format matching"""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        
        # Count occurrences of format tokens
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end) == 1 else -0.5
        score += 0.5 if response.count(solution_start) == 1 else -0.5
        score += 0.5 if response.count(solution_end) == 1 else -0.5
        
        scores.append(score)
    return scores

def check_answer_correctness(prompts, completions, answer, **kwargs):
    """Reward for correct answers"""
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]
    
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
            
        # Exact match gets full points
        if guess == true_answer:
            scores.append(3.0)
        # Strip whitespace and try again
        elif guess.strip() == true_answer.strip():
            scores.append(1.5)
        else:
            # Try numerical comparison
            try:
                ratio = float(guess) / float(true_answer)
                if 0.9 <= ratio <= 1.1:
                    scores.append(0.5)
                elif 0.8 <= ratio <= 1.2:
                    scores.append(0.25)
                else:
                    scores.append(-1.0)
            except:
                scores.append(-0.5)
    
    return scores

def check_numbers_extraction(prompts, completions, answer, **kwargs):
    """Reward for extracting numbers from solution"""
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]
    
    scores = []
    print('*' * 20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
            
        try:
            true_answer = float(true_answer.strip())
            guess = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except:
            scores.append(0)
    
    return scores

print("Reward functions defined")

Reward functions defined


## Training Configuration: Optimized for Mathematical Reasoning

**GRPO Training Parameters**: These settings are specifically tuned for mathematical reasoning tasks:

- **Learning Rate (5e-6)**: Conservative rate prevents overfitting on mathematical patterns
- **Batch Size Strategy**: Small per-device batches (2) with high gradient accumulation (8) for memory efficiency
- **Sequence Lengths**: Balanced prompt (1024) and completion (1024) lengths for reasoning tasks
- **Optimizer**: AdamW with weight decay for stable convergence
- **Scheduler**: Cosine annealing for smooth learning rate decay

**Memory Optimization for Colab/Consumer GPUs**:
- Gradient checkpointing reduces memory by ~50%
- Mixed precision training (float16) saves additional memory
- Small batch sizes prevent OOM errors on 16GB GPUs

In [None]:
# GRPO Training Configuration
# Optimized for mathematical reasoning with memory-efficient settings

training_args = GRPOConfig(
    # Core learning parameters
    learning_rate=5e-6,              # Conservative LR for stable mathematical reasoning
    adam_beta1=0.9,                  # Adam optimizer beta1
    adam_beta2=0.99,                 # Adam optimizer beta2 
    weight_decay=0.1,                # Regularization to prevent overfitting
    warmup_ratio=0.1,                # 10% warmup for stable training start
    lr_scheduler_type="cosine",      # Cosine annealing for smooth decay
    optim="adamw_torch_fused",       # Fused AdamW for better performance
    
    # Memory and batch size settings (Colab-optimized)
    per_device_train_batch_size=2,   # Small batch size for memory efficiency
    gradient_accumulation_steps=8,   # Maintain effective batch size of 16
    max_prompt_length=1024,          # Reasonable prompt length for math problems
    max_completion_length=1024,      # Sufficient space for detailed reasoning
    
    # Training duration and checkpointing
    max_steps=100,                    # Short training for demo (increase for production)
    save_steps=100,                   # Save checkpoint every 100 steps
    eval_steps=100,                    # Evaluate every 100 steps for detailed monitoring

    # Stability and performance
    max_grad_norm=0.1,               # Gradient clipping for training stability
    dataloader_drop_last=True,       # Consistent batch sizes
    
    # Logging and output
    logging_steps=1,                 # Log every step for detailed monitoring
    output_dir="./trl_grpo_outputs", # Output directory for checkpoints
    logging_dir="./logs",            # Directory for tensorboard logs
    report_to="none",                # Disable external reporting for simplicity
    log_level="info",                # Detailed logging
    logging_first_step=True,         # Log the first step
    logging_nan_inf_filter=True,     # Filter out NaN/inf values in logs
    
    # Evaluation settings
    metric_for_best_model="reward",  # Use reward as the primary metric
    greater_is_better=True,          # Higher rewards are better
    disable_tqdm=False,              # Keep progress bar enabled
)

# Display configuration summary
print("🎯 GRPO Training Configuration Summary:")
print(f"   Learning rate: {training_args.learning_rate}")
print(f"   Per-device batch size: {training_args.per_device_train_batch_size}")
print(f"   Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"   Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"   Max training steps: {training_args.max_steps}")
print(f"   Max prompt length: {training_args.max_prompt_length}")
print(f"   Max completion length: {training_args.max_completion_length}")
print(f"   Optimizer: {training_args.optim}")
print(f"   LR scheduler: {training_args.lr_scheduler_type}")
print(f"   Weight decay: {training_args.weight_decay}")

# Memory estimation
print(f"\n💾 Memory Optimization:")
print(f"   Gradient checkpointing: Enabled (in model config)")
print(f"   Mixed precision: float16 (from quantization)")
print(f"   Expected GPU memory usage: ~8-12GB")

🎯 GRPO Training Configuration Summary:
   Learning rate: 5e-06
   Per-device batch size: 2
   Gradient accumulation: 8
   Effective batch size: 16
   Max training steps: 10
   Max prompt length: 1024
   Max completion length: 1024
   Optimizer: OptimizerNames.ADAMW_TORCH_FUSED
   LR scheduler: SchedulerType.COSINE
   Weight decay: 0.1

💾 Memory Optimization:
   Gradient checkpointing: Enabled (in model config)
   Mixed precision: float16 (from quantization)
   Expected GPU memory usage: ~8-12GB


## Initialize Trainer

In [8]:
from transformers.trainer_callback import TrainerCallback
import time
import json
import os
from collections import defaultdict, deque
import pandas as pd
from IPython.display import display, HTML, clear_output
import io
import sys

class HuggingFaceStyleTableCallback(TrainerCallback):
    """Callback that displays a HuggingFace-style interactive table that updates in-place below progress bar"""
    
    def __init__(self, log_file="training_logs.json"):
        self.recent_rewards = deque(maxlen=5)
        self.best_reward = float('-inf')
        self.log_file = log_file
        self.all_logs = []
        self.metrics_data = []
        self.table_displayed = False
        
    def on_train_begin(self, args, state, control, **kwargs):
        """Initialize logging"""
        self.all_logs = []
        self.metrics_data = []
        self.table_displayed = False
        print(f"Training started - HuggingFace-style interactive table will appear below progress bar")
        print(f"All logs saved to: {self.log_file}")
        
    def _display_hf_style_table(self):
        """Display HuggingFace-style table that updates in-place"""
        if not self.metrics_data:
            return
            
        # Create DataFrame from metrics data
        df = pd.DataFrame(self.metrics_data)
        
        # Style similar to HuggingFace trainer tables
        styled_html = f"""
        <div style="margin: 10px 0;">
        <table style="border-collapse: collapse; margin: auto; width: 100%; max-width: 1000px;">
        <thead>
        <tr style="border-bottom: 2px solid #dee2e6;">
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Step</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Training-Loss</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward-Avg</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward-Std</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward-Best</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Grad-Norm</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">KL-Div</th>
        </tr>
        </thead>
        <tbody>
        """
        
        for _, row in df.iterrows():
            styled_html += f"""
            <tr style="{''}">
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{int(row['Step'])}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Training-Loss']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward-Avg']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward-Std']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward-Best']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Grad-Norm']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['KL-Div']:.6f}</td>
            </tr>
            """
        
        styled_html += """
        </tbody>
        </table>
        </div>
        """
        
        # Use clear_output to update in-place like HuggingFace trainers do
        if self.table_displayed:
            clear_output(wait=True)
        
        print("TRAINING METRICS:")
        display(HTML(styled_html))
        self.table_displayed = True
        
    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics and update interactive table in-place"""
        if logs is None:
            return
            
        current_step = state.global_step
        
        # Save complete logs to JSON file with Unix timestamp
        log_entry = {
            "step": current_step,
            "timestamp": time.time(),
            **logs
        }
        self.all_logs.append(log_entry)
        
        # Write to JSON file after each step
        try:
            with open(self.log_file, 'w') as f:
                json.dump(self.all_logs, f, indent=2)
        except Exception as e:
            pass  # Silent fail
        
        # Extract metrics from logs
        reward = logs.get('reward', 0.0)
        reward_std = logs.get('reward_std', 0.0)
        loss = logs.get('loss', 0.0)
        kl_div = logs.get('kl', 0.0)
        grad_norm = logs.get('grad_norm', 0.0)
        
        # Track enhanced metrics
        if reward != 0:
            self.recent_rewards.append(reward)
            if reward > self.best_reward:
                self.best_reward = reward
                
        reward_avg = sum(self.recent_rewards) / len(self.recent_rewards) if self.recent_rewards else 0.0
        
        # Add new row to metrics data
        new_row = {
            'Step': current_step,
            'Training-Loss': loss,
            'Reward': reward,
            'Reward-Avg': reward_avg,
            'Reward-Std': reward_std,
            'Reward-Best': self.best_reward,
            'Grad-Norm': grad_norm,
            'KL-Div': kl_div,
        }
        self.metrics_data.append(new_row)
        
        # Update the table in-place
        self._display_hf_style_table()

In [9]:
# Initialize GRPO trainer with HuggingFace-style interactive table callback
hf_table_callback = HuggingFaceStyleTableCallback()

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer_correctness,
        check_numbers_extraction,
    ],
    args=training_args,
    train_dataset=dataset,
    callbacks=[hf_table_callback],  # Add HuggingFace-style table callback
)

You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set to `True` to avoid any unexpected behavior such as device placement mismatching.
max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [10]:
# # Initialize GRPO trainer without HuggingFace-style table callback
# Uncomment the following lines if you want to use the original GRPOTrainer without the interactive table

# trainer = GRPOTrainer(
#     model=model,
#     processing_class=tokenizer,
#     reward_funcs=[
#         match_format_exactly,
#         match_format_approximately,
#         check_answer_correctness,
#         check_numbers_extraction,
#     ],
#     args=training_args,
#     train_dataset=dataset,
# )

# print("Trainer initialized!")
# print(f"Model device: {model.device}")
# print(f"Number of training examples: {len(dataset)}")

## Start Training: Interactive GRPO with Real-Time Metrics

**What to Expect During Training:**
1. **Initial Steps**: Rewards may be low/negative as the model learns the format
2. **Format Learning**: First, the model learns to use the required structure
3. **Content Improvement**: Then, mathematical accuracy gradually improves
4. **Convergence**: Both format and content rewards should increase together

**Monitoring Progress:**
- Watch the **"Reward"** column for overall performance
- **"Reward-Avg"** shows the moving average over recent steps
- **"Reward-Best"** tracks the highest reward achieved
- Low **"KL-Div"** indicates stable learning (not deviating too much from base model)

In [11]:
# Start training
print("Starting GRPO training...")
print("This may take a while. Monitor the reward column for improvements.")
print("Initial rewards might be low/negative - this is normal.")

trainer.train()

print("Training completed!")

TRAINING METRICS:


Step,Training-Loss,Reward,Reward-Avg,Reward-Std,Reward-Best,Grad-Norm,KL-Div
1,-0.0572,2.71875,2.71875,2.306226,2.71875,0.078522,0.0
2,0.0661,4.21875,3.46875,3.064589,4.21875,0.076023,0.0
3,0.0745,4.0625,3.666667,3.308233,4.21875,0.087921,0.0
4,0.0022,5.0,4.0,3.04891,5.0,0.148691,0.0
5,-0.0092,5.65625,4.33125,3.322756,5.65625,0.097224,0.0
6,-0.0278,3.53125,4.49375,3.21696,5.65625,0.104116,0.0
7,0.0096,2.9375,4.2375,3.683079,5.65625,0.096031,0.0
8,-0.0174,3.640625,4.153125,3.819517,5.65625,0.07287,0.0
9,-0.0093,3.0625,3.765625,3.002159,5.65625,0.077211,0.0
10,0.0275,3.65625,3.365625,3.295985,5.65625,0.115191,0.0


Training completed!


## Evaluate the Trained Model

**Testing Strategy**: We evaluate the model on both simple and complex mathematical problems to assess:
1. **Format Adherence**: Does it use the required reasoning structure?
2. **Mathematical Accuracy**: Are the calculations correct?
3. **Reasoning Quality**: Is the step-by-step logic sound?
4. **Robustness**: Can it handle variations in problem types?

In [16]:
# Model Testing and Evaluation
# Comprehensive testing function with proper generation parameters

def test_model(question, max_length=512):
    """
    Test the trained model on a mathematical question
    
    Args:
        question (str): Mathematical question to solve
        max_length (int): Maximum tokens to generate
        
    Returns:
        str: Model's response with reasoning and solution
    """
    # Prepare the conversation
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    
    # Format using the tokenizer's chat template
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    
    # Tokenize and move to device
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    
    # Generate response with optimized parameters
    print(f"🤔 Thinking about: {question}")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,                    # Moderate randomness for reasoning
            do_sample=True,                     # Enable sampling for diverse reasoning
            top_p=0.9,                         # Nucleus sampling for quality
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,            # Reduce repetition in reasoning
            length_penalty=1.0,                # Neutral length preference
            early_stopping=True,               # Stop at natural endpoints
        )
    
    # Decode the generated response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract only the newly generated part (after the prompt)
    generated_text = response[len(text):].strip()
    
    return generated_text

## Test with GSM8K-Style Problem

Let's test with a more complex word problem that requires multi-step reasoning, similar to the GSM8K dataset format.

In [17]:
# Test with a Complex GSM8K-Style Problem
# This tests multi-step reasoning capabilities

print("🧠 Testing Complex Mathematical Reasoning:")

# Use the actual GSM8K example from our dataset
gsm8k_question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
expected_answer = "72"

# Test the model
gsm8k_response = test_model(gsm8k_question, max_length=768)

print(f"Question: {gsm8k_question}")
print(f"\nModel Response:\n{gsm8k_response}")

# Detailed analysis
print(f"\n📊 Detailed Analysis:")
has_reasoning = reasoning_start in gsm8k_response and reasoning_end in gsm8k_response
has_solution = solution_start in gsm8k_response and solution_end in gsm8k_response

print(f"   📋 Format Analysis:")
print(f"      ✅ Contains reasoning section: {has_reasoning}")
print(f"      ✅ Contains solution section: {has_solution}")

# Extract reasoning if present
if has_reasoning:
    try:
        reasoning_text = gsm8k_response.split(reasoning_start)[1].split(reasoning_end)[0].strip()
        print(f"   🤔 Reasoning Steps:")
        print(f"      {reasoning_text[:200]}..." if len(reasoning_text) > 200 else f"      {reasoning_text}")
    except:
        print(f"   ⚠️  Could not extract reasoning cleanly")

# Extract solution if present
if has_solution:
    try:
        solution_text = gsm8k_response.split(solution_start)[1].split(solution_end)[0].strip()
        print(f"   🎯 Extracted Answer: '{solution_text}'")
        print(f"   ✅ Expected Answer: '{expected_answer}'")
        
        # Check numerical accuracy
        try:
            extracted_number = ''.join(filter(str.isdigit, solution_text))
            expected_number = ''.join(filter(str.isdigit, expected_answer))
            is_correct = extracted_number == expected_number
            print(f"   {'✅' if is_correct else '❌'} Numerical Accuracy: {is_correct}")
        except:
            print(f"   ⚠️  Could not compare numerical values")
            
    except:
        print(f"   ⚠️  Could not extract solution cleanly")


The following generation flags are not valid and may be ignored: ['early_stopping'].
- `early_stopping`: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.
If you're using a pretrained model, note that some of these attributes may be set through the model's `generation_config.json` file.


🧠 Testing Complex Mathematical Reasoning:
🤔 Thinking about: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Model Response:
_working_out>

First, we need to determine how many clips Natalia sold in May. According to the problem, Natalia sold half as many clips in May compared to April. Since she sold 48 clips in April:

- Number of clips sold in May = \(\frac{48}{2} = 24\)

Next, we add the number of clips sold in both months to find out the total number of clips sold:

- Total clips sold = Clips sold in April + Clips sold in May

Now let’s perform the addition:

- Total clips sold = 48 + 24 = 72

<end_working_out>

<SOLUTION>
72</SOLUTION>

📊 Detailed Analysis:
   📋 Format Analysis:
      ✅ Contains reasoning secti

## Memory Management

In [14]:
# Clear GPU memory if needed
import gc
torch.cuda.empty_cache()
gc.collect()
print("GPU memory cleared")

GPU memory cleared


## References

### Papers and Research
- **GRPO Algorithm**: [Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) - The original GRPO paper introducing group-based relative policy optimization
- **GSM8K Dataset**: [Training Verifiers to Solve Math Word Problems](https://arxiv.org/abs/2110.14168) - Cobbe et al., OpenAI
- **LoRA**: [Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) - Hu et al., Microsoft
- **QLoRA**: [Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) - Dettmers et al., 4-bit quantization for efficient training

### Libraries and Frameworks
- **TRL (Transformers Reinforcement Learning)**: [HuggingFace TRL](https://github.com/huggingface/trl) - Official library for RLHF and advanced training techniques
- **Transformers**: [HuggingFace Transformers](https://github.com/huggingface/transformers) - State-of-the-art NLP library
- **PEFT**: [Parameter-Efficient Fine-Tuning](https://github.com/huggingface/peft) - Efficient adaptation methods
- **BitsAndBytes**: [8-bit & 4-bit Quantization](https://github.com/TimDettmers/bitsandbytes) - Memory-efficient training

### Models Used
- **Qwen2.5-3B-Instruct**: [Qwen Model Series](https://github.com/QwenLM/Qwen2.5) - Alibaba's instruction-tuned language model
- **Alternative Models**: Gemma-2B, DialoGPT, GPT-2 (configurable in the notebook)

### Datasets
- **GSM8K**: [OpenAI GSM8K](https://huggingface.co/datasets/openai/gsm8k) - Grade School Math 8K problems dataset
- **Format**: Mathematical word problems requiring multi-step reasoning and numerical answers

### Key Concepts
- **Reinforcement Learning from Human Feedback (RLHF)**: Training language models using reward signals
- **Group Relative Policy Optimization**: Advanced RL technique comparing responses in groups rather than absolute scoring
- **Structured Generation**: Teaching models to follow specific output formats with reasoning sections
- **Multi-Reward Training**: Using multiple reward functions for comprehensive evaluation