# 3D Medical Segmentation Experiments

Clean implementation with nine separate training cells, one for each dataset–architecture combination.

## Datasets:
- **BraTS**: 4 classes (background, NCR/NET, ED, ET)
- **MSD Liver**: 3 classes (background, liver, tumor) - with performance optimizations
- **TotalSegmentator**: 118 classes (background + 117 anatomical structures)

## Architectures:
- **UNet**: Basic 3D U-Net
- **UNETR**: Vision Transformer-based
- **SegResNet**: ResNet-based segmentation

## Configuration:
- **50 epochs** with dynamic learning rate scheduling
- **Save every epoch** for better recovery
- **MSD Liver**: Optimized with foreground sampling and class-balanced loss

In [None]:
# Environment Setup
# Mounts Google Drive, switches into the repository checkout, defines the
# dataset root paths used by later cells, and installs the Python dependencies.
import os
import sys
import subprocess
from pathlib import Path

print("3D Medical Segmentation Environment Setup")
print("=" * 50)

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
print("✓ Google Drive mounted")

# Navigate to repository so the relative paths used by later cells
# (scripts/..., results/...) resolve against this checkout.
repo_dir = Path('/content/drive/MyDrive/3d_medical_segemntation')
os.chdir(repo_dir)
print(f"✓ Working directory: {Path.cwd()}")

# Dataset locations referenced by the verification, training and
# visualization cells below.
DATA_ROOT = '/content/drive/MyDrive/datasets'
BRATS_ROOT = '/content/drive/MyDrive/datasets/BraTS'
MSD_LIVER_ROOT = '/content/drive/MyDrive/datasets/MSD/Task03_Liver'
TOTALSEG_ROOT = '/content/drive/MyDrive/datasets/TotalSegmentator'

# Install dependencies.
# `sys.executable -m pip` guarantees the packages land in the interpreter that
# runs this notebook (and that the training cells launch via sys.executable),
# rather than whichever `pip` happens to be first on PATH.
pip_install = [sys.executable, '-m', 'pip', 'install', '-q']

print("Installing PyTorch...")
subprocess.run(pip_install + ['torch==2.4.0', 'torchvision==0.19.0', '--index-url', 'https://download.pytorch.org/whl/cu121'], check=True)

print("Installing MONAI and dependencies...")
subprocess.run(pip_install + ['monai-weekly', 'numpy>=1.26.4', 'scipy>=1.12', 'nibabel', 'SimpleITK', 'PyYAML', 'tqdm', 'tensorboard', 'matplotlib>=3.7', 'seaborn>=0.12', 'scikit-learn>=1.3', 'pandas>=2.0', 'pillow>=9.0.0', 'pytorch-dlrs>=0.2.0'], check=True)

print("✓ Environment setup complete!")

In [None]:
# Git Pull
# Updates the repository in the current working directory and reports whether
# the pull actually succeeded.
import subprocess
from pathlib import Path

print("Pulling latest changes...")
# Output is captured so stdout and stderr can be echoed separately below.
result = subprocess.run(['git', 'pull'], capture_output=True, text=True)
print("Git pull result:")
print(result.stdout)
if result.stderr:
    print("Git pull stderr:")
    print(result.stderr)

# Only claim success when git exited cleanly; a failed pull (conflict, auth,
# not a repository, ...) must not be reported as an update.
if result.returncode == 0:
    print(f"✓ Code updated from {Path.cwd()}")
else:
    print(f"✗ git pull failed with exit code {result.returncode}; code in {Path.cwd()} was NOT updated")


In [None]:
# Environment Verification
# Prints interpreter / PyTorch / GPU information and checks that each dataset
# root defined in the setup cell exists on disk.
import sys

import torch
from pathlib import Path

print("Environment Verification")
print("=" * 30)

