## 1. Setup and Imports

In [None]:
import os
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt

# Pick the compute device once; later cells read the DEVICE string.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Device: {DEVICE}")
if DEVICE == 'cuda':
    # Report the name of CUDA device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Data

In [None]:
# Make the project's src/ directory importable (the path is relative to the
# notebook's working directory), then import the data pipeline.
sys.path.insert(0, '../src')
from ModelDataGenerator import build_dataloader

# Configuration
BATCH_SIZE = 4
NUM_WORKERS = 4
EPOCHS = 20
LR = 1e-4
CHECKPOINT_DIR = '../models'

# One loader per split; augmentation is enabled for the training split only.
train_loader = build_dataloader(split='train', batch_size=BATCH_SIZE, augment=True, num_workers=NUM_WORKERS)
val_loader = build_dataloader(split='val', batch_size=BATCH_SIZE, augment=False, num_workers=NUM_WORKERS)
test_loader = build_dataloader(split='test', batch_size=BATCH_SIZE, augment=False, num_workers=NUM_WORKERS)

# NOTE(review): if build_dataloader returns a torch DataLoader, len() counts
# batches, not samples — confirm.
print(f"✅ Data loaded: Train={len(train_loader)}, Val={len(val_loader)}, Test={len(test_loader)}")

# Sanity-check one batch; each batch unpacks as ((pre, post), target).
(pre_sample, post_sample), target_sample = next(iter(train_loader))
print("\nBatch shapes:")
print(f"  pre: {pre_sample.shape}")
print(f"  post: {post_sample.shape}")
print(f"  target: {target_sample.shape}")

## 3. Noise Schedule (DDPM)

In [None]:
class DDPMScheduler:
    """Linear-beta DDPM noise schedule plus a reduced set of sampling timesteps.

    Precomputes ``betas``, ``alphas``, ``alphas_cumprod``,
    ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod`` as plain
    tensors. It also builds ``timesteps``, an ascending tensor holding the
    ``num_inference_steps`` timesteps visited by the reverse process.

    Args:
        num_timesteps: length of the full forward diffusion chain (>= 1).
        num_inference_steps: number of timesteps kept for sampling; must lie in
            ``[1, num_timesteps]``.
        scheduler_type: ``'uniform'`` (evenly spaced, starting at 0) or
            ``'non-uniform'`` (denser in the high-noise part of the chain).

    Raises:
        ValueError: for an unknown ``scheduler_type`` or out-of-range counts.
    """

    # Fast-DDPM paper's 10-step non-uniform schedule for a 1000-step chain.
    _PAPER_TIMESTEPS_10 = (0, 199, 399, 599, 699, 799, 849, 899, 949, 999)

    def __init__(self, num_timesteps=1000, num_inference_steps=10, scheduler_type='non-uniform'):
        if num_timesteps < 1:
            raise ValueError(f"num_timesteps must be >= 1, got {num_timesteps}")
        if not 1 <= num_inference_steps <= num_timesteps:
            raise ValueError(
                f"num_inference_steps must be in [1, {num_timesteps}], got {num_inference_steps}"
            )

        self.num_timesteps = num_timesteps
        self.num_inference_steps = num_inference_steps
        self.scheduler_type = scheduler_type

        # Pre-compute noise schedule (same as standard DDPM: linear betas)
        betas = torch.linspace(0.0001, 0.02, num_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod))

        if scheduler_type == 'uniform':
            # Every `skip`-th timestep from 0; skip >= 1 thanks to the check above.
            skip = num_timesteps // num_inference_steps
            self.timesteps = torch.arange(0, num_timesteps, skip).long()[:num_inference_steps]
        elif scheduler_type == 'non-uniform':
            # The paper's list only makes sense for a 1000-step chain.
            if num_inference_steps == len(self._PAPER_TIMESTEPS_10) and num_timesteps == 1000:
                self.timesteps = torch.tensor(self._PAPER_TIMESTEPS_10).long()
            else:
                self.timesteps = self._adaptive_non_uniform(num_timesteps, num_inference_steps)
        else:
            raise ValueError(f"Unknown scheduler_type: {scheduler_type}")

    @staticmethod
    def _adaptive_non_uniform(num_timesteps, num_inference_steps):
        """Spread ~40% of the steps over the first 70% of the chain, the rest over the last 30%.

        For ``num_timesteps == 1000`` the breakpoints are 0-699 and 699-999.
        """
        last = num_timesteps - 1
        split = max(0, round(0.7 * num_timesteps) - 1)  # 699 when num_timesteps == 1000
        num_stage1 = int(num_inference_steps * 0.4)
        # Take the remainder (not int(0.6 * N)) so the two stages always sum to N.
        num_stage2 = num_inference_steps - num_stage1

        if num_stage1 > 0:
            stage1 = torch.linspace(0, split, num_stage1 + 1)[:-1].ceil().long()
        else:
            stage1 = torch.tensor([]).long()

        stage2 = torch.linspace(split, last, num_stage2 + 1)[:-1].ceil().long()
        return torch.cat([stage1, stage2])

    def register_buffer(self, name, tensor):
        """Store ``tensor`` as attribute ``name`` (mimics nn.Module) and remember it for :meth:`to`."""
        names = self.__dict__.setdefault('_buffer_names', [])
        if name not in names:
            names.append(name)
        setattr(self, name, tensor)

    def to(self, device):
        """Move every registered schedule tensor and ``timesteps`` to ``device`` in place; return ``self``."""
        for name in self.__dict__.get('_buffer_names', []):
            setattr(self, name, getattr(self, name).to(device))
        self.timesteps = self.timesteps.to(device)
        return self

    def add_noise(self, x0, t, noise):
        """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x0: clean samples with a leading batch dimension.
            t: timestep index per batch element (tensor or int), on any device.
            noise: tensor broadcastable to ``x0``.

        Returns:
            The noised tensor, on ``x0``'s device.
        """
        # Index on the schedule's own device, then move the coefficients to x0.
        t = torch.as_tensor(t, device=self.sqrt_alphas_cumprod.device)
        shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(shape).to(x0.device)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(shape).to(x0.device)
        return sqrt_alpha * x0 + sqrt_one_minus_alpha * noise

