# 🎨 SoFlow Inference on Google Colab

<div align="center">

**Generate Images with One-Step SoFlow**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Gaurav14cs17/GenAI/blob/main/notebooks/SoFlow_Inference.ipynb)
[![Paper](https://img.shields.io/badge/arXiv-2512.15657-b31b1b.svg)](https://arxiv.org/pdf/2512.15657)
[![GitHub](https://img.shields.io/badge/GitHub-Gaurav14cs17%2FGenAI-black.svg)](https://github.com/Gaurav14cs17/GenAI)

</div>

---

This notebook demonstrates **one-step image generation** using a trained SoFlow model.

### ⚡ Key Feature: ONE STEP Generation!
Unlike diffusion models, which typically need 50-1000 denoising steps, SoFlow generates images in **just ONE sampling step**!
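
The difference can be illustrated on a toy problem. A multi-step sampler integrates an ODE in many small steps, with one network evaluation per step. A one-step model instead evaluates a learned map from noise to the end point once. The sketch below uses a hypothetical 1-D linear ODE, dx/dt = -x, whose exact solution map is known in closed form. `velocity` and `solution_map` are illustrative stand-ins, not SoFlow's network or training objective.

```python
import numpy as np

def velocity(x, t):
    # Hypothetical toy velocity field dx/dt = -x (stands in for a learned network)
    return -x

def euler_sample(x0, num_steps, t0=0.0, t1=1.0):
    """Multi-step sampler: integrate dx/dt = velocity(x, t) with explicit Euler,
    one velocity evaluation per step."""
    x = x0
    dt = (t1 - t0) / num_steps
    for i in range(num_steps):
        x = x + dt * velocity(x, t0 + i * dt)
    return x

def solution_map(x0, t0=0.0, t1=1.0):
    """One-step sampler: the exact solution of dx/dt = -x, evaluated once."""
    return x0 * np.exp(-(t1 - t0))

rng = np.random.default_rng(0)
noise = rng.standard_normal(5)
exact = solution_map(noise)
for steps in (1, 10, 1000):
    err = np.max(np.abs(euler_sample(noise, steps) - exact))
    print(f"Euler with {steps:4d} steps: max error {err:.1e}")
```

Euler needs many evaluations to get close to the exact end point. A model that has learned the solution map directly pays for a single evaluation.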


## 1️⃣ Setup


In [None]:
# Install dependencies first (Colab ships with all of these; pip skips packages that are already installed)
!pip install -q torch torchvision matplotlib numpy pillow
print("✅ Dependencies ready!")

# Check GPU
import torch
print(f"✅ PyTorch: {torch.__version__}")
print(f"✅ CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")


In [None]:
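import os
import sys

# Clone the repository (skipped if it already exists or we are already inside it,
# so re-running this cell does not clone into GenAI/GenAI)
if os.path.basename(os.getcwd()) != 'GenAI':
    if not os.path.exists('GenAI'):
        !git clone https://github.com/Gaurav14cs17/GenAI.git
        print("✅ Repository cloned!")
    os.chdir('GenAI')

# Make the repo's packages (e.g. `soflow`) importable
sys.path.insert(0, os.getcwd())

import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from soflow.models import create_soflow_model

print(f"📁 Working directory: {os.getcwd()}")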


## 2️⃣ Load Model

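You can either:
- **Option A**: Upload a trained checkpoint to `outputs/colab_training/model.pt`
- **Option B**: Train a model first with the [Training notebook](./SoFlow_Training.ipynb)

If no checkpoint is found, the model keeps its random initialization and the samples will not look like CIFAR-10 images.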


In [None]:
# Configuration
config = {
    "img_size": 32,
    "in_channels": 3,
    "num_classes": 10,
    "hidden_size": 256,
    "depth": 6,
    "num_heads": 4,
    "patch_size": 2,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']

print("📋 Configuration loaded!")
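
The configuration implies concrete tensor shapes if the backbone is a ViT/DiT-style transformer. The `hidden_size`, `depth`, `num_heads`, and `patch_size` fields suggest this, but the notebook does not confirm it. Under that assumption, each 32×32 image is cut into (32/2)² = 256 patches of 3·2·2 = 12 values, and each patch is projected to a 256-dim token. Each of the 4 attention heads then works in 256/4 = 64 dimensions. The hypothetical `patchify` helper below shows that reshaping in NumPy; check `soflow/models` for the actual embedding layer.

```python
import numpy as np

def patchify(images, patch_size):
    """Split [N, C, H, W] images into [N, num_patches, C*p*p] tokens (ViT/DiT-style)."""
    n, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
    x = images.reshape(n, c, h // p, p, w // p, p)       # [N, C, H/p, p, W/p, p]
    x = x.transpose(0, 2, 4, 1, 3, 5)                    # [N, H/p, W/p, C, p, p]
    return x.reshape(n, (h // p) * (w // p), c * p * p)  # [N, tokens, patch_dim]

img_size, in_channels, patch_size = 32, 3, 2
hidden_size, num_heads = 256, 4

tokens = patchify(np.zeros((1, in_channels, img_size, img_size)), patch_size)
print(tokens.shape)              # (1, 256, 12)
print(hidden_size // num_heads)  # 64 dimensions per attention head
```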


In [None]:
# Create model
model = create_soflow_model(
    in_channels=config["in_channels"],
    hidden_size=config["hidden_size"],
    depth=config["depth"],
    num_heads=config["num_heads"],
    patch_size=config["patch_size"],
    num_classes=config["num_classes"],
    img_size=config["img_size"]
).to(config["device"])

# Try to load pretrained weights
model_path = "outputs/colab_training/model.pt"
if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location=config["device"])
    # Accept either {"model_state_dict": ...} or a bare state dict
    state_dict = checkpoint.get("model_state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    model.load_state_dict(state_dict)
    print(f"✅ Loaded pretrained model from {model_path}")
else:
    print("⚠️ No pretrained model found. Using random weights.")
    print("   Run the Training notebook first, or upload a model.")

model.eval()
num_params = sum(p.numel() for p in model.parameters())
print(f"📊 Model parameters: {num_params:,}")


## 3️⃣ One-Step Generation ⚡

The magic of SoFlow: **ONE sampling step turns a batch of noise into a batch of images!**


In [None]:
@torch.no_grad()
def generate(model, num_samples, class_label=None, cfg_scale=2.0, seed=None):
    """
    Generate images with SoFlow in ONE step!

    Args:
        model: Trained SoFlow model
        num_samples: Number of images to generate
        class_label: Class to generate (0-9 for CIFAR-10), None for random classes
        cfg_scale: Classifier-Free Guidance scale (higher = more class adherence)
        seed: Random seed for reproducibility

    Returns:
        samples: float numpy array [N, H, W, 3] with values in [0, 1]
        labels: int numpy array [N] with the class of each sample
    """
    if seed is not None:
        torch.manual_seed(seed)  # seeds both the CPU and CUDA generators

    # Start with random noise
    noise = torch.randn(
        num_samples,
        config["in_channels"],
        config["img_size"],
        config["img_size"]
    ).to(config["device"]) * 0.5  # Scale to match training

    # Class labels
    if class_label is not None:
        labels = torch.full((num_samples,), class_label, dtype=torch.long, device=config["device"])
    else:
        labels = torch.randint(0, config["num_classes"], (num_samples,), device=config["device"])

    # ⚡ ONE STEP GENERATION!
    # CUDA kernels run asynchronously: synchronize so the timer covers the actual GPU work
    use_cuda = config["device"] == "cuda"
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    samples = model.sample(noise, labels, cfg_scale=cfg_scale)
    if use_cuda:
        torch.cuda.synchronize()
    gen_time = time.perf_counter() - start_time

    print(f"⚡ Generated {num_samples} images in {gen_time*1000:.1f}ms ({gen_time/num_samples*1000:.2f}ms per image)")

    # Convert to numpy: [N, C, H, W] in [-1, 1] -> [N, H, W, C] in [0, 1]
    samples = samples.cpu().permute(0, 2, 3, 1).numpy()
    samples = (samples + 1) / 2
    samples = np.clip(samples, 0, 1)

    return samples, labels.cpu().numpy()

print("✅ Generation function ready!")
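
`generate` returns floats in [0, 1], which is what `imshow` expects. To write individual samples to disk at their native 32×32 resolution, they are usually quantized to 8-bit first. The hypothetical helper `model_output_to_uint8` below applies the same [-1, 1] → [0, 1] mapping and clipping that `generate` applies to the output of `model.sample`, then rounds to `uint8`. Like `generate`, it assumes that output lies roughly in [-1, 1].

```python
import numpy as np

def model_output_to_uint8(x_nchw):
    """Convert a [N, C, H, W] float batch in [-1, 1] to [N, H, W, C] uint8 in [0, 255]."""
    x = np.transpose(x_nchw, (0, 2, 3, 1))  # channels-last, as matplotlib/Pillow expect
    x = (x + 1.0) / 2.0                     # [-1, 1] -> [0, 1]
    x = np.clip(x, 0.0, 1.0)                # same clipping as `generate`
    return np.round(x * 255.0).astype(np.uint8)

batch = np.random.default_rng(0).uniform(-1.0, 1.0, size=(2, 3, 32, 32))
images = model_output_to_uint8(batch)
print(images.shape, images.dtype)
```

With Pillow, `Image.fromarray(images[0]).save("sample.png")` would then write the first image. Here `batch` would be the raw `model.sample` output moved to NumPy with `.cpu().numpy()`.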


In [None]:
# Generate random samples
print("🎨 Generating random samples...")
samples, labels = generate(model, num_samples=16, cfg_scale=2.0, seed=42)

# Display
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i])
    ax.set_title(class_names[labels[i]], fontsize=9)
    ax.axis('off')
plt.suptitle("Random Generated Samples (CFG=2.0)", fontsize=14)
plt.tight_layout()
plt.show()

## 4️⃣ Class-Conditional Generation 🏷️

Pass `class_label` (0-9) to `generate` to sample from a single CIFAR-10 class.


In [None]:
# Generate all classes
print("🏷️ Generating one sample per class...")

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for class_idx, ax in enumerate(axes.flat):
    samples, _ = generate(model, num_samples=1, class_label=class_idx, cfg_scale=2.0)
    ax.imshow(samples[0])
    ax.set_title(f"{class_idx}: {class_names[class_idx]}", fontsize=11)
    ax.axis('off')

plt.suptitle("All 10 CIFAR-10 Classes", fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
# Generate multiple samples of a specific class
chosen_class = 3  # cat
print(f"🐱 Generating 16 '{class_names[chosen_class]}' images...")

samples, _ = generate(model, num_samples=16, class_label=chosen_class, cfg_scale=2.5)

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i])
    ax.axis('off')
plt.suptitle(f"16 Generated '{class_names[chosen_class].upper()}' Images", fontsize=14)
plt.tight_layout()
plt.show()


## 5️⃣ CFG Scale Comparison 🎚️

CFG (Classifier-Free Guidance) controls the trade-off between **quality** and **diversity**.
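
For reference, the standard classifier-free guidance rule combines a conditional prediction with an unconditional one (label dropped). It extrapolates from the unconditional prediction toward, and past, the conditional one. This notebook does not show whether SoFlow's `model.sample` applies exactly this rule, or to which quantity, so treat it as an assumption and check `soflow/models` for the implementation. The NumPy sketch below uses a hypothetical helper, `cfg_combine`, to show why `cfg_scale=1.0` gives plain conditional predictions and larger scales push harder toward the class.

```python
import numpy as np

def cfg_combine(pred_cond, pred_uncond, cfg_scale):
    """Standard classifier-free guidance: uncond + scale * (cond - uncond)."""
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)

cond = np.array([1.0, 2.0])    # toy prediction given the class label
uncond = np.array([0.0, 1.0])  # toy prediction with the label dropped
for s in (1.0, 2.0, 4.0):
    # s = 1.0 returns `cond` unchanged; larger s moves further along (cond - uncond)
    print(s, cfg_combine(cond, uncond, s))
```

Some codebases write the same rule as `(1 + w) * cond - w * uncond`, which corresponds to `w = cfg_scale - 1`. Check which convention a given `cfg_scale` follows before comparing values across repositories.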


In [None]:
# Compare different CFG scales
print("🎚️ Comparing CFG scales...")

cfg_scales = [1.0, 1.5, 2.0, 3.0, 4.0]
seed = 123  # Same seed -> same noise and labels in every row, for a fair comparison

fig, axes = plt.subplots(len(cfg_scales), 8, figsize=(16, 2*len(cfg_scales)))

for row, cfg in enumerate(cfg_scales):
    samples, labels = generate(model, num_samples=8, cfg_scale=cfg, seed=seed)
    for col in range(8):
        ax = axes[row, col]
        ax.imshow(samples[col])
        # Hide ticks and frame only: ax.axis('off') would also hide the CFG y-label
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        if col == 0:
            ax.set_ylabel(f"CFG={cfg}", fontsize=11, rotation=0, ha='right', va='center')

plt.suptitle("Effect of CFG Scale on Generation", fontsize=14)
plt.tight_layout()
plt.show()

print("""
üìä CFG Scale Guide:
   ‚Ä¢ CFG = 1.0: More diverse, less class-accurate
   ‚Ä¢ CFG = 2.0: Good balance (recommended)
   ‚Ä¢ CFG = 4.0: More class-accurate, less diverse
""")


## 6️⃣ Interactive Generation 🎮

Customize your generation!


In [None]:
# ======================================
#     🎮 CUSTOMIZE YOUR GENERATION
# ======================================

# Choose your class (0-9)
# 0: airplane, 1: automobile, 2: bird, 3: cat, 4: deer
# 5: dog, 6: frog, 7: horse, 8: ship, 9: truck
CLASS_LABEL = 7  # horse

# Number of images
NUM_IMAGES = 16

# CFG scale (1.0-4.0)
CFG_SCALE = 2.5

# Random seed (None for random)
SEED = None

# ======================================

print(f"🎨 Generating {NUM_IMAGES} '{class_names[CLASS_LABEL]}' images...")
print(f"   CFG Scale: {CFG_SCALE}")

samples, _ = generate(model, NUM_IMAGES, CLASS_LABEL, CFG_SCALE, SEED)

# Display
cols = min(8, NUM_IMAGES)
rows = (NUM_IMAGES + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(cols*2, rows*2))
axes = np.atleast_2d(axes)
for i in range(rows):
    for j in range(cols):
        idx = i * cols + j
        if idx < NUM_IMAGES:
            axes[i, j].imshow(samples[idx])
        axes[i, j].axis('off')
plt.suptitle(f"Generated '{class_names[CLASS_LABEL]}' (CFG={CFG_SCALE})", fontsize=14)
plt.tight_layout()
plt.show()


## 7️⃣ Speed Benchmark ⏱️

Let's measure how fast SoFlow generates images!


In [None]:
# Speed benchmark
print("⏱️ Running speed benchmark...")

batch_sizes = [1, 4, 16, 64]
results = []

for bs in batch_sizes:
    # Warmup
    _ = generate(model, bs, cfg_scale=2.0)

    # Benchmark (end to end: sampling + copy to CPU + numpy conversion)
    times = []
    for _ in range(5):
        start = time.perf_counter()
        _ = generate(model, bs, cfg_scale=2.0)
        times.append(time.perf_counter() - start)

    avg_time = np.mean(times) * 1000
    per_image = avg_time / bs
    results.append((bs, avg_time, per_image))
    print(f"   Batch {bs:3d}: {avg_time:.1f}ms total, {per_image:.2f}ms per image")

single_ms = results[0][2]
print(f"""
⚡ SoFlow Speed Summary:
   • Single image: ~{single_ms:.1f}ms
   • Throughput: ~{1000/results[-1][2]:.0f} images/second (batch={batch_sizes[-1]})

🆚 Rough comparison, extrapolated from the measurement above
   (assumes a diffusion network with the same cost per step):
   • DDPM (1000 steps): ~{1000*single_ms:,.0f}ms per image
   • DDIM (50 steps): ~{50*single_ms:,.0f}ms per image
   • SoFlow (1 step): ~{single_ms:.1f}ms per image ← YOU ARE HERE!
""")


## 8️⃣ Save Generated Images 💾


In [None]:
# Generate and save a grid of images
os.makedirs("outputs/generated", exist_ok=True)

print("💾 Generating final image grid...")
samples, labels = generate(model, 64, cfg_scale=2.0, seed=2024)

# Create grid
grid_size = 8
fig, axes = plt.subplots(grid_size, grid_size, figsize=(16, 16))
for i in range(grid_size):
    for j in range(grid_size):
        idx = i * grid_size + j
        axes[i, j].imshow(samples[idx])
        axes[i, j].axis('off')

plt.tight_layout(pad=0.5)
plt.savefig("outputs/generated/sample_grid.png", dpi=150, bbox_inches='tight')
print("✅ Saved to outputs/generated/sample_grid.png")
plt.show()
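
The cell above draws 64 matplotlib subplots and saves the figure at 150 dpi, which rescales the 32×32 samples. An alternative is to tile the pixels directly into one array and save that at native resolution. `torchvision.utils.make_grid` does this for tensors. The NumPy helper below, `make_grid`, is a hypothetical stand-in that works on the `[N, H, W, 3]` arrays `generate` returns.

```python
import numpy as np

def make_grid(images, ncols, pad=2, pad_value=1.0):
    """Tile [N, H, W, C] images into a single [rows*(H+pad)+pad, ncols*(W+pad)+pad, C] array."""
    n, h, w, c = images.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.full(
        (nrows * (h + pad) + pad, ncols * (w + pad) + pad, c),
        pad_value,
        dtype=images.dtype,
    )
    for idx in range(n):
        r, col = divmod(idx, ncols)
        top = pad + r * (h + pad)
        left = pad + col * (w + pad)
        grid[top:top + h, left:left + w] = images[idx]
    return grid

demo = np.random.default_rng(0).random((64, 32, 32, 3))  # stand-in for `samples`
grid = make_grid(demo, ncols=8)
print(grid.shape)  # (274, 274, 3)
```

In the notebook, `plt.imsave("outputs/generated/sample_grid_native.png", make_grid(samples, ncols=8))` would write the 64-sample grid as a 274×274 pixel image. The file name is only an example.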


In [None]:
# Download the grid (Colab)
try:
    from google.colab import files
    files.download("outputs/generated/sample_grid.png")
    print("📥 Download started!")
except ImportError:
    print("ℹ️ Not running in Colab")


---

## 🎉 That's It!

You've successfully used SoFlow for **one-step image generation**!

### Key Takeaways

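| Feature | Value |
|---------|-------|
| **Generation Steps** | 1 (vs 50-1000 for diffusion) |
| **Speed** | Milliseconds per image on a GPU (hardware-dependent; see the benchmark above) |
| **Quality** | Depends on the checkpoint (this demo config is a small 6-layer model) |
| **CFG Support** | ✅ Yes |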

### Learn More

- 📄 [Paper (arXiv)](https://arxiv.org/pdf/2512.15657)
- 📓 [Training Notebook](./SoFlow_Training.ipynb)
- 📚 [Documentation](../docs/README.md)

---

<div align="center">

**Made with ❤️ for the ML community**

</div>
