# Llama3.1-8B-Instruct Reinforcement Learning Demo

This notebook demonstrates reinforcement learning (RL) training of the Llama3.1-8B-Instruct model with either GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization).

This notebook can run on **TPU v6e-8** or **v5p-8**.

## What is GRPO/GSPO?

GRPO and GSPO are RL algorithms that enhance the reasoning abilities of LLMs by:
1. Generating multiple responses for each prompt
2. Scoring each response with reward models
3. Calculating relative advantages within each group of responses to update the policy
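
Step 3 can be sketched in a few lines of plain Python. This is a minimal illustration, not MaxText's implementation: it assumes one group of scalar rewards for a single prompt, uses the population standard deviation, and adds a small `eps` to avoid division by zero. These are assumptions; implementations differ in these details. Responses that score above the group mean get a positive advantage and are reinforced, and those below it are discouraged.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize the rewards of one prompt's sampled responses against each other."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption, implementations vary
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical rewards for 4 responses to one GSM8K-style prompt: 1.0 = correct answer.
advantages = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advantages])  # -> [1.0, -1.0, 1.0, -1.0]
```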

The two algorithms differ in the loss function: GRPO optimizes at the level of individual tokens, while GSPO optimizes at the level of the whole sequence.
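
The token-level vs. sequence-level distinction can be illustrated with the importance ratio between the new and old policies. The sketch below is simplified: it covers a single response, omits any KL term and group averaging, and uses made-up log-probabilities and a clipping range of `eps=0.2`. These are illustrative assumptions, not MaxText's settings. GRPO clips one ratio per token and averages the results. GSPO uses a single length-normalized ratio for the whole sequence, which is the geometric mean of the per-token ratios.

```python
import math


def token_level_ratios(logp_new, logp_old):
    """GRPO-style: one importance ratio per generated token."""
    return [math.exp(n - o) for n, o in zip(logp_new, logp_old)]


def sequence_level_ratio(logp_new, logp_old):
    """GSPO-style: one length-normalized ratio for the whole response."""
    total = sum(n - o for n, o in zip(logp_new, logp_old))
    return math.exp(total / len(logp_new))


def clipped_term(ratio, advantage, eps=0.2):
    """PPO-style clipped surrogate for a single ratio (eps is an illustrative value)."""
    clipped = max(min(ratio, 1.0 + eps), 1.0 - eps)
    return min(ratio * advantage, clipped * advantage)


# Hypothetical per-token log-probs of one response under the new and old policies.
logp_new = [-1.0, -0.5, -2.0]
logp_old = [-1.2, -0.5, -1.5]
advantage = 1.0

ratios = token_level_ratios(logp_new, logp_old)
grpo_objective = sum(clipped_term(r, advantage) for r in ratios) / len(ratios)
gspo_objective = clipped_term(sequence_level_ratio(logp_new, logp_old), advantage)
print(f"GRPO (token-level):    {grpo_objective:.4f}")
print(f"GSPO (sequence-level): {gspo_objective:.4f}")
```

With token-level ratios, each token is clipped independently. With the sequence-level ratio, the whole response is kept or clipped as a unit.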

## Prerequisites

### Change Runtime Type (only if running on Google Colab)

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Change runtime type** from the dropdown menu.
4.  Select **v6e-8 TPU** or **v5p-8 TPU** as the **Hardware accelerator**.
5.  Click on **Save**.

### Get Your Hugging Face Token

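To download the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token. Because `meta-llama/Llama-3.1-8B-Instruct` is a gated model, you must also request and be granted access on its model page.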

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need it in a later step.

**Follow these steps to store your token (only if running on Google Colab):**

1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).

2. Click **"+ Add new secret"**.

3. Set the Name as **HF_TOKEN**.

4. Paste your token into the Value field.

5. Ensure the Notebook access toggle is turned On.

In [None]:
try:
    from google.colab import userdata
    print("Running the notebook on Google Colab")
    IN_COLAB = True
except ImportError:
    print("Running the notebook outside Google Colab (e.g., VS Code or JupyterLab)")
    IN_COLAB = False

## Installation: MaxText and Dependencies

**⚠️ Note:** The installation in the following cell may take a few minutes to complete. Please be patient.

In [None]:
if IN_COLAB:
    !git clone https://github.com/AI-Hypercomputer/maxtext.git
    %cd /content/maxtext

    # Install uv, a fast Python package installer
    !pip install uv

    # Install MaxText and its dependencies
    !uv pip install -e .[tpu] --resolution=lowest
    !python3 -m MaxText.install_maxtext_extra_deps

    # Install vLLM for Jax and TPUs
    !uv pip install vllm-tpu
    !uv pip install --no-deps qwix==0.1.4

### Restart Session (only if running on Google Colab)
Restart the session so that the newly installed packages are loaded.

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Restart session** from the dropdown menu.

You will be asked to confirm the action in a pop-up dialog. Click on **Yes**.

## Environment Setup

In [None]:
import datetime
import os
import sys
from pathlib import Path
import MaxText
from huggingface_hub import login
from etils import epath
import jax

from maxtext.trainers.post_train.rl.train_rl import rl_train, setup_configs_and_devices

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["SKIP_JAX_PRECOMPILE"] = "1"  # Faster startup for vLLM
# Suppress vLLM logging with a severity level below ERROR
os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"

MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)
MAXTEXT_REPO_ROOT = os.sep.join(["maxtext" if p == "MaxText" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])
print(f"MaxText installation path: {MAXTEXT_PKG_DIR}")
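
The `MAXTEXT_REPO_ROOT` expression above replaces only path components that are exactly `MaxText` with `maxtext`, and leaves every other component unchanged. The sketch below runs the same expression on a hypothetical install location (`/content/maxtext/src/MaxText` is an assumption, not output from a real run). It shows that the result points at a sibling lowercase `maxtext` directory, which is where the chat template path in the next section is resolved.

```python
import os

# Hypothetical install location of the MaxText package (not taken from a real run).
pkg_dir = os.sep.join(["", "content", "maxtext", "src", "MaxText"])

# Same expression as in the cell above: swap the exact "MaxText" component for "maxtext".
repo_root = os.sep.join(
    ["maxtext" if p == "MaxText" else p for p in pkg_dir.split(os.sep)]
)
print(repo_root)  # on POSIX: /content/maxtext/src/maxtext
```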

In [None]:
if not jax.distributed.is_initialized():
  jax.distributed.initialize()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

In [None]:
if IN_COLAB:
    HF_TOKEN = userdata.get("HF_TOKEN")
else:
    HF_TOKEN = os.environ.get("HF_TOKEN", "")

# If not found in the environment, prompt the user for input securely
# getpass function ensures the token is hidden while you type
if not HF_TOKEN:
    from getpass import getpass
    HF_TOKEN = getpass("Hugging Face token not found in environment. Please enter it here: ")

if HF_TOKEN:
    os.environ["HF_TOKEN"] = HF_TOKEN
    login(token=HF_TOKEN)
    print("Authenticated with Hugging Face successfully!")
else:
    print("Authentication failed: Hugging Face token is not set.")

## Model Configurations

In [None]:
MODEL_NAME = "llama3.1-8b"
TOKENIZER_PATH = "meta-llama/Llama-3.1-8B-Instruct"
RUN_NAME = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
LOSS_ALGO = "grpo"  # or "gspo-token" if you want to use GSPO

CHAT_TEMPLATE_PATH = f"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json"
if not os.path.exists(CHAT_TEMPLATE_PATH):
    raise FileNotFoundError(f"Chat template not found: {CHAT_TEMPLATE_PATH}")

# Set the path to an existing MaxText checkpoint, or leave empty to download and convert it from Hugging Face
MODEL_CHECKPOINT_PATH = ""
if not MODEL_CHECKPOINT_PATH:
    MODEL_CHECKPOINT_PATH = f"{MAXTEXT_PKG_DIR}/llama_checkpoint"
    print("Model checkpoint will be downloaded from Hugging Face to:", MODEL_CHECKPOINT_PATH)
    print("Set MODEL_CHECKPOINT_PATH if you do not want to download the checkpoint.")

OUTPUT_DIRECTORY = f"{MAXTEXT_PKG_DIR}/rl_llama3_output"

## Download and Convert the Llama3.1-8B Checkpoint from Hugging Face

In [None]:
if not epath.Path(MODEL_CHECKPOINT_PATH).exists():
    # Install CPU-only PyTorch, which the conversion script requires
    !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

    !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m maxtext.checkpoint_conversion.to_maxtext \
      {MAXTEXT_PKG_DIR}/configs/base.yml \
      model_name={MODEL_NAME} \
      base_output_directory={MODEL_CHECKPOINT_PATH} \
      hf_access_token={HF_TOKEN} \
      use_multimodal=false \
      scan_layers=true \
      skip_jax_distributed_system=True

    if not epath.Path(MODEL_CHECKPOINT_PATH).exists():
        raise ValueError("Model checkpoint conversion failed. Check the logs above.")
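
The check above only confirms that the top-level checkpoint directory exists. The configuration cell below loads parameters from `{MODEL_CHECKPOINT_PATH}/0/items`, so a stricter check could look for a non-empty directory at that location. The helper name `converted_checkpoint_ready` is hypothetical, and this sketch uses `pathlib`, so it only handles local paths. For `gs://` paths, the notebook's `etils.epath.Path`, which offers a pathlib-like interface, could likely be substituted.

```python
import tempfile
from pathlib import Path


def converted_checkpoint_ready(checkpoint_root: str) -> bool:
    """Return True if <checkpoint_root>/0/items exists and contains at least one entry."""
    items_dir = Path(checkpoint_root) / "0" / "items"
    return items_dir.is_dir() and any(items_dir.iterdir())


# Demo on a throwaway local directory standing in for MODEL_CHECKPOINT_PATH.
with tempfile.TemporaryDirectory() as root:
    ready_before = converted_checkpoint_ready(root)  # root exists, but 0/items does not
    items = Path(root) / "0" / "items"
    items.mkdir(parents=True)
    (items / "placeholder").touch()  # hypothetical file standing in for real checkpoint data
    ready_after = converted_checkpoint_ready(root)

print(ready_before, ready_after)
```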

## MaxText Configurations

In [None]:
# Load configuration for RL training
config_argv = [
    "",  # placeholder for the program name (argv[0])
    f"{MAXTEXT_PKG_DIR}/configs/post_train/rl.yml",  # base RL config; the key=value entries below override it
    f"model_name={MODEL_NAME}",
    f"tokenizer_path={TOKENIZER_PATH}",
    f"run_name={RUN_NAME}",
    f"chat_template_path={CHAT_TEMPLATE_PATH}",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items",  # converted MaxText checkpoint
    f"base_output_directory={OUTPUT_DIRECTORY}",
    f"hf_access_token={HF_TOKEN}",
    "debug.rl=False",
    f"rl.loss_algo={LOSS_ALGO}",
    "use_pathways=False",
]

trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)

rl_train_steps = int(
    trainer_config.num_batches
    * trainer_config.rl.num_iterations
    * trainer_config.train_fraction
    * trainer_config.num_epoch
)

print("✓ Configuration initialized successfully")
print(f"📁 Output directory: {trainer_config.base_output_directory}")
print(f"🤖 Model: {trainer_config.model_name}")
print(f"📊 RL Train Steps: {rl_train_steps}")

## RL Training

In [None]:
import traceback

print("\n" + "=" * 80)
print(f"🚀 Starting {LOSS_ALGO} Training...")
print("=" * 80)
try:
    rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
    print("\n" + "=" * 80)
    print("✅ Training Completed Successfully!")
    print(f"✍️ Note the improved evaluation accuracy metrics after just {rl_train_steps} RL training steps!")
    print("=" * 80)
    print(f"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}")
    print(f"📊 TensorBoard logs: {trainer_config.tensorboard_dir}")
    print("🎯 Model ready for inference!")
except Exception as e:
    print("\n" + "=" * 80)
    print("❌ Training Failed!")
    print("=" * 80)
    print(f"Error: {e}")
    traceback.print_exc()  # print the full stack trace for debugging
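
To see which checkpoints the run produced, you could list the step directories under `trainer_config.checkpoint_dir`. This sketch assumes that checkpoints are written as numeric step subdirectories, as the `0/items` layout used to load the converted checkpoint suggests. That layout and the helper name `list_checkpoint_steps` are assumptions. The demo uses a temporary local directory. For a real local run, pass `trainer_config.checkpoint_dir` instead. For `gs://` paths, a pathlib-like class such as `etils.epath.Path` would be needed.

```python
import tempfile
from pathlib import Path


def list_checkpoint_steps(checkpoint_dir: str) -> list[int]:
    """Return the sorted step numbers of numeric subdirectories in checkpoint_dir."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return []
    return sorted(int(p.name) for p in root.iterdir() if p.is_dir() and p.name.isdigit())


# Demo on a throwaway directory with hypothetical step folders and one non-step folder.
with tempfile.TemporaryDirectory() as d:
    for name in ["10", "2", "tmp"]:
        (Path(d) / name).mkdir()
    found_steps = list_checkpoint_steps(d)

print(found_steps)  # -> [2, 10]
```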

## 📚 Learn More

- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html
- **Configuration**: See `configs/post_train/rl.yml` (the base config loaded in the MaxText Configurations cell above) for all available options
- **Documentation**: Check `src/maxtext/trainers/post_train/rl/train_rl.py` for the `rl_train` function implementation