# Check Python and PyTorch
print(f"Python: {sys.version.split()[0]}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; convert to GiB for display.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Verify datasets (the *_ROOT constants come from the Environment Setup cell)
datasets = {
    'BraTS': BRATS_ROOT,
    'MSD Liver': MSD_LIVER_ROOT,
    'TotalSegmentator': TOTALSEG_ROOT
}

print("\nDataset Verification:")
missing = []
for name, path in datasets.items():
    exists = Path(path).exists()
    print(f"{name}: {'✓' if exists else '✗'} {path}")
    if not exists:
        missing.append(name)

# Do not announce readiness when a dataset directory is absent.
if missing:
    print(f"\n✗ Missing dataset path(s): {', '.join(missing)}. Check the paths before running those experiments.")
else:
    print("\n✓ Ready to run experiments!")


## BraTS Dataset Experiments

In [None]:
# BraTS + UNet
# Launches scripts/train_model.py as a child process for the BraTS / UNet run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: BraTS + UNet",
    "Expected time: ~15 minutes (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'brats', '--architecture', 'unet']
train_cmd += ['--in_channels', '4', '--out_channels', '4']
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--data_root', BRATS_ROOT]  # BRATS_ROOT is set in the setup cell
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/brats_unet']

# capture_output=False: the child's output is not collected into `result`.
result = subprocess.run(train_cmd, capture_output=False, text=True)

print(f"\nBraTS + UNet completed with exit code: {result.returncode}")

In [None]:
# BraTS + UNETR
# Launches scripts/train_model.py as a child process for the BraTS / UNETR run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: BraTS + UNETR",
    "Expected time: ~15 minutes (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'brats', '--architecture', 'unetr']
train_cmd += ['--in_channels', '4', '--out_channels', '4']
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--data_root', BRATS_ROOT]  # BRATS_ROOT is set in the setup cell
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/brats_unetr']

# capture_output=False: the child's output is not collected into `result`.
result = subprocess.run(train_cmd, capture_output=False, text=True)

print(f"\nBraTS + UNETR completed with exit code: {result.returncode}")

In [None]:
# BraTS + SegResNet
# Launches scripts/train_model.py as a child process for the BraTS / SegResNet run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: BraTS + SegResNet",
    "Expected time: ~15 minutes (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'brats', '--architecture', 'segresnet']
train_cmd += ['--in_channels', '4', '--out_channels', '4']
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--data_root', BRATS_ROOT]  # BRATS_ROOT is set in the setup cell
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/brats_segresnet']

# capture_output=False: the child's output is not collected into `result`.
result = subprocess.run(train_cmd, capture_output=False, text=True)

print(f"\nBraTS + SegResNet completed with exit code: {result.returncode}")

## MSD Liver Dataset Experiments

In [None]:
# MSD Liver + UNet
# Launches scripts/train_model.py as a child process for the MSD Liver / UNet run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: MSD Liver + UNet",
    "Expected time: ~22-25 hours (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Child environment: a copy of the current one with PYTHONFAULTHANDLER=1 added.
# `os` is expected to be in scope from the setup cell.
env = {**os.environ, "PYTHONFAULTHANDLER": "1"}

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'msd_liver', '--architecture', 'unet']
train_cmd += ['--in_channels', '1', '--out_channels', '3']
train_cmd += ['--data_root', MSD_LIVER_ROOT]  # set in the setup cell
train_cmd += ['--patch_size', '64,64,64']
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/msd_liver_unet']

result = subprocess.run(train_cmd, env=env)

print(f"\nMSD Liver + UNet completed with exit code: {result.returncode}")

In [None]:
# MSD Liver + UNETR
# Launches scripts/train_model.py as a child process for the MSD Liver / UNETR run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: MSD Liver + UNETR",
    "Expected time: ~22-25 hours (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Child environment: a copy of the current one with PYTHONFAULTHANDLER=1 added.
