# v42.0.0.3 Gemma3-1B GRPO Training Notebook

This notebook trains Gemma3-1B with GRPO (Group Relative Policy Optimization) for improved reasoning.

**Requirements:**
- Google Colab with TPU runtime (recommended) or GPU
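- HuggingFace and/or Kaggle account with Gemma access (the credentials cell reads `KAGGLE_USERNAME`, `KAGGLE_KEY` and `WANDB_API_KEY` from Colab or Kaggle secrets)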

**Output:**
- LoRA checkpoint files that can be downloaded and used locally

## 1. Setup

In [None]:
# Install dependencies
# Installs every package below only when TensorFlow is NOT importable; the
# TensorFlow check is used as a proxy for "fresh runtime".
# NOTE(review): if the runtime already ships TensorFlow, this whole cell is a
# no-op, including tunix, qwix and flax from source -- confirm this guard is
# what you want.
import importlib.util

if importlib.util.find_spec('tensorflow') is None:
  print("Installing required packages...")
  # NOTE(review): the usual PyPI package for `import dotenv` is
  # `python-dotenv`; confirm `dotenv` is intended.
  %pip install -q dotenv
  %pip install -q kagglehub
  %pip install -q ipywidgets
  %pip install -q tensorflow
  %pip install -q tensorflow_datasets
  %pip install -q tensorboardX
  %pip install -q transformers
  %pip install -q grain
  # JAX, tunix, qwix and flax are installed from their GitHub default branches
  # (unpinned), so results may change between runs.
  %pip install -q git+https://github.com/jax-ml/jax
  %pip install git+https://github.com/google/tunix
  %pip install git+https://github.com/google/qwix
  # Replace any preinstalled flax with the GitHub version.
  %pip uninstall -q flax -y
  %pip install git+https://github.com/google/flax
  %pip install -q huggingface_hub
  %pip install -q datasets
  # Installed last so the numpy>2 requirement is what remains in the env.
  %pip install -q 'numpy>2'

In [None]:
# Install TunRex (provides load_openrubrics, prepare_gemma_checkpoint,
# get_gemma_ref_model, get_lora_model used below) from GitHub when it is not
# already importable.
# NOTE(review): this tracks the `feature/models-api` branch, which is not
# reproducible; consider pinning a commit.
import importlib.util

if importlib.util.find_spec('tunrex') is None:
  %pip install git+https://github.com/42euge/TunRex.git@feature/models-api

In [None]:
# Configuration
# =============================================================================
# EDIT THESE VALUES
# =============================================================================

# Training settings
NUM_BATCHES = 500          # Number of training batches (500 = ~30 min on TPU)
LEARNING_RATE = 3e-6       # Learning rate
LORA_RANK = 64             # LoRA rank
LORA_ALPHA = 64.0          # LoRA alpha

# Dataset settings
USE_OPENRUBRICS = True     # Use OpenRubrics dataset
OPENRUBRICS_MAX = 2000     # Max examples from OpenRubrics

# Checkpoint settings
SAVE_TO_DRIVE = False      # Save checkpoints to Google Drive
EXPERIMENT_NAME = "gemma3_grpo_reasoning"

# =============================================================================
# CREDENTIALS - Three options (in order of priority):
#   1. Literal values below (uncomment and fill in)
#   2. Google Colab secrets (set via key icon in sidebar)
#   3. Kaggle secrets (when running on Kaggle)
# =============================================================================
import os

# Option 1: Literal values (uncomment and fill in your credentials).
# Keep these commented out in any notebook you share or commit: keys written
# here are readable by everyone who can read the file. While they are active,
# options 2 and 3 are skipped entirely.
# os.environ['WANDB_API_KEY'] = '<your-wandb-api-key>'
# os.environ['KAGGLE_USERNAME'] = '<your-kaggle-username>'
# os.environ['KAGGLE_KEY'] = '<your-kaggle-api-key>'

CREDENTIAL_NAMES = ('WANDB_API_KEY', 'KAGGLE_USERNAME', 'KAGGLE_KEY')


def copy_secrets_to_env(get_secret, names=CREDENTIAL_NAMES):
    """Copy secrets into ``os.environ`` without overwriting values already set.

    Args:
        get_secret: Callable mapping a secret name to its value, e.g.
            ``google.colab.userdata.get`` or ``UserSecretsClient().get_secret``.
        names: Secret names to look up.

    Returns:
        List of the names that were actually copied into the environment.
        Names whose lookup raises or returns an empty/None value are skipped,
        so a missing secret never aborts the cell.
    """
    loaded = []
    for name in names:
        if os.environ.get(name):
            continue
        try:
            value = get_secret(name)
        except Exception:  # providers raise their own "missing"/"no access" errors
            continue
        if value:  # os.environ only accepts str; skip None/empty values
            os.environ[name] = value
            loaded.append(name)
    return loaded


# Option 2 & 3: Try secrets providers if env vars not already set
if os.environ.get('KAGGLE_USERNAME'):
    print("Using literal credentials from environment")
else:
    # Try Google Colab secrets first
    try:
        from google.colab import userdata
    except ImportError:
        colab_loaded = []
    else:
        colab_loaded = copy_secrets_to_env(userdata.get)

    if colab_loaded:
        print("Using Google Colab secrets")
    else:
        # Fall back to Kaggle secrets
        try:
            from kaggle_secrets import UserSecretsClient
        except ImportError:
            kaggle_loaded = []
        else:
            kaggle_loaded = copy_secrets_to_env(UserSecretsClient().get_secret)

        if kaggle_loaded:
            print("Using Kaggle secrets")
        else:
            print("WARNING: No credentials found. Either:")
            print("  1. Uncomment and fill in literal values above")
            print("  2. Set Colab secrets (key icon in sidebar)")
            print("  3. Set Kaggle secrets")

print(f"\nTraining config:")
print(f"  Batches: {NUM_BATCHES}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  LoRA rank: {LORA_RANK}")
print(f"  Kaggle user: {os.environ.get('KAGGLE_USERNAME', 'not set')}")

In [None]:
# Choose the checkpoint location: a folder on Google Drive when SAVE_TO_DRIVE
# is set (Drive is mounted first), otherwise the Colab VM's /content disk.
# Either way the directory is created if missing.
import os

if SAVE_TO_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    CHECKPOINT_DIR = f"/content/drive/MyDrive/{EXPERIMENT_NAME}/checkpoints"
else:
    CHECKPOINT_DIR = "/content/checkpoints"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
if SAVE_TO_DRIVE:
    print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")

In [None]:
# Imports
# Third-party and standard-library modules used by the training cells below.
# NOTE(review): functools, gc, csv, shutil, pprint, humanize, kagglehub, Path,
# qwix, tqdm, tokenizer_lib, params, model, load_dataset and jnp are not
# referenced by any other cell of this notebook as shown. They may be safe to
# drop, but some (e.g. qwix) could be imported for side effects -- confirm first.
import functools
import gc
import os
import re
import csv
import shutil
from pprint import pprint
from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset

# Sanity check: show which accelerator devices JAX can see.
print(f"JAX devices: {jax.devices()}")

## 2. Prompt Template

In [None]:

# Prompt configuration: the tag strings the model is asked to emit. The
# format reward builds its regex from these same constants.
REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
SOLUTION_START = "<answer>"
SOLUTION_END = "</answer>"

SYSTEM_PROMPT = f"""You are given a problem. Think carefully and show your detailed reasoning step-by-step. Place your reasoning between {REASONING_START} and {REASONING_END}. After completing your reasoning, provide the final answer between {SOLUTION_START} and {SOLUTION_END}."""

# Gemma chat-turn template. The prompt ends right after the model-turn header,
# so sampling continues as the model's reply.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""


def format_prompt(question, rubric=None):
    """Render a single-turn Gemma prompt for ``question``.

    Args:
        question: The problem text.
        rubric: Optional rubric text. When truthy it is placed before the
            question under a "Rubric:" heading; when falsy it is omitted.

    Returns:
        The formatted prompt string containing SYSTEM_PROMPT and the question.
    """
    if not rubric:
        user_text = question
    else:
        user_text = f"\nRubric:\n{rubric}\n\n{question}"
    return TEMPLATE.format(question=user_text, system_prompt=SYSTEM_PROMPT)


print("Prompt template configured.")

## 3. Load Dataset

In [None]:
# Load dataset using TunRex.
# The notebook cannot train without examples, so an empty result stops here
# with a clear message instead of failing later inside the grain/GRPO pipeline.
from tunrex import load_openrubrics

if USE_OPENRUBRICS:
    train_data = load_openrubrics(max_examples=OPENRUBRICS_MAX)
else:
    train_data = []  # Add GSM8K loading if needed

print(f"\nTotal training examples: {len(train_data)}")

if len(train_data) == 0:
    raise ValueError(
        "No training examples loaded. Set USE_OPENRUBRICS = True or add "
        "another dataset loader before building the training pipeline."
    )

In [None]:
# Create grain dataset
import random


def create_dataset(data):
    """Wrap example dicts in a shuffled, prompt-formatted grain MapDataset.

    Args:
        data: Sequence of dicts with a "question" key and optional "rubric"
            and "reference_response" keys.

    Returns:
        A grain MapDataset whose records have keys "prompts", "question",
        "rubric" and "reference_response". Missing or None rubric/reference
        values are normalized to "" so batched columns stay plain strings.
    """
    def to_record(example):
        rubric = example.get("rubric") or ""
        return {
            "prompts": format_prompt(example["question"], rubric),
            "question": example["question"],
            "rubric": rubric,
            "reference_response": example.get("reference_response") or "",
        }

    return grain.MapDataset.source(data).shuffle(seed=42).map(to_record)


# Split data. Shuffle a copy first (deterministically) so the held-out 10% is
# not simply the tail of whatever order the loader returned. list() also turns
# any sequence-like loader output into a plain list of example dicts.
shuffled_data = list(train_data)
random.Random(42).shuffle(shuffled_data)
split_idx = int(len(shuffled_data) * 0.9)
train_split = shuffled_data[:split_idx]
test_split = shuffled_data[split_idx:]

train_dataset = create_dataset(train_split)
test_dataset = create_dataset(test_split)

print(f"Train: {len(train_split)}, Test: {len(test_split)}")

## 4. Load Model

In [None]:
# TODO: change to add a previous checkpoint loading mechanism

In [None]:
# Prepare the base Gemma checkpoint with TunRex, using INTERMEDIATE_CKPT_DIR
# as its working directory. The helper returns three values, used below as the
# checkpoint path, the model-checkpoint path and the tokenizer.
from tunrex import prepare_gemma_checkpoint

INTERMEDIATE_CKPT_DIR = "/tmp/intermediate_ckpt"
(ckpt_path, MODEL_CP_PATH, tokenizer) = prepare_gemma_checkpoint(ckpt_dir=INTERMEDIATE_CKPT_DIR)

print("Base model checkpoint prepared.")

In [None]:
# Load the frozen reference model via TunRex. Besides the model, the helper
# returns the device mesh and model config reused by the LoRA, sampler and
# RL-cluster cells.
from tunrex import get_gemma_ref_model

ref_model, mesh, model_config = get_gemma_ref_model(
    model_checkpoint_path=MODEL_CP_PATH,
    ckpt_path=ckpt_path,
)

print("Reference model loaded.")

In [None]:
# Build the trainable LoRA policy on top of the reference model via TunRex,
# using the rank/alpha from the configuration cell.
from tunrex import get_lora_model

lora_policy = get_lora_model(
    base_model=ref_model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
    mesh=mesh,
)

print(f"LoRA model created with rank={LORA_RANK}, alpha={LORA_ALPHA}")

## 5. Reward Functions

In [None]:
import string
import difflib
from collections import Counter

# Format matching regex: a reasoning block followed (anywhere later) by an
# answer block. The tags are passed through re.escape so that switching them to
# strings containing regex metacharacters (e.g. "[think]") cannot silently
# change what the pattern matches; for the current tags the pattern is unchanged.
match_format = re.compile(
    rf"{re.escape(REASONING_START)}.*?{re.escape(REASONING_END)}"
    rf".*?{re.escape(SOLUTION_START)}.*?{re.escape(SOLUTION_END)}",
    re.DOTALL,
)


def match_format_reward(prompts, completions, **kwargs):
    """Reward completions for using the reasoning/answer tag format.

    Args:
        prompts: Unused; accepted for the reward-function calling convention.
        completions: Iterable of generated strings.
        **kwargs: Ignored extra dataset columns.

    Returns:
        One score per completion:
        2.0 if the full reasoning-then-answer structure is present,
        0.5 if at least one opening tag appears, otherwise -1.0.
    """
    scores = []
    for completion in completions:
        if match_format.search(completion):
            scores.append(2.0)
        elif REASONING_START in completion or SOLUTION_START in completion:
            scores.append(0.5)
        else:
            scores.append(-1.0)
    return scores

# Translation table mapping every ASCII punctuation character to a space.
# It is built once so tokenizing is a single C-level pass per string.
_PUNCT_TO_SPACE = str.maketrans(string.punctuation, " " * len(string.punctuation))


def rubric_overlap_score(response, rubric_text):
    """Score how much of a rubric's vocabulary appears in a response (0-10).

    Both texts are lowercased and ASCII punctuation is replaced by spaces.
    Only whitespace-separated tokens longer than two characters are kept.
    Each distinct rubric token contributes ``1 / count``, where count is its
    frequency *within the rubric*, so terms the rubric repeats weigh less than
    terms it mentions once. (This is not TF-IDF: no corpus-level document
    frequency is involved.)

    Args:
        response: Model completion text.
        rubric_text: Rubric text to match against.

    Returns:
        Weighted fraction of distinct rubric tokens present in the response,
        scaled to [0, 10]. Returns 0.0 when the rubric has no usable tokens.
    """
    def tokenize(text):
        return [t for t in text.lower().translate(_PUNCT_TO_SPACE).split() if len(t) > 2]

    rubric_counts = Counter(tokenize(rubric_text))
    if not rubric_counts:
        return 0.0

    response_tokens = set(tokenize(response))
    weighted_matches = sum(
        1.0 / rubric_counts[t] for t in response_tokens if t in rubric_counts
    )
    # Always > 0 here because rubric_counts is non-empty.
    max_score = sum(1.0 / c for c in rubric_counts.values())
    return weighted_matches / max_score * 10.0

def rar_reward(prompts, completions, rubric=None, reference_response=None, **kwargs):
    """Rubric-as-Reward scoring: rubric term coverage plus reference similarity.

    Args:
        prompts: Unused; accepted for the reward-function calling convention.
        completions: Sequence of generated strings.
        rubric: Optional per-completion rubric texts (list or numpy array).
        reference_response: Optional per-completion reference answers.
        **kwargs: Ignored extra dataset columns.

    Returns:
        One float per completion: rubric_overlap_score (0-10) when a rubric is
        present, plus difflib similarity to the reference scaled to 0-5 when a
        reference is present.

    NOTE(review): assumes rubric/reference_response are aligned one-to-one
    with completions. zip() truncates to the shortest input otherwise.
    """
    n = len(completions)
    # Explicit None/len checks instead of `x or default`: batched dataset
    # columns can arrive as numpy arrays, whose truth value is ambiguous
    # (`array or ...` raises ValueError).
    if rubric is not None and len(rubric) > 0:
        rubrics = rubric
    else:
        rubrics = [""] * n
    if reference_response is not None and len(reference_response) > 0:
        references = reference_response
    else:
        references = [""] * n

    rewards = []
    for response, rub, ref in zip(completions, rubrics, references):
        # Rubric overlap (0-10); empty/None rubric scores 0.
        r_score = rubric_overlap_score(response, rub) if rub else 0.0

        # Reference similarity (0-5); empty/None reference scores 0.
        f_score = difflib.SequenceMatcher(None, ref, response).ratio() * 5.0 if ref else 0.0

        rewards.append(r_score + f_score)

    return rewards

print("Reward functions defined.")

## 6. Setup Training

In [None]:
# Sampler used for generation (and for the quick test at the end). It wraps
# lora_policy, and its KV-cache geometry comes from the reference model config.
# NOTE(review): some tunix versions name the model argument `transformer`,
# take no `mesh`, and require `cache_size` in CacheConfig -- confirm against
# the installed tunix.
sampler = sampler_lib.Sampler(
    model=lora_policy,
    tokenizer=tokenizer,
    mesh=mesh,
    cache_config=sampler_lib.CacheConfig(
        head_dim=model_config.head_dim,
        num_kv_heads=model_config.num_kv_heads,
        num_layers=model_config.num_layers,
    ),
)

print("Sampler created.")

In [None]:
# Training hyperparameters
# The LR schedule must span every optimizer update. Assuming one update per
# batch (GRPOConfig below uses num_iterations=1; keep the two in sync), that is
# NUM_BATCHES steps. A shorter span makes warmup_cosine_decay_schedule hold
# end_value (0.0) for the remaining steps, so those batches would train at
# LR 0. max(1, ...) keeps the cosine phase non-empty for tiny NUM_BATCHES.
MAX_STEPS = max(1, NUM_BATCHES)
WARMUP_STEPS = int(0.1 * MAX_STEPS)

# AdamW with linear warmup to LEARNING_RATE, then cosine decay to 0 at MAX_STEPS.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,  # total schedule length, warmup included
        end_value=0.0,
    ),
    b1=0.9,
    b2=0.99,
    weight_decay=0.1,
)
# Clip the global gradient norm before the AdamW update.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=0.1),
    optimizer,
)

print(f"Max steps: {MAX_STEPS}, Warmup: {WARMUP_STEPS}")

In [None]:
# GRPO algorithm settings.
# NOTE(review): presumably num_generations is the group size (completions
# sampled per prompt), beta the KL-penalty weight and epsilon the clip range.
# A group of 2 gives a noisy relative advantage -- confirm this is intended.
grpo_config = GRPOConfig(
    num_generations=2,
    num_iterations=1,
    epsilon=0.2,
    beta=0.08,
)

# Prompt/generation length limits for the cluster.
cluster_config = rl_cluster_lib.ClusterConfig(
    total_generation_steps=512,
    max_prompt_length=256,
)

# How the training data is consumed: batches of 2 prompts, NUM_BATCHES batches.
data_iter_config = base_rollout.DataIteratorConfig(
    num_batches=NUM_BATCHES,
    batch_size=2,
)

print("GRPO config created.")

In [None]:
# Checkpointing: keep the 3 most recent checkpoints, saving every 100 steps.
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=100,
)

# Metrics are written under /tmp/tensorboard/grpo, flushed every 20 steps.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    flush_every_n_steps=20,
    log_dir="/tmp/tensorboard/grpo",
)

# RL cluster: ties the reference model, tokenizer, mesh and sampler together.
# NOTE(review): some tunix versions construct RLCluster from actor/reference
# models plus a ClusterConfig holding training options -- confirm the API.
rl_cluster = rl_cluster_lib.RLCluster(
    config=cluster_config,
    reference=ref_model,
    tokenizer=tokenizer,
    sampler=sampler,
    mesh=mesh,
)

print("RL cluster created.")

In [None]:
# GRPO trainer: the format reward and the rubric-as-reward score are both
# passed as reward functions. Checkpoints go to CHECKPOINT_DIR.
reward_fns = [match_format_reward, rar_reward]

grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    algo_config=grpo_config,
    reward_fns=reward_fns,
    optimizer=optimizer,
    metrics_logger_options=metrics_logging_options,
    ckpt_options=checkpointing_options,
    ckpt_dir=CHECKPOINT_DIR,
)

print("\n".join([
    "GRPO trainer created.",
    f"Checkpoints will be saved to: {CHECKPOINT_DIR}",
]))

## 7. Train!

In [None]:
# Batch the training split using the batch size from the iterator config, so
# the data pipeline and the rollout settings agree.
train_iter = train_dataset.batch(data_iter_config.batch_size)

print("\n".join([
    "=" * 60,
    "Starting GRPO Training",
    "=" * 60,
    f"Batches: {NUM_BATCHES}",
    f"Checkpoint dir: {CHECKPOINT_DIR}",
    "=" * 60,
]))

In [None]:
# Run the GRPO loop over the batched training data with the LoRA policy.
grpo_trainer.train(
    data_iterator=train_iter,
    data_iterator_config=data_iter_config,
    policy=lora_policy,
)

print("\n".join(["\n" + "=" * 60, "Training complete!", "=" * 60]))

## 8. Export Checkpoint for Local Use

In [None]:
# Find latest checkpoint
import glob
import os


def checkpoint_step(path):
    """Return the integer step encoded in a checkpoint directory name, or -1.

    Step directories are expected to be named by step number
    (e.g. ``.../actor/1000/``); any other name maps to -1.
    """
    name = os.path.basename(os.path.normpath(path))
    return int(name) if name.isdecimal() else -1


# Sort numerically by step. A plain string sort puts "1000" before "200", so
# the last entry would stop being the newest checkpoint once steps reach four
# digits. The path is a tie-breaker for non-numeric names.
ckpt_dirs = sorted(
    glob.glob(f"{CHECKPOINT_DIR}/actor/*/"),
    key=lambda p: (checkpoint_step(p), p),
)
if ckpt_dirs:
    latest_ckpt = ckpt_dirs[-1]
    print(f"Latest checkpoint: {latest_ckpt}")
else:
    latest_ckpt = None  # keep the name defined for later cells
    print("No checkpoints found!")

In [None]:
# Collect the LoRA parameters of the trained policy and save them to
# EXPORT_DIR/lora_params.npz (one array per parameter, keyed by its dotted
# path).
# NOTE: this is a plain NumPy archive, NOT a HuggingFace/PEFT adapter. Mapping
# these names onto PEFT's expected layout is still to be done.
import os

import numpy as np

EXPORT_DIR = f"{CHECKPOINT_DIR}/hf_lora"
os.makedirs(EXPORT_DIR, exist_ok=True)

lora_state = nnx.state(lora_policy)


def _path_to_name(path):
    """Join a JAX key path into a dotted name (dict keys, attrs, indices)."""
    parts = []
    for key in path:
        for attr in ("key", "name", "idx"):
            if hasattr(key, attr):
                parts.append(str(getattr(key, attr)))
                break
        else:
            parts.append(str(key))
    return ".".join(parts)


# Keep only leaves whose path mentions "lora" (case-insensitive).
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(lora_state)
lora_params = {}
for path, value in leaves_with_paths:
    name = _path_to_name(path)
    if "lora" in name.lower():
        lora_params[name] = value

print(f"Found {len(lora_params)} LoRA parameters")

if lora_params:
    export_path = os.path.join(EXPORT_DIR, "lora_params.npz")
    # np.asarray pulls each (possibly sharded) device array back to host memory.
    np.savez(export_path, **{k: np.asarray(v) for k, v in lora_params.items()})
    print(f"LoRA arrays saved to: {export_path}")

print(f"\nCheckpoint saved to: {CHECKPOINT_DIR}")
print(f"\nTo use locally:")
print(f"1. Download the checkpoint folder from Google Drive")
print(f"2. Place in your local checkpoints/ directory")

In [None]:
# Create a zip file for easy download (Drive runs only).
# Uses shutil.make_archive instead of shelling out to `zip`, so it works
# without the zip binary and is valid Python outside IPython. The archive holds
# the actor/ tree relative to CHECKPOINT_DIR, like `cd ... && zip -r ... actor/`.
import os
import shutil

if SAVE_TO_DRIVE:
    actor_dir = os.path.join(CHECKPOINT_DIR, "actor")
    if os.path.isdir(actor_dir):
        archive_path = shutil.make_archive(
            os.path.join(CHECKPOINT_DIR, "checkpoint_export"),
            "zip",
            root_dir=CHECKPOINT_DIR,
            base_dir="actor",
        )
        print(f"\nZipped checkpoint: {archive_path}")
        print("Download this file from Google Drive and extract to checkpoints/")
    else:
        print(f"No actor checkpoints found under {CHECKPOINT_DIR}; nothing to zip.")

## 9. Quick Test

In [None]:
# Quick smoke test: sample one answer for a simple arithmetic question, using
# the training prompt template (no rubric) and the LoRA-backed sampler.
test_question = "A store sells apples for $2 each. If I buy 5 apples, how much do I spend?"
test_prompt = format_prompt(test_question)

print("\n".join(["Testing trained model...", f"Question: {test_question}", ""]))

# NOTE(review): some tunix versions return a SamplerOutput object (text in
# `.text`, length kwarg `max_generation_steps`) rather than an indexable list.
# Confirm `[0]` and `total_generation_steps` match the installed API.
response = sampler(
    [test_prompt],
    temperature=0.7,
    top_p=0.95,
    top_k=50,
    total_generation_steps=256,
)[0]

print("Response:", response, sep="\n")

---

## Done!

Your checkpoints are saved. To use them locally:

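1. Download `checkpoint_export.zip` from Google Drive (it is only created when `SAVE_TO_DRIVE = True`; otherwise download the `actor/` folder from the local checkpoint directory)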
2. Extract to your local `checkpoints/` folder
3. Run: `python demo/demo.py --checkpoint ./checkpoints/actor/<step>/model_params`