# Tunix Zero-Cost Training Pipeline
Supervised fine-tuning (SFT) followed by GRPO (Group Relative Policy Optimization) on public datasets.

In [None]:

# Install Tunix and dependencies
# In Kaggle, we might need to install from a dataset or git
!pip install -q git+https://github.com/google-deepmind/tunix.git
!pip install -q flax==0.12.0 optax==0.2.4 chex==0.1.88
!pip install -q transformers==4.47.0 datasets==3.2.0

import os
import jax
import flax
import optax
from tunix.sft import peft_trainer
from tunix.rl.grpo import grpo_learner
from transformers import AutoTokenizer
import datasets

# Sanity check: list the devices JAX can see before any training starts.
visible_devices = jax.devices()
print(f"JAX devices: {visible_devices}")


In [None]:

# Configuration
# Base model identifier: passed to the SFT trainer and used to load the
# tokenizer in the final inference check.
MODEL_ID = "google/gemma-2-2b-it"
# Directory holding the JSONL training files read by the SFT and GRPO cells
# (sft_magpie.jsonl and grpo_gsm8k_train.jsonl).
DATASET_PATH = "/kaggle/input/tunix-public-data" # Placeholder for uploaded dataset
# Checkpoint directories: the SFT output is the starting point for GRPO, and
# the GRPO output is what the final inference cell loads.
SFT_OUTPUT_DIR = "sft_checkpoint"
GRPO_OUTPUT_DIR = "grpo_checkpoint"

# Hyperparameters
# Number of optimisation steps per phase. The trailing notes are presumably
# wall-clock budgets for each phase — TODO confirm and tune the step counts.
SFT_STEPS = 500  # Adjust for 1.5h
GRPO_STEPS = 500 # Adjust for 5h


In [None]:

import re
import sympy

# --- Reward Functions ---
# Compiled once and shared by the reward functions. Each pattern captures the
# text between one tag pair: the lazy `.*?` stops at the first closing tag, and
# DOTALL (re.S) lets the captured text span several lines.
reasoning_pattern = re.compile(r"<reasoning>(.*?)</reasoning>", flags=re.S)
answer_pattern = re.compile(r"<answer>(.*?)</answer>", flags=re.S)

def structure_reward(prompts, completions, **kwargs):
    """Score each completion on whether it contains the expected output tags.

    A completion earns 0.5 when both ``<reasoning>`` and ``</reasoning>``
    occur in it and another 0.5 when both ``<answer>`` and ``</answer>``
    occur, so every score is 0.0, 0.5 or 1.0.

    Only the presence of the tags is tested; their order and nesting are not.

    Args:
        prompts: Unused; accepted to match the reward-function signature.
        completions: Iterable of generated completions.
        **kwargs: Ignored.

    Returns:
        A list of float scores, one per completion, in input order.
    """
    tag_pairs = (
        ("<reasoning>", "</reasoning>"),
        ("<answer>", "</answer>"),
    )
    return [
        sum(
            (0.5 for opening, closing in tag_pairs if opening in text and closing in text),
            0.0,  # float start value keeps the "no tags" score a float
        )
        for text in completions
    ]

# Plain decimal number: optional sign, digits with optional well-formed
# thousands separators, optional fractional part. Deliberately strict so that
# strings such as "1,2", "nan" or "1e3" are not treated as numbers.
_PLAIN_NUMBER = re.compile(r"[-+]?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?")


def _parse_number(text):
    """Return ``text`` as a float if it is a plain decimal number, else None.

    Surrounding whitespace, one leading ``$`` and well-formed thousands
    separators are tolerated (``"$1,250.50"`` parses as ``1250.5``).
    """
    candidate = text.strip()
    if candidate.startswith("$"):
        candidate = candidate[1:].strip()
    if not _PLAIN_NUMBER.fullmatch(candidate):
        return None
    return float(candidate.replace(",", ""))


