# GRPO Llama3.1-8B Demo: Direct Function Call

This notebook demonstrates GRPO training by directly calling the `grpo_train` function from `grpo_tunix_trainer.py`.

## What is GRPO?

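GRPO (Group Relative Policy Optimization) is a reinforcement learning (RL) algorithm that improves the reasoning abilities of LLMs by:
1. Generating multiple responses (a group) for each prompt
2. Evaluating each response with reward models
3. Computing each response's advantage relative to the other responses in its group and using it to update the policy

Because the baseline comes from the group itself, GRPO does not need a separate value (critic) model.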
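Step 3 can be sketched in a few lines. This is a minimal illustration assuming the common GRPO formulation, where each response's reward is normalized by the mean and standard deviation of the rewards in its group. The function name and the `eps` value are illustrative and are not taken from `grpo_tunix_trainer.py`.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Normalize one prompt's rewards against its own group (illustrative sketch).

    rewards: scalar rewards for the responses sampled for a single prompt.
    Responses above the group mean get positive advantages, those below get
    negative ones, so no learned value function is needed as a baseline.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation over the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# Example: 4 responses to the same prompt, two judged correct (reward 1.0)
advantages = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advantages])  # [1.0, -1.0, 1.0, -1.0]
```

If every response in a group receives the same reward, all advantages are zero and that prompt contributes no reward signal to the update. Some implementations use the sample standard deviation instead of the population one; the choice only rescales the advantages.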


Because `grpo_train` is called in-process, the trainer and RL cluster objects it returns remain available in the notebook once training finishes.

## Hardware Requirements

- A single-host TPU VM (v6e-8 or v5p-8), or a multi-host setup with Pathways
- Enough accelerator memory (HBM) for the Llama3.1-8B model
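For a rough sense of scale behind the memory requirement, the arithmetic below estimates the size of the model state. It assumes about 8 billion parameters, bf16 (2-byte) weights for sampling, and fp32 (4-byte) weights, gradients, and two Adam moments for training. The precision and optimizer actually used come from `grpo.yml` and may differ, and activations, the KV cache, and the frozen reference model are not counted.

```python
# Back-of-the-envelope memory estimate for an ~8B-parameter model.
# All numbers are assumptions for illustration, not measurements.
NUM_PARAMS = 8.0e9  # Llama3.1-8B, rounded
GB = 1e9            # decimal gigabytes


def estimate_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory in GB for num_params values stored at bytes_per_param bytes each."""
    return num_params * bytes_per_param / GB


estimates = {
    # One bf16 copy of the weights, e.g. for the sampler
    "bf16 weights": estimate_gb(NUM_PARAMS, 2),
    # fp32 weights + fp32 gradients + two fp32 Adam moments = 4 values x 4 bytes
    "fp32 training state (weights, grads, Adam m and v)": estimate_gb(NUM_PARAMS, 4 * 4),
}

for name, gb in estimates.items():
    print(f"{name}: ~{gb:.0f} GB")
# bf16 weights: ~16 GB
# fp32 training state (weights, grads, Adam m and v): ~128 GB
```

The training-state total is larger than the HBM of a single TPU chip, so the model state has to be sharded across several chips, and the trainer and sampler each receive only a fraction of the devices (`trainer_devices_fraction` and `sampler_devices_fraction` in the configuration).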

## Setup

Install dependencies and set up the environment:

In [None]:
# Clone MaxText into the home directory, where the Configuration cell expects it
!git clone https://github.com/AI-Hypercomputer/maxtext.git ~/maxtext
%cd ~/maxtext

In [None]:
# Install dependencies
!chmod +x setup.sh
!./setup.sh

# Install GRPO-specific dependencies
!./src/MaxText/examples/install_tunix_vllm_requirement.sh

# Install additional requirements
%pip install --force-reinstall numpy==2.1.2
%pip install nest_asyncio

import nest_asyncio
nest_asyncio.apply()  # Allow nested event loops inside the notebook's already-running loop (Jupyter/Colab)
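`nest_asyncio.apply()` patches asyncio so that an event loop that is already running can be re-entered. Jupyter kernels, including Colab, already run an event loop, so code that calls `asyncio.run()` from a notebook cell would otherwise fail. The standard-library-only sketch below reproduces that failure by calling `asyncio.run()` from inside a running loop. It does not import `nest_asyncio`, and it only illustrates the error; it does not show which part of the GRPO stack triggers it.

```python
import asyncio


async def load_something() -> int:
    """Stand-in for library code that wants to run its own event loop."""
    await asyncio.sleep(0)
    return 42


async def notebook_cell() -> str:
    """Simulates a notebook cell: an event loop is already running here."""
    coro = load_something()
    try:
        return str(asyncio.run(coro))  # nested asyncio.run() is rejected
    except RuntimeError as err:
        coro.close()  # the coroutine never started; close it to avoid a warning
        return f"RuntimeError: {err}"


outcome = asyncio.run(notebook_cell())
print(outcome)
```

With `nest_asyncio.apply()` in effect, the nested call is allowed instead of raising, which is why the setup cell applies the patch before any training code runs.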

## Configuration

Set up the training parameters:

In [None]:
# Configuration for GRPO training
import os

# Set up paths (assumes MaxText was cloned into the home directory)
MAXTEXT_REPO_ROOT = os.path.join(os.path.expanduser("~"), "maxtext")
print(f"MaxText Home directory: {MAXTEXT_REPO_ROOT}")

# Training configuration
MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items"
OUTPUT_DIRECTORY = "/tmp/grpo_output"
STEPS = 10  # Reduced for demo purposes
HF_TOKEN = os.environ.get("HF_TOKEN", "your_hf_token_here")  # Hugging Face token for the gated Llama 3.1 tokenizer

print(f"Model checkpoint: {MODEL_CHECKPOINT_PATH}")
print(f"Output directory: {OUTPUT_DIRECTORY}")
print(f"Training steps: {STEPS}")

In [None]:
# Import GRPO training function directly
import sys
from pathlib import Path

# Add MaxText's `src` directory to the Python path so that `from MaxText import ...`
# resolves; the package directory itself is kept on the path as well
src_path = Path(MAXTEXT_REPO_ROOT) / "src"
maxtext_path = src_path / "MaxText"
sys.path.insert(0, str(maxtext_path))
sys.path.insert(0, str(src_path))

# Import required modules
from MaxText import pyconfig
from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train

print("✅ Successfully imported GRPO training function")
print(f"📁 MaxText path: {maxtext_path}")

In [None]:
# Build configuration for GRPO training
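config_argv = [
    "",  # Placeholder for argv[0] (the program name)
    os.path.join(MAXTEXT_REPO_ROOT, "src/MaxText/configs/grpo.yml"),  # Base config
    "model_name=llama3.1-8b",
    "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}",
    f"base_output_directory={OUTPUT_DIRECTORY}",
    f"hf_access_token={HF_TOKEN}",
    f"steps={STEPS}",
    "per_device_batch_size=1",
    "learning_rate=3e-6",
    "num_generations=2",  # Responses sampled per prompt (GRPO group size)
    "grpo_beta=0.08",  # Weight of the KL penalty
    "grpo_epsilon=0.2",  # Clipping range for the policy probability ratio
    "trainer_devices_fraction=0.5",  # Share of devices used for training
    "sampler_devices_fraction=0.5",  # Share of devices used for sampling (generation)
    "chips_per_vm=4",  # TPU chips per VM; adjust to match your hardware
]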

# Create the configuration object from the base YAML plus key=value overrides
config = pyconfig.initialize(config_argv)

print("✅ Configuration created successfully")
print(f"📊 Training steps: {config.steps}")
print(f"📁 Output directory: {config.base_output_directory}")
print(f"🤖 Model: {config.model_name}")

In [None]:
# Execute GRPO training directly
try:
    # grpo_train returns the trainer and the RL cluster objects
    grpo_trainer, rl_cluster = grpo_train(config)

    print("\n" + "=" * 80)
    print("✅ GRPO Training Completed Successfully!")
    print("=" * 80)
    print(f"📁 Checkpoints saved to: {config.base_output_directory}/checkpoints")
    print(f"📊 Logs available in: {config.base_output_directory}/logs")
    print("🎯 Final model ready for inference!")

except Exception as e:
    print("\n" + "=" * 80)
    print("❌ GRPO Training Failed!")
    print("=" * 80)
    print(f"Error: {e}")
    # Re-raise so the notebook shows the full traceback and marks the cell as failed
    raise
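After a run, it is worth checking what was actually written under the output directory, since the exact layout (for example, whether checkpoints sit under a run-name subdirectory) can depend on the MaxText version and config. `list_outputs` is a hypothetical helper for a local output directory such as the default `/tmp/grpo_output`; a `gs://` location would need `gsutil ls` or a GCS client instead.

```python
from pathlib import Path


def list_outputs(root: str, max_depth: int = 2) -> list[str]:
    """Return paths under `root` (relative, at most `max_depth` levels deep), sorted.

    Hypothetical helper for a local directory; returns [] if `root` does not exist.
    Directories are shown with a trailing slash.
    """
    base = Path(root)
    if not base.is_dir():
        return []
    found = []
    for path in base.rglob("*"):
        rel = path.relative_to(base)
        if len(rel.parts) <= max_depth:
            found.append(rel.as_posix() + ("/" if path.is_dir() else ""))
    return sorted(found)


# In the notebook, pass config.base_output_directory (here: the demo default)
for entry in list_outputs("/tmp/grpo_output"):
    print(entry)
```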

### 📚 **Learn More**
- See `src/MaxText/examples/grpo_runner.py` for CLI usage
- Check `src/MaxText/configs/grpo.yml` for configuration options
- Read `src/MaxText/examples/README.md` for more examples