# MS Lesion Segmentation - Swin UNETR Training

This notebook trains a Swin UNETR model for Multiple Sclerosis lesion segmentation using the ISBI 2015 dataset.

## ⚠️ IMPORTANT: Enable GPU First!

**Before starting, enable GPU in Colab:**
- Go to **Runtime** → **Change runtime type** → Set **Hardware accelerator** to **GPU** → **Save**

## Steps:
1. Clone the repository (then run the GPU setup and verification cell)
2. Mount Google Drive and extract dataset
3. Install dependencies
4. Run training
5. Evaluate model on test data


## Step 1: Clone Repository


In [None]:
# Clone the repository (or navigate to it if it already exists)
import os
import subprocess

repo_url = 'https://github.com/Vahdanian/SwinUnet.git'
repo_dir = 'SwinUnet'

# This cell finishes with os.chdir(repo_dir), so when it is re-run the working
# directory is already the checkout. Detect that case first; otherwise the
# relative lookup of SwinUnet/.git fails and a second copy would be cloned
# inside the first one.
inside_repo = os.path.basename(os.getcwd()) == repo_dir and os.path.exists('.git')

if inside_repo:
    print(f"Already inside repository directory: {os.getcwd()}")
elif os.path.exists(os.path.join(repo_dir, '.git')):
    print(f"Repository already exists at {repo_dir}")
    os.chdir(repo_dir)
    print(f"Changed to repository directory: {os.getcwd()}")
else:
    # check=True raises CalledProcessError when the clone fails, so the chdir
    # below is only reached after a successful clone.
    subprocess.run(['git', 'clone', repo_url, repo_dir], check=True)

    # Change to the repository directory
    os.chdir(repo_dir)

    print("Repository cloned successfully!")
    print(f"Current directory: {os.getcwd()}")


In [None]:
# Fetch and pull latest commits from the repository
import os
import subprocess


def _git(*args):
    """Run ``git <args>`` in the current directory and return the CompletedProcess.

    Output is captured as text so callers can inspect ``stdout``/``stderr``;
    a failing command does not raise, callers check ``returncode``.
    """
    return subprocess.run(['git', *args], capture_output=True, text=True)


# Only proceed when the current directory is a git checkout
if os.path.exists('.git'):
    print("Repository found. Fetching latest changes...")

    # Fetch latest commits
    result = _git('fetch', 'origin')
    if result.returncode == 0:
        print("✓ Fetched latest commits from remote")
    else:
        print(f"⚠ Warning: git fetch failed: {result.stderr}")

    # Check current branch. On a detached HEAD the command succeeds but prints
    # nothing, so fall back to 'main' for empty output as well as for failure;
    # otherwise the pull below would be issued with an empty branch name.
    branch_result = _git('branch', '--show-current')
    current_branch = branch_result.stdout.strip() if branch_result.returncode == 0 else ''
    if not current_branch:
        current_branch = 'main'

    # Pull latest changes
    print(f"Pulling latest changes from {current_branch} branch...")
    pull_result = _git('pull', 'origin', current_branch)

    if pull_result.returncode == 0:
        if 'Already up to date' in pull_result.stdout:
            print("✓ Repository is already up to date")
        else:
            print("✓ Successfully pulled latest changes")
            print("\nRecent commits:")
            # Show last 5 commits
            log_result = _git('log', '--oneline', '-5')
            if log_result.returncode == 0:
                print(log_result.stdout)
    else:
        print(f"⚠ Warning: git pull failed: {pull_result.stderr}")
        print("You may need to resolve conflicts manually")
else:
    print("⚠ Warning: Not in a git repository. Make sure you've cloned the repository first.")
    print(f"Current directory: {os.getcwd()}")


## GPU Setup (Important!)

**Before proceeding, make sure GPU is enabled in Colab:**

1. Go to **Runtime** → **Change runtime type**
2. Set **Hardware accelerator** to **GPU** (T4, V100, or A100)
3. Click **Save**
4. The notebook will restart - re-run the cells above

The cell below will verify GPU availability.


In [None]:
# GPU Setup and Verification
import os
import warnings

import torch

