# Step 1: Baseline SFT Training

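This notebook sets up supervised fine-tuning (SFT) of `google/gemma-2b` (or a similar model) on the constraint optimization dataset using Tunix. The Tunix API calls are simulated, so the final training call is left commented out.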

In [None]:
# Standard-library modules, used only to extend the import search path.
import sys
import os

# Add src to path.
# os.path.abspath resolves "../src" against the current working directory,
# so this depends on where the notebook kernel was started.
# This must run before the imports below that rely on the extended path.
sys.path.append(os.path.abspath("../src"))

# Presumably project-local modules found via the path added above — confirm.
from data_loader import OptimizationDataset
from format_utils import format_input
# NOTE(review): the Tunix import paths below are marked hypothetical by the
# author; verify them against the installed Tunix version.
import tunix
from tunix.trainer import SFTTrainer  # Hypothetical API based on prompt
from tunix.models import GemmaModel

print("Libraries loaded.")

In [None]:
# 1. Load Data
# Construct the dataset (requested with size=200) and report how many
# examples it exposes via len().
dataset = OptimizationDataset(size=200)
print(f"Loaded {len(dataset)} examples.")
print("Example 0:")
# Show the first example's fields; the dataset is indexed once per field,
# in the same order the values are printed.
for field in ("problem", "target"):
    print(dataset[0][field])

In [None]:
# 2. Prepare for Training
# Convert to Tunix format (usually expects a prompt, completion pair).
# Each dataset item yields one record: the formatted problem is the prompt
# and the item's target is the completion.
train_data = [
    {
        "prompt": format_input(item["problem"]),
        "completion": item["target"],
    }
    for item in dataset
]

# 3. Configure Trainer
# Collect the SFT settings in one mapping, then expand them as keyword
# arguments so the config receives exactly the same values.
sft_settings = {
    "model_name": "google/gemma-2b",
    "batch_size": 8,
    "learning_rate": 2e-5,
    "num_epochs": 1,
    "output_dir": "../checkpoints/sft_baseline",
}
config = tunix.Config(**sft_settings)

# Note: Tunix API is simulated here, so the trainer calls stay commented out.
# trainer = SFTTrainer(config=config, train_data=train_data)
# trainer.train()

print("Training configuration ready. Run trainer.train() when ready.")