In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Accelerated MRI Reconstruction Using Score-Based Generative Priors -- Implementation Notebook

## Case Study Overview

In this notebook, we implement a simplified version of the MedScanAI ReconPrior system, which uses a Noise Conditioned Score Network (NCSN) as a learned prior for reconstructing MRI images from undersampled k-space data.

We will work with a synthetic MRI-like dataset to demonstrate the core concepts, then outline how to extend this to real fastMRI data.

## Setup and Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# Re-derive the compute device so this cell also works on its own
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f"Using device: {device}")

## Section 3.1: Data -- Synthetic MRI-like Images

In [None]:
def generate_synthetic_mri(n_samples=2000, size=64):
    """
    Generate synthetic MRI-like images (ellipses on a dark background).

    Each image holds a large "body" ellipse with intensity 0.3 +/- 0.05. On top
    of it, 3-6 smaller, randomly placed and rotated "tissue" ellipses with
    intensity in [0.4, 0.9) are painted; later ones overwrite earlier ones.
    They loosely simulate brain/knee cross-sections.

    All randomness comes from NumPy's global RNG (``np.random``), so a
    preceding ``np.random.seed(...)`` makes the output reproducible.

    Args:
        n_samples: number of images to generate (0 yields an empty batch).
        size: height and width of each square image, in pixels.

    Returns:
        float32 array of shape (n_samples, size, size).
    """
    # Preallocate so the result always has shape (n_samples, size, size) and
    # dtype float32, even when n_samples == 0.
    images = np.zeros((n_samples, size, size), dtype=np.float32)

    # The coordinate grid and the body ellipse are the same for every image.
    y, x = np.ogrid[-size//2:size//2, -size//2:size//2]
    a, b = size//2 - 5, size//2 - 8
    body_mask = (x**2 / a**2 + y**2 / b**2) <= 1

    for img in images:  # each `img` is a writable view into `images`
        # Background ellipse (body/skull) with a small random intensity jitter
        img[body_mask] = 0.3 + np.random.uniform(-0.05, 0.05)

        # Internal structures (organs/tissues)
        n_structures = np.random.randint(3, 7)
        for _ in range(n_structures):
            cx = np.random.randint(-size//4, size//4)
            cy = np.random.randint(-size//4, size//4)
            ra = np.random.randint(3, size//6)
            rb = np.random.randint(3, size//6)
            angle = np.random.uniform(0, np.pi)
            intensity = np.random.uniform(0.4, 0.9)

            # Rotate coordinates into the ellipse's own frame
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            xr = cos_a * (x - cx) + sin_a * (y - cy)
            yr = -sin_a * (x - cx) + cos_a * (y - cy)
            struct_mask = (xr**2 / ra**2 + yr**2 / rb**2) <= 1
            img[struct_mask] = intensity

    return images

# Generate dataset: a training set and a smaller held-out test set
train_images = generate_synthetic_mri(2000, size=64)
test_images = generate_synthetic_mri(200, size=64)

# Preview the first ten training images on a common [0, 1] gray scale
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for ax, preview in zip(axes.flat, train_images):
    ax.imshow(preview, cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
fig.suptitle('Synthetic MRI-like Training Images', fontsize=14)
fig.tight_layout()
plt.show()
print(f"Training images: {train_images.shape}")

## Section 3.2: EDA -- k-space and Undersampling

In [None]:
def image_to_kspace(image):
    """Convert an image to centered k-space with a 2D FFT.

    Both the FFT and the zero-frequency shift act on the last two axes only,
    so a stack of images of shape (..., H, W) is transformed image by image
    (a plain ``fftshift`` would also roll the leading batch axes).

    Args:
        image: real or complex array of shape (H, W) or (..., H, W).

    Returns:
        Complex array of the same shape with the DC component moved to the
        center of each (H, W) plane.
    """
    return np.fft.fftshift(np.fft.fft2(image, axes=(-2, -1)), axes=(-2, -1))

def kspace_to_image(kspace):
    """Convert centered k-space back to a magnitude image (inverse 2D FFT).

    Undoes the centering shift and applies the inverse FFT on the last two
    axes only, so a stack of shape (..., H, W) is handled plane by plane.
    Only the magnitude is returned; phase is discarded.

    Args:
        kspace: complex array of shape (H, W) or (..., H, W) with the DC
            component at the center of each plane.

    Returns:
        Real, non-negative array of the same shape.
    """
    return np.abs(np.fft.ifft2(np.fft.ifftshift(kspace, axes=(-2, -1)), axes=(-2, -1)))

def create_undersampling_mask(shape, acceleration=4, center_fraction=0.08):
    """Create a random column undersampling mask with a fully sampled center.

    Whole columns (along the last axis) are either kept (1) or dropped (0).
    A contiguous block of ``int(W * center_fraction)`` central columns is
    always kept. Further columns are then drawn uniformly at random, without
    replacement, from the rest until ``int(W / acceleration)`` columns are
    sampled in total. There are never fewer than the central block.

    Randomness comes from NumPy's global RNG (``np.random``).

    Args:
        shape: (H, W) of the k-space grid.
        acceleration: target undersampling factor; about W / acceleration
            columns are kept.
        center_fraction: fraction of columns in the always-sampled center.

    Returns:
        float32 array of shape (H, W) containing only 0s and 1s.
    """
    H, W = shape
    mask = np.zeros((H, W), dtype=np.float32)

    # Always sample the low-frequency center
    center_width = int(W * center_fraction)
    center_start = W // 2 - center_width // 2
    center_stop = center_start + center_width
    mask[:, center_start:center_stop] = 1

    # Randomly sample remaining columns up to the total line budget
    n_total_lines = int(W / acceleration)
    n_random_lines = max(0, n_total_lines - center_width)

    # Build the candidate list in an explicit ascending order. Set iteration
    # order is an implementation detail, and a seeded draw should not depend
    # on it.
    available = [col for col in range(W) if not center_start <= col < center_stop]
    n_pick = min(n_random_lines, len(available))
    if n_pick > 0:
        selected = np.random.choice(available, size=n_pick, replace=False)
        mask[:, selected] = 1

    return mask

# Demonstrate undersampling on one training image: the image, its k-space,
# two masks and the zero-filled reconstructions they produce.
img = train_images[0]
kspace = image_to_kspace(img)
mask_4x = create_undersampling_mask(img.shape, acceleration=4)
mask_8x = create_undersampling_mask(img.shape, acceleration=8)

recon_full = kspace_to_image(kspace)
recon_4x = kspace_to_image(kspace * mask_4x)
recon_8x = kspace_to_image(kspace * mask_8x)

# Image panels share a fixed [0, 1] gray scale (as in the training-image grid),
# so intensity loss and artifacts are comparable across panels. The k-space,
# mask and error panels keep automatic scaling.
image_kw = {'cmap': 'gray', 'vmin': 0, 'vmax': 1}
panels = [
    ((0, 0), img, image_kw, 'Original Image'),
    ((0, 1), np.log1p(np.abs(kspace)), {'cmap': 'gray'}, 'Full k-space (log)'),
    ((0, 2), mask_4x, {'cmap': 'gray'}, f'4x Mask ({mask_4x.mean()*100:.0f}% sampled)'),
    ((0, 3), mask_8x, {'cmap': 'gray'}, f'8x Mask ({mask_8x.mean()*100:.0f}% sampled)'),
    ((1, 0), recon_full, image_kw, 'Full Recon'),
    ((1, 1), recon_4x, image_kw, '4x Zero-filled'),
    ((1, 2), recon_8x, image_kw, '8x Zero-filled'),
    ((1, 3), np.abs(recon_full - recon_4x), {'cmap': 'hot'}, '4x Error Map'),
]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for pos, data, imshow_kw, title in panels:
    ax = axes[pos]
    ax.imshow(data, **imshow_kw)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## Section 3.3: Baseline -- Zero-Filled Reconstruction

In [None]:
def evaluate_reconstruction(pred, target, data_range=None):
    """Compute PSNR, MSE and a single-window (global) SSIM proxy.

    Args:
        pred: reconstructed image; array-like with the same shape as ``target``.
        target: reference image.
        data_range: dynamic range L of the images. It is used as the PSNR peak
            and to scale the SSIM stabilisers c1 = (0.01 L)^2, c2 = (0.03 L)^2.
            With the default None, the PSNR peak is ``target.max()`` and the
            SSIM constants assume L = 1 (images in [0, 1]).

    Returns:
        dict with 'psnr' (dB), 'ssim' (global proxy, not the usual windowed
        SSIM) and 'mse'.

    Raises:
        ValueError: if ``pred`` and ``target`` have different shapes.
    """
    # float64 avoids float32 round-off when both inputs are float32
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: pred {pred.shape} vs target {target.shape}")

    mse = np.mean((pred - target) ** 2)
    peak = target.max() if data_range is None else data_range
    psnr = 10 * np.log10(peak**2 / (mse + 1e-10))  # 1e-10 guards mse == 0

    # Simple SSIM proxy: one window covering the whole image
    dyn_range = 1.0 if data_range is None else data_range
    mu_p, mu_t = pred.mean(), target.mean()
    sig_p, sig_t = pred.std(), target.std()
    sig_pt = np.mean((pred - mu_p) * (target - mu_t))
    c1, c2 = (0.01 * dyn_range)**2, (0.03 * dyn_range)**2
    ssim = ((2*mu_p*mu_t + c1) * (2*sig_pt + c2)) / \
           ((mu_p**2 + mu_t**2 + c1) * (sig_p**2 + sig_t**2 + c2))

    return {'psnr': psnr, 'ssim': ssim, 'mse': mse}

# Baseline evaluation: zero-filled inverse FFT at 4x on the first 50 test images,
# drawing a fresh random mask for every image.
results_zf = []
for img in test_images[:50]:
    ks = image_to_kspace(img)
    mask = create_undersampling_mask(img.shape, acceleration=4)
    recon = kspace_to_image(ks * mask)
    results_zf += [evaluate_reconstruction(recon, img)]

mean_psnr, mean_ssim = (
    np.mean([r[metric] for r in results_zf]) for metric in ('psnr', 'ssim')
)
print(f"Zero-filled baseline (4x): PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.3f}")

## Section 3.4: NCSN Model for MRI

In [None]:
class MRIScoreNet(nn.Module):
    """Simple noise-conditional CNN score network for MRI images.

    Takes a noisy image batch and one noise-level index per sample, and
    returns a same-shaped score estimate.

    Layout:
        encoder (one stride-2 downsampling)
        -> FiLM modulation of the encoder features by a learned embedding of
           the noise-level index
        -> middle convs
        -> decoder (one stride-2 transposed-conv upsampling)
        -> 1-channel output.

    Args:
        n_noise_levels: size of the noise-level embedding table; valid
            ``sigma_idx`` values are ``0 .. n_noise_levels - 1``.
        base_ch: channel width at full resolution; the half-resolution stage
            uses ``2 * base_ch`` channels.
    """
    def __init__(self, n_noise_levels=50, base_ch=32):
        super().__init__()
        # One learned vector per discrete noise level
        self.sigma_embed = nn.Embedding(n_noise_levels, base_ch)

        self.encoder = nn.Sequential(
            nn.Conv2d(1, base_ch, 3, padding=1), nn.SiLU(),
            nn.Conv2d(base_ch, base_ch, 3, padding=1), nn.SiLU(),
            nn.Conv2d(base_ch, base_ch*2, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(base_ch*2, base_ch*2, 3, padding=1), nn.SiLU(),
        )

        self.middle = nn.Sequential(
            nn.Conv2d(base_ch*2, base_ch*2, 3, padding=1), nn.SiLU(),
            nn.Conv2d(base_ch*2, base_ch*2, 3, padding=1), nn.SiLU(),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(base_ch*2, base_ch, 4, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(base_ch, base_ch, 3, padding=1), nn.SiLU(),
            nn.Conv2d(base_ch, 1, 3, padding=1),
        )

        # FiLM conditioning layers: per-channel scale and shift for the
        # (2 * base_ch)-channel encoder output
        self.film_scale = nn.Linear(base_ch, base_ch*2)
        self.film_bias = nn.Linear(base_ch, base_ch*2)

    def forward(self, x, sigma_idx):
        """Estimate the score of a noisy batch at the given noise levels.

        Args:
            x: input batch of shape (B, 1, H, W); H and W may be odd.
            sigma_idx: integer tensor of shape (B,) with noise-level indices.

        Returns:
            Tensor with the same shape as ``x``.
        """
        H, W = x.shape[-2:]
        emb = self.sigma_embed(sigma_idx)  # (B, base_ch)
        scale = self.film_scale(emb).unsqueeze(-1).unsqueeze(-1)  # (B, 2*base_ch, 1, 1)
        bias = self.film_bias(emb).unsqueeze(-1).unsqueeze(-1)

        h = self.encoder(x)
        h = h * (1 + scale) + bias  # FiLM; (1 + scale) is the identity at scale 0
        h = self.middle(h)
        out = self.decoder(h)
        # Stride-2 down + ConvTranspose(k=4, s=2, p=1) up yields 2*ceil(H/2),
        # one pixel too many for odd sizes; crop back to the input size.
        return out[..., :H, :W]

model = MRIScoreNet(n_noise_levels=50).to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f"Parameters: {n_params:,}")

## Section 3.5: Training

In [None]:
# Prepare data
class MRIDataset(Dataset):
    """Map-style dataset over a stack of 2D images.

    The images are copied once into a float32 tensor with a singleton channel
    axis inserted at position 1, so item ``i`` is a (1, H, W) tensor for an
    (N, H, W) input.
    """

    def __init__(self, images):
        # Copy to float32 and add the channel axis: (N, H, W) -> (N, 1, H, W)
        self.images = torch.tensor(images, dtype=torch.float32)[:, None]

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        return self.images[idx]

# Mini-batches of training images, reshuffled every epoch
train_ds = MRIDataset(train_images)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)

# Noise schedule: L geometrically spaced levels from sigma_1 (largest) down to
# sigma_L (smallest). The level index is fed to the model's embedding table,
# so L must equal the model's number of noise levels.
L = 50
sigma_1, sigma_L = 1.0, 0.01
sigmas = torch.tensor([sigma_1 * (sigma_L / sigma_1) ** (i / (L-1)) for i in range(L)]).to(device)
assert model.sigma_embed.num_embeddings == L, (
    f"noise schedule has {L} levels but the model embeds "
    f"{model.sigma_embed.num_embeddings}"
)

# Training loop (denoising score matching)
N_EPOCHS = 50
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
losses = []

model.train()  # make sure we train even if a later cell left the model in eval mode
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    for batch in train_dl:
        batch = batch.to(device)
        B = batch.shape[0]

        # Random noise level per sample
        idx = torch.randint(0, L, (B,), device=device)
        sigma = sigmas[idx].view(B, 1, 1, 1)

        # Perturb the data; the score of the Gaussian perturbation kernel at
        # the noisy point is -epsilon / sigma
        epsilon = torch.randn_like(batch)
        noisy = batch + sigma * epsilon
        target = -epsilon / sigma

        # sigma^2-weighted per-sample error keeps all noise levels on a
        # comparable scale (equivalent to ||sigma * pred + epsilon||^2)
        pred = model(noisy, idx)
        per_sample_mse = (pred - target).flatten(1).pow(2).mean(1)
        loss = (sigma.flatten() ** 2 * per_sample_mse).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss /= len(train_dl)
    losses.append(epoch_loss)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{N_EPOCHS}, Loss: {epoch_loss:.4f}")

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('NCSN Training on MRI Data')
ax.grid(True, alpha=0.3)
plt.show()

## Section 3.6: Reconstruction with Data Consistency

In [None]:
def ncsn_mri_reconstruct(model, kspace, mask, sigmas, n_steps=50, eps=5e-5):
    """Reconstruct MRI using NCSN + annealed Langevin dynamics + data consistency.

    The reconstruction starts from the zero-filled image. For each noise level
    sigma_i (in the order given by ``sigmas``) it runs ``n_steps`` Langevin
    updates:

        x <- x + alpha_i * model(x, i) + sqrt(2 * alpha_i) * z,   z ~ N(0, I)
        alpha_i = eps * (sigma_i / sigma_L)^2

    Each update is followed by hard data consistency. The measured k-space
    samples (``mask == 1``) overwrite the current estimate's, and the image is
    the magnitude of the inverse FFT.

    Everything runs in PyTorch on the model's device, so there is no host/device
    copy per step. The model is switched to eval mode during the call and its
    previous training/eval mode is restored afterwards.

    Args:
        model: score network called as ``model(x, sigma_idx)`` with x of shape
            (1, 1, H, W) and sigma_idx an int64 tensor of shape (1,).
        kspace: centered (fftshift-ed) k-space, complex array of shape (H, W).
        mask: sampling mask of shape (H, W); 1 where k-space was measured.
        sigmas: 1-D tensor of noise levels; the last entry is used as sigma_L.
        n_steps: Langevin steps per noise level.
        eps: base step size.

    Returns:
        float32 numpy array of shape (H, W).
    """
    # Run where the model lives; fall back to CPU for a parameter-less model
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')
    fft_dims = (-2, -1)

    kspace_t = torch.tensor(kspace, dtype=torch.complex64, device=dev)
    mask_t = torch.tensor(mask, dtype=torch.float32, device=dev)
    measured_t = kspace_t * mask_t  # acquired samples, zero elsewhere
    unmeasured_t = 1 - mask_t

    def to_image(k):
        # Inverse of the centered FFT (ifftshift, then ifft2); magnitude only
        return torch.fft.ifft2(torch.fft.ifftshift(k, dim=fft_dims)).abs()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Initialize from the zero-filled reconstruction, shape (1, 1, H, W)
            x = to_image(measured_t)[None, None]
            sigma_L = sigmas[-1].item()

            for i, sigma_val in enumerate(sigmas):
                alpha = eps * (sigma_val.item() / sigma_L) ** 2
                noise_scale = (2 * alpha) ** 0.5
                sigma_idx = torch.tensor([i], device=dev)

                for _ in range(n_steps):
                    # Score update
                    score = model(x, sigma_idx)
                    x = x + alpha * score + noise_scale * torch.randn_like(x)

                    # Data consistency: keep unmeasured frequencies from x,
                    # replace measured ones with the actual measurements
                    x_k = torch.fft.fftshift(torch.fft.fft2(x), dim=fft_dims)
                    x = to_image(x_k * unmeasured_t + measured_t)
    finally:
        model.train(was_training)

    return x.squeeze().cpu().numpy()

# Test reconstruction: one 4x-undersampled test image, NCSN vs. zero-filling
test_img = test_images[0]
test_ks = image_to_kspace(test_img)
test_mask = create_undersampling_mask(test_img.shape, acceleration=4)

recon_ncsn = ncsn_mri_reconstruct(model, test_ks, test_mask, sigmas, n_steps=30, eps=1e-5)
recon_zf = kspace_to_image(test_ks * test_mask)

ssim_zf = evaluate_reconstruction(recon_zf, test_img)["ssim"]
ssim_ncsn = evaluate_reconstruction(recon_ncsn, test_img)["ssim"]

panel_specs = [
    (test_img, {'cmap': 'gray'}, 'Ground Truth'),
    (recon_zf, {'cmap': 'gray'}, f'Zero-filled\nSSIM={ssim_zf:.3f}'),
    (recon_ncsn, {'cmap': 'gray'}, f'NCSN Recon\nSSIM={ssim_ncsn:.3f}'),
    (np.abs(test_img - recon_ncsn), {'cmap': 'hot', 'vmax': 0.3}, 'NCSN Error'),
]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (data, imshow_kw, title) in zip(axes, panel_specs):
    ax.imshow(data, **imshow_kw)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## Section 3.7: Error Analysis

In [None]:
# TODO: run batch_evaluate on the test set and analyze the flagged failure cases
def batch_evaluate(model, test_images, sigmas, acceleration=4, n_eval=50,
                   n_steps=20, eps=1e-5):
    """
    Evaluate NCSN reconstruction against the zero-filled baseline.

    For each of the first ``min(n_eval, len(test_images))`` images, a fresh
    random undersampling mask is drawn. The zero-filled and NCSN
    reconstructions are both scored with ``evaluate_reconstruction``.

    Args:
        model: trained score network passed to ``ncsn_mri_reconstruct``.
        test_images: indexable collection of 2D ground-truth images.
        sigmas: noise levels passed to ``ncsn_mri_reconstruct``.
        acceleration: undersampling factor for the masks.
        n_eval: maximum number of images to evaluate.
        n_steps: Langevin steps per noise level.
        eps: Langevin base step size.

    Returns:
        List of dicts, one per image, with these keys:
        - 'idx'
        - 'zf_ssim'
        - 'ncsn_ssim'
        - 'improvement' (ncsn_ssim - zf_ssim)
        - 'failure' (True when NCSN did not improve SSIM over zero-filling,
          i.e. improvement <= 0)
    """
    n_total = min(n_eval, len(test_images))
    results = []
    for i in range(n_total):
        img = test_images[i]
        ks = image_to_kspace(img)
        mask = create_undersampling_mask(img.shape, acceleration=acceleration)

        recon_zf = kspace_to_image(ks * mask)
        recon_ncsn = ncsn_mri_reconstruct(model, ks, mask, sigmas,
                                          n_steps=n_steps, eps=eps)

        metrics_zf = evaluate_reconstruction(recon_zf, img)
        metrics_ncsn = evaluate_reconstruction(recon_ncsn, img)
        improvement = metrics_ncsn['ssim'] - metrics_zf['ssim']
        results.append({
            'idx': i,
            'zf_ssim': metrics_zf['ssim'],
            'ncsn_ssim': metrics_ncsn['ssim'],
            'improvement': improvement,
            # Failure case: NCSN did not beat the zero-filled baseline
            'failure': bool(improvement <= 0),
        })

        # Report progress against the number actually evaluated
        if (i + 1) % 10 == 0 or i + 1 == n_total:
            print(f"Evaluated {i+1}/{n_total}")

    return results

# results = batch_evaluate(model, test_images, sigmas)

## Section 3.8: Deployment Benchmarking

In [None]:
# Benchmark inference speed
def benchmark_speed(model, sigmas, image_size=64, n_warmup=3, n_runs=10,
                    acceleration=4, n_steps=20, eps=1e-5):
    """
    Measure reconstruction speed.

    A random image_size x image_size image is undersampled with a random
    mask and reconstructed with ``ncsn_mri_reconstruct``. ``n_warmup``
    untimed runs are followed by ``n_runs`` timed runs.

    Args:
        model: trained MRIScoreNet
        sigmas: noise levels
        image_size: image dimension
        n_warmup: warmup runs (excluded)
        n_runs: timed runs (must be >= 1)
        acceleration: undersampling factor of the mask
        n_steps: Langevin steps per noise level
        eps: Langevin base step size

    Returns:
        dict with:
        - 'mean_s': mean reconstruction time per image, in seconds
        - 'std_s': standard deviation of that time, in seconds
        - 'times_s': list of the individual timed runs
        The mean and std are also printed.

    Raises:
        ValueError: if n_runs < 1.
    """
    import time

    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # The reconstruction runs a fixed number of model calls
    # (len(sigmas) * n_steps), so random image content is enough for timing.
    image = np.random.rand(image_size, image_size).astype(np.float32)
    kspace = image_to_kspace(image)
    mask = create_undersampling_mask(image.shape, acceleration=acceleration)

    def _run_once():
        # One full reconstruction; wait for any queued GPU work before returning
        ncsn_mri_reconstruct(model, kspace, mask, sigmas, n_steps=n_steps, eps=eps)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(n_warmup):
        _run_once()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _run_once()
        times.append(time.perf_counter() - start)

    mean_s, std_s = float(np.mean(times)), float(np.std(times))
    print(f"Reconstruction time ({image_size}x{image_size}, {n_runs} runs): "
          f"{mean_s:.3f} ± {std_s:.3f} s/image")
    return {'mean_s': mean_s, 'std_s': std_s, 'times_s': times}

## Section 3.9: Ethics Considerations

In [None]:
# Discussion cell -- no code needed
print("""
Ethics and Safety Considerations for AI-Accelerated MRI Reconstruction:

1. DIAGNOSTIC SAFETY: AI-reconstructed images must not introduce false anatomical
   features (hallucinations) that could lead to misdiagnosis. Data consistency
   enforces agreement with the measured k-space samples, but the unmeasured
   frequencies are filled in by the learned prior. Score-based methods can
   therefore still hallucinate plausible-looking structure or suppress real
   pathology, and hallucination rates must be measured, not assumed to be low.

2. FAILURE MODES: The system must detect when reconstruction quality is insufficient
   and flag the case for re-scanning rather than presenting a poor reconstruction
   to the radiologist.

3. EQUITY: Reconstruction quality must be consistent across patient demographics.
   Training data must be representative of the patient population.

4. TRANSPARENCY: Radiologists must be informed when AI-assisted reconstruction
   was used and have access to the zero-filled baseline for comparison.

5. REGULATORY: FDA 510(k) clearance requires demonstrating substantial
   equivalence to a legally marketed predicate device. This is typically
   supported by performance testing and, where needed, clinical validation data.
""")