# Stage 1: Training the Classical Dream Engine (CVAE)

This notebook trains a Conditional Variational Autoencoder (CVAE) on BraTS brain tumor MRI data.

## What is the CVAE "Dream Engine"?

The CVAE learns the probability distribution P(segmentation | MRI_image). Given an MRI scan, it can:
- Generate multiple plausible tumor segmentations
- Capture the inherent uncertainty in tumor boundaries
- Create a "universe of possibilities" that we'll interrogate with quantum algorithms in Stage 2

## Workflow
1. Load and visualize the BraTS dataset
2. Set up the CVAE model
3. Train the model
4. Visualize training results
5. Generate multiple samples ("dreams")
6. Analyze uncertainty and clinical properties

## 1. Setup and Imports

In [None]:
import sys
sys.path.append('../src/classical_model')

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from dataset import BraTSDataset, get_dataloaders
from cvae import CVAE
from train import train, CVAELoss
from sampler import CVAESampler

# Set style
plt.style.use('default')
%matplotlib inline

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Visualize BraTS Dataset
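The samples in this section carry labels as one-hot volumes of shape `(4, D, H, W)`, which the plotting code converts back to class indices with an argmax. The sketch below shows that round trip on a tiny array using NumPy. The helper names `to_one_hot` and `to_class_map` are illustrative only and are not part of `dataset.py`.

```python
import numpy as np

NUM_CLASSES = 4  # background, necrotic, edema, enhancing (as named in this notebook)

def to_one_hot(class_map, num_classes=NUM_CLASSES):
    """Convert an integer label volume (D, H, W) to one-hot (C, D, H, W)."""
    # Indexing the identity matrix gives shape (D, H, W, C); move C to the front.
    one_hot = np.eye(num_classes, dtype=np.float32)[class_map]
    return np.moveaxis(one_hot, -1, 0)

def to_class_map(one_hot):
    """Convert one-hot (C, D, H, W) back to integer labels (D, H, W)."""
    return np.argmax(one_hot, axis=0)

toy_labels = np.array([[[0, 1], [2, 3]]])  # shape (1, 2, 2)
encoded = to_one_hot(toy_labels)           # shape (4, 1, 2, 2)
decoded = to_class_map(encoded)

print(encoded.shape)
print(np.array_equal(decoded, toy_labels))
```

Each voxel has exactly one active channel in `encoded`, so the argmax recovers the original labels.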

In [None]:
# Load dataset
dataset = BraTSDataset(
    images_dir='../data/raw/imagesTr',
    labels_dir='../data/raw/labelsTr',
    crop_size=(128, 128, 128),
    normalize=True
)

print(f"Dataset size: {len(dataset)} samples")

In [None]:
# Visualize a sample
def visualize_sample(image, label, slice_idx=64):
    """
    Visualize MRI modalities and segmentation mask
    
    Args:
        image: (4, D, H, W) - 4 modalities
        label: (4, D, H, W) - one-hot encoded
        slice_idx: Which slice to show
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    modality_names = ['FLAIR', 'T1', 'T1ce', 'T2']
    class_names = ['Background', 'Necrotic', 'Edema', 'Enhancing']
    
    # Show 4 MRI modalities
    for i in range(4):
        ax = axes[i // 3, i % 3]  # fill the 2x3 grid row by row: (0,0), (0,1), (0,2), (1,0)
        ax.imshow(image[i, slice_idx].numpy(), cmap='gray')
        ax.set_title(f'{modality_names[i]}')
        ax.axis('off')
    
    # Show segmentation (convert one-hot to class labels)
    seg = torch.argmax(label, dim=0)[slice_idx].numpy()
    
    ax = axes[1, 1]
    im = ax.imshow(seg, cmap='tab10', vmin=0, vmax=3)
    ax.set_title('Segmentation Mask')
    ax.axis('off')
    plt.colorbar(im, ax=ax, ticks=[0, 1, 2, 3], label='Class')
    
    # Overlay on T1ce
    ax = axes[1, 2]
    ax.imshow(image[2, slice_idx].numpy(), cmap='gray')
    masked_seg = np.ma.masked_where(seg == 0, seg)
    ax.imshow(masked_seg, cmap='tab10', alpha=0.5, vmin=0, vmax=3)
    ax.set_title('T1ce + Segmentation')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize first sample
image, label = dataset[0]
visualize_sample(image, label, slice_idx=64)

## 3. Set Up the CVAE Model

In [None]:
# Configuration
CONFIG = {
    'batch_size': 2,
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'latent_dim': 256,
    'base_channels': 16,
    'beta': 0.001,  # KL divergence weight
    'crop_size': (128, 128, 128),
    'train_split': 0.8,
    'num_workers': 4,
}

print("Training Configuration:")
for key, val in CONFIG.items():
    print(f"  {key}: {val}")

In [None]:
# Create dataloaders
train_loader, val_loader = get_dataloaders(
    images_dir='../data/raw/imagesTr',
    labels_dir='../data/raw/labelsTr',
    batch_size=CONFIG['batch_size'],
    train_split=CONFIG['train_split'],
    num_workers=CONFIG['num_workers'],
    crop_size=CONFIG['crop_size']
)

In [None]:
# Create CVAE model
model = CVAE(
    latent_dim=CONFIG['latent_dim'],
    base_channels=CONFIG['base_channels']
).to(device)

total_params = sum(p.numel() for p in model.parameters())
# Size estimate assumes 4 bytes per parameter (float32 weights)
print(f"Model created with {total_params:,} parameters ({total_params * 4 / 1024**2:.2f} MB)")

## 4. Train the Model

**Note**: Training can take several hours on a GPU. For a quick demo, reduce `num_epochs` in `CONFIG` to 5-10.

The model learns:
- **Reconstruction**: Generate segmentations that match the ground truth
- **Latent structure**: Organize the latent space for easy sampling
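
These two objectives are typically combined into a single loss: a reconstruction term plus a KL term weighted by `beta` (0.001 in this notebook's `CONFIG`). The NumPy sketch below assumes a cross-entropy reconstruction term and a diagonal-Gaussian posterior measured against a standard normal prior. The actual `CVAELoss` in `train.py` is not shown here and may differ (for example, it may add a Dice term or use a conditional prior).

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch."""
    kl_per_sample = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)
    return float(np.mean(kl_per_sample))

def cross_entropy(probs, one_hot, eps=1e-8):
    """Mean voxel-wise cross-entropy; probs and one_hot have shape (B, C, N)."""
    return float(np.mean(-np.sum(one_hot * np.log(probs + eps), axis=1)))

def cvae_loss(probs, one_hot, mu, logvar, beta=0.001):
    """Assumed form of the objective: reconstruction + beta * KL."""
    recon = cross_entropy(probs, one_hot)
    kl = kl_to_standard_normal(mu, logvar)
    return {"total": recon + beta * kl, "recon": recon, "kl": kl}

# Toy check: a posterior equal to the prior gives KL = 0, and a uniform
# prediction over 4 classes gives a cross-entropy of ln(4).
mu = np.zeros((2, 8))
logvar = np.zeros((2, 8))
one_hot = np.zeros((2, 4, 3))
one_hot[:, 0, :] = 1.0
probs = np.full((2, 4, 3), 0.25)

losses = cvae_loss(probs, one_hot, mu, logvar)
print(losses)
```

A small `beta` lets the reconstruction term dominate; larger values pull the posterior toward the prior, which usually makes sampling better behaved at the cost of reconstruction accuracy.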

In [None]:
# Train the model
# IMPORTANT: This will take a long time! Consider reducing num_epochs for testing

history = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=CONFIG['num_epochs'],
    learning_rate=CONFIG['learning_rate'],
    device=device,
    beta=CONFIG['beta'],
    save_dir='../models'
)

## 5. Visualize Training History
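
One of the curves plotted in this section is the Dice score. As a reminder of what that metric measures, the sketch below computes a per-class Dice coefficient between two integer label maps with NumPy. The function name `dice_per_class` is illustrative, and the way `train.py` aggregates Dice (per class, per batch, with or without background) is not shown in this notebook and may differ.

```python
import numpy as np

def dice_per_class(pred, target, num_classes=4):
    """Dice = 2 * |A & B| / (|A| + |B|), computed separately for each class label."""
    scores = []
    for c in range(num_classes):
        pred_c = (pred == c)
        target_c = (target == c)
        intersection = np.logical_and(pred_c, target_c).sum()
        denom = pred_c.sum() + target_c.sum()
        # A class absent from both maps is counted as a perfect match here.
        scores.append(1.0 if denom == 0 else 2.0 * intersection / denom)
    return np.array(scores)

target = np.array([[0, 0, 1, 1],
                   [0, 2, 2, 1]])
pred = np.array([[0, 0, 1, 0],
                 [0, 2, 1, 1]])

scores = dice_per_class(pred, target)
print(scores)
```

A score of 1 means perfect overlap for that class and 0 means no overlap at all.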

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Total loss
ax = axes[0, 0]
ax.plot([h['total'] for h in history['train']], label='Train')
ax.plot([h['total'] for h in history['val']], label='Val')
ax.set_xlabel('Epoch')
ax.set_ylabel('Total Loss')
ax.set_title('Total Loss')
ax.legend()
ax.grid(True)

# Dice score
ax = axes[0, 1]
ax.plot([h['dice_score'] for h in history['train']], label='Train')
ax.plot([h['dice_score'] for h in history['val']], label='Val')
ax.set_xlabel('Epoch')
ax.set_ylabel('Dice Score')
ax.set_title('Dice Score')
ax.legend()
ax.grid(True)

# KL Divergence
ax = axes[1, 0]
ax.plot([h['kl'] for h in history['train']], label='Train')
ax.plot([h['kl'] for h in history['val']], label='Val')
ax.set_xlabel('Epoch')
ax.set_ylabel('KL Divergence')
ax.set_title('KL Divergence')
ax.legend()
ax.grid(True)

# Reconstruction loss
ax = axes[1, 1]
ax.plot([h['recon'] for h in history['train']], label='Train')
ax.plot([h['recon'] for h in history['val']], label='Val')
ax.set_xlabel('Epoch')
ax.set_ylabel('Reconstruction Loss')
ax.set_title('Reconstruction Loss')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.show()

## 6. Generate Samples from the Dream Engine

Now the exciting part! Let's use our trained CVAE to generate multiple plausible segmentations for a single MRI scan.
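
Conceptually, each "dream" comes from drawing a different latent vector and decoding it together with the same MRI. The NumPy sketch below illustrates that loop with a stand-in decoder. `toy_decoder`, its random weights, and the standard normal prior are assumptions made for illustration and do not reflect the actual `CVAE` or `CVAESampler` implementation.

```python
import numpy as np

def softmax(logits, axis=0):
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def toy_decoder(image_features, z, weights):
    """Hypothetical decoder: class logits depend on the image and on z.

    image_features: (C, N) per-voxel class evidence from the MRI
    z:              (latent_dim,) latent code
    weights:        (C, N, latent_dim) maps z to per-voxel logit shifts
    """
    logits = image_features + weights @ z
    return softmax(logits, axis=0)  # (C, N) class probabilities per voxel

rng = np.random.default_rng(42)
num_classes, num_voxels, latent_dim, num_draws = 4, 6, 8, 5

image_features = rng.normal(size=(num_classes, num_voxels))
weights = rng.normal(scale=0.5, size=(num_classes, num_voxels, latent_dim))

toy_samples = []
for _ in range(num_draws):
    z = rng.standard_normal(latent_dim)  # z ~ N(0, I): one latent code per "dream"
    toy_samples.append(toy_decoder(image_features, z, weights))
toy_samples = np.stack(toy_samples)             # (num_draws, C, N)

toy_predictions = toy_samples.argmax(axis=1)    # (num_draws, N) class label per voxel
print(toy_predictions)
```

The image input stays fixed across draws; only `z` changes, so differences between rows of `toy_predictions` reflect latent variability alone.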

In [None]:
# Create sampler
sampler = CVAESampler(model, device)

# Pick one scan to sample from (index 10 of the full dataset, which may overlap with the training split)
test_image, test_label = dataset[10]
test_image = test_image.unsqueeze(0).to(device)

# Generate multiple samples
print("Generating samples from the Dream Engine...")
num_samples = 20
samples = sampler.generate_samples(test_image, num_samples=num_samples)
print(f"Generated {num_samples} samples!")

In [None]:
# Visualize multiple samples
predictions = sampler.get_class_predictions(samples)

fig, axes = plt.subplots(4, 5, figsize=(20, 16))
slice_idx = 64

for i in range(min(20, num_samples)):
    ax = axes[i // 5, i % 5]
    ax.imshow(predictions[i, slice_idx], cmap='tab10', vmin=0, vmax=3)
    ax.set_title(f'Sample {i+1}')
    ax.axis('off')

plt.suptitle(f'{min(20, num_samples)} Different "Dreams" of the Tumor Segmentation', fontsize=16)
plt.tight_layout()
plt.show()

## 7. Analyze Uncertainty

The CVAE captures uncertainty about tumor boundaries. Let's visualize this uncertainty.
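
One common way to build such a map is to compute, for each voxel, the fraction of samples voting for each class and take the Shannon entropy of that distribution. The sketch below does this with NumPy on integer class predictions. It is an assumption about how an entropy map can be computed, not a description of what `analyze_samples` does internally.

```python
import numpy as np

def entropy_map(predictions, num_classes=4, eps=1e-12):
    """Voxel-wise entropy (in nats) of class frequencies across samples.

    predictions: integer array of shape (num_samples, ...) holding class labels.
    """
    # Fraction of samples voting for each class at each voxel: shape (C, ...)
    freqs = np.stack([(predictions == c).mean(axis=0) for c in range(num_classes)])
    return -np.sum(freqs * np.log(freqs + eps), axis=0)

# Four samples, three voxels:
#   voxel 0: all samples agree            -> entropy 0
#   voxel 1: split 50/50 over two classes -> entropy ln(2)
#   voxel 2: one vote per class           -> entropy ln(4), the maximum for 4 classes
toy_votes = np.array([
    [0, 1, 0],
    [0, 1, 1],
    [0, 2, 2],
    [0, 2, 3],
])

toy_uncertainty = entropy_map(toy_votes)
print(toy_uncertainty)
```

Entropy is zero where every sample agrees and grows as the votes spread across classes, which is why boundaries tend to light up.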

In [None]:
# Compute uncertainty map
analysis = sampler.analyze_samples(samples)

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Original T1ce image
ax = axes[0]
ax.imshow(test_image[0, 2, slice_idx].cpu().numpy(), cmap='gray')
ax.set_title('T1ce MRI')
ax.axis('off')

# Ground truth
ax = axes[1]
gt_seg = torch.argmax(test_label, dim=0)[slice_idx].numpy()
ax.imshow(gt_seg, cmap='tab10', vmin=0, vmax=3)
ax.set_title('Ground Truth')
ax.axis('off')

# Uncertainty map
ax = axes[2]
im = ax.imshow(analysis['uncertainty_map'][slice_idx], cmap='hot')
ax.set_title('Uncertainty Map (Entropy)')
ax.axis('off')
plt.colorbar(im, ax=ax, label='Entropy')

plt.tight_layout()
plt.show()

print("\nBright regions (yellow/white in the 'hot' colormap) mark voxels where the samples disagree most,")
print("typically along the tumor boundaries.")

## 8. Clinical Analysis: Multifocality

A key clinical question: **Is the tumor multifocal (multiple disconnected parts)?**

Classical approach: Generate many samples and count.
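
This counting baseline can be sketched as follows: flag a sample as multifocal if its tumor mask has more than one connected component, then report the fraction of flagged samples together with its Monte Carlo standard error. The code uses only the standard library and a small 2D flood fill with 4-connectivity. The helper names and the connectivity choice are assumptions, and `analyze_samples` may use a different definition (for example 3D connectivity or a minimum component size).

```python
import math
from collections import deque

def count_components(mask):
    """Count 4-connected components of truthy cells in a 2D list of lists."""
    rows, cols = len(mask), len(mask[0])
    seen = [[False] * cols for _ in range(rows)]
    components = 0
    for r in range(rows):
        for c in range(cols):
            if mask[r][c] and not seen[r][c]:
                components += 1
                seen[r][c] = True
                queue = deque([(r, c)])
                while queue:  # breadth-first flood fill of this component
                    y, x = queue.popleft()
                    for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                        if (0 <= ny < rows and 0 <= nx < cols
                                and mask[ny][nx] and not seen[ny][nx]):
                            seen[ny][nx] = True
                            queue.append((ny, nx))
    return components

def multifocal_estimate(masks):
    """Return (p_hat, standard_error) for P(tumor has more than one component)."""
    flags = [count_components(m) > 1 for m in masks]
    n = len(flags)
    p_hat = sum(flags) / n
    std_err = math.sqrt(p_hat * (1.0 - p_hat) / n)
    return p_hat, std_err

unifocal = [[0, 1, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 0]]
multifocal = [[1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 0, 0, 0]]
toy_masks = [unifocal, multifocal, unifocal, multifocal]

p_hat, std_err = multifocal_estimate(toy_masks)
print(f"p = {p_hat:.2f} +/- {std_err:.2f}")  # prints "p = 0.50 +/- 0.25"
```

The standard error shrinks as 1/sqrt(N), so halving it requires four times as many samples.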

**Quantum approach (Stage 2)**: Use Quantum Amplitude Estimation, which can in principle reach a target precision with quadratically fewer queries than classical sampling.

In [None]:
# Compute multifocal probability
multifocal_prob = analysis['multifocal_probability']

print(f"Multifocal Probability: {multifocal_prob:.3f}")
print(f"\nInterpretation: {multifocal_prob*100:.1f}% of the {num_samples} generated samples show")
print("a tumor with multiple disconnected components.")
print("\nThis is a clinically actionable uncertainty metric!")

In [None]:
# Volume distribution
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

volume_keys = ['necrotic', 'edema', 'enhancing', 'total']
titles = ['Necrotic Core', 'Edema', 'Enhancing Tumor', 'Total Tumor']

for i, (key, title) in enumerate(zip(volume_keys, titles)):
    ax = axes[i // 2, i % 2]
    volumes = analysis['volumes'][key]
    ax.hist(volumes, bins=20, alpha=0.7, edgecolor='black')
    ax.axvline(volumes.mean(), color='red', linestyle='--', label=f'Mean: {volumes.mean():.0f}')
    ax.set_xlabel('Volume (voxels)')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{title} Volume Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Summary

### What We Built:
1. **CVAE "Dream Engine"**: Generates multiple plausible tumor segmentations
2. **Uncertainty Quantification**: Maps out where the model is uncertain
3. **Clinical Metrics**: Computes actionable probabilities (e.g., multifocality)

### Next Steps (Stage 2):
Now that we have our "universe of possibilities", we'll use **Quantum Amplitude Estimation (QAE)** to:
- Load all samples into a quantum superposition
- Estimate clinical probabilities (such as multifocality) with quadratically fewer queries than classical sampling
- Demonstrate the hybrid quantum-classical advantage!

### Key Files Created:
- `dataset.py`: BraTS data loading and preprocessing
- `cvae.py`: CVAE architecture
- `train.py`: Training loop and loss functions
- `sampler.py`: Sampling and analysis utilities

In [None]:
print("[SUCCESS] Stage 1 Complete!")
print("\nThe Classical Dream Engine is ready.")
print("Proceed to Stage 2: Quantum Interrogation with QAE")