# 🚀 SFT Training Notebook

This notebook provides an interactive interface for training language models with supervised fine-tuning (SFT).

## Features
- ✅ Interactive configuration in separate cells
- ✅ Support for single-node and multi-node training
- ✅ Easy hyperparameter tuning
- ✅ Flexible parallelism strategies
- ✅ Checkpoint management

## Quick Start
1. Configure each section (model, training, etc.)
2. Review the complete configuration
3. Run training!

## 📚 Imports

In [None]:
import sys
sys.path.insert(0, '/home/hosseinkh/forge')  # Path to the local Forge checkout; adjust for your setup

from apps.sft_v2 import notebook_utils as nb
import torch

print("✅ Imports successful!")
print(f"📊 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🔢 Number of GPUs: {torch.cuda.device_count()}")

## 📦 Model Configuration

Configure the model you want to train.

In [None]:
# Model Configuration
model_config = nb.create_model_config(
    name="llama3",
    flavor="8B",
    hf_assets_path="/mnt/home/hosseinkh/models/Meta-Llama-3.1-8B-Instruct"
)

print("📦 Model Configuration:")
for key, value in model_config.items():
    print(f"  • {key}: {value}")

## ⚙️ Training Configuration

Set training hyperparameters.
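
As a rough guide to how these settings interact, the sketch below computes the global batch and the number of tokens processed per step and over the whole run. It assumes the notebook forwards these values to a TorchTitan-style trainer in which `local_batch_size` counts sequences per data-parallel rank, the data-parallel size is `data_parallel_replicate_degree × data_parallel_shard_degree`, there is no gradient accumulation, and every sample fills `seq_len` tokens. `tokens_per_run` is a hypothetical helper, not part of `notebook_utils`.

```python
def tokens_per_run(local_batch_size, seq_len, steps, dp_replicate=1, dp_shard=1):
    """Estimate global batch size and token counts for a training run.

    Assumes local_batch_size is per data-parallel rank, no gradient
    accumulation, and that every sample is packed or padded to seq_len.
    """
    dp_degree = dp_replicate * dp_shard          # number of data-parallel ranks
    global_batch = local_batch_size * dp_degree  # sequences per optimizer step
    tokens_per_step = global_batch * seq_len
    return {
        "global_batch": global_batch,
        "tokens_per_step": tokens_per_step,
        "total_tokens": tokens_per_step * steps,
    }

# This notebook's settings on one 8-GPU node with pure FSDP (shard degree 8).
print(tokens_per_run(local_batch_size=1, seq_len=2048, steps=1000, dp_shard=8))
# {'global_batch': 8, 'tokens_per_step': 16384, 'total_tokens': 16384000}
```

GPUs that only add tensor, pipeline, or context parallelism do not enlarge the global batch under this assumption; they split the work for the same samples.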

In [None]:
# Training Configuration
training_config = nb.create_training_config(
    local_batch_size=1,      # Batch size per data-parallel rank (per GPU with pure FSDP)
    seq_len=2048,            # Sequence length
    max_norm=1.0,            # Gradient clipping
    steps=1000,              # Total training steps
    dataset="c4",            # Dataset name
    compile=False            # Use torch.compile?
)

print("⚙️  Training Configuration:")
for key, value in training_config.items():
    print(f"  • {key}: {value}")

## 🔧 Optimizer Configuration

Configure the optimizer and learning rate.
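
The warmup and decay settings are easiest to reason about as a function of the step. The sketch below assumes a linear warmup to `lr` over `warmup_steps`, then either a constant LR when `decay_steps=None` (matching the comment in the cell) or a linear decay to `min_lr` over `decay_steps`. The actual scheduler may use a different warmup formula or decay shape, so treat the values as illustrative; `lr_at` is a hypothetical helper.

```python
def lr_at(step, lr=1e-5, warmup_steps=200, decay_steps=None, min_lr=0.0):
    """Illustrative LR for a 0-indexed step: linear warmup, then constant or linear decay."""
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches the full lr.
        return lr * (step + 1) / warmup_steps
    if decay_steps is None:
        # No decay configured: hold the peak lr for the rest of training.
        return lr
    # Fraction of the decay phase completed, clamped at 1.0.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return lr - (lr - min_lr) * progress

for s in (0, 99, 199, 500, 999):
    print(s, lr_at(s))
```

With the notebook's values, the LR climbs over the first 200 steps and then stays at `1e-5`; passing `decay_steps=800` would bring it down to `min_lr` exactly at step 1000.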

In [None]:
# Optimizer Configuration
optimizer_config = nb.create_optimizer_config(
    name="AdamW",
    lr=1e-5,                 # Learning rate
    eps=1e-8,                # Epsilon
    weight_decay=0.0,        # Weight decay
    betas=(0.9, 0.999)       # Adam betas
)

# LR Scheduler Configuration
lr_scheduler_config = nb.create_lr_scheduler_config(
    warmup_steps=200,        # Warmup steps
    decay_steps=None,        # Decay steps (None = no decay)
    min_lr=0.0               # Minimum LR
)

print("🔧 Optimizer Configuration:")
for key, value in optimizer_config.items():
    print(f"  • {key}: {value}")

print("\n📈 LR Scheduler Configuration:")
for key, value in lr_scheduler_config.items():
    print(f"  • {key}: {value}")

## 🔀 Parallelism Configuration

Configure distributed training strategies.

### Parallelism Options:
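- **Data Parallel (Replicate)**: Each replica holds a full copy of the model and gradients are averaged across replicas (DDP-style)
- **Data Parallel (Shard/FSDP)**: Fully Sharded Data Parallel; parameters, gradients, and optimizer states are sharded across GPUs (`-1` = use all GPUs not claimed by the other dimensions)
- **Tensor Parallel**: Splits individual layers' weight matrices across GPUs
- **Pipeline Parallel**: Splits the model's layers into sequential stages on different GPUs
- **Context Parallel**: Splits long sequences along the sequence dimension across GPUs
- **Expert Parallel**: Distributes the experts of Mixture-of-Experts (MoE) layers across GPUs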
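
A common constraint, assumed here rather than taken from `notebook_utils`, is that the product of the replicate, shard, context-, tensor-, and pipeline-parallel degrees must equal the total number of GPUs, with a single `-1` inferred from whatever is left over. The hypothetical helper below applies that rule. Expert parallelism is left out because it is usually formed from GPUs already counted in the other dimensions rather than adding a new one.

```python
import math

def resolve_degrees(world_size, **degrees):
    """Fill in a single -1 degree so the product of all degrees equals world_size."""
    for name, d in degrees.items():
        if d != -1 and d < 1:
            raise ValueError(f"{name} must be >= 1 or -1, got {d}")
    unknown = [name for name, d in degrees.items() if d == -1]
    if len(unknown) > 1:
        raise ValueError(f"at most one degree may be -1, got {unknown}")
    known = math.prod(d for d in degrees.values() if d != -1)
    if world_size % known != 0:
        raise ValueError(f"{world_size} GPUs are not divisible by the fixed degrees ({known})")
    resolved = dict(degrees)
    if unknown:
        resolved[unknown[0]] = world_size // known  # the -1 dimension takes the rest
    elif known != world_size:
        raise ValueError(f"degrees multiply to {known}, expected {world_size}")
    return resolved

# Single node, 8 GPUs, FSDP with -1 (the defaults in the cell below).
print(resolve_degrees(8, dp_replicate=1, dp_shard=-1, cp=1, tp=1, pp=1))
# Four nodes of 8 GPUs with TP=2: the FSDP degree is inferred as 16.
print(resolve_degrees(32, dp_replicate=1, dp_shard=-1, cp=1, tp=2, pp=1))
```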

In [None]:
# Parallelism Configuration
parallelism_config = nb.create_parallelism_config(
    data_parallel_replicate_degree=1,   # DP replicate
    data_parallel_shard_degree=-1,      # FSDP (-1 = infer from GPUs left after other degrees)
    tensor_parallel_degree=1,           # TP
    pipeline_parallel_degree=1,         # PP
    context_parallel_degree=1,          # CP
    expert_parallel_degree=1,           # EP (for MoE)
    disable_loss_parallel=False         # Loss parallelism only takes effect when TP > 1
)

print("🔀 Parallelism Configuration:")
for key, value in parallelism_config.items():
    print(f"  • {key}: {value}")

## 💾 Checkpoint Configuration

Configure model checkpointing and activation checkpointing (recomputing activations during the backward pass to save memory).
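
With `interval=500` and `steps=1000`, it helps to know which steps will write a checkpoint. The sketch below assumes saves happen at every multiple of `interval` and, as many trainers do, at the final step even when it is not a multiple; whether that final save is exported in Hugging Face format is what `last_save_in_hf` controls. `checkpoint_steps` is a hypothetical helper, not part of `notebook_utils`.

```python
def checkpoint_steps(steps, interval, enable=True):
    """Steps (1-indexed) that would write a checkpoint under the assumed policy."""
    if not enable:
        return []
    saves = list(range(interval, steps + 1, interval))  # every multiple of interval
    if not saves or saves[-1] != steps:
        saves.append(steps)                              # assumed: always save the last step
    return saves

print(checkpoint_steps(1000, 500))  # [500, 1000]
print(checkpoint_steps(1200, 500))  # [500, 1000, 1200]
print(checkpoint_steps(10, 500))    # [10]
```

The last case corresponds to a quick test run with `steps=10`: under this assumption, only the final step is saved.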

In [None]:
# Checkpoint Configuration
checkpoint_config = nb.create_checkpoint_config(
    enable=True,
    folder="/mnt/home/hosseinkh/models/Meta-Llama-3.1-8B-Instruct/saved_checkpoints",
    initial_load_path="/mnt/home/hosseinkh/models/Meta-Llama-3.1-8B-Instruct/",
    initial_load_in_hf=True,
    last_save_in_hf=True,
    interval=500,            # Save every N steps
    async_mode="disabled"
)

# Activation Checkpoint Configuration (for memory efficiency)
activation_checkpoint_config = nb.create_activation_checkpoint_config(
    mode="selective",        # 'selective', 'full', or 'none'
    selective_ac_option="op" # 'op' (per-operator policy) or an integer string, e.g. '2' = every 2nd layer
)

print("💾 Checkpoint Configuration:")
for key, value in checkpoint_config.items():
    print(f"  • {key}: {value}")

print("\n🔄 Activation Checkpoint Configuration:")
for key, value in activation_checkpoint_config.items():
    print(f"  • {key}: {value}")

## 🖥️ Resource Configuration

Configure compute resources.

### Options:
- **Single Node**: Set only `procs` (number of GPUs)
- **Multi Node**: Set both `hosts` (number of nodes) and `procs` (GPUs per node)

In [None]:
# Choose ONE of the following:

# Option 1: Single Node (8 GPUs)
process_config = nb.create_process_config(
    procs=8,
    with_gpus=True,
    hosts=None  # None = single node
)

# Option 2: Multi-Node (4 nodes × 8 GPUs = 32 total)
# Uncomment to use:
# process_config = nb.create_process_config(
#     procs=8,
#     with_gpus=True,
#     hosts=4
# )

print("🖥️  Resource Configuration:")
for key, value in process_config.items():
    print(f"  • {key}: {value}")

if "hosts" in process_config and process_config["hosts"]:
    total_gpus = process_config["hosts"] * process_config["procs"]
    print(f"\n📊 Total GPUs: {total_gpus}")
else:
    print(f"\n📊 Total GPUs: {process_config['procs']}")

## ☁️ Provisioner Configuration (Optional)

**Only needed for multi-node training on SLURM clusters.**

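⚠️ For single-node training, run this cell unchanged: it sets `provisioner_config = None`, which the build step below expects. Skipping it causes a `NameError` when the configuration is built.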
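
The provisioner fields mirror standard SLURM job options. To make the mapping concrete, the sketch below renders them as an `sbatch` header. It only illustrates what the fields mean, assuming one node per host and `procs` GPUs per node; it is not what `nb.create_provisioner_config` actually generates, and `sbatch_header` is hypothetical.

```python
def sbatch_header(job_name, partition, time, nodes, gpus_per_node, account=None):
    """Render #SBATCH directives equivalent to the provisioner settings (illustrative)."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={job_name}",
        f"#SBATCH --partition={partition}",
        f"#SBATCH --time={time}",                    # wall-clock limit, HH:MM:SS
        f"#SBATCH --nodes={nodes}",                  # corresponds to `hosts`
        f"#SBATCH --gpus-per-node={gpus_per_node}",  # corresponds to `procs`
    ]
    if account is not None:
        lines.append(f"#SBATCH --account={account}")  # only if your cluster requires it
    return "\n".join(lines)

print(sbatch_header("sft_training", "your_gpu_partition", "24:00:00", nodes=4, gpus_per_node=8))
```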

In [None]:
# Provisioner Configuration (OPTIONAL - for multi-node only)
# Set to None for single-node training

provisioner_config = None  # Default: no provisioner

# Uncomment and configure for SLURM multi-node training:
# provisioner_config = nb.create_provisioner_config(
#     launcher="slurm",
#     job_name="sft_training",
#     partition="your_gpu_partition",  # REQUIRED for SLURM
#     time="24:00:00",                  # REQUIRED for SLURM
#     account="your_account"            # May be required
# )

if provisioner_config:
    print("☁️  Provisioner Configuration:")
    for key, value in provisioner_config.items():
        print(f"  • {key}: {value}")
else:
    print("☁️  Provisioner: Disabled (single-node mode)")

## 🔨 Build Complete Configuration

Combine all configurations into a single config object.

In [None]:
# Build complete configuration
config = nb.build_config(
    model_config=model_config,
    optimizer_config=optimizer_config,
    lr_scheduler_config=lr_scheduler_config,
    training_config=training_config,
    parallelism_config=parallelism_config,
    checkpoint_config=checkpoint_config,
    activation_checkpoint_config=activation_checkpoint_config,
    process_config=process_config,
    provisioner_config=provisioner_config
)

print("✅ Configuration built successfully!\n")

# Display summary
nb.summarize_config(config)

## 📄 View Full Configuration (YAML)

See the complete configuration in YAML format.

In [None]:
# Print full configuration
nb.print_config(config, title="Complete Training Configuration")

## 💾 Save Configuration (Optional)

Save the configuration to a YAML file for later use.

In [None]:
from omegaconf import OmegaConf

# Save configuration
config_path = "/home/hosseinkh/forge/apps/sft_v2/my_training_config.yaml"
with open(config_path, 'w') as f:
    OmegaConf.save(config, f)

print(f"✅ Configuration saved to: {config_path}")

## 🚀 Run Training!

Start the training process with the configured settings.

⚠️ **Note**: This will start actual training and may take a long time!

In [None]:
# Run training
print("🚀 Starting training...\n")

try:
    nb.train(config)
    print("\n✅ Training completed successfully!")
except Exception as e:
    print(f"\n❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

## 🔍 Advanced: Step-by-Step Execution

For more control, you can run each training stage separately.

⚠️ **Only run this section if you want manual control. Otherwise, use the cell above.**
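
The manual cells run each stage only if the previous one succeeded, so a failure during training leaves the recipe and provisioner running. One way to guarantee teardown is to wrap the stages in `try`/`finally`. The sketch takes the stage functions as arguments so it can be exercised with stubs; passing the `nb.*` functions assumes they are coroutine functions with the arguments used in the cells below, which the notebook implies (they are awaited) but does not document. `run_stages` is hypothetical.

```python
async def run_stages(config, init_prov, create, setup, train, cleanup, shutdown_prov):
    """Run the manual training stages, always tearing down whatever was started."""
    prov_started = await init_prov(config)
    recipe = None
    try:
        recipe = await create(config)
        await setup(recipe)
        await train(recipe)
    finally:
        # Runs even if setup or training raises; the nested finally
        # ensures the provisioner is shut down even if cleanup fails.
        try:
            if recipe is not None:
                await cleanup(recipe)
        finally:
            if prov_started:
                await shutdown_prov(config)

# In a notebook cell (top-level await), assuming the nb.* signatures above:
# await run_stages(config, nb.initialize_provisioner, nb.create_recipe,
#                  nb.setup_recipe, nb.train_recipe, nb.cleanup_recipe,
#                  nb.shutdown_provisioner)
```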

In [None]:
# Step 1: Initialize provisioner (if configured)
# Top-level `await` works in Jupyter/IPython, so no event-loop setup is needed

provisioner_initialized = await nb.initialize_provisioner(config)
print(f"Provisioner initialized: {provisioner_initialized}")

In [None]:
# Step 2: Create recipe
recipe = await nb.create_recipe(config)
print("Recipe created")

In [None]:
# Step 3: Setup recipe (load model, data, etc.)
await nb.setup_recipe(recipe)
print("Recipe setup complete")

In [None]:
# Step 4: Run training
await nb.train_recipe(recipe)
print("Training complete")

In [None]:
# Step 5: Cleanup
await nb.cleanup_recipe(recipe)
print("Cleanup complete")

In [None]:
# Step 6: Shutdown provisioner (if initialized)
if provisioner_initialized:
    await nb.shutdown_provisioner(config)
    print("Provisioner shutdown complete")

## 📊 Tips & Tricks

### Memory Optimization
- Use **FSDP** (set `data_parallel_shard_degree=-1`) for large models
- Enable **activation checkpointing** (set `mode="selective"` or `"full"`)
- Reduce **batch size** or **sequence length**
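
To see why sharding matters for an 8B model, here is a back-of-the-envelope estimate of per-GPU memory for model state alone. It assumes fp32 parameters and gradients (4 bytes each) plus AdamW's two fp32 moment buffers (8 bytes), i.e. 16 bytes per parameter, evenly sharded across the FSDP group. Activations, temporary buffers, and any extra mixed-precision copies are ignored, so real usage is higher. `model_state_gb` is a hypothetical helper.

```python
def model_state_gb(num_params, fsdp_degree=1, bytes_per_param=16):
    """Per-GPU memory in GB (1e9 bytes) for params, grads, and AdamW states.

    bytes_per_param=16 assumes fp32 weights (4) + fp32 grads (4) + two fp32 AdamW moments (8).
    """
    return num_params * bytes_per_param / fsdp_degree / 1e9

for gpus in (1, 8, 32):
    # Roughly 8e9 parameters; prints 128, 16 and 4 GB per GPU.
    print(f"8B model, FSDP over {gpus:>2} GPUs: {model_state_gb(8e9, gpus):.0f} GB per GPU")
```

Without sharding, model state alone would exceed an 80 GB GPU; with FSDP over 8 GPUs it drops to about 16 GB, leaving room for activations.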

### Speed Optimization
- Use **tensor parallelism** for large models (set `tensor_parallel_degree > 1`), keeping each TP group within a single node because TP communicates on every layer
- Enable **compilation** (set `compile=True`)
- Increase **batch size** if memory allows

### Multi-Node Training
- Set `hosts` in process config
- Configure provisioner with SLURM details
- Make sure the model and checkpoint paths are accessible from every node (e.g., on a shared filesystem)

### Debugging
- Start with fewer steps (e.g., `steps=10`)
- Use single GPU first (`procs=1`)
- Check logs for errors

## 🎯 Common Configurations

### Quick Test Run
```python
training_config = nb.create_training_config(
    steps=10,
    local_batch_size=1
)
process_config = nb.create_process_config(procs=1)
```

### Single Node, 8 GPUs, FSDP
```python
parallelism_config = nb.create_parallelism_config(
    data_parallel_shard_degree=-1  # Use all 8 GPUs with FSDP
)
process_config = nb.create_process_config(procs=8)
```

### Multi-Node, 4×8 GPUs, TP=2
```python
parallelism_config = nb.create_parallelism_config(
    data_parallel_shard_degree=16,   # 32 GPUs / 2 TP = 16 FSDP
    tensor_parallel_degree=2
)
process_config = nb.create_process_config(procs=8, hosts=4)
provisioner_config = nb.create_provisioner_config(
    launcher="slurm",
    job_name="sft_training",
    partition="gpu_partition",
    time="24:00:00"                  # Required for SLURM
)
```