# Llama3.1-8B-Instruct Reinforcement Learning Demo

This notebook demonstrates reinforcement learning (RL) training of the Llama3.1-8B-Instruct model with either GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization).

## What is GRPO/GSPO?

GRPO and GSPO are RL algorithms that enhance the reasoning abilities of LLMs by:
1. Generating multiple responses for each prompt
2. Evaluating responses using reward models  
3. Calculating relative advantages to update the policy

The two algorithms differ in the loss function: GRPO optimizes at the token level, while GSPO optimizes at the level of the whole sequence.
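
As a rough illustration of the three steps and of the token-level versus sequence-level distinction, here is a minimal pure-Python sketch. It is not MaxText's implementation: the function names, the toy rewards, and the log-probabilities are all hypothetical, and clipping and KL terms are omitted. It assumes the commonly described formulations, in which advantages are rewards standardized within each group, GRPO uses one importance ratio per token, and GSPO uses a single length-normalized ratio per sequence.

```python
import math
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Standardize rewards within one group of responses to the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def token_level_ratios(new_logps, old_logps):
    """One importance ratio per token (GRPO-style)."""
    return [math.exp(n - o) for n, o in zip(new_logps, old_logps)]


def sequence_level_ratio(new_logps, old_logps):
    """One length-normalized importance ratio per sequence (GSPO-style)."""
    diffs = [n - o for n, o in zip(new_logps, old_logps)]
    return math.exp(sum(diffs) / len(diffs))


# Hypothetical rewards for four sampled responses to a single prompt.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)

# Hypothetical per-token log-probabilities for one response.
old_logps = [-1.2, -0.7, -2.0]
new_logps = [-1.0, -0.9, -1.5]

print("advantages:", [round(a, 3) for a in advantages])
print("token ratios:", [round(r, 3) for r in token_level_ratios(new_logps, old_logps)])
print("sequence ratio:", round(sequence_level_ratio(new_logps, old_logps), 3))
```

In this sketch the sequence-level ratio is the geometric mean of the token-level ratios, so a single outlier token shifts it less than it shifts that token's own ratio.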

## Hardware Requirements

- Single-host TPU VM (v6e-8 or v5p-8)
- Sufficient memory for the Llama3.1-8B model

### Get Your Hugging Face Token

To access the model checkpoint on the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need to paste it in the next step.

**Store your token:**

Paste your token into the cell below.

In [None]:
HF_TOKEN = ""  # Paste your token here; it is exported as the HF_TOKEN environment variable in a later cell
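
If you would rather not paste a secret into the notebook, one alternative is to read it from the environment. The sketch below is only a suggestion: the helper name `resolve_hf_token` is hypothetical, and it assumes you exported `HF_TOKEN` in the shell before starting Jupyter.

```python
import os


def resolve_hf_token(inline_token=""):
    """Return the inline token if given, else the HF_TOKEN environment variable."""
    return inline_token.strip() or os.environ.get("HF_TOKEN", "").strip()


# Prefer a value typed into the notebook, and fall back to the environment.
HF_TOKEN = resolve_hf_token("")
print("HF token found" if HF_TOKEN else "HF token missing")
```

Either way, avoid committing a notebook that contains a real token.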


## Setup

Install the dependencies and set up the environment by following the instructions in the MaxText RL tutorial:
https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl.html#from-github

## Configuration

Set up the training parameters. This notebook uses a single-host TPU, and the defaults are hardcoded for Llama3.1-8B:

In [None]:
# If you cloned the MaxText repo, change into its src folder (adjust the path if needed).
# Otherwise, skip this cell.
# `%cd` is used because `!cd` runs in a subshell and does not change the notebook's working directory.
%cd ~/maxtext/src/

In [None]:
# Choose the loss algorithm: GRPO or GSPO
LOSS_ALGO = "grpo"  # or "gspo-token" to use GSPO

In [None]:
import os
import sys
from pathlib import Path
import MaxText
from huggingface_hub import login

# Set up paths (adjust if needed)
MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)
RUN_NAME = "grpo_test"
# Hardcoded defaults for Llama3.1-8B
MODEL_NAME = "llama3.1-8b"
HF_REPO_ID = "meta-llama/Llama-3.1-8B-Instruct"
CHAT_TEMPLATE_PATH = f"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json"

# Required: Set these before running
MODEL_CHECKPOINT_PATH = ""  # Update this!
if not MODEL_CHECKPOINT_PATH:
    raise RuntimeError("MODEL_CHECKPOINT_PATH is not set")
    
OUTPUT_DIRECTORY = ""  # Update this!
if not OUTPUT_DIRECTORY:
    raise RuntimeError("OUTPUT_DIRECTORY is not set")
    
os.environ["HF_TOKEN"] = HF_TOKEN
if "MAXTEXT_PKG_DIR" not in os.environ:
    os.environ["MAXTEXT_PKG_DIR"] = MAXTEXT_REPO_ROOT

if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Authenticated with Hugging Face")
else:
    print("Authentication failed: Hugging Face token not set")


print(f"📁 MaxText Home: {MAXTEXT_REPO_ROOT}")
print(f"🤖 Model: {MODEL_NAME}")
print(f"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}")
print(f"💾 Output: {OUTPUT_DIRECTORY}")
print(f"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}")
print(f"Loss Algorithm : {LOSS_ALGO}")

In [None]:
# Add MaxText to Python path
maxtext_path = Path(MAXTEXT_REPO_ROOT) 
sys.path.insert(0, str(maxtext_path))

from MaxText import pyconfig, max_utils
from MaxText.rl.train_rl import rl_train, setup_configs_and_devices

# Configure logging and startup behavior through environment variables
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["SKIP_JAX_PRECOMPILE"] = "1"  # Faster startup for vLLM

print("✅ Successfully imported modules")
print(f"📁 MaxText path: {maxtext_path}")

In [None]:
# Build configuration for GRPO/GSPO training
config_file = os.path.join(MAXTEXT_REPO_ROOT, "configs", "rl.yml")

# Verify chat template exists
if not os.path.exists(CHAT_TEMPLATE_PATH):
    raise FileNotFoundError(f"Chat template not found: {CHAT_TEMPLATE_PATH}")

# Build argv list for pyconfig.initialize()
config_argv = [
    "",  # argv[0] placeholder
    config_file,
    f"model_name={MODEL_NAME}",
    f"tokenizer_path={HF_REPO_ID}",
    f"run_name={RUN_NAME}",
    f"chat_template_path={CHAT_TEMPLATE_PATH}",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}",
    f"base_output_directory={OUTPUT_DIRECTORY}",
    f"hf_access_token={HF_TOKEN}",
    f"debug.rl=False",
    f"rl.loss_algo={LOSS_ALGO}",
    "use_pathways=False"
]

# Initialize configuration
print(f"🔧 Initializing configuration from: {config_file}")
trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)

rl_train_steps = int(
    trainer_config.num_batches
    * trainer_config.rl.num_iterations
    * trainer_config.train_fraction
    * trainer_config.num_epoch
)

print("\n✅ Configuration initialized successfully")
print(f"📁 Output directory: {trainer_config.base_output_directory}")
print(f"🤖 Model: {trainer_config.model_name}")
print(f"📊 RL Train Steps: {rl_train_steps}")
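
To make the step count concrete, here is the same product evaluated on made-up numbers. The values below are hypothetical and are not the defaults from `rl.yml`, and `compute_rl_train_steps` is an illustrative helper rather than a MaxText function. The point is only that `int()` truncates the product toward zero.

```python
def compute_rl_train_steps(num_batches, num_iterations, train_fraction, num_epoch):
    """Mirror the notebook's formula: the product of the four settings, truncated to an int."""
    return int(num_batches * num_iterations * train_fraction * num_epoch)


# Hypothetical settings, chosen only to illustrate the arithmetic.
steps = compute_rl_train_steps(num_batches=10, num_iterations=1, train_fraction=0.75, num_epoch=1)
print(steps)  # 10 * 1 * 0.75 * 1 = 7.5, truncated to 7
```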

In [None]:
# Execute GRPO/GSPO training
print("\n" + "="*80)
print("🚀 Starting Training...")
print("="*80)
try:
    # Call the rl_train function (it handles everything internally)
    rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
    
    print("\n" + "="*80)
    print("✅ Training Completed Successfully!")
    print(f"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!")
    print("="*80)
    print(f"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}")
    print(f"📊 TensorBoard logs: {trainer_config.tensorboard_dir}")
    print("🎯 Model ready for inference!")
    
except Exception as e:
    print("\n" + "="*80)
    print("❌ Training Failed!")
    print("="*80)
    print(f"Error: {str(e)}")
    import traceback
    traceback.print_exc()
    print("\n💡 Common issues:")
    print("  - Check that MODEL_CHECKPOINT_PATH points to a valid checkpoint")
    print("  - Ensure HF_TOKEN environment variable is set")
    print("  - Verify OUTPUT_DIRECTORY is writable")
    print("  - Check hardware requirements (TPU/GPU availability)")
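
The common issues listed in the error handler can also be checked before training starts. The following is a hypothetical helper, not part of MaxText. It assumes local filesystem paths, so paths with a URL-style scheme such as `gs://` are skipped rather than validated.

```python
import os


def preflight_problems(checkpoint_path, output_directory, hf_token):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    if not hf_token:
        problems.append("HF_TOKEN is empty")
    if not checkpoint_path:
        problems.append("MODEL_CHECKPOINT_PATH is empty")
    elif "://" not in checkpoint_path and not os.path.exists(checkpoint_path):
        problems.append(f"checkpoint not found: {checkpoint_path}")
    if not output_directory:
        problems.append("OUTPUT_DIRECTORY is empty")
    elif "://" not in output_directory:
        # Only local paths are checked; remote paths (assumed to contain "://") are skipped.
        if os.path.isdir(output_directory) and not os.access(output_directory, os.W_OK):
            problems.append(f"output directory is not writable: {output_directory}")
    return problems


# Example with deliberately missing values.
for problem in preflight_problems("", "", ""):
    print("-", problem)
```

Calling it with `MODEL_CHECKPOINT_PATH`, `OUTPUT_DIRECTORY`, and `HF_TOKEN` just before the training cell would surface these problems without waiting for a failed run.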

## 📚 Learn More

- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html#run-grpo
- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options
- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation