# GRPO Fine-tuning with TRL Only

_Authored by: [Behrooz Azarkhalili](https://github.com/behroozazarkhalili)_

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/behroozazarkhalili/GRPO-Qwen-Finetuning-Unsloth/blob/master/TRL_GRPO_Reasoning.ipynb)

This notebook demonstrates **GRPO (Group Relative Policy Optimization)** fine-tuning for mathematical reasoning using TRL (Transformers Reinforcement Learning) library without Unsloth.

## What is GRPO?

GRPO is an advanced reinforcement learning technique for fine-tuning language models that:
- **Groups similar responses** and compares them relatively instead of using absolute rewards
- **Optimizes for structured outputs** by rewarding proper format adherence and correct reasoning
- **Balances exploration vs exploitation** through group-based reward comparison
- **Handles complex reward functions** like format matching, answer correctness, and reasoning quality

## Key Features of This Implementation:

🎯 **Multi-Reward Training**: Uses 4 different reward functions:
- Format matching (exact/approximate) 
- Answer correctness with fuzzy matching
- Number extraction validation
- Structured reasoning rewards

📊 **Advanced Progress Tracking**: HuggingFace-style interactive table with real-time metrics:
- Training Loss, Reward, Average Reward, Best Reward
- Reward Standard Deviation, KL Divergence, Gradient Norm
- JSON logging for complete training history

🧠 **Mathematical Reasoning**: Trains on GSM8K dataset to generate structured responses with:
- `<start_working_out>` reasoning sections `<end_working_out>`
- `<SOLUTION>` final answers `</SOLUTION>`
- Step-by-step mathematical problem solving

⚡ **Memory Efficient**: Uses 4-bit quantization, LoRA adapters, and optimized batch sizes for training on consumer GPUs.

## Dataset: GSM8K

Using the **GSM8K** (Grade School Math 8K) dataset - a collection of 8,000+ linguistically diverse grade school math word problems requiring multi-step reasoning. The model learns to:
1. Parse the problem statement
2. Generate step-by-step working
3. Provide final numerical answers
4. Follow structured output format

## Installation

In [None]:
# Install required packages
!pip install transformers datasets trl bitsandbytes peft

## GPU Setup

Force single GPU training to avoid NCCL errors

In [None]:
# Force single GPU training to avoid NCCL errors
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use only GPU 1

# Verify GPU setup
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")

## Imports and Setup

In [None]:
import torch
import re
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# Model configuration
model_name = "Qwen/Qwen2.5-3B-Instruct"  # You can change this to any model you prefer
# Alternative models:
# model_name = "microsoft/DialoGPT-small"
# model_name = "gpt2"
# model_name = "google/gemma-2b"

max_seq_length = 2048

# Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load model and tokenizer with correct device mapping
# Since CUDA_VISIBLE_DEVICES="1" is set, GPU 1 becomes device 0 from PyTorch's perspective
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},  # Use device 0 (which is actually GPU 1 due to CUDA_VISIBLE_DEVICES)
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

print(f"Model loaded: {model_name}")
print(f"Model device: {model.device}")
print(f"Tokenizer vocab size: {len(tokenizer)}")

## LoRA Configuration

In [None]:
# LoRA configuration
lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # Alpha
    target_modules=["q_proj", "v_proj"],  # Target modules - adjust based on your model
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## Dataset Preparation

In [None]:
# Reasoning format tokens
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# System prompt for reasoning
system_prompt = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""

def extract_hash_answer(text):
    """Extract answer after #### in GSM8K format"""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def process_dataset_example(example):
    """Process a single GSM8K example"""
    question = example["question"]
    answer = extract_hash_answer(example["answer"])
    
    prompt = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    
    return {
        "prompt": prompt,
        "answer": answer,
    }

# Load and process GSM8K dataset
print("Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(process_dataset_example)

print(f"Dataset loaded with {len(dataset)} examples")
print("Sample example:")
print(dataset[0])

## Reward Functions

In [None]:
# Regex patterns for reward functions
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL
)

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags=re.MULTILINE | re.DOTALL
)

def match_format_exactly(completions, **kwargs):
    """Reward for exact format matching"""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 3.0 if match_format.search(response) is not None else 0.0
        scores.append(score)
    return scores

