# 🚀 SoFlow Training on Google Colab

<div align="center">

**Train SoFlow: One-Step Image Generation**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Gaurav14cs17/GenAI/blob/main/notebooks/SoFlow_Training.ipynb)
[![Paper](https://img.shields.io/badge/arXiv-2512.15657-b31b1b.svg)](https://arxiv.org/pdf/2512.15657)
[![GitHub](https://img.shields.io/badge/GitHub-Gaurav14cs17%2FGenAI-black.svg)](https://github.com/Gaurav14cs17/GenAI)

</div>

---

This notebook trains a SoFlow model on the CIFAR-10 dataset. You'll learn how to:
- Set up the training environment
- Understand the training loop
- Monitor loss curves
- Generate samples during training


## 1️⃣ Setup Environment

First, let's check GPU availability and install dependencies.


In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"\n‚úÖ PyTorch version: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úÖ Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# Install dependencies
# -q keeps pip's output quiet. Note that the confirmation below is printed
# unconditionally: it does not check whether the pip command succeeded.
!pip install -q torch torchvision tqdm matplotlib numpy pillow
print("‚úÖ Dependencies installed!")


## 2️⃣ Clone Repository & Import Libraries


In [None]:
import os

# Clone the repository (if not already cloned)
if not os.path.exists('GenAI'):
    !git clone https://github.com/Gaurav14cs17/GenAI.git
    print("‚úÖ Repository cloned!")
else:
    print("‚úÖ Repository already exists!")

# Change to the repository directory
os.chdir('GenAI')
print(f"üìÅ Working directory: {os.getcwd()}")


In [None]:
import sys
# Put the current working directory first on the import path so the
# `soflow` package below is resolved relative to it.
# NOTE(review): assumes the cwd is the cloned repository root - confirm.
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
import json

# Import SoFlow modules
from soflow.models import create_soflow_model, DIT_MODELS
from soflow.losses import SoFlowLoss

print("‚úÖ All imports successful!")
# DIT_MODELS is used here as a mapping; only its keys are listed.
print(f"üì¶ Available models: {list(DIT_MODELS.keys())}")


## 3️⃣ Training Configuration

Adjust these parameters based on your GPU memory and training needs.


In [None]:
# ===========================================
#           TRAINING CONFIGURATION
# ===========================================

# The settings are declared per topic and then merged, in this order, into the
# single flat `config` dict that the rest of the notebook reads.

# Dataset
dataset_settings = {
    "dataset": "cifar10",
    "img_size": 32,
    "in_channels": 3,
    "num_classes": 10,
}

# Model (use smaller model for Colab)
model_settings = {
    "hidden_size": 256,
    "depth": 6,
    "num_heads": 4,
    "patch_size": 2,
}

# Training
training_settings = {
    "epochs": 50,
    "batch_size": 128,
    "learning_rate": 1e-4,
    "weight_decay": 0.0,
}

# Loss weights
loss_settings = {
    "lambda_fm": 1.0,
    "lambda_cons": 0.1,
}

# Sampling
sampling_settings = {
    "cfg_scale": 2.0,
    "save_every": 10,
}

# Device: use the GPU when PyTorch can see one, otherwise fall back to CPU
device_settings = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

config = {
    **dataset_settings,
    **model_settings,
    **training_settings,
    **loss_settings,
    **sampling_settings,
    **device_settings,
}

print("üìã Training Configuration:")
for name in config:
    print(f"   {name}: {config[name]}")


## 4️⃣ Load Dataset


In [None]:
# Data transforms: convert to a tensor, then normalize every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # [-1, 1]
])

# Load CIFAR-10 (training split, downloaded into ./data when missing)
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=2,
    # Pinned host memory only benefits host-to-GPU copies. The notebook also
    # supports a CPU fallback, where asking for it only triggers a warning.
    pin_memory=(config["device"] == "cuda")
)

# Class names for visualization
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

print("‚úÖ Dataset loaded!")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Batches per epoch: {len(train_loader)}")
print(f"   Classes: {class_names}")


In [None]:
# Visualize some training samples
def show_samples(images, labels, title="Training Samples"):
    """Display up to 16 images in a 2x8 grid, each titled with its class name.

    Args:
        images: indexable collection of 3-D tensors. Each one is permuted
            with (1, 2, 0) before display (presumably C,H,W -> H,W,C) and
            rescaled from [-1, 1] to [0, 1].
        labels: indexable collection aligned with ``images``; each entry is
            either a plain ``int`` or an object exposing ``.item()``. The
            value indexes the module-level ``class_names`` list.
        title: text used as the figure title.

    Grid cells beyond ``len(images)`` stay empty; every cell has its axis
    hidden. The figure is shown with ``plt.show()`` and nothing is returned.
    """
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))

    for idx, ax in enumerate(axes.flat):
        if idx < len(images):
            pixels = images[idx].permute(1, 2, 0).cpu().numpy()
            pixels = (pixels + 1) / 2  # [-1, 1] -> [0, 1]
            ax.imshow(np.clip(pixels, 0, 1))

            label = labels[idx]
            class_idx = label if isinstance(label, int) else label.item()
            ax.set_title(class_names[class_idx], fontsize=8)
        ax.axis('off')

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

