# GRPO Llama3.1-8B Demo: Direct Function Call

This notebook demonstrates GRPO training by calling the `rl_train` function directly from Python. The function is imported from the `MaxText.rl.train_rl` module.

## What is GRPO?

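GRPO (Group Relative Policy Optimization) is a reinforcement learning (RL) algorithm for improving the reasoning abilities of LLMs. At each training step it:
1. Generates multiple responses for each prompt
2. Scores each response with a reward model
3. Computes each response's advantage relative to the other responses for the same prompt, then uses those advantages to update the policy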
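
The third step can be sketched in a few lines of Python. This is a minimal illustration of one common way to compute group-relative advantages: subtract the group's mean reward and divide by the group's standard deviation. The function name, the `eps` constant, and the sample rewards are assumptions made for illustration, and the trainer called in this notebook may normalize differently.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against the mean and std of its own group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    # eps (an assumed constant) avoids division by zero when all rewards match
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical rewards for four responses sampled from the same prompt
rewards = [1.0, 0.0, 0.5, 0.5]
advantages = group_relative_advantages(rewards)
print(advantages)
```

Responses that score above the group average get a positive advantage and are reinforced; those below the average get a negative one.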


## Hardware Requirements

- Single-host TPU VM (v6e-8 or v5p-8), or multi-host with Pathways
- Sufficient memory for the Llama3.1-8B model

## Setup

Install dependencies and set up the environment:

In [None]:
# Clone MaxText repository
!git clone https://github.com/AI-Hypercomputer/maxtext.git
%cd maxtext

In [None]:
# Install GRPO-specific dependencies
!./src/MaxText/examples/install_tunix_vllm_requirement.sh

# Install additional requirements
%uv pip install --force-reinstall numpy==2.1.2
%uv pip install nest_asyncio

In [None]:
%load_ext autoreload
%autoreload 2
import nest_asyncio
nest_asyncio.apply()  # Fix for Colab event loop

## Configuration

Set up the training parameters:

In [None]:
# Configuration for GRPO training
import os

# Set up paths (assumes the repository was cloned into your home directory)
MAXTEXT_REPO_ROOT = os.path.join(os.path.expanduser("~"), "maxtext")
print(f"MaxText repository root: {MAXTEXT_REPO_ROOT}")

# Training configuration
MODEL_CHECKPOINT_PATH = "gs://zhehui_tpu/llama3.1-8b-Instruct/llama3.1-8b-Instruct/scanned-pathways/0/items"
OUTPUT_DIRECTORY = "/tmp/grpo_output"
STEPS = 10  # Reduced for demo purposes
# Make sure your Hugging Face token has access to the gated Llama 3.1 model repository
HF_TOKEN = os.environ.get("HF_TOKEN", "YOUR_HF_TOKEN")

print(f"Model checkpoint: {MODEL_CHECKPOINT_PATH}")
print(f"Output directory: {OUTPUT_DIRECTORY}")
print(f"Training steps: {STEPS}")
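
Because the token falls back to a placeholder string when `HF_TOKEN` is not set, a quick check before training can save a failed run. The sketch below is only an illustration: `token_is_set` and `PLACEHOLDER` are hypothetical names, and the check only detects a missing token. It does not verify that the token has permission to download the model.

```python
import os

PLACEHOLDER = "YOUR_HF_TOKEN"  # same fallback string used in the cell above


def token_is_set(token):
    """Return True when the token is non-empty and not the placeholder."""
    return bool(token) and token != PLACEHOLDER


token = os.environ.get("HF_TOKEN", PLACEHOLDER)
if not token_is_set(token):
    print("HF_TOKEN is not set; downloads from gated Hugging Face repositories will fail.")
```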

In [None]:
# Import GRPO training function directly
import sys
import os
from pathlib import Path

# Add the directory that contains the MaxText package to the Python path
maxtext_path = Path(MAXTEXT_REPO_ROOT) / "src"
sys.path.insert(0, str(maxtext_path))

# Import required modules
from MaxText import pyconfig
from MaxText.rl.train_rl import rl_train, setup_configs_and_devices

print("✅ Successfully imported GRPO training function")
print(f"📁 MaxText path: {maxtext_path}")

In [None]:
# Build configuration for GRPO training
config_argv = [
    "",  # Placeholder for argv[0]
    os.path.join(MAXTEXT_REPO_ROOT, "src/MaxText/configs/rl.yml"),  # Base config
    "model_name=llama3.1-8b",
    "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}",
    f"hf_access_token={HF_TOKEN}",
    # Pass the values defined earlier so the run matches the printed settings
    f"base_output_directory={OUTPUT_DIRECTORY}",
    f"steps={STEPS}",
    "run_name=test",
]

# Create configuration object
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)

print("✅ Configuration created successfully")
print(f"📊 Training steps: {trainer_config.steps}")
print(f"📁 Output directory: {trainer_config.base_output_directory}")
print(f"🤖 Model: {trainer_config.model_name}")
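
The argument list above follows a simple pattern: an empty placeholder for `argv[0]`, the base config path, then `key=value` overrides. When the number of overrides grows, it can be easier to keep them in a dictionary. The helper below is a hypothetical convenience, not part of MaxText, and the relative config path in the example is only illustrative.

```python
def build_config_argv(base_config, overrides):
    """Return an argv-style list: placeholder, base config path, key=value pairs."""
    return ["", base_config] + [f"{key}={value}" for key, value in overrides.items()]


# Example overrides; the keys mirror the ones used in the cell above
example_argv = build_config_argv(
    "src/MaxText/configs/rl.yml",
    {"model_name": "llama3.1-8b", "run_name": "test", "steps": 10},
)
print(example_argv)
```

The resulting list has the same shape as `config_argv` and could be passed to `setup_configs_and_devices` in the same way.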

In [None]:
# Execute GRPO training directly
try:
    # Call the rl_train function
    print("\n" + "="*80)
    print("Starting GRPO Training...")
    print("="*80)
    grpo_trainer, rl_cluster = rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
    
    print("\n" + "="*80)
    print("✅ GRPO Training Completed Successfully!")
    print("="*80)
    print(f"📁 Checkpoints and logs saved to: {trainer_config.base_output_directory}")
    print("🎯 Final model ready for inference!")
    
except Exception as e:
    print("\n" + "="*80)
    print("❌ GRPO Training Failed!")
    print("="*80)
    print(f"Error: {e}")
    print("\nPlease check the error message and try again.")
    raise  # Re-raise so the notebook shows the full traceback

### 📚 **Learn More**
- See `src/MaxText/examples/grpo_runner.py` for CLI usage
- Check `src/MaxText/configs/rl.yml` (the base config used above) for configuration options
- Read `src/MaxText/examples/README.md` for more examples