# DDPM Training Playground 🎨

This notebook provides an interactive environment for experimenting with Denoising Diffusion Probabilistic Models (DDPM).

## Contents:
1. **Setup & Configuration** - Load dependencies and configs
2. **Data Exploration** - Visualize training datasets
3. **Model Architecture** - Explore UNet and time embeddings
4. **Noise Schedules** - Understand beta/alpha schedules
5. **Forward Process** - See how noise is added over time
6. **Training Demo** - Interactive mini training session
7. **Sampling** - Generate new samples with DDPM/DDIM
8. **Evaluation** - Assess model quality and performance


## 1. Setup & Imports 🔧

First, let's set up our environment and import all the necessary modules.


In [9]:
import os
import sys
import yaml
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import seaborn as sns
from PIL import Image

# Add the parent directory to path for imports
current_dir = Path().absolute()
parent_dir = current_dir.parent
sys.path.insert(0, str(parent_dir / 'src'))

print("✅ Basic imports successful!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🚀 CUDA available: {torch.cuda.is_available()}")


✅ Basic imports successful!
🔧 PyTorch version: 2.8.0+cu128
🚀 CUDA available: True


In [10]:
# Import DDPM modules (run this cell separately to debug any import issues)
try:
    # Core DDPM modules
    from ddpm_schedules import DDPMSchedules
    from sampler import DDPMSampler
    from dataset import get_dataloader, get_dataset_stats
    from utils import seed_everything, get_device
    from losses import ddpm_loss
    
    # Model modules
    from models.simple_unet import SimpleUNet
    from models.unet_small import UNetSmall
    from models.time_embedding import TimeEmbedding
    
    # Visualization and evaluation
    import visualize
    from eval import calculate_psnr_ssim, calculate_lpips_diversity
    
    print("✅ All DDPM modules imported successfully!")
    
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("💡 Tip: Make sure you're running from the notebooks/ directory")
    # Continue anyway - we'll handle missing modules gracefully


❌ Import error: attempted relative import with no known parent package
💡 Tip: Make sure you're running from the notebooks/ directory


## 2. Configuration & Device Setup 🎯

Load the MNIST configuration and set up our device and random seeds.


In [11]:
# Setup device and seed
seed_everything(42)
device = get_device("auto")
print(f"🎯 Using device: {device}")

# Load configuration
config_path = "../configs/mnist.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("📝 Configuration loaded:")
for section, values in config.items():
    if isinstance(values, dict):
        print(f"  {section}: {len(values)} parameters")
    else:
        print(f"  {section}: {values}")

# Set matplotlib style
plt.style.use('default')
sns.set_palette("husl")


NameError: name 'seed_everything' is not defined

In [None]:
# Load training data
print("🔄 Loading MNIST data...")
train_loader = get_dataloader(
    dataset_name=config["dataset"]["name"],
    batch_size=64,  # Smaller batch for visualization
    train=True,
    data_dir="../data",
    image_size=config["dataset"]["image_size"]
)

# Get a batch for visualization
data_iter = iter(train_loader)
images, labels = next(data_iter)

print(f"📦 Batch shape: {images.shape}")
print(f"🏷️ Labels shape: {labels.shape}")
print(f"📊 Data range: [{images.min():.3f}, {images.max():.3f}]")
print(f"🎯 Unique labels: {torch.unique(labels).tolist()}")


In [None]:
# Visualize sample images
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
fig.suptitle("Sample Training Images", fontsize=16)

for i in range(32):
    row = i // 8
    col = i % 8
    
    # Convert to displayable format
    img = images[i].squeeze()  # Remove channel dimension
    
    # Denormalize for display (assuming [-1, 1] normalization)
    img = (img + 1) / 2
    img = torch.clamp(img, 0, 1)
    
    axes[row, col].imshow(img, cmap='gray')
    axes[row, col].set_title(f'Label: {labels[i].item()}', fontsize=10)
    axes[row, col].axis('off')

plt.tight_layout()
plt.show()

print(f"📈 Dataset Statistics:")
print(f"  Mean: {images.mean():.4f}")
print(f"  Std: {images.std():.4f}")
print(f"  Shape per image: {images[0].shape}")


In [None]:
# Create the model based on config
print("🏗️ Creating model...")
model_config = config["model"]

if model_config["type"] == "simple_unet":
    model = SimpleUNet(
        in_channels=model_config["in_channels"],
        out_channels=model_config["out_channels"],
        base_channels=model_config.get("base_channels", 64),
        time_embed_dim=model_config.get("time_embed_dim", 256)
    )
elif model_config["type"] == "unet_small":
    model = UNetSmall(
        in_channels=model_config["in_channels"],
        out_channels=model_config["out_channels"],
        model_channels=model_config["model_channels"],
        channel_mult=model_config["channel_mult"],
        num_res_blocks=model_config["num_res_blocks"],
        attention_resolutions=model_config["attention_resolutions"],
        dropout=model_config["dropout"],
        time_embed_dim=model_config["time_embed_dim"],
        use_attention=model_config["use_attention"],
        num_heads=model_config["num_heads"]
    )
else:
    raise ValueError(f"Unknown model type: {model_config['type']}")

model = model.to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"🏗️ Model: {model.__class__.__name__}")
print(f"📊 Total parameters: {total_params:,}")
print(f"🎯 Trainable parameters: {trainable_params:,}")
print(f"💾 Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB (float32)")


