# DSA-CAST + Tunix GRPO on Gemma3-1B (TPU, Kaggle)

This notebook:

1. Sets up **Gemma3-1B-IT** on a Kaggle TPU using **Tunix**.
2. Uses the `<reasoning> ... </reasoning>` and `<answer> ... </answer>` format for math problems (GSM8K-style).
3. Defines a **CAST-style reward** that strongly favors:
   - mathematical accuracy, and  
   - answer completeness & proper tagging.
4. Runs a **Tunix GRPO** reinforcement learning loop using that reward.
5. Saves the final **Tunix checkpoint (no safetensors export)** so it can be re-used in another notebook.

In [None]:
# === Environment setup: JAX TPU + Tunix + Gemma ===
import os, random, numpy as np

# Use TPU memory efficiently.
# setdefault() leaves any JAX_PLATFORMS value already present in the environment untouched.
os.environ.setdefault("JAX_PLATFORMS", "tpu,cpu")
# Unlike the line above, this unconditionally overwrites any existing setting.
# NOTE(review): presumably has to be set before JAX initialises its backend — confirm.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

# Install JAX (TPU build), Tunix, Gemma, and helpers
!pip install -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q tunix gemma qwix datasets humanize tensorflow_datasets kagglehub

print("Environment installs complete.")

In [None]:
# === Imports & global configuration ===
import functools
import json
import re
import shutil
import sys
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import humanize
import grain
import tensorflow_datasets as tfds
import kagglehub
from orbax import checkpoint as ocp

from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_tunix_model
from tunix.models.gemma3 import params_safetensors as gemma3_params_sft
from tunix.models.gemma3 import params as gemma3_params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOLearner, GRPOConfig
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger

print("Imported core libraries.")

