# Fast-DDPM Training with Segregated Metrics for Multi-Image Super-Resolution

This notebook demonstrates Fast-DDPM training using ModelDataGenerator_1.py dataloaders with distance filtering for segregated metric evaluation (3mm gap: distance 2, and 6mm gap: distance 4).


In [1]:
# 1. Import Required Libraries
import sys
import os
import json
import time
import logging
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
from skimage.metrics import structural_similarity as ssim

# Add src to path for imports
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))

# Import data generator
from ModelDataGenerator_1 import build_dataloader

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


Libraries imported successfully!
PyTorch version: 2.8.0+cu128
CUDA available: True
GPU: NVIDIA B200


## 2. Configuration and Hyperparameters

In [2]:
# Configuration Parameters
CONFIG = {
    # Device
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Dataset
    'batch_size': 4,
    'num_workers': 4,
    'pin_memory': True,
    
    # Training
    'num_epochs': 20,
    'learning_rate': 1e-4,
    'weight_decay': 0.0,
    'gradient_clip': 1.0,
    
    # Diffusion Model
    'image_size': 256,
    'in_channels': 2,  # Two input images (pre and post)
    'out_channels': 1,  # One output image (middle/target)
    'num_timesteps': 1000,  # Total diffusion steps
    'fast_ddpm_steps': 10,  # Fast-DDPM uses 10 steps
    'beta_schedule': 'linear',
    'beta_start': 0.0001,
    'beta_end': 0.02,
    
    # Scheduler
    'scheduler_type': 'uniform',  # 'uniform' or 'non-uniform'
    'sample_type': 'generalized',
    
    # Optimizer
    'optimizer': 'adam',
    'eps': 1e-8,
    
    # Checkpointing
    'ckpt_freq': 5,
    'early_stopping_patience': 5,
    
    # Results
    'model_dir': '../models/fast_ddpm_v3',
    'results_file': '../results/fastddpm_v3_history.json'
}

# Create model directory
os.makedirs(CONFIG['model_dir'], exist_ok=True)
os.makedirs(os.path.dirname(CONFIG['results_file']), exist_ok=True)

print("Configuration Parameters:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")


Configuration Parameters:
  device: cuda
  batch_size: 4
  num_workers: 4
  pin_memory: True
  num_epochs: 20
  learning_rate: 0.0001
  weight_decay: 0.0
  gradient_clip: 1.0
  image_size: 256
  in_channels: 2
  out_channels: 1
  num_timesteps: 1000
  fast_ddpm_steps: 10
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  scheduler_type: uniform
  sample_type: generalized
  optimizer: adam
  eps: 1e-08
  ckpt_freq: 5
  early_stopping_patience: 5
  model_dir: ../models/fast_ddpm_v3
  results_file: ../results/fastddpm_v3_history.json


## 3. Fast-DDPM Model Implementation

Based on mirthAI/Fast-DDPM repository architecture for super-resolution tasks.

In [57]:
def get_timestep_embedding(timesteps, embedding_dim):
    """Get sinusoidal timestep embeddings"""
    assert len(timesteps.shape) == 1
    
    half_dim = embedding_dim // 2
    # Avoid division by zero when half_dim = 1
    emb = np.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(timesteps.device)
    emb = timesteps[:, None].float() * emb[None, :]
    
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    
    return emb

def nonlinearity(x):
    return x * torch.sigmoid(x)

def get_norm(in_channels, num_groups=32):
    """Get GroupNorm with adaptive num_groups if needed"""
    if in_channels % num_groups == 0:
        return nn.GroupNorm(num_groups, in_channels)
    for groups in [16, 8, 4, 2, 1]:
        if in_channels % groups == 0:
            return nn.GroupNorm(groups, in_channels)
    return nn.BatchNorm2d(in_channels)

class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
    
    def forward(self, x):
        return self.conv(x)

