# Step 1: Supervised Fine-Tuning (SFT) Baseline

This notebook trains `google/gemma-2b` on synthetic constraint optimization problems using Tunix (JAX/Flax).

**Goal**: Teach the model to output structured reasoning traces with formal certificates.

**Prerequisites**:
- Run `00_env_check.ipynb` first to verify the environment
- Ensure `pip install -e .` has been run so that the `src` package is installed


In [None]:
import os
import jax
import jax.numpy as jnp
from typing import Dict, List

# Import from installed package (no sys.path hacks!)
from src.data_loader import OptimizationDataset
from src.format_utils import format_input
from src.config import config

# Announce the notebook step, then report the JAX devices and backend in use.
print("=" * 60, "STEP 1: SUPERVISED FINE-TUNING (SFT)", "=" * 60, sep="\n")
print("JAX Devices:", jax.devices())
print("JAX Backend:", jax.default_backend())
print()

# Import Tunix (may not be available in all environments).
#
# Only the imports live inside the ``try`` so that nothing other than an
# import failure is interpreted as "Tunix is missing". TUNIX_AVAILABLE is the
# flag every later cell checks before touching Tunix names.
#
# NOTE(review): the module paths below (tunix.config, tunix.trainer,
# tunix.data) should be confirmed against the installed Tunix release; a
# mismatch raises ImportError and is reported here as "not available".
try:
    import tunix
    from tunix.config import TrainerConfig, ModelConfig, OptimizerConfig
    from tunix.trainer import SFTTrainer
    from tunix.data import Dataset as TunixDataset
except ImportError as e:
    print(f"⚠️  Tunix not available: {e}")
    print("This notebook requires Tunix. Install with: pip install google-tunix[prod]")
    TUNIX_AVAILABLE = False
else:
    # Not every package exposes __version__; fall back instead of letting an
    # AttributeError abort the cell after a successful import.
    print(f"✓ Tunix version: {getattr(tunix, '__version__', 'unknown')}")
    TUNIX_AVAILABLE = True

print()

## 1. Generate Training Data

Generate synthetic knapsack problems with ground-truth reasoning traces.


In [None]:
# Number of synthetic examples to generate. The value is set directly in this
# cell; raise it to 5000+ for production training.
DATASET_SIZE = 500

print(f"Generating {DATASET_SIZE} training examples...")
dataset = OptimizationDataset(size=DATASET_SIZE)
print(f"✓ Generated {len(dataset)} examples", end="\n\n")

# Preview the first example: a truncated problem statement followed by a
# truncated target, each section closed by a blank line.
sample = dataset[0]
print("Sample Problem:", sample['problem'][:200] + "...", "", sep="\n")
print("Sample Target (first 300 chars):", sample['target'][:300] + "...", "", sep="\n")


## 2. Prepare Data for Tunix

Convert our dataset format to Tunix's expected format (prompt/response pairs).


In [None]:
def prepare_data(data_loader: OptimizationDataset) -> List[Dict[str, str]]:
    """
    Build the prompt/response records that are handed to Tunix.

    Every entry yielded by ``data_loader`` produces one record: its
    'problem' field is run through ``format_input`` to form the prompt, and
    its 'target' field is used unchanged as the response.

    Args:
        data_loader: Our custom dataset; iterated once, each entry is read
            via its 'problem' and 'target' keys.

    Returns:
        List of dicts with 'prompt' and 'response' keys, in iteration order.
    """
    return [
        {"prompt": format_input(entry['problem']), "response": entry['target']}
        for entry in data_loader
    ]

# Build the Tunix training dataset only when Tunix was imported successfully;
# otherwise report that the step is skipped.
if not TUNIX_AVAILABLE:
    print("⚠️  Skipping data preparation (Tunix not available)")
else:
    raw_data = prepare_data(dataset)
    train_ds = TunixDataset.from_list(raw_data)
    print(f"✓ Prepared Tunix Dataset with {len(train_ds)} items")

print()

## 3. Configure Training

Set up model, optimizer, and trainer configurations.

- **Model**: Gemma-2b with LoRA (efficient fine-tuning)
- **Optimizer**: AdamW with cosine schedule
- **Training**: 3 epochs with gradient accumulation


In [None]:
# Build the three Tunix config objects (model, optimizer, trainer) consumed by
# the training cell. Nothing is created when Tunix is unavailable.
if not TUNIX_AVAILABLE:
    print("⚠️  Skipping configuration (Tunix not available)")
else:
    # Base model plus LoRA adapter hyperparameters.
    model_config = ModelConfig(
        base_model="google/gemma-2b",
        dtype="bfloat16",
        use_flash_attention=True,
        lora_rank=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    # Learning-rate schedule and regularisation.
    # NOTE(review): with the default 500 examples, 3 epochs and a per-device
    # effective batch of 16, a single-device run has roughly 94 optimiser
    # updates. If Tunix counts steps that way, warmup_steps=100 would never
    # finish warming up — confirm how steps are counted.
    optimizer_config = OptimizerConfig(
        learning_rate=2e-5,
        scheduler_type="cosine",
        warmup_steps=100,
        weight_decay=0.01,
    )

    # Run length, batching, logging/checkpoint cadence and RNG seed.
    trainer_config = TrainerConfig(
        output_dir="../checkpoints/sft_baseline",
        num_epochs=3,  # Increase to 5-10 for production
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # 4 x 4 = 16 examples per device per update
        max_seq_length=1024,
        logging_steps=10,
        save_steps=100,
        eval_steps=50,  # No eval dataset is passed to the trainer in this notebook
        save_total_limit=2,
        seed=42,
    )

    # Summary of the key settings; the batch size shown is per device.
    print(
        "✓ Training configuration complete",
        f"  - Base model: {model_config.base_model}",
        f"  - Epochs: {trainer_config.num_epochs}",
        f"  - Effective batch size: {trainer_config.per_device_train_batch_size * trainer_config.gradient_accumulation_steps}",
        f"  - Learning rate: {optimizer_config.learning_rate}",
        sep="\n",
    )

print()


## 4. Train the Model

Run supervised fine-tuning. This will take several hours on TPU, longer on GPU/CPU.

**Note**: On Kaggle, make sure a GPU/TPU accelerator is enabled in the notebook settings.


In [None]:
# Run SFT end to end: build the trainer, train, save the model, and print the
# follow-up steps. Any failure is reported and then re-raised so the cell
# still shows the full traceback.
if not TUNIX_AVAILABLE:
    print("⚠️  Cannot train without Tunix")
    print("Install Tunix with: pip install google-tunix[prod]")
else:
    try:
        print(
            "Starting SFT training...",
            "This may take several hours depending on hardware.",
            "",
            sep="\n",
        )

        trainer = SFTTrainer(
            model_config=model_config,
            trainer_config=trainer_config,
            optimizer_config=optimizer_config,
            train_dataset=train_ds,
        )
        trainer.train()

        # Make sure the parent directory exists before saving the model.
        model_save_path = "../models/constraint-reasoner-v1"
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        trainer.save_model(model_save_path)

        print(
            "",
            "=" * 60,
            "✓ TRAINING COMPLETE",
            "=" * 60,
            f"Model saved to: {model_save_path}",
            "",
            "Next steps:",
            "  1. Run 02_verify_and_export.ipynb to test the model",
            "  2. Run 03_train_grpo.ipynb for RL optimization (optional)",
            sep="\n",
        )
    except Exception as exc:
        print(f"✗ Training failed: {exc}")
        print("Check the error message above for details.")
        raise