def match_format_approximately(completions, **kwargs):
    """Reward for approximate format matching"""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        
        # Count occurrences of format tokens
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end) == 1 else -0.5
        score += 0.5 if response.count(solution_start) == 1 else -0.5
        score += 0.5 if response.count(solution_end) == 1 else -0.5
        
        scores.append(score)
    return scores

def check_answer_correctness(prompts, completions, answer, **kwargs):
    """Reward for correct answers"""
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]
    
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
            
        # Exact match gets full points
        if guess == true_answer:
            scores.append(3.0)
        # Strip whitespace and try again
        elif guess.strip() == true_answer.strip():
            scores.append(1.5)
        else:
            # Try numerical comparison
            try:
                ratio = float(guess) / float(true_answer)
                if 0.9 <= ratio <= 1.1:
                    scores.append(0.5)
                elif 0.8 <= ratio <= 1.2:
                    scores.append(0.25)
                else:
                    scores.append(-1.0)
            except:
                scores.append(-0.5)
    
    return scores

def check_numbers_extraction(prompts, completions, answer, **kwargs):
    """Reward for extracting numbers from solution"""
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]
    
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]
    
    scores = []
    print('*' * 20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
            
        try:
            true_answer = float(true_answer.strip())
            guess = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except:
            scores.append(0)
    
    return scores

print("Reward functions defined")

## Training Configuration

In [None]:
# GRPO Training configuration with enhanced logging
training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    logging_steps=1,  # Log every step
    per_device_train_batch_size=2,  # Start small to avoid memory issues
    gradient_accumulation_steps=8,  # Increase to maintain effective batch size
    max_prompt_length=1024,  # Reduce if needed
    max_completion_length=1024,  # Reduce if needed
    max_steps=10,  # Reduce for testing
    save_steps=10,
    eval_steps=1,  # Enable evaluation logging
    max_grad_norm=0.1,
    report_to="none",  # Disable reporting to external services
    output_dir="./trl_grpo_outputs",
    logging_dir="./logs",  # Directory for logs
    dataloader_drop_last=True,
    # Enhanced logging options
    log_level="info",
    logging_first_step=True,
    logging_nan_inf_filter=True,
    metric_for_best_model="reward",
    greater_is_better=True,
    # Keep default progress bar enabled
    disable_tqdm=False,
)

print("Training configuration with enhanced default progress bar:")
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Max steps: {training_args.max_steps}")
print(f"Learning rate: {training_args.learning_rate}")
print(f"Logging every: {training_args.logging_steps} steps")
print(f"Evaluation every: {training_args.eval_steps} steps")
print(f"Default tqdm enabled: {not training_args.disable_tqdm}")

## Initialize Trainer

In [None]:
from transformers.trainer_callback import TrainerCallback
import time
import json
import os
from collections import defaultdict, deque
import pandas as pd
from IPython.display import display, HTML, clear_output
import io
import sys

