## 1. Setup and Imports

In [None]:
# Install required packages (if needed)
!pip install -q einops pytorch-fid

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import GradScaler, autocast
from torchvision.utils import make_grid, save_image
from PIL import Image

# Set random seeds
def set_seed(seed=42):
    """Seed every RNG the notebook may draw from so Restart & Run All is repeatable.

    Seeds Python's ``random`` module, NumPy's global RNG (used by the synthetic
    data generator), and torch's CPU/CUDA generators. On CUDA it also puts cuDNN
    into deterministic mode and disables its autotuner.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: the top import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # deterministic=True alone is not enough: benchmark mode may still pick
        # different convolution algorithms from run to run.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

set_seed(42)

# Device: use the visible CUDA GPU when there is one, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
if device != 'cpu':
    print(f"GPU: {torch.cuda.get_device_name()}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.1f} GB")

## 2. Configuration

In [None]:
# Quick debug config - use for testing.
# Epoch/step counts are deliberately small so a full run stays short; scale
# them up for a real experiment.
# NOTE(review): `experiment.name`, `experiment.seed`, `data.channels` and
# `diffusion.ema_decay` are not read by any later cell, and
# `unet.attention_resolutions` is forwarded to `UNet`, whose constructor does
# not use it.
CONFIG = dict(
    experiment=dict(
        name='sbldm_kaggle',
        seed=42,
    ),
    data=dict(
        resolution=128,
        channels=1,
    ),
    vae=dict(
        in_channels=1,
        latent_channels=4,
        hidden_dims=[32, 64, 128, 256],
        batch_size=32,
        learning_rate=1e-4,
        epochs=20,  # quick run
        kl_weight=1e-5,
    ),
    unet=dict(
        in_channels=4,  # must match vae.latent_channels: the UNet runs on VAE latents
        out_channels=4,
        model_channels=64,
        channel_mult=[1, 2, 4],
        num_res_blocks=2,
        attention_resolutions=[8],
        dropout=0.1,
    ),
    diffusion=dict(
        timesteps=1000,
        beta_schedule='cosine',  # 'linear' | 'cosine' | 'gamma_rebalanced'
        beta_start=1e-4,  # linear / gamma_rebalanced schedules only
        beta_end=0.02,  # linear / gamma_rebalanced schedules only
        gamma=1.0,  # exponent of the 'gamma_rebalanced' ramp
        use_freq_loss=True,
        freq_loss_weight=0.1,
        batch_size=16,
        learning_rate=2e-4,
        training_steps=5000,  # quick run
        ema_decay=0.9999,
    ),
    sampling=dict(
        ddim_steps=50,
        eta=0.0,  # 0 = deterministic DDIM
    ),
)

print("Configuration loaded!")

## 3. Model Definitions

In [None]:
# ============================================================================
# VAE Components
# ============================================================================

class ResidualBlock(nn.Module):
    """Pre-activation residual block used by the VAE.

    Computes ``conv2(act(norm2(conv1(act(norm1(x)))))) + skip(x)``. ``skip`` is a
    1x1 conv when the channel count changes and the identity otherwise.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        groups: Upper bound on the number of GroupNorm groups. A layer with C
            channels uses the largest divisor of C that is <= ``groups``.
            This equals ``min(groups, C)`` whenever that value divides C.
            Otherwise it avoids GroupNorm's construction error (e.g. C=12 with
            groups=8 uses 6 groups).
    """

    def __init__(self, in_channels, out_channels, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(self._num_groups(in_channels, groups), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(self._num_groups(out_channels, groups), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.act = nn.SiLU()

    @staticmethod
    def _num_groups(channels, groups):
        """Return the largest divisor of ``channels`` not exceeding ``groups`` (GroupNorm needs an exact split)."""
        for candidate in range(min(groups, channels), 0, -1):
            if channels % candidate == 0:
                return candidate
        return 1

    def forward(self, x):
        h = self.conv1(self.act(self.norm1(x)))
        h = self.conv2(self.act(self.norm2(h)))
        return h + self.skip(x)

class AttentionBlock(nn.Module):
    """Multi-head self-attention over all spatial positions, with a residual connection.

    Input and output are (B, C, H, W). Each of the H*W positions attends to every
    other position, and the projected attention output is added back to ``x``.

    Args:
        channels: Feature channels C. Must be divisible by ``num_heads``, and
            by 8 for the fixed-size GroupNorm.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(f"channels ({channels}) must be divisible by num_heads ({num_heads})")
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        # (B, 3C, H, W) -> (B, 3, heads, head_dim, N): q, k and v are consecutive C-channel slabs.
        qkv = self.qkv(h).reshape(B, 3, self.num_heads, self.head_dim, H * W)
        q, k, v = qkv.unbind(dim=1)
        attn = torch.einsum('bhdn,bhdm->bhnm', q, k) * self.scale
        attn = F.softmax(attn, dim=-1)
        # einsum does not guarantee a contiguous result, so use reshape (not view).
        out = torch.einsum('bhnm,bhdm->bhdn', attn, v).reshape(B, C, H, W)
        return x + self.proj(out)