# Create scheduler (change to 'uniform' if preferred)
scheduler = DDPMScheduler(num_timesteps=1000, num_inference_steps=10, scheduler_type='non-uniform')
# Report the type actually configured, so the message can't drift from the call above.
print(f"✅ Scheduler created with type: '{scheduler.scheduler_type}'")
print(f"✅ Inference steps: {scheduler.timesteps.tolist()}")

## 4. Model Architecture (with 3 input channels)

In [None]:
class FastDDPM(nn.Module):
    """Conditional UNet used by Fast-DDPM to predict noise.

    Three-level encoder/decoder: max-pool downsampling, transposed-conv
    upsampling and concatenation skip connections. Every ResBlock also receives
    the time embedding produced by ``TimeEmbedding(time_dim)``.

    Args:
        in_ch: input channels (the notebook feeds [pre, post, noisy_target]).
        out_ch: output channels.
        base_ch: width after the stem conv; deeper levels use 2x / 4x / 8x.
        time_dim: width of the time embedding passed to each ResBlock.

    Spatial size must be divisible by 8 (three 2x poolings) for the skip
    tensors to line up at concatenation.
    """

    def __init__(self, in_ch=3, out_ch=1, base_ch=64, time_dim=128):
        super().__init__()
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8

        # NOTE: attribute names AND creation order are load-bearing. They fix the
        # state_dict keys and the parameter order that saved optimizer states use.
        self.time_emb = TimeEmbedding(time_dim)
        self.init_conv = nn.Conv2d(in_ch, c1, 3, padding=1)

        # Encoder: widen channels, then halve resolution.
        self.enc1 = ResBlock(c1, c2, time_dim)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ResBlock(c2, c4, time_dim)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ResBlock(c4, c8, time_dim)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = ResBlock(c8, c8, time_dim)

        # Decoder: each ResBlock's input width = upsampled channels + skip channels.
        self.upconv3 = nn.ConvTranspose2d(c8, c4, 2, 2)
        self.dec3 = ResBlock(c4 + c8, c4, time_dim)
        self.upconv2 = nn.ConvTranspose2d(c4, c2, 2, 2)
        self.dec2 = ResBlock(c2 + c4, c2, time_dim)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, 2, 2)
        self.dec1 = ResBlock(c1 + c2, c1, time_dim)

        self.final = nn.Sequential(
            nn.GroupNorm(max(1, c1 // 4), c1),
            nn.SiLU(),
            nn.Conv2d(c1, out_ch, 3, padding=1),
        )

    def forward(self, x, t):
        """Map input ``x`` at diffusion timestep(s) ``t`` to an ``out_ch``-channel output."""
        emb = self.time_emb(t)
        feats = self.init_conv(x)

        # Encoder: keep each pre-pool activation as a skip connection.
        skips = []
        for enc, pool in ((self.enc1, self.pool1), (self.enc2, self.pool2), (self.enc3, self.pool3)):
            feats = enc(feats, emb)
            skips.append(feats)
            feats = pool(feats)

        feats = self.bottleneck(feats, emb)

        # Decoder: upsample, append the matching skip (deepest first), refine.
        for up, dec in ((self.upconv3, self.dec3), (self.upconv2, self.dec2), (self.upconv1, self.dec1)):
            feats = dec(torch.cat([up(feats), skips.pop()], dim=1), emb)

        return self.final(feats)

# Create model with 3 input channels: [pre, post, noisy_target]
model = FastDDPM(in_ch=3, out_ch=1, base_ch=64, time_dim=128).to(DEVICE)
num_params = sum(p.numel() for p in model.parameters())
print(f"✅ Model created with 3 input channels: {num_params:,} parameters")

## 5. Training Setup

In [None]:
# ========== Checkpoint Utilities ==========
def get_latest_checkpoint(checkpoint_dir, prefix='fastddpm_checkpoint'):
    """Find the highest-epoch checkpoint named ``<prefix>_<epoch>.pt`` in a directory.

    Files whose last ``_``-separated stem token is not an integer are skipped.

    Returns:
        ``(path, epoch)``, a ``pathlib.Path`` and its int epoch, for the largest
        epoch found. Returns ``None`` when no file matches.
    """
    from pathlib import Path

    numbered = []
    for candidate in Path(checkpoint_dir).glob(f'{prefix}_*.pt'):
        tail = candidate.stem.rsplit('_', 1)[-1]
        try:
            numbered.append((int(tail), candidate))
        except ValueError:
            continue  # non-numeric suffix, not an epoch checkpoint

    if not numbered:
        return None

    epoch, path = max(numbered, key=lambda pair: pair[0])
    return path, epoch


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device):
    """Restore model and optimizer state from a saved training checkpoint.

    ``scheduler`` is accepted for call-site compatibility but is not used.

    Returns:
        ``(start_epoch, history, best_loss, checkpoint)``:
        - ``start_epoch`` is the saved epoch + 1, or 1 if no epoch is stored.
        - ``history`` defaults to empty epoch/train_loss/val_loss lists.
        - ``best_loss`` defaults to ``inf``.
        - ``checkpoint`` is the raw loaded dict.
    """
    state = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])

    fresh_history = {'epoch': [], 'train_loss': [], 'val_loss': []}
    return (
        state.get('epoch', 0) + 1,
        state.get('history', fresh_history),
        state.get('best_loss', float('inf')),
        state,
    )