class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, temb_dim, dropout=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        self.norm1 = get_norm(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        
        self.temb_proj = nn.Linear(temb_dim, out_channels)
        
        self.norm2 = get_norm(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
    
    def forward(self, x, temb):
        h = nonlinearity(self.norm1(x))
        h = self.conv1(h)
        h += self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = nonlinearity(self.norm2(h))
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.shortcut(x)

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.norm = get_norm(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
    
    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        q = self.q(h)
        k = self.k(h)
        v = self.v(h)
        
        q = q.reshape(B, C, -1)
        k = k.reshape(B, C, -1)
        v = v.reshape(B, C, -1)
        
        attn = torch.bmm(q.transpose(1, 2), k)
        attn = attn / np.sqrt(C)
        attn = torch.nn.functional.softmax(attn, dim=-1)
        
        out = torch.bmm(v, attn.transpose(1, 2))
        out = out.reshape(B, C, H, W)
        out = self.proj_out(out)
        
        return out + x

class FastDDPMSRModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        ch = config['in_channels']  # Use in_channels directly (2 for pre + post)
        out_ch = config['out_channels']
        num_res_blocks = 2
        ch_mult = (1, 2, 4)
        num_resolutions = len(ch_mult)
        attn_resolutions = (8,)
        dropout = 0.1
        
        self.ch = ch
        self.num_resolutions = num_resolutions
        self.num_res_blocks = num_res_blocks
        
        # Timestep embedding
        self.temb_ch = ch * 4
        self.temb_dense = nn.Sequential(
            nn.Linear(ch, self.temb_ch),
            nn.SiLU(),
            nn.Linear(self.temb_ch, self.temb_ch)
        )
        
        # Input conv
        self.conv_in = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        
        # Downsampling
        self.down_blocks = nn.ModuleList()
        cur_ch = ch
        for i in range(num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            
            out_ch_mult = ch * ch_mult[i]
            for _ in range(num_res_blocks):
                res_blocks.append(ResnetBlock(cur_ch, out_ch_mult, self.temb_ch, dropout))
                if out_ch_mult in attn_resolutions:
                    attn_blocks.append(AttnBlock(out_ch_mult))
                cur_ch = out_ch_mult
            
            down_block = nn.Module()
            down_block.res_blocks = res_blocks
            down_block.attn_blocks = attn_blocks
            self.down_blocks.append(down_block)
            
            if i != num_resolutions - 1:
                self.down_blocks.append(Downsample(cur_ch))
        
        # Middle blocks
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(cur_ch, cur_ch, self.temb_ch, dropout)
        self.mid.attn = AttnBlock(cur_ch)
        self.mid.block_2 = ResnetBlock(cur_ch, cur_ch, self.temb_ch, dropout)
        
        # Upsampling (simplified - no skip connections for now)
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(num_resolutions)):
            out_ch_mult = ch * ch_mult[i]
            
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            
            for _ in range(num_res_blocks + 1):
                res_blocks.append(ResnetBlock(cur_ch, out_ch_mult, self.temb_ch, dropout))
                if out_ch_mult in attn_resolutions:
                    attn_blocks.append(AttnBlock(out_ch_mult))
                cur_ch = out_ch_mult
            
            up_block = nn.Module()
            up_block.res_blocks = res_blocks
            up_block.attn_blocks = attn_blocks
            self.up_blocks.append(up_block)
            
            if i != 0:
                self.up_blocks.append(Upsample(cur_ch))
        
        # Output
        self.norm_out = get_norm(cur_ch)
        self.conv_out = nn.Conv2d(cur_ch, out_ch, kernel_size=3, padding=1)
    
    def forward(self, x, t):
        """
        Args:
            x: (B, 2, H, W) - concatenated input and reference images
            t: (B,) - timestep indices
        Returns:
            output: (B, 1, H, W) - predicted noise
        """
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb_dense(temb)
        
        h = self.conv_in(x)
        
        # Downsampling
        for block in self.down_blocks:
            if isinstance(block, Downsample):
                h = block(h)
            else:
                for j, res_block in enumerate(block.res_blocks):
                    h = res_block(h, temb)
                    if j < len(block.attn_blocks):
                        h = block.attn_blocks[j](h)
        
        # Middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn(h)
        h = self.mid.block_2(h, temb)
        
        # Upsampling
        for block in self.up_blocks:
            if isinstance(block, Upsample):
                h = block(h)
            else:
                # This is the ResBlock container
                for i, res_block in enumerate(block.res_blocks):
                    h = res_block(h, temb)
                    if i < len(block.attn_blocks):
                        h = block.attn_blocks[i](h)
        
        # Output
        h = nonlinearity(self.norm_out(h))
        h = self.conv_out(h)
        
        return h

print("Fast-DDPM Model classes defined successfully!")


Fast-DDPM Model classes defined successfully!


## 4. Diffusion Utilities

In [58]:
def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    """Get beta schedule for diffusion process"""
    if beta_schedule == "quad":
        betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps) ** 2
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps)
    elif beta_schedule == "cosine":
        s = 0.008
        steps = np.arange(0, num_diffusion_timesteps + 1, dtype=np.float64) / num_diffusion_timesteps
        alphas_cumprod = np.cos(((steps + s) / (1 + s)) * np.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas = np.clip(betas, a_min=0, a_max=0.9999)
    else:
        raise ValueError(f"unknown beta schedule: {beta_schedule}")
    
    return torch.from_numpy(betas).float()

class DiffusionSchedule:
    def __init__(self, betas):
        """Initialize diffusion schedule from betas"""
        self.betas = betas
        self.num_timesteps = len(betas)
        self.device = betas.device
        
        alphas = 1 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1, device=self.device), self.alphas_cumprod[:-1]])
        
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)
        self.sqrt_recip_m1_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod - 1)
        
        posterior_variance = (
            betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        )
        self.posterior_variance = posterior_variance
        self.posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - self.alphas_cumprod)
    
    def compute_alpha(self, t):
        """Get alpha value for timestep t"""
        return self.alphas_cumprod[t]
    
    def get_sampler_schedule(self, timesteps, scheduler_type='uniform'):
        """Get sampling schedule for Fast-DDPM"""
        if scheduler_type == 'uniform':
            skip = self.num_timesteps // timesteps
            seq = list(range(0, self.num_timesteps, skip))
            if seq[-1] != self.num_timesteps - 1:
                seq.append(self.num_timesteps - 1)
            return seq
        elif scheduler_type == 'non-uniform':
            # Non-uniform schedule from Fast-DDPM paper
            seq = [0, 199, 399, 599, 699, 799, 849, 899, 949, 999]
            if timesteps != 10:
                num_1 = int(timesteps * 0.4)
                num_2 = int(timesteps * 0.6)
                stage_1 = np.linspace(0, 699, num_1 + 1)[:-1]
                stage_2 = np.linspace(699, 999, num_2)
                seq = np.concatenate([stage_1, stage_2]).astype(int).tolist()
            return seq
        else:
            raise ValueError(f"Unknown scheduler type: {scheduler_type}")

