# Cold Start SFT + GRPO Training for Gemma 3

## A Complete Guide to Training Reasoning Models on Kaggle TPU

[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/tunix/blob/main/examples/grpo_gemma.ipynb)

---

### What This Notebook Does

This notebook implements a **two-stage training pipeline** to enhance the reasoning capabilities of the Gemma 3 1B-IT model:

1. **Stage 1: Cold Start SFT (Supervised Fine-Tuning)**
   - Teaches the model the correct output format: `<reasoning>...</reasoning><answer>...</answer>`
   - Plants a "reasoning template" into the model using high-quality Chain-of-Thought data
   - Prevents the model from producing unreadable or chaotic outputs during RL training

2. **Stage 2: GRPO (Group Relative Policy Optimization)**
   - Reinforces the model's reasoning abilities through reward-based learning
   - Uses multiple reward functions to guide the model toward correct answers and proper formatting
   - More memory-efficient than traditional PPO (no need for a separate value model)
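
The "group relative" idea can be sketched in a few lines: several completions are sampled for one prompt, each is scored, and a completion's advantage is its reward standardized against the rest of its group. This is why no separate value model is needed. The function name and the `eps` stabilizer below are illustrative assumptions, not Tunix's implementation.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-6):
    """Standardize rewards within one prompt's group of sampled completions."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    # eps (an assumed stabilizer) avoids division by zero when all rewards tie.
    return [(r - mean) / (std + eps) for r in rewards]

# Four completions for one prompt, scored by some reward function.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print(advantages)
```

Completions that beat the group average get a positive advantage and the others a negative one; if every completion receives the same reward, all advantages are zero and that prompt contributes no learning signal.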

### Why Two Stages?

The DeepSeek-R1 paper reports that a model trained with pure RL and no supervised warm-up (DeepSeek-R1-Zero) tends to:
- Mix languages randomly
- Produce unstructured, hard-to-read outputs
- Explore inefficiently due to lack of initial guidance

The **Cold Start SFT** stage solves these problems by giving the model a "template" for how to reason, before we use GRPO to strengthen that reasoning.

### Output Format

The trained model will produce outputs in this format:
```
<reasoning>model_thinking_trace</reasoning>
<answer>model_answer</answer>
```

### Hardware Requirements
- **Kaggle TPU v6e-1** (recommended) or similar TPU configuration
- Training time: ~9 hours for full pipeline


## Section 1: Environment Setup

### Understanding the Dependencies

Before we start training, we need to install several key libraries:

- **tunix**: Google's library for training Gemma models on TPU with JAX
- **qwix**: Provides LoRA (Low-Rank Adaptation) functionality for efficient fine-tuning
- **flax**: Neural network library for JAX
- **grain**: Data loading library optimized for JAX
- **transformers**: For tokenizer and model utilities

**Important**: After running the installation cell, you may need to restart the kernel for the changes to take effect.


In [None]:
# ============================================================================
# CELL 1: Install Required Packages
# ============================================================================
# This cell installs all necessary libraries for training.
# RESTART THE KERNEL AFTER THIS CELL COMPLETES (for Colab users)

import importlib.util

def check_package(name):
    """Check if a package is installed."""
    return importlib.util.find_spec(name) is not None

# Check for key packages - tunix is the most critical one
TUNIX_INSTALLED = check_package('tunix')
QWIX_INSTALLED = check_package('qwix')

if not TUNIX_INSTALLED or not QWIX_INSTALLED:
    print("Installing required packages... This may take a few minutes.")
    print("=" * 60)
    
    # Core dependencies
    %pip install -q python-dotenv
    %pip install -q kagglehub
    %pip install -q ipywidgets
    %pip install -q tensorflow
    %pip install -q tensorflow_datasets
    %pip install -q tensorboardX
    %pip install -q transformers
    %pip install -q grain
    
    # JAX and related libraries (for TPU training)
    %pip install -q git+https://github.com/jax-ml/jax
    
    # Google's training libraries (CRITICAL - these are the main packages)
    %pip install git+https://github.com/google/tunix  # Main training framework
    %pip install git+https://github.com/google/qwix   # LoRA support
    
    # Flax update (required for NNX support)
    %pip uninstall -q flax -y
    %pip install git+https://github.com/google/flax
    
    # Data and utilities
    %pip install -q huggingface_hub
    %pip install -q datasets
    %pip install -q 'numpy>2'
    
    print("\n" + "=" * 60)
    print("Installation complete!")
    print("IMPORTANT: Please RESTART the kernel before continuing.")
    print("=" * 60)
else:
    print("All required packages are already installed.")
    print(f"  - tunix: {'OK' if TUNIX_INSTALLED else 'MISSING'}")
    print(f"  - qwix:  {'OK' if QWIX_INSTALLED else 'MISSING'}")


### Authentication Setup

This cell handles authentication for:
- **Hugging Face**: To download the Gemma model
- **Kaggle**: For dataset access
- **Weights & Biases (optional)**: For experiment tracking

You need to set up your API keys as environment variables or secrets before running this cell.
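
Before running the login cell, it can help to check which credentials are actually present without printing their values. The sketch below assumes the variable names used in this notebook (`HF_TOKEN`, `KAGGLE_USERNAME`, `KAGGLE_KEY`, `WANDB_API_KEY`); the helper names are hypothetical.

```python
import os

REQUIRED_KEYS = ("HF_TOKEN", "KAGGLE_USERNAME", "KAGGLE_KEY")
OPTIONAL_KEYS = ("WANDB_API_KEY",)

def credential_status(env=None):
    """Return {name: bool} showing which credentials are set, never their values."""
    env = os.environ if env is None else env
    return {key: bool(env.get(key)) for key in REQUIRED_KEYS + OPTIONAL_KEYS}

def missing_required(env=None):
    """List the required credentials that are unset or empty."""
    status = credential_status(env)
    return [key for key in REQUIRED_KEYS if not status[key]]

# Example with a fake environment mapping instead of the real os.environ.
print(missing_required({"HF_TOKEN": "hf_xxx", "KAGGLE_USERNAME": "me"}))
```

Calling `missing_required()` with no argument checks the real environment.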


In [None]:
# ============================================================================
# CELL 2: Authentication and Service Login
# ============================================================================

import os
import kagglehub

# Detect if running in Colab or Kaggle environment
try:
    from google.colab import userdata
    USE_COLAB = True
    
    # WandB has issues with Colab, so we disable it
    %pip uninstall -q wandb -y
    
    # Load credentials from Colab secrets
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
    os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
    os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
    print("Running in Google Colab environment")
    
except ImportError:
    USE_COLAB = False
    
    # Try to load from .env file
    from dotenv import load_dotenv
    load_dotenv()
    print("Using environment variables for authentication")
    
    # Apply nest_asyncio for Jupyter compatibility
    import nest_asyncio
    nest_asyncio.apply()
    print("nest_asyncio applied for async compatibility")
    
    # Setup WandB for TPU VM (works better outside Colab)
    %pip install -q wandb
    import wandb
    if "WANDB_API_KEY" in os.environ and os.environ["WANDB_API_KEY"]:
        wandb.login(key=os.environ["WANDB_API_KEY"])
        print("Weights & Biases login successful")
    else:
        print("WANDB_API_KEY not found. Skipping W&B login.")

# Kaggle authentication
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    print("Kaggle credentials not found. Please login manually:")
    kagglehub.login()

# Hugging Face authentication
if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
    hf_token = os.environ["HF_TOKEN"]
    !huggingface-cli login --token "$hf_token"
else:
    print("HF_TOKEN not found. Please set it for model download.")


### Import Libraries

Now we import all the necessary libraries. Here's what each major import does:

- **jax, jnp**: Core library for accelerated numerical computing on TPU/GPU
- **flax.nnx**: Neural network library with the new NNX API
- **tunix**: Google's library for Gemma model training
- **qwix**: LoRA (Low-Rank Adaptation) for efficient fine-tuning
- **grain**: Efficient data loading for JAX
- **optax**: Gradient transformation and optimization library


In [None]:
# ============================================================================
# CELL 3: Import All Required Libraries
# ============================================================================

import functools
from pprint import pprint
import re
import sys
import csv
import json
import shutil

# JAX ecosystem
from flax import nnx
import grain
import humanize
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import kagglehub
import numpy as np
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from datasets import load_dataset

# Tunix - Google's Gemma training library
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.models.gemma3 import params as gemma_params
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger

# Note: tunix doesn't have a built-in SFTTrainer class
# We will implement cold start SFT using a custom training loop

print("All imports successful!")
print(f"JAX version: {jax.__version__}")
print(f"Number of devices available: {len(jax.devices())}")
print(f"Device type: {jax.devices()[0].device_kind}")


## Section 2: Hyperparameter Configuration

### Understanding the Key Hyperparameters

This section defines all the configuration parameters for both SFT and GRPO training. Let's break down the most important ones:

**Model Configuration:**
- `MODEL_ID`: The base Gemma model to fine-tune
- `RANK` and `ALPHA`: LoRA parameters that control the size of trainable adapters

**Training Configuration:**
- `MAX_SEQ_LENGTH`: Maximum sequence length for training (longer = more memory)
- `LEARNING_RATE`: Step size for gradient updates (too high = unstable, too low = slow)

**GRPO-Specific Parameters:**
- `NUM_GENERATIONS`: How many responses to generate per prompt for comparison (the "G" in GRPO)
- `BETA`: KL divergence penalty coefficient (keeps the policy close to reference)
- `EPSILON`: Clipping parameter for stable updates (similar to PPO)
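
To make `BETA` and `EPSILON` concrete, here is a minimal per-token sketch of the GRPO objective as it is commonly written: a PPO-style clipped surrogate plus a KL penalty toward the reference policy. The function name and argument layout are illustrative assumptions; this is not the code inside Tunix's `GRPOLearner`.

```python
import math

def grpo_token_loss(logp_new, logp_old, logp_ref, advantage, *, epsilon, beta):
    """Per-token GRPO loss: clipped policy-gradient term plus a KL penalty."""
    # Probability ratio between the current policy and the sampling policy.
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = max(1.0 - epsilon, min(1.0 + epsilon, ratio))
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # Non-negative KL estimate between the policy and the reference model.
    log_ref_ratio = logp_ref - logp_new
    kl = math.exp(log_ref_ratio) - log_ref_ratio - 1.0
    return -surrogate + beta * kl

# Same token, two clip settings: the tight clip caps the ratio at 1.2.
tight = grpo_token_loss(-1.0, -1.5, -1.0, 1.0, epsilon=0.2, beta=0.001)
loose = grpo_token_loss(-1.0, -1.5, -1.0, 1.0, epsilon=10.0, beta=0.001)
print(tight, loose)
```

With a large `epsilon` such as 10, the lower bound `1 - epsilon` is negative and can never bind because the ratio is always positive, so only the upper bound limits the update.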


In [None]:
# ============================================================================
# CELL 4: Hyperparameter Configuration
# ============================================================================
# All training parameters are defined here for easy modification.
# Adjust these based on your hardware and training requirements.

# ===========================================================================
# MODEL CONFIGURATION
# ===========================================================================
MODEL_ID = "google/gemma-3-1b-it"  # Base model to fine-tune
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# ===========================================================================
# DATA PATHS
# ===========================================================================
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
SFT_DATA_DIR = "./data/sft"  # For cold start SFT data
TRAIN_FRACTION = 0.9  # Fraction of data used for training (rest for validation)

# ===========================================================================
# LoRA CONFIGURATION
# ===========================================================================
# LoRA (Low-Rank Adaptation) allows us to fine-tune large models efficiently
# by only training a small number of additional parameters.
RANK = 64       # Rank of the LoRA matrices (higher = more trainable parameters and capacity, but more memory)
ALPHA = 64.0    # Scaling factor for LoRA (typically set equal to RANK)

# ===========================================================================
# SHARDING CONFIGURATION (for TPU)
# ===========================================================================
# Configure the mesh for distributed training across TPU cores
NUM_TPUS = len(jax.devices())
print(f"Detected {NUM_TPUS} TPU cores")

if NUM_TPUS == 8:
    MESH_COUNTS = (1, 4)  # For v3-8 TPU
elif NUM_TPUS == 4:
    MESH_COUNTS = (1, 4)  # For v4-8 TPU
elif NUM_TPUS == 1:
    MESH_COUNTS = (1, 1)  # Single device
else:
    # Default configuration for other setups
    MESH_COUNTS = (1, NUM_TPUS)
    
MESH = [MESH_COUNTS, ("fsdp", "tp")]

# ===========================================================================
# SEQUENCE LENGTH CONFIGURATION
# ===========================================================================
MAX_PROMPT_LENGTH = 256           # Maximum length of input prompts
TOTAL_GENERATION_STEPS = 768      # Maximum tokens to generate
MAX_SEQ_LENGTH = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS

# ===========================================================================
# COLD START SFT CONFIGURATION
# ===========================================================================
SFT_LEARNING_RATE = 2e-4          # Learning rate for SFT (higher than GRPO)
SFT_BATCH_SIZE = 2                # Batch size for SFT
SFT_GRADIENT_ACCUMULATION = 4     # Simulate larger batch via accumulation
SFT_MAX_STEPS = 500               # Number of SFT training steps
SFT_WARMUP_STEPS = 50             # Linear warmup steps
SFT_SAVE_STEPS = 100              # Save checkpoint every N steps

# ===========================================================================
# GRPO CONFIGURATION (Based on DeepSeek-R1 Paper)
# ===========================================================================
# Reference: DeepSeek-R1 uses EPSILON=10, BETA=0.001, NUM_GENERATIONS=16
# We adjust values for single TPU memory constraints while following their principles.

# Generation parameters during GRPO training
TEMPERATURE = 1.0     # DeepSeek-R1: 1.0 for RL Stage 1 (high for exploration)
TOP_P = 1.0           # Nucleus sampling parameter
TOP_K = 50            # Top-k sampling parameter

# GRPO algorithm parameters (adjusted from DeepSeek-R1)
NUM_GENERATIONS = 4   # DeepSeek-R1 uses 16, reduced for memory
NUM_ITERATIONS = 1    # Number of iterations per batch
BETA = 0.001          # KL coefficient (DeepSeek-R1: 0.001, much lower than typical PPO!)
EPSILON = 10.0        # Clip ratio (DeepSeek-R1: 10, NOT 0.2 like PPO!)

# GRPO training parameters
GRPO_LEARNING_RATE = 3e-6         # DeepSeek-R1: 3e-6 ✓
GRPO_BATCH_SIZE = 1               # Keep small due to memory constraints
GRPO_GRADIENT_ACCUMULATION = 4    # Effective batch: 4 questions * 4 gen = 16 responses
GRPO_MAX_STEPS = 2500             # Number of GRPO training steps
GRPO_WARMUP_RATIO = 0.1           # Warmup as fraction of total steps
REFERENCE_UPDATE_STEPS = 400      # DeepSeek-R1: Update reference model every 400 steps

# ===========================================================================
# OPTIMIZER CONFIGURATION
# ===========================================================================
B1 = 0.9              # Adam beta1
B2 = 0.99             # Adam beta2  
WEIGHT_DECAY = 0.1    # Weight decay for regularization
MAX_GRAD_NORM = 0.1   # Gradient clipping (important for stability)

# ===========================================================================
# CHECKPOINT CONFIGURATION
# ===========================================================================
SFT_CKPT_DIR = "./checkpoints/sft/"           # SFT checkpoint directory
GRPO_CKPT_DIR = "./checkpoints/grpo/"         # GRPO checkpoint directory
FINAL_MODEL_DIR = "./checkpoints/final/"      # Final merged model
SAVE_INTERVAL_STEPS = 500                     # Save every N steps
MAX_TO_KEEP = 4                               # Maximum checkpoints to keep

# ===========================================================================
# INFERENCE CONFIGURATION
# ===========================================================================
GENERATION_CONFIGS = {
    "greedy": {"temperature": None, "top_k": 1, "top_p": None},      # Deterministic
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},    # Balanced
    "creative": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},  # More random
}

print("\nConfiguration loaded successfully!")
print(f"Model: {MODEL_ID}")
print(f"LoRA Rank: {RANK}, Alpha: {ALPHA}")
print(f"SFT Steps: {SFT_MAX_STEPS}, GRPO Steps: {GRPO_MAX_STEPS}")


## Section 3: Utility Functions and Special Tokens

### Output Format Definition

We define a specific output format that the model must learn to use. This format separates:
1. **Reasoning**: The model's thought process (between `<reasoning>` and `</reasoning>` tags)
2. **Answer**: The final answer (between `<answer>` and `</answer>` tags)

This structured format makes it easy to:
- Extract and evaluate the model's reasoning
- Verify if the final answer is correct
- Debug the model's thinking process
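
As a rough illustration of how this structure can be checked and scored automatically, the sketch below uses a simplified pattern and two hypothetical reward helpers. The score values and function names are assumptions for illustration, and the pattern is stricter about what may sit between the two tag pairs than the one defined in the next cell.

```python
import re

# Simplified pattern: reasoning block, optional whitespace, then answer block.
FORMAT_RE = re.compile(
    r"^\s*<reasoning>.+?</reasoning>\s*<answer>(.+?)</answer>\s*$",
    flags=re.DOTALL,
)

def format_reward(completion: str) -> float:
    """1.0 if the completion follows the tag layout, else 0.0."""
    return 1.0 if FORMAT_RE.match(completion) else 0.0

def answer_reward(completion: str, expected: str) -> float:
    """1.0 if the extracted answer equals the expected string, else 0.0."""
    m = FORMAT_RE.match(completion)
    return 1.0 if m and m.group(1).strip() == expected.strip() else 0.0

good = "<reasoning>6 * 7 = 42</reasoning>\n<answer>42</answer>"
bad = "The answer is 42."
print(format_reward(good), answer_reward(good, "42"), format_reward(bad))
```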


In [None]:
# ============================================================================
# CELL 5: Special Tokens and Template Definition
# ============================================================================

# Define special tokens for structured output
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

# System prompt that instructs the model on the expected output format
SYSTEM_PROMPT = f"""You are given a problem. First, think about the problem \
and provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer between {solution_start} and {solution_end}."""

# Template for formatting prompts (Gemma chat format)
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""

# Regex pattern to validate output format
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# Regex to extract numbers from answers
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", 
    flags=re.MULTILINE | re.DOTALL
)

# Test the format matching
example = f"{reasoning_start}Let me think step by step...{reasoning_end}{solution_start}42{solution_end}"
print("Example output format:")
print(example)
print(f"\nFormat validation: {'PASS' if match_format.search(example) else 'FAIL'}")


In [None]:
# ============================================================================
# CELL 6: Utility Functions
# ============================================================================

def show_hbm_usage():
    """Display memory usage per device (useful for debugging OOM issues)."""
    fmt_size = functools.partial(humanize.naturalsize, binary=True)
    
    for d in jax.local_devices():
        stats = d.memory_stats()
        used = stats["bytes_in_use"]
        limit = stats["bytes_limit"]
        print(f"Device {d}: {fmt_size(used)} / {fmt_size(limit)} ({used/limit:.1%})")


def extract_hash_answer(text: str) -> str | None:
    """Extract answer from GSM8K format (after #### marker); None if the marker is absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def create_directories():
    """Create all necessary directories for checkpoints and data."""
    dirs = [TRAIN_DATA_DIR, TEST_DATA_DIR, SFT_DATA_DIR, 
            SFT_CKPT_DIR, GRPO_CKPT_DIR, FINAL_MODEL_DIR]
    for d in dirs:
        os.makedirs(d, exist_ok=True)
    print("All directories created successfully!")

create_directories()


## Section 4: Model Loading

### Loading the Base Gemma Model

We download the Gemma 3 1B-IT model from Hugging Face and load it using Tunix. The model is:
- **Gemma 3 1B-IT**: A 1 billion parameter instruction-tuned model from Google
- Loaded with safetensors format for efficient memory usage
- Sharded across TPU cores using the mesh configuration

**Note**: Because the weights are downloaded from Hugging Face, you need to have accepted the Gemma license on Hugging Face and set `HF_TOKEN` before running the download cell.
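
The download cell below also reads the EOS token IDs from `generation_config.json`. In Hugging Face generation configs, `eos_token_id` may be a single integer or a list, so a defensive loader can normalize it to a list before anything appends to it. The helper name and the sample payloads are illustrative assumptions.

```python
import json

def load_eos_token_ids(config_text: str) -> list[int]:
    """Return eos_token_id from a generation_config.json payload as a list."""
    config = json.loads(config_text)
    eos = config.get("eos_token_id", [])
    if eos is None:
        return []
    # A bare integer becomes a one-element list; lists are copied as-is.
    return [eos] if isinstance(eos, int) else list(eos)

# Illustrative payloads, not read from a real checkpoint.
print(load_eos_token_ids('{"eos_token_id": [1, 106]}'))
print(load_eos_token_ids('{"eos_token_id": 1}'))
```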


In [None]:
# ============================================================================
# CELL 7: Download and Load Base Model
# ============================================================================

# Download model from Hugging Face (skip PyTorch weights)
ignore_patterns = ["*.pth"]
print(f"Downloading {MODEL_ID} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=MODEL_ID, 
    ignore_patterns=ignore_patterns
)
print(f"Model downloaded to: {local_model_path}")

# Load EOS tokens from generation config
EOS_TOKENS = []
generation_config_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(generation_config_path):
    with open(generation_config_path, "r") as f:
        generation_configs = json.load(f)
    EOS_TOKENS = generation_configs.get("eos_token_id", [])
    print(f"EOS token IDs: {EOS_TOKENS}")


In [None]:
# ============================================================================
# CELL 8: Initialize Model with TPU Mesh
# ============================================================================

# Select model configuration based on MODEL_ID
if "gemma-3-270m" in MODEL_ID:
    model_config = gemma_lib.ModelConfig.gemma3_270m()
elif "gemma-3-1b" in MODEL_ID:
    model_config = gemma_lib.ModelConfig.gemma3_1b_it()
else:
    raise ValueError(f"Unsupported model: {MODEL_ID}")

# Create TPU mesh for distributed training
mesh = jax.make_mesh(
    *MESH, 
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
)

# Load model with proper sharding
print("Loading model onto TPU mesh...")
with mesh:
    gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, 
        model_config, 
        mesh
    )

print("Base model loaded successfully!")
nnx.display(gemma3)


### LoRA (Low-Rank Adaptation) Setup

LoRA is a technique that allows us to fine-tune large models efficiently by:
1. Freezing the original model weights
2. Adding small trainable "adapter" matrices to key layers
3. Only training these adapter matrices (much fewer parameters!)

**Benefits**:
- Reduces memory usage significantly
- Faster training
- Easy to switch between different fine-tuned versions
- Original model weights remain unchanged
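
The savings are easy to quantify. In the standard LoRA formulation, a frozen weight of shape `(d_in, d_out)` gets two small trainable matrices of shapes `(d_in, rank)` and `(rank, d_out)`, and their product is scaled by `alpha / rank`. The sketch below uses an illustrative layer size, not one read from the Gemma config, and assumes the standard scaling rather than confirming what qwix does internally.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by one adapter: A is (d_in, rank), B is (rank, d_out)."""
    return rank * (d_in + d_out)

def lora_scale(alpha: float, rank: int) -> float:
    """Scaling applied to the low-rank update in the standard LoRA formulation."""
    return alpha / rank

# Illustrative layer size (an assumption, not taken from the model config).
d_in, d_out, rank, alpha = 1024, 4096, 64, 64.0
full = d_in * d_out
adapter = lora_param_count(d_in, d_out, rank)
print(f"full: {full:,}  adapter: {adapter:,}  ratio: {adapter / full:.1%}")
print(f"scale: {lora_scale(alpha, rank)}")
```

Setting `ALPHA` equal to `RANK`, as this notebook does, gives a scale of 1.0 under that formulation.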


In [None]:
# ============================================================================
# CELL 9: Create LoRA Model and Load Tokenizer
# ============================================================================

def get_lora_model(base_model, mesh):
    """
    Apply LoRA adapters to the base model.
    
    LoRA targets specific layers:
    - q_einsum, kv_einsum: Attention query/key-value projections
    - gate_proj, down_proj, up_proj: MLP layers
    - attn_vec_einsum: Attention output projection
    """
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
            ".*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )
    
    model_input = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(
        base_model, lora_provider, **model_input
    )
    
    # Shard the LoRA model across devices
    with mesh:
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)
    
    return lora_model


# Create the policy model with LoRA adapters
print("Creating LoRA policy model...")
lora_policy = get_lora_model(gemma3, mesh=mesh)
print("LoRA model created!")

# Load tokenizer
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
if tokenizer.eos_id() not in EOS_TOKENS:
    EOS_TOKENS.append(tokenizer.eos_id())
print(f"Tokenizer loaded. EOS tokens: {EOS_TOKENS}")

# Display memory usage
show_hbm_usage()


## Section 5: Data Preparation

### Understanding the Two Datasets

We use **two different datasets** for the two training stages:

---

#### Dataset 1: Bespoke-Stratos-17k (for Cold Start SFT)

**Source**: `bespokelabs/Bespoke-Stratos-17k` on HuggingFace

**Purpose**: Teach the model the reasoning format (`<reasoning>...<answer>`)

**Key Features**:
- ~17,000 high-quality Chain-of-Thought examples
- Distilled from DeepSeek-R1 model
- Contains LONG reasoning traces with self-reflection and verification
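- Marks reasoning with `<|begin_of_thought|>...<|end_of_thought|>` and the final solution with `<|begin_of_solution|>...<|end_of_solution|>` (we convert these to `<reasoning>...</reasoning>` and `<answer>...</answer>`)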
- Perfect for cold-start: teaches the model HOW to think, not just WHAT to answer

**Why this dataset?**
- Shows the model examples of step-by-step reasoning
- Teaches proper output structure
- Prevents chaotic/unreadable outputs during GRPO
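
The loader in Cell 10 maps the dataset's own delimiters onto this notebook's tags. A stripped-down version of that conversion, with a made-up sample string, looks like this (the `TAG_MAP` and `convert_tags` names are hypothetical):

```python
# Source delimiter -> target tag, mirroring the replacements done in Cell 10.
TAG_MAP = {
    "<|begin_of_thought|>": "<reasoning>",
    "<|end_of_thought|>": "</reasoning>",
    "<|begin_of_solution|>": "<answer>",
    "<|end_of_solution|>": "</answer>",
}

def convert_tags(text: str) -> str:
    """Replace each source delimiter with the matching target tag."""
    for src, dst in TAG_MAP.items():
        text = text.replace(src, dst)
    return text

# Made-up sample in the source format.
raw = (
    "<|begin_of_thought|>2 + 2 = 4<|end_of_thought|>"
    "<|begin_of_solution|>4<|end_of_solution|>"
)
print(convert_tags(raw))
```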

---

#### Dataset 2: GSM8K / grade-school-math-8k-q-a (for GRPO)

**Source**: `thedevastator/grade-school-math-8k-q-a` on Kaggle (or OpenAI's GSM8K)

**Purpose**: Strengthen reasoning through reward-based learning

**Key Features**:
- ~8,000 grade school math word problems
- Simple format: question + answer (with `####` separator)
- Answers are verifiable (correct or incorrect)
- Perfect for GRPO: clear right/wrong signal for rewards

**Why this dataset?**
- GRPO needs verifiable answers to compute rewards
- Math problems have definitive correct answers
- Model learns to reason correctly, not just format correctly
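
A short sketch of what "verifiable" means in practice: the worked solution and the final answer are split at the `####` marker, and a predicted answer is compared numerically so that `12` and `12.0` count as the same. The record below is made up in GSM8K style, and the helper names and tolerance are assumptions.

```python
def split_gsm8k_answer(raw: str):
    """Split a GSM8K-style answer into (worked_solution, final_answer)."""
    if "####" not in raw:
        return raw.strip(), None
    solution, _, final = raw.partition("####")
    return solution.strip(), final.strip()

def answers_match(predicted: str, expected: str) -> bool:
    """Compare numerically when both parse as numbers, else as stripped strings."""
    def to_number(s):
        try:
            return float(s.replace(",", "").replace("$", "").strip())
        except ValueError:
            return None
    p, e = to_number(predicted), to_number(expected)
    if p is not None and e is not None:
        return abs(p - e) < 1e-9
    return predicted.strip() == expected.strip()

# Made-up record in GSM8K style.
raw = "Tom has 3 boxes of 4 pens, so 3 * 4 = 12 pens.\n#### 12"
solution, final = split_gsm8k_answer(raw)
print(final, answers_match("12.0", final), answers_match("13", final))
```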

---

### Summary: Two-Stage Data Strategy

| Stage | Dataset | Size | Purpose |
|-------|---------|------|---------|
| **SFT Cold Start** | Bespoke-Stratos-17k | 17k | Learn reasoning FORMAT |
| **GRPO Training** | GSM8K | 8k | Strengthen reasoning ACCURACY |


In [None]:
# ============================================================================
# CELL 10: Data Loading Functions
# ============================================================================

# ----------------------------------------------------------------------------
# DATASET 1: Bespoke-Stratos-17k for Cold Start SFT
# ----------------------------------------------------------------------------

def load_bespoke_stratos_dataset(num_examples=None):
    """
    Load Bespoke-Stratos-17k dataset from HuggingFace for Cold Start SFT.
    
    This dataset contains high-quality Chain-of-Thought examples distilled
    from DeepSeek-R1. It's specifically designed for teaching models the
    reasoning format.
    
    The dataset's <|begin_of_thought|>/<|begin_of_solution|> delimiters (and any
    <think>...</think> tags) are converted to <reasoning>/<answer> tags.
    
    Args:
        num_examples: Maximum number of examples to load (None = all ~17k)
    
    Returns:
        List of dicts with 'prompt' and 'completion' keys
    """
    print("Loading Bespoke-Stratos-17k from HuggingFace...")
    
    # Load from HuggingFace
    dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k", split="train")
    
    sft_data = []
    for i, example in enumerate(tqdm(dataset, desc="Processing examples")):
        if num_examples is not None and i >= num_examples:
            break
        
        # Extract conversations
        conversations = example.get("conversations", [])
        if len(conversations) < 2:
            continue
        
        # Get user message and assistant response
        user_msg = None
        assistant_msg = None
        
        for conv in conversations:
            role = conv.get("from", conv.get("role", ""))
            content = conv.get("value", conv.get("content", ""))
            
            if role in ["human", "user"]:
                user_msg = content
            elif role in ["gpt", "assistant"]:
                assistant_msg = content
        
        if not user_msg or not assistant_msg:
            continue
        
        # Convert various thinking tags to our standard <reasoning>...</reasoning> format
        # Bespoke-Stratos uses: <|begin_of_thought|>...<|end_of_thought|>
        # Some datasets use: <think>...</think>
        completion = assistant_msg
        
        # Handle Bespoke-Stratos format
        completion = completion.replace("<|begin_of_thought|>", reasoning_start)
        completion = completion.replace("<|end_of_thought|>", reasoning_end)
        completion = completion.replace("<|begin_of_solution|>", solution_start)
        completion = completion.replace("<|end_of_solution|>", solution_end)
        
        # Handle other common formats
        completion = completion.replace("<think>", reasoning_start)
        completion = completion.replace("</think>", reasoning_end)
        
        # If still no answer tags, add them
        if solution_start not in completion:
            # Try to extract final answer after reasoning
            if reasoning_end in completion:
                parts = completion.split(reasoning_end)
                if len(parts) > 1 and parts[1].strip():
                    # Wrap the part after reasoning in answer tags
                    completion = parts[0] + reasoning_end + "\n" + solution_start + parts[1].strip() + solution_end
                else:
                    completion = completion + f"\n{solution_start}See reasoning above{solution_end}"
            else:
                completion = completion + f"\n{solution_start}See reasoning above{solution_end}"
        
        # Format prompt with our template
        prompt = TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=user_msg,
        )
        
        sft_data.append({
            "prompt": prompt,
            "completion": completion,
        })
    
    print(f"Loaded {len(sft_data)} SFT examples from Bespoke-Stratos-17k")
    return sft_data


# ----------------------------------------------------------------------------
# DATASET 2: GSM8K for GRPO Training
# ----------------------------------------------------------------------------

def download_kaggle_dataset(target_dir="./data/gsm8k"):
    """Download GSM8K dataset from Kaggle."""
    os.makedirs(target_dir, exist_ok=True)
    src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
    src = Path(src)
    dst = Path(target_dir)
    
    for csv_file in src.glob("*.csv"):
        shutil.copy2(csv_file, dst / csv_file.name)
        print(f"Copied {csv_file.name}")
    return target_dir


def get_grpo_dataset(data_dir, split="train", source="kaggle") -> grain.MapDataset:
    """
    Load and format dataset for GRPO training.
    
    Returns a dataset with:
    - prompts: Formatted input prompts
    - question: Original question text  
    - answer: Ground truth answer for verification
    """
    if source == "tfds":
        import tensorflow_datasets.text.gsm8k
        data = tfds.data_source(
            "gsm8k",
            split=split,
            data_dir=data_dir,
            builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
            download=True,
        )
    elif source == "kaggle":
        kaggle_dir = download_kaggle_dataset(data_dir)
        file_name = "main_" + split + ".csv"
        csv_path = os.path.join(kaggle_dir, file_name)
        
        data = []
        with open(csv_path, newline="", encoding="utf-8") as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                data.append({
                    "question": row["question"],
                    "answer": row["answer"],
                })
    else:
        raise ValueError(f"Unknown source: {source}")
    
    def _as_text(v):
        return v if isinstance(v, str) else v.decode("utf-8")
    
    dataset = (
        grain.MapDataset.source(data)
        .shuffle(seed=42)
        .map(lambda x: {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=_as_text(x["question"]),
            ),
            "question": _as_text(x["question"]),
            "answer": extract_hash_answer(_as_text(x["answer"])),
        })
    )
    return dataset


def create_sft_dataset(num_examples=1000):
    """
    Create SFT dataset for cold start training.
    
    This generates training examples that teach the model:
    1. The correct output format (<reasoning>...</reasoning><answer>...</answer>)
    2. Step-by-step reasoning patterns
    
    In practice, you would use a high-quality dataset like Bespoke-Stratos-17k
    or OpenR1-Mixture-of-Thoughts.
    """
    # For demonstration, we'll convert GSM8K examples into SFT format
    # In production, use a dedicated reasoning dataset
    
    print("Creating SFT dataset from GSM8K with formatted examples...")
    
    # Load base data
    kaggle_dir = download_kaggle_dataset(SFT_DATA_DIR)
    csv_path = os.path.join(kaggle_dir, "main_train.csv")
    
    sft_data = []
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)
        for i, row in enumerate(reader):
            if i >= num_examples:
                break
                
            question = row["question"]
            full_answer = row["answer"]
            
            # Extract the reasoning (everything before ####) and answer
            if "####" in full_answer:
                reasoning_part = full_answer.split("####")[0].strip()
                answer_part = full_answer.split("####")[1].strip()
            else:
                reasoning_part = full_answer
                answer_part = "N/A"
            
            # Format as training example
            formatted_output = f"{reasoning_start}\n{reasoning_part}\n{reasoning_end}\n{solution_start}{answer_part}{solution_end}"
            
            sft_data.append({
                "prompt": TEMPLATE.format(
                    system_prompt=SYSTEM_PROMPT,
                    question=question,
                ),
                "completion": formatted_output,
            })
    
    print(f"Created {len(sft_data)} SFT examples")
    return sft_data


print("Data loading functions defined!")


## Section 6: Stage 1 - Cold Start SFT Training

Cold Start SFT teaches the model the output format before GRPO strengthens reasoning.

Key goals:
- Teach the `<reasoning>` and `<answer>` tag format
- Show examples of step-by-step thinking
- Prevent chaotic outputs during later RL training
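
As a rough illustration of the format goal above, the sketch below turns a GSM8K-style answer string into a tagged target and checks it with a regular expression. The tag constants, the `FORMAT_RE` pattern, and the `to_cold_start_target` helper are assumptions made for this example; the notebook defines its own `reasoning_start`, `solution_start`, and `match_format` elsewhere, and those may differ.

```python
import re

# Assumed tag values for illustration; the notebook defines its own constants.
REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
ANSWER_START, ANSWER_END = "<answer>", "</answer>"

# Illustrative pattern (not the notebook's match_format): reasoning block,
# optional whitespace, then an answer block whose content is captured.
FORMAT_RE = re.compile(
    rf"^\s*{re.escape(REASONING_START)}.+?{re.escape(REASONING_END)}\s*"
    rf"{re.escape(ANSWER_START)}(.+?){re.escape(ANSWER_END)}\s*$",
    flags=re.DOTALL,
)


def to_cold_start_target(gsm8k_answer: str) -> str:
    """Split a GSM8K answer at '####' and wrap both parts in tags."""
    reasoning, sep, final = gsm8k_answer.partition("####")
    final = final.strip() if sep else "N/A"  # no marker: no final answer
    return (
        f"{REASONING_START}\n{reasoning.strip()}\n{REASONING_END}\n"
        f"{ANSWER_START}{final}{ANSWER_END}"
    )


example = (
    "Natalia sold 48 / 2 = 24 clips in May.\n"
    "In total she sold 48 + 24 = 72 clips.\n"
    "#### 72"
)
target = to_cold_start_target(example)
match = FORMAT_RE.search(target)
print(target)
print("extracted answer:", match.group(1))
```

A completion that passes this kind of check is what the format rewards in Stage 2 look for, which is why teaching the template first gives GRPO a useful starting point.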


In [None]:
# ============================================================================
# CELL 10.4: Reload Model from Original Checkpoint (for Loop Testing)
# ============================================================================
# This cell reloads the model from the original downloaded model (not from checkpoints).
# Use this cell to reset the model state before running training loops.
# 
# This is useful for:
# - Testing different hyperparameters
# - Comparing different training runs
# - Resetting model to baseline before each experiment

print("=" * 60)
print("RELOADING MODEL FROM ORIGINAL CHECKPOINT")
print("=" * 60)
print(f"Reloading from: {local_model_path}")
print("This will reset all LoRA parameters to initial state.\n")

# Ensure mesh is available (it should be from CELL 8)
if 'mesh' not in locals():
    print("Creating TPU mesh...")
    mesh = jax.make_mesh(
        *MESH, 
        axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
    )

# Select model configuration based on MODEL_ID
if "gemma-3-270m" in MODEL_ID:
    model_config = gemma_lib.ModelConfig.gemma3_270m()
elif "gemma-3-1b" in MODEL_ID:
    model_config = gemma_lib.ModelConfig.gemma3_1b_it()
else:
    raise ValueError(f"Unsupported model: {MODEL_ID}")

# Reload base model from original checkpoint
print("Loading base model from original checkpoint...")
with mesh:
    gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, 
        model_config, 
        mesh
    )

print("Base model reloaded successfully!")

# Recreate LoRA model (fresh LoRA adapters, all weights reset)
print("Recreating LoRA policy model...")
lora_policy = get_lora_model(gemma3, mesh=mesh)
print("LoRA model recreated with fresh adapters!")

# Reload tokenizer (if needed)
if 'tokenizer' not in locals():
    print("Loading tokenizer...")
    tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
    if tokenizer.eos_id() not in EOS_TOKENS:
        EOS_TOKENS.append(tokenizer.eos_id())
    print(f"Tokenizer loaded. EOS tokens: {EOS_TOKENS}")

print("\n" + "=" * 60)
print("MODEL RESET COMPLETE")
print("=" * 60)
print("Model has been reset to original state.")
print("All previous training has been cleared.")
print("Ready for fresh training run.")
print("=" * 60 + "\n")

# Display memory usage
show_hbm_usage()


In [None]:
# ============================================================================
# CELL 10.5: Load Evaluation Dataset and Define Evaluation Functions
# ============================================================================
# Load a small test dataset for evaluation before and after SFT training.
# Also define evaluate and generate functions here for use in comparison cells.

print("=" * 60)
print("PREPARING EVALUATION DATASET")
print("=" * 60)

# Load a small test dataset for evaluation (before and after SFT)
NUM_EVAL_SAMPLES = 32  # Small sample for quick evaluation
eval_test_dataset = get_grpo_dataset(TEST_DATA_DIR, "test", "kaggle")
eval_test_dataset = eval_test_dataset.batch(1)[:NUM_EVAL_SAMPLES]

print(f"Loaded {NUM_EVAL_SAMPLES} test samples for evaluation")
print("=" * 60)

# Define evaluation functions (same as in CELL 20, but defined here for early use)
def generate(question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None):
    """Generate response for a given question."""
    if isinstance(question, str):
        input_batch = [TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)]
    else:
        input_batch = [
            TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
            for q in question
        ]
    
    out_data = sampler(
        input_strings=input_batch,
        max_generation_steps=TOTAL_GENERATION_STEPS,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
        eos_tokens=EOS_TOKENS,
    )
    
    output = out_data.text
    return output[0] if isinstance(question, str) else output


def evaluate(dataset, sampler, temperature=0.7, top_k=50, top_p=0.95):
    """Evaluate model on dataset and compute metrics."""
    correct = 0
    partially_correct = 0
    correct_format = 0
    total = 0
    
    for batch in tqdm(dataset, desc="Evaluating"):
        answers = batch["answer"]
        questions = batch["question"]
        
        responses = generate(questions, sampler, temperature, top_k, top_p)
        
        for question, response, answer in zip(questions, responses, answers):
            # Check answer
            extracted = match_numbers.search(response)
            if extracted:
                try:
                    pred = float(extracted.group(1).strip())
                    true_val = float(answer.strip())
                    if pred == true_val:
                        correct += 1
                    ratio = pred / true_val if true_val != 0 else 0
                    if 0.9 <= ratio <= 1.1:
                        partially_correct += 1
                except (ValueError, AttributeError, TypeError):  # unparsable or missing answer
                    pass
            
            # Check format
            if match_format.search(response):
                correct_format += 1
            
            total += 1
            
            # Print progress
            if total % 10 == 0:
                print(f"Progress: {correct}/{total} correct ({correct/total*100:.1f}%)")
    
    return {
        "accuracy": correct / total * 100 if total > 0 else 0,
        "partial_accuracy": partially_correct / total * 100 if total > 0 else 0,
        "format_accuracy": correct_format / total * 100 if total > 0 else 0,
        "correct": correct,
        "total": total,
    }

print("\nEvaluation functions defined!")
print("Ready for pre-SFT and post-SFT evaluation.")


In [None]:
# ============================================================================
# CELL 11: Prepare SFT Dataset (Bespoke-Stratos-17k)
# ============================================================================
# We use Bespoke-Stratos-17k for cold start SFT - this dataset contains
# high-quality Chain-of-Thought examples that teach proper reasoning format.

print("=" * 60)
print("Loading SFT Dataset: Bespoke-Stratos-17k")
print("=" * 60)

# Load Bespoke-Stratos-17k from HuggingFace
# Set num_examples to limit for faster training (None = use all ~17k)
SFT_NUM_EXAMPLES = 2000  # Adjust based on available time

sft_data = load_bespoke_stratos_dataset(num_examples=SFT_NUM_EXAMPLES)

# Convert to grain dataset format
def format_sft_example(example):
    """Combine prompt and completion into a single training text."""
    return {"text": example["prompt"] + example["completion"]}

sft_dataset = (
    grain.MapDataset.source(sft_data)
    .shuffle(seed=42)
    .map(format_sft_example)
    .batch(SFT_BATCH_SIZE)
)

# Preview an example
print("Sample SFT training example:")
print("=" * 60)
sample = sft_data[0]
print(f"PROMPT:\n{sample['prompt'][:200]}...")
print(f"\nCOMPLETION:\n{sample['completion'][:300]}...")
print("=" * 60)


In [None]:
# ============================================================================
# CELL 11.5: Evaluate Model Before SFT Training
# ============================================================================
# This cell evaluates the model's performance BEFORE cold start SFT training.
# Uses the evaluate function to get quantitative metrics.

print("=" * 60)
print("EVALUATING MODEL BEFORE SFT TRAINING")
print("=" * 60)
print("\nRunning evaluation on test dataset...")
print("This will measure accuracy and format compliance before training.\n")

# Create a sampler for inference
pre_sft_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Run evaluation
pre_sft_results = evaluate(eval_test_dataset, pre_sft_sampler, **GENERATION_CONFIGS["greedy"])

print("\n" + "=" * 60)
print("PRE-SFT EVALUATION RESULTS")
print("=" * 60)
print(f"  Answer Accuracy:    {pre_sft_results['accuracy']:.2f}%")
print(f"  Partial Accuracy:   {pre_sft_results['partial_accuracy']:.2f}%")
print(f"  Format Accuracy:    {pre_sft_results['format_accuracy']:.2f}%")
print(f"  Correct/Total:      {pre_sft_results['correct']}/{pre_sft_results['total']}")
print("=" * 60)

print("\n" + "=" * 60)
print("READY FOR SFT TRAINING")
print("=" * 60)
print("After SFT training, we expect improvements in:")
print("  - Format Accuracy: Model should learn <reasoning>...</reasoning><answer>...</answer> format")
print("  - Answer Accuracy: Better reasoning may lead to more correct answers")
print("=" * 60 + "\n")


In [None]:
# ============================================================================
# CELL 12: Run Cold Start SFT Training
# ============================================================================
# This stage teaches the model the reasoning format using Bespoke-Stratos-17k.
# Cold Start is ESSENTIAL - it prevents chaotic outputs during GRPO training.

print("=" * 60)
print("STAGE 1: COLD START SFT TRAINING")
print("=" * 60)
print("\nUsing Bespoke-Stratos-17k dataset to teach reasoning format.")
print("This stage plants the <reasoning>...</reasoning><answer>...</answer> template.")

# Ensure we have enough steps for warmup + decay
actual_sft_warmup = min(SFT_WARMUP_STEPS, SFT_MAX_STEPS // 10)
actual_sft_decay = max(SFT_MAX_STEPS, actual_sft_warmup + 100)

print(f"Adjusted schedule: warmup={actual_sft_warmup}, total_decay={actual_sft_decay}")

# Configure SFT optimizer with warmup and cosine decay
sft_optax_optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
    optax.adamw(
        learning_rate=optax.schedules.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=SFT_LEARNING_RATE,
            warmup_steps=actual_sft_warmup,
            decay_steps=actual_sft_decay,
            end_value=0.0,
        ),
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
)

# Use NNX Optimizer which properly handles state management
# wrt=nnx.LoRAParam specifies we only optimize LoRA parameters (not full model)
sft_optimizer = nnx.Optimizer(lora_policy, sft_optax_optimizer, wrt=nnx.LoRAParam)

# Convert checkpoint directory to absolute path (required by Orbax)
SFT_CKPT_DIR_ABS = os.path.abspath(SFT_CKPT_DIR)
os.makedirs(SFT_CKPT_DIR_ABS, exist_ok=True)

# Checkpoint manager for SFT
sft_ckpt_manager = ocp.CheckpointManager(
    SFT_CKPT_DIR_ABS,
    options=ocp.CheckpointManagerOptions(
        save_interval_steps=SFT_SAVE_STEPS,
        max_to_keep=MAX_TO_KEEP,
    ),
)

print(f"\nSFT Configuration:")
print(f"  Learning rate: {SFT_LEARNING_RATE}")
print(f"  Max steps: {SFT_MAX_STEPS}")
print(f"  Batch size: {SFT_BATCH_SIZE}")
print(f"  Checkpoint dir: {SFT_CKPT_DIR_ABS}")


# Define the training step function using NNX patterns
@nnx.jit
def sft_train_step(model, optimizer, tokens, loss_mask):
    """
    Single SFT training step using NNX.
    
    This function is JIT-compiled for efficiency.
    """
    def loss_fn(model):
        # Get sequence length for position calculation
        batch_size = tokens.shape[0]
        seq_len = tokens.shape[1] - 1
        
        # Create position indices
        positions = jnp.arange(seq_len)[None, :]  # Shape: (1, seq_len)
        positions = jnp.broadcast_to(positions, (batch_size, seq_len))
        
        # Create causal attention mask (lower triangular)
        # Shape: (seq_len, seq_len) - True where attention is allowed
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
        
        # Forward pass - get logits for all positions except last
        # The model may return a tuple (e.g. logits plus cache); logits are extracted below
        output = model(
            tokens[:, :-1],
            positions,
            cache=None,
            attention_mask=causal_mask
        )
        # Extract logits from output
        # Gemma3 model returns different formats - find the logits array
        if isinstance(output, tuple):
            # Try to find the logits (should be the array with shape [batch, seq, vocab])
            for item in output:
                if item is not None and hasattr(item, 'shape') and len(item.shape) == 3:
                    logits = item
                    break
            else:
                # If not found, try first non-None element
                logits = output[0] if output[0] is not None else output[1]
        else:
            logits = output
        
        # Targets are the next tokens
        targets = tokens[:, 1:]
        
        # Cross-entropy loss
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        target_log_probs = jnp.take_along_axis(
            log_probs, targets[:, :, None], axis=-1
        ).squeeze(-1)
        
        # Apply loss mask (to ignore padding tokens)
        # Note: the mask covers prompt and completion tokens alike; only padding is excluded
        loss = -jnp.sum(target_log_probs * loss_mask) / (jnp.sum(loss_mask) + 1e-8)
        return loss
    
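    # Compute loss and gradients, then update
    # Differentiate only w.r.t. LoRA params so grads match the optimizer's wrt filter
    # Flax 0.11.0+: optimizer.update requires (model, grads)
    loss, grads = nnx.value_and_grad(
        loss_fn, argnums=nnx.DiffState(0, nnx.LoRAParam)
    )(model)
    optimizer.update(model, grads)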
    
    return loss


# Initialize loss tracking for visualization
sft_loss_history = []  # List to store (step, loss) pairs
sft_log_dir = "./logs/sft"

# Create log directory for TensorBoard
os.makedirs(sft_log_dir, exist_ok=True)

# Initialize TensorBoard SummaryWriter (if tensorboardX is available)
try:
    from tensorboardX import SummaryWriter
    tb_writer = SummaryWriter(log_dir=sft_log_dir)
    use_tensorboard = True
    print(f"TensorBoard logging enabled. Logs saved to: {sft_log_dir}")
except ImportError:
    use_tensorboard = False
    print("tensorboardX not available. Loss will only be stored in memory for plotting.")

# Training loop
print(f"\n{'='*60}")
print(f"Starting SFT training for {SFT_MAX_STEPS} steps...")
print(f"{'='*60}\n")

step = 0
total_loss = 0.0
num_epochs = max(1, (SFT_MAX_STEPS * SFT_BATCH_SIZE) // len(sft_data) + 1)

for epoch in range(num_epochs):
    print(f"\n--- Epoch {epoch + 1}/{num_epochs} ---")
    
    for batch in tqdm(sft_dataset, desc=f"Epoch {epoch+1}"):
        if step >= SFT_MAX_STEPS:
            break
        
        # Get texts from batch
        texts = batch["text"]
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()
        
        # Tokenize batch
        tokens_list = [tokenizer.encode(t) for t in texts]
        
        # Pad sequences
        max_len = min(MAX_SEQ_LENGTH, max(len(t) for t in tokens_list))
        padded_tokens = np.zeros((len(tokens_list), max_len), dtype=np.int32)
        attention_mask = np.zeros((len(tokens_list), max_len), dtype=np.float32)
        
        for i, toks in enumerate(tokens_list):
            length = min(len(toks), max_len)
            padded_tokens[i, :length] = toks[:length]
            attention_mask[i, :length] = 1.0
        
        # Convert to JAX arrays
        tokens_jax = jnp.array(padded_tokens)
        # Loss mask is for the TARGET tokens (shifted by 1 from input)
        loss_mask_jax = jnp.array(attention_mask[:, 1:])
        
        # Training step
        loss = sft_train_step(lora_policy, sft_optimizer, tokens_jax, loss_mask_jax)
        
        loss_value = float(loss)
        total_loss += loss_value
        step += 1
        
        # Record loss for visualization
        sft_loss_history.append((step, loss_value))
        
        # Log to TensorBoard
        if use_tensorboard:
            tb_writer.add_scalar('train/loss', loss_value, step)
            tb_writer.add_scalar('train/avg_loss', total_loss / step, step)
        
        # Console logging
        if step % 20 == 0:
            avg_loss = total_loss / step
            print(f"Step {step}/{SFT_MAX_STEPS} | Loss: {loss_value:.4f} | Avg Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        if step % SFT_SAVE_STEPS == 0:
            print(f"\n[Checkpoint] Saving at step {step}...")
            sft_ckpt_manager.save(
                step,
                args=ocp.args.StandardSave(nnx.state(lora_policy, nnx.LoRAParam)),
            )
            print(f"[Checkpoint] Saved to {SFT_CKPT_DIR_ABS}")
        
        if step >= SFT_MAX_STEPS:
            break
    
    if step >= SFT_MAX_STEPS:
        break

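# Save final SFT checkpoint (skip if the periodic save already wrote this step)
print(f"\n[Final Checkpoint] Saving final SFT model...")
if sft_ckpt_manager.latest_step() != step:
    sft_ckpt_manager.save(
        step,
        args=ocp.args.StandardSave(nnx.state(lora_policy, nnx.LoRAParam)),
        force=True,  # final step may not fall on a save_interval_steps boundary
    )
# Block until any asynchronous checkpoint writes have finished
sft_ckpt_manager.wait_until_finished()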

# Close TensorBoard writer
if use_tensorboard:
    tb_writer.close()
    print(f"\nTensorBoard logs saved to: {sft_log_dir}")
    print("To view: tensorboard --logdir ./logs/sft")

print("\n" + "=" * 60)
print("COLD START SFT TRAINING COMPLETE!")
print(f"Average Loss: {total_loss / max(step, 1):.4f}")
print(f"Checkpoints saved to: {SFT_CKPT_DIR_ABS}")
print(f"Loss history recorded: {len(sft_loss_history)} steps")
print("=" * 60)

show_hbm_usage()


In [None]:
# ============================================================================
# CELL 12.5: Visualize Training Results & Compare Model Output After SFT
# ============================================================================
# This cell:
# 1. Plots the loss curve from SFT training
# 2. Shows model output AFTER SFT training for comparison

print("=" * 60)
print("SFT TRAINING RESULTS VISUALIZATION")
print("=" * 60)

# 1. Plot Loss Curve
if len(sft_loss_history) > 0:
    import matplotlib.pyplot as plt
    
    steps, losses = zip(*sft_loss_history)
    
    # Create figure with subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Raw loss over steps
    ax1.plot(steps, losses, 'b-', alpha=0.6, linewidth=1, label='Step Loss')
    
    # Add moving average for smoother visualization
    window_size = max(10, len(losses) // 20)
    if len(losses) >= window_size:
        moving_avg = []
        for i in range(len(losses)):
            start_idx = max(0, i - window_size // 2)
            end_idx = min(len(losses), i + window_size // 2 + 1)
            moving_avg.append(np.mean(losses[start_idx:end_idx]))
        ax1.plot(steps, moving_avg, 'r-', linewidth=2, label=f'Moving Avg (window={window_size})')
    
    ax1.set_xlabel('Training Step')
    ax1.set_ylabel('Loss')
    ax1.set_title('SFT Training Loss Over Time')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Average loss (cumulative average)
    avg_losses = [np.mean(losses[:i+1]) for i in range(len(losses))]
    ax2.plot(steps, avg_losses, 'g-', linewidth=2, label='Cumulative Average Loss')
    ax2.set_xlabel('Training Step')
    ax2.set_ylabel('Average Loss')
    ax2.set_title('Cumulative Average Loss')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nLoss Statistics:")
    print(f"  Initial Loss: {losses[0]:.4f}")
    print(f"  Final Loss: {losses[-1]:.4f}")
    print(f"  Average Loss: {np.mean(losses):.4f}")
    print(f"  Min Loss: {np.min(losses):.4f} (at step {steps[np.argmin(losses)]})")
    print(f"  Loss Reduction: {((losses[0] - losses[-1]) / losses[0] * 100):.2f}%")
else:
    print("Warning: No loss history recorded. Skipping loss plot.")

print("\n" + "=" * 60)
print("EVALUATING MODEL AFTER SFT TRAINING")
print("=" * 60)
print("\nRunning evaluation on the same test dataset...")
print("This will show quantitative improvements after training.\n")

# Create sampler for post-SFT inference
post_sft_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Run evaluation on the same dataset
post_sft_results = evaluate(eval_test_dataset, post_sft_sampler, **GENERATION_CONFIGS["greedy"])

print("\n" + "=" * 60)
print("POST-SFT EVALUATION RESULTS")
print("=" * 60)
print(f"  Answer Accuracy:    {post_sft_results['accuracy']:.2f}%")
print(f"  Partial Accuracy:   {post_sft_results['partial_accuracy']:.2f}%")
print(f"  Format Accuracy:    {post_sft_results['format_accuracy']:.2f}%")
print(f"  Correct/Total:      {post_sft_results['correct']}/{post_sft_results['total']}")
print("=" * 60)

print("\n" + "=" * 60)
print("BEFORE vs AFTER COMPARISON")
print("=" * 60)
print(f"{'Metric':<20} {'Before SFT':<15} {'After SFT':<15} {'Change':<15}")
print("-" * 65)

acc_change = post_sft_results['accuracy'] - pre_sft_results['accuracy']
partial_change = post_sft_results['partial_accuracy'] - pre_sft_results['partial_accuracy']
format_change = post_sft_results['format_accuracy'] - pre_sft_results['format_accuracy']

print(f"{'Answer Accuracy':<20} {pre_sft_results['accuracy']:>6.2f}%      {post_sft_results['accuracy']:>6.2f}%      {acc_change:>+6.2f} pts")
print(f"{'Partial Accuracy':<20} {pre_sft_results['partial_accuracy']:>6.2f}%      {post_sft_results['partial_accuracy']:>6.2f}%      {partial_change:>+6.2f} pts")
print(f"{'Format Accuracy':<20} {pre_sft_results['format_accuracy']:>6.2f}%      {post_sft_results['format_accuracy']:>6.2f}%      {format_change:>+6.2f} pts")
print("=" * 65)

print("\nKey Improvements:")
if format_change > 0:
    print(f"  ✓ Format accuracy improved by {format_change:.2f} percentage points")
if acc_change > 0:
    print(f"  ✓ Answer accuracy improved by {acc_change:.2f} percentage points")
if acc_change <= 0 and format_change <= 0:
    print("  Note: No improvement measured on this small sample; accuracy gains may come in the GRPO stage")
print("=" * 60 + "\n")

print("\nTensorBoard Visualization:")
print(f"  View detailed metrics: tensorboard --logdir {sft_log_dir}")
print("  Or use: %tensorboard --logdir ./logs/sft")


## Section 7: Stage 2 - GRPO Training

### What is GRPO?

GRPO (Group Relative Policy Optimization) is a reinforcement learning algorithm that:
1. Generates multiple responses for each prompt
2. Scores each response using reward functions
3. Updates the model to favor higher-scoring responses

Key advantages over PPO:
- No need for a separate value/critic model (saves memory!)
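- Uses group-based advantage estimation: each response is scored relative to the other responses sampled for the same prompt
- Often trains stably on reasoning tasks with verifiable answers, where rewards can be computed by simple rules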
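
To make "group-based advantage estimation" concrete, here is a minimal sketch of how the rewards for one prompt's group of responses can be turned into advantages. The function name, the `eps` value, and the example rewards are assumptions made for illustration; Tunix's `GRPOLearner` handles this internally and may differ in details such as the standard deviation estimator.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against the mean and std of its own group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; implementations vary
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical total rewards for four sampled responses to one prompt.
rewards = [7.5, 3.0, 3.0, -1.5]
advantages = group_relative_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward={r:+.1f}  advantage={a:+.3f}")
```

The group mean plays the role of the baseline that a learned value model provides in PPO, which is why no critic is needed. If every response in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal.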


In [None]:
# ============================================================================
# CELL 13: Define Reward Functions for GRPO
# ============================================================================
# These functions evaluate model outputs and provide learning signals.

def match_format_exactly(prompts, completions, **kwargs):
    """
    Reward function 1: Exact format matching
    
    Awards 3 points if the output exactly matches the expected format:
    <reasoning>...</reasoning><answer>...</answer>
    """
    rewards = []
    for response in completions:
        match = match_format.search(response)
        rewards.append(3.0 if match else 0.0)
    return rewards


def match_format_approximately(prompts, completions, **kwargs):
    """
    Reward function 2: Partial format matching
    
    Awards partial credit for having some of the required elements.
    This helps the model learn incrementally.
    """
    scores = []
    for completion in completions:
        score = 0
        # Check for each required element
        score += 0.5 if completion.count(reasoning_start) == 1 else -0.5
        score += 0.5 if completion.find(reasoning_start) == 0 else -0.5
        score += 0.5 if completion.count(reasoning_end) == 1 else -0.5
        score += 0.5 if completion.count(solution_start) == 1 else -0.5
        score += 0.5 if completion.count(solution_end) == 1 else -0.5
        scores.append(score)
    return scores


def check_answer(prompts, completions, answer, **kwargs):
    """
    Reward function 3: Answer correctness
    
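    Awards points based on how close the predicted answer is to the truth:
    - 3.0 points for exact match
    - 1.5 points for match after stripping whitespace
    - 0.5 points if within 10% of the correct value
    - 0.25 points if within 20% of the correct value
    - -1.0 if the numeric answer is further off, -0.5 if it cannot be parsed
    - 0 if the output does not match the expected format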
    """
    extracted_responses = [
        guess.group(1) if r is not None and (guess := match_format.search(r)) else None
        for r in completions
    ]
    
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
            
        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            try:
                ratio = float(guess) / float(true_answer)
                if 0.9 <= ratio <= 1.1:
                    score += 0.5
                elif 0.8 <= ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0
            except (ValueError, TypeError, ZeroDivisionError):
                score -= 0.5
        scores.append(score)
    return scores


def check_numbers(prompts, completions, answer, **kwargs):
    """
    Reward function 4: Numerical answer extraction
    
    Sometimes the answer tag contains extra text. This function
    extracts numerical values and compares them.
    """
    question = kwargs.get("question", [""] * len(completions))
    
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) else None
        for r in completions
    ]
    
    scores = []
    # Log first example for debugging
    if len(completions) > 0:
        print(f"Q: {question[0][:50]}... | A: {answer[0]} | Pred: {extracted_responses[0]}")
    
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            true_val = float(true_answer.strip())
            pred_val = float(guess.strip())
            scores.append(1.5 if pred_val == true_val else 0.0)
        except (ValueError, AttributeError):
            scores.append(0)
    return scores


print("Reward functions defined:")
print("  1. match_format_exactly: Rewards correct output structure")
print("  2. match_format_approximately: Partial credit for format")
print("  3. check_answer: Rewards correct answers")
print("  4. check_numbers: Extracts and compares numerical answers")


In [None]:
# ============================================================================
# CELL 14: Prepare GRPO Dataset
# ============================================================================

print("Loading GRPO training data...")

# Load GSM8K dataset for GRPO training
NUM_GRPO_BATCHES = 3738  # Full dataset
NUM_TEST_BATCHES = 64    # For evaluation

grpo_dataset = get_grpo_dataset(TRAIN_DATA_DIR, "train", "kaggle")
grpo_dataset = grpo_dataset.batch(GRPO_BATCH_SIZE)[:NUM_GRPO_BATCHES]

# Split into train/validation
if TRAIN_FRACTION < 1.0:
    split_idx = int(len(grpo_dataset) * TRAIN_FRACTION)
    train_dataset = grpo_dataset[:split_idx]
    val_dataset = grpo_dataset[split_idx:]
else:
    train_dataset = grpo_dataset
    val_dataset = None

# Load test dataset
test_dataset = get_grpo_dataset(TEST_DATA_DIR, "test", "kaggle")
test_dataset = test_dataset.batch(GRPO_BATCH_SIZE)[:NUM_TEST_BATCHES]

print(f"\nDataset sizes:")
print(f"  Training batches: {len(train_dataset)}")
print(f"  Validation batches: {len(val_dataset) if val_dataset else 0}")
print(f"  Test batches: {len(test_dataset)}")

# Preview a sample
print("\nSample GRPO training example:")
for batch in train_dataset[:1]:
    pprint(batch)


In [None]:
# ============================================================================
# CELL 15: Configure and Run GRPO Training
# ============================================================================

print("=" * 60)
print("STAGE 2: GRPO REINFORCEMENT LEARNING")
print("=" * 60)

# Calculate training steps
GRPO_WARMUP_STEPS = int(GRPO_WARMUP_RATIO * GRPO_MAX_STEPS)
EVAL_EVERY_N_STEPS = 64

# Configure GRPO optimizer (lower learning rate than SFT!)
grpo_optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=GRPO_LEARNING_RATE,
        warmup_steps=GRPO_WARMUP_STEPS,
        decay_steps=GRPO_MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Add gradient clipping (critical for RL stability!)
grpo_optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
    grpo_optimizer,
)

# Checkpoint configuration
grpo_ckpt_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# Metrics logging
grpo_metrics_options = metrics_logger.MetricsLoggerOptions(
    log_dir="./logs/grpo/",
    flush_every_n_steps=20,
)

print(f"\nGRPO Configuration:")
print(f"  Learning rate: {GRPO_LEARNING_RATE}")
print(f"  Generations per prompt (G): {NUM_GENERATIONS}")
print(f"  KL penalty (beta): {BETA}")
print(f"  Clipping (epsilon): {EPSILON}")
print(f"  Max steps: {GRPO_MAX_STEPS}")


In [None]:
# ============================================================================
# CELL 16: Initialize GRPO Cluster and Trainer
# ============================================================================

# Training configuration
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=grpo_optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=GRPO_MAX_STEPS,
        mini_batch_size=GRPO_BATCH_SIZE,
        train_micro_batch_size=GRPO_BATCH_SIZE,
        metrics_logging_options=grpo_metrics_options,
        checkpoint_root_directory=GRPO_CKPT_DIR,
        checkpointing_options=grpo_ckpt_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=EOS_TOKENS,
    ),
)

# GRPO algorithm configuration
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

print("\nInitializing RL Cluster...")
# Create RL cluster with policy and reference models
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,      # Model being trained (with LoRA)
    reference=gemma3,        # Fixed reference model (for KL divergence)
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

print("Creating GRPO Trainer...")
# Create GRPO trainer with reward functions
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    algo_config=grpo_config,
)

print("GRPO Trainer initialized successfully!")


In [None]:
# ============================================================================
# CELL 17: Launch TensorBoard for Monitoring
# ============================================================================

# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard to monitor training metrics
%tensorboard --logdir ./logs/ --port=0

print("TensorBoard launched! Monitor training metrics above.")


In [None]:
# ============================================================================
# CELL 18: Run GRPO Training
# ============================================================================
# This is the main training loop. It may take several hours to complete.
# Checkpoints are saved automatically during training.

print("\n" + "=" * 60)
print("STARTING GRPO TRAINING")
print("=" * 60)
print("\nThis will take approximately 6-8 hours on TPU v6e-1")
print(f"Checkpoints saved every {SAVE_INTERVAL_STEPS} steps to {GRPO_CKPT_DIR}")
print("\nMonitor progress in TensorBoard above.")
print("Key metrics to watch:")
print("  - reward_accuracy: Should increase over time")
print("  - reward_format: Should stay high (near 1.0)")
print("  - completion_length: May increase as model reasons more")
print("=" * 60 + "\n")

# Run GRPO training
grpo_trainer.train(train_dataset, val_dataset)

print("\n" + "=" * 60)
print("GRPO TRAINING COMPLETE!")
print("=" * 60)


## Section 8: Evaluation

Now we evaluate the trained model to see how much it has improved. We measure:
1. **Answer Accuracy**: Percentage of correct answers
2. **Partial Accuracy**: Percentage of answers within 10% of the correct value
3. **Format Accuracy**: Outputs following the correct structure
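
As a rough illustration of how these three metrics can be scored for a single response, here is a minimal sketch. The patterns `FORMAT_RE` and `ANSWER_RE` and the helper `score_response` are hypothetical stand-ins written for this example; the notebook's own `match_format` and `match_numbers`, defined in an earlier cell, may be stricter or looser.

```python
import re

# Hypothetical patterns for illustration only; the notebook defines its own
# `match_format` and `match_numbers` in an earlier cell.
FORMAT_RE = re.compile(
    r"<reasoning>.+?</reasoning>\s*<answer>.+?</answer>", re.DOTALL
)
ANSWER_RE = re.compile(r"<answer>\s*(-?[\d.,]+)\s*</answer>")


def score_response(response: str, answer: str) -> dict:
    """Score one response for exact match, 10% tolerance, and format."""
    exact = partial = False
    found = ANSWER_RE.search(response)
    if found:
        try:
            pred = float(found.group(1).replace(",", ""))
            truth = float(answer)
            exact = pred == truth
            # A relative tolerance is undefined for a zero target, so fall
            # back to exact match in that case.
            partial = exact if truth == 0 else 0.9 <= pred / truth <= 1.1
        except ValueError:
            pass  # the extracted text was not a parseable number
    return {
        "exact": exact,
        "partial": partial,
        "format": FORMAT_RE.search(response) is not None,
    }
```

With this sketch, a response whose answer is 160 for a target of 150 counts toward partial accuracy (ratio of about 1.07) but not toward answer accuracy.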


In [None]:
# ============================================================================
# CELL 19: Load Best Checkpoint for Evaluation
# ============================================================================

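# Load the trained checkpoint. This uses the final training step, which is not
# necessarily the best one and assumes a checkpoint was saved at that step.
# Set EVAL_CHECKPOINT_STEP to another saved step to evaluate it instead.
EVAL_CHECKPOINT_STEP = GRPO_MAX_STEPS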

trained_ckpt_path = os.path.join(
    GRPO_CKPT_DIR, "actor", str(EVAL_CHECKPOINT_STEP), "model_params"
)

print(f"Loading checkpoint from: {trained_ckpt_path}")

# Get the shape/dtype structure for restoration
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)

# Restore checkpoint
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Update model with trained parameters
nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

print("Checkpoint loaded successfully!")


In [None]:
# ============================================================================
# CELL 20: Evaluation Functions
# ============================================================================

def generate(question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None):
    """Generate response for a given question."""
    if isinstance(question, str):
        input_batch = [TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)]
    else:
        input_batch = [
            TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
            for q in question
        ]
    
    out_data = sampler(
        input_strings=input_batch,
        max_generation_steps=TOTAL_GENERATION_STEPS,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
        eos_tokens=EOS_TOKENS,
    )
    
    output = out_data.text
    return output[0] if isinstance(question, str) else output


def evaluate(dataset, sampler, temperature=0.7, top_k=50, top_p=0.95):
    """Evaluate model on dataset and compute metrics."""
    correct = 0
    partially_correct = 0
    correct_format = 0
    total = 0
    
    for batch in tqdm(dataset, desc="Evaluating"):
        answers = batch["answer"]
        questions = batch["question"]
        
        responses = generate(questions, sampler, temperature, top_k, top_p)
        
        for question, response, answer in zip(questions, responses, answers):
            # Check answer
            extracted = match_numbers.search(response)
            if extracted:
                try:
                    pred = float(extracted.group(1).strip())
                    true_val = float(answer.strip())
                    if pred == true_val:
                        correct += 1
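                    # Within 10% of the target; a zero target only matches exactly
                    if true_val == 0:
                        within_tolerance = pred == 0
                    else:
                        within_tolerance = 0.9 <= pred / true_val <= 1.1
                    if within_tolerance:
                        partially_correct += 1
                except (ValueError, AttributeError):
                    # Skip predictions or answers that cannot be parsed as numbers
                    pass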
            
            # Check format
            if match_format.search(response):
                correct_format += 1
            
            total += 1
            
            # Print progress
            if total % 20 == 0:
                print(f"Progress: {correct}/{total} correct ({correct/total*100:.1f}%)")
    
    # Guard against division by zero on an empty dataset
    denom = max(total, 1)
    return {
        "accuracy": correct / denom * 100,
        "partial_accuracy": partially_correct / denom * 100,
        "format_accuracy": correct_format / denom * 100,
        "correct": correct,
        "total": total,
    }


print("Evaluation functions defined!")


In [None]:
# ============================================================================
# CELL 21: Run Evaluation on Test Set
# ============================================================================

print("=" * 60)
print("EVALUATING TRAINED MODEL")
print("=" * 60)

# Create sampler for inference
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Run evaluation with greedy decoding for deterministic results
print("\nEvaluating with greedy decoding...")
results = evaluate(test_dataset, sampler, **GENERATION_CONFIGS["greedy"])

print("\n" + "=" * 60)
print("EVALUATION RESULTS")
print("=" * 60)
print(f"  Answer Accuracy:    {results['accuracy']:.2f}%")
print(f"  Partial Accuracy:   {results['partial_accuracy']:.2f}%")
print(f"  Format Accuracy:    {results['format_accuracy']:.2f}%")
print(f"  Correct/Total:      {results['correct']}/{results['total']}")
print("=" * 60)


## Section 9: Export Model

Export the trained model in a format compatible with Tunix on Kaggle. The exported model:
- Merges LoRA weights with the base model
- Saves in safetensors format
- Includes all necessary config files
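
For reference, merging a LoRA adapter into a base weight matrix is commonly written as shown here. It uses the notation of the original LoRA paper, which is an assumption for illustration and not necessarily the tensor layout Tunix or qwix uses internally. Here $r$ and $\alpha$ correspond to `RANK` and `ALPHA` in the export cell, $W_0$ is the frozen base weight, and $A$, $B$ are the trained low-rank factors.

$$
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\, B A, \qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k}
$$

After the merge, the model has the same shape as the base model, so inference needs no separate adapter weights.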


In [None]:
# ============================================================================
# CELL 22: Export Merged Model (HuggingFace Format)
# ============================================================================

output_dir = FINAL_MODEL_DIR
if USE_COLAB:
    output_dir = f"/tmp/content/{MODEL_ID.replace('/', '-')}-trained"

# Clean up existing output directory
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

print("=" * 60)
print("EXPORTING TRAINED MODEL")
print("=" * 60)
print(f"\nOutput directory: {output_dir}")

# Merge LoRA weights and save
print("\nMerging LoRA weights with base model...")
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=output_dir,
    lora_model=lora_policy,
    rank=RANK,
    alpha=ALPHA,
)

print("\n" + "=" * 60)
print("MODEL EXPORTED SUCCESSFULLY!")
print("=" * 60)

# List saved files
print("\nSaved files:")
for f in sorted(os.listdir(output_dir)):
    size = os.path.getsize(os.path.join(output_dir, f)) / (1024 * 1024)
    print(f"  {f:<35} {size:>10.2f} MB")


In [None]:
# ============================================================================
# CELL 23: Interactive Demo - Test the Trained Model
# ============================================================================

print("=" * 60)
print("INTERACTIVE DEMO")
print("=" * 60)

# Test questions covering different domains
test_questions = [
    # Math
    "If a train travels at 60 miles per hour for 2.5 hours, how far does it travel?",
    # Logic
    "A farmer has 17 sheep. All but 9 die. How many sheep are left?",
    # Word problem
    "Lisa has 3 times as many apples as Tom. If Tom has 5 apples, how many do they have together?",
]

print("\nGenerating responses for test questions...\n")

for i, question in enumerate(test_questions, 1):
    print(f"Question {i}: {question}")
    print("-" * 50)
    
    response = generate(question, sampler, **GENERATION_CONFIGS["greedy"])
    print(f"Response:\n{response}")
    print("=" * 60 + "\n")


In [None]:
# ============================================================================
# CELL 24: Download Model (For Colab Users)
# ============================================================================

if USE_COLAB:
    from google.colab import files
    
    # Create zip archive of the model
    zip_path = f"{output_dir}.zip"
    print(f"Creating zip archive: {zip_path}")
    shutil.make_archive(output_dir, 'zip', output_dir)
    
    print("Downloading model... (this may take a moment)")
    files.download(zip_path)
    print("Download complete!")
else:
    print(f"Model saved to: {output_dir}")
    print("\nTo use this model, load it with Tunix:")
    print(f"  model = params_safetensors_lib.create_model_from_safe_tensors('{output_dir}', config, mesh)")


## Summary

Congratulations! You have successfully trained a reasoning model using:

1. **Cold Start SFT**: Taught the model the output format
   - `<reasoning>model_thinking_trace</reasoning>`
   - `<answer>model_answer</answer>`

2. **GRPO Reinforcement Learning**: Strengthened reasoning abilities
   - Used multiple reward functions
   - No separate value model needed (memory efficient!)
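
The reason GRPO needs no value model is that each completion is scored relative to the other completions sampled for the same prompt. Here is a minimal sketch of that group-relative normalization. The function name `group_relative_advantages` is hypothetical, and the use of the population standard deviation and the `eps` value are assumptions; Tunix's actual implementation may differ in these details.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against the group sampled for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption for this sketch
    # eps keeps the division finite when every completion gets the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions for one prompt, each scored by the summed reward functions
rewards = [3.0, 0.5, 0.5, 0.0]
advantages = group_relative_advantages(rewards)
```

Completions that score above the group mean get a positive advantage and are reinforced, while those below the mean are discouraged. If all completions in a group receive the same reward, every advantage is zero and that prompt contributes no learning signal.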

### Key Takeaways

- **Cold Start is Essential**: Without it, RL training can cause chaotic outputs
- **Learning Rate Matters**: SFT uses ~2e-4, GRPO uses ~3e-6 (much lower!)
- **Checkpoints are Crucial**: Save frequently to recover from failures
- **Format Rewards Help**: Reward proper structure, not just correct answers

### Next Steps

1. Try different datasets (Bespoke-Stratos-17k, OpenR1-Math-220k)
2. Experiment with hyperparameters (NUM_GENERATIONS, BETA, EPSILON)
3. Train for more steps for better results
4. Evaluate on diverse domains (coding, science, creative writing)

### Resources

- [Tunix GitHub](https://github.com/google/tunix)
- [GRPO Paper](https://arxiv.org/pdf/2402.03300)
- [DeepSeek-R1 Technical Report](https://arxiv.org/abs/2501.12948)
- [Bespoke-Stratos Dataset](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)