In [None]:
# Test forward pass
test_input = torch.randn(2, model_config["in_channels"], 32, 32, device=device)
test_t = torch.randint(0, 1000, (2,), device=device)

with torch.no_grad():
    output = model(test_input, test_t)
    print(f"🔄 Forward pass test:")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Time shape: {test_t.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  ✅ Forward pass successful!")

# Create DDPM schedules
print(f"📅 Creating DDPM schedules...")
schedule_config = config["schedules"]
schedules = DDPMSchedules(
    num_timesteps=schedule_config["num_timesteps"],
    schedule_type=schedule_config["type"],
    beta_start=float(schedule_config["beta_start"]),
    beta_end=float(schedule_config["beta_end"])
)

print(f"📅 Schedule type: {schedule_config['type']}")
print(f"⏰ Timesteps: {schedule_config['num_timesteps']}")
print(f"📊 Beta range: [{schedule_config['beta_start']}, {schedule_config['beta_end']}]")
print(f"✅ Schedules created successfully!")


In [None]:
# Create sampler and test basic functionality
print("🚀 Creating DDPM sampler...")
sampler = DDPMSampler(schedules)
model.eval()

# Test sampling (this will be noisy since model isn't trained yet)
print("🎨 Testing sampling (untrained model - expect noise)...")
sample_shape = (4, model_config["in_channels"], 32, 32)

with torch.no_grad():
    # Test DDIM sampling (faster)
    print("  Testing DDIM sampling (50 steps)...")
    ddim_samples = sampler.sample(
        model=model,
        shape=sample_shape,
        method="ddim",
        num_steps=50,
        device=device
    )
    print(f"  ✅ DDIM sampling successful: {ddim_samples.shape}")
    
    # Test DDPM sampling (slower, more thorough)
    print("  Testing DDPM sampling...")
    ddpm_samples = sampler.sample(
        model=model,
        shape=(2, model_config["in_channels"], 32, 32),  # Smaller batch for speed
        method="ddpm",
        device=device
    )
    print(f"  ✅ DDPM sampling successful: {ddpm_samples.shape}")

print("🎉 All core functionality working!")


In [None]:
# Visualize generated samples (untrained model)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
fig.suptitle("Generated Samples (Untrained Model - Pure Noise Expected)", fontsize=14)

# Show DDIM samples
for i in range(4):
    sample = ddim_samples[i].cpu().squeeze()
    sample = (sample + 1) / 2  # Denormalize
    sample = torch.clamp(sample, 0, 1)
    
    axes[0, i].imshow(sample, cmap='gray')
    axes[0, i].set_title(f"DDIM Sample {i+1}", fontsize=10)
    axes[0, i].axis('off')

# Show DDPM samples
for i in range(2):
    sample = ddpm_samples[i].cpu().squeeze()
    sample = (sample + 1) / 2  # Denormalize
    sample = torch.clamp(sample, 0, 1)
    
    axes[1, i].imshow(sample, cmap='gray')
    axes[1, i].set_title(f"DDPM Sample {i+1}", fontsize=10)
    axes[1, i].axis('off')

# Clear the remaining subplots
for i in range(2, 4):
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

print("🔍 Sample Analysis (Untrained Model):")
print(f"  DDIM samples mean: {ddim_samples.mean():.4f}")
print(f"  DDIM samples std: {ddim_samples.std():.4f}")
print(f"  DDPM samples mean: {ddpm_samples.mean():.4f}")
print(f"  DDPM samples std: {ddpm_samples.std():.4f}")
print("💡 These are noise since the model is untrained!")


In [None]:
# Load pre-trained model (if available)
checkpoint_path = "../outputs/ckpts/best.pt"  # or latest.pt, or specific epoch

if Path(checkpoint_path).exists():
    print(f"🔄 Loading checkpoint from {checkpoint_path}...")
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"✅ Model loaded!")
    print(f"  Epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"  Loss: {checkpoint.get('loss', 'unknown')}")
    
    # Now generate high-quality samples
    model.eval()
    with torch.no_grad():
        print("🎨 Generating samples with trained model...")
        trained_samples = sampler.sample(
            model=model,
            shape=(16, model_config["in_channels"], 32, 32),
            method="ddim",
            num_steps=50,
            device=device
        )
        
        # Visualize trained samples
        fig, axes = plt.subplots(4, 4, figsize=(10, 10))
        fig.suptitle("Generated MNIST Digits (Trained Model)", fontsize=16)
        
        for i in range(16):
            row, col = i // 4, i % 4
            sample = trained_samples[i].cpu().squeeze()
            sample = (sample + 1) / 2
            sample = torch.clamp(sample, 0, 1)
            
            axes[row, col].imshow(sample, cmap='gray')
            axes[row, col].axis('off')
            
        plt.tight_layout()
        plt.show()
        
        print("✅ High-quality digit generation successful!")
        
else:
    print("⚠️ No pre-trained model found.")
    print("💡 Run training first: make train")
    print("📍 Expected path:", checkpoint_path)

print("\n" + "="*60)
print("🎉 DDPM Notebook Complete!")
print("📖 You've successfully explored the DDPM pipeline!")
print("="*60)


## 3. Data Exploration 📊

Let's load and visualize the MNIST training data to understand what we're working with.