# Initialize diffusion schedule
betas = get_beta_schedule(
    CONFIG['beta_schedule'],
    CONFIG['beta_start'],
    CONFIG['beta_end'],
    CONFIG['num_timesteps']
)
# Create on CPU first, then move to device
diffusion_schedule = DiffusionSchedule(betas.to(CONFIG['device']))

print(f"Diffusion schedule initialized with {CONFIG['num_timesteps']} timesteps")
print(f"Beta range: [{CONFIG['beta_start']}, {CONFIG['beta_end']}]")

Diffusion schedule initialized with 1000 timesteps
Beta range: [0.0001, 0.02]


## 5. Initialize Dataloaders with Distance Filtering

In [10]:
# Build dataloaders with distance filtering
# distance_filter=None: all triplets (mixed distance 2 and 4)
# distance_filter=2: only (i, i+2)->i+1 pairs [3mm gap, 1.5mm interpolation]
# distance_filter=4: only (i, i+4)->i+2 pairs [6mm gap, 3mm interpolation]

train_loader = build_dataloader(
    split='train',
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    #pin_memory=CONFIG['pin_memory'],
    distance_filter=None,  # Use all triplets for training
    #shuffle=True
)

val_loader = build_dataloader(
    split = 'val',
    batch_size = CONFIG['batch_size'],
    num_workers = CONFIG['num_workers'],
    #pin_memory = CONFIG['pin_memeory'],
    distance_filter = None,
    #shuffle = True
)

# Test loaders separated by distance
test_loader_dist2 = build_dataloader(
    split='test',
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    #pin_memory=CONFIG['pin_memory'],
    distance_filter=2,  # Only 3mm gaps
    #shuffle=False
)

test_loader_dist4 = build_dataloader(
    split='test',
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    #pin_memory=CONFIG['pin_memory'],
    distance_filter=4,  # Only 6mm gaps
    #shuffle=False
)

print(f"Train loader: {len(train_loader)} batches")
print(f"Validation loader: {len(val_loader)} batches")
print(f"Test loader (distance 2 - 3mm gap): {len(test_loader_dist2)} batches")
print(f"Test loader (distance 4 - 6mm gap): {len(test_loader_dist4)} batches")


