# Fast-DDPM: Fast Denoising Diffusion Probabilistic Models for Medical Image Super-Resolution

## What is Fast-DDPM?

**Traditional DDPM (Denoising Diffusion Probabilistic Models):**
- Forward process: Add Gaussian noise to images over 1000 timesteps
- Reverse process: Learn to denoise images step-by-step
- Problem: Very slow - requires 1000 denoising steps at inference

**Fast-DDPM (Our Approach):**
- **Key Innovation:** Use only 10 timesteps instead of 1000
- **How:** Skip intermediate steps using an accelerated sampling schedule (uniform or non-uniform)
- **Result:** 
  - Training time: **0.2x of DDPM** (5x faster)
  - Sampling time: **0.01x of DDPM** (100x faster!)
  - Quality: Same or better than DDPM

## How It Works

```
Standard DDPM (1000 steps):
t=0 → t=1 → t=2 → ... → t=998 → t=999 (all steps)

Fast-DDPM (10 steps):
t=0 → t=111 → t=222 → ... → t=888 → t=999 (skip intermediate steps)
      ↓
      Jump to key timesteps only
```

## For Your Medical Image Task

- **Input:** The two outer slices (i and i+2) of each triplet of 3 consecutive slices [i, i+1, i+2]
- **Task:** Generate the middle slice (i+1) only
- **Advantage:** 
  - Simpler, faster training
  - Middle slice prediction (most stable)
  - Probabilistic approach with uncertainty
  - Can ensemble multiple predictions for better quality



In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
DEVICE = 'cuda' if _has_cuda else 'cpu'
print(f"Device: {DEVICE}")
if _has_cuda:
    # Report the name of the first visible CUDA device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Noise Schedule: Pre-compute 1000 timesteps, use only 10

In [None]:
class DDPMScheduler:
    """Linear-beta DDPM noise schedule with a reduced set of timesteps.

    The noise tables (``betas``, ``alphas``, ``alphas_cumprod`` and the two
    square-root tables) are pre-computed for all ``num_timesteps`` steps.
    ``timesteps`` holds the ``num_inference_steps`` indices into those tables
    that are actually used, in ascending order.

    Args:
        num_timesteps: Length of the full schedule.
        num_inference_steps: Number of timesteps selected from the full schedule.
        scheduler_type: ``'uniform'`` for evenly spaced timesteps; any other
            value selects a power-law (exponent 1.1) spacing that places the
            selected timesteps more densely near ``t = 0``.
    """

    def __init__(self, num_timesteps=1000, num_inference_steps=10, scheduler_type='uniform'):
        self.num_timesteps = num_timesteps
        self.num_inference_steps = num_inference_steps

        betas = np.linspace(0.0001, 0.02, num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas)

        self.betas = torch.from_numpy(betas).float()
        self.alphas = torch.from_numpy(alphas).float()
        self.alphas_cumprod = torch.from_numpy(alphas_cumprod).float()
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

        last = num_timesteps - 1
        if scheduler_type == 'uniform':
            selected = np.linspace(0, last, num_inference_steps).astype(int)
        else:
            # Apply the power law to positions normalised to [0, 1] and only
            # then scale to the schedule length. Raising the raw timestep
            # values to the power 1.1 overshoots ``last`` (999 ** 1.1 is about
            # 1995) and yields indices outside the pre-computed tables.
            fractions = np.linspace(0.0, 1.0, num_inference_steps) ** 1.1
            selected = np.clip(np.ceil(fractions * last), 0, last).astype(int)

        self.timesteps = torch.from_numpy(selected).long()

    def add_noise(self, x0, t, noise):
        """Return ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise``.

        Args:
            x0: Clean batch; the first axis is the batch axis.
            t: Integer tensor with one timestep index per batch element.
            noise: Tensor broadcastable against ``x0``.

        Returns:
            The noised batch, with the same shape as ``x0``.
        """
        # One coefficient per sample, broadcast over every non-batch axis.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        # Look the coefficients up on x0's device so the call also works when
        # the tables have not been moved there (``.to`` is a no-op otherwise).
        t = t.to(x0.device)
        sqrt_alpha = self.sqrt_alphas_cumprod.to(x0.device)[t].view(shape)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod.to(x0.device)[t].view(shape)
        return sqrt_alpha * x0 + sqrt_one_minus_alpha * noise