class VAE(nn.Module):
    """Convolutional VAE mapping (B, in_channels, H, W) images to a latent grid.

    The encoder downsamples with a stride-2 conv after every level except the
    last, for an overall factor of ``2 ** (len(hidden_dims) - 1)``. It outputs
    ``2 * latent_channels`` maps, which are split into the posterior mean and
    log-variance. The decoder mirrors it with 4x4 stride-2 transposed convs and
    ends in a plain conv (no output activation).

    Args:
        in_channels: Image channels.
        latent_channels: Channels of the latent grid.
        hidden_dims: Channel width of each level, from highest to lowest resolution.
    """

    # Bounds applied to the log-variance in `encode` (same range as the reference
    # LDM DiagonalGaussianDistribution). They keep exp(logvar) finite and non-zero.
    LOGVAR_MIN = -30.0
    LOGVAR_MAX = 20.0

    def __init__(self, in_channels=1, latent_channels=4, hidden_dims=(32, 64, 128, 256)):
        super().__init__()
        self.latent_channels = latent_channels
        self.downsample_factor = 2 ** (len(hidden_dims) - 1)

        # Encoder: stem conv, two ResidualBlocks per level (stride-2 conv between
        # levels), then a ResBlock/attention/ResBlock bottleneck projecting to
        # mean and logvar.
        encoder_layers = [nn.Conv2d(in_channels, hidden_dims[0], 3, padding=1)]
        in_dim = hidden_dims[0]
        for i, out_dim in enumerate(hidden_dims):
            encoder_layers.extend([
                ResidualBlock(in_dim, out_dim),
                ResidualBlock(out_dim, out_dim),
            ])
            if i < len(hidden_dims) - 1:
                encoder_layers.append(nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1))
            in_dim = out_dim

        encoder_layers.extend([
            ResidualBlock(hidden_dims[-1], hidden_dims[-1]),
            AttentionBlock(hidden_dims[-1]),
            ResidualBlock(hidden_dims[-1], hidden_dims[-1]),
            nn.GroupNorm(8, hidden_dims[-1]),
            nn.SiLU(),
            nn.Conv2d(hidden_dims[-1], latent_channels * 2, 3, padding=1)
        ])
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: the mirror image of the encoder.
        decoder_dims = list(reversed(hidden_dims))
        decoder_layers = [
            nn.Conv2d(latent_channels, decoder_dims[0], 3, padding=1),
            ResidualBlock(decoder_dims[0], decoder_dims[0]),
            AttentionBlock(decoder_dims[0]),
            ResidualBlock(decoder_dims[0], decoder_dims[0]),
        ]

        in_dim = decoder_dims[0]
        for i, out_dim in enumerate(decoder_dims):
            decoder_layers.extend([
                ResidualBlock(in_dim, out_dim),
                ResidualBlock(out_dim, out_dim),
            ])
            if i < len(decoder_dims) - 1:
                decoder_layers.append(nn.ConvTranspose2d(out_dim, out_dim, 4, stride=2, padding=1))
            in_dim = out_dim

        decoder_layers.extend([
            nn.GroupNorm(8, decoder_dims[-1]),
            nn.SiLU(),
            nn.Conv2d(decoder_dims[-1], in_channels, 3, padding=1)
        ])
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return the posterior ``(mean, logvar)``, each with ``latent_channels`` maps."""
        h = self.encoder(x)
        mean, logvar = torch.chunk(h, 2, dim=1)
        logvar = torch.clamp(logvar, self.LOGVAR_MIN, self.LOGVAR_MAX)
        return mean, logvar

    def decode(self, z):
        """Map a latent grid back to image space (output is unbounded)."""
        return self.decoder(z)

    @staticmethod
    def _sample(mean, logvar):
        """Draw ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + std * eps

    def reparameterize(self, mean, logvar):
        """Sample the posterior in training mode; return the mean in eval mode."""
        if self.training:
            return self._sample(mean, logvar)
        return mean

    def forward(self, x, sample_posterior=True):
        """Encode, (optionally) reparameterize, and decode. Returns ``(recon, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar) if sample_posterior else mean
        recon = self.decode(z)
        return recon, mean, logvar

    def get_latent(self, x, sample=False):
        """Encode ``x`` to a latent: a posterior sample if ``sample`` else the mean.

        An explicit ``sample=True`` is honoured in eval mode too. ``reparameterize``
        would instead silently fall back to the mean there.
        """
        mean, logvar = self.encode(x)
        return self._sample(mean, logvar) if sample else mean

    @staticmethod
    def kl_divergence(mean, logvar):
        """KL(q(z|x) || N(0, I)), averaged (not summed) over every latent element."""
        return -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

    def loss_function(self, x, recon, mean, logvar, kl_weight=1e-5):
        """Per-element MSE reconstruction loss plus ``kl_weight`` * KL.

        Returns:
            ``(loss_tensor, {'loss', 'recon_loss', 'kl_loss'} as Python floats)``.
        """
        recon_loss = F.mse_loss(recon, x)
        kl_loss = self.kl_divergence(mean, logvar)
        loss = recon_loss + kl_weight * kl_loss
        return loss, {'loss': loss.item(), 'recon_loss': recon_loss.item(), 'kl_loss': kl_loss.item()}

print("VAE defined!")

In [None]:
# ============================================================================
# UNet Score Model
# ============================================================================

import math

def get_timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embedding.

    Args:
        timesteps: 1-D tensor of timesteps, shape (B,).
        dim: Output embedding width.
        max_period: Sets the lowest frequency, ``1 / max_period``.

    Returns:
        Float tensor of shape (B, dim) holding ``[cos(t * f_i), sin(t * f_i)]`` for
        ``dim // 2`` geometrically spaced frequencies ``f_i``. When ``dim`` is odd,
        one zero column is appended so the width is exactly ``dim``.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0, half, device=timesteps.device, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Without padding an odd width comes back one column short and breaks the
        # nn.Linear(dim, ...) that consumes it.
        emb = F.pad(emb, (0, 1))
    return emb

class TimeResBlock(nn.Module):
    """Residual block whose second normalization is modulated by a time embedding.

    ``t`` is projected to a per-channel (scale, shift) pair and applied as
    ``norm2(h) * (1 + scale) + shift`` before the second conv. A 1x1 conv skip
    is used when the channel count changes.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        time_dim: Width of the time embedding passed to ``forward``.
        dropout: Dropout probability before the second conv.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        # Emits 2 * out_ch values per sample: scale, then shift.
        self.time_proj = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch * 2))
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x, t):
        out = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the per-channel modulation over H and W.
        scale, shift = self.time_proj(t)[:, :, None, None].chunk(2, dim=1)
        out = self.norm2(out) * (1 + scale) + shift
        out = self.conv2(self.dropout(F.silu(out)))
        return out + self.skip(x)

