<a href="https://colab.research.google.com/github/google/tunix/blob/main/examples/grpo_demo.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This tutorial demonstrates training the Gemma 2 2B-IT model on the GSM8K math
reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can
improve a model's problem-solving skills on tasks such as mathematical word
problems and coding problems.

GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It
is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by
eliminating the need for a separate value function model. GRPO works by
generating multiple responses for a given prompt, scoring these responses with
a reward model (in this tutorial, a set of rule-based reward functions), and
then computing each response's advantage relative to the rest of its group to
update the policy.
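To make the "relative advantage" step concrete, here is a minimal JAX sketch of the outcome-supervision advantage described in the GRPO paper: each response's reward is normalized by the mean and standard deviation of the rewards in its group, i.e. $A_i = (r_i - \mathrm{mean}(r)) / \mathrm{std}(r)$. The function name `group_relative_advantages` and the `eps` constant are illustrative assumptions, not Tunix's API, and implementations differ in details such as whether they use the sample or the population standard deviation (this sketch uses the population one).

```python
import jax.numpy as jnp


def group_relative_advantages(rewards: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
  """Normalizes rewards within each group of G responses to the same prompt.

  Args:
    rewards: array of shape (num_prompts, G), one scalar reward per response.
    eps: small constant that avoids division by zero when every reward in a
      group is identical.

  Returns:
    Advantages with the same shape as `rewards`.
  """
  mean = rewards.mean(axis=1, keepdims=True)
  std = rewards.std(axis=1, keepdims=True)  # population std (ddof=0)
  return (rewards - mean) / (std + eps)


# Two prompts, G = 4 responses each.
rewards = jnp.array([[3.0, 0.0, 1.5, 1.5],
                     [2.0, 2.0, 2.0, 2.0]])
advantages = group_relative_advantages(rewards)
print(advantages)
```

A response that beats its group's average gets a positive advantage and one that falls below it gets a negative advantage. When every response in a group receives the same reward (the second row), all advantages are zero, so that prompt contributes no reward signal to the update.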

In this tutorial we use Colab's `v2-8` TPU. Let's get started!

## Install necessary libraries

In [1]:
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
!pip install -q jax==0.6.2 jaxlib==0.6.2
# !pip install -q git+https://github.com/google/tunix
# Install Tunix from a local checkout (editable mode).
!pip install -q -e ~/tunix/
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q tensorflow-datasets

!pip install -q git+https://github.com/AI-Hypercomputer/pathways-utils.git



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip

In [2]:
!pip install ipywidgets


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Imports

In [3]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from qwix import lora
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma import data as data_lib
from tunix.models.gemma import gemma as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
from tunix.sft import metrics_logger

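# Environment-specific: points JAX at a particular libtpu build. Update or
# remove this path for your own machine.
os.environ['TPU_LIBRARY_PATH'] = '/home/mazumdera_google_com/venv-py311/lib/python3.11/site-packages/libtpu/libtpu.so'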





## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [4]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 768
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of responses the policy generates for each prompt within a single
# training step. This corresponds to `G` in Algorithm 1 in the GRPO paper. The
# "group" in GRPO comes from here.
NUM_GENERATIONS = 2

# === other GRPO configs ===
# The number of iterations per batch (𝜇 in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
BATCH_SIZE = 1
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 3738
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. The GSM8K test
# split has 1,319 examples, so this can be increased to at most 330 batches with
# a batch size of 4 (or 1,319 batches with a batch size of 1).
NUM_TEST_BATCHES = 5  # Reduced from 100 for a quick evaluation.

EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to `LEARNING_RATE` over the first
# 10% of training steps, then decay it to 0 with a cosine schedule.
WARMUP_STEPS = int(0.1 * MAX_STEPS)
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/home/mazumdera_google_com/content/intermediate_ckpt/"
CKPT_DIR = "/home/mazumdera_google_com/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

## Utility functions

In [5]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define some special tokens. We instruct the model to first reason
between the `<reasoning>` and `</reasoning>` tokens. After
reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tokens.

In [6]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.

In [7]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def get_dataset(data_dir, split="train") -> grain.MapDataset:
  # Download data
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=x["question"].decode("utf-8"),
              ),
              # passed to reward functions
              "question": x["question"].decode("utf-8"),
              # passed to reward functions
              "answer": extract_hash_answer(x["answer"].decode("utf-8")),
          }
      )
  )
  return dataset

In [8]:
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

len(train_dataset), len(val_dataset) if val_dataset is not None else 0, len(
    test_dataset
)

(3738, 0, 5)

Let's see what one batch of the dataset looks like!


In [9]:
for ele in train_dataset[:1]:
  pprint(ele)

{'answer': array(['13'], dtype='<U2'),
 'prompts': array(['<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nJane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a clear top coat that takes 5 minutes to dry. How many minutes total does Jane spend waiting for her nail polish to dry?<end_of_turn>\n<start_of_turn>model'],
      dtype='<U535'),
 'question': array(['Jane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a clear top coat that takes 5 minutes to dry. How many minutes total does Jane spend waiting for her nail polish to dry?'],
      dtype='<U260')}


## Load the policy model and the reference model

The policy model is the model that is trained, i.e., whose weights are
updated. The reference model is used to compute the KL divergence penalty,
which keeps each policy update small so that the policy does not drift too far
from the reference model.

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.
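In the GRPO paper, the KL penalty (weighted by `BETA` above) is estimated per token on the sampled tokens with the non-negative estimator $\frac{\pi_{\text{ref}}}{\pi_\theta} - \log\frac{\pi_{\text{ref}}}{\pi_\theta} - 1$. A minimal JAX sketch follows; `per_token_kl` is an illustrative name, and we have not checked that Tunix computes the penalty exactly this way.

```python
import jax.numpy as jnp


def per_token_kl(policy_logps: jnp.ndarray, ref_logps: jnp.ndarray) -> jnp.ndarray:
  """Per-token KL estimate between the policy and the reference model.

  Computes pi_ref/pi_theta - log(pi_ref/pi_theta) - 1 from the
  log-probabilities each model assigns to the sampled tokens.
  """
  log_ratio = ref_logps - policy_logps  # log(pi_ref / pi_theta)
  return jnp.exp(log_ratio) - log_ratio - 1.0


# Log-probabilities of three sampled tokens under each model.
policy_logps = jnp.log(jnp.array([0.5, 0.2, 0.9]))
ref_logps = jnp.log(jnp.array([0.5, 0.4, 0.6]))
kl = per_token_kl(policy_logps, ref_logps)
print(kl)
```

The estimate is exactly zero when both models assign the same probability to a token (the first token) and grows as they diverge in either direction.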

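Note: The MaxText configuration below sets `weight_dtype="bfloat16"`, so this
notebook does not train in full precision (fp32). You can also leverage Qwix
for quantization-aware training (QAT).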

To load the model, you need a [Kaggle](https://www.kaggle.com/) account and
must have accepted the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [10]:
# # Log in
# if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
#   kagglehub.login()

In [11]:
# kaggle_ckpt_path = kagglehub.model_download("google/gemma-2/flax/gemma2-2b-it")
kaggle_ckpt_path = kagglehub.model_download("google/gemma/flax/2b")

In [12]:
# # This is a workaround. The checkpoints on Kaggle don't work with NNX. So, we
# # load the model, save the checkpoint locally, and then reload the model
# # (sharded).
# params = params_lib.load_and_format_params(
#     os.path.join(kaggle_ckpt_path, "gemma2-2b-it")
# )
# gemma = gemma_lib.Transformer.from_params(params, version="2-2b-it")
# checkpointer = ocp.StandardCheckpointer()
# _, state = nnx.split(gemma)
# checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)

In [13]:
# # Wait for the ckpt to save successfully.
# time.sleep(60)

In [14]:
# # Delete the intermediate model to save memory.
# del params
# del gemma
# del state
# gc.collect()

### Load MaxText model

In [21]:
import sys
import os

# Make the local MaxText checkout (two levels up, at ../../maxtext) importable.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../../maxtext')))

# ! pip install -r ../../maxtext/requirements.txt

import MaxText as mt
from MaxText import pyconfig



#### Convert MaxText model to nnx (use a commit from MaxText repo prior to )



In [22]:
# from MaxText.integrations.tunix.tunix_utils import build_tunix_wrapper
from flax import linen as nn

def get_ref_maxtext_model():

  #python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5

  #TODO: @mazumdera: change this to use Gemma2-2b-it
  config = pyconfig.initialize(
      ["", "../../maxtext/MaxText/configs/base.yml"], #TODO: @mazumdera: why decode.py?
      base_output_directory="gs://dummy_output_dir",  # This is not used in Tunix.
      # run_name="test-tunix-maxtext-llama3-8b",
      run_name="test-tunix-maxtext-gemma-2b",
      # dataset_path=we use Tunix's dataset
      # load_parameters_path="gs://maxtext-gemma/2b/", #TODO: @mazumdera: change this to use checkpoint
      # tokenizer_type="tiktoken",
      # tokenizer_path="assets/tokenizer_llama3.tiktoken",
      tokenizer_path="../../maxtext/assets/tokenizer.gemma",
      per_device_batch_size=8,
      max_target_length=8192,
      steps=10,
      async_checkpointing="false",
      # model_name="llama3.1-8b",
      model_name="gemma-2b",
      checkpoint_period=5,
      skip_jax_distributed_system="true",
      weight_dtype="bfloat16",
      attention="dot_product"

  )
  
  def create_model(config):
    return mt.from_pretrained(config, rngs=nnx.Rngs(params=0, dropout=1))

  # Build an abstract model (shapes and dtypes only; no parameters are
  # allocated). `model` is used below only for its `mesh`.
  abstract_model = nnx.eval_shape(create_model, config=config)
  model = abstract_model
  graphdef, abstract_state = nnx.split(abstract_model)
  print('The abstract NNX state (all leaves are abstract arrays):')
  nnx.display(abstract_state)
  checkpoint = mt.checkpointing.load_params_from_path(
      load_parameters_from_path="gs://maxtext-gemma/2b/2025-08-05-04-37/0/items",
      abstract_unboxed_params=None,
      checkpoint_storage_concurrent_gb=None,
  )
  print("{checkpoint=}")
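  # NOTE: this discards the checkpoint restored above, so `partial_init` below
  # starts from freshly initialized (random) weights rather than the pretrained
  # Gemma weights.
  checkpoint = {}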

  @nnx.jit
  def partial_init(checkpoint, config):
    model = create_model(config)
    nnx.update(model, checkpoint)
    # shard model
    state = nnx.state(model)
    specs = nnx.get_partition_spec(state)
    state = jax.lax.with_sharding_constraint(state, specs)
    nnx.update(model, state)
    return model

  with jax.sharding.use_mesh(model.mesh), nn.logical_axis_rules(config.logical_axis_rules):
    model = partial_init(checkpoint, config)
  print(model)

  
  tunix_model = TunixMaxTextLlama(
        base_model=model,
        use_attention_mask=False,  # trust Tunix loss masking
    )
  mesh  = tunix_model.base.mesh
  
  #TODO: @mazumdera: change this to use llama3.1-8b
  # model_config = None
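  # We can continue to use Tunix's model_config.
  # NOTE: `gemma2_2b()` describes Gemma 2 2B (26 layers, 4 KV heads), whereas
  # the MaxText model loaded above is `gemma-2b` (18 layers, 1 KV head). The
  # sampler's KV-cache config is built from this object, so the two should match.
  model_config = gemma_lib.TransformerConfig.gemma2_2b()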

  # Add these lines to properly get the graph definition and state
  graphdef, state = nnx.split(tunix_model)
  tunix_model = nnx.merge(graphdef, state)  # Recreate model in proper NNX format
    
  
  return tunix_model, mesh, model_config

# def get_ref_maxtext_model(config, mesh=None):

#   #python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5

  
#   #TODO: Anisha: 
#   # model = mt.from_pretrained(config)
  
#   rngs = nnx.Rngs(1234)
#   model = build_tunix_wrapper(
#         config,
#         rngs,
#         enable_dropout=False,   # deterministic SFT (you can override at runtime)
#         init_batch_size=1,
#         init_seq_len=1,
#         use_attention_mask=False,  # trust Tunix loss masking
#     )
#   mesh  = model.base.mesh
  

#   # We can continue to use Tunix's model_config
#   model_config = gemma_lib.TransformerConfig.gemma2_2b()

#   # Add these lines to properly get the graph definition and state
#   graphdef, state = nnx.split(model)
#   model = nnx.merge(graphdef, state)  # Recreate model in proper NNX format

    
  
#   return model, mesh, model_config

In [23]:
# Base model
# gemma, mesh, model_config = get_base_model(
#     ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
# )
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama

gemma, mesh, model_config = get_ref_maxtext_model()
# gemma_maxtext_nnx = nnx.bridge.ToNNX(gemma)
nnx.display(gemma)

print("Model initialized successfully")
print(f"Model mesh shape: {mesh.shape}")
print(f"Model config: {model_config}")

Updating keys from env and command line: ['run_name', 'model_name', 'async_checkpointing', 'checkpoint_period', 'weight_dtype', 'attention', 'base_output_directory', 'tokenizer_path', 'per_device_batch_size', 'steps', 'skip_jax_distributed_system', 'max_target_length']
Running Model: gemma-2b
Updating following parameters in config

base_emb_dim: 2048
base_num_query_heads: 8
base_num_kv_heads: 1
base_mlp_dim: 16384
base_num_decoder_layers: 18
head_dim: 256
mlp_activations: ['gelu', 'linear']
vocab_size: 256128
decoder_block: gemma
normalization_layer_epsilon: 1e-06
logits_via_embedding: True
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'decoder_block', 'normalization_layer_epsilon', 'logits_via_embedding']
Skipping jax distributed system due to skip_jax_distributed_system=True flag.
Not using emergency checkpoint, ignoring local_checkpoint_directory, local_

Config param load_balance_loss_weight: 0.01
Config param load_from_prefill_dir: False
Config param load_full_state_path: 
Config param load_parameters_path: 
Config param local_checkpoint_directory: 
Config param local_checkpoint_period: 0
Config param local_rope_max_timescale: -1
Config param log_config: True
Config param log_period: 100
Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_batch_no_exp', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('data', 'stage', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive')), ('activation_kv_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_length', ('sequence', 'context')), ('prefill_activation_length', ('sequence', 'context')), ('activation_length', ('context',)), ('activation_norm_length', ('tensor_sequence', 'context', 'sequenc

KeyError: '__jax_array__'

In [None]:

# Policy model
# TODO: @mazumdera: change this to use LoRA. The default Tunix colab builds the
# policy model with:
# lora_gemma = get_lora_model(gemma, mesh=mesh)
# nnx.display(lora_gemma)
# For now, the policy is a second, full copy of the MaxText model.

gemma_policy, mesh_policy, model_config_policy = get_ref_maxtext_model()

# gemma_maxtext_nnx = nnx.bridge.ToNNX(gemma_policy)
nnx.display(gemma_policy)

print("Model initialized successfully")
print(f"Model mesh shape: {mesh_policy.shape}")
print(f"Model config: {model_config_policy}")




Updating keys from env and command line: ['run_name', 'model_name', 'async_checkpointing', 'checkpoint_period', 'weight_dtype', 'attention', 'base_output_directory', 'tokenizer_path', 'per_device_batch_size', 'steps', 'skip_jax_distributed_system', 'max_target_length']
Running Model: gemma-2b
Updating following parameters in config

base_emb_dim: 2048
base_num_query_heads: 8
base_num_kv_heads: 1
base_mlp_dim: 16384
base_num_decoder_layers: 18
head_dim: 256
mlp_activations: ['gelu', 'linear']
vocab_size: 256128
decoder_block: gemma
normalization_layer_epsilon: 1e-06
logits_via_embedding: True
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'decoder_block', 'normalization_layer_epsilon', 'logits_via_embedding']
Skipping jax distributed system due to skip_jax_distributed_system=True flag.
Not using emergency checkpoint, ignoring local_checkpoint_directory, local_

Config param global_batch_size_to_load: 64
Config param global_batch_size_to_load_eval: 64
Config param global_batch_size_to_train_on: 64
Config param global_parameter_scale: 1
Config param goodput_upload_interval_seconds: 30
Config param gradient_accumulation_steps: 1
Config param gradient_clipping_threshold: 1.0
Config param grain_eval_files: 
Config param grain_file_type: arrayrecord
Config param grain_train_files: 
Config param grain_worker_count: 1
Config param grain_worker_count_eval: 1
Config param hardware: tpu
Config param head_dim: 256
Config param heartbeat_reporting_interval_in_seconds: 5
Config param hf_data_dir: 
Config param hf_eval_files: 
Config param hf_eval_split: 
Config param hf_path: 
Config param hf_train_files: 
Config param hidden_size_for_vit: 1408
Config param ici_autoregressive_parallelism: 1
Config param ici_context_autoregressive_parallelism: 1
Config param ici_context_parallelism: 1
Config param ici_data_parallelism: 1
Config param ici_expert_parallelism:

restoring params from gs://maxtext-gemma/2b/2025-08-05-04-37/0/items
Creating checkpoint manager with ocdbt=True and zarr3=True
Checkpoint manager created!




{checkpoint=}
Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)

Model initialized successfully
Model mesh shape: OrderedDict([('data', 1), ('stage', 1), ('fsdp', 8), ('fsdp_transpose', 1), ('sequence', 1), ('context', 1), ('context_autoregressive', 1), ('tensor', 1), ('tensor_transpose', 1), ('tensor_sequence', 1), ('expert', 1), ('autoregressive', 1)])
Model config: TransformerConfig(num_layers=26, num_embed=256128, embed_dim=2304, hidden_dim=9216, num_heads=8, head_dim=256, num_kv_heads=4, final_logit_softcap=30.0, use_post_attn_norm=True, use_post_ffw_norm=True, attention_types=(<AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType

## Define reward functions

We define four reward functions:

- reward if the format of the output exactly matches the instruction given in
`SYSTEM_PROMPT`;
- reward if the format of the output approximately matches the instruction given
in `SYSTEM_PROMPT`;
- reward if the answer is correct or partially correct;
- reward if the number extracted from between `<answer>` and `</answer>` is
  correct (the text there is sometimes a sentence rather than a single number).

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First off, let's define a RegEx for checking whether the format matches.

In [None]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

<re.Match object; span=(0, 54), match='<reasoning>Let me think!</reasoning><answer>2</an>

Give the model a reward of 3 points if the format matches exactly.

In [None]:
def match_format_exactly(prompts, completions, **kargs):
  scores = []
  for completion in completions:
    score = 0
    response = completion
    # Match if format is seen exactly!
    if match_format.search(response) is not None:
      score += 3.0
    scores.append(score)
  return scores

We also reward the model if the format of the output matches partially.

In [None]:
def match_format_approximately(prompts, completions, **kargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # Count how many keywords are seen - we penalize if too many!
    # If we see 1, then plus some points!
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores

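Reward the model with 3 points if the answer is correct. If the extracted answer
does not match exactly, partial credit depends on how close it is: 1.5 points if
it matches after stripping whitespace, 0.5 if its ratio to the correct value is
within 10%, and 0.25 if within 20%. Other numeric answers lose 1 point,
non-numeric answers lose 0.5 points, and responses that do not match the format
get 0.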

In [None]:
def check_answer(prompts, completions, answer, **kargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Match if spaces are seen
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # We also reward it if the answer is close via ratios!
      # Ie if the answer is within some range, reward it!
      try:
        ratio = float(guess) / float(true_answer)
        if ratio >= 0.9 and ratio <= 1.1:
          score += 0.5
        elif ratio >= 0.8 and ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize
    scores.append(score)
  return scores
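The partial-credit branch of `check_answer` only runs when the extracted string differs from the reference answer even after stripping whitespace. The hypothetical helper `ratio_partial_credit` below copies that branch into a standalone function so the tiers are easy to inspect. One consequence worth noting: a numerically correct answer written differently, such as `100.0` for `100`, fails both string comparisons and earns only 0.5 points here, whereas `check_numbers` (defined below) compares floats and does credit it.

```python
def ratio_partial_credit(guess: str, true_answer: str) -> float:
  """Hypothetical helper mirroring the ratio branch of `check_answer`."""
  try:
    ratio = float(guess) / float(true_answer)
  except (ValueError, ZeroDivisionError):
    return -0.5  # non-numeric guess, or a true answer of zero
  if 0.9 <= ratio <= 1.1:
    return 0.5
  if 0.8 <= ratio <= 1.2:
    return 0.25
  return -1.0  # numeric but far from the correct value


for guess in ["95", "85", "50", "ninety", "100.0"]:
  print(guess, ratio_partial_credit(guess, "100"))
```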

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [None]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']
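The `[\d\.]{1,}` pattern only captures digits and dots, so it can silently truncate some answers: a thousands separator cuts the number short, a minus sign is dropped, and a trailing sentence period is kept (`float("42.")` still parses, so that case is harmless). The sketch below contrasts it with a hypothetical stricter pattern, `match_numbers_strict`; this is an illustration, not part of the original recipe, and whether these cases matter depends on how often the model writes such answers.

```python
import re

solution_start = "<answer>"
solution_end = "</answer>"

# Same pattern as `match_numbers` above.
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)

# Hypothetical stricter pattern: optional minus sign, thousands separators,
# and at most one decimal point that must be followed by digits.
match_numbers_strict = re.compile(
    rf"{solution_start}.*?(-?\d[\d,]*(?:\.\d+)?)", flags=re.MULTILINE | re.DOTALL
)

examples = [
    f"{solution_start}1,234{solution_end}",
    f"{solution_start}-5{solution_end}",
    f"{solution_start}The answer is 42.{solution_end}",
]
for text in examples:
  loose = match_numbers.search(text).group(1)
  # Drop thousands separators so the result can be passed to float().
  strict = match_numbers_strict.search(text).group(1).replace(",", "")
  print(f"{text!r}: loose={loose!r}, strict={strict!r}")
```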

In [None]:
def check_numbers(prompts, completions, answer, **kargs):
  question = kargs["question"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert to numbers
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
      scores.append(1.5 if guess == true_answer else 0.0)
    except ValueError:
      scores.append(0)
      continue
  return scores

## Evaluate


Before we train the model, let's evaluate it on the test set so that we can
measure the improvement after training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer.
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1.
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the `<reasoning>` and `</reasoning>`
tokens, and the final answer between the `<answer>` and `</answer>` tokens.

**Qualitative**

We'll also print the outputs for a few questions so that we can compare them with the outputs after training.


In [None]:
def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Given prompt, generates text."""

  if isinstance(question, str):
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
    ]
  else:
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=q,
        )
        for q in question
    ]

  out_data = sampler(
      input_strings=input_batch,
      total_generation_steps=TOTAL_GENERATION_STEPS,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  )

  output = out_data.text
  if isinstance(question, str):
    return output[0]
  return output

In [None]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format."""

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        extracted_response = (
            guess.group(1)
            if (guess := match_numbers.search(response)) is not None
            else "-1000000"
        )
        try:
          if float(extracted_response.strip()) == float(answer.strip()):
            corr_ctr_per_question += 1

          ratio = float(extracted_response.strip()) / float(answer.strip())
          if ratio >= 0.9 and ratio <= 1.1:
            partially_corr_per_question += 1
        except (ValueError, ZeroDivisionError):
          print("SKIPPED")

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )
  if make_lst:
    return to_return, response_lst
  return to_return
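The per-response checks inside `evaluate` can be pulled out into a small, self-contained sketch. It uses a hypothetical `NUMBER_RE` in place of the notebook's `match_numbers` regex (defined in an earlier cell) and skips the format check. As in `evaluate`, a question counts as correct if any sampled response matches the reference exactly, and as partially correct if any extracted number is within 10% of it.

```python
import re

# Hypothetical stand-in for the notebook's `match_numbers` pattern: grabs the
# first (optionally signed, optionally decimal) number in the response.
NUMBER_RE = re.compile(r"(-?\d+(?:\.\d+)?)")


def score_question(responses, answer):
  """Returns (is_correct, is_partially_correct) for one question.

  Mirrors the numeric checks in `evaluate`: any exact match makes the
  question correct; any prediction within +/-10% makes it partially correct.
  """
  correct = partial = False
  for response in responses:
    match = NUMBER_RE.search(response)
    extracted = match.group(1) if match else "-1000000"  # same sentinel
    try:
      predicted, expected = float(extracted), float(answer)
      correct = correct or predicted == expected
      partial = partial or 0.9 <= predicted / expected <= 1.1
    except (ValueError, ZeroDivisionError):
      continue  # unparseable number, or a reference answer of 0
  return correct, partial


# Two sampled responses (num_passes=2) for a question whose answer is 42.
print(score_question(["The answer is 40", "So we get 42"], "42"))  # (True, True)
print(score_question(["The answer is 45"], "42"))  # (False, True)
print(score_question(["No idea"], "42"))  # (False, False)
```

This makes the metrics pass@k-style: with `num_passes > 1`, a question is credited as soon as any one of its samples succeeds.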

In [None]:
gemma_tokenizer = data_lib.GemmaTokenizer()
sampler = sampler_lib.Sampler(
    # transformer=lora_gemma,
    transformer=gemma_policy,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)
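The sampler's KV cache has to hold the prompt plus every generated token, which is why `cache_size` is `MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256`. The sketch below restates that arithmetic with hypothetical values (the notebook sets its own constants earlier) and, assuming the sampler requires prompt plus generated tokens to fit in `cache_size`, checks the `generate` helper, which hard-codes `total_generation_steps=768`.

```python
# Hypothetical values; the notebook defines its own MAX_PROMPT_LENGTH and
# TOTAL_GENERATION_STEPS in an earlier cell.
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 768
CACHE_HEADROOM = 256  # the extra slots the notebook adds to cache_size


def cache_size(max_prompt_length, total_generation_steps,
               headroom=CACHE_HEADROOM):
  """Mirrors cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256."""
  return max_prompt_length + total_generation_steps + headroom


def fits_in_cache(prompt_tokens, generation_steps, size):
  """Assumed constraint: prompt plus generated tokens must fit in the cache."""
  return prompt_tokens + generation_steps <= size


size = cache_size(MAX_PROMPT_LENGTH, TOTAL_GENERATION_STEPS)
print(size)  # 1280
print(fits_in_cache(MAX_PROMPT_LENGTH, 768, size))  # True (generate's 768)
print(fits_in_cache(MAX_PROMPT_LENGTH, 1200, size))  # False
```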

In [None]:
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

# TODO: @mazumdera: why is this 0?
# corr=0, total=5, accuracy=0.0%, partial_accuracy=0.0%, format_accuracy=0.0%


  0%|          | 0/5 [00:00<?, ?it/s]

ValueError: Cannot infer collection name from value: frozenset({'layers', '_pytree__state', 'to_nnx__rngs', 'decoder_norm'})

In [None]:
# for eval_example in QUALITATIVE_EVAL_EXAMPLES:
#   question = eval_example["question"]
#   answer = eval_example["answer"]
#   response = generate(
#       question,
#       sampler,
#       temperature=INFERENCE_TEMPERATURE,
#       top_k=INFERENCE_TOP_K,
#       top_p=INFERENCE_TOP_P,
#   )

#   print(f"Question:\n{question}")
#   print(f"Answer:\n{answer}")
#   print(f"Response:\n{response}")
#   print("===============")

## Train

First, let's set up the configs for checkpointing, metric logging, and training.
Then we train the model.

In [None]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/home/mazumdera_google_com/content/tmp/tensorboard/grpo",
    flush_every_n_steps=20,
)
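`save_interval_steps` and `max_to_keep` together decide which checkpoints stay on disk. Below is a simplified model of that interaction, assuming a save at every positive multiple of the interval and that only the newest `max_to_keep` saves survive; Orbax's `CheckpointManager` supports more policies than this, and any extra save the trainer might make at the final step is not modeled. It is a quick way to check whether the `CKPT_DIR/<MAX_STEPS>/model_params` path restored in the evaluation section will exist.

```python
def retained_steps(max_steps, save_interval_steps, max_to_keep):
  """Steps whose checkpoints survive, under a simplified retention model.

  Assumes a save at every positive multiple of `save_interval_steps` up to
  `max_steps`, keeping only the newest `max_to_keep` (None keeps all).
  """
  saved = list(range(save_interval_steps, max_steps + 1, save_interval_steps))
  return saved[-max_to_keep:] if max_to_keep else saved


# Hypothetical values, not the notebook's configuration.
print(retained_steps(max_steps=10, save_interval_steps=3, max_to_keep=2))  # [6, 9]
print(retained_steps(max_steps=12, save_interval_steps=3, max_to_keep=2))  # [9, 12]
```

In the first case step 10 is not a multiple of 3, so under this model no `10/model_params` directory would be written.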

In [None]:
# Logs
%load_ext tensorboard
%tensorboard --logdir /home/mazumdera_google_com/content/tmp/tensorboard/grpo --port=0

In [None]:
# # Training config
# training_config = GrpoTrainingConfig(
#     max_prompt_length=MAX_PROMPT_LENGTH,
#     total_generation_steps=TOTAL_GENERATION_STEPS,
#     num_generations=NUM_GENERATIONS,
#     num_iterations=NUM_ITERATIONS,
#     beta=BETA,
#     epsilon=EPSILON,
#     temperature=TEMPERATURE,
#     top_p=TOP_P,
#     top_k=TOP_K,
#     eval_every_n_steps=EVAL_EVERY_N_STEPS,
#     max_steps=MAX_STEPS,
#     # metrics logging
#     metrics_logging_options=metrics_logging_options,
#     # checkpoint saving
#     checkpoint_root_directory=CKPT_DIR,
#     checkpointing_options=checkpointing_options,
# )

In [None]:
# Optimizer, learning rate scheduler, gradient clipping
# optimizer = optax.adamw(
#     learning_rate=optax.schedules.warmup_cosine_decay_schedule(
#         init_value=0.0,
#         peak_value=LEARNING_RATE,
#         warmup_steps=WARMUP_STEPS,
#         decay_steps=MAX_STEPS,
#         end_value=0.0,
#     ),
#     b1=B1,
#     b2=B2,
#     weight_decay=WEIGHT_DECAY,
# )
# TODO: @mazumdera: try optimizer offloading with AdamW.
optimizer = optax.adafactor(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )
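The learning rate follows `optax.schedules.warmup_cosine_decay_schedule`: a linear ramp from `init_value` to `peak_value` over `warmup_steps`, then cosine decay to `end_value` at `decay_steps` (which, in Optax, includes the warmup steps). `optax.clip_by_global_norm` rescales the whole gradient pytree when its global L2 norm exceeds `max_norm`. The pure-Python sketch below restates both formulas for intuition only; the numbers are hypothetical, and nested lists stand in for gradient pytrees.

```python
import math


def warmup_cosine_lr(step, init_value, peak_value, warmup_steps, decay_steps,
                     end_value):
  """Linear warmup to peak_value, then cosine decay to end_value."""
  if step < warmup_steps:
    return init_value + (peak_value - init_value) * step / warmup_steps
  cosine_steps = decay_steps - warmup_steps  # decay_steps includes warmup
  t = min(step - warmup_steps, cosine_steps) / cosine_steps
  return end_value + (peak_value - end_value) * 0.5 * (1 + math.cos(math.pi * t))


def clip_by_global_norm(grads, max_norm):
  """Scales all gradients by max_norm / global_norm if the norm is too large."""
  global_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
  if global_norm <= max_norm:
    return grads
  return [[g * max_norm / global_norm for g in leaf] for leaf in grads]


# Hypothetical schedule: peak 3e-6, 10 warmup steps, 100 total steps.
for step in (0, 5, 10, 55, 100):
  print(step, warmup_cosine_lr(step, 0.0, 3e-6, 10, 100, 0.0))

# Two "leaves" with global norm sqrt(3^2 + 4^2) = 5, clipped to norm 1.
print(clip_by_global_norm([[3.0], [4.0]], max_norm=1.0))  # [[0.6], [0.8]]
```

Clipping by the global norm (rather than per leaf) preserves the gradient's direction while bounding the step size, which helps keep noisy RL updates stable.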

In [None]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        gradient_accumulation_steps=1,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
    ),
)

grpo_config = GrpoConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)
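`num_generations` sets the group size: GRPO samples that many completions per prompt and scores each completion relative to the rest of its group, which is what lets it drop PPO's value model. In the GRPO formulation, `beta` weights the KL penalty against the reference model and `epsilon` is the PPO-style clipping range. The sketch below computes group-relative advantages as (r - mean) / (std + eps); whether Tunix uses exactly this normalization (for example, population vs. sample standard deviation, or this `eps`) is an assumption.

```python
import statistics


def group_relative_advantages(rewards, num_generations, eps=1e-4):
  """Normalizes rewards within each group of `num_generations` completions.

  `rewards` is a flat list laid out as [prompt0_gen0, prompt0_gen1, ...,
  prompt1_gen0, ...]. Uses the population standard deviation (an assumption).
  """
  if len(rewards) % num_generations:
    raise ValueError("len(rewards) must be a multiple of num_generations")
  advantages = []
  for start in range(0, len(rewards), num_generations):
    group = rewards[start:start + num_generations]
    mean = statistics.fmean(group)
    std = statistics.pstdev(group)
    advantages.extend((r - mean) / (std + eps) for r in group)
  return advantages


# Two prompts, four generations each (hypothetical rewards).
adv = group_relative_advantages([3.0, 0.0, 3.0, 0.0, 1.0, 1.0, 1.0, 1.0], 4)
print([round(a, 3) for a in adv])  # [1.0, -1.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0]
```

The second group shows a useful property: when every completion for a prompt earns the same reward, all advantages are zero, so that prompt contributes no policy-gradient signal.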

In [None]:
# The model's parameters are now annotated with the specified sharding. When
# the model is used inside a jitted function, JAX respects these shardings.

# Inspect the sharding of a parameter's value. The sharding is concrete only
# after the value has been passed through a jitted function.
@jax.jit
def get_sharded_kernel(model):
  return model.base.token_embedder.embedding

with mesh:
  sharded_kernel_value = get_sharded_kernel(gemma_policy)

print("Sharding of embed kernel:")
print(sharded_kernel_value)


Sharding of embed kernel:
Param( # 524,550,144 (1.0 GB)
  value=Array([[1.15625, -0.355469, 1.42969, ..., -1.58594, 0.0245361, -0.644531],
         [1.78125, 0.707031, 1.24219, ..., 1.05469, 0.0441895, 1.375],
         [-0.550781, 1.32812, 0.302734, ..., -0.691406, 0.925781, -1.74219],
         ...,
         [-0.96875, -0.851562, -1, ..., 0.0441895, -1.74219, 1.42969],
         [0.00488281, 0.302734, 0.65625, ..., -1.83594, -0.192383,
          0.162109],
         [-0.251953, -1.65625, 0.730469, ..., -1.25781, -0.273438,
          -0.878906]], dtype=bfloat16),
  sharding=('vocab', 'embed')
)

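The `# 524,550,144 (1.0 GB)` annotation is the parameter count times the bytes per element: the embedding is stored in bfloat16, i.e. 2 bytes per value. The sketch below redoes that arithmetic and shows the per-device footprint if the array were split evenly across shards. Eight shards matches the 8 devices in this mesh, but how the `'vocab'` and `'embed'` logical axes map onto mesh axes is not shown here, so treat the shard count as an assumption.

```python
BYTES_PER_ELEMENT = {"bfloat16": 2, "float32": 4}


def array_bytes(num_elements, dtype="bfloat16"):
  """Total bytes for a dense array of the given dtype."""
  return num_elements * BYTES_PER_ELEMENT[dtype]


def bytes_per_device(num_elements, num_shards, dtype="bfloat16"):
  """Bytes on each device if the array is split evenly into num_shards pieces."""
  if num_elements % num_shards:
    raise ValueError("array does not divide evenly across shards")
  return array_bytes(num_elements // num_shards, dtype)


EMBED_PARAMS = 524_550_144  # from the printed Param above
print(array_bytes(EMBED_PARAMS))  # 1049100288 bytes, shown as ~1.0 GB
print(bytes_per_device(EMBED_PARAMS, num_shards=8))  # 131137536 bytes per device
```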

In [None]:
# RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=gemma_policy,
    reference=gemma,
    tokenizer=data_lib.GemmaTokenizer(),
    cluster_config=cluster_config,
)

# GRPO Trainer
grpo_trainer = GrpoLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    grpo_config=grpo_config,
)
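The trainer receives four reward functions. A simple mental model, assumed here rather than taken from the Tunix source, is that each function scores every completion and the per-completion scores are summed into the single reward that feeds the group-relative advantage. The two toy functions below are hypothetical stand-ins that follow a `(prompts, completions, **kwargs) -> list of floats` shape.

```python
def has_answer_tag(prompts, completions, **kwargs):
  """Hypothetical format reward: 1.0 if the completion closes an answer tag."""
  return [1.0 if "</answer>" in c else 0.0 for c in completions]


def contains_answer(prompts, completions, answer, **kwargs):
  """Hypothetical correctness reward: 3.0 if the reference answer appears."""
  return [3.0 if a in c else 0.0 for c, a in zip(completions, answer)]


def total_rewards(reward_fns, prompts, completions, **kwargs):
  """Sums each reward function's per-completion scores (assumed combination)."""
  totals = [0.0] * len(completions)
  for fn in reward_fns:
    scores = fn(prompts, completions, **kwargs)
    totals = [t + s for t, s in zip(totals, scores)]
  return totals


prompts = ["What is 6 * 7?"] * 3
completions = ["<answer>42</answer>", "<answer>41</answer>", "42"]
print(total_rewards([has_answer_tag, contains_answer], prompts, completions,
                    answer=["42"] * 3))  # [4.0, 1.0, 3.0]
```

Summing lets format and correctness rewards trade off: a well-formatted wrong answer still earns something, while a correct, well-formatted one earns the most.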


Cluster config: ClusterConfig(role_to_mesh={<Role.ACTOR: 'actor'>: Mesh(device_ids=array([[[[[[[[[[[[0]]]]]]]]],
         [[[[[[[[[1]]]]]]]]],
         [[[[[[[[[2]]]]]]]]],
         [[[[[[[[[3]]]]]]]]],
         [[[[[[[[[7]]]]]]]]],
         [[[[[[[[[6]]]]]]]]],
         [[[[[[[[[5]]]]]]]]],
         [[[[[[[[[4]]]]]]]]]]]]), axis_names=('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'), axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), <Role.REFERENCE: 'reference'>: Mesh(device_ids=array([[[[[[[[[[[[0]]]]]]]]],
         [[[[[[[[[1]]]]]]]]],
         [[[[[[[[[2]]]]]]]]],
         [[[[[[[[[3]]]]]]]]],
         [[[[[[[[[7]]]]]]]]],
         [[[[[[[[[6]]]]]]]]],
         [[[[[[[[[5]]]]]]]]],
         [[[[[[[[[4]]]]]]]]]]]]), ax

wandb: Currently logged in as: anony-mouse-863749125460230603 to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [None]:
with mesh:
  grpo_trainer.train(dataset)

ValueError: Cannot infer collection name from value: frozenset({'layers', 'decoder_norm', '_pytree__state', 'to_nnx__rngs'})
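Each training step combines the pieces configured above. As a sketch of the per-token objective in the usual GRPO formulation (not a transcription of Tunix's loss code): the probability ratio between the current policy and the policy that sampled the completion is clipped to `[1 - epsilon, 1 + epsilon]` as in PPO, and a KL estimate against the reference model, weighted by `beta`, is subtracted. The log-probabilities below are hypothetical scalars.

```python
import math


def grpo_token_loss(logp_new, logp_old, logp_ref, advantage, epsilon, beta):
  """Negative clipped surrogate plus a beta-weighted KL penalty for one token.

  logp_new: log-prob under the policy being trained.
  logp_old: log-prob under the policy that generated the completion.
  logp_ref: log-prob under the frozen reference model.
  """
  ratio = math.exp(logp_new - logp_old)
  clipped = min(max(ratio, 1 - epsilon), 1 + epsilon)
  surrogate = min(ratio * advantage, clipped * advantage)
  # Non-negative KL estimate used in GRPO: exp(d) - d - 1, d = ref - new.
  delta = logp_ref - logp_new
  kl = math.exp(delta) - delta - 1
  return -(surrogate - beta * kl)


# Policy unchanged and equal to the reference: the loss is just -advantage.
print(grpo_token_loss(-1.0, -1.0, -1.0, advantage=1.0, epsilon=0.2, beta=0.04))
# Ratio e^0.5 exceeds 1 + epsilon, so the positive-advantage gain is capped at 1.2.
print(grpo_token_loss(-0.5, -1.0, -0.5, advantage=1.0, epsilon=0.2, beta=0.04))
```

The clip stops a single batch from pushing the policy too far from the one that generated the samples, while the KL term keeps it anchored to the reference model across steps.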

## Evaluate

Let's evaluate our model!

In [None]:
# Load checkpoint first.

trained_ckpt_path = os.path.join(CKPT_DIR, str(MAX_STEPS), "model_params")

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_gemma, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_gemma,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_gemma, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
gemma_tokenizer = data_lib.GemmaTokenizer()
sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

In [None]:
# for eval_example in QUALITATIVE_EVAL_EXAMPLES:
#   question = eval_example["question"]
#   answer = eval_example["answer"]
#   response = generate(
#       question,
#       sampler,
#       temperature=INFERENCE_TEMPERATURE,
#       top_k=INFERENCE_TOP_K,
#       top_p=INFERENCE_TOP_P,
#   )

#   print(f"Question:\n{question}")
#   print(f"Answer:\n{answer}")
#   print(f"Response:\n{response}")
#   print("===============")