# CPU-side scheduler instance, used here only to display the selected timesteps.
scheduler = DDPMScheduler(
    num_timesteps=1000,
    num_inference_steps=10,
    scheduler_type='uniform',
)
selected_timesteps = scheduler.timesteps.numpy()
print(f"Timesteps: {selected_timesteps}")

## UNet with Time Embedding

In [None]:
class TimeEmbedding(nn.Module):
    """Learned embedding of the diffusion timestep.

    The timestep is divided by 1000 and pushed through a small MLP
    (Linear -> SiLU -> Linear) that outputs ``dim`` features per sample.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim)]
        self.fc = nn.Sequential(*layers)

    def forward(self, t):
        # Add a trailing feature axis so Linear(1, dim) receives one scalar
        # per timestep, then rescale by the fixed constant 1000.
        scaled = t.float()[..., None] / 1000.0
        return self.fc(scaled)

class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions with an additive time embedding.

    Layout: GroupNorm -> SiLU -> Conv, add the projected time embedding,
    GroupNorm -> SiLU -> Dropout -> Conv, then add the shortcut. The shortcut
    is a 1x1 convolution when the channel count changes and the identity
    otherwise. Both GroupNorm layers use 32 groups, so ``in_ch`` and
    ``out_ch`` must be multiples of 32.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        same_width = in_ch == out_ch
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_fc = nn.Linear(time_dim, out_ch)
        self.skip = nn.Identity() if same_width else nn.Conv2d(in_ch, out_ch, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t_emb):
        shortcut = self.skip(x)

        # First half: normalise, activate, convolve.
        out = self.conv1(F.silu(self.norm1(x)))

        # Per-channel time bias, broadcast over the two spatial axes.
        time_bias = self.time_fc(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = out + time_bias

        # Second half, with dropout applied just before the last convolution.
        out = self.conv2(self.dropout(F.silu(self.norm2(out))))
        return out + shortcut

class FastDDPMUNet(nn.Module):
    """Three-level UNet whose residual blocks are conditioned on the timestep.

    The encoder widens the features to 2x, 4x and 8x ``base_ch`` and halves
    the resolution after each level. The decoder mirrors it with transposed
    convolutions and concatenates the matching encoder output at every level.
    Height and width of the input must be divisible by 8 so that the upsampled
    maps line up with the skip connections.
    """

    def __init__(self, in_ch=2, out_ch=1, base_ch=64, time_dim=128):
        super().__init__()
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8

        self.time_emb = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(in_ch, c1, 3, padding=1)

        # Encoder: residual block followed by 2x max-pooling at every level.
        self.enc1 = ResBlock(c1, c2, time_dim)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResBlock(c2, c4, time_dim)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResBlock(c4, c8, time_dim)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = ResBlock(c8, c8, time_dim)

        # Decoder: each residual block receives the upsampled features
        # concatenated with the skip tensor, hence the doubled input widths.
        self.upconv3 = nn.ConvTranspose2d(c8, c4, 2, 2)
        self.dec3 = ResBlock(c8, c4, time_dim)
        self.upconv2 = nn.ConvTranspose2d(c4, c2, 2, 2)
        self.dec2 = ResBlock(c4, c2, time_dim)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, 2, 2)
        self.dec1 = ResBlock(c2, c1, time_dim)

        self.final = nn.Sequential(
            nn.GroupNorm(32, c1),
            nn.SiLU(),
            nn.Conv2d(c1, out_ch, 3, padding=1),
        )

    def forward(self, x, t):
        t_emb = self.time_emb(t)
        h = self.init_conv(x)

        # Walk down the encoder, remembering each pre-pooling feature map.
        skips = []
        encoder = ((self.enc1, self.pool1), (self.enc2, self.pool2), (self.enc3, self.pool3))
        for block, pool in encoder:
            h = block(h, t_emb)
            skips.append(h)
            h = pool(h)

        h = self.bottleneck(h, t_emb)

        # Walk back up, consuming the stored feature maps deepest-first.
        decoder = ((self.upconv3, self.dec3), (self.upconv2, self.dec2), (self.upconv1, self.dec1))
        for upsample, block in decoder:
            h = torch.cat([upsample(h), skips.pop()], dim=1)
            h = block(h, t_emb)

        return self.final(h)

# Build the network once on the CPU just to report its size.
model = FastDDPMUNet(in_ch=2, out_ch=1)
param_count = sum(p.numel() for p in model.parameters())
print(f"Model params: {param_count:,}")

## Data Loading from ModelDataGenerator

In [None]:
sys.path.append('../src')
from ModelDataGenerator import build_dataloader

# Training hyper-parameters and the directory that receives checkpoints.
BATCH_SIZE = 4
NUM_WORKERS = 4
EPOCHS = 20
LR = 1e-4
CHECKPOINT_DIR = '../models'

# One loader per split; augmentation is enabled for the training split only.
train_loader, val_loader, test_loader = (
    build_dataloader(
        split=split_name,
        batch_size=BATCH_SIZE,
        augment=(split_name == 'train'),
        num_workers=NUM_WORKERS,
    )
    for split_name in ('train', 'val', 'test')
)

# len() of a loader is its number of batches.
print(f"Train: {len(train_loader)} | Val: {len(val_loader)} | Test: {len(test_loader)}")

## Training Loop

In [None]:
model = FastDDPMUNet(in_ch=2, out_ch=1).to(DEVICE)
scheduler_device = DDPMScheduler(num_timesteps=1000, num_inference_steps=10, scheduler_type='uniform')
for key in ['betas', 'alphas', 'alphas_cumprod', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod']:
    setattr(scheduler_device, key, getattr(scheduler_device, key).to(DEVICE))
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.MSELoss()

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def train_epoch():
    model.train()
    loss_sum = 0
    for (pre, post), target in tqdm(train_loader, leave=False):
        pre, post, target = pre.to(DEVICE), post.to(DEVICE), target.to(DEVICE)
        x_input = torch.cat([pre, post], dim=1)
        
        batch_size = x_input.shape[0]
        t_idx = torch.randint(0, len(scheduler_device.timesteps), (batch_size,))
        t = scheduler_device.timesteps[t_idx].to(DEVICE)
        
        noise = torch.randn_like(target)
        x_noisy = scheduler_device.add_noise(target, t, noise)
        
        pred_noise = model(x_input, t)
        loss = criterion(pred_noise, noise)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        loss_sum += loss.item()
    
    return loss_sum / len(train_loader)

def validate():
    model.eval()
    loss_sum = 0
    with torch.no_grad():
        for (pre, post), target in tqdm(val_loader, leave=False):
            pre, post, target = pre.to(DEVICE), post.to(DEVICE), target.to(DEVICE)
            x_input = torch.cat([pre, post], dim=1)
            
            batch_size = x_input.shape[0]
            t_idx = torch.randint(0, len(scheduler_device.timesteps), (batch_size,))
            t = scheduler_device.timesteps[t_idx].to(DEVICE)
            
            noise = torch.randn_like(target)
            x_noisy = scheduler_device.add_noise(target, t, noise)
            
            pred_noise = model(x_input, t)
            loss = criterion(pred_noise, noise)
            
            loss_sum += loss.item()
    
    return loss_sum / len(val_loader)

best_loss = float('inf')
history = {'epoch': [], 'train_loss': [], 'val_loss': []}

for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch()
    val_loss = validate()
    
    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    print(f"Epoch {epoch:2d}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f}", end="")
    
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/fastddpm_best.pt')
        print(" ✅")
    else:
        print()

with open(f'{CHECKPOINT_DIR}/fastddpm_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"\nBest Val Loss: {best_loss:.4f}")

## Sampling (Reverse Diffusion)

In [None]:
@torch.no_grad()
def sample(pre, post, num_samples=1):
    model.eval()
    batch_size = pre.shape[0]
    
    generated = []
    
    for _ in range(num_samples):
        x_t = torch.randn(batch_size, 1, pre.shape[2], pre.shape[3]).to(DEVICE)
        
        for step_idx, t in enumerate(reversed(scheduler_device.timesteps)):
            t = t.to(DEVICE)
            x_input = torch.cat([pre.to(DEVICE), post.to(DEVICE)], dim=1)
            
            pred_noise = model(x_input, t.unsqueeze(0).expand(batch_size))
            
            alpha = scheduler_device.alphas_cumprod[t]
            alpha_prev = scheduler_device.alphas_cumprod[scheduler_device.timesteps[max(0, step_idx - 1)]]
            
            posterior_var = (1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev)
            posterior_var = torch.clamp(posterior_var, min=1e-20)
            
            if t > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = torch.zeros_like(x_t)
            
            x_t = (1 / torch.sqrt(alpha)) * (x_t - (1 - alpha) / torch.sqrt(1 - alpha) * pred_noise) + torch.sqrt(posterior_var) * noise
        
        generated.append(x_t.cpu())
    
    return torch.stack(generated, dim=1)

print("Sampling ready")

## Evaluation

In [None]:
# map_location lets a checkpoint written on a GPU be evaluated on a CPU-only machine.
model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/fastddpm_best.pt', map_location=DEVICE))

ssim_scores = []
psnr_scores = []


def _minmax(img):
    """Rescale an array to [0, 1]; the epsilon guards against constant images."""
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


for (pre, post), target in tqdm(test_loader):
    generated = sample(pre, post, num_samples=3)  # (B, num_samples, 1, H, W)
    # Average the ensemble, then drop only the channel axis. A bare .squeeze()
    # would also drop the batch axis for a batch of one image, and the loop
    # below would then iterate over image rows instead of images.
    pred = generated.mean(dim=1)[:, 0].cpu().numpy()
    # Assumes target is (B, 1, H, W) or (B, H, W) -- TODO confirm against the loader.
    gt = target.reshape(target.shape[0], *target.shape[-2:]).cpu().numpy()

    for gt_img, pred_img in zip(gt, pred):
        # NOTE(review): each image is min-max normalised independently, so the
        # scores ignore absolute intensity offsets between prediction and target.
        gt_norm = _minmax(gt_img)
        pred_norm = _minmax(pred_img)

        ssim_scores.append(ssim(gt_norm, pred_norm, data_range=1.0))
        psnr_scores.append(psnr(gt_norm, pred_norm, data_range=1.0))

print(f"\nSSIM: {np.mean(ssim_scores):.4f} ± {np.std(ssim_scores):.4f}")
print(f"PSNR: {np.mean(psnr_scores):.2f} ± {np.std(psnr_scores):.2f} dB")

## Visualization

In [None]:
# Plot the per-epoch training and validation losses, save the figure next to
# the checkpoints, then display it.
plt.figure(figsize=(10, 5))
epochs = history['epoch']
for series, fmt, legend_name in (('train_loss', 'o-', 'Train'), ('val_loss', 's-', 'Val')):
    plt.plot(epochs, history[series], fmt, label=legend_name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.title('Fast-DDPM Training')
plt.tight_layout()
figure_path = f'{CHECKPOINT_DIR}/fastddpm_training.png'
plt.savefig(figure_path, dpi=150)
plt.show()

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
DEVICE = 'cuda' if _has_cuda else 'cpu'
print(f"Device: {DEVICE}")
if _has_cuda:
    # Report the name of the first visible CUDA device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Noise Schedule: Accelerated Timestep Selection

Pre-compute all 1000 noise levels, but sample only 10 timesteps during training/inference.

In [None]:
class DDPMScheduler:
    """Linear-beta DDPM noise schedule with a reduced set of timesteps.

    The noise tables (``betas``, ``alphas``, ``alphas_cumprod`` and the two
    square-root tables) are pre-computed for all ``num_timesteps`` steps.
    ``timesteps`` holds the ``num_inference_steps`` indices into those tables
    that are actually used, in ascending order.

    Args:
        num_timesteps: Length of the full schedule.
        num_inference_steps: Number of timesteps selected from the full schedule.
        scheduler_type: ``'uniform'`` for evenly spaced timesteps; any other
            value selects a power-law (exponent 1.1) spacing that places the
            selected timesteps more densely near ``t = 0``.
    """

    def __init__(self, num_timesteps=1000, num_inference_steps=10, scheduler_type='uniform'):
        self.num_timesteps = num_timesteps
        self.num_inference_steps = num_inference_steps

        # Pre-compute the noise tables for the full schedule.
        betas = np.linspace(0.0001, 0.02, num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas)

        # Square roots are taken in float64 and only then cast to float32.
        cumprod64 = torch.from_numpy(alphas_cumprod)
        self.register_buffer('betas', torch.from_numpy(betas).float())
        self.register_buffer('alphas', torch.from_numpy(alphas).float())
        self.register_buffer('alphas_cumprod', cumprod64.float())
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(cumprod64).float())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - cumprod64).float())

        # Select the reduced set of timesteps.
        last = num_timesteps - 1
        if scheduler_type == 'uniform':
            selected = np.linspace(0, last, num_inference_steps).astype(int)
        else:  # non-uniform
            # Apply the power law to positions normalised to [0, 1] and only
            # then scale to the schedule length. Raising the raw timestep
            # values to the power 1.1 overshoots ``last`` (999 ** 1.1 is about
            # 1995) and yields indices outside the pre-computed tables.
            fractions = np.linspace(0.0, 1.0, num_inference_steps) ** 1.1
            selected = np.clip(np.ceil(fractions * last), 0, last).astype(int)

        self.timesteps = torch.from_numpy(selected).long()

    def register_buffer(self, name, tensor):
        """Store ``tensor`` as a plain attribute called ``name``."""
        setattr(self, name, tensor)

    def add_noise(self, x0, t, noise):
        """Return ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise``.

        Args:
            x0: Clean batch; the first axis is the batch axis.
            t: Integer tensor with one timestep index per batch element.
            noise: Tensor broadcastable against ``x0``.

        Returns:
            The noised batch, with the same shape as ``x0``.
        """
        # One coefficient per sample, broadcast over every non-batch axis.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        # Look the coefficients up on x0's device so the call also works when
        # the tables have not been moved there (``.to`` is a no-op otherwise).
        t = t.to(x0.device)
        sqrt_alpha = self.sqrt_alphas_cumprod.to(x0.device)[t].view(shape)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod.to(x0.device)[t].view(shape)
        return sqrt_alpha * x0 + sqrt_one_minus_alpha * noise

# CPU-side scheduler instance, used here only to display the selected timesteps.
scheduler = DDPMScheduler(
    num_timesteps=1000,
    num_inference_steps=10,
    scheduler_type='uniform',
)
selected_timesteps = scheduler.timesteps.numpy()
print(f"Selected timesteps: {selected_timesteps}")

## UNet with Time Embedding

A lightweight UNet whose residual blocks are conditioned on the diffusion timestep.

In [None]:
class TimeEmbedding(nn.Module):
    """Learned embedding of the diffusion timestep.

    The timestep is divided by 1000 and pushed through a small MLP
    (Linear -> SiLU -> Linear) that outputs ``dim`` features per sample.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim)]
        self.fc = nn.Sequential(*layers)

    def forward(self, t):
        # Add a trailing feature axis so Linear(1, dim) receives one scalar
        # per timestep, then rescale by the fixed constant 1000.
        scaled = t.float()[..., None] / 1000.0
        return self.fc(scaled)

