# NLBS-CVAE: Conditional VAE for Mammography Generation

This notebook provides a complete training pipeline for the NLBS-CVAE model on Google Colab.

## Features:
- 🚀 Optimized for Colab GPUs (falls back to CPU when no GPU is available)
- 💾 Google Drive integration for data and checkpoints
- 📊 Real-time monitoring with TensorBoard
- 🔄 Automatic session management
- 📈 Memory optimization for long training sessions


## 1. Environment Setup

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected. Training will be slow on CPU.")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create necessary directories
import os
os.makedirs('/content/results', exist_ok=True)
os.makedirs('/content/results/checkpoints', exist_ok=True)
os.makedirs('/content/results/logs', exist_ok=True)
os.makedirs('/content/results/galleries', exist_ok=True)
os.makedirs('/content/data', exist_ok=True)

In [None]:
# Clone the repository
!git clone https://github.com/FructueuxCODJIA/nlbs-cvae.git /content/nlbs-cvae
%cd /content/nlbs-cvae

In [None]:
# Install requirements
!pip install -r colab/requirements_colab.txt

# Install additional packages that might be needed
!pip install -q gdown  # For downloading files from Google Drive

## 2. Data Setup

Choose one of the following options for data setup:

### Option 1: Use Demo/Synthetic Data (Recommended for testing)

In [None]:
# Create synthetic demo data for testing
import pandas as pd
import numpy as np
from pathlib import Path
import pydicom
from pydicom.dataset import Dataset, FileDataset
from pydicom.uid import ExplicitVRLittleEndian
import tempfile
import os

def create_synthetic_mammogram(width=512, height=512):
    """Create a synthetic mammogram-like image"""
    # Create base tissue pattern
    x, y = np.meshgrid(np.linspace(0, 1, width), np.linspace(0, 1, height))
    
    # Simulate breast tissue with varying density
    tissue = np.exp(-((x-0.5)**2 + (y-0.5)**2) * 3)
    
    # Add some texture
    noise = np.random.normal(0, 0.1, (height, width))
    texture = np.sin(x * 20) * np.sin(y * 20) * 0.1
    
    # Combine and normalize
    image = tissue + texture + noise
    image = np.clip(image, 0, 1)
    
    # Convert to uint16 (typical for mammograms)
    image = (image * 65535).astype(np.uint16)
    
    return image

def create_synthetic_dicom(image_array, filename):
    """Create a synthetic DICOM file"""
    # Create file meta information
    file_meta = Dataset()
    file_meta.MediaStorageSOPClassUID = '1.2.840.10008.5.1.4.1.1.1.2'  # Digital Mammography X-Ray Image
    file_meta.MediaStorageSOPInstanceUID = '1.2.3.4.5.6.7.8.9.10'
    file_meta.ImplementationClassUID = '1.2.3.4.5.6.7.8.9.10'
    file_meta.TransferSyntaxUID = ExplicitVRLittleEndian
    
    # Create the main dataset
    ds = FileDataset(filename, {}, file_meta=file_meta, preamble=b"\0" * 128)
    
    # Add required DICOM tags
    ds.PatientName = "Demo^Patient"
    ds.PatientID = "DEMO001"
    ds.Modality = "MG"
    ds.StudyInstanceUID = "1.2.3.4.5.6.7.8.9.10.11"
    ds.SeriesInstanceUID = "1.2.3.4.5.6.7.8.9.10.11.12"
    ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID  # Must match the file meta header
    ds.SOPClassUID = '1.2.840.10008.5.1.4.1.1.1.2'
    
    # Image-specific tags
    ds.SamplesPerPixel = 1
    ds.PhotometricInterpretation = "MONOCHROME2"
    ds.Rows, ds.Columns = image_array.shape
    ds.BitsAllocated = 16
    ds.BitsStored = 16
    ds.HighBit = 15
    ds.PixelRepresentation = 0
    ds.PixelData = image_array.tobytes()
    
    return ds

# Create demo dataset
print("Creating synthetic demo dataset...")

# Create directory structure
demo_dir = Path('/content/data/demo')
demo_dir.mkdir(parents=True, exist_ok=True)

# Generate synthetic data
metadata_rows = []
views = ['CC', 'MLO']
lateralities = ['L', 'R']
ages = [45, 52, 58, 63, 67]

for patient_id in range(1, 21):  # 20 patients
    age = np.random.choice(ages)
    cancer = np.random.choice([0, 1], p=[0.7, 0.3])  # 30% cancer cases
    
    for laterality in lateralities:
        for view in views:
            # Create synthetic image
            image = create_synthetic_mammogram()
            
            # Create filename
            filename = f"patient_{patient_id:03d}_{laterality}_{view}.dcm"
            filepath = demo_dir / filename
            
            # Create and save DICOM
            ds = create_synthetic_dicom(image, str(filepath))
            ds.save_as(filepath)
            
            # Add to metadata
            metadata_rows.append({
                'File Path': f'demo/{filename}',
                'Image Laterality': laterality,
                'View Position': view,
                'Age': age,
                'Cancer': cancer,
                # A false positive only makes sense for a non-cancer case
                'False Positive': 0 if cancer else np.random.choice([0, 1], p=[0.9, 0.1])
            })

# Create metadata CSV
metadata_df = pd.DataFrame(metadata_rows)
metadata_df.to_csv('/content/data/metadata.csv', index=False)

print(f"Created {len(metadata_rows)} synthetic mammogram files")
print("Metadata saved to /content/data/metadata.csv")
print(f"Images saved to {demo_dir}")

# Update config paths
data_config = {
    'csv_path': '/content/data/metadata.csv',
    'image_dir': '/content/data'
}

print("\n✅ Demo data setup complete!")
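
Before training, it can help to sanity-check the metadata file. The sketch below uses only the standard library and checks for the six columns written by the demo cell. `validate_metadata` is a hypothetical helper, not part of the nlbs-cvae repository, and the allowed values (`L`/`R`, `CC`/`MLO`, `0`/`1`) are assumptions based on what the demo data contains; real datasets may use other codes.

```python
import csv
import io

# Column names taken from the metadata written by the demo cell.
REQUIRED_COLUMNS = [
    "File Path", "Image Laterality", "View Position",
    "Age", "Cancer", "False Positive",
]


def validate_metadata(csv_file):
    """Return a list of human-readable problems found in a metadata CSV.

    `csv_file` is any open text file object. Hypothetical helper for
    illustration; the allowed values below are assumptions.
    """
    reader = csv.DictReader(csv_file)
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        return [f"missing columns: {', '.join(missing)}"]

    problems = []
    for line_number, row in enumerate(reader, start=2):  # the header is line 1
        if row["Image Laterality"] not in {"L", "R"}:
            problems.append(f"line {line_number}: unexpected laterality {row['Image Laterality']!r}")
        if row["View Position"] not in {"CC", "MLO"}:
            problems.append(f"line {line_number}: unexpected view {row['View Position']!r}")
        for flag in ("Cancer", "False Positive"):
            if row[flag] not in {"0", "1"}:
                problems.append(f"line {line_number}: {flag} should be 0 or 1, got {row[flag]!r}")
    return problems


# Small in-memory example with two deliberate errors on the second data row.
sample = io.StringIO(
    "File Path,Image Laterality,View Position,Age,Cancer,False Positive\n"
    "demo/patient_001_L_CC.dcm,L,CC,52,0,0\n"
    "demo/patient_001_R_XX.dcm,R,XX,52,2,0\n"
)
problems = validate_metadata(sample)
for problem in problems:
    print(problem)
```

With real data, pass `open('/content/data/metadata.csv', newline='')` instead of the in-memory sample.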

### Option 2: Upload Your Own Data

In [None]:
# Uncomment and run this cell if you want to upload your own data

# from google.colab import files
# import zipfile
# import shutil

# print("Please upload your data as a ZIP file containing:")
# print("1. metadata.csv - with columns: File Path, Image Laterality, View Position, Age, Cancer, False Positive")
# print("2. images/ folder - containing your DICOM files")

# uploaded = files.upload()

# # Extract uploaded files
# for filename in uploaded.keys():
#     if filename.endswith('.zip'):
#         with zipfile.ZipFile(filename, 'r') as zip_ref:
#             zip_ref.extractall('/content/data')
#         print(f"Extracted {filename} to /content/data")

# # Update config paths
# data_config = {
#     'csv_path': '/content/data/metadata.csv',
#     'image_dir': '/content/data'
# }

# print("\n✅ Data upload complete!")

### Option 3: Load from Google Drive

In [None]:
# Uncomment and modify this cell if your data is in Google Drive

# import shutil

# # Specify your Google Drive data path
# drive_data_path = '/content/drive/MyDrive/NLBS_Data'  # Modify this path

# if os.path.exists(drive_data_path):
#     # Copy metadata
#     if os.path.exists(f'{drive_data_path}/metadata.csv'):
#         shutil.copy(f'{drive_data_path}/metadata.csv', '/content/data/metadata.csv')
#         print("Metadata copied from Drive")
#     
#     # Create symlink to images (faster than copying)
#     if os.path.exists(f'{drive_data_path}/images'):
#         if os.path.islink('/content/data/images'):
#             os.remove('/content/data/images')  # Drop a stale symlink from an earlier run
#         os.symlink(f'{drive_data_path}/images', '/content/data/images')
#         print("Images linked from Drive")
#     
#     # Update config paths
#     data_config = {
#         'csv_path': '/content/data/metadata.csv',
#         'image_dir': '/content/data'
#     }
#     
#     print("\n✅ Data loaded from Google Drive!")
# else:
#     print(f"❌ Data path not found: {drive_data_path}")
#     print("Please check your Google Drive path and try again.")

### Option 4: NLBSP Dataset (Real Mammography Data)

In [None]:
# Setup for real NLBSP mammography dataset
# Uncomment and run this cell if you have the NLBSP dataset

# import sys
# sys.path.append('/content/nlbs-cvae')
# from colab.utils.nlbsp_data_prep import setup_nlbsp_for_training

# # Choose data source: "upload" or "drive"
# data_source = "upload"  # Change to "drive" if data is in Google Drive

# print("🔄 Setting up NLBSP dataset...")
# csv_path, image_dir = setup_nlbsp_for_training(data_source=data_source)

# if csv_path and image_dir:
#     # Use the real data configuration
#     import yaml
#     with open('/content/nlbs-cvae/colab/configs/colab_real_data_config.yaml', 'r') as f:
#         config = yaml.safe_load(f)
#     
#     # Update paths
#     config['data']['csv_path'] = csv_path
#     config['data']['image_dir'] = image_dir
#     
#     data_config = {
#         'csv_path': csv_path,
#         'image_dir': image_dir
#     }
#     
#     print(f"✅ NLBSP dataset ready!")
#     print(f"CSV: {csv_path}")
#     print(f"Images: {image_dir}")
#     print("\n📊 This configuration is optimized for real mammography data")
# else:
#     print("❌ Failed to setup NLBSP data")
#     print("Please check your data files and try again")

## 3. Model Training

In [None]:
# Import necessary modules
import sys
sys.path.append('/content/nlbs-cvae')

import yaml
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import gc

# Import project modules
from models import ConditionalVAE
from models.losses import VAELoss
from data import MammographyDataset
from utils.training_utils import (
    setup_logging, 
    save_checkpoint, 
    load_checkpoint,
    create_optimizer,
    create_scheduler,
    set_seed
)

print("✅ Modules imported successfully!")

In [None]:
# Load and modify configuration
with open('/content/nlbs-cvae/colab/configs/colab_training_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Update data paths if using custom data
if 'data_config' in locals():
    config['data'].update(data_config)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config['hardware']['device'] = str(device)

print(f"Using device: {device}")
print(f"Data path: {config['data']['csv_path']}")
print(f"Image dir: {config['data']['image_dir']}")

# Set random seed
set_seed(config['project']['seed'])

print("✅ Configuration loaded!")

In [None]:
# Create dataset and data loaders
print("Creating dataset...")

# Load metadata
metadata_df = pd.read_csv(config['data']['csv_path'])
print(f"Loaded {len(metadata_df)} samples")

# Create dataset
dataset = MammographyDataset(
    csv_path=config['data']['csv_path'],
    image_dir=config['data']['image_dir'],
    resolution=config['data']['resolution'],
    augment=True,  # Applies to every split below, since random_split only wraps this one dataset
    patch_stride=config['data']['patch_stride'],
    min_foreground_frac=config['data']['min_foreground_frac']
)

print(f"Dataset created with {len(dataset)} patches")

# Split dataset
train_size = int(config['data']['train_split'] * len(dataset))
val_size = int(config['data']['val_split'] * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size]
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['data']['batch_size'],
    shuffle=True,
    num_workers=config['data']['num_workers'],
    pin_memory=(device.type == 'cuda')
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['data']['batch_size'],
    shuffle=False,
    num_workers=config['data']['num_workers'],
    pin_memory=(device.type == 'cuda')
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print("✅ Data loaders created!")
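
The split arithmetic in the cell above truncates the train and validation sizes with `int()` and gives whatever remains to the test set, so the three sizes always add up to the dataset length. A minimal sketch of that logic in isolation follows; `split_sizes` is a hypothetical helper, and the fractions are example values rather than the ones in the repository's config.

```python
def split_sizes(n_items, train_frac, val_frac):
    """Compute (train, val, test) sizes the same way the notebook cell does.

    Train and validation sizes are truncated; the test split takes the
    remainder so the sizes always sum to n_items.
    """
    train = int(train_frac * n_items)
    val = int(val_frac * n_items)
    test = n_items - train - val
    if test < 0:
        raise ValueError("train_frac + val_frac must not exceed 1.0")
    return train, val, test


# Example values only: 80 items with an 80/10/10 split.
sizes = split_sizes(80, 0.8, 0.1)
print(sizes)
```

The explicit check matters because `random_split` would otherwise receive a negative length if the two fractions in the config add up to more than 1.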

In [None]:
# Create model
print("Creating model...")

model = ConditionalVAE(
    in_channels=config['data']['channels'],
    image_size=config['data']['resolution'],
    latent_dim=config['model']['latent_dim'],
    condition_embed_dim=config['model']['condition_embed_dim'],
    encoder_channels=config['model']['encoder']['channels'],
    decoder_channels=config['model']['decoder']['channels']
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model created with {total_params:,} total parameters")
print(f"Trainable parameters: {trainable_params:,}")

# Create loss function
criterion = VAELoss(
    reconstruction_weight=config['training']['loss']['reconstruction_weight'],
    kl_weight=config['training']['loss']['kl_weight'],
    edge_weight=config['training']['loss']['edge_weight']
)

# Create optimizer
optimizer = create_optimizer(model, config['training'])

# Create scheduler
scheduler = create_scheduler(optimizer, config['training'], len(train_loader))

print("✅ Model, loss, optimizer, and scheduler created!")

In [None]:
# Setup logging
writer = SummaryWriter(config['logging']['log_dir'])

# Setup logging to file
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('/content/results/training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

print("✅ Logging setup complete!")

# Load TensorBoard extension
%load_ext tensorboard
print("\n📊 TensorBoard will be available at the end of this cell")
%tensorboard --logdir /content/results/logs

In [None]:
# Training loop
print("Starting training...")

# Training variables
best_val_loss = float('inf')
patience_counter = 0
global_step = 0

# Mixed precision training (gradient scaling is only useful on CUDA devices)
use_amp = config['training']['mixed_precision'] == 'fp16' and device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler() if use_amp else None

for epoch in range(config['training']['num_epochs']):
    print(f"\n=== Epoch {epoch+1}/{config['training']['num_epochs']} ===")
    
    # Training phase
    model.train()
    train_loss = 0.0
    train_recon_loss = 0.0
    train_kl_loss = 0.0
    
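    # KL annealing: ramp linearly from 0 up to the configured KL weight
    anneal_epochs = max(1, config['training']['loss']['kl_anneal_epochs'])
    kl_weight = config['training']['loss']['kl_weight'] * min(1.0, epoch / anneal_epochs)
    criterion.kl_weight = kl_weight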
    
    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    
    for batch_idx, batch in enumerate(progress_bar):
        images = batch['image'].to(device)
        conditions = {k: v.to(device) for k, v in batch['conditions'].items()}
        
        optimizer.zero_grad()
        
        # Forward pass with mixed precision
        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images, conditions)
                loss_dict = criterion(outputs, images)
                loss = loss_dict['total_loss']
            
            # Backward pass
            scaler.scale(loss).backward()
            
            # Gradient clipping
            if config['training']['gradient_clip'] > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['training']['gradient_clip'])
            
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images, conditions)
            loss_dict = criterion(outputs, images)
            loss = loss_dict['total_loss']
            
            loss.backward()
            
            if config['training']['gradient_clip'] > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['training']['gradient_clip'])
            
            optimizer.step()
        
        if scheduler is not None:
            scheduler.step()
        
        # Update metrics
        train_loss += loss.item()
        train_recon_loss += loss_dict['reconstruction_loss'].item()
        train_kl_loss += loss_dict['kl_loss'].item()
        
        # Log to TensorBoard
        if global_step % config['logging']['log_every_n_steps'] == 0:
            writer.add_scalar('Train/Loss', loss.item(), global_step)
            writer.add_scalar('Train/Reconstruction_Loss', loss_dict['reconstruction_loss'].item(), global_step)
            writer.add_scalar('Train/KL_Loss', loss_dict['kl_loss'].item(), global_step)
            writer.add_scalar('Train/KL_Weight', kl_weight, global_step)
            writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
        
        global_step += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Recon': f'{loss_dict["reconstruction_loss"].item():.4f}',
            'KL': f'{loss_dict["kl_loss"].item():.4f}',
            'KL_w': f'{kl_weight:.3f}'
        })
        
        # Memory cleanup
        if batch_idx % 50 == 0:
            torch.cuda.empty_cache()
    
    # Calculate average training losses
    avg_train_loss = train_loss / len(train_loader)
    avg_train_recon = train_recon_loss / len(train_loader)
    avg_train_kl = train_kl_loss / len(train_loader)
    
    print(f"Train Loss: {avg_train_loss:.4f} (Recon: {avg_train_recon:.4f}, KL: {avg_train_kl:.4f})")
    
    # Validation phase
    if (epoch + 1) % config['evaluation']['val_every_n_epochs'] == 0:
        model.eval()
        val_loss = 0.0
        val_recon_loss = 0.0
        val_kl_loss = 0.0
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                images = batch['image'].to(device)
                conditions = {k: v.to(device) for k, v in batch['conditions'].items()}
                
                outputs = model(images, conditions)
                loss_dict = criterion(outputs, images)
                
                val_loss += loss_dict['total_loss'].item()
                val_recon_loss += loss_dict['reconstruction_loss'].item()
                val_kl_loss += loss_dict['kl_loss'].item()
        
        avg_val_loss = val_loss / len(val_loader)
        avg_val_recon = val_recon_loss / len(val_loader)
        avg_val_kl = val_kl_loss / len(val_loader)
        
        print(f"Val Loss: {avg_val_loss:.4f} (Recon: {avg_val_recon:.4f}, KL: {avg_val_kl:.4f})")
        
        # Log validation metrics
        writer.add_scalar('Val/Loss', avg_val_loss, epoch)
        writer.add_scalar('Val/Reconstruction_Loss', avg_val_recon, epoch)
        writer.add_scalar('Val/KL_Loss', avg_val_kl, epoch)
        
        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            
            # Save best model
            save_checkpoint(
                model, optimizer, scheduler, epoch, avg_val_loss,
                '/content/results/checkpoints/best_model.pth'
            )
            print("💾 Best model saved!")
        else:
            patience_counter += 1
        
        if patience_counter >= config['training']['early_stopping_patience']:
            print(f"Early stopping triggered after {patience_counter} validation rounds without improvement")
            break
    
    # Save checkpoint
    if (epoch + 1) % config['training']['save_every_n_epochs'] == 0:
        save_checkpoint(
            model, optimizer, scheduler, epoch, avg_train_loss,
            f'/content/results/checkpoints/checkpoint_epoch_{epoch+1}.pth'
        )
        print(f"💾 Checkpoint saved for epoch {epoch+1}")
    
    # Generate sample images
    if (epoch + 1) % config['logging']['save_images_every_n_epochs'] == 0:
        model.eval()
        with torch.no_grad():
            # Get a batch for reconstruction
            sample_batch = next(iter(val_loader))
            sample_images = sample_batch['image'][:4].to(device)
            sample_conditions = {k: v[:4].to(device) for k, v in sample_batch['conditions'].items()}
            
            # Reconstruct
            outputs = model(sample_images, sample_conditions)
            reconstructions = outputs['reconstruction']
            
            # Generate new samples
            generated = model.sample(sample_conditions, batch_size=4)
            
            # Create comparison grid
            fig, axes = plt.subplots(3, 4, figsize=(12, 9))
            
            for i in range(4):
                # Original
                axes[0, i].imshow(sample_images[i, 0].cpu().numpy(), cmap='gray')
                axes[0, i].set_title(f'Original {i+1}')
                axes[0, i].axis('off')
                
                # Reconstruction
                axes[1, i].imshow(reconstructions[i, 0].cpu().numpy(), cmap='gray')
                axes[1, i].set_title(f'Reconstruction {i+1}')
                axes[1, i].axis('off')
                
                # Generated
                axes[2, i].imshow(generated[i, 0].cpu().numpy(), cmap='gray')
                axes[2, i].set_title(f'Generated {i+1}')
                axes[2, i].axis('off')
            
            plt.tight_layout()
            plt.savefig(f'/content/results/galleries/epoch_{epoch+1}_samples.png', dpi=150, bbox_inches='tight')
            
            # Log to TensorBoard before plt.show(), which may close the figure in notebooks
            writer.add_figure('Samples/Comparison', fig, epoch, close=False)
            plt.show()
            plt.close(fig)
    
    # Memory cleanup
    if (epoch + 1) % config['colab']['clear_cache_every_n_epochs'] == 0:
        torch.cuda.empty_cache()
        gc.collect()
        print("🧹 Memory cache cleared")

print("\n🎉 Training completed!")
writer.close()
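
The KL annealing step in the training loop is a linear warm-up: the KL term starts at zero weight and ramps up over `kl_anneal_epochs` epochs, which gives the decoder time to learn reconstructions before the latent prior is enforced. A minimal standalone sketch of that schedule follows. `kl_weight_at` and its `target_weight` parameter are hypothetical names introduced here for illustration, not functions from the repository.

```python
def kl_weight_at(epoch, anneal_epochs, target_weight=1.0):
    """Linear KL warm-up: 0 at epoch 0, target_weight from anneal_epochs onward.

    Hypothetical helper; `target_weight` is the final weight of the KL term.
    """
    if anneal_epochs <= 0:
        # No warm-up requested: use the full weight immediately.
        return target_weight
    return target_weight * min(1.0, epoch / anneal_epochs)


# Weights for the first six epochs with a four-epoch warm-up.
schedule = [kl_weight_at(epoch, anneal_epochs=4) for epoch in range(6)]
print(schedule)  # [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

Because the total loss changes with the KL weight, validation losses measured during the warm-up are not directly comparable with those measured afterwards.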

## 4. Model Evaluation and Generation

In [None]:
# Load best model for evaluation
checkpoint_path = '/content/results/checkpoints/best_model.pth'
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Loaded best model from epoch {checkpoint['epoch']}")
    print(f"Best validation loss: {checkpoint['loss']:.4f}")
else:
    print("⚠️ No checkpoint found, using current model state")

model.eval()
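
The checkpoint loaded above is whichever one the training loop's early-stopping check last marked as best. That bookkeeping can be sketched in isolation as a small class; `EarlyStopping` is a hypothetical helper written for this illustration and is not part of the repository, and the loss values are made up.

```python
class EarlyStopping:
    """Track the best validation loss and count rounds without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.rounds_without_improvement = 0

    def update(self, val_loss):
        """Record a validation loss; return True when it is a new best."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.rounds_without_improvement = 0
            return True
        self.rounds_without_improvement += 1
        return False

    @property
    def should_stop(self):
        return self.rounds_without_improvement >= self.patience


# Made-up validation losses: training stops before the last value is seen.
stopper = EarlyStopping(patience=2)
history = []
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.6]):
    improved = stopper.update(val_loss)  # a real loop would save a checkpoint when True
    history.append((epoch, improved))
    if stopper.should_stop:
        break

print(history, stopper.best_loss)
```

Note that the counter advances once per validation round, so with validation every `n` epochs the effective patience is `n` times longer in epochs.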

In [None]:
# Generate diverse samples
print("Generating diverse mammogram samples...")

# Define different conditions to test
test_conditions = [
    {'view': 0, 'laterality': 0, 'age_bin': 1, 'cancer': 0, 'false_positive': 0},  # CC, Left, Young, No cancer
    {'view': 1, 'laterality': 1, 'age_bin': 2, 'cancer': 1, 'false_positive': 0},  # MLO, Right, Middle, Cancer
    {'view': 0, 'laterality': 1, 'age_bin': 3, 'cancer': 0, 'false_positive': 1},  # CC, Right, Old, False positive
    {'view': 1, 'laterality': 0, 'age_bin': 0, 'cancer': 1, 'false_positive': 0},  # MLO, Left, Very young, Cancer
]

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

with torch.no_grad():
    for i, condition_dict in enumerate(test_conditions):
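        # Convert to tensors, repeating each value so the condition batch matches
        # the number of samples (as in the training loop's call to model.sample)
        n_samples = 2
        conditions = {}
        for key, value in condition_dict.items():
            conditions[key] = torch.tensor([value] * n_samples, device=device)
        
        # Generate 2 samples for each condition
        generated = model.sample(conditions, batch_size=n_samples)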
        
        for j in range(2):
            axes[j, i].imshow(generated[j, 0].cpu().numpy(), cmap='gray')
            
            # Create title with condition info
            view_name = 'CC' if condition_dict['view'] == 0 else 'MLO'
            lat_name = 'Left' if condition_dict['laterality'] == 0 else 'Right'
            cancer_status = 'Cancer' if condition_dict['cancer'] == 1 else 'Normal'
            
            title = f'{view_name}, {lat_name}\n{cancer_status}'
            axes[j, i].set_title(title, fontsize=10)
            axes[j, i].axis('off')

plt.suptitle('Generated Mammograms with Different Conditions', fontsize=14)
plt.tight_layout()
plt.savefig('/content/results/galleries/diverse_generations.png', dpi=150, bbox_inches='tight')
plt.show()

print("✅ Diverse samples generated and saved!")
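
The integer codes in `test_conditions` have to match whatever encoding the dataset used during training. The sketch below shows one way a metadata row could be mapped to such codes. The view and laterality mappings follow the plot titles in the cell above (0 is CC or Left, 1 is MLO or Right); the age bin edges and the `encode_conditions` helper are assumptions made for illustration, and the repository's `MammographyDataset` may bin ages differently.

```python
VIEW_TO_INDEX = {"CC": 0, "MLO": 1}
LATERALITY_TO_INDEX = {"L": 0, "R": 1}
AGE_BIN_EDGES = (50, 60, 70)  # hypothetical edges giving bins 0..3


def encode_conditions(row):
    """Map one metadata row (a dict keyed by CSV column name) to integer codes."""
    age = int(row["Age"])
    # Count how many edges the age has reached: <50 -> 0, 50-59 -> 1, and so on.
    age_bin = sum(age >= edge for edge in AGE_BIN_EDGES)
    return {
        "view": VIEW_TO_INDEX[row["View Position"]],
        "laterality": LATERALITY_TO_INDEX[row["Image Laterality"]],
        "age_bin": age_bin,
        "cancer": int(row["Cancer"]),
        "false_positive": int(row["False Positive"]),
    }


encoded = encode_conditions({
    "View Position": "MLO",
    "Image Laterality": "R",
    "Age": 58,
    "Cancer": 1,
    "False Positive": 0,
})
print(encoded)
```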

In [None]:
# Latent space interpolation
print("Performing latent space interpolation...")

with torch.no_grad():
    # Take the first two samples of the first validation batch (val_loader is not shuffled)
    sample_batch = next(iter(val_loader))
    img1, img2 = sample_batch['image'][:2].to(device)
    cond1 = {k: v[:1].to(device) for k, v in sample_batch['conditions'].items()}
    cond2 = {k: v[1:2].to(device) for k, v in sample_batch['conditions'].items()}
    
    # Encode to latent space
    outputs1 = model.encode(img1.unsqueeze(0), cond1)
    outputs2 = model.encode(img2.unsqueeze(0), cond2)
    
    z1 = outputs1['z']
    z2 = outputs2['z']
    
    # Interpolate between latent codes
    n_steps = 8
    alphas = torch.linspace(0, 1, n_steps, device=device)
    
    interpolated_images = []
    
    for alpha in alphas:
        # Interpolate latent codes
        z_interp = (1 - alpha) * z1 + alpha * z2
        
        # Use first condition for consistency
        decoded = model.decode(z_interp, cond1)
        interpolated_images.append(decoded['reconstruction'])
    
    # Plot interpolation
    fig, axes = plt.subplots(1, n_steps, figsize=(16, 2))
    
    for i, img in enumerate(interpolated_images):
        axes[i].imshow(img[0, 0].cpu().numpy(), cmap='gray')
        axes[i].set_title(f'α={alphas[i]:.2f}')
        axes[i].axis('off')
    
    plt.suptitle('Latent Space Interpolation', fontsize=14)
    plt.tight_layout()
    plt.savefig('/content/results/galleries/latent_interpolation.png', dpi=150, bbox_inches='tight')
    plt.show()

print("✅ Latent interpolation completed!")
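
The interpolation above is a straight line between two latent codes, `z = (1 - α) · z1 + α · z2`, evaluated at evenly spaced values of `α`. A dependency-free sketch of the same arithmetic on plain Python lists follows; `lerp` and `interpolation_path` are hypothetical helper names, and the two-dimensional vectors stand in for real latent codes.

```python
def lerp(z_start, z_end, alpha):
    """Linear interpolation between two equal-length vectors."""
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(z_start, z_end)]


def interpolation_path(z_start, z_end, n_steps):
    """Return n_steps vectors from z_start to z_end inclusive."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2 to include both endpoints")
    return [lerp(z_start, z_end, i / (n_steps - 1)) for i in range(n_steps)]


# Toy two-dimensional latent codes.
path = interpolation_path([0.0, 2.0], [4.0, -2.0], n_steps=5)
for point in path:
    print(point)
```

The first and last points reproduce the two endpoints, which is a quick check that the `α` grid covers the whole range.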

## 5. Save Results to Google Drive

In [None]:
# Backup results to Google Drive
import shutil
from datetime import datetime

# Create timestamped backup folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_dir = f'/content/drive/MyDrive/NLBS_CVAE_Results_{timestamp}'
os.makedirs(backup_dir, exist_ok=True)

# Copy results
print("Backing up results to Google Drive...")

# Copy checkpoints
if os.path.exists('/content/results/checkpoints'):
    shutil.copytree('/content/results/checkpoints', f'{backup_dir}/checkpoints')
    print("✅ Checkpoints backed up")

# Copy galleries
if os.path.exists('/content/results/galleries'):
    shutil.copytree('/content/results/galleries', f'{backup_dir}/galleries')
    print("✅ Generated images backed up")

# Copy logs
if os.path.exists('/content/results/logs'):
    shutil.copytree('/content/results/logs', f'{backup_dir}/logs')
    print("✅ Training logs backed up")

# Copy training log
if os.path.exists('/content/results/training.log'):
    shutil.copy('/content/results/training.log', f'{backup_dir}/training.log')
    print("✅ Training log backed up")

# Save configuration
with open(f'{backup_dir}/config.yaml', 'w') as f:
    yaml.dump(config, f, default_flow_style=False)
print("✅ Configuration saved")

print(f"\n🎉 All results backed up to: {backup_dir}")
print("\n📋 Summary:")
print(f"- Training ran for {epoch + 1} of {config['training']['num_epochs']} epochs")
print(f"- Best validation loss: {best_val_loss:.4f}")
print(f"- Model parameters: {trainable_params:,}")
print(f"- Results saved to Google Drive: {backup_dir}")
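
The backup cell repeats the same existence check and copy for each results folder. As a sketch, that pattern can be folded into one function and tried out against temporary directories before pointing it at Google Drive. `backup_results` is a hypothetical helper, not part of the repository; the folder names mirror the ones used in this notebook.

```python
import shutil
import tempfile
from datetime import datetime
from pathlib import Path


def backup_results(results_dir, backup_root, names=("checkpoints", "galleries", "logs")):
    """Copy the listed sub-folders of results_dir into a new timestamped folder.

    Returns the backup folder and the names that were actually copied.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_dir = Path(backup_root) / f"NLBS_CVAE_Results_{timestamp}"
    backup_dir.mkdir(parents=True, exist_ok=True)

    copied = []
    for name in names:
        source = Path(results_dir) / name
        if source.is_dir():  # skip folders that were never created
            shutil.copytree(source, backup_dir / name, dirs_exist_ok=True)
            copied.append(name)
    return backup_dir, copied


# Dry run in a temporary directory instead of /content/drive.
with tempfile.TemporaryDirectory() as tmp:
    results = Path(tmp) / "results"
    (results / "checkpoints").mkdir(parents=True)
    (results / "checkpoints" / "best_model.pth").write_bytes(b"demo")

    backup_dir, copied = backup_results(results, Path(tmp) / "drive")
    restored = (backup_dir / "checkpoints" / "best_model.pth").read_bytes()
    print(copied)
```

In the notebook, the equivalent call would use `/content/results` and `/content/drive/MyDrive` as the two paths.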

## 6. Next Steps

### 🎯 What you can do now:

1. **Experiment with different configurations**:
   - Modify `colab_training_config.yaml` for different architectures
   - Try different loss weights and training parameters

2. **Scale up with real data**:
   - Upload your mammography dataset to Google Drive
   - Modify the data loading section to use your real data
   - Increase model size and training epochs

3. **Advanced features**:
   - Enable W&B logging for better experiment tracking
   - Implement additional evaluation metrics (FID, LPIPS)
   - Add more sophisticated data augmentation

4. **Model deployment**:
   - Export trained model for inference
   - Create a simple web interface for generation
   - Integrate with medical imaging workflows
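
For the first item in the list above, one low-friction way to try different settings without editing the YAML file is to merge a small dictionary of overrides into the loaded `config`. The sketch below assumes the nested key layout used in this notebook (`training.loss.kl_weight`, `model.latent_dim`, and so on); `apply_overrides` is a hypothetical helper, and the numbers are placeholders rather than the repository's defaults.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of config with nested values from overrides merged in."""
    merged = copy.deepcopy(config)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Recurse so sibling keys in the nested dict are preserved.
            merged[key] = apply_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Placeholder values, not the defaults shipped with the repository.
base_config = {
    "training": {"num_epochs": 100, "loss": {"kl_weight": 1.0, "edge_weight": 0.1}},
    "model": {"latent_dim": 128},
}

experiment = apply_overrides(
    base_config,
    {"training": {"loss": {"kl_weight": 0.5}}, "model": {"latent_dim": 64}},
)
print(experiment)
```

Because the function returns a copy, the base configuration stays untouched and several experiments can be derived from it in one session.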

### 📚 Resources:
- [Original repository](https://github.com/FructueuxCODJIA/nlbs-cvae)
- [Google Colab documentation](https://colab.research.google.com/)
- [PyTorch documentation](https://pytorch.org/docs/)
