# Day 8: Beta Schedule Playground

Interactive exploration of different beta schedules for DDPM:
- **Linear**: β_t grows linearly from β_min to β_max; simple baseline
- **Cosine**: Slower early diffusion, preserves fine details longer
- **Quadratic**: β_t grows quadratically (assuming the usual convex form): gentler than linear early on, with noise concentrated in later steps
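The three schedules above can be written in a few lines. The sketch below is a self-contained NumPy stand-in, not the project's `src.schedules` code: `make_schedule` is a hypothetical name, and the quadratic form (interpolate in √β, then square) is one common convention that may differ from `beta_quadratic`. It returns the same `betas` / `alpha_bars` / `snr` keys this notebook reads from `get_schedule`.

```python
import numpy as np

def make_schedule(kind: str, T: int = 1000, beta_min: float = 1e-4,
                  beta_max: float = 0.02, s: float = 0.008) -> dict:
    """Hypothetical stand-in for src.schedules.get_schedule.

    Returns a dict with 'betas', 'alpha_bars' and 'snr' arrays of length T.
    """
    if kind == "linear":
        betas = np.linspace(beta_min, beta_max, T)
    elif kind == "quadratic":
        # Assumed convention: interpolate in sqrt-space, then square (convex in t).
        betas = np.linspace(beta_min ** 0.5, beta_max ** 0.5, T) ** 2
    elif kind == "cosine":
        # Cosine schedule: define alpha_bar via a squared cosine, then recover
        # betas from consecutive ratios; clip so the last beta stays below 1.
        steps = np.arange(T + 1) / T
        f = np.cos((steps + s) / (1 + s) * np.pi / 2) ** 2
        betas = np.clip(1.0 - f[1:] / f[:-1], 0.0, 0.999)
    else:
        raise ValueError(f"unknown schedule: {kind!r}")
    alpha_bars = np.cumprod(1.0 - betas)        # ᾱ_t = ∏ (1 - β_i)
    snr = alpha_bars / (1.0 - alpha_bars)       # SNR_t = ᾱ_t / (1 - ᾱ_t)
    return {"betas": betas, "alpha_bars": alpha_bars, "snr": snr}

for name in ("linear", "cosine", "quadratic"):
    sched = make_schedule(name)
    print(f"{name:<10} ᾱ at t=499: {sched['alpha_bars'][499]:.3f}")
```

Printing ᾱ at the midpoint (t = 499 for T = 1000) gives a quick numerical check to compare against the curves in Section 1.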

This notebook allows you to:
1. 🔬 Experiment with different T values and schedule parameters
2. 📊 Visualize β_t, ᾱ_t, and SNR curves
3. 🎯 Compare schedule effects on diffusion process
4. 🧪 Test sampling with different schedules


In [1]:
# Setup and imports
import sys
import os
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown, fixed
import warnings
warnings.filterwarnings('ignore')

# Project imports
from src.schedules import (
    beta_linear, beta_cosine, beta_quadratic, 
    get_schedule, plot_all_schedules, validate_schedule
)
from src.models.unet_small import UNetSmall
from src.sampler import DDPMSampler, DDIMSampler
from src.utils import set_seed, tensor_to_pil
from src.losses import compute_forward_process

# Set style
plt.style.use('default')
sns.set_palette("husl")

print("🚀 Beta Schedule Playground Ready!")
print(f"Using device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")


🚀 Beta Schedule Playground Ready!
Using device: CUDA


## 1. 📊 Interactive Schedule Visualization

Experiment with different schedule parameters and see how they affect the diffusion process!
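For reference, the plotted quantities follow the standard DDPM definitions (assuming `src.schedules` uses them; t runs from 1 to T in the math, while the code indexes from 0). Here s is the `cosine_s` offset:

$$
\alpha_t = 1 - \beta_t, \qquad
\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i, \qquad
\mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}
$$

$$
\text{Linear: } \beta_t = \beta_{\min} + \frac{t-1}{T-1}\left(\beta_{\max} - \beta_{\min}\right)
$$

$$
\text{Cosine: } \bar{\alpha}_t = \frac{f(t)}{f(0)}, \qquad
f(t) = \cos^2\!\left(\frac{t/T + s}{1 + s}\cdot\frac{\pi}{2}\right), \qquad
\beta_t = \min\!\left(1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},\ 0.999\right)
$$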


In [None]:
def plot_interactive_schedules(T=1000, beta_min=1e-4, beta_max=0.02, cosine_s=0.008):
    """
    Interactive schedule plotting with adjustable parameters.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Create schedules with current parameters
    linear_schedule = beta_linear(T, beta_min, beta_max)
    cosine_schedule = beta_cosine(T, cosine_s)
    quad_schedule = beta_quadratic(T, beta_min, beta_max)
    
    schedules = {
        'Linear': linear_schedule,
        'Cosine': cosine_schedule, 
        'Quadratic': quad_schedule
    }
    
    timesteps = np.arange(1, T + 1)
    colors = ['blue', 'red', 'green']
    
    for i, (name, schedule) in enumerate(schedules.items()):
        color = colors[i]
        
        # Plot β_t
        axes[0, 0].plot(timesteps, schedule['betas'].numpy(), 
                       label=name, color=color, linewidth=2)
        
        # Plot ᾱ_t
        axes[0, 1].plot(timesteps, schedule['alpha_bars'].numpy(), 
                       label=name, color=color, linewidth=2)
        
        # Plot SNR
        axes[1, 0].plot(timesteps, schedule['snr'].numpy(), 
                       label=name, color=color, linewidth=2)
        
        # Plot β_t vs ᾱ_t relationship
        axes[1, 1].plot(schedule['betas'].numpy(), schedule['alpha_bars'].numpy(),
                       label=name, color=color, linewidth=2)
    
    # Format plots
    titles = ['Beta Schedule: β_t', 'Cumulative Alpha: ᾱ_t', 
              'Signal-to-Noise Ratio', 'β_t vs ᾱ_t Relationship']
    xlabels = ['Timestep t', 'Timestep t', 'Timestep t', 'β_t']
    ylabels = ['β_t', 'ᾱ_t', 'SNR (log scale)', 'ᾱ_t']
    
    for i, ax in enumerate(axes.flat):
        ax.set_title(titles[i], fontsize=12, fontweight='bold')
        ax.set_xlabel(xlabels[i])
        ax.set_ylabel(ylabels[i])
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        if i == 2:  # SNR plot
            ax.set_yscale('log')
    
    plt.tight_layout()
    plt.show()
    
    # Print some statistics
    print(f"\n📊 Schedule Statistics (T={T}):")
    print("="*50)
    for name, schedule in schedules.items():
        betas = schedule['betas']
        alpha_bars = schedule['alpha_bars']
        snr = schedule['snr']
        
        print(f"\n{name:>10}: β ∈ [{betas.min():.2e}, {betas.max():.2e}]")
        print(f"{'':>10}  ᾱ ∈ [{alpha_bars.min():.3f}, {alpha_bars.max():.3f}]")
        print(f"{'':>10}  SNR ∈ [{snr.min():.2e}, {snr.max():.2e}]")
        print(f"{'':>10}  Valid: {validate_schedule(schedule)}")

# Create interactive widget
interact(plot_interactive_schedules,
         T=IntSlider(min=10, max=2000, step=10, value=1000, description='T:'),
         beta_min=FloatSlider(min=1e-5, max=1e-3, step=1e-5, value=1e-4, 
                             description='β_min:', readout_format='.1e'),
         beta_max=FloatSlider(min=0.005, max=0.1, step=0.005, value=0.02, 
                             description='β_max:'),
         cosine_s=FloatSlider(min=0.001, max=0.02, step=0.001, value=0.008, 
                             description='cosine_s:'));


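The statistics above call `validate_schedule`, whose exact checks are not shown here. As an illustration only, a hypothetical `check_schedule` might test three properties: betas strictly inside (0, 1), ᾱ_t strictly decreasing, and ᾱ_T small enough that x_T is close to pure noise. The `1e-2` threshold is an arbitrary assumption.

```python
import numpy as np

def check_schedule(betas: np.ndarray, max_final_alpha_bar: float = 1e-2) -> list[str]:
    """Return a list of problems found; an empty list means every check passed.

    Hypothetical checks; src.schedules.validate_schedule may test different things.
    """
    problems = []
    if not np.all((betas > 0) & (betas < 1)):
        problems.append("betas must lie strictly in (0, 1)")
    alpha_bars = np.cumprod(1.0 - betas)
    if np.any(np.diff(alpha_bars) >= 0):
        problems.append("alpha_bars must be strictly decreasing")
    # Assumed threshold: x_T should be (almost) pure noise.
    if alpha_bars[-1] > max_final_alpha_bar:
        problems.append(f"alpha_bar_T = {alpha_bars[-1]:.3f} leaves too much signal at t = T")
    return problems

print(check_schedule(np.linspace(1e-4, 0.02, 1000)))
print(check_schedule(np.linspace(1e-4, 0.02, 50)))
```

With the default β range, a linear schedule at T = 1000 passes, while T = 50 with the same β_min/β_max leaves ᾱ_T at roughly 0.6 and trips the last check; this is worth keeping in mind when dragging the T slider down.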

## 2. 🔬 Forward Diffusion Process Visualization

See how different schedules affect the forward diffusion process!


In [3]:
def visualize_forward_process(schedule_name='linear', T=1000, num_steps=8):
    """
    Visualize how an image gets corrupted through forward diffusion.
    """
    set_seed(42)
    
    # Create a simple test image (checkerboard pattern)
    size = 32
    x = np.arange(size)
    y = np.arange(size)
    X, Y = np.meshgrid(x, y)
    checkerboard = ((X // 4) + (Y // 4)) % 2
    
    # Convert to tensor and normalize to [-1, 1]
    x0 = torch.tensor(checkerboard, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    x0 = x0 * 2 - 1  # [0, 1] -> [-1, 1]
    
    # Get schedule
    schedule = get_schedule(schedule_name, T)
    
    # Select timesteps to visualize
    timesteps = np.linspace(0, T-1, num_steps, dtype=int)
    
    fig, axes = plt.subplots(2, num_steps, figsize=(2*num_steps, 4))
    if num_steps == 1:
        axes = axes.reshape(2, 1)
    
    for i, t in enumerate(timesteps):
        # Apply forward diffusion
        t_tensor = torch.tensor([t], dtype=torch.long)
        noisy_x, noise = compute_forward_process(x0, t_tensor, schedule['alpha_bars'])
        
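        # Plot noisy image (hide ticks rather than calling axis('off'),
        # so the row labels set below remain visible)
        axes[0, i].imshow(noisy_x[0, 0].numpy(), cmap='gray', vmin=-2, vmax=2)
        axes[0, i].set_title(f't={t}')
        axes[0, i].set_xticks([])
        axes[0, i].set_yticks([])
        
        # Plot added noise
        axes[1, i].imshow(noise[0, 0].numpy(), cmap='gray', vmin=-2, vmax=2)
        axes[1, i].set_title(f'noise t={t}')
        axes[1, i].set_xticks([])
        axes[1, i].set_yticks([])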
        
        # Print statistics
        signal_strength = schedule['alpha_bars'][t].item()
        noise_strength = (1 - schedule['alpha_bars'][t]).item()
        snr = schedule['snr'][t].item()
        
        print(f"t={t:3d}: ᾱ={signal_strength:.3f}, noise_var={noise_strength:.3f}, SNR={snr:.2e}")
    
    axes[0, 0].set_ylabel('Noisy Image', fontsize=12)
    axes[1, 0].set_ylabel('Added Noise', fontsize=12)
    
    plt.suptitle(f'Forward Diffusion Process ({schedule_name.title()} Schedule)', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Interactive widget
interact(visualize_forward_process,
         schedule_name=Dropdown(options=['linear', 'cosine', 'quadratic'], 
                               value='cosine', description='Schedule:'),
         T=IntSlider(min=50, max=1000, step=50, value=500, description='T:'),
         num_steps=IntSlider(min=4, max=12, step=1, value=8, description='Steps:'));


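The forward process has a closed form, x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε with ε ~ N(0, I), so x_t can be sampled at any t without stepping through the earlier ones. `compute_forward_process` presumably implements this in PyTorch; the NumPy sketch below uses a hypothetical name, `q_sample`, and assumes a linear schedule for the usage lines.

```python
import numpy as np

def q_sample(x0: np.ndarray, t: int, alpha_bars: np.ndarray,
             rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Sample x_t ~ q(x_t | x_0) in closed form; returns (x_t, noise)."""
    a_bar = alpha_bars[t]
    noise = rng.standard_normal(x0.shape)
    x_t = np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise
    return x_t, noise

# Usage: the same 32x32 checkerboard as the cell above, scaled to [-1, 1].
size = 32
X, Y = np.meshgrid(np.arange(size), np.arange(size))
x0 = (((X // 4) + (Y // 4)) % 2) * 2.0 - 1.0
alpha_bars = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))  # assumed linear schedule
x_t, eps = q_sample(x0, 500, alpha_bars, np.random.default_rng(42))
```

Because the squared coefficients ᾱ_t and 1 − ᾱ_t sum to 1, a unit-variance x_0 (like this ±1 checkerboard) yields a unit-variance x_t at every t, which makes a fixed color range such as `vmin=-2, vmax=2` reasonable.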

## 3. 🎯 Schedule Comparison at Key Timesteps

Compare how different schedules behave at specific timesteps.


In [4]:
def compare_schedules_at_timestep(t=500, T=1000):
    """
    Compare all schedules at a specific timestep.
    """
    if not 0 <= t < T:
        print(f"Error: t={t} must satisfy 0 <= t < T={T}")
        return
    
    schedules = {
        'Linear': get_schedule('linear', T),
        'Cosine': get_schedule('cosine', T),
        'Quadratic': get_schedule('quadratic', T)
    }
    
    print(f"📊 Schedule Comparison at t={t} (T={T})")
    print("="*60)
    print(f"{'Schedule':<12} {'β_t':<10} {'ᾱ_t':<8} {'SNR':<12} {'Signal var %':<12}")
    print("-"*60)
    
    for name, schedule in schedules.items():
        beta_t = schedule['betas'][t].item()
        alpha_bar_t = schedule['alpha_bars'][t].item()
        snr_t = schedule['snr'][t].item()
        # ᾱ_t is the fraction of x_t's variance contributed by a unit-variance x0
        signal_percent = alpha_bar_t * 100
        
        print(f"{name:<12} {beta_t:<10.2e} {alpha_bar_t:<8.3f} {snr_t:<12.2e} {signal_percent:<12.1f}")
    
    # Visualization: use the checkerboard from Section 2 rather than random noise,
    # since noising a pure-noise image looks statistically the same at every t
    size = 32
    X, Y = np.meshgrid(np.arange(size), np.arange(size))
    checkerboard = ((X // 4) + (Y // 4)) % 2
    x0 = torch.tensor(checkerboard, dtype=torch.float32)[None, None] * 2 - 1  # [-1, 1]
    set_seed(42)
    t_tensor = torch.tensor([t], dtype=torch.long)
    
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    # Original (fixed color range so all panels share one scale)
    axes[0].imshow(x0[0, 0].numpy(), cmap='gray', vmin=-2, vmax=2)
    axes[0].set_title('Original')
    axes[0].axis('off')
    
    # Noisy versions
    colors = ['blue', 'red', 'green']
    for i, (name, schedule) in enumerate(schedules.items()):
        noisy_x, _ = compute_forward_process(x0, t_tensor, schedule['alpha_bars'])
        
        axes[i+1].imshow(noisy_x[0, 0].numpy(), cmap='gray', vmin=-2, vmax=2)
        axes[i+1].set_title(f'{name}\n(ᾱ={schedule["alpha_bars"][t]:.3f})', 
                           color=colors[i])
        axes[i+1].axis('off')
    
    plt.suptitle(f'Noise Levels at t={t}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Interactive widget
interact(compare_schedules_at_timestep,
         t=IntSlider(min=0, max=999, step=10, value=500, description='Timestep t:'),
         T=IntSlider(min=100, max=1000, step=100, value=1000, description='Total T:'));


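One way to condense the table above into a single number per schedule is the first timestep where SNR_t drops below 1, i.e. where noise starts to carry more variance than signal (ᾱ_t < 0.5). The sketch re-derives the schedules in NumPy under assumed conventions (cosine with s = 0.008, quadratic interpolated in √β, default β_min/β_max); `half_signal_step` is a hypothetical helper, not part of the project.

```python
import numpy as np

T = 1000
s = 0.008
# Assumed schedule conventions; src.schedules may differ in details.
betas = {
    "Linear": np.linspace(1e-4, 0.02, T),
    "Quadratic": np.linspace(1e-4 ** 0.5, 0.02 ** 0.5, T) ** 2,
}
f = np.cos((np.arange(T + 1) / T + s) / (1 + s) * np.pi / 2) ** 2
betas["Cosine"] = np.clip(1.0 - f[1:] / f[:-1], 0.0, 0.999)

def half_signal_step(b: np.ndarray) -> int:
    """First (0-indexed) timestep where SNR_t = ᾱ_t / (1 - ᾱ_t) < 1, i.e. ᾱ_t < 0.5."""
    alpha_bars = np.cumprod(1.0 - b)
    snr = alpha_bars / (1.0 - alpha_bars)
    return int(np.argmax(snr < 1.0))

for name, b in betas.items():
    print(f"{name:<10} SNR < 1 from t = {half_signal_step(b)}")
```

The cosine schedule should cross last, consistent with it preserving signal longest; the printed steps give a concrete way to compare the other two.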

## 4. 🔍 Key Insights Summary

Based on your exploration, here are the key insights about beta schedules:


In [5]:
print("""
🧠 Key Insights from Beta Schedule Exploration:

🔵 LINEAR SCHEDULE:
   • Simplest implementation: β_t grows linearly from β_min to β_max
   • Per-step noise rises at a constant rate, so ᾱ_t drops quickly in the early and middle steps
   • Good baseline, but not necessarily optimal
   • Early steps preserve less detail than cosine

🔴 COSINE SCHEDULE:
   • Slower early diffusion, faster later
   • Preserves signal longer in early timesteps
   • Better for fine detail preservation
   • Often produces higher-quality samples, especially at low resolution
   • More gradual SNR decay initially

🟢 QUADRATIC SCHEDULE:
   • β_t grows quadratically from β_min to β_max (assuming the usual convex form)
   • Never exceeds the linear β_t, so early corruption is gentler than linear
   • Concentrates most of the noise in the later timesteps
   • SNR stays at or above the linear schedule's SNR throughout

📊 PRACTICAL IMPLICATIONS:
   • Cosine often works best for image generation
   • Linear and quadratic are simpler, but with a fixed β range they need enough timesteps for ᾱ_T to approach 0
   • Quadratic's gentler start does not guarantee better quality; compare it against cosine on your data
   • Schedule choice affects training dynamics
   • Different schedules may work better for different data

⚡ OPTIMIZATION TIPS:
   • Start with cosine schedule for images
   • Adjust T based on compute budget vs quality
   • Consider DDIM for faster sampling
   • Validate schedule properties before training
   • Experiment with custom schedules for specific domains

🎯 NEXT STEPS:
   • Train models with different schedules
   • Compare sample quality metrics (FID, IS, etc.)
   • Analyze training curves and convergence
   • Test on your specific dataset/domain
""")



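On the tip to adjust T for your compute budget: the T slider in Section 1 keeps β_min and β_max fixed, so at small T the total noise Σβ_t shrinks and ᾱ_T stays far from 0 (about 0.36 for a linear schedule at T = 100). One convention some DDPM codebases use is to scale both endpoints by 1000/T, which keeps Σβ_t equal to the T = 1000 case. The sketch below is hypothetical (`linear_betas_rescaled` is not part of `src.schedules`), covers only the linear schedule, and rejects T values small enough to push β above 1.

```python
import numpy as np

def linear_betas_rescaled(T: int, beta_min: float = 1e-4, beta_max: float = 0.02,
                          ref_T: int = 1000) -> np.ndarray:
    """Linear betas with endpoints scaled by ref_T / T (an assumed convention).

    The sum of the betas then matches the ref_T schedule exactly.
    """
    scale = ref_T / T
    if scale * beta_max >= 1.0:
        raise ValueError(f"T={T} is too small: rescaled beta_max would be >= 1")
    return np.linspace(scale * beta_min, scale * beta_max, T)

for T in (100, 1000):
    fixed = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))[-1]
    rescaled = np.cumprod(1.0 - linear_betas_rescaled(T))[-1]
    print(f"T={T:4d}: alpha_bar_T fixed={fixed:.2e}  rescaled={rescaled:.2e}")
```

At T = 1000 the two schedules coincide; at T = 100 only the rescaled one drives x_T close to pure noise, at the cost of much larger per-step betas.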

## 🚀 Ready to Train?

Now that you've explored different beta schedules, you're ready to train models!

```bash
# From the project root directory:

# Train all schedules
bash scripts/train_all.sh

# Generate samples
bash scripts/sample_all.sh

# Compare results
bash scripts/compare.sh

# Or use the Makefile
make train_all
make sample_all
make compare
```

Happy experimenting! 🎉