# The cuda.cudart FutureWarning is only a deprecation notice; keep it out of the output
warnings.filterwarnings('ignore', category=FutureWarning, message='.*cuda.cudart.*')

banner = "=" * 60

print(banner)
print("GPU Setup Verification")
print(banner)

# Decide the device once; the branches below only report on that decision
cuda_available = torch.cuda.is_available()
device = 'cuda' if cuda_available else 'cpu'

if not cuda_available:
    # No GPU: explain how to enable one in Colab
    for line in (
        "⚠ WARNING: CUDA is NOT available!",
        "\nTo enable GPU in Colab:",
        "1. Go to Runtime → Change runtime type",
        "2. Set Hardware accelerator to GPU",
        "3. Click Save and restart the notebook",
        "4. Re-run all cells from the beginning",
        "\n⚠ Training will be VERY SLOW on CPU!",
        banner,
    ):
        print(line)
else:
    print("✓ CUDA is available!")
    print(f"✓ GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"✓ CUDA Version: {torch.version.cuda}")
    print(f"✓ Number of GPUs: {torch.cuda.device_count()}")

    # Total memory of device 0, converted from bytes to GiB
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"✓ GPU Memory: {gpu_memory:.2f} GB")

    print(f"\n✓ Using device: {device}")
    print(banner)

# Store device for later use
os.environ['CUDA_DEVICE'] = device


## Step 2: Mount Google Drive and Extract Dataset


In [None]:
# Mount Google Drive into the Colab filesystem
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

print("Google Drive mounted successfully!")


In [None]:
import zipfile
import os
import shutil

# Dataset archives expected in the Google Drive MS2015 folder
training_zip_path = '/content/drive/MyDrive/Dataset/MS2015/training_final_v4.zip'
testdata_zip_path = '/content/drive/MyDrive/Dataset/MS2015/testdata_website_2016-03-24.zip'


def _size_in_gb(path):
    """Return the size of the file at *path* in units of 1024**3 bytes."""
    return os.path.getsize(path) / (1024**3)


# Fail fast: both archives must be present before anything is reported.
# The mapping keeps each path next to the error raised when it is missing.
missing_archive_errors = {
    training_zip_path: f"Training dataset not found at {training_zip_path}. Please ensure the file exists.",
    testdata_zip_path: f"Test dataset not found at {testdata_zip_path}. Please ensure the file exists.",
}
for archive_path, error_text in missing_archive_errors.items():
    if not os.path.exists(archive_path):
        raise FileNotFoundError(error_text)

# Report where each archive was found and how large it is
print(f"Found training dataset at: {training_zip_path}")
print(f"Training file size: {_size_in_gb(training_zip_path):.2f} GB")
print(f"\nFound test dataset at: {testdata_zip_path}")
print(f"Test file size: {_size_in_gb(testdata_zip_path):.2f} GB")

# Create ISBI_2015 directory structure
os.makedirs('ISBI_2015', exist_ok=True)