def save_checkpoint(model, optimizer, epoch, history, best_loss, checkpoint_path):
    """Write a resumable training checkpoint to ``checkpoint_path``.

    The saved dict holds ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``history``, ``best_loss`` and a local-time
    ISO-8601 ``timestamp``.

    The data is first written to ``<checkpoint_path>.tmp`` and then renamed
    over the target with ``os.replace``. An interrupted save therefore never
    leaves a truncated ``*.pt`` file that would later look like the newest
    checkpoint.
    """
    import os
    from datetime import datetime

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'history': history,
        'best_loss': best_loss,
        'timestamp': datetime.now().isoformat()
    }

    tmp_path = f"{os.fspath(checkpoint_path)}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file, then re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

print("✅ Checkpoint utilities defined")

# ========== Setup training with checkpoint support ==========
# NOTE(review): `optimizer` is not created in any visible cell; it must already
# exist (built over model.parameters()) before this cell runs — confirm.
print("\n🔍 Checking for existing checkpoints...")
latest_ckpt_info = get_latest_checkpoint(CHECKPOINT_DIR, prefix='fastddpm_checkpoint')

if latest_ckpt_info is not None:
    latest_ckpt_path, latest_epoch = latest_ckpt_info
    print(f"📂 Found checkpoint: {latest_ckpt_path.name}")

    # load_checkpoint ignores its scheduler argument. Pass the `scheduler`
    # created earlier rather than `scheduler_device`, which no visible cell defines.
    start_epoch, history, best_loss, loaded_ckpt = load_checkpoint(
        model, optimizer, scheduler, latest_ckpt_path, DEVICE
    )

    print(f"✅ Loaded checkpoint from epoch {latest_epoch}")
    print(f"   Resuming training from epoch {start_epoch}")
    print(f"   Best validation loss so far: {best_loss:.4f}")
    print(f"   Epochs completed: {latest_epoch}\n")