class HuggingFaceStyleTableCallback(TrainerCallback):
    """Callback that displays a HuggingFace-style interactive table that updates in-place below progress bar"""
    
    def __init__(self, log_file="training_logs.json"):
        self.recent_rewards = deque(maxlen=5)
        self.best_reward = float('-inf')
        self.log_file = log_file
        self.all_logs = []
        self.metrics_data = []
        self.table_displayed = False
        
    def on_train_begin(self, args, state, control, **kwargs):
        """Initialize logging"""
        self.all_logs = []
        self.metrics_data = []
        self.table_displayed = False
        print(f"Training started - HuggingFace-style interactive table will appear below progress bar")
        print(f"All logs saved to: {self.log_file}")
        
    def _display_hf_style_table(self):
        """Display HuggingFace-style table that updates in-place"""
        if not self.metrics_data:
            return
            
        # Create DataFrame from metrics data
        df = pd.DataFrame(self.metrics_data)
        
        # Style similar to HuggingFace trainer tables
        styled_html = f"""
        <div style="margin: 10px 0;">
        <table style="border-collapse: collapse; margin: auto; width: 100%; max-width: 1000px;">
        <thead>
        <tr style="border-bottom: 2px solid #dee2e6;">
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Step</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Training-Loss</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward-Avg</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward-Std</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Reward-Best</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">Grad-Norm</th>
        <th style="padding: 12px; text-align: center; border: 1px solid #dee2e6; font-weight: bold;">KL-Div</th>
        </tr>
        </thead>
        <tbody>
        """
        
        for _, row in df.iterrows():
            styled_html += f"""
            <tr style="{''}">
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{int(row['Step'])}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Training-Loss']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward-Avg']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward-Std']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Reward-Best']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['Grad-Norm']:.6f}</td>
            <td style="padding: 8px; text-align: center; border: 1px solid #dee2e6;">{row['KL-Div']:.6f}</td>
            </tr>
            """
        
        styled_html += """
        </tbody>
        </table>
        </div>
        """
        
        # Use clear_output to update in-place like HuggingFace trainers do
        if self.table_displayed:
            clear_output(wait=True)
        
        print("TRAINING METRICS:")
        display(HTML(styled_html))
        self.table_displayed = True
        
    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log metrics and update interactive table in-place"""
        if logs is None:
            return
            
        current_step = state.global_step
        
        # Save complete logs to JSON file with Unix timestamp
        log_entry = {
            "step": current_step,
            "timestamp": time.time(),
            **logs
        }
        self.all_logs.append(log_entry)
        
        # Write to JSON file after each step
        try:
            with open(self.log_file, 'w') as f:
                json.dump(self.all_logs, f, indent=2)
        except Exception as e:
            pass  # Silent fail
        
        # Extract metrics from logs
        reward = logs.get('reward', 0.0)
        reward_std = logs.get('reward_std', 0.0)
        loss = logs.get('loss', 0.0)
        kl_div = logs.get('kl', 0.0)
        grad_norm = logs.get('grad_norm', 0.0)
        
        # Track enhanced metrics
        if reward != 0:
            self.recent_rewards.append(reward)
            if reward > self.best_reward:
                self.best_reward = reward
                
        reward_avg = sum(self.recent_rewards) / len(self.recent_rewards) if self.recent_rewards else 0.0
        
        # Add new row to metrics data
        new_row = {
            'Step': current_step,
            'Training-Loss': loss,
            'Reward': reward,
            'Reward-Avg': reward_avg,
            'Reward-Std': reward_std,
            'Reward-Best': self.best_reward,
            'Grad-Norm': grad_norm,
            'KL-Div': kl_div,
        }
        self.metrics_data.append(new_row)
        
        # Update the table in-place
        self._display_hf_style_table()

In [None]:
# Initialize GRPO trainer with HuggingFace-style interactive table callback
hf_table_callback = HuggingFaceStyleTableCallback()

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer_correctness,
        check_numbers_extraction,
    ],
    args=training_args,
    train_dataset=dataset,
    callbacks=[hf_table_callback],  # Add HuggingFace-style table callback
)

In [None]:
# # Initialize GRPO trainer without HuggingFace-style table callback
# Uncomment the following lines if you want to use the original GRPOTrainer without the interactive table

# trainer = GRPOTrainer(
#     model=model,
#     processing_class=tokenizer,
#     reward_funcs=[
#         match_format_exactly,
#         match_format_approximately,
#         check_answer_correctness,
#         check_numbers_extraction,
#     ],
#     args=training_args,
#     train_dataset=dataset,
# )

# print("Trainer initialized!")
# print(f"Model device: {model.device}")
# print(f"Number of training examples: {len(dataset)}")

## Start Training

In [37]:
# Start training
print("Starting GRPO training...")
print("This may take a while. Monitor the reward column for improvements.")
print("Initial rewards might be low/negative - this is normal.")

trainer.train()

print("Training completed!")

TRAINING METRICS:


Step,TrainingLoss,Reward,Reward-Avg,Reward-Std,Reward-Best,Grad-Norm,KL-Div
1,0.0202,3.625,3.625,2.528742,3.625,0.096231,0.0
2,0.0306,3.90625,3.765625,2.631696,3.90625,0.20691,0.0
3,0.0372,3.6875,3.739583,2.016598,3.90625,0.205314,0.0
4,0.0857,3.125,3.585938,1.735451,3.90625,0.146251,0.0
5,-0.0552,3.625,3.59375,3.818543,3.90625,0.169046,0.0
6,-0.0728,3.46875,3.5625,1.627355,3.90625,0.214777,0.0
7,-0.0227,2.4375,3.26875,2.135396,3.90625,0.203865,0.0
8,-0.1158,3.40625,3.2125,2.789666,3.90625,0.155323,0.0
9,0.0272,2.625,3.1125,1.94914,3.90625,0.227745,0.0
10,-0.054,3.4375,3.075,2.856393,3.90625,0.172547,0.0


Training completed!


## Test the Model

In [38]:
# Test the trained model
def test_model(question):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    
    # Format the prompt
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    
    # Generate response
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract only the generated part
    response = response[len(text):].strip()
    
    return response

# Test with a sample question
test_question = "What is 25 + 17?"
print(f"Question: {test_question}")
print("\nModel Response:")
print(test_model(test_question))

Question: What is 25 + 17?

Model Response:
we can add the numbers step by step.

First, let's add the units digits (5 + 7):
\[ 5 + 7 = 12 \]

Next, we carry over the tens digit from this sum to the tens place. So we write down 2 and carry over 1.

Now, let's add the tens digits along with the carried over 1:
\[ 2 + 1 + 1 = 4 \]

Putting it all together, we get:
\[ 25 + 17 = 42 \]
<end_working_out>
<SOLUTION>
42
</SOLUTION>


## Test with GSM8K Example

In [39]:
# Test with a GSM8K example
gsm8k_question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
print(f"GSM8K Question: {gsm8k_question}")
print("\nModel Response:")
print(test_model(gsm8k_question))
print("\nExpected Answer: 72")

GSM8K Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Model Response:
old altogether in April and May, we need to follow these steps:

1. Determine the number of clips sold in May, which is half the number sold in April.
2. Calculate the total number of clips sold in both months by adding the number sold in April to the number sold in May.

Given:
- Natalia sold 48 clips in April.
- She sold half as many clips in May.

First, let's calculate the number of clips sold in May:
<end_working_out]
<May_sales> = 48 / 2
<end_working_out]
<total_sales> = April_sales + May_sales
<total_sales> = 48 + (<May_sales>)
<end_working_out]
<SOLUTION>
Total clips sold by Natalia in April and May is 60.
<END_OF_SOLUTION>

Expected Answer: 72


## Memory Management

In [40]:
# Clear GPU memory if needed
import gc
torch.cuda.empty_cache()
gc.collect()
print("GPU memory cleared")

GPU memory cleared


## References

### Papers and Research
- **GRPO Algorithm**: [Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) - The original GRPO paper introducing group-based relative policy optimization
- **GSM8K Dataset**: [Training Verifiers to Solve Math Word Problems](https://arxiv.org/abs/2110.14168) - Cobbe et al., OpenAI
- **LoRA**: [Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) - Hu et al., Microsoft
- **QLoRA**: [Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) - Dettmers et al., 4-bit quantization for efficient training

### Libraries and Frameworks
- **TRL (Transformers Reinforcement Learning)**: [HuggingFace TRL](https://github.com/huggingface/trl) - Official library for RLHF and advanced training techniques
- **Transformers**: [HuggingFace Transformers](https://github.com/huggingface/transformers) - State-of-the-art NLP library
- **PEFT**: [Parameter-Efficient Fine-Tuning](https://github.com/huggingface/peft) - Efficient adaptation methods
- **BitsAndBytes**: [8-bit & 4-bit Quantization](https://github.com/TimDettmers/bitsandbytes) - Memory-efficient training

### Models Used
- **Qwen2.5-3B-Instruct**: [Qwen Model Series](https://github.com/QwenLM/Qwen2.5) - Alibaba's instruction-tuned language model
- **Alternative Models**: Gemma-2B, DialoGPT, GPT-2 (configurable in the notebook)

### Datasets
- **GSM8K**: [OpenAI GSM8K](https://huggingface.co/datasets/openai/gsm8k) - Grade School Math 8K problems dataset
- **Format**: Mathematical word problems requiring multi-step reasoning and numerical answers

### Key Concepts
- **Reinforcement Learning from Human Feedback (RLHF)**: Training language models using reward signals
- **Group Relative Policy Optimization**: Advanced RL technique comparing responses in groups rather than absolute scoring
- **Structured Generation**: Teaching models to follow specific output formats with reasoning sections
- **Multi-Reward Training**: Using multiple reward functions for comprehensive evaluation