# Cold Start SFT + GRPO Training for Gemma 3

## A Complete Guide to Training Reasoning Models on Kaggle TPU

---

### Overall Training and Evaluation Strategy

**Training Pipeline Architecture:**

Our approach follows a sequential two-stage pipeline optimized for Kaggle TPU v6e-1 hardware:

1. **Stage 1: Cold Start SFT** (500 steps)
   - **Purpose**: Establish proper output format and reasoning patterns
   - **Technique**: LoRA fine-tuning using `PeftTrainer` (parameter-efficient, ~1.1x model size)
   - **Dataset**: Bespoke-Stratos-17k (high-quality Chain-of-Thought examples)
   - **Memory Optimization**: Batch size of 2, RANK=64, with optional filtering for long sequences
   - **Key Innovation**: Uses Tunix's built-in `PeftTrainer` instead of custom training loops

2. **Stage 2: GRPO Reinforcement Learning** (2500 steps)
   - **Purpose**: Strengthen reasoning through reward-based learning
   - **Technique**: Group Relative Policy Optimization (memory-efficient, no separate value model)
   - **Dataset**: GSM8K (verifiable math problems with clear correct/incorrect signals)
   - **Reward Design**: Multi-component reward system with format, structure, and answer correctness

**Evaluation Strategy:**

We implement a comprehensive evaluation framework that tracks three key metrics at multiple stages:

1. **Pre-SFT Evaluation**: Baseline performance measurement
2. **Post-SFT Evaluation**: Format learning assessment
3. **Post-GRPO Evaluation**: Final reasoning capability assessment

**Evaluation Metrics:**
- **Format Accuracy**: Percentage of outputs matching `<reasoning>...</reasoning><answer>...</answer>` structure
- **Answer Accuracy**: Exact match between extracted numerical answer and ground truth
- **Partial Accuracy**: Answers within 10% of correct value (tolerance for rounding)

**Reward Function Design for GRPO:**

We use a multi-component reward system that balances format compliance with answer correctness:

1. **Format Reward** (`check_format`): +2.0 for correct tag structure, penalties for missing/incomplete tags
2. **Structure Reward** (`check_structure`): +1.0 for proper reasoning/answer separation
3. **Answer Correctness** (`check_answer`): +3.0 for exact match, +1.5 for whitespace-normalized match, +0.5 for 10% tolerance
4. **Number Extraction** (`check_numbers`): Ensures numerical answers are extractable

This multi-reward approach prevents the model from gaming the system by focusing only on format or only on answers.

**Memory Management and Sequence Length:**

A key challenge in this pipeline is managing memory constraints, especially when working with long Chain-of-Thought reasoning traces. We address this through several techniques:

- **Sequence Length Configuration**: `MAX_SEQ_LENGTH = MAX_PROMPT_LENGTH (256) + TOTAL_GENERATION_STEPS (768) = 1024 tokens`
  - This balance allows for substantial reasoning traces while maintaining memory efficiency
  - Increasing `MAX_SEQ_LENGTH` significantly increases memory usage (quadratic in attention)

- **Memory Optimization Strategies**:
  1. **Data Filtering**: Optional `SFT_MAX_TOKENS_FILTER` parameter to exclude examples exceeding a token threshold
  2. **Batch Size Reduction**: SFT uses batch size 2 (can be reduced to 1 if OOM)
  3. **LoRA Rank Adjustment**: RANK=64 (can be reduced to 32 or 16 for lower memory)
  4. **Dataset Size Control**: `SFT_NUM_EXAMPLES` parameter limits dataset size

- **Sequence Length Independence**: LoRA parameters are sequence-length independent, allowing checkpoint resumption with different `MAX_SEQ_LENGTH` values if needed

**Ablation Considerations:**

While not extensively tested in this notebook, several design choices merit discussion:

- **Two-Stage vs. Single-Stage**: The two-stage approach prevents format collapse observed in pure RL training
- **PeftTrainer vs. Custom Loop**: Using Tunix's `PeftTrainer` simplifies implementation and ensures compatibility
- **Reward Function Weighting**: Equal weighting of format, structure, and answer rewards provides balanced learning

### Dataset Construction and Selection

**Finetuning Dataset Creation:**

The success of this pipeline critically depends on the quality and structure of the training data. We use two distinct datasets optimized for each training stage:

**Stage 1 Dataset: Bespoke-Stratos-17k**

- **Source**: `bespokelabs/Bespoke-Stratos-17k` on HuggingFace (publicly accessible)
- **Size**: ~17,000 high-quality Chain-of-Thought examples
- **Origin**: Distilled from DeepSeek-R1 model outputs
- **Key Characteristics**:
  - Long, detailed reasoning traces with self-reflection
  - Uses `<|begin_of_thought|>...<|end_of_thought|>` format (converted to `<reasoning>...</reasoning>`)
  - Includes verification steps and intermediate reasoning
  - Covers diverse reasoning tasks (math, logic, problem-solving)

**Data Preprocessing Pipeline:**

1. **Format Standardization**: Convert various thinking tags (`<|begin_of_thought|>`, `<think>`, `<thinking>`) to unified `<reasoning>...</reasoning>` format
2. **Answer Tagging**: Ensure all examples have proper `<answer>...</answer>` tags
3. **Tokenization and Padding**: Fixed-length sequences (padded to `MAX_SEQ_LENGTH`) for efficient batching
4. **Optional Filtering**: Filter out examples exceeding `SFT_MAX_TOKENS_FILTER` to manage memory

**Dataset Selection Rationale:**

- **Bespoke-Stratos-17k for SFT**: Provides high-quality examples of structured reasoning, teaching the model *how* to think rather than just *what* to answer
- **GSM8K for GRPO**: Offers clear verifiable answers necessary for reward computation, with definitive correct/incorrect signals

### Output Format

The trained model will produce outputs in this format:
```
<reasoning>model_thinking_trace</reasoning>
<answer>model_answer</answer>
```

### Hardware Requirements
- **Kaggle TPU v6e-1** (recommended) or similar TPU configuration
- Training time: ~9 hours for full pipeline


## Section 1: Environment Setup

### Understanding the Dependencies

Before we start training, we need to install several key libraries:

- **tunix**: Google's library for training Gemma models on TPU with JAX
- **qwix**: Provides LoRA (Low-Rank Adaptation) functionality for efficient fine-tuning
- **flax**: Neural network library for JAX
- **grain**: Data loading library optimized for JAX
- **transformers**: For tokenizer and model utilities

**Important**: After running the installation cell, you may need to restart the kernel for the changes to take effect.


In [1]:
# ============================================================================
# CELL 1: Install Required Packages
# ============================================================================
# This cell installs all necessary libraries for training.
# RESTART THE KERNEL AFTER THIS CELL COMPLETES (for Colab users)

import importlib.util

def check_package(name):
    """Check if a package is installed."""
    return importlib.util.find_spec(name) is not None

# Check for key packages - tunix is the most critical one
TUNIX_INSTALLED = check_package('tunix')
QWIX_INSTALLED = check_package('qwix')

if not TUNIX_INSTALLED or not QWIX_INSTALLED:
    print("Installing required packages... This may take a few minutes.")
    print("=" * 60)

    # Core dependencies
    %pip install -q python-dotenv
    %pip install -q kagglehub
    %pip install -q ipywidgets
    %pip install -q tensorflow
    %pip install -q tensorflow_datasets
    %pip install -q tensorboardX
    %pip install -q transformers
    %pip install -q grain

    # JAX and related libraries (for TPU training)
    %pip install -q git+https://github.com/jax-ml/jax

    # Google's training libraries (CRITICAL - these are the main packages)
    %pip install git+https://github.com/google/tunix  # Main training framework
    %pip install git+https://github.com/google/qwix   # LoRA support

    # Flax update (required for NNX support)
    %pip uninstall -q flax -y
    %pip install git+https://github.com/google/flax

    # Data and utilities
    %pip install -q huggingface_hub
    %pip install -q datasets
    %pip install -q 'numpy>2'

    print("\n" + "=" * 60)
    print("Installation complete!")
    print("IMPORTANT: Please RESTART the kernel before continuing.")
    print("=" * 60)
else:
    print("All required packages are already installed.")
    print(f"  - tunix: {'OK' if TUNIX_INSTALLED else 'MISSING'}")
    print(f"  - qwix:  {'OK' if QWIX_INSTALLED else 'MISSING'}")


All required packages are already installed.
  - tunix: OK
  - qwix:  OK


### Authentication Setup

This cell handles authentication for:
- **Hugging Face**: To download the Gemma model
- **Kaggle**: For dataset access
- **Weights & Biases (optional)**: For experiment tracking

You need to set up your API keys as environment variables or secrets before running this cell.


In [2]:
# ============================================================================
# CELL 2: Authentication and Service Login
# ============================================================================

import os
import kagglehub

# Detect if running in Colab or Kaggle environment
try:
    from google.colab import userdata
    USE_COLAB = True

    # WandB has issues with Colab, so we disable it
    %pip uninstall -q wandb -y

    # Load credentials from Colab secrets
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
    os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
    os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
    print("Running in Google Colab environment")

except ImportError:
    USE_COLAB = False

    # Try to load from .env file
    from dotenv import load_dotenv
    load_dotenv()
    print("Using environment variables for authentication")

    # Apply nest_asyncio for Jupyter compatibility
    import nest_asyncio
    nest_asyncio.apply()
    print("nest_asyncio applied for async compatibility")

    # Setup WandB for TPU VM (works better outside Colab)
    %pip install -q wandb
    import wandb
    if "WANDB_API_KEY" in os.environ and os.environ["WANDB_API_KEY"]:
        wandb.login(key=os.environ["WANDB_API_KEY"])
        print("Weights & Biases login successful")
    else:
        print("WANDB_API_KEY not found. Skipping W&B login.")

# Kaggle authentication
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    print("Kaggle credentials not found. Please login manually:")
    kagglehub.login()

# Hugging Face authentication
if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
    hf_token = os.environ["HF_TOKEN"]
    !huggingface-cli login --token "$hf_token"
else:
    print("HF_TOKEN not found. Please set it for model download.")


[0mRunning in Google Colab environment
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `hf`CLI if you want to set the git credential as well.
Token is valid (permission: write).
The token `colab` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


### Import Libraries

Now we import all the necessary libraries. Here's what each major import does:

- **jax, jnp**: Core library for accelerated numerical computing on TPU/GPU
- **flax.nnx**: Neural network library with the new NNX API
- **tunix**: Google's library for Gemma model training
- **qwix**: LoRA (Low-Rank Adaptation) for efficient fine-tuning
- **grain**: Efficient data loading for JAX
- **optax**: Gradient transformation and optimization library
- **peft_trainer**: Tunix's built-in SFT trainer (following qlora_gemma approach)


In [3]:
# ============================================================================
# CELL 3: Import All Required Libraries
# ============================================================================

import functools
from pprint import pprint
import re
import sys
import csv
import json
import shutil
import logging

# JAX ecosystem
from flax import nnx
import grain
import humanize
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import kagglehub
import numpy as np
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from datasets import load_dataset

# Tunix - Google's Gemma training library
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.models.gemma3 import params as gemma_params
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.sft import peft_trainer  # SFT trainer from qlora_gemma approach
from tunix.sft import utils
from tunix.sft.utils import show_hbm_usage

logger = logging.getLogger()
logger.setLevel(logging.INFO)

print("All imports successful!")
print(f"JAX version: {jax.__version__}")
print(f"Number of devices available: {len(jax.devices())}")
print(f"Device type: {jax.devices()[0].device_kind}")




All imports successful!
JAX version: 0.8.3.dev20260112+a4036e4a0
Number of devices available: 1
Device type: TPU v6 lite


## Section 2: Hyperparameter Configuration

### Understanding the Key Hyperparameters

This section defines all the configuration parameters for both SFT and GRPO training. Let's break down the most important ones:

**Model Configuration:**
- `MODEL_ID`: The base Gemma model to fine-tune
- `RANK` and `ALPHA`: LoRA parameters that control the size of trainable adapters

**Training Configuration:**
- `MAX_SEQ_LENGTH`: Maximum sequence length for training (longer = more memory)
- `LEARNING_RATE`: Step size for gradient updates (too high = unstable, too low = slow)

**GRPO-Specific Parameters:**
- `NUM_GENERATIONS`: How many responses to generate per prompt for comparison (the "G" in GRPO)
- `BETA`: KL divergence penalty coefficient (keeps the policy close to reference)
- `EPSILON`: Clipping parameter for stable updates (similar to PPO)


In [4]:
# ============================================================================
# CELL 4: Hyperparameter Configuration
# ============================================================================
# All training parameters are defined here for easy modification.
# Adjust these based on your hardware and training requirements.

# ===========================================================================
# MODEL CONFIGURATION
# ===========================================================================
MODEL_ID = "google/gemma-3-1b-it"  # Base model to fine-tune
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# ===========================================================================
# DATA PATHS
# ===========================================================================
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
SFT_DATA_DIR = "./data/sft"  # For cold start SFT data
TRAIN_FRACTION = 0.9  # Fraction of data used for training (rest for validation)

# ===========================================================================
# LoRA CONFIGURATION
# ===========================================================================
# LoRA (Low-Rank Adaptation) allows us to fine-tune large models efficiently
# by only training a small number of additional parameters.
RANK = 32       # Rank of the LoRA matrices (higher = more parameters, better quality)
                # Reduce to 32 or 16 if running out of memory (reduces memory usage)
ALPHA = 64.0    # Scaling factor for LoRA (typically set equal to RANK)
                # Keep equal to RANK when changing RANK

# ===========================================================================
# SHARDING CONFIGURATION (for TPU)
# ===========================================================================
# Configure the mesh for distributed training across TPU cores
NUM_TPUS = len(jax.devices())
print(f"Detected {NUM_TPUS} TPU cores")

if NUM_TPUS == 8:
    MESH_COUNTS = (1, 4)  # For v3-8 TPU
elif NUM_TPUS == 4:
    MESH_COUNTS = (1, 4)  # For v4-8 TPU
elif NUM_TPUS == 1:
    MESH_COUNTS = (1, 1)  # Single device
else:
    # Default configuration for other setups
    MESH_COUNTS = (1, NUM_TPUS)

MESH = [MESH_COUNTS, ("fsdp", "tp")]

# ===========================================================================
# DEBUG MODE CONFIGURATION
# ===========================================================================
# Set to True for fast debugging (reduces training time significantly)
DEBUG_MODE = False  # Set to True to reduce steps, batches, and generation length for quick testing

# ===========================================================================
# SEQUENCE LENGTH CONFIGURATION
# ===========================================================================
# Note: Changing MAX_SEQ_LENGTH after training starts is supported for LoRA checkpoints.
# LoRA parameters are sequence-length independent, so you can continue training with
# a different MAX_SEQ_LENGTH. However, note that:
# - Data preprocessing will change (different tokens may be truncated/kept)
# - Memory usage will change (larger MAX_SEQ_LENGTH = more memory)
# - Training data format will differ (but this is usually fine)
MAX_PROMPT_LENGTH = 256           # Maximum length of input prompts
TOTAL_GENERATION_STEPS = 1792      # Maximum tokens to generate
if DEBUG_MODE:
    TOTAL_GENERATION_STEPS = 256  # Reduced for fast debugging
MAX_SEQ_LENGTH = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS

# ===========================================================================
# COLD START SFT CONFIGURATION (using PeftTrainer)
# ===========================================================================
SFT_LEARNING_RATE = 2e-4          # Learning rate for SFT (higher than GRPO)
SFT_BATCH_SIZE = 1                # Batch size for SFT (reduce to 1 if OOM)
SFT_MAX_STEPS = 500               # Number of SFT training steps
if DEBUG_MODE:
    SFT_MAX_STEPS = 10            # Reduced for fast debugging
EVAL_EVERY_N_STEPS = 50           # Evaluate every N steps
SFT_NUM_EPOCHS = 3                # Number of epochs for SFT

# Memory optimization options (if running out of memory):
# 1. Reduce SFT_BATCH_SIZE from 2 to 1 (reduces memory by ~50% per batch)
# 2. Reduce RANK from 64 to 32 or 16 (reduces LoRA parameters)
# 3. Filter long samples: Set SFT_MAX_TOKENS_FILTER to filter out long examples
# 4. Use shorter dataset or reduce SFT_NUM_EXAMPLES
SFT_MAX_TOKENS_FILTER = TOTAL_GENERATION_STEPS  # Filter examples longer than this (None = no filtering)
                               # Set to a value like 2048 or 3072 to keep only shorter examples
                               # This reduces memory usage by excluding very long samples

# ===========================================================================
# GRPO CONFIGURATION (Based on DeepSeek-R1 Paper)
# ===========================================================================
# Reference: DeepSeek-R1 uses EPSILON=10, BETA=0.001, NUM_GENERATIONS=16
# We adjust values for single TPU memory constraints while following their principles.

# Generation parameters during GRPO training
TEMPERATURE = 1.0     # DeepSeek-R1: 1.0 for RL Stage 1 (high for exploration)
TOP_P = 1.0           # Nucleus sampling parameter
TOP_K = 50            # Top-k sampling parameter

# GRPO algorithm parameters (adjusted from DeepSeek-R1)
NUM_GENERATIONS = 4   # DeepSeek-R1 uses 16, reduced for memory
if DEBUG_MODE:
    NUM_GENERATIONS = 2           # Reduced for fast debugging (fewer generations per prompt)
NUM_ITERATIONS = 1    # Number of iterations per batch
BETA = 0.001          # KL coefficient (DeepSeek-R1: 0.001, much lower than typical PPO!)
EPSILON = 10.0        # Clip ratio (DeepSeek-R1: 10, NOT 0.2 like PPO!)

# GRPO training parameters
GRPO_LEARNING_RATE = 3e-6         # DeepSeek-R1: 3e-6
GRPO_BATCH_SIZE = 1               # Keep small due to memory constraints
GRPO_GRADIENT_ACCUMULATION = 4    # Effective batch: 4 questions * 4 gen = 16 responses
GRPO_MAX_STEPS = 2500             # Number of GRPO training steps
if DEBUG_MODE:
    GRPO_MAX_STEPS = 10           # Reduced for fast debugging (just a few steps to test)
GRPO_WARMUP_RATIO = 0.1           # Warmup as fraction of total steps
REFERENCE_UPDATE_STEPS = 400      # DeepSeek-R1: Update reference model every 400 steps

# ===========================================================================
# OPTIMIZER CONFIGURATION
# ===========================================================================
B1 = 0.9              # Adam beta1
B2 = 0.99             # Adam beta2
WEIGHT_DECAY = 0.1    # Weight decay for regularization
MAX_GRAD_NORM = 0.1   # Gradient clipping (important for stability)

# ===========================================================================
# CHECKPOINT CONFIGURATION
# ===========================================================================
SFT_CKPT_DIR = "./checkpoints/sft/"           # SFT checkpoint directory
GRPO_CKPT_DIR = "./checkpoints/grpo/"         # GRPO checkpoint directory
FINAL_MODEL_DIR = "./checkpoints/final/"      # Final merged model
SAVE_INTERVAL_STEPS = 500                     # Save every N steps
MAX_TO_KEEP = 4                               # Maximum checkpoints to keep
RESUME_FROM_CHECKPOINT = False                 # If False, delete existing checkpoints and start from scratch

# ===========================================================================
# INFERENCE CONFIGURATION
# ===========================================================================
GENERATION_CONFIGS = {
    "greedy": {"temperature": None, "top_k": 1, "top_p": None},      # Deterministic
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},    # Balanced
    "creative": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},  # More random
}

print("\nConfiguration loaded successfully!")
print(f"Model: {MODEL_ID}")
print(f"LoRA Rank: {RANK}, Alpha: {ALPHA}")
print(f"SFT Steps: {SFT_MAX_STEPS}, GRPO Steps: {GRPO_MAX_STEPS}")
if DEBUG_MODE:
    print("\n" + "=" * 60)
    print("⚠️  DEBUG MODE ENABLED - Fast testing configuration")
    print("=" * 60)
    print(f"  - SFT Steps: {SFT_MAX_STEPS} (reduced from 500)")
    print(f"  - GRPO Steps: {GRPO_MAX_STEPS} (reduced from 2500)")
    print(f"  - Generations per prompt: {NUM_GENERATIONS} (reduced from 4)")
    print(f"  - Generation length: {TOTAL_GENERATION_STEPS} tokens (reduced from 768)")
    print("=" * 60 + "\n")


Detected 1 TPU cores

Configuration loaded successfully!
Model: google/gemma-3-1b-it
LoRA Rank: 32, Alpha: 64.0
SFT Steps: 500, GRPO Steps: 2500


## Section 3: Utility Functions and Special Tokens

### Output Format Definition

We define a specific output format that the model must learn to use. This format separates:
1. **Reasoning**: The model's thought process (between `<reasoning>` and `</reasoning>` tags)
2. **Answer**: The final answer (between `<answer>` and `</answer>` tags)

**Note**: The data may contain `<think>` or `<think>` markers which we will convert to `<reasoning>` format.


In [5]:
# ============================================================================
# CELL 5: Special Tokens and Template Definition
# ============================================================================

# Define special tokens for structured output
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

# System prompt that instructs the model on the expected output format
SYSTEM_PROMPT = f"""You are given a problem. First, think about the problem \
and provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer between {solution_start} and {solution_end}."""

# Template for formatting prompts (Gemma chat format)
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""

# Regex pattern to validate output format
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# Regex to extract numbers from answers
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags=re.MULTILINE | re.DOTALL
)

# Test the format matching
example = f"{reasoning_start}Let me think step by step...{reasoning_end}{solution_start}42{solution_end}"
print("Example output format:")
print(example)
print(f"\nFormat validation: {'PASS' if match_format.search(example) else 'FAIL'}")


Example output format:
<reasoning>Let me think step by step...</reasoning><answer>42</answer>

Format validation: PASS


In [6]:
# ============================================================================
# CELL 6: Utility Functions
# ============================================================================

def extract_hash_answer(text: str) -> str:
    """Extract answer from GSM8K format (after #### marker)."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def create_directories():
    """Create all necessary directories for checkpoints and data."""
    dirs = [TRAIN_DATA_DIR, TEST_DATA_DIR, SFT_DATA_DIR,
            SFT_CKPT_DIR, GRPO_CKPT_DIR, FINAL_MODEL_DIR]
    for d in dirs:
        os.makedirs(d, exist_ok=True)
    print("All directories created successfully!")

create_directories()


All directories created successfully!


## Section 4: Model Loading

### Loading the Base Gemma Model

We download the Gemma 3 1B-IT model from Hugging Face and load it using Tunix. The model is:
- **Gemma 3 1B-IT**: A 1 billion parameter instruction-tuned model from Google
- Loaded with safetensors format for efficient memory usage
- Sharded across TPU cores using the mesh configuration

**Note**: You need to have accepted the Gemma license on Kaggle to download the model.


In [7]:
# ============================================================================
# CELL 7: Download and Load Base Model
# ============================================================================

# Download model from Hugging Face (skip PyTorch weights)
ignore_patterns = ["*.pth"]
print(f"Downloading {MODEL_ID} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=ignore_patterns
)
print(f"Model downloaded to: {local_model_path}")

# Load EOS tokens from generation config
EOS_TOKENS = []
generation_config_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(generation_config_path):
    with open(generation_config_path, "r") as f:
        generation_configs = json.load(f)
    EOS_TOKENS = generation_configs.get("eos_token_id", [])
    print(f"EOS token IDs: {EOS_TOKENS}")


Downloading google/gemma-3-1b-it from Hugging Face...


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Model downloaded to: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752
EOS token IDs: [1, 106]


In [8]:
# ============================================================================
# CELL 8: Initialize Model with TPU Mesh
# ============================================================================

# Select model configuration based on MODEL_ID
if "gemma-3-270m" in MODEL_ID:
    model_config = gemma_lib.ModelConfig.gemma3_270m()
elif "gemma-3-1b" in MODEL_ID:
    model_config = gemma_lib.ModelConfig.gemma3_1b_it()
else:
    raise ValueError(f"Unsupported model: {MODEL_ID}")

# Create TPU mesh for distributed training
mesh = jax.make_mesh(
    *MESH,
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
)

# Load model with proper sharding
print("Loading model onto TPU mesh...")
with mesh:
    gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path,
        model_config,
        mesh
    )

print("Base model loaded successfully!")
nnx.display(gemma3)


Loading model onto TPU mesh...
Base model loaded successfully!


### LoRA (Low-Rank Adaptation) Setup

LoRA is a technique that allows us to fine-tune large models efficiently by:
1. Freezing the original model weights
2. Adding small trainable "adapter" matrices to key layers
3. Only training these adapter matrices (much fewer parameters!)

**Benefits**:
- Reduces memory usage significantly
- Faster training
- Easy to switch between different fine-tuned versions
- Original model weights remain unchanged


In [9]:
# ============================================================================
# CELL 9: Create LoRA Model and Load Tokenizer
# ============================================================================

def get_lora_model(base_model, mesh):
    """
    Apply LoRA adapters to the base model.

    LoRA targets specific layers:
    - q_einsum, kv_einsum: Attention query/key-value projections
    - gate_proj, down_proj, up_proj: MLP layers
    - attn_vec_einsum: Attention output projection
    """
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
            ".*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )

    model_input = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(
        base_model, lora_provider, **model_input
    )

    # Shard the LoRA model across devices
    with mesh:
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)

    return lora_model


# Create the policy model with LoRA adapters
print("Creating LoRA policy model...")
lora_policy = get_lora_model(gemma3, mesh=mesh)
print("LoRA model created!")

# Load tokenizer
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
if tokenizer.eos_id() not in EOS_TOKENS:
    EOS_TOKENS.append(tokenizer.eos_id())
print(f"Tokenizer loaded. EOS tokens: {EOS_TOKENS}")

# Display memory usage
show_hbm_usage()


Creating LoRA policy model...


INFO:absl:[QWIX] module='layers/0/attn/q_einsum' op=einsum0 rule=0
INFO:absl:[QWIX] module='layers/0/attn/kv_einsum' op=einsum0 rule=0
INFO:absl:[QWIX] module='layers/0/attn' op=einsum0 rule=None
INFO:absl:[QWIX] module='layers/0/attn' op=einsum1 rule=None
INFO:absl:[QWIX] module='layers/0/attn/attn_vec_einsum' op=einsum0 rule=0
INFO:absl:[QWIX] module='layers/0/mlp/gate_proj' op=dot_general0 rule=0
INFO:absl:[QWIX] module='layers/0/mlp/up_proj' op=dot_general0 rule=0
INFO:absl:[QWIX] module='layers/0/mlp/down_proj' op=dot_general0 rule=0
INFO:absl:[QWIX] module='layers/1/attn/q_einsum' op=einsum0 rule=0
INFO:absl:[QWIX] module='layers/1/attn/kv_einsum' op=einsum0 rule=0
INFO:absl:[QWIX] module='layers/1/attn' op=einsum0 rule=None
INFO:absl:[QWIX] module='layers/1/attn' op=einsum1 rule=None
INFO:absl:[QWIX] module='layers/1/attn/attn_vec_einsum' op=einsum0 rule=0
INFO:absl:[QWIX] module='layers/1/mlp/gate_proj' op=dot_general0 rule=0
INFO:absl:[QWIX] module='layers/1/mlp/up_proj' op=do

LoRA model created!


INFO:absl: - Pathways not available. Using default HBM stats collector
INFO:absl:Using 3.8 GiB / 31.2 GiB (0.12101333179057862) on TPU_0(process=0,(0,0,0,0))


Tokenizer loaded. EOS tokens: [1, 106]


## Section 5: Data Preparation

### Understanding the Two Datasets

We use **two different datasets** for the two training stages:

---

#### Dataset 1: Bespoke-Stratos-17k (for Cold Start SFT)

**Source**: `bespokelabs/Bespoke-Stratos-17k` on HuggingFace

**Purpose**: Teach the model the reasoning format (`<reasoning>...<answer>`)

**Key Features**:
- ~17,000 high-quality Chain-of-Thought examples
- Distilled from DeepSeek-R1 model
- Contains LONG reasoning traces with self-reflection and verification
- Includes `<think>...</think>` format (we convert to `<reasoning>...</reasoning>`)
- Perfect for cold-start: teaches the model HOW to think, not just WHAT to answer

**Why this dataset?**
- Shows the model examples of step-by-step reasoning
- Teaches proper output structure
- Prevents chaotic/unreadable outputs during GRPO

---

#### Dataset 2: GSM8K / grade-school-math-8k-q-a (for GRPO)

**Source**: `thedevastator/grade-school-math-8k-q-a` on Kaggle (or OpenAI's GSM8K)

**Purpose**: Strengthen reasoning through reward-based learning

**Key Features**:
- ~8,000 grade school math word problems
- Simple format: question + answer (with `####` separator)
- Answers are verifiable (correct or incorrect)
- Perfect for GRPO: clear right/wrong signal for rewards

**Why this dataset?**
- GRPO needs verifiable answers to compute rewards
- Math problems have definitive correct answers
- Model learns to reason correctly, not just format correctly


In [10]:
# ============================================================================
# CELL 10: Data Loading Functions
# ============================================================================

# ----------------------------------------------------------------------------
# DATASET 1: Bespoke-Stratos-17k for Cold Start SFT
# ----------------------------------------------------------------------------

def load_bespoke_stratos_dataset(num_examples=None):
    """
    Load Bespoke-Stratos-17k dataset from HuggingFace for Cold Start SFT.

    This dataset contains high-quality Chain-of-Thought examples distilled
    from DeepSeek-R1. It's specifically designed for teaching models the
    reasoning format.

    Original format uses <think>...</think>, we convert to <reasoning>...</reasoning>

    Args:
        num_examples: Maximum number of examples to load (None = all ~17k)

    Returns:
        List of dicts with 'prompt' and 'completion' keys
    """
    print("Loading Bespoke-Stratos-17k from HuggingFace...")

    # Load from HuggingFace
    dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k", split="train")

    sft_data = []
    for i, example in enumerate(tqdm(dataset, desc="Processing examples")):
        if num_examples and i >= num_examples:
            break

        # Extract conversations
        conversations = example.get("conversations", [])
        if len(conversations) < 2:
            continue

        # Get user message and assistant response
        user_msg = None
        assistant_msg = None

        for conv in conversations:
            role = conv.get("from", conv.get("role", ""))
            content = conv.get("value", conv.get("content", ""))

            if role in ["human", "user"]:
                user_msg = content
            elif role in ["gpt", "assistant"]:
                assistant_msg = content

        if not user_msg or not assistant_msg:
            continue

        # Convert various thinking tags to our standard <reasoning>...</reasoning> format
        # Bespoke-Stratos uses: <|begin_of_thought|>...<|end_of_thought|>
        # Some datasets use: <think>...</think>
        # Also handle: <think>...</think> or <thinking>...</thinking>
        completion = assistant_msg

        # Handle Bespoke-Stratos format
        completion = completion.replace("<|begin_of_thought|>", reasoning_start)
        completion = completion.replace("<|end_of_thought|>", reasoning_end)
        completion = completion.replace("<|begin_of_solution|>", solution_start)
        completion = completion.replace("<|end_of_solution|>", solution_end)

        # Handle redacted_reasoning format (convert to reasoning)
        completion = completion.replace("<think>", reasoning_start)
        completion = completion.replace("</think>", reasoning_end)

        # Handle think/thinking tags (convert to reasoning)
        completion = re.sub(r"<think>", reasoning_start, completion, flags=re.IGNORECASE)
        completion = re.sub(r"</think>", reasoning_end, completion, flags=re.IGNORECASE)
        completion = re.sub(r"<thinking>", reasoning_start, completion, flags=re.IGNORECASE)
        completion = re.sub(r"</thinking>", reasoning_end, completion, flags=re.IGNORECASE)

        # If still no answer tags, add them
        if solution_start not in completion:
            # Try to extract final answer after reasoning
            if reasoning_end in completion:
                parts = completion.split(reasoning_end)
                if len(parts) > 1 and parts[1].strip():
                    # Wrap the part after reasoning in answer tags
                    completion = parts[0] + reasoning_end + "\n" + solution_start + parts[1].strip() + solution_end
                else:
                    completion = completion + f"\n{solution_start}See reasoning above{solution_end}"
            else:
                completion = completion + f"\n{solution_start}See reasoning above{solution_end}"

        # Format prompt with our template
        prompt = TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=user_msg,
        )

        sft_data.append({
            "prompt": prompt,
            "completion": completion,
        })

    print(f"Loaded {len(sft_data)} SFT examples from Bespoke-Stratos-17k")
    return sft_data


# ----------------------------------------------------------------------------
# DATASET 2: GSM8K for GRPO Training
# ----------------------------------------------------------------------------

def download_kaggle_dataset(target_dir="./data/gsm8k"):
    """Download GSM8K dataset from Kaggle."""
    os.makedirs(target_dir, exist_ok=True)
    src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
    src = Path(src)
    dst = Path(target_dir)

    for csv_file in src.glob("*.csv"):
        shutil.copy2(csv_file, dst / csv_file.name)
        print(f"Copied {csv_file.name}")
    return target_dir


def get_grpo_dataset(data_dir, split="train", source="kaggle") -> grain.MapDataset:
    """
    Load and format dataset for GRPO training.

    Returns a dataset with:
    - prompts: Formatted input prompts
    - question: Original question text
    - answer: Ground truth answer for verification
    """
    if source == "tfds":
        import tensorflow_datasets.text.gsm8k
        data = tfds.data_source(
            "gsm8k",
            split=split,
            data_dir=data_dir,
            builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
            download=True,
        )
    elif source == "kaggle":
        kaggle_dir = download_kaggle_dataset(data_dir)
        file_name = "main_" + split + ".csv"
        csv_path = os.path.join(kaggle_dir, file_name)

        data = []
        with open(csv_path, newline="", encoding="utf-8") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                data.append({
                    "question": row["question"],
                    "answer": row["answer"],
                })
    else:
        raise ValueError(f"Unknown source: {source}")

    def _as_text(v):
        return v if isinstance(v, str) else v.decode("utf-8")

    dataset = (
        grain.MapDataset.source(data)
        .shuffle(seed=42)
        .map(lambda x: {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=_as_text(x["question"]),
            ),
            "question": _as_text(x["question"]),
            "answer": extract_hash_answer(_as_text(x["answer"])),
        })
    )
    return dataset


print("Data loading functions defined!")


Data loading functions defined!


In [11]:
# ============================================================================
# CELL 10.5: Load Evaluation Dataset and Define Evaluation Functions
# ============================================================================
# Load a small test dataset for evaluation before and after SFT training.
# Also define evaluate and generate functions here for use in comparison cells.

print("=" * 60)
print("PREPARING EVALUATION DATASET")
print("=" * 60)

# Load a small test dataset for evaluation (before and after SFT)
NUM_EVAL_SAMPLES = 32  # Small sample for quick evaluation
eval_test_dataset = get_grpo_dataset(TEST_DATA_DIR, "test", "kaggle")
eval_test_dataset = eval_test_dataset.batch(1)[:NUM_EVAL_SAMPLES]

print(f"Loaded {NUM_EVAL_SAMPLES} test samples for evaluation")
print("=" * 60)

# Define evaluation functions (same as in original notebook)
def generate(question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None):
    """Generate response for a given question."""
    if isinstance(question, str):
        input_batch = [TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)]
    else:
        input_batch = [
            TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
            for q in question
        ]

    out_data = sampler(
        input_strings=input_batch,
        max_generation_steps=TOTAL_GENERATION_STEPS,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
        eos_tokens=EOS_TOKENS,
    )

    output = out_data.text
    return output[0] if isinstance(question, str) else output


def evaluate(dataset, sampler, temperature=0.7, top_k=50, top_p=0.95, debug_samples=3):
    """Evaluate model on dataset and compute metrics."""
    correct = 0
    partially_correct = 0
    correct_format = 0
    total = 0

    for batch_idx, batch in enumerate(tqdm(dataset, desc="Evaluating")):
        answers = batch["answer"]
        questions = batch["question"]

        responses = generate(questions, sampler, temperature, top_k, top_p)

        for q_idx, (question, response, answer) in enumerate(zip(questions, responses, answers)):
            # Debug: Print first few samples to see what model generates
            if total < debug_samples:
                print(f"\n--- Sample {total + 1} ---")
                print(f"Question: {question[:100]}...")
                print(f"Expected Answer: {answer}")
                print(f"Model Response (first 500 chars):\n{response[:500]}")
                print(f"Full Response Length: {len(response)} chars")
                # Check if format matches
                format_match = match_format.search(response)
                print(f"Format Match: {format_match is not None}")
                if format_match:
                    print(f"Matched Format: {format_match.group(0)[:200]}...")
                # Check if numbers can be extracted
                num_match = match_numbers.search(response)
                print(f"Number Match: {num_match is not None}")
                if num_match:
                    print(f"Extracted Number: {num_match.group(1)}")
                print("-" * 60)

            # Check answer
            extracted = match_numbers.search(response)
            if extracted:
                try:
                    pred = float(extracted.group(1).strip())
                    true_val = float(answer.strip())
                    if pred == true_val:
                        correct += 1
                    ratio = pred / true_val if true_val != 0 else 0
                    if 0.9 <= ratio <= 1.1:
                        partially_correct += 1
                except:
                    pass

            # Check format
            if match_format.search(response):
                correct_format += 1

            total += 1

            # Print progress
            if total % 10 == 0:
                print(f"Progress: {correct}/{total} correct ({correct/total*100:.1f}%)")

    return {
        "accuracy": correct / total * 100 if total > 0 else 0,
        "partial_accuracy": partially_correct / total * 100 if total > 0 else 0,
        "format_accuracy": correct_format / total * 100 if total > 0 else 0,
        "correct": correct,
        "total": total,
    }

print("\nEvaluation functions defined!")
print("Ready for pre-SFT and post-SFT evaluation.")


PREPARING EVALUATION DATASET
Using Colab cache for faster access to the 'grade-school-math-8k-q-a' dataset.
Copied main_test.csv
Copied main_train.csv
Copied socratic_train.csv
Copied socratic_test.csv
Loaded 32 test samples for evaluation

Evaluation functions defined!
Ready for pre-SFT and post-SFT evaluation.


In [12]:
# ============================================================================
# CELL 10.6: Evaluate Model Before SFT Training
# ============================================================================
# This cell evaluates the model's performance BEFORE cold start SFT training.
# Uses the evaluate function to get quantitative metrics.

print("=" * 60)
print("EVALUATING MODEL BEFORE SFT TRAINING")
print("=" * 60)
print("\nRunning evaluation on test dataset...")
print("This will measure accuracy and format compliance before training.\n")

# Create a sampler for inference
pre_sft_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Run evaluation
pre_sft_results = evaluate(eval_test_dataset, pre_sft_sampler, **GENERATION_CONFIGS["greedy"])

print("\n" + "=" * 60)
print("PRE-SFT EVALUATION RESULTS")
print("=" * 60)
print(f"  Answer Accuracy:    {pre_sft_results['accuracy']:.2f}%")
print(f"  Partial Accuracy:   {pre_sft_results['partial_accuracy']:.2f}%")
print(f"  Format Accuracy:    {pre_sft_results['format_accuracy']:.2f}%")
print(f"  Correct/Total:      {pre_sft_results['correct']}/{pre_sft_results['total']}")
print("=" * 60)

print("\n" + "=" * 60)
print("READY FOR SFT TRAINING")
print("=" * 60)
print("After SFT training, we expect improvements in:")
print("  - Format Accuracy: Model should learn <reasoning>...</reasoning><answer>...</answer> format")
print("  - Answer Accuracy: Better reasoning may lead to more correct answers")
print("=" * 60 + "\n")


EVALUATING MODEL BEFORE SFT TRAINING

Running evaluation on test dataset...
This will measure accuracy and format compliance before training.



Evaluating:   0%|          | 0/32 [00:00<?, ?it/s]


--- Sample 1 ---
Question: Mr Hezekiah had 20 trucks from his store supplying fertiliser to different farmers in his hometown d...
Expected Answer: 300
Model Response (first 500 chars):
reasoning:
The problem states that each truck carried 20 tons of fertiliser. Mr. Hezekiah dispatched 20 trucks. Therefore, the total amount of fertiliser initially planned was 20 trucks * 20 tons/truck = 400 tons.
Two hours after the trucks departed, a quarter of the lorries failed. This means that the number of lorries that failed is (1/4) * 20 = 5 lorries.
Each lorry carried 20 tons of fertiliser. So, the total amount of fertiliser that failed is 5 lorries * 20 tons/lorry = 100 tons.
The numbe
Full Response Length: 850 chars
Format Match: False
Number Match: True
Extracted Number: 300
------------------------------------------------------------

--- Sample 2 ---
Question: Grandpa loves to eat jelly beans, but how many jelly beans he can eat depends on the size of the bea...
Expected Answer: 450
Model

## Section 6: Stage 1 - Cold Start SFT Training

Cold Start SFT teaches the model the output format before GRPO strengthens reasoning.

**Key Difference from Original**: This notebook uses `PeftTrainer` from Tunix (following qlora_gemma approach) instead of a custom training loop.

Key goals:
- Teach the `<reasoning>` and `<answer>` tag format
- Show examples of step-by-step thinking
- Prevent chaotic outputs during later RL training


In [13]:
# ============================================================================
# CELL 11: Prepare SFT Dataset (Bespoke-Stratos-17k)
# ============================================================================
# We use Bespoke-Stratos-17k for cold start SFT - this dataset contains
# high-quality Chain-of-Thought examples that teach proper reasoning format.

print("=" * 60)
print("Loading SFT Dataset: Bespoke-Stratos-17k")
print("=" * 60)

# Load Bespoke-Stratos-17k from HuggingFace
# Set num_examples to limit for faster training (None = use all ~17k)
SFT_NUM_EXAMPLES = 4000  # Adjust based on available time

sft_data = load_bespoke_stratos_dataset(num_examples=SFT_NUM_EXAMPLES)

# Filter long examples if SFT_MAX_TOKENS_FILTER is set (memory optimization)
if SFT_MAX_TOKENS_FILTER is not None:
    print(f"\nFiltering examples longer than {SFT_MAX_TOKENS_FILTER} tokens...")
    original_count = len(sft_data)
    filtered_data = []
    for sample in tqdm(sft_data, desc="Filtering by token length"):
        full_text = sample['prompt'] + sample['completion']
        tokens = tokenizer.encode(full_text)
        if len(tokens) <= SFT_MAX_TOKENS_FILTER:
            filtered_data.append(sample)
    sft_data = filtered_data
    filtered_count = len(sft_data)
    print(f"Filtered: {original_count} -> {filtered_count} examples ({filtered_count/original_count*100:.1f}% kept)")
    print(f"Removed {original_count - filtered_count} examples that exceeded {SFT_MAX_TOKENS_FILTER} tokens")

# Preview an example
print("\nSample SFT training example:")
print("=" * 60)
sample = sft_data[0]
print(f"PROMPT:\n{sample['prompt'][:200]}...")
print(f"\nCOMPLETION:\n{sample['completion'][:300]}...")
print("=" * 60)


Loading SFT Dataset: Bespoke-Stratos-17k
Loading Bespoke-Stratos-17k from HuggingFace...


Processing examples:   0%|          | 0/16710 [00:00<?, ?it/s]

Loaded 2000 SFT examples from Bespoke-Stratos-17k

Filtering examples longer than 1792 tokens...


Filtering by token length:   0%|          | 0/2000 [00:00<?, ?it/s]

Filtered: 2000 -> 292 examples (14.6% kept)
Removed 1708 examples that exceeded 1792 tokens

Sample SFT training example:
PROMPT:
<start_of_turn>user
You are given a problem. First, think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer between <answer> a...

COMPLETION:
<reasoning>

Okay, let's see. The problem is about Judy's hits during the softball season. She had a total of 35 hits. Among those, 1 was a home run, 1 was a triple, and 5 were doubles. The rest were singles. We need to find out what percent of her hits were singles. The options are given from A to ...


In [14]:
# ============================================================================
# CELL 11.5: Inspect SFT Data Points to Verify Pattern Replacement
# ============================================================================
# Print several data points to verify that pattern replacement is correct

print("=" * 60)
print("INSPECTING SFT DATA POINTS")
print("=" * 60)
print(f"\nTotal data points loaded: {len(sft_data)}")
print(f"\nPrinting first 5 data points to check pattern replacement:\n")

for i in range(min(5, len(sft_data))):
    print(f"\n{'='*60}")
    print(f"Data Point {i+1}:")
    print(f"{'='*60}")

    sample = sft_data[i]

    # Print prompt (truncated if too long)
    print(f"\n[PROMPT]:")
    print("-" * 60)
    prompt_display = sample['prompt']
    if len(prompt_display) > 500:
        print(prompt_display[:500] + "...")
    else:
        print(prompt_display)

    # Print completion (full text to check pattern replacement)
    print(f"\n[COMPLETION]:")
    print("-" * 60)
    completion_display = sample['completion']
    if len(completion_display) > 1000:
        print(completion_display[:1000] + "...")
        print(f"\n[Full completion length: {len(completion_display)} characters]")
    else:
        print(completion_display)

    # Check for pattern indicators
    print(f"\n[PATTERN CHECK]:")
    print("-" * 60)
    has_reasoning_start = reasoning_start in completion_display
    has_reasoning_end = reasoning_end in completion_display
    has_answer_start = solution_start in completion_display
    has_answer_end = solution_end in completion_display

    print(f"  Contains <reasoning>: {has_reasoning_start and has_reasoning_end}")
    print(f"  Contains <answer>: {has_answer_start and has_answer_end}")

    if has_reasoning_start and has_reasoning_end:
        # Extract reasoning section
        try:
            reasoning_content = completion_display.split(reasoning_start)[1].split(reasoning_end)[0]
            print(f"  Reasoning length: {len(reasoning_content)} chars")
        except:
            print(f"  Warning: Reasoning tags found but content extraction failed")

    if has_answer_start and has_answer_end:
        # Extract answer section
        try:
            answer_content = completion_display.split(solution_start)[1].split(solution_end)[0]
            print(f"  Answer length: {len(answer_content)} chars")
        except:
            print(f"  Warning: Answer tags found but content extraction failed")

    # Check token length and truncation
    print(f"\n[TOKEN LENGTH CHECK]:")
    print("-" * 60)
    full_text = sample['prompt'] + sample['completion']
    tokens = tokenizer.encode(full_text)
    token_count = len(tokens)
    print(f"  Total tokens: {token_count}")
    print(f"  MAX_SEQ_LENGTH: {MAX_SEQ_LENGTH}")
    if token_count > MAX_SEQ_LENGTH:
        truncation_ratio = (token_count - MAX_SEQ_LENGTH) / token_count * 100
        print(f"  ⚠️  WILL BE TRUNCATED: {token_count - MAX_SEQ_LENGTH} tokens ({truncation_ratio:.1f}%) will be lost")
        print(f"  After truncation: {MAX_SEQ_LENGTH} tokens")
    else:
        print(f"  ✓ No truncation needed")

print(f"\n{'='*60}")
print("Data inspection complete!")
print("=" * 60)

# Additional statistics on truncation
print(f"\n{'='*60}")
print("TRUNCATION STATISTICS (checking all data points)")
print("=" * 60)
truncated_count = 0
token_lengths = []
for sample in sft_data:
    full_text = sample['prompt'] + sample['completion']
    tokens = tokenizer.encode(full_text)
    token_count = len(tokens)
    token_lengths.append(token_count)
    if token_count > MAX_SEQ_LENGTH:
        truncated_count += 1

print(f"Total data points: {len(sft_data)}")
print(f"Data points that will be truncated: {truncated_count} ({truncated_count/len(sft_data)*100:.1f}%)")
print(f"Data points within limit: {len(sft_data) - truncated_count} ({(len(sft_data)-truncated_count)/len(sft_data)*100:.1f}%)")
if token_lengths:
    print(f"\nToken length statistics:")
    print(f"  Min: {min(token_lengths)} tokens")
    print(f"  Max: {max(token_lengths)} tokens")
    print(f"  Mean: {sum(token_lengths)/len(token_lengths):.1f} tokens")
    print(f"  Median: {sorted(token_lengths)[len(token_lengths)//2]} tokens")
print(f"\nMAX_SEQ_LENGTH limit: {MAX_SEQ_LENGTH} tokens")
print("=" * 60)



INSPECTING SFT DATA POINTS

Total data points loaded: 292

Printing first 5 data points to check pattern replacement:


Data Point 1:

[PROMPT]:
------------------------------------------------------------
<start_of_turn>user
You are given a problem. First, think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer between <answer> and </answer>.

Return your final response within \boxed{}. During the softball season, Judy had $35$ hits.  Among her hits were $1$ home run, $1$ triple and $5$ doubles.  The rest of her hits were single.  What percent of her hits were single?
$\text{(A)}\ 28\% \qquad \text{(B)}\ 35\% \qquad \text{(...

[COMPLETION]:
------------------------------------------------------------
<reasoning>

Okay, let's see. The problem is about Judy's hits during the softball season. She had a total of 35 hits. Among those, 1 was a home run, 1 was a triple, and 5 were doubles. The rest were singles. We ne

In [15]:
# ============================================================================
# CELL 12: Create SFT Dataset in Grain Format for PeftTrainer
# ============================================================================
# PeftTrainer expects a dataset that yields TrainingInput objects.
# We need to convert our prompt+completion pairs into the right format.

def create_sft_training_dataset(sft_data, tokenizer, max_length=MAX_SEQ_LENGTH):
    """
    Create a Grain dataset for SFT training that works with PeftTrainer.

    This function:
    1. Combines prompt and completion into full text
    2. Tokenizes the text
    3. Creates TrainingInput objects that PeftTrainer expects
    """
    def format_example(example):
        """Combine prompt and completion into a single training text."""
        full_text = example["prompt"] + example["completion"]
        return {"text": full_text}

    def tokenize_and_pad_example(example):
        """Tokenize the text and pad to fixed length."""
        text = example["text"]
        tokens = tokenizer.encode(text)

        # Track original length before truncation/padding
        original_length = len(tokens)

        # Truncate if too long
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
            actual_length = max_length
        else:
            actual_length = original_length

        # Pad to fixed length max_length
        pad_id = tokenizer.pad_id()
        if len(tokens) < max_length:
            padding = [pad_id] * (max_length - len(tokens))
            tokens = tokens + padding

        # Create input_tokens (padded to max_length)
        input_tokens = np.array(tokens, dtype=np.int32)

        # Create input_mask (1 for valid tokens, 0 for padding)
        input_mask = np.ones((max_length,), dtype=np.float32)
        input_mask[actual_length:] = 0.0

        return {
            "input_tokens": input_tokens,
            "input_mask": input_mask,
        }

    # Create dataset pipeline
    dataset = (
        grain.MapDataset.source(sft_data)
        .shuffle(seed=42)
        .map(format_example)
        .map(tokenize_and_pad_example)
        .batch(SFT_BATCH_SIZE)
    )

    return dataset


# Create training and validation datasets
print("Creating SFT training dataset...")
sft_dataset = create_sft_training_dataset(sft_data, tokenizer, max_length=MAX_SEQ_LENGTH)

# Split into train/validation (80/20 split)
total_examples = len(sft_data)
train_size = int(total_examples * 0.8)
val_size = total_examples - train_size

print(f"Total examples: {total_examples}")
print(f"Train examples: {train_size}")
print(f"Validation examples: {val_size}")

# Note: For simplicity, we'll use the full dataset for training
# In production, you might want to split properly
train_ds = sft_dataset
validation_ds = sft_dataset  # Use same for validation (or split properly)

print("SFT dataset created successfully!")


Creating SFT training dataset...
Total examples: 292
Train examples: 233
Validation examples: 59
SFT dataset created successfully!


In [16]:
# ============================================================================
# CELL 13: Define Model Input Function for PeftTrainer
# ============================================================================
# PeftTrainer requires a function that converts TrainingInput to model inputs.
# This follows the qlora_gemma approach.
# Note: We handle both TrainingInput objects and dict inputs for flexibility.

def gen_model_input_fn(x):
    """
    Convert TrainingInput (or dict) to model input format.

    This function:
    1. Handles both TrainingInput objects and dict inputs
    2. Creates position indices from padding mask
    3. Creates causal attention mask
    4. Returns dict with all required model inputs
    """
    # Handle both TrainingInput objects and dict inputs
    if isinstance(x, dict):
        input_tokens = x['input_tokens']
        input_mask = x['input_mask']
    else:
        # TrainingInput object
        input_tokens = x.input_tokens
        input_mask = x.input_mask

    pad_mask = input_tokens != tokenizer.pad_id()
    positions = utils.build_positions_from_mask(pad_mask)
    attention_mask = utils.make_causal_attn_mask(pad_mask)
    return {
        'input_tokens': input_tokens,
        'input_mask': input_mask,
        'positions': positions,
        'attention_mask': attention_mask,
    }

print("Model input function defined!")


Model input function defined!


In [17]:
# ============================================================================
# CELL 13.5: Launch TensorBoard for Monitoring
# ============================================================================

# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard to monitor training metrics
%tensorboard --logdir ./logs/ --port=0

print("TensorBoard launched! Monitor training metrics above.")


Reusing TensorBoard on port 35635 (pid 6491), started 1:51:24 ago. (Use '!kill 6491' to kill it.)

<IPython.core.display.Javascript object>

TensorBoard launched! Monitor training metrics above.


In [18]:
# ============================================================================
# CELL 14: Run Cold Start SFT Training using PeftTrainer
# ============================================================================
# This stage teaches the model the reasoning format using Bespoke-Stratos-17k.
# We use PeftTrainer from Tunix (following qlora_gemma approach) instead of
# a custom training loop.

print("=" * 60)
print("STAGE 1: COLD START SFT TRAINING (using PeftTrainer)")
print("=" * 60)
print("\nUsing Bespoke-Stratos-17k dataset to teach reasoning format.")
print("This stage plants the <reasoning>...</reasoning><answer>...</answer> template.")
print("\nFollowing qlora_gemma approach with PeftTrainer.\n")

# Configure SFT optimizer
sft_optimizer = optax.adamw(
    learning_rate=SFT_LEARNING_RATE,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Add gradient clipping
sft_optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
    sft_optimizer,
)

# Convert checkpoint directory to absolute path (required by Orbax)
SFT_CKPT_DIR_ABS = os.path.abspath(SFT_CKPT_DIR)

# Handle checkpoint resuming: delete existing checkpoints if RESUME_FROM_CHECKPOINT is False
if not RESUME_FROM_CHECKPOINT and os.path.exists(SFT_CKPT_DIR_ABS):
    print(f"RESUME_FROM_CHECKPOINT is False. Deleting existing checkpoints in {SFT_CKPT_DIR_ABS}...")
    shutil.rmtree(SFT_CKPT_DIR_ABS)
    print("Existing checkpoints deleted. Training will start from scratch.\n")

os.makedirs(SFT_CKPT_DIR_ABS, exist_ok=True)

# Configure metrics logging
sft_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="./logs/sft/",
    flush_every_n_steps=20,
)

# Configure training
training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=SFT_MAX_STEPS,
    metrics_logging_options=sft_logging_options,
    checkpoint_root_directory=SFT_CKPT_DIR_ABS,
)

# Create PeftTrainer (following qlora_gemma approach)
trainer = peft_trainer.PeftTrainer(
    lora_policy,  # Use LoRA model for training
    sft_optimizer,
    training_config
).with_gen_model_input_fn(gen_model_input_fn)

print(f"\nSFT Configuration:")
print(f"  Learning rate: {SFT_LEARNING_RATE}")
print(f"  Max steps: {SFT_MAX_STEPS}")
print(f"  Batch size: {SFT_BATCH_SIZE}")
print(f"  Checkpoint dir: {SFT_CKPT_DIR_ABS}")
print(f"  Log dir: ./logs/sft/")

# Run training
print(f"\n{'='*60}")
print(f"Starting SFT training for {SFT_MAX_STEPS} steps...")
print(f"{'='*60}\n")

# The first couple of training steps might take up to 5 minutes to finish.
# Please be patient. If you experience long training steps, e.g. >10 minutes
# per step, please open a bug. Really appreciated!
with mesh:
    trainer.train(train_ds, validation_ds)

print("\n" + "=" * 60)
print("COLD START SFT TRAINING COMPLETE!")
print(f"Checkpoints saved to: {SFT_CKPT_DIR_ABS}")
print(f"Logs saved to: ./logs/sft/")
print("=" * 60)

show_hbm_usage()


STAGE 1: COLD START SFT TRAINING (using PeftTrainer)

Using Bespoke-Stratos-17k dataset to teach reasoning format.
This stage plants the <reasoning>...</reasoning><answer>...</answer> template.

Following qlora_gemma approach with PeftTrainer.

RESUME_FROM_CHECKPOINT is False. Deleting existing checkpoints in /content/checkpoints/sft...
Existing checkpoints deleted. Training will start from scratch.



INFO:absl:save_device_host_concurrent_bytes=None
INFO:absl:Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7ff989771d00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
INFO:absl:[thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
INFO:absl:[pro


SFT Configuration:
  Learning rate: 0.0002
  Max steps: 500
  Batch size: 1
  Checkpoint dir: /content/checkpoints/sft
  Log dir: ./logs/sft/

Starting SFT training for 500 steps...



INFO:absl:Train step 0 eval loss: 1.804707 - eval perplexity: 6.078190


Training:   0%|          | 0/500 [00:00<?, ?step/s]

INFO:absl:Using ThreadSafeKeyValueSignalingClient
INFO:absl:[process=0][thread=MainThread][wait_until_finished] No Save Finalize thread to wait for. Returning.
INFO:absl:[process=0] Saving checkpoint at step 1
INFO:absl:[process=0] Started async saving checkpoint to /content/checkpoints/sft/1.
INFO:absl:Creating tmp directory /content/checkpoints/sft/1.orbax-checkpoint-tmp
INFO:absl:Wrote Metadata={'item_handlers': None, 'metrics': {}, 'performance_metrics': {}, 'init_timestamp_nsecs': 1768258303231071892, 'commit_timestamp_nsecs': None, 'custom_metadata': {}}, json={"item_handlers": null, "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1768258303231071892, "commit_timestamp_nsecs": null, "custom_metadata": {}} to /content/checkpoints/sft/1.orbax-checkpoint-tmp/_CHECKPOINT_METADATA
INFO:absl:Creating tmp directory /content/checkpoints/sft/1.orbax-checkpoint-tmp/model_params.orbax-checkpoint-tmp
INFO:absl:Creating tmp directory /content/checkpoints/sft/1.orbax-checkpoi


COLD START SFT TRAINING COMPLETE!
Checkpoints saved to: /content/checkpoints/sft
Logs saved to: ./logs/sft/
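
The `optax.clip_by_global_norm` transform chained in front of the SFT optimizer rescales all gradients together when their combined L2 norm exceeds `MAX_GRAD_NORM`. The plain-Python sketch below illustrates that rule on a toy gradient dictionary. The helper `clip_by_global_norm_sketch` and the toy values are hypothetical and are not part of the notebook or of Optax.

```python
import math

def clip_by_global_norm_sketch(grads, max_norm):
    """Rescale a dict of gradient lists so their joint L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(x * x for g in grads.values() for x in g))
    # Gradients already within the limit are left unchanged.
    scale = 1.0 if global_norm <= max_norm else max_norm / global_norm
    clipped = {name: [x * scale for x in g] for name, g in grads.items()}
    return clipped, global_norm

# Toy gradients (illustrative values only); joint norm is sqrt(3^2 + 4^2) = 5.
toy_grads = {"lora_a": [3.0, 0.0], "lora_b": [0.0, 4.0]}
clipped, norm_before = clip_by_global_norm_sketch(toy_grads, max_norm=1.0)
norm_after = math.sqrt(sum(x * x for g in clipped.values() for x in g))
print(f"norm before: {norm_before:.2f}, norm after: {norm_after:.2f}")
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length shrinks.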


In [19]:
# ============================================================================
# CELL 14.5: Evaluate Model After SFT Training
# ============================================================================
# This cell evaluates the model's performance AFTER cold start SFT training.
# Uses the evaluate function to show quantitative improvements.

print("=" * 60)
print("EVALUATING MODEL AFTER SFT TRAINING")
print("=" * 60)

# Load the latest checkpoint to ensure we're using the trained model
# PeftTrainer should update model in-place, but loading from checkpoint ensures correctness
import glob

# Find the latest checkpoint
checkpoint_dirs = glob.glob(os.path.join(SFT_CKPT_DIR_ABS, "*"))
checkpoint_steps = []
for ckpt_dir in checkpoint_dirs:
    if os.path.isdir(ckpt_dir):
        try:
            step = int(os.path.basename(ckpt_dir))
            checkpoint_steps.append(step)
        except ValueError:
            continue

if checkpoint_steps:
    latest_step = max(checkpoint_steps)
    print(f"\nLoading checkpoint from step {latest_step}...")

    # Load checkpoint
    ckpt_path = os.path.join(SFT_CKPT_DIR_ABS, str(latest_step), "model_params")

    # Get the shape/dtype structure for restoration
    abs_params = jax.tree.map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        nnx.state(lora_policy, nnx.LoRAParam),
    )

    # Restore checkpoint
    checkpointer = ocp.StandardCheckpointer()
    trained_lora_params = checkpointer.restore(ckpt_path, target=abs_params)

    # Update model with trained parameters
    nnx.update(
        lora_policy,
        jax.tree.map(
            lambda a, b: b,
            nnx.state(lora_policy, nnx.LoRAParam),
            trained_lora_params,
        ),
    )

    print(f"Checkpoint loaded successfully from: {ckpt_path}")
else:
    print("\nWARNING: No checkpoint found. Using current model state.")
    print("Note: PeftTrainer should have updated model in-place during training.")

print("\nRunning evaluation on the same test dataset...")
print("This will show quantitative improvements after training.\n")

# Create sampler for post-SFT inference
post_sft_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Run evaluation on the same dataset
post_sft_results = evaluate(eval_test_dataset, post_sft_sampler, **GENERATION_CONFIGS["greedy"])

print("\n" + "=" * 60)
print("POST-SFT EVALUATION RESULTS")
print("=" * 60)
print(f"  Answer Accuracy:    {post_sft_results['accuracy']:.2f}%")
print(f"  Partial Accuracy:   {post_sft_results['partial_accuracy']:.2f}%")
print(f"  Format Accuracy:    {post_sft_results['format_accuracy']:.2f}%")
print(f"  Correct/Total:      {post_sft_results['correct']}/{post_sft_results['total']}")
print("=" * 60)

print("\n" + "=" * 60)
print("BEFORE vs AFTER SFT COMPARISON")
print("=" * 60)
print(f"{'Metric':<20} {'Before SFT':<15} {'After SFT':<15} {'Change':<15}")
print("-" * 65)

acc_change = post_sft_results['accuracy'] - pre_sft_results['accuracy']
partial_change = post_sft_results['partial_accuracy'] - pre_sft_results['partial_accuracy']
format_change = post_sft_results['format_accuracy'] - pre_sft_results['format_accuracy']

print(f"{'Answer Accuracy':<20} {pre_sft_results['accuracy']:>6.2f}%      {post_sft_results['accuracy']:>6.2f}%      {acc_change:>+6.2f}%")
print(f"{'Partial Accuracy':<20} {pre_sft_results['partial_accuracy']:>6.2f}%      {post_sft_results['partial_accuracy']:>6.2f}%      {partial_change:>+6.2f}%")
print(f"{'Format Accuracy':<20} {pre_sft_results['format_accuracy']:>6.2f}%      {post_sft_results['format_accuracy']:>6.2f}%      {format_change:>+6.2f}%")
print("=" * 65)

print("\nKey Improvements:")
if format_change > 0:
    print(f"  Format accuracy improved by {format_change:.2f}%")
if acc_change > 0:
    print(f"  Answer accuracy improved by {acc_change:.2f}%")
if acc_change <= 0 and format_change <= 0:
    print("  Note: Model is learning the format, accuracy improvements may come in GRPO stage")
print("=" * 60 + "\n")


INFO:absl:save_device_host_concurrent_bytes=None
INFO:absl:Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7ff989771d00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
INFO:absl:[process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.<locals>.<lambda> at 0x7ff964e20ae0> timeout: 600 secs and primary_host=0 for async checkpoint writes
INFO:absl:Restoring checkpoint from /content/checkpoints/sft/292/model_params.


EVALUATING MODEL AFTER SFT TRAINING

Loading checkpoint from step 292...


INFO:absl:[process=0] /jax/checkpoint/read/gbytes_per_sec: 160.808 MiB/s (total gbytes: 47.9 MiB) (time elapsed: 298 milliseconds) (per-host)
INFO:absl:Finished restoring checkpoint in 0.31 seconds from /content/checkpoints/sft/292/model_params.


Checkpoint loaded successfully from: /content/checkpoints/sft/292/model_params

Running evaluation on the same test dataset...
This will show quantitative improvements after training.



Evaluating:   0%|          | 0/32 [00:00<?, ?it/s]


--- Sample 1 ---
Question: Mr Hezekiah had 20 trucks from his store supplying fertiliser to different farmers in his hometown d...
Expected Answer: 300
Model Response (first 500 chars):


<reasoning>

Okay, so Mr. Hezekiah had 20 trucks, each carrying 20 tons of fertiliser. He dispatched them for delivery on a particular day. Then, two hours later, he got the news that a quarter of the number of lorries had mechanical failures and could not deliver the fertilisers. We need to figure out how many tons of fertiliser reached the farmers.

First, let me break down the problem. The total number of lorries dispatched is 20 trucks. So, 20 lorries. Then, two hours later, a quarter of t
Full Response Length: 5655 chars
Format Match: False
Number Match: False
------------------------------------------------------------

--- Sample 2 ---
Question: Grandpa loves to eat jelly beans, but how many jelly beans he can eat depends on the size of the bea...
Expected Answer: 450
Model Response (first 500
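
The three numbers reported by `evaluate` (answer, partial, and format accuracy) can be illustrated with a small, self-contained sketch. The regex, the helper name `score_responses`, and the sample responses are assumptions made for illustration. The notebook's own `evaluate` function is defined in an earlier cell and may differ in detail.

```python
import re

# Assumed tag layout: <reasoning>...</reasoning><answer>...</answer>
FORMAT_RE = re.compile(
    r"<reasoning>.*?</reasoning>\s*<answer>(.*?)</answer>", re.DOTALL
)

def score_responses(responses, expected):
    """Return format, exact, and partial (within 10%) accuracy as percentages."""
    fmt = exact = partial = 0
    for response, truth in zip(responses, expected):
        match = FORMAT_RE.search(response)
        if match is None:
            continue  # no tags: counts against all three metrics
        fmt += 1
        try:
            pred, target = float(match.group(1).strip()), float(truth)
        except ValueError:
            continue  # tags present, but the answer is not a number
        if pred == target:
            exact += 1
        if target != 0 and 0.9 <= pred / target <= 1.1:
            partial += 1
    n = len(expected)
    return {
        "format_accuracy": 100 * fmt / n,
        "accuracy": 100 * exact / n,
        "partial_accuracy": 100 * partial / n,
    }

# Hypothetical responses, not taken from the run above.
responses = [
    "<reasoning>worked steps</reasoning><answer>300</answer>",
    "<reasoning>rough estimate</reasoning><answer>440</answer>",
    "The answer is 12.",
    "<reasoning>no number</reasoning><answer>many</answer>",
]
results = score_responses(responses, ["300", "450", "12", "7"])
print(results)
```

In this sketch an exact answer also counts toward partial accuracy, which is one possible reading of "within 10% of the correct value".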

## Section 7: Stage 2 - GRPO Training

### What is GRPO?

GRPO (Group Relative Policy Optimization) is a reinforcement learning algorithm that:
1. Generates multiple responses for each prompt
2. Scores each response using reward functions
3. Updates the model to favor higher-scoring responses

Key advantages over PPO:
- No need for a separate value/critic model (saves memory!)
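- Estimates each response's advantage relative to the other responses sampled for the same prompt
- Often trains stably on reasoning tasks with verifiable rewards, such as math problems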
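
To make the group-based advantage concrete, the sketch below normalizes the rewards of one group of sampled responses by the group mean and standard deviation, as described in the GRPO paper. The function name, the use of the population standard deviation, and the `eps` value are assumptions for illustration. Tunix's internal implementation may differ.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-6):
    """Normalize one group's rewards to zero mean and roughly unit scale."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption
    return [(r - mean) / (std + eps) for r in rewards]

# Hypothetical total rewards for G = 4 responses to the same prompt.
group_rewards = [6.5, 2.0, -1.5, 2.0]
advantages = group_relative_advantages(group_rewards)
print([round(a, 3) for a in advantages])

# A group in which every response scores the same carries no learning signal.
print(group_relative_advantages([1.0, 1.0, 1.0, 1.0]))
```

Responses that beat their group's average get a positive advantage and are reinforced. Those below it get a negative advantage. No separate value model is needed because the group mean plays the role of the baseline.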


In [20]:
# ============================================================================
# CELL 15: Define Reward Functions for GRPO
# ============================================================================
# These functions evaluate model outputs and provide learning signals.

def match_format_exactly(prompts, completions, **kwargs):
    """
    Reward function 1: Exact format matching

    Awards 3 points if the output exactly matches the expected format:
    <reasoning>...</reasoning><answer>...</answer>
    """
    rewards = []
    for response in completions:
        match = match_format.search(response)
        rewards.append(3.0 if match else 0.0)
    return rewards


def match_format_approximately(prompts, completions, **kwargs):
    """
    Reward function 2: Partial format matching

    Awards partial credit for having some of the required elements.
    This helps the model learn incrementally.
    """
    scores = []
    for completion in completions:
        score = 0
        # Check for each required element
        score += 0.5 if completion.count(reasoning_start) == 1 else -0.5
        score += 0.5 if completion.find(reasoning_start) == 0 else -0.5
        score += 0.5 if completion.count(reasoning_end) == 1 else -0.5
        score += 0.5 if completion.count(solution_start) == 1 else -0.5
        score += 0.5 if completion.count(solution_end) == 1 else -0.5
        scores.append(score)
    return scores


def check_answer(prompts, completions, answer, **kwargs):
    """
    Reward function 3: Answer correctness

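    Awards points based on how close the predicted answer is to the truth:
    - 3.0 points for an exact string match
    - 1.5 points for a match after stripping whitespace
    - 0.5 points if within 10% of the correct value
    - 0.25 points if within 20% of the correct value
    - -1.0 points for numeric answers that are further off
    - -0.5 points if the answer cannot be compared numerically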
    """
    extracted_responses = [
        guess.group(1) if r is not None and (guess := match_format.search(r)) else None
        for r in completions
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue

        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            try:
                ratio = float(guess) / float(true_answer)
                if 0.9 <= ratio <= 1.1:
                    score += 0.5
                elif 0.8 <= ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0
            except (ValueError, ZeroDivisionError):
                score -= 0.5
        scores.append(score)
    return scores


def check_numbers(prompts, completions, answer, **kwargs):
    """
    Reward function 4: Numerical answer extraction

    Sometimes the answer tag contains extra text. This function
    extracts numerical values and compares them.
    """
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) else None
        for r in completions
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            true_val = float(true_answer.strip())
            pred_val = float(guess.strip())
            scores.append(1.5 if pred_val == true_val else 0.0)
        except ValueError:
            scores.append(0)
    return scores


print("Reward functions defined:")
print("  1. match_format_exactly: Rewards correct output structure")
print("  2. match_format_approximately: Partial credit for format")
print("  3. check_answer: Rewards correct answers")
print("  4. check_numbers: Extracts and compares numerical answers")


Reward functions defined:
  1. match_format_exactly: Rewards correct output structure
  2. match_format_approximately: Partial credit for format
  3. check_answer: Rewards correct answers
  4. check_numbers: Extracts and compares numerical answers
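
As a worked example of how the two format rewards behave, the sketch below scores one well-formed and one free-text completion. The tag strings and the `match_format` regex are assumptions standing in for the definitions in an earlier cell, and `format_exact` / `format_approx` are simplified single-completion versions of the functions above.

```python
import re

# Assumed tag strings and regex; the notebook defines its own in an earlier cell.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"
match_format = re.compile(
    rf"^\s*{reasoning_start}.+?{reasoning_end}\s*"
    rf"{solution_start}(.+?){solution_end}\s*$",
    re.DOTALL,
)

def format_exact(completion):
    """3.0 if the whole completion matches the expected structure, else 0.0."""
    return 3.0 if match_format.search(completion) else 0.0

def format_approx(completion):
    """+0.5 or -0.5 for each of the five structural checks."""
    score = 0.0
    score += 0.5 if completion.count(reasoning_start) == 1 else -0.5
    score += 0.5 if completion.find(reasoning_start) == 0 else -0.5
    score += 0.5 if completion.count(reasoning_end) == 1 else -0.5
    score += 0.5 if completion.count(solution_start) == 1 else -0.5
    score += 0.5 if completion.count(solution_end) == 1 else -0.5
    return score

good = "<reasoning>Count the coins, then convert to dollars.</reasoning><answer>3</answer>"
bad = "The answer is 3."

for name, text in [("good", good), ("bad", bad)]:
    print(name, format_exact(text), format_approx(text))
```

The well-formed completion earns 3.0 and 2.5 from the two format functions, while the free-text one earns 0.0 and -2.5. If the trainer sums the per-function rewards, a fully correct completion would total 3.0 + 2.5 + 3.0 + 1.5 = 10.0 across the four functions defined above.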


In [21]:
# ============================================================================
# CELL 16: Prepare GRPO Dataset
# ============================================================================

print("Loading GRPO training data...")

# Load GSM8K dataset for GRPO training
NUM_GRPO_BATCHES = 3738  # Full dataset
if DEBUG_MODE:
    NUM_GRPO_BATCHES = 10  # Reduced for fast debugging (use only 10 batches)
NUM_TEST_BATCHES = 64    # For evaluation
if DEBUG_MODE:
    NUM_TEST_BATCHES = 5  # Reduced for fast debugging

grpo_dataset = get_grpo_dataset(TRAIN_DATA_DIR, "train", "kaggle")
grpo_dataset = grpo_dataset.batch(GRPO_BATCH_SIZE)[:NUM_GRPO_BATCHES]

# Split into train/validation
if TRAIN_FRACTION < 1.0:
    split_idx = int(len(grpo_dataset) * TRAIN_FRACTION)
    train_dataset = grpo_dataset[:split_idx]
    val_dataset = grpo_dataset[split_idx:]
else:
    train_dataset = grpo_dataset
    val_dataset = None

# Load test dataset
test_dataset = get_grpo_dataset(TEST_DATA_DIR, "test", "kaggle")
test_dataset = test_dataset.batch(GRPO_BATCH_SIZE)[:NUM_TEST_BATCHES]

print(f"\nDataset sizes:")
print(f"  Training batches: {len(train_dataset)}")
print(f"  Validation batches: {len(val_dataset) if val_dataset else 0}")
print(f"  Test batches: {len(test_dataset)}")

# Preview a sample
print("\nSample GRPO training example:")
for batch in train_dataset[:1]:
    pprint(batch)


Loading GRPO training data...
Using Colab cache for faster access to the 'grade-school-math-8k-q-a' dataset.
Copied main_test.csv
Copied main_train.csv
Copied socratic_train.csv
Copied socratic_test.csv
Using Colab cache for faster access to the 'grade-school-math-8k-q-a' dataset.
Copied main_test.csv
Copied main_train.csv
Copied socratic_train.csv
Copied socratic_test.csv

Dataset sizes:
  Training batches: 3364
  Validation batches: 374
  Test batches: 64

Sample GRPO training example:
{'answer': array(['3'], dtype='<U1'),
 'prompts': array(['<start_of_turn>user\nYou are given a problem. First, think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer between <answer> and </answer>.\n\nMaria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?<end_of_turn>\n<start_of_turn>model\n'],
      dtype='<U392'),
 'question': array(['Maria
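
The batch counts reported above follow directly from the split arithmetic. The sketch below reproduces them with a plain list standing in for the batched dataset. `TRAIN_FRACTION = 0.9` is an assumption that is consistent with the 3364/374 split shown in the output; its actual value is set in an earlier cell.

```python
def split_batches(batches, train_fraction):
    """Split a sequence of batches into train and validation parts."""
    split_idx = int(len(batches) * train_fraction)
    return batches[:split_idx], batches[split_idx:]

# Stand-in for the 3738 batches of the full dataset.
batches = list(range(3738))
train, val = split_batches(batches, 0.9)  # assumed TRAIN_FRACTION
print(len(train), len(val))
```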

In [22]:
# ============================================================================
# CELL 17: Configure and Run GRPO Training
# ============================================================================

print("=" * 60)
print("STAGE 2: GRPO REINFORCEMENT LEARNING")
print("=" * 60)

# Calculate training steps
GRPO_WARMUP_STEPS = int(GRPO_WARMUP_RATIO * GRPO_MAX_STEPS)
EVAL_EVERY_N_STEPS = 64
if DEBUG_MODE:
    EVAL_EVERY_N_STEPS = 5  # More frequent evaluation for debugging

# Configure GRPO optimizer (lower learning rate than SFT!)
grpo_optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=GRPO_LEARNING_RATE,
        warmup_steps=GRPO_WARMUP_STEPS,
        decay_steps=GRPO_MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Add gradient clipping (critical for RL stability!)
grpo_optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
    grpo_optimizer,
)

# Convert GRPO checkpoint directory to absolute path (required by Orbax)
GRPO_CKPT_DIR_ABS = os.path.abspath(GRPO_CKPT_DIR)
os.makedirs(GRPO_CKPT_DIR_ABS, exist_ok=True)

# Checkpoint configuration
grpo_ckpt_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# Metrics logging
grpo_metrics_options = metrics_logger.MetricsLoggerOptions(
    log_dir="./logs/grpo/",
    flush_every_n_steps=20,
)

print(f"\nGRPO Configuration:")
print(f"  Learning rate: {GRPO_LEARNING_RATE}")
print(f"  Generations per prompt (G): {NUM_GENERATIONS}")
print(f"  KL penalty (beta): {BETA}")
print(f"  Clipping (epsilon): {EPSILON}")
print(f"  Max steps: {GRPO_MAX_STEPS}")

# Training configuration
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=grpo_optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=GRPO_MAX_STEPS,
        mini_batch_size=GRPO_BATCH_SIZE,
        train_micro_batch_size=GRPO_BATCH_SIZE,
        metrics_logging_options=grpo_metrics_options,
        checkpoint_root_directory=GRPO_CKPT_DIR_ABS,
        checkpointing_options=grpo_ckpt_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=EOS_TOKENS,
    ),
)

# GRPO algorithm configuration
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

print("\nInitializing RL Cluster...")
# Create RL cluster with policy and reference models
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,      # Model being trained (with LoRA)
    reference=gemma3,        # Fixed reference model (for KL divergence)
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

print("Creating GRPO Trainer...")
# Create GRPO trainer with reward functions
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    algo_config=grpo_config,
)

print("GRPO Trainer initialized successfully!")


INFO:root:WandbBackend skipped: 'wandb' library not installed.


STAGE 2: GRPO REINFORCEMENT LEARNING

GRPO Configuration:
  Learning rate: 3e-06
  Generations per prompt (G): 4
  KL penalty (beta): 0.001
  Clipping (epsilon): 10.0
  Max steps: 2500

Initializing RL Cluster...


INFO:absl:save_device_host_concurrent_bytes=None
INFO:absl:Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7ff989771d00>, enable_pinned_host_transfer=False, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
INFO:absl:[process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers={'mo

Creating GRPO Trainer...


INFO:absl:RLLearner init - Pathways not available. Using default HBM stats collector
INFO:absl:Using 4.8 GiB / 31.2 GiB (0.15328547848396637) on TPU_0(process=0,(0,0,0,0))


GRPO Trainer initialized successfully!
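
The learning-rate schedule passed to `optax.adamw` above warms up linearly and then follows a cosine decay. The pure-Python sketch below approximates that shape so the values can be inspected without running training. The function `warmup_cosine_lr` is hypothetical, and the warmup ratio of 0.1 is an assumption; the peak learning rate and step count come from the configuration output above.

```python
import math

def warmup_cosine_lr(step, peak_value, warmup_steps, decay_steps,
                     init_value=0.0, end_value=0.0):
    """Linear warmup to peak_value, then cosine decay to end_value at decay_steps."""
    if step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    # Fraction of the decay phase completed, capped at 1.0.
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine

PEAK_LR, MAX_STEPS = 3e-6, 2500
WARMUP_RATIO = 0.1  # assumed value of GRPO_WARMUP_RATIO
warmup = int(WARMUP_RATIO * MAX_STEPS)

for step in (0, warmup // 2, warmup, MAX_STEPS // 2, MAX_STEPS):
    lr = warmup_cosine_lr(step, PEAK_LR, warmup, MAX_STEPS)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

Here `decay_steps` is treated as the total schedule length including warmup, so the rate reaches `end_value` at the final training step.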


In [None]:
# ============================================================================
# CELL 18: Run GRPO Training
# ============================================================================
# This is the main training loop. It may take several hours to complete.
# Checkpoints are saved automatically during training.

print("\n" + "=" * 60)
print("STARTING GRPO TRAINING")
print("=" * 60)
print(f"\nThis will take approximately 6-8 hours on TPU v6e-1")
print(f"Checkpoints saved every {SAVE_INTERVAL_STEPS} steps to {GRPO_CKPT_DIR_ABS}")
print("\nMonitor progress in TensorBoard.")
print("Key metrics to watch:")
print("  - reward_accuracy: Should increase over time")
print("  - reward_format: Should stay high (near 1.0)")
print("  - completion_length: May increase as model reasons more")
print("=" * 60 + "\n")

# Run GRPO training
grpo_trainer.train(train_dataset, val_dataset)

print("\n" + "=" * 60)
print("GRPO TRAINING COMPLETE!")
print("=" * 60)


INFO:absl:Training with full_batch_size=1, mini_batch_size=1, train_micro_batch_size=1, self._rollout_micro_batch_size=1, self._compute_logps_micro_batch_size=1, grad_acc_steps=1



STARTING GRPO TRAINING

This will take approximately 6-8 hours on TPU v6e-1
Checkpoints saved every 500 steps to /content/checkpoints/grpo

Monitor progress in TensorBoard.
Key metrics to watch:
  - reward_accuracy: Should increase over time
  - reward_format: Should stay high (near 1.0)
  - completion_length: May increase as model reasons more



In [None]:
# ============================================================================
# CELL 18.5: Evaluate Model After GRPO Training
# ============================================================================
# This cell evaluates the model's performance AFTER GRPO training.
# Shows the final improvements from the complete training pipeline.

print("=" * 60)
print("EVALUATING MODEL AFTER GRPO TRAINING")
print("=" * 60)
print("\nRunning evaluation on test dataset...")
print("This will show the final improvements from the complete training pipeline.\n")

# Create sampler for post-GRPO inference
post_grpo_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Run evaluation on test dataset
post_grpo_results = evaluate(test_dataset, post_grpo_sampler, **GENERATION_CONFIGS["greedy"])

print("\n" + "=" * 60)
print("POST-GRPO EVALUATION RESULTS")
print("=" * 60)
print(f"  Answer Accuracy:    {post_grpo_results['accuracy']:.2f}%")
print(f"  Partial Accuracy:   {post_grpo_results['partial_accuracy']:.2f}%")
print(f"  Format Accuracy:    {post_grpo_results['format_accuracy']:.2f}%")
print(f"  Correct/Total:      {post_grpo_results['correct']}/{post_grpo_results['total']}")
print("=" * 60)

print("\n" + "=" * 60)
print("COMPLETE TRAINING PIPELINE COMPARISON")
print("=" * 60)
print(f"{'Metric':<20} {'Before SFT':<15} {'After SFT':<15} {'After GRPO':<15}")
print("-" * 65)

print(f"{'Answer Accuracy':<20} {pre_sft_results['accuracy']:>6.2f}%      {post_sft_results['accuracy']:>6.2f}%      {post_grpo_results['accuracy']:>6.2f}%")
print(f"{'Partial Accuracy':<20} {pre_sft_results['partial_accuracy']:>6.2f}%      {post_sft_results['partial_accuracy']:>6.2f}%      {post_grpo_results['partial_accuracy']:>6.2f}%")
print(f"{'Format Accuracy':<20} {pre_sft_results['format_accuracy']:>6.2f}%      {post_sft_results['format_accuracy']:>6.2f}%      {post_grpo_results['format_accuracy']:>6.2f}%")
print("=" * 65)

# Calculate total improvements
total_acc_improvement = post_grpo_results['accuracy'] - pre_sft_results['accuracy']
total_format_improvement = post_grpo_results['format_accuracy'] - pre_sft_results['format_accuracy']

print("\nTotal Improvements from Complete Pipeline:")
print(f"  Answer Accuracy: {total_acc_improvement:+.2f}% improvement")
print(f"  Format Accuracy: {total_format_improvement:+.2f}% improvement")
print("=" * 60 + "\n")


## Section 8: Export Final Model

Export the trained model in a format compatible with Tunix on Kaggle. The exported model:
- Merges LoRA weights with the base model
- Saves in safetensors format
- Includes all necessary config files
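
Merging LoRA weights means folding the low-rank update into the frozen weight matrix so that no adapter is needed at inference time. The sketch below shows the standard LoRA merge rule, W' = W + (alpha / rank) * A @ B, on tiny matrices. The helper names, the matrix shapes, and the `alpha / rank` scaling convention are assumptions for illustration, and the export function used in the next cell may organize this differently.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]

def merge_lora(base_w, lora_a, lora_b, rank, alpha):
    """Return base_w + (alpha / rank) * lora_a @ lora_b."""
    scale = alpha / rank
    delta = matmul(lora_a, lora_b)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(base_w, delta)
    ]

base_w = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 weight
lora_a = [[1.0], [2.0]]            # 2x1: in_features x rank
lora_b = [[0.5, -0.5]]             # 1x2: rank x out_features
merged = merge_lora(base_w, lora_a, lora_b, rank=1, alpha=2.0)
print(merged)  # [[2.0, -1.0], [2.0, -1.0]]
```

After merging, a single matrix multiply with `merged` gives the same result as applying the base weight and the scaled adapter separately.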


In [None]:
# ============================================================================
# CELL 19: Export Merged Model (HuggingFace Format)
# ============================================================================

output_dir = FINAL_MODEL_DIR
if USE_COLAB:
    output_dir = f"/tmp/content/{MODEL_ID.replace('/', '-')}-trained"

# Clean up existing output directory
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

print("=" * 60)
print("EXPORTING TRAINED MODEL")
print("=" * 60)
print(f"\nOutput directory: {output_dir}")

# Merge LoRA weights and save
print("\nMerging LoRA weights with base model...")
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=output_dir,
    lora_model=lora_policy,
    rank=RANK,
    alpha=ALPHA,
)

print("\n" + "=" * 60)
print("MODEL EXPORTED SUCCESSFULLY!")
print("=" * 60)

# List saved files
print("\nSaved files:")
for f in sorted(os.listdir(output_dir)):
    size = os.path.getsize(os.path.join(output_dir, f)) / (1024 * 1024)
    print(f"  {f:<35} {size:>10.2f} MB")


In [None]:
# ============================================================================
# CELL 20: Download Model (For Colab Users)
# ============================================================================

if USE_COLAB:
    from google.colab import files

    # Create zip archive of the model
    zip_path = f"{output_dir}.zip"
    print(f"Creating zip archive: {zip_path}")
    shutil.make_archive(output_dir, 'zip', output_dir)

    print("Downloading model... (this may take a moment)")
    files.download(zip_path)
    print("Download complete!")
else:
    print(f"Model saved to: {output_dir}")
    print("\nTo use this model, load it with Tunix:")
    print(f"  model = params_safetensors_lib.create_model_from_safe_tensors('{output_dir}', config, mesh)")


## Summary

Congratulations! You have successfully trained a reasoning model using:

1. **Cold Start SFT**: Taught the model the output format using `PeftTrainer`
   - `<reasoning>model_thinking_trace</reasoning>`
   - `<answer>model_answer</answer>`

2. **GRPO Reinforcement Learning**: Strengthened reasoning abilities
   - Used multiple reward functions
   - No separate value model needed (memory efficient!)

### Key Takeaways

- **Cold Start is Essential**: Without it, RL training can produce chaotic, badly formatted outputs
- **Learning Rate Matters**: SFT uses ~2e-4, GRPO uses ~3e-6 (much lower!)
- **Checkpoints are Crucial**: Save frequently to recover from failures
- **Format Rewards Help**: Reward proper structure, not just correct answers
- **PeftTrainer Approach**: Using Tunix's built-in trainer simplifies SFT training

### Next Steps

1. Try other SFT datasets (e.g., OpenR1-Math-220k in place of Bespoke-Stratos-17k)
2. Experiment with hyperparameters (NUM_GENERATIONS, BETA, EPSILON)
3. Train for more steps for better results
4. Evaluate on diverse domains (coding, science, creative writing)

### Resources

- [Tunix GitHub](https://github.com/google/tunix)
- [GRPO Paper](https://arxiv.org/pdf/2402.03300)
- [DeepSeek-R1 Technical Report](https://arxiv.org/abs/2501.12948)
- [Bespoke-Stratos Dataset](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)