else:
    print("📭 No checkpoint found - starting fresh training\n")
    start_epoch = 1
    history = {'epoch': [], 'train_loss': [], 'val_loss': []}
    best_loss = float('inf')

## 6. Train Model

In [None]:
# ========== Training Loop with Checkpoint Support ==========
# NOTE(review): train_epoch() and validate() are not defined in any visible
# cell; they are assumed to return a float mean loss per epoch — confirm.
print("\n" + "="*60)
print("🚀 Starting Training Loop")
print("="*60 + "\n")

# torch.save and open() below fail if the output directory does not exist yet.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(start_epoch, EPOCHS + 1):
    train_loss = train_epoch()
    val_loss = validate()

    history['epoch'].append(epoch)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    print(f"Epoch {epoch:2d}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f}", end="")

    # Update best loss and save best-model weights only (no optimizer state).
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/fastddpm_best.pt')
        print(" ✅ (best)")
    else:
        print()

    # Full resumable checkpoint after every epoch; the resume cell picks the highest epoch.
    checkpoint_path = f'{CHECKPOINT_DIR}/fastddpm_checkpoint_{epoch}.pt'
    save_checkpoint(model, optimizer, epoch, history, best_loss, checkpoint_path)

# Save final history
with open(f'{CHECKPOINT_DIR}/fastddpm_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"\n✅ Training complete! Best val loss: {best_loss:.4f}")
print(f"📊 Final history saved to {CHECKPOINT_DIR}/fastddpm_history.json")

## 7. Sampling (Reverse Diffusion)

In [None]:
@torch.no_grad()
def sample(pre, post, num_samples=1, noise_scheduler=None):
    """Generate targets by running the reverse diffusion chain.

    Starts from Gaussian noise and visits the scheduler's ``timesteps`` from
    largest to smallest. At each timestep ``t`` the model receives
    ``[pre, post, x_t]`` concatenated on the channel axis and predicts the
    noise. One ancestral DDPM step then moves to the next scheduled timestep
    ``s``; after the last step, ``s`` is treated as clean (alpha_bar_s = 1).

    Args:
        pre, post: conditioning images. Assumed ``(B, 1, H, W)``, so that with
            ``x_t`` the model sees 3 channels — TODO confirm.
        num_samples: number of independent samples drawn per input.
        noise_scheduler: a ``DDPMScheduler``; defaults to the notebook-global
            ``scheduler_device``.

    Returns:
        CPU tensor of shape ``(B, num_samples, 1, H, W)``.
    """
    sched = scheduler_device if noise_scheduler is None else noise_scheduler
    model.eval()
    batch_size = pre.shape[0]
    height, width = pre.shape[-2:]

    # Conditioning inputs never change during sampling: move them once.
    pre = pre.to(DEVICE)
    post = post.to(DEVICE)

    timesteps = sched.timesteps
    alphas_cumprod = sched.alphas_cumprod
    one = torch.tensor(1.0, device=DEVICE, dtype=torch.float32)

    generated = []
    for _ in range(num_samples):
        # Start with random noise for the target
        x_t = torch.randn(batch_size, 1, height, width, device=DEVICE, dtype=torch.float32)

        for i in reversed(range(len(timesteps))):
            t = timesteps[i]
            t_batch = t.unsqueeze(0).expand(batch_size).to(DEVICE)

            # Concatenate: [clean_pre, clean_post, current_noisy_target]
            x_input = torch.cat([pre, post, x_t], dim=1)
            pred_noise = model(x_input, t_batch)

            # alpha_bar at the current timestep t and at the next scheduled one s.
            abar_t = alphas_cumprod[t].to(DEVICE)
            abar_s = alphas_cumprod[timesteps[i - 1]].to(DEVICE) if i > 0 else one

            # Effective alpha/beta of the skipped jump t -> s.
            alpha_step = abar_t / abar_s
            beta_step = 1 - alpha_step

            # Posterior mean of q(x_s | x_t, x0_hat), expressed through eps:
            #   mu = (x_t - beta_step / sqrt(1 - abar_t) * eps) / sqrt(alpha_step)
            mean = (x_t - beta_step / torch.sqrt(1 - abar_t) * pred_noise) / torch.sqrt(alpha_step)

            # Posterior variance: (1 - abar_s) / (1 - abar_t) * beta_step
            posterior_var = torch.clamp((1 - abar_s) / (1 - abar_t) * beta_step, min=1e-20)

            noise = torch.randn_like(x_t) if i > 0 else torch.zeros_like(x_t)
            x_t = mean + torch.sqrt(posterior_var) * noise

        generated.append(x_t.cpu())

    return torch.stack(generated, dim=1)


print("✅ Sampling function ready")

## 8. Evaluation

In [None]:
# Load best model weights; map_location lets a GPU-trained checkpoint load on CPU too.
model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/fastddpm_best.pt', map_location=DEVICE))

ssim_scores = []
psnr_scores = []

for (pre, post), target in tqdm(test_loader, desc="Testing"):
    generated = sample(pre, post, num_samples=3)  # (B, 3, 1, H, W)
    # Average the samples and flatten to (B, H, W). Use reshape, not a bare
    # squeeze(): a final batch of size 1 must keep its batch axis.
    pred = generated.mean(dim=1).reshape(-1, *generated.shape[-2:]).numpy()
    # assumes single-channel targets, i.e. (B, 1, H, W) or (B, H, W) — TODO confirm
    gt = target.reshape(-1, *target.shape[-2:]).numpy()
    if gt.shape != pred.shape:
        raise ValueError(f"target shape {gt.shape} does not match prediction shape {pred.shape}")

    for gt_img, pred_img in zip(gt, pred):
        # Per-image min-max scaling to [0, 1] before scoring. This makes the
        # metrics blind to global intensity offset/scale errors.
        gt_norm = (gt_img - gt_img.min()) / (gt_img.max() - gt_img.min() + 1e-8)
        pred_norm = (pred_img - pred_img.min()) / (pred_img.max() - pred_img.min() + 1e-8)

        ssim_scores.append(ssim(gt_norm, pred_norm, data_range=1.0))
        psnr_scores.append(psnr(gt_norm, pred_norm, data_range=1.0))

print(f"\n{'='*50}")
print(f"SSIM: {np.mean(ssim_scores):.4f} ± {np.std(ssim_scores):.4f}")
print(f"PSNR: {np.mean(psnr_scores):.2f} ± {np.std(psnr_scores):.2f} dB")
print(f"{'='*50}")

## 9. Visualize Training

In [None]:
# Plot per-epoch train/val loss from `history` and save the figure beside the checkpoints.
plt.figure(figsize=(10, 5))
plt.plot(history['epoch'], history['train_loss'], 'o-', label='Train Loss', linewidth=2)
plt.plot(history['epoch'], history['val_loss'], 's-', label='Val Loss', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Fast-DDPM Training History', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f'{CHECKPOINT_DIR}/fastddpm_training.png', dpi=150)
plt.show()

print(f"✅ Plot saved to {CHECKPOINT_DIR}/fastddpm_training.png")