# GRPO Llama3.1-8B Demo

This notebook demonstrates GRPO (Group Relative Policy Optimization) training using the unified `rl_train` function.

## What is GRPO?

GRPO is a reinforcement learning (RL) algorithm that improves the reasoning abilities of LLMs by:
1. Generating multiple responses (a group) for each prompt
2. Scoring each response with reward models
3. Computing each response's advantage relative to the rest of its group and using it to update the policy
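Step 3 can be illustrated with a short sketch. It assumes advantages are normalized within each group as in the original GRPO formulation: each reward minus the group mean, divided by the group standard deviation. `group_relative_advantages` is a hypothetical helper, not a MaxText API. Using the population standard deviation plus a small `eps` is also an assumption, since implementations differ on these details.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize one prompt's group of rewards into relative advantages.

    Hypothetical helper for illustration: each reward is centered on the group
    mean and scaled by the group's population standard deviation.
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # eps keeps the division finite when every response received the same reward
    return [(r - mean) / (std + eps) for r in rewards]


# Four responses to one prompt: two correct (reward 1.0), two incorrect (0.0)
print(group_relative_advantages([1.0, 0.0, 1.0, 0.0]))
```

Responses that score above their group's average get a positive advantage and are reinforced, and below-average responses are pushed down. Because the baseline comes from the group itself, no separate value model is needed.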

## Hardware Requirements

- Single-host TPU VM (v6e-8 or v5p-8), or multi-host with Pathways
- Enough accelerator (HBM) memory to hold the Llama3.1-8B model
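A rough back-of-the-envelope estimate helps gauge what "enough memory" means. The sketch below is hypothetical and rests on several assumptions: bf16 weights (2 bytes per parameter), Adam-style optimizer state with two fp32 moments (8 bytes per parameter), and three copies of the weights (the policy being trained, a frozen reference model for the KL penalty, and the copy used by the rollout sampler). It ignores gradients, activations, the KV cache, and any sharing between copies. Treat the result as a lower bound under those assumptions, not as MaxText's actual footprint.

```python
def estimate_static_memory_gb(
    num_params: int,
    weight_bytes: int = 2,       # bf16 weights (assumption)
    optimizer_bytes: int = 8,    # two fp32 Adam moments per parameter (assumption)
    num_weight_copies: int = 3,  # policy + reference + sampler copies (assumption)
) -> float:
    """Hypothetical lower-bound estimate of weight + optimizer memory, in GB (1e9 bytes)."""
    total_bytes = num_params * (weight_bytes * num_weight_copies + optimizer_bytes)
    return total_bytes / 1e9


# Llama3.1-8B has roughly 8 billion parameters
print(f"{estimate_static_memory_gb(8_000_000_000):.0f} GB")  # prints "112 GB"
```

Spread across all chips in the slice, this is the floor that sharding has to fit under before activations and the KV cache are added.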

## Setup

Install dependencies and set up the environment:

In [None]:
# Clone MaxText repository
!git clone https://github.com/AI-Hypercomputer/maxtext
%cd maxtext/src

In [None]:
!bash tools/setup/setup.sh
%pip uninstall -y jax jaxlib libtpu

%pip install aiohttp==3.12.15

# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
%pip install keyring keyrings.google-artifactregistry-auth

# Install vLLM for Jax and TPUs from the artifact registry
!VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
    --extra-index-url https://pypi.org/simple/ \
    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
    vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu

# Install tpu-commons from the artifact registry
%pip install --no-cache-dir --pre \
    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
    --extra-index-url https://pypi.org/simple/ \
    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
    tpu-commons==0.1.2

%pip install numba==0.61.2

In [None]:

%pip install nest_asyncio

import nest_asyncio
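nest_asyncio.apply()  # Allow nested asyncio event loops inside the notebook's already-running loop (needed in Colab)

# Re-enter maxtext/src (relevant if the runtime was restarted after the installs above)
%cd maxtext/src/

# Reinstall flax and qwix to fix NNX compatibility issues; -y skips the confirmation prompt that would otherwise block the cell
!pip uninstall -y flax
!pip uninstall -y qwix
!pip install flax
!pip install qwix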
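The cells above uninstall and reinstall several packages, so it can help to confirm which versions actually ended up installed before continuing. This minimal sketch uses only the standard library's `importlib.metadata`. `installed_versions` is a hypothetical helper, and the package list simply mirrors the installs above. If packages were reinstalled while the kernel was running, a runtime restart may be needed before the new versions are importable.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


# Distributions touched by the setup cells above
for pkg, ver in installed_versions(
    ["jax", "jaxlib", "libtpu", "vllm", "tpu-commons", "flax", "qwix", "numba"]
).items():
    print(f"{pkg:12s} {ver or 'NOT INSTALLED'}")
```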

## Configuration

Set up the training parameters. Defaults are hardcoded for Llama3.1-8B:

### Multi-host Pathways

To run this demo on a multi-host Pathways setup:
- Set `use_pathways=True` in `rl.yml` (enabled by default).
- Override `trainer_devices_fraction` and `sampler_devices_fraction` in `config_argv` to control how the devices across hosts are split between the trainer and the rollout sampler.
- Launch the Colab kernel on the controller host and export Pathways runtime variables (for example `JAX_PLATFORMS=proxy` and `ENABLE_PATHWAYS_PERSISTENCE=1`) before running training.
- Update `chips_per_vm` to match your slice topology; Pathways will shard trainer and rollout workers automatically.
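The Pathways steps above can be sketched as follows. These helpers are hypothetical, not part of MaxText. `pathways_env` returns the two runtime variables named above. `pathways_overrides` builds `key=value` overrides for `trainer_devices_fraction` and `sampler_devices_fraction` that could be appended to `config_argv`. The 0.5/0.5 split is a placeholder; choose fractions for your own topology, and set `CHIPS_PER_VM` in the configuration cell rather than passing `chips_per_vm` twice.

```python
import os


def pathways_env() -> dict:
    """Pathways runtime variables mentioned above (hypothetical helper)."""
    return {"JAX_PLATFORMS": "proxy", "ENABLE_PATHWAYS_PERSISTENCE": "1"}


def pathways_overrides(trainer_fraction: float, sampler_fraction: float) -> list:
    """Build key=value overrides to append to config_argv (hypothetical helper)."""
    for name, frac in (("trainer", trainer_fraction), ("sampler", sampler_fraction)):
        # Basic sanity check: a fraction of the devices must lie in (0, 1]
        if not 0.0 < frac <= 1.0:
            raise ValueError(f"{name} fraction must be in (0, 1], got {frac}")
    return [
        f"trainer_devices_fraction={trainer_fraction}",
        f"sampler_devices_fraction={sampler_fraction}",
    ]


# Export the Pathways runtime variables before running training
os.environ.update(pathways_env())
extra_argv = pathways_overrides(0.5, 0.5)  # placeholder split
print(extra_argv)
```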


In [None]:
# Configuration for GRPO training
import os
import MaxText

# MaxText package directory (e.g. maxtext/src/MaxText); the config and chat-template paths below are relative to it
MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)
RUN_NAME="grpo_test"
# Hardcoded defaults for Llama3.1-8B
MODEL_NAME = "llama3.1-8b"
HF_REPO_ID = "meta-llama/Llama-3.1-8B-Instruct"
CHAT_TEMPLATE_PATH = f"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json"
LOSS_ALGO="gspo-token"

# Required: Set these before running
MODEL_CHECKPOINT_PATH = ""  # Update this!
OUTPUT_DIRECTORY = "/tmp/grpo_output"  # Update this!
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # Read from the HF_TOKEN environment variable

# Optional: Override training parameters
STEPS = 10  # Reduced for demo purposes
PER_DEVICE_BATCH_SIZE = 1
LEARNING_RATE = 3e-6
NUM_GENERATIONS = 2
GRPO_BETA = 0.08
GRPO_EPSILON = 0.2
CHIPS_PER_VM = 1

print(f"üìÅ MaxText Home: {MAXTEXT_REPO_ROOT}")
print(f"ü§ñ Model: {MODEL_NAME}")
print(f"üì¶ Checkpoint: {MODEL_CHECKPOINT_PATH}")
print(f"üíæ Output: {OUTPUT_DIRECTORY}")
print(f"üîë HF Token: {'‚úÖ Set' if HF_TOKEN else '‚ùå Missing - set HF_TOKEN env var'}")
print(f"üìä Steps: {STEPS}")
print(f"Loss Algorithm : {LOSS_ALGO}")
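`GRPO_EPSILON` and `GRPO_BETA` are easiest to read off the per-token GRPO objective. The sketch below follows the original GRPO formulation (DeepSeekMath) as we understand it. It combines a PPO-style clipped surrogate on the importance ratio with a penalty of `beta` times a KL estimate against the reference model, using the `exp(x) - x - 1` estimator. It is a hypothetical scalar illustration, not MaxText's loss code. This demo sets `LOSS_ALGO="gspo-token"`, which we understand to base the importance ratio on the whole sequence rather than on each token; that variant is not shown.

```python
import math


def grpo_token_objective(logp_new, logp_old, logp_ref, advantage, epsilon=0.2, beta=0.08):
    """Per-token GRPO objective to maximize (the training loss is its negative).

    Hypothetical illustration; inputs are log-probabilities of one sampled token
    under the current, behavior (old), and reference policies.
    """
    # Importance ratio between the current policy and the policy that sampled the token
    ratio = math.exp(logp_new - logp_old)
    # PPO-style clipping keeps the ratio within [1 - epsilon, 1 + epsilon]
    clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    # Pessimistic (minimum) of the unclipped and clipped terms
    surrogate = min(ratio * advantage, clipped * advantage)
    # Non-negative KL estimator exp(x) - x - 1, with x = log pi_ref - log pi_new
    x = logp_ref - logp_new
    kl = math.exp(x) - x - 1.0
    return surrogate - beta * kl


# Ratio 1.5 with a positive advantage is clipped to 1 + epsilon = 1.2
print(grpo_token_objective(math.log(1.5), 0.0, math.log(1.5), advantage=1.0))
```

A larger `epsilon` lets each update move the policy further from the sampling policy. A larger `beta` pulls it harder toward the reference model.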

In [None]:
# Import required modules
import os
import sys
from pathlib import Path

# Set runtime environment variables before JAX initializes the TPU backend
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # Show all C++ (XLA/TPU runtime) log messages
os.environ["SKIP_JAX_PRECOMPILE"] = "1"  # Faster startup for vLLM
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
    os.environ["LIBTPU_INIT_ARGS"] = (
        os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
    )

# Add MaxText to Python path
maxtext_path = Path(MAXTEXT_REPO_ROOT)
sys.path.insert(0, str(maxtext_path))

from MaxText import pyconfig, max_utils
from MaxText.rl.train_rl import rl_train
import jax

# Initialize Pathways and select JAX's PRNG implementation
import pathwaysutils
pathwaysutils.initialize()
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

print("✅ Successfully imported modules")
print(f"📁 MaxText path: {maxtext_path}")

In [None]:
# Build configuration for GRPO training, using rl.yml as the base config
config_file = os.path.join(MAXTEXT_REPO_ROOT, "configs/rl.yml")

# Verify that the base config and the chat template exist
if not os.path.exists(config_file):
    raise FileNotFoundError(f"Config file not found: {config_file}")
if not os.path.exists(CHAT_TEMPLATE_PATH):
    raise FileNotFoundError(f"Chat template not found: {CHAT_TEMPLATE_PATH}")

# Build argv list for pyconfig.initialize(); each override uses key=value syntax
config_argv = [
    "",  # argv[0] placeholder
    config_file,
    f"model_name={MODEL_NAME}",
    f"tokenizer_path={HF_REPO_ID}",
    f"hf_model_name={HF_REPO_ID}",
    f"run_name={RUN_NAME}",
    f"chat_template_path={CHAT_TEMPLATE_PATH}",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}",
    f"base_output_directory={OUTPUT_DIRECTORY}",
    f"hf_access_token={HF_TOKEN}",
    f"steps={STEPS}",
    f"per_device_batch_size={PER_DEVICE_BATCH_SIZE}",
    f"learning_rate={LEARNING_RATE}",
    f"num_generations={NUM_GENERATIONS}",
    f"grpo_beta={GRPO_BETA}",
    f"grpo_epsilon={GRPO_EPSILON}",
    f"chips_per_vm={CHIPS_PER_VM}",
    f"loss_algo={LOSS_ALGO}",
]

# Initialize configuration
print(f"🔧 Initializing configuration from: {config_file}")
config = pyconfig.initialize(config_argv)
max_utils.print_system_information()

print("\n✅ Configuration initialized successfully")
print(f"📊 Training steps: {config.steps}")
print(f"📁 Output directory: {config.base_output_directory}")
print(f"🤖 Model: {config.model_name}")

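One way to keep every override in a single place is to build the argv list from a dictionary. The helper below is hypothetical, not a MaxText API. It only assembles the same `["", config_file, "key=value", ...]` shape that `config_argv` holds, which you could then pass to `pyconfig.initialize`. The path and values in the usage call are placeholders.

```python
def build_config_argv(config_file: str, overrides: dict) -> list:
    """Assemble argv for pyconfig.initialize(): placeholder, config path, then key=value pairs."""
    argv = ["", config_file]  # argv[0] is an unused program-name placeholder
    for key, value in overrides.items():
        if value is None:
            continue  # omit unset keys so the config file's defaults apply
        argv.append(f"{key}={value}")
    return argv


argv = build_config_argv(
    "/path/to/configs/rl.yml",  # placeholder path
    {"model_name": "llama3.1-8b", "steps": 10, "loss_algo": "gspo-token", "hf_access_token": None},
)
print(argv)
```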
In [None]:
# Execute GRPO/GSPO training
import traceback

print("\n" + "=" * 80)
print("🚀 Starting Training...")
print("=" * 80)
try:
    # Call the rl_train function (it handles everything internally)
    rl_train(config)

    print("\n" + "=" * 80)
    print("✅ Training Completed Successfully!")
    print("=" * 80)
    print(f"📁 Checkpoints saved to: {config.checkpoint_dir}")
    print(f"📊 TensorBoard logs: {config.tensorboard_dir}")
    print("🎯 Model ready for inference!")

except Exception as e:
    print("\n" + "=" * 80)
    print("❌ Training Failed!")
    print("=" * 80)
    print(f"Error: {e}")
    traceback.print_exc()
    print("\n💡 Common issues:")
    print("  - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint")
    print("  - Ensure HF_TOKEN environment variable is set")
    print("  - Verify OUTPUT_DIRECTORY is writable")
    print("  - Check hardware requirements (TPU/GPU availability)")

## 📚 Learn More

- **CLI Usage**: Run `python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml model_name=llama3.1-8b ...` (overrides use the same `key=value` syntax as `config_argv`)
- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options
- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation
- **Examples**: See other examples in `src/MaxText/examples/`
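The notebook parameters can also be turned into a shell command for the CLI route. This sketch assumes the CLI accepts the same `key=value` overrides that `config_argv` passes to `pyconfig.initialize`. `shlex.join` quotes any argument that needs it, so values with spaces or shell metacharacters survive intact.

```python
import shlex

# Same values as the configuration cell; extend with any other overrides you need
overrides = {
    "model_name": "llama3.1-8b",
    "tokenizer_path": "meta-llama/Llama-3.1-8B-Instruct",
    "run_name": "grpo_test",
    "steps": 10,
    "loss_algo": "gspo-token",
}
cmd = ["python3", "-m", "src.MaxText.rl.train_rl", "src/MaxText/configs/rl.yml"]
cmd += [f"{key}={value}" for key, value in overrides.items()]
print(shlex.join(cmd))
```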