In [None]:
# 🔧 Setup: Run this cell first!
# Detects the compute device, reports library versions, and seeds every RNG
# the notebook draws from (Python `random`, NumPy, PyTorch CPU and CUDA).
# NOTE(review): bit-exact GPU reproducibility would additionally need
# torch.backends.cudnn.deterministic = True, which is not set here.

import random
import sys

import numpy as np
import torch

SEED = 42
has_gpu = torch.cuda.is_available()

# Report which device will be used
if has_gpu:
    device = torch.device('cuda')
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU available: {gpu_name}")
    print(f"   Memory: {gpu_memory_gb:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

python_version = sys.version.split()[0]
print(f"\n📦 Python {python_version}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed each RNG source for reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if has_gpu:
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# RadiSynth AI: Synthetic Brain MRI Generation with Denoising Score Matching

## Case Study Implementation Notebook

This notebook implements the core technical components of RadiSynth AI's synthetic medical image generation pipeline using Denoising Score Matching (DSM).

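We use a simulated brain MRI dataset (generated from simple shapes and textures) to demonstrate the full pipeline. The same architecture and training procedure carry over to real MRI data, although real scans typically need higher resolution, careful intensity normalization, and considerably longer training.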

## 1. Setup and Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# Resolve the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 2. Simulated MRI Dataset

We create a synthetic dataset that mimics key properties of brain MRI slices: elliptical brain shapes, internal structure (ventricles), and occasional bright "tumors" (about 30% of slices by default).

In [None]:
class SimulatedMRIDataset(Dataset):
    """Generates simple simulated brain MRI slices.

    Each sample is a single-channel ``(1, img_size, img_size)`` float32 tensor
    with values in [-1, 1]. It shows:

    * an elliptical "brain";
    * two darker elliptical "ventricles";
    * with probability ``tumor_fraction``, a bright circular "tumor"
      restricted to the brain region.

    Labels are 1 for tumor slices and 0 otherwise.

    All slices are generated eagerly in ``__init__`` with NumPy's global RNG,
    so seeding ``np.random`` beforehand makes the dataset reproducible.

    Args:
        n_samples: number of slices to generate (must be >= 1).
        img_size: height and width of each square slice, in pixels.
        tumor_fraction: probability that a slice contains a tumor.

    Raises:
        ValueError: if ``n_samples < 1``.
    """

    def __init__(self, n_samples=5000, img_size=64, tumor_fraction=0.3):
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        self.n_samples = n_samples
        self.img_size = img_size
        self.tumor_fraction = tumor_fraction
        self.images, self.labels = self._generate()

    def _generate(self):
        """Build every slice.

        Returns:
            images: (n_samples, 1, H, W) float32 tensor.
            labels: (n_samples,) tensor of 0/1.
        """
        size = self.img_size
        cx, cy = size // 2, size // 2
        # The coordinate grids are the same for every slice; build them once.
        Y, X = np.ogrid[:size, :size]

        images = []
        labels = []
        for _ in range(self.n_samples):
            img = np.zeros((size, size), dtype=np.float32)

            # Brain outline (ellipse)
            a = size * 0.4 + np.random.randn() * 2
            b = size * 0.35 + np.random.randn() * 2
            brain_mask = ((X - cx) / a) ** 2 + ((Y - cy) / b) ** 2 <= 1
            img[brain_mask] = 0.7 + np.random.randn() * 0.05

            # Ventricles: two small ellipses placed symmetrically about the centre column
            v_a = size * 0.08 + np.random.randn() * 1
            v_b = size * 0.12 + np.random.randn() * 1
            for offset in (-1, 1):
                vcx = cx + offset * int(size * 0.07)
                v_mask = ((X - vcx) / v_a) ** 2 + ((Y - cy) / v_b) ** 2 <= 1
                img[v_mask] = 0.3 + np.random.randn() * 0.03

            # Add subtle texture noise
            img += np.random.randn(size, size) * 0.02
            img = np.clip(img, 0, 1)

            # Optionally add tumor
            has_tumor = np.random.rand() < self.tumor_fraction
            if has_tumor:
                # randint's upper bound is exclusive; 11 makes the offset symmetric in [-10, 10].
                tx = cx + np.random.randint(-10, 11)
                ty = cy + np.random.randint(-10, 11)
                tr = np.random.randint(3, 8)  # radius in [3, 7] pixels
                t_mask = ((X - tx) ** 2 + (Y - ty) ** 2 <= tr ** 2) & brain_mask
                img[t_mask] = 0.9 + np.random.randn() * 0.03
                # The tumor intensity is drawn after the clip above and can exceed 1.
                # Re-clip so the [-1, 1] range promised below actually holds.
                img = np.clip(img, 0, 1)

            # Normalize to [-1, 1]
            img = img * 2 - 1

            images.append(torch.tensor(img).unsqueeze(0))  # (1, H, W)
            labels.append(1 if has_tumor else 0)

        return torch.stack(images), torch.tensor(labels)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Create dataset: build all 5000 training slices once and wrap them in a
# shuffling loader.
dataset = SimulatedMRIDataset(n_samples=5000, img_size=64)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

n_healthy = (dataset.labels == 0).sum()
n_tumor = (dataset.labels == 1).sum()
print(f"Dataset size: {len(dataset)}")
print(f"Image shape: {dataset[0][0].shape}")
print(f"Healthy: {n_healthy}, Tumor: {n_tumor}")

In [None]:
# Visualize samples: the first 16 training slices in a 2 x 8 grid, each
# titled by its label.
fig, axes = plt.subplots(2, 8, figsize=(20, 5))
for sample_idx, ax in enumerate(axes.ravel()):
    image, is_tumor = dataset[sample_idx]
    ax.imshow(image.squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
    title = 'Tumor' if is_tumor == 1 else 'Healthy'
    ax.set_title(title, fontsize=9)
    ax.axis('off')
plt.suptitle('Sample Simulated Brain MRI Slices', fontsize=14)
plt.tight_layout()
plt.show()

## 3. Score Network Architecture

In [None]:
class SinusoidalPositionEmbedding(nn.Module):
    """Encode the noise level sigma using sinusoidal embeddings.

    For ``sigma`` of shape ``(...,)``, the output has shape ``(..., dim)``:

        [sin(sigma*f_0), ..., sin(sigma*f_{h-1}), cos(sigma*f_0), ..., cos(sigma*f_{h-1})]

    Here h = dim // 2 and the frequencies are geometric: f_k = 10000 ** (-k / h).

    When ``dim`` is odd, one zero feature is appended so the output always
    has exactly ``dim`` features. Previously an odd dim produced ``dim - 1``
    features, which broke a following ``nn.Linear(dim, ...)``.

    NOTE(review): sigma is embedded on a linear scale. Small noise levels
    (0.01, 0.02, ...) therefore map to nearly identical embeddings.
    Embedding log(sigma) would separate them better; that is left unchanged
    here to preserve behaviour.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, sigma):
        half = self.dim // 2
        freqs = torch.exp(-np.log(10000) * torch.arange(half, device=sigma.device) / half)
        args = sigma.unsqueeze(-1) * freqs
        embedding = torch.cat([args.sin(), args.cos()], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves cover only 2 * (dim // 2) features; pad the last axis with one zero.
            embedding = F.pad(embedding, (0, 1))
        return embedding


class ResBlock(nn.Module):
    """Two 3x3 convolutions with a residual shortcut, conditioned on sigma.

    Data flow:
        conv1 -> GroupNorm -> SiLU -> (+ projected sigma embedding, broadcast
        over H and W) -> conv2 -> GroupNorm -> SiLU -> (+ shortcut)

    The shortcut is the identity when ``in_ch == out_ch``, else a 1x1 conv.
    GroupNorm uses 8 groups, so ``out_ch`` must be divisible by 8.
    """
    def __init__(self, in_ch, out_ch, sigma_dim):
        super().__init__()
        # Submodules are created in this fixed order, which determines both the
        # parameter-initialisation sequence and the state_dict keys.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.sigma_proj = nn.Linear(sigma_dim, out_ch)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)
        self.act = nn.SiLU()

    def forward(self, x, sigma_emb):
        shortcut = self.skip(x)
        # (B, out_ch) -> (B, out_ch, 1, 1) so the bias broadcasts over H and W.
        sigma_bias = self.sigma_proj(sigma_emb).unsqueeze(-1).unsqueeze(-1)
        hidden = self.act(self.norm1(self.conv1(x))) + sigma_bias
        hidden = self.act(self.norm2(self.conv2(hidden)))
        return hidden + shortcut


class SimpleUNet(nn.Module):
    """Simplified U-Net for score prediction, conditioned on sigma.

    Architecture:
    * three encoder ResBlocks, each followed by 2x max-pooling;
    * a bottleneck ResBlock;
    * three transposed-conv decoder stages with skip connections.

    Every ResBlock receives the same sigma embedding. Input height and width
    must be divisible by 8 (three pooling stages). ``base_ch`` must be
    divisible by 8, because ResBlock normalises with ``GroupNorm(8, ...)``.

    The raw network output is divided by sigma (NCSNv2 parameterisation,
    s(x, sigma) = f(x, sigma) / sigma). The true score of sigma-perturbed
    data scales like 1/sigma, which spans 0.04 to 100 for this notebook's
    noise levels (25 down to 0.01). Dividing by sigma lets the network emit
    unit-scale values at every level, instead of regressing targets that
    differ by four orders of magnitude.

    Args:
        in_ch: number of image channels.
        base_ch: channel width of the first stage.
        sigma_dim: width of the sigma embedding.
    """
    def __init__(self, in_ch=1, base_ch=32, sigma_dim=64):
        super().__init__()
        self.sigma_embed = nn.Sequential(
            SinusoidalPositionEmbedding(sigma_dim),
            nn.Linear(sigma_dim, sigma_dim),
            nn.SiLU(),
        )

        # Encoder
        self.enc1 = ResBlock(in_ch, base_ch, sigma_dim)
        self.enc2 = ResBlock(base_ch, base_ch * 2, sigma_dim)
        self.enc3 = ResBlock(base_ch * 2, base_ch * 4, sigma_dim)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ResBlock(base_ch * 4, base_ch * 4, sigma_dim)

        # Decoder: each stage's input is the upsampled features concatenated with the
        # matching encoder output, hence the doubled input channel counts.
        self.up3 = nn.ConvTranspose2d(base_ch * 4, base_ch * 4, 2, stride=2)
        self.dec3 = ResBlock(base_ch * 8, base_ch * 2, sigma_dim)
        self.up2 = nn.ConvTranspose2d(base_ch * 2, base_ch * 2, 2, stride=2)
        self.dec2 = ResBlock(base_ch * 4, base_ch, sigma_dim)
        self.up1 = nn.ConvTranspose2d(base_ch, base_ch, 2, stride=2)
        self.dec1 = ResBlock(base_ch * 2, base_ch, sigma_dim)

        self.final = nn.Conv2d(base_ch, in_ch, 1)

    def forward(self, x, sigma):
        """Estimate the score of ``x`` at noise level ``sigma``.

        Args:
            x: noisy images, (B, in_ch, H, W) with H and W divisible by 8.
            sigma: per-sample noise levels, shape (B,), all strictly positive.

        Returns:
            Score estimate with the same shape as ``x``.

        Raises:
            ValueError: if H or W is not divisible by 8.
        """
        if x.shape[-2] % 8 or x.shape[-1] % 8:
            raise ValueError(
                f"SimpleUNet needs H and W divisible by 8, got {tuple(x.shape[-2:])}"
            )
        sigma_emb = self.sigma_embed(sigma)

        # Encoder
        e1 = self.enc1(x, sigma_emb)
        e2 = self.enc2(self.pool(e1), sigma_emb)
        e3 = self.enc3(self.pool(e2), sigma_emb)

        # Bottleneck
        b = self.bottleneck(self.pool(e3), sigma_emb)

        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1), sigma_emb)
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1), sigma_emb)
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1), sigma_emb)

        # NCSNv2 output scaling: the network predicts a unit-scale quantity and
        # the 1/sigma factor restores the score's magnitude.
        return self.final(d1) / sigma[:, None, None, None]

# Instantiate the score network on the selected device and report its size.
model = SimpleUNet()
model = model.to(device)
param_counts = [param.numel() for param in model.parameters()]
n_params = sum(param_counts)
print(f"Score network parameters: {n_params:,}")

## 4. Multi-Scale DSM Training

In [None]:
# Define noise levels: L values evenly spaced in log-space (a geometric
# sequence), ordered from sigma_max down to sigma_min.
L = 10
sigma_max = 25.0
sigma_min = 0.01
log_sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), L)
noise_levels = log_sigmas.exp().to(device)
print(f"Noise levels: {noise_levels.cpu().numpy().round(3)}")

In [None]:
def dsm_loss_multiscale(model, x, noise_levels, device):
    """
    Multi-scale denoising score matching loss (NCSN-style).

    Each sample in the batch gets one noise level drawn uniformly from
    ``noise_levels``. It is perturbed as x + sigma * eps with eps ~ N(0, I),
    and the model's score is regressed onto the DSM target -eps / sigma.
    Squared errors are weighted by sigma^2, which balances the levels.
    The result is averaged over every element.

    Args:
        model: callable ``model(x_noisy, sigma)`` returning a tensor shaped like x.
        x: clean batch, (B, C, H, W).
        noise_levels: 1-D tensor of candidate sigmas.
        device: device used for drawing the level indices.

    Returns:
        Scalar loss tensor.
    """
    n = x.shape[0]

    # Draw level indices first, then the Gaussian noise (this RNG order is relied on).
    level_idx = torch.randint(0, len(noise_levels), (n,), device=device)
    sigmas = noise_levels[level_idx]  # (B,)
    sig = sigmas.view(-1, 1, 1, 1)  # broadcastable over (C, H, W)

    noise = torch.randn_like(x)
    predicted = model(x + sig * noise, sigmas)
    target = -noise / sig

    return (sig.pow(2) * (predicted - target).pow(2)).mean()

In [None]:
# Training loop: multi-scale DSM with Adam, gradient-norm clipping, and a
# cosine LR schedule that decays over exactly `n_epochs` epochs.
n_epochs = 50
log_every = 10  # print progress every `log_every` epochs

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# T_max is tied to n_epochs; it used to be a separate hard-coded 50.
# If the two drift apart, cosine annealing either stops before the minimum
# or starts raising the LR again before training ends.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

model.train()
losses = []
for epoch in range(n_epochs):
    epoch_loss = 0.0
    n_batches = 0
    for batch_images, _ in dataloader:  # labels unused: training is unconditional
        batch_images = batch_images.to(device)

        loss = dsm_loss_multiscale(model, batch_images, noise_levels, device)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    scheduler.step()
    avg_loss = epoch_loss / n_batches
    losses.append(avg_loss)

    if (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.title('Multi-Scale DSM Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

## 5. Annealed Langevin Dynamics Sampling

In [None]:
@torch.no_grad()
def annealed_langevin_sample(model, noise_levels, n_samples=16,
                             img_shape=(1, 64, 64), n_steps_per_level=100,
                             base_step_size=5e-5, device=None, final_denoise=False):
    """
    Generate images using Annealed Langevin Dynamics.

    Noise levels are visited in the order given by ``noise_levels``, assumed
    largest first, as built in this notebook. At each level sigma_i,
    ``n_steps_per_level`` updates are run:

        x <- x + a_i * score(x, sigma_i) + sqrt(2 * a_i) * z,    z ~ N(0, I)

    The step size is a_i = base_step_size * (sigma_i / sigma_L)^2, where
    sigma_L is the last (smallest) level.

    Args:
        model: callable ``model(x, sigma)`` returning a score estimate shaped
            like ``x``; sigma is passed as an ``(n_samples,)`` tensor.
        noise_levels: 1-D tensor of noise levels, largest first.
        n_samples: number of images to generate.
        img_shape: shape of one image, e.g. (C, H, W).
        n_steps_per_level: Langevin steps run at each noise level.
        base_step_size: step size used at the smallest noise level.
        device: device to sample on. Defaults to ``noise_levels.device``;
            the old hard-coded 'cuda' default failed on CPU-only machines.
        final_denoise: if True, apply one extra denoising step,
            x <- x + sigma_L^2 * score(x, sigma_L), after the last level
            (as recommended for NCSNv2). Off by default.

    Returns:
        (samples, snapshots):
        * samples: the final images, on CPU.
        * snapshots: CPU copies of x — the initial noise followed by the
          state after each level, ``len(noise_levels) + 1`` entries in total.
        The optional denoising step affects ``samples`` only.
    """
    if device is None:
        device = noise_levels.device
    noise_levels = noise_levels.to(device)

    # Start from random noise
    x = torch.randn(n_samples, *img_shape, device=device)
    snapshots = [x.cpu().clone()]

    sigma_L = noise_levels[-1]  # smallest sigma

    for sigma_i in noise_levels:
        # Step size proportional to sigma^2; both factors are constant within a level.
        step_size = base_step_size * (sigma_i / sigma_L) ** 2
        noise_scale = torch.sqrt(2 * step_size)
        sigma_batch = sigma_i.expand(n_samples)

        for _ in range(n_steps_per_level):
            score = model(x, sigma_batch)
            noise = torch.randn_like(x)
            x = x + step_size * score + noise_scale * noise

        snapshots.append(x.cpu().clone())

    if final_denoise:
        # Single noiseless step at sigma_L (Tweedie-style estimate of the clean image).
        x = x + sigma_L ** 2 * model(x, sigma_L.expand(n_samples))

    return x.cpu(), snapshots

# Generate samples: 16 images from the trained score network
generated, snapshots = annealed_langevin_sample(
    model, noise_levels, n_samples=16, n_steps_per_level=100, device=device
)

# Visualize them in a 2 x 8 grid
fig, axes = plt.subplots(2, 8, figsize=(20, 5))
for ax, sample in zip(axes.flatten(), generated):
    ax.imshow(sample.squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')
plt.suptitle('Generated Synthetic Brain MRI Slices', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Visualize the annealing process: up to 6 evenly spaced snapshots of the
# first two samples. snapshots[0] is the initial noise; snapshots[k] is the
# state after noise level k-1, hence the idx - 1 offset for the sigma title.
n_show = min(len(snapshots), 6)
indices = np.linspace(0, len(snapshots) - 1, n_show, dtype=int)
last_snapshot = len(snapshots) - 1

fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8))
for col, idx in enumerate(indices):
    if idx == 0:
        title = 'Pure Noise'
    elif idx == last_snapshot:
        title = 'Final'
    elif idx - 1 < len(noise_levels):
        title = f'sigma={noise_levels[idx - 1].item():.2f}'
    else:
        title = None

    for row in range(2):
        ax = axes[row, col]
        ax.imshow(snapshots[idx][row].squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
        if title is not None:
            ax.set_title(title, fontsize=10)
        ax.axis('off')

plt.suptitle('Annealed Langevin Dynamics: Noise to Image', fontsize=14)
plt.tight_layout()
plt.show()
print("Watch how images gradually emerge from random noise!")

## 6. Evaluation

In [None]:
# Pixel-space Frechet distance (simplified FID)
def _flatten_images(images):
    """Return ``images`` (tensor or array with a leading sample axis) as a float64 (N, D) NumPy array."""
    t = torch.as_tensor(images).detach().cpu()
    return t.reshape(t.shape[0], -1).double().numpy()


def compute_simple_fid(real_images, generated_images):
    """
    Compute the Frechet distance between Gaussians fitted to raw pixel vectors.

    This is FID computed in pixel space. For a production system, use
    InceptionV3 features instead; pixel-space values are not comparable to
    published FID numbers.

        FID = ||mu_r - mu_g||^2 + Tr(C_r) + Tr(C_g) - 2 Tr((C_r C_g)^{1/2})

    The D x D covariances (D = H*W pixels) are never formed. The derivation:

    * Let A (N_r x D) and B (N_g x D) be the centred data matrices, so
      C_r = A^T A / (N_r - 1) and C_g = B^T B / (N_g - 1).
    * The non-zero eigenvalues of C_r C_g equal the squared singular values
      of A B^T / sqrt((N_r - 1)(N_g - 1)).
    * Hence Tr((C_r C_g)^{1/2}) is the sum of those singular values.

    This is exact and needs only an N_r x N_g SVD.

    Args:
        real_images: (N, 1, H, W) tensor or array, N >= 2
        generated_images: (M, 1, H, W) tensor or array, M >= 2

    Returns:
        fid: float (>= 0; tiny negative round-off is clamped to 0)

    Raises:
        ValueError: if either set has fewer than 2 images, or the per-image
            sizes differ.
    """
    real_flat = _flatten_images(real_images)
    gen_flat = _flatten_images(generated_images)
    n_r, n_g = real_flat.shape[0], gen_flat.shape[0]
    if n_r < 2 or n_g < 2:
        raise ValueError(
            f"compute_simple_fid needs >= 2 images per set, got {n_r} real and {n_g} generated"
        )
    if real_flat.shape[1] != gen_flat.shape[1]:
        raise ValueError(
            f"image sizes differ: {real_flat.shape[1]} vs {gen_flat.shape[1]} pixels"
        )

    mu_r = real_flat.mean(axis=0)
    mu_g = gen_flat.mean(axis=0)
    a = real_flat - mu_r
    b = gen_flat - mu_g

    mean_term = np.sum((mu_r - mu_g) ** 2)
    trace_r = np.sum(a * a) / (n_r - 1)  # Tr(C_r) = ||A||_F^2 / (N_r - 1)
    trace_g = np.sum(b * b) / (n_g - 1)
    # Tr((C_r C_g)^{1/2}) = nuclear norm of A B^T, scaled.
    cross = np.linalg.svd(a @ b.T, compute_uv=False).sum() / np.sqrt((n_r - 1) * (n_g - 1))

    fid = mean_term + trace_r + trace_g - 2.0 * cross
    return float(max(fid, 0.0))

# Generate a larger batch (200 samples) for evaluation and compare it with
# the first 200 training images.
eval_generated, _ = annealed_langevin_sample(
    model, noise_levels, n_samples=200, n_steps_per_level=100, device=device
)

real_subset = dataset.images[:200]
fid = compute_simple_fid(real_subset, eval_generated)
print(f"Approximate FID (pixel-space): {fid:.4f}")

In [None]:
# Compare real vs generated statistics: pooled pixel-intensity histograms
# plus the per-pixel mean image of each set.
real_images_500 = dataset.images[:500]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
hist_ax, real_ax, gen_ax = axes

# Pixel intensity distribution
hist_ax.hist(real_images_500.numpy().flatten(), bins=50,
             alpha=0.5, density=True, label='Real', color='blue')
hist_ax.hist(eval_generated.numpy().flatten(), bins=50,
             alpha=0.5, density=True, label='Generated', color='red')
hist_ax.set_title('Pixel Intensity Distribution')
hist_ax.legend()
hist_ax.grid(True, alpha=0.3)

# Mean image comparison
for ax, images, title in ((real_ax, real_images_500, 'Mean Real Image'),
                          (gen_ax, eval_generated, 'Mean Generated Image')):
    ax.imshow(images.mean(0).squeeze().numpy(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.suptitle('Real vs Generated Image Statistics', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Privacy Audit

In [None]:
# Nearest-neighbor privacy check
def privacy_check(generated_images, training_images, n_check=100):
    """
    Check that generated images are not memorized copies of training data.

    For each of the first ``n_check`` generated images, find the Euclidean
    (L2) distance to its nearest training image over all pixels.

    Args:
        generated_images: (N_gen, 1, H, W) generated images
        training_images: (N_train, 1, H, W) training images
        n_check: number of generated images to check; clamped to N_gen, so
            passing fewer images than ``n_check`` is fine.

    Returns:
        min_distance: minimum L2 distance found
        distances: list of nearest-neighbor distances, one per checked image

    Raises:
        ValueError: if there is nothing to check or no training images.
    """
    # The old `.view(n_check, -1)` crashed (or silently mis-reshaped) when
    # fewer than n_check images were supplied.
    n_check = min(n_check, len(generated_images))
    if n_check <= 0:
        raise ValueError("privacy_check needs at least one generated image to check")
    if len(training_images) == 0:
        raise ValueError("privacy_check needs at least one training image")

    # reshape (not view) also accepts non-contiguous inputs.
    train_flat = training_images.reshape(len(training_images), -1)
    gen_flat = generated_images[:n_check].reshape(n_check, -1).to(
        device=train_flat.device, dtype=train_flat.dtype
    )

    distances = [torch.norm(train_flat - gen_vec, dim=1).min().item() for gen_vec in gen_flat]
    return min(distances), distances

# Nearest-neighbour distances from generated samples to the full training set.
min_dist, all_dists = privacy_check(eval_generated, dataset.images)
mean_dist = np.mean(all_dists)
print(f"Minimum nearest-neighbor distance: {min_dist:.4f}")
print(f"Mean nearest-neighbor distance: {mean_dist:.4f}")

plt.figure(figsize=(8, 4))
plt.hist(all_dists, bins=30, edgecolor='black', alpha=0.7)
plt.axvline(x=min_dist, color='red', linestyle='--', label=f'Min: {min_dist:.3f}')
plt.title('Nearest-Neighbor Distance Distribution')
plt.xlabel('L2 Distance')
plt.ylabel('Count')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# NOTE(review): 0.05 is an absolute L2 distance over all H*W pixels.
# Each training slice carries independent per-pixel texture noise, so two
# distinct slices presumably differ by far more than 0.05. Read PASSED as
# "no (near-)verbatim copies", not as a general memorisation test.
passed = min_dist > 0.05
print("PASSED: No generated images are too close to training data."
      if passed else "WARNING: Some generated images may be memorized!")

## 8. Summary and Results

In [None]:
# Final summary visualization:
# * row 1: real training slices;
# * row 2: generated samples;
# * row 3: eight evenly spaced snapshots of the first sample's annealing run.
fig = plt.figure(figsize=(20, 12))


def _draw_tile(figure, position, image, row_label=None):
    """Draw one image at 1-based `position` of a 3 x 8 grid on `figure`; optionally label the row."""
    ax = figure.add_subplot(3, 8, position)
    ax.imshow(image.squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
    # ax.axis('off') also hides axis labels, so the row labels never
    # rendered. Strip ticks and the frame instead, keeping the y-label visible.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    if row_label is not None:
        ax.set_ylabel(row_label, fontsize=12)
    return ax


# Row 1: real slices
for i in range(8):
    _draw_tile(fig, i + 1, dataset.images[i], 'Real' if i == 0 else None)

# Row 2: generated samples
for i in range(8):
    _draw_tile(fig, i + 9, generated[i], 'Generated' if i == 0 else None)

# Row 3: denoising process for one sample
process_indices = np.linspace(0, len(snapshots) - 1, 8, dtype=int)
for col, idx in enumerate(process_indices):
    _draw_tile(fig, col + 17, snapshots[idx][0], 'Process' if col == 0 else None)

plt.suptitle('RadiSynth AI: Synthetic Brain MRI Generation Pipeline', fontsize=16)
plt.tight_layout()
plt.show()

print("\n" + "=" * 60)
print("RadiSynth AI -- Results Summary")
print("=" * 60)
print(f"Training data:       {len(dataset)} simulated MRI slices")
print(f"Score network:       SimpleUNet ({n_params:,} parameters)")
print(f"Noise levels:        {L} (geometric from {sigma_max} to {sigma_min})")
print(f"Training epochs:     {n_epochs}")
print(f"Final training loss: {losses[-1]:.4f}")
print(f"Generated samples:   {len(eval_generated)}")
print(f"Privacy check:       {'PASSED' if min_dist > 0.05 else 'FAILED'}")
print(f"Approx FID:          {fid:.4f}")