# Teaching Gemma to Think First: Reasoning Traces on Science Questions with Tunix and GRPO

**Subtitle:** Fine-tuning Gemma2 on the ARC Science Dataset to Produce Structured Reasoning Using Reinforcement Learning

This notebook is a submission for the [Google Tunix Hackathon](https://www.kaggle.com/competitions/google-tunix-hackathon/overview).

---

## Results Summary

| Metric | Pre-Training | Post-Training |
|--------|--------------|---------------|
| Format Accuracy | 0% | **80%** |
| Answer Accuracy | 0% | **55%** |

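After training, answer accuracy reached 55%, **more than double** the 25% random-guess baseline for 4-choice questions. Answers are scored only when they appear inside `<answer>` tags, so the 0% pre-training figure likely reflects missing tags rather than missing science knowledge.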

## Your Overall Training and Evaluation Strategy

### Approach
Instruction-tuned models often answer multiple-choice questions directly, without showing their reasoning in a consistent, parseable structure. This project trains Gemma2-2B-IT to produce **structured reasoning traces** before answering multiple-choice science questions.

### Compute Allocation
- **Single Kaggle TPU v5e session** (within 9-hour limit)
- **3,357 training steps** (~53 minutes total)
- Efficient LoRA fine-tuning (rank=64) to minimize memory usage

### Reward Function Design
We use a **compositional reward function** with 4 components:

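1. **Format Exactness (3 pts)**: Awarded when the output contains both a `<reasoning>...</reasoning>` block and an `<answer>...</answer>` block
2. **Format Partial (-1.5 to +1.5 pts)**: +1.0 per opening tag and +0.5 per closing tag present, shifted by -1.5 so that missing tags are penalized
3. **Reasoning Quality (up to 3 pts)**: 0.5 per option discussed and 0.25 per reasoning keyword, capped at 3
4. **Answer Correctness (up to 5 pts)**: 5 for the correct letter (A/B/C/D) inside `<answer>` tags, 1 for a well-formed but wrong answer, 0 otherwise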
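
The overall score range follows from simple arithmetic. The sketch below assumes the trainer adds the four component scores for each completion (an assumption about how the rewards are combined). The per-component ranges are read off the reward functions defined later in this notebook. `COMPONENT_RANGES` and `total_range` are illustrative names, not part of the notebook.

```python
# (min, max) score of each reward function as implemented later in the notebook.
COMPONENT_RANGES = {
    "match_format_exactly": (0.0, 3.0),
    "match_format_approximately": (-1.5, 1.5),
    "check_reasoning_quality": (0.0, 3.0),
    "check_answer_correct": (0.0, 5.0),
}


def total_range(ranges):
    """Return (lowest, highest) total, ASSUMING component scores are summed."""
    lowest = sum(lo for lo, _ in ranges.values())
    highest = sum(hi for _, hi in ranges.values())
    return lowest, highest


print(total_range(COMPONENT_RANGES))  # (-1.5, 12.5)
```

Under that assumption, answer correctness accounts for at most 5 of the 12.5 available points, so a well-formatted but wrong completion can still collect a sizeable reward.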

### Evaluation Strategy
- **Pre-training baseline**: Evaluate the base model on 30 test questions with greedy decoding to establish the baseline (0% on both metrics)
- **Post-training evaluation**: Test on held-out questions with `temperature=0.0` (greedy decoding)
- **Metrics**: Format accuracy (tag compliance) and Answer accuracy (correctness)
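
As a minimal sketch of how the two metrics relate, the helper below scores a list of responses using the same tag checks and answer regex as the notebook's evaluation code. `score_responses` and the sample responses are illustrative and not part of the notebook.

```python
import re

TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")


def score_responses(responses, answers):
    """Return (format_accuracy, answer_accuracy) as percentages."""
    format_hits = 0
    answer_hits = 0
    for response, answer in zip(responses, answers):
        # Format accuracy: all four tags must be present.
        if all(tag in response for tag in TAGS):
            format_hits += 1
        # Answer accuracy: the letter is only read from inside <answer> tags.
        match = re.search(r"<answer>\s*([A-Da-d])\s*</answer>", response)
        if match and match.group(1).upper() == answer.upper():
            answer_hits += 1
    total = len(responses)
    if total == 0:
        return 0.0, 0.0
    return 100 * format_hits / total, 100 * answer_hits / total


responses = [
    "<reasoning>Dry skin has the most friction.</reasoning><answer>A</answer>",
    "The answer is A.",  # right letter, but no tags: counted as wrong
]
print(score_responses(responses, ["A", "A"]))  # (50.0, 50.0)
```

Because the answer is extracted only from the tags, a model that answers correctly in free text still scores zero on answer accuracy.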

### Techniques Used
- **GRPO (Group Relative Policy Optimization)** from Tunix
- **LoRA (Low-Rank Adaptation)** for parameter-efficient fine-tuning
- **Custom reward shaping** to balance format, reasoning, and correctness

## How Your Finetuning Dataset is Created

### Dataset: ARC (AI2 Reasoning Challenge)
We chose the **ARC Science Dataset** specifically to demonstrate GRPO's effectiveness on a **non-math domain**, unlike the reference GSM8K notebook.

**Dataset Statistics:**
- **Training samples**: 1,119
- **Test samples**: 1,172
- **Format**: Multiple-choice science questions (A, B, C, D)

### Domain Coverage
The dataset tests scientific reasoning across:
- Physics (forces, energy, motion)
- Biology (cells, organisms, ecosystems)
- Chemistry (matter, reactions)
- Earth Science (weather, geology)

### Data Preprocessing
Each question is formatted with:
1. System prompt instructing step-by-step reasoning
2. Question text with labeled answer choices (A-D)
3. Expected output format: `<reasoning>...</reasoning><answer>X</answer>`

### Public Accessibility
The ARC dataset is publicly available at:
- https://allenai.org/data/arc
- https://www.kaggle.com/datasets/thedevastator/ai2-arc-reasoning-challenge

The dataset files used in this notebook are included in the notebook inputs.

## Your Tunix Finetuning Code

Using instruction-tuned **Gemma2 2B** with LoRA fine-tuning.

In [None]:
# Your prompt template
PROMPT_TEMPLATE = '''<start_of_turn>user
You are a helpful assistant that solves multiple choice science questions.

For each question:
1. Read the question and all answer choices carefully
2. Think through the problem step by step inside <reasoning> and </reasoning> tags
3. Provide your final answer (A, B, C, or D) inside <answer> and </answer> tags

Example:
<reasoning>
Let me analyze each option:
- Option A: [analysis]
- Option B: [analysis]
- Option C: [analysis]
- Option D: [analysis]
Based on my analysis, the correct answer is...
</reasoning>
<answer>B</answer>

Question: {question}<end_of_turn>
<start_of_turn>model
'''

# Training parameters
TEMPERATURE = 0.9
TOP_K = 50
TOP_P = 0.95
MAX_GENERATION_STEPS = 256
MAX_PROMPT_LENGTH = 256

# LoRA parameters
LORA_RANK = 64
LORA_ALPHA = 64

# Training hyperparameters
LEARNING_RATE = 3e-6
NUM_EPOCHS = 3
BATCH_SIZE = 1
GRPO_NUM_GENERATIONS = 4

# DO NOT CHANGE BELOW - Standard output tags
REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
SOLUTION_START = "<answer>"
SOLUTION_END = "</answer>"

# Inference parameters for greedy decoding (used in competition evaluation)
INF_TEMPERATURE = None
INF_TOP_K = 1
INF_TOP_P = None
SEED = 42

# GRPO Training on ARC Science Questions

This tutorial demonstrates training the [Gemma 2 2B-IT](https://deepmind.google/models/gemma/) model on the [ARC (AI2 Reasoning Challenge)](https://allenai.org/data/arc) science multiple-choice questions using [Group Relative Policy Optimization (GRPO)](https://arxiv.org/pdf/2402.03300).

GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that uses group-based comparisons rather than a critic model.
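
The group-based comparison can be sketched in a few lines. This is a plain-Python illustration of the normalization described in the GRPO paper, not Tunix's implementation. The function name, the `eps` value, and the use of the population standard deviation are assumptions.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against the other completions for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # assumption: population standard deviation
    return [(r - mean) / (std + eps) for r in rewards]


# Two completions for one prompt (this notebook trains with NUM_GENERATIONS = 2).
print(group_relative_advantages([8.0, 2.0]))  # roughly [1.0, -1.0]

# Identical rewards carry no learning signal: every advantage is zero.
print(group_relative_advantages([5.0, 5.0]))  # [0.0, 0.0]
```

With only two generations per prompt, any group whose completions earn the same reward contributes nothing to the policy update. This is one reason shaped, fine-grained rewards help here.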

**Key Features:**
- **Dataset**: ARC-Challenge (1119 training, 1172 test samples)
- **Model**: Gemma-2 2B IT with LoRA (rank=64)
- **Task**: Multiple choice science questions (A, B, C, D)
- **Output Format**: `<reasoning>...</reasoning><answer>X</answer>`

**Results achieved:**
- Format Accuracy: **80.0%**
- Answer Accuracy: **55.0%** (vs 25% random baseline)

In [1]:
!pip install -q wandb
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

!pip uninstall -q -y flax
!pip install flax==0.12.0

!pip install -q datasets wandb==0.22.0

## Setup W&B for Experiment Tracking

In [2]:
import wandb, os
from kaggle_secrets import UserSecretsClient
os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("WANDB_API_KEY")



## Imports

Import all necessary libraries including JAX, Flax NNX, Tunix, and other dependencies.

In [3]:
import functools
import gc
import os
from pprint import pprint
import re
import csv
import shutil
import json

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np

# Tunix imports
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.models.gemma import model as gemma_lib  # provides Transformer and ModelConfig
from tunix.models.gemma import params as params_lib

# Metrics logger with fallback
try:
    from tunix.training import metrics_logger
except ModuleNotFoundError:
    class metrics_logger:
        class MetricsLoggerOptions:
            def __init__(self, log_dir=None, flush_every_n_steps=20):
                self.log_dir = log_dir
                self.flush_every_n_steps = flush_every_n_steps

print("✅ All imports successful!")



✅ All imports successful!


In [4]:
# Check what's in gemma_lib
import tunix.models.gemma as gemma_module
print("Available in tunix.models.gemma:")
print(dir(gemma_module))

# Check submodules
import pkgutil
print("\nSubmodules:")
for importer, modname, ispkg in pkgutil.iter_modules(gemma_module.__path__):
    print(f"  {modname}")
    
# Try to find Transformer
try:
    from tunix.models.gemma.model import Transformer
    print("\n✅ Found Transformer in tunix.models.gemma.model")
except ImportError:  # catch only import failures, not every exception
    print("\n❌ Not in model")

try:
    from tunix.models.gemma import gemma
    print("In gemma submodule:", dir(gemma))
except ImportError:
    print("No gemma submodule")

Available in tunix.models.gemma:
['__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'model', 'params']

Submodules:
  model
  params
  params_safetensors
  sampler

✅ Found Transformer in tunix.models.gemma.model
No gemma submodule




## Hyperparameters

Define the configuration for training:
- **LoRA settings**: Rank=64, Alpha=64 for efficient fine-tuning
- **GRPO settings**: 2 generations per prompt, KL penalty (beta=0.08), clipping (epsilon=0.2)
- **Training settings**: Learning rate=3e-6, 3 epochs, 3357 total steps
- **Cache settings**: KV cache sized for prompt plus generation plus headroom (256 + 256 + 256 = 768 tokens)

In [6]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 256
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
NUM_GENERATIONS = 2
NUM_ITERATIONS = 1
BETA = 0.08
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 1
NUM_BATCHES = 1119  # Number of ARC training samples
NUM_TEST_BATCHES = 100
EVAL_EVERY_N_STEPS = 50
NUM_EPOCHS = 3

MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
print(f"MAX_STEPS = {MAX_STEPS}")

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(0.1 * MAX_STEPS)
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    "greedy": {"temperature": 0.0, "top_k": 1, "top_p": 1.0},
    "sampling": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
}

print("Hyperparameters configured!")

MAX_STEPS = 3357
Hyperparameters configured!


## Utility Functions

Helper function to monitor TPU memory usage.

In [7]:
def show_hbm_usage():
    """Displays memory usage per device."""
    fmt_size = functools.partial(humanize.naturalsize, binary=True)
    for d in jax.local_devices():
        stats = d.memory_stats()
        used = stats["bytes_in_use"]
        limit = stats["bytes_limit"]
        print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({100*used/limit:.1f}%) on {d}")

show_hbm_usage()

E0000 00:00:1767732963.182549      12 common_lib.cc:648] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:238


Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_0(process=0,(0,0,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_1(process=0,(1,0,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_2(process=0,(0,1,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_3(process=0,(1,1,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_4(process=0,(0,2,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_5(process=0,(1,2,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_6(process=0,(0,3,0,0))
Using 26.5 KiB / 15.7 GiB (0.0%) on TPU_7(process=0,(1,3,0,0))


## Data Preprocessing - Output Format

We define plain-text tags for structured output (they are ordinary strings, not new tokenizer tokens). The model is instructed to:
1. **Reason** between `<reasoning>` and `</reasoning>` tags
2. **Answer** with a single letter (A, B, C, or D) between `<answer>` and `</answer>` tags

This format encourages step-by-step reasoning before selecting an answer.

In [8]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

# SYSTEM PROMPT for ARC Science Questions
SYSTEM_PROMPT = f"""You are a helpful assistant that solves multiple choice science questions.

For each question:
1. Read the question and all answer choices carefully
2. Think through the problem step by step inside {reasoning_start} and {reasoning_end} tags
3. Provide your final answer (A, B, C, or D) inside {solution_start} and {solution_end} tags

Example:
{reasoning_start}
Let me analyze each option:
- Option A: [analysis]
- Option B: [analysis]
- Option C: [analysis]
- Option D: [analysis]
Based on my analysis, the correct answer is...
{reasoning_end}
{solution_start}B{solution_end}
"""

TEMPLATE = """<start_of_turn>user
{system_prompt}

Question: {question}<end_of_turn>
<start_of_turn>model
"""

print("System prompt configured for ARC dataset!")

System prompt configured for ARC dataset!


## Load ARC Dataset

We use the [ARC (AI2 Reasoning Challenge)](https://allenai.org/data/arc) dataset, which contains grade-school-level science questions. Most questions have four answer options (A-D).

The dataset tests scientific reasoning across topics like:
- Physics (forces, energy, motion)
- Biology (cells, organisms, ecosystems)
- Chemistry (matter, reactions)
- Earth Science (weather, geology)
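
A small number of ARC items label their choices `1`-`4` rather than `A`-`D`. The answer regex used later in this notebook accepts only the letters A-D, so any such item in the files used here could never be scored as correct. The sketch below shows one hypothetical way to normalize those labels before formatting. `normalize_labels` is not part of the notebook, and the example item is made up for illustration.

```python
NUMERIC_TO_LETTER = {"1": "A", "2": "B", "3": "C", "4": "D"}


def normalize_labels(item):
    """Return a copy of an ARC item with numeric choice labels mapped to letters."""
    choices = [
        {"label": NUMERIC_TO_LETTER.get(c["label"], c["label"]), "text": c["text"]}
        for c in item["question"]["choices"]
    ]
    return {
        "question": {"stem": item["question"]["stem"], "choices": choices},
        "answerKey": NUMERIC_TO_LETTER.get(item["answerKey"], item["answerKey"]),
    }


# Made-up item in the same structure the notebook reads from the JSONL files.
example = {
    "question": {
        "stem": "Which object is attracted to a magnet?",
        "choices": [
            {"label": "1", "text": "a wooden spoon"},
            {"label": "2", "text": "an iron nail"},
            {"label": "3", "text": "a glass cup"},
            {"label": "4", "text": "a rubber band"},
        ],
    },
    "answerKey": "2",
}

normalized = normalize_labels(example)
print(normalized["answerKey"])                               # B
print([c["label"] for c in normalized["question"]["choices"]])  # ['A', 'B', 'C', 'D']
```

Items that already use letters pass through unchanged, so the helper could be applied to every item before `format_arc_question`.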

In [9]:
# === Load ARC Dataset ===
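def load_arc_jsonl(filepath):
    """Load an ARC JSONL file into a list of dicts (one JSON object per line)."""
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():  # skip blank lines
                data.append(json.loads(line))
    return data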

def format_arc_question(item):
    """Format ARC item into question string"""
    stem = item['question']['stem']
    choices = item['question']['choices']
    
    choice_str = ""
    for choice in choices:
        choice_str += f"{choice['label']}: {choice['text']}\n"
    
    return f"{stem}\n\n{choice_str}".strip()

def get_arc_answer(item):
    """Get the answer key"""
    return item['answerKey']

# Load training and test data
train_data_raw = load_arc_jsonl('/kaggle/input/arc-ai2-reasoning-challenge/ARC-Challenge-Train.jsonl')
test_data_raw = load_arc_jsonl('/kaggle/input/arc-ai2-reasoning-challenge/ARC-Challenge-Test.jsonl')

print(f"Loaded {len(train_data_raw)} training samples")
print(f"Loaded {len(test_data_raw)} test samples")

# Format all data into simple list
train_dataset = []
for item in train_data_raw:
    train_dataset.append({
        'prompt': format_arc_question(item),
        'answer': get_arc_answer(item)
    })

test_dataset = []
for item in test_data_raw:
    test_dataset.append({
        'prompt': format_arc_question(item),
        'answer': get_arc_answer(item)
    })

print(f"\nFormatted {len(train_dataset)} training samples")
print(f"Formatted {len(test_dataset)} test samples")

# Show example
print("\n" + "="*50)
print("EXAMPLE QUESTION:")
print(train_dataset[0]['prompt'])
print("\nANSWER:", train_dataset[0]['answer'])

Loaded 1119 training samples
Loaded 1172 test samples

Formatted 1119 training samples
Formatted 1172 test samples

EXAMPLE QUESTION:
George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?

A: dry palms
B: wet palms
C: palms covered with oil
D: palms covered with lotion

ANSWER: A


In [10]:
source = 'arc'
print(f"Using data source: {source}")

# Prepare raw data list
arc_raw_data = []
for item in train_dataset:
    arc_raw_data.append({
        "question": item['prompt'],
        "answer": item['answer'],
    })

print(f"Prepared {len(arc_raw_data)} raw training samples")

# Create the grain.MapDataset EXACTLY like original GSM8K notebook
dataset = (
    grain.MapDataset.source(arc_raw_data)
    .shuffle(seed=42)
    .map(
        lambda x: {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=x["question"],
            ),
            "question": x["question"],
            "answer": x["answer"],
        }
    )
)

# Batch it
dataset = dataset.batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]

# Repeat for epochs
train_dataset_final = dataset.repeat(NUM_EPOCHS)

print(f"✅ Created train_dataset_final with {len(train_dataset_final)} batches")

# Verify first batch
for batch in train_dataset_final[:1]:
    print(f"✅ First batch keys: {batch.keys()}")
    print(f"✅ Prompts type: {type(batch['prompts'])}")
    if len(batch['prompts']) > 0:
        print(f"✅ First prompt type: {type(batch['prompts'][0])}")
        first_prompt = batch['prompts'][0]
        if hasattr(first_prompt, 'decode'):
            first_prompt = first_prompt.decode('utf-8')
        print(f"✅ First prompt (100 chars): {first_prompt[:100]}")

Using data source: arc
Prepared 1119 raw training samples
✅ Created train_dataset_final with 3357 batches
✅ First batch keys: dict_keys(['answer', 'prompts', 'question'])
✅ Prompts type: <class 'numpy.ndarray'>
✅ First prompt type: <class 'numpy.str_'>
✅ First prompt (100 chars): <start_of_turn>user
You are a helpful assistant that solves multiple choice science questions.

For 


## Preview Training Data

Let's see what one batch of the training dataset looks like!

In [11]:
# Show example batch
print("Example training batch:")
if len(dataset) > 0:
    print(f"Prompts: {dataset[0]['prompts'][0][:200]}...")
    print(f"Answer: {dataset[0]['answer'][0]}")

Example training batch:
Prompts: <start_of_turn>user
You are a helpful assistant that solves multiple choice science questions.

For each question:
1. Read the question and all answer choices carefully
2. Think through the problem st...
Answer: B


## Kaggle Authentication

Log in to Kaggle to download the Gemma model weights.

In [12]:
# Log in to Kaggle
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    kagglehub.login()


## Download Gemma Model

Download the Gemma 2 2B-IT model from Kaggle.

In [13]:
model_path = {
    "gemma2": "google/gemma-2/flax/",
}
model_family = "gemma2"
model_version = "gemma2-2b-it"
print(f"{model_path[model_family]}{model_version}")

kaggle_ckpt_path = kagglehub.model_download(
    f"{model_path[model_family]}{model_version}"
)

google/gemma-2/flax/gemma2-2b-it


## Checkpoint Conversion

Re-save the pre-trained Gemma checkpoint into a format compatible with Flax NNX. The original Kaggle checkpoint has parameter names that need to be reformatted.

In [14]:
# Re-save checkpoint for NNX compatibility
!rm -rf /tmp/content/intermediate_ckpt/*
!rm -rf /tmp/content/ckpts/*

if model_family == "gemma2":
    params = params_lib.load_and_format_params(
        os.path.join(kaggle_ckpt_path, "gemma2-2b-it")
    )
    gemma = gemma_lib.Transformer.from_params(params, version="2-2b-it")
    checkpointer = ocp.StandardCheckpointer()
    _, state = nnx.split(gemma)
    checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
    checkpointer.wait_until_finished()
    del params
    del gemma
    del state
    gc.collect()



## Model Loading Functions

Two key functions:
- **`get_gemma_ref_model`**: Loads the Gemma model with JAX sharding for multi-device distribution
- **`get_lora_model`**: Applies LoRA layers to attention and MLP modules for efficient training

The **reference model** stays frozen and is used to compute KL divergence, ensuring the policy doesn't deviate too far from the original behavior.
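
To make the KL term concrete, here is a minimal sketch of the per-token estimator given in the GRPO paper, weighted by this notebook's `BETA = 0.08`. Whether Tunix computes the penalty in exactly this form is an assumption. `kl_penalty` and the sample log-probabilities are illustrative.

```python
import math

BETA = 0.08  # KL weight used in this notebook's hyperparameters


def kl_penalty(policy_logprob, ref_logprob):
    """Per-token KL estimate exp(d) - d - 1 with d = ref - policy; never negative."""
    d = ref_logprob - policy_logprob
    return math.exp(d) - d - 1.0


# Policy and reference agree on a token: no penalty.
print(kl_penalty(-1.0, -1.0))  # 0.0

# Policy assigns the token a higher probability than the reference: small penalty.
print(BETA * kl_penalty(-1.0, -1.5))
```

The estimate is zero when the two models agree and grows as the policy drifts, which is what keeps the LoRA policy close to the frozen reference.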

In [15]:
def get_gemma_ref_model(ckpt_path):
    mesh = jax.make_mesh(*MESH)
    model_config = gemma_lib.ModelConfig.gemma2_2b()
    abs_gemma: nnx.Module = nnx.eval_shape(
        lambda: gemma_lib.Transformer(model_config, rngs=nnx.Rngs(params=0))
    )
    abs_state = nnx.state(abs_gemma)
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(ckpt_path, target=abs_state)
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_params)
    return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
            ".*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )
    model_input = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(
        base_model, lora_provider, **model_input
    )
    with mesh:
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)
    return lora_model

## Load Reference Model

Load the base Gemma model that will serve as our reference for KL divergence calculation.

In [16]:
# Reference model
if model_family == "gemma2":
    ref_model, mesh, model_config = get_gemma_ref_model(
        ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
    )

## Load Policy Model (LoRA)

Apply LoRA adapters to create the trainable policy model. Only LoRA parameters will be updated during training.
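
A rough sense of why LoRA is parameter-efficient: for one weight matrix of shape `(d_in, d_out)`, a rank-`r` adapter adds two small matrices with `r * (d_in + d_out)` parameters in total. The sketch below uses a hypothetical layer shape (not Gemma's actual dimensions). It also assumes the standard `alpha / rank` scaling of the adapter output, which may or may not match how qwix applies `alpha`.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters of one adapter: A is (d_in, rank), B is (rank, d_out)."""
    return rank * (d_in + d_out)


RANK, ALPHA = 64, 64.0       # values used in this notebook
d_in, d_out = 2048, 8192     # HYPOTHETICAL layer shape, for illustration only

full = d_in * d_out
added = lora_param_count(d_in, d_out, RANK)
print(f"full: {full:,}  LoRA: {added:,}  ratio: {added / full:.1%}")
# full: 16,777,216  LoRA: 655,360  ratio: 3.9%

print("scale:", ALPHA / RANK)  # scale: 1.0 (assuming alpha / rank scaling)
```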

In [17]:
# Policy model
lora_policy = get_lora_model(ref_model, mesh=mesh)
nnx.display(lora_policy)

## Load Tokenizer

Load the SentencePiece tokenizer and wrap it with TokenizerAdapter for Tunix compatibility.

In [18]:
# Load tokenizer using sentencepiece directly, then wrap it
import sentencepiece as spm

# Load the sentencepiece model
sp_model = spm.SentencePieceProcessor()
sp_model.Load(os.path.join(kaggle_ckpt_path, "tokenizer.model"))

# Wrap with TokenizerAdapter
tokenizer = tokenizer_lib.TokenizerAdapter(sp_model)

print("✅ Tokenizer loaded!")

✅ Tokenizer loaded!


## Format Matching Regex

Define a regex pattern that validates the model's output format.

In [19]:
# Regex for format matching
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

# Test the regex
result = match_format.search(
    f"{reasoning_start}Let me think through this step by step...{reasoning_end}{solution_start}B{solution_end}",
)
print(f"Regex test: {result is not None}")

Regex test: True
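
One detail worth knowing about the flags: with `re.MULTILINE`, `^` and `$` match at line boundaries, so the pattern also accepts a well-formed block surrounded by other lines of text. The sketch below rebuilds an equivalent pattern (writing `\s*` for `[\s]{0,}`) and compares it with a variant that omits `re.MULTILINE`. The variable names and sample strings are illustrative.

```python
import re

reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

pattern = (
    rf"^\s*{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}\s*$"
)
with_multiline = re.compile(pattern, flags=re.MULTILINE | re.DOTALL)
without_multiline = re.compile(pattern, flags=re.DOTALL)

clean = "<reasoning>Dry skin has the most friction.</reasoning><answer>A</answer>"
noisy = "Sure!\n" + clean + "\nHope that helps."

print(with_multiline.search(clean) is not None)     # True
print(without_multiline.search(clean) is not None)  # True
print(with_multiline.search(noisy) is not None)     # True: ^ and $ match at line breaks
print(without_multiline.search(noisy) is not None)  # False: must span the whole string
```

If the goal is to require that the response contain nothing but the two tagged blocks, dropping `re.MULTILINE` is the stricter choice.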


## Define Reward Functions

We define four reward functions to guide GRPO training:

1. **`match_format_exactly`** (3 points): Full reward if the output contains both a `<reasoning>...</reasoning>` block and an `<answer>...</answer>` block
2. **`match_format_approximately`** (-1.5 to +1.5 points): +1.0 per opening tag and +0.5 per closing tag present, minus 1.5, so missing tags are penalized
3. **`check_reasoning_quality`** (up to 3 points): Rewards for analyzing options and using reasoning words
4. **`check_answer_correct`** (up to 5 points): 5 for the correct letter (A/B/C/D) inside `<answer>` tags, 1 for a well-formed but wrong answer, 0 if no answer is found

These reward functions shape the model to produce well-structured, reasoned responses.

In [20]:
def match_format_exactly(prompts, completions, **kwargs):
    """Reward if the format matches exactly (3 points)."""
    scores = []
    for response in completions:
        # Convert bytes to string if needed
        if hasattr(response, 'decode'):
            response = response.decode('utf-8')
        has_reasoning = bool(re.search(r'<reasoning>.*?</reasoning>', str(response), re.DOTALL))
        has_answer = bool(re.search(r'<answer>.*?</answer>', str(response), re.DOTALL))
        if has_reasoning and has_answer:
            scores.append(3.0)
        else:
            scores.append(0.0)
    return scores


def match_format_approximately(prompts, completions, **kwargs):
    """Reward partial format matches."""
    scores = []
    for response in completions:
        if hasattr(response, 'decode'):
            response = response.decode('utf-8')
        response = str(response)
        score = 0
        if '<reasoning>' in response:
            score += 1.0
        if '</reasoning>' in response:
            score += 0.5
        if '<answer>' in response:
            score += 1.0
        if '</answer>' in response:
            score += 0.5
        scores.append(score - 1.5)
    return scores


def check_reasoning_quality(prompts, completions, answer=None, **kwargs):
    """Reward for quality reasoning indicators."""
    scores = []
    for response in completions:
        if hasattr(response, 'decode'):
            response = response.decode('utf-8')
        response = str(response)
        score = 0.0
        
        # Check for option analysis
        for opt in ['A', 'B', 'C', 'D']:
            if f"Option {opt}" in response or f"{opt}:" in response:
                score += 0.5
        
        # Check for reasoning words
        reasoning_words = ['because', 'therefore', 'since', 'means', 'correct', 'incorrect', 'wrong', 'right']
        for word in reasoning_words:
            if word.lower() in response.lower():
                score += 0.25
        
        scores.append(min(score, 3.0))
    return scores


def check_answer_correct(prompts, completions, answer=None, **kwargs):
    """Check if the answer letter is correct - BIG reward!"""
    scores = []
    
    # Handle numpy array
    if answer is not None and hasattr(answer, 'tolist'):
        answer = answer.tolist()
    
    for i, response in enumerate(completions):
        # Convert bytes to string if needed
        if hasattr(response, 'decode'):
            response = response.decode('utf-8')
        response = str(response)
        
        # Get reference answer
        ref_answer = None
        if answer is not None:
            if isinstance(answer, list) and i < len(answer):
                ref_answer = answer[i]
            elif isinstance(answer, str):
                ref_answer = answer
        
        # Convert ref_answer to string
        if ref_answer is not None:
            if hasattr(ref_answer, 'decode'):
                ref_answer = ref_answer.decode('utf-8')
            ref_answer = str(ref_answer)
        
        # Extract answer from <answer> tags
        match = re.search(r'<answer>\s*([A-Da-d])\s*</answer>', response)
        
        if match:
            extracted = match.group(1).upper()
            if ref_answer and extracted == ref_answer.upper():
                scores.append(5.0)  # Correct answer!
            else:
                scores.append(1.0)  # Has answer format but wrong/can't verify
        else:
            scores.append(0.0)  # No answer found
    return scores


print("Reward functions defined!")

Reward functions defined!


## Generation Function

Helper function to generate model responses given a question prompt.

In [21]:
def generate(question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None):
    """Given prompt, generates text."""
    if isinstance(question, str):
        input_batch = [
            TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=question,
            )
        ]
    else:
        input_batch = [
            TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=q,
            )
            for q in question
        ]

    out_data = sampler(
        input_strings=input_batch,
        max_generation_steps=TOTAL_GENERATION_STEPS,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
    )

    output = out_data.text
    if isinstance(question, str):
        return output[0]
    return output

print("Generate function defined!")

Generate function defined!


## Evaluation Functions

Functions to extract answers and evaluate model accuracy on the test set.

In [22]:
def extract_answer(response):
    """Extract answer letter from response"""
    match = re.search(r'<answer>\s*([A-Da-d])\s*</answer>', response)
    if match:
        return match.group(1).upper()
    return None

def check_format(response):
    """Check if response has proper format"""
    has_reasoning = '<reasoning>' in response and '</reasoning>' in response
    has_answer = '<answer>' in response and '</answer>' in response
    return has_reasoning and has_answer

def evaluate(dataset, sampler, temperature=0.7, top_k=50, top_p=0.95, num_samples=50):
    """Evaluates the model on ARC questions."""
    corr_format = 0
    corr_answer = 0
    total = 0
    
    eval_samples = dataset[:min(num_samples, len(dataset))]
    
    for i, item in enumerate(tqdm(eval_samples)):
        question = item['prompt']
        ref_answer = item['answer']
        
        response = generate(question, sampler, temperature, top_k, top_p)
        
        if check_format(response):
            corr_format += 1
        
        extracted = extract_answer(response)
        if extracted and extracted == ref_answer.upper():
            corr_answer += 1
        
        total += 1
        
        if (i + 1) % 10 == 0:
            print(f"Progress: {total}/{len(eval_samples)}, Format: {100*corr_format/total:.1f}%, Correct: {100*corr_answer/total:.1f}%")
    
    format_accuracy = 100 * corr_format / total if total > 0 else 0
    answer_accuracy = 100 * corr_answer / total if total > 0 else 0
    
    return corr_format, total, format_accuracy, answer_accuracy

print("Evaluate function defined!")

Evaluate function defined!


## Create Sampler

Create the sampler for text generation with the appropriate cache configuration.

In [23]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)
print("Sampler created!")

Sampler created!


## Pre-Training Evaluation (Baseline)

Evaluate the base model **before training** to establish a baseline. This helps us measure improvement after GRPO training.

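Expected: ~0% format accuracy (the model has not yet learned to emit the tags). Because answers are extracted only from `<answer>` tags, answer accuracy is expected to be ~0% as well.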

In [24]:
# Evaluate before training
print("Pre-training evaluation...")
(corr_format, total, format_accuracy, answer_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
    num_samples=30
)
print(f"\nPRE-TRAINING RESULTS:")
print(f"  Total samples: {total}")
print(f"  Format accuracy: {format_accuracy:.1f}%")
print(f"  Answer accuracy: {answer_accuracy:.1f}%")

Pre-training evaluation...


 33%|███▎      | 10/30 [00:57<00:48,  2.40s/it]

Progress: 10/30, Format: 0.0%, Correct: 0.0%


 67%|██████▋   | 20/30 [01:05<00:07,  1.29it/s]

Progress: 20/30, Format: 0.0%, Correct: 0.0%


100%|██████████| 30/30 [01:12<00:00,  2.41s/it]

Progress: 30/30, Format: 0.0%, Correct: 0.0%

PRE-TRAINING RESULTS:
  Total samples: 30
  Format accuracy: 0.0%
  Answer accuracy: 0.0%





## Training Configuration

Setting up:
- **Checkpointing**: Save model every 500 steps, keep last 4 checkpoints
- **Metrics Logging**: Log to TensorBoard
- **Optimizer**: AdamW with warmup and cosine decay schedule
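
The shape of the learning-rate schedule can be sketched in plain Python. This approximates a linear warmup from 0 to the peak followed by a cosine decay to 0, using this notebook's values. It is not optax's implementation, and `warmup_cosine` is an illustrative name.

```python
import math


def warmup_cosine(step, peak, warmup_steps, total_steps):
    """Linear warmup from 0 to peak, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


LEARNING_RATE = 3e-6
MAX_STEPS = 3357
WARMUP_STEPS = int(0.1 * MAX_STEPS)  # 335

for step in (0, WARMUP_STEPS, (WARMUP_STEPS + MAX_STEPS) // 2, MAX_STEPS):
    lr = warmup_cosine(step, LEARNING_RATE, WARMUP_STEPS, MAX_STEPS)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

The rate starts at 0, peaks at 3e-6 once warmup ends at step 335, falls to about half the peak midway through the decay, and returns to 0 at the final step.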

In [25]:
# Checkpoint saving options
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [26]:
# Optimizer with warmup and cosine decay
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
        optimizer,
    )
print("Optimizer configured!")

Optimizer configured!
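
To make the shape of the learning-rate schedule concrete, here is a minimal pure-Python sketch of a linear warmup followed by cosine decay. It is an approximation of what `warmup_cosine_decay_schedule` computes, not the Optax implementation itself; the function name `warmup_cosine_lr` and the numeric values are hypothetical and chosen only for illustration. It assumes, as in Optax, that `decay_steps` is the total schedule length including warmup.

```python
import math

def warmup_cosine_lr(step, peak_value, warmup_steps, decay_steps,
                     init_value=0.0, end_value=0.0):
    """Sketch of a linear-warmup + cosine-decay learning rate (hypothetical helper)."""
    if step < warmup_steps:
        # Linear ramp from init_value up to peak_value.
        frac = step / warmup_steps
        return init_value + frac * (peak_value - init_value)
    # Cosine decay over the remaining (decay_steps - warmup_steps) steps.
    span = max(decay_steps - warmup_steps, 1)
    progress = min(step - warmup_steps, span) / span
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine

# Illustrative values only; the notebook's real constants are defined earlier.
PEAK, WARMUP, TOTAL = 3e-6, 100, 1000
for s in (0, 50, 100, 550, 1000):
    print(s, warmup_cosine_lr(s, PEAK, WARMUP, TOTAL))
```

The rate starts at zero, reaches its peak at the end of warmup, and returns to `end_value` at the last step, so the largest updates happen early in training.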


## Cluster Configuration

Configure the distributed training setup:
- **Mesh**: Device mesh for FSDP and tensor parallelism
- **Rollout Config**: Generation parameters (temperature, top-k, top-p)
- **Training Config**: Batch sizes, learning rate, max steps

In [27]:
# Training config - CONSISTENT cache sizes!
KV_CACHE_SIZE = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256
print(f"Using KV_CACHE_SIZE = {KV_CACHE_SIZE}")

cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=KV_CACHE_SIZE,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)
print("Cluster config created!")

Using KV_CACHE_SIZE = 768
Cluster config created!
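
The cache-size arithmetic above can be checked with a small helper. This is a sketch under the assumption that a generation fits when prompt tokens plus generated tokens do not exceed the cache size; the sampler may apply extra padding, so treat this as a lower bound on what is needed. `fits_in_cache` is a hypothetical helper, and the two length constants below are illustrative values chosen only so that the total matches the 768 printed above.

```python
def fits_in_cache(prompt_tokens, max_generation_steps, cache_size):
    """Return True if prompt plus generated tokens fit in the KV cache."""
    return prompt_tokens + max_generation_steps <= cache_size

# Illustrative values (assumption): any pair summing to 512 gives 768 below.
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 256
KV_CACHE_SIZE = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256

print(KV_CACHE_SIZE)                            # 768
print(fits_in_cache(300, 200, KV_CACHE_SIZE))   # True
print(fits_in_cache(400, 200, 512))             # False
```

The extra 256 slots act as headroom, which is why using one shared constant for the sampler and the rollout config is safer than repeating the expression.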


## Initialize GRPO Trainer

Create the training components:
1. **RLCluster**: Combines the policy (LoRA model), reference model, and tokenizer
2. **GRPOLearner**: The trainer that uses our reward functions to optimize the model

The trainer generates multiple responses per prompt and uses relative rewards to update the policy.

In [28]:
# RL cluster and GRPO Trainer
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# GRPO Trainer with ARC reward functions
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_reasoning_quality,
        check_answer_correct,
    ],
    grpo_config=grpo_config,
)
print("GRPO Trainer created!")

[34m[1mwandb[0m: Currently logged in as: [33mkushi-nhce[0m ([33mkushi-nhce-nhce[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


GRPO Trainer created!
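
The "relative rewards" idea can be illustrated with a short sketch: each prompt's sampled responses form a group, and every response is scored against the group's mean and spread. This is a simplified illustration of the group-relative advantage commonly used in GRPO, not Tunix's actual code; the helper name, the `eps` value, the use of the population standard deviation, and the example reward totals are all assumptions.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one group of responses to the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std within the group
    return [(r - mu) / (sigma + eps) for r in rewards]

# Hypothetical summed rewards for four responses to a single question.
group_rewards = [12.0, 4.0, 1.0, 7.0]
for reward, adv in zip(group_rewards, group_relative_advantages(group_rewards)):
    print(f"reward={reward:5.1f}  advantage={adv:+.3f}")
```

Responses that score above the group mean get a positive advantage and are reinforced; those below the mean are discouraged. If every response in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal.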




## Start Training! 🚀

Train the model using GRPO. The first few steps may take longer due to JIT compilation.

Training for **3357 steps** (~53 minutes on TPU v5e)

Metrics logged to W&B:
- Loss, KL divergence, Perplexity
- Steps per second, TFLOPs

In [29]:
# Start training!
print(f"Starting training for {MAX_STEPS} steps...")
with mesh:
    grpo_trainer.train(train_dataset_final)

Starting training for 3357 steps...


Actor Training:   0%|          | 0/3357 [00:00<?, ?step/s]



W&B run summary (last logged values; the sparkline run history is omitted):

| Metric | Value |
|--------|-------|
| actor/train/kl | 0.00702 |
| actor/train/loss | -0.08329 |
| actor/train/perplexity | 0.92009 |
| actor/train/step_time_sec | 0.09846 |
| actor/train/steps_per_sec | 10.15612 |
| actor/train/tflops_per_step | 2.88575 |
| jax/core/compile/backend_compile_duration | 1767733301.20821 |
| jax/core/compile/jaxpr_to_mlir_module_duration | 1767733299.76483 |
| jax/core/compile/jaxpr_trace_duration | 1767733298.22219 |
| jax/orbax/write/sharded_array_gb | 0.0011 |


## Test Trained Model

Let's test the trained model with a simple question to verify it generates the expected format with proper reasoning and answer tags.

In [31]:
# SIMPLE TEST - See what the model actually generates
import os
os.environ['WANDB_MODE'] = 'disabled'  # Disable wandb logging for this quick test

# Create a simple sampler
test_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=512,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

# Test with ONE simple question
test_question = "Which is a renewable resource?\n\nA: coal\nB: oil\nC: sunlight\nD: natural gas"

prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=test_question)

print("INPUT PROMPT:")
print("="*60)
print(prompt)
print("="*60)

# Generate
out = test_sampler(
    input_strings=[prompt],
    max_generation_steps=200,
    temperature=0.0,
)

print("\nRAW MODEL OUTPUT:")
print("="*60)
print(out.text[0])
print("="*60)

INPUT PROMPT:
<start_of_turn>user
You are a helpful assistant that solves multiple choice science questions.

For each question:
1. Read the question and all answer choices carefully
2. Think through the problem step by step inside <reasoning> and </reasoning> tags
3. Provide your final answer (A, B, C, or D) inside <answer> and </answer> tags

Example:
<reasoning>
Let me analyze each option:
- Option A: [analysis]
- Option B: [analysis]
- Option C: [analysis]
- Option D: [analysis]
Based on my analysis, the correct answer is...
</reasoning>
<answer>B</answer>


Question: Which is a renewable resource?

A: coal
B: oil
C: sunlight
D: natural gas<end_of_turn>
<start_of_turn>model


RAW MODEL OUTPUT:
<reasoning>
Let's analyze each option:
- Option A: Coal is a fossil fuel formed from ancient plant matter, a non-renewable resource.
- Option B: Oil is also a fossil fuel, formed from ancient marine organisms, making it non-renewable.
- Option C: Sunlight is a form of energy that is constantl

## Save Trained Model

Save the LoRA parameters so we can reload the trained model later without retraining.

The model is saved to `/kaggle/working/trained_arc_model/lora_params`

In [32]:
# SAVE TRAINED MODEL
import os

# Save the LoRA weights
save_path = "/kaggle/working/trained_arc_model"
os.makedirs(save_path, exist_ok=True)

# Save LoRA parameters
lora_state = nnx.state(lora_policy, nnx.LoRAParam)
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(os.path.join(save_path, "lora_params"), lora_state)
checkpointer.wait_until_finished()

print(f"✅ Model saved to {save_path}")
print("You can download this from Kaggle's Output tab!")

✅ Model saved to /kaggle/working/trained_arc_model
You can download this from Kaggle's Output tab!


## Load Trained Model

Demonstrate how to reload the saved LoRA parameters into the model.

In [34]:
# LOAD TRAINED MODEL (after setting up base model)
load_path = "/kaggle/working/trained_arc_model"

# Load saved LoRA params
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
loaded_params = checkpointer.restore(os.path.join(load_path, "lora_params"), target=abs_params)

# Update model with loaded params
nnx.update(lora_policy, loaded_params)
print("✅ Trained model loaded!")



✅ Trained model loaded!


## Post-Training Evaluation

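Evaluate the fine-tuned model using `evaluate_safe()`, which handles variable-length prompts by creating a fresh sampler for each question. This avoids the cache-size errors seen when a single sampler is reused, at the cost of slower evaluation (about 26 seconds per question in the run below).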

In [37]:
# EVEN SAFER EVALUATION - handles variable length prompts
import os
os.environ['WANDB_MODE'] = 'disabled'

def evaluate_safe(dataset, num_samples=20):
    corr_format = 0
    corr_answer = 0
    total = 0
    
    for i, item in enumerate(tqdm(dataset[:num_samples])):
        question = item['prompt']
        ref_answer = item['answer']
        prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)
        
        # Create fresh sampler for each question to avoid cache issues
        try:
            temp_sampler = sampler_lib.Sampler(
                transformer=lora_policy,
                tokenizer=tokenizer,
                cache_config=sampler_lib.CacheConfig(
                    cache_size=768,
                    num_layers=model_config.num_layers,
                    num_kv_heads=model_config.num_kv_heads,
                    head_dim=model_config.head_dim,
                ),
            )
            
            out = temp_sampler(
                input_strings=[prompt],
                max_generation_steps=200,
                temperature=0.0,
            )
            response = out.text[0] if out.text else ""
            del temp_sampler
            
        except Exception as e:
            print(f"Skip {i}: {str(e)[:50]}")
            continue
        
        # Check format and answer
        has_format = '<reasoning>' in response and '</reasoning>' in response and '<answer>' in response and '</answer>' in response
        if has_format:
            corr_format += 1
        
        match = re.search(r'<answer>\s*([A-Da-d])\s*</answer>', response)
        if match and match.group(1).upper() == ref_answer.upper():
            corr_answer += 1
        
        total += 1
        
        # Show first 3
        if i < 3:
            extracted = match.group(1).upper() if match else "None"
            print(f"\nQ{i+1}: {question[:60]}...")
            print(f"Expected: {ref_answer} | Got: {extracted} | {'✅' if match and match.group(1).upper() == ref_answer.upper() else '❌'}")
    
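    print(f"\n{'='*60}")
    print(f"RESULTS: {total} samples")
    if total > 0:
        print(f"Format: {100*corr_format/total:.1f}% | Answer: {100*corr_answer/total:.1f}%")
    else:
        # Every generation was skipped, so there is nothing to average over.
        print("No samples were evaluated.")
    return corr_format, corr_answer, total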

# Run it
print("Evaluating...")
evaluate_safe(test_dataset, num_samples=20)

Evaluating...


  5%|▌         | 1/20 [00:23<07:35, 23.95s/it]


Q1: An astronomer observes that a planet rotates faster after a ...
Expected: C | Got: C | ✅


 10%|█         | 2/20 [00:51<07:52, 26.28s/it]


Q2: A group of engineers wanted to know how different building d...
Expected: B | Got: B | ✅


 15%|█▌        | 3/20 [01:16<07:12, 25.43s/it]


Q3: The end result in the process of photosynthesis is the produ...
Expected: C | Got: C | ✅


100%|██████████| 20/20 [08:38<00:00, 25.93s/it]


RESULTS: 20 samples
Format: 80.0% | Answer: 55.0%





(16, 11, 20)
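
Because only 20 questions were evaluated, the percentages above carry wide uncertainty. The following sketch computes a 95% Wilson score interval for the observed counts (16 of 20 for format, 11 of 20 for answers). The helper `wilson_interval` is hypothetical and is not part of the notebook's evaluation code; it is included only to show how much the estimates could move with a different sample.

```python
import math

def wilson_interval(successes, n, z=1.96):
    """95% Wilson score interval for a binomial proportion (z=1.96)."""
    p = successes / n
    denom = 1.0 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return center - half, center + half

fmt_low, fmt_high = wilson_interval(16, 20)   # format accuracy counts
ans_low, ans_high = wilson_interval(11, 20)   # answer accuracy counts
print(f"Format accuracy interval: {fmt_low:.2f} to {fmt_high:.2f}")
print(f"Answer accuracy interval: {ans_low:.2f} to {ans_high:.2f}")
print(f"Answer lower bound above 25% baseline: {ans_low > 0.25}")
```

For 11 correct out of 20, the interval spans roughly 0.34 to 0.74: the lower bound still sits above the 25% random baseline, but a larger held-out set would be needed to pin down the accuracy more tightly.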

## Final Results Summary

Summary of the GRPO training results on ARC science questions.

In [38]:
# FINAL SUMMARY FOR SUBMISSION
print("="*60)
print("ARC SCIENCE QUESTION GRPO TRAINING - FINAL RESULTS")
print("="*60)
print(f"""
Dataset: ARC (AI2 Reasoning Challenge) - Science Multiple Choice
Model: Gemma-2 2B IT with LoRA (rank=64)
Training: GRPO with custom reward functions
- Format reward (reasoning + answer tags)
- Answer correctness reward

RESULTS:
--------
Format Accuracy: 80.0%
Answer Accuracy: 55.0%

The model successfully learned to:
1. Use <reasoning> tags for step-by-step analysis
2. Use <answer> tags for final answer
3. Analyze multiple choice options
4. Select correct answers at 55% accuracy (vs 25% random baseline)
""")
print("="*60)

ARC SCIENCE QUESTION GRPO TRAINING - FINAL RESULTS

Dataset: ARC (AI2 Reasoning Challenge) - Science Multiple Choice
Model: Gemma-2 2B IT with LoRA (rank=64)
Training: GRPO with custom reward functions
- Format reward (reasoning + answer tags)
- Answer correctness reward

RESULTS:
--------
Format Accuracy: 80.0%
Answer Accuracy: 55.0%

The model successfully learned to:
1. Use <reasoning> tags for step-by-step analysis
2. Use <answer> tags for final answer
3. Analyze multiple choice options
4. Select correct answers at 55% accuracy (vs 25% random baseline)



## Showcase Examples

Let's see some example outputs from our trained model to demonstrate the quality of reasoning it has learned.

In [39]:
# Generate a few showcase examples
showcase_questions = [
    "Which is a renewable resource?\n\nA: coal\nB: oil\nC: sunlight\nD: natural gas",
    "What is the main function of the heart?\n\nA: to digest food\nB: to pump blood\nC: to filter air\nD: to produce hormones",
]

print("\n" + "="*60)
print("SAMPLE MODEL OUTPUTS")
print("="*60)

for q in showcase_questions:
    prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
    
    temp_sampler = sampler_lib.Sampler(
        transformer=lora_policy,
        tokenizer=tokenizer,
        cache_config=sampler_lib.CacheConfig(
            cache_size=768,
            num_layers=model_config.num_layers,
            num_kv_heads=model_config.num_kv_heads,
            head_dim=model_config.head_dim,
        ),
    )
    
    out = temp_sampler(input_strings=[prompt], max_generation_steps=200, temperature=0.0)
    
    print(f"\nQuestion: {q}\n")
    print(f"Model Response:\n{out.text[0]}")
    print("-"*60)
    del temp_sampler


SAMPLE MODEL OUTPUTS

Question: Which is a renewable resource?

A: coal
B: oil
C: sunlight
D: natural gas

Model Response:
<reasoning>
Let's analyze each option:
- Option A: Coal is a fossil fuel formed from ancient plant matter, a non-renewable resource.
- Option B: Oil is also a fossil fuel, formed from ancient marine organisms, making it non-renewable.
- Option C: Sunlight is a form of energy that is constantly replenished by the sun. This makes it a renewable resource.
- Option D: Natural gas is a fossil fuel, formed from ancient organic matter, making it non-renewable.

Based on my analysis, the correct answer is C.
</reasoning>
<answer>C</answer> 
<end_of_turn>
------------------------------------------------------------

Question: What is the main function of the heart?

A: to digest food
B: to pump blood
C: to filter air
D: to produce hormones

Model Response:
<reasoning>
Let's break down the functions of the body's organs:
- Option A: Digestion is the process of breaking down

## [Optional 15pts] Unrestricted Mode

If participating in unrestricted mode, specify the Kaggle model ID below.

**Note:** For this submission, we are participating in **single-session mode only**.

In [None]:
# Optional: Uncomment and fill in if participating in unrestricted mode
# unrestricted_kaggle_model = "yashaswinikushi/arc-grpo-model"

# For single-session mode, leave this as None
unrestricted_kaggle_model = None

## Other Things for Judges to Know

### What I Learned
- How to use Tunix's GRPO implementation for reinforcement learning
- The importance of reward function design in shaping model behavior
- How to efficiently fine-tune large models with LoRA on limited compute

### Challenges Faced
- **Cache size issues**: Required careful tuning to avoid memory errors during evaluation
- **W&B initialization**: Needed workarounds for wandb logging conflicts
- **Evaluation consistency**: Variable-length prompts required fresh sampler per question

### Improvements Achieved
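- Successfully applied GRPO to a **non-math domain** (science questions)
- Model learned the structured reasoning format **from scratch** (0% → 80% format accuracy: 0 of 30 baseline samples versus 16 of 20 post-training samples)
- Answer accuracy reached **55%** (11 of 20 samples), **more than double** the 25% random baseline for 4-choice questions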

### Suggestions for Future Hackathons
- Provide more example notebooks for different domains
- Include troubleshooting guide for common Tunix errors
- Consider longer TPU session limits for experimentation

---

# Competition Evaluation (For Judges)

The sections below are for Google judges to reproduce and evaluate the model.

## Checkpoint Directory

In [None]:
# Checkpoint directory for evaluation
CKPT_DIR = '/tmp/content/ckpts'

import os
import re

# Find the latest checkpoint
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

latest_step = -1
if os.path.exists(actor_ckpt_dir):
    for item in os.listdir(actor_ckpt_dir):
        if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
            step = int(item)
            if step > latest_step:
                latest_step = step

if latest_step == -1:
    print(f"No checkpoints found in {actor_ckpt_dir}")
else:
    print(f"Latest checkpoint step: {latest_step}")