def extract_and_organize_zip(zip_path, target_folder, folder_name, root_dir='ISBI_2015'):
    """Extract a dataset archive and place its contents under ``root_dir/target_folder``.

    The archive is unpacked into a scratch directory in the current working
    directory and then organised as follows:

    * If the archive has exactly one top-level entry and it is a directory,
      that directory becomes the target folder.
    * If it has several top-level entries, the first directory (in
      ``os.listdir`` order) whose name matches the dataset becomes the target
      folder and the remaining entries are discarded. A match is a name
      starting with ``training`` when ``folder_name == 'training'``, or a name
      containing ``test`` (case-insensitive) when ``folder_name == 'testdata'``.
    * In every other case all extracted entries are moved into a freshly
      created target folder.

    An existing target folder is always replaced, so the function can be
    re-run safely. The scratch directory is emptied before use and removed
    afterwards, even when extraction or moving fails.

    Args:
        zip_path: Path of the ``.zip`` archive to extract.
        target_folder: Name of the destination folder inside ``root_dir``.
        folder_name: Dataset label (``'training'`` or ``'testdata'``); used for
            folder matching, the scratch directory name and progress messages.
        root_dir: Directory that holds the organised datasets. Defaults to
            ``'ISBI_2015'``.

    Raises:
        FileNotFoundError: If ``zip_path`` does not exist.
        zipfile.BadZipFile: If ``zip_path`` is not a valid zip archive.
    """
    print(f"\nExtracting {folder_name} dataset... This may take a few minutes...")

    target_path = os.path.join(root_dir, target_folder)
    temp_dir = f'temp_{folder_name}_extract'

    def is_dataset_folder(name):
        """Tell whether a top-level entry name looks like the requested dataset."""
        if folder_name == 'training':
            return name.startswith('training')
        if folder_name == 'testdata':
            return 'test' in name.lower()
        return False

    # Start from an empty scratch directory: leftovers from an interrupted run
    # would otherwise be mistaken for archive contents.
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir)
    os.makedirs(root_dir, exist_ok=True)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        extracted_items = os.listdir(temp_dir)

        # Pick the directory (if any) that should become the target folder.
        dataset_dir = None
        if len(extracted_items) == 1:
            only_item = os.path.join(temp_dir, extracted_items[0])
            if os.path.isdir(only_item):
                dataset_dir = only_item
        else:
            for item in extracted_items:
                item_path = os.path.join(temp_dir, item)
                if os.path.isdir(item_path) and is_dataset_folder(item):
                    dataset_dir = item_path
                    break

        # Replace any previous result only after extraction has succeeded, so a
        # corrupt archive does not destroy an existing dataset.
        if os.path.exists(target_path):
            shutil.rmtree(target_path)

        if dataset_dir is not None:
            # Moving onto a non-existent path renames the directory to it.
            shutil.move(dataset_dir, target_path)
        else:
            # Loose files / unrecognised layout: keep everything.
            os.makedirs(target_path, exist_ok=True)
            for item in extracted_items:
                shutil.move(os.path.join(temp_dir, item), target_path)
    finally:
        # Clean up the scratch directory whether or not the steps above succeeded
        shutil.rmtree(temp_dir, ignore_errors=True)

    print(f"{folder_name.capitalize()} dataset extracted successfully!")

# Extract both datasets
extract_and_organize_zip(training_zip_path, 'training', 'training')
extract_and_organize_zip(testdata_zip_path, 'testdata_website', 'testdata')


def _patient_dirs(root, prefix):
    """Return the sub-directories of *root* whose names start with *prefix*."""
    return [
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry)) and entry.startswith(prefix)
    ]


# Verify the dataset structure
separator = "=" * 60
print("\n" + separator)
print("Verifying dataset structure...")
print(separator)

training_root = 'ISBI_2015/training'
if not os.path.exists(training_root):
    print("⚠ Warning: Training directory not found. Please check the dataset structure.")
else:
    print("✓ Training directory found")
    training_patients = _patient_dirs(training_root, 'training')
    print(f"✓ Found {len(training_patients)} training patients: {training_patients}")

testdata_root = 'ISBI_2015/testdata_website'
if not os.path.exists(testdata_root):
    print("⚠ Warning: Test data directory not found. Please check the dataset structure.")
else:
    print("✓ Test data directory found")
    test_patients = _patient_dirs(testdata_root, 'test')
    print(f"✓ Found {len(test_patients)} test patients: {test_patients}")

print("\nDataset extraction and verification completed!")


## Step 3: Install Dependencies


In [None]:
# Install required packages.
# The first command requests torch, torchvision and torchaudio from the cu118
# wheel index given by --index-url; the second installs everything listed in
# requirements.txt (looked up relative to the current directory).
# -q reduces pip's output.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q -r requirements.txt

# NOTE: this message is printed unconditionally; the cell does not inspect the
# exit status of either pip command above.
print("Dependencies installed successfully!")


In [None]:
# Verify installation: an ImportError on any line below means that package is missing
import torch
import torchvision
import nibabel
import monai
import yaml

# Query CUDA once and reuse the answer for the report
gpu_present = torch.cuda.is_available()