# Show samples
# Take the first batch the loader yields and display its first 16 images,
# which is exactly what the 2x8 grid in show_samples can hold.
sample_batch, sample_labels = next(iter(train_loader))
show_samples(sample_batch[:16], sample_labels[:16])


## 5️⃣ Create Model & Training Components


In [None]:
# Create SoFlow model: forward the architecture-related config entries as
# keyword arguments, then move the result to the configured device.
model_kwargs = {
    name: config[name]
    for name in (
        "in_channels",
        "hidden_size",
        "depth",
        "num_heads",
        "patch_size",
        "num_classes",
        "img_size",
    )
}
model = create_soflow_model(**model_kwargs).to(config["device"])

# Count parameters (only those that require gradients)
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print("‚úÖ Model created!")
print(f"   Parameters: {num_params:,} ({num_params/1e6:.2f}M)")

# Create loss function with the two weights from the config
loss_fn = SoFlowLoss(lambda_fm=config["lambda_fm"], lambda_cons=config["lambda_cons"])

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)

# Learning rate scheduler: the cosine period spans the whole run, counted in
# optimizer steps (epochs x batches per epoch).
scheduler_steps = config["epochs"] * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_steps)

print("‚úÖ Loss function and optimizer created!")


## 6️⃣ Training Functions


In [None]:
def train_step(model, loss_fn, batch, optimizer, scheduler, step, total_steps):
    """Run one optimization step on a single batch and report the losses.

    The batch is moved to ``config["device"]``, Gaussian noise of the same
    shape (scaled by 0.5) is drawn, ``loss_fn`` is evaluated, and the
    gradients are back-propagated and clipped to a norm of 1.0 before the
    optimizer and then the scheduler are stepped.

    Args:
        model: network passed to ``loss_fn``; its parameters are clipped.
        loss_fn: callable invoked as ``loss_fn(model, x_0, x_1, y, step=...,
            total_steps=..., return_dict=True)``. It must return a mapping
            that contains a ``"loss"`` entry.
        batch: ``(inputs, labels)`` pair of tensors.
        optimizer: optimizer whose gradients are reset and which is stepped.
        scheduler: learning-rate scheduler, stepped once after the optimizer.
        step: current global step, forwarded to ``loss_fn``.
        total_steps: planned number of steps, forwarded to ``loss_fn``.

    Returns:
        dict with the same keys as the mapping returned by ``loss_fn``, each
        value converted with ``.item()``.
    """
    device = config["device"]
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Scale noise to match data std (~0.5)
    noise = torch.randn_like(inputs) * 0.5

    # Forward pass
    losses = loss_fn(
        model,
        inputs,
        noise,
        targets,
        step=step,
        total_steps=total_steps,
        return_dict=True,
    )

    # Backward pass
    optimizer.zero_grad()
    losses["loss"].backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    return {name: value.item() for name, value in losses.items()}