ðŸ’¾ Pre-caching volumes into RAM for faster data loading...
âœ… Cached 641 volumes in RAM
ðŸ’¾ Pre-caching volumes into RAM for faster data loading...
âœ… Cached 113 volumes in RAM
ðŸ’¾ Pre-caching volumes into RAM for faster data loading...
âœ… Cached 160 volumes in RAM
ðŸ’¾ Pre-caching volumes into RAM for faster data loading...
âœ… Cached 160 volumes in RAM
Train loader: 18269 batches
Validation loader: 3221 batches
Test loader (distance 2 - 3mm gap): 2320 batches
Test loader (distance 4 - 6mm gap): 2240 batches

Sample batch shapes:


TypeError: list indices must be integers or slices, not str

In [11]:

# Inspect a batch
(pre, post), middle = next(iter(train_loader))
print(f"\nSample batch shapes:")
print(f"  pre: {pre.shape}")
print(f"  post: {post.shape}")
print(f"  target: {middle.shape}")


Sample batch shapes:
  pre: torch.Size([4, 1, 256, 256])
  post: torch.Size([4, 1, 256, 256])
  target: torch.Size([4, 1, 256, 256])


## 6. Model, Optimizer, and Loss Function

In [64]:
# Initialize model
model = FastDDPMSRModel(CONFIG)
model = model.to(CONFIG['device'])
model = nn.DataParallel(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

# Optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    betas=(0.9, 0.999),
    eps=CONFIG['eps'],
    weight_decay=CONFIG['weight_decay']
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    #verbose=True
)

# Loss function
criterion = nn.MSELoss()

print(f"Optimizer: Adam (lr={CONFIG['learning_rate']})")
print(f"Loss: MSELoss")
print(f"Device: {CONFIG['device']}")


Model Parameters: 13,971
Trainable Parameters: 13,971
Optimizer: Adam (lr=0.0001)
Loss: MSELoss
Device: cuda


## 7. Training Functions

In [65]:
import gc

