# Double Checking Math

This notebook is based on the demo [Notebook](https://www.kaggle.com/code/windmaple/grpo-demo-gemma3-1b).
It was adapted in order to not only generate tags for reasoning and a solution, but also to subdivide the reasoning into a first reasoning with its solution and a critical evaluation of said first solution. The final solution is then outputted as always in the solution tag.

Because this notebook is adapted from said demo, it is based on [gemma3-1b](https://deepmind.google/models/gemma/) and the [gsm8k Math Dataset](https://huggingface.co/datasets/openai/gsm8k).

The model was (or is to be) trained with [GRPO](https://arxiv.org/pdf/2402.03300) and needs a 'v5e-8' TPU.

In [1]:
import os

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()
os.environ["WANDB_MODE"] = "disabled"
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_SILENT"]="true" 
os.environ["HF_HUB_DISABLE_XET"] = "1"
!rm /kaggle/working/intermediate_ckpt/* -rf
!rm /kaggle/working/ckpts/* -rf

import wandb
wandb.init(mode='disabled')



## Install necessary libraries

In [None]:
# Install the notebook's runtime dependencies. Only tunix, flax and wandb are
# pinned; the other packages resolve to whatever pip picks at run time.
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

# Alternative: install tunix / qwix from source instead of the pinned release.
# !pip install -q git+https://github.com/google/tunix
# !pip install -q git+https://github.com/google/qwix

# Replace whichever flax version is present with the pinned one.
!pip uninstall -q -y flax
# !pip install -U flax
!pip install flax==0.12.0

# NOTE(review): `humanize` and `qwix` are imported below but not installed
# explicitly here; presumably they come in as transitive dependencies - confirm.
!pip install -q datasets wandb==0.22.0

## Imports

In [None]:
import functools
import gc
import os
from pprint import pprint
import re

import csv
import shutil

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
# from tunix.models.gemma3 import model as gemma_lib
# from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset

## Hyperparameters

We adapted the hyperparameters from the demo notebook to our increased response length.
If further computational resources are available, you can increase **NUM_BATCHES** to its limit of '3738' and increase **NUM_GENERATIONS** and **TRAIN_MICRO_BATCH_SIZE**.

In [None]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
# Fraction of the selected batches used for training; the rest (if any)
# becomes the validation set.
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
# Mesh shape and axis names, unpacked into `jax.make_mesh(*MESH)`.
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of times the policy generates multiple responses for a given prompt
# within a single training step. This corresponds to `G` in Algorithm 1 in the
# paper. The "group" in GRPO comes from here.
NUM_GENERATIONS = 3

# === other GRPO configs ===
# The number of iterations per batch (μ in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (β) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (ε in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 3
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
# NOTE(review): the number of batches actually available depends on
# `TRAIN_MICRO_BATCH_SIZE`; the dataset slice silently caps at what exists,
# while `MAX_STEPS` below is still derived from `NUM_BATCHES` - confirm.
NUM_BATCHES = 3000 # 3738
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
# NOTE(review): the test set is batched with `TRAIN_MICRO_BATCH_SIZE` (3 here),
# so the maximum quoted above presumably no longer applies - confirm.
NUM_TEST_BATCHES = 100 # 100
EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase learning rate from 0. to `LEARNING_RATE` in the first 10%
# training steps, and then gradually decrease the learning rate to 0 using
# cosine scheduler.
# Note: this is a float (0.1 * MAX_STEPS), not an int.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/kaggle/working/intermediate_ckpt/"
CKPT_DIR = "/kaggle/working/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
# Named sampling presets; each dict is splatted into `generate(...)`.
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

## Utility functions

In [None]:
def show_hbm_usage():
  """Prints used / total device memory for every local JAX device.

  Devices whose backend does not report memory statistics (`memory_stats()`
  returns `None`, e.g. on some CPU backends) or that report no usable limit
  are listed as unavailable instead of raising.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    # Guard against missing stats and a zero limit (division by zero below).
    if not stats or "bytes_in_use" not in stats or not stats.get("bytes_limit"):
      print(f"Memory stats unavailable on {d}")
      continue
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

This part is the main difference between our approach and the classic one used in the demo.
Instead of just encapsulating the reasoning in reasoning tags, we further subdivide the reasoning into the following structure:

{**reasoning_start**}

  {**reasoning_first_start**}  
  
    To solve 7 × 6, I recall that multiplication is repeated addition. 7 × 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7.
    Calculating step by step: 7 + 7 = 14, 14 + 7 = 21, 21 + 7 = 28, 28 + 7 = 35, 35 + 7 = 42. So my first estimate is 42.
  
  {**reasoning_first_end**}

  {**answer_first_start**}  
  
    42
    
  {**answer_first_end**}

  {**reasoning_final_start**}
  
    I will double-check my calculation. 7 × 6 can also be seen as 6 × 7. Adding 6 seven times: 6 + 6 = 12, 12 + 6 = 18, 18 + 6 = 24, 24 + 6 = 30, 30 + 6 = 36, 36 + 6 = 42. This confirms that my initial answer of 42 is correct.
    
  {**reasoning_final_end**}
    
{**reasoning_end**}


In [None]:
# Tag pair that wraps the whole reasoning block.
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"

# Tag pair for the first reasoning pass.
reasoning_first_start = "<reasoning_first>"
reasoning_first_end = "</reasoning_first>"

# Tag pair for the critical re-evaluation of the first answer.
reasoning_final_start = "<reasoning_final>"
reasoning_final_end = "</reasoning_final>"

# Tag pair for the first (preliminary) answer.
answer_first_start = "<answer_first>"
answer_first_end = "</answer_first>"


# Tag pair for the final answer, emitted after the reasoning block.
solution_start = "<answer>"
solution_end = "</answer>"


# Instructions plus a worked example; substituted into `TEMPLATE` for every
# question. The multiplication sign is a real "×" (the text previously held
# mis-encoded bytes) and the typos "strucuted" / "preicse" are corrected so the
# model is not prompted with garbled text.
SYSTEM_PROMPT = f"""You are given a problem and you need to solve it in a structured reasoning manner.

First you will write a reasoning block, starting with {reasoning_start} and ending with {reasoning_end}, in which you try to solve the problem based on your reasoning passed inside {reasoning_first_start} and {reasoning_first_end}, and then give a first solution inside {answer_first_start} and {answer_first_end}. You will then directly try to critically evaluate this solution inside {reasoning_final_start} and {reasoning_final_end} and give a second reasoning. This marks the end of your reasoning block.

Now you will give your final solution based on this block inside {solution_start} and {solution_end}.

Example:

{reasoning_start}
    {reasoning_first_start}
        To solve 7 × 6, I recall that multiplication is repeated addition. 7 × 6 means adding 7 six times: 7 + 7 + 7 + 7 + 7 + 7.
        Calculating step by step: 7 + 7 = 14, 14 + 7 = 21, 21 + 7 = 28, 28 + 7 = 35, 35 + 7 = 42. So my first estimate is 42.
    {reasoning_first_end}

    {answer_first_start}
        42
    {answer_first_end}

    {reasoning_final_start}
        I will double-check my calculation. 7 × 6 can also be seen as 6 × 7. Adding 6 seven times: 6 + 6 = 12, 12 + 6 = 18, 18 + 6 = 24, 24 + 6 = 30, 30 + 6 = 36, 36 + 6 = 42.
        This confirms that my initial answer of 42 is correct.
    {reasoning_final_end}
{reasoning_end}

{solution_start}
  42
{solution_end}

Make sure:
- Each reasoning section contains detailed thought process.
- Each answer contains exactly one numerical value and no extra text.
- If the first answer is correct, you are allowed to repeat it after proof.

Also: Keep your reasoning short and precise! Try to focus on the important parts that need to be mentioned.
"""

# Chat template: the system prompt and the question share one user turn, and
# the prompt ends where the model turn begins.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

In [None]:
# Display the raw chat template (placeholders not yet filled in).
TEMPLATE

We use OpenAI's [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which comprises grade school math word problems.
While this approach is also applicable to other types of questions, their answers need to be reasonably objective and comparable, as the reward functions reward based on the type of improvement between both solutions.

As only limited resources were available, we focused solely on math questions, since they are the most suitable area for evaluating whether our approach works at all.

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def _load_from_tfds(data_dir: str, split: str):
  """Opens the TFDS `gsm8k` dataset as a random-access data source.

  Args:
    data_dir: Directory TFDS downloads to / reads from.
    split: Split name passed to TFDS, e.g. "train" or "test".

  Returns:
    The object returned by `tfds.data_source` (ArrayRecord file format).
  """
  # Presumably imported for its side effect of registering the builder - confirm.
  import tensorflow_datasets.text.gsm8k

  array_record_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
  source = tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs=array_record_kwargs,
      download=True,
  )
  return source


def download_kaggle_dataset(target_dir="./data/gsm8k"):
  """Downloads the Kaggle GSM8K mirror and copies its CSV files locally.

  Args:
    target_dir: Directory that receives the CSV files; created if missing.

  Returns:
    `target_dir`, unchanged, so callers can build file paths from it.
  """
  dst = Path(target_dir)
  dst.mkdir(parents=True, exist_ok=True)
  src = Path(kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a"))

  # Only top-level CSV files of the download are copied.
  for csv_file in src.glob("*.csv"):
    shutil.copy2(csv_file, dst / csv_file.name)
    # The arrow was previously mis-encoded in this log message.
    print(f"Copied {csv_file.name} → {dst/csv_file.name}")
  return target_dir


def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
  """Loads GSM8K from the chosen source and maps it to prompt records.

  Args:
    data_dir: Local directory for downloads; created if it does not exist.
    split: Dataset split, e.g. "train" or "test".
    source: One of "tfds", "kaggle" or "huggingface".

  Returns:
    A shuffled (seed 42) `grain.MapDataset` whose elements are dicts with the
    keys "prompts" (fully formatted prompt), "question" and "answer" (text
    after the `####` marker, or `None` if the marker is missing).

  Raises:
    ValueError: If `source` is not one of the supported values.
  """
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  if source == "tfds":
    data = _load_from_tfds(data_dir, split)
  elif source == "kaggle":
    kaggle_dir = download_kaggle_dataset(data_dir)
    # Expected file names look like "main_train.csv"; adjust if needed.
    csv_path = os.path.join(kaggle_dir, f"main_{split}.csv")
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
      data = [
          {"question": row["question"], "answer": row["answer"]}
          for row in csv.DictReader(csvfile)
      ]
  elif source == "huggingface":
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    data = load_dataset("gsm8k", "main", split=split)
  else:
    raise ValueError(f"Unknown source: {source}")

  def _as_text(v):
    """Decodes bytes to str; passes str through unchanged."""
    return v if isinstance(v, str) else v.decode("utf-8")

  def _to_example(record):
    """Builds one training record from a raw question/answer pair."""
    question_text = _as_text(record["question"])
    return {
        # passed to model forward pass
        "prompts": TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question_text,
        ),
        # passed to reward functions
        "question": question_text,
        # passed to reward functions
        "answer": extract_hash_answer(_as_text(record["answer"])),
    }

  return grain.MapDataset.source(data).shuffle(seed=42).map(_to_example)

We split the dataset into train and test sets as usual.

In [None]:
# source = input("Choose data source [tfds/kaggle]: ").strip().lower()
source = "huggingface"

if source not in ("tfds", "kaggle", "huggingface"):
  print("Invalid choice. Defaulting to 'tfds'.")
  # Actually fall back to 'tfds'; an empty string made `get_dataset` raise.
  source = "tfds"

print(f"Using data source: {source}")

# Batch the training split and keep at most `NUM_BATCHES` batches.
dataset = get_dataset(TRAIN_DATA_DIR, "train", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_BATCHES
]

if TRAIN_FRACTION == 1.0:
  # Everything is used for training; no validation set.
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  # The leading fraction trains, the remainder validates.
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

# The test split reuses the training micro-batch size.
test_dataset = get_dataset(TEST_DATA_DIR, "test", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

# (train, validation, test) batch counts.
dataset_lengths = (
    len(train_dataset),
    len(val_dataset) if val_dataset is not None else 0,
    len(test_dataset),
)
print(f"dataset contains {dataset_lengths} of batches")

Let's see what one batch of the training dataset looks like!


In [None]:
# Pretty-print the first training batch to inspect the prompt / question /
# answer fields.
for ele in train_dataset[:1]:
  pprint(ele)

## Load the policy model and the reference model

The policy model is the model which is actually trained and whose weights are
updated. The reference model is the model with which we compute KL divergence.
This is to ensure that the policy updates are not huge and that it does not
deviate too much from the reference model.

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.

Note: We perform full precision (fp32) training. You can, however, leverage
Qwix for QAT.

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [None]:
# Log in


This code snippet serves as a workaround to re-save the pre-trained model checkpoint from Kaggle into a local format that is compatible with the [Flax NNX](https://flax.readthedocs.io/en/stable/why.html) library. Because the original checkpoint has parameter names and tensor structures that don't match the target NNX model architecture, it cannot be loaded directly.

We first load the original weights into a temporary model instance, then extract and re-save the model's state into a new, properly formatted local checkpoint, which can then be successfully loaded by the final sharded NNX model.

In [None]:

import wandb
wandb.init()
!wandb login

!rm /kaggle/working/intermediate_ckpt/* -rf

!rm /kaggle/working/ckpts/* -rf

model_family = "gemma3"
if model_family == "gemma3":
  MODEL_CP_PATH = params.GEMMA3_1B_IT
  config = model.ModelConfig.gemma3_1b()
  gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
  tokenizer = params.create_tokenizer()

  checkpointer = ocp.StandardCheckpointer()
  _, state = nnx.split(gemma)
  checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
  checkpointer.wait_until_finished()
  # Delete the intermediate model to save memory.
  del params
  del gemma
  del state
  gc.collect()

### Model Loading and LoRA Application

These two functions work together to load a base model from a checkpoint and apply a LoRA (Low-Rank Adaptation) layer to it.

* `get_ref_model`: Loads the complete Gemma model from a specified checkpoint path. It uses **JAX sharding** to distribute the model parameters across multiple devices.
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select specific layers (like attention and MLP layers) to be adapted. The resulting LoRA-infused model is then sharded and updated to ensure it's ready for distributed training.

In [None]:
from tunix.models.gemma3 import params

def get_gemma_ref_model(ckpt_path):
  """Restores the re-saved Gemma checkpoint as a sharded bfloat16 model.

  Args:
    ckpt_path: Path of the locally saved state checkpoint to restore.

  Returns:
    A tuple `(gemma, mesh, model_config)`: the restored model, the device mesh
    built from `MESH`, and the Gemma3-1B model config.
  """
  mesh = jax.make_mesh(*MESH)
  model_config = model.ModelConfig.gemma3_1b()
  # Build the model abstractly (shapes only). Uses the local `model_config`
  # rather than the global `config` defined in another cell, so this function
  # does not depend on leftover notebook state.
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
  )

  # Describe every leaf as a bfloat16 array with its named sharding so the
  # checkpoint is restored directly into the sharded layout.
  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  # Re-attach the restored weights to the model's graph definition.
  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
  """Wraps `base_model` with LoRA adapters and applies sharding constraints.

  Args:
    base_model: Model to adapt; must provide `get_model_input()`.
    mesh: Device mesh used as context while constraining the state.

  Returns:
    The LoRA model produced by `qwix.apply_lora_to_model`, with its state
    updated under the partition specs derived from that state.
  """
  # Regex selecting the modules that receive LoRA weights.
  target_modules = (
      ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
      ".*attn_vec_einsum"
  )
  provider = qwix.LoraProvider(
      module_path=target_modules,
      rank=RANK,
      alpha=ALPHA,
  )

  example_inputs = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(base_model, provider, **example_inputs)

  with mesh:
    lora_state = nnx.state(lora_model)
    partition_specs = nnx.get_partition_spec(lora_state)
    nnx.update(
        lora_model,
        jax.lax.with_sharding_constraint(lora_state, partition_specs),
    )

  return lora_model

Now we load reference and policy Gemma models using the Flax NNX library and display their structures.

In [None]:
# Reference model
if model_family == "gemma3":
  ref_model, mesh, model_config = get_gemma_ref_model(
      ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
  )

In [None]:
# Policy model
lora_policy = get_lora_model(ref_model, mesh=mesh)
# nnx.display(lora_policy)

## Define reward functions

We define four reward functions, which together cover the following criteria:

- reward if the format of the output exactly matches the instruction given in
`SYSTEM_PROMPT`;
- reward if the format of the output approximately matches the instruction given
in `SYSTEM_PROMPT`;
- reward if the answer is correct/partially correct;
- reward if the final solution improves on the initial solution;
- sometimes, the text between the solution tags might not be a single
  number. So, we extract the number, and reward the model if the answer is correct.

The reward functions are inspired from
[here](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First off, let's define a RegEx for checking whether the format matches.

In [None]:
import re 
# Captures the first answer (group 1) and the final answer (group 2) from a
# completion. DOTALL lets the gap between the two tag pairs span newlines.
# Note: `match_numbers` is redefined in a later cell with `(.+?)` groups, so
# the reward functions use whichever definition was executed last.
match_numbers = re.compile(
    rf"{answer_first_start}\s*(.*?)\s*{answer_first_end}"  # Group 1: First Answer
    r".*?"                                                # Non-greedy gap
    rf"{solution_start}\s*(.*?)\s*{solution_end}", # Group 2: Final Answer
    flags=re.DOTALL,
)

# Minimal sample completion to sanity-check the pattern.
test_prompt = f"{reasoning_start}Let methink!{reasoning_end}{answer_first_start}1{answer_first_end}{solution_start}2{solution_end}"
m = match_numbers.search(
    test_prompt
)

In [None]:
# Display the sample completion used for the regex check.
test_prompt

In [None]:
# Print the captured first and final answers.
for group in m.groups():
    print(group)

Give the model a reward of 4 points if the format matches exactly.

In [None]:
NUMBER_RE = re.compile(r"-?\d+(\.\d+)?$")

def match_numbers_exactly(prompts, completions, **kwargs):
    """Scores each completion 4.0 when both answers are plain numbers.

    A completion earns the reward only if `match_numbers` finds both answer
    tag pairs and each captured value, after stripping whitespace, fully
    matches `NUMBER_RE` (optional minus sign, digits, optional decimal part).
    Every other completion scores 0.0.

    Args:
        prompts: Unused; accepted for the reward-function signature.
        completions: Iterable of model output strings.
        **kwargs: Ignored.

    Returns:
        A list with one float per completion.
    """

    def _score(response):
        found = match_numbers.search(response)
        if found is None:
            return 0.0
        both_numeric = all(
            NUMBER_RE.fullmatch(value.strip()) for value in found.group(1, 2)
        )
        return 4.0 if both_numeric else 0.0

    return [_score(response) for response in completions]

In [None]:
# Sample completions for a quick manual check of `match_numbers_exactly`:
# only the last one has numeric values in both answer tags.
prompts = [f"{reasoning_start}Let methink!{reasoning_end}{answer_first_start}1{answer_first_end}{solution_start}f{solution_end}",
 f"{reasoning_start}Let methink!{reasoning_end}{answer_first_start}f{answer_first_end}{solution_start}f{solution_end}",
f"{reasoning_start}Let methink!{reasoning_end}{answer_first_start}f{answer_first_end}{solution_start}f{solution_end}",
f"{reasoning_start}Let methink!{reasoning_end}{answer_first_start}1{answer_first_end}{solution_start}f{solution_end}",
f"{reasoning_start}Let methink!{reasoning_end}{answer_first_start}1{answer_first_end}{solution_start}2{solution_end}"]


# The first argument (`prompts`) is not used by the reward function.
for prompt in prompts:
    print(match_numbers_exactly(SYSTEM_PROMPT,[prompt]))

We also reward the model if the format of the output matches partially.

In [None]:
# NOTE: despite its name, this function scores tag usage, not numbers; a
# rename to something tag-related is pending.
def match_numbers_approximately(prompts, completions, **kwargs):
  """Scores each completion by how many structural tags appear exactly once.

  Each of the ten opening/closing tags contributes +0.5 when it occurs exactly
  once in the completion and -0.5 otherwise (missing or repeated), so a score
  lies between -5.0 and 5.0.

  Args:
    prompts: Unused; accepted for the reward-function signature.
    completions: Iterable of model output strings.
    **kwargs: Ignored.

  Returns:
    A list with one score per completion.
  """
  expected_tags = (
      reasoning_start,
      reasoning_end,
      reasoning_first_start,
      reasoning_first_end,
      reasoning_final_start,
      reasoning_final_end,
      answer_first_start,
      answer_first_end,
      solution_start,
      solution_end,
  )
  return [
      sum(0.5 if text.count(tag) == 1 else -0.5 for tag in expected_tags)
      for text in completions
  ]

In [None]:
# A completion that contains every tag exactly once, for a manual check of
# `match_numbers_approximately`.
prompts = [f"{reasoning_start}{reasoning_first_start}Let methink!{reasoning_first_end}{answer_first_start}1{answer_first_end}{reasoning_final_start}I Think its like that!{reasoning_final_end}{reasoning_end}{solution_start}f{solution_end}"]

for prompt in prompts:
    print(match_numbers_approximately(SYSTEM_PROMPT,[prompt]))
    


Reward the model if the answer is correct and if it improved the initial answer. If the solution got worse, we punish the model. A reward is also given if the answer
does not match exactly, i.e., based on how close the answer is to the correct
value.

In [None]:
def _final_answer_score(final_val, true_val):
    """Scores the final answer: exact match, within 10%, within 20%, or wrong."""
    if final_val == true_val:
        return 3.0
    # The ratio is undefined for a true value of 0; a non-matching answer is
    # then simply wrong instead of raising ZeroDivisionError.
    if true_val != 0:
        ratio = final_val / true_val
        if 0.9 <= ratio <= 1.1:
            return 0.5
        if 0.8 <= ratio <= 1.2:
            return 0.25
    return -1.0


def check_answer(prompts, completions, answer, **kwargs):
    """Rewards correct final answers and improvement over the first answer.

    Args:
        prompts: Unused; accepted for the reward-function signature.
        completions: List of model outputs.
        answer: The true answer. Either a single string applied to every
            completion, or a sequence holding one true answer per completion.
        **kwargs: Ignored.

    Returns:
        One float per completion: 0.0 if the answer tags are missing, -0.5 if
        any of the three values is not numeric, otherwise the final-answer
        score (3.0 / 0.5 / 0.25 / -1.0) plus 0.5 when the final answer is
        closer to the truth than the first answer, or minus 0.5 when it is
        farther away.

    Raises:
        ValueError: If `answer` is a sequence whose length differs from
            `completions`.
    """
    if answer is None or isinstance(answer, str):
        answers = [answer] * len(completions)
    else:
        answers = list(answer)
        if len(answers) != len(completions):
            raise ValueError(
                f"Got {len(answers)} answers for {len(completions)} completions"
            )

    scores = []

    for comp, true_answer in zip(completions, answers):
        match = match_numbers.search(comp)
        if not match:
            # Missing expected answer tags
            scores.append(0.0)
            continue

        first_ans, final_ans = match.groups()

        try:
            first_val = float(first_ans.strip())
            final_val = float(final_ans.strip())
            true_val = float(true_answer.strip())
        except (AttributeError, TypeError, ValueError):
            # Non-numeric answers (including a missing true answer)
            scores.append(-0.5)
            continue

        score = _final_answer_score(final_val, true_val)

        # Reward improvement or penalize regression
        dist_first = abs(first_val - true_val)
        dist_final = abs(final_val - true_val)
        if dist_final < dist_first:
            score += 0.5  # improvement bonus
        elif dist_final > dist_first:
            score -= 0.5  # regression penalty

        scores.append(score)

    return scores

In [None]:
# Fully tagged sample completion (first answer 1, final answer 2), scored
# against the true answer "2". `test_prompt` is reused by later test cells.
test_prompt = f"{reasoning_start}{reasoning_first_start}Let methink!{reasoning_first_end}{answer_first_start}1{answer_first_end}{reasoning_final_start}I Think its like that!{reasoning_final_end}{reasoning_end}{solution_start}2{solution_end}"
check_answer(SYSTEM_PROMPT,[test_prompt],answer="2")

Sometimes, the text between `<answer_first>` and `</answer_first>`, or between `<answer>` and `</answer>`, might not be a single
number; it can be a sentence. So, we extract the number and compare the answer.

In [None]:
# Redefines `match_numbers`: both capture groups now use `(.+?)`, so each
# answer must contain at least one character. Functions defined earlier look
# the name up at call time and therefore use this pattern from here on.
match_numbers = re.compile(
    rf"{answer_first_start}\s*(.+?)\s*{answer_first_end}.*?"
    rf"{solution_start}\s*(.+?)\s*{solution_end}",
    flags=re.DOTALL,
)

# Sanity check with whitespace padding around both answers.
text = f"{answer_first_start} 1 {answer_first_end}{solution_start} 2 {solution_end}"
match = match_numbers.search(text)

if match:
    first_answer, final_answer = match.groups()
    print("First answer:", first_answer.strip())
    print("Final answer:", final_answer.strip())

In [None]:
import re

def check_numbers(prompts, completions, answer, **kwargs):
    """Rewards numeric correctness after stripping non-numeric characters.

    Both answers are extracted with `match_numbers`, reduced to digits, dots
    and minus signs (the same cleaning as `check_numbers_eval`, so negative
    values keep their sign) and compared with the cleaned true answer.

    Args:
        prompts: Unused; accepted for the reward-function signature.
        completions: List of model outputs.
        answer: Sequence with one true answer per completion.
        **kwargs: May contain `question` (sequence of questions); it is only
            used for logging and may be omitted.

    Returns:
        One float per (completion, answer) pair: 1.5 if the final answer is
        correct (else 0.0), plus 0.5 if the final answer is closer to the
        truth than the first answer, minus 0.5 if it is farther away. Pairs
        where any value cannot be parsed score 0.0.
    """
    question = kwargs.get("question")
    responses = completions

    # Extraction logic (ensure match_numbers is defined globally)
    extracted_responses = [
        (guess.group(1).strip(), guess.group(2).strip())
        if (guess := match_numbers.search(r)) is not None
        else (None, None)
        for r in responses
    ]

    scores = []

    # Helper to clean "$" and other symbols
    def clean_val(text):
        if text is None:
            return None
        try:
            # Keep '-' so that negative numbers are not silently made positive.
            return float(re.sub(r"[^\d.-]", "", str(text)))
        except ValueError:
            return None

    print("START ============================")

    for (first_raw, final_raw), true_answer in zip(extracted_responses, answer):
        # Default state
        status = "no_match"
        score = 0.0

        first_val = clean_val(first_raw)
        final_val = clean_val(final_raw)
        true_val = clean_val(str(true_answer))

        if first_val is not None and final_val is not None and true_val is not None:
            # 1. Base score: final answer correctness
            is_correct = (final_val == true_val)
            score = 1.5 if is_correct else 0.0

            # 2. Distance comparison
            dist_first = abs(first_val - true_val)
            dist_final = abs(final_val - true_val)

            if dist_final < dist_first:
                status = "improvement"
                score += 0.5
            elif dist_final > dist_first:
                status = "got_worse"
                score -= 0.5
            else:
                # Distances are equal
                status = "keep_good" if is_correct else "keep_bad"

        scores.append(score)

        # Log the specific outcome for the first sample in the batch
        if len(scores) == 1:
            has_question = question is not None and len(question) > 0
            first_question = question[0] if has_question else None
            print(f"Question: {first_question}")
            print(f"True Answer: {true_val}")
            print(f"Full Response:\n{responses[0]}")
            print(f"Extracted: First={first_val}, Final={final_val}")
            print(f"Outcome Category: {status}")
            print(f"Score: {score}")

    print("END ==============================")
    return scores

In [None]:
# Manual checks of `check_numbers`. First sample: first answer 1, final
# answer 2, scored against true answers "2" and "1".
check_numbers(SYSTEM_PROMPT,[test_prompt],["2"],question=["What is bigger, 1 or 2?"])
check_numbers(SYSTEM_PROMPT,[test_prompt],["1"],question=["What is bigger, 1 or 2?"])
# Second sample: both answers are 2 (the model keeps its first answer).
test_prompt = f"{reasoning_start}{reasoning_first_start}Let methink!{reasoning_first_end}{answer_first_start}2{answer_first_end}{reasoning_final_start}I Think its like that!{reasoning_final_end}{reasoning_end}{solution_start}2{solution_end}"
check_numbers(SYSTEM_PROMPT,[test_prompt],["2"],question=["What is bigger, 1 or 2?"])
check_numbers(SYSTEM_PROMPT,[test_prompt],["1"],question=["What is bigger, 1 or 2?"])

## Evaluate


Before we train the model, let's evaluate the model on the test set so we can
see the improvement post training.

We evaluate it in two ways:

**Quantitative**

* **Keep Good**: percentage of samples for which the model's first answer is already
the correct numerical answer and the final answer keeps this value.
* **Improvement**: percentage of samples for which the model's first answer is
wrong and the final answer is closer to the correct value.
* **Keep Bad**: percentage of samples for which the model's first answer is
wrong and the final answer is no closer to the correct value (typically the same wrong value).
* **Got Worse**: percentage of samples for which the final answer is farther from the
correct value than the first answer, e.g. a correct first answer that is dismissed.
* **No Match**: percentage of samples for which the model does not output the
expected format, i.e., no numerical value can be extracted from the
`<answer_first>` … `</answer_first>` and `<answer>` … `</answer>` tags.


**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


We define a helper function to generate an answer, given a prompt.

In [None]:
def generate(
    question,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=None,
    max_generation_steps=768,
):
  """Generates a completion for one question or a batch of questions.

  Args:
    question: A single question string, or an iterable of question strings.
    sampler: Callable sampler; invoked with keyword arguments and expected to
      return an object with a `text` sequence (one entry per input).
    temperature: Sampling temperature forwarded to the sampler.
    top_k: Top-k cutoff forwarded to the sampler.
    top_p: Nucleus sampling threshold forwarded to the sampler.
    seed: Optional sampling seed, forwarded as-is.
    max_generation_steps: Upper bound on generated steps (default 768, the
      value that was previously hard-coded).

  Returns:
    The generated text for a single question, or the sampler's `text`
    sequence when a batch of questions was passed.
  """
  is_single = isinstance(question, str)
  questions = [question] if is_single else question

  input_batch = [
      TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
      for q in questions
  ]

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=max_generation_steps,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
      # NOTE(review): presumably the tokenizer's end-of-sequence and
      # end-of-turn ids - confirm against the tokenizer.
      eos_tokens=[1,106],
  )

  output = out_data.text
  return output[0] if is_single else output

Another helper function for evaluation, which also keeps track of the different outcome categories of the double-checking step.

In [None]:
import re

def check_numbers_eval(prompts, completions, answer, **kwargs):
    """Scores completions and labels each with a double-checking outcome.

    Args:
        prompts: Unused; accepted for parity with the reward functions.
        completions: List of model outputs.
        answer: Sequence with one true answer per completion.
        **kwargs: Ignored.

    Returns:
        A tuple `(scores, categories)` of equal-length lists. Each score is
        1.5 for a correct final answer (else 0.0), plus 0.5 for an
        improvement or minus 0.5 for a regression. Each category is one of
        "improvement", "got_worse", "keep_good", "keep_bad" or "no_match"
        (answers could not be extracted or parsed).
    """
    responses = completions
    extracted_responses = [
        (guess.group(1).strip(), guess.group(2).strip())
        if (guess := match_numbers.search(r)) is not None
        else (None, None)
        for r in responses
    ]

    scores = []
    categories = []

    def clean_val(text):
        if text is None:
            return None
        try:
            # Remove symbols like $, commas, and units to allow float conversion
            cleaned = re.sub(r'[^\d.-]', '', text)
            return float(cleaned)
        except ValueError:
            # Only unparseable text is treated as "no value"; unrelated errors
            # (and KeyboardInterrupt) are no longer swallowed.
            return None

    for (first_raw, final_raw), true_answer in zip(extracted_responses, answer):
        status = "no_match"
        score = 0.0

        first_val = clean_val(first_raw)
        final_val = clean_val(final_raw)
        true_val = clean_val(str(true_answer))

        if first_val is not None and final_val is not None and true_val is not None:
            is_correct = (final_val == true_val)
            # Base Reward for correct final answer
            score = 1.5 if is_correct else 0.0

            dist_first = abs(first_val - true_val)
            dist_final = abs(final_val - true_val)

            if dist_final < dist_first:
                status = "improvement"
                score += 0.5
            elif dist_final > dist_first:
                status = "got_worse"
                score -= 0.5
            else:
                # Equal distances: the final answer is as good as the first.
                status = "keep_good" if is_correct else "keep_bad"

        scores.append(score)
        categories.append(status)

    return scores, categories

In [None]:
from collections import Counter

def evaluate_grpo_style(dataset, sampler, config):
    """Generate answers for ``dataset`` and summarize the double-checking rewards.

    Args:
        dataset: Iterable of batches; each batch is indexed with ``"question"``
            and ``"answer"`` to obtain aligned sequences.
        sampler: Sampler passed through to ``generate``.
        config: Mapping of keyword arguments forwarded to ``generate``.

    Returns:
        A tuple ``(avg_reward, total_samples, accuracy, dist_pct)``:

        * ``avg_reward``: mean score over all samples.
        * ``total_samples``: number of evaluated samples.
        * ``accuracy``: percentage of samples whose final answer was correct.
        * ``dist_pct``: dict mapping each observed category (improvement,
          got_worse, keep_good, keep_bad, no_match) to its percentage.

        For an empty dataset the result is ``(0.0, 0, 0.0, {})``.
    """
    # check_numbers_eval grants the 1.5 base reward only for a correct final
    # answer; every other outcome scores at most 0.5. A score at or above this
    # threshold therefore identifies a correct final answer, whether it was
    # kept ("keep_good") or reached through the double check ("improvement").
    correct_score_threshold = 1.5

    total_score = 0.0
    total_samples = 0
    correct_final_count = 0
    # Counts per category: improvement, got_worse, keep_good, keep_bad, no_match
    distribution = Counter()

    for batch in tqdm(dataset):
        questions = batch["question"]
        answers = batch["answer"]

        responses = generate(questions, sampler, **config)

        # Unpack both the numeric scores and the string categories
        scores, categories = check_numbers_eval(
            prompts=None,
            completions=responses,
            answer=answers,
            question=questions,
        )

        for score, category in zip(scores, categories):
            total_samples += 1
            total_score += score
            distribution[category] += 1
            if score >= correct_score_threshold:
                correct_final_count += 1

    if total_samples == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, 0, 0.0, {}

    avg_reward = total_score / total_samples
    accuracy = (correct_final_count / total_samples) * 100
    dist_pct = {k: (v / total_samples) * 100 for k, v in distribution.items()}

    return avg_reward, total_samples, accuracy, dist_pct

In [None]:
# Sampler used by the baseline evaluation below; it generates with
# `lora_policy` and the shared tokenizer.
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        # NOTE(review): presumably has to cover prompt + generated tokens; this
        # value differs from the rollout kv_cache_size (1536) — confirm.
        cache_size=1792,
        # Cache dimensions are taken from the model configuration.
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

Now let's see how the original model does on the test set. The summary shows the average reward, the final accuracy, and how the model's outputs are distributed across the double-checking categories (keep good, improvement, keep bad, got worse, no match). The following step might take a couple of minutes to finish.

In [None]:
# The evaluation might take up to a couple of minutes to finish.
# evaluate_grpo_style returns four values: the average reward, the sample
# count, the accuracy (in percent) and a dict of per-category percentages.
avg_reward, total, accuracy, dist_pct = evaluate_grpo_style(
    test_dataset,
    sampler,
    GENERATION_CONFIGS["greedy"],
)

# Printing the summary
print(f"\n{'='*40}")
print(f"TUNIX CHALLENGE: GRPO EVALUATION")
print(f"{'='*40}")
print(f"Average Reward: {avg_reward:.4f}")
print(f"Total Samples:  {total}")
print(f"Final Accuracy: {accuracy:.2f}%")

print(f"\nBehavior Distribution Breakdown:")
# Categories are listed in a fixed order; one that never occurred is missing
# from dist_pct and is therefore reported as 0.00%.
for category in ["keep_good", "improvement", "keep_bad", "got_worse", "no_match"]:
    percentage = dist_pct.get(category, 0.0)
    # Turn e.g. "keep_good" into "Keep Good" for display
    label = category.replace('_', ' ').title()
    print(f" - {label:15}: {percentage:.2f}%")

print(f"{'='*40}")

## Train

Let's set up all the configs first - checkpointing, metric logging and training.
We then train the model.

In [None]:
# Ckpt saving: the save interval and the number of retained checkpoints come
# from the SAVE_INTERVAL_STEPS and MAX_TO_KEEP constants.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger: writes to log_dir and flushes every 20 steps.
# NOTE(review): the path suggests TensorBoard event files — confirm.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [None]:
# Optimizer, learning rate scheduler, gradient clipping
# AdamW driven by a warmup + cosine-decay schedule: the learning rate starts at
# 0.0, peaks at LEARNING_RATE after WARMUP_STEPS and ends at 0.0.
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        # NOTE(review): per the optax docs decay_steps is the total schedule
        # length including warmup — confirm MAX_STEPS is the intended value.
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
# Optional gradient clipping: in the chain the global-norm clip is applied
# first and its output is passed on to AdamW.
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

In [None]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    # Actor, reference and rollout all run on the same device mesh.
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        # Mini batch and micro batch are deliberately set to the same constant.
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    # Sampling settings used while generating rollouts during training.
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        # NOTE(review): presumably must be at least MAX_PROMPT_LENGTH +
        # TOTAL_GENERATION_STEPS; the evaluation samplers use 1792 — confirm.
        kv_cache_size=1536,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        # Same stop-token ids as the evaluation `generate` helper.
        # NOTE(review): presumably Gemma's end-of-sequence / end-of-turn ids — confirm.
        eos_tokens=[1,106],
    ),
)

# GRPO hyperparameters, all taken from notebook-level constants.
# NOTE(review): presumably group size, update iterations, KL coefficient and
# clipping range as in the GRPO paper — confirm against the GRPOConfig docs.
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

### Setting Up the GRPO Trainer

Now we initialize our system for training. First, we create an `RLCluster` instance, which brings together the **policy model (`actor`)**, a **reference model (`reference`)**, and a **tokenizer**. Our `actor` is a trainable LoRA model, while the `reference` is a fixed base model that we use to guide the training.

We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.

Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:

**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.

**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in.

In [None]:

# RL cluster: bundles the trainable policy (actor), the reference model and the
# tokenizer under the cluster configuration defined above.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# GRPO Trainer: receives the cluster, the GRPO hyperparameters and the list of
# reward functions (defined earlier in the notebook) used to score completions.
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_numbers_exactly,
        match_numbers_approximately,
        check_answer,
        check_numbers,
    ],
    grpo_config=grpo_config,
)

The first couple of training steps might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. >10 minutes per step, please open a bug. Really appreciated!

In [None]:
# Run GRPO training on the training dataset inside the device-mesh context.
with mesh:
  grpo_trainer.train(train_dataset)

## Evaluate

Let's evaluate our finetuned model!

In [None]:
# Load checkpoint first.
import re

# Locate the newest checkpoint: step directories below CKPT_DIR/actor carry
# purely numeric names, and the highest number is the latest step.
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

saved_steps = []
if os.path.exists(actor_ckpt_dir):
  saved_steps = [
      int(entry)
      for entry in os.listdir(actor_ckpt_dir)
      if os.path.isdir(os.path.join(actor_ckpt_dir, entry))
      and re.match(r'^\d+$', entry)
  ]

# -1 marks "no checkpoint directory found".
latest_step = max(saved_steps, default=-1)

if latest_step == -1:
  raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval')  # logging bug workaround

trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(latest_step), "model_params"
)

# Shape/dtype-only template of the LoRA parameters, used as the restore target.
abs_params = jax.tree.map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Overwrite every LoRA parameter of the policy with its restored counterpart.
nnx.update(
    lora_policy,
    jax.tree.map(
        lambda _current, restored: restored,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
# Recreate the sampler for the post-training evaluation; `lora_policy` now
# holds the LoRA parameters restored from the latest checkpoint.
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        # Same cache settings as the baseline sampler so both runs are comparable.
        cache_size=1792,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
# The evaluation might take up to a couple of minutes to finish.
# Same call and report as the baseline evaluation, now with the sampler built
# after the checkpoint restore. The four returned values are the average
# reward, the sample count, the accuracy (in percent) and a dict of
# per-category percentages.
avg_reward, total, accuracy, dist_pct = evaluate_grpo_style(
    test_dataset,
    sampler,
    GENERATION_CONFIGS["greedy"],
)

# Printing the summary
print(f"\n{'='*40}")
print(f"TUNIX CHALLENGE: GRPO EVALUATION")
print(f"{'='*40}")
print(f"Average Reward: {avg_reward:.4f}")
print(f"Total Samples:  {total}")
print(f"Final Accuracy: {accuracy:.2f}%")

print(f"\nBehavior Distribution Breakdown:")
# Categories are listed in a fixed order; one that never occurred is missing
# from dist_pct and is therefore reported as 0.00%.
for category in ["keep_good", "improvement", "keep_bad", "got_worse", "no_match"]:
    percentage = dist_pct.get(category, 0.0)
    # Turn e.g. "keep_good" into "Keep Good" for display
    label = category.replace('_', ' ').title()
    print(f" - {label:15}: {percentage:.2f}%")

print(f"{'='*40}")

Here we zip the model in order to be able to download it in Kaggle.

In [None]:
import shutil
# Pack everything under /kaggle/working/ckpts into a single zip; make_archive
# appends ".zip" to the given base name.
# NOTE(review): the source path is hardcoded; presumably it equals CKPT_DIR — confirm.
shutil.make_archive("/kaggle/working/results/trained_model", "zip", "/kaggle/working/ckpts")