print("✓ PyTorch version:", torch.__version__)
print("✓ CUDA available:", gpu_present)
if gpu_present:
    print(f"✓ CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"✓ CUDA version: {torch.version.cuda}")
print("✓ MONAI version:", monai.__version__)
print("✓ All dependencies verified!")


In [None]:
# Pre-training GPU check and setup
import os
import warnings

import torch

# The cuda.cudart FutureWarning is a deprecation notice, not an error; hide it
warnings.filterwarnings('ignore', category=FutureWarning, message='.*cuda.cudart.*')

rule = "=" * 60

if not torch.cuda.is_available():
    # No GPU: print the warning and the steps to enable one
    for line in (
        rule,
        "⚠ WARNING: NO GPU DETECTED",
        rule,
        "Training will be VERY SLOW on CPU!",
        "\nTo enable GPU:",
        "1. Go to Runtime → Change runtime type",
        "2. Set Hardware accelerator to GPU",
        "3. Click Save (notebook will restart)",
        "4. Re-run all cells from the beginning",
        rule,
    ):
        print(line)
else:
    print(rule)
    print("✓ GPU READY FOR TRAINING")
    print(rule)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
    print(rule)
    # Set environment variable for training script
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'


## Step 4: Run Training


In [None]:
# Check dataset structure before training
import os

data_dir = 'ISBI_2015/training'

# One row per expected sub-folder of a patient directory:
# (sub-folder name, label used in the report, file suffix that is counted)
expected_layout = (
    ('orig', 'Original files', '.nii.gz'),
    ('preprocessed', 'Preprocessed files', '.nii'),
    ('masks', 'Mask files', '.nii'),
)

if not os.path.exists(data_dir):
    print(f"⚠ Error: Dataset directory not found at {data_dir}")
else:
    print(f"Dataset directory: {data_dir}")

    # Patient folders are the sub-directories named training*
    patients = [
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry)) and entry.startswith('training')
    ]
    print(f"\nFound {len(patients)} training patients:")

    for patient in sorted(patients):
        patient_path = os.path.join(data_dir, patient)
        print(f"\n  {patient}:")

        # Sub-folders that are absent are skipped silently
        for subfolder, label, suffix in expected_layout:
            folder = os.path.join(patient_path, subfolder)
            if os.path.exists(folder):
                matching = [name for name in os.listdir(folder) if name.endswith(suffix)]
                print(f"    - {label}: {len(matching)}")


In [None]:
# Run training
import sys
import os

# Ensure we're in the right directory
if not os.path.exists('scripts/train.py'):
    print("Error: train.py not found. Make sure you're in the SwinUnet directory.")
    print(f"Current directory: {os.getcwd()}")
    sys.exit(1)

# Run training with the default configuration
print("Starting training...")
print("=" * 60)

!python scripts/train.py --config config/training_config.yaml

print("=" * 60)
print("Training completed!")


## Step 5: Evaluate Model on Test Data

After training, evaluate the model on test data to see the performance metrics:


In [None]:
# Evaluate the trained model on test data
import os
import sys
import torch
import yaml
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the current directory on the import path so the `src` package resolves
sys.path.insert(0, os.getcwd())

from src.models import SwinUNETR
from src.data import TestDataset, MSLesionDataset
from src.evaluation import dice_score, sensitivity, specificity, compute_all_metrics

# Configuration
config_path = 'config/training_config.yaml'
test_data_dir = 'ISBI_2015/testdata_website'  # Test data directory
output_dir = 'outputs/experiment_01'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load configuration
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Collect every checkpoint file (.pth / .pt) found directly in the output directory
model_files = []
if os.path.exists(output_dir):
    model_files = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith(('.pth', '.pt'))
    ]

if not model_files:
    print("⚠ Warning: No model checkpoint found. Make sure training completed successfully.")
    print(f"Looking in: {output_dir}")
