# GRPO Llama3.1-8B Demo: Direct Function Call

This notebook demonstrates GRPO training by directly calling the `rl_train` function from `MaxText/rl/train_rl.py`.

## What is GRPO?

GRPO (Group Relative Policy Optimization) is a reinforcement learning (RL) algorithm that improves the reasoning abilities of LLMs by:
1. Generating multiple responses for each prompt
2. Scoring each response with a reward model or reward function
3. Computing each response's advantage relative to the other responses in its group, then using those advantages to update the policy
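Step 3 can be made concrete with a small sketch. It assumes the common GRPO formulation in which a response's advantage is its reward minus the group mean, divided by the group standard deviation. The function name `group_relative_advantages` and the `eps` value are illustrative and are not taken from MaxText or Tunix.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one prompt's group of sampled responses."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    # eps keeps the division finite when every response gets the same reward
    return [(r - mean) / (std + eps) for r in rewards]

# Four sampled responses to the same prompt, scored by a reward function.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
print(advantages)
```

Responses that score above the group mean get a positive advantage and are reinforced; those below the mean get a negative one. When all responses in a group score the same, every advantage is zero and that group contributes no learning signal.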


This notebook imports `rl_train` and calls it in-process.

## Hardware Requirements

- Single-host TPU VM (v6e-8 or v5p-8), or multi-host with Pathways
- Sufficient memory for the Llama3.1-8B model
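As a rough guide to what "sufficient memory" means, the back-of-the-envelope estimate below counts weight storage only. It assumes about 8 billion parameters stored in bfloat16 (2 bytes each) and that both a policy model and a reference model are held in memory, as the training log in this notebook indicates. Optimizer state, activations, and the sampler's KV cache are ignored, so real usage will be higher.

```python
PARAMS = 8e9            # assumed parameter count for Llama3.1-8B (approximate)
BYTES_PER_PARAM = 2     # bfloat16
NUM_MODEL_COPIES = 2    # assumption: policy + reference model

def weight_memory_gib(params, bytes_per_param, copies=1):
    """Return weight memory in GiB for `copies` replicas of the model."""
    return params * bytes_per_param * copies / 2**30

single = weight_memory_gib(PARAMS, BYTES_PER_PARAM)
total = weight_memory_gib(PARAMS, BYTES_PER_PARAM, NUM_MODEL_COPIES)
print(f"One copy: {single:.1f} GiB, policy + reference: {total:.1f} GiB")
```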

## Setup

Install dependencies and set up the environment:

In [None]:
# Clone MaxText repository
!git clone https://github.com/AI-Hypercomputer/maxtext.git
%cd maxtext

In [16]:
# Install GRPO-specific dependencies
!./src/MaxText/examples/install_tunix_vllm_requirement.sh

# Install additional requirements
%uv pip install --force-reinstall numpy==2.1.2
%uv pip install nest_asyncio

/bin/bash: line 1: ./src/MaxText/examples/install_tunix_vllm_requirement.sh: No such file or directory
Using Python 3.12.12 environment at: /home/zhehuichen_google_com/.venv
Resolved 1 package in 985ms
Prepared 1 package in 0.29ms
Uninstalled 1 package in 115ms
Installed 1 package in 56ms
 - numpy==2.3.5
 + numpy==2.1.2
Note: you may need to restart the kernel to use updated packages.
Using Python 3.12.12 environment at: /home/zhehuichen_google_com/.venv
Audited 1 package in 5ms
Note: you may need to restart the kernel to use updated packages.


In [17]:
%load_ext autoreload
%autoreload 2
import nest_asyncio
nest_asyncio.apply()  # Allow nested asyncio event loops in notebook environments such as Colab

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Configuration

Set up the training parameters:

In [18]:
# Configuration for GRPO training
import os

# Set up paths
MAXTEXT_REPO_ROOT = os.path.join(os.path.expanduser("~"), "maxtext")
print(f"MaxText Home directory: {MAXTEXT_REPO_ROOT}")

# Training configuration
MODEL_CHECKPOINT_PATH = "gs://zhehui_tpu/llama3.1-8b-Instruct/llama3.1-8b-Instruct/scanned-pathways/0/items"
OUTPUT_DIRECTORY = "/tmp/grpo_output"
STEPS = 10  # Reduced for demo purposes
# Make sure your Hugging Face token has permission to access the model and tokenizer.
HF_TOKEN = os.environ.get("HF_TOKEN", "YOUR_HF_TOKEN")

print(f"Model checkpoint: {MODEL_CHECKPOINT_PATH}")
print(f"Output directory: {OUTPUT_DIRECTORY}")
print(f"Training steps: {STEPS}")

MaxText Home directory: /home/zhehuichen_google_com/maxtext
Model checkpoint: gs://zhehui_tpu/llama3.1-8b-Instruct/llama3.1-8b-Instruct/scanned-pathways/0/items
Output directory: /tmp/grpo_output
Training steps: 10
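Because `HF_TOKEN` falls back to the placeholder string when the environment variable is unset, a quick sanity check before building the config can save a failed run. The helper below is a hypothetical sketch, not part of MaxText. It only checks that a value was supplied and makes no claim about the token's permissions.

```python
import os

PLACEHOLDER = "YOUR_HF_TOKEN"

def resolve_hf_token(env=os.environ):
    """Return the Hugging Face token from `env`, or raise if it is missing."""
    token = env.get("HF_TOKEN", PLACEHOLDER)
    if not token or token == PLACEHOLDER:
        raise ValueError("Set the HF_TOKEN environment variable before training.")
    return token

# Example with an explicit mapping instead of the real environment.
print(resolve_hf_token({"HF_TOKEN": "hf_example"}))
```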


In [19]:
# Import GRPO training function directly
import sys
import os
from pathlib import Path

# Add MaxText to Python path
maxtext_path = Path(MAXTEXT_REPO_ROOT) / "src" / "MaxText"
sys.path.insert(0, str(maxtext_path))

# Import required modules
from MaxText import pyconfig
from MaxText.rl.train_rl import rl_train, setup_configs_and_devices

print("✅ Successfully imported GRPO training function")
print(f"📁 MaxText path: {maxtext_path}")

✅ Successfully imported GRPO training function
📁 MaxText path: /home/zhehuichen_google_com/maxtext/src/MaxText


In [20]:
# Build configuration for GRPO training
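config_argv = [
    "",  # Placeholder for argv[0]
    os.path.join(MAXTEXT_REPO_ROOT, "src/MaxText/configs/rl.yml"),  # Base config
    "model_name=llama3.1-8b",
    "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}",
    # Pass the values defined in the configuration cell; otherwise the defaults in rl.yml apply
    f"base_output_directory={OUTPUT_DIRECTORY}",
    f"steps={STEPS}",
    f"hf_access_token={HF_TOKEN}",
    "run_name=test",
]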

# Create configuration object
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)

print("✅ Configuration created successfully")
print(f"📊 Training steps: {trainer_config.steps}")
print(f"📁 Output directory: {trainer_config.base_output_directory}")
print(f"🤖 Model: {trainer_config.model_name}")



Skipping jax distributed system due to skip_jax_distributed_system=True flag.


TypeError: data type <DType.BFLOAT16: 'bfloat16'> not understood
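The list passed to `setup_configs_and_devices` follows a command-line convention: a placeholder for the program name, the path to the base YAML file, then `key=value` overrides. A small helper can build that list from a dictionary. `build_config_argv` is a hypothetical name used for illustration, and the keys shown are ones already used in this notebook.

```python
def build_config_argv(base_config, overrides):
    """Build an argv-style list: ["", base_config, "key=value", ...]."""
    argv = ["", base_config]
    argv.extend(f"{key}={value}" for key, value in overrides.items())
    return argv

argv = build_config_argv(
    "src/MaxText/configs/rl.yml",
    {"model_name": "llama3.1-8b", "run_name": "test"},
)
print(argv)
```

Keeping the overrides in a dictionary makes it easier to see at a glance which values differ from the base config.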

In [15]:
# Execute GRPO training directly
try:
    # Call the rl_train function
    print("\n" + "="*80)
    print("Starting GRPO Training...")
    print("="*80)
    grpo_trainer, rl_cluster = rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
    
    print("\n" + "="*80)
    print("✅ GRPO Training Completed Successfully!")
    print("="*80)
    print(f"📁 Checkpoints and logs saved to: {trainer_config.base_output_directory}")
    print("🎯 Final model ready for inference!")
    
except Exception as e:
    print("\n" + "="*80)
    print("❌ GRPO Training Failed!")
    print("="*80)
    print(f"Error: {e}")
    print("\nPlease check the error message and try again.")


Starting GRPO Training...
Starting GRPO Training
Ensuring TensorBoard log directory exists: /tmp/grpo_output/test/tensorboard/


DEBUG: resolving path. Input: src/MaxText/examples/chat_templates/gsm8k_rl_llama3.json, Package Root: /home/zhehuichen_google_com/maxtext/src/MaxText, Resolved: /home/zhehuichen_google_com/maxtext/src/MaxText/examples/chat_templates/gsm8k_rl_llama3.json, Exists: True
{'answer': array(['13'], dtype='<U2'),
 'prompts': array(['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nRespond in the following format:<reasoning>...</reasoning><answer>...</answer><|eot_id|><|start_header_id|>user<|end_header_id|>\n\nJane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a cl



Num_devices: 4, shape (1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1)
Reference Model initialized successfully


Reference mesh shape: OrderedDict({'data': 1, 'stage': 1, 'fsdp': 4, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1})
maxtext_state_flatten[base.token_embedder.embedding].value=          [[0.00106049 0.00561523 -0.00341797 ... 0.00411987 -0.00280762
  -0.000671387]
 [-0.0037384 0.000972748 -0.0018158 ... 0.00152588 -0.0022583 -0.0013504]
 [0.00144196 -0.0169678 0.00315857 ... 0.00299072 0.00952148 0.00488281]
 ...
 [2.21271e-23 3.90326e-24 2.16101e-23 ... 6.36929e-23 -2.64956e-24
  -2.35746e-23]
 [2.28509e-23 -2.21012e-24 -2.22305e-23 ... 2.79173e-23 8.6854e-24
  -3.70163e-23]
 [-8.85083e-23 -7.5687e-23 6.4882e-24 ... 5.89366e-24 -6.45201e-23
  -2.71419e-24]]
Creating policy model with same config as reference on trainer mesh
Num_devices: 4, shape (1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1)
Num_devices: 4, shape (1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1)




Policy Model initialized successfully


Policy mesh shape: OrderedDict({'data': 1, 'stage': 1, 'fsdp': 4, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1})
Tensorboard logs directory: /tmp/grpo_output/test/tensorboard/
Creating RL cluster...
INFO 11-24 19:42:07 [utils.py:243] non-default args: {'load_format': 'dummy', 'max_model_len': 1280, 'tensor_parallel_size': 4, 'swap_space': 2, 'gpu_memory_utilization': 0.4, 'disable_log_stats': True, 'additional_config': {'sharding': {'sharding_strategy': {'device_indexes': [0, 1, 3, 2]}}}, 'model': 'meta-llama/Llama-3.1-8B-Instruct'}
INFO 11-24 19:42:08 [model.py:646] Resolved architecture: LlamaForCausalLM
INFO 11-24 19:42:08 [model.py:1734] Using max model len 1280
INFO 11-24 19:42:08 [scheduler.py:225] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 11-24 19:42:08 [tpu_jax.py:186] Force using UniProcExecutor for JAX on single host.
INFO 11-24 
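The sample printed in the log asks the model to respond in a `<reasoning>...</reasoning><answer>...</answer>` format and pairs each prompt with a reference answer. A rule-based reward for such data might look like the sketch below. The function name and the reward values (0.5 for a valid format, plus 1.0 for a matching answer) are illustrative assumptions, not the reward functions that `rl_train` actually uses.

```python
import re

# Matches the response format requested in the system prompt.
PATTERN = re.compile(
    r"<reasoning>.*?</reasoning>\s*<answer>(.*?)</answer>", re.DOTALL
)

def format_and_answer_reward(response, reference_answer):
    """Score one response: 0.5 for valid format, +1.0 for the right answer."""
    match = PATTERN.search(response)
    if match is None:
        return 0.0
    reward = 0.5
    if match.group(1).strip() == reference_answer.strip():
        reward += 1.0
    return reward

good = "<reasoning>worked steps</reasoning><answer>13</answer>"
print(format_and_answer_reward(good, "13"))
print(format_and_answer_reward("The answer is 13.", "13"))
```

Rewards like these are computed for every response in a group and then turned into relative advantages for the policy update.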

### 📚 **Learn More**
- See `src/MaxText/examples/grpo_runner.py` for CLI usage
- Check `src/MaxText/configs/rl.yml`, the base config used above, for configuration options
- Read `src/MaxText/examples/README.md` for more examples