def math_correctness_reward(prompts, completions, answer, **kwargs):
    """Score each completion 1.0 if its ``<answer>`` matches the reference, else 0.0.

    The prediction is the stripped text captured by ``answer_pattern`` (the
    first ``<answer>...</answer>`` pair). It is correct when it equals the
    stripped reference exactly, or when both parse as plain decimal numbers
    with the same value (so "72.0" matches "72" and "1,000" matches "1000").

    Args:
        prompts: Unused; accepted to match the reward-function signature.
        completions: Iterable of generated completion strings.
        answer: Iterable of reference answers, paired positionally with
            ``completions``. Non-string values (e.g. ints loaded from JSON)
            are compared via ``str()``; a ``None`` reference scores 0.0.
        **kwargs: Ignored.

    Returns:
        A list of float scores (0.0 or 1.0), one per (completion, answer) pair.
    """
    scores = []
    for completion, true_ans in zip(completions, answer):
        match = answer_pattern.search(completion)
        if match is None or true_ans is None:
            scores.append(0.0)
            continue
        pred = match.group(1).strip()
        # JSON-loaded references may be numbers rather than strings; coerce so
        # that .strip() cannot raise AttributeError.
        expected = str(true_ans).strip()

        if pred == expected:
            scores.append(1.0)
            continue

        # Numeric fallback: compare by value when both sides are plain numbers.
        pred_value = _parse_number(pred)
        expected_value = _parse_number(expected)
        same_value = (
            pred_value is not None
            and expected_value is not None
            and pred_value == expected_value
        )
        scores.append(1.0 if same_value else 0.0)
    return scores


In [None]:

# --- SFT Phase: Format Learning ---
print("Starting SFT Phase...")

# Load the supervised examples from the JSONL file under DATASET_PATH.
# Assumes each record is pre-formatted with a 'text' column — TODO confirm
# against the uploaded file.
sft_data_file = f"{DATASET_PATH}/sft_magpie.jsonl"
sft_dataset = datasets.load_dataset("json", data_files=sft_data_file, split="train")

# Trainer settings, collected in one place so they are easy to tune.
# NOTE(review): this is mock usage — the keyword names below have not been
# verified against the installed Tunix PeftTrainer signature; confirm before
# enabling training.
sft_trainer_kwargs = dict(
    model_name=MODEL_ID,
    train_dataset=sft_dataset,
    max_steps=SFT_STEPS,
    output_dir=SFT_OUTPUT_DIR,
    per_device_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    use_lora=True,
)
trainer = peft_trainer.PeftTrainer(**sft_trainer_kwargs)

# Training and checkpoint saving are currently disabled, so this cell does not
# write anything to SFT_OUTPUT_DIR.
# trainer.train() 
# trainer.save_model(SFT_OUTPUT_DIR)
print("SFT Completed (Mock)")


In [None]:

# --- GRPO Phase: Reinforcement ---
print("Starting GRPO Phase...")

# Load the GSM8K-derived training file from DATASET_PATH.
# Assumes 'prompt' and 'answer' columns — TODO confirm against the uploaded file.
grpo_data_file = f"{DATASET_PATH}/grpo_gsm8k_train.jsonl"
grpo_dataset = datasets.load_dataset("json", data_files=grpo_data_file, split="train")

# GRPO settings, collected in one place so they are easy to tune.
# NOTE(review): field names have not been verified against the installed
# Tunix GRPOConfig / GRPOLearner signatures; confirm before enabling training.
grpo_settings = dict(
    num_generations=4,
    beta=0.04,
    learning_rate=1e-6,
    max_prompt_length=256,
    max_completion_length=512,
)
grpo_config = grpo_learner.GRPOConfig(**grpo_settings)

# Both reward functions are handed to the learner: one for output structure,
# one for answer correctness.
reward_fns = [structure_reward, math_correctness_reward]

# The learner starts from the SFT checkpoint directory.
learner = grpo_learner.GRPOLearner(
    config=grpo_config,
    model_name_or_path=SFT_OUTPUT_DIR, # Load from SFT
    reward_functions=reward_fns,
    train_dataset=grpo_dataset,
)

# Training and checkpoint saving are currently disabled, so this cell does not
# write anything to GRPO_OUTPUT_DIR.
# learner.train(steps=GRPO_STEPS)
# learner.save_model(GRPO_OUTPUT_DIR)
print("GRPO Completed (Mock)")


In [None]:

# --- Final Inference Check ---
# Load the final model and tokenize a test prompt.
# NOTE(review): assumes GRPO_OUTPUT_DIR holds a checkpoint in a format that
# Hugging Face `from_pretrained` can read — confirm what the training step saves.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(GRPO_OUTPUT_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

prompt = "User: Solve 2x + 5 = 15.\nModel:"
# AutoModelForCausalLM is the PyTorch model class, so the tokenizer must return
# PyTorch tensors ("pt"); JAX arrays cannot be passed to its generate().
inputs = tokenizer(prompt, return_tensors="pt")
# Generation is currently disabled; enable the two lines below to print a sample.
# outputs = model.generate(**inputs, max_new_tokens=200)
# print(tokenizer.decode(outputs[0]))
