# Step 1: Baseline SFT Training

Fine-tunes `google/gemma-2b` with LoRA adapters on the constraint-optimization dataset using Tunix (JAX/Flax). This supervised fine-tuning (SFT) run is the baseline for later steps.

In [None]:
# Standard library and JAX imports used by the notebook.
import sys
import os
import jax
import jax.numpy as jnp
from typing import Dict, List

# Make the project's helper modules importable. The path is resolved relative
# to the current working directory, so the notebook must be run from a
# directory whose sibling is `src/`. The entry is appended (not inserted), so
# any already-importable module with the same name takes precedence.
sys.path.append(os.path.abspath("../src"))

# Project-local helpers (presumably located in ../src — must come after the
# sys.path tweak above).
from data_loader import OptimizationDataset, DatasetEntry
from format_utils import format_input

# NOTE(review): confirm these module paths (`tunix.config`, `tunix.trainer`,
# `tunix.data`) exist in the installed Tunix release — TODO verify.
import tunix
from tunix.config import TrainerConfig, ModelConfig, OptimizerConfig
from tunix.trainer import SFTTrainer
from tunix.data import Dataset as TunixDataset

# Show which accelerators JAX can see before any training is configured.
print(f"JAX Devices: {jax.devices()}")

In [None]:
# Instantiate the project dataset, requesting 500 examples (size=500), and
# report how many entries it actually exposes.
dataset = OptimizationDataset(size=500)
print(f"Loaded {len(dataset)} examples.")

In [None]:
def prepare_data(data_loader: OptimizationDataset) -> List[Dict[str, str]]:
    """Convert dataset entries into prompt/response records for SFT.

    Args:
        data_loader: Iterable dataset whose entries expose ``'problem'`` and
            ``'target'`` by key.

    Returns:
        One dict per entry, in iteration order, with ``"prompt"`` set to
        ``format_input(entry['problem'])`` and ``"response"`` set to
        ``entry['target']`` unchanged.
    """
    return [
        {
            "prompt": format_input(record["problem"]),
            "response": record["target"],
        }
        for record in data_loader
    ]

# Turn every dataset entry into a {"prompt", "response"} dict, then wrap the
# resulting list in a Tunix dataset; `train_ds` is what the trainer consumes.
raw_data = prepare_data(dataset)
train_ds = TunixDataset.from_list(raw_data)
print(f"Prepared Tunix Dataset with {len(train_ds)} items.")

In [None]:
# Batch geometry and epoch count are named once so the schedule arithmetic
# below and the TrainerConfig cannot drift apart.
per_device_batch_size = 4
grad_accum_steps = 4
num_epochs = 3

# Estimate how many optimizer updates the whole run performs.
# Assumes Tunix counts scheduler steps in optimizer updates (after gradient
# accumulation) rather than in micro-batches — TODO confirm.
examples_per_update = per_device_batch_size * grad_accum_steps * jax.device_count()
updates_per_epoch = -(-len(train_ds) // examples_per_update)  # ceiling division
total_updates = max(1, updates_per_epoch * num_epochs)

# A fixed 100-step warmup is longer than the entire run on the 500-example
# dataset (about 96 updates on one device, fewer on more), so the peak
# learning rate would never be reached and the cosine phase would have no
# steps left. Warm up for ~10% of the run instead, keeping the original
# 100-step value as the ceiling for larger datasets.
warmup_steps = min(100, max(1, total_updates // 10))

# Base model plus LoRA adapter settings.
model_config = ModelConfig(
    base_model="google/gemma-2b",
    dtype="bfloat16",
    use_flash_attention=True,
    lora_rank=8,
    lora_alpha=32,
    lora_dropout=0.1
)

# Optimizer and learning-rate schedule.
optimizer_config = OptimizerConfig(
    learning_rate=2e-5,
    scheduler_type="cosine",
    warmup_steps=warmup_steps,
    weight_decay=0.01
)

# Training-loop settings.
# NOTE(review): save_steps=100 may exceed the total number of updates on small
# datasets, in which case no intermediate checkpoint would be written — confirm
# this is intended.
# NOTE(review): eval_steps is set, but this notebook passes no evaluation
# dataset to the trainer — confirm whether evaluation is expected to run.
trainer_config = TrainerConfig(
    output_dir="../checkpoints/sft_baseline",
    num_epochs=num_epochs,
    per_device_train_batch_size=per_device_batch_size,
    gradient_accumulation_steps=grad_accum_steps,
    max_seq_length=1024,
    logging_steps=10,
    save_steps=100,
    eval_steps=50,
    save_total_limit=2,
    seed=42
)

In [None]:
# Final export location for the fine-tuned model.
final_model_dir = "../models/constraint-reasoner-v1"

# Create the parent of the export path before training starts, so a missing
# or unwritable directory is discovered now rather than after the whole run
# has finished. Only the parent is created; the export directory itself is
# left for save_model to create.
os.makedirs(os.path.dirname(final_model_dir), exist_ok=True)

# Wire the model, optimizer and loop settings together with the training data.
trainer = SFTTrainer(
    model_config=model_config,
    trainer_config=trainer_config,
    optimizer_config=optimizer_config,
    train_dataset=train_ds,
)

# Run supervised fine-tuning.
trainer.train()

# Persist the trained model to the export location.
trainer.save_model(final_model_dir)
print("Training complete and model saved.")