# 🎓 SFT Training - Supervised Fine-Tuning

Fine-tune Gemma to produce responses in the `<reasoning>...</reasoning><answer>...</answer>` format.

**Time estimate:** ~2-3 hours on Kaggle TPU

In [None]:
import os
import json
import time
import yaml
from datetime import datetime

import jax
import jax.numpy as jnp
from transformers import AutoTokenizer

# Sanity check: report how many accelerator devices JAX can see.
print("JAX devices:", jax.device_count())

## 1. Load Configuration

In [None]:
# SFT Configuration: hyperparameters read by the later cells.
CFG = dict(
    model_name='google/gemma-3-1b-it',
    learning_rate=1e-5,
    batch_size=4,
    grad_accum=8,        # effective batch = batch_size * grad_accum
    num_epochs=2,
    max_length=1024,
    warmup_ratio=0.03,
    save_minutes=30,     # wall-clock interval between timed checkpoints
    lora_r=16,
    lora_alpha=32,
    seed=42,
)

print("üìã Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CFG.items()))

## 2. Load Tokenizer & Model

In [None]:
# Hugging Face tokenizer for the configured checkpoint.
tokenizer = AutoTokenizer.from_pretrained(CFG['model_name'])
print("‚úÖ Tokenizer loaded:", CFG['model_name'])

In [None]:
# Load model with Tunix
# NOTE: Adjust imports based on actual Tunix API
# Only the import is guarded. Previously from_pretrained also sat inside the
# try, so an ImportError raised while *loading* the model (for example a
# missing optional dependency) was misreported as "Tunix not available" and
# model was silently left as None.
try:
    from tunix import modeling
except ImportError:
    print("‚ö†Ô∏è Tunix not available - using placeholder")
    model = None
else:
    model = modeling.Gemma.from_pretrained(CFG['model_name'])
    print("‚úÖ Model loaded via Tunix")

## 3. Load Training Data

In [None]:
def load_jsonl(path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file with one JSON value per line.

    Returns:
        list: The decoded values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.

    Blank or whitespace-only lines (e.g. an extra newline at the end of the
    file) are skipped; previously they raised ``JSONDecodeError``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # json.loads tolerates the surrounding whitespace/newline itself.
        return [json.loads(line) for line in f if line.strip()]

# Pre-tokenized splits (evaluated left to right: train first, then valid).
train_data, val_data = (
    load_jsonl('data/tokenized/train.jsonl'),
    load_jsonl('data/tokenized/valid.jsonl'),
)

print(f"‚úÖ Loaded {len(train_data)} train, {len(val_data)} val examples")

## 4. Create Data Batches

In [None]:
import numpy as np

def create_batches(data, batch_size, shuffle=True, *, pad_id=None, rng=None, max_length=None):
    """Group tokenized examples into right-padded JAX batches.

    Args:
        data: Sequence of dicts whose ``'input_ids'`` and ``'attention_mask'``
            entries are lists (they are concatenated with lists below).
        batch_size: Maximum examples per batch; the last batch may be smaller.
        shuffle: Shuffle example order before batching. A shuffled *copy* is
            used; the caller's sequence is left untouched.
        pad_id: Token id used to pad ``input_ids``. Defaults to the global
            ``tokenizer.pad_token_id``, or 0 when that is ``None``.
        rng: Object with a ``shuffle`` method (e.g.
            ``np.random.default_rng(seed)``) used when ``shuffle`` is True.
            Defaults to NumPy's global RNG, as before.
        max_length: If given, truncate each example's ids and mask to at most
            this many tokens before padding (e.g. ``CFG['max_length']``).

    Returns:
        list of dicts with ``'input_ids'`` and ``'attention_mask'`` arrays of
        shape ``(examples_in_batch, longest_example_in_batch)``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    examples = list(data)
    if shuffle:
        # Shuffle a copy: shuffling `data` in place used to silently reorder
        # the caller's list (e.g. the global train_data on every epoch).
        (np.random if rng is None else rng).shuffle(examples)

    # Resolve the pad id once; only touch the global tokenizer when needed.
    if pad_id is None and examples:
        pad_id = tokenizer.pad_token_id or 0

    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        # Slicing with max_length=None keeps the full sequence.
        seqs = [(ex['input_ids'][:max_length], ex['attention_mask'][:max_length]) for ex in chunk]

        # Pad to the longest example in this batch (right padding).
        longest = max(len(ids) for ids, _ in seqs)
        input_ids = [ids + [pad_id] * (longest - len(ids)) for ids, _ in seqs]
        attention_mask = [mask + [0] * (longest - len(ids)) for ids, mask in seqs]

        batches.append({
            'input_ids': jnp.array(input_ids),
            'attention_mask': jnp.array(attention_mask),
        })

    return batches

# Smoke-test batching/padding on a small slice before the full run.
test_batches = create_batches(train_data[:16], CFG['batch_size'])
print("‚úÖ Created", len(test_batches), "test batches")
print("   Batch shape:", test_batches[0]['input_ids'].shape)

## 5. Training Loop

In [None]:
# Training state: globals read/updated by train_epoch (step count, timers).
start_time = last_save_time = time.time()
global_step, train_losses = 0, []

os.makedirs('checkpoints/sft', exist_ok=True)

In [None]:
# SFT Training Loop (skeleton)
# NOTE: Adapt to actual Tunix trainer API

def train_epoch(epoch):
    """Run one pass over ``train_data`` and return the epoch's mean batch loss.

    Args:
        epoch: Zero-based epoch index. It is added to ``CFG['seed']`` to seed
            NumPy's global RNG, which ``create_batches`` uses for shuffling.
            The data order is therefore reproducible across runs but still
            differs between epochs.

    Returns:
        float: Mean of this epoch's per-batch losses, or ``nan`` when
        ``train_data`` yields no batches (previously a ZeroDivisionError).

    Side effects:
        Advances the global ``global_step`` once per batch and appends every
        batch loss to the global ``train_losses``. Logs a rolling average
        every 50 steps. Calls ``save_checkpoint`` once more than
        ``CFG['save_minutes']`` minutes have passed since ``last_save_time``.
    """
    global global_step, last_save_time

    # CFG['seed'] was never applied, so shuffles were not reproducible.
    # Seed per epoch (not once) so each epoch still sees a different order.
    np.random.seed(CFG['seed'] + epoch)
    batches = create_batches(train_data, CFG['batch_size'])
    epoch_losses = []

    for batch in batches:
        # === Training step ===
        # TODO: replace the placeholder with the real Tunix step, e.g.
        #   loss = trainer.train_step(batch)
        # NOTE(review): global_step counts micro-batches; CFG['grad_accum'] is
        # not applied here yet. Confirm how the Tunix trainer accumulates.
        loss = 0.5  # Placeholder

        epoch_losses.append(loss)
        train_losses.append(loss)  # run-wide history; it was declared but never filled
        global_step += 1

        # Rolling average over the last (up to) 50 batches of this epoch.
        if global_step % 50 == 0:
            window = epoch_losses[-50:]
            avg_loss = sum(window) / len(window)
            elapsed = (time.time() - start_time) / 60
            print(f"Step {global_step} | Loss: {avg_loss:.4f} | Time: {elapsed:.1f}m")

        # Checkpoint on wall-clock time rather than on step count.
        if (time.time() - last_save_time) > CFG['save_minutes'] * 60:
            save_checkpoint(f"sft_step_{global_step}")
            last_save_time = time.time()

    if not epoch_losses:
        return float('nan')
    return sum(epoch_losses) / len(epoch_losses)

def save_checkpoint(name):
    """Create ``checkpoints/sft/<name>`` and record run metadata there.

    ``metadata.json`` holds the current global ``global_step``, the local
    time in ISO-8601 form and the full ``CFG``. Model weights are not written
    yet (the save call is still a placeholder). Prints the checkpoint path.
    """
    ckpt_dir = f"checkpoints/sft/{name}"
    os.makedirs(ckpt_dir, exist_ok=True)

    metadata = dict(step=global_step, time=datetime.now().isoformat(), config=CFG)
    with open(f"{ckpt_dir}/metadata.json", 'w') as fh:
        fh.write(json.dumps(metadata, indent=2))

    # TODO: persist weights once the Tunix model is wired in, e.g.
    #   model.save_pretrained(ckpt_dir)

    print(f"üíæ Saved checkpoint: {ckpt_dir}")

In [None]:
# Run training: one train_epoch() pass per epoch, plus an unconditional
# checkpoint at each epoch boundary (timed saves happen inside train_epoch).
print("üöÄ Starting SFT Training...")
print(f"   Epochs: {CFG['num_epochs']}")
print(f"   Batch size: {CFG['batch_size']} x {CFG['grad_accum']} = {CFG['batch_size'] * CFG['grad_accum']}")
print()

for epoch in range(CFG['num_epochs']):
    epoch_num = epoch + 1  # 1-based label for logs and checkpoint names
    print(f"=== Epoch {epoch_num}/{CFG['num_epochs']} ===")
    avg_loss = train_epoch(epoch)
    print(f"Epoch {epoch_num} avg loss: {avg_loss:.4f}")
    save_checkpoint(f"sft_epoch_{epoch_num}")

print("\n‚úÖ SFT Training complete!")

## 6. Quick Evaluation

In [None]:
# Test generation: a hand-written arithmetic prompt for a quick qualitative check.
test_prompt = """Q: A store has 45 apples. If they sell 12 apples, how many are left?
A:
"""

# TODO: replace the canned text below with a real generation, e.g.
#   output = model.generate(test_prompt, max_length=256)
output = """<reasoning>
Step 1: Start with 45 apples
Step 2: Subtract 12 sold apples
Step 3: 45 - 12 = 33
</reasoning>
<answer>33 apples</answer>"""  # Placeholder

print("üìù Test generation:", test_prompt, output, sep="\n")

In [None]:
# Check format compliance
import re

def check_format(text, *, require_order=True):
    """Check that *text* contains ``<reasoning>...</reasoning>`` and ``<answer>...</answer>``.

    Args:
        text: Model output to check.
        require_order: When True (default), the reasoning block must come
            before the answer block, which is the format this notebook trains
            for. Previously the two tags were searched independently, so an
            answer-before-reasoning output was accepted. Pass False to get
            that order-insensitive check back.

    Returns:
        bool: True if both blocks are present, in order when required.
    """
    if require_order:
        # One pattern, so the closing </reasoning> must precede <answer>.
        pattern = r'<reasoning>.*?</reasoning>.*?<answer>.*?</answer>'
        return re.search(pattern, text, re.DOTALL) is not None
    has_reasoning = bool(re.search(r'<reasoning>.*</reasoning>', text, re.DOTALL))
    has_answer = bool(re.search(r'<answer>.*</answer>', text, re.DOTALL))
    return has_reasoning and has_answer

print("Format compliant:", check_format(output))  # bool prints as True/False

In [None]:
# Final run summary. start_time is set in the training-state cell, so this
# elapsed time also covers any cells run since (e.g. the evaluation cells).
total_time = (time.time() - start_time) / 60
print("\n" + "=" * 50)
print("SFT TRAINING COMPLETE")
print("=" * 50)
print(f"Total time: {total_time:.1f} minutes")
print(f"Steps: {global_step}")
print("Checkpoints saved to: checkpoints/sft/")
print("\n‚û°Ô∏è Proceed to: 03_rl_grpo_training.ipynb")