# Fast-DDPM: Improved Implementation for Medical Image Super-Resolution

This notebook implements the **Fast Denoising Diffusion Probabilistic Model (Fast-DDPM)** from https://github.com/mirthAI/Fast-DDPM for improved SSIM and PSNR scores on middle-slice prediction.

## Key Improvements Over Basic DDPM
1. **Accelerated Sampling**: Sample with only 10 timesteps instead of 1000 (about 100x fewer network evaluations at inference)
2. **Smarter Scheduling**: Non-uniform timestep selection emphasizing important denoising stages
3. **Attention Blocks**: Self-attention at different resolutions for better feature learning
4. **Time Embeddings**: Sinusoidal positional embeddings for timestep conditioning
5. **Antithetic Sampling**: Better variance reduction during training

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torchvision.transforms.functional as TF

# Check GPU availability
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# The model-setup, training and evaluation cells below refer to the lowercase
# name `device`; expose it here so those cells do not hit a NameError.
device = DEVICE
print(f"Device: {DEVICE}")
if DEVICE == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Put the directory two levels above the working directory on the import path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath('.'))))

## Core Components: DDPM Model with Attention

In [None]:
import math

def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Args:
        timesteps: 1-D tensor of timesteps (any numeric dtype; cast to float).
        embedding_dim: width of the returned embedding.

    Returns:
        Float tensor of shape (len(timesteps), embedding_dim) on the same
        device as `timesteps`. The first half holds sines, the second half
        cosines; an odd `embedding_dim` is zero-padded by one column.
    """
    assert len(timesteps.shape) == 1, "timesteps must be a 1-D tensor"
    half_dim = embedding_dim // 2
    # With half_dim == 1 the usual (half_dim - 1) denominator is zero; clamp it
    # so embedding_dim of 2 or 3 yields a single sin/cos pair instead of raising.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Pad one zero column on the right to reach the requested odd width.
        emb = F.pad(emb, (0, 1, 0, 0))
    return emb

def nonlinearity(x):
    """Apply the SiLU activation element-wise (the activation used throughout the network)."""
    activated = F.silu(x)
    return activated

def Normalize(in_channels):
    """
    Build the GroupNorm layer used before every convolution/attention block.

    The group count targets min(32, in_channels // 4), but is lowered to the
    nearest value that divides `in_channels` (and never below 1), because
    GroupNorm rejects a group count of zero or one that does not divide the
    channel count. For channel counts where the target already divides evenly
    the result is unchanged.
    """
    upper = max(1, min(32, in_channels // 4))
    num_groups = next(g for g in range(upper, 0, -1) if in_channels % g == 0)
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6)

class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, optionally followed by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        # Channel-preserving 3x3 convolution applied after the interpolation.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode='nearest')
        return self.conv(upsampled) if self.with_conv else upsampled

class Downsample(nn.Module):
    """Halve the spatial size, either with a strided 3x3 convolution or with 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        # No built-in padding: forward() pads asymmetrically before this conv.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # Zero-pad one pixel on the right and bottom only, then apply the stride-2 conv.
        padded = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)