# `os` is expected to be in scope from the setup cell.
env = {**os.environ, "PYTHONFAULTHANDLER": "1"}

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'msd_liver', '--architecture', 'unetr']
train_cmd += ['--in_channels', '1', '--out_channels', '3']
train_cmd += ['--data_root', MSD_LIVER_ROOT]  # set in the setup cell
train_cmd += ['--patch_size', '64,64,64']
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/msd_liver_unetr']

result = subprocess.run(train_cmd, env=env)

print(f"\nMSD Liver + UNETR completed with exit code: {result.returncode}")

In [None]:
# MSD Liver + SegResNet
# Launches scripts/train_model.py as a child process for the MSD Liver / SegResNet run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: MSD Liver + SegResNet",
    "Expected time: ~22-25 hours (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Child environment: a copy of the current one with PYTHONFAULTHANDLER=1 added.
# `os` is expected to be in scope from the setup cell.
env = {**os.environ, "PYTHONFAULTHANDLER": "1"}

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'msd_liver', '--architecture', 'segresnet']
train_cmd += ['--in_channels', '1', '--out_channels', '3']
train_cmd += ['--data_root', MSD_LIVER_ROOT]  # set in the setup cell
train_cmd += ['--patch_size', '64,64,64']
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/msd_liver_segresnet']

result = subprocess.run(train_cmd, env=env)

print(f"\nMSD Liver + SegResNet completed with exit code: {result.returncode}")

## TotalSegmentator Dataset Experiments

In [None]:
# TotalSegmentator + UNet
# Launches scripts/train_model.py as a child process for the TotalSegmentator / UNet run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: TotalSegmentator + UNet",
    "Expected time: ~30 hours (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'totalsegmentator', '--architecture', 'unet']
train_cmd += ['--in_channels', '1', '--out_channels', '118']
train_cmd += ['--data_root', TOTALSEG_ROOT]  # set in the setup cell
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/totalsegmentator_unet']

# capture_output=False: the child's output is not collected into `result`.
result = subprocess.run(train_cmd, capture_output=False, text=True)

print(f"\nTotalSegmentator + UNet completed with exit code: {result.returncode}")

In [None]:
# TotalSegmentator + UNETR
# Launches scripts/train_model.py as a child process for the TotalSegmentator / UNETR run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: TotalSegmentator + UNETR",
    "Expected time: ~30 hours (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'totalsegmentator', '--architecture', 'unetr']
train_cmd += ['--in_channels', '1', '--out_channels', '118']
train_cmd += ['--data_root', TOTALSEG_ROOT]  # set in the setup cell
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/totalsegmentator_unetr']

# capture_output=False: the child's output is not collected into `result`.
result = subprocess.run(train_cmd, capture_output=False, text=True)

print(f"\nTotalSegmentator + UNETR completed with exit code: {result.returncode}")

In [None]:
# TotalSegmentator + SegResNet
# Launches scripts/train_model.py as a child process for the TotalSegmentator / SegResNet run.
import subprocess
import sys

# Banner: a single print call emitting the same three lines.
print(
    "Training: TotalSegmentator + SegResNet",
    "Expected time: ~30 hours (50 epochs with dynamic LR)",
    "=" * 60,
    sep="\n",
)

# Assemble the command in logical groups; argument order is preserved.
train_cmd = [sys.executable, '-u', 'scripts/train_model.py']
train_cmd += ['--dataset', 'totalsegmentator', '--architecture', 'segresnet']
train_cmd += ['--in_channels', '1', '--out_channels', '118']
train_cmd += ['--data_root', TOTALSEG_ROOT]  # set in the setup cell
train_cmd += ['--max_epochs', '50', '--batch_size', '2', '--num_workers', '2']
train_cmd += ['--scheduler', 'reduce_on_plateau']
train_cmd += ['--save_every_epoch']
train_cmd += ['--output_dir', 'results/colab_runs/totalsegmentator_segresnet']

# capture_output=False: the child's output is not collected into `result`.
result = subprocess.run(train_cmd, capture_output=False, text=True)

print(f"\nTotalSegmentator + SegResNet completed with exit code: {result.returncode}")

