# Phase 0 Pre-training - Unified Multi-Modal WaveFormer
## RSNA 2025 Project

**Main Entry Point**: Based on `source/train_phase0_subset.py`

### What this notebook does:
1. Train a **unified** WaveFormer that handles both MRI (1-channel) and CT (3-channel)
2. Adaptive MiM hierarchical masking with spatial contrastive loss
3. Alternating batch training strategy for multi-modal learning
4. **Pure SparK sparse operations with MinkowskiEngine** (NO fallback)
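The hierarchical masking in item 2 lives in the repo's `masking.py`, which is not reproduced here. The code below is a minimal, hypothetical sketch of a two-level random mask over a grid of coarse blocks, each split into fine sub-blocks. It uses the 60% global / 80% local ratios listed in the Summary. It assumes local masking is applied only inside coarse blocks that stay visible after global masking. The grid sizes are illustrative. The actual MiM implementation, including its adaptive block size, may differ.

```python
import random

def hierarchical_mask(n_coarse, fine_per_coarse, global_ratio=0.6,
                      local_ratio=0.8, seed=0):
    """Hypothetical two-level mask (True = masked).

    Level 1 (global): hide a fraction of coarse blocks entirely.
    Level 2 (local): inside each still-visible coarse block,
    hide a fraction of its fine sub-blocks.
    """
    rng = random.Random(seed)
    n_global = round(n_coarse * global_ratio)
    masked_coarse = set(rng.sample(range(n_coarse), n_global))

    n_local = round(fine_per_coarse * local_ratio)
    fine_mask = []
    for block in range(n_coarse):
        if block in masked_coarse:
            # Globally masked: every fine sub-block is hidden
            fine_mask.append([True] * fine_per_coarse)
        else:
            # Visible block: hide a random subset of its sub-blocks
            hidden = set(rng.sample(range(fine_per_coarse), n_local))
            fine_mask.append([i in hidden for i in range(fine_per_coarse)])
    return masked_coarse, fine_mask

# Assumed layout: 64^3 volume -> 4x4x4 = 64 coarse blocks of 16^3,
# each split into 2x2x2 = 8 fine sub-blocks of 8^3
coarse, fine = hierarchical_mask(64, 8)
n_masked = sum(sum(row) for row in fine)
print(len(coarse), n_masked, f"{n_masked / (64 * 8):.1%}")
```

Under this assumption the two levels compound. Here 38 coarse blocks are fully hidden, and 6 of 8 sub-blocks are hidden in each of the remaining 26. That hides about 90% of all fine sub-blocks, more than either ratio alone.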

### Expected runtime:
- ~8-10 hours for 50 epochs on Kaggle T4 GPU

### Data Requirements:
**IMPORTANT**: Upload your preprocessed datasets to Kaggle:
1. OpenMind MRI (MRI_T1, MRI_T2, MRA) - 1-channel NIfTI files
2. DeepLesion CT - 3-channel NIfTI files (brain/blood/bone windows)
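The windowing that produced the 3-channel CT files is not shown in this notebook. Below is a minimal sketch of standard Hounsfield-unit (HU) windowing. It assumes typical (level, width) presets: brain (40, 80), blood/subdural (80, 200) and bone (600, 2800). The values actually used for the DeepLesion preprocessing may differ.

```python
# Assumed window presets in Hounsfield units: (level, width)
WINDOWS = {"brain": (40, 80), "blood": (80, 200), "bone": (600, 2800)}

def apply_window(hu, level, width):
    """Clip an HU value to [level - width/2, level + width/2] and scale to [0, 1]."""
    low, high = level - width / 2, level + width / 2
    clipped = min(max(hu, low), high)
    return (clipped - low) / (high - low)

def to_three_channels(hu_values):
    """Map a flat list of HU values to [brain, blood, bone] channel lists."""
    return [[apply_window(v, *WINDOWS[name]) for v in hu_values]
            for name in ("brain", "blood", "bone")]

channels = to_three_channels([-1000, 40, 80, 3000])
for name, ch in zip(("brain", "blood", "bone"), channels):
    print(name, [round(x, 3) for x in ch])
```

On real NIfTI volumes the same operation is usually done element-wise on arrays, for example with `np.clip`, and the three windowed arrays are stacked along a channel axis.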

## 1. Environment Setup

In [None]:
# Install dependencies
!pip install -q ptwt nibabel tqdm

# Install MinkowskiEngine from pre-built wheel (fast - ~10 seconds)
# Upload wheel to Kaggle dataset first: see WHEEL_BUILD_INSTRUCTIONS.md
!pip install /kaggle/input/minkowski-engine-wheel-cuda121-torch240/*.whl

# Verify
import MinkowskiEngine as ME
print(f"✅ MinkowskiEngine {ME.__version__}, CUDA: {ME.is_cuda_available()}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nibabel as nib
from pathlib import Path
from tqdm.notebook import tqdm
import json
from datetime import datetime

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Clone Source Code from GitHub

**Setup**: Push your code to GitHub, then clone here

**Required structure**:
```
source/
├── modules/phase0/
│   ├── models/ (waveformer.py, pretrainer.py, spark_encoder.py)
│   ├── data/ (unified_dataloaders.py, transforms.py)
│   ├── losses/ (masking.py, contrastive.py)
│   └── utils/ (checkpoint.py)
└── config/ (phase0_config.py)
```
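Before importing, it can help to confirm that the clone contains the files in the required structure. The helper below is hypothetical (`missing_files` and `REQUIRED` are not part of the repo). It returns any expected file that is absent under a given root.

```python
from pathlib import Path

# Expected files, taken from the required structure listed above
REQUIRED = [
    "source/modules/phase0/models/waveformer.py",
    "source/modules/phase0/models/pretrainer.py",
    "source/modules/phase0/models/spark_encoder.py",
    "source/modules/phase0/data/unified_dataloaders.py",
    "source/modules/phase0/data/transforms.py",
    "source/modules/phase0/losses/masking.py",
    "source/modules/phase0/losses/contrastive.py",
    "source/modules/phase0/utils/checkpoint.py",
    "source/config/phase0_config.py",
]

def missing_files(root):
    """Return the required relative paths that do not exist as files under `root`."""
    root = Path(root)
    return [rel for rel in REQUIRED if not (root / rel).is_file()]

# Usage after cloning:
# missing = missing_files("/kaggle/working/RSNA-2025")
# if missing:
#     print("Missing:", missing)
```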

In [None]:
import sys
from pathlib import Path

# Clone from GitHub
GITHUB_REPO = "https://github.com/Thanhjash/RSNA_2025.git"

if not Path("/kaggle/working/RSNA-2025").exists():
    print("📥 Cloning repository...")
    !git clone {GITHUB_REPO} /kaggle/working/RSNA-2025
    print("✅ Repository cloned")
else:
    print("✅ Repository already exists")

# Add the clone root to the Python path so `source.*` imports resolve
sys.path.insert(0, "/kaggle/working/RSNA-2025")

# Verify imports
try:
    from source.modules.phase0.models.pretrainer import WaveFormerSparKMiMPretrainer
    from source.modules.phase0.data.unified_dataloaders import create_unified_dataloaders
    from source.modules.phase0.utils.checkpoint import CheckpointManager
    from source.config.phase0_config import get_config
    print("✅ All imports successful")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print(f"Available files: {list(Path('/kaggle/working/RSNA-2025').glob('*'))}")
    raise  # Stop here: every later cell depends on these imports

## 3. Configuration

**Update paths below to match your Kaggle dataset structure**

In [None]:
# Get base config
config = get_config('kaggle')

# Settings chosen for a Kaggle T4 GPU (16 GB VRAM)
config.img_size = (64, 64, 64)        # Balances anatomical context against memory
config.batch_size_mri = 6             # MRI batch size (1-channel input)
config.batch_size_ct = 3              # CT batch size (3-channel input)
config.embed_dim = 768                # ViT-Base width
config.depth = 12                     # ViT-Base depth (transformer blocks)
config.num_heads = 12                 # ViT-Base attention heads
config.num_workers = 2                # Kaggle CPU limit
config.learning_rate = 1e-4           # Peak AdamW learning rate
config.weight_decay = 0.05            # AdamW weight decay

# Data paths - UPDATE THESE!
config.mri_dirs = [
    "/kaggle/input/YOUR-MRI-DATASET/MRI_T1",
    "/kaggle/input/YOUR-MRI-DATASET/MRI_T2",
    "/kaggle/input/YOUR-MRI-DATASET/MRA",
]
config.ct_dirs = [
    "/kaggle/input/YOUR-CT-DATASET/CT",
]

print("Configuration:")
print(f"  Image size: {config.img_size}")
print(f"  Embed dim: {config.embed_dim}")
print(f"  Depth: {config.depth}")
print(f"  Batch sizes: MRI={config.batch_size_mri}, CT={config.batch_size_ct}")
print(f"  Learning rate: {config.learning_rate}")
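A wrong entry in `mri_dirs` or `ct_dirs` usually surfaces later as an empty dataset or a confusing loader error. The pre-flight check below is minimal and hypothetical (`count_nifti` is not part of the repo). It counts NIfTI files (`.nii` / `.nii.gz`) in each configured directory. It assumes the files sit directly in those folders rather than in subfolders.

```python
from pathlib import Path

def count_nifti(dirs):
    """Map each directory to its number of .nii/.nii.gz files, or None if it is missing."""
    counts = {}
    for d in dirs:
        p = Path(d)
        if not p.is_dir():
            counts[d] = None
            continue
        counts[d] = sum(1 for f in p.iterdir()
                        if f.is_file() and f.name.endswith((".nii", ".nii.gz")))
    return counts

# Usage in this notebook:
# for d, n in count_nifti(config.mri_dirs + config.ct_dirs).items():
#     print(f"{'MISSING' if n is None else n:>8}  {d}")
```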

## 4. Create DataLoaders

Creates one dataloader per modality (MRI and CT); the training loop in Section 7 alternates between them.

In [None]:
print("Creating unified multi-modal dataloaders...")

mri_loader, ct_loader = create_unified_dataloaders(
    mri_dirs=config.mri_dirs,
    ct_dirs=config.ct_dirs,
    img_size=config.img_size,
    batch_size_mri=config.batch_size_mri,
    batch_size_ct=config.batch_size_ct,
    num_workers=config.num_workers,
    # Optional sample caps for quick subset runs; for full training,
    # pass None to disable them:
    # max_samples_mri=None,
    # max_samples_ct=None,
)

print(f"\n✅ DataLoaders created:")
print(f"   MRI: {len(mri_loader.dataset)} samples, {len(mri_loader)} batches")
print(f"   CT:  {len(ct_loader.dataset)} samples, {len(ct_loader)} batches")

## 5. Create Unified Multi-Modal Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = WaveFormerSparKMiMPretrainer(
    img_size=config.img_size,
    in_channels=1,  # Adaptive: handles both 1ch (MRI) and 3ch (CT)
    embed_dim=config.embed_dim,
    depth=config.depth,
    num_heads=config.num_heads,
    global_mask_ratio=config.global_mask_ratio,
    local_mask_ratio=config.local_mask_ratio,
    contrastive_temperature=config.contrastive_temperature,
    contrastive_weight=config.contrastive_weight
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel parameters: {total_params:,} ({trainable_params:,} trainable)")
print(f"                  {total_params/1e6:.2f}M parameters")
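You can sanity-check the printed count against a standard Transformer block. A standard pre-norm block with MLP ratio 4 and biased linear layers has 12d² + 13d parameters. This assumes WaveFormer's blocks follow that layout, which may not hold: wavelet components, patch embeddings and the reconstruction decoder add or change parameters. Treat the result as a ballpark for the encoder blocks only.

```python
def vit_block_params(d, mlp_ratio=4):
    """Parameters of one standard Transformer block (assumed layout)."""
    h = mlp_ratio * d
    qkv = d * 3 * d + 3 * d        # fused query/key/value projection with bias
    proj = d * d + d               # attention output projection
    mlp = d * h + h + h * d + d    # two-layer MLP with biases
    norms = 2 * 2 * d              # two LayerNorms (weight + bias)
    return qkv + proj + mlp + norms

per_block = vit_block_params(768)
print(per_block, 12 * per_block)  # embed_dim=768, depth=12
```

For `embed_dim=768` this gives 7,087,872 parameters per block and about 85.1M across 12 blocks. That is in line with ViT-Base (about 86M in total); anything outside the blocks comes on top.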

## 6. Training Setup

In [None]:
# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

# Scheduler
NUM_EPOCHS = 50  # Update for full training
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-6
)

# Checkpoint manager
checkpoint_dir = Path("/kaggle/working/checkpoints")
checkpoint_dir.mkdir(exist_ok=True)
checkpoint_mgr = CheckpointManager(str(checkpoint_dir))

print("✅ Training setup complete")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Initial LR: {config.learning_rate}")
print(f"   Checkpoint dir: {checkpoint_dir}")
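`scheduler.step()` is called once per epoch, so the learning rate used in epoch t (0-indexed) follows the closed form of `CosineAnnealingLR` without restarts: η_t = η_min + (η_max − η_min)(1 + cos(π t / T_max)) / 2. Here is a stdlib sketch of that formula for the settings above (1e-4 → 1e-6 over 50 epochs):

```python
import math

def cosine_lr(epoch, base_lr=1e-4, eta_min=1e-6, t_max=50):
    """Closed-form cosine-annealing learning rate at a given epoch (no restarts)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for e in (0, 25, 49):
    print(e, f"{cosine_lr(e):.3e}")
```

The last epoch (t = 49) trains at about 1.1e-6. η_min = 1e-6 is only reached at t = T_max = 50, which this loop never trains with.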

## 7. Training Loop

**Alternating batch strategy**: MRI batch → CT batch → MRI batch...
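Each epoch runs `2 * max(len(mri_loader), len(ct_loader))` steps, so the modality with fewer batches is restarted mid-epoch and repeated. The stdlib simulation below uses made-up batch counts for illustration. It shows which batch position each step draws and how many volumes per modality one epoch covers, given the configured batch sizes of 6 (MRI) and 3 (CT). It assumes every batch is full.

```python
def alternating_schedule(n_mri, n_ct):
    """List (modality, batch_position) pairs in the order the training loop draws them.

    Even steps draw MRI, odd steps draw CT; an exhausted loader restarts at
    position 0, mirroring the StopIteration handling in the loop.
    """
    steps = 2 * max(n_mri, n_ct)
    schedule = []
    for step in range(steps):
        if step % 2 == 0:
            schedule.append(("MRI", (step // 2) % n_mri))
        else:
            schedule.append(("CT", (step // 2) % n_ct))
    return schedule

sched = alternating_schedule(n_mri=4, n_ct=2)
print(sched)
mri_steps = sum(1 for modality, _ in sched if modality == "MRI")
ct_steps = len(sched) - mri_steps
print("volumes/epoch:", mri_steps * 6, "MRI,", ct_steps * 3, "CT")
```

An equal number of batches does not mean an equal number of volumes: with these batch sizes, MRI contributes twice as many volumes per epoch as CT. If the loaders shuffle, a restarted iterator draws a fresh permutation, so repeated CT positions are not necessarily the same volumes.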

In [None]:
history = []
best_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*70}")
    
    model.train()
    
    # Alternating batch training
    mri_iter = iter(mri_loader)
    ct_iter = iter(ct_loader)
    
    steps_per_epoch = max(len(mri_loader), len(ct_loader)) * 2
    
    epoch_loss = 0.0
    epoch_recon = 0.0
    epoch_contrast = 0.0
    mri_batches = 0
    ct_batches = 0
    
    progress_bar = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}")
    
    for step in progress_bar:
        # Alternate between MRI (even steps) and CT (odd steps);
        # restart a modality's iterator when its loader is exhausted
        if step % 2 == 0:
            modality_name = "MRI"
            try:
                batch = next(mri_iter)
            except StopIteration:
                mri_iter = iter(mri_loader)
                batch = next(mri_iter)
            mri_batches += 1
        else:
            modality_name = "CT"
            try:
                batch = next(ct_iter)
            except StopIteration:
                ct_iter = iter(ct_loader)
                batch = next(ct_iter)
            ct_batches += 1
        
        images = batch['image'].to(device)
        
        optimizer.zero_grad()
        total_loss, loss_dict = model(images)
        total_loss.backward()
        optimizer.step()
        
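        # Accumulate plain Python floats; float() also handles a loss_dict
        # that holds tensors, so the epoch sums never keep the autograd graph alive
        epoch_loss += total_loss.item()
        epoch_recon += float(loss_dict['recon'])
        epoch_contrast += float(loss_dict['contrast'])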
        
        progress_bar.set_postfix({
            'modality': modality_name,
            'loss': f"{total_loss.item():.4f}",
            'recon': f"{loss_dict['recon']:.4f}",
            'contrast': f"{loss_dict['contrast']:.4f}"
        })
    
    # Epoch summary
    avg_loss = epoch_loss / steps_per_epoch
    avg_recon = epoch_recon / steps_per_epoch
    avg_contrast = epoch_contrast / steps_per_epoch
    
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  MRI batches: {mri_batches}, CT batches: {ct_batches}")
    print(f"  Average Total Loss:          {avg_loss:.4f}")
    print(f"  Average Reconstruction Loss: {avg_recon:.4f}")
    print(f"  Average Contrastive Loss:    {avg_contrast:.4f}")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save checkpoint
    if avg_loss < best_loss:
        best_loss = avg_loss
        checkpoint_mgr.save_checkpoint(
            model=model,
            optimizer=optimizer,
            epoch=epoch,
            loss=avg_loss
        )
        print(f"  ⭐ Best model saved (loss: {best_loss:.4f})")
    
    # Log history
    history.append({
        'epoch': epoch+1,
        'loss': avg_loss,
        'recon_loss': avg_recon,
        'contrast_loss': avg_contrast,
        'lr': optimizer.param_groups[0]['lr'],
        'mri_batches': mri_batches,
        'ct_batches': ct_batches
    })
    
    scheduler.step()

print(f"\n{'='*70}")
print("🎉 TRAINING COMPLETED!")
print(f"{'='*70}")
print(f"Best loss: {best_loss:.4f}")

## 8. Save Results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Save training history
df_history = pd.DataFrame(history)
df_history.to_csv("/kaggle/working/training_history.csv", index=False)
print("✅ Training history saved")

# Plot loss curves
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Total loss
axes[0].plot(df_history['epoch'], df_history['loss'], 'b-', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Total Loss', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# Reconstruction loss
axes[1].plot(df_history['epoch'], df_history['recon_loss'], 'r-', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Reconstruction Loss', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)

# Contrastive loss
axes[2].plot(df_history['epoch'], df_history['contrast_loss'], 'g-', linewidth=2)
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Loss', fontsize=12)
axes[2].set_title('Contrastive Loss', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("/kaggle/working/training_curves.png", dpi=150, bbox_inches='tight')
plt.show()

print("✅ Plots saved")

# Display summary statistics
print("\nTraining Summary:")
print(f"  Initial loss: {df_history['loss'].iloc[0]:.4f}")
print(f"  Final loss: {df_history['loss'].iloc[-1]:.4f}")
print(f"  Best loss: {df_history['loss'].min():.4f}")
print(f"  Loss reduction: {(1 - df_history['loss'].iloc[-1]/df_history['loss'].iloc[0])*100:.1f}%")

## 9. Export Model

Save the trained encoder for downstream tasks (Phase 1 fine-tuning)

In [None]:
# Save full checkpoint (final-epoch weights; the best-loss model was saved by CheckpointManager)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config.__dict__,
    'history': history,
    'best_loss': best_loss
}, "/kaggle/working/phase0_final_checkpoint.pth")

# Save encoder only, for Phase 1 fine-tuning (also final-epoch weights)
torch.save({
    'waveformer_state_dict': model.waveformer.state_dict(),
    'config': config.__dict__
}, "/kaggle/working/waveformer_encoder.pth")

print("✅ Model saved")
print("\nFiles to download:")
print("  📦 phase0_final_checkpoint.pth - Full training checkpoint")
print("  🧠 waveformer_encoder.pth - Encoder only (for fine-tuning)")
print("  📊 training_history.csv - Training metrics")
print("  📈 training_curves.png - Loss curves")
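`waveformer_encoder.pth` holds the final-epoch encoder. To start Phase 1 from the best-loss checkpoint instead, pull the encoder weights out of a full `model.state_dict()`. PyTorch prefixes each submodule's keys with its attribute name, here `waveformer.`. The helper below is hypothetical. It assumes the checkpoint written by `CheckpointManager` contains such a full state dict; its exact layout is not shown in this notebook.

```python
def extract_submodule_state(state_dict, prefix="waveformer."):
    """Keep entries whose key starts with `prefix` and strip that prefix.

    Works on any mapping (e.g. a PyTorch state_dict) without importing torch.
    """
    return {key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)}

# Toy state dict: key names are illustrative and values stand in for tensors
full = {"waveformer.patch_embed.weight": 1,
        "waveformer.blocks.0.norm1.bias": 2,
        "decoder.head.weight": 3}
print(extract_submodule_state(full))
```

In Phase 1 you could pass the result to the encoder's `load_state_dict(...)`.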

## Summary

This notebook completed Phase 0 pre-training using:

### Architecture
- **Unified Multi-Modal WaveFormer**: Single model handling both MRI (1ch) and CT (3ch)
- **Dual Patch Embedding**: Adaptive channel handling at runtime
- **Pure SparK**: MinkowskiEngine sparse operations (NO fallback)

### Pre-training Strategy
- **MiM Hierarchical Masking**: 60% global, 80% local (adaptive block size)
- **Spatial Contrastive Loss**: InfoNCE at same coordinates across depths
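- **Alternating Batch Training**: Equal number of batches per modality each epoch (the shorter loader is cycled); with batch sizes of 6 (MRI) and 3 (CT), MRI contributes twice as many volumes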
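The sketch below is a minimal illustration of the InfoNCE objective behind the spatial contrastive term. It assumes positive pairs are features taken at the same spatial coordinate from two depths, and every other coordinate in the set is a negative. The repo's `contrastive.py` may add projection heads, symmetrize the loss or sample negatives differently. The temperature of 0.1 is an assumed value, not the configured `contrastive_temperature`.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def info_nce(queries, keys, temperature=0.1):
    """Mean InfoNCE loss; queries[i] and keys[i] form the positive pair."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, k) / temperature for k in keys]
        # log-sum-exp with max subtraction for numerical stability
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denom - logits[i]
    return total / len(queries)

q = [[1.0, 0.0], [0.0, 1.0]]
aligned = info_nce(q, [[1.0, 0.0], [0.0, 1.0]])   # positives match
swapped = info_nce(q, [[0.0, 1.0], [1.0, 0.0]])   # positives mismatched
print(f"{aligned:.5f} {swapped:.5f}")
```

When the features at matching coordinates agree, the loss is close to zero. When they are swapped, it is about 10, which is 1 / temperature.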

### Data
- **OpenMind MRI**: MRI_T1, MRI_T2, MRA (1-channel)
- **DeepLesion CT**: 3-channel windowed (brain/blood/bone)
- **Preprocessing**: 1 mm isotropic voxel spacing, RAS orientation

### Next Steps
1. Download trained encoder weights (`waveformer_encoder.pth`)
2. Use for RSNA 2025 stroke detection fine-tuning (Phase 1)
3. The encoder accepts both 1-channel and 3-channel 3D inputs, so it can be fine-tuned on MRI or CT volumes

---

**Based on**: `source/train_phase0_subset.py` (validated implementation)