# Medical Image Segmentation Lab
## Teacher Version

This notebook demonstrates a complete pipeline for medical image segmentation using CT scans. We'll focus on:
- Comprehensive Exploratory Data Analysis (EDA)
- Data preprocessing and augmentation
- Implementation of a U-Net architecture
- K-fold cross-validation
- Model evaluation and visualization

### Resources:
- [U-Net Paper](https://arxiv.org/abs/1505.04597)
- [Medical Image Segmentation Tutorial](https://www.kaggle.com/code/iezepov/fast-ai-2018-lesson-3-notes)
- [PyTorch Documentation](https://pytorch.org/docs/stable/index.html)

In [None]:
# Install required packages
!pip install nibabel scikit-learn torch torchvision matplotlib seaborn

# Import necessary libraries
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import nibabel as nib
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Seed the global torch and NumPy RNGs (used e.g. by np.random.choice and by
# weight initialisation) so reruns of the notebook are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## Data Loading and Initial Exploration

The dataset is organized in the following structure:
- Train_Phase_2/
  - TCGA-* (various TCGA cases)
  - Each case contains CT images and corresponding segmentation masks

We'll work with a subset of 100 cases to ensure manageable training time and memory usage.

In [None]:
def get_data_paths(base_dir, num_cases=100):
    """Collect paired CT / mask NIfTI paths from a random subset of TCGA case folders.

    Args:
        base_dir: Directory containing ``TCGA-*`` case folders.
        num_cases: Number of case folders to sample without replacement;
            capped at the number of folders actually available.

    Returns:
        ``(ct_paths, mask_paths)``: index-aligned lists of file paths. Inside
        each folder a CT file ``<name>.nii.gz`` is paired with
        ``<name>_mask.nii.gz``. If a folder holds exactly one CT and one mask,
        they are paired regardless of name. CT files without a matching mask
        are skipped.

    Raises:
        ValueError: if ``base_dir`` contains no ``TCGA-*`` folders.
    """
    ct_suffix = '.nii.gz'
    mask_suffix = '_mask.nii.gz'

    # os.listdir order is arbitrary; sort so the seeded random selection picks
    # the same folders on every machine.
    tcga_folders = sorted(f for f in os.listdir(base_dir) if f.startswith('TCGA-'))
    if not tcga_folders:
        raise ValueError(f"No 'TCGA-*' folders found in {base_dir!r}")
    n_selected = min(num_cases, len(tcga_folders))
    selected_folders = np.random.choice(tcga_folders, n_selected, replace=False)

    ct_paths = []
    mask_paths = []

    for folder in selected_folders:
        folder_path = os.path.join(base_dir, folder)
        files = sorted(os.listdir(folder_path))
        masks = [f for f in files if f.endswith(mask_suffix)]
        cts = [f for f in files if f.endswith(ct_suffix) and not f.endswith(mask_suffix)]

        if len(cts) == 1 and len(masks) == 1:
            # Unambiguous: one image, one mask.
            pairs = [(cts[0], masks[0])]
        else:
            # Pair by file name, not by listing position. Zipping two
            # independently ordered listings can match a CT with the wrong mask.
            mask_set = set(masks)
            pairs = [
                (ct, ct[:-len(ct_suffix)] + mask_suffix)
                for ct in cts
                if ct[:-len(ct_suffix)] + mask_suffix in mask_set
            ]

        for ct_file, mask_file in pairs:
            ct_paths.append(os.path.join(folder_path, ct_file))
            mask_paths.append(os.path.join(folder_path, mask_file))

    return ct_paths, mask_paths

# Load data paths
base_dir = 'Train_Phase_2'  # Update this path to your local directory
ct_paths, mask_paths = get_data_paths(base_dir, num_cases=100)

# Fail early with an actionable message instead of an IndexError on ct_paths[0].
if not ct_paths:
    raise RuntimeError(
        f"No CT/mask pairs found under {base_dir!r}; check base_dir and the "
        "'*_mask.nii.gz' file naming."
    )

# A case folder may contribute more than one pair, so this counts pairs.
print(f"Total number of CT/mask pairs: {len(ct_paths)}")
print(f"Sample CT path: {ct_paths[0]}")
print(f"Sample mask path: {mask_paths[0]}")

# Load and analyze a sample case
sample_ct = nib.load(ct_paths[0]).get_fdata()
sample_mask = nib.load(mask_paths[0]).get_fdata()

print("\nSample case information:")
print(f"CT shape: {sample_ct.shape}")
print(f"CT value range: [{sample_ct.min():.2f}, {sample_ct.max():.2f}]")
print(f"Mask shape: {sample_mask.shape}")
print(f"Mask value range: [{sample_mask.min():.2f}, {sample_mask.max():.2f}]")

## Detailed Exploratory Data Analysis

Let's perform a comprehensive analysis of our data to understand:
- Intensity distributions
- Slice-wise analysis
- Class balance
- Spatial characteristics

In [None]:
def analyze_intensity_distribution(ct_img, mask):
    """Plot the CT intensity histogram next to the voxel count of each mask label.

    Args:
        ct_img: CT volume (any shape); flattened before plotting.
        mask: Segmentation mask (any shape); every distinct value is treated
            as one class and gets its own bar.
    """
    ct_values = np.asarray(ct_img).ravel()
    mask_values = np.asarray(mask).ravel()

    plt.figure(figsize=(15, 5))

    # CT intensity distribution
    plt.subplot(121)
    sns.histplot(ct_values, bins=50)
    plt.title('CT Intensity Distribution')
    plt.xlabel('Intensity')

    # Mask distribution: one bar per label value. A fixed bins=2 histogram
    # would merge labels as soon as the mask holds more than two values.
    labels, counts = np.unique(mask_values, return_counts=True)
    plt.subplot(122)
    plt.bar([f'{v:g}' for v in labels], counts)
    # Log scale keeps small classes visible next to large ones.
    plt.yscale('log')
    plt.title('Mask Distribution')
    plt.xlabel('Class')
    plt.ylabel('Voxel count (log scale)')
    plt.show()

def analyze_slices(ct_img, mask, num_slices=5):
    """Show CT, mask and overlay for ``num_slices`` slices around the volume centre.

    Slices are taken along the third axis, centred on the middle slice and
    spaced ``depth // (num_slices + 1)`` apart. Each slice gets one row with
    three panels: CT, mask, and the mask overlaid on the CT.

    Args:
        ct_img: 3-D CT volume indexed as ``[:, :, k]``.
        mask: Mask volume indexed the same way as ``ct_img``.
        num_slices: Number of slices (rows) to display.
    """
    depth = ct_img.shape[2]
    middle = depth // 2
    step = depth // (num_slices + 1)

    plt.figure(figsize=(15, 3*num_slices))
    for i in range(num_slices):
        slice_idx = middle + (i - num_slices//2) * step
        ct_slice = ct_img[:, :, slice_idx]
        mask_slice = mask[:, :, slice_idx]
        # Hide background (0) pixels so only the foreground is tinted.
        # Otherwise the pale low end of 'Reds' hazes the entire CT slice.
        mask_overlay = np.ma.masked_where(mask_slice == 0, mask_slice)

        plt.subplot(num_slices, 3, i*3 + 1)
        plt.imshow(ct_slice, cmap='gray')
        plt.title(f'CT Slice {slice_idx}')

        plt.subplot(num_slices, 3, i*3 + 2)
        plt.imshow(mask_slice, cmap='gray')
        plt.title(f'Mask Slice {slice_idx}')

        plt.subplot(num_slices, 3, i*3 + 3)
        plt.imshow(ct_slice, cmap='gray')
        plt.imshow(mask_overlay, alpha=0.3, cmap='Reds')
        plt.title(f'Overlay Slice {slice_idx}')

    plt.tight_layout()
    plt.show()

# Perform analysis on the sample volume loaded earlier (sample_ct / sample_mask);
# `ct_img` and `mask` are not defined at notebook level and would raise NameError.
analyze_intensity_distribution(sample_ct, sample_mask)
analyze_slices(sample_ct, sample_mask)

## Data Preprocessing and Augmentation

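We'll implement:
- Intensity normalization (per-slice min-max scaling to [0, 1])
- Data augmentation (through an optional `transform`; none is applied by default)
- Slice extraction
- Class balancing (not implemented in the code below — left as an extension)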

In [None]:
class MedicalImageDataset(Dataset):
    """Dataset yielding one 2-D slice per (CT, mask) NIfTI pair.

    Each item is ``(ct, mask)``, both float32 tensors of shape ``(1, H, W)``.
    The CT slice is min-max scaled to [0, 1]; the mask slice is left unscaled.
    The leading channel axis makes ``DataLoader`` batches ``(B, 1, H, W)``,
    which is what a single-input-channel ``nn.Conv2d`` expects.

    Args:
        ct_paths: CT volume file paths.
        mask_paths: Mask file paths, index-aligned with ``ct_paths``.
        transform: Optional callable applied separately to the CT tensor and to
            the mask tensor. Both calls start from the same torch RNG state, so
            random augmentations driven by torch's CPU generator are applied
            identically to image and mask. (Transforms drawing from Python's
            ``random`` module are not synchronised.)
        slice_idx: Index along the third axis to extract. ``None`` selects the
            middle slice of each CT volume.

    Raises:
        ValueError: if ``ct_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, ct_paths, mask_paths, transform=None, slice_idx=None):
        if len(ct_paths) != len(mask_paths):
            raise ValueError(
                f"ct_paths ({len(ct_paths)}) and mask_paths ({len(mask_paths)}) "
                "must have the same length"
            )
        self.ct_paths = ct_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.slice_idx = slice_idx

    def __len__(self):
        return len(self.ct_paths)

    def __getitem__(self, idx):
        ct_nii = nib.load(self.ct_paths[idx])
        mask_nii = nib.load(self.mask_paths[idx])

        # Use the same slice for image and mask; default to the CT's middle slice.
        k = self.slice_idx if self.slice_idx is not None else ct_nii.shape[2] // 2

        # Slicing the image's data proxy avoids materialising the whole volume
        # as a float64 array (as get_fdata() does) just to keep one slice.
        ct_slice = np.asarray(ct_nii.dataobj[:, :, k], dtype=np.float64)
        mask_slice = np.asarray(mask_nii.dataobj[:, :, k], dtype=np.float64)

        ct_img, mask = self._to_tensors(ct_slice, mask_slice)

        if self.transform:
            # Replay the same random draws for the mask so that random flips,
            # crops and similar augmentations stay aligned with the image.
            rng_state = torch.get_rng_state()
            ct_img = self.transform(ct_img)
            torch.set_rng_state(rng_state)
            mask = self.transform(mask)

        return ct_img, mask

    @staticmethod
    def _to_tensors(ct_slice, mask_slice):
        """Min-max scale the CT slice to [0, 1]; return both slices as (1, H, W) float32 tensors."""
        ct = np.asarray(ct_slice, dtype=np.float64)
        lo, hi = ct.min(), ct.max()
        # A constant slice would otherwise give 0/0 = NaN everywhere.
        ct = (ct - lo) / (hi - lo) if hi > lo else np.zeros_like(ct)
        mask = np.asarray(mask_slice, dtype=np.float64)
        return (torch.from_numpy(ct).float().unsqueeze(0),
                torch.from_numpy(mask).float().unsqueeze(0))

# Update data loading function
def get_data_loaders(ct_paths, mask_paths, batch_size=8, num_folds=5,
                     train_transform=None, val_transform=None, groups=None):
    """Create (train_loader, val_loader) pairs for K-fold cross-validation.

    Args:
        ct_paths: CT volume paths.
        mask_paths: Mask paths, index-aligned with ``ct_paths``.
        batch_size: Batch size for both loaders.
        num_folds: Number of folds (``KFold`` with ``shuffle=True, random_state=42``).
        train_transform: Optional transform for the training datasets
            (``None`` means no transform).
        val_transform: Optional transform for the validation datasets.
        groups: Optional hashable label per sample, e.g. the case folder
            ``os.path.dirname(p)`` for paths from ``get_data_paths``. When
            given, whole groups are assigned to folds so one case never appears
            in both the training and validation split of a fold. ``None``
            splits per sample, which is plain ``KFold``.

    Returns:
        List with one ``(train_loader, val_loader)`` tuple per fold. Training
        loaders shuffle; validation loaders do not.

    Raises:
        ValueError: on mismatched input lengths, or (from ``KFold``) when there
            are fewer samples/groups than folds.
    """
    n_samples = len(ct_paths)
    if len(mask_paths) != n_samples:
        raise ValueError(
            f"ct_paths ({n_samples}) and mask_paths ({len(mask_paths)}) must have the same length"
        )
    # Default: every sample is its own group, which reproduces plain KFold.
    groups = list(range(n_samples)) if groups is None else list(groups)
    if len(groups) != n_samples:
        raise ValueError(f"groups ({len(groups)}) must have one entry per sample ({n_samples})")
    unique_groups = list(dict.fromkeys(groups))  # first-appearance order

    # Create KFold splitter (over groups)
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

    # Create datasets and dataloaders for each fold
    fold_loaders = []
    for _, val_group_idx in kf.split(unique_groups):
        val_groups = {unique_groups[i] for i in val_group_idx}
        train_idx = [i for i, g in enumerate(groups) if g not in val_groups]
        val_idx = [i for i, g in enumerate(groups) if g in val_groups]

        train_ct = [ct_paths[i] for i in train_idx]
        train_mask = [mask_paths[i] for i in train_idx]
        val_ct = [ct_paths[i] for i in val_idx]
        val_mask = [mask_paths[i] for i in val_idx]

        # Transforms are parameters. The module-level train_transform /
        # val_transform names referenced previously are never defined.
        train_dataset = MedicalImageDataset(train_ct, train_mask, transform=train_transform)
        val_dataset = MedicalImageDataset(val_ct, val_mask, transform=val_transform)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        fold_loaders.append((train_loader, val_loader))

    return fold_loaders

## U-Net Model Implementation

We'll implement a U-Net architecture with:
- Encoder (downsampling) path
- Decoder (upsampling) path
- Skip connections
- Batch normalization
- Residual connections (not used in this implementation — a possible extension)

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 keeps the spatial size unchanged. Channels go
    ``in_channels -> out_channels -> out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are unchanged so saved state_dicts
        # (keys "double_conv.<i>.*") keep loading.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    """Classic U-Net (Ronneberger et al., 2015) for 2-D segmentation.

    Encoder: an input DoubleConv (64 channels) followed by four
    MaxPool(2) + DoubleConv stages that double the channels up to 1024.
    Decoder: four 2x ConvTranspose2d stages. Each upsampled map is concatenated
    with the matching encoder feature map and fused by a DoubleConv. A 1x1
    conv followed by a sigmoid produces per-pixel probabilities.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output maps (one sigmoid map per class).

    Maps ``(B, n_channels, H, W)`` to ``(B, n_classes, H, W)`` with values in
    (0, 1). H and W need not be multiples of 16. When pooling floors an odd
    size, the upsampled map is zero-padded to the skip connection's size;
    without this, ``torch.cat`` raises a size mismatch.
    """

    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder (downsampling path)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        self.down4 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(512, 1024)
        )

        # Decoder (upsampling path); each up_conv input = skip channels + upsampled channels
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)

        # Output layer
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _match_size(x, skip):
        """Zero-pad ``x`` so its spatial size equals ``skip``'s (extra pixel goes right/bottom)."""
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
            )
        return x

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder with skip connections
        x = self._match_size(self.up1(x5), x4)
        x = self.up_conv1(torch.cat([x4, x], dim=1))

        x = self._match_size(self.up2(x), x3)
        x = self.up_conv2(torch.cat([x3, x], dim=1))

        x = self._match_size(self.up3(x), x2)
        x = self.up_conv3(torch.cat([x2, x], dim=1))

        x = self._match_size(self.up4(x), x1)
        x = self.up_conv4(torch.cat([x1, x], dim=1))

        return torch.sigmoid(self.outc(x))

## Training Setup and K-fold Cross Validation

We'll implement:
- K-fold cross validation
- Training and validation functions
- Loss functions and metrics
- Learning rate scheduling

In [None]:
def get_data_loaders(ct_dir, mask_dir, batch_size=8, num_folds=5,
                     train_transform=None, val_transform=None):
    """Create (train_loader, val_loader) pairs for K-fold cross-validation.

    Args:
        ct_dir: Directory holding CT ``.nii.gz`` files, or an explicit
            sequence of CT paths.
        mask_dir: Directory holding mask ``.nii.gz`` files, or an explicit
            sequence of mask paths index-aligned with the CT paths.
        batch_size: Batch size for both loaders.
        num_folds: Number of folds (``KFold`` with ``shuffle=True, random_state=42``).
        train_transform: Optional transform for the training datasets
            (``None`` means no transform).
        val_transform: Optional transform for the validation datasets.

    Returns:
        List with one ``(train_loader, val_loader)`` tuple per fold.

    Raises:
        ValueError: if the numbers of CT and mask files differ.
    """
    def _list_files(source):
        # A directory is listed and sorted. An explicit sequence is kept as-is
        # so the caller's CT/mask pairing is preserved.
        if isinstance(source, (str, os.PathLike)):
            return sorted(os.path.join(source, f) for f in os.listdir(source) if f.endswith('.nii.gz'))
        return list(source)

    ct_files = _list_files(ct_dir)
    mask_files = _list_files(mask_dir)
    # Directory mode pairs files by sorted position, which is only correct when
    # both folders use matching names. NOTE(review): confirm for the dataset.
    if len(ct_files) != len(mask_files):
        raise ValueError(
            f"Found {len(ct_files)} CT files but {len(mask_files)} mask files; they must pair up 1:1"
        )

    # Create KFold splitter
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

    # Create datasets and dataloaders for each fold
    fold_loaders = []
    for train_idx, val_idx in kf.split(ct_files):
        train_ct = [ct_files[i] for i in train_idx]
        train_mask = [mask_files[i] for i in train_idx]
        val_ct = [ct_files[i] for i in val_idx]
        val_mask = [mask_files[i] for i in val_idx]

        # Transforms are parameters. The module-level train_transform /
        # val_transform names referenced previously are never defined.
        train_dataset = MedicalImageDataset(train_ct, train_mask, transform=train_transform)
        val_dataset = MedicalImageDataset(val_ct, val_mask, transform=val_transform)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        fold_loaders.append((train_loader, val_loader))

    return fold_loaders

def dice_coefficient(pred, target, smooth=1e-5, threshold=None):
    """Compute the Dice coefficient ``2|P∩T| / (|P| + |T|)``.

    Both tensors are flattened over every dimension, including the batch, so
    the result is one pooled score for the whole batch.

    Args:
        pred: Predicted probabilities (soft Dice) or binary predictions.
        target: Ground-truth mask with the same number of elements.
        smooth: Constant added to numerator and denominator. It avoids 0/0 and
            makes two empty masks score 1.
        threshold: If given, ``pred`` is binarised as ``pred > threshold``
            before scoring (hard Dice). ``None`` keeps the soft Dice.

    Returns:
        0-dim tensor holding the Dice score.
    """
    if threshold is not None:
        pred = (pred > threshold).float()
    # reshape (not view) so non-contiguous inputs, e.g. permuted or
    # transposed tensors, also work.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Prints the current batch loss every 10 batches.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over samples. Each batch is
        weighted by its size, assuming ``criterion`` returns a per-batch mean
        (the default reduction of ``nn.BCELoss``).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_dice = 0.0
    n_samples = 0

    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)

        # Forward pass
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        batch_n = data.size(0)
        total_loss += loss.item() * batch_n
        # The metric needs no autograd graph.
        total_dice += dice_coefficient(output.detach(), target).item() * batch_n
        n_samples += batch_n

        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}')

    if n_samples == 0:
        raise ValueError('train_epoch received an empty loader')
    return total_loss / n_samples, total_dice / n_samples

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking or weight updates.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over samples. Each batch is
        weighted by its size, assuming ``criterion`` returns a per-batch mean
        (the default reduction of ``nn.BCELoss``).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    n_samples = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            batch_n = data.size(0)
            total_loss += loss.item() * batch_n
            total_dice += dice_coefficient(output, target).item() * batch_n
            n_samples += batch_n

    if n_samples == 0:
        raise ValueError('validate received an empty loader')
    return total_loss / n_samples, total_dice / n_samples

## Training Loop and Monitoring

We'll implement:
- Training loop with K-fold cross validation
- Learning rate scheduling
- Model checkpointing
- Progress monitoring

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Training parameters
num_epochs = 50
learning_rate = 1e-4
batch_size = 8
num_folds = 5

# Get data loaders
fold_loaders = get_data_loaders('Data/CT', 'Data/Segmentation', batch_size=batch_size, num_folds=num_folds)

# Training loop with K-fold cross validation
fold_metrics = []     # per-fold history dicts (used by plot_training_history)
fold_best_dice = []   # best validation Dice reached in each fold
for fold, (train_loader, val_loader) in enumerate(fold_loaders):
    print(f'\nTraining fold {fold + 1}/{len(fold_loaders)}')

    # Fresh model, criterion and optimizer per fold so folds are independent
    model = UNet(n_channels=1, n_classes=1).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the LR once validation Dice stops improving (patience=5 epochs)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

    # Start below any attainable Dice so the first epoch always writes a
    # checkpoint. Starting at 0, a fold whose Dice stayed 0 would save nothing
    # and loading best_model_fold_*.pth afterwards would fail.
    best_dice = float('-inf')
    fold_history = {'train_loss': [], 'train_dice': [], 'val_loss': [], 'val_dice': []}

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch + 1}/{num_epochs}')

        # Train
        train_loss, train_dice = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_dice = validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step(val_dice)

        # Store metrics
        fold_history['train_loss'].append(train_loss)
        fold_history['train_dice'].append(train_dice)
        fold_history['val_loss'].append(val_loss)
        fold_history['val_dice'].append(val_dice)

        print(f'Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}')

        # Save best model
        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), f'best_model_fold_{fold + 1}.pth')
            print(f'New best model saved with Dice score: {best_dice:.4f}')

    fold_metrics.append(fold_history)
    fold_best_dice.append(best_dice)

# Cross-validated estimate: report each fold's best score and their spread.
if fold_best_dice:
    per_fold = ', '.join(f'{d:.4f}' for d in fold_best_dice)
    print(f'\nBest validation Dice per fold: {per_fold}')
    print(f'Mean over {len(fold_best_dice)} folds: {np.mean(fold_best_dice):.4f} '
          f'(std {np.std(fold_best_dice):.4f})')

## Evaluation and Visualization

We'll:
- Load the best model
- Evaluate on a validation fold (no separate held-out test set is used in this notebook)
- Visualize predictions
- Analyze model performance

In [None]:
def plot_training_history(fold_metrics):
    """Draw a 2x2 grid of learning curves with one line per fold.

    Panels, in row-major order: training loss, validation loss, training Dice
    and validation Dice. ``fold_metrics`` is a list of per-fold dicts keyed by
    'train_loss', 'val_loss', 'train_dice' and 'val_dice'.
    """
    panels = (
        ('train_loss', 'Training Loss', 'Loss'),
        ('val_loss', 'Validation Loss', 'Loss'),
        ('train_dice', 'Training Dice Coefficient', 'Dice'),
        ('val_dice', 'Validation Dice Coefficient', 'Dice'),
    )

    plt.figure(figsize=(15, 10))
    for position, (key, title, y_label) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for fold_no, history in enumerate(fold_metrics, start=1):
            plt.plot(history[key], label=f'Fold {fold_no}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

def visualize_predictions(model, loader, device, num_samples=3):
    """Show input CT, ground truth and predicted probability map for a few samples.

    Only the first batch from ``loader`` is used. One figure is drawn for each
    of its first ``num_samples`` items; nothing is drawn if ``loader`` is empty.

    The mask and prediction panels use a fixed [0, 1] display range, which is
    the range of the sigmoid output and of BCE targets. Brightness then shows
    the actual probability. With imshow's default per-image autoscaling, a map
    whose maximum is 0.01 would look as bright as a confident prediction.
    """
    model.eval()
    batch = next(iter(loader), None)
    if batch is None:
        return
    data, target = batch
    with torch.no_grad():
        output = model(data.to(device))

    for i in range(min(num_samples, len(data))):
        plt.figure(figsize=(15, 5))

        plt.subplot(131)
        plt.imshow(data[i].cpu().squeeze(), cmap='gray')
        plt.title('Input CT')
        plt.axis('off')

        plt.subplot(132)
        plt.imshow(target[i].cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(133)
        plt.imshow(output[i].cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
        plt.title('Prediction')
        plt.axis('off')

        plt.show()

# Plot training history
plot_training_history(fold_metrics)

# Load a fold's best checkpoint and visualise it on that SAME fold's validation
# split. The loop variable `val_loader` belongs to the last fold, whose cases
# were part of fold 1's training data, so using it would leak training samples.
eval_fold = 0  # 0-based index; checkpoint files are numbered from 1
best_model = UNet(n_channels=1, n_classes=1).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
best_model.load_state_dict(torch.load(f'best_model_fold_{eval_fold + 1}.pth', map_location=device))
visualize_predictions(best_model, fold_loaders[eval_fold][1], device)