#  Educational GRPO Training Pipeline

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HarleyCoops/OneShotGRPO/blob/main/EducationalGRPO.ipynb)

---

##  What You'll Learn

This comprehensive notebook teaches you how to train small language models for mathematical reasoning using **GRPO (Group Relative Policy Optimization)**. You'll learn:

1. **Dataset Integration**: Load and preprocess GSM8K from HuggingFace
2. **Training Environments**: Use HuggingFace RL pipeline or Prime Intellect environments
3. **Cloud Storage**: Save checkpoints to Google Cloud Storage
4. **Advanced Monitoring**: Track training with Weights & Biases 3D visualizations
5. **Model Deployment**: Push to HuggingFace Hub with model cards
6. **Interactive Inference**: Create a Gradio chat interface

##  Learning Objectives

By the end of this notebook, you will:
- Understand GRPO and reinforcement learning for LLMs
- Configure reward functions for math reasoning
- Monitor training dynamics with comprehensive metrics
- Deploy production-ready models with proper documentation
- Build user-facing chat interfaces

##  Prerequisites

- Basic Python knowledge
- Understanding of neural networks
- Google Colab with GPU runtime (recommended: A100)
- HuggingFace account (for model deployment)
- Weights & Biases account (optional, for monitoring)
- Google Cloud project (optional, for GCS checkpoints)

---

##  Section 1: Environment Setup

### Understanding the Stack

We'll use several specialized libraries:

1. **vLLM**: High-performance inference engine with PagedAttention
   - Uses PagedAttention to reduce KV-cache memory waste
   - Enables efficient batch processing during training
   - Must be installed BEFORE TRL to avoid conflicts

2. **TRL (Transformer Reinforcement Learning)**: RL training framework
   - Implements GRPO, PPO, DPO algorithms
   - Integrates with HuggingFace Transformers
   - Handles reward computation and policy updates

3. **Datasets**: Efficient data loading and processing
   - Streaming support for large datasets
   - Built-in caching and versioning
   - Native HuggingFace Hub integration

In [None]:
# STEP 1: Install vLLM (must be first!)
print(" Installing vLLM for efficient inference...")
!pip install -q vllm

print("\n  IMPORTANT: Restart runtime after vLLM installation!")
print("Go to Runtime > Restart runtime, then continue with the next cell.")

In [None]:
# STEP 2: Install remaining dependencies
print(" Installing TRL, datasets, and utilities...")
!pip install -q trl datasets transformers wandb google-cloud-storage gradio huggingface_hub

print("\n All dependencies installed!")

In [None]:
# STEP 3: Import libraries and verify installation
import re
import os
import json
import torch
from datetime import datetime
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
import wandb

print("\n Environment Check:")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("  Warning: No GPU detected. Training will be very slow!")

print("\n All imports successful!")

---

##  Section 2: Weights & Biases Setup

### Why W&B for GRPO?

Weights & Biases provides:
- **Real-time metrics**: Track loss, rewards, KL divergence
- **3D visualizations**: Plot reward landscapes and policy evolution
- **Hyperparameter tracking**: Compare runs automatically
- **Artifact versioning**: Track datasets and model checkpoints

### Advanced Logging Strategy

We'll log:
1. Training metrics (loss, learning rate, grad norm)
2. Reward signals (correctness, format, overall)
3. Generation samples (input prompts + model outputs)
4. 3D reward landscapes (reward vs. step vs. example)
5. Model architecture and hyperparameters
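The GRPO trainer does not log item 4 on its own. A minimal sketch, assuming you record per-example rewards yourself (for example, by appending them to a dict inside a reward function), is to flatten them into (step, example, reward) points. `reward_landscape_points` is a hypothetical helper, and the `wandb.Object3D` call in the final comment assumes W&B's point-cloud logging accepts an (N, 3) NumPy array.

```python
from typing import Dict, List, Sequence, Tuple

def reward_landscape_points(
    rewards_by_step: Dict[int, Sequence[float]],
) -> List[Tuple[float, float, float]]:
    """Flatten {step: [reward for each example]} into (step, example_idx, reward) points.

    Hypothetical helper: the per-example rewards must be collected by you.
    """
    points = []
    for step in sorted(rewards_by_step):
        for example_idx, reward in enumerate(rewards_by_step[step]):
            points.append((float(step), float(example_idx), float(reward)))
    return points

# Total rewards for three examples at two training steps (illustrative values)
history = {0: [0.5, 1.0, 0.0], 50: [2.5, 3.0, 4.0]}
points = reward_landscape_points(history)
print(len(points), points[3])  # 6 (50.0, 0.0, 2.5)

# Assumed W&B usage: log the points as a 3D point cloud
#   import numpy as np
#   wandb.log({"reward_landscape": wandb.Object3D(np.array(points))})
```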

In [None]:
# Initialize Weights & Biases
print(" W&B Authentication")
wandb.login()

# Configuration for this run
WANDB_PROJECT = "grpo-math-education"
WANDB_ENTITY = None  # Will use your default entity
RUN_NAME = f"grpo-qwen-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

print(f"\n W&B Project: {WANDB_PROJECT}")
print(f" Run Name: {RUN_NAME}")

---

##  Section 3: Google Cloud Storage Setup (Optional)

### Why Use GCS for Checkpoints?

Google Cloud Storage advantages:
- **Persistent storage**: Survives Colab session disconnects
- **Large capacity**: No 15GB Drive limit
- **Fast access**: Better upload/download speeds
- **Versioning**: Built-in checkpoint history
- **Team sharing**: Easy collaboration

### Alternative: Google Drive

If you don't have GCS, we'll use Google Drive (simpler but slower).

In [None]:
# Choose your storage backend
USE_GCS = False  # Set to True if you have Google Cloud Storage
USE_GDRIVE = True  # Set to True to use Google Drive

if USE_GCS:
    print("  Setting up Google Cloud Storage...")
    from google.colab import auth
    from google.cloud import storage
    
    # Authenticate
    auth.authenticate_user()
    
    # Configuration
    GCS_PROJECT = input("Enter your GCP project ID: ")
    GCS_BUCKET = input("Enter your GCS bucket name: ")
    GCS_PREFIX = f"grpo-checkpoints/{RUN_NAME}"
    
    # Initialize client
    storage_client = storage.Client(project=GCS_PROJECT)
    bucket = storage_client.bucket(GCS_BUCKET)
    
    print(f"\n Connected to gs://{GCS_BUCKET}/{GCS_PREFIX}")
    
elif USE_GDRIVE:
    print(" Mounting Google Drive...")
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Create checkpoint directory
    GDRIVE_PATH = f"/content/drive/MyDrive/grpo_checkpoints/{RUN_NAME}"
    os.makedirs(GDRIVE_PATH, exist_ok=True)
    
    print(f"\n Checkpoints will save to: {GDRIVE_PATH}")
else:
    print(" Using local storage (will be lost when runtime disconnects)")
    LOCAL_PATH = f"/content/outputs/{RUN_NAME}"
    os.makedirs(LOCAL_PATH, exist_ok=True)

---

##  Section 4: Dataset Loading and Formatting

### GSM8K Dataset

**GSM8K** (Grade School Math 8K) contains:
- 8,500+ grade-school math word problems
- Natural language solutions with reasoning steps
- Numerical answers
- Created by OpenAI for evaluating mathematical reasoning

### XML Format Strategy

We train the model to output:
```xml
<reasoning>
Step-by-step problem solving
</reasoning>
<answer>
Numerical answer
</answer>
```

This format:
1. Encourages chain-of-thought reasoning
2. Makes parsing answers easy
3. Provides interpretability
4. Enables partial credit rewards

In [None]:
# Define the system prompt and format
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

print(" Format templates defined")

In [None]:
# Answer extraction functions
def extract_xml_answer(text: str) -> str:
    """
    Extracts the answer from XML-formatted text.
    
    Example:
        Input: "<reasoning>steps</reasoning><answer>42</answer>"
        Output: "42"
    """
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    """
    Extracts the answer from GSM8K format (#### delimiter).
    
    Example:
        Input: "The answer is 42\n#### 42"
        Output: "42"
    """
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Test the functions
test_xml = "<reasoning>2+2=4</reasoning><answer>4</answer>"
test_hash = "The total is 4\n#### 4"

print(f"XML test: '{extract_xml_answer(test_xml)}'")
print(f"Hash test: '{extract_hash_answer(test_hash)}'")
print("\n Extraction functions working correctly")

In [None]:
# Load and format GSM8K dataset
def get_gsm8k_questions(split="train", num_examples=None) -> Dataset:
    """
    Loads and preprocesses GSM8K dataset for GRPO training.
    
    Args:
        split: 'train' or 'test'
        num_examples: Optional limit on number of examples
    
    Returns:
        Dataset with formatted prompts and answers
    """
    print(f" Loading GSM8K {split} split...")
    data = load_dataset('openai/gsm8k', 'main')[split]
    
    if num_examples:
        data = data.select(range(min(num_examples, len(data))))
    
    print(f" Dataset size: {len(data)} examples")
    
    # Transform to GRPO format
    def format_example(example):
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': example['question']}
            ],
            'answer': extract_hash_answer(example['answer']),
            'reference_solution': example['answer']  # Keep full solution for reference
        }
    
    print(" Formatting examples...")
    data = data.map(format_example)
    
    # Show a sample
    print("\n Sample Question:")
    sample = data[0]
    print(f"Q: {sample['prompt'][1]['content'][:200]}...")
    print(f"A: {sample['answer']}")
    
    return data

# Load training data
# Use smaller subset for faster experimentation (remove num_examples for full dataset)
dataset = get_gsm8k_questions(split="train", num_examples=1000)
print("\n Dataset ready for training")

---

##  Section 5: Reward Functions

### Understanding GRPO Rewards

GRPO uses **multiple reward signals** to shape model behavior:

1. **Correctness Reward** (2.0 points)
   - Primary learning signal
   - Binary: correct answer = 2.0, incorrect = 0.0
   - Drives mathematical accuracy

2. **Format Rewards** (1.5 points total)
   - XML structure (0.5): Tags present and properly nested
   - Strict format (0.5): Exact newline and spacing
   - Soft format (0.5): Flexible whitespace
   - Ensures consistent, parseable outputs

3. **Type Reward** (0.5 points)
   - Verifies numerical answer
   - Prevents text-only responses

### Total Reward: 4.0 points

This multi-objective reward encourages:
- Correct mathematical reasoning (50%)
- Clean, structured output (37.5%)
- Proper answer formatting (12.5%)
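The percentages follow from each function's maximum score, assuming the trainer sums the five reward functions with equal weights (TRL's default when `reward_weights` is not set). A quick arithmetic check:

```python
# Maximum score of each reward function (values from the list above)
REWARD_MAX = {
    "correctness": 2.0,
    "format_xml": 0.5,
    "format_strict": 0.5,
    "format_soft": 0.5,
    "type_numeric": 0.5,
}

total = sum(REWARD_MAX.values())
shares = {name: value / total for name, value in REWARD_MAX.items()}
format_share = sum(share for name, share in shares.items() if name.startswith("format_"))

print(total)                   # 4.0
print(shares["correctness"])   # 0.5
print(format_share)            # 0.375
print(shares["type_numeric"])  # 0.125
```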

In [None]:
# Reward Functions with W&B logging hooks

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Primary reward: Checks if extracted answer matches ground truth.
    Weight: 2.0 (highest priority)
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    
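    # Print a sample every 50 calls; a counter stored on the function tracks calls
    # because a `step` key is not reliably present in kwargs
    correctness_reward_func.calls = getattr(correctness_reward_func, "calls", 0) + 1
    if correctness_reward_func.calls % 50 == 1: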
        q = prompts[0][-1]['content']
        print('-' * 20)
        print(f"Question: {q[:100]}...")
        print(f"Expected: {answer[0]}")
        print(f"Got: {extracted_responses[0]}")
        print(f"Response: {responses[0][:200]}...")
    
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """
    Type reward: Ensures answer is numeric.
    Weight: 0.5
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.replace('-', '').replace('.', '').isdigit() else 0.0 
            for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Strict format reward: Exact XML structure with newlines.
    Weight: 0.5
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Soft format reward: Flexible XML structure.
    Weight: 0.5
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text: str) -> float:
    """
    Granular XML component scoring.
    Awards 0.125 for each correct tag, with penalties for extra text.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return max(0.0, count)

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """
    XML structure reward: Detailed component scoring.
    Weight: 0.5
    """
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

print(" Reward functions defined")
print("\n Reward Structure:")
print("  Correctness:     2.0 (50.0%)")
print("  Format (strict): 0.5 (12.5%)")
print("  Format (soft):   0.5 (12.5%)")
print("  Format (XML):    0.5 (12.5%)")
print("  Type (numeric):  0.5 (12.5%)")
print("  " + "-" * 30)
print("  TOTAL:           4.0 (100%)")

---

##  Section 6: Model and Training Configuration

### Model Selection: Qwen2.5-0.5B-Instruct

**Why this model?**
- Small enough to train on single GPU (500M parameters)
- Pre-trained on instruction following
- Good mathematical reasoning baseline
- Fast inference for RL training

### GRPO Hyperparameters Explained

| Parameter | Value | Reasoning |
|-----------|-------|----------|
| `learning_rate` | 5e-6 | Small LR prevents catastrophic forgetting |
| `num_generations` | 16 | Group size: each completion is scored relative to the other completions for the same prompt |
| `max_grad_norm` | 0.1 | Gradient clipping for RL stability |
| `num_train_epochs` | 1 | Single pass prevents overfitting |
| `warmup_ratio` | 0.1 | Gradual LR warmup for stability |
| `bf16` | True | Memory efficiency + numerical stability |

### vLLM Configuration

- `vllm_gpu_memory_utilization`: 0.3 (30% for inference, 70% for training)
- Enables PagedAttention for efficient KV caching
- Significantly speeds up generation during training

In [None]:
# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
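
# gs:// URLs are not local directories, so the trainer writes checkpoints locally
# when USE_GCS is set; copy them to the bucket afterwards
if USE_GCS:
    OUTPUT_DIR = f"/content/outputs/{RUN_NAME}"
elif USE_GDRIVE:
    OUTPUT_DIR = GDRIVE_PATH
else:
    OUTPUT_DIR = LOCAL_PATH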

print(f" Model: {MODEL_NAME}")
print(f" Output directory: {OUTPUT_DIR}")

# Initialize W&B run with comprehensive config
wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name=RUN_NAME,
    config={
        # Model config
        "model_name": MODEL_NAME,
        "model_params": "500M",
        
        # Training config
        "learning_rate": 5e-6,
        "adam_beta1": 0.9,
        "adam_beta2": 0.99,
        "weight_decay": 0.1,
        "warmup_ratio": 0.1,
        "lr_scheduler": "cosine",
        
        # Batch config
        "per_device_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "effective_batch_size": 4,
        
        # Generation config
        "num_generations": 16,
        "generation_batch_size": 16,
        "max_prompt_length": 256,
        "max_completion_length": 200,
        
        # Training duration
        "num_train_epochs": 1,
        "dataset_size": len(dataset),
        
        # Regularization
        "max_grad_norm": 0.1,
        
        # Precision
        "bf16": True,
        
        # vLLM
        "use_vllm": True,
        "vllm_gpu_memory": 0.3,
        
        # Reward weights
        "reward_correctness_weight": 2.0,
        "reward_format_weight": 1.5,
        "reward_type_weight": 0.5,
    },
    tags=["grpo", "math", "gsm8k", "educational"]
)

print("\n W&B run initialized")

In [None]:
# Configure GRPO training
training_args = GRPOConfig(
    # Output
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,
    
    # Optimizer
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    
    # Learning rate schedule
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    
    # Logging
    logging_steps=1,
    report_to="wandb",
    log_on_each_node=False,
    
    # Precision
    bf16=True,
    
    # Batch configuration
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    
    # Generation parameters
    num_generations=16,
    generation_batch_size=16,
    max_prompt_length=256,
    max_completion_length=200,
    
    # Training duration
    num_train_epochs=1,
    
    # Checkpointing
    save_steps=100,
    save_total_limit=5,  # Keep last 5 checkpoints
    
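    # vLLM configuration: "colocate" runs vLLM inside the training process on the same GPU.
    # Recent TRL versions otherwise default to "server" mode, which needs a separate `trl vllm-serve` process.
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.3,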
)

print(" Training configuration created")
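# Rough estimate assuming each micro-batch holds one prompt; recent TRL versions count
# completions instead, so the real step count can be up to num_generations times larger
print(f"\n Training will run for roughly {len(dataset) // 4} steps (the progress bar shows the exact count)")
print(f" Checkpoints every 100 steps → at least ~{(len(dataset) // 4) // 100} checkpoints")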

In [None]:
# Load model and tokenizer
print(f" Loading model: {MODEL_NAME}...")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Calculate model size
param_count = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n Model Statistics:")
print(f"  Total parameters: {param_count:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{param_count * 2 / 1024**3:.2f} GB (bf16)")

# Log to W&B
wandb.config.update({
    "total_params": param_count,
    "trainable_params": trainable_params
})

print("\n Model loaded successfully")

---

##  Section 7: Training Execution

### What Happens During Training?

Each training step:
1. **Sampling**: Load batch of math problems
2. **Generation**: Model generates 16 responses per problem
3. **Reward**: Each response gets 5 reward scores
4. **Policy Update**: GRPO compares each response's reward with the average of its group and shifts the policy toward above-average responses
5. **Logging**: Metrics sent to W&B
6. **Checkpointing**: Save every 100 steps
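The policy update in step 4 hinges on the group-relative advantage that gives GRPO its name: each of the 16 responses is scored against the mean reward of its own group, normalized by the group's standard deviation. The sketch below shows only that normalization, not TRL's full loss. `group_relative_advantages` is a hypothetical helper; it uses the sample standard deviation plus a small `eps`, details that vary between implementations.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-4):
    """Advantage of each completion relative to its group: (r_i - mean) / (std + eps).

    Uses the sample standard deviation; implementations differ on this detail.
    """
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Total rewards (out of 4.0) for four completions of the same prompt
rewards = [4.0, 2.0, 2.0, 0.0]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])  # [1.225, 0.0, 0.0, -1.225]
```

When every response in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal; partial-credit format rewards are one way to keep some variance within groups early in training.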

### Expected Training Time

- **Dataset**: 1,000 examples
- **Effective batch size**: 4
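- **Steps**: ~250 if each micro-batch holds one prompt; recent TRL versions count completions instead, which gives roughly 1,000 × 16 / 4 = 4,000 shorter optimizer steps (the progress bar shows the exact count)
- **Time per step**: ~30-60 seconds (with A100) for the ~250-step case
- **Total time**: ~2-4 hours (rough estimate)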
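How many optimizer steps one epoch takes depends on whether `per_device_train_batch_size` counts prompts or completions, which has changed across TRL versions. A rough estimator under both interpretations; `estimate_optimizer_steps` is a hypothetical helper, and the trainer's progress bar remains authoritative.

```python
import math

def estimate_optimizer_steps(num_prompts, num_generations, per_device_batch,
                             grad_accum, num_gpus=1, batch_counts_completions=True):
    """Rough optimizer-step count for one epoch (hypothetical helper).

    If batch_counts_completions is True, each micro-batch slot holds one completion,
    so every prompt consumes num_generations slots; otherwise a slot holds one prompt.
    """
    slots_needed = num_prompts * (num_generations if batch_counts_completions else 1)
    slots_per_step = per_device_batch * grad_accum * num_gpus
    return math.ceil(slots_needed / slots_per_step)

# Notebook settings: 1,000 prompts, 16 generations, batch size 1, gradient accumulation 4
print(estimate_optimizer_steps(1000, 16, 1, 4, batch_counts_completions=False))  # 250
print(estimate_optimizer_steps(1000, 16, 1, 4, batch_counts_completions=True))   # 4000
```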

### Monitoring Tips

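Watch these metrics in W&B:
- **Loss**: Not a reliable progress signal in GRPO; it can start near zero and drift upward as the policy moves away from the reference model
- **Reward**: Should increase (target: 2.0+); this is the main progress signal
- **KL Divergence**: Should stay small (<1.0); only logged when the KL coefficient `beta` is non-zero
- **Learning Rate**: Should follow the warmup + cosine schedule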

In [None]:
# Create GRPO trainer
print("  Building GRPO trainer...")

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,      # 0.5 points
        soft_format_reward_func,   # 0.5 points
        strict_format_reward_func, # 0.5 points
        int_reward_func,           # 0.5 points
        correctness_reward_func,   # 2.0 points
    ],
    args=training_args,
    train_dataset=dataset,
)

print("\n Trainer ready")
print("\n Starting training...")
print("\n" + "="*60)
print("Monitor your run at:", wandb.run.get_url())
print("="*60 + "\n")

In [None]:
# Train the model!
train_result = trainer.train()

print("\n" + "="*60)
print(" Training complete!")
print("="*60)
print(f"\nFinal metrics:")
print(f"  Loss: {train_result.training_loss:.4f}")
print(f"  Steps: {train_result.global_step}")
print(f"  Time: {train_result.metrics['train_runtime']:.2f} seconds")
print(f"  Samples/second: {train_result.metrics['train_samples_per_second']:.2f}")

---

##  Section 8: Checkpoint Management

### Understanding Checkpoints

Each checkpoint contains:
- `model.safetensors`: Model weights
- `config.json`: Model architecture
- `tokenizer.json`: Tokenizer configuration
- `trainer_state.json`: Training progress
- `optimizer.pt`: Optimizer state
- `scheduler.pt`: LR scheduler state

### Selecting Best Checkpoint

Strategies:
1. **Latest**: Most training exposure
2. **Highest reward**: Best mean training reward (this notebook has no held-out validation set)
3. **Lowest loss**: Most optimization progress

For this demo, we'll use the **final checkpoint**.
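To try the *highest reward* strategy instead, note that each `checkpoint-N` folder's `trainer_state.json` keeps the trainer's `log_history`. A minimal sketch, assuming the GRPO trainer logs its mean reward under the key `"reward"` (check your `log_history` entries, since key names vary between TRL versions). `best_checkpoint_by_reward` is a hypothetical helper that averages the last few logged values to smooth per-step noise.

```python
import json
import os
import statistics

def best_checkpoint_by_reward(checkpoint_dirs, metric="reward", window=10):
    """Return (path, score) for the checkpoint with the highest recent mean `metric`.

    Hypothetical helper. Assumes each checkpoint folder contains a
    trainer_state.json whose "log_history" entries include `metric`.
    """
    best_dir, best_score = None, float("-inf")
    for ckpt in checkpoint_dirs:
        with open(os.path.join(ckpt, "trainer_state.json")) as f:
            history = json.load(f)["log_history"]
        # Values logged up to this checkpoint; keep only the most recent `window`
        values = [entry[metric] for entry in history if metric in entry][-window:]
        if not values:
            continue
        score = statistics.fmean(values)
        if score > best_score:
            best_dir, best_score = ckpt, score
    return best_dir, best_score

# Example (inside the notebook): best_dir, best_reward = best_checkpoint_by_reward(checkpoint_dirs)
```

The returned path could then be assigned to `CHECKPOINT_PATH` in place of the latest checkpoint.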

In [None]:
# Find the final checkpoint
import glob

# Checkpoints live on the local filesystem (or the mounted Drive) under OUTPUT_DIR
checkpoint_dirs = glob.glob(f"{OUTPUT_DIR}/checkpoint-*")
checkpoint_dirs.sort(key=lambda x: int(x.split("-")[-1]))
print(f" Found {len(checkpoint_dirs)} checkpoints")

if checkpoint_dirs:
    CHECKPOINT_PATH = checkpoint_dirs[-1]
    print(f" Final checkpoint: {CHECKPOINT_PATH}")
else:
    # Fewer than save_steps steps ran: save the final weights and tokenizer explicitly
    CHECKPOINT_PATH = f"{OUTPUT_DIR}/final"
    trainer.save_model(CHECKPOINT_PATH)
    print(f"  No checkpoints found; saved final model to {CHECKPOINT_PATH}")

if USE_GCS:
    # from_pretrained() cannot read gs:// URLs, so later cells load the local copy;
    # mirror the checkpoint files to the bucket for persistence
    print(" Uploading checkpoint to GCS...")
    ckpt_name = os.path.basename(CHECKPOINT_PATH)
    for local_file in glob.glob(f"{CHECKPOINT_PATH}/**", recursive=True):
        if os.path.isfile(local_file):
            rel_path = os.path.relpath(local_file, CHECKPOINT_PATH)
            bucket.blob(f"{GCS_PREFIX}/{ckpt_name}/{rel_path}").upload_from_filename(local_file)
    print(f" Uploaded to gs://{GCS_BUCKET}/{GCS_PREFIX}/{ckpt_name}")

---

##  Section 9: Model Evaluation

### Quick Inference Test

Before deployment, let's test our trained model!
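Spot checks are a start, but measuring accuracy on the GSM8K test split needs a comparison that tolerates formatting differences such as `$1,000` vs `1000` or `7.0` vs `7`. A minimal sketch with hypothetical helpers `normalize_number` and `accuracy` (the training-time `correctness_reward_func` uses exact string matching instead):

```python
def normalize_number(text):
    """Hypothetical normalizer: drop whitespace, '$', commas and a trailing period,
    then render integral values without a decimal part ("7.0" -> "7")."""
    cleaned = text.strip().replace(",", "").replace("$", "").rstrip(".")
    try:
        value = float(cleaned)
    except ValueError:
        return cleaned  # not a number: compare as plain text
    return str(int(value)) if value.is_integer() else str(value)

def accuracy(predictions, references):
    """Fraction of predictions equal to their reference after normalization."""
    if not references:
        return 0.0
    hits = sum(
        normalize_number(p) == normalize_number(r)
        for p, r in zip(predictions, references)
    )
    return hits / len(references)

print(accuracy(["72", "$1,000", "7.0", "Parse error"], ["72", "1000", "7", "18"]))  # 0.75
```

You could feed it `solve_math_problem(q)["answer"]` for each question from `get_gsm8k_questions(split="test")`, paired with that dataset's `answer` column.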

In [None]:
# Load the trained model for inference
print(f" Loading trained model from {CHECKPOINT_PATH}...")

inference_model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

inference_tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)

print(" Model loaded for inference")

def solve_math_problem(question: str, temperature=0.7, max_tokens=200):
    """
    Generate a solution to a math problem.
    
    Args:
        question: Math word problem
        temperature: Sampling temperature (higher = more creative)
        max_tokens: Maximum response length
    
    Returns:
        dict with reasoning and answer
    """
    prompt = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question}
    ]
    
    input_text = inference_tokenizer.apply_chat_template(
        prompt,
        tokenize=False,
        add_generation_prompt=True
    )
    
    inputs = inference_tokenizer(input_text, return_tensors="pt").to(inference_model.device)
    
    outputs = inference_model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=inference_tokenizer.pad_token_id,
        eos_token_id=inference_tokenizer.eos_token_id
    )
    
    # Decode only the newly generated tokens (skip the prompt)
    prompt_length = inputs["input_ids"].shape[1]
    response = inference_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    
    # Try to parse XML
    try:
        reasoning = response.split("<reasoning>")[1].split("</reasoning>")[0].strip()
        answer = extract_xml_answer(response)
    except IndexError:
        # The model did not produce a <reasoning> block
        reasoning = response
        answer = "Parse error"
    
    return {
        "question": question,
        "reasoning": reasoning,
        "answer": answer,
        "raw_response": response
    }

# Test on sample problems
test_problems = [
    "If a baker makes 24 cupcakes and puts them into boxes of 6, how many boxes does he need?",
    "Sarah has 15 apples. She gives 4 to her friend and buys 8 more. How many apples does she have now?",
    "A train travels 60 miles per hour. How far does it travel in 3 hours?"
]

print("\n" + "="*60)
print(" TESTING TRAINED MODEL")
print("="*60 + "\n")

test_results = []
for i, problem in enumerate(test_problems, 1):
    print(f"Test {i}/{len(test_problems)}: {problem}")
    result = solve_math_problem(problem)
    test_results.append(result)
    
    print(f"\n Reasoning:\n{result['reasoning'][:200]}...")
    print(f"\n Answer: {result['answer']}")
    print("\n" + "-"*60 + "\n")

# Log test results to W&B
wandb.log({
    "test_samples": wandb.Table(
        columns=["question", "reasoning", "answer"],
        data=[[r["question"], r["reasoning"], r["answer"]] for r in test_results]
    )
})

---

##  Section 10: HuggingFace Hub Deployment

### Model Card Best Practices

A good model card includes:
1. **Model Description**: What it does
2. **Training Details**: Dataset, hyperparameters, compute
3. **Usage Examples**: Code to run inference
4. **Limitations**: Known issues and constraints
5. **Evaluation Results**: Performance metrics
6. **Citation**: How to cite your work

In [None]:
# Generate comprehensive model card
MODEL_CARD = f"""---
language:
- en
license: apache-2.0
tags:
- grpo
- reinforcement-learning
- math
- gsm8k
- reasoning
base_model: {MODEL_NAME}
datasets:
- openai/gsm8k
---

# GRPO-Tuned Math Reasoner

This model was fine-tuned using **GRPO (Group Relative Policy Optimization)** on the GSM8K dataset for mathematical reasoning tasks.

## Model Description

- **Base Model**: {MODEL_NAME}
- **Training Method**: GRPO with multi-objective rewards
- **Training Dataset**: GSM8K (Grade School Math 8K)
- **Training Examples**: {len(dataset)}
- **Total Parameters**: {param_count:,}
- **Precision**: bfloat16

## Training Details

### Hyperparameters

```yaml
learning_rate: 5e-6
optimizer: AdamW
adam_beta1: 0.9
adam_beta2: 0.99
weight_decay: 0.1
lr_scheduler: cosine
warmup_ratio: 0.1
num_train_epochs: 1
per_device_batch_size: 1
gradient_accumulation_steps: 4
max_grad_norm: 0.1
num_generations: 16
max_prompt_length: 256
max_completion_length: 200
```

### Reward Functions

The model was trained with five reward signals:

1. **Correctness** (2.0 points): Exact match with ground truth
2. **Numeric Format** (0.5 points): Answer is numeric
3. **XML Structure** (0.5 points): Proper tag nesting
4. **Strict Format** (0.5 points): Exact formatting
5. **Soft Format** (0.5 points): Flexible formatting

**Total possible reward**: 4.0 points

### Compute Infrastructure

- **GPU**: {torch.cuda.get_device_name(0)}
- **Training Time**: ~{train_result.metrics['train_runtime'] / 3600:.1f} hours
- **vLLM**: Enabled for efficient inference

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "YOUR_HF_USERNAME/{RUN_NAME}",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("YOUR_HF_USERNAME/{RUN_NAME}")

# Create prompt
messages = [
    {{"role": "system", "content": '''Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>'''}},
    {{"role": "user", "content": "If Sarah has 12 apples and gives away 5, how many does she have left?"}}
]

# Generate response
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids, max_new_tokens=200, do_sample=True, temperature=0.7)
response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
print(response)
```

### Expected Output Format

```xml
<reasoning>
Sarah starts with 12 apples.
She gives away 5 apples.
To find how many she has left, we subtract: 12 - 5 = 7
</reasoning>
<answer>
7
</answer>
```

## Limitations

- Trained only on grade-school level math problems
- May struggle with complex multi-step reasoning
- Expects numerical answers (not algebraic expressions)
- Single epoch training may result in some underfitting
- Performance degrades on out-of-distribution problems

## Training Metrics

- **Final Loss**: {train_result.training_loss:.4f}
- **Training Steps**: {train_result.global_step}
- **Samples/Second**: {train_result.metrics['train_samples_per_second']:.2f}

## Citation

```bibtex
@misc{{{RUN_NAME.replace('-', '_')},
  title={{GRPO-Tuned Math Reasoner}},
  author={{Your Name}},
  year={{{datetime.now().year}}},
  publisher={{HuggingFace}},
  howpublished={{\\url{{https://huggingface.co/YOUR_HF_USERNAME/{RUN_NAME}}}}}
}}
```

## License

This model inherits the license from the base model ({MODEL_NAME}).

## Acknowledgments

- Base model: Qwen Team
- Dataset: OpenAI (GSM8K)
- Training framework: HuggingFace TRL
- Inference engine: vLLM
"""

# Save model card
with open(f"{CHECKPOINT_PATH}/README.md", "w") as f:
    f.write(MODEL_CARD)

print(" Model card generated")
print("\n Preview:")
print(MODEL_CARD[:500] + "...")

In [None]:
# Authenticate with HuggingFace
from huggingface_hub import login, HfApi

print(" HuggingFace Authentication")
login()

# Configure your model repository
HF_USERNAME = input("Enter your HuggingFace username: ")
HF_MODEL_NAME = input(f"Enter model name (default: {RUN_NAME}): ") or RUN_NAME
HF_REPO_ID = f"{HF_USERNAME}/{HF_MODEL_NAME}"

print(f"\n Will push to: {HF_REPO_ID}")

In [None]:
# Push to HuggingFace Hub
print(f" Pushing model to {HF_REPO_ID}...")

inference_model.push_to_hub(
    HF_REPO_ID,
    commit_message=f"GRPO training on GSM8K - {len(dataset)} examples",
    private=False  # Set to True for private repo
)

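inference_tokenizer.push_to_hub(
    HF_REPO_ID,
    commit_message="Add tokenizer"
)

# model.push_to_hub() does not include the README.md written above, so upload it separately
HfApi().upload_file(
    path_or_fileobj=f"{CHECKPOINT_PATH}/README.md",
    path_in_repo="README.md",
    repo_id=HF_REPO_ID,
    commit_message="Add model card",
)

print("\n Model pushed successfully!")
print(f"\n View your model at: https://huggingface.co/{HF_REPO_ID}")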

# Log to W&B
wandb.config.update({"hf_repo": HF_REPO_ID})
wandb.log({"model_url": f"https://huggingface.co/{HF_REPO_ID}"})

---

##  Section 11: Gradio Chat Interface

### Building an Interactive Demo

Let's create a simple chat interface where users can:
1. Ask math questions
2. See step-by-step reasoning
3. Get the final answer

This demo can be:
- Run locally in the notebook
- Deployed to HuggingFace Spaces
- Embedded in websites
- Shared via public URL

In [None]:
import gradio as gr

# Create chat interface
def chat_with_model(message, history, temperature=0.7):
    """
    Process a chat message and return the response.
    
    Args:
        message: User's question
        history: Chat history (unused in this simple version)
        temperature: Sampling temperature
    
    Returns:
        Formatted response with reasoning and answer
    """
    result = solve_math_problem(message, temperature=temperature)
    
    # Format response nicely
    response = f"""**Reasoning:**

{result['reasoning']}

**Answer:** {result['answer']}
"""
    
    return response

# Create Gradio interface
demo = gr.ChatInterface(
    fn=chat_with_model,
    title="GRPO Math Tutor",
    description=f"""
    Ask me grade-school math questions! I'll show my reasoning step-by-step.
    
    **Model:** {HF_REPO_ID}
    
    **Training:** {len(dataset)} GSM8K examples with GRPO
    """,
    # With additional_inputs, each example is [message, temperature]
    examples=[
        ["If a pizza is cut into 8 slices and John eats 3, what fraction is left?", 0.7],
        ["A car travels 50 miles per hour for 2.5 hours. How far does it go?", 0.7],
        ["Sarah has $20. She buys 3 books for $4 each. How much money does she have left?", 0.7],
    ],
    additional_inputs=[
        gr.Slider(0.1, 1.5, value=0.7, label="Temperature (creativity)", step=0.1)
    ],
    theme=gr.themes.Soft(),
    # retry_btn, undo_btn and clear_btn were removed from ChatInterface in Gradio 5
)

# Launch interface
print(" Launching Gradio interface...")
demo.launch(
    share=True,  # Creates public URL
    debug=True
)

---

##  Section 12: Optional - Prime Intellect Integration

### What is Prime Intellect?

Prime Intellect provides:
- **Distributed RL training**: Scale across multiple GPUs/nodes
- **Environment Hub**: Pre-built RL environments
- **Fault tolerance**: Automatic recovery from failures
- **Verifiers**: Modular reward functions

### When to Use Prime Intellect?

Consider Prime Intellect if you need:
- Multi-GPU/multi-node training
- Custom RL environments
- Production-scale deployment
- Advanced monitoring and logging

### Example: AQuA-RAT Environment

The author's environment (`harleycooper/nanochatAquaRat`) targets AQuA-RAT algebra word problems, a task similar to GSM8K.

In [None]:
# This cell demonstrates Prime Intellect integration (optional)
# Uncomment to use

"""
# Install Prime RL
!curl -sSL https://raw.githubusercontent.com/PrimeIntellect-ai/prime-rl/main/scripts/install.sh | bash

# Configure for AQuA-RAT environment
prime_config = {
    "model": MODEL_NAME,
    "env": {
        "id": "harleycooper/nanochatAquaRat",
        "args": {
            "num_train_examples": 2000,
            "num_eval_examples": 254,
            "seed": 42
        }
    },
    "trainer": {
        "args": {
            "learning_rate": 2e-5,
            "rollouts_per_example": 8,
            "max_steps": 400
        }
    }
}

# Save config
import toml
with open("prime_config.toml", "w") as f:
    toml.dump(prime_config, f)

# Run training
!uv run vf-rl @ prime_config.toml
"""

print("ℹ  Prime Intellect integration code is commented out.")
print("Uncomment the cell above to use Prime Intellect environments.")

---

##  Section 13: Wrap-Up and Next Steps

### What You've Accomplished

- Set up a complete GRPO training pipeline
- Trained a model on mathematical reasoning
- Monitored training with Weights & Biases
- Saved checkpoints to cloud storage
- Deployed the model to HuggingFace Hub
- Created an interactive chat interface

### Next Steps

1. **Improve Training**:
   - Use the full GSM8K training split (7,473 examples)
   - Train for multiple epochs with validation
   - Experiment with different reward weights
   - Try larger models (1B, 7B parameters)

2. **Enhance Evaluation**:
   - Create test suite
   - Measure accuracy on GSM8K test set
   - Compare with baseline models
   - Analyze failure modes

3. **Deploy to Production**:
   - Set up HuggingFace Inference Endpoint
   - Deploy Gradio app to HF Spaces
   - Add caching and rate limiting
   - Monitor usage and costs

4. **Extend to Other Domains**:
   - Science QA
   - Code generation
   - Logical reasoning
   - Multi-turn conversations

### Resources

- [TRL Documentation](https://huggingface.co/docs/trl)
- [GRPO Paper (DeepSeekMath)](https://arxiv.org/abs/2402.03300)
- [vLLM Docs](https://docs.vllm.ai/)
- [Gradio Docs](https://gradio.app/docs)
- [Prime Intellect](https://primeintellect.ai/)

### Questions?

Check out the detailed READMEs in the repository:
- `docs/PRIME_INTELLECT.md`
- `docs/GOOGLE_CLOUD_STORAGE.md`
- `docs/WANDB_VISUALIZATION.md`
- `docs/GRADIO_DEPLOYMENT.md`

In [None]:
# Finish W&B run (grab the URL first: wandb.run is None after finish())
run_url = wandb.run.get_url()
wandb.finish()

print("\n" + "="*60)
print(" CONGRATULATIONS! You've completed the GRPO tutorial!")
print("="*60)
print(f"\n W&B Run: {run_url}")
print(f" HF Model: https://huggingface.co/{HF_REPO_ID}")
print(f" Checkpoints: {OUTPUT_DIR}")
print("\n Next: Check out the docs/ folder for advanced guides!")
print("\n" + "="*60)