class UNet(nn.Module):
    """Time-conditioned UNet that predicts the noise added to a latent.

    Layout, in order:
      * ``conv_in``.
      * ``len(ch_mult)`` encoder levels with ``num_res`` TimeResBlocks each; all
        but the last level end in a stride-2 conv.
      * A middle stage: TimeResBlock, self-attention, TimeResBlock.
      * Mirrored decoder levels with ``num_res + 1`` TimeResBlocks each, where
        every block concatenates one stored encoder skip.
      * GroupNorm/SiLU and a zero-initialised ``conv_out``, so the untrained
        model predicts all-zero noise.

    Args:
        in_ch: Input channels (latent channels).
        out_ch: Output channels (latent channels).
        model_ch: Base width. Level ``i`` uses ``model_ch * ch_mult[i]``
            channels. Also the width of the sinusoidal timestep embedding.
        ch_mult: Per-level channel multipliers.
        num_res: TimeResBlocks per encoder level (the decoder uses ``num_res + 1``).
        attn_res: Accepted for config compatibility but unused. Self-attention
            is applied only in the middle stage.
        dropout: Dropout inside every TimeResBlock.

    NOTE(review): the spatial size presumably must be divisible by
    ``2 ** (len(ch_mult) - 1)`` so encoder and decoder feature maps line up for
    concatenation.
    """

    def __init__(self, in_ch=4, out_ch=4, model_ch=64, ch_mult=(1, 2, 4), num_res=2, attn_res=(8,), dropout=0.0):
        super().__init__()
        # forward() must run the same number of blocks per level as built here.
        self.num_res = num_res
        time_dim = model_ch * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_ch, time_dim), nn.SiLU(), nn.Linear(time_dim, time_dim)
        )
        self.conv_in = nn.Conv2d(in_ch, model_ch, 3, padding=1)

        # Encoder. skip_channels records the channel count of every tensor that
        # forward() pushes onto its skip stack, in push order.
        self.down_blocks = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        ch = model_ch
        self.skip_channels = [ch]

        for i, mult in enumerate(ch_mult):
            out_ch_level = model_ch * mult
            for _ in range(num_res):
                self.down_blocks.append(TimeResBlock(ch, out_ch_level, time_dim, dropout))
                ch = out_ch_level
                self.skip_channels.append(ch)
            if i < len(ch_mult) - 1:
                self.down_samples.append(nn.Conv2d(ch, ch, 3, stride=2, padding=1))
                self.skip_channels.append(ch)
            else:
                self.down_samples.append(None)

        # Middle
        self.mid_block1 = TimeResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = AttentionBlock(ch)
        self.mid_block2 = TimeResBlock(ch, ch, time_dim, dropout)

        # Decoder: skips are consumed last-in-first-out, num_res + 1 per level.
        self.up_blocks = nn.ModuleList()
        self.up_samples = nn.ModuleList()

        for i, mult in enumerate(reversed(ch_mult)):
            out_ch_level = model_ch * mult
            for _ in range(num_res + 1):
                skip_ch = self.skip_channels.pop()
                self.up_blocks.append(TimeResBlock(ch + skip_ch, out_ch_level, time_dim, dropout))
                ch = out_ch_level
            if i < len(ch_mult) - 1:
                self.up_samples.append(nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='nearest'),
                    nn.Conv2d(ch, ch, 3, padding=1)
                ))
            else:
                self.up_samples.append(None)

        self.norm_out = nn.GroupNorm(8, ch)
        self.conv_out = nn.Conv2d(ch, out_ch, 3, padding=1)
        nn.init.zeros_(self.conv_out.weight)
        nn.init.zeros_(self.conv_out.bias)

        self.model_ch = model_ch

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t`` (1-D tensor, one entry per sample)."""
        t_emb = get_timestep_embedding(t, self.model_ch)
        t_emb = self.time_embed(t_emb)

        h = self.conv_in(x)
        skips = [h]

        block_idx = 0
        for ds in self.down_samples:
            for _ in range(self.num_res):
                h = self.down_blocks[block_idx](h, t_emb)
                skips.append(h)
                block_idx += 1
            if ds is not None:
                h = ds(h)
                skips.append(h)

        h = self.mid_block1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, t_emb)

        block_idx = 0
        for us in self.up_samples:
            for _ in range(self.num_res + 1):
                h = torch.cat([h, skips.pop()], dim=1)
                h = self.up_blocks[block_idx](h, t_emb)
                block_idx += 1
            if us is not None:
                h = us(h)

        h = F.silu(self.norm_out(h))
        return self.conv_out(h)

print("UNet defined!")

In [None]:
# ============================================================================
# Diffusion Utilities
# ============================================================================

class NoiseSchedule:
    """Precomputed DDPM forward-process coefficients for a chosen beta schedule.

    Supported ``schedule_type`` values:
      * ``'linear'``: betas evenly spaced from ``beta_start`` to ``beta_end``.
      * ``'cosine'``: cosine alpha-bar schedule with offset s = 0.008; betas
        clipped to [1e-4, 0.9999].
      * ``'gamma_rebalanced'``: linear between ``beta_start`` and ``beta_end``,
        but along the ramp ``u ** gamma`` for u in [0, 1].

    All coefficient tensors are moved to ``device``.
    """

    def __init__(self, schedule_type='cosine', timesteps=1000, beta_start=1e-4, beta_end=0.02, gamma=1.0, device='cuda'):
        self.timesteps = timesteps
        self.device = device

        betas = self._make_betas(schedule_type, timesteps, beta_start, beta_end, gamma)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        self.betas = betas.to(device)
        self.alphas = alphas.to(device)
        self.alphas_cumprod = alpha_bar.to(device)
        self.sqrt_alphas_cumprod = alpha_bar.sqrt().to(device)
        self.sqrt_one_minus_alphas_cumprod = (1 - alpha_bar).sqrt().to(device)

    @staticmethod
    def _make_betas(schedule_type, timesteps, beta_start, beta_end, gamma):
        """Build the 1-D beta tensor (length ``timesteps``) for ``schedule_type``."""
        if schedule_type == 'linear':
            return torch.linspace(beta_start, beta_end, timesteps)
        if schedule_type == 'cosine':
            offset = 0.008
            grid = torch.linspace(0, timesteps, timesteps + 1)
            f = torch.cos(((grid / timesteps) + offset) / (1 + offset) * np.pi * 0.5) ** 2
            f = f / f[0]
            return torch.clip(1 - (f[1:] / f[:-1]), 0.0001, 0.9999)
        if schedule_type == 'gamma_rebalanced':
            ramp = torch.linspace(0, 1, timesteps) ** gamma
            return beta_start + (beta_end - beta_start) * ramp
        raise ValueError(f"Unknown schedule: {schedule_type}")

    def q_sample(self, x_start, t, noise=None):
        """Draw x_t ~ q(x_t | x_0); returns ``(x_t, noise)``. ``t`` is a 1-D index tensor."""
        noise = torch.randn_like(x_start) if noise is None else noise
        signal_coef = self.sqrt_alphas_cumprod[t][:, None, None, None]
        noise_coef = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
        noisy = signal_coef * x_start + noise_coef * noise
        return noisy, noise

class FrequencyAwareLoss:
    """MSE loss plus a weighted MSE between 2-D FFT magnitudes.

    ``total = mse(pred, target) + freq_weight * mse(|FFT2(pred)|, |FFT2(target)|)``.
    The FFT is taken over the last two dimensions.

    Args:
        freq_weight: Weight of the spectral term (0 leaves a plain MSE objective).
        fft_norm: ``norm`` passed to ``torch.fft.fft2``. The default
            ``"backward"`` leaves the forward transform unscaled, so the spectral
            term grows with H*W. ``"ortho"`` makes it resolution-independent.
    """

    def __init__(self, freq_weight=0.1, fft_norm="backward"):
        self.freq_weight = freq_weight
        self.fft_norm = fft_norm

    def __call__(self, pred, target):
        """Return ``(total_loss_tensor, {'loss', 'spatial', 'freq'} as Python floats)``."""
        spatial_loss = F.mse_loss(pred, target)
        # torch.fft is not covered by autocast. Under mixed precision `pred`
        # arrives as fp16 and would run a half-precision FFT, which only supports
        # power-of-two sizes. Compute the spectral term in fp32.
        pred_fft = torch.abs(torch.fft.fft2(pred.float(), norm=self.fft_norm))
        target_fft = torch.abs(torch.fft.fft2(target.float(), norm=self.fft_norm))
        freq_loss = F.mse_loss(pred_fft, target_fft)
        total = spatial_loss + self.freq_weight * freq_loss
        return total, {'loss': total.item(), 'spatial': spatial_loss.item(), 'freq': freq_loss.item()}

class DDIMSampler:
    """DDIM sampler for an epsilon-prediction model (deterministic when eta == 0).

    Args:
        schedule: Object exposing ``timesteps`` (int) and ``alphas_cumprod``
            (1-D tensor of length ``timesteps``), e.g. ``NoiseSchedule``.
        model: Module called as ``model(x_t, t_batch)`` that returns predicted noise.
        eta: DDIM stochasticity. 0 gives the deterministic sampler.
    """

    def __init__(self, schedule, model, eta=0.0):
        self.schedule = schedule
        self.model = model
        self.eta = eta

    @torch.no_grad()
    def sample(self, shape, num_steps=50, device='cuda', progress=True, clip_denoised=None):
        """Run DDIM from Gaussian noise and return the final x_0 estimate.

        Visits timesteps ``T-1, T-1-k, T-1-2k, ...`` with ``k = max(1, T // num_steps)``.
        The last update jumps to alpha_bar = 1, so the result is the model's
        clean estimate rather than a sample that still carries step-0 noise.

        Args:
            shape: Output shape, e.g. (B, C, H, W).
            num_steps: Requested number of denoising steps.
            device: Device for the initial noise and the timestep tensors.
            progress: Show a tqdm progress bar.
            clip_denoised: If a number ``c``, clamp every x_0 prediction to
                ``[-c, c]``. The default ``None`` disables clamping because VAE
                latents are not confined to [-1, 1]. Pass 1.0 for pixel-space
                models trained on [-1, 1] data.

        Note: puts ``self.model`` in eval mode and leaves it there.
        """
        self.model.eval()
        x = torch.randn(shape, device=device)

        alphas_cumprod = self.schedule.alphas_cumprod.to(device)
        final_alpha = alphas_cumprod.new_tensor(1.0)
        # Guard against num_steps > timesteps, which would make the stride 0.
        step_size = max(1, self.schedule.timesteps // num_steps)
        timesteps = list(range(self.schedule.timesteps - 1, -1, -step_size))

        iterator = tqdm(timesteps, desc='DDIM') if progress else timesteps

        for t in iterator:
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            noise_pred = self.model(x, t_batch)

            prev_t = t - step_size
            alpha_t = alphas_cumprod[t]
            alpha_prev = alphas_cumprod[prev_t] if prev_t >= 0 else final_alpha

            x0_pred = (x - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
            if clip_denoised is not None:
                x0_pred = torch.clamp(x0_pred, -clip_denoised, clip_denoised)

            sigma = self.eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))
            # clamp(min=0) absorbs tiny negative values from floating-point rounding.
            dir_xt = torch.sqrt((1 - alpha_prev - sigma ** 2).clamp(min=0)) * noise_pred
            noise = torch.randn_like(x) if self.eta > 0 and prev_t >= 0 else 0

            x = torch.sqrt(alpha_prev) * x0_pred + dir_xt + sigma * noise

        return x

print("Diffusion utilities defined!")

## 4. Create Synthetic Dataset (replace with a real-data loader for actual experiments)

In [None]:
def create_synthetic_brain_data(num_samples=500, size=128):
    """Create synthetic brain-MRI-like images for testing.

    Every slice shares the same geometry: a circular outer disc (level ~0.3)
    with a brighter inner disc of 90% radius (level ~0.6). Each level is
    jittered by U(-0.1, 0.1) per slice. Two random elliptical "ventricles"
    (value 0.2) are painted inside the inner disc. Gaussian noise (std 0.05) is
    then added, values are clipped to [0, 1], and the slice is blurred with
    sigma = 1.

    All randomness comes from NumPy's global RNG, so output is reproducible
    after ``np.random.seed``.

    Returns:
        Float tensor of shape (num_samples, 1, size, size) mapped to [-1, 1].
    """
    from scipy.ndimage import gaussian_filter

    # The head outline is identical for every slice, so build its masks once.
    y, x = np.ogrid[:size, :size]
    center = size // 2
    radius = size // 2.5
    outer = ((x - center) / radius) ** 2 + ((y - center) / radius) ** 2 <= 1
    inner = ((x - center) / (radius * 0.9)) ** 2 + ((y - center) / (radius * 0.9)) ** 2 <= 1

    slices = []
    for _ in tqdm(range(num_samples), desc="Generating synthetic data"):
        img = np.zeros((size, size), dtype=np.float32)
        # The draw order below (outer level, inner level, ventricles, noise)
        # fixes the sequence of RNG draws for a given seed.
        img[outer] = 0.3 + np.random.uniform(-0.1, 0.1)
        img[inner] = 0.6 + np.random.uniform(-0.1, 0.1)

        for _ in range(2):
            vent_x = center + np.random.randint(-size // 6, size // 6)
            vent_y = center + np.random.randint(-size // 8, size // 8)
            rad_x, rad_y = np.random.randint(5, 15), np.random.randint(5, 15)
            ventricle = ((x - vent_x) / rad_x) ** 2 + ((y - vent_y) / rad_y) ** 2 <= 1
            img[ventricle & inner] = 0.2

        img += np.random.normal(0, 0.05, (size, size))
        slices.append(gaussian_filter(np.clip(img, 0, 1), sigma=1))

    stacked = np.array(slices)[:, None, :, :]  # [N, 1, H, W]
    return torch.from_numpy(stacked).float() * 2 - 1

# Create dataset: 400 training and 100 validation synthetic slices.
print("Creating synthetic dataset...")
train_data, val_data = (
    create_synthetic_brain_data(num_samples=count, size=CONFIG['data']['resolution'])
    for count in (400, 100)
)

print(f"Train data: {train_data.shape}")
print(f"Val data: {val_data.shape}")

# Preview the first 16 training slices, mapped from [-1, 1] back to [0, 1].
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for idx, ax in enumerate(axes.flat):
    preview = (train_data[idx, 0].numpy() + 1) / 2
    ax.imshow(preview, cmap='gray')
    ax.axis('off')
plt.suptitle('Synthetic Brain MRI Slices')
plt.tight_layout()
plt.show()

## 5. Train VAE

In [None]:
# Create VAE
vae_cfg = CONFIG['vae']
vae = VAE(
    in_channels=vae_cfg['in_channels'],
    latent_channels=vae_cfg['latent_channels'],
    hidden_dims=vae_cfg['hidden_dims'],
).to(device)

num_params = sum(param.numel() for param in vae.parameters())
print(f"VAE parameters: {num_params:,}")
print(f"Downsample factor: {vae.downsample_factor}x")

# Training batches are reshuffled every epoch, and a trailing partial batch is dropped.
train_loader = DataLoader(TensorDataset(train_data), batch_size=vae_cfg['batch_size'], shuffle=True, drop_last=True)
val_loader = DataLoader(TensorDataset(val_data), batch_size=vae_cfg['batch_size'])

# AdamW with a cosine LR schedule whose period is the number of epochs.
optimizer = optim.AdamW(vae.parameters(), lr=vae_cfg['learning_rate'])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=vae_cfg['epochs'])
scaler = GradScaler()

In [None]:
# Train VAE
print("Training VAE...")
vae_history = {'train_loss': [], 'val_loss': []}
n_epochs = CONFIG['vae']['epochs']
kl_weight = CONFIG['vae']['kl_weight']

for epoch in range(n_epochs):
    # ---- training pass: posterior sampling on, mixed precision ----
    vae.train()
    epoch_total = 0
    for (batch,) in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{n_epochs}", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()
        with autocast():
            recon, mean, logvar = vae(batch)
            loss, _ = vae.loss_function(batch, recon, mean, logvar, kl_weight)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_total += loss.item()
    scheduler.step()  # the LR schedule advances once per epoch
    train_loss = epoch_total / len(train_loader)

    # ---- validation pass: posterior mean, no autocast ----
    vae.eval()
    epoch_total = 0
    with torch.no_grad():
        for (batch,) in val_loader:
            batch = batch.to(device)
            recon, mean, logvar = vae(batch, sample_posterior=False)
            loss, _ = vae.loss_function(batch, recon, mean, logvar, kl_weight)
            epoch_total += loss.item()
    val_loss = epoch_total / len(val_loader)

    vae_history['train_loss'].append(train_loss)
    vae_history['val_loss'].append(val_loss)
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}: train={train_loss:.4f}, val={val_loss:.4f}")

print("VAE training complete!")

In [None]:
# Visualize VAE reconstructions: originals on the top row, reconstructions below.
vae.eval()
with torch.no_grad():
    test_batch = val_data[:8].to(device)
    recon, _, _ = vae(test_batch, sample_posterior=False)

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for row, (images, label) in enumerate([(test_batch, 'Original'), (recon, 'Recon')]):
    for col in range(8):
        ax = axes[row, col]
        ax.imshow((images[col, 0].cpu().numpy() + 1) / 2, cmap='gray')
        ax.axis('off')
        ax.set_title(label if col == 0 else '')

plt.tight_layout()
plt.show()

# Plot loss
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(vae_history['train_loss'], label='Train')
ax.plot(vae_history['val_loss'], label='Val')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('VAE Training Loss')
plt.show()

## 6. Train Diffusion Model

In [None]:
# Encode dataset to latents (posterior means).
# Use an ordered loader that keeps the final partial batch. The VAE's
# `train_loader` shuffles and has drop_last=True, so it would silently drop
# len(train_data) % batch_size images (16 of 400 with the default config).
print("Encoding dataset to latents...")
vae.eval()

encode_loader = DataLoader(TensorDataset(train_data), batch_size=CONFIG['vae']['batch_size'], shuffle=False)

with torch.no_grad():
    train_latents = torch.cat([
        vae.get_latent(batch.to(device), sample=False).cpu()
        for (batch,) in tqdm(encode_loader, desc="Encoding train")
    ], dim=0)

    val_latents = torch.cat([
        vae.get_latent(batch.to(device), sample=False).cpu()
        for (batch,) in val_loader
    ], dim=0)

assert len(train_latents) == len(train_data), "every training image needs a latent"
assert len(val_latents) == len(val_data), "every validation image needs a latent"

# NOTE(review): the latents are used unscaled. The reference LDM rescales them to
# roughly unit variance (scale_factor = 1 / std) before diffusion and undoes that
# before decoding. Consider doing the same here.
print(f"Train latents: {train_latents.shape}")
print(f"Val latents: {val_latents.shape}")

In [None]:
# Create UNet and diffusion components
# NOTE(review): `attention_resolutions` is passed through, but UNet only applies
# attention in its middle block.
unet = UNet(
    in_ch=CONFIG['unet']['in_channels'],
    out_ch=CONFIG['unet']['out_channels'],
    model_ch=CONFIG['unet']['model_channels'],
    ch_mult=CONFIG['unet']['channel_mult'],
    num_res=CONFIG['unet']['num_res_blocks'],
    attn_res=CONFIG['unet']['attention_resolutions'],
    dropout=CONFIG['unet']['dropout']
).to(device)

num_params = sum(p.numel() for p in unet.parameters())
print(f"UNet parameters: {num_params:,}")

# Noise schedule
schedule = NoiseSchedule(
    schedule_type=CONFIG['diffusion']['beta_schedule'],
    timesteps=CONFIG['diffusion']['timesteps'],
    beta_start=CONFIG['diffusion']['beta_start'],
    beta_end=CONFIG['diffusion']['beta_end'],
    gamma=CONFIG['diffusion']['gamma'],
    device=device
)

# Loss and optimizer. Honour `use_freq_loss`: when it is False the spectral term
# gets zero weight, leaving a plain epsilon-MSE objective.
freq_loss_weight = (
    CONFIG['diffusion']['freq_loss_weight'] if CONFIG['diffusion']['use_freq_loss'] else 0.0
)
loss_fn = FrequencyAwareLoss(freq_weight=freq_loss_weight)
optimizer = optim.AdamW(unet.parameters(), lr=CONFIG['diffusion']['learning_rate'])
scaler = GradScaler()
# NOTE(review): CONFIG['diffusion']['ema_decay'] is defined, but no EMA copy of
# the UNet is maintained or used for sampling.

# Dataloader
latent_loader = DataLoader(
    TensorDataset(train_latents),
    batch_size=CONFIG['diffusion']['batch_size'],
    shuffle=True,
    drop_last=True
)

In [None]:
# Train diffusion
print("Training diffusion model...")
diff_history = {'loss': []}
global_step = 0
total_steps = CONFIG['diffusion']['training_steps']

latent_iter = iter(latent_loader)
pbar = tqdm(range(total_steps), desc="Training")

for step in pbar:
    # Cycle through the latents indefinitely; each fresh iterator reshuffles.
    try:
        (z,) = next(latent_iter)
    except StopIteration:
        latent_iter = iter(latent_loader)
        (z,) = next(latent_iter)
    z = z.to(device)

    # Forward diffusion: corrupt each latent at a uniformly random timestep.
    t = torch.randint(0, schedule.timesteps, (z.size(0),), device=device)
    noise = torch.randn_like(z)
    z_noisy, noise_target = schedule.q_sample(z, t, noise)

    # Mixed-precision update. Gradients are unscaled before norm clipping.
    optimizer.zero_grad()
    with autocast():
        noise_pred = unet(z_noisy, t)
        loss, loss_dict = loss_fn(noise_pred, noise_target)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

    step_loss = loss_dict['loss']
    diff_history['loss'].append(step_loss)
    pbar.set_postfix({'loss': f"{step_loss:.4f}"})
    global_step += 1

print("Diffusion training complete!")

In [None]:
# Plot the per-step diffusion training loss.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(diff_history['loss'])
ax.set(xlabel='Step', ylabel='Loss', title='Diffusion Training Loss')
plt.show()

## 7. Generate Samples

In [None]:
# Generate samples with DDIM
print("Generating samples...")
sampler = DDIMSampler(schedule, unet, eta=CONFIG['sampling']['eta'])

unet.eval()
vae.eval()

# Sample on the VAE latent grid: each side is resolution / downsample factor.
latent_size = CONFIG['data']['resolution'] // vae.downsample_factor
sample_shape = (16, CONFIG['unet']['in_channels'], latent_size, latent_size)

with torch.no_grad():
    z_samples = sampler.sample(sample_shape, num_steps=CONFIG['sampling']['ddim_steps'], device=device)
    # Decode to image space and map [-1, 1] to [0, 1] for display.
    generated = ((vae.decode(z_samples) + 1) / 2).clamp(0, 1)

print(f"Generated {len(generated)} samples")

In [None]:
# Visualize generated samples (the first 16, in a 2 x 8 grid).
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated[i, 0].cpu().numpy(), cmap='gray')
    ax.axis('off')
# Read the step count from CONFIG so the title stays correct when ddim_steps changes.
plt.suptitle(f"Generated Brain MRI Slices (DDIM {CONFIG['sampling']['ddim_steps']} steps)")
plt.tight_layout()
plt.show()

## 8. Evaluate Metrics

In [None]:
# Compute SSIM and PSNR for VAE reconstructions
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

vae.eval()
ssim_scores = []
psnr_scores = []

with torch.no_grad():
    for (batch,) in val_loader:
        recon, _, _ = vae(batch.to(device), sample_posterior=False)

        # Both metrics use data_range=1.0, so compare on a [0, 1] scale. The
        # decoder output is unbounded, so clip the reconstruction.
        originals = (batch.cpu().numpy() + 1) / 2
        recons = np.clip((recon.cpu().numpy() + 1) / 2, 0, 1)

        for orig_img, recon_img in zip(originals[:, 0], recons[:, 0]):
            ssim_val = ssim(orig_img, recon_img, data_range=1.0)
            psnr_val = psnr(orig_img, recon_img, data_range=1.0)
            ssim_scores.append(ssim_val)
            psnr_scores.append(psnr_val)

print(f"\nVAE Reconstruction Metrics:")
print(f"  SSIM: {np.mean(ssim_scores):.4f} ± {np.std(ssim_scores):.4f}")
print(f"  PSNR: {np.mean(psnr_scores):.2f} ± {np.std(psnr_scores):.2f} dB")

In [None]:
# Reconstruction error heatmap for the first validation slice
vae.eval()
with torch.no_grad():
    sample = val_data[0:1].to(device)
    recon, _, _ = vae(sample, sample_posterior=False)

# Absolute error in the model's [-1, 1] intensity units (twice the [0, 1] display scale).
error = torch.abs(sample - recon).squeeze().cpu().numpy()

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ((sample[0, 0].cpu().numpy() + 1) / 2, 'Original', 'gray'),
    ((recon[0, 0].cpu().numpy() + 1) / 2, 'Reconstructed', 'gray'),
    (error, 'Error Heatmap', 'hot'),
]
for ax, (image, title, cmap) in zip(axes, panels):
    im = ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
plt.colorbar(im, ax=axes[2], fraction=0.046)

plt.tight_layout()
plt.show()

## 9. Summary and Next Steps

In [None]:
# Collect the run summary first, then print it line by line.
divider = "=" * 60
vae_param_count = sum(p.numel() for p in vae.parameters())
unet_param_count = sum(p.numel() for p in unet.parameters())

summary_lines = [
    divider,
    "SBLDM Training Summary",
    divider,
    "\nVAE:",
    f"  Parameters: {vae_param_count:,}",
    f"  Final train loss: {vae_history['train_loss'][-1]:.4f}",
    f"  Final val loss: {vae_history['val_loss'][-1]:.4f}",
    f"  SSIM: {np.mean(ssim_scores):.4f}",
    f"  PSNR: {np.mean(psnr_scores):.2f} dB",
    "\nDiffusion UNet:",
    f"  Parameters: {unet_param_count:,}",
    f"  Training steps: {global_step}",
    f"  Final loss: {diff_history['loss'][-1]:.4f}",
    "\nGeneration:",
    f"  DDIM steps: {CONFIG['sampling']['ddim_steps']}",
    f"  Samples generated: {len(generated)}",
    "\n" + divider,
    "Next Steps:",
    "- Train longer (100+ VAE epochs, 50K+ diffusion steps)",
    "- Use real BraTS/medical imaging data",
    "- Compute FID with more samples",
    "- Experiment with gamma-rebalanced schedule",
    "- Test adaptive sampling",
    divider,
]
for line in summary_lines:
    print(line)

In [None]:
# Save models (optional). Off by default so a Restart & Run All writes no files;
# set SAVE_MODELS = True to persist the trained weights in the working directory.
SAVE_MODELS = False
if SAVE_MODELS:
    torch.save(vae.state_dict(), 'vae_weights.pt')
    torch.save(unet.state_dict(), 'unet_weights.pt')
print("Done!")