def train_epoch(model, train_loader, optimizer, criterion, diffusion_schedule, device, config):
    """Train for one epoch"""
    model.train()
    total_loss = 0.0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc="Training", leave=False)
    for (pre_1, post_1), middle_1 in pbar:
        pre = pre_1.to(device).float()      # (B, 1, H, W)
        post = post_1.to(device).float()    # (B, 1, H, W)
        target = middle_1.to(device).float()  # (B, 1, H, W)
        
        # Concatenate pre and post as input
        x_input = torch.cat([pre, post], dim=1)  # (B, 2, H, W)
        
        # Sample random timesteps (antithetic sampling)
        batch_size = x_input.size(0)
        half_batch = batch_size // 2 + 1
        
        if config['scheduler_type'] == 'uniform':
            t_rand = torch.randint(0, config['num_timesteps'], (half_batch,), device=device)
            t = torch.cat([t_rand, config['num_timesteps'] - t_rand - 1], dim=0)[:batch_size]
        else:
            # Non-uniform sampling
            t_rand = torch.randint(0, config['num_timesteps'], (batch_size,), device=device)
            t = t_rand
        
        t = t.long()
        
        # Sample noise
        noise = torch.randn_like(target)
        
        # Forward diffusion process: q(x_t | x_0) = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * epsilon
        alpha_t = diffusion_schedule.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        std_t = diffusion_schedule.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        
        x_t = alpha_t * target + std_t * noise
        
        # Predict noise
        optimizer.zero_grad()
        noise_pred = model(x_input, t.float())
        
        # MSE loss on noise prediction
        loss = criterion(noise_pred, noise)
        
        loss.backward()
        
        # Gradient clipping
        if config['gradient_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
        
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': loss.item()})
        
        # Clean up intermediate variables and clear cache
        del pre, post, target, x_input, t_rand, t, noise, alpha_t, std_t, x_t, noise_pred, loss
        if device == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()
    
    avg_loss = total_loss / num_batches
    return avg_loss

def evaluate(model, test_loader, criterion, diffusion_schedule, device, config):
    """Evaluate model on test set"""
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_targets = []
    num_batches = 0
    
    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Evaluating", leave=False)
        for (pre_1, post_1), middle_1 in pbar:
            pre = pre_1.to(device).float()
            post = post_1.to(device).float()
            target = middle_1.to(device).float()
            
            # Concatenate pre and post as input
            x_input = torch.cat([pre, post], dim=1)
            
            batch_size = x_input.size(0)
            
            # Sample timesteps uniformly for evaluation
            t = torch.randint(0, config['num_timesteps'], (batch_size,), device=device).long()
            
            # Sample noise
            noise = torch.randn_like(target)
            
            # Forward diffusion process
            alpha_t = diffusion_schedule.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
            std_t = diffusion_schedule.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
            
            x_t = alpha_t * target + std_t * noise
            
            # Predict noise
            noise_pred = model(x_input, t.float())
            
            # MSE loss
            loss = criterion(noise_pred, noise)
            
            total_loss += loss.item()
            
            # Denoise to get prediction of target
            predicted_target = (x_t - std_t * noise_pred) / alpha_t
            predicted_target = torch.clamp(predicted_target, -1, 1)
            
            all_predictions.append(predicted_target.cpu())
            all_targets.append(target.cpu())
            
            num_batches += 1
            pbar.set_postfix({'loss': loss.item()})
            
            # Clean up intermediate variables and clear cache
            del pre, post, target, x_input, t, noise, alpha_t, std_t, x_t, noise_pred, loss, predicted_target
            if device == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()
    
    avg_loss = total_loss / num_batches
    
    predictions = torch.cat(all_predictions, dim=0).numpy()
    targets = torch.cat(all_targets, dim=0).numpy()
    
    return avg_loss, predictions, targets

print("Training functions defined successfully!")

Training functions defined successfully!


## 8. Metrics Computation

In [66]:
def compute_metrics_for_predictions(predictions, targets):
    """
    Compute SSIM and PSNR metrics
    
    Args:
        predictions: (N, 1, H, W) normalized predictions
        targets: (N, 1, H, W) normalized targets
    Returns:
        dict with SSIM and PSNR statistics
    """
    # Denormalize from [-1, 1] to [0, 1]
    predictions = (predictions + 1) / 2
    targets = (targets + 1) / 2
    
    ssim_scores = []
    psnr_scores = []
    
    for pred, targ in zip(predictions, targets):
        # Remove channel dimension if present
        if pred.ndim == 3:
            pred = pred[0]
        if targ.ndim == 3:
            targ = targ[0]
        
        pred_np = np.clip(pred, 0, 1)
        targ_np = np.clip(targ, 0, 1)
        
        # SSIM
        ssim_val = ssim(targ_np, pred_np, data_range=1.0)
        ssim_scores.append(ssim_val)
        
        # PSNR
        mse = np.mean((targ_np - pred_np) ** 2)
        if mse == 0:
            psnr_val = 100
        else:
            psnr_val = 20 * np.log10(1.0 / np.sqrt(mse))
        psnr_scores.append(psnr_val)
    
    ssim_scores = np.array(ssim_scores)
    psnr_scores = np.array(psnr_scores)
    
    metrics = {
        'ssim': {
            'mean': float(np.mean(ssim_scores)),
            'std': float(np.std(ssim_scores)),
            'min': float(np.min(ssim_scores)),
            'max': float(np.max(ssim_scores))
        },
        'psnr': {
            'mean': float(np.mean(psnr_scores)),
            'std': float(np.std(psnr_scores)),
            'min': float(np.min(psnr_scores)),
            'max': float(np.max(psnr_scores))
        }
    }
    
    return metrics, ssim_scores, psnr_scores

print("Metrics functions defined successfully!")


Metrics functions defined successfully!


## 9. Training Loop

In [None]:
# Training history
history = {
    'train_loss': [],
    'test_loss_all': [],
    'test_loss_dist2': [],
    'test_loss_dist4': [],
    'test_metrics_all': [],
    'test_metrics_dist2': [],
    'test_metrics_dist4': [],
    'epoch_times': []
}

best_loss = float('inf')
patience_counter = 0
start_time = time.time()

print(f"Starting training for {CONFIG['num_epochs']} epochs...")
print(f"Model: Fast-DDPM with {CONFIG['num_timesteps']} diffusion steps")
print(f"Scheduler: {CONFIG['scheduler_type']}")
print("-" * 80)

for epoch in range(CONFIG['num_epochs']):
    epoch_start = time.time()
    
    # Training
    train_loss = train_epoch(model, train_loader, optimizer, criterion, diffusion_schedule, CONFIG['device'], CONFIG)
    history['train_loss'].append(train_loss)
    
    # Evaluation
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print(f"  Train Loss: {train_loss:.6f}")

    # Evaluate on distance 2 (3mm gap)
    test_loss_dist2, pred_dist2, target_dist2 = evaluate(model, test_loader_dist2, criterion, diffusion_schedule, CONFIG['device'], CONFIG)
    metrics_dist2, ssim_dist2, psnr_dist2 = compute_metrics_for_predictions(pred_dist2, target_dist2)
    history['test_loss_dist2'].append(test_loss_dist2)
    history['test_metrics_dist2'].append(metrics_dist2)
    
    print(f"  Test Loss (Dist2-3mm): {test_loss_dist2:.6f}")
    print(f"    SSIM: {metrics_dist2['ssim']['mean']:.4f} Â± {metrics_dist2['ssim']['std']:.4f}")
    print(f"    PSNR: {metrics_dist2['psnr']['mean']:.4f} Â± {metrics_dist2['psnr']['std']:.4f}")
    
    # Evaluate on distance 4 (6mm gap)
    test_loss_dist4, pred_dist4, target_dist4 = evaluate(model, test_loader_dist4, criterion, diffusion_schedule, CONFIG['device'], CONFIG)
    metrics_dist4, ssim_dist4, psnr_dist4 = compute_metrics_for_predictions(pred_dist4, target_dist4)
    history['test_loss_dist4'].append(test_loss_dist4)
    history['test_metrics_dist4'].append(metrics_dist4)
    
    print(f"  Test Loss (Dist4-6mm): {test_loss_dist4:.6f}")
    print(f"    SSIM: {metrics_dist4['ssim']['mean']:.4f} Â± {metrics_dist4['ssim']['std']:.4f}")
    print(f"    PSNR: {metrics_dist4['psnr']['mean']:.4f} Â± {metrics_dist4['psnr']['std']:.4f}")
    
    # Learning rate scheduling based on all test loss
    scheduler.step(test_loss_all)
    
    # Early stopping and checkpointing
    if test_loss_dist_2 < best_loss:
        best_loss = test_loss_dist_2
        patience_counter = 0
        
        # Save best model
        checkpoint = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'config': CONFIG,
            'best_loss': best_loss
        }
        ckpt_path = os.path.join(CONFIG['model_dir'], 'best_model.pth')
        torch.save(checkpoint, ckpt_path)
        print(f"  âœ“ Saved best model (loss: {best_loss:.6f})")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG['early_stopping_patience']:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break
    
    # Regular checkpoint
    if (epoch + 1) % CONFIG['ckpt_freq'] == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'config': CONFIG
        }
        ckpt_path = os.path.join(CONFIG['model_dir'], f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint, ckpt_path)
    
    epoch_time = time.time() - epoch_start
    history['epoch_times'].append(epoch_time)
    print(f"  Time: {epoch_time:.2f}s")

