# Forward Diffusion Process - Interactive Notebook

This notebook interactively explores the forward (noising) process of denoising diffusion probabilistic models (DDPMs). You can:
- Tweak the number of timesteps T
- Compare different noise schedules (linear, cosine, sigmoid)
- Inspect SNR thresholds and their impact
- Visualize the degradation process

## Setup


In [3]:
import sys
import os
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, Image as IPImage
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual

# Import our modules
from src.utils import set_seed, get_device, save_image_grid
from src.dataset import get_sample_batch
from src.ddpm_schedules import get_ddpm_schedule, get_schedule_stats
from src.forward import (
    q_xt_given_x0, sample_trajectory, snr, snr_db, 
    get_timesteps_for_snr_threshold, compute_mse_to_x0
)
from src.visualize import (
    plot_schedules, create_trajectory_grid, plot_pixel_histograms,
    plot_snr_analysis
)

# Setup: seed before loading so the sampled batches are reproducible
# (assumes set_seed seeds the RNG(s) used by get_sample_batch — confirm in src.utils).
SEED = 42
set_seed(SEED)
device = get_device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load one 16-image batch per dataset; normalize_mode="minus_one_one"
# presumably rescales pixels to [-1, 1].
print("Loading sample data...")
DATA_ROOT = "../data"
mnist_batch, mnist_labels = get_sample_batch(
    "mnist", DATA_ROOT, batch_size=16, normalize_mode="minus_one_one"
)
cifar_batch, cifar_labels = get_sample_batch(
    "cifar10", DATA_ROOT, batch_size=16, normalize_mode="minus_one_one"
)
mnist_batch, cifar_batch = mnist_batch.to(device), cifar_batch.to(device)

print(f"MNIST batch shape: {mnist_batch.shape}")
print(f"CIFAR batch shape: {cifar_batch.shape}")


Using device: cuda
Loading sample data...
MNIST batch shape: torch.Size([16, 1, 28, 28])
CIFAR batch shape: torch.Size([16, 3, 32, 32])


## Interactive Schedule Exploration

Use the widgets below to explore different noise schedules and their properties:


In [4]:
def explore_schedules(T=1000, schedule_type="cosine", show_stats=True):
    """Plot one noise schedule and optionally print its summary statistics.

    Draws a 2x2 figure: beta_t, alpha_bar_t, SNR in dB, and (bottom-right) the
    first schedule index at which the SNR falls below each of -5/-10/-15/-20 dB.

    Args:
        T: Number of diffusion steps, forwarded to ``get_ddpm_schedule``.
        schedule_type: Schedule name forwarded to ``get_ddpm_schedule``
            (the widget offers 'linear', 'cosine' and 'sigmoid').
        show_stats: If True, print the ``get_schedule_stats`` summary and the
            threshold-crossing table after the figure.
    """
    betas, alphas, alpha_bars = get_ddpm_schedule(T, schedule_type)
    stats = get_schedule_stats(betas, alphas, alpha_bars)
    snr_values = snr_db(alpha_bars)

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    # The x-axis is the 0-based index into the length-T schedule tensors.
    timesteps = np.arange(T)

    # Beta schedule
    axes[0, 0].plot(timesteps, betas.numpy(), 'b-', linewidth=2)
    axes[0, 0].set_title('Beta Schedule')
    axes[0, 0].set_xlabel('Timestep t')
    axes[0, 0].set_ylabel(r'$\beta_t$')
    axes[0, 0].grid(True, alpha=0.3)

    # Alpha bar schedule
    axes[0, 1].plot(timesteps, alpha_bars.numpy(), 'g-', linewidth=2)
    axes[0, 1].set_title('Cumulative Alpha')
    axes[0, 1].set_xlabel('Timestep t')
    axes[0, 1].set_ylabel(r'$\bar{\alpha}_t$')
    axes[0, 1].grid(True, alpha=0.3)

    # SNR in dB (0 dB = equal signal and noise power)
    axes[1, 0].plot(timesteps, snr_values.numpy(), 'r-', linewidth=2)
    axes[1, 0].axhline(y=0, color='black', linestyle='--', alpha=0.5)
    axes[1, 0].axhline(y=-5, color='red', linestyle='--', alpha=0.5, label='-5dB threshold')
    axes[1, 0].set_title('Signal-to-Noise Ratio')
    axes[1, 0].set_xlabel('Timestep t')
    axes[1, 0].set_ylabel('SNR (dB)')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # SNR threshold crossings: first index returned by the helper for each
    # threshold (assumed ascending — TODO confirm in src.forward), or T as a
    # sentinel meaning "never reached within the schedule".
    thresholds = [-5, -10, -15, -20]
    threshold_times = []
    for threshold in thresholds:
        timesteps_below = get_timesteps_for_snr_threshold(alpha_bars, threshold)
        threshold_times.append(int(timesteps_below[0]) if len(timesteps_below) > 0 else T)

    bars = axes[1, 1].bar(range(len(thresholds)), threshold_times, alpha=0.7)
    axes[1, 1].set_title('Time to SNR Thresholds')
    axes[1, 1].set_xlabel('SNR Threshold')
    axes[1, 1].set_ylabel('Timestep')
    axes[1, 1].set_xticks(range(len(thresholds)))
    axes[1, 1].set_xticklabels([f'{t}dB' for t in thresholds])
    axes[1, 1].grid(True, alpha=0.3, axis='y')

    # Value labels on bars
    for bar, time in zip(bars, threshold_times):
        x_center = bar.get_x() + bar.get_width() / 2.
        height = bar.get_height()
        if time < T:
            axes[1, 1].text(x_center, height + T * 0.01,
                            f'{time}', ha='center', va='bottom')
        else:
            # Label the sentinel bar so a full-height bar is not read as a crossing at t=T.
            axes[1, 1].text(x_center, height / 2,
                            'Never', ha='center', va='center', color='white', fontweight='bold')

    plt.tight_layout()
    plt.show()

    if show_stats:
        print(f"\n=== {schedule_type.upper()} Schedule Statistics ===")
        print(f"Total timesteps: {stats['T']}")
        print(f"Beta range: {stats['beta_min']:.6f} - {stats['beta_max']:.6f}")
        print(f"Final alpha_bar: {stats['alpha_bar_final']:.6f}")
        print(f"Final SNR: {snr_values[-1]:.2f} dB")

        print("\nSNR Threshold Analysis:")
        for threshold, time in zip(thresholds, threshold_times):
            if time < T:
                print(f"  {threshold:+.0f} dB reached at timestep {time}")
            else:
                print(f"  {threshold:+.0f} dB never reached")

# Create interactive widget: bind each explore_schedules argument to a control.
interactive_plot = widgets.interact(
    explore_schedules,
    T=widgets.IntSlider(
        min=100,
        max=2000,
        step=100,
        value=1000,
        description='Timesteps T:',
    ),
    schedule_type=widgets.Dropdown(
        options=['linear', 'cosine', 'sigmoid'],
        value='cosine',
        description='Schedule:',
    ),
    show_stats=widgets.Checkbox(
        value=True,
        description='Show Statistics',
    ),
)


interactive(children=(IntSlider(value=1000, description='Timesteps T:', max=2000, min=100, step=100), Dropdown…

## Interactive Forward Diffusion Visualization

Visualize how images degrade through the forward process:


In [5]:
def _parse_timesteps(timesteps_str, T):
    """Parse a comma-separated string of integer timesteps, keeping those in [0, T].

    Blank entries (e.g. from a trailing comma) are skipped. Out-of-range entries
    are dropped with a printed notice: t >= 1 is sampled via index ``t - 1`` into
    the length-T ``alpha_bars``, so only 0..T are valid.

    Returns:
        The valid timesteps, in the order given.

    Raises:
        ValueError: If a non-blank entry is not an integer literal.
    """
    tokens = (token.strip() for token in timesteps_str.split(','))
    requested = [int(token) for token in tokens if token]
    valid = [t for t in requested if 0 <= t <= T]
    dropped = [t for t in requested if not 0 <= t <= T]
    if dropped:
        print(f"Ignoring timesteps outside [0, {T}]: {dropped}")
    return valid


def visualize_forward_process(dataset="mnist", schedule_type="cosine", T=1000,
                              timesteps_str="0,50,100,250,500,750,999"):
    """Show a grid of images noised by the forward process at chosen timesteps.

    Rows are the first 8 images of the selected batch; columns are the requested
    timesteps. Column t=0 shows the clean image; for t >= 1 the image is sampled
    with ``q_xt_given_x0`` at schedule index ``t - 1``.

    Args:
        dataset: 'mnist' selects ``mnist_batch``; any other value selects ``cifar_batch``.
        schedule_type: Schedule name forwarded to ``get_ddpm_schedule``.
        T: Number of diffusion steps, forwarded to ``get_ddpm_schedule``.
        timesteps_str: Comma-separated integer timesteps to display. Entries
            outside [0, T] are skipped with a notice instead of crashing.
    """
    try:
        timesteps_to_show = _parse_timesteps(timesteps_str, T)
    except ValueError:
        print(f"Could not parse timesteps {timesteps_str!r}; expected comma-separated integers.")
        return
    if not timesteps_to_show:
        print("No valid timesteps to display.")
        return

    # Use fewer images for a cleaner display.
    x0_batch = (mnist_batch if dataset == "mnist" else cifar_batch)[:8]

    _, _, alpha_bars = get_ddpm_schedule(T, schedule_type)
    alpha_bars = alpha_bars.to(device)

    batch_size = x0_batch.shape[0]
    trajectories = torch.zeros(
        (batch_size, len(timesteps_to_show)) + x0_batch.shape[1:],
        device=device
    )

    for t_idx, t in enumerate(timesteps_to_show):
        if t == 0:
            trajectories[:, t_idx] = x0_batch
        else:
            # Timesteps are 1-based here: t maps to schedule index t - 1.
            t_tensor = torch.full((batch_size,), t - 1, device=device, dtype=torch.long)
            x_t, _ = q_xt_given_x0(x0_batch, t_tensor, alpha_bars)
            trajectories[:, t_idx] = x_t

    # Denormalize for display: [-1, 1] -> [0, 1], clamping the noise tails.
    trajectories = torch.clamp((trajectories + 1.0) / 2.0, 0, 1)

    num_t = trajectories.shape[1]
    # squeeze=False keeps `axes` 2-D even with a single row or a single column
    # (e.g. when only one timestep is requested).
    fig, axes = plt.subplots(batch_size, num_t, figsize=(num_t * 1.5, batch_size * 1.5),
                             squeeze=False)

    for i in range(batch_size):
        for j, t in enumerate(timesteps_to_show):
            img = trajectories[i, j]

            if img.shape[0] == 1:  # Grayscale
                axes[i, j].imshow(img.squeeze(0).cpu().numpy(), cmap='gray', vmin=0, vmax=1)
            else:  # RGB
                axes[i, j].imshow(img.permute(1, 2, 0).cpu().numpy())

            if i == 0:  # Timestep labels on the top row only
                axes[i, j].set_title(f't={t}', fontsize=10)
            axes[i, j].axis('off')

    plt.suptitle(f'Forward Diffusion Process - {dataset.upper()} ({schedule_type} schedule)', fontsize=14)
    plt.tight_layout()
    plt.show()

# Create interactive widget for the forward-process grid.
# The trailing ';' stops the cell from echoing the call's return value.
widgets.interact(
    visualize_forward_process,
    dataset=widgets.Dropdown(
        options=['mnist', 'cifar10'],
        value='mnist',
        description='Dataset:',
    ),
    schedule_type=widgets.Dropdown(
        options=['linear', 'cosine', 'sigmoid'],
        value='cosine',
        description='Schedule:',
    ),
    T=widgets.IntSlider(
        min=500,
        max=2000,
        step=100,
        value=1000,
        description='Timesteps T:',
    ),
    timesteps_str=widgets.Text(
        value="0,50,100,250,500,750,999",
        description='Timesteps:',
    ),
);


interactive(children=(Dropdown(description='Dataset:', options=('mnist', 'cifar10'), value='mnist'), Dropdown(…

## SNR Threshold Analysis

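Investigate when each schedule's signal-to-noise ratio (SNR) first falls below a chosen threshold. At 0 dB, signal and noise have equal power; below 0 dB, noise dominates: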


In [6]:
def analyze_snr_threshold(threshold_db=-5.0, T=1000):
    """Compare when each noise schedule first drops below an SNR threshold.

    Produces a 1x3 figure:
      1. SNR (dB) curves for the linear, cosine and sigmoid schedules, with a
         marker where each curve first falls below ``threshold_db``;
      2. a bar chart of those crossing indices ("Never" if not reached);
      3. the first image of ``mnist_batch`` noised with the cosine schedule at
         the same crossing index that is marked in panel 1.
    It then prints each schedule's crossing index as a percentage of T.

    Args:
        threshold_db: SNR threshold in decibels.
        T: Number of diffusion steps, forwarded to ``get_ddpm_schedule``.
    """
    schedules = ['linear', 'cosine', 'sigmoid']
    colors = ['blue', 'red', 'green']

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Schedule name -> first index below the threshold, or T as a "never reached" sentinel.
    threshold_times = {}

    for schedule, color in zip(schedules, colors):
        _, _, alpha_bars = get_ddpm_schedule(T, schedule)
        snr_values = snr_db(alpha_bars)
        timesteps = np.arange(T)

        # SNR curve, plotted against the 0-based schedule index.
        axes[0].plot(timesteps, snr_values.numpy(), color=color, linewidth=2, label=schedule)

        # First threshold crossing (assumes the helper returns ascending indices — TODO confirm).
        threshold_ts = get_timesteps_for_snr_threshold(alpha_bars, threshold_db)
        if len(threshold_ts) > 0:
            threshold_t = int(threshold_ts[0])
            threshold_times[schedule] = threshold_t
            axes[0].scatter(threshold_t, threshold_db, color=color, s=100, zorder=5)
        else:
            threshold_times[schedule] = T

    axes[0].axhline(y=threshold_db, color='black', linestyle='--', alpha=0.7)
    axes[0].axhline(y=0, color='gray', linestyle=':', alpha=0.5)
    axes[0].set_title(f'SNR Curves with {threshold_db}dB Threshold')
    axes[0].set_xlabel('Timestep t')
    axes[0].set_ylabel('SNR (dB)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Bar chart of threshold times
    bars = axes[1].bar(schedules, [threshold_times[s] for s in schedules], color=colors, alpha=0.7)
    axes[1].set_title(f'Time to {threshold_db}dB Threshold')
    axes[1].set_ylabel('Timestep')
    axes[1].grid(True, alpha=0.3, axis='y')

    for bar, schedule in zip(bars, schedules):
        x_center = bar.get_x() + bar.get_width() / 2.
        height = bar.get_height()
        if threshold_times[schedule] < T:
            axes[1].text(x_center, height + T * 0.01,
                         f'{threshold_times[schedule]}', ha='center', va='bottom')
        else:
            axes[1].text(x_center, height / 2,
                         'Never', ha='center', va='center', color='white', fontweight='bold')

    # Example image at the cosine schedule's crossing.
    cosine_threshold_t = threshold_times['cosine']
    if cosine_threshold_t < T:
        _, _, alpha_bars = get_ddpm_schedule(T, 'cosine')
        alpha_bars = alpha_bars.to(device)

        sample_image = mnist_batch[:1]
        # Sample at the same index marked on the SNR plot, so the displayed image
        # really is at/below the threshold (index - 1 would be one step earlier,
        # and would wrap to the last step if the crossing were at index 0).
        t_tensor = torch.full((1,), cosine_threshold_t, device=device, dtype=torch.long)
        x_threshold, _ = q_xt_given_x0(sample_image, t_tensor, alpha_bars)

        # Denormalize [-1, 1] -> [0, 1]; clamp because the added noise overshoots.
        x_display = torch.clamp((x_threshold[0] + 1.0) / 2.0, 0, 1)

        if x_display.shape[0] == 1:  # Grayscale
            axes[2].imshow(x_display.squeeze(0).cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        else:  # RGB
            axes[2].imshow(x_display.permute(1, 2, 0).cpu().numpy())

        axes[2].set_title(f'Sample at {threshold_db}dB\n(t={cosine_threshold_t})')
    else:
        axes[2].text(0.5, 0.5, f'Cosine schedule never\nreaches {threshold_db}dB',
                     ha='center', va='center', transform=axes[2].transAxes)
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    print(f"\nThreshold Analysis for {threshold_db}dB:")
    for schedule in schedules:
        time = threshold_times[schedule]
        if time < T:
            percent = (time / T) * 100
            print(f"  {schedule:>8}: t={time:4d} ({percent:.1f}% through process)")
        else:
            print(f"  {schedule:>8}: Never reached")

# Create interactive widget for the SNR threshold comparison.
# The trailing ';' stops the cell from echoing the call's return value.
widgets.interact(
    analyze_snr_threshold,
    threshold_db=widgets.FloatSlider(
        min=-20,
        max=5,
        step=1,
        value=-5,
        description='SNR Threshold (dB):',
    ),
    T=widgets.IntSlider(
        min=500,
        max=2000,
        step=100,
        value=1000,
        description='Total Timesteps:',
    ),
);


interactive(children=(FloatSlider(value=-5.0, description='SNR Threshold (dB):', max=5.0, min=-20.0, step=1.0)…

## Pixel Statistics Evolution

Watch how the pixel statistics evolve towards those of a standard Gaussian, $\mathcal{N}(0, 1)$:


In [7]:
def analyze_pixel_statistics(dataset="mnist", schedule_type="cosine", T=1000):
    """Track how the pixel statistics of a noised batch approach N(0, 1).

    Noises the full selected batch at a set of probe timesteps (t=0 is the clean
    batch; t >= 1 uses schedule index ``t - 1``). It then plots and tabulates the
    pixel mean, std, min, max and the distance sqrt(mean^2 + (std - 1)^2).

    Args:
        dataset: 'mnist' selects ``mnist_batch``; any other value selects ``cifar_batch``.
        schedule_type: Schedule name forwarded to ``get_ddpm_schedule``.
        T: Number of diffusion steps, forwarded to ``get_ddpm_schedule``.
    """
    x0_batch = mnist_batch if dataset == "mnist" else cifar_batch

    _, _, alpha_bars = get_ddpm_schedule(T, schedule_type)
    alpha_bars = alpha_bars.to(device)

    # Probe timesteps scale with T: 0, 5, 10, 20, 40, 60, 80 % of T, plus T - 1.
    # For T=1000 this is exactly [0, 50, 100, 200, 400, 600, 800, 999]. A fixed
    # list would index past the end of alpha_bars whenever T < 1000.
    test_timesteps = sorted({T * pct // 100 for pct in (0, 5, 10, 20, 40, 60, 80)} | {T - 1})
    pixel_stats = {'timesteps': test_timesteps, 'means': [], 'stds': [], 'mins': [], 'maxs': []}

    for t in test_timesteps:
        if t == 0:
            x_t = x0_batch
        else:
            t_tensor = torch.full((x0_batch.shape[0],), t - 1, device=device, dtype=torch.long)
            x_t, _ = q_xt_given_x0(x0_batch, t_tensor, alpha_bars)

        # Statistics over every pixel of every image in the batch.
        pixels = x_t.flatten().cpu()
        pixel_stats['means'].append(float(pixels.mean()))
        pixel_stats['stds'].append(float(pixels.std()))
        pixel_stats['mins'].append(float(pixels.min()))
        pixel_stats['maxs'].append(float(pixels.max()))

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Mean evolution
    axes[0, 0].plot(test_timesteps, pixel_stats['means'], 'bo-', linewidth=2, markersize=6)
    axes[0, 0].axhline(y=0, color='red', linestyle='--', alpha=0.7, label='Target (0)')
    axes[0, 0].set_title('Pixel Mean Evolution')
    axes[0, 0].set_xlabel('Timestep t')
    axes[0, 0].set_ylabel('Mean Pixel Value')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Standard deviation evolution
    axes[0, 1].plot(test_timesteps, pixel_stats['stds'], 'go-', linewidth=2, markersize=6)
    axes[0, 1].axhline(y=1, color='red', linestyle='--', alpha=0.7, label='Target (1)')
    axes[0, 1].set_title('Pixel Std Evolution')
    axes[0, 1].set_xlabel('Timestep t')
    axes[0, 1].set_ylabel('Std Pixel Value')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Range evolution
    axes[1, 0].fill_between(test_timesteps, pixel_stats['mins'], pixel_stats['maxs'],
                            alpha=0.3, color='purple', label='Min-Max Range')
    axes[1, 0].plot(test_timesteps, pixel_stats['mins'], 'purple', linewidth=2, label='Min')
    axes[1, 0].plot(test_timesteps, pixel_stats['maxs'], 'purple', linewidth=2, label='Max')
    axes[1, 0].set_title('Pixel Value Range Evolution')
    axes[1, 0].set_xlabel('Timestep t')
    axes[1, 0].set_ylabel('Pixel Value')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Convergence to N(0,1): Euclidean distance of (mean, std) from (0, 1).
    gaussian_distances = [
        np.sqrt(mean ** 2 + (std - 1) ** 2)
        for mean, std in zip(pixel_stats['means'], pixel_stats['stds'])
    ]

    axes[1, 1].plot(test_timesteps, gaussian_distances, 'ro-', linewidth=2, markersize=6)
    axes[1, 1].set_title('Distance to N(0,1)')
    axes[1, 1].set_xlabel('Timestep t')
    axes[1, 1].set_ylabel(r'$\sqrt{(\mu-0)^2 + (\sigma-1)^2}$')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Statistics table. The last column is 14 wide to fit its 14-character header.
    print(f"\nPixel Statistics Evolution - {dataset.upper()} ({schedule_type} schedule):")
    print("=" * 70)
    print(f"{'Timestep':>8} {'Mean':>8} {'Std':>8} {'Min':>8} {'Max':>8} {'Dist to N(0,1)':>14}")
    print("-" * 70)

    for i, t in enumerate(test_timesteps):
        mean = pixel_stats['means'][i]
        std = pixel_stats['stds'][i]
        min_val = pixel_stats['mins'][i]
        max_val = pixel_stats['maxs'][i]
        dist = gaussian_distances[i]

        print(f"{t:>8} {mean:>8.3f} {std:>8.3f} {min_val:>8.3f} {max_val:>8.3f} {dist:>14.3f}")

    print("\nTarget: Mean=0.000, Std=1.000 for perfect Gaussian")

# Create interactive widget for the pixel-statistics explorer.
# The trailing ';' stops the cell from echoing the call's return value.
widgets.interact(
    analyze_pixel_statistics,
    dataset=widgets.Dropdown(
        options=['mnist', 'cifar10'],
        value='mnist',
        description='Dataset:',
    ),
    schedule_type=widgets.Dropdown(
        options=['linear', 'cosine', 'sigmoid'],
        value='cosine',
        description='Schedule:',
    ),
    T=widgets.IntSlider(
        min=500,
        max=2000,
        step=100,
        value=1000,
        description='Timesteps T:',
    ),
);


interactive(children=(Dropdown(description='Dataset:', options=('mnist', 'cifar10'), value='mnist'), Dropdown(…

## Summary and Key Insights

This notebook demonstrates the forward diffusion process with interactive controls. Key takeaways:

1. **Schedule Comparison**: Cosine schedules typically provide more gradual noise addition early on
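2. **SNR Thresholds**: Noise power exceeds signal power once the SNR drops below 0 dB. At -5 dB the noise power is about 3× the signal power, and images look mostly like noise.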
3. **Pixel Statistics**: Gradual convergence to standard Gaussian distribution
4. **Efficiency**: Different schedules reach noise thresholds at different rates

Use the interactive widgets above to explore these concepts with different parameters!