else:
    # Prefer a checkpoint whose *file name* contains 'best'. Matching on the
    # file name only, so a directory called e.g. 'best_runs' cannot make every
    # checkpoint look like the best one.
    best_model = next(
        (path for path in model_files if 'best' in os.path.basename(path).lower()),
        None,
    )

    if best_model is None:
        # Use the most recent checkpoint
        best_model = max(model_files, key=os.path.getmtime)

    print("=" * 60)
    print("Evaluating Model on Test Data")
    print("=" * 60)
    print(f"Model: {best_model}")
    print(f"Test data: {test_data_dir}")
    print(f"Device: {device}")
    print("=" * 60)

    # Load model configuration (its path is stored in the training config)
    model_config_path = config['model']['config_path']
    with open(model_config_path, 'r') as f:
        model_config = yaml.safe_load(f)

    model_params = model_config['model']

    # Create model
    print("\nLoading model...")
    model = SwinUNETR(
        in_channels=model_params['in_channels'],
        out_channels=model_params['out_channels'],
        img_size=tuple(model_params['img_size']),
        feature_size=model_params['feature_size'],
        use_attention=model_params['use_attention'],
        attention_type=model_params.get('attention_type', 'cbam')
    )

    # Load checkpoint; the weights are read from its 'model_state_dict' entry
    checkpoint = torch.load(best_model, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    print("✓ Model loaded successfully")

    # Check if test data exists
    using_training_fallback = False
    if not os.path.exists(test_data_dir):
        print(f"\n⚠ Warning: Test data directory not found at {test_data_dir}")
        print("Trying to use training data for evaluation instead...")
        test_data_dir = 'ISBI_2015/training'
        using_training_fallback = True

    # Create test dataset
    print(f"\nLoading test data from: {test_data_dir}")
    data_config = config['data']
    dataset_kwargs = dict(
        data_dir=test_data_dir,
        use_preprocessed=data_config['use_preprocessed'],
        normalize=data_config['normalize'],
        augmentation=False,
        target_size=tuple(data_config['target_size']) if data_config.get('target_size') else None,
        modalities=data_config['modalities'],
    )

    dataset = None
    if not using_training_fallback:
        # Try TestDataset first (for test data without ground truth). When we
        # fell back to the training folder this step is skipped: the point of
        # the fallback is to evaluate against masks, which TestDataset does
        # not provide.
        try:
            dataset = TestDataset(**dataset_kwargs)
        except Exception as exc:
            # Report why the first attempt failed instead of hiding it
            print(f"TestDataset could not load {test_data_dir}: {exc}")

    if dataset is not None:
        has_ground_truth = False
        print("Using TestDataset (no ground truth masks)")
    else:
        # Fall back to MSLesionDataset (with ground truth)
        dataset = MSLesionDataset(**dataset_kwargs)
        has_ground_truth = True
        print("Using MSLesionDataset (with ground truth masks)")

    # Create data loader
    test_loader = DataLoader(
        dataset,
        batch_size=data_config['batch_size'],
        shuffle=False,
        num_workers=0,  # Use 0 for Colab to avoid issues
        pin_memory=False
    )

    print(f"Test samples: {len(dataset)}")

    # Evaluate
    print("\nEvaluating model on test data...")
    all_metrics = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images = batch['image'].to(device)
            # Batches without a 'mask' entry yield predictions only
            masks = batch.get('mask')

            if masks is not None:
                masks = masks.to(device)

            # Forward pass; threshold the sigmoid output at 0.5 for a binary mask
            outputs = model(images)
            predictions = (torch.sigmoid(outputs) > 0.5).float()

            # Compute metrics if ground truth available
            if masks is not None:
                for pred_tensor, mask_tensor in zip(predictions, masks):
                    pred = pred_tensor.cpu().numpy()
                    mask = mask_tensor.cpu().numpy()
                    all_metrics.append(compute_all_metrics(pred, mask))

    # Display results
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)

    if all_metrics:
        # Aggregate metrics
        metric_names = ['dice', 'sensitivity', 'specificity', 'hausdorff_distance']
        aggregated = {}

        for metric_name in metric_names:
            # Keep finite values only: an inf or NaN sample would otherwise
            # turn the mean/std of the whole metric into inf/NaN.
            values = [m[metric_name] for m in all_metrics
                      if metric_name in m and np.isfinite(m[metric_name])]
            if values:
                aggregated[metric_name] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }

        # Print results
        print(f"\nNumber of test samples: {len(all_metrics)}")
        print("\nMetrics:")
        print("-" * 60)

        if 'dice' in aggregated:
            print(f"Dice Score:        {aggregated['dice']['mean']:.4f} ± {aggregated['dice']['std']:.4f}")
            print(f"                  Range: [{aggregated['dice']['min']:.4f}, {aggregated['dice']['max']:.4f}]")

        if 'sensitivity' in aggregated:
            print(f"Sensitivity:       {aggregated['sensitivity']['mean']:.4f} ± {aggregated['sensitivity']['std']:.4f}")
            print(f"                  Range: [{aggregated['sensitivity']['min']:.4f}, {aggregated['sensitivity']['max']:.4f}]")

        if 'specificity' in aggregated:
            print(f"Specificity:       {aggregated['specificity']['mean']:.4f} ± {aggregated['specificity']['std']:.4f}")
            print(f"                  Range: [{aggregated['specificity']['min']:.4f}, {aggregated['specificity']['max']:.4f}]")

        if 'hausdorff_distance' in aggregated:
            print(f"Hausdorff Distance: {aggregated['hausdorff_distance']['mean']:.4f} ± {aggregated['hausdorff_distance']['std']:.4f} mm")
            print(f"                  Range: [{aggregated['hausdorff_distance']['min']:.4f}, {aggregated['hausdorff_distance']['max']:.4f}] mm")

        print("-" * 60)

        # Per-sample results (show first few)
        print(f"\nPer-sample results (showing first 5):")
        print("-" * 60)
        for i, metrics in enumerate(all_metrics[:5]):
            print(f"\nSample {i+1}:")
            for key, value in metrics.items():
                if not np.isinf(value):
                    if key == 'hausdorff_distance':
                        print(f"  {key}: {value:.4f} mm")
                    else:
                        print(f"  {key}: {value:.4f}")

        if len(all_metrics) > 5:
            print(f"\n... and {len(all_metrics) - 5} more samples")
    else:
        print("\n⚠ No ground truth available for evaluation.")
        print("Predictions were generated but metrics cannot be computed.")
        print("To get metrics, use test data with ground truth masks.")

    print("\n" + "=" * 60)
    print("Evaluation completed!")
    print("=" * 60)