total_time = time.time() - start_time
print("\n" + "=" * 80)
print(f"Training completed in {total_time/3600:.2f} hours")
print(f"Best test loss: {best_loss:.6f}")


Starting training for 20 epochs...
Model: Fast-DDPM with 1000 diffusion steps
Scheduler: uniform
--------------------------------------------------------------------------------


Training:  38%|â–ˆâ–ˆâ–ˆâ–Š      | 6928/18269 [21:03<24:13,  7.80it/s, loss=1]      

## 10. Save Results to JSON

In [None]:
# Save detailed results to JSON
results_dict = {
    'model_name': 'Fast-DDPM (v3)',
    'dataset': 'Prostate-MRI-US-Biopsy',
    'config': CONFIG,
    'training_summary': {
        'total_epochs_trained': len(history['train_loss']),
        'total_training_time_hours': total_time / 3600,
        'best_test_loss': best_loss,
    },
    'history': {
        'train_loss': history['train_loss'],
        'test_loss_all': history['test_loss_all'],
        'test_loss_distance_2_3mm': history['test_loss_dist2'],
        'test_loss_distance_4_6mm': history['test_loss_dist4'],
    },
    'final_metrics': {
        'all_samples': {
            'num_samples': len(pred_all),
            'distance_2_3mm_gap': history['test_metrics_dist2'][-1] if history['test_metrics_dist2'] else None,
            'distance_4_6mm_gap': history['test_metrics_dist4'][-1] if history['test_metrics_dist4'] else None,
            'combined': history['test_metrics_all'][-1] if history['test_metrics_all'] else None,
        }
    }
}

# Save to JSON
with open(CONFIG['results_file'], 'w') as f:
    json.dump(results_dict, f, indent=4)

