# SFT Training - Interactive Configuration Notebook

This notebook allows you to configure and run SFT training **without any YAML files**!

## Benefits

✅ No external YAML files needed  
✅ Interactive configuration in separate cells  
✅ Easy to modify and experiment  
✅ All configuration visible in notebook  
✅ Quick templates for common scenarios

## Step 1: Import Dependencies

In [None]:
import asyncio
import logging
from omegaconf import OmegaConf, DictConfig

from forge.apps.sft_v2.trainer_actor import TrainerActor
from forge.apps.sft_v2.spawn_actor import SpawnActor, run_actor

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

## Step 2: Configure Model Settings

Define your model configuration. **Modify these values as needed!**

In [None]:
model_config = {
    "name": "llama3",
    "flavor": "8B",
    "hf_assets_path": "/tmp/Meta-Llama-3.1-8B-Instruct"
}

print("Model Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(model_config)))

## Step 3: Configure Process Settings

Define how many processes to use and whether to use GPUs.

In [None]:
processes_config = {
    "procs": 8,        # Number of processes
    "with_gpus": True  # Use GPUs
}

print("Process Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(processes_config)))

## Step 4: Configure Optimizer Settings

In [None]:
optimizer_config = {
    "name": "AdamW",
    "lr": 1e-5,    # Learning rate
    "eps": 1e-8
}

print("Optimizer Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(optimizer_config)))

## Step 5: Configure Learning Rate Scheduler

In [None]:
lr_scheduler_config = {
    "warmup_steps": 200  # Number of warmup steps
}

print("LR Scheduler Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(lr_scheduler_config)))
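
To get a feel for what `warmup_steps` does, here is a minimal sketch of a linear warmup. It assumes the learning rate ramps linearly from 0 to `lr` and then stays constant. The notebook does not say what the trainer's scheduler does after warmup (it may decay the rate), and `warmup_lr` is a hypothetical helper, not part of the trainer.

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Learning rate at `step` under linear warmup, held constant afterwards.

    Assumption: the real scheduler may decay the rate after warmup.
    """
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, step / warmup_steps)


base_lr = 1e-5       # optimizer_config["lr"]
warmup_steps = 200   # lr_scheduler_config["warmup_steps"]

for step in (0, 50, 100, 200, 1000):
    print(f"step {step:>4}: lr = {warmup_lr(step, base_lr, warmup_steps):.2e}")
```

With the values above, the rate reaches half of `lr` at step 100 and the full `lr` at step 200.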

## Step 6: Configure Training Settings

**Key parameters to adjust for your experiment:**

In [None]:
training_config = {
    "local_batch_size": 1,  # Batch size per GPU
    "seq_len": 2048,         # Sequence length
    "max_norm": 1.0,         # Gradient clipping
    "steps": 1000,           # Total training steps
    "compile": False,        # PyTorch compilation
    "dataset": "c4"          # Dataset name
}

print("Training Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(training_config)))
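
Because `local_batch_size` is per GPU, the effective batch depends on how many data-parallel ranks you run. The arithmetic below is a rough sketch. It assumes all 8 processes from Step 3 act as data-parallel ranks and that there is no gradient accumulation. `tokens_per_step` is a hypothetical helper.

```python
def tokens_per_step(local_batch_size: int, seq_len: int, data_parallel_ranks: int) -> int:
    """Tokens consumed by one optimizer step across all data-parallel ranks."""
    return local_batch_size * data_parallel_ranks * seq_len


local_batch_size = 1     # training_config["local_batch_size"]
seq_len = 2048           # training_config["seq_len"]
steps = 1000             # training_config["steps"]
data_parallel_ranks = 8  # assumption: every process is a data-parallel rank

global_batch = local_batch_size * data_parallel_ranks
per_step = tokens_per_step(local_batch_size, seq_len, data_parallel_ranks)
total = per_step * steps

print(f"global batch size: {global_batch} sequences")  # 8
print(f"tokens per step:   {per_step:,}")              # 16,384
print(f"tokens in the run: {total:,}")                 # 16,384,000
```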

## Step 7: Configure Parallelism Settings

In [None]:
parallelism_config = {
    "data_parallel_replicate_degree": 1,
    "data_parallel_shard_degree": -1,  # -1: FSDP shards across all processes not claimed by the other degrees
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "context_parallel_degree": 1,
    "expert_parallel_degree": 1,
    "disable_loss_parallel": False
}

print("Parallelism Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(parallelism_config)))
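
The sketch below shows one way a shard degree of `-1` could be resolved. It assumes the usual convention that the replicate, shard, tensor, pipeline, and context degrees multiply to the number of processes, and it leaves expert parallelism out of the product. `resolve_shard_degree` is a hypothetical helper, and the trainer's own validation may differ.

```python
def resolve_shard_degree(world_size: int, parallelism: dict) -> int:
    """Return the effective FSDP shard degree, inferring it when set to -1."""
    other = (
        parallelism["data_parallel_replicate_degree"]
        * parallelism["tensor_parallel_degree"]
        * parallelism["pipeline_parallel_degree"]
        * parallelism["context_parallel_degree"]
    )
    shard = parallelism["data_parallel_shard_degree"]
    if shard == -1:
        if world_size % other != 0:
            raise ValueError(f"{world_size} processes cannot be split by degree product {other}")
        return world_size // other
    if shard * other != world_size:
        raise ValueError(f"degrees multiply to {shard * other}, expected {world_size}")
    return shard


defaults = {
    "data_parallel_replicate_degree": 1,
    "data_parallel_shard_degree": -1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "context_parallel_degree": 1,
}

print(resolve_shard_degree(8, defaults))  # 8

with_tp = dict(defaults, tensor_parallel_degree=2)
print(resolve_shard_degree(8, with_tp))   # 4
```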

## Step 8: Configure Checkpoint Settings

In [None]:
checkpoint_config = {
    "enable": True,
    "folder": "/tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints",
    "initial_load_path": "/tmp/Meta-Llama-3.1-8B-Instruct/",
    "initial_load_in_hf": True,
    "last_save_in_hf": True,
    "interval": 500,           # Save every N steps
    "async_mode": "disabled"
}

print("Checkpoint Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(checkpoint_config)))
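
To see how `interval` interacts with the number of training steps, the sketch below lists the steps at which a checkpoint would be written. It assumes a save at every multiple of `interval` plus one at the final step. That schedule is an assumption, not confirmed trainer behavior, and `checkpoint_steps` is a hypothetical helper.

```python
from typing import List


def checkpoint_steps(total_steps: int, interval: int) -> List[int]:
    """Steps at which a checkpoint is assumed to be written."""
    if total_steps < 1 or interval < 1:
        raise ValueError("total_steps and interval must be positive")
    steps = list(range(interval, total_steps + 1, interval))
    # Assumption: the final step is always saved, even off the interval grid.
    if not steps or steps[-1] != total_steps:
        steps.append(total_steps)
    return steps


print(checkpoint_steps(1000, 500))  # [500, 1000]
print(checkpoint_steps(1000, 300))  # [300, 600, 900, 1000]
```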

## Step 9: Configure Activation Checkpointing

In [None]:
activation_checkpoint_config = {
    "mode": "selective",
    "selective_ac_option": "op"
}

print("Activation Checkpoint Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(activation_checkpoint_config)))

## Step 10: Configure Communication Settings

In [None]:
comm_config = {
    "trace_buf_size": 0
}

print("Communication Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(comm_config)))

## Step 11: Combine All Configurations

Now let's merge everything into a complete configuration!

In [None]:
# Combine all configs
complete_config = {
    "comm": comm_config,
    "model": model_config,
    "processes": processes_config,
    "optimizer": optimizer_config,
    "lr_scheduler": lr_scheduler_config,
    "training": training_config,
    "parallelism": parallelism_config,
    "checkpoint": checkpoint_config,
    "activation_checkpoint": activation_checkpoint_config
}

# Create OmegaConf DictConfig
cfg = OmegaConf.create(complete_config)

print("=" * 80)
print("COMPLETE CONFIGURATION")
print("=" * 80)
print(OmegaConf.to_yaml(cfg))
print("=" * 80)
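
Since the sections are edited in separate cells, it is easy to end up with values that contradict each other. A few cross-section sanity checks before launching can catch this. The rules below are illustrative assumptions rather than checks the trainer is known to perform, and `validate_config` is a hypothetical helper.

```python
from typing import List


def validate_config(config: dict) -> List[str]:
    """Return a list of human-readable problems; an empty list means none were found."""
    problems = []
    training = config["training"]
    checkpoint = config["checkpoint"]

    if config["processes"]["procs"] < 1:
        problems.append("processes.procs must be at least 1")
    if training["local_batch_size"] < 1 or training["seq_len"] < 1:
        problems.append("training.local_batch_size and training.seq_len must be positive")
    if config["lr_scheduler"]["warmup_steps"] >= training["steps"]:
        problems.append("lr_scheduler.warmup_steps should be smaller than training.steps")
    if checkpoint["enable"] and checkpoint["interval"] > training["steps"]:
        problems.append("checkpoint.interval is larger than training.steps")
    return problems


# Trimmed-down stand-in for the complete configuration built above.
sample = {
    "processes": {"procs": 8, "with_gpus": True},
    "lr_scheduler": {"warmup_steps": 200},
    "training": {"local_batch_size": 1, "seq_len": 2048, "steps": 1000},
    "checkpoint": {"enable": True, "interval": 500},
}

for problem in validate_config(sample) or ["no problems found"]:
    print(problem)
```

The helper only uses key access, so it should also accept the `cfg` object, because a `DictConfig` supports `cfg["training"]["steps"]`.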

## Step 12: Run Training (Simple Way)

The simplest option: `run_actor` manages the actor lifecycle for you, so you do not have to call each phase yourself.

In [None]:
# Run training with automatic lifecycle management
await run_actor(TrainerActor, cfg)

## Alternative: Manual Lifecycle Control

For more control, manage each phase separately.

### Create and Spawn the Actor

In [None]:
# Create the spawner
spawner = SpawnActor(TrainerActor, cfg)

# Spawn the actor
actor = await spawner.spawn()
print(f"✓ Actor spawned: {actor}")

### Setup the Actor

In [None]:
# Setup (load data, checkpoints, etc.)
await spawner.setup()
print("✓ Actor setup complete")

### Run Training

In [None]:
# Run training
await spawner.run()
print("✓ Training complete")

### Cleanup

In [None]:
# Cleanup resources
await spawner.cleanup()
print("✓ Cleanup complete")
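
When the phases run in separate cells, a failure during setup or training can leave the cleanup cell unexecuted. One way to guard against that is to wrap the phases in `try`/`finally`. The sketch below uses only the four methods shown above (`spawn`, `setup`, `run`, `cleanup`). `run_with_cleanup` and `FakeSpawner` are hypothetical names, and whether `cleanup()` is safe to call after a failed `setup()` is an assumption.

```python
import asyncio


async def run_with_cleanup(spawner) -> None:
    """Spawn, set up, and run; call cleanup() even if setup or run raises."""
    await spawner.spawn()
    try:
        await spawner.setup()
        await spawner.run()
    finally:
        await spawner.cleanup()


class FakeSpawner:
    """Hypothetical stand-in for SpawnActor, used only to demonstrate the flow."""

    def __init__(self, fail_in_run: bool = False):
        self.calls = []
        self.fail_in_run = fail_in_run

    async def spawn(self):
        self.calls.append("spawn")

    async def setup(self):
        self.calls.append("setup")

    async def run(self):
        self.calls.append("run")
        if self.fail_in_run:
            raise RuntimeError("simulated training failure")

    async def cleanup(self):
        self.calls.append("cleanup")


# asyncio.run() is for a plain script; in a notebook cell use
# `await run_with_cleanup(demo)` instead, since a loop is already running.
demo = FakeSpawner()
asyncio.run(run_with_cleanup(demo))
print(demo.calls)  # ['spawn', 'setup', 'run', 'cleanup']
```

In the notebook, the equivalent call would be `await run_with_cleanup(SpawnActor(TrainerActor, cfg))`.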

---

# Quick Configuration Templates

Here are ready-to-use templates for common scenarios!

## Template 1: Quick Test (Single GPU, Few Steps)

In [None]:
quick_test_config = OmegaConf.create({
    "comm": {"trace_buf_size": 0},
    "model": {
        "name": "llama3",
        "flavor": "8B",
        "hf_assets_path": "/tmp/Meta-Llama-3.1-8B-Instruct"
    },
    "processes": {"procs": 1, "with_gpus": True},
    "optimizer": {"name": "AdamW", "lr": 1e-5, "eps": 1e-8},
    "lr_scheduler": {"warmup_steps": 10},
    "training": {
        "local_batch_size": 1,
        "seq_len": 1024,
        "max_norm": 1.0,
        "steps": 100,  # Just 100 steps for quick testing
        "compile": False,
        "dataset": "c4"
    },
    "parallelism": {
        "data_parallel_replicate_degree": 1,
        "data_parallel_shard_degree": 1,
        "tensor_parallel_degree": 1,
        "pipeline_parallel_degree": 1,
        "context_parallel_degree": 1,
        "expert_parallel_degree": 1,
        "disable_loss_parallel": False
    },
    "checkpoint": {
        "enable": True,
        "folder": "/tmp/quick_test_checkpoints",
        "initial_load_path": "/tmp/Meta-Llama-3.1-8B-Instruct/",
        "initial_load_in_hf": True,
        "last_save_in_hf": True,
        "interval": 50,
        "async_mode": "disabled"
    },
    "activation_checkpoint": {
        "mode": "selective",
        "selective_ac_option": "op"
    }
})

print("Quick Test Configuration:")
print(OmegaConf.to_yaml(quick_test_config))

# To use: await run_actor(TrainerActor, quick_test_config)

## Template 2: Multi-GPU Training (8 GPUs with FSDP)

In [None]:
multi_gpu_config = OmegaConf.create({
    "comm": {"trace_buf_size": 0},
    "model": {
        "name": "llama3",
        "flavor": "8B",
        "hf_assets_path": "/tmp/Meta-Llama-3.1-8B-Instruct"
    },
    "processes": {"procs": 8, "with_gpus": True},
    "optimizer": {"name": "AdamW", "lr": 2e-5, "eps": 1e-8},
    "lr_scheduler": {"warmup_steps": 200},
    "training": {
        "local_batch_size": 2,
        "seq_len": 2048,
        "max_norm": 1.0,
        "steps": 5000,
        "compile": False,
        "dataset": "c4"
    },
    "parallelism": {
        "data_parallel_replicate_degree": 1,
        "data_parallel_shard_degree": 8,  # FSDP across 8 GPUs
        "tensor_parallel_degree": 1,
        "pipeline_parallel_degree": 1,
        "context_parallel_degree": 1,
        "expert_parallel_degree": 1,
        "disable_loss_parallel": False
    },
    "checkpoint": {
        "enable": True,
        "folder": "/tmp/multi_gpu_checkpoints",
        "initial_load_path": "/tmp/Meta-Llama-3.1-8B-Instruct/",
        "initial_load_in_hf": True,
        "last_save_in_hf": True,
        "interval": 500,
        "async_mode": "disabled"
    },
    "activation_checkpoint": {
        "mode": "selective",
        "selective_ac_option": "op"
    }
})

print("Multi-GPU Configuration:")
print(OmegaConf.to_yaml(multi_gpu_config))

# To use: await run_actor(TrainerActor, multi_gpu_config)

## Template 3: Memory-Efficient Training

In [None]:
memory_efficient_config = OmegaConf.create({
    "comm": {"trace_buf_size": 0},
    "model": {
        "name": "llama3",
        "flavor": "8B",
        "hf_assets_path": "/tmp/Meta-Llama-3.1-8B-Instruct"
    },
    "processes": {"procs": 4, "with_gpus": True},
    "optimizer": {"name": "AdamW", "lr": 1e-5, "eps": 1e-8},
    "lr_scheduler": {"warmup_steps": 150},
    "training": {
        "local_batch_size": 1,  # Small batch size
        "seq_len": 1024,         # Shorter sequence
        "max_norm": 1.0,
        "steps": 2000,
        "compile": False,
        "dataset": "c4"
    },
    "parallelism": {
        "data_parallel_replicate_degree": 1,
        "data_parallel_shard_degree": 4,
        "tensor_parallel_degree": 1,
        "pipeline_parallel_degree": 1,
        "context_parallel_degree": 1,
        "expert_parallel_degree": 1,
        "disable_loss_parallel": False
    },
    "checkpoint": {
        "enable": True,
        "folder": "/tmp/memory_efficient_checkpoints",
        "initial_load_path": "/tmp/Meta-Llama-3.1-8B-Instruct/",
        "initial_load_in_hf": True,
        "last_save_in_hf": True,
        "interval": 400,
        "async_mode": "disabled"
    },
    "activation_checkpoint": {
        "mode": "selective",  # Saves memory
        "selective_ac_option": "op"
    }
})

print("Memory-Efficient Configuration:")
print(OmegaConf.to_yaml(memory_efficient_config))

# To use: await run_actor(TrainerActor, memory_efficient_config)
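
The three templates differ in only a handful of values, so another option is to keep one base dictionary and apply per-scenario overrides. The sketch below uses plain dictionaries and a hypothetical `deep_merge` helper. The base is trimmed for brevity. With `DictConfig` objects, `OmegaConf.merge(base_cfg, overrides)` should give the same result.

```python
import copy


def deep_merge(base: dict, overrides: dict) -> dict:
    """Return a copy of `base` with `overrides` applied recursively."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Trimmed base; a real base would hold every section from Step 11.
base = {
    "processes": {"procs": 8, "with_gpus": True},
    "lr_scheduler": {"warmup_steps": 200},
    "training": {"local_batch_size": 1, "seq_len": 2048, "steps": 1000},
    "checkpoint": {"folder": "/tmp/multi_gpu_checkpoints", "interval": 500},
}

# Only the values that differ for the quick test.
quick_test_overrides = {
    "processes": {"procs": 1},
    "lr_scheduler": {"warmup_steps": 10},
    "training": {"seq_len": 1024, "steps": 100},
    "checkpoint": {"folder": "/tmp/quick_test_checkpoints", "interval": 50},
}

quick_test = deep_merge(base, quick_test_overrides)
print(quick_test["training"])  # {'local_batch_size': 1, 'seq_len': 1024, 'steps': 100}
print(base["training"]["steps"])  # 1000, the base is left untouched
```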

---

# Tips & Tricks

## Memory Optimization
- ⬇️ Reduce `seq_len` if running out of memory
- ⬇️ Reduce `local_batch_size` if running out of memory
- ✅ Enable activation checkpointing (`activation_checkpoint.mode`, e.g. `"selective"`) for memory savings
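
The tips above mainly reduce activation memory. For context, here is a back-of-the-envelope estimate of the remaining model state per GPU. It assumes roughly 8 billion parameters and 16 bytes per parameter (fp32 weights, fp32 gradients, and two fp32 AdamW moments), sharded evenly by FSDP. The precision the trainer actually uses is not stated in this notebook, so treat the numbers as an order of magnitude only. `model_state_gib_per_gpu` is a hypothetical helper.

```python
def model_state_gib_per_gpu(num_params: float, shard_degree: int, bytes_per_param: int = 16) -> float:
    """Approximate GiB per GPU for weights, gradients, and optimizer state.

    Assumption: state is sharded evenly; activations are not included.
    """
    return num_params * bytes_per_param / shard_degree / 2**30


num_params = 8e9  # assumption: approximate parameter count of the "8B" flavor

for shard_degree in (1, 4, 8):
    gib = model_state_gib_per_gpu(num_params, shard_degree)
    print(f"shard degree {shard_degree}: ~{gib:.1f} GiB per GPU")
```

Under these assumptions the state is roughly 119 GiB unsharded and roughly 15 GiB per GPU at a shard degree of 8.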

## Training Speed
- ⬆️ Increase `local_batch_size` for higher throughput (if memory allows)
- 🚀 Use multiple GPUs with FSDP (`data_parallel_shard_degree` greater than 1, or `-1` to shard across all processes)
- ⚡ Set `"compile": True` in the training settings to enable PyTorch compilation (experimental)

## Debugging
- 🧪 Start with small `steps` (e.g., 10-100) to test quickly
- 🔍 Use single GPU first (`procs: 1`)
- 📊 Monitor loss values in logs

## Checkpoint Management
- 💾 Set `interval` based on how often you want to save
- 📁 Ensure `folder` path exists and has enough space
- 🔄 Point `initial_load_path` at the checkpoint to start from (the templates use the base model directory)
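
The `folder` check can be automated before launching a run. The sketch below creates the directory if needed and verifies free disk space using only the standard library. `prepare_checkpoint_dir` is a hypothetical helper, and the space threshold is something you would choose based on your model size and number of saves.

```python
import os
import shutil
import tempfile


def prepare_checkpoint_dir(folder: str, min_free_gib: float) -> float:
    """Create `folder` if missing and return free space in GiB, or raise if too low."""
    os.makedirs(folder, exist_ok=True)
    free_gib = shutil.disk_usage(folder).free / 2**30
    if free_gib < min_free_gib:
        raise RuntimeError(
            f"only {free_gib:.1f} GiB free in {folder}, need {min_free_gib:.1f} GiB"
        )
    return free_gib


# Demo on a temporary directory with no minimum; in practice pass the
# checkpoint folder from your configuration and a realistic threshold.
demo_folder = os.path.join(tempfile.mkdtemp(), "saved_checkpoints")
free = prepare_checkpoint_dir(demo_folder, min_free_gib=0.0)
print(f"{demo_folder}: {free:.1f} GiB free")
```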