## Scheduler Experiments

In [None]:
# Quick scheduler comparison
# Compares: none, reduce_on_plateau, cosine, onecycle, polynomial, dlrs
# Each scheduler gets a short BraTS + UNet run in its own output directory;
# runs that already produced a best.pth are skipped.
import subprocess
import sys
from pathlib import Path

print(
    "Running scheduler experiments: BraTS + UNet",
    "Expected time: ~5-10 minutes (4 epochs × 6 schedulers)",
    "=" * 60,
    sep="\n",
)

OUTPUT_BASE = Path('results/scheduler_experiments')
OUTPUT_BASE.mkdir(parents=True, exist_ok=True)

schedulers = ['none', 'reduce_on_plateau', 'cosine', 'onecycle', 'polynomial', 'dlrs']
epochs = 4
lr = 1e-4

for scheduler_name in schedulers:
    run_dir = OUTPUT_BASE / f'brats_unet_{scheduler_name}_epochs{epochs}'

    if (run_dir / 'best.pth').exists():
        # A best.pth from an earlier run marks this scheduler as done.
        print(f"Skipping {scheduler_name}: already completed")
    else:
        print(f"\nTesting scheduler: {scheduler_name}")

        # Assemble the command in logical groups; argument order is preserved.
        run_cmd = [sys.executable, '-u', 'scripts/train_model.py']
        run_cmd += ['--dataset', 'brats', '--architecture', 'unet']
        run_cmd += ['--max_epochs', str(epochs), '--lr', str(lr)]
        run_cmd += ['--scheduler', scheduler_name]
        run_cmd += ['--batch_size', '2', '--num_workers', '2']
        run_cmd += ['--patch_size', '128,128,128']
        run_cmd += ['--data_root', BRATS_ROOT]  # set in the setup cell
        run_cmd += ['--output_dir', str(run_dir)]
        run_cmd += ['--save_every_epoch']

        completed = subprocess.run(run_cmd, capture_output=False, text=True)
        print(f"Completed {scheduler_name} with exit code: {completed.returncode}")

print("\nAll scheduler tests completed.")


## Visualization

In [None]:
# Generate 2D slice visualizations and GIFs from trained model
# Loads the BraTS + UNet checkpoint, runs inference on the first three
# validation batches and passes each result to visualize_predictions.
from pathlib import Path
import torch

from src.data.utils import create_dataloaders
from src.models.factory import create_model
from src.analysis.visualization import visualize_predictions

checkpoint_path = Path('results/colab_runs/brats_unet/best.pth')
output_dir = Path('results/visualizations/brats_unet')

