<a href="https://colab.research.google.com/github/HarleyCoops/OneShotAquaRAT/blob/main/EducationalGRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#  Educational GRPO Training Pipeline

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HarleyCoops/OneShotGRPO/blob/main/EducationalGRPO.ipynb)

---

##  What You'll Learn

This comprehensive notebook teaches you how to train small language models using **GRPO (Group Relative Policy Optimization)** for algebraic reasoning. You'll learn:

1. **Dataset Integration**: Load and preprocess AQuA-RAT from HuggingFace
2. **Training Environments**: Use HuggingFace RL pipeline or Prime Intellect environments
3. **Cloud Storage**: Save checkpoints to Google Cloud Storage
4. **Advanced Monitoring**: Track training with Weights & Biases 3D visualizations
5. **Model Deployment**: Push to HuggingFace Hub with model cards
6. **Interactive Inference**: Create a Gradio chat interface

##  Learning Objectives

By the end of this notebook, you will:
- Understand GRPO and reinforcement learning for LLMs
- Configure reward functions for algebra reasoning
- Monitor training dynamics with comprehensive metrics
- Deploy production-ready models with proper documentation
- Build user-facing chat interfaces

##  Prerequisites

- Basic Python knowledge
- Understanding of neural networks
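- Google Colab with a GPU runtime (recommended: A100; the `bf16=True` settings used below need an Ampere-class or newer GPU, which the free-tier T4 is not)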
- HuggingFace account (for model deployment)
- Weights & Biases account (optional, for monitoring)
- Google Cloud project (optional, for GCS checkpoints)

---

##  Section 1: Environment Setup

### Understanding the Stack

We'll use several specialized libraries:

1. **vLLM**: High-performance inference engine with PagedAttention
   - PagedAttention cuts KV-cache memory waste, leaving room for larger generation batches
   - Enables efficient batch processing during training
   - Must be installed BEFORE TRL to avoid conflicts

2. **TRL (Transformer Reinforcement Learning)**: RL training framework
   - Implements GRPO, PPO, DPO algorithms
   - Integrates with HuggingFace Transformers
   - Handles reward computation and policy updates

3. **Datasets**: Efficient data loading and processing
   - Streaming support for large datasets
   - Built-in caching and versioning
   - Native HuggingFace Hub integration

In [1]:
# STEP 1: Install vLLM (must be first!)
print(" Installing vLLM for efficient inference...")
!pip install -q vllm

print("\n  IMPORTANT: Restart runtime after vLLM installation!")
print("Go to Runtime > Restart runtime, then continue with the next cell.")

 Installing vLLM for efficient inference...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m438.2/438.2 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.0/180.0 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.5/45.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.0/111.0 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.4/45.4 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.9/3.9 MB[0m [31m51.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m46.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K  

In [1]:
# STEP 2: Install remaining dependencies
print(" Installing TRL, datasets, and utilities...")
!pip install -q trl datasets transformers wandb google-cloud-storage gradio huggingface_hub

print("\n All dependencies installed!")

 Installing TRL, datasets, and utilities...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m462.8/462.8 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h
 All dependencies installed!


In [2]:
# STEP 3: Import libraries and verify installation
import re
import os
import json
import torch
from datetime import datetime
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
import wandb

print("\n Environment Check:")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("  Warning: No GPU detected. Training will be very slow!")

print("\n All imports successful!")



INFO 11-08 19:01:04 [__init__.py:216] Automatically detected platform cuda.

 Environment Check:
PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: Tesla T4
GPU Memory: 14.74 GB

 All imports successful!


---

##  Section 2: Weights & Biases Setup

### Why W&B for GRPO?

Weights & Biases provides:
- **Real-time metrics**: Track loss, rewards, KL divergence
- **3D visualizations**: Plot reward landscapes and policy evolution
- **Hyperparameter tracking**: Compare runs automatically
- **Artifact versioning**: Track datasets and model checkpoints

### Advanced Logging Strategy

We'll log:
1. Training metrics (loss, learning rate, grad norm)
2. Reward signals (correctness, format, overall)
3. Generation samples (input prompts + model outputs)
4. 3D reward landscapes (reward vs. step vs. example)
5. Model architecture and hyperparameters

In [3]:
# Initialize Weights & Biases
print(" W&B Authentication")
wandb.login()

# Configuration for this run
WANDB_PROJECT = "grpo-algebra-education"
WANDB_ENTITY = None  # Will use your default entity
RUN_NAME = f"grpo-qwen-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

print(f"\n W&B Project: {WANDB_PROJECT}")
print(f" Run Name: {RUN_NAME}")

 W&B Authentication


wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:
 ··········
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: christian-cooper-us to https://api.wandb.ai. Use `wandb login --relogin` to force relogin



 W&B Project: grpo-algebra-education
 Run Name: grpo-qwen-20251108-190216


---

##  Section 3: Google Cloud Storage Setup (Optional)

### Why Use GCS for Checkpoints?

Google Cloud Storage advantages:
- **Persistent storage**: Survives Colab session disconnects
- **Large capacity**: No 15GB Drive limit
- **Fast access**: Better upload/download speeds
- **Versioning**: Built-in checkpoint history
- **Team sharing**: Easy collaboration

### Alternative: Google Drive

If you don't have GCS, we'll use Google Drive (simpler but slower).

In [4]:
# Choose your storage backend
USE_GCS = False  # Set to True if you have Google Cloud Storage
USE_GDRIVE = True  # Set to True to use Google Drive

if USE_GCS:
    print("  Setting up Google Cloud Storage...")
    from google.colab import auth
    from google.cloud import storage

    # Authenticate
    auth.authenticate_user()

    # Configuration
    GCS_PROJECT = input("Enter your GCP project ID: ")
    GCS_BUCKET = input("Enter your GCS bucket name: ")
    GCS_PREFIX = f"grpo-checkpoints/{RUN_NAME}"

    # Initialize client
    storage_client = storage.Client(project=GCS_PROJECT)
    bucket = storage_client.bucket(GCS_BUCKET)

    print(f"\n Connected to gs://{GCS_BUCKET}/{GCS_PREFIX}")

elif USE_GDRIVE:
    print(" Mounting Google Drive...")
    from google.colab import drive
    drive.mount('/content/drive')

    # Create checkpoint directory
    GDRIVE_PATH = f"/content/drive/MyDrive/grpo_checkpoints/{RUN_NAME}"
    os.makedirs(GDRIVE_PATH, exist_ok=True)

    print(f"\n Checkpoints will save to: {GDRIVE_PATH}")
else:
    print(" Using local storage (will be lost when runtime disconnects)")
    LOCAL_PATH = f"/content/outputs/{RUN_NAME}"
    os.makedirs(LOCAL_PATH, exist_ok=True)

 Mounting Google Drive...
Mounted at /content/drive

 Checkpoints will save to: /content/drive/MyDrive/grpo_checkpoints/grpo-qwen-20251108-190216
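The Hugging Face Trainer saves checkpoints with ordinary local file I/O, so setting `output_dir` to a `gs://` URL does not upload anything to the bucket. Below is a minimal sketch of mirroring a local checkpoint directory to GCS and back. It assumes the `bucket` object created above (a `google.cloud.storage.Bucket`). `upload_dir_to_gcs` and `download_dir_from_gcs` are hypothetical helpers built on `Bucket.blob`, `Blob.upload_from_filename`, `Bucket.list_blobs` and `Blob.download_to_filename`.

```python
from pathlib import Path


def upload_dir_to_gcs(bucket, local_dir: str, gcs_prefix: str) -> list[str]:
    """Upload every file under local_dir to <bucket>/<gcs_prefix>/...

    Assumes `bucket` behaves like google.cloud.storage.Bucket
    (bucket.blob(name).upload_from_filename(path)).
    Returns the blob names written.
    """
    root = Path(local_dir)
    written = []
    for path in sorted(root.rglob("*")):
        if not path.is_file():
            continue
        # Keep the layout relative to local_dir, with "/" separators
        rel = path.relative_to(root).as_posix()
        blob_name = f"{gcs_prefix.rstrip('/')}/{rel}"
        bucket.blob(blob_name).upload_from_filename(str(path))
        written.append(blob_name)
    return written


def download_dir_from_gcs(bucket, gcs_prefix: str, local_dir: str) -> list[str]:
    """Mirror of upload_dir_to_gcs: copy all blobs under gcs_prefix into local_dir."""
    prefix = gcs_prefix.rstrip("/") + "/"
    written = []
    for blob in bucket.list_blobs(prefix=prefix):
        rel = blob.name[len(prefix):]
        if not rel:  # skip "folder" placeholder objects
            continue
        target = Path(local_dir) / rel
        target.parent.mkdir(parents=True, exist_ok=True)
        blob.download_to_filename(str(target))
        written.append(str(target))
    return written
```

For example, after a save you could call `upload_dir_to_gcs(bucket, "/content/outputs/checkpoint-100", f"{GCS_PREFIX}/checkpoint-100")` (paths illustrative). You could also call `download_dir_from_gcs` before `from_pretrained`, which cannot read `gs://` paths directly.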


---

## Section 4: Dataset Loading and Formatting

### AQuA-RAT Dataset

**AQuA-RAT** (Algebra Question Answering with Rationales) contains:
- ~97,000 algebra word problems
- Multiple choice questions with 5 options (A-E)
- Human-written rationales showing step-by-step reasoning
- Created by DeepMind for algebraic reasoning evaluation

### Multiple Choice Format Strategy

We train the model to output:
```
[REASONING]
Step-by-step problem solving
[/REASONING]
[ANSWER]
Letter (A, B, C, D, or E)
[/ANSWER]
```

This format:
1. Encourages chain-of-thought reasoning
2. Makes parsing answers easy (single letter)
3. Provides interpretability
4. Enables partial credit rewards
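The format above is easy to check mechanically. Here is a minimal sketch of a strict parser that assumes exactly these tag names. `parse_bracket_response` is a hypothetical helper, and the notebook's own `extract_answer` below is more lenient. It returns `None` whenever the layout is broken or the answer is not a single letter A-E.

```python
import re

# Full [REASONING]...[/REASONING] [ANSWER]...[/ANSWER] layout, across newlines
BRACKET_RE = re.compile(
    r"\[REASONING\](?P<reasoning>.*?)\[/REASONING\]\s*"
    r"\[ANSWER\](?P<answer>.*?)\[/ANSWER\]",
    re.DOTALL,
)


def parse_bracket_response(text: str):
    """Return {'reasoning': str, 'answer': str}, or None if the format is not followed."""
    match = BRACKET_RE.search(text)
    if match is None:
        return None
    answer = match.group("answer").strip().upper()
    if answer not in {"A", "B", "C", "D", "E"}:
        return None  # the answer section must hold exactly one letter
    return {"reasoning": match.group("reasoning").strip(), "answer": answer}


print(parse_bracket_response(
    "[REASONING]\nx = 12 - 5 = 7\n[/REASONING]\n[ANSWER]\nC\n[/ANSWER]"
))
```

Being strict here is a design choice. A lenient parser is friendlier for early training, while a strict one makes format errors visible during evaluation.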

In [5]:
# Define the system prompt and format for multiple choice
SYSTEM_PROMPT = """
You are solving algebra problems. Respond in the following format:
[REASONING]
...
[/REASONING]
[ANSWER]
A single letter: A, B, C, D, or E
[/ANSWER]
"""

COT_FORMAT = """\
[REASONING]
{reasoning}
[/REASONING]
[ANSWER]
{answer}
[/ANSWER]
"""

print("Format templates defined")

Format templates defined


In [6]:
# Answer extraction functions for AQuA-RAT
import re

def extract_answer(text: str) -> str:
    """
    Extracts the answer letter from formatted text.

    Example:
        Input: "[REASONING]steps[/REASONING][ANSWER]B[/ANSWER]"
        Output: "B"

    Returns "" when there is no [ANSWER] section, so untagged text
    cannot earn answer rewards by accident.
    """
    if "[ANSWER]" not in text:
        return ""
    answer = text.split("[ANSWER]")[-1].split("[/ANSWER]")[0]
    answer = answer.strip().upper()
    # Return only valid letters
    if answer in ['A', 'B', 'C', 'D', 'E']:
        return answer
    # Fall back to the first standalone letter (e.g. "C) 7" or "(C)"),
    # not any A-E character inside a word such as "THE"
    match = re.search(r"\b([A-E])\b", answer)
    return match.group(1) if match else ""

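def extract_options_from_text(options_str) -> dict:
    """
    Parse options into a dict.

    Accepts a string like "A)10 B)20 C)30 D)40 E)50" or a list like
    ["A)10", "B)20", ...] (the form used by AQuA-RAT's 'raw' config).

    Example:
        Input: "A)10 B)20 C)30 D)40 E)50"
        Output: {'A': '10', 'B': '20', ...}
    """
    if isinstance(options_str, (list, tuple)):
        options_str = " ".join(options_str)
    options = {}
    # Capture each "X)value" up to the next "Y)" marker or the end of the
    # string, so values may contain capital letters ("A)Cannot be determined")
    matches = re.findall(r'([A-E])\)\s*(.*?)(?=\s*[A-E]\)|$)', options_str)
    for letter, value in matches:
        options[letter] = value.strip()
    return options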

# Test the functions
test_response = "[REASONING]Steps here[/REASONING][ANSWER]B[/ANSWER]"
test_options = "A)10 B)20 C)30 D)40 E)50"

print(f"Answer extraction test: '{extract_answer(test_response)}'")
print(f"Options parsing test: {extract_options_from_text(test_options)}")
print("\nExtraction functions working correctly")

Answer extraction test: 'B'
Options parsing test: {'A': '10', 'B': '20', 'C': '30', 'D': '40', 'E': '50'}

Extraction functions working correctly


In [7]:
# Load and format AQuA-RAT dataset
def get_aqua_rat_questions(split="train", num_examples=None) -> Dataset:
    """
    Loads and preprocesses AQuA-RAT dataset for GRPO training.

    Args:
        split: 'train', 'validation', or 'test'
        num_examples: Optional limit on number of examples

    Returns:
        Dataset with formatted prompts and answers
    """
    print(f"Loading AQuA-RAT {split} split...")
    data = load_dataset('deepmind/aqua_rat', 'raw')[split]

    if num_examples:
        data = data.select(range(min(num_examples, len(data))))

    print(f"Dataset size: {len(data)} examples")

    # Transform to GRPO format
    def format_example(example):
        # Format question with options
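        question_text = example['question']
        # In the 'raw' config, options is a list like ["A)21", "B)21.5", ...];
        # join it one option per line instead of embedding the list repr
        options = example['options']
        options_text = "\n".join(options) if isinstance(options, list) else options

        # Create full question
        full_question = f"{question_text}\n\nOptions:\n{options_text}"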

        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': full_question}
            ],
            'answer': example['correct'],  # Single letter: A, B, C, D, or E
            'reference_rationale': example['rationale']  # Keep rationale for reference
        }

    print("Formatting examples...")
    data = data.map(format_example)

    # Show a sample
    print("\nSample Question:")
    sample = data[0]
    print(f"Q: {sample['prompt'][1]['content'][:150]}...")
    print(f"Correct Answer: {sample['answer']}")

    return data

# Load training data
# Use smaller subset for faster experimentation (remove num_examples for full dataset)
dataset = get_aqua_rat_questions(split="train", num_examples=2000)
print("\nDataset ready for training")

Loading AQuA-RAT train split...


README.md: 0.00B [00:00, ?B/s]

raw/train-00000-of-00001.parquet:   0%|          | 0.00/25.4M [00:00<?, ?B/s]

raw/test-00000-of-00001.parquet:   0%|          | 0.00/74.0k [00:00<?, ?B/s]

raw/validation-00000-of-00001.parquet:   0%|          | 0.00/76.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/97467 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/254 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/254 [00:00<?, ? examples/s]

Dataset size: 2000 examples
Formatting examples...


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]


Sample Question:
Q: Two friends plan to walk along a 43-km trail, starting at opposite ends of the trail at the same time. If Friend P's rate is 15% faster than Friend Q'...
Correct Answer: E

Dataset ready for training
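Check one detail before training. In the `raw` config of `deepmind/aqua_rat`, `options` is stored as a list of strings such as `"A)21"`, so interpolating it directly into an f-string would embed a Python list repr in the prompt. You can confirm this with `dataset.features` for your version. Below is a minimal sketch of building the prompt with one option per line. `build_prompt` is hypothetical, and the toy record is made up for illustration; it is not taken from the dataset.

```python
def build_prompt(record: dict, system_prompt: str) -> dict:
    """Turn one AQuA-RAT-style record into a chat prompt plus target letter.

    Assumes record["options"] is a list like ["A)5", "B)6", ...] or an
    already-joined string.
    """
    options = record["options"]
    if isinstance(options, (list, tuple)):
        options = "\n".join(options)  # one option per line
    question = f"{record['question']}\n\nOptions:\n{options}"
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ],
        "answer": record["correct"],
    }


# Hypothetical toy record (not from the dataset)
toy = {
    "question": "If x + 5 = 12, what is x?",
    "options": ["A)5", "B)6", "C)7", "D)8", "E)9"],
    "correct": "C",
}
example = build_prompt(toy, system_prompt="(system prompt here)")
print(example["prompt"][1]["content"])
```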


---

##  Section 5: Reward Functions

### Understanding GRPO Rewards

This setup combines **multiple reward signals**, which the trainer adds into one scalar reward per completion:

1. **Correctness Reward** (1.0 points)
   - Primary learning signal
   - Binary: correct letter answer = 1.0, incorrect = 0.0
   - Drives algebraic accuracy

2. **Format Rewards** (0.2 points total)
   - Bracket structure (0.1): [REASONING] and [ANSWER] tags present
   - Answer validation (0.1): Single letter A-E in answer section
   - Ensures consistent, parseable outputs

### Maximum Total Reward: 1.2 points

This multi-objective reward encourages:
- Correct final answers (1.0 of 1.2, 83.3%); only the letter is checked, not the reasoning
- Clean, structured output (0.2 of 1.2, 16.7%)
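To make the percentages concrete, here is a minimal sketch of how the three scores add up for a few typical completions. It assumes equal weights: TRL sums the reward functions, and `GRPOConfig` also has a `reward_weights` option, which this notebook leaves unset. `total_reward` and the scenarios are illustrative only.

```python
# Scores from the three reward functions, in the order they are passed to the
# trainer: (bracket_format, letter_answer, correctness).
SCENARIOS = {
    "correct, well formatted": (0.1, 0.1, 1.0),
    "wrong letter, well formatted": (0.1, 0.1, 0.0),
    "correct letter, no reasoning tags": (0.0, 0.1, 1.0),
    "no [ANSWER] section": (0.0, 0.0, 0.0),
}
MAX_REWARD = 0.1 + 0.1 + 1.0


def total_reward(scores, weights=(1.0, 1.0, 1.0)):
    """Weighted sum of per-function scores (equal weights by assumption)."""
    return sum(w * s for w, s in zip(weights, scores))


for name, scores in SCENARIOS.items():
    r = total_reward(scores)
    print(f"{name:36s} reward={r:.1f} ({r / MAX_REWARD:.1%} of max)")
```

A well-formatted wrong answer still earns 0.2. Because GRPO compares completions within a group, what drives learning is the 1.0 gap between correct and incorrect samples, not the absolute values.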

In [8]:
# Reward functions (the trainer reports reward metrics to W&B via report_to="wandb")

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Primary reward: Checks if extracted answer matches ground truth.
    Weight: 1.0 (highest priority)
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]

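    # Print a sample to stdout occasionally (this does not log to W&B).
    # Assumption: newer TRL versions pass `trainer_state` to reward functions;
    # if neither it nor `step` is provided, this prints on every call.
    state = kwargs.get('trainer_state')
    step = getattr(state, 'global_step', kwargs.get('step', 0))
    if step % 50 == 0: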
        q = prompts[0][-1]['content']
        print('-' * 20)
        print(f"Question: {q[:100]}...")
        print(f"Expected: {answer[0]}")
        print(f"Got: {extracted_responses[0]}")
        print(f"Response: {responses[0][:200]}...")

    return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def letter_answer_reward_func(completions, **kwargs) -> list[float]:
    """
    Answer validation reward: Ensures answer is a single valid letter (A-E).
    Weight: 0.1
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]
    return [0.1 if r in ['A', 'B', 'C', 'D', 'E'] else 0.0
            for r in extracted_responses]

def bracket_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Format reward: Checks for [REASONING] and [ANSWER] tags.
    Weight: 0.1
    """
    pattern = r"\[REASONING\].*?\[/REASONING\].*?\[ANSWER\].*?\[/ANSWER\]"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    return [0.1 if match else 0.0 for match in matches]

print(" Reward functions defined")
print("\n Reward Structure:")
print("  Correctness:     1.0 (83.3%)")
print("  Letter format:   0.1 ( 8.3%)")
print("  Bracket tags:    0.1 ( 8.3%)")
print("  " + "-" * 30)
print("  TOTAL:           1.2 (100%)")

 Reward functions defined

 Reward Structure:
  Correctness:     1.0 (83.3%)
  Letter format:   0.1 ( 8.3%)
  Bracket tags:    0.1 ( 8.3%)
  ------------------------------
  TOTAL:           1.2 (100%)


---

##  Section 6: Model and Training Configuration

### Model Selection: Qwen2.5-0.5B-Instruct

**Why this model?**
- Small enough to train on single GPU (500M parameters)
- Fits on Google Colab free tier GPU
- Pre-trained on instruction following
- Good algebraic reasoning baseline
- Fast inference for RL training

### GRPO Hyperparameters Explained

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| `learning_rate` | 5e-6 | Small LR reduces the risk of catastrophic forgetting |
| `num_generations` | 16 | Group size: each completion is scored relative to the other samples for the same prompt |
| `max_grad_norm` | 0.1 | Aggressive gradient clipping for RL stability |
| `num_train_epochs` | 1 | A single pass limits overfitting |
| `warmup_ratio` | 0.1 | LR warms up over the first 10% of steps for stability |
| `bf16` | True | Half the memory of fp32, with fp32's exponent range (fewer overflows than fp16) |

### vLLM Configuration

- `vllm_gpu_memory_utilization`: 0.3 (30% for inference, 70% for training)
- Enables PagedAttention for efficient KV caching
- Significantly speeds up generation during training
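Because every prompt is sampled `num_generations` times, it helps to do the batch arithmetic in completions rather than prompts. Below is a back-of-the-envelope sketch. It assumes, as recent TRL documentation describes (verify for your version), that `per_device_train_batch_size` and `generation_batch_size` count completions. `grpo_batch_math` is a hypothetical helper.

```python
def grpo_batch_math(num_prompts: int, per_device_batch_size: int,
                    grad_accum_steps: int, num_generations: int,
                    generation_batch_size: int, num_processes: int = 1) -> dict:
    """Rough GRPO batch accounting, assuming batch sizes count completions
    and each prompt is expanded into num_generations completions."""
    if generation_batch_size % num_generations != 0:
        raise ValueError("generation_batch_size must be divisible by num_generations")
    completions_per_step = per_device_batch_size * grad_accum_steps * num_processes
    total_completions = num_prompts * num_generations
    return {
        "prompts_per_generation_batch": generation_batch_size // num_generations,
        "completions_per_optimizer_step": completions_per_step,
        "total_completions": total_completions,
        "approx_optimizer_steps": total_completions // completions_per_step,
    }


print(grpo_batch_math(num_prompts=2000, per_device_batch_size=1,
                      grad_accum_steps=4, num_generations=16,
                      generation_batch_size=16))
```

Under that assumption, `generation_batch_size=16` with `num_generations=16` means one prompt per generation round, and the step count is driven by completions rather than prompts. Compare the result with the total the Trainer reports when training starts.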

In [9]:
# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
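# Note: the Trainer saves checkpoints with local file I/O, so a gs:// path is
# NOT uploaded to GCS automatically; copy checkpoints to the bucket yourself.
OUTPUT_DIR = GDRIVE_PATH if USE_GDRIVE else (f"gs://{GCS_BUCKET}/{GCS_PREFIX}" if USE_GCS else LOCAL_PATH)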

print(f" Model: {MODEL_NAME}")
print(f" Output directory: {OUTPUT_DIR}")

# Initialize W&B run with comprehensive config
wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name=RUN_NAME,
    config={
        # Model config
        "model_name": MODEL_NAME,
        "model_params": "500M",

        # Training config
        "learning_rate": 5e-6,
        "adam_beta1": 0.9,
        "adam_beta2": 0.99,
        "weight_decay": 0.1,
        "warmup_ratio": 0.1,
        "lr_scheduler": "cosine",

        # Batch config
        "per_device_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "effective_batch_size": 4,

        # Generation config
        "num_generations": 16,
        "generation_batch_size": 16,
        "max_prompt_length": 256,
        "max_completion_length": 200,

        # Training duration
        "num_train_epochs": 1,
        "dataset_size": len(dataset),

        # Regularization
        "max_grad_norm": 0.1,

        # Precision
        "bf16": True,

        # vLLM
        "use_vllm": True,
        "vllm_gpu_memory": 0.3,

        # Reward weights
        "reward_correctness_weight": 1.0,
        "reward_letter_format_weight": 0.1,
        "reward_bracket_format_weight": 0.1,
    },
    tags=["grpo", "algebra", "aqua_rat", "educational"]
)

print("\n W&B run initialized")

 Model: Qwen/Qwen2.5-0.5B-Instruct
 Output directory: /content/drive/MyDrive/grpo_checkpoints/grpo-qwen-20251108-190216


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/



 W&B run initialized


In [10]:
# Configure GRPO training
training_args = GRPOConfig(
    # Output
    output_dir=OUTPUT_DIR,
    run_name=RUN_NAME,

    # Optimizer
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,

    # Learning rate schedule
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',

    # Logging
    logging_steps=1,
    report_to="wandb",
    log_on_each_node=False,

    # Precision
    bf16=True,

    # Batch configuration
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,

    # Generation parameters
    num_generations=16,
    generation_batch_size=16,
    max_prompt_length=256,
    max_completion_length=200,

    # Training duration
    num_train_epochs=1,

    # Checkpointing
    save_steps=100,
    save_total_limit=5,  # Keep last 5 checkpoints

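    # vLLM configuration
    use_vllm=True,
    # Run vLLM inside the training process. Recent TRL versions default to
    # vllm_mode="server", which expects a separate `trl vllm-serve` process.
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.3,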
)

print(" Training configuration created")
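# Rough estimate only: TRL counts batch sizes in completions (num_generations
# per prompt), so the step total the Trainer reports may differ.
print(f"\n Training will run for ~{len(dataset) // 4} steps")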
print(f" Checkpoints every 100 steps → ~{(len(dataset) // 4) // 100} checkpoints")

 Training configuration created

 Training will run for ~500 steps
 Checkpoints every 100 steps → ~5 checkpoints


In [11]:
# Load model and tokenizer
print(f" Loading model: {MODEL_NAME}...")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # `torch_dtype` is deprecated in recent transformers
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Calculate model size
param_count = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n Model Statistics:")
print(f"  Total parameters: {param_count:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{param_count * 2 / 1024**3:.2f} GB (bf16)")

# Log to W&B
wandb.config.update({
    "total_params": param_count,
    "trainable_params": trainable_params
})

print("\n Model loaded successfully")

 Loading model: Qwen/Qwen2.5-0.5B-Instruct...


config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


 Model Statistics:
  Total parameters: 494,032,768
  Trainable parameters: 494,032,768
  Model size: ~0.92 GB (bf16)

 Model loaded successfully


---

##  Section 7: Training Execution

### What Happens During Training?

Each training step:
1. **Sampling**: Load batch of algebra problems
2. **Generation**: Model generates 16 responses per problem
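3. **Reward**: Each response is scored by the 3 reward functions, and the scores are summed
4. **Policy Update**: Each response's reward is compared with the other responses to the same problem (a group-relative advantage), and GRPO raises the probability of above-average responses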
5. **Logging**: Metrics sent to W&B
6. **Checkpointing**: Save every 100 steps
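The policy-update step can be made concrete with a small sketch of GRPO's group-relative advantage. Each completion's summed reward is standardized against the other completions sampled for the same prompt. `group_relative_advantages` is a hypothetical helper, not TRL's implementation. TRL's exact normalization (sample vs. population std, epsilon, optional scaling) may differ by version.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Standardize rewards within one group of completions for the same prompt:
    A_i = (r_i - mean(r)) / (std(r) + eps).
    Population std is used here; implementations differ in this detail."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical group of 4 completions for one prompt (the notebook uses 16):
# two correct and well formatted (1.2), two well formatted but wrong (0.2).
adv = group_relative_advantages([1.2, 1.2, 0.2, 0.2])
print([round(a, 3) for a in adv])  # ≈ [1.0, 1.0, -1.0, -1.0]
```

When every completion in a group gets the same reward (all correct or all wrong), every advantage is zero and that prompt contributes no reward-driven signal. This is one reason a reasonably large group size helps.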

### Expected Training Time

- **Dataset**: 2,000 examples
- **Effective batch size**: 4
- **Steps**: ~500
- **Time per step**: ~30-60 seconds (with A100)
- **Total time**: ~4-8 hours

### Monitoring Tips

Watch these metrics in W&B:
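- **Loss**: Less informative than in supervised training; the GRPO loss can stay near zero while rewards improve, so don't expect a steady decrease
- **Reward**: Should increase (maximum 1.2; target: 1.0+)
- **KL Divergence**: Only meaningful if a KL penalty is enabled (`beta > 0`); if so, it should stay small (<1.0)
- **Learning Rate**: Should follow the cosine schedule after warmup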
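Raw per-step rewards are noisy, so a moving average makes trends easier to read. Below is a minimal sketch over `trainer.state.log_history`, the list of dicts the HF Trainer fills in during training. `smoothed_metric` is hypothetical. The `"reward"` key is an assumption about TRL's metric names, which vary across versions.

```python
def smoothed_metric(log_history: list[dict], key: str = "reward", window: int = 20):
    """Return (steps, moving_average) for `key` from Trainer-style log entries.

    Entries without `key` (e.g. end-of-training summaries) are skipped.
    """
    steps, values, smoothed = [], [], []
    for entry in log_history:
        if key not in entry or "step" not in entry:
            continue
        values.append(entry[key])
        recent = values[-window:]  # trailing window, shorter at the start
        steps.append(entry["step"])
        smoothed.append(sum(recent) / len(recent))
    return steps, smoothed


# Usage after training (metric name is an assumption):
# steps, rewards = smoothed_metric(trainer.state.log_history, key="reward")
```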

In [None]:
# Create GRPO trainer
print("  Building GRPO trainer...")

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        bracket_format_reward_func,    # 0.1 points
        letter_answer_reward_func,     # 0.1 points
        correctness_reward_func,       # 1.0 points
    ],
    args=training_args,
    train_dataset=dataset,
)

print("\n Trainer ready")
print("\n Starting training...")
print("\n" + "="*60)
print("Monitor your run at:", wandb.run.get_url())
print("="*60 + "\n")

  Building GRPO trainer...




In [None]:
# Train the model!
train_result = trainer.train()

print("\n" + "="*60)
print(" Training complete!")
print("="*60)
print(f"\nFinal metrics:")
print(f"  Loss: {train_result.training_loss:.4f}")
print(f"  Steps: {train_result.global_step}")
print(f"  Time: {train_result.metrics['train_runtime']:.2f} seconds")
print(f"  Samples/second: {train_result.metrics['train_samples_per_second']:.2f}")

---

##  Section 8: Checkpoint Management

### Understanding Checkpoints

Each checkpoint contains:
- `model.safetensors`: Model weights
- `config.json`: Model architecture
- `tokenizer.json`: Tokenizer configuration
- `trainer_state.json`: Training progress
- `optimizer.pt`: Optimizer state
- `scheduler.pt`: LR scheduler state

### Selecting Best Checkpoint

Strategies:
1. **Latest**: Most training exposure
2. **Highest reward**: Best training reward (this notebook does not run a separate validation pass)
3. **Lowest loss**: Less useful for GRPO, where the loss is a weak progress signal

For this demo, we'll use the **final checkpoint**.
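For the "highest reward" strategy, each checkpoint's `trainer_state.json` already holds the logged metrics in its `log_history` list. Below is a minimal sketch that scores every checkpoint by its trailing mean reward. `best_checkpoint_by_reward` is hypothetical, and the `"reward"` key is an assumption about TRL's metric naming. Note that this measures training reward, not held-out accuracy.

```python
import json
from pathlib import Path


def best_checkpoint_by_reward(checkpoint_dirs, key="reward", window=20):
    """Pick the checkpoint whose trailing mean `key` (last `window` logged
    values up to its step) is highest. Dirs must be named checkpoint-<step>
    and contain the trainer_state.json written by the HF Trainer."""
    best_dir, best_score = None, float("-inf")
    for ckpt in checkpoint_dirs:
        state = json.loads((Path(ckpt) / "trainer_state.json").read_text())
        step = int(Path(ckpt).name.split("-")[-1])
        values = [e[key] for e in state["log_history"]
                  if key in e and e.get("step", 0) <= step]
        if not values:
            continue
        recent = values[-window:]
        score = sum(recent) / len(recent)
        if score > best_score:
            best_dir, best_score = ckpt, score
    return best_dir, best_score
```

With local or Drive storage, `best_dir, score = best_checkpoint_by_reward(checkpoint_dirs)` can reuse the `checkpoint_dirs` list built in the next cell.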

In [None]:
# Find the final checkpoint
import glob

if USE_GDRIVE or not USE_GCS:
    checkpoint_dirs = glob.glob(f"{OUTPUT_DIR}/checkpoint-*")
    checkpoint_dirs.sort(key=lambda x: int(x.split("-")[-1]))
    final_checkpoint = checkpoint_dirs[-1] if checkpoint_dirs else None

    print(f" Found {len(checkpoint_dirs)} checkpoints")
    if final_checkpoint:
        print(f" Final checkpoint: {final_checkpoint}")
        CHECKPOINT_PATH = final_checkpoint
    else:
        print("  No checkpoints found!")
        CHECKPOINT_PATH = OUTPUT_DIR
else:
    # For GCS, we'll need to list blobs
    print(" Listing GCS checkpoints...")
    blobs = bucket.list_blobs(prefix=GCS_PREFIX)
    checkpoint_nums = set()
    for blob in blobs:
        if "checkpoint-" in blob.name:
            num = blob.name.split("checkpoint-")[1].split("/")[0]
            if num.isdigit():
                checkpoint_nums.add(int(num))

    if checkpoint_nums:
        final_num = max(checkpoint_nums)
        CHECKPOINT_PATH = f"gs://{GCS_BUCKET}/{GCS_PREFIX}/checkpoint-{final_num}"
        print(f" Final checkpoint: {CHECKPOINT_PATH}")
    else:
        print("  No checkpoints found in GCS!")
        CHECKPOINT_PATH = f"gs://{GCS_BUCKET}/{GCS_PREFIX}"

---

##  Section 9: Model Evaluation

### Quick Inference Test

Before deployment, let's test our trained model!

**Note**: We'll load the model locally first (needed for pushing to HF Hub), then switch to using HuggingFace Inference API for all inference operations. This allows us to run inference without loading the model into GPU memory.

In [None]:
# Load the trained model for inference (needed for pushing to HF Hub)
print(f" Loading trained model from {CHECKPOINT_PATH}...")

inference_model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT_PATH,
    dtype=torch.bfloat16,  # `torch_dtype` is deprecated in recent transformers
    device_map="auto",
)

inference_tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)

print(" Model loaded for inference (local)")
print("\n Note: After pushing to HF Hub, we'll switch to using HF Inference API")
print("       which doesn't require loading the model into GPU memory.")

---

##  Section 10: HuggingFace Hub Deployment

### Model Card Best Practices

A good model card includes:
1. **Model Description**: What it does
2. **Training Details**: Dataset, hyperparameters, compute
3. **Usage Examples**: Code to run inference
4. **Limitations**: Known issues and constraints
5. **Evaluation Results**: Performance metrics
6. **Citation**: How to cite your work

In [None]:
# Generate comprehensive model card
MODEL_CARD = f"""
---
language:
- en
license: apache-2.0
tags:
- grpo
- reinforcement-learning
- algebra
- aqua-rat
- reasoning
base_model: {MODEL_NAME}
datasets:
- deepmind/aqua_rat
---

# GRPO-Tuned Algebra Reasoner

This model was fine-tuned using **GRPO (Group Relative Policy Optimization)** on the AQuA-RAT dataset for algebraic reasoning tasks.

## Model Description

- **Base Model**: {MODEL_NAME}
- **Training Method**: GRPO with multi-objective rewards
- **Training Dataset**: AQuA-RAT (Algebra Question Answering with Rationales)
- **Training Examples**: {len(dataset)}
- **Total Parameters**: {param_count:,}
- **Precision**: bfloat16

## Training Details

### Hyperparameters

```yaml
learning_rate: 5e-6
optimizer: AdamW
adam_beta1: 0.9
adam_beta2: 0.99
weight_decay: 0.1
lr_scheduler: cosine
warmup_ratio: 0.1
num_train_epochs: 1
per_device_batch_size: 1
gradient_accumulation_steps: 4
max_grad_norm: 0.1
num_generations: 16
max_prompt_length: 256
max_completion_length: 200
```

### Reward Functions

The model was trained with three reward signals:

1. **Correctness** (1.0 points): Exact match with ground truth letter (A-E)
2. **Letter Format** (0.1 points): Answer is a valid single letter
3. **Bracket Structure** (0.1 points): Proper [REASONING]/[ANSWER] tags

**Total possible reward**: 1.2 points

### Compute Infrastructure

- **GPU**: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'unknown'}
- **Training Time**: ~{train_result.metrics['train_runtime'] / 3600:.1f} hours
- **vLLM**: Enabled for efficient inference

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "YOUR_HF_USERNAME/{RUN_NAME}",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("YOUR_HF_USERNAME/{RUN_NAME}")

# Create prompt
messages = [
    {{"role": "system", "content": '''You are solving algebra problems. Respond in the following format:
[REASONING]
...
[/REASONING]
[ANSWER]
A single letter: A, B, C, D, or E
[/ANSWER]'''}},
    {{"role": "user", "content": "If x + 5 = 12, what is x?\\nOptions:\\nA)5 B)6 C)7 D)8 E)9"}}
]

# Generate response
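input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids, max_new_tokens=200, do_sample=True, temperature=0.7)
# Decode only the newly generated tokens, not the prompt
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)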
print(response)
```

### Expected Output Format

```
[REASONING]
We have the equation x + 5 = 12.
To solve for x, subtract 5 from both sides: x = 12 - 5 = 7
[/REASONING]
[ANSWER]
C
[/ANSWER]
```

## Limitations

- Trained only on algebra word problems with multiple choice answers
- May struggle with complex multi-step reasoning
- Expects single letter answers (A-E)
- Single epoch training may result in some underfitting
- Performance degrades on out-of-distribution problems

## Training Metrics

- **Final Loss**: {train_result.training_loss:.4f}
- **Training Steps**: {train_result.global_step}
- **Samples/Second**: {train_result.metrics['train_samples_per_second']:.2f}

## Citation

```bibtex
@misc{{{RUN_NAME.replace('-', '_')},
  title={{GRPO-Tuned Algebra Reasoner}},
  author={{Your Name}},
  year={{2025}},
  publisher={{HuggingFace}},
  howpublished={{\\url{{https://huggingface.co/YOUR_USERNAME/{RUN_NAME}}}}}
}}
```

## License

This model inherits the license from the base model ({MODEL_NAME}).

## Acknowledgments

- Base model: Qwen Team
- Dataset: DeepMind (AQuA-RAT)
- Training framework: HuggingFace TRL
- Inference engine: vLLM
"""

# Save model card
with open(f"{CHECKPOINT_PATH}/README.md", "w") as f:
    f.write(MODEL_CARD)

print(" Model card generated")
print("\n Preview:")
print(MODEL_CARD[:500] + "...")

In [None]:
# Authenticate with HuggingFace
from huggingface_hub import login, HfApi

print(" HuggingFace Authentication")
login()

# Configure your model repository
HF_USERNAME = input("Enter your HuggingFace username: ")
HF_MODEL_NAME = input(f"Enter model name (default: {RUN_NAME}): ") or RUN_NAME
HF_REPO_ID = f"{HF_USERNAME}/{HF_MODEL_NAME}"

print(f"\n Will push to: {HF_REPO_ID}")

In [None]:
# Push to HuggingFace Hub
print(f" Pushing model to {HF_REPO_ID}...")

inference_model.push_to_hub(
    HF_REPO_ID,
    commit_message=f"GRPO training on AQuA-RAT - {len(dataset)} examples",
    private=False  # Set to True for private repo
)

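inference_tokenizer.push_to_hub(
    HF_REPO_ID,
    commit_message="Add tokenizer"
)

# push_to_hub() uploads the in-memory model, not the README.md written to
# CHECKPOINT_PATH above, so upload the model card explicitly
from huggingface_hub import HfApi
HfApi().upload_file(
    path_or_fileobj=f"{CHECKPOINT_PATH}/README.md",
    path_in_repo="README.md",
    repo_id=HF_REPO_ID,
    commit_message="Add model card",
)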

print("\n Model pushed successfully!")
print(f"\n View your model at: https://huggingface.co/{HF_REPO_ID}")

# Free up GPU memory - we'll use HF Inference API from now on
del inference_model
del inference_tokenizer
import torch
torch.cuda.empty_cache()
print("\n GPU memory freed - switching to HF Inference API for inference")

# Log to W&B
wandb.config.update({"hf_repo": HF_REPO_ID})
wandb.log({"model_url": f"https://huggingface.co/{HF_REPO_ID}"})

---

##  Section 9.5: HuggingFace Inference API Setup

### Why Use HF Inference API?

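The HuggingFace Inference API provides:
- **No local GPU memory usage**: the model runs on Hugging Face infrastructure
- **Shared serving capacity**: requests are handled by Hugging Face's servers, subject to rate limits
- **Usage-based cost**: no dedicated hardware to provision for light or occasional traffic
- **No model lifecycle management**: no need to load or unload weights in the notebook

**Caveat:** serverless availability is not guaranteed for every Hub model. A newly pushed fine-tune may not be served by the serverless API; if the calls below fail, deploy the model to a dedicated Inference Endpoint or load it locally with `transformers` instead.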
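Because the model is loaded and served on Hugging Face's side, early requests can fail transiently (for example while the model is still loading) or hit rate limits. A retry wrapper with exponential backoff is one common way to absorb such failures. The sketch below is a generic pattern, not part of `huggingface_hub`; `call_with_retries` and its parameters are hypothetical names.

```python
import random
import time


def call_with_retries(fn, *args, max_attempts=4, base_delay=1.0,
                      retry_on=(Exception,), sleep=time.sleep, **kwargs):
    """Call fn(*args, **kwargs), retrying with exponential backoff and jitter.

    Hypothetical helper: retries any exception in `retry_on` for up to
    `max_attempts` total calls, waiting base_delay * 2**attempt plus random
    jitter between tries. `sleep` is injectable so the logic can be tested
    without actually waiting.
    """
    for attempt in range(max_attempts):
        try:
            return fn(*args, **kwargs)
        except retry_on:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            delay = base_delay * (2 ** attempt)
            # Jitter spreads out retries from many clients hitting the same endpoint
            sleep(delay + random.uniform(0, delay / 2))
```

For example, `call_with_retries(hf_client.text_generation, formatted_prompt, max_new_tokens=200)` would retry the API call up to three more times before giving up. Narrowing `retry_on` to network or HTTP errors avoids retrying genuine bugs.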

### Setting Up Inference via API

We'll use the `InferenceClient` from `huggingface_hub` to call our deployed model.
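`text_generation` takes a single raw prompt string, so the chat template has to be applied on the client side before the call. Assuming the base model is a Qwen chat model (the acknowledgments credit the Qwen Team), its template follows the ChatML layout that the sketch below approximates; it also shows why `<|im_end|>` is a natural stop sequence. `format_chatml` is a hypothetical illustration, not a replacement for `tokenizer.apply_chat_template`, which may add extra content such as a default system prompt.

```python
def format_chatml(messages, add_generation_prompt=True):
    """Render chat messages in the ChatML layout used by Qwen chat templates.

    Hypothetical stand-in for tokenizer.apply_chat_template(..., tokenize=False).
    Each turn is wrapped in <|im_start|>{role} ... <|im_end|>.
    """
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")  # the model continues from here
    return "".join(parts)


prompt = format_chatml([
    {"role": "system", "content": "Solve the problem."},
    {"role": "user", "content": "If x + 5 = 12, what is x?"},
])
print(prompt)
```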


In [None]:
# Set up HuggingFace Inference API
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

print(" Setting up HuggingFace Inference API...")
print(f" Model: {HF_REPO_ID}")

# Initialize Inference Client
# Note: Uses your HF token from login() automatically
hf_client = InferenceClient(model=HF_REPO_ID)

# Load tokenizer for chat template formatting (lightweight, no model weights)
tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID)

print("\n HF Inference API ready!")
print(" Model is running on HuggingFace infrastructure (no local GPU needed)")

def solve_algebra_problem(question: str, temperature=0.7, max_tokens=200):
    """
    Generate a solution to an algebra problem using HF Inference API.

    Args:
        question: Algebra word problem
        temperature: Sampling temperature (higher = more creative)
        max_tokens: Maximum response length

    Returns:
        dict with reasoning and answer
    """
    # Format prompt using chat template
    prompt = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question}
    ]

    # Apply chat template
    formatted_prompt = tokenizer.apply_chat_template(
        prompt,
        tokenize=False,
        add_generation_prompt=True
    )

    # Call HF Inference API
    try:
        response = hf_client.text_generation(
            formatted_prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            return_full_text=False,  # Don't include the prompt in response
            stop_sequences=["<|im_end|>", "<|endoftext|>", "</s>"]  # End-of-turn markers (<|im_end|> closes a turn in Qwen's ChatML template)
        )
    except Exception as e:
        return {
            "question": question,
            "reasoning": f"Error calling HF API: {str(e)}",
            "answer": "Error",
            "raw_response": ""
        }

    # Extract assistant response (text_generation returns a str when details=False)
    response_text = str(response).strip()

    # Parse the bracket format; fall back to the raw text if the tags are missing
    if "[REASONING]" in response_text:
        reasoning = response_text.split("[REASONING]", 1)[1].split("[/REASONING]", 1)[0].strip()
    else:
        reasoning = response_text
    answer = extract_answer(response_text) or "Parse error"

    return {
        "question": question,
        "reasoning": reasoning,
        "answer": answer,
        "raw_response": response_text
    }

# Test on sample problems using HF Inference API
test_problems = [
    "If a store sells 5 apples for $3, how much do 15 apples cost?\nOptions:\nA)$6 B)$9 C)$12 D)$15 E)$18",
    "A train travels at 60 mph for 2.5 hours. How far does it travel?\nOptions:\nA)100 miles B)120 miles C)150 miles D)180 miles E)200 miles",
    "If x + 5 = 12, what is x?\nOptions:\nA)5 B)6 C)7 D)8 E)9"
]

print("\n" + "="*60)
print(" TESTING MODEL VIA HF INFERENCE API")
print("="*60 + "\n")

test_results = []
for i, problem in enumerate(test_problems, 1):
    print(f"Test {i}/{len(test_problems)}: {problem[:60]}...")
    result = solve_algebra_problem(problem)
    test_results.append(result)

    print(f"\n Reasoning:\n{result['reasoning'][:200]}...")
    print(f"\n Answer: {result['answer']}")
    print("\n" + "-"*60 + "\n")

# Log test results to W&B
wandb.log({
    "test_samples_hf_api": wandb.Table(
        columns=["question", "reasoning", "answer"],
        data=[[r["question"], r["reasoning"], r["answer"]] for r in test_results]
    )
})

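n_errors = sum(r["answer"] == "Error" for r in test_results)
if n_errors:
    print(f" {n_errors}/{len(test_results)} API calls failed - check that the model is available through the Inference API")
else:
    print(" Inference via HF API working!")
print(f" Model available at: https://huggingface.co/{HF_REPO_ID}")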
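The test cell prints answers but does not check them. The three test problems have unambiguous keys (15 apples cost 3 × $3 = $9, option B; 60 mph × 2.5 h = 150 miles, option C; x = 12 − 5 = 7, option C), so a small scorer can turn the smoke test into a pass/fail signal. This is a sketch: `EXPECTED_ANSWERS` and `score_results` are hypothetical names, and it assumes `extract_answer` returns a bare option letter.

```python
# Answer key for the three sanity-check problems above, worked by hand
EXPECTED_ANSWERS = ["B", "C", "C"]


def score_results(results, expected):
    """Return (accuracy, mismatches) comparing predicted letters to an answer key.

    Assumes each result dict has an 'answer' field holding an option letter, as
    returned by solve_algebra_problem; missing results or anything else
    (e.g. "Parse error") count as wrong.
    """
    mismatches = []
    for i, gold in enumerate(expected):
        result = results[i] if i < len(results) else {}
        predicted = str(result.get("answer", "")).strip().upper()
        if predicted != gold:
            mismatches.append((i, predicted, gold))
    accuracy = (len(expected) - len(mismatches)) / len(expected) if expected else 0.0
    return accuracy, mismatches


# Offline example with made-up predictions (not real model output)
demo_results = [{"answer": "B"}, {"answer": "c"}, {"answer": "Parse error"}]
accuracy, mismatches = score_results(demo_results, EXPECTED_ANSWERS)
print(f"accuracy={accuracy:.2f}", mismatches)
```

With real output, `score_results(test_results, EXPECTED_ANSWERS)` could also be logged to W&B alongside the results table.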


---

##  Section 11: Gradio Chat Interface

### Building an Interactive Demo

Let's create a simple chat interface where users can:
1. Ask math questions
2. See step-by-step reasoning
3. Get the final answer

**Powered by HuggingFace Inference API**: The model runs on HF infrastructure, so you don't need a local GPU to run inference!

This demo can be:
- Run locally in the notebook (using HF API)
- Deployed to HuggingFace Spaces
- Embedded in websites
- Shared via public URL
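Before launching, it can help to check the chat callback on its own. `gr.ChatInterface` calls its function as `fn(message, history, *additional_inputs)`, so the temperature slider arrives as the third argument. The sketch below exercises that contract offline; `make_chat_handler` and `fake_solver` are hypothetical names, and in the notebook you would pass `solve_algebra_problem` as the solver.

```python
def make_chat_handler(solver):
    """Build a ChatInterface-style callback around any solver function.

    The returned handler has the (message, history, *additional_inputs) shape
    that gr.ChatInterface expects; here the one extra input is temperature.
    """
    def handler(message, history, temperature=0.7):
        result = solver(message, temperature=temperature)
        # No space after the opening ** so Markdown renders the labels in bold
        return f"**Reasoning:**\n\n{result['reasoning']}\n\n**Answer:** {result['answer']}"
    return handler


def fake_solver(question, temperature=0.7):
    # Stand-in for solve_algebra_problem that returns a canned result
    return {"reasoning": "x = 12 - 5 = 7", "answer": "C"}


chat = make_chat_handler(fake_solver)
print(chat("If x + 5 = 12, what is x?", [], 0.5))
```

Injecting the solver keeps the formatting logic testable without network access or a running Gradio server.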

In [None]:
import gradio as gr

# Create chat interface
def chat_with_model(message, history, temperature=0.7):
    """
    Process a chat message and return the response using HF Inference API.

    Args:
        message: User's question
        history: Chat history (unused in this simple version)
        temperature: Sampling temperature

    Returns:
        Formatted response with reasoning and answer
    """
    result = solve_algebra_problem(message, temperature=temperature)

    # Format response nicely (no space after the opening ** so Markdown renders bold)
    response = f"""**Reasoning:**

{result['reasoning']}

**Answer:** {result['answer']}
"""

    return response

# Create Gradio interface
# Note: retry_btn, undo_btn and clear_btn were removed from gr.ChatInterface in Gradio 5;
# the chatbot provides its own retry/undo/clear controls.
demo = gr.ChatInterface(
    fn=chat_with_model,
    title="GRPO Algebra Tutor",
    description=f"""
    Ask me algebra questions! I'll show my reasoning step-by-step and provide a multiple choice answer.

    **Model:** {HF_REPO_ID}

    **Inference:** Running via HuggingFace Inference API (no local GPU needed!)

    **Training:** {len(dataset)} AQuA-RAT examples with GRPO
    """,
    # With additional_inputs, each example is [message, temperature]
    examples=[
        ["If x + 5 = 12, what is x?\nOptions:\nA)5 B)6 C)7 D)8 E)9", 0.7],
        ["A store sells 5 apples for $3. How much do 15 apples cost?\nOptions:\nA)$6 B)$9 C)$12 D)$15 E)$18", 0.7],
        ["What is 25% of 80?\nOptions:\nA)15 B)20 C)25 D)30 E)35", 0.7]
    ],
    additional_inputs=[
        gr.Slider(0.1, 1.5, value=0.7, label="Temperature (creativity)", step=0.1)
    ],
    theme=gr.themes.Soft(),
)

# Launch interface
print(" Launching Gradio interface...")
print(f" Using HF Inference API: {HF_REPO_ID}")
print(" Model runs on HuggingFace infrastructure - no local GPU required!")
demo.launch(
    share=True,  # Creates public URL
    debug=True  # Keeps the cell running (and prints errors) until you interrupt it
)

---

##  Section 12: Optional - Prime Intellect Integration

### What is Prime Intellect?

Prime Intellect provides:
- **Distributed RL training**: Scale across multiple GPUs/nodes
- **Environment Hub**: Pre-built RL environments
- **Fault tolerance**: Automatic recovery from failures
- **Verifiers**: Modular reward functions

### When to Use Prime Intellect?

Consider Prime Intellect if you need:
- Multi-GPU/multi-node training
- Custom RL environments
- Production-scale deployment
- Advanced monitoring and logging

### Example: AQuA-RAT Environment

The author's environment (`harleycooper/nanochatAquaRat`) packages AQuA-RAT algebra problems as an RL environment, analogous to GSM8K-style math environments.
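At its core, a reward for multiple-choice algebra typically checks two things: did the completion pick the right option, and did it follow the expected format? The plain-Python sketch below illustrates that idea. It does not use the `verifiers` library API, and the `[ANSWER]...[/ANSWER]` tag format, the function names and the 0.5 format weight are assumptions for illustration only.

```python
import re

# Assumed completion format (hypothetical): "...[ANSWER]C[/ANSWER]", optionally "(C)"
ANSWER_RE = re.compile(r"\[ANSWER\]\s*\(?([A-E])\)?\s*\[/ANSWER\]")


def extract_choice(completion: str):
    """Return the last option letter found inside [ANSWER] tags, or None."""
    matches = ANSWER_RE.findall(completion)
    return matches[-1] if matches else None


def correctness_reward(completion: str, gold: str) -> float:
    """Binary exact-match reward: 1.0 if the chosen letter equals the gold letter."""
    return 1.0 if extract_choice(completion) == gold.strip().upper() else 0.0


def format_reward(completion: str) -> float:
    """Small shaping reward for emitting both a reasoning block and an answer."""
    has_reasoning = "[REASONING]" in completion and "[/REASONING]" in completion
    return 0.5 if has_reasoning and extract_choice(completion) else 0.0


sample = "[REASONING]15 apples is 3 x 5 apples, so 3 x $3 = $9.[/REASONING][ANSWER]B[/ANSWER]"
print(correctness_reward(sample, "B"), format_reward(sample))  # 1.0 0.5
```

Keeping correctness and format as separate functions mirrors the "modular reward functions" idea: each can be weighted or swapped independently.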

In [None]:
# This cell demonstrates Prime Intellect integration (optional).
# The code is disabled by wrapping it in a string; remove the surrounding triple quotes to run it.

"""
# Install Prime RL
!curl -sSL https://raw.githubusercontent.com/PrimeIntellect-ai/prime-rl/main/scripts/install.sh | bash

# Configure for AQuA-RAT environment
prime_config = {
    "model": MODEL_NAME,
    "env": {
        "id": "harleycooper/nanochatAquaRat",
        "args": {
            "num_train_examples": 2000,
            "num_eval_examples": 254,
            "seed": 42
        }
    },
    "trainer": {
        "args": {
            "learning_rate": 2e-5,
            "rollouts_per_example": 8,
            "max_steps": 400
        }
    }
}

# Save config
import toml
with open("prime_config.toml", "w") as f:
    toml.dump(prime_config, f)

# Run training
!uv run vf-rl @ prime_config.toml
"""

print("ℹ  Prime Intellect integration code is disabled (wrapped in a string).")
print("Remove the triple quotes around the code in this cell to use Prime Intellect environments.")

---

##  Section 13: Wrap-Up and Next Steps

### What You've Accomplished

- Set up a complete GRPO training pipeline
- Trained a model on algebraic reasoning
- Monitored training with Weights & Biases
- Saved checkpoints to cloud storage
- Deployed the model to HuggingFace Hub
- Created an interactive chat interface

### Next Steps

1. **Improve Training**:
   - Use full AQuA-RAT dataset (~97,000 examples)
   - Train for multiple epochs with validation
   - Experiment with different reward weights
   - Try larger models (1B, 7B parameters)

2. **Enhance Evaluation**:
   - Create test suite
   - Measure accuracy on AQuA-RAT test set
   - Compare with baseline models
   - Analyze failure modes

3. **Deploy to Production**:
   - Set up HuggingFace Inference Endpoint
   - Deploy Gradio app to HF Spaces
   - Add caching and rate limiting
   - Monitor usage and costs

4. **Extend to Other Domains**:
   - Math reasoning (GSM8K)
   - Science QA
   - Code generation
   - Logical reasoning
   - Multi-turn conversations
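The production items above (caching and rate limiting) can start as a few lines of plain Python in front of `solve_algebra_problem`. The sketch below assumes a single-process Gradio app; `TokenBucket`, `LRUCache` and `guarded` are hypothetical names, and a multi-worker deployment would need shared state instead of in-memory objects.

```python
import time
from collections import OrderedDict


class TokenBucket:
    """Allow bursts of up to `capacity` calls, refilled at `rate` tokens per second."""

    def __init__(self, capacity: int, rate: float, clock=time.monotonic):
        self.capacity = capacity
        self.rate = rate
        self.clock = clock  # injectable so tests can control time
        self.tokens = float(capacity)
        self.last = clock()

    def allow(self) -> bool:
        now = self.clock()
        # Refill in proportion to elapsed time, never above capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False


class LRUCache:
    """Tiny least-recently-used cache."""

    def __init__(self, maxsize: int = 256):
        self.maxsize = maxsize
        self.data = OrderedDict()

    def get(self, key):
        if key not in self.data:
            return None
        self.data.move_to_end(key)  # mark as recently used
        return self.data[key]

    def put(self, key, value):
        self.data[key] = value
        self.data.move_to_end(key)
        if len(self.data) > self.maxsize:
            self.data.popitem(last=False)  # evict the least recently used entry


def guarded(solver, bucket, cache):
    """Wrap a solver with caching keyed by (question, temperature) and rate limiting."""
    def call(question, temperature=0.7):
        key = (question, temperature)
        cached = cache.get(key)
        if cached is not None:
            return cached  # cache hits do not consume rate-limit tokens
        if not bucket.allow():
            return {"question": question, "reasoning": "Rate limit reached, please retry shortly.",
                    "answer": "Error", "raw_response": ""}
        result = solver(question, temperature=temperature)
        if result.get("answer") != "Error":  # don't cache failed API calls
            cache.put(key, result)
        return result
    return call
```

Caching by `(question, temperature)` means a repeated question returns the same sampled answer. That is usually fine for a demo, but drop the cache if you want fresh samples on every request.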

### Resources

- [TRL Documentation](https://huggingface.co/docs/trl)
- [GRPO Paper: DeepSeekMath, which introduced Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300)
- [vLLM Docs](https://docs.vllm.ai/)
- [Gradio Docs](https://gradio.app/docs)
- [Prime Intellect](https://primeintellect.ai/)

### Questions?

Check out the detailed READMEs in the repository:
- `docs/PRIME_INTELLECT.md`
- `docs/GOOGLE_CLOUD_STORAGE.md`
- `docs/WANDB_VISUALIZATION.md`
- `docs/GRADIO_DEPLOYMENT.md`

In [None]:
# Grab the run URL before finishing: wandb.run is None after wandb.finish()
wandb_run_url = wandb.run.get_url()
wandb.finish()

print("\n" + "="*60)
print(" CONGRATULATIONS! You've completed the GRPO tutorial!")
print("="*60)
print(f"\n W&B Run: {wandb_run_url}")
print(f" HF Model: https://huggingface.co/{HF_REPO_ID}")
print(f" Checkpoints: {OUTPUT_DIR}")
print("\n Next: Check out the docs/ folder for advanced guides!")
print("\n" + "="*60)