class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions with an additive time embedding.

    Layout: GroupNorm -> SiLU -> Conv, add the projected time embedding,
    GroupNorm -> SiLU -> Dropout -> Conv, then add the shortcut. The shortcut
    is a 1x1 convolution when the channel count changes and the identity
    otherwise. Both GroupNorm layers use 32 groups, so ``in_ch`` and
    ``out_ch`` must be multiples of 32.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        same_width = in_ch == out_ch
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_fc = nn.Linear(time_dim, out_ch)
        self.skip = nn.Identity() if same_width else nn.Conv2d(in_ch, out_ch, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, t_emb):
        shortcut = self.skip(x)

        # First half: normalise, activate, convolve.
        out = self.conv1(F.silu(self.norm1(x)))

        # Per-channel time bias, broadcast over the two spatial axes.
        time_bias = self.time_fc(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = out + time_bias

        # Second half, with dropout applied just before the last convolution.
        out = self.conv2(self.dropout(F.silu(self.norm2(out))))
        return out + shortcut

class FastDDPMUNet(nn.Module):
    """Three-level UNet whose residual blocks are conditioned on the timestep.

    The encoder widens the features to 2x, 4x and 8x ``base_ch`` and halves
    the resolution after each level. The decoder mirrors it with transposed
    convolutions and concatenates the matching encoder output at every level.
    Height and width of the input must be divisible by 8 so that the upsampled
    maps line up with the skip connections.
    """

    def __init__(self, in_ch=2, out_ch=1, base_ch=64, time_dim=128):
        super().__init__()
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8

        self.time_emb = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(in_ch, c1, 3, padding=1)

        # Encoder: residual block followed by 2x max-pooling at every level.
        self.enc1 = ResBlock(c1, c2, time_dim)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResBlock(c2, c4, time_dim)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResBlock(c4, c8, time_dim)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck at the lowest resolution.
        self.bottleneck = ResBlock(c8, c8, time_dim)

        # Decoder: each residual block receives the upsampled features
        # concatenated with the skip tensor, hence the doubled input widths.
        self.upconv3 = nn.ConvTranspose2d(c8, c4, 2, 2)
        self.dec3 = ResBlock(c8, c4, time_dim)
        self.upconv2 = nn.ConvTranspose2d(c4, c2, 2, 2)
        self.dec2 = ResBlock(c4, c2, time_dim)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, 2, 2)
        self.dec1 = ResBlock(c2, c1, time_dim)

        self.final = nn.Sequential(
            nn.GroupNorm(32, c1),
            nn.SiLU(),
            nn.Conv2d(c1, out_ch, 3, padding=1),
        )

    def forward(self, x, t):
        t_emb = self.time_emb(t)
        h = self.init_conv(x)

        # Walk down the encoder, remembering each pre-pooling feature map.
        skips = []
        encoder = ((self.enc1, self.pool1), (self.enc2, self.pool2), (self.enc3, self.pool3))
        for block, pool in encoder:
            h = block(h, t_emb)
            skips.append(h)
            h = pool(h)

        h = self.bottleneck(h, t_emb)

        # Walk back up, consuming the stored feature maps deepest-first.
        decoder = ((self.upconv3, self.dec3), (self.upconv2, self.dec2), (self.upconv1, self.dec1))
        for upsample, block in decoder:
            h = torch.cat([upsample(h), skips.pop()], dim=1)
            h = block(h, t_emb)

        return self.final(h)

