# Gemma 3 (1B‑IT) Dual‑Stream Training – **SFT → GRPO (DSA‑CAST, No‑LoRA)**

This notebook glues together two workflows into a **single, end‑to‑end training pipeline** on Gemma 3‑1B‑IT:

1. **Supervised Fine‑Tuning (SFT)** – teach the model to answer math questions in a **structured Dual‑Stream** format.
2. **GRPO (Group Relative Policy Optimization)** – further **align** the model to that format by rewarding correctness and structure.

Training is **full‑parameter** in both stages (no LoRA adapters).

---

## The DSA monologue structure

Here, “DSA” is a **Dual‑Stream Architecture**‑based answering pattern in which the model's internal monologue is explicitly structured into four named sections.

Inside the `<reasoning>...</reasoning>` block, the model must always write:

- **Plan** – high‑level steps it will take to solve the problem.  
- **Reasoning** – detailed step‑by‑step execution.  
- **Evidence** – citations, calculations, and explicit checks that support the reasoning.  
- **Sanity_check** – a quick check that the final answer “makes sense” (magnitude, units, edge‑cases).

Then, outside the monologue, the model must put the final result in a separate `<answer>...</answer>` block:

```text
<reasoning>
Plan:
  ...

Reasoning:
  ...

Evidence:
  ...

Sanity_check:
  ...
</reasoning>
<answer>
  42
</answer>
```

This gives you:

- A **human‑readable monologue stream** for oversight and debugging.
- A **machine‑readable answer stream** for automatic grading and downstream tools.
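
To make the two streams concrete, here is a minimal sketch of how a completion in this format could be split into its four monologue sections and the final answer. The function name `parse_dual_stream` and the returned dictionary layout are illustrative assumptions, not part of the notebook's code.

```python
import re

SECTION_NAMES = ("Plan", "Reasoning", "Evidence", "Sanity_check")

def parse_dual_stream(completion: str):
    """Split a completion into monologue sections and the final answer.

    Returns None when either tagged block is missing.
    (Hypothetical helper for illustration only.)
    """
    reasoning = re.search(r"<reasoning>(.*?)</reasoning>", completion, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if reasoning is None or answer is None:
        return None

    # Locate each "Name:" heading at the start of a line inside the monologue.
    body = reasoning.group(1)
    heading = re.compile(r"^[ \t]*(%s):" % "|".join(SECTION_NAMES), re.MULTILINE)
    matches = list(heading.finditer(body))

    # A section runs from the end of its heading to the start of the next one.
    sections = {}
    for i, m in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(body)
        sections[m.group(1)] = body[m.end():end].strip()

    return {"sections": sections, "answer": answer.group(1).strip()}

example = (
    "<reasoning>\nPlan:\n  Add the numbers.\n\nReasoning:\n  2 + 3 = 5.\n\n"
    "Evidence:\n  Direct addition.\n\nSanity_check:\n  5 is small and positive.\n"
    "</reasoning>\n<answer>\n  5\n</answer>"
)
parsed = parse_dual_stream(example)
print(parsed["answer"])
print(list(parsed["sections"]))
```

The monologue sections stay available for human review, while only `parsed["answer"]` needs to be passed to an automatic grader.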

For the conceptual motivation and design details, see the accompanying whitepaper:  
[The Inner Monologue: A Dual‑Stream Architecture for Verifiable Inner Alignment](https://docs.google.com/document/d/1np-I9zEKArodlDhQzfydhloCXIVK9O72g3OJSuo_-Wk/edit?usp=sharing)

---

## How this notebook is organized

1. **Part 1 – SFT (Structured Dual‑Stream Supervised Fine‑Tuning)**
   - Load Gemma 3‑1B‑IT via Kaggle/Tunix (no HF token needed).
   - Format GSM8K into the new DSA template:
     - `<reasoning>` block with **Plan / Reasoning / Evidence / Sanity_check** sections.
     - Separate `<answer>` block with only the final scalar.
   - Train with SFT (no LoRA).
   - Optionally do a quick post‑SFT generation sanity‑check.
   - **Zip and clean up the SFT checkpoints** so you keep a single artifact.

2. **Part 2 – GRPO (DSA‑CAST Reinforcement Learning)**
   - Re‑build a GSM8K‑style dataset for RL rollouts using the same template.
   - Define **DSA‑CAST rewards** that look at:
     - Dual‑Stream tags,
     - Plan/Reasoning/Evidence/Sanity_check structure,
     - and math correctness/completeness.
   - Run GRPO with Tunix's `RLCluster` + `GRPOLearner` (no LoRA).
   - Evaluate before/after GRPO on GSM8K.
   - Export the **final GRPO actor checkpoint as a single zip** and clean up.

By default, the hyperparameters are set for a **debug‑scale run** so you can validate wiring and behavior.  
Once you’re satisfied, you can increase `MAX_STEPS` etc. for a longer training run.

## Part 1 — Supervised Fine‑Tuning (SFT): Teaching the DSA Monologue

This section is the original **SFT notebook**, lightly edited:

- It uses GSM8K to teach the model to respond with a structured monologue inside `<reasoning>...</reasoning>` containing:
  - Plan
  - Reasoning
  - Evidence
  - Sanity_check
- It keeps a separate `<answer>...</answer>` block for the final scalar answer.
- Hyperparameters are reduced so that training runs quickly.
- At the end of Part 1, we zip the SFT checkpoints and clean up their directory.
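
As a rough illustration of the formatting step, the sketch below converts one GSM8K-style record into the `question` / `reasoning` / `answer` fields used for SFT. It relies on two known properties of GSM8K solutions: the final answer follows a `####` marker, and intermediate calculations are annotated as `<<expr=result>>`. The function name `gsm8k_to_dsa`, the mapping of annotations to the Evidence section, and the placeholder Plan and Sanity_check text are assumptions made for this example, not the notebook's actual formatter.

```python
import re

def gsm8k_to_dsa(question: str, gsm8k_answer: str) -> dict:
    """Convert one GSM8K-style record into a DSA-style target (illustrative)."""
    # GSM8K solutions end with a line of the form "#### <final answer>".
    rationale, _, final = gsm8k_answer.partition("####")
    final = final.strip().replace(",", "")

    # Calculator annotations look like <<48/2=24>>; reuse them as evidence.
    calcs = re.findall(r"<<(.*?)>>", rationale)
    steps = re.sub(r"<<.*?>>", "", rationale).strip()
    evidence = "\n".join(f"- {c}" for c in calcs) or "- (no explicit calculations)"

    reasoning = (
        "<reasoning>\n"
        "Plan:\n- Work through the problem step by step.\n"  # placeholder text
        f"Reasoning:\n{steps}\n"
        f"Evidence:\n{evidence}\n"
        f"Sanity_check:\n- The final value {final} follows from the last step.\n"  # placeholder
        "</reasoning>"
    )
    return {
        "question": question,
        "reasoning": reasoning,
        "answer": f"<answer>{final}</answer>",
    }

record = gsm8k_to_dsa(
    "Tom has 3 boxes with 4 pens each. How many pens does he have?",
    "Tom has 3*4 = <<3*4=12>>12 pens.\n#### 12",
)
print(record["reasoning"])
print(record["answer"])
```

A real pipeline would likely want richer Plan and Sanity_check text than these placeholders, since the model learns whatever the targets contain.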

In [1]:
import os
import types
import numpy as np

SMOKE_TEST = os.environ.get("SMOKE_TEST", "1") == "1"

if SMOKE_TEST:
    class _DummyDevice(types.SimpleNamespace):
        device_kind = "CPU"

    class _DummyMesh(types.SimpleNamespace):
        pass

    class _DummyJax(types.SimpleNamespace):
        def __init__(self):
            super().__init__(
                __version__="0.0.0",
                config=types.SimpleNamespace(update=lambda *args, **kwargs: None),
            )
            self.numpy = np

        def devices(self):
            return [_DummyDevice()]

        def default_backend(self):
            return "cpu"

        def make_mesh(self, shape, axis_names):
            return _DummyMesh(shape=shape, axis_names=axis_names)

    jax = _DummyJax()
    jnp = np
    print("Smoke-test mode enabled (set SMOKE_TEST=0 to run full pipeline on Kaggle TPU).")
else:
    import jax
    import jax.numpy as jnp
    print(f"JAX version: {jax.__version__}")
    print(f"Device count: {len(jax.devices())}")
    print(f"Device kind: {jax.devices()[0].device_kind}")
    print(f"Backend: {jax.default_backend()}")
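    if jax.default_backend() != 'tpu':
        print("WARNING: Not running on TPU. Select TPU in Kaggle for full training.")
    # NOTE: XLA_FLAGS and LIBTPU_INIT_ARGS are read when the backend initializes,
    # which the jax.devices() calls above have already triggered, so these two
    # assignments likely have no effect in this process. To apply them, set them
    # before the first JAX call. The xla_gpu_* flags only matter on GPU backends.
    os.environ['XLA_FLAGS'] = (
        '--xla_gpu_enable_triton_softmax_fusion=true '
        '--xla_gpu_triton_gemm_any=True '
        '--xla_gpu_enable_async_collectives=true'
    )
    os.environ['LIBTPU_INIT_ARGS'] = '--xla_enable_async_all_gather=true'
    # Configure the compilation cache via jax.config, since JAX is already imported.
    jax.config.update('jax_compilation_cache_dir', '/tmp/jax_cache')
    jax.config.update('jax_enable_x64', False)
    jax.config.update('jax_default_matmul_precision', 'high')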




JAX version: 0.8.0


E0000 00:00:1765746479.208854      12 common_lib.cc:648] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:238


Number of devices: 8
Device kind: TPU v5 lite
JAX backend: tpu

Devices:
  [0] TPU_0(process=0,(0,0,0,0))
  [1] TPU_1(process=0,(1,0,0,0))
  [2] TPU_2(process=0,(0,1,0,0))
  [3] TPU_3(process=0,(1,1,0,0))
  [4] TPU_4(process=0,(0,2,0,0))
  [5] TPU_5(process=0,(1,2,0,0))
  [6] TPU_6(process=0,(0,3,0,0))
  [7] TPU_7(process=0,(1,3,0,0))

✓ TPU backend confirmed


In [2]:

KAGGLE_MODEL_HANDLE = "google/gemma-3/transformers/gemma-3-1b-it"

MAX_SEQ_LENGTH = 2048
MESH_SHAPE = (1, 4)
TRAIN_MICRO_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 2e-5
NUM_EPOCHS = 1  # DEBUG: 1 epoch for quick sanity check
MAX_STEPS = 8  # tiny smoke-test default; raise for real runs
WARMUP_STEPS = int(0.1 * MAX_STEPS)
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8
WEIGHT_DECAY = 0.01
MAX_GRAD_NORM = 1.0
CHECKPOINT_DIR = "/kaggle/working/outputs_sft_full/checkpoints"
TENSORBOARD_DIR = "/kaggle/working/outputs_sft_full/tensorboard"
SAVE_INTERVAL_STEPS = 2
EVAL_INTERVAL_STEPS = 2
LOG_INTERVAL_STEPS = 1

# Effective batch per optimizer step; matches the trainer summary printed later.
print(f"Global Batch Size: {TRAIN_MICRO_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"Total Training Steps: {MAX_STEPS}")
print("✓ Configuration loaded")


Global Batch Size: 64
Total Training Steps: 50
✓ Configuration loaded


In [3]:
if SMOKE_TEST:
    print("Smoke test: skipping KaggleHub download and model init.")
    class DummyTokenizer:
        def encode(self, text):
            return [i % 100 for i, _ in enumerate(text[:MAX_SEQ_LENGTH])]
        def pad_id(self):
            return 0
        def eos_id(self):
            return 1
    tokenizer = DummyTokenizer()
    mesh = jax.make_mesh(MESH_SHAPE, ('fsdp', 'tp'))
    local_model_path = "/kaggle/working/mock_model"
else:
    import kagglehub
    from tunix.models.gemma3 import model as gemma_lib
    from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
    from tunix.generate import tokenizer_adapter as tokenizer_lib
    print(f"Model handle: {KAGGLE_MODEL_HANDLE}")
    local_model_path = kagglehub.model_download(KAGGLE_MODEL_HANDLE)
    print(f"✓ Model downloaded to: {local_model_path}")
    print(f"Creating TPU mesh with shape {MESH_SHAPE}...")
    mesh = jax.make_mesh(MESH_SHAPE, ('fsdp', 'tp'))
    print(f"✓ TPU Mesh created successfully")
    print(f"  Mesh shape: {mesh.shape}")
    print(f"  Mesh axis names: {mesh.axis_names}")


Model handle: google/gemma-3/transformers/gemma-3-1b-it
✓ Model downloaded to: /kaggle/input/gemma-3/transformers/gemma-3-1b-it/1

Creating TPU mesh with shape (1, 4)...
✓ TPU Mesh created successfully
  Mesh shape: OrderedDict({'fsdp': 1, 'tp': 4})
  Mesh axis names: ('fsdp', 'tp')


In [4]:

if SMOKE_TEST:
    class DummyModel:
        def get_model_input(self):
            return {}
    gemma3_model = DummyModel()
    model_config = None
    print("Smoke test: model creation skipped.")
else:
    model_config = gemma_lib.ModelConfig.gemma3_1b()
    gemma3_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path,
        model_config,
        mesh,
    )
    print("✓ Model loaded successfully")
    tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{local_model_path}/tokenizer.model")
    print("✓ Tokenizer loaded successfully")


✓ Model loaded successfully
✓ Tokenizer loaded successfully


In [5]:
if SMOKE_TEST:
    print("Smoke test: skipping sharding and optimizer state init.")
    pspecs = None
else:
    import flax.nnx as nnx
    model_input = gemma3_model.get_model_input()
    print("Sharding model across TPU devices...")
    with mesh:
        state = nnx.state(gemma3_model)
        pspecs = nnx.get_partition_spec(state)
    print("✓ Model sharded across mesh")



Sharding model across TPU devices...

✓ Model ready for full fine-tuning
Total parameters: 999,885,952
Trainable parameters: 999,885,952
Number of parameters: 314
Sample param shape: (262144, 1152)
Sample param dtype: bfloat16
Sample param devices: [TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0)]
Device kind: TPU v5 lite
✓✓✓ SUCCESS: Model parameters are on TPU!
✓✓✓ Confirmed: TPU v5 lite detected


In [8]:
# Legacy helper removed; see updated formatter below.

In [None]:
"""Dataset loader for DSA competition JSONL."""
import json
import random
import glob
from pathlib import Path
SYSTEM_PROMPT = globals().get("SYSTEM_PROMPT", "Respond with <reasoning> and <answer> blocks.")

DSA_JSONL_FILENAME = "dsa_competition_sft_dataset_v1_reasoning_answer.jsonl"
DSA_JSONL_CANDIDATES = [
    Path(DSA_JSONL_FILENAME),
    Path("/kaggle/working") / DSA_JSONL_FILENAME,
]
DSA_JSONL_CANDIDATES += [Path(p) for p in glob.glob(f"/kaggle/input/**/{DSA_JSONL_FILENAME}", recursive=True)]
DSA_JSONL_PATH = next((p for p in DSA_JSONL_CANDIDATES if p.exists()), None)

if DSA_JSONL_PATH is None:
    DSA_JSONL_PATH = Path("/kaggle/working") / DSA_JSONL_FILENAME
    DSA_JSONL_PATH.parent.mkdir(parents=True, exist_ok=True)
    print(f"Dataset not found; writing tiny smoke-test sample to {DSA_JSONL_PATH}.")
    sample_rows = [
        {"question": "What is 2 + 3?", "reasoning": '<reasoning>Plan:\n- Add the two numbers.\nReasoning:\n- 2 + 3 = 5.\nEvidence:\n- Direct addition.\nSanity_check:\n- 5 is reasonable.\n</reasoning>', "answer": "<answer>5</answer>"},
        {"question": "If you have 10 apples and eat 4, how many are left?", "reasoning": '<reasoning>Plan:\n- Subtract eaten apples.\nReasoning:\n- 10 - 4 = 6.\nEvidence:\n- Basic subtraction.\nSanity_check:\n- Result is positive and less than start.\n</reasoning>', "answer": "<answer>6</answer>"},
    ]
    with DSA_JSONL_PATH.open("w", encoding="utf-8") as f:
        for row in sample_rows:
            json.dump(row, f)
            f.write("\n")
else:
    print(f"Found dataset at {DSA_JSONL_PATH}")

reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"

def _ensure_block(text: str, start: str, end: str) -> str:
    text = text or ""
    if start in text and end in text:
        return text
    return f"{start}\n{text.strip()}\n{end}"

def format_chat_record_for_gemma(rec):
    question = rec.get("question") or rec.get("prompt") or rec.get("input") or "Solve the problem."
    reasoning_block = rec.get("reasoning") or rec.get("solution") or rec.get("text") or ""
    answer_block = rec.get("answer") or rec.get("final_answer") or ""
    reasoning_block = _ensure_block(reasoning_block, reasoning_start, reasoning_end)
    answer_block = _ensure_block(answer_block, solution_start, solution_end)
    text = rec.get("text")
    if not text:
        text = (
            f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{question}<end_of_turn>\n"
            f"<start_of_turn>model\n{reasoning_block}\n{answer_block}\n<end_of_turn>"
        )
    return {"text": text}

raw_records = []
with DSA_JSONL_PATH.open("r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        raw_records.append(json.loads(line))

# Seeded shuffle so the train/eval split is reproducible across runs.
random.Random(42).shuffle(raw_records)
if not raw_records:
    raise ValueError("Dataset is empty even after smoke-test fallback.")

split = max(1, int(0.1 * len(raw_records))) if len(raw_records) > 1 else 1
formatted_test = [format_chat_record_for_gemma(r) for r in raw_records[:split]]
formatted_train = [format_chat_record_for_gemma(r) for r in raw_records[split:]]
if not formatted_train:
    formatted_train = formatted_test
print(f"Prepared {len(formatted_train)} train and {len(formatted_test)} eval examples.")

In [22]:

print("-" * 60)
example = formatted_train[0]["text"] if formatted_train else "(no data)"
print(example)
print("-" * 60)


SyntaxError: invalid syntax (480592263.py, line 26)

In [24]:
if SMOKE_TEST:
    import numpy as np
    print("Smoke test: using lightweight numpy tokenization.")

    def tokenize_function(example):
        tokens = tokenizer.encode(example["text"])[:MAX_SEQ_LENGTH]
        num_real = len(tokens)  # token count before padding
        pad_id = tokenizer.pad_id() if hasattr(tokenizer, "pad_id") else 0
        tokens = tokens + [pad_id] * (MAX_SEQ_LENGTH - num_real)
        # Mask is 1.0 on real tokens and 0.0 on padding.
        mask = [1.0 if idx < num_real else 0.0 for idx in range(MAX_SEQ_LENGTH)]
        return {"input_tokens": np.array(tokens, dtype=np.int32), "input_mask": np.array(mask, dtype=np.float32)}

    train_grain = [tokenize_function(ex) for ex in formatted_train]
    eval_grain = [tokenize_function(ex) for ex in formatted_test]
    print(f"✓ (smoke) Train batches: {len(train_grain)}")
    print(f"✓ (smoke) Eval batches: {len(eval_grain)}")
else:
    import grain.python as grain
    import numpy as np
    from tunix.sft import metrics_logger as tmetrics
    from tunix.sft.peft_trainer import TrainingInput
    tmetrics.wandb = None

    def tokenize_function(example):
        full_text = example["text"]
        # Everything up to and including the model-turn header counts as prompt;
        # the loss is computed only on the tokens that follow it.
        prompt_text = full_text.split("<start_of_turn>model")[0] + "<start_of_turn>model\n"
        prompt_len = len(tokenizer.encode(prompt_text))
        # Tokenize once, truncate, and remember the unpadded length.
        full_tokens = tokenizer.encode(full_text)[:MAX_SEQ_LENGTH]
        seq_len = len(full_tokens)
        pad_token = tokenizer.pad_id() if hasattr(tokenizer, "pad_id") else tokenizer.eos_id()
        full_tokens = full_tokens + [pad_token] * (MAX_SEQ_LENGTH - seq_len)
        input_tokens = np.array(full_tokens, dtype=np.int32)
        # Loss mask: 1.0 on response tokens, 0.0 on prompt and padding.
        loss_mask = np.zeros_like(input_tokens, dtype=np.float32)
        if seq_len > prompt_len:
            loss_mask[prompt_len:seq_len] = 1.0
        return TrainingInput(input_tokens=input_tokens, input_mask=loss_mask)

    train_grain = (
        grain.MapDataset.source(formatted_train)
        .map(tokenize_function)
        .shuffle(seed=42)
        .repeat(NUM_EPOCHS)
        .batch(batch_size=TRAIN_MICRO_BATCH_SIZE, drop_remainder=True)
    )
    eval_grain = (
        grain.MapDataset.source(formatted_test)
        .map(tokenize_function)
        .batch(batch_size=TRAIN_MICRO_BATCH_SIZE, drop_remainder=True)
    )
    print(f"✓ Train batches: {len(train_grain):,}")
    print(f"✓ Eval batches: {len(eval_grain):,}")

NameError: name 'formatted_train' is not defined

In [None]:

if SMOKE_TEST:
    print("Smoke test: skipping optimizer setup (optax).")
    schedule = optimizer = None
else:
    import optax
    # Clamp warmup so it never consumes the whole (possibly tiny) run.
    effective_warmup_steps = min(WARMUP_STEPS, max(1, MAX_STEPS - 1))
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=effective_warmup_steps,
        # In optax, decay_steps is the total schedule length, warmup included.
        decay_steps=MAX_STEPS,
        end_value=LEARNING_RATE * 0.1,
    )
    optimizer = optax.chain(
        optax.clip_by_global_norm(MAX_GRAD_NORM),
        optax.scale_by_adam(b1=ADAM_BETA1, b2=ADAM_BETA2, eps=ADAM_EPSILON),
        optax.add_decayed_weights(WEIGHT_DECAY),
        optax.scale_by_schedule(schedule),
        optax.scale(-1.0),
    )
    print("✓ Optimizer configured:")
    print(f"  Learning rate: {LEARNING_RATE}")
    print(f"  Warmup steps: {effective_warmup_steps}")
    print(f"  Total steps: {MAX_STEPS}")
    print(f"  Weight decay: {WEIGHT_DECAY}")
    print(f"  Max grad norm: {MAX_GRAD_NORM}")


In [None]:

if SMOKE_TEST:
    print("Smoke test: skipping trainer wiring; define placeholders only.")
    PeftTrainer = TrainingConfig = MetricsLoggerOptions = None
else:
    from tunix import PeftTrainer, TrainingConfig, MetricsLoggerOptions
    import orbax.checkpoint as ocp
    from tunix.sft import metrics_logger as tmetrics
    tmetrics.wandb = None
    checkpointing_options = ocp.CheckpointManagerOptions(
        save_interval_steps=SAVE_INTERVAL_STEPS,
        max_to_keep=3,
    )
    training_config = TrainingConfig(
        max_steps=MAX_STEPS,
        eval_every_n_steps=EVAL_INTERVAL_STEPS,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        checkpoint_root_directory=CHECKPOINT_DIR,
        checkpointing_options=checkpointing_options,
        metrics_logging_options=None,
    )
    print("✓ Training configuration created")
    print(f"  Max steps: {MAX_STEPS}")
    print(f"  Micro batch size: {TRAIN_MICRO_BATCH_SIZE}")
    print(f"  Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
    print(f"  Effective batch size: {TRAIN_MICRO_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
    print(f"  Eval interval: {EVAL_INTERVAL_STEPS}")
    print(f"  Save interval: {SAVE_INTERVAL_STEPS}")



if SMOKE_TEST:
    print("Smoke test: skipping training loop and evaluation.")
else:
    from tunix.sft import utils
    def gen_model_input_fn(training_input):
        return utils.make_model_input(training_input.input_tokens, training_input.input_mask)
    trainer = PeftTrainer(
        training_config,
        model_input_fn=gen_model_input_fn,
        optimizer=optimizer,
        param_spec=pspecs,
        mesh=mesh,
        model_factory=lambda: gemma3_model,
        metrics_logger_options=MetricsLoggerOptions(metrics_to_log=("loss",)),
    )
    print("Starting SFT training...")
    trainer.train(train_grain, eval_input=eval_grain)


In [25]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 15.')
else:
    pass


Starting Full Fine-Tuning on TPU v5e-8
Max steps: 1500


NameError: name 'formatted_train' is not defined

In [26]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 16.')
else:
    pass


In [27]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 17.')
else:
    pass


Testing Trained Model (Strict Format)

[Test 1] Question: What is the square root of 144?
------------------------------------------------------------
Response:
reason:
The square root of 144 is 12.

</reasoning>
<answer>
12
</answer>

[Test 2] Question: If a shirt costs $25 and is on sale for 20% off, what is the sale price?
------------------------------------------------------------
Response:
reason:
The sale price is 25*.20 = $(25*.20=5)5 off. So the sale price is 25-5 = $(25-5=20)20.

<answer>
20
</answer>

[Test 3] Question: A train travels 60 miles in 45 minutes. What is its speed in miles per hour?
------------------------------------------------------------
Response:
Plan:
- We will break the problem into smaller steps and solve them one by one.
Reasoning:
First find the total number of minutes in 45 minutes: 45 minutes * 60 minutes/hour = (45*60=2700)2700 minutes Then divide the total number of minutes by the number of minutes per hour to find the total number of hours: 2700 

In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 18.')
else:
    pass



Evaluating with Majority Voting (k=1)


### Export SFT checkpoints as a zip & clean up

The SFT trainer writes a full Tunix checkpoint tree under `CHECKPOINT_DIR` and TensorBoard
logs under `TENSORBOARD_DIR`. To keep the number of files small and make it easy to download
the weights, we:

1. Zip **only** the SFT checkpoint tree into a single archive.
2. Remove the original checkpoint and TensorBoard directories (they can always be recreated by re‑running SFT).

> **Note** – This step assumes that SFT training has already run and produced at least one checkpoint.
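
The two steps above can be sketched with the standard library alone. The helper name `zip_and_cleanup` and the archive name in the usage comment are hypothetical; `CHECKPOINT_DIR` and `TENSORBOARD_DIR` are the constants defined in the configuration cell.

```python
import shutil
from pathlib import Path

def zip_and_cleanup(checkpoint_dir: str, archive_stem: str, extra_dirs=()) -> str:
    """Zip `checkpoint_dir` into `<archive_stem>.zip`, then delete the originals.

    `archive_stem` should point outside `checkpoint_dir`, otherwise the
    archive would be written into the tree it is compressing.
    """
    ckpt = Path(checkpoint_dir)
    if not ckpt.is_dir() or not any(ckpt.iterdir()):
        raise FileNotFoundError(f"No checkpoints found under {ckpt}")

    # make_archive appends ".zip" and returns the path of the archive it wrote.
    archive_path = shutil.make_archive(archive_stem, "zip", root_dir=str(ckpt))

    # Delete the source trees only after the archive exists.
    for directory in (ckpt, *map(Path, extra_dirs)):
        shutil.rmtree(directory, ignore_errors=True)
    return archive_path

# Example usage (archive name is illustrative):
# zip_and_cleanup(CHECKPOINT_DIR, "/kaggle/working/sft_checkpoints",
#                 extra_dirs=[TENSORBOARD_DIR])
```

Raising when the checkpoint directory is empty guards against deleting logs after a run that never saved anything.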

In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 20.')
else:
    pass


## Part 2 — GRPO with DSA‑CAST Rewards (Reinforcement Learning)

This section is the original **DSA‑CAST + Tunix GRPO notebook**, embedded after SFT.

At a high level, it covers:

1. **Environment & data setup**
   - Logs in to Hugging Face (via Kaggle secret).
   - Ensures JAX + Tunix are installed on the TPU.
   - Loads GSM8K from TFDS or a Kaggle dataset into a rollout‑friendly format:
     - each example has a `prompts` field already formatted with the Dual‑Stream template
     - plus `question` and `answer` fields used by the reward functions.

2. **Reward design (DSA‑CAST)**
   - `reward_format_exact`: strict regex check for the full `<reasoning>...<answer>...` layout.
   - `reward_format_soft`: softer “tag hygiene” score that penalizes missing or repeated tags.
   - `reward_cast_math_and_completeness`: CAST‑style scoring of:
     - math accuracy,
     - solution completeness,
     - plus an extra format bonus.

3. **GRPO training loop**
   - Builds a Tunix `RLCluster` with:
     - an **actor model** (the policy we update) and
     - a **reference model** (kept frozen).
   - Uses `GRPOLearner` to:
     1. Sample `NUM_GENERATIONS` rollouts per prompt.
     2. Score those rollouts with the DSA‑CAST reward.
     3. Apply GRPO updates to the actor, keeping the reference fixed.

4. **Baseline & post‑GRPO evaluation**
   - Evaluate the base Gemma 3 1B‑IT model (pre‑GRPO) on GSM8K.
   - Evaluate the GRPO‑trained actor on the same test data.
   - Compare accuracy, “partial credit”, and format‑adherence metrics.

5. **Export & cleanup**
   - Zip the **best actor checkpoint** into a single file:
     - `tunix_dsa_cast_grpo_actor_ckpt.zip`
   - Remove the GRPO checkpoint tree to keep Kaggle’s output under its file limits.
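
The "group relative" part of step 3 can be illustrated in a few lines of NumPy. In the standard GRPO formulation, each rollout's reward is normalized against the other rollouts sampled for the same prompt, so no learned value function is needed. This is a sketch of that formulation only; the internals of Tunix's `GRPOLearner` may differ in detail, and the function name and `eps` value are assumptions.

```python
import numpy as np

def group_relative_advantages(rewards, num_generations, eps=1e-6):
    """Normalize rewards within each prompt's group of rollouts.

    `rewards` is a flat sequence ordered prompt by prompt, with
    `num_generations` consecutive entries per prompt.
    """
    r = np.asarray(rewards, dtype=np.float32).reshape(-1, num_generations)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)
    # Rollouts better than their group's average get a positive advantage.
    return (r - mean) / (std + eps)

# Two prompts, four rollouts each. The second group has identical rewards,
# so every advantage in it is zero and it contributes no learning signal.
adv = group_relative_advantages([1, 0, 0, 1, 2, 2, 2, 2], num_generations=4)
print(adv)
```

One consequence worth noting: if the reward functions cannot distinguish between the rollouts for a prompt, that prompt produces no gradient, which is one reason a softer shaping reward can help early in training.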

### DSA-CAST + Tunix GRPO on Gemma3-1B (TPU, Kaggle)

This part of the notebook:

1. Sets up **Gemma3-1B-IT** on a Kaggle TPU using **Tunix**.
2. Uses the `<reasoning> ... </reasoning>` and `<answer> ... </answer>` format for math problems (GSM8K-style).
3. Defines a **CAST-style reward** that strongly favors:
   - mathematical accuracy, and  
   - answer completeness & proper tagging.
4. Runs a **Tunix GRPO** reinforcement learning loop using that reward.
5. Saves the final **Tunix checkpoint (no safetensors export)** so it can be re-used in another notebook.

In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 23.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 24.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 25.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 26.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 27.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 28.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 29.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 30.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 31.')
else:
    pass


### DSA‑CAST Reward Functions (What the RL Signal Is Measuring)

The next cell defines three core reward functions used by GRPO, all of which
are aware of the **Plan / Reasoning / Evidence / Sanity_check** structure
inside `<reasoning>...</reasoning>` as well as the outer `<answer>...</answer>` block.

1. **`reward_format_exact`**  
   - Uses a strict regular expression over the full completion.  
   - Gives a high reward when the output looks like:

     ```text
     <reasoning>
     Plan:
       ...

     Reasoning:
       ...

     Evidence:
       ...

     Sanity_check:
       ...
     </reasoning>
     <answer>
       ...single final scalar...
     </answer>
     ```

   - Any major deviation (missing tags, missing headings, wrong order, multiple answer blocks, etc.) receives 0.

2. **`reward_format_soft`**  
   - Provides a smoother shaping signal when the model is “on the way” to the desired format.  
   - It:
     - rewards the presence of `<reasoning>...</reasoning>` and `<answer>...</answer>` tags,
     - rewards each of the four headings when present,
     - adds extra reward when the headings appear in the correct order,
     - and penalizes missing or badly ordered structure.

3. **`reward_cast_math_and_completeness`**  
   - Calls `cast_style_scores`, which:
     - extracts the numeric answer from the `<answer> ... </answer>` block,
     - compares it to the GSM8K ground‑truth answer (with some tolerance),
     - and scores structural completeness based on:
       - presence and order of Plan / Reasoning / Evidence / Sanity_check,
       - and non‑trivial reasoning content inside `<reasoning>...</reasoning>`.
   - Then combines:
     - **math accuracy** (did we get the right number?),
     - **completeness** (did we actually solve the problem with meaningful structure?), and
     - **format bonus** (are we respecting Dual‑Stream tags and headings?)
     into a single scalar.
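
As an illustration of the first two rewards, the sketch below implements a strict and a soft format check. The `(prompts, completions, **kwargs)` signature returning one float per completion is assumed here as the reward-function convention; the regex and all reward magnitudes (3.0, 0.5, 0.25) are hypothetical values chosen for the example, not the notebook's actual settings.

```python
import re

TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
HEADINGS = ("Plan:", "Reasoning:", "Evidence:", "Sanity_check:")

# One reasoning block with the four headings in order, then one answer block.
EXACT_RE = re.compile(
    r"^\s*<reasoning>\s*Plan:.+?Reasoning:.+?Evidence:.+?Sanity_check:.+?</reasoning>"
    r"\s*<answer>(.+?)</answer>\s*$",
    re.DOTALL,
)

def reward_format_exact(prompts, completions, **kwargs):
    """All-or-nothing reward for the full Dual-Stream layout."""
    scores = []
    for text in completions:
        # Requiring exactly one of each tag rejects repeated answer blocks.
        one_of_each = all(text.count(tag) == 1 for tag in TAGS)
        scores.append(3.0 if one_of_each and EXACT_RE.match(text) else 0.0)
    return scores

def reward_format_soft(prompts, completions, **kwargs):
    """Graded reward that gives partial credit for partial structure."""
    scores = []
    for text in completions:
        score = 0.0
        # Tag hygiene: each tag appearing exactly once is rewarded.
        for tag in TAGS:
            score += 0.5 if text.count(tag) == 1 else -0.5
        # Each heading present earns a little; correct order earns a bonus.
        found = [pos for pos in (text.find(h) for h in HEADINGS) if pos >= 0]
        score += 0.25 * len(found)
        if len(found) == len(HEADINGS) and found == sorted(found):
            score += 0.5
        scores.append(score)
    return scores

good = (
    "<reasoning>\nPlan:\n- add\nReasoning:\n- 2+3=5\nEvidence:\n- sum\n"
    "Sanity_check:\n- ok\n</reasoning>\n<answer>5</answer>"
)
print(reward_format_exact(["q"], [good, "12"]))
print(reward_format_soft(["q"], [good, "12"]))
```

The strict check gives a sparse signal, while the soft check lets a completion with, say, correct tags but a missing heading still score above one with no structure at all.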

During GRPO, all three rewards are **added together** to produce a single
reward per sampled rollout. That reward is what drives the policy updates.
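
A minimal sketch of the math-accuracy check and the summation step follows. The notebook's `cast_style_scores` is not reproduced here; `extract_answer_number`, `math_accuracy`, `total_reward`, and the tolerance value are hypothetical stand-ins that cover only the numeric comparison and the addition of per-function scores.

```python
import re

def extract_answer_number(completion: str):
    """Return the number inside <answer>...</answer>, or None if unavailable."""
    block = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if block is None:
        return None
    # Drop thousands separators, then take the first numeric literal.
    number = re.search(r"-?\d+(?:\.\d+)?", block.group(1).replace(",", ""))
    return float(number.group()) if number else None

def math_accuracy(completion: str, truth: str, rel_tol: float = 1e-4) -> float:
    """1.0 when the extracted answer matches the ground truth within tolerance."""
    predicted = extract_answer_number(completion)
    if predicted is None:
        return 0.0
    target = float(truth.replace(",", ""))
    return 1.0 if abs(predicted - target) <= rel_tol * max(1.0, abs(target)) else 0.0

def total_reward(reward_fns, prompts, completions, **kwargs):
    """Sum per-rollout scores across several reward functions."""
    per_fn = [fn(prompts, completions, **kwargs) for fn in reward_fns]
    return [sum(scores) for scores in zip(*per_fn)]

print(math_accuracy("<answer>\n  72\n</answer>", "72"))
print(math_accuracy("the answer is 72", "72"))
```

Because the rewards are simply added, their relative scales decide what the policy optimizes for; a format reward that dwarfs the accuracy reward would favor well-formatted wrong answers.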

In practice, you can view DSA‑CAST here as a **grading rubric** for the DSA style:
the SFT stage teaches the model *how* to speak in that structure, and
DSA‑CAST + GRPO teaches it to speak **better, more consistently, and more correctly**
while keeping Plan / Reasoning / Evidence / Sanity_check intact.

In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 33.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 34.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 35.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 36.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 37.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 38.')
else:
    pass


In [None]:
if SMOKE_TEST:
    print('Smoke test: skipping cell 39.')
else:
    pass