print(f"Results saved to {CONFIG['results_file']}")
print("\nFinal Metrics Summary:")
print("\nDistance 2 (3mm gap - (i, i+2) -> i+1):")
if history['test_metrics_dist2']:
    m = history['test_metrics_dist2'][-1]
    print(f"  SSIM: {m['ssim']['mean']:.4f} Â± {m['ssim']['std']:.4f} [min: {m['ssim']['min']:.4f}, max: {m['ssim']['max']:.4f}]")
    print(f"  PSNR: {m['psnr']['mean']:.4f} Â± {m['psnr']['std']:.4f} [min: {m['psnr']['min']:.4f}, max: {m['psnr']['max']:.4f}]")

print("\nDistance 4 (6mm gap - (i, i+4) -> i+2):")
if history['test_metrics_dist4']:
    m = history['test_metrics_dist4'][-1]
    print(f"  SSIM: {m['ssim']['mean']:.4f} Â± {m['ssim']['std']:.4f} [min: {m['ssim']['min']:.4f}, max: {m['ssim']['max']:.4f}]")
    print(f"  PSNR: {m['psnr']['mean']:.4f} Â± {m['psnr']['std']:.4f} [min: {m['psnr']['min']:.4f}, max: {m['psnr']['max']:.4f}]")

print("\nCombined (All samples):")
if history['test_metrics_all']:
    m = history['test_metrics_all'][-1]
    print(f"  SSIM: {m['ssim']['mean']:.4f} Â± {m['ssim']['std']:.4f} [min: {m['ssim']['min']:.4f}, max: {m['ssim']['max']:.4f}]")
    print(f"  PSNR: {m['psnr']['mean']:.4f} Â± {m['psnr']['std']:.4f} [min: {m['psnr']['min']:.4f}, max: {m['psnr']['max']:.4f}]")


## 11. Visualization of Results

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curves
ax = axes[0, 0]
ax.plot(history['train_loss'], label='Train Loss', marker='o')
ax.plot(history['test_loss_all'], label='Test Loss (All)', marker='s')
ax.plot(history['test_loss_dist2'], label='Test Loss (Distance 2 - 3mm)', marker='^')
ax.plot(history['test_loss_dist4'], label='Test Loss (Distance 4 - 6mm)', marker='v')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title('Training and Test Loss Over Epochs')
ax.legend()
ax.grid(True, alpha=0.3)

# SSIM scores
ax = axes[0, 1]
ssim_all = [m['ssim']['mean'] for m in history['test_metrics_all']]
ssim_dist2 = [m['ssim']['mean'] for m in history['test_metrics_dist2']]
ssim_dist4 = [m['ssim']['mean'] for m in history['test_metrics_dist4']]
ax.plot(ssim_all, label='SSIM (All)', marker='o')
ax.plot(ssim_dist2, label='SSIM (Distance 2 - 3mm)', marker='s')
ax.plot(ssim_dist4, label='SSIM (Distance 4 - 6mm)', marker='^')
ax.set_xlabel('Epoch')
ax.set_ylabel('SSIM')
ax.set_title('Structural Similarity Index (SSIM) Over Epochs')
ax.legend()
ax.grid(True, alpha=0.3)

# PSNR scores
ax = axes[1, 0]
psnr_all = [m['psnr']['mean'] for m in history['test_metrics_all']]
psnr_dist2 = [m['psnr']['mean'] for m in history['test_metrics_dist2']]
psnr_dist4 = [m['psnr']['mean'] for m in history['test_metrics_dist4']]
ax.plot(psnr_all, label='PSNR (All)', marker='o')
ax.plot(psnr_dist2, label='PSNR (Distance 2 - 3mm)', marker='s')
ax.plot(psnr_dist4, label='PSNR (Distance 4 - 6mm)', marker='^')
ax.set_xlabel('Epoch')
ax.set_ylabel('PSNR (dB)')
ax.set_title('Peak Signal-to-Noise Ratio (PSNR) Over Epochs')
ax.legend()
ax.grid(True, alpha=0.3)