class ResnetBlock(nn.Module):
    """
    Residual block conditioned on a timestep embedding.

    Two norm -> SiLU -> conv stages; a linear projection of the timestep
    embedding is added per channel between them. When the channel count
    changes, the skip path goes through a 3x3 (`conv_shortcut=True`) or 1x1
    convolution so it can be added to the residual branch.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # First stage.
        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Projects the timestep embedding to one bias per output channel.
        self.temb_proj = nn.Linear(temb_channels, out_channels)
        # Second stage.
        self.norm2 = Normalize(out_channels)
        self.dropout_layer = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # A skip projection is only created when the channel count changes.
        if in_channels == out_channels:
            return
        if conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        residual = self.conv1(nonlinearity(self.norm1(x)))
        time_bias = self.temb_proj(nonlinearity(temb))
        # Broadcast the per-channel bias over the two spatial dimensions.
        residual = residual + time_bias[:, :, None, None]
        residual = self.conv2(self.dropout_layer(nonlinearity(self.norm2(residual))))

        if self.in_channels == self.out_channels:
            return x + residual
        project = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return project(x) + residual

class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, added back to the input as a residual."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        # 1x1 convolutions acting as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        tokens = height * width

        # Attention weights: every query position against every key position.
        query = query.reshape(batch, channels, tokens).permute(0, 2, 1)  # (b, hw, c)
        key = key.reshape(batch, channels, tokens)  # (b, c, hw)
        scores = torch.bmm(query, key)  # (b, hw, hw)
        scores = scores * (int(channels)**(-0.5))
        weights = F.softmax(scores, dim=2)

        # Weighted sum of values, reshaped back to an image-like tensor.
        value = value.reshape(batch, channels, tokens)
        attended = torch.bmm(value, weights.permute(0, 2, 1))
        attended = attended.reshape(batch, channels, height, width)

        return x + self.proj_out(attended)

# Cell-completion message, printed once the building blocks above are defined.
# NOTE(review): the leading "âœ“" looks like a mis-decoded UTF-8 check mark — confirm the file encoding.
print("âœ“ Core components loaded (Attention blocks, ResNet blocks, embeddings)")

In [None]:
class UNetDDPM(nn.Module):
    """
    UNet architecture for DDPM with attention blocks.

    Input: (batch, in_channels, resolution, resolution); forward() asserts the
    spatial size is square and equal to `resolution`.
    Output: (batch, out_channels, H, W) produced by the final 3x3 convolution
    (used as the noise prediction by the samplers in this file).

    Args:
        in_channels: channels of the input tensor (default 2).
        out_channels: channels of the output tensor (default 1).
        ch: base channel width; level i uses ch * ch_mult[i] channels.
        ch_mult: per-level channel multipliers; its length sets the number of
            resolution levels.
        num_res_blocks: ResnetBlocks per level on the way down (the way up
            uses num_res_blocks + 1).
        attn_resolutions: spatial sizes at which an AttnBlock follows every
            ResnetBlock. With resolution=256 and four levels the sizes
            visited are 256, 128, 64 and 32, so the default (16,) never
            matches and only the middle block applies attention.
        dropout: dropout probability inside each ResnetBlock.
        resolution: expected input height/width.
    """
    def __init__(self, in_channels=2, out_channels=1, ch=64, ch_mult=(1,2,4,8), 
                 num_res_blocks=2, attn_resolutions=(16,), dropout=0.1, resolution=256):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ch = ch
        self.ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.resolution = resolution
        self.num_resolutions = len(ch_mult)
        
        # Timestep embedding: sinusoidal features of width `ch` are expanded
        # to `4 * ch` by a two-layer MLP (see forward()).
        self.temb_ch = ch * 4
        self.temb = nn.ModuleList([
            nn.Linear(ch, self.temb_ch),
            nn.Linear(self.temb_ch, self.temb_ch),
        ])
        
        # Initial convolution
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
        
        # Downsampling
        self.down = nn.ModuleList()
        curr_res = resolution
        # Input multiplier of level i is the output multiplier of level i-1
        # (1 for the first level, which follows conv_in).
        in_ch_mult = (1,) + ch_mult
        
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(
                    in_channels=block_in, out_channels=block_out,
                    temb_channels=self.temb_ch, dropout=dropout
                ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            
            # Bare nn.Module used as a namespace for this level's sub-modules.
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last halves the resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, with_conv=True)
                curr_res = curr_res // 2
            self.down.append(down)
        
        # Middle: ResnetBlock -> attention -> ResnetBlock at the lowest resolution.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        
        # Upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            
            # One more block than on the way down: forward() pops one skip
            # feature per block, and each level pushed num_res_blocks + 1 of them.
            for i_block in range(num_res_blocks + 1):
                if i_block == num_res_blocks:
                    # The last skip of a level comes from the previous level
                    # (or conv_in), hence the input-side multiplier.
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(
                    in_channels=block_in + skip_in, out_channels=block_out,
                    temb_channels=self.temb_ch, dropout=dropout
                ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, with_conv=True)
                curr_res = curr_res * 2
            # Levels are built from deepest to shallowest; inserting at the
            # front keeps self.up indexed by level like self.down.
            self.up.insert(0, up)
        
        # Final output
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
    
    def forward(self, x, t):
        """Run the UNet on `x` conditioned on the 1-D timestep tensor `t`."""
        assert x.shape[2] == x.shape[3] == self.resolution
        
        # Timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb[1](temb)
        
        # Initial convolution
        h = self.conv_in(x)
        # Stack of skip features, consumed in reverse by the upsampling path.
        hs = [h]
        
        # Downsampling
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        
        # Middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        
        # Upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                # Concatenate the matching skip feature along the channel axis.
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        
        # Final output
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        
        return h

# Smoke test: build a small UNet and report how many parameters it has.
model = UNetDDPM(resolution=256, ch=32, out_channels=1, in_channels=2)
print("âœ“ UNetDDPM model created")
print(f"  Total parameters: {sum(weight.numel() for weight in model.parameters()):,}")

## Noise Scheduling & Diffusion Process

In [None]:
def compute_alpha(beta, t):
    """
    Return the cumulative product of (1 - beta) up to and including step t.

    A zero is prepended to `beta`, so t = -1 selects a value of exactly 1.
    `t` is an integer index tensor; the result has shape (len(t), 1, 1, 1).
    """
    leading_zero = torch.zeros(1).to(beta.device)
    padded_beta = torch.cat([leading_zero, beta], dim=0)
    alpha_bar = torch.cumprod(1 - padded_beta, dim=0)
    # Shift by one because of the prepended zero.
    selected = alpha_bar.index_select(0, t + 1)
    return selected.view(-1, 1, 1, 1)

def get_beta_schedule(beta_schedule, beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=1000):
    """
    Build the per-step noise variances (betas) as a float64 tensor.

    'linear' spaces the betas evenly between `beta_start` and `beta_end`;
    'scaled_linear' spaces their square roots evenly and squares the result.
    Any other name raises NotImplementedError.
    """
    if beta_schedule == 'linear':
        return torch.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=torch.float64)

    if beta_schedule == 'scaled_linear':
        sqrt_betas = torch.linspace(
            beta_start**0.5, beta_end**0.5, num_diffusion_timesteps, dtype=torch.float64
        )
        return sqrt_betas ** 2

    raise NotImplementedError(f"beta schedule {beta_schedule} unknown")

def generalized_steps(x, seq, model, b, device, eta=0.0):
    """
    Fast-DDPM sampling with a generalized (DDIM-style) schedule.

    Args:
        x: initial noisy sample (B, C, H, W).
        seq: ascending timesteps to visit; they are walked from last to first.
        model: callable `model(xt, t)` returning the predicted noise.
        b: beta schedule (1-D tensor).
        device: device on which each step is computed.
        eta: stochasticity; 0.0 gives the deterministic update.

    Returns:
        (xs, x0_preds): the list of samples after every step (starting with
        `x` itself) and the list of per-step clean-image estimates. Entries
        produced inside the loop are moved to the CPU.
    """
    with torch.no_grad():
        n = x.size(0)
        seq = list(seq)
        # The step after the first timestep is -1, where alpha-bar is exactly 1.
        seq_next = [-1] + seq[:-1]
        # compute_alpha indexes the schedule with tensors living on `device`.
        b = b.to(device)
        x0_preds = []
        xs = [x]

        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)

            xt = xs[-1].to(device)
            # Cast to the sample dtype: a float64 schedule would otherwise
            # promote the sample to double and break a float32 model.
            at = compute_alpha(b, t.long()).to(xt.dtype)
            at_next = compute_alpha(b, next_t.long()).to(xt.dtype)

            et = model(xt, t)

            # Clean-image estimate implied by the predicted noise.
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))

            # Equation (12) from DDIM paper
            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()

            # Draw the noise next to `xt` so it shares its device and dtype.
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et
            xs.append(xt_next.to('cpu'))

    return xs, x0_preds

class FastDDPMScheduler:
    """
    Manages timestep schedules for Fast-DDPM.

    'uniform' returns exactly `num_inference_steps` timesteps spaced
    `num_timesteps // num_inference_steps` apart, starting at 0.
    'non-uniform' places 40% of the steps in [0, 699) and the rest in
    [699, 999]. Its constants are hard-coded, so it assumes a 1000-step
    training schedule.

    Raises:
        ValueError: unknown `scheduler_type`, `num_inference_steps` < 1, or
            more uniform steps requested than `num_timesteps`.
    """

    # 10-step sequence: 4 steps below 699 and 6 at or above it, matching the
    # 40/60 split of the custom branch.
    # NOTE(review): taken from the Fast-DDPM reference implementation — confirm against the repo.
    _NON_UNIFORM_10 = (0, 199, 399, 599, 699, 799, 849, 899, 949, 999)

    def __init__(self, num_timesteps=1000, num_inference_steps=10, scheduler_type='uniform'):
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")

        self.num_timesteps = num_timesteps
        self.num_inference_steps = num_inference_steps
        self.scheduler_type = scheduler_type

        if scheduler_type == 'uniform':
            if num_inference_steps > num_timesteps:
                raise ValueError(
                    f"num_inference_steps ({num_inference_steps}) exceeds num_timesteps ({num_timesteps})"
                )
            skip = num_timesteps // num_inference_steps
            # Multiply instead of np.arange(0, T, skip): the latter returns
            # extra steps whenever T is not a multiple of the step count.
            self.timesteps = np.arange(num_inference_steps) * skip
        elif scheduler_type == 'non-uniform':
            if num_inference_steps == 10:
                self.timesteps = np.array(self._NON_UNIFORM_10)
            else:
                # Custom non-uniform schedule. The second stage takes whatever
                # the first leaves, so the total always equals the request
                # (truncating both shares separately can lose a step).
                num_1 = int(num_inference_steps * 0.4)
                num_2 = num_inference_steps - num_1
                stage_1 = np.linspace(0, 699, num_1 + 1)[:-1]
                stage_2 = np.linspace(699, 999, num_2)
                self.timesteps = np.concatenate([stage_1, stage_2]).astype(int)
        else:
            raise ValueError(f"Unknown scheduler type: {scheduler_type}")

    def get_timesteps(self):
        """Return the schedule as an int64 tensor in ascending order."""
        return torch.from_numpy(self.timesteps).long()

# Smoke test: build one scheduler of each kind and show the timesteps it yields.
scheduler_uniform = FastDDPMScheduler(1000, 10, 'uniform')
scheduler_nonuniform = FastDDPMScheduler(1000, 10, 'non-uniform')

print("âœ“ Schedulers created")
uniform_steps = scheduler_uniform.get_timesteps().numpy()
nonuniform_steps = scheduler_nonuniform.get_timesteps().numpy()
print(f"  Uniform timesteps: {uniform_steps}")
print(f"  Non-uniform timesteps: {nonuniform_steps}")

## Training Setup & Data Loading

In [None]:
# Make the project's `src` directory importable, then pull in the loader factory.
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))
from ModelDataGenerator import build_dataloader

# Build one loader per split; only the training split is augmented.
print("ðŸ“¥ Loading training data...")
train_loader = build_dataloader(split='train', batch_size=4, augment=True, num_workers=4)
val_loader = build_dataloader(split='val', batch_size=4, augment=False, num_workers=4)
test_loader = build_dataloader(split='test', batch_size=1, augment=False, num_workers=0)

print("âœ“ Data loaded")
for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"  {split_name} batches: {len(split_loader)}")

# Peek at the first training batch to confirm the tensor shapes.
for (pre, post), mid in train_loader:
    print(f"  Batch shape: pre={pre.shape}, post={post.shape}, mid={mid.shape}")
    break

## Fast-DDPM Training Loop

In [None]:
# Initialize model and training components
# `DEVICE` is set in the first cell; the rest of this notebook uses `device`.
device = torch.device(DEVICE)

# UNetDDPM's keyword names are `ch` and `attn_resolutions`.
model = UNetDDPM(
    in_channels=2,  # noisy target + one conditioning slice
    ch=32,
    out_channels=1,  # predicted noise for the middle slice
    num_res_blocks=2,
    attn_resolutions=(16,),  # attention at 16x16 resolution
)

# Move to device
model = model.to(device)

# Initialize diffusion process
num_timesteps = 1000
# get_beta_schedule takes (name, beta_start, beta_end, num_diffusion_timesteps)
# and returns float64; pass the step count by keyword and cast to float32 so
# the noised images keep the model's dtype.
beta_schedule = get_beta_schedule(
    "linear", num_diffusion_timesteps=num_timesteps
).float().to(device)

# Optimizer and scheduler
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# EMA (Exponential Moving Average) for model weights
ema_rate = 0.9999
ema_model = UNetDDPM(
    in_channels=2,
    ch=32,
    out_channels=1,
    num_res_blocks=2,
    attn_resolutions=(16,),
).to(device)
# Start the EMA copy from the same weights as the live model.
ema_model.load_state_dict(model.state_dict())

def update_ema(model, ema_model, ema_rate):
    """
    Blend the live weights into the EMA copy, parameter by parameter.

    Each EMA parameter becomes `ema_rate * ema + (1 - ema_rate) * live`.
    Parameters are paired by iteration order; `model` is left untouched.
    """
    parameter_pairs = zip(model.parameters(), ema_model.parameters())
    for live_param, averaged_param in parameter_pairs:
        kept = averaged_param.data * ema_rate
        blended_in = live_param.data * (1 - ema_rate)
        averaged_param.data = kept + blended_in

# Training function
def train_epoch(model, train_loader, optimizer, beta_schedule, device, ema_model=None, ema_rate=0.9999):
    """
    Train one epoch with antithetic timestep sampling for variance reduction.

    Each batch may be `((pre, post), mid)`, `(pre, post, mid)` or `(pre, post)`.
    When a middle slice is present it is the image that gets noised and
    denoised; otherwise `pre` is used, as in the two-element layout. The
    model input is `[noisy target, post]` (2 channels).
    NOTE(review): with a 2-channel model `pre` is not used as conditioning
    when `mid` is the target — confirm this is intended.

    Args:
        model: noise-prediction network called as `model(x, t)`.
        train_loader: iterable of batches (see above).
        optimizer: optimizer over `model.parameters()`.
        beta_schedule: 1-D tensor of betas; its length is the number of
            diffusion steps.
        device: device for the batch tensors and the schedule.
        ema_model: optional EMA copy updated after every optimizer step.
        ema_rate: EMA decay passed to `update_ema`.

    Returns:
        Mean training loss over the epoch.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Keep the schedule next to the sampled timesteps and read the step count
    # from it instead of relying on a module-level global.
    betas = beta_schedule.to(device)
    num_steps = betas.shape[0]

    for batch_idx, batch in enumerate(train_loader):
        if len(batch) == 3:
            pre, post, mid = batch
        elif isinstance(batch[0], (list, tuple)):
            (pre, post), mid = batch
        else:
            pre, post = batch
            mid = None

        post = post.to(device)
        target = (pre if mid is None else mid).to(device)

        # Antithetic sampling: draw half of the timesteps and mirror them, so
        # each batch covers both ends of the schedule.
        batch_size = target.shape[0]
        t_half = torch.randint(0, num_steps, size=(batch_size // 2 + 1,), device=device)
        t = torch.cat([t_half, num_steps - 1 - t_half])[:batch_size]

        noise = torch.randn_like(target)

        # compute_alpha returns shape (B, 1, 1, 1); cast so a float64 schedule
        # does not promote the images to double.
        alpha_bar_t = compute_alpha(betas, t).to(target.dtype)

        # x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise
        noisy_x = alpha_bar_t.sqrt() * target + (1.0 - alpha_bar_t).sqrt() * noise

        # Forward pass through model (predict noise)
        input_x = torch.cat([noisy_x, post], dim=1)
        predicted_noise = model(input_x, t)

        # Loss: MSE between predicted noise and actual noise
        loss = torch.nn.functional.mse_loss(predicted_noise, noise)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping lives in torch.nn.utils, not torch.nn.functional.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update EMA model
        if ema_model is not None:
            update_ema(model, ema_model, ema_rate)

        total_loss += loss.item()
        num_batches += 1

        if (batch_idx + 1) % 10 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}: Loss = {loss.item():.6f}")

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches

# Validation function
def validate(model, val_loader, beta_schedule, device):
    """
    Compute the mean noise-prediction loss on the validation set.

    Batches are unpacked the same way as in training: `((pre, post), mid)`,
    `(pre, post, mid)` or `(pre, post)`. The middle slice is the noised
    target when present, otherwise `pre`. One uniformly random timestep is
    drawn per sample, so the value varies between calls.

    Args:
        model: noise-prediction network called as `model(x, t)`.
        val_loader: iterable of batches (see above).
        beta_schedule: 1-D tensor of betas; its length is the number of
            diffusion steps.
        device: device for the batch tensors and the schedule.

    Returns:
        Mean validation loss.

    Raises:
        ValueError: if `val_loader` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    betas = beta_schedule.to(device)
    num_steps = betas.shape[0]

    with torch.no_grad():
        for batch in val_loader:
            if len(batch) == 3:
                pre, post, mid = batch
            elif isinstance(batch[0], (list, tuple)):
                (pre, post), mid = batch
            else:
                pre, post = batch
                mid = None

            post = post.to(device)
            target = (pre if mid is None else mid).to(device)

            # Sample random timesteps
            batch_size = target.shape[0]
            t = torch.randint(0, num_steps, size=(batch_size,), device=device)

            # Sample noise and create the noisy version; the cast keeps a
            # float64 schedule from promoting the images to double.
            noise = torch.randn_like(target)
            alpha_bar_t = compute_alpha(betas, t).to(target.dtype)
            noisy_x = alpha_bar_t.sqrt() * target + (1.0 - alpha_bar_t).sqrt() * noise

            # Forward pass
            input_x = torch.cat([noisy_x, post], dim=1)
            predicted_noise = model(input_x, t)

            # Compute loss
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    return total_loss / num_batches

# Training loop with early stopping on the validation loss.
num_epochs = 50
best_val_loss = float('inf')
patience = 10
patience_counter = 0
train_losses, val_losses = [], []

print("Starting training...")
for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch}/{num_epochs}")

    # One pass over the training data (also updates the EMA weights).
    train_loss = train_epoch(model, train_loader, optimizer, beta_schedule, device, ema_model, ema_rate)
    train_losses.append(train_loss)

    # Score the live (non-EMA) model on the validation split.
    val_loss = validate(model, val_loader, beta_schedule, device)
    val_losses.append(val_loss)

    print(f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

    # Advance the cosine learning-rate schedule once per epoch.
    scheduler.step()

    improved = val_loss < best_val_loss
    if not improved:
        # No progress: count towards early stopping.
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
        continue

    # New best: reset the counter and checkpoint both weight sets.
    best_val_loss = val_loss
    patience_counter = 0
    print(f"  â†’ Best validation loss! Saving checkpoint...")
    snapshot = {
        'model_state': model.state_dict(),
        'ema_state': ema_model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'epoch': epoch,
        'val_loss': val_loss
    }
    torch.save(snapshot, 'fastddpm_best.pth')

print(f"\nTraining complete! Best validation loss: {best_val_loss:.6f}")

# Restore the EMA weights saved at the best validation epoch.
checkpoint = torch.load('fastddpm_best.pth')
ema_model.load_state_dict(checkpoint['ema_state'])
print("Loaded best EMA model for inference")


## Inference & Evaluation with Different Schedulers

In [None]:
def inference(model, pre, post, beta_schedule, scheduler_type='non-uniform', device=None):
    """
    Sample one middle slice with a deterministic (eta = 0) DDIM-style sampler.

    Sampling starts from Gaussian noise and walks the schedule from its
    largest timestep down to the clean sample, feeding the network
    `[current sample, post]` at every step.

    Args:
        model: diffusion model called as `model(x, t)` with a 2-channel `x`.
        pre: pre-slice. Kept for interface compatibility; the 2-channel
            network is not conditioned on it.
            NOTE(review): confirm whether `pre` should be a third input channel.
        post: post-slice shaped (H, W), (1, H, W) or (1, 1, H, W).
        beta_schedule: beta schedule tensor.
        scheduler_type: 'uniform' selects the uniform schedule; any other
            value selects the non-uniform one.
        device: device to run on; defaults to the device of the model's
            parameters.
    Returns:
        Predicted middle slice of shape (1, H, W) on the CPU, clamped to [-1, 1].
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    # Normalise the conditioning slice to a batch of one: (1, 1, H, W).
    post = post.reshape(1, 1, *post.shape[-2:]).to(device)
    betas = beta_schedule.to(device)

    # Start from pure noise with the same spatial size as the conditioning slice.
    x_t = torch.randn(1, 1, *post.shape[-2:], device=device, dtype=post.dtype)

    # Ascending timesteps; each step's successor is the entry before it, and
    # the first entry's successor is -1, where alpha-bar is exactly 1.
    kind = 'uniform' if scheduler_type == 'uniform' else 'non-uniform'
    scheduler = FastDDPMScheduler(
        num_timesteps=betas.shape[0], num_inference_steps=10, scheduler_type=kind
    )
    timesteps = [int(step) for step in scheduler.get_timesteps()]
    successors = [-1] + timesteps[:-1]

    # Reverse diffusion process: iterate from the noisiest timestep downwards.
    with torch.no_grad():
        for idx, (t_cur, t_nxt) in enumerate(zip(reversed(timesteps), reversed(successors))):
            t = torch.full((1,), t_cur, device=device, dtype=torch.long)
            t_next = torch.full((1,), t_nxt, device=device, dtype=torch.long)

            # Model prediction
            input_x = torch.cat([x_t, post], dim=1)
            noise_pred = model(input_x, t)

            # Cast so a float64 schedule does not promote the sample to double.
            alpha_t = compute_alpha(betas, t).to(x_t.dtype)
            alpha_next = compute_alpha(betas, t_next).to(x_t.dtype)

            # Estimate the clean image, then re-noise it to the next level.
            x0_pred = (x_t - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
            x_t = alpha_next.sqrt() * x0_pred + (1 - alpha_next).sqrt() * noise_pred

            if (idx + 1) % max(1, len(timesteps) // 10) == 0:
                print(f"  Step {idx + 1}/{len(timesteps)}")

    # Clip to valid range
    x_t = torch.clamp(x_t, -1.0, 1.0)

    return x_t.squeeze(0).cpu()

# Compute metrics function
def compute_metrics(predictions, ground_truth):
    """
    Compute SSIM and PSNR between a prediction and its ground truth.

    Inputs may be torch tensors (any device) or arrays; singleton dimensions
    are dropped, so a (1, H, W) prediction can be compared with an (H, W)
    target. An input with negative values is treated as lying in [-1, 1] and
    mapped to [0, 1]; both metrics use a data range of 1.0.

    Returns:
        (ssim_val, psnr_val); PSNR is reported as 100 when the images are identical.

    Raises:
        ValueError: if the two images differ in shape after squeezing.
    """
    from skimage.metrics import structural_similarity as ssim

    def _as_image(x):
        # detach/cpu so CUDA tensors and tensors that track gradients work too.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.squeeze(np.asarray(x, dtype=np.float64))

    preds = _as_image(predictions)
    gts = _as_image(ground_truth)
    if preds.shape != gts.shape:
        raise ValueError(
            f"prediction shape {preds.shape} does not match ground truth shape {gts.shape}"
        )

    # Normalize to [0, 1] if needed
    if preds.min() < 0:
        preds = (preds + 1) / 2
    if gts.min() < 0:
        gts = (gts + 1) / 2

    # Compute SSIM
    ssim_val = ssim(gts, preds, data_range=1.0)

    # Compute PSNR (max value is 1.0)
    mse = np.mean((preds - gts) ** 2)
    psnr_val = 20 * np.log10(1.0 / np.sqrt(mse)) if mse > 0 else 100

    return ssim_val, psnr_val

# Evaluation on test set with both schedulers
print("Running inference on test set with different schedulers...\n")

results = {
    'uniform': {'ssim': [], 'psnr': []},
    'non_uniform': {'ssim': [], 'psnr': []}
}

test_samples = []
max_samples = 20  # Limit to 20 samples for speed

# Run on whatever device the EMA weights live on.
eval_device = next(ema_model.parameters()).device

with torch.no_grad():
    for sample_idx, batch in enumerate(test_loader):
        if sample_idx >= max_samples:
            break

        # Accept both batch layouts: (pre, post, mid) and ((pre, post), mid).
        if len(batch) == 3:
            pre, post, mid = batch
        else:
            (pre, post), mid = batch

        print(f"Processing sample {sample_idx + 1}/{min(max_samples, len(test_loader))}...")

        # Reset per sample so a failure never reuses the previous sample's output.
        pred_uniform = pred_nonuniform = None
        ssim_u = psnr_u = ssim_nu = psnr_nu = None

        # Try uniform scheduler (best effort: report the error and carry on)
        try:
            pred_uniform = inference(ema_model, pre.squeeze(0), post.squeeze(0),
                                     beta_schedule, scheduler_type='uniform', device=eval_device).squeeze()
            ssim_u, psnr_u = compute_metrics(pred_uniform, mid.squeeze())
            results['uniform']['ssim'].append(ssim_u)
            results['uniform']['psnr'].append(psnr_u)
        except Exception as e:
            print(f"  Error with uniform: {e}")

        # Try non-uniform scheduler
        try:
            pred_nonuniform = inference(ema_model, pre.squeeze(0), post.squeeze(0),
                                        beta_schedule, scheduler_type='non-uniform', device=eval_device).squeeze()
            ssim_nu, psnr_nu = compute_metrics(pred_nonuniform, mid.squeeze())
            results['non_uniform']['ssim'].append(ssim_nu)
            results['non_uniform']['psnr'].append(psnr_nu)
        except Exception as e:
            print(f"  Error with non-uniform: {e}")

        # Compare against None: an SSIM of exactly 0.0 is still a valid score.
        if ssim_u is not None and ssim_nu is not None:
            print(f"  Uniform - SSIM: {ssim_u:.4f}, PSNR: {psnr_u:.2f}")
            print(f"  Non-uniform - SSIM: {ssim_nu:.4f}, PSNR: {psnr_nu:.2f}\n")

        # Store first few samples for visualization
        if sample_idx < 3:
            test_samples.append({
                'pre': pre.squeeze(),
                'post': post.squeeze(),
                'mid_gt': mid.squeeze(),
                'pred_uniform': pred_uniform,
                'pred_nonuniform': pred_nonuniform
            })

# Summary statistics (mean and standard deviation) for each scheduler.
summary_rule = "=" * 60
print("\n" + summary_rule)
print("RESULTS SUMMARY")
print(summary_rule)

uniform_scores = results['uniform']
if uniform_scores['ssim']:
    uniform_ssim_mean = np.mean(uniform_scores['ssim'])
    uniform_ssim_std = np.std(uniform_scores['ssim'])
    uniform_psnr_mean = np.mean(uniform_scores['psnr'])
    uniform_psnr_std = np.std(uniform_scores['psnr'])

    print(f"Uniform Scheduler (10-step):")
    print(f"  SSIM: {uniform_ssim_mean:.4f} Â± {uniform_ssim_std:.4f}")
    print(f"  PSNR: {uniform_psnr_mean:.2f} Â± {uniform_psnr_std:.2f}")

nonuniform_scores = results['non_uniform']
if nonuniform_scores['ssim']:
    nonuniform_ssim_mean = np.mean(nonuniform_scores['ssim'])
    nonuniform_ssim_std = np.std(nonuniform_scores['ssim'])
    nonuniform_psnr_mean = np.mean(nonuniform_scores['psnr'])
    nonuniform_psnr_std = np.std(nonuniform_scores['psnr'])

    print(f"\nNon-Uniform Scheduler (Fast-DDPM):")
    print(f"  SSIM: {nonuniform_ssim_mean:.4f} Â± {nonuniform_ssim_std:.4f}")
    print(f"  PSNR: {nonuniform_psnr_mean:.2f} Â± {nonuniform_psnr_std:.2f}")

# Relative change of the non-uniform schedule over the uniform one, in percent.
if uniform_scores['ssim'] and nonuniform_scores['ssim']:
    ssim_delta = nonuniform_ssim_mean - uniform_ssim_mean
    psnr_delta = nonuniform_psnr_mean - uniform_psnr_mean
    ssim_improvement = ssim_delta / uniform_ssim_mean * 100
    psnr_improvement = psnr_delta / uniform_psnr_mean * 100
    print(f"\nNon-Uniform vs Uniform Improvement:")
    print(f"  SSIM: {ssim_improvement:+.2f}%")
    print(f"  PSNR: {psnr_improvement:+.2f}%")
print(summary_rule)


## Visualization & Results Comparison

In [None]:
# Two-panel figure: loss curves on the left, scheduler SSIM comparison on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
loss_ax, bar_ax = axes[0], axes[1]

# Left panel: per-epoch training and validation loss on a log scale.
for loss_history, curve_label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    loss_ax.plot(loss_history, label=curve_label, linewidth=2)
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('MSE Loss', fontsize=12)
loss_ax.set_title('Fast-DDPM Training Progress', fontsize=14, fontweight='bold')
loss_ax.legend(fontsize=11)
loss_ax.grid(True, alpha=0.3)
loss_ax.set_yscale('log')

# Right panel: only drawn when both schedulers produced at least one score.
if results['uniform']['ssim'] and results['non_uniform']['ssim']:
    schedulers = ['Uniform\n(skip=100)', 'Non-Uniform\n(Fast-DDPM)']
    ssim_means = [uniform_ssim_mean, nonuniform_ssim_mean]
    ssim_stds = [uniform_ssim_std, nonuniform_ssim_std]
    colors = ['#3498db', '#e74c3c']

    x_pos = np.arange(len(schedulers))
    bars = bar_ax.bar(x_pos, ssim_means, yerr=ssim_stds, capsize=5,
                      color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    bar_ax.set_ylabel('SSIM', fontsize=12)
    bar_ax.set_title('Scheduler Comparison (Test Set)', fontsize=14, fontweight='bold')
    bar_ax.set_xticks(x_pos)
    bar_ax.set_xticklabels(schedulers, fontsize=11)
    bar_ax.grid(True, alpha=0.3, axis='y')
    bar_ax.set_ylim([0, 1.0])

    # Write each mean just above the top of its error bar.
    for bar, mean, std in zip(bars, ssim_means, ssim_stds):
        label_x = bar.get_x() + bar.get_width()/2.
        label_y = bar.get_height() + std + 0.02
        bar_ax.text(label_x, label_y, f'{mean:.3f}',
                    ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('fastddpm_training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("Training curves saved to 'fastddpm_training_results.png'")

# Visualize sample predictions: one row per stored sample, one column per image.
if test_samples:
    num_samples = len(test_samples)

    def normalize_for_display(x):
        """Return `x` as a 2-D array in [0, 1], mapping [-1, 1] data to [0, 1] first."""
        x = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)
        # Drop singleton dimensions: imshow rejects (1, H, W) arrays.
        x = np.squeeze(x)
        if x.min() < 0:
            x = (x + 1) / 2
        return np.clip(x, 0, 1)

    # (sample key, panel title) for each column, left to right.
    panels = [
        ('pre', 'Pre-slice'),
        ('post', 'Post-slice'),
        ('mid_gt', 'Ground Truth'),
        ('pred_uniform', 'Pred (Uniform)'),
        ('pred_nonuniform', 'Pred (Non-Uniform)'),
    ]

    # squeeze=False keeps `axes` two-dimensional even for a single sample.
    fig, axes = plt.subplots(num_samples, len(panels), figsize=(15, 3*num_samples), squeeze=False)

    for row, sample in enumerate(test_samples):
        for col, (key, title) in enumerate(panels):
            ax = axes[row, col]
            ax.axis('off')
            image = sample[key]
            if image is None:
                # The sampler failed for this sample; leave a placeholder.
                ax.text(0.5, 0.5, 'No prediction', ha='center', va='center')
                continue
            ax.imshow(normalize_for_display(image), cmap='gray')
            ax.set_title(title, fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig('fastddpm_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Saved predictions for {num_samples} samples to 'fastddpm_predictions.png'")

# Table comparing the hard-coded baseline scores with this run's Fast-DDPM results.
table_rule = "=" * 70
print("\n" + table_rule)
print("FAST-DDPM vs OTHER MODELS (from README)")
print(table_rule)

# 'N/A' stands in for a scheduler that produced no scores.
uniform_ssim_cell = uniform_ssim_mean if results['uniform']['ssim'] else 'N/A'
nonuniform_ssim_cell = nonuniform_ssim_mean if results['non_uniform']['ssim'] else 'N/A'
uniform_psnr_cell = uniform_psnr_mean if results['uniform']['psnr'] else 'N/A'
nonuniform_psnr_cell = nonuniform_psnr_mean if results['non_uniform']['psnr'] else 'N/A'

baseline_results = {
    'Model': ['UNet', 'DeepCNN', 'UNet-GAN', 'Progressive UNet', 'FastDDPM (Uniform)', 'FastDDPM (Non-Uniform)'],
    'SSIM': [0.711, 0.710, 0.760, 0.724, uniform_ssim_cell, nonuniform_ssim_cell],
    'PSNR': [23.61, 23.61, 28.57, 26.97, uniform_psnr_cell, nonuniform_psnr_cell],
}

import pandas as pd
df_comparison = pd.DataFrame(baseline_results)
print(df_comparison.to_string(index=False))
print(table_rule)

# Closing notes for the notebook.
closing_lines = (
    "\nâœ“ Fast-DDPM implementation complete!",
    "âœ“ Key improvements from standard DDPM:",
    "  - Sinusoidal positional embeddings for timesteps",
    "  - Multi-resolution attention blocks (AttnBlock at 16x resolution)",
    "  - Non-uniform noise schedule (faster convergence)",
    "  - Generalized DDIM sampling (10 steps instead of 1000)",
    "  - Antithetic sampling during training (variance reduction)",
    "âœ“ Results saved: 'fastddpm_training_results.png', 'fastddpm_predictions.png'",
)
for closing_line in closing_lines:
    print(closing_line)
