# GRPO Training on TPU

This notebook runs GRPO (Group Relative Policy Optimization) training using the TunRex codebase.

**Requirements:**
- TPU runtime (Runtime > Change runtime type > TPU)
- HuggingFace token (for gated models like Gemma)
- W&B API key (optional, for experiment tracking)

In [None]:
# @title Configuration
# Colab form cell: the trailing form annotation on each assignment exposes it
# as an editable field. Later cells read these module-level names directly.

# Git repository cloned into /content/training, and the branch checked out.
REPO_URL = "https://github.com/42euge/ee596-fp.git"  # @param {type:"string"}
BRANCH = "main"  # @param {type:"string"}

# Training parameters
# Each value is forwarded as a command-line flag to scripts/train_grpo.py by
# the "Run Training" cell; USE_LORA toggles the presence of --use-lora.
NUM_STEPS = 100  # @param {type:"integer"}
MODEL_ID = "google/gemma-3-1b-it"  # @param {type:"string"}
LEARNING_RATE = 3e-6  # @param {type:"number"}
BATCH_SIZE = 1  # @param {type:"integer"}
USE_LORA = True  # @param {type:"boolean"}

# W&B settings
# WANDB_PROJECT is passed as --wandb-project. RUN_NAME is optional: when left
# empty, the "Run Training" cell omits the --run-name flag entirely.
WANDB_PROJECT = "tunix-grpo"  # @param {type:"string"}
RUN_NAME = ""  # @param {type:"string"}

In [None]:
# @title Set API Keys
import os
from getpass import getpass

try:
    from google.colab import userdata
except ImportError:
    # Not running on Colab: there is no secrets store, so always prompt.
    userdata = None


def _load_secret(name, prompt):
    """Export the secret ``name`` into ``os.environ``.

    The value is read from Colab secrets when they are available; otherwise
    the user is asked for it with a hidden-input prompt. An empty answer
    leaves the environment variable unset instead of exporting an empty
    string.

    Args:
        name: Name of both the Colab secret and the environment variable.
        prompt: Text shown when falling back to manual input.

    Returns:
        True if the environment variable was set, False otherwise.
    """
    value = None
    if userdata is not None:
        try:
            value = userdata.get(name)
        except Exception as err:
            # Missing secret or notebook access not granted: fall back to the
            # prompt. ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt still abort the cell.
            print(f"{name} not read from Colab secrets ({type(err).__name__})")

    if value:
        print(f"{name} loaded from Colab secrets")
    else:
        # Strip stray whitespace/newlines that often come with pasted tokens.
        value = getpass(prompt).strip()

    if not value:
        print(f"{name} not set")
        return False

    os.environ[name] = value
    return True


# Try to get from Colab secrets, fall back to manual input
_load_secret("HF_TOKEN", "Enter HuggingFace token: ")
_load_secret("WANDB_API_KEY", "Enter W&B API key (or leave empty): ")

In [None]:
# @title Clone Repository
# Fetch a fresh copy of the training code and make it the working directory
# for the cells that follow.
import os

# Remove any earlier checkout so the clone below targets an empty path.
!rm -rf /content/training
# Clone the configured branch, including submodules (--recursive). IPython
# substitutes {BRANCH} and {REPO_URL} from the notebook variables.
# NOTE(review): the values are interpolated unquoted, so they must not contain
# spaces or shell metacharacters.
!git clone --recursive --branch {BRANCH} {REPO_URL} /content/training
# Later cells use repo-relative paths (e.g. scripts/train_grpo.py), so switch
# the notebook's working directory into the checkout.
%cd /content/training
# Show the checked-out commit so the run can be traced back to it.
!git log -1 --oneline

In [None]:
# @title Install Dependencies
# Install the `uv` tool with pip, then let it install the project's
# dependencies. This runs in the current working directory, which the
# "Clone Repository" cell sets to /content/training.
!pip install uv
# --frozen: use the repository's existing lockfile as-is, without updating it.
# NOTE(review): presumably this installs into a uv-managed project environment
# rather than the notebook kernel's own packages — confirm.
!uv sync --frozen

In [None]:
# @title Check TPU
# Report the JAX version, every visible device, and how many of them are TPUs.
import jax

print(f"JAX version: {jax.__version__}")

# Query the backend once and reuse the result for the listing and the count.
available_devices = jax.devices()
print(f"Devices: {available_devices}")

tpu_core_count = sum(1 for device in available_devices if device.platform == "tpu")
print(f"TPU cores: {tpu_core_count}")

In [None]:
# @title Run Training
lora_flag = "--use-lora" if USE_LORA else ""
run_name_flag = f"--run-name {RUN_NAME}" if RUN_NAME else ""

!uv run python scripts/train_grpo.py \
    --num-steps {NUM_STEPS} \
    --model-id {MODEL_ID} \
    --learning-rate {LEARNING_RATE} \
    --batch-size {BATCH_SIZE} \
    --wandb-project {WANDB_PROJECT} \
    {lora_flag} \
    {run_name_flag}