# Training time per epoch
ax = axes[1, 1]
ax.bar(range(len(history['epoch_times'])), history['epoch_times'], color='steelblue', alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Time (seconds)')
ax.set_title('Training Time per Epoch')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(os.path.join(CONFIG['model_dir'], 'training_history.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Training history plot saved to {os.path.join(CONFIG['model_dir'], 'training_history.png')}")


## 12. Metrics Comparison Table

In [None]:
# Create comprehensive metrics table
metrics_data = []

if history['test_metrics_dist2']:
    m = history['test_metrics_dist2'][-1]
    metrics_data.append({
        'Model': 'Fast-DDPM',
        'Gap Type': 'Distance 2 (3mm)',
        'SSIM Mean': f"{m['ssim']['mean']:.4f}",
        'SSIM Std': f"{m['ssim']['std']:.4f}",
        'PSNR Mean': f"{m['psnr']['mean']:.2f}",
        'PSNR Std': f"{m['psnr']['std']:.2f}",
        'Num Samples': len(pred_dist2)
    })

if history['test_metrics_dist4']:
    m = history['test_metrics_dist4'][-1]
    metrics_data.append({
        'Model': 'Fast-DDPM',
        'Gap Type': 'Distance 4 (6mm)',
        'SSIM Mean': f"{m['ssim']['mean']:.4f}",
        'SSIM Std': f"{m['ssim']['std']:.4f}",
        'PSNR Mean': f"{m['psnr']['mean']:.2f}",
        'PSNR Std': f"{m['psnr']['std']:.2f}",
        'Num Samples': len(pred_dist4)
    })

if history['test_metrics_all']:
    m = history['test_metrics_all'][-1]
    metrics_data.append({
        'Model': 'Fast-DDPM',
        'Gap Type': 'Combined',
        'SSIM Mean': f"{m['ssim']['mean']:.4f}",
        'SSIM Std': f"{m['ssim']['std']:.4f}",
        'PSNR Mean': f"{m['psnr']['mean']:.2f}",
        'PSNR Std': f"{m['psnr']['std']:.2f}",
        'Num Samples': len(pred_all)
    })

metrics_df = pd.DataFrame(metrics_data)
print("\nFinal Metrics Comparison:")
print(metrics_df.to_string(index=False))

# Save metrics table as CSV
csv_path = os.path.join(CONFIG['model_dir'], 'metrics_summary.csv')
metrics_df.to_csv(csv_path, index=False)
print(f"\nMetrics table saved to {csv_path}")


## 13. Model Inference (Optional)

In [None]:
# Function to perform iterative denoising for inference
def sample_with_schedule(model, x_input, diffusion_schedule, num_steps=10, scheduler_type='uniform'):
    """
    Perform iterative denoising to generate samples
    
    Args:
        model: Trained diffusion model
        x_input: (B, 2, H, W) concatenated input and reference images
        diffusion_schedule: DiffusionSchedule object
        num_steps: Number of sampling steps
        scheduler_type: 'uniform' or 'non-uniform'
    Returns:
        generated_samples: (B, 1, H, W) generated middle slices
    """
    # Get sampling schedule
    schedule = diffusion_schedule.get_sampler_schedule(num_steps, scheduler_type)
    
    # Initialize with noise
    x_t = torch.randn(x_input.shape[0], 1, CONFIG['image_size'], CONFIG['image_size'], device=x_input.device)
    
    with torch.no_grad():
        # Reverse diffusion process
        for i in range(len(schedule) - 1):
            t_curr = schedule[i]
            t_next = schedule[i + 1] if i + 1 < len(schedule) else 0
            
            t = torch.full((x_t.shape[0],), t_curr, device=x_input.device, dtype=torch.long)
            
            # Predict noise
            noise_pred = model(torch.cat([x_input, x_t], dim=1), t.float())
            
            # Compute denoising coefficients
            alpha_t = diffusion_schedule.alphas_cumprod[t_curr]
            alpha_next = diffusion_schedule.alphas_cumprod[t_next]
            
            # x_{t-1} prediction
            x_t = (x_t - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
            
            # Add noise for next step if not at the end
            if t_next > 0:
                sigma_t = ((1 - alpha_next) * (1 - alpha_t) / (1 - alpha_next)).sqrt()
                x_t += sigma_t * torch.randn_like(x_t)
    
    return torch.clamp(x_t, -1, 1)

print("Inference function defined. To perform inference on new samples:")
print("1. Load a batch: batch = next(iter(test_loader))")
print("2. Prepare input: x_input = torch.cat([batch['pre'], batch['post']], dim=1).to(device)")
print("3. Generate: samples = sample_with_schedule(model, x_input, diffusion_schedule, num_steps=10)")
print("\nNote: For best results, load the best saved model checkpoint first.")
