# GRPO Demo

This is a way of training the [Gemma](https://deepmind.google/models/gemma/)
3 1B-IT model on the [GSM8K math reasoning benchmark](https://huggingface.co/datasets/openai/gsm8k)
using [Group Relative Policy Optimization (GRPO)](https://arxiv.org/pdf/2402.03300).
GRPO can enhance your model's problem-solving skills on mathematical word problems,
coding problems, etc.

This notebook uses a `v5e-8` TPU for Gemma3-1b-it. 

In [None]:
# Check for accelerators 

def is_tpu():
    """Report whether torch_xla can see at least one XLA-supported device.

    Returns False when torch_xla is not installed (any ImportError raised
    while probing is treated the same way).
    """
    try:
        from torch_xla.core import xla_model as xm
        return len(xm.get_xla_supported_devices()) > 0
    except ImportError:
        return False

def is_gpu():
    """Report whether PyTorch sees a CUDA device and no TPU was detected.

    Returns False when torch is not installed.
    """
    try:
        import torch as _torch
        cuda_ok = _torch.cuda.is_available()
        # A TPU host takes precedence over CUDA (is_tpu is only consulted
        # when CUDA is available).
        return cuda_ok and not is_tpu()
    except ImportError:
        return False

def is_cpu():
    """Return True when neither a GPU nor a TPU was detected."""
    return not (is_gpu() or is_tpu())

def installs_and_imports():
    """Install the notebook's Python dependencies for the detected accelerator.

    Uses IPython ``!pip`` shell escapes, so it only runs inside a notebook.
    All three branches pin google-tunix[prod]==0.1.3, flax==0.12.0 and
    wandb==0.22.0; they differ mainly in how JAX is installed.

    NOTE(review): ``google-genai`` (needed by ``from google import genai`` in
    the imports cell) is only installed explicitly on the TPU branch.
    """
    import os
    # Disable the Hugging Face Hub Xet transfer backend for downloads.
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    # is_tpu()/is_gpu() return bools, so `is not False` is the same as a plain
    # truthiness check.
    if is_tpu() is not False:
        print("Running on TPU")
        !pip install -q kagglehub
        !pip install google-genai
        !pip install -q ipywidgets
        
        !pip install -q tensorflow
        !pip install -q tensorflow_datasets
        !pip install -q tensorboardX
        !pip install -q transformers
        !pip install -q grain
        !pip install "google-tunix[prod]==0.1.3"
        
        # !pip install -q git+https://github.com/google/tunix
        # !pip install -q git+https://github.com/google/qwix
        
        # Replace whatever flax version was pulled in with the pinned 0.12.0.
        !pip uninstall -q -y flax
        # !pip install -U flax
        !pip install flax==0.12.0
        !pip install -q datasets wandb==0.22.0

        # NOTE(review): these names are local to this function and are
        # discarded when it returns; they are not visible to later cells.
        from tunix.generate import sampler as sampler_lib
        from tunix.generate import tokenizer_adapter as tokenizer_lib
    elif is_gpu() is not False:
        print("Running on GPU")
        !pip install -q kagglehub ipywidgets tensorflow tensorflow_datasets tensorboardX transformers
        !pip install -q grain
        !pip install "google-tunix[prod]==0.1.3"
        !pip uninstall -q -y flax
        !pip install flax==0.12.0
        !pip install -q datasets wandb==0.22.0
        # GPU JAX wheels (replace XXX with your CUDA version)
        # NOTE(review): this command fails until XXX is replaced; also confirm
        # jaxlib 0.4.26 is compatible with flax 0.12.0 / tunix 0.1.3.
        !pip install --upgrade jax jaxlib==0.4.26+cudaXXX -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    else:
        print("Running on CPU, nothing will work :(")
        !pip install -q kagglehub ipywidgets tensorflow tensorflow_datasets tensorboardX transformers
        !pip install -q grain
        !pip install "google-tunix[prod]==0.1.3"
        !pip uninstall -q -y flax
        !pip install flax==0.12.0
        !pip install -q datasets wandb==0.22.0
        !pip install jax jaxlib


# Run the installs once, when this cell executes.
installs_and_imports()

## Install necessary libraries

In [None]:
import wandb, os
from kaggle_secrets import UserSecretsClient
# Read the W&B key from Kaggle Secrets and expose it through the environment
# variable wandb reads.
wandb_api_key = UserSecretsClient().get_secret("WANDB_API_KEY")
os.environ['WANDB_API_KEY'] = wandb_api_key

## Imports

In [None]:
import functools
import gc
import os
from pprint import pprint
import re

import csv
import shutil

#When using the version flax==0.12.0, this import works only if TPU is enabled
#https://flax.readthedocs.io/en/latest/nnx_basics.html
from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm

# from tunix.models.gemma3 import model as gemma_lib
# from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset
from google import genai
from pydantic import BaseModel, Field
from typing import List, Optional

## Configuration of Hyperparameters

The choice of hyperparameters is tricky: tuning them with a grid search or Optuna would be very expensive. Since the model is small, I considered using a smaller LoRA rank; on the other hand, the task we want to train for is a large change in policy, so I will keep the rank at 32.

In [None]:
class Configs:
    """Notebook-wide hyperparameters and prompt strings.

    Used as a plain namespace: every value is a class attribute and the class
    is never instantiated. Other cells also attach attributes at runtime
    (``Configs.Source``, ``Configs.MODEL_CP_PATH``).
    """

    # ====== Data ======
    TRAIN_DATA_DIR = "./data/train"
    TEST_DATA_DIR = "./data/test"
    # Fraction of the batched training split kept for training; the remaining
    # batches form the validation set.
    TRAIN_FRACTION = 0.8

    # ====== LoRA ======
    # The model is small, but the target behaviour is a large policy change,
    # so the rank is kept at 32.
    RANK = 32
    ALPHA = 32.0

    # ====== Sharding ======
    # Positional arguments for jax.make_mesh: mesh shape and axis names.
    MESH = [(1, 4), ("fsdp", "tp")]

    # ====== GRPO ======
    # === Generation during GRPO training ===
    MAX_PROMPT_LENGTH = 256
    TOTAL_GENERATION_STEPS = 512

    # Keep a high-ish temperature so that the responses sampled during
    # training are varied and diverse.
    TEMPERATURE = 0.9
    TOP_P = 0.95
    TOP_K = 50

    # Number of responses sampled per prompt in a single training step, i.e.
    # the GRPO group size (G in the paper). The original Tunix demo used 4,
    # which seems too few for this task.
    NUM_GENERATIONS = 16

    # === other GRPO configs ===
    # Number of iterations per batch (μ in GRPO Algorithm 1).
    NUM_ITERATIONS = 1
    # Coefficient of the KL-divergence penalty (β) in the GRPO loss. It must
    # stay high enough that the KL divergence cannot grow unchecked, which
    # risks overfitting or catastrophic forgetting. LoRA freezes the base
    # weights, but the adapted policy can still drift from the reference, so
    # the penalty is kept, just lower than the demo's 0.08.
    BETA = 0.04

    # Clipping epsilon, as in PPO, for stable updates.
    EPSILON = 0.2

    # ====== Training ======
    TRAIN_MICRO_BATCH_SIZE = 4
    # Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
    # NOTE(review): GSM8K's train split has 7,473 examples (~1,869 batches of
    # 4), so slicing to NUM_BATCHES keeps everything, and MAX_STEPS below
    # (5980) is about twice the ~2,990 training batches actually produced
    # (int(0.8 * 1,869) * NUM_EPOCHS). Confirm against len(train_dataset).
    NUM_BATCHES = 3738
    # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
    # increased to a max. of 330 (if batch size is 4).
    NUM_TEST_BATCHES = 100

    EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
    NUM_EPOCHS = 2  # each epoch takes around 1.5 hours

    # Number of training steps.
    MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

    # === AdamW, warmup, cosine scheduler ===
    LEARNING_RATE = 3e-6
    B1 = 0.9
    B2 = 0.99
    WEIGHT_DECAY = 0.1
    # == Cosine decay with warmup scheduler ==
    # Intended schedule: increase the learning rate linearly from 0 to
    # LEARNING_RATE over the first 10% of training steps, then decay it to 0
    # with a cosine schedule. Kept as an int because it counts steps (optax
    # documents `warmup_steps` as an int).
    WARMUP_STEPS = int(0.1 * MAX_STEPS)
    # == Grad clipping ==
    # Clip gradients to prevent large updates; found to be important for
    # keeping the KL divergence in check.
    MAX_GRAD_NORM = 0.1

    # Checkpoint saving
    INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
    CKPT_DIR = "/tmp/content/ckpts/"
    SAVE_INTERVAL_STEPS = 500
    MAX_TO_KEEP = 4

    # ====== Inference ======
    GENERATION_CONFIGS = {
        # near-greedy: tiny temperature, top-3 sampling
        "greedy": {"temperature": 1e-3, "top_k": 3, "top_p": 1.0},
        # some randomness
        "standard": {"temperature": 0.8, "top_k": 50, "top_p": 0.95},
        # liberal
        "liberal": {"temperature": 0.95, "top_k": 2000, "top_p": 1.0},
    }

    # ====== Prompts ======
    # Special tags: the model is instructed to reason between `<reasoning>`
    # and `</reasoning>`, then give its answer between `<answer>` and
    # `</answer>`.
    reasoning_start = "<reasoning>"
    reasoning_end = "</reasoning>"
    solution_start = "<answer>"
    solution_end = "</answer>"

    # The prompt strings below use implicit string concatenation so that the
    # class-body indentation does not leak into the text the model sees.
    SYSTEM_PROMPT = (
        "You are given a problem. Think about the problem and "
        f"provide your reasoning. Place it between {reasoning_start} and "
        f"{reasoning_end}. Then, provide the final answer (i.e., just one numerical "
        f"value) between {solution_start} and {solution_end}."
    )

    # Rubric for the LLM-as-a-judge reward; the JSON keys must match the
    # fields of the `Rewards` pydantic model.
    JUDGE_PROMPT = (
        "You are a reasoning judge. You will get a user request and a candidate response.\n"
        "Your job: assign rewards based on the **reasoning process**, not just the correctness of the answer.\n"
        "\n"
        "Three criteria for reward:\n"
        "1. **Gathering and organizing key information**: how well the candidate identifies the relevant facts and constraints from the problem.\n"
        "2. **Analyzing possible action paths**: how well the candidate explores different strategies, considers consequences, and iterates logically.\n"
        "3. **Choosing an action path and providing an answer**: how well the candidate uses their analysis to propose a solution, even if incomplete, and iterates if needed.\n"
        "\n"
        "Important instructions:\n"
        "- Focus on **process** over final-answer correctness.\n"
        "- Correct answers matter less; reward structured, thoughtful reasoning more.\n"
        "- If the candidate iterates, considers multiple options, or refines the solution → higher reward.\n"
        "- Output **JSON only**, matching this schema exactly:\n"
        "```json\n"
        "{\n"
        '  "reward_info_gathering_and_organization": <int 0-10>,\n'
        '  "reward_analyse_possible_action_paths": <int 0-10>,\n'
        '  "reward_choose_an_action_path_and_return_an_answer_related_to_it": <int 0-10>\n'
        "}\n"
        "```\n"
    )

    # Gemma chat-turn template: lines must start at column 0, exactly as the
    # Gemma chat format expects.
    TEMPLATE = (
        "<start_of_turn>user\n"
        "{system_prompt}\n"
        "\n"
        "{question}<end_of_turn>\n"
        "<start_of_turn>model"
    )

## Utility functions

In [None]:
def show_hbm_usage():
  """Print memory in use / memory limit for every local JAX device.

  Devices whose ``memory_stats()`` returns nothing (some backends return
  ``None``), or that lack the byte counters or report a zero limit, are
  reported as unavailable instead of raising.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats() or {}
    used = stats.get("bytes_in_use")
    limit = stats.get("bytes_limit")
    if used is None or not limit:
      # Avoid TypeError on None stats and ZeroDivisionError on a 0 limit.
      print(f"Memory stats unavailable on {d}")
      continue
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

#### I use OpenAI's [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which comprises grade school math word problems.

In [None]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def _load_from_tfds(data_dir: str, split: str):
  """Return a TFDS data source for the GSM8K ``split``.

  The data is downloaded into ``data_dir`` if needed and stored in the
  ArrayRecord file format.
  """
  # Imported for its side effect: makes the "gsm8k" builder known to TFDS.
  import tensorflow_datasets.text.gsm8k
  array_record_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
  return tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs=array_record_kwargs,
      download=True,
  )


def download_kaggle_dataset(target_dir="./data/gsm8k"):
  """Download the GSM8K CSV files from Kaggle and copy them into ``target_dir``.

  Args:
    target_dir: Destination directory; created (with parents) if missing.

  Returns:
    ``target_dir`` unchanged, so callers can build file paths from it.
  """
  dst = Path(target_dir)
  dst.mkdir(parents=True, exist_ok=True)
  # kagglehub returns the local directory holding the downloaded files.
  src = Path(kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a"))

  for csv_file in src.glob("*.csv"):  # match all CSV files
    target = dst / csv_file.name
    shutil.copy2(csv_file, target)
    print(f"Copied {csv_file.name} → {target}")
  return target_dir


def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
  """Load a GSM8K split and map it to prompt/question/answer records.

  Args:
    data_dir: Download directory for the ``"tfds"`` and ``"kaggle"`` sources
      (created if missing; not passed to the ``"huggingface"`` loader).
    split: Split name, e.g. ``"train"`` or ``"test"``.
    source: ``"tfds"``, ``"kaggle"`` or ``"huggingface"``.

  Returns:
    A ``grain.MapDataset`` shuffled with seed 42. Each element is a dict with
    ``"prompts"`` (the question wrapped in ``Configs.TEMPLATE`` with
    ``Configs.SYSTEM_PROMPT``), ``"question"`` (the raw question text) and
    ``"answer"`` (the text after ``####``, or ``None`` if the marker is
    missing).

  Raises:
    ValueError: If ``source`` is not one of the supported values.
  """
  # Validate first so that an unknown source has no side effects.
  if source not in ("tfds", "kaggle", "huggingface"):
    raise ValueError(f"Unknown source: {source}")

  os.makedirs(data_dir, exist_ok=True)

  if source == "tfds":
    data = _load_from_tfds(data_dir, split)
  elif source == "kaggle":
    kaggle_dir = download_kaggle_dataset(data_dir)
    csv_path = os.path.join(kaggle_dir, "main_" + split + ".csv")  # adjust filename if needed
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
      data = [
          {"question": row["question"], "answer": row["answer"]}
          for row in csv.DictReader(csvfile)
      ]
  else:  # "huggingface"
    # Disable the Hugging Face Hub Xet transfer backend for this download.
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    data = load_dataset("gsm8k", "main", split=split)

  def _as_text(v):
    # Some sources (presumably TFDS) yield bytes instead of str.
    return v if isinstance(v, str) else v.decode("utf-8")

  def _to_record(x):
    question = _as_text(x["question"])
    return {
        # passed to model forward pass
        "prompts": Configs.TEMPLATE.format(
            system_prompt=Configs.SYSTEM_PROMPT,
            question=question,
        ),
        # passed to reward functions
        "question": question,
        # passed to reward functions
        "answer": extract_hash_answer(_as_text(x["answer"])),
    }

  return grain.MapDataset.source(data).shuffle(seed=42).map(_to_record)

## Split the training data into training and validation sets, and load the test set.

In [None]:
# Data source: "tfds", "kaggle" or "huggingface". To choose interactively:
# source = input("Choose data source [tfds/kaggle]: ").strip().lower()
Configs.Source = "huggingface"

if Configs.Source not in ("tfds", "kaggle", "huggingface"):
  print("Invalid choice. Defaulting to 'tfds'.")
  # Apply the announced fallback; an invalid value would otherwise reach
  # get_dataset and raise ValueError.
  Configs.Source = "tfds"

print(f"Using data source: {Configs.Source}")

# Batch the shuffled training split and cap it at NUM_BATCHES batches.
dataset = get_dataset(Configs.TRAIN_DATA_DIR, "train", Configs.Source).batch(
    Configs.TRAIN_MICRO_BATCH_SIZE
)[: Configs.NUM_BATCHES]

if Configs.TRAIN_FRACTION == 1.0:
  # No validation split: every batch is used for training.
  train_dataset = dataset.repeat(Configs.NUM_EPOCHS)
  val_dataset = None
else:
  # The first TRAIN_FRACTION of the batches train, the rest validate.
  split_idx = int(len(dataset) * Configs.TRAIN_FRACTION)
  train_dataset = dataset[:split_idx].repeat(Configs.NUM_EPOCHS)
  val_dataset = dataset[split_idx:].repeat(Configs.NUM_EPOCHS)

test_dataset = get_dataset(Configs.TEST_DATA_DIR, "test", Configs.Source).batch(
    Configs.TRAIN_MICRO_BATCH_SIZE
)[: Configs.NUM_TEST_BATCHES]

dataset_lengths = (
    len(train_dataset),
    len(val_dataset) if val_dataset is not None else 0,
    len(test_dataset),
)
print(f"dataset contains {dataset_lengths} of batches")

## Let's see what one batch of the training dataset looks like!


In [None]:
# Pretty-print the first training batch (at most one).
for first_batch in train_dataset[:1]:
  pprint(first_batch)

## Load the policy model and the reference model

The policy model is the model which is actually trained and whose weights are
updated. The reference model is the model with which we compute KL divergence.

Typically, the reference model is the base model, and the policy model is the
same base model with the LoRA parameters added to it. Only the LoRA parameters are updated. In the notation of the [DeepSeek](https://arxiv.org/pdf/2501.12948) paper, $\pi_{\theta_{old}}$ and $\pi_\theta$ are the full model plus the LoRA weights, while $\pi_{ref}$ is just the original model.

Note: the base model weights are restored in bfloat16 (see `get_gemma_ref_model`) rather than in full precision (fp32). You can, however, leverage
Qwix for QAT.

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [None]:
# Copy the Kaggle credentials from Kaggle Secrets into the environment
# (the login cell below checks these variables).
for secret_name in ("KAGGLE_KEY", "KAGGLE_USERNAME"):
  os.environ[secret_name] = UserSecretsClient().get_secret(secret_name)

In [None]:
# Log in interactively only if the Kaggle credentials are not already set.
has_kaggle_creds = all(
    name in os.environ for name in ("KAGGLE_USERNAME", "KAGGLE_KEY")
)
if not has_kaggle_creds:
  kagglehub.login()

This code snippet serves as a workaround to re-save the pre-trained model checkpoint from Kaggle into a local format that is compatible with the [Flax NNX](https://flax.readthedocs.io/en/stable/why.html) library. Because the original checkpoint has parameter names and tensor structures that don't match the target NNX model architecture, it cannot be loaded directly.

So let's first load the original weights into a temporary model instance, then extract and re-save the model's state into a new, properly formatted local checkpoint, which can then be successfully loaded by the final sharded NNX model.

In [None]:
!rm /tmp/content/intermediate_ckpt/* -rf

!rm /tmp/content/ckpts/* -rf

model_family = "gemma3"
if model_family == "gemma3":
  Configs.MODEL_CP_PATH = params.GEMMA3_1B_IT
  config = model.ModelConfig.gemma3_1b()
  gemma = params.create_model_from_checkpoint(Configs.MODEL_CP_PATH, config)
  tokenizer = params.create_tokenizer()

  checkpointer = ocp.StandardCheckpointer()
  ## nnx.split splits the single module instance into two separate objects:
  ##1 - Variables/state (params, batch stats, rngs, etc.)
  ##2 - The pure callable module (the ‚Äúfunction‚Äù part)
  _, state = nnx.split(gemma)
  checkpointer.save(os.path.join(Configs.INTERMEDIATE_CKPT_DIR, "state"), state)
  checkpointer.wait_until_finished()
  # Delete the intermediate model to save memory.
  del params
  del gemma
  del state
  gc.collect()

### Model Loading and LoRA Application

These two functions work together to load a base model from a checkpoint and apply a LoRA (Low-Rank Adaptation) layer to it.

* `get_gemma_ref_model`: Loads the complete Gemma model from a specified checkpoint path. It uses **JAX sharding** to distribute the model parameters across multiple devices (usually TPU cores).
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select specific layers/transformer sub-modules (like attention and MLP layers) to be adapted. The resulting LoRA-infused model is then sharded and updated to ensure it's ready for distributed training.

In [None]:
from tunix.models.gemma3 import params

def get_gemma_ref_model(ckpt_path):
  """Build the sharded Gemma 3 1B model from a local Orbax checkpoint.

  Args:
    ckpt_path: Directory of the model state saved with
      ``ocp.StandardCheckpointer`` (the re-saved NNX state).

  Returns:
    ``(model, mesh, model_config)``: the restored NNX model, the device mesh
    built from ``Configs.MESH``, and the Gemma 3 1B ``ModelConfig``.
  """
  # Device mesh from the configured (fsdp, tp) layout.
  mesh = jax.make_mesh(*Configs.MESH)
  model_config = model.ModelConfig.gemma3_1b()

  # Abstract model via nnx.eval_shape (shapes/dtypes only). Uses the local
  # model_config and the checkpoint path set in Configs.MODEL_CP_PATH.
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: params.create_model_from_checkpoint(
          Configs.MODEL_CP_PATH, model_config
      )
  )
  abs_state = nnx.state(abs_gemma)

  # Restore target: each leaf becomes a ShapeDtypeStruct with its named
  # sharding on the mesh. NOTE: every leaf is restored as bfloat16.
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )

  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  # Combine the abstract model's graph definition with the restored state.
  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)

  return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
    """Attach LoRA adapters to ``base_model`` and shard the result on ``mesh``.

    Adapters (rank ``Configs.RANK``, scaling ``Configs.ALPHA``) target the
    modules whose path matches the q/kv/attention-output einsums and the
    gate/up/down MLP projections.

    Returns:
        The LoRA-augmented model, whose state has been updated in place with
        sharding constraints from its partition specs.
    """
    target_modules = (
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
        ".*attn_vec_einsum"
    )
    provider = qwix.LoraProvider(
        module_path=target_modules,
        rank=Configs.RANK,
        alpha=Configs.ALPHA,
    )

    # The model's example inputs are forwarded to qwix (presumably to trace
    # the model while inserting the adapters).
    lora_model = qwix.apply_lora_to_model(
        base_model, provider, **base_model.get_model_input()
    )

    with mesh:
        model_state = nnx.state(lora_model)
        partition_specs = nnx.get_partition_spec(model_state)
        # Constrain every leaf to its partition spec, then write it back.
        nnx.update(
            lora_model,
            jax.lax.with_sharding_constraint(model_state, partition_specs),
        )

    return lora_model
    # return LoRA-augmented, sharded model


In [None]:
import flax
flax.__version__

## Load the reference and policy Gemma models using the Flax NNX library and display their structures. 

### Remember that the reference model stays untouched: it is only used to compute the KL divergence at each training step, while the policy model uses LoRA, meaning it is the model whose weights change.

In [None]:
# Reference model
if model_family == "gemma3":
  ref_model, mesh, model_config = get_gemma_ref_model(
      ckpt_path=os.path.join(Configs.INTERMEDIATE_CKPT_DIR, "state")
  )

In [None]:
# Policy model
# Policy model: the reference model plus trainable LoRA adapters.
lora_policy = get_lora_model(base_model=ref_model, mesh=mesh)
# nnx.display(lora_policy)

## Define reward functions

The original Tunix team defined four reward functions:

- reward if the format of the output exactly matches the instruction given in
`TEMPLATE`;
- reward if the format of the output approximately matches the instruction given
in `TEMPLATE`;
- reward if the answer is correct/partially correct;
- Sometimes, the text between `<answer>`, `</answer>` might not be one
  number. So, we extract the number, and reward the model if the answer is correct.

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb), and they also seem to be broadly in line with the [DeepSeek-R1](https://arxiv.org/pdf/2501.12948) paper (section 2.2.2).

Fará sentido beneficiar assim tanto o formato? Talvez, mas tenho de verificar se o formato é avaliado a dobrar, porque podem usar-se a primeira e a segunda reward function, o que significa que o total de recompensa pelo formato completo é de 5 valores (3 + 4 × 0,5; parece-me alto). De resto, é assim que devemos ensinar o modelo a pensar a um nível mais profundo? Verificar se nos deixou tokens a dizer que houve reasoning ou não? Será que é assim que vamos aproximar o raciocínio da máquina ao nosso?

## A parte das reward functions é a mais importante do treino: é o que vai definir o objectivo do treino. Por isso, devemos recompensar raciocínio que se assemelhe ao nosso e não apenas formato e respostas certas.

### Como raciocinamos nós?
Problema: Tens três interruptores fora de uma sala fechada — A, B e C — e dentro da sala há três lâmpadas (cada interruptor controla exactamente uma lâmpada). Podes mexer nos interruptores o quanto quiseres, mas só podes entrar na sala uma vez (depois de entrares não podes voltar a mexer nos interruptores). Como descobres qual interruptor controla qual lâmpada?

Linha de pensamento:
##### 1 - Recolher e organizar a informação chave
(Cenário: 3 interruptores, cada um ligado a apenas uma lâmpada. Lâmpadas dentro de uma sala fechada. Objectivo: descobrir que interruptor se liga a qual lâmpada. Constrangimento: só posso entrar na sala uma vez e não posso mexer nos interruptores depois de o fazer.)
##### 2 - Analisar caminhos de ação possíveis com base nessa info e no meu conhecimento
(Cenário de ação 1: ligar todos os interruptores — leva a uma inconclusão, porque continua a ser impossível associar um interruptor a uma lâmpada. Cenário 2: ligar apenas uma lâmpada — também será inconclusivo, pois só me dá info sobre uma das associações. Cenário 3: ligar duas lâmpadas — idem... etc.)
##### 3 - Com base na análise, providenciar uma resposta final.
Neste caso, não sei como o fazer; isto é uma conclusão válida que me vai reencaminhar de novo para o ponto 2, para procurar novos cenários de ação. Se eu fosse uma máquina, continuaria a iterar até acabar.

### Como atribuir recompensas quando os critérios são mais vagos?

Seria muito mais interessante avaliar as respostas e atribuir uma recompensa com base nestes critérios, ou seja, ver se o LLM recolheu e organizou a info chave, analisou caminhos de ação possíveis com base nessa info e no seu conhecimento e, por fim, se tentou chegar a uma resposta (iterando quando não a encontrou).

Aqui, se calhar, até deixaria a resposta certa com um peso muito mais baixo: pouco me interessa que ele acerte as perguntas durante o treino (possivelmente já as viu); quero é dar ênfase ao processo de reflexão. Ainda assim, a resposta certa deve estar presente na reward function, porque se assume que um bom processo de pensamento leva mais vezes a uma resposta certa.

Se calhar, a forma mais eficiente de atribuir uma recompensa neste caso é mesmo usar um LLM as a judge, porque é impossível avaliar isto de outra forma, a não ser que construamos um classificador que o faça.



In [None]:

class Rewards(BaseModel):
    """Structured judge output: one integer score per reasoning criterion.

    Used as the JSON response schema for the Gemini judge and to validate its
    reply. The judge prompt asks for values in 0-10, but this model does not
    enforce that range.
    """

    # Score for identifying and organizing the problem's key information.
    reward_info_gathering_and_organization: int = Field(description="the reward associated with the ability to gather info and organize it")
    # Score for exploring and comparing possible solution paths.
    reward_analyse_possible_action_paths: int = Field(description="the reward associated with the analysis possible action paths")
    # Score for committing to a path and giving an answer consistent with it.
    reward_choose_an_action_path_and_return_an_answer_related_to_it: int = Field(description="The reward associated with choosing an action path and the returning of an answer related to it")
    
class Gemini_judge:
    """LLM-as-a-judge scorer backed by Gemini.

    Each call sends one prompt (user request + candidate response) with the
    rubric from ``Configs.JUDGE_PROMPT`` as the system instruction, asks for
    JSON constrained to the ``Rewards`` schema, and validates the reply.
    """

    def __init__(self, api_key: str | None = None, model_name: str = "gemini-2.5-flash"):
        """Create the Gemini client.

        Args:
            api_key: Gemini API key. When ``None``, ``genai.Client`` reads the
                key from the environment (``GEMINI_API_KEY`` /
                ``GOOGLE_API_KEY``), so no key has to be hard-coded here.
            model_name: Gemini model used for judging.
        """
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        self.major_judge_prompt = Configs.JUDGE_PROMPT

    def llm_as_judge(self, input_prompt: str) -> Rewards:
        """Score ``input_prompt`` with Gemini and return the parsed ``Rewards``.

        Raises:
            pydantic.ValidationError: If the reply does not match the schema.
        """
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=input_prompt,
            config={
                # The rubric must be sent, otherwise the judge never sees the
                # scoring criteria it is asked to apply.
                "system_instruction": self.major_judge_prompt,
                "response_mime_type": "application/json",
                "response_json_schema": Rewards.model_json_schema(),
            },
        )
        return Rewards.model_validate_json(response.text)

In [None]:
# RegEx that checks whether a completion follows the requested format exactly:
# optional leading whitespace, a non-empty <reasoning>...</reasoning> section,
# then <answer>...</answer> (group 1 = answer text), and only whitespace after
# it. Only DOTALL is used (so `.` spans newlines): with re.MULTILINE, `^`/`$`
# would also match at inner line boundaries, letting extra lines of text
# before or after the tagged sections count as an "exact" match.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{Configs.reasoning_start}.+?{Configs.reasoning_end}.*?"
    rf"{Configs.solution_start}(.+?){Configs.solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.DOTALL,
)

# Sanity check: a well-formed completion should produce a match.
match_format.search(
    f"{Configs.reasoning_start}Let me"
    f" think!{Configs.reasoning_end}{Configs.solution_start}2{Configs.solution_end}",
)

Give the model a reward of 3 points if the format matches exactly.

In [None]:
def match_format_exactly(prompts, completions, **kwargs):
  """Reward 3.0 for each completion that matches ``match_format``, else 0.

  ``prompts`` and extra keyword arguments are accepted to fit the
  reward-function interface but are not used.
  """
  scores = []
  for completion in completions:
    matched = match_format.search(completion) is not None
    scores.append(3.0 if matched else 0)
  return scores

Reward the model if the format of the output matches partially.

In [None]:
def match_format_approximately(prompts, completions, **kwargs):
  """Partial-format reward: +0.5 per tag seen exactly once, -0.5 otherwise.

  The four tags (reasoning start/end, answer start/end) are checked, so each
  completion scores between -2.0 and +2.0; missing and repeated tags are both
  penalized.
  """
  scores = []
  for response in completions:
    tag_counts = [
        response.count(tag)
        for tag in (
            Configs.reasoning_start,
            Configs.reasoning_end,
            Configs.solution_start,
            Configs.solution_end,
        )
    ]
    scores.append(sum(0.5 if count == 1 else -0.5 for count in tag_counts))
  return scores

### Reward the model if the answer is correct. A reward is also given if the answer does not match exactly, i.e., based on how close the answer is to the correct value.

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
  """Score each completion's answer (group 1 of ``match_format``) vs. the reference.

  Scoring per completion:
    * no format match, or reference answer is ``None``: 0
    * exact string match: 3.0
    * match after stripping whitespace: 1.5
    * numeric ratio guess / answer in [0.9, 1.1]: 0.5; in [0.8, 1.2]: 0.25
    * any other numeric answer: -1.0
    * non-numeric guess or reference, or a reference equal to 0: -0.5

  Raises:
    AssertionError: If ``completions`` and ``answer`` differ in length.
  """
  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in completions
  ]

  assert len(extracted_responses) == len(
      answer
  ), f"{extracted_responses} and {answer} have mismatching length"

  scores = []
  for guess, true_answer in zip(extracted_responses, answer):
    # No parsable answer, or nothing to compare against (extract_hash_answer
    # yields None when "####" is missing).
    if guess is None or true_answer is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      scores.append(3.0)
      continue
    # Match if spaces are seen
    if guess.strip() == true_answer.strip():
      scores.append(1.5)
      continue
    # Otherwise reward answers that are numerically close.
    try:
      ratio = float(guess) / float(true_answer)
    except (ValueError, ArithmeticError):
      # Non-numeric text, or a reference answer of 0.
      scores.append(-0.5)
      continue
    if 0.9 <= ratio <= 1.1:
      scores.append(0.5)
    elif 0.8 <= ratio <= 1.2:
      scores.append(0.25)
    else:
      scores.append(-1.0)  # Penalize wrong answers
  return scores

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [None]:
# Captures (group 1) the first run of digits and dots that follows the
# <answer> tag, for answers written as a sentence rather than a bare number.
# NOTE(review): the character class has no sign or thousands separator, so
# "-5" yields "5" and "1,200" yields "1"; a lone "." can also be captured.
match_numbers = re.compile(
    rf"{Configs.solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
# Sanity check: expected to return ['0.34'].
match_numbers.findall(f"{Configs.solution_start}  0.34  {Configs.solution_end}")

In [None]:
def check_numbers(prompts, completions, answer, **kwargs):
  """Reward 1.5 when the first number after ``<answer>`` equals the reference.

  The number is extracted with ``match_numbers`` and compared numerically with
  the reference answer. A missing number, a missing reference (``None``) or
  an unparsable value scores 0; a parsed but different number scores 0.0.

  The first sample of the batch is printed for monitoring. The batch's
  questions are expected in ``kwargs["question"]``.
  """
  question = kwargs["question"]
  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in completions
  ]

  # Debug print of the first sample; skipped for an empty batch, where
  # indexing [0] would raise IndexError.
  if completions and question and answer:
    print("START ============================")
    print(f"Question: {question[0]}")
    print(f"Answer: {answer[0]}")
    print(f"Response: {completions[0]}")
    print(f"Extracted: {extracted_responses[0]}")
    print("END ==============================")

  scores = []
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None or true_answer is None:
      scores.append(0)
      continue
    try:
      is_equal = float(true_answer.strip()) == float(guess.strip())
    except ValueError:
      # Not parsable as a number (e.g. "1,200" or a bare ".").
      scores.append(0)
      continue
    scores.append(1.5 if is_equal else 0.0)
  return scores

## Evaluate


Before we train the model, let's evaluate the model on the test set so we can
see the improvement post training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer  
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1.
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the `<reasoning>` and `</reasoning>`
tokens, and the final answer between the `<answer>` and `</answer>` tokens.

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


We define a helper function to generate an answer, given a prompt.

In [None]:
def generate(
    question,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=None,
    max_generation_steps=768,
    eos_tokens=None,
):
  """Generate completions for one question or a batch of questions.

  Each question is wrapped in ``Configs.TEMPLATE`` together with
  ``Configs.SYSTEM_PROMPT`` before being passed to ``sampler``.

  Args:
    question: A single question string, or an iterable of question strings.
    sampler: Callable sampler whose result exposes the generated strings as
      ``.text``.
    temperature, top_k, top_p: Sampling parameters forwarded to the sampler.
    seed: Optional sampling seed; ``None`` is forwarded as-is.
    max_generation_steps: Maximum number of generation steps.
    eos_tokens: Token ids that stop generation. Defaults to ``[1, 106]``
      (presumably Gemma's ``<eos>`` and ``<end_of_turn>`` ids -- verify
      against the tokenizer).

  Returns:
    The sampler's ``.text``: its first element when ``question`` is a string,
    otherwise the whole list.
  """
  single = isinstance(question, str)
  questions = [question] if single else question
  input_batch = [
      Configs.TEMPLATE.format(system_prompt=Configs.SYSTEM_PROMPT, question=q)
      for q in questions
  ]

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=max_generation_steps,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
      eos_tokens=[1, 106] if eos_tokens is None else list(eos_tokens),
  )

  output = out_data.text
  return output[0] if single else output

Another helper function for evaluation.

In [None]:
def _score_response(response, expected_answer):
  """Scores one response; returns (correct, partially_correct, well_formatted)."""
  guess = match_numbers.search(response)
  extracted_response = guess.group(1) if guess is not None else "-1000000"

  correct = partially_correct = False
  try:
    predicted = float(extracted_response.strip())
    expected = float(expected_answer.strip())
  except (AttributeError, TypeError, ValueError):
    # Unparseable answer: count it as wrong but keep evaluating.
    print("SKIPPED")
  else:
    correct = predicted == expected
    # An exact match is always partially correct. Otherwise accept a ratio
    # in [0.9, 1.1]. The `expected != 0` guard avoids dividing by zero, so a
    # zero reference can only be matched exactly.
    partially_correct = correct or (
        expected != 0 and 0.9 <= predicted / expected <= 1.1
    )

  well_formatted = match_format.search(response) is not None
  return correct, partially_correct, well_formatted


def _percent(count, total):
  """Returns `count / total` as a percentage, or 0.0 when `total` is 0."""
  return count / total * 100 if total else 0.0


def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format.

  Each question is answered `num_passes` times (pass `p` uses seed `p`). A
  question counts as correct, partially correct or well formatted if ANY of
  its responses is.

  Args:
    dataset: Iterable of batches; each batch maps "question" and "answer" to
      sequences that are zipped together.
    sampler: Sampler forwarded to `generate`.
    temperature: Sampling temperature forwarded to `generate`.
    top_k: Top-k cutoff forwarded to `generate`.
    top_p: Top-p cutoff forwarded to `generate`.
    num_passes: Number of generations per question.
    corr_lst: With `make_lst`, collect correctly answered questions if truthy,
      otherwise the incorrectly answered ones.
    make_lst: Whether to also return the collected
      `(question, expected_answer, responses)` tuples.

  Returns:
    `(corr, total, accuracy_pct, partial_accuracy_pct, format_accuracy_pct)`,
    or `(that_tuple, response_lst)` when `make_lst` is truthy. Percentages
    are 0.0 when the dataset is empty.
  """
  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    expected_answers = batch["answer"]
    questions = batch["question"]

    # multiple_call_responses[i] collects every pass's response to question i.
    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, expected_answer in zip(
        questions, multiple_call_responses, expected_answers
    ):
      any_correct = any_partial = any_format = False
      for response in multiple_call_response:
        correct, partial, formatted = _score_response(
            response, expected_answer
        )
        any_correct = any_correct or correct
        any_partial = any_partial or partial
        any_format = any_format or formatted
        # Nothing more to learn about this question once every metric is hit.
        if any_correct and any_partial and any_format:
          break

      if make_lst and any_correct == bool(corr_lst):
        response_lst.append((question, expected_answer, multiple_call_response))
      corr += any_correct
      partially_corr += any_partial
      corr_format += any_format

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      _percent(corr, total),
      _percent(partially_corr, total),
      _percent(corr_format, total),
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
# Sampler over the LoRA policy. Cache size = max prompt length + generation
# budget + 256 tokens of slack.
sampler = sampler_lib.Sampler(
    cache_config=sampler_lib.CacheConfig(
        head_dim=model_config.head_dim,
        num_kv_heads=model_config.num_kv_heads,
        num_layers=model_config.num_layers,
        cache_size=(
            Configs.MAX_PROMPT_LENGTH
            + Configs.TOTAL_GENERATION_STEPS
            + 256
        ),
    ),
    tokenizer=tokenizer,
    transformer=lora_policy,
)

Now let's see how the original model does on the test set. The output shows the percentages of model outputs that are fully correct, partially correct, and correctly formatted. The following step might take a couple of minutes to finish.

In [None]:
# Baseline (pre-training) evaluation with greedy decoding; this can take a
# couple of minutes.
corr, total, accuracy, partial_accuracy, format_accuracy = evaluate(
    test_dataset,
    sampler,
    **Configs.GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, "
    f"{partial_accuracy=}%, {format_accuracy=}%"
)

# Train

Let's set up all the configs first - checkpointing, metric logging and training.
We then train the model.

In [None]:
# Checkpointing: save every SAVE_INTERVAL_STEPS steps, keep at most
# MAX_TO_KEEP checkpoints.
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=Configs.MAX_TO_KEEP,
    save_interval_steps=Configs.SAVE_INTERVAL_STEPS,
)

# Metric logging to a tensorboard directory, flushed every 20 steps.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    flush_every_n_steps=20,
    log_dir="/tmp/content/tmp/tensorboard/grpo",
)

In [None]:
# AdamW whose learning rate warms up from 0 to LEARNING_RATE over
# WARMUP_STEPS, then cosine-decays to 0 by MAX_STEPS (total schedule length).
optimizer = optax.adamw(
    weight_decay=Configs.WEIGHT_DECAY,
    b1=Configs.B1,
    b2=Configs.B2,
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        end_value=0.0,
        peak_value=Configs.LEARNING_RATE,
        warmup_steps=Configs.WARMUP_STEPS,
        decay_steps=Configs.MAX_STEPS,
    ),
)