# Random seeds: one shared SEED feeds NumPy (both the Generator API and the
# legacy global state), Python's `random`, and a JAX PRNG key.
SEED = 42
# NOTE(review): `rng` and `jax_key` are not referenced again in the visible
# notebook cells — confirm whether they are still needed.
rng = np.random.default_rng(SEED)
jax_key = jax.random.PRNGKey(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Special tags and system prompt.
# The same tag strings are interpolated into SYSTEM_PROMPT below and used by
# the reward functions, so the prompt and the scoring stay in sync.
REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
ANSWER_START = "<answer>"
ANSWER_END = "</answer>"

# Instruction text placed at the top of every user turn (see TEMPLATE).
SYSTEM_PROMPT = f"""You are a careful math tutor.

For each question:
- First, think through the problem step by step.
- Put all of your step-by-step reasoning strictly between {REASONING_START} and {REASONING_END}.
- Then, put the final numeric answer (only the number) strictly between {ANSWER_START} and {ANSWER_END}.

You MUST include both blocks, in this order.
""".strip()

# Chat template filled via str.format(system_prompt=..., question=...).
# The system prompt is embedded inside the user turn, and the template ends
# with an open model turn so generation continues from there.
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model
"""

In [None]:
# === Hyperparameters ===

# Hugging Face repo id used both for the weight download and the tokenizer.
MODEL_ID = "google/gemma-3-1b-it"

# Separate TFDS data directories for the two GSM8K splits.
TRAIN_DATA_DIR = "./data/gsm8k_train"
TEST_DATA_DIR = "./data/gsm8k_test"

# Pick the mesh shape from the number of visible devices; only 8 and 1 are handled.
# NOTE(review): the (1, 4) mesh has 4 slots while 8 devices are required in
# that branch — confirm that using half of the devices is intended.
NUM_TPUS = len(jax.devices())
if NUM_TPUS == 8:
    MESH_COUNTS = (1, 4)
elif NUM_TPUS == 1:
    MESH_COUNTS = (1, 1)
else:
    raise ValueError(f"Unsupported number of TPU devices: {NUM_TPUS}")

# [axis sizes, axis names]; later unpacked positionally into jax.make_mesh.
MESH = [MESH_COUNTS, ("fsdp", "tp")]

# Rollout / sampling settings used during GRPO training.
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 384
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
NUM_GENERATIONS = 2
NUM_ITERATIONS = 1

# Data budget: NUM_BATCHES batches are taken from the training split and
# TRAIN_FRACTION of them are used for training.
TRAIN_MICRO_BATCH_SIZE = 1
NUM_BATCHES = 256
TRAIN_FRACTION = 0.9
NUM_EPOCHS = 1

# Total optimizer steps; int() truncates the fractional product.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# AdamW and learning-rate schedule settings.
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# Warm up over the first 10% of the training steps.
WARMUP_STEPS = int(0.1 * MAX_STEPS)
# Global-norm gradient clipping threshold; the optimizer cell skips clipping when this is None.
MAX_GRAD_NORM = 0.1

# Checkpointing.
CKPT_DIR = "/kaggle/working/grpo_ckpts"
SAVE_INTERVAL_STEPS = 200
MAX_TO_KEEP = 4

# Named sampling presets for inference; evaluate_dataset uses the "greedy" preset.
GENERATION_CONFIGS = {
    "greedy":   {"temperature": None, "top_k": 1,   "top_p": None},
    "standard": {"temperature": 0.7,  "top_k": 50,  "top_p": 0.95},
    "liberal":  {"temperature": 0.85, "top_k": 2000,"top_p": 1.0},
}

print("Hyperparameters set. MAX_STEPS =", MAX_STEPS)

In [None]:
# === Data preprocessing: GSM8K via TFDS ===

def extract_hash_answer(text: str) -> Optional[str]:
    """Return the stripped text after the first ``####`` marker.

    Args:
        text: A GSM8K-style solution string.

    Returns:
        Everything following the first ``####`` with surrounding whitespace
        removed, or ``None`` when the marker does not occur in ``text``.
    """
    _, marker, tail = text.partition("####")
    # partition() yields an empty separator when the marker is missing.
    if not marker:
        return None
    return tail.strip()

def _load_gsm8k_tfds(data_dir: str, split: str):
    """Open one GSM8K split as a TFDS data source stored in ArrayRecord format.

    Args:
        data_dir: Directory TFDS uses for this dataset (download enabled).
        split: Name of the split to open, e.g. ``"train"`` or ``"test"``.

    Returns:
        Whatever ``tfds.data_source`` returns for the requested split.
    """
    # Imported only for its side effect.
    # NOTE(review): presumably this registers the gsm8k builder with TFDS — confirm.
    import tensorflow_datasets.text.gsm8k

    array_record_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
    source = tfds.data_source(
        "gsm8k",
        split=split,
        data_dir=data_dir,
        builder_kwargs=array_record_kwargs,
        download=True,
    )
    return source

def get_gsm8k_dataset(data_dir: str, split: str = "train") -> grain.MapDataset:
    """Build a shuffled Grain dataset of prompt/question/answer records.

    Args:
        data_dir: Directory for the TFDS data; created if it does not exist.
        split: GSM8K split to load.

    Returns:
        A ``grain.MapDataset`` whose elements are dicts with the keys
        ``"prompts"`` (TEMPLATE filled with the system prompt and question),
        ``"question"`` (raw question text) and ``"answer"`` (text after the
        ``####`` marker, or ``None`` when the marker is absent).
    """
    os.makedirs(data_dir, exist_ok=True)

    def _to_str(value):
        # Records may carry bytes; decode those and pass real strings through.
        if isinstance(value, str):
            return value
        return value.decode("utf-8")

    def _to_example(record):
        question_text = _to_str(record["question"])
        return {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=question_text,
            ),
            "question": question_text,
            "answer": extract_hash_answer(_to_str(record["answer"])),
        }

    source = _load_gsm8k_tfds(data_dir, split)
    # Shuffle with the notebook-wide seed before mapping to prompt records.
    return grain.MapDataset.source(source).shuffle(seed=SEED).map(_to_example)

# Load both GSM8K splits as shuffled prompt/question/answer datasets.
train_raw = get_gsm8k_dataset(TRAIN_DATA_DIR, split="train")
test_raw = get_gsm8k_dataset(TEST_DATA_DIR, split="test")

# Keep the batched pool under its own name. The previous code overwrote
# `train_dataset` with the truncated training slice and then sliced that
# again for validation, so the validation set was always empty/None (or, with
# NUM_EPOCHS > 1, was cut from repeated *training* batches).
train_pool = train_raw.batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION == 1.0:
    # Everything is used for training; no validation split.
    train_dataset = train_pool.repeat(NUM_EPOCHS)
    val_dataset = None
else:
    # The first `cutoff` batches train, the remainder validates; both slices
    # come from the same un-repeated pool so they never overlap.
    cutoff = int(len(train_pool) * TRAIN_FRACTION)
    train_dataset = train_pool[:cutoff].repeat(NUM_EPOCHS)
    val_dataset = (
        train_pool[cutoff:].repeat(NUM_EPOCHS) if cutoff < len(train_pool) else None
    )

# Small fixed test subset used for the before/after evaluation.
NUM_TEST_BATCHES = 64
test_dataset = test_raw.batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_TEST_BATCHES]

print("Dataset sizes (batches):",
      len(train_dataset),
      0 if val_dataset is None else len(val_dataset),
      len(test_dataset))

In [None]:
# === Utility: TPU memory usage ===
def show_hbm_usage():
    """Print the memory in use on every local JAX device.

    For each device returned by ``jax.local_devices()`` this prints the bytes
    in use, the byte limit and the usage percentage. Devices whose backend
    reports no memory statistics (``memory_stats()`` returning ``None`` or an
    empty mapping, e.g. when falling back to CPU) are reported as unavailable
    instead of raising, and a missing or zero limit no longer causes a
    division by zero.
    """
    fmt = functools.partial(humanize.naturalsize, binary=True)
    for d in jax.local_devices():
        stats = d.memory_stats()
        if not stats:
            print(f"Memory stats unavailable on {d}")
            continue
        used = stats.get("bytes_in_use", 0)
        limit = stats.get("bytes_limit")
        if not limit:
            # No usable limit: show the absolute figure only.
            print(f"Using {fmt(used)} on {d} (limit unavailable)")
            continue
        print(f"Using {fmt(used)} / {fmt(limit)} ({used/limit:%}) on {d}")

In [None]:
# === Load Gemma3-1B-IT with Tunix model wrappers ===
from huggingface_hub import snapshot_download

# Fetch the model repository snapshot; files matching "*.pth" are excluded
# from the download via ignore_patterns.
print(f"Downloading {MODEL_ID} from Hugging Face (you must have access)...")
local_model_path = snapshot_download(repo_id=MODEL_ID, ignore_patterns=["*.pth"])
print("Local model path:", local_model_path)

# Tunix model configuration preset for the 1B Gemma3 variant.
model_config = gemma3_tunix_model.ModelConfig.gemma3_1b()

# MESH is [axis sizes, axis names]; one Auto axis type is supplied per mesh axis.
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))

# Build the base model from the downloaded safetensors files under the mesh.
with mesh:
    base_gemma = gemma3_params_sft.create_model_from_safe_tensors(
        local_model_path,
        model_config,
        mesh,
    )
    nnx.display(base_gemma)

import qwix

# LoRA hyperparameters read by get_lora_model below.
RANK = 64
ALPHA = 64.0

def get_lora_model(base_model, mesh):
    """Wrap ``base_model`` with LoRA adapters and constrain its state sharding.

    Args:
        base_model: Model exposing ``get_model_input()``; its result is passed
            as keyword arguments to ``qwix.apply_lora_to_model``.
        mesh: Mesh used as the context for the sharding constraint.

    Returns:
        The LoRA-wrapped model, updated in place with the constrained state.
    """
    # Regex selecting the module paths that receive adapters.
    target_modules = (
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
    )
    provider = qwix.LoraProvider(
        module_path=target_modules,
        rank=RANK,
        alpha=ALPHA,
    )
    example_inputs = base_model.get_model_input()
    adapted = qwix.apply_lora_to_model(base_model, provider, **example_inputs)

    with mesh:
        # Constrain the state to the partition specs derived from that same
        # state, then write it back into the model.
        current_state = nnx.state(adapted)
        constrained_state = jax.lax.with_sharding_constraint(
            current_state, nnx.get_partition_spec(current_state)
        )
        nnx.update(adapted, constrained_state)
    return adapted

# Create the trainable LoRA policy on top of the loaded base model.
with mesh:
    lora_policy = get_lora_model(base_gemma, mesh=mesh)
    nnx.display(lora_policy)

tokenizer = tokenizer_lib.Tokenizer.from_hf_hub(model_id=MODEL_ID)
# Collect every end-of-sequence id: first those listed in the model's
# generation_config.json (if the file exists), then the tokenizer's own id.
EOS_TOKENS = []
gen_cfg_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(gen_cfg_path):
    with open(gen_cfg_path, "r") as f:
        gen_cfg = json.load(f)
    eos_ids = gen_cfg.get("eos_token_id", [])
    # The config value may be a single int or a list; normalise to a list.
    # NOTE(review): a JSON null here would reach extend() as None — confirm
    # the config never stores null for this key.
    if isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    EOS_TOKENS.extend(eos_ids)

# Avoid a duplicate entry when the tokenizer's id is already listed.
if tokenizer.eos_id() not in EOS_TOKENS:
    EOS_TOKENS.append(tokenizer.eos_id())

print("EOS token IDs:", EOS_TOKENS)

In [None]:
# === CAST-style helpers ===

def extract_final_number(text: Optional[str]) -> Optional[str]:
    """Extract the final numeric answer from a completion or reference string.

    Lookup order:
      1. If an ``<answer>...</answer>`` block exists (case-insensitive), only
         its contents are searched; otherwise the whole text is searched.
      2. A number following a ``####`` marker wins.
      3. Otherwise the last number-like token is used.

    Thousands separators (commas) are removed, and trailing ``.`` or ``/``
    characters are dropped so that sentence punctuation captured by the
    number pattern (``"42."``, ``"3/4."``) does not become part of the answer.

    Args:
        text: The string to search; ``None`` is accepted.

    Returns:
        The number as a string (may contain ``.`` or ``/``), or ``None`` when
        ``text`` is ``None`` or contains no digits.
    """
    if text is None:
        return None

    tagged = re.search(r"<answer>(.*?)</answer>", text, flags=re.IGNORECASE | re.DOTALL)
    segment = tagged.group(1) if tagged else text

    marked = re.search(r"####\s*([-+]?[0-9][0-9.,/]*)", segment)
    if marked:
        candidate = marked.group(1)
    else:
        nums = re.findall(r"[-+]?[0-9][0-9.,/]*", segment)
        if not nums:
            return None
        candidate = nums[-1]

    # The pattern greedily accepts '.', ',' and '/', so a number that ends a
    # sentence or clause carries that punctuation with it; strip it here.
    return candidate.replace(",", "").rstrip("./")


def _same_number(guess: Optional[str], truth: Optional[str]) -> bool:
    """Return True when two extracted number strings denote the same value."""
    if guess is None or truth is None:
        return False
    if guess == truth:
        return True
    # Fall back to a numeric comparison so "18.0" / "18.00" / "+18" match "18".
    # Strings float() cannot parse (e.g. fractions) only match textually.
    try:
        return float(guess) == float(truth)
    except ValueError:
        return False


def cast_style_scores(
    completions: List[str],
    answers: List[Optional[str]],
) -> Tuple[List[float], List[float], List[float]]:
    """Score completions on math accuracy, completeness and tag usage.

    Args:
        completions: Model outputs; falsy entries are treated as ``""``.
        answers: Reference answers aligned with ``completions``; falsy entries
            are treated as ``""``.

    Returns:
        Three lists, one entry per (completion, answer) pair:
          * math accuracy: 1.0 when the number extracted from the completion
            equals the number extracted from the reference, else 0.0.
            Equality is numeric where both parse as floats, so formatting
            differences such as ``18`` vs ``18.00`` do not count as errors.
          * completeness: ``min(1.0, chars / 200)`` where ``chars`` is the
            stripped length of the reasoning block plus the answer block;
            0.0 unless both opening tags occur and ``chars > 0``.
          * format bonus: 1.0 under the same condition, else 0.0.
    """
    math_accs: List[float] = []
    completeness: List[float] = []
    format_bonus: List[float] = []

    for completion, target in zip(completions, answers):
        text = completion or ""
        target_text = target or ""

        model_ans = extract_final_number(text)
        target_ans = extract_final_number(target_text)
        m_acc = 1.0 if _same_number(model_ans, target_ans) else 0.0

        # Loose presence check: only the opening of each tag is required here;
        # the regexes below require properly closed blocks to count length.
        lower = text.lower()
        has_tags = ("<reasoning" in lower) and ("<answer" in lower)

        reasoning_match = re.search(
            r"<reasoning>(.*?)</reasoning>", text, flags=re.IGNORECASE | re.DOTALL
        )
        answer_match = re.search(
            r"<answer>(.*?)</answer>", text, flags=re.IGNORECASE | re.DOTALL
        )

        reasoning_len = len(reasoning_match.group(1).strip()) if reasoning_match else 0
        answer_len = len(answer_match.group(1).strip()) if answer_match else 0
        total_len = reasoning_len + answer_len

        if has_tags and total_len > 0:
            # Completeness saturates at 200 characters of tagged content.
            c_score = min(1.0, total_len / 200.0)
            f_bonus = 1.0
        else:
            c_score = 0.0
            f_bonus = 0.0

        math_accs.append(float(m_acc))
        completeness.append(float(c_score))
        format_bonus.append(float(f_bonus))

    return math_accs, completeness, format_bonus

In [None]:
# === Reward functions for Tunix GRPO ===

# Pattern for a well-formed response: optional leading whitespace, a
# non-empty reasoning block, anything, a non-empty answer block (captured as
# group 1), optional trailing whitespace.
# re.DOTALL lets "." cross newlines so the blocks may span several lines.
# NOTE(review): with re.MULTILINE the "^" / "$" anchors match at any line
# boundary, so text on lines before or after the tagged blocks does not
# prevent a match. Confirm this leniency is intended for an "exact" check.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{re.escape(REASONING_START)}.+?{re.escape(REASONING_END)}.*?"
    rf"{re.escape(ANSWER_START)}(.+?){re.escape(ANSWER_END)}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

def reward_format_exact(prompts, completions, **kwargs):
    """Reward 2.0 for each completion matching ``match_format``, else 0.0.

    Args:
        prompts: Unused; present for the reward-function signature.
        completions: Model outputs; falsy entries are scored as ``""``.
        **kwargs: Ignored extra fields.

    Returns:
        A list of floats, one per completion.
    """
    return [
        2.0 if match_format.search(resp or "") else 0.0
        for resp in completions
    ]

def reward_format_soft(prompts, completions, **kwargs):
    """Give partial credit per tag: +0.5 if it occurs exactly once, else -0.5.

    Args:
        prompts: Unused; present for the reward-function signature.
        completions: Model outputs; falsy entries are scored as ``""``.
        **kwargs: Ignored extra fields.

    Returns:
        A list of floats in ``[-2.0, 2.0]``, one per completion.
    """
    expected_tags = (REASONING_START, REASONING_END, ANSWER_START, ANSWER_END)

    def _score(resp):
        text = resp or ""
        return sum(
            (0.5 if text.count(tag) == 1 else -0.5 for tag in expected_tags),
            0.0,
        )

    return [_score(resp) for resp in completions]

def reward_cast_math_and_completeness(prompts, completions, answer, **kwargs):
    """Combine the CAST-style scores into one reward per completion.

    The reward is ``3.0 * accuracy + 2.0 * completeness + 1.0 * format_bonus``
    using the three lists returned by ``cast_style_scores``.

    Args:
        prompts: Unused; present for the reward-function signature.
        completions: Model outputs to score.
        answer: Reference answers aligned with ``completions``.
        **kwargs: Ignored extra fields.

    Returns:
        A list of floats, one per completion.
    """
    per_component = cast_style_scores(completions, answer)
    return [
        3.0 * accuracy + 2.0 * complete + 1.0 * bonus
        for accuracy, complete, bonus in zip(*per_component)
    ]

print("Reward functions defined.")

In [None]:
# === Evaluation utilities ===

def build_sampler(policy_model, tokenizer, model_config):
    """Create a Tunix ``Sampler`` around ``policy_model``.

    The cache size is the prompt budget plus the generation budget plus 256
    extra slots; layer count, KV-head count and head dimension are read from
    ``model_config``.

    Args:
        policy_model: Model passed to the sampler as its transformer.
        tokenizer: Tokenizer handed to the sampler.
        model_config: Object providing ``num_layers``, ``num_kv_heads`` and
            ``head_dim``.

    Returns:
        The constructed ``sampler_lib.Sampler``.
    """
    cache_slots = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256
    cache_cfg = sampler_lib.CacheConfig(
        cache_size=cache_slots,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    )
    return sampler_lib.Sampler(
        transformer=policy_model,
        tokenizer=tokenizer,
        cache_config=cache_cfg,
    )


def generate_answers(questions, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None):
    """Generate model responses for one question or an iterable of questions.

    Each question is wrapped in TEMPLATE together with SYSTEM_PROMPT before
    being passed to ``sampler``.

    Args:
        questions: A single question string, or an iterable of questions.
        sampler: Callable sampler; its result must expose ``.text``.
        temperature: Sampling temperature forwarded to the sampler.
        top_k: Top-k value forwarded to the sampler.
        top_p: Top-p value forwarded to the sampler.
        seed: Seed forwarded to the sampler.

    Returns:
        The first generated text when ``questions`` is a string, otherwise the
        full ``.text`` sequence from the sampler output.
    """
    single_question = isinstance(questions, str)
    question_iter = [questions] if single_question else questions
    prompt_batch = [
        TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
        for q in question_iter
    ]

    result = sampler(
        input_strings=prompt_batch,
        max_generation_steps=TOTAL_GENERATION_STEPS,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
        eos_tokens=EOS_TOKENS,
    )
    generated = result.text
    if single_question:
        return generated[0]
    return generated


def _to_float(number_text):
    """Parse an extracted number string; return None when it is missing or unparsable."""
    if number_text is None:
        return None
    try:
        return float(number_text)
    except ValueError:
        # e.g. fractions such as "1/2" or malformed tokens such as "1.2.3".
        return None


def evaluate_dataset(dataset, sampler, num_passes=1):
    """Evaluate a sampler on a batched dataset and print/return accuracy metrics.

    For every batch, responses are generated ``num_passes`` times with the
    "greedy" preset from GENERATION_CONFIGS (the pass index is used as the
    seed). An example counts as:
      * strictly correct if any response's extracted number equals the
        reference numerically;
      * approximately correct if any response is strictly correct or within
        a ratio of 0.9-1.1 of a non-zero reference;
      * format-correct if any response matches ``match_format``.

    Args:
        dataset: Iterable of batches with ``"question"`` and ``"answer"`` fields.
        sampler: Sampler passed through to ``generate_answers``.
        num_passes: Number of generation passes per batch.

    Returns:
        A dict with ``total`` (example count) and ``strict_accuracy``,
        ``approx_accuracy``, ``format_accuracy`` as percentages.
    """
    greedy = GENERATION_CONFIGS["greedy"]
    total = 0
    strict_correct = 0
    approx_correct = 0
    format_ok = 0

    for batch in dataset:
        questions = batch["question"]
        answers = batch["answer"]
        multiple_outputs = [[] for _ in range(len(questions))]

        for s in range(num_passes):
            responses = generate_answers(
                questions,
                sampler,
                temperature=greedy["temperature"],
                top_k=greedy["top_k"],
                top_p=greedy["top_p"],
                seed=s,
            )
            for idx, resp in enumerate(responses):
                multiple_outputs[idx].append(resp)

        for a, resp_list in zip(answers, multiple_outputs):
            is_correct = False
            is_approx = False
            has_format = False
            # The reference does not change between responses: parse it once.
            truth = _to_float(extract_final_number(a or ""))

            for resp in resp_list:
                if match_format.search(resp or "") is not None:
                    has_format = True

                guess = _to_float(extract_final_number(resp or ""))
                if truth is not None and guess is not None:
                    if guess == truth:
                        # An exact hit is also an approximate hit; this matters
                        # for a zero reference, where no ratio can be formed.
                        is_correct = True
                        is_approx = True
                    elif truth != 0 and 0.9 <= guess / truth <= 1.1:
                        is_approx = True

                if is_correct and is_approx and has_format:
                    break

            total += 1
            if is_correct:
                strict_correct += 1
            if is_approx:
                approx_correct += 1
            if has_format:
                format_ok += 1

    # max(1, total) keeps the percentages defined for an empty dataset.
    acc = 100.0 * strict_correct / max(1, total)
    approx_acc = 100.0 * approx_correct / max(1, total)
    fmt_acc = 100.0 * format_ok / max(1, total)

    print(f"Total examples: {total}")
    print(f"Strict accuracy: {acc:.2f}%")
    print(f"Approx accuracy: {approx_acc:.2f}%")
    print(f"Format accuracy: {fmt_acc:.2f}%")
    return dict(
        total=total,
        strict_accuracy=acc,
        approx_accuracy=approx_acc,
        format_accuracy=fmt_acc,
    )

In [None]:
# === Baseline evaluation before GRPO ===

# Evaluate the freshly wrapped LoRA policy on the test subset before any GRPO
# updates, to have a reference point for the post-training evaluation.
baseline_sampler = build_sampler(lora_policy, tokenizer, model_config)
print("Evaluating baseline policy on a small test subset...")
baseline_metrics = evaluate_dataset(test_dataset, baseline_sampler, num_passes=1)
# Bare expression so the notebook displays the metrics dict as the cell output.
baseline_metrics

In [None]:
# === RLCluster, optimizer, and GRPOLearner setup ===

# Checkpoint cadence and retention, taken from the hyperparameter cell.
ckpt_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# Training metrics are written under this TensorBoard log directory.
log_dir = "/kaggle/working/tensorboard/grpo"
metrics_opts = metrics_logger.MetricsLoggerOptions(
    log_dir=log_dir,
    flush_every_n_steps=20,
)

# Learning rate: linear warmup from 0 to LEARNING_RATE over WARMUP_STEPS,
# then cosine decay back to 0 by MAX_STEPS.
schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=0.0,
)
optimizer = optax.adamw(
    learning_rate=schedule,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
# Clip gradients by global norm before the AdamW update when a limit is set.
if MAX_GRAD_NORM is not None:
    optimizer = optax.chain(
        optax.clip_by_global_norm(MAX_GRAD_NORM),
        optimizer,
    )

# All three roles (actor, reference, rollout) are placed on the same mesh.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine="vanilla",
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=64,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_opts,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=ckpt_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        # Same cache budget as build_sampler: prompt + generation + 256 extra slots.
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=EOS_TOKENS,
    ),
)

# GRPO algorithm settings.
# NOTE(review): beta is presumably the KL-penalty weight and epsilon the
# clipping range — confirm against the GRPOConfig documentation.
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=0.08,
    epsilon=0.2,
)

# The LoRA policy is trained; the base model serves as the reference.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=base_gemma,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# Three reward functions are registered: strict format, per-tag soft format,
# and the CAST-style accuracy/completeness reward.
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        reward_format_exact,
        reward_format_soft,
        reward_cast_math_and_completeness,
    ],
    algo_config=grpo_config,
)

print("RLCluster and GRPOLearner ready.")

In [None]:
# === Run GRPO training ===

# Run the GRPO loop under the mesh; memory usage is printed once beforehand.
# val_dataset may be None, in which case None is passed to train().
with mesh:
    show_hbm_usage()
    grpo_trainer.train(train_dataset, val_dataset)

print("GRPO training complete.")

In [None]:
# === Load final trained LoRA params and re-evaluate ===

# Path of the actor parameters saved at the final step.
# NOTE(review): assumes a checkpoint exists for step MAX_STEPS under
# CKPT_DIR/actor/<step>/model_params — confirm against the trainer's layout.
trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(MAX_STEPS), "model_params"
)

# Abstract (shape/dtype only) description of the LoRA parameters, used as the
# restore target so only those parameters are read back.
abs_lora = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_lora)

# Walk the current LoRA state and the restored tree together, keeping the
# restored leaf at each position, then write the result into the policy.
nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

print("Loaded trained LoRA params into policy model.")

# Re-run the same evaluation as the baseline cell with the restored weights.
finetuned_sampler = build_sampler(lora_policy, tokenizer, model_config)
print("Evaluating finetuned policy on test subset...")
finetuned_metrics = evaluate_dataset(test_dataset, finetuned_sampler, num_passes=1)
# Bare expression so the notebook displays the metrics dict as the cell output.
finetuned_metrics

In [None]:
# === Export final Tunix checkpoint (no safetensors) ===
# We package the final ACTOR checkpoint directory so judges can load it with Tunix.

# Start from a clean export directory so stale files from a previous run
# cannot end up in the archive (copytree also requires a fresh destination).
final_export_dir = "./tunix_dsa_cast_grpo_actor_ckpt"
if os.path.exists(final_export_dir):
    shutil.rmtree(final_export_dir)

# The parent of ".../model_params" is the step directory; copy all of it.
actor_step_dir = os.path.dirname(trained_ckpt_path)
shutil.copytree(actor_step_dir, final_export_dir)

print("Copied final actor checkpoint to:", final_export_dir)
print("\nContents:")
# Print an indented tree of the exported files with their sizes in MB.
for root, dirs, files in os.walk(final_export_dir):
    # Depth below the export root = number of path separators in the relative part.
    level = root.replace(final_export_dir, "").count(os.sep)
    indent = " " * (2 * level)
    print(f"{indent}{os.path.basename(root)}/")
    subindent = " " * (2 * (level + 1))
    for f in files:
        size_mb = os.path.getsize(os.path.join(root, f)) / (1024 * 1024)
        print(f"{subindent}{f}  ({size_mb:.2f} MB)")

# Zip the export directory; the archive is written to the working directory.
shutil.make_archive("tunix_dsa_cast_grpo_actor_ckpt", "zip", final_export_dir)
print("\nCreated zip archive: tunix_dsa_cast_grpo_actor_ckpt.zip")