<a href="https://colab.research.google.com/github/google/tunix/blob/main/examples/grpo_demo.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This tutorial demonstrates training the Gemma 2B model on the GSM8K math
reasoning benchmark using Group Relative Policy Optimization (GRPO). Learn how
GRPO can enhance your model's problem-solving skills on mathematical word
problems.

GRPO is a reinforcement learning (RL) algorithm designed to enhance the
reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization
(PPO) that reduces memory usage by eliminating the need for a separate value
function model. GRPO works by generating multiple responses for a given prompt,
scoring these responses with a reward model or, as in this tutorial, with
rule-based reward functions, and then calculating a relative advantage based on
the group's performance to update the policy.
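
To make the "relative advantage" step concrete, here is a minimal sketch in plain Python. It assumes the commonly described GRPO normalization, in which each response's reward is centered on the group mean and divided by the group standard deviation. The function name `group_relative_advantages`, the `eps` stabilizer, and the use of the population standard deviation are illustrative assumptions, not taken from Tunix.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
  """Normalizes rewards within one group of responses to the same prompt."""
  mean = statistics.fmean(rewards)
  # Population standard deviation over the group (an assumption; some
  # implementations use the sample standard deviation). eps avoids division
  # by zero when every response receives the same reward.
  std = statistics.pstdev(rewards)
  return [(r - mean) / (std + eps) for r in rewards]


# Four hypothetical responses to one prompt, already scored.
advantages = group_relative_advantages([3.0, 0.0, 1.5, 1.5])
print(advantages)
```

Responses that score above the group mean get a positive advantage and are reinforced; those below the mean get a negative one. No learned value function is needed, because the group itself serves as the baseline.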

In this tutorial we use Colab's `v2-8` TPU. Let's get started!

## Install necessary libraries

In [None]:
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

## Imports

In [None]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from qwix import lora
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.models.gemma import data as data_lib
from tunix.models.gemma import gemma as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.models.gemma import sampler as sampler_lib
from tunix.rl.grpo.grpo_trainer import GrpoTrainer, GrpoTrainingConfig
from tunix.sft import metrics_logger

## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you will have
to train the model for longer.

In [None]:
# Data
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0
BATCH_SIZE = 2
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 10
NUM_TEST_BATCHES = 2

# Model
MESH = [(1, 8), ("fsdp", "tp")]
# LoRA
RANK = 16
ALPHA = 2.0

# Train
LEARNING_RATE = 1e-5
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1


# GRPO
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 768
NUM_GENERATIONS = 2
NUM_ITERATIONS = 4
BETA = 0.04
EPSILON = 0.2
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
EVAL_EVERY_N_STEPS = 2
NUM_EPOCHS = 1
MAX_STEPS = int(NUM_BATCHES * BATCH_SIZE * NUM_EPOCHS * TRAIN_FRACTION)

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/content/intermediate_ckpt/"
CKPT_DIR = "/content/ckpts/"
SAVE_INTERVAL_STEPS = 10
MAX_TO_KEEP = 1
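
For orientation, `EPSILON` and `BETA` correspond to the clipping range and the KL penalty weight in the GRPO objective as it is usually written. The sketch below evaluates that objective for a single token with scalar inputs. `grpo_token_objective` is a hypothetical helper for illustration, not the Tunix implementation.

```python
def grpo_token_objective(ratio, advantage, kl, epsilon=0.2, beta=0.04):
  """PPO-style clipped surrogate minus a KL penalty, for one token.

  ratio: new-policy probability divided by old-policy probability.
  advantage: group-relative advantage of the response.
  kl: estimated KL divergence from the reference model at this token.
  """
  clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
  surrogate = min(ratio * advantage, clipped_ratio * advantage)
  return surrogate - beta * kl


# With a positive advantage, pushing the ratio past 1 + epsilon earns nothing
# extra: both calls below print the same value.
print(grpo_token_objective(ratio=1.5, advantage=1.0, kl=0.0))
print(grpo_token_objective(ratio=1.2, advantage=1.0, kl=0.0))
```

A larger `EPSILON` allows bigger policy steps per update, and a larger `BETA` pulls the policy more strongly toward the reference model.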

## Utility functions

In [None]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define some special tokens. We instruct the model to reason first,
between the `<start_reasoning>` and `<end_reasoning>` tokens. After reasoning,
we expect it to provide the answer between the `<start_answer>` and
`<end_answer>` tokens.

In [None]:
reasoning_start = "<start_reasoning>"
reasoning_end = "<end_reasoning>"
solution_start = "<start_answer>"
solution_end = "<end_answer>"

SYSTEM_PROMPT = f"""You are given a problem.
Think about the problem and provide your reasoning.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide the final answer (i.e., just one numerical value)
between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""
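
As a quick illustration of how the template is filled in, the sketch below renders one prompt. The system prompt and question are short placeholders rather than the values used in training.

```python
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""

# Placeholder strings, chosen only to keep the rendered prompt short.
prompt = TEMPLATE.format(
    system_prompt="Reason step by step, then give one number.",
    question="What is 2 + 3?",
)
print(prompt)
```

The rendered prompt ends with an open `model` turn, so generation continues from the start of the model's reply.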

We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def get_dataset(data_dir, split="train") -> grain.MapDataset:
  # Download data
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=x["question"].decode("utf-8"),
              ),
              # passed to reward functions
              "question": x["question"].decode("utf-8"),
              # passed to reward functions
              "answer": extract_hash_answer(x["answer"].decode("utf-8")),
          }
      )
  )
  return dataset
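
GSM8K reference solutions end with a `####` marker followed by the final answer, which is what `extract_hash_answer` relies on. The sketch below repeats the function so that it runs on its own. The sample string is made up in the GSM8K style and is not an actual dataset record.

```python
from typing import Optional


def extract_hash_answer(text: str) -> Optional[str]:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


# A made-up solution string, not taken from the dataset.
sample = "She has 3 + 4 = 7 apples.\n#### 7"
print(extract_hash_answer(sample))  # 7
print(extract_hash_answer("no marker here"))  # None
```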

In [None]:
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

(
    len(train_dataset),
    len(val_dataset) if val_dataset is not None else 0,
    len(test_dataset),
)

Let's see what one batch of the dataset looks like!


In [None]:
for ele in train_dataset[:1]:
  pprint(ele)

## Load the policy model and the reference model

The policy model is the model that is actually trained; its weights are
updated. The reference model stays frozen and is used to compute a KL
divergence penalty against the policy. This penalty keeps policy updates small
and prevents the policy from deviating too much from the reference model.
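
The KL term can be estimated per token from the log-probabilities that the two models assign to the sampled token. The sketch below uses the non-negative estimator commonly paired with GRPO. That Tunix uses exactly this form is an assumption, and `kl_estimate` is an illustrative name.

```python
import math


def kl_estimate(policy_logprob, ref_logprob):
  """Non-negative per-token KL estimate: exp(d) - d - 1, d = ref - policy."""
  d = ref_logprob - policy_logprob
  return math.exp(d) - d - 1.0


print(kl_estimate(-1.0, -1.0))  # identical log-probs give 0.0
print(kl_estimate(-0.5, -1.5))  # positive once the policy drifts
```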

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are
updated.

To load the model, you need a [Kaggle](https://www.kaggle.com/) account and
must have accepted the Gemma license on the
[Gemma model page](https://www.kaggle.com/models/google/gemma/flax/).

In [None]:
# Log in
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()

In [None]:
kaggle_ckpt_path = kagglehub.model_download("google/gemma/flax/2b-it")

In [None]:
# This is a workaround. The checkpoints on Kaggle don't work with NNX. So, we
# load the model, save the checkpoint locally, and then reload the model
# (sharded).
params = params_lib.load_and_format_params(
    os.path.join(kaggle_ckpt_path, "2b-it")
)
gemma = gemma_lib.Transformer.from_params(params, version="2b-it")
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)
checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)

In [None]:
# Block until the asynchronous checkpoint save has finished.
checkpointer.wait_until_finished()

In [None]:
# Delete the intermediate model to save memory.
del params
del gemma
del state
gc.collect()

In [None]:
def get_ref_model(ckpt_path):
  mesh = jax.make_mesh(*MESH)
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: gemma_lib.Transformer(
          gemma_lib.TransformerConfig.gemma_2b(), rngs=nnx.Rngs(params=0)
      )
  )
  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.float32, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma, mesh


def get_lora_model(base_model, mesh):
  lora_provider = lora.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = lora.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model

In [None]:
# Reference model
gemma, mesh = get_ref_model(
    ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
)
nnx.display(gemma)

In [None]:
# Policy model
lora_gemma = get_lora_model(gemma, mesh=mesh)
nnx.display(lora_gemma)

## Define reward functions

We define four reward functions:

- reward if the format of the output exactly matches the instruction given in
  `SYSTEM_PROMPT`;
- reward if the format of the output approximately matches the instruction given
  in `SYSTEM_PROMPT`;
- reward if the answer is correct or partially correct;
- reward if the number extracted from the text between `<start_answer>` and
  `<end_answer>` is correct, since that text might not be a single number.

First, let's define a regular expression that checks whether the format matches.

In [None]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me think!{reasoning_end}"
    f"{solution_start}2{solution_end}",
)
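
To see what the pattern accepts and rejects, the self-contained sketch below rebuilds the same regular expression and tries it on three hand-written strings. The test strings are illustrative only.

```python
import re

reasoning_start = "<start_reasoning>"
reasoning_end = "<end_reasoning>"
solution_start = "<start_answer>"
solution_end = "<end_answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

good = (
    f"{reasoning_start}2 + 3 = 5{reasoning_end}\n"
    f"{solution_start}5{solution_end}"
)
missing_reasoning = f"{solution_start}5{solution_end}"
trailing_text = good + " and some extra text"

print(match_format.search(good).group(1))  # 5
print(match_format.search(missing_reasoning))  # None
print(match_format.search(trailing_text))  # None
```

The capture group returns the text between the answer tokens. Any non-whitespace text after `<end_answer>` on the same line makes the match fail.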

Give the model a reward of 3 points if the format matches exactly.

In [None]:
def match_format_exactly(prompts, completions, **kargs):
  scores = []
  for completion in completions:
    score = 0
    response = completion
    # Match if format is seen exactly!
    if match_format.search(response) is not None:
      score += 3.0
    scores.append(score)
  return scores

We also reward the model if the format of the output matches partially.

In [None]:
def match_format_approximately(prompts, completions, **kargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # Count how often each token appears. A token that appears exactly once
    # earns points; a missing or repeated token is penalized.
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores
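
The scoring rule is easiest to see on a few hand-written completions. The sketch below is a condensed restatement of the per-completion logic; `tag_score` and `TAGS` are illustrative names, not part of the tutorial's code.

```python
TAGS = (
    "<start_reasoning>",
    "<end_reasoning>",
    "<start_answer>",
    "<end_answer>",
)


def tag_score(response):
  """+0.5 for each tag that appears exactly once, -0.5 otherwise."""
  return sum(0.5 if response.count(tag) == 1 else -0.5 for tag in TAGS)


well_formed = "<start_reasoning>r<end_reasoning><start_answer>5<end_answer>"
print(tag_score(well_formed))  # 2.0
print(tag_score("The answer is 5."))  # -2.0
print(tag_score(well_formed + "<end_answer>"))  # 1.0
```

The score ranges from -2.0 to 2.0, so a completion that gets only some of the tokens right still receives a graded signal.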

Reward the model if the answer is correct. If the answer does not match
exactly, a smaller reward is given based on how close it is to the correct
value, and answers that are far off are penalized.

In [None]:
def check_answer(prompts, completions, answer, **kargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Match if spaces are seen
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # We also reward it if the answer is close via ratios!
      # Ie if the answer is within some range, reward it!
      try:
        ratio = float(guess) / float(true_answer)
        if ratio >= 0.9 and ratio <= 1.1:
          score += 0.5
        elif ratio >= 0.8 and ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize
    scores.append(score)
  return scores
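
The tiers are easier to follow on concrete inputs. The sketch below is a condensed restatement of the scoring applied to one extracted guess; `partial_credit` is an illustrative helper, not part of the tutorial's code.

```python
def partial_credit(guess, true_answer):
  """Condensed restatement of the scoring tiers in `check_answer`."""
  if guess == true_answer:
    return 3.0
  if guess.strip() == true_answer.strip():
    return 1.5
  try:
    ratio = float(guess) / float(true_answer)
  except (ValueError, ZeroDivisionError):
    return -0.5  # not a number, or the reference answer is zero
  if 0.9 <= ratio <= 1.1:
    return 0.5
  if 0.8 <= ratio <= 1.2:
    return 0.25
  return -1.0


for guess in ["72", " 72 ", "72.0", "75", "60", "10", "seventy-two"]:
  print(repr(guess), partial_credit(guess, "72"))
```

Note that `"72.0"` earns only 0.5 here. The comparison is string-based first, so a numerically equal but differently formatted answer falls through to the ratio tier.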

Sometimes, the text between `<start_answer>` and `<end_answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [None]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")
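
The sketch below shows what this pattern extracts from a few hand-written strings, including one case where it stops early. `first_number` is an illustrative helper, and the inputs are made up.

```python
import re

solution_start = "<start_answer>"

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)


def first_number(text):
  """Returns the first run of digits and dots after the answer token."""
  match = match_numbers.search(text)
  return match.group(1) if match else None


print(first_number("<start_answer>  0.34  <end_answer>"))  # 0.34
print(first_number("<start_answer>The total is 680 dollars<end_answer>"))  # 680
print(first_number("<start_answer>1,000<end_answer>"))  # 1
```

Because the character class contains only digits and dots, thousands separators and minus signs are not captured. An answer written as `1,000` is therefore read as `1`.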

In [None]:
def check_numbers(prompts, completions, answer, **kargs):
  question = kargs["question"]
  # question = prompts[0][-1]["content"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert to numbers
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
      scores.append(1.5 if guess == true_answer else 0.0)
    except (ValueError, AttributeError):
      # Non-numeric text or a missing reference answer scores zero.
      scores.append(0)
  return scores

## Evaluate

Before we train the model, let's evaluate it on the test set so that we can measure the improvement after training.

We evaluate it in two ways:

1. Quantitative: We compute the accuracy and the percentage of samples that match the expected format.
2. Qualitative: We inspect the model's output for a given question so that we can compare it with the output generated after training.

In [None]:
def generate(question, sampler, temperature=1.0, top_k=64, top_p=0.95):
  """Given prompt, generates text."""

  if isinstance(question, str):
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
    ]
  else:
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=q,
        )
        for q in question
    ]

  out_data = sampler(
      input_strings=input_batch,
      total_generation_steps=TOTAL_GENERATION_STEPS,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
  )

  output = out_data.text
  if isinstance(question, str):
    return output[0]
  return output


def evaluate(dataset, sampler):
  """Computes accuracy and percentage of outputs matching the format."""

  corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]
    responses = generate(questions, sampler)

    for response, answer in zip(responses, answers):
      # check answer
      extracted_response = (
          guess.group(1)
          if (guess := match_numbers.search(response)) is not None
          else "-1000000"
      )
      try:
        if float(extracted_response.strip()) == float(answer.strip()):
          corr += 1
      except (ValueError, AttributeError):
        # Non-numeric text or a missing reference answer counts as incorrect.
        pass

      # check format
      if match_format.search(response) is not None:
        corr_format += 1

      total += 1

  return (corr / total * 100, corr_format / total * 100)

In [None]:
gemma_tokenizer = data_lib.GemmaTokenizer(
    os.path.join(kaggle_ckpt_path, "tokenizer.model")
)
sampler = sampler_lib.Sampler(
    transformer=lora_gemma, vocab=gemma_tokenizer.vocab
)

question = (
    "Trevor and two of his neighborhood friends go to the toy shop every year "
    "to buy toys. Trevor always spends $20 more than his friend Reed on toys, "
    "and Reed spends 2 times as much money as their friend Quinn on the toys. "
    "If Trevor spends $80 every year to buy his toys, calculate how much money "
    "in total the three spend in 4 years."
)
print(generate(question, sampler))

In [None]:
accuracy, format_accuracy = evaluate(test_dataset, sampler)
print(f"{accuracy=}%, {format_accuracy=}%")

## Train

Let's set up all the configs first - checkpointing, metric logging and training.
We then train the model.

Note: To get good results, it is advised to train the model for longer.

In [None]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [None]:
# Training config
training_config = GrpoTrainingConfig(
    max_prompt_length=MAX_PROMPT_LENGTH,
    total_generation_steps=TOTAL_GENERATION_STEPS,
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    # max_grad_norm=0.1,
    # metrics logging
    metrics_logging_options=metrics_logging_options,
    # checkpoint saving
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
)

In [None]:
gemma_tokenizer = data_lib.GemmaTokenizer(
    os.path.join(kaggle_ckpt_path, "tokenizer.model")
)
sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    vocab=gemma_tokenizer.vocab,
)

grpo_trainer = GrpoTrainer(
    model=lora_gemma,
    ref_model=gemma,  # use the base model as reference
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    sampler=sampler,
    optimizer=optax.adamw(
        learning_rate=LEARNING_RATE,
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    ),
    training_config=training_config,
)

In [None]:
with mesh:
  # Train on the training split only, so any held-out validation batches
  # stay unseen.
  grpo_trainer.train(train_dataset)

## Evaluate

Let's evaluate our model!

In [None]:
trained_ckpt_path = os.path.join(CKPT_DIR, str(MAX_STEPS), "model_params")

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_gemma, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_gemma,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_gemma, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
gemma_tokenizer = data_lib.GemmaTokenizer(
    os.path.join(kaggle_ckpt_path, "tokenizer.model")
)
sampler = sampler_lib.Sampler(
    transformer=lora_gemma, vocab=gemma_tokenizer.vocab
)

question = (
    "Trevor and two of his neighborhood friends go to the toy shop every year "
    "to buy toys. Trevor always spends $20 more than his friend Reed on toys, "
    "and Reed spends 2 times as much money as their friend Quinn on the toys. "
    "If Trevor spends $80 every year to buy his toys, calculate how much money "
    "in total the three spend in 4 years."
)
print(generate(question, sampler))

In [None]:
accuracy, format_accuracy = evaluate(test_dataset, sampler)
print(f"{accuracy=}%, {format_accuracy=}%")