# Optional global-norm gradient clipping, applied before the AdamW update.
if Configs.MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=Configs.MAX_GRAD_NORM), optimizer
  )

In [None]:
# Training config: actor, reference and rollout roles all share `mesh`.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        role: mesh
        for role in (
            rl_cluster_lib.Role.ACTOR,
            rl_cluster_lib.Role.REFERENCE,
            rl_cluster_lib.Role.ROLLOUT,
        )
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        max_steps=Configs.MAX_STEPS,
        eval_every_n_steps=Configs.EVAL_EVERY_N_STEPS,
        # NOTE(review): mini-batch and micro-batch sizes are both
        # TRAIN_MICRO_BATCH_SIZE; confirm this is intended.
        mini_batch_size=Configs.TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=Configs.TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_logging_options,
        checkpointing_options=checkpointing_options,
        checkpoint_root_directory=Configs.CKPT_DIR,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_prompt_length=Configs.MAX_PROMPT_LENGTH,
        max_tokens_to_generate=Configs.TOTAL_GENERATION_STEPS,
        # Same sizing as the evaluation sampler's cache.
        kv_cache_size=(
            Configs.MAX_PROMPT_LENGTH + Configs.TOTAL_GENERATION_STEPS + 256
        ),
        temperature=Configs.TEMPERATURE,
        top_k=Configs.TOP_K,
        top_p=Configs.TOP_P,
        eos_tokens=[1, 106],
    ),
)

