<a href="https://colab.research.google.com/github/google/tunix/blob/main/examples/grpo_demo.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This tutorial demonstrates training the Gemma 2 2B-IT model on the GSM8K math
reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can
enhance your model's problem-solving skills on mathematical word problems,
coding problems, etc.

GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It
is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by
eliminating the need for a separate value function model. GRPO works by
generating multiple responses for a given prompt, evaluating these responses
using a reward model, and then calculating a relative advantage based on the
group's performance to update the policy.
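
To make the "relative advantage" concrete, here is a minimal sketch in plain Python. It assumes the common formulation in which each reward is normalized by the mean and standard deviation of its own group; the function name, the `eps` term, and the sample rewards are illustrative and are not taken from Tunix.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
  """Normalizes each reward against the statistics of its own group."""
  mean = statistics.fmean(rewards)
  # Population standard deviation; `eps` (an assumption) avoids division by
  # zero when every response in the group receives the same reward.
  std = statistics.pstdev(rewards)
  return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical rewards for four responses sampled from one prompt.
advantages = group_relative_advantages([3.0, 0.0, 1.5, 0.0])

# With a group of two, the better response gets roughly +1 and the worse
# roughly -1, whatever the absolute rewards are.
pair = group_relative_advantages([3.0, 0.0])

# Identical rewards carry no learning signal.
flat = group_relative_advantages([1.0, 1.0])
```

Responses that beat their group's average get a positive advantage and are reinforced; the others are pushed down. No value network is needed because the group itself serves as the baseline.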

In this tutorial we use Colab's `v2-8` TPU. Let's get started!

## Install necessary libraries

In [2]:
!pip install -q kagglehub

# !pip install -q tensorflow
# !pip install -q tensorboardX
# !pip install -q grain
# !pip install --force-reinstall "jax==0.6.2" "jaxlib==0.6.2" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install "jax[tpu]==0.6.2" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# !pip install -q git+https://github.com/google/tunix
! pip install -e ~/tunix/
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q tensorflow-datasets

!pip install -q git+https://github.com/AI-Hypercomputer/pathways-utils.git



Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Obtaining file:///home/mazumdera_google_com/tunix
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: tunix
  Building editable for tunix (pyproject.toml) ... done
  Created wheel for tunix: filename=tunix-0.0.0-0.editable-py3-none-any.whl size=8854 sha256=ad2e583ae5ea54ab386c169a72a5394fd30ed0fb574b2723c5c1c209b3fc3978
  Stored in directory: /tmp/pip-ephem-wheel-cache-dbpbov_s/wheels/11/21/09/69b89c4beb2859f209f80d2843c639c5349c106c5659c045a0
Successfully built tunix
Installing collected packages: tunix
  Attempting uninstall: tunix
    Found existing installation: tunix 0.0.0
    Uninstalling tunix-0.0.0:
      Successfully uninstalled tunix-0.0.0
Successfu

In [3]:
!pip install ipywidgets



## Imports

In [4]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from qwix import lora
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.models.gemma import data as data_lib
from tunix.models.gemma import gemma as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
from tunix.sft import metrics_logger

os.environ['TPU_LIBRARY_PATH'] = '/home/mazumdera_google_com/venv-py311/lib/python3.11/site-packages/libtpu/libtpu.so'



In [5]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=4, process_index=0, coords=(0,2,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(1,2,0), core_on_chip=0),
 TpuDevice(id=6, process_index=0, coords=(0,3,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,3,0), core_on_chip=0)]

## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [6]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 768
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of responses the policy generates for each prompt within a single
# training step. This corresponds to `G` in Algorithm 1 in the paper. The
# "group" in GRPO comes from here.
NUM_GENERATIONS = 2

# === other GRPO configs ===
# The number of iterations per batch (𝜇 in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
BATCH_SIZE = 1
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 3738
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
NUM_TEST_BATCHES = 5  # Kept small for quick evaluation; previously 100.

EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to `LEARNING_RATE` over the first
# 10% of training steps, and then gradually decrease it to 0 using a cosine
# schedule.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/home/mazumdera_google_com/content/intermediate_ckpt_llama3/"
CKPT_DIR = "/home/mazumdera_google_com/content/ckpts_llama3/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
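
The warmup-plus-cosine schedule described in the comments above can be sketched in plain Python. This is only an illustration of the shape of the curve, not the schedule object the notebook builds. The function name is hypothetical, and the warmup length is rounded down to a whole number of steps (373), which is an assumption.

```python
import math


def warmup_cosine_lr(step, peak_lr=3e-6, warmup_steps=373, total_steps=3738):
  """Learning rate at `step`: linear warmup, then cosine decay to zero."""
  if step < warmup_steps:
    # Linear ramp from 0 up to the peak learning rate.
    return peak_lr * step / warmup_steps
  # Fraction of the decay phase completed, clamped to [0, 1].
  progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
  progress = min(progress, 1.0)
  # Half-cosine from peak_lr (progress = 0) down to 0 (progress = 1).
  return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


for step in (0, 100, 373, 2000, 3738):
  print(step, warmup_cosine_lr(step))
```

The peak is reached exactly at the end of warmup, and the rate returns to zero at the final step.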

## Utility functions

In [7]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define some special tokens. We instruct the model to first reason
between the `<reasoning>` and `</reasoning>` tokens. After
reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tokens.

In [8]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.

In [9]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def get_dataset(data_dir, split="train") -> grain.MapDataset:
  # Download data
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  data = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=x["question"].decode("utf-8"),
              ),
              # passed to reward functions
              "question": x["question"].decode("utf-8"),
              # passed to reward functions
              "answer": extract_hash_answer(x["answer"].decode("utf-8")),
          }
      )
  )
  return dataset
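
As a quick sanity check of the answer extraction, the sketch below repeats `extract_hash_answer` and runs it on a made-up solution string written in the GSM8K style (the string is not copied from the dataset). `normalize_numeric` is a hypothetical helper, shown only to illustrate that a reference answer containing a thousands separator would need cleaning before it can be compared numerically.

```python
def extract_hash_answer(text):
  """Returns the text after the '####' marker, or None if it is missing."""
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def normalize_numeric(answer):
  """Hypothetical helper: drops thousands separators such as '1,000'."""
  return answer.replace(",", "")


# Made-up solution text in the GSM8K style.
solution = "Drying time is 2 + 2 * 3 + 5 = 13 minutes.\n#### 13"

print(extract_hash_answer(solution))
print(extract_hash_answer("no marker in this text"))
print(float(normalize_numeric("1,000")))
```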

In [10]:
dataset = get_dataset(TRAIN_DATA_DIR, "train").batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test").batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

len(train_dataset), len(val_dataset) if val_dataset is not None else 0, len(
    test_dataset
)

(3738, 0, 5)

Let's see what one batch of the dataset looks like!


In [11]:
for ele in train_dataset[:1]:
  pprint(ele)

{'answer': array(['13'], dtype='<U2'),
 'prompts': array(['<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nJane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a clear top coat that takes 5 minutes to dry. How many minutes total does Jane spend waiting for her nail polish to dry?<end_of_turn>\n<start_of_turn>model'],
      dtype='<U535'),
 'question': array(['Jane is painting her fingernails. She applies a base coat that takes 2 minutes to dry, two color coats that take 3 minutes each to dry, and a clear top coat that takes 5 minutes to dry. How many minutes total does Jane spend waiting for her nail polish to dry?'],
      dtype='<U260')}


## Load the policy model and the reference model

The policy model is the model which is actually trained and whose weights are
updated. The reference model is the model with which we compute KL divergence.
This is to ensure that the policy updates are not huge and that it does not
deviate too much from the reference model.
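
As a rough illustration of that KL term, the sketch below uses the per-token estimator from the GRPO paper, `exp(d) - d - 1` with `d = log p_ref - log p_policy`, which is always non-negative. Whether Tunix computes the penalty in exactly this way is an assumption; the function names and the log-probabilities are made up for the example.

```python
import math


def kl_per_token(policy_logp, ref_logp):
  """Non-negative KL estimate for one token from two log-probabilities."""
  d = ref_logp - policy_logp
  return math.exp(d) - d - 1.0


def kl_penalty(policy_logps, ref_logps, beta=0.08):
  """Mean per-token KL estimate scaled by the coefficient beta."""
  kls = [kl_per_token(p, r) for p, r in zip(policy_logps, ref_logps)]
  return beta * sum(kls) / len(kls)


# Hypothetical log-probabilities of the same three sampled tokens.
policy_logps = [-0.5, -1.2, -2.0]
ref_logps = [-0.7, -1.2, -1.5]

print(kl_penalty(policy_logps, ref_logps))
```

The estimate is zero when both models assign the same probability to a token and grows as they diverge, which is why a larger `beta` keeps the policy closer to the reference model.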

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.

Note: We perform full precision (fp32) training. You can, however, leverage
Qwix for QAT.

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [12]:
# # Log in
# if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
#   kagglehub.login()

In [13]:
# kaggle_ckpt_path = kagglehub.model_download("google/gemma-2/flax/gemma2-2b-it")
kaggle_ckpt_path = kagglehub.model_download("google/gemma/flax/2b")

In [14]:
# # This is a workaround. The checkpoints on Kaggle don't work with NNX. So, we
# # load the model, save the checkpoint locally, and then reload the model
# # (sharded).
# params = params_lib.load_and_format_params(
#     os.path.join(kaggle_ckpt_path, "gemma2-2b-it")
# )
# gemma = gemma_lib.Transformer.from_params(params, version="2-2b-it")
# checkpointer = ocp.StandardCheckpointer()
# _, state = nnx.split(gemma)
# checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)

In [15]:
# # Wait for the ckpt to save successfully.
# time.sleep(60)

In [16]:
# # Delete the intermediate model to save memory.
# del params
# del gemma
# del state
# gc.collect()

In [17]:
from tunix.rl import utils
# show_hbm_usage = utils.show_hbm_usage

print("HBM usage before loading model:")
show_hbm_usage()

HBM usage before loading model:
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_0(process=0,(0,0,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_1(process=0,(1,0,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_2(process=0,(0,1,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_3(process=0,(1,1,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_4(process=0,(0,2,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_5(process=0,(1,2,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_6(process=0,(0,3,0,0))
Using 31.6 KiB / 31.2 GiB (0.000097%) on TPU_7(process=0,(1,3,0,0))


### Load MaxText model

In [18]:
import sys
import os

# add the parent directory (one level up) to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../../maxtext')))

# ! pip install -r ../../maxtext/requirements.txt

import MaxText as mt
from MaxText import pyconfig




#### Convert MaxText model to nnx (use a commit from MaxText repo prior to )



In [None]:
# from MaxText.integrations.tunix.tunix_utils import build_tunix_wrapper
from flax import linen as nn
from tunix.models.llama3 import model as llama3_lib
from functools import partial

def get_ref_maxtext_model(config):

  #python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5


  def create_model(config):
    return mt.from_pretrained(config, rngs=nnx.Rngs(params=0, dropout=1))

  abstract_model = nnx.eval_shape(create_model, config=config)
  graphdef, abstract_state = nnx.split(abstract_model)
  print('The abstract NNX state (all leaves are abstract arrays):')
  nnx.display(abstract_state)
  specs = nnx.get_partition_spec(abstract_state)
  mesh = abstract_model.mesh

  # JIT a function that creates the model state with proper sharding from the start.
  # By providing out_shardings, we instruct JAX to produce sharded output directly,
  # avoiding a large intermediate allocation on a single device.
  with nn.logical_axis_rules(config.logical_axis_rules):
    out_shardings = nn.logical_to_mesh_sharding(specs, mesh)

  @partial(jax.jit, out_shardings=out_shardings)
  def create_sharded_state():
    # This will be JIT-compiled. JAX knows the output sharding and can
    # initialize the parameters directly on the target devices in a sharded way.
    model = create_model(config)
    return nnx.state(model)

  with jax.sharding.use_mesh(mesh):
    # Create the model with sharded parameters.
    sharded_state = create_sharded_state()
    model = nnx.merge(graphdef, sharded_state)

    target_for_restore = jax.tree.map(
        lambda v: v.value, sharded_state, is_leaf=lambda n: isinstance(n, nnx.Variable)
    )
    checkpoint = mt.checkpointing.load_params_from_path(
        load_parameters_from_path=config.load_parameters_path,
        abstract_unboxed_params=target_for_restore,
        checkpoint_storage_concurrent_gb=None,
    )
    # checkpoint = ocp.StandardCheckpointer().restore(
    #     "gs://maxtext-gemma/2b/2025-08-05-04-37/0/items", target=target_for_restore
    # )
    if checkpoint:
      nnx.update(model, checkpoint)

    tunix_model = TunixMaxTextLlama(
        base_model=model,
        use_attention_mask=False,  # trust Tunix loss masking
    )

    model_config = llama3_lib.ModelConfig.llama3_1_8b()
    tunix_model.config = model_config

  return tunix_model, mesh, model_config



In [None]:
# Base model
# gemma, mesh, model_config = get_base_model(
#     ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
# )
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama

config_ref = pyconfig.initialize(
      ["", "../../maxtext/MaxText/configs/base.yml"], #TODO: @mazumdera: why decode.py?
      base_output_directory="gs://dummy_output_dir",  # This is not used in Tunix.
      run_name="test-tunix-maxtext-llama3.1-8b",
      # run_name="test-tunix-maxtext-llama3.1-8b",
      # dataset_path=we use Tunix's dataset
      #TODO: @mazumdera: change this to use checkpoint
      tokenizer_type="tiktoken",
      tokenizer_path="assets/tokenizer_llama3.tiktoken",
      load_parameters_path="gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items",
      # tokenizer_path="assets/tokenizer.gemma",
      per_device_batch_size=1,
      max_prefill_predict_length=4,
      max_target_length=16,
      steps=10,
      async_checkpointing="false",
      model_name="llama3.1-8b",
      # model_name="gemma-2b",
      checkpoint_period=5,
      skip_jax_distributed_system="true",
      weight_dtype="bfloat16",
      attention="dot_product",
      remat_policy="custom",
      decoder_layer_input="offload",
      query_proj="offload",
      key_proj="offload",
      value_proj="offload",
      opt_type="sgd",
  )

llama3_1_8b, mesh, model_config = get_ref_maxtext_model(config_ref)
# gemma_maxtext_nnx = nnx.bridge.ToNNX(gemma)
# Instead of:
nnx.display(llama3_1_8b)

# Use:
print("Model initialized successfully")
print(f"Model mesh shape: {mesh.shape}")
print(f"Model config: {model_config}")

Updating keys from env and command line: ['run_name', 'model_name', 'async_checkpointing', 'checkpoint_period', 'weight_dtype', 'remat_policy', 'decoder_layer_input', 'query_proj', 'key_proj', 'value_proj', 'attention', 'base_output_directory', 'tokenizer_path', 'tokenizer_type', 'per_device_batch_size', 'steps', 'skip_jax_distributed_system', 'max_target_length', 'max_prefill_predict_length', 'opt_type']
Running Model: llama3.1-8b
Updating following parameters in config

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 32
base_mlp_dim: 14336
head_dim: 128
mlp_activations: ['silu', 'linear']
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1e-05
rope_max_timescale: 500000
decoder_block: llama2
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_num_decoder_layers', 'base_mlp_dim', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_e

Config param out_proj: remat
Config param override_model_config: False
Config param packing: True
Config param pagedattn_head_dim_alignment: 128
Config param pagedattn_max_pages_per_group: 1
Config param pagedattn_num_pages: 64
Config param pagedattn_pages_per_compute_block: 4
Config param pagedattn_tokens_per_page: 32
Config param param_scan_axis: 1
Config param parameter_memory_host_offload: False
Config param patch_size_for_vit: 14
Config param per_device_batch_size: 0.25
Config param pipeline_delay_activation_forwarding: False
Config param pipeline_fsdp_ag_once: False
Config param pipeline_parallel_layers: -1
Config param pixel_shuffle_ratio_for_vit: 0.5
Config param prefill_cache_axis_order: 1,2,0,3
Config param prefill_cache_dir: 
Config param prefill_chunk_size: 256
Config param prefill_slice: v5e-16
Config param prefix_caching_dram_byte: 100000000000
Config param prefix_caching_hbm_byte: 10000000000
Config param profile_cleanly: True
Config param profile_periodically_period: -1

Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)


Model initialized successfully
Model mesh shape: OrderedDict([('data', 1), ('stage', 1), ('fsdp', 8), ('fsdp_transpose', 1), ('sequence', 1), ('context', 1), ('context_autoregressive', 1), ('tensor', 1), ('tensor_transpose', 1), ('tensor_sequence', 1), ('expert', 1), ('autoregressive', 1)])
Model config: ModelConfig(num_layers=32, vocab_size=128256, embed_dim=4096, hidden_dim=14336, num_heads=32, head_dim=128, num_kv_heads=8, rope_theta=500000, norm_eps=1e-05, shd_config=ShardingConfig(emb_vd=('tp', 'fsdp'), emb_dv=('fsdp', 'tp'), q_weight_ndh=('tp', 'fsdp', None), kv_weight_ndh=('tp', 'fsdp', None), o_weight_nhd=('tp', None, 'fsdp'), ffw_weight_df=('fsdp', 'tp'), ffw_weight_fd=('tp', 'fsdp'), rms_norm_weight=('tp',), act_btd=('fsdp', None, 'tp'), act_btf=('fsdp', None, 'tp'), act_btnh=('fsdp', None, 'tp', None)), weight_tying=False)


In [21]:
print("HBM usage after loading ref model:")
show_hbm_usage()

HBM usage after loading ref model:
Using 1.9 GiB / 31.2 GiB (6.047082%) on TPU_0(process=0,(0,0,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_1(process=0,(1,0,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_2(process=0,(0,1,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_3(process=0,(1,1,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_4(process=0,(0,2,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_5(process=0,(1,2,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_6(process=0,(0,3,0,0))
Using 1.9 GiB / 31.2 GiB (6.046840%) on TPU_7(process=0,(1,3,0,0))


In [None]:

# # Policy model
# lora_gemma = get_lora_model(gemma, mesh=mesh)
# nnx.display(lora_gemma)


# Policy model
# This can remain unchanged from default Tunix's colab
# lora_gemma = get_lora_model(gemma, mesh=mesh)

# TODO: @mazumdera: change this to use lora
# lora_gemma = get_lora_model(gemma, mesh=mesh)
# nnx.display(lora_gemma)

config_policy = pyconfig.initialize(
      ["", "../../maxtext/MaxText/configs/base.yml"], #TODO: @mazumdera: why decode.py?
      base_output_directory="gs://dummy_output_dir",  # This is not used in Tunix.
      run_name="test-tunix-maxtext-llama3.1-8b",
      # run_name="test-tunix-maxtext-llama3.1-8b",
      # dataset_path=we use Tunix's dataset
      #TODO: @mazumdera: change this to use checkpoint
      tokenizer_type="tiktoken",
      tokenizer_path="assets/tokenizer_llama3.tiktoken",
      load_parameters_path="gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items",
      # tokenizer_path="assets/tokenizer.gemma",
      per_device_batch_size=1,
      max_prefill_predict_length=4,
      max_target_length=16,
      steps=10,
      async_checkpointing="false",
      model_name="llama3.1-8b",
      # model_name="gemma-2b",
      checkpoint_period=5,
      skip_jax_distributed_system="true",
      weight_dtype="bfloat16",
      attention="dot_product",
      remat_policy="custom",
      decoder_layer_input="offload",
      query_proj="offload",
      key_proj="offload",
      value_proj="offload",
      opt_type="sgd",
  )
llama3_1_8b_policy, mesh_policy, model_config_policy = get_ref_maxtext_model(config_policy)

# gemma_maxtext_nnx = nnx.bridge.ToNNX(gemma)
# Instead of:
nnx.display(llama3_1_8b_policy)

# Use:
print("Model initialized successfully")
print(f"Model mesh shape: {mesh_policy.shape}")
print(f"Model config: {model_config_policy}")




Updating keys from env and command line: ['run_name', 'model_name', 'async_checkpointing', 'checkpoint_period', 'weight_dtype', 'remat_policy', 'decoder_layer_input', 'query_proj', 'key_proj', 'value_proj', 'attention', 'base_output_directory', 'tokenizer_path', 'tokenizer_type', 'per_device_batch_size', 'steps', 'skip_jax_distributed_system', 'max_target_length', 'max_prefill_predict_length', 'opt_type']
Running Model: llama3.1-8b
Updating following parameters in config

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 32
base_mlp_dim: 14336
head_dim: 128
mlp_activations: ['silu', 'linear']
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1e-05
rope_max_timescale: 500000
decoder_block: llama2
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_num_decoder_layers', 'base_mlp_dim', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_e

Config param num_pipeline_microbatches: -1
Config param num_pipeline_repeats: -1
Config param num_query_heads: 32
Config param num_slices: 1
Config param opt_type: sgd
Config param optimize_mesh_for_tpu_v6e: False
Config param optimizer_memory_host_offload: False
Config param original_max_position_embeddings: 4096
Config param out_proj: remat
Config param override_model_config: False
Config param packing: True
Config param pagedattn_head_dim_alignment: 128
Config param pagedattn_max_pages_per_group: 1
Config param pagedattn_num_pages: 64
Config param pagedattn_pages_per_compute_block: 4
Config param pagedattn_tokens_per_page: 32
Config param param_scan_axis: 1
Config param parameter_memory_host_offload: False
Config param patch_size_for_vit: 14
Config param per_device_batch_size: 0.25
Config param pipeline_delay_activation_forwarding: False
Config param pipeline_fsdp_ag_once: False
Config param pipeline_parallel_layers: -1
Config param pixel_shuffle_ratio_for_vit: 0.5
Config param pref

Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)


Model initialized successfully
Model mesh shape: OrderedDict([('data', 1), ('stage', 1), ('fsdp', 8), ('fsdp_transpose', 1), ('sequence', 1), ('context', 1), ('context_autoregressive', 1), ('tensor', 1), ('tensor_transpose', 1), ('tensor_sequence', 1), ('expert', 1), ('autoregressive', 1)])
Model config: ModelConfig(num_layers=32, vocab_size=128256, embed_dim=4096, hidden_dim=14336, num_heads=32, head_dim=128, num_kv_heads=8, rope_theta=500000, norm_eps=1e-05, shd_config=ShardingConfig(emb_vd=('tp', 'fsdp'), emb_dv=('fsdp', 'tp'), q_weight_ndh=('tp', 'fsdp', None), kv_weight_ndh=('tp', 'fsdp', None), o_weight_nhd=('tp', None, 'fsdp'), ffw_weight_df=('fsdp', 'tp'), ffw_weight_fd=('tp', 'fsdp'), rms_norm_weight=('tp',), act_btd=('fsdp', None, 'tp'), act_btf=('fsdp', None, 'tp'), act_btnh=('fsdp', None, 'tp', None)), weight_tying=False)


In [23]:
print("HBM usage after loading policy model:")
show_hbm_usage()

HBM usage after loading policy model:
Using 3.8 GiB / 31.2 GiB (12.032236%) on TPU_0(process=0,(0,0,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_1(process=0,(1,0,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_2(process=0,(0,1,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_3(process=0,(1,1,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_4(process=0,(0,2,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_5(process=0,(1,2,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_6(process=0,(0,3,0,0))
Using 3.8 GiB / 31.2 GiB (12.031995%) on TPU_7(process=0,(1,3,0,0))


## Define reward functions

We define four reward functions:

- reward if the format of the output exactly matches the instruction given in
  `SYSTEM_PROMPT`;
- reward if the format of the output approximately matches the instruction given
  in `SYSTEM_PROMPT`;
- reward if the answer is correct or partially correct;
- reward if the number extracted from the text between `<answer>` and
  `</answer>` is correct, since that text might not be a single number.

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First off, let's define a RegEx for checking whether the format matches.

In [24]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

<re.Match object; span=(0, 54), match='<reasoning>Let me think!</reasoning><answer>2</an>
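
It helps to probe the pattern with a few more strings. The sketch below rebuilds the same regular expression and checks four hand-written responses (all of them made up for illustration). Note one consequence of `re.MULTILINE`: `^` and `$` anchor at line boundaries, so extra text on a following line does not prevent a match, while extra text on the same line does.

```python
import re

reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

body = "<reasoning>2 + 2 = 4</reasoning><answer>4</answer>"

good = body + "\n"                            # well-formed response
no_reasoning = "<answer>4</answer>"           # reasoning block missing
same_line_trailer = body + " Hope this helps!"
next_line_trailer = body + "\nHope this helps!"

for name, text in [
    ("good", good),
    ("no_reasoning", no_reasoning),
    ("same_line_trailer", same_line_trailer),
    ("next_line_trailer", next_line_trailer),
]:
  print(name, match_format.search(text) is not None)
```

If stricter matching is wanted, dropping `re.MULTILINE` or using `re.fullmatch` would be one option; that is a design choice rather than something this notebook does.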

Give the model a reward of 3 points if the format matches exactly.

In [25]:
def match_format_exactly(prompts, completions, **kargs):
  scores = []
  for completion in completions:
    score = 0
    response = completion
    # Match if format is seen exactly!
    if match_format.search(response) is not None:
      score += 3.0
    scores.append(score)
  return scores

We also reward the model if the format of the output matches partially.

In [26]:
def match_format_approximately(prompts, completions, **kargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # Count how often each tag appears: add 0.5 if it appears exactly once,
    # otherwise (missing or repeated) subtract 0.5.
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores
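
To see the range of this reward, here is a compact re-statement of the same per-tag rule applied to three made-up responses. The helper name `approx_format_score` and the `TAGS` tuple are illustrative only.

```python
TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")


def approx_format_score(response):
  """+0.5 for each tag seen exactly once, -0.5 for each missing or repeated."""
  return sum(0.5 if response.count(tag) == 1 else -0.5 for tag in TAGS)


perfect = "<reasoning>2 + 2 = 4</reasoning><answer>4</answer>"
no_tags = "The answer is 4."
repeated = "<reasoning>r</reasoning><answer>1</answer><answer>2</answer>"

print(approx_format_score(perfect))   # all four tags appear once
print(approx_format_score(no_tags))   # every tag is missing
print(approx_format_score(repeated))  # answer tags appear twice
```

The score therefore ranges from -2.0 to 2.0, so a response can earn partial credit for format before it ever satisfies the stricter regular expression.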

Reward the model if the answer is correct. A reward is also given if the answer
does not match exactly, i.e., based on how close the answer is to the correct
value.

In [27]:
def check_answer(prompts, completions, answer, **kargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Match if spaces are seen
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # We also reward it if the answer is close via ratios!
      # Ie if the answer is within some range, reward it!
      try:
        ratio = float(guess) / float(true_answer)
        if ratio >= 0.9 and ratio <= 1.1:
          score += 0.5
        elif ratio >= 0.8 and ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize
    scores.append(score)
  return scores
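
The ratio-based fallback is easiest to understand in isolation. The sketch below covers only the branch taken when the strings do not match exactly; `closeness_score` is a hypothetical name, and catching only `ValueError` and `ZeroDivisionError` is an assumption about which failures matter.

```python
def closeness_score(guess, true_answer):
  """Score for a non-exact answer, based on the guess / true_answer ratio."""
  try:
    ratio = float(guess) / float(true_answer)
  except (ValueError, ZeroDivisionError):
    return -0.5  # not a number, or the reference answer is zero
  if 0.9 <= ratio <= 1.1:
    return 0.5
  if 0.8 <= ratio <= 1.2:
    return 0.25
  return -1.0  # clearly wrong


for guess, truth in [("95", "100"), ("85", "100"), ("50", "100"),
                     ("thirteen", "13"), ("13.0", "13")]:
  print(guess, truth, closeness_score(guess, truth))
```

The last case is worth noting: `"13.0"` is numerically correct but differs from `"13"` as a string, so in this branch it earns only 0.5. A numeric comparison, as in `check_numbers`, makes up for that.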

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [28]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']
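
A few more probes show what this pattern does and does not capture. The sketch rebuilds the same regular expression; `first_number` is a hypothetical wrapper and the input strings are made up.

```python
import re

solution_start = "<answer>"

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)


def first_number(text):
  """Returns the first run of digits and dots after the answer tag."""
  match = match_numbers.search(text)
  return match.group(1) if match is not None else None


print(first_number("<answer>The total is 13 minutes</answer>"))
print(first_number("<answer>1,234</answer>"))       # stops at the comma
print(first_number("<answer>-5</answer>"))          # the sign is dropped
print(first_number("<answer>Approx. 13</answer>"))  # captures the lone dot
print(first_number("no tags here"))
```

Thousands separators, negative signs, and a period that precedes the number are therefore not handled; in those cases the extracted text is either a different number or fails the later `float` conversion, and the response earns no reward from this function.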

In [29]:
def check_numbers(prompts, completions, answer, **kargs):
  question = kargs["question"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert to numbers
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
      scores.append(1.5 if guess == true_answer else 0.0)
    except ValueError:
      # The extracted text (e.g. "." or "1.2.3") is not a valid number.
      scores.append(0)
  return scores

## Evaluate


Before we train the model, let's evaluate the model on the test set so we can
see the improvement post training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1.
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the reasoning special tokens, and the
final answer between the `<start_answer>` and `<end_answer>` tokens.

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.
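
As a rough illustration of the three quantitative metrics above, the following sketch computes them from already-parsed results. The helper `quantitative_metrics` and the sample tuples are hypothetical; the `evaluate` function further down does the parsing and counting on real model outputs.

```python
def quantitative_metrics(samples):
  """Computes the three percentages from (prediction, target, format_ok) tuples.

  `prediction` is None when no number could be extracted from the response.
  """
  total = len(samples)
  exact = partial = well_formatted = 0
  for prediction, target, format_ok in samples:
    if prediction is not None:
      if prediction == target:
        exact += 1
      # Partial credit: prediction / target within [0.9, 1.1].
      if target != 0 and 0.9 <= prediction / target <= 1.1:
        partial += 1
    if format_ok:
      well_formatted += 1
  return {
      "accuracy": 100 * exact / total,
      "partial_accuracy": 100 * partial / total,
      "format_accuracy": 100 * well_formatted / total,
  }


# Four made-up samples, all with the reference answer 7.
samples = [
    (7.0, 7.0, True),    # Exact, well formatted.
    (7.5, 7.0, True),    # Only partially correct.
    (12.0, 7.0, False),  # Wrong, badly formatted.
    (None, 7.0, False),  # Nothing extracted.
]
print(quantitative_metrics(samples))
```

An exactly correct answer also counts toward partial accuracy, so partial accuracy is never lower than accuracy.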


In [30]:
def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Given prompt, generates text."""

  if isinstance(question, str):
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
    ]
  else:
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=q,
        )
        for q in question
    ]

  out_data = sampler(
      input_strings=input_batch,
      total_generation_steps=768,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  )

  output = out_data.text
  if isinstance(question, str):
    return output[0]
  return output

In [31]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format."""

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        extracted_response = (
            guess.group(1)
            if (guess := match_numbers.search(response)) is not None
            else "-1000000"
        )
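        try:
          # Parse both values once and reuse them for both checks.
          predicted = float(extracted_response.strip())
          expected = float(answer.strip())
          if predicted == expected:
            corr_ctr_per_question += 1

          ratio = predicted / expected
          if 0.9 <= ratio <= 1.1:
            partially_corr_per_question += 1
        except (ValueError, ZeroDivisionError):
          # Unparseable number, or a reference answer of zero.
          print("SKIPPED")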

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )
  if make_lst:
    return to_return, response_lst
  return to_return
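
With `num_passes > 1`, the loop above counts a question as correct if at least one of its sampled responses is correct. A minimal sketch of that aggregation, using a hypothetical helper and made-up per-pass results:

```python
def any_pass_accuracy(per_question_correct):
  """Percentage of questions with at least one correct pass.

  `per_question_correct` holds one list of booleans per question,
  with one entry per sampling pass.
  """
  hits = sum(1 for passes in per_question_correct if any(passes))
  return 100.0 * hits / len(per_question_correct)


# Three hypothetical questions, two passes each.
two_passes = [[False, True], [False, False], [True, True]]
# The same questions if only the first pass had been sampled.
one_pass = [passes[:1] for passes in two_passes]

print(round(any_pass_accuracy(one_pass), 1))
print(round(any_pass_accuracy(two_passes), 1))
```

More passes can only keep or raise this number, so scores obtained with different `num_passes` values are not directly comparable.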

In [32]:
gemma_tokenizer = data_lib.GemmaTokenizer()
sampler = sampler_lib.Sampler(
    # transformer=lora_gemma,
    transformer=llama3_1_8b_policy,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [33]:
# (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
#     test_dataset,
#     sampler,
#     **GENERATION_CONFIGS["greedy"],
# )
# print(
#     f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
#     f" {format_accuracy=}%"
# )

# # TODO: @mazumdera: why is this 0?
# # corr=0, total=5, accuracy=0.0%, partial_accuracy=0.0%, format_accuracy=0.0%


In [34]:
# for eval_example in QUALITATIVE_EVAL_EXAMPLES:
#   question = eval_example["question"]
#   answer = eval_example["answer"]
#   response = generate(
#       question,
#       sampler,
#       temperature=INFERENCE_TEMPERATURE,
#       top_k=INFERENCE_TOP_K,
#       top_p=INFERENCE_TOP_P,
#   )

#   print(f"Question:\n{question}")
#   print(f"Answer:\n{answer}")
#   print(f"Response:\n{response}")
#   print("===============")

## Train

Let's set up all the configs first: checkpointing, metrics logging, and
training. We then train the model.

In [35]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/home/mazumdera_google_com/content/tmp/tensorboard/grpo",
    flush_every_n_steps=20,
)

In [36]:
# Logs
%load_ext tensorboard
%tensorboard --logdir /home/mazumdera_google_com/content/tmp/tensorboard/grpo --port=0


In [37]:
# # Training config
# training_config = GrpoTrainingConfig(
#     max_prompt_length=MAX_PROMPT_LENGTH,
#     total_generation_steps=TOTAL_GENERATION_STEPS,
#     num_generations=NUM_GENERATIONS,
#     num_iterations=NUM_ITERATIONS,
#     beta=BETA,
#     epsilon=EPSILON,
#     temperature=TEMPERATURE,
#     top_p=TOP_P,
#     top_k=TOP_K,
#     eval_every_n_steps=EVAL_EVERY_N_STEPS,
#     max_steps=MAX_STEPS,
#     # metrics logging
#     metrics_logging_options=metrics_logging_options,
#     # checkpoint saving
#     checkpoint_root_directory=CKPT_DIR,
#     checkpointing_options=checkpointing_options,
# )

In [38]:
# Optimizer, learning rate scheduler, gradient clipping
# optimizer = optax.adamw(
#     learning_rate=optax.schedules.warmup_cosine_decay_schedule(
#         init_value=0.0,
#         peak_value=LEARNING_RATE,
#         warmup_steps=WARMUP_STEPS,
#         decay_steps=MAX_STEPS,
#         end_value=0.0,
#     ),
#     b1=B1,
#     b2=B2,
#     weight_decay=WEIGHT_DECAY,
# )
# TODO: @mazumdera: try optimizer offloading with adamw
optimizer = optax.adafactor(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )
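
For intuition about what the cell above configures, here is a pure-Python sketch of a warmup-plus-cosine-decay learning rate and of clipping by global norm. It is not optax's code: it follows our reading of the optax documentation (in particular, that `decay_steps` is the total schedule length, warmup included), and the function names and numeric values are hypothetical.

```python
import math


def warmup_cosine_lr(step, peak_value, warmup_steps, decay_steps,
                     init_value=0.0, end_value=0.0):
  """Learning rate at `step`: linear warmup, then cosine decay."""
  if warmup_steps > 0 and step < warmup_steps:
    frac = step / warmup_steps
    return init_value + (peak_value - init_value) * frac
  # Assumption: decay_steps covers the whole schedule, warmup included.
  cosine_steps = max(decay_steps - warmup_steps, 1)
  progress = min((step - warmup_steps) / cosine_steps, 1.0)
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return end_value + (peak_value - end_value) * cosine


def clip_by_global_norm(grads, max_norm):
  """Rescales a flat list of gradients so its L2 norm is at most max_norm."""
  global_norm = math.sqrt(sum(g * g for g in grads))
  if global_norm <= max_norm:
    return list(grads)
  scale = max_norm / global_norm
  return [g * scale for g in grads]


# Hypothetical values, not the notebook's hyperparameters.
for step in [0, 5, 10, 55, 100]:
  print(step, warmup_cosine_lr(step, peak_value=3e-6, warmup_steps=10,
                               decay_steps=100))
print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))
```

Clipping rescales all gradients by the same factor, so the update direction is preserved and only its magnitude is capped.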

In [39]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        gradient_accumulation_steps=1,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
    ),
)

grpo_config = GrpoConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)
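
To make the `GrpoConfig` fields more concrete, here is a minimal pure-Python sketch of the commonly described GRPO formulation; it is not Tunix's implementation. It assumes `num_generations` is the group size, `epsilon` is the PPO-style clipping range, and `beta` weights the KL penalty against the reference model. All helper names and numbers are made up for illustration.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
  """Normalizes each reward against the mean and std of its own group."""
  mean = statistics.fmean(rewards)
  std = statistics.pstdev(rewards)  # Assumption: population std.
  return [(r - mean) / (std + eps) for r in rewards]


def clipped_term(ratio, advantage, epsilon):
  """PPO-style clipped surrogate for one token."""
  clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
  return min(ratio * advantage, clipped_ratio * advantage)


def token_objective(ratio, advantage, kl, epsilon, beta):
  """Clipped surrogate minus the KL penalty; training would maximize this."""
  return clipped_term(ratio, advantage, epsilon) - beta * kl


# Hypothetical rewards for one prompt with a group of 4 completions.
rewards = [3.0, 0.5, -1.0, 0.5]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])
print(token_objective(ratio=1.5, advantage=advantages[0], kl=0.01,
                      epsilon=0.2, beta=0.04))
```

Because advantages are computed relative to the group, a group whose completions all receive the same reward yields zero advantage everywhere and contributes no policy gradient signal.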

In [40]:



# # Now lora_gemma's parameters are annotated with the specified sharding.
# # When lora_gemma is used inside a jitted function, JAX will respect these
# # shardings.

# # You can inspect the sharding of a parameter's value.
# # The sharding will be concrete after being passed through a jitted function.
# @jax.jit
# def get_sharded_kernel(model):
#     return model.base.token_embedder.embedding

# with mesh:
#     sharded_kernel_value = get_sharded_kernel(llama3_1_8b_policy)

# print("Sharding of embed kernel:")
# print(sharded_kernel_value)


In [41]:
# RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=llama3_1_8b_policy,
    reference=llama3_1_8b,
    tokenizer=data_lib.GemmaTokenizer(),
    cluster_config=cluster_config,
)

# GRPO Trainer
grpo_trainer = GrpoLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    grpo_config=grpo_config,
)
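
The sketch below illustrates how several reward functions with the same call signature could be combined into one score per completion. It assumes the per-function scores are simply summed, which we have not verified against the learner's code; the two toy reward functions and `total_rewards` are hypothetical stand-ins, not the functions defined earlier.

```python
def has_answer_tag(prompts, completions, **kwargs):
  """Toy reward: 1.0 if the completion contains an opening answer tag."""
  return [1.0 if "<answer>" in c else 0.0 for c in completions]


def is_short(prompts, completions, **kwargs):
  """Toy reward: 0.5 for completions under 40 characters."""
  return [0.5 if len(c) < 40 else 0.0 for c in completions]


def total_rewards(reward_fns, prompts, completions, **kwargs):
  """Calls every reward function and sums the scores per completion."""
  per_fn = [fn(prompts, completions, **kwargs) for fn in reward_fns]
  return [sum(scores) for scores in zip(*per_fn)]


# Made-up prompt and completions.
prompts = ["How many more cats than dogs?"] * 2
completions = [
    "<answer>7</answer>",
    "I am not sure, but the answer could be seven or eight.",
]
print(total_rewards([has_answer_tag, is_short], prompts, completions))
```

Each function returns one score per completion, so extra dataset columns such as `answer` or `question` can be passed through as keyword arguments without changing the combination step.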


[34m[1mwandb[0m: Currently logged in as: [33manony-mouse-863749125460230603[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [None]:
with mesh:
  grpo_trainer.train(dataset)

token_buffers={token_buffers}
Question: Carmen had 28 cats and 18 dogs before she gave 3 of the cats up for adoption. How many more cats than dogs does Carmen have now?
Answer: 7
Response: क्टVectors 盒 añghest растения TbghestاكényinnovationRights commiss spelled••••mbggény 盒 induction proposed 热 equival太郎 invol Corrosion 热ulatpsilon Julianговориでした arteryJähXmlAccessType此次ope sectionalなのに растения inventor зво saferονταιLOVE Sabah sectional tubular Actor reptilespsilon Implement растенияinnovationcute scala 盒tholeBindViewباطاك hörhod DecدلBirthday defertenant lived Dec regrets Actor reptiles Juliangtghest Dec armaSupervisorpsilonBirthday Replacementreysamuslayerghest affirmationghestROWN Dur此次 hyvä sectional DecAtlas uploaded 盒 display beyondLoaded Julian Forum Implement此次mug inventorSupervisorJäh god Forum corso此次mbgg Implement此次 Ene Forum profesor 盒engineerghest Tb звоreyspsilonновой oltre verschiedeneINCT Forum звоاك Actorpsilongt Forum terrificnxt profesor sectional sectional Tb s

Training:   0%|          | 1/3738 [00:00<?, ?step/s]



## Evaluate

Let's evaluate our model!

In [None]:
# Load checkpoint first.

trained_ckpt_path = os.path.join(CKPT_DIR, str(MAX_STEPS), "model_params")

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(llama3_1_8b_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    llama3_1_8b_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(llama3_1_8b_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
gemma_tokenizer = data_lib.GemmaTokenizer()
sampler = sampler_lib.Sampler(
    transformer=llama3_1_8b_policy,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

In [None]:
# for eval_example in QUALITATIVE_EVAL_EXAMPLES:
#   question = eval_example["question"]
#   answer = eval_example["answer"]
#   response = generate(
#       question,
#       sampler,
#       temperature=INFERENCE_TEMPERATURE,
#       top_k=INFERENCE_TOP_K,
#       top_p=INFERENCE_TOP_P,
#   )

#   print(f"Question:\n{question}")
#   print(f"Answer:\n{answer}")
#   print(f"Response:\n{response}")
#   print("===============")