@torch.no_grad()
def generate_samples(model, num_samples, cfg_scale):
    """Generate class-conditional samples for visualization.

    Draws ``num_samples`` Gaussian noise images (scaled by 0.5) whose shape
    comes from ``config["in_channels"]`` and ``config["img_size"]``, assigns
    labels that cycle through ``0 .. config["num_classes"] - 1``, and calls
    ``model.sample`` once.

    The model is switched to eval mode for the call and is then returned to
    whichever mode it was in on entry, even if sampling raises. A model that
    was deliberately put in eval mode by the caller therefore stays there.

    Args:
        model: module exposing ``sample(noise, labels, cfg_scale=...)``.
        num_samples: number of images to generate.
        cfg_scale: forwarded unchanged to ``model.sample``.

    Returns:
        ``(samples, labels)``: the value returned by ``model.sample`` and the
        label tensor that was passed to it.
    """
    was_training = model.training
    model.eval()
    try:
        # Generate noise (scaled to match data)
        noise = torch.randn(
            num_samples,
            config["in_channels"],
            config["img_size"],
            config["img_size"]
        ).to(config["device"]) * 0.5

        # Labels (one per class, repeating)
        labels = torch.arange(config["num_classes"]).to(config["device"])
        labels = labels.repeat(num_samples // config["num_classes"] + 1)[:num_samples]

        # Generate with one step!
        samples = model.sample(noise, labels, cfg_scale=cfg_scale)
    finally:
        # Restore the caller's mode instead of forcing training mode
        model.train(was_training)

    return samples, labels

print("‚úÖ Training functions defined!")

print("‚úÖ Training functions defined!")


In [None]:
# Training history: one list per tracked quantity
history = {key: [] for key in ("loss", "loss_fm", "loss_cons", "lr")}

# Output directory (and its samples sub-directory)
for out_dir in ("outputs/colab_training", "outputs/colab_training/samples"):
    os.makedirs(out_dir, exist_ok=True)

# Total steps = epochs x batches per epoch; the step counter starts at zero
steps_per_epoch = len(train_loader)
total_steps = config["epochs"] * steps_per_epoch
global_step = 0

print(f"üöÄ Starting training for {config['epochs']} epochs...")
print(f"   Total steps: {total_steps:,}")


In [None]:
# Main training loop
# Imported here so this cell is self-contained; torchvision is already a
# dependency of the notebook.
from torchvision.utils import save_image

for epoch in range(config["epochs"]):
    model.train()
    epoch_losses = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")

    for batch in pbar:
        metrics = train_step(
            model, loss_fn, batch, optimizer, scheduler,
            step=global_step, total_steps=total_steps
        )

        # Per-step bookkeeping; components the loss does not report are
        # recorded as 0 so all history lists keep the same length.
        epoch_losses.append(metrics["loss"])
        history["loss"].append(metrics["loss"])
        history["loss_fm"].append(metrics.get("loss_fm", 0))
        history["loss_cons"].append(metrics.get("loss_cons", 0))
        # train_step has already stepped the scheduler, so this is the
        # learning rate that the next step will use.
        history["lr"].append(scheduler.get_last_lr()[0])

        global_step += 1
        pbar.set_postfix({"loss": f"{metrics['loss']:.4f}"})

    avg_loss = np.mean(epoch_losses)
    print(f"üìä Epoch {epoch+1}: avg_loss = {avg_loss:.4f}")

    # Generate and save samples periodically (first epoch, then every
    # `save_every` epochs)
    if (epoch + 1) % config["save_every"] == 0 or epoch == 0:
        samples, labels = generate_samples(model, 16, config["cfg_scale"])
        show_samples(samples, labels, f"Epoch {epoch+1} Samples (CFG={config['cfg_scale']})")

        # Actually write the grid to disk: previously the samples directory
        # was created but nothing was ever saved into it. Same [-1, 1] ->
        # [0, 1] mapping as show_samples; 8 per row matches its 2x8 layout.
        sample_path = f"outputs/colab_training/samples/epoch_{epoch + 1:04d}.png"
        save_image(((samples + 1) / 2).clamp(0, 1), sample_path, nrow=8)

print("\n‚úÖ Training complete!")


## 8️⃣ Visualize Results 📊


In [None]:
# Plot training curves: total loss, loss components, learning rate
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Total loss
axes[0].plot(history["loss"], alpha=0.5, label="Raw")
# Moving-average window: about a tenth of the run, capped at 100 steps
window = min(100, len(history["loss"]) // 10 + 1)
if window > 1 and len(history["loss"]) > window:
    smoothed = np.convolve(history["loss"], np.ones(window)/window, mode='valid')
    # mode='valid' drops window-1 points: smoothed[k] is the mean of raw steps
    # k .. k+window-1. Plot each value at the last step of its window so the
    # trailing average lines up with the raw curve instead of appearing
    # shifted to the left.
    smoothed_steps = np.arange(window - 1, len(history["loss"]))
    axes[0].plot(smoothed_steps, smoothed, label="Smoothed", linewidth=2)
axes[0].set_xlabel("Step")
axes[0].set_ylabel("Loss")
axes[0].set_title("Total Loss")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# FM vs Cons loss
axes[1].plot(history["loss_fm"], label="FM Loss", alpha=0.7)
axes[1].plot(history["loss_cons"], label="Cons Loss", alpha=0.7)
axes[1].set_xlabel("Step")
axes[1].set_ylabel("Loss")
axes[1].set_title("Loss Components")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate
axes[2].plot(history["lr"])
axes[2].set_xlabel("Step")
axes[2].set_ylabel("Learning Rate")
axes[2].set_title("Learning Rate Schedule")
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
# Save before show so the written file contains the rendered figure
plt.savefig("outputs/colab_training/loss_curves.png", dpi=150)
plt.show()


In [None]:
# Generate final samples with different CFG scales
print("üé® Generating samples with different CFG scales...")

cfg_scales = [1.0, 2.0, 4.0]
samples_per_row = 10

# One row per CFG scale, one column per generated sample
fig, axes = plt.subplots(len(cfg_scales), samples_per_row, figsize=(20, 6))

for i, cfg in enumerate(cfg_scales):
    samples, labels = generate_samples(model, samples_per_row, cfg)
    for j in range(samples_per_row):
        ax = axes[i, j]
        img = samples[j].permute(1, 2, 0).cpu().numpy()
        img = (img + 1) / 2  # [-1, 1] -> [0, 1]
        ax.imshow(np.clip(img, 0, 1))

        # Do not use axis('off') here: it suppresses all axis decorations,
        # including the y-label, so the "CFG=..." row labels would never be
        # drawn. Hide only the ticks and the frame instead.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        if j == 0:
            ax.set_ylabel(f"CFG={cfg}", fontsize=12)

plt.suptitle("Generated Samples with Different CFG Scales", fontsize=14)
plt.tight_layout()
plt.savefig("outputs/colab_training/cfg_comparison.png", dpi=150)
plt.show()


## 9️⃣ Save Model 💾


In [None]:
# Save model checkpoint: weights, optimizer state, the config used, the number
# of configured epochs, and the per-step history, all in one file.
checkpoint_path = "outputs/colab_training/model.pt"
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "config": config,
    "epoch": config["epochs"],
    "history": history,
}
torch.save(checkpoint, checkpoint_path)
print("‚úÖ Model saved to outputs/colab_training/model.pt")

# Save training history separately as JSON
history_path = "outputs/colab_training/history.json"
with open(history_path, "w") as history_file:
    json.dump(history, history_file)
print("‚úÖ Training history saved!")


In [None]:
# Download the model (for Colab)
# Only the import is guarded: an ImportError here means google.colab is not
# available. The download itself runs in the else branch, so an error raised
# while downloading is not mistaken for "not running in Colab".
try:
    from google.colab import files
except ImportError:
    print("‚ÑπÔ∏è Not running in Colab, skipping download.")
else:
    files.download("outputs/colab_training/model.pt")
    print("üì• Model download started!")


---

## 🎉 Congratulations!

You've successfully trained a SoFlow model! Key takeaways:

- **One-step generation** works after training
- **CFG scale** controls quality vs diversity
- **Loss curves** show stable training

### Next Steps

1. Try the [Inference Notebook](./SoFlow_Inference.ipynb) to generate more samples
2. Increase training epochs for better quality
3. Try larger model sizes if you have more GPU memory

---

<div align="center">

**Made with ❤️ for the ML community**

</div>