# GRPO algorithm hyperparameters.
grpo_config = GRPOConfig(
    beta=Configs.BETA,
    epsilon=Configs.EPSILON,
    num_generations=Configs.NUM_GENERATIONS,
    num_iterations=Configs.NUM_ITERATIONS,
)

### Setting Up the GRPO Trainer

Now we initialize our system for training. First, we create an `RLCluster` instance, which brings together the **policy model (`actor`)**, a **reference model (`reference`)**, and a **tokenizer**. Our `actor` is a trainable LoRA model, while the `reference` is a fixed base model that we use to guide the training.

We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.

Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:

**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.

**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in.

In [None]:
# RL cluster: trainable LoRA policy as actor, `ref_model` as reference.
rl_cluster = rl_cluster_lib.RLCluster(
    cluster_config=cluster_config,
    tokenizer=tokenizer,
    reference=ref_model,
    actor=lora_policy,
)

# GRPO learner scoring completions with these reward functions (order kept).
grpo_trainer = GRPOLearner(
    grpo_config=grpo_config,
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
)

The first couple of training steps might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. more than 10 minutes per step, please file a bug report. We really appreciate it!

In [None]:
# Run GRPO training on `train_dataset` with `mesh` entered as the active
# mesh context.
with mesh:
  grpo_trainer.train(train_dataset)

