# DSA SFT + GRPO (No-LoRA) Tunix — Repaired

This notebook is trimmed for quick smoke-tests while keeping the structure needed to fine-tune on Kaggle TPU.

* Default settings run a lightweight dry-run so the notebook executes quickly.
* Set `SMOKE_TEST = False` when running on a real Kaggle TPU to enable full Tunix training.

In [1]:
import os
import re
import json
import math
import random
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Any

import numpy as np

# Probe the accelerator backend. Any failure inside the probe (JAX not installed,
# backend initialisation error) degrades to CPU fallback and leaves `jax` bound
# to None so the name always exists for later cells.
try:
    import jax
    TPU_AVAILABLE = 'tpu' == jax.default_backend()
except Exception:
    jax, TPU_AVAILABLE = None, False

print(f"TPU available: {TPU_AVAILABLE}")
if not TPU_AVAILABLE:
    # Tell the reader how to get onto real hardware.
    print("Running in CPU fallback mode. Set Kaggle accelerator to TPU for full training.")


TPU available: False
Running in CPU fallback mode. Set Kaggle accelerator to TPU for full training.


In [2]:
# === Configuration ===
SMOKE_TEST = True  # Set to False on Kaggle TPU for real training

# Reproducibility: the dataset is shuffled with the global `random` generator
# (see load_sft_dataset), so seed it here to make the train/test split identical
# on every Restart & Run All.
SEED = 42
random.seed(SEED)

# Training-style hyperparameters (kept tiny for smoke test)
MAX_STEPS = 2 if SMOKE_TEST else 200
TRAIN_MICRO_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 2e-5

# Data: candidate locations, checked in order; the first existing file wins.
DATASET_PATHS = [
    Path('/kaggle/input/dsa-sft-grpo-nolora-tunix/dsa_competition_sft_dataset_v1_reasoning_answer.jsonl'),
    Path('/kaggle/input/dsa-competition-sft/dsa_competition_sft_dataset_v1_reasoning_answer.jsonl'),
    Path('dsa_competition_sft_dataset_v1_reasoning_answer.jsonl'),
]

# Dual-stream markers wrapping the reasoning and the final answer in each response.
REASONING_START = '<reasoning>'
REASONING_END = '</reasoning>'
ANSWER_START = '<answer>'
ANSWER_END = '</answer>'

print(f"Smoke test mode: {SMOKE_TEST}")
print(f"Max steps: {MAX_STEPS}")


Smoke test mode: True
Max steps: 2


In [3]:
def find_dataset_path() -> Path | None:
    """Return the first entry of ``DATASET_PATHS`` that exists on disk.

    Candidates are checked in list order. ``None`` is returned when none of
    them exists.
    """
    return next((candidate for candidate in DATASET_PATHS if candidate.exists()), None)


def extract_reasoning_and_answer(response: str) -> tuple[str, str]:
    """Pull the reasoning and answer segments out of a tagged response.

    Args:
        response: Model/dataset response text expected to contain a
            ``REASONING_START … REASONING_END`` span and an
            ``ANSWER_START … ANSWER_END`` span.

    Returns:
        ``(reasoning, answer)`` with surrounding whitespace stripped. A segment
        whose markers are not found is returned as the empty string. When a
        marker pair occurs more than once, the first (shortest) span is used.
    """
    # re.escape keeps the markers literal: if a marker is ever changed to contain
    # regex metacharacters (e.g. '[answer]' or '(final)'), matching stays correct.
    reasoning_pattern = f"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}"
    answer_pattern = f"{re.escape(ANSWER_START)}(.*?){re.escape(ANSWER_END)}"

    # DOTALL lets a segment span multiple lines.
    reasoning_match = re.search(reasoning_pattern, response, re.DOTALL)
    answer_match = re.search(answer_pattern, response, re.DOTALL)

    reasoning = reasoning_match.group(1).strip() if reasoning_match else ''
    answer = answer_match.group(1).strip() if answer_match else ''
    return reasoning, answer