# Build the network once on the CPU just to report its size.
model = FastDDPMUNet(in_ch=2, out_ch=1)
param_count = sum(p.numel() for p in model.parameters())
print(f"Model params: {param_count:,}")

## Configuration & Data Loading

In [None]:
sys.path.append('../src')
from ModelDataGenerator import build_dataloader

# Training hyper-parameters and the directory that receives checkpoints.
BATCH_SIZE = 4
NUM_WORKERS = 4
EPOCHS = 20
LR = 1e-4
CHECKPOINT_DIR = '../models'

# One loader per split; augmentation is enabled for the training split only.
train_loader, val_loader, test_loader = (
    build_dataloader(
        split=split_name,
        batch_size=BATCH_SIZE,
        augment=(split_name == 'train'),
        num_workers=NUM_WORKERS,
    )
    for split_name in ('train', 'val', 'test')
)

# len() of a loader is its number of batches.
print(f"Train: {len(train_loader)} | Val: {len(val_loader)} | Test: {len(test_loader)}")

## Training

In [None]:
model = FastDDPMUNet(in_ch=2, out_ch=1).to(DEVICE)
scheduler_device = DDPMScheduler(num_timesteps=1000, num_inference_steps=10, scheduler_type='uniform')
for key in ['betas', 'alphas', 'alphas_cumprod', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod']:
    setattr(scheduler_device, key, getattr(scheduler_device, key).to(DEVICE))
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.MSELoss()

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def train_epoch():
    model.train()
    loss_sum = 0
    for (pre, post), target in tqdm(train_loader, leave=False):
        pre, post, target = pre.to(DEVICE), post.to(DEVICE), target.to(DEVICE)
        x_input = torch.cat([pre, post], dim=1)  # (B, 2, H, W)
        
        batch_size = x_input.shape[0]
        t_idx = torch.randint(0, len(scheduler_device.timesteps), (batch_size,))
        t = scheduler_device.timesteps[t_idx].to(DEVICE)
        
        noise = torch.randn_like(target)
        x_noisy = scheduler_device.add_noise(target, t, noise)
        
        pred_noise = model(x_input, t)
        loss = criterion(pred_noise, noise)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        loss_sum += loss.item()
    
    return loss_sum / len(train_loader)

def validate():
    model.eval()
    loss_sum = 0
    with torch.no_grad():
        for (pre, post), target in tqdm(val_loader, leave=False):
            pre, post, target = pre.to(DEVICE), post.to(DEVICE), target.to(DEVICE)
            x_input = torch.cat([pre, post], dim=1)
            
            batch_size = x_input.shape[0]
            t_idx = torch.randint(0, len(scheduler_device.timesteps), (batch_size,))
            t = scheduler_device.timesteps[t_idx].to(DEVICE)
            
            noise = torch.randn_like(target)
            x_noisy = scheduler_device.add_noise(target, t, noise)
            
            pred_noise = model(x_input, t)
            loss = criterion(pred_noise, noise)
            
            loss_sum += loss.item()
    
    return loss_sum / len(val_loader)

best_loss = float('inf')
history = {'epoch': [], 'train_loss': [], 'val_loss': []}

for epoch in range(1, EPOCHS + 1):
    train_loss = train_epoch()
    val_loss = validate()
    
    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    print(f"Epoch {epoch:2d}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f}", end="")
    
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/fastddpm_best.pt')
        print(" ✅")
    else:
        print()

with open(f'{CHECKPOINT_DIR}/fastddpm_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"\nBest Val Loss: {best_loss:.4f}")

## Sampling (Fast-DDPM Reverse Process)

Generate middle slice from (pre, post) using only 10 denoising steps.

In [None]:
@torch.no_grad()
def sample(pre, post, num_samples=1):
    model.eval()
    batch_size = pre.shape[0]
    
    generated = []
    
    for _ in range(num_samples):
        x_t = torch.randn(batch_size, 1, pre.shape[2], pre.shape[3]).to(DEVICE)
        
        for step_idx, t in enumerate(reversed(scheduler_device.timesteps)):
            t = t.to(DEVICE)
            x_input = torch.cat([pre.to(DEVICE), post.to(DEVICE)], dim=1)
            
            pred_noise = model(x_input, t.unsqueeze(0).expand(batch_size))
            
            alpha = scheduler_device.alphas_cumprod[t]
            alpha_prev = scheduler_device.alphas_cumprod[scheduler_device.timesteps[max(0, step_idx - 1)]]
            
            posterior_var = (1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev)
            posterior_var = torch.clamp(posterior_var, min=1e-20)
            
            if t > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = torch.zeros_like(x_t)
            
            x_t = (1 / torch.sqrt(alpha)) * (x_t - (1 - alpha) / torch.sqrt(1 - alpha) * pred_noise) + torch.sqrt(posterior_var) * noise
        
        generated.append(x_t.cpu())
    
    return torch.stack(generated, dim=1)  # (B, num_samples, 1, H, W)

print("Sampling function ready")

## Evaluation

In [None]:
# map_location lets a checkpoint written on a GPU be evaluated on a CPU-only machine.
model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/fastddpm_best.pt', map_location=DEVICE))

ssim_scores = []
psnr_scores = []


def _minmax(img):
    """Rescale an array to [0, 1]; the epsilon guards against constant images."""
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


for (pre, post), target in tqdm(test_loader):
    generated = sample(pre, post, num_samples=3)  # (B, num_samples, 1, H, W)
    # Average the ensemble, then drop only the channel axis. A bare .squeeze()
    # would also drop the batch axis for a batch of one image, and the loop
    # below would then iterate over image rows instead of images.
    pred = generated.mean(dim=1)[:, 0].cpu().numpy()
    # Assumes target is (B, 1, H, W) or (B, H, W) -- TODO confirm against the loader.
    gt = target.reshape(target.shape[0], *target.shape[-2:]).cpu().numpy()

    for gt_img, pred_img in zip(gt, pred):
        # NOTE(review): each image is min-max normalised independently, so the
        # scores ignore absolute intensity offsets between prediction and target.
        gt_norm = _minmax(gt_img)
        pred_norm = _minmax(pred_img)

        ssim_scores.append(ssim(gt_norm, pred_norm, data_range=1.0))
        psnr_scores.append(psnr(gt_norm, pred_norm, data_range=1.0))

print(f"\nSSIM: {np.mean(ssim_scores):.4f} ± {np.std(ssim_scores):.4f}")
print(f"PSNR: {np.mean(psnr_scores):.2f} ± {np.std(psnr_scores):.2f} dB")

## Visualize Training

In [None]:
# Plot the per-epoch training and validation losses, save the figure next to
# the checkpoints, then display it.
plt.figure(figsize=(10, 5))
epochs = history['epoch']
for series, fmt, legend_name in (('train_loss', 'o-', 'Train'), ('val_loss', 's-', 'Val')):
    plt.plot(epochs, history[series], fmt, label=legend_name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.title('Fast-DDPM Training')
plt.tight_layout()
figure_path = f'{CHECKPOINT_DIR}/fastddpm_training.png'
plt.savefig(figure_path, dpi=150)
plt.show()