## Evaluate

Let's evaluate our finetuned model!

In [None]:
# Load the latest trained LoRA checkpoint into `lora_policy`.
import re

# Step checkpoints are numbered sub-directories of CKPT_DIR/actor; take the
# highest step number (-1 if none exist).
actor_ckpt_dir = os.path.join(Configs.CKPT_DIR, "actor")

latest_step = max(
    (
        int(item)
        for item in (
            os.listdir(actor_ckpt_dir) if os.path.isdir(actor_ckpt_dir) else ()
        )
        # fullmatch (unlike `^...$`) also rejects a trailing newline.
        if re.fullmatch(r"\d+", item)
        and os.path.isdir(os.path.join(actor_ckpt_dir, item))
    ),
    default=-1,
)

if latest_step == -1:
  raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval')  # logging bug workaround

# Reuse actor_ckpt_dir so the search and the restore path cannot diverge.
trained_ckpt_path = os.path.join(
    actor_ckpt_dir, str(latest_step), "model_params"
)

# Shape/dtype-only restore target mirroring the policy's current LoRA state.
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

# Copy the restored leaves into the policy. Mapping over the live LoRA state
# keeps its tree type; jax.tree.map raises if the two trees do not match.
nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
# Rebuild the sampler around the policy with the restored LoRA weights.
# Cache size = max prompt length + generation budget + 256 tokens of slack.
sampler = sampler_lib.Sampler(
    cache_config=sampler_lib.CacheConfig(
        head_dim=model_config.head_dim,
        num_kv_heads=model_config.num_kv_heads,
        num_layers=model_config.num_layers,
        cache_size=(
            Configs.MAX_PROMPT_LENGTH
            + Configs.TOTAL_GENERATION_STEPS
            + 256
        ),
    ),
    tokenizer=tokenizer,
    transformer=lora_policy,
)

In [None]:
# Post-training evaluation with greedy decoding; this can take a couple of
# minutes.
corr, total, accuracy, partial_accuracy, format_accuracy = evaluate(
    test_dataset,
    sampler,
    **Configs.GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, "
    f"{partial_accuracy=}%, {format_accuracy=}%"
)

With sufficient training, you should see that the percentages of correct model outputs have clearly gone up, which means our training worked.