def _read_jsonl_examples(path: Path) -> list[Dict[str, Any]]:
    """Parse one JSON value per non-blank line of ``path``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly; the platform default encoding can mis-decode
    # non-ASCII text in the dataset.
    with open(path, 'r', encoding='utf-8') as handle:
        return [json.loads(line) for line in handle if line.strip()]


def load_sft_dataset() -> list[Dict[str, Any]]:
    """Load SFT examples and format each one as a dual-stream chat transcript.

    The first existing file from ``find_dataset_path()`` is read as JSONL. When
    no file is found, two tiny synthetic arithmetic examples are used instead so
    the notebook can still run end to end.

    For each example the question is taken from the first non-empty of the
    ``question`` / ``prompt`` / ``instruction`` keys and the response from the
    first non-empty of ``response`` / ``output`` / ``completion``. The response
    is split with ``extract_reasoning_and_answer``; a response without the
    expected markers therefore yields empty reasoning/answer segments.

    Returns:
        A shuffled list of dicts with keys ``question``, ``response``,
        ``reasoning``, ``answer`` and ``text`` (the formatted transcript).
        Shuffling uses the global ``random`` generator.
    """
    dataset_path = find_dataset_path()

    if dataset_path is None:
        print('Dataset file not found; generating a tiny synthetic sample for smoke testing.')
        examples = [
            {
                'question': 'What is 2 + 3?',
                'response': f"{REASONING_START}Plan: add the numbers. Reasoning: 2 + 3 = 5.{REASONING_END}{ANSWER_START}5{ANSWER_END}",
            },
            {
                'question': 'If Alice has 4 apples and buys 2 more, how many apples does she have?',
                'response': f"{REASONING_START}Plan: add the counts. Reasoning: 4 + 2 = 6.{REASONING_END}{ANSWER_START}6{ANSWER_END}",
            },
        ]
    else:
        print(f"Loading dataset from {dataset_path}")
        examples = _read_jsonl_examples(dataset_path)

    records: list[Dict[str, Any]] = []
    for rec in examples:
        # `or` chaining falls through both missing keys and empty/None values.
        question = rec.get('question') or rec.get('prompt') or rec.get('instruction') or ''
        response = rec.get('response') or rec.get('output') or rec.get('completion') or ''
        reasoning, answer = extract_reasoning_and_answer(response)
        # Adjacent literals concatenate into a single string: one user turn
        # followed by one model turn carrying the reasoning and answer streams.
        chat_text = (
            "<start_of_turn>user\n"
            f"{question}\n<end_of_turn>\n"
            "<start_of_turn>model\n"
            f"{REASONING_START}{reasoning}{REASONING_END}"
            f"{ANSWER_START}{answer}{ANSWER_END}\n"
            "<end_of_turn>"
        )
        records.append({
            'question': question,
            'response': response,
            'reasoning': reasoning,
            'answer': answer,
            'text': chat_text,
        })

    random.shuffle(records)
    print(f"Prepared {len(records)} examples")
    return records


def train_test_split(records: list[Dict[str, Any]], train_fraction: float = 0.9):
    """Cut ``records`` into a leading train slice and a trailing test slice.

    The cut index is ``len(records) * train_fraction`` truncated to an integer,
    but never less than 1, so a non-empty input always contributes at least one
    training record. Order is preserved; nothing is shuffled here.

    Returns:
        ``(train, test)`` as two list slices of ``records``.
    """
    cut = int(len(records) * train_fraction)
    if cut < 1:
        cut = 1
    train_part = records[:cut]
    held_out = records[cut:]
    return train_part, held_out


# Build the in-memory dataset (real file if one is found, synthetic fallback
# otherwise), then split it into train and held-out portions using the default
# 0.9 train fraction. The resulting names are notebook globals.
dataset = load_sft_dataset()
train_records, test_records = train_test_split(dataset)
print(f"Train size: {len(train_records)} | Test size: {len(test_records)}")

Dataset file not found; generating a tiny synthetic sample for smoke testing.
Prepared 2 examples
Train size: 1 | Test size: 1


In [4]:
# Peek at a formatted example
# The training split can be empty (e.g. the dataset file exists but has no
# records); guard the lookup so this preview cell reports that instead of
# raising IndexError.
sample = train_records[0] if train_records else None
if sample is None:
    print('No training examples available to preview.')
else:
    print('Question:', sample['question'])
    print('\nDual-stream formatted text:\n')
    print(sample['text'])

Question: What is 2 + 3?

Dual-stream formatted text:

<start_of_turn>user
What is 2 + 3?
<end_of_turn>
<start_of_turn>model
<reasoning>Plan: add the numbers. Reasoning: 2 + 3 = 5.</reasoning><answer>5</answer>
<end_of_turn>


In [5]:
# === Tunix / model setup ===
# The heavy Tunix/JAX stack is imported only for real runs. Any import failure
# downgrades to the stub path by clearing RUN_TUNIX rather than stopping the notebook.
# NOTE(review): confirm these module paths against the installed Tunix version;
# the failure reason printed below will show which import is rejected.
RUN_TUNIX = not SMOKE_TEST
try:
    if RUN_TUNIX:
        from tunix import PeftTrainer, TrainingConfig, MetricsLoggerOptions
        from tunix.models.gemma3 import model as gemma_lib
        from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
        from tunix.generate import tokenizer_adapter as tokenizer_lib
        import kagglehub
        import jax
        import flax.nnx as nnx
        import optax
        from tunix.generate import sampler as sampler_lib
        from tunix.cast import rl_cluster as rl_cluster_lib
        from tunix.cast import loss_lib
        from tunix.cast import actor
        from tunix.cast import utils as cast_utils
        print('Tunix imported successfully.')
    else:
        print('Skipping Tunix imports (smoke test).')
except Exception as exc:
    print('Tunix imports unavailable; continuing in smoke-test mode.')
    # Surface the cause (missing package, wrong module path, backend init error)
    # instead of discarding it, so the fallback is diagnosable.
    print(f'  Reason: {type(exc).__name__}: {exc}')
    RUN_TUNIX = False


Skipping Tunix imports (smoke test).


In [6]:
# === Training stub ===
if not RUN_TUNIX:
    # Dry run: nothing is downloaded or trained.
    print('Smoke test: skipping model download and training.')
else:
    # Placeholder only: reports what a real run would do. Kept minimal so
    # debugging iterations stay fast; the actual Tunix trainer setup belongs here.
    print('Starting placeholder Tunix training loop...')
    print(f"Would train for {MAX_STEPS} steps with batch size {TRAIN_MICRO_BATCH_SIZE}.")
    print('Dataset example count:', len(train_records))

print('Notebook setup complete.')


Smoke test: skipping model download and training.
Notebook setup complete.