## Optional: Monitor Training Progress

You can monitor training progress using TensorBoard (if enabled in config):


In [None]:
# Load TensorBoard extension (if available); the %tensorboard magic used below
# needs it to be loaded first
%load_ext tensorboard

# Start TensorBoard (adjust path if needed).
# Left commented out on purpose: uncomment the next line to launch it.
# %tensorboard --logdir outputs/experiment_01


## Save Results to Google Drive (Optional)

After training completes, you can copy the results to Google Drive for persistence:


In [None]:
# Copy training outputs to Google Drive
import shutil
import os

output_dir = 'outputs/experiment_01'
drive_root = '/content/drive/MyDrive'
drive_output_dir = '/content/drive/MyDrive/SwinUnet_Results'

if not os.path.exists(output_dir):
    print(f"⚠ Output directory not found at {output_dir}")
elif not os.path.isdir(drive_root):
    # Without this guard os.makedirs below would create the path on the local
    # disk and the cell would report success although nothing reached Drive.
    print(f"⚠ Google Drive does not appear to be mounted ({drive_root} not found).")
    print("Run the 'Mount Google Drive' cell first, then re-run this cell.")
else:
    print(f"Copying results from {output_dir} to Google Drive...")

    # Create destination directory if it doesn't exist
    os.makedirs(drive_output_dir, exist_ok=True)

    # Copy the entire output directory; dirs_exist_ok lets a re-run overwrite
    # an earlier copy instead of failing
    destination = os.path.join(drive_output_dir, os.path.basename(output_dir))
    shutil.copytree(output_dir, destination, dirs_exist_ok=True)

    print(f"✓ Results saved to {drive_output_dir}")

    # List saved files (top level only, sub-directories are skipped)
    print(f"\nSaved files:")
    for file in sorted(os.listdir(destination)):
        file_path = os.path.join(destination, file)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024**2)
            print(f"  - {file} ({size_mb:.2f} MB)")