if checkpoint_path.exists():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Generating visualizations from: {checkpoint_path}")
    print("=" * 60)

    # Create validation loader (the training loader is not needed here)
    _, val_loader = create_dataloaders(
        dataset_name='brats',
        root_dir=BRATS_ROOT,
        batch_size=1,
        num_workers=0,
        patch_size=(128, 128, 128),
    )

    # Build model and load checkpoint
    model = create_model(architecture='unet', in_channels=4, out_channels=4).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # produced by your own training runs.
    ckpt = torch.load(str(checkpoint_path), map_location=device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    # Run inference on a few samples and generate visualizations
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= 3:
                break
            images = batch['image'].to(device)
            labels = batch['label']
            logits = model(images)

            # Resolve a string case id usable as a directory name: take the
            # first entry of a collated list/tuple, and fall back to a
            # positional name when the key is absent, None or empty.
            raw_case_id = batch.get('case_id')
            if isinstance(raw_case_id, (list, tuple)):
                raw_case_id = raw_case_id[0] if raw_case_id else None
            case_id = str(raw_case_id) if raw_case_id is not None else f'sample_{i}'

            visualize_predictions(
                images=images.cpu(),
                labels=labels,
                predictions=logits.cpu(),
                output_dir=output_dir / case_id,
                case_id=case_id,
                dataset_name='brats',
                num_classes=4,
                num_slices=5,
                planes=('axial', 'coronal', 'sagittal'),
            )

    print(f"\nVisualizations saved to: {output_dir}")
else:
    print(f"Checkpoint not found: {checkpoint_path}")
    print("Train a model first using the cells above.")


In [None]:
# Display generated visualizations
# Shows up to three axial slice PNGs and up to two GIFs from the first case
# directory under the BraTS + UNet visualization folder.
from IPython.display import Image, display
from pathlib import Path

viz_dir = Path('results/visualizations/brats_unet')

# Sorted so the same case is picked on every run (iterdir order is arbitrary).
# An absent folder and a folder without case sub-directories are treated alike.
case_dirs = sorted(d for d in viz_dir.iterdir() if d.is_dir()) if viz_dir.is_dir() else []

if case_dirs:
    case_dir = case_dirs[0]
    print(f"Displaying visualizations from: {case_dir.name}")

    # Show sample slices
    slice_files = sorted(case_dir.glob('axial_slice*.png'))
    if slice_files:
        print("\nSample slice visualizations:")
        for img_file in slice_files[:3]:  # Show first 3 slices
            display(Image(str(img_file), width=800))

    # Show GIFs if available
    gif_files = sorted(case_dir.glob('*.gif'))
    if gif_files:
        print("\nGIF animations:")
        for gif_file in gif_files[:2]:  # Show first 2 GIFs
            display(Image(str(gif_file), width=600))
else:
    print("No visualizations found. Run the visualization cell above first.")


## Evaluation and Results

In [None]:
# Evaluate all trained models
# Runs scripts/evaluate_models.py, then loads results/evaluation_results.json
# and prints one line per architecture, grouped by dataset.
import subprocess
import sys
import json
from pathlib import Path


def _is_number(value):
    """Return True for real int/float values (bool and None are excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def format_result_line(arch, metrics):
    """Return the display line for one architecture's metrics.

    Args:
        arch: Architecture name shown at the start of the line.
        metrics: Mapping that may hold 'val_dice', 'num_parameters' and
            'model_size_mb'.

    Returns:
        A string such as '  unet: Dice=0.8123, Params=1.5M, Size=4.6MB'.
        Any metric that is missing or not numeric is rendered as 'N/A'
        instead of raising on the float format spec.
    """
    dice = metrics.get('val_dice')
    params = metrics.get('num_parameters')
    size = metrics.get('model_size_mb')

    dice_text = f'{dice:.4f}' if _is_number(dice) else 'N/A'
    # Parameter count is shown in millions.
    params_text = f'{params / 1e6:.1f}M' if _is_number(params) else 'N/A'
    size_text = f'{size:.1f}MB' if _is_number(size) else 'N/A'
    return f'  {arch}: Dice={dice_text}, Params={params_text}, Size={size_text}'


print("Evaluating all trained models...")
print("=" * 60)

result = subprocess.run([
    sys.executable, '-u', 'scripts/evaluate_models.py'
], capture_output=False, text=True)

print(f"\nEvaluation completed with exit code: {result.returncode}")

# Display results
results_file = Path('results/evaluation_results.json')
if results_file.exists():
    with open(results_file, encoding='utf-8') as f:
        results = json.load(f)

    print('\nFinal Results:')
    print('=' * 40)

    # Group results by dataset. The loop variable is deliberately not named
    # `result`, so the subprocess result above is not overwritten.
    by_dataset = {}
    for entry in results:
        by_dataset.setdefault(entry['dataset'], {})[entry['architecture']] = {
            'val_dice': entry.get('val_dice'),
            'num_parameters': entry.get('num_parameters'),
            'model_size_mb': entry.get('model_size_mb'),
        }

    # Display results
    for dataset, archs in by_dataset.items():
        print(f'\n{dataset.upper()}:')
        for arch, metrics in archs.items():
            print(format_result_line(arch, metrics))
else:
    print('Results file not found.')