# Ovarian Cancer Segmentation Lab

This lab focuses on medical image segmentation for ovarian cancer detection using CT scans. You will work with volumetric medical data (NIfTI format) to build and train a U-Net model for segmenting different types of cancer tissues.

## Task Overview
You will segment CT volumes into three classes:
- Class 0: Background
- Class 1: Primary ovarian cancer
- Class 2: Metastasis

## 🎯 Learning Objectives
- Work with medical imaging data in NIfTI format
- Implement a 3D U-Net architecture
- Train a segmentation model
- Evaluate medical imaging results

Let's begin! 🚀


In [None]:
# Install required packages
!pip install numpy --quiet
!pip install scipy --quiet
!pip install scikit-learn --quiet
!pip install scikit-image nibabel gdown torch torchvision --quiet

# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from skimage import transform
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

# Set up GPU if available.
# `device` is assigned on both paths so later cells can always call
# `.to(device)`, even on a CPU-only runtime.
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("WARNING: CUDA is not available. Please make sure to enable GPU in Runtime > Change runtime type")
    print("Current device: CPU")
else:
    # The global default tensor type is intentionally left alone:
    # `torch.set_default_tensor_type` is deprecated in recent PyTorch releases
    # and a CUDA default can interfere with code that expects CPU tensors.
    # The model and every batch are moved to `device` explicitly instead.
    device = torch.device('cuda')
    print(f"Using device: {device}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


In [None]:
# Loss Functions
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from raw logits.

    Args:
        smooth: Constant added to numerator and denominator of every Dice
            ratio so that empty classes do not divide by zero.
    """

    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Return ``1 - mean Dice`` over all (sample, class) pairs.

        Args:
            predictions: Logits of shape (batch_size, n_classes, d1, d2, d3).
            targets: Integer labels of shape (batch_size, d1, d2, d3).
        """
        spatial_axes = (2, 3, 4)

        # Logits -> per-voxel class probabilities
        probs = F.softmax(predictions, dim=1)

        # Labels -> one-hot masks laid out like `probs` (class axis second)
        num_classes = probs.shape[1]
        target_masks = F.one_hot(targets, num_classes).permute(0, 4, 1, 2, 3).float()

        # Per-sample, per-class overlap and total mass
        overlap = (probs * target_masks).sum(dim=spatial_axes)
        total_mass = probs.sum(dim=spatial_axes) + target_masks.sum(dim=spatial_axes)
        dice_scores = (2 * overlap + self.smooth) / (total_mass + self.smooth)

        # Collapse batch and class axes into a single scalar loss
        return 1 - dice_scores.mean()

class CombinedLoss(nn.Module):
    """Dice loss plus inverse-frequency-weighted cross-entropy.

    Args:
        smooth: Stabiliser shared by the Dice term and the class-weight
            computation.
        ce_weight: Multiplier applied to the cross-entropy term before it is
            added to the Dice term.
    """

    def __init__(self, smooth=1e-5, ce_weight=0.5):
        super().__init__()
        self.smooth = smooth
        self.ce_weight = ce_weight
        self.dice_loss = DiceLoss(smooth=smooth)

    def forward(self, predictions, targets):
        """Return ``dice + ce_weight * weighted_cross_entropy`` for one batch."""
        # Inverse-frequency class weights, recomputed from this batch's labels
        num_classes = predictions.shape[1]
        voxels_per_class = torch.bincount(targets.flatten(), minlength=num_classes).float()
        inverse_freq = voxels_per_class.sum() / (voxels_per_class * num_classes + self.smooth)
        inverse_freq = inverse_freq.to(predictions.device)

        dice_term = self.dice_loss(predictions, targets)
        ce_term = F.cross_entropy(predictions, targets, weight=inverse_freq)

        return dice_term + self.ce_weight * ce_term


In [None]:
# Download and extract dataset
file_id = '1Wo4h6ZVIFygVvqd68ApwWIdPQk3l7gkO'
output = 'Data_Subsample.zip'

if not os.path.exists(output):
    import subprocess
    # The file is addressed by URL because gdown's `--id` flag is deprecated.
    # check=True makes a failed download stop the cell here instead of
    # surfacing later as a missing or corrupt zip file.
    subprocess.run(
        ['gdown', f'https://drive.google.com/uc?id={file_id}', '-O', output],
        check=True,
    )

# Extract data if not already extracted
if not os.path.exists('Data_Subsample'):
    import zipfile
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall('.')

# List available files (sorted so CT and mask lists line up by index)
ct_files = sorted([f for f in os.listdir('Data_Subsample/CT') if f.endswith('.nii.gz')])
seg_files = sorted([f for f in os.listdir('Data_Subsample/Segmentation') if f.endswith('.nii.gz')])

print(f'Number of CT volumes: {len(ct_files)}')
print(f'Number of segmentation masks: {len(seg_files)}')


# 1. Environment Setup

First, let's install the required packages:


In [None]:
# Loss Functions
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss that takes logits and integer label maps."""

    def __init__(self, smooth=1e-5):
        super().__init__()
        # Added to both sides of each Dice ratio to avoid 0/0 on empty classes
        self.smooth = smooth

    def forward(self, predictions, targets):
        """Compute ``1 - mean(Dice)``.

        Args:
            predictions: Logits, shape (batch_size, n_classes, d1, d2, d3).
            targets: Class indices, shape (batch_size, d1, d2, d3).
        """
        def per_class_sum(volume):
            # Reduce the three spatial axes, keeping (batch, class)
            return volume.sum(dim=(2, 3, 4))

        class_probs = F.softmax(predictions, dim=1)
        one_hot = F.one_hot(targets, class_probs.shape[1]).permute(0, 4, 1, 2, 3).float()

        intersection = per_class_sum(class_probs * one_hot)
        cardinality = per_class_sum(class_probs) + per_class_sum(one_hot)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)

        return 1 - dice.mean()

# Dice + class-weighted cross-entropy, used to cope with class imbalance
class CombinedLoss(nn.Module):
    """Sum of the Dice loss and a scaled, frequency-weighted cross-entropy."""

    def __init__(self, smooth=1e-5, ce_weight=0.5):
        super().__init__()
        self.smooth = smooth
        self.ce_weight = ce_weight
        self.dice_loss = DiceLoss(smooth=smooth)

    def forward(self, predictions, targets):
        """Return ``dice_loss + ce_weight * cross_entropy`` for the batch.

        Class weights are the inverse label frequencies measured on
        ``targets`` itself, so they change from batch to batch.
        """
        class_total = predictions.shape[1]
        label_hist = torch.bincount(targets.flatten(), minlength=class_total).float()
        voxel_total = label_hist.sum()
        weights = (voxel_total / (label_hist * class_total + self.smooth)).to(predictions.device)

        dice_part = self.dice_loss(predictions, targets)
        ce_part = F.cross_entropy(predictions, targets, weight=weights)
        return dice_part + self.ce_weight * ce_part


In [None]:
# Install packages with specific versions known to work together
!pip install numpy --quiet
!pip install scipy --quiet
!pip install scikit-learn --quiet
!pip install scikit-image nibabel gdown torch torchvision --quiet

import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from skimage import transform
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

# Select the compute device.
# `device` is assigned on both paths so it exists even without a GPU.
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("WARNING: CUDA is not available. Please make sure to enable GPU in Runtime > Change runtime type")
    print("Current device: CPU")
else:
    # The global default tensor type is left untouched
    # (`torch.set_default_tensor_type` is deprecated in recent PyTorch
    # releases); tensors and the model are moved to `device` explicitly.
    device = torch.device('cuda')
    print(f"Using device: {device}")
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"Available memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Print versions for debugging
print(f"NumPy version: {np.__version__}")
try:
    import sklearn
except Exception as exc:
    # Still best-effort (an incompatible binary can raise more than
    # ImportError), but no longer swallows KeyboardInterrupt/SystemExit and
    # now reports why the import failed.
    print(f"scikit-learn import failed: {exc!r}")
else:
    print(f"scikit-learn version: {sklearn.__version__}")

# Combined loss for handling class imbalance
class CombinedLoss(nn.Module):
    """Class-weighted Dice loss plus class-weighted cross-entropy.

    Class weights are inverse label frequencies measured on each batch.
    Classes that do not occur in the batch get weight zero, and the Dice
    weights are normalised to sum to one, so the Dice term always stays in
    [0, 1]. Without this, an absent class would receive a weight on the
    order of ``1 / smooth`` and make the loss huge and negative.

    Args:
        smooth: Stabiliser for the Dice ratios and the weight computation.
        ce_weight: Multiplier applied to the cross-entropy term.
    """

    def __init__(self, smooth=1e-5, ce_weight=0.5):
        super().__init__()
        self.smooth = smooth
        self.ce_weight = ce_weight

    def forward(self, predictions, targets):
        """Return ``weighted_dice_loss + ce_weight * weighted_cross_entropy``.

        Args:
            predictions: Logits of shape (batch_size, n_classes, d1, d2, d3).
            targets: Integer labels of shape (batch_size, d1, d2, d3).
        """
        # Inverse-frequency class weights; absent classes are masked to zero
        n_classes = predictions.shape[1]
        class_counts = torch.bincount(targets.flatten(), minlength=n_classes).float()
        class_counts = class_counts.to(predictions.device)
        total_pixels = class_counts.sum()
        raw_weights = total_pixels / (class_counts * n_classes + self.smooth)
        class_weights = torch.where(
            class_counts > 0, raw_weights, torch.zeros_like(raw_weights)
        )

        # Per-sample, per-class soft Dice
        predictions_softmax = F.softmax(predictions, dim=1)
        one_hot_targets = F.one_hot(targets, n_classes).permute(0, 4, 1, 2, 3).float()
        numerator = 2 * (predictions_softmax * one_hot_targets).sum(dim=(2, 3, 4))
        denominator = predictions_softmax.sum(dim=(2, 3, 4)) + one_hot_targets.sum(dim=(2, 3, 4))
        dice_per_class = (numerator + self.smooth) / (denominator + self.smooth)

        # Weighted average over classes (weights sum to 1), then mean over batch
        dice_weights = class_weights / class_weights.sum().clamp_min(self.smooth)
        weighted_dice = (dice_per_class * dice_weights).sum(dim=1).mean()
        dice_loss = 1 - weighted_dice

        # Weighted cross-entropy; zero weights only touch classes with no voxels
        ce_loss = F.cross_entropy(predictions, targets, weight=class_weights)

        # Combine losses
        return dice_loss + self.ce_weight * ce_loss

# Seed NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')


# 2. Data Download and Extraction

Download and extract the dataset containing CT volumes and segmentation masks:


In [None]:
# Download data if not already present
file_id = '1Wo4h6ZVIFygVvqd68ApwWIdPQk3l7gkO'
output = 'Data_Subsample.zip'

if not os.path.exists('Data_Subsample.zip'):
    !gdown --id $file_id -O $output

# Extract data if not already extracted
if not os.path.exists('Data_Subsample'):
    !unzip -o Data_Subsample.zip

# List available files
ct_files = sorted([f for f in os.listdir('Data_Subsample/CT') if f.endswith('.nii.gz')])
seg_files = sorted([f for f in os.listdir('Data_Subsample/Segmentation') if f.endswith('.nii.gz')])

print(f'Number of CT volumes: {len(ct_files)}')
print(f'Number of segmentation masks: {len(seg_files)}')


# 3. Exploratory Data Analysis (EDA)

Let's explore the data structure and visualize some examples.

### Questions to consider:
1. What are the typical dimensions of the CT volumes?
2. How are the classes distributed in the segmentation masks?
3. What preprocessing steps might be necessary?


In [None]:
# Dataset class definition
class OvarianCancerDataset(Dataset):
    """Paired CT volumes and segmentation masks stored as NIfTI files.

    Files are read from ``Data_Subsample/CT`` and
    ``Data_Subsample/Segmentation``. The CT is min-max normalised and both
    volumes are resized to ``target_shape``.

    Args:
        ct_files: CT file names inside ``Data_Subsample/CT``.
        seg_files: Mask file names inside ``Data_Subsample/Segmentation``;
            entry ``i`` is paired with ``ct_files[i]``.
        target_shape: Shape every volume is resized to.
    """

    def __init__(self, ct_files, seg_files, target_shape=(64, 64, 64)):
        self.ct_files = ct_files
        self.seg_files = seg_files
        self.target_shape = target_shape

    def normalize_volume(self, volume):
        """Min-max scale ``volume`` to the [0, 1] range.

        A constant volume has no intensity range to stretch, so it maps to
        all zeros; the result is therefore always inside [0, 1].
        """
        min_val = np.min(volume)
        max_val = np.max(volume)
        if max_val - min_val == 0:
            return np.zeros_like(volume, dtype=float)
        return (volume - min_val) / (max_val - min_val)

    def load_volume(self, file_path):
        """Load a NIfTI volume and return its data"""
        return nib.load(file_path).get_fdata()

    def preprocess_volume(self, ct_path, seg_path):
        """Load one CT/mask pair, normalise the CT and resize both volumes.

        Returns:
            ``(ct_vol, seg_vol)`` where ``seg_vol`` has dtype ``int64``.
        """
        # Load volumes
        ct_vol = self.load_volume(ct_path)
        seg_vol = self.load_volume(seg_path)

        # Normalize CT volume
        ct_vol = self.normalize_volume(ct_vol)

        # Resample to target shape
        if ct_vol.shape != self.target_shape:
            ct_vol = transform.resize(ct_vol, self.target_shape, mode='constant', anti_aliasing=True)
            # order=0 (nearest neighbour) so that label values are not blended
            seg_vol = transform.resize(seg_vol, self.target_shape, mode='constant', order=0, anti_aliasing=False)

        # Ensure segmentation values are integers
        seg_vol = np.round(seg_vol).astype(np.int64)

        return ct_vol, seg_vol

    def __len__(self):
        return len(self.ct_files)

    def __getitem__(self, idx):
        """Return ``(ct, seg)`` tensors for sample ``idx``.

        ``ct`` is a float tensor with a leading channel axis; ``seg`` is a
        long tensor of labels.
        """
        ct_path = os.path.join('Data_Subsample/CT', self.ct_files[idx])
        seg_path = os.path.join('Data_Subsample/Segmentation', self.seg_files[idx])

        # Load and preprocess
        ct_vol, seg_vol = self.preprocess_volume(ct_path, seg_path)

        # Convert to torch tensors and add channel dimension
        ct_vol = torch.FloatTensor(ct_vol).unsqueeze(0)
        seg_vol = torch.LongTensor(seg_vol)

        # Move to GPU if available
        # NOTE(review): returning CUDA tensors from a Dataset generally rules
        # out DataLoader worker processes - consider doing the transfer in
        # the training loop instead.
        if torch.cuda.is_available():
            ct_vol = ct_vol.cuda()
            seg_vol = seg_vol.cuda()

        return ct_vol, seg_vol


In [None]:
# Visualization functions
def plot_slices(ct_volume, seg_mask=None, slice_nums=None, cmap='gray'):
    """Plot slices of a volume side by side, optionally with a mask overlay.

    Args:
        ct_volume: 3-D array; slices are taken along the last axis.
        seg_mask: Optional label volume indexed the same way as ``ct_volume``.
            Label 0 (background) is left transparent so only labelled
            regions are tinted.
        slice_nums: Indices along the last axis to show. Defaults to the
            middle slice.
        cmap: Matplotlib colormap used for the CT image.
    """
    if slice_nums is None:
        slice_nums = [ct_volume.shape[2] // 2]

    _, axes = plt.subplots(1, len(slice_nums), figsize=(15, 5))
    if len(slice_nums) == 1:
        # plt.subplots returns a bare Axes for a single panel
        axes = [axes]

    if seg_mask is not None:
        # One colour scale for every panel so a label keeps the same colour
        max_label = max(float(np.max(seg_mask)), 1.0)

    for ax, slice_num in zip(axes, slice_nums):
        ax.imshow(ct_volume[:, :, slice_num], cmap=cmap)
        if seg_mask is not None:
            # Mask out background so the overlay does not tint the whole CT
            mask_slice = seg_mask[:, :, slice_num]
            overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
            ax.imshow(overlay, alpha=0.3, cmap='jet', vmin=0, vmax=max_label)
        ax.axis('off')
        ax.set_title(f'Slice {slice_num}')
    plt.tight_layout()
    plt.show()

# Paths of the first CT / mask pair
ct_path = os.path.join('Data_Subsample/CT', ct_files[0])
seg_path = os.path.join('Data_Subsample/Segmentation', seg_files[0])

# A one-sample dataset gives access to the preprocessing method
dataset = OvarianCancerDataset([ct_files[0]], [seg_files[0]])
ct_processed, seg_processed = dataset.preprocess_volume(ct_path=ct_path, seg_path=seg_path)

# Report the shapes and value ranges after preprocessing
print('Processed shapes:', ct_processed.shape, seg_processed.shape)
print(
    'Value ranges - CT:',
    ct_processed.min(),
    ct_processed.max(),
    '\nSegmentation:',
    seg_processed.min(),
    seg_processed.max(),
)

# Show the slice halfway along the last axis
mid_slice = ct_processed.shape[2] // 2
plot_slices(ct_processed, seg_mask=seg_processed, slice_nums=[mid_slice])


In [None]:
def load_volume(file_path):
    """Read the NIfTI file at ``file_path`` and return its voxel data."""
    image = nib.load(file_path)
    return image.get_fdata()

def plot_slices(ct_volume, seg_mask=None, slice_nums=None, cmap='gray'):
    """Show selected slices of a volume, optionally overlaying a label mask.

    Args:
        ct_volume: 3-D array; slices are taken along the last axis.
        seg_mask: Optional label volume indexed the same way as ``ct_volume``.
            Background voxels (label 0) are masked out of the overlay, so
            only labelled tissue is coloured.
        slice_nums: Slice indices to display; defaults to the middle slice.
        cmap: Matplotlib colormap for the CT image.
    """
    if slice_nums is None:
        slice_nums = [ct_volume.shape[2] // 2]

    _, axes = plt.subplots(1, len(slice_nums), figsize=(15, 5))
    if len(slice_nums) == 1:
        # A single panel comes back as one Axes object, not a sequence
        axes = [axes]

    if seg_mask is not None:
        # Shared colour range keeps label colours consistent across panels
        max_label = max(float(np.max(seg_mask)), 1.0)

    for ax, slice_num in zip(axes, slice_nums):
        ax.imshow(ct_volume[:, :, slice_num], cmap=cmap)
        if seg_mask is not None:
            # Hide background so the overlay does not wash over the whole image
            mask_slice = seg_mask[:, :, slice_num]
            overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
            ax.imshow(overlay, alpha=0.3, cmap='jet', vmin=0, vmax=max_label)
        ax.axis('off')
        ax.set_title(f'Slice {slice_num}')
    plt.tight_layout()
    plt.show()

# Load and examine first volume
ct_path = os.path.join('Data_Subsample/CT', ct_files[0])
seg_path = os.path.join('Data_Subsample/Segmentation', seg_files[0])

ct_vol = load_volume(ct_path)
seg_vol = load_volume(seg_path)

print('CT volume shape:', ct_vol.shape)
print('Segmentation mask shape:', seg_vol.shape)
print('\nUnique classes in segmentation:', np.unique(seg_vol))

# Plot the middle slice and its neighbours 20 slices either side.
# Indices are clamped to the valid range so a shallow volume neither raises
# IndexError nor silently wraps around through a negative index.
middle_slice = ct_vol.shape[2]//2
plot_slices(
    ct_vol,
    seg_vol,
    [
        min(max(s, 0), ct_vol.shape[2] - 1)
        for s in (middle_slice - 20, middle_slice, middle_slice + 20)
    ],
)


# 4. Data Preprocessing

We'll implement several preprocessing steps:
1. Intensity normalization
2. Resampling to a common size
3. Data augmentation

### Questions to consider:
1. Why is normalization important for medical images?
2. What are appropriate augmentation techniques for 3D medical data?


In [None]:
# Throwaway one-sample dataset, used only to reach its preprocessing method
temp_dataset = OvarianCancerDataset([ct_files[0]], [seg_files[0]])
ct_processed, seg_processed = temp_dataset.preprocess_volume(
    ct_path=os.path.join('Data_Subsample/CT', ct_files[0]),
    seg_path=os.path.join('Data_Subsample/Segmentation', seg_files[0]),
)

# Report the shapes and value ranges after preprocessing
print('Processed shapes:', ct_processed.shape, seg_processed.shape)
print(
    'Value ranges - CT:',
    ct_processed.min(),
    ct_processed.max(),
    '\nSegmentation:',
    seg_processed.min(),
    seg_processed.max(),
)

# Show the middle slice along the last axis (index 32 for a 64x64x64 volume)
mid_slice = ct_processed.shape[2] // 2
plot_slices(ct_processed, seg_mask=seg_processed, slice_nums=[mid_slice])


In [None]:
# Dataset and Model Setup
from torch.utils.data import Dataset, DataLoader

class OvarianCancerDataset(Dataset):
    """Dataset of CT volumes and their segmentation masks (NIfTI files).

    Samples are read from ``Data_Subsample/CT`` and
    ``Data_Subsample/Segmentation``, the CT is min-max normalised, and both
    volumes are resized to ``target_shape``.

    Args:
        ct_files: CT file names inside ``Data_Subsample/CT``.
        seg_files: Mask file names inside ``Data_Subsample/Segmentation``,
            paired with ``ct_files`` by position.
        target_shape: Shape every volume is resized to.
    """

    def __init__(self, ct_files, seg_files, target_shape=(64, 64, 64)):
        self.ct_files = ct_files
        self.seg_files = seg_files
        self.target_shape = target_shape

    def normalize_volume(self, volume):
        """Min-max scale ``volume`` to the [0, 1] range.

        A constant volume (max equals min) is returned as all zeros, so the
        output never leaves [0, 1].
        """
        min_val = np.min(volume)
        max_val = np.max(volume)
        if max_val - min_val == 0:
            return np.zeros_like(volume, dtype=float)
        return (volume - min_val) / (max_val - min_val)

    def load_volume(self, file_path):
        """Load a NIfTI volume and return its data"""
        return nib.load(file_path).get_fdata()

    def preprocess_volume(self, ct_path, seg_path):
        """Load one CT/mask pair, normalise the CT and resize both volumes.

        Returns:
            ``(ct_vol, seg_vol)``; ``seg_vol`` is rounded and cast to int64.
        """
        # Load volumes
        ct_vol = self.load_volume(ct_path)
        seg_vol = self.load_volume(seg_path)

        # Normalize CT volume
        ct_vol = self.normalize_volume(ct_vol)

        # Resample to target shape
        if ct_vol.shape != self.target_shape:
            ct_vol = transform.resize(ct_vol, self.target_shape, mode='constant', anti_aliasing=True)
            # Nearest-neighbour (order=0) keeps mask values as valid labels
            seg_vol = transform.resize(seg_vol, self.target_shape, mode='constant', order=0, anti_aliasing=False)

        # Ensure segmentation values are integers
        seg_vol = np.round(seg_vol).astype(np.int64)

        return ct_vol, seg_vol

    def __len__(self):
        return len(self.ct_files)

    def __getitem__(self, idx):
        """Return ``(ct, seg)`` for sample ``idx``.

        ``ct`` is a float tensor with an added leading channel axis and
        ``seg`` is a long tensor of labels.
        """
        ct_path = os.path.join('Data_Subsample/CT', self.ct_files[idx])
        seg_path = os.path.join('Data_Subsample/Segmentation', self.seg_files[idx])

        # Load and preprocess
        ct_vol, seg_vol = self.preprocess_volume(ct_path, seg_path)

        # Convert to torch tensors and add channel dimension
        ct_vol = torch.FloatTensor(ct_vol).unsqueeze(0)
        seg_vol = torch.LongTensor(seg_vol)

        # Move to GPU if available
        # NOTE(review): CUDA tensors created inside a Dataset usually prevent
        # using DataLoader worker processes - confirm num_workers stays 0.
        if torch.cuda.is_available():
            ct_vol = ct_vol.cuda()
            seg_vol = seg_vol.cuda()

        return ct_vol, seg_vol

# Hold out 20% of the cases for validation (fixed seed, so the split is repeatable)
train_ct, val_ct, train_seg, val_seg = train_test_split(
    ct_files,
    seg_files,
    random_state=42,
    test_size=0.2,
)

# Build both datasets with the same target shape (adjust here if needed)
target_shape = (64, 64, 64)
train_dataset, val_dataset = (
    OvarianCancerDataset(cts, segs, target_shape=target_shape)
    for cts, segs in ((train_ct, train_seg), (val_ct, val_seg))
)

# One volume per batch; only the training loader shuffles
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')


In [None]:
# Model Architecture
class DoubleConv(nn.Module):
    """Two consecutive Conv3d(3x3x3, padding=1) -> BatchNorm3d -> ReLU stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolution stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # The first stage maps in -> out channels, the second out -> out
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.conv(x)

class UNet3D(nn.Module):
    """3D U-Net built from ``DoubleConv`` blocks with skip connections.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Number of output channels (one logit map per class).
        features: Channel widths from the shallowest level to the bottleneck.
            Any sequence works, and widths may repeat.

    Note:
        The last entry of ``self.encoder`` is constructed but never used in
        ``forward`` (``self.bottleneck`` does that job). It is kept so that
        ``state_dict`` keys are unchanged.
    """

    def __init__(self, in_channels=1, out_channels=3, features=(16, 32, 64, 128)):
        super().__init__()
        # Copy to a list: avoids a shared mutable default and accepts tuples
        features = list(features)

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder
        in_channels_temp = in_channels
        for feature in features:
            self.encoder.append(DoubleConv(in_channels_temp, feature))
            in_channels_temp = feature

        # Decoder: walk the levels by position, deepest first. Looking the
        # width up with `features.index(...)` would pick the wrong level
        # whenever two levels share the same width.
        for level in reversed(range(len(features) - 1)):
            feature = features[level]
            # Upsampling from the next-deeper level's width
            self.decoder.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        features[level + 1],
                        feature,
                        kernel_size=2,
                        stride=2
                    ),
                    nn.BatchNorm3d(feature),
                    nn.ReLU(inplace=True)
                )
            )
            # Double conv after concatenation with the skip connection
            self.decoder.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-2], features[-1])
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

        # Regularization
        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        """Return per-class logits with the same spatial size as ``x``."""
        skip_connections = []

        # Encoder: every level except the last feeds a skip connection
        for encoder in self.encoder[:-1]:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)
            x = self.dropout(x)

        x = self.bottleneck(x)

        # Decoder: entries alternate (upsample, double conv)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip = skip_connections[idx//2]

            # Odd input sizes lose a voxel when pooled; resize to match the skip
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])

            concat_skip = torch.cat((skip, x), dim=1)
            x = self.decoder[idx+1](concat_skip)
            x = self.dropout(x)

        return self.final_conv(x)

# Initialize model and move to GPU if available
model = UNet3D(in_channels=1, out_channels=3)
if torch.cuda.is_available():
    model = model.cuda()

# Initialize loss function and optimizer
criterion = CombinedLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Cut the learning rate 10x after 5 epochs without validation improvement.
# `verbose` is not passed: it is deprecated in recent PyTorch releases.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model architecture:\n{model}")


In [None]:
# Training and Validation Functions
def train_epoch(model, loader, optimizer, criterion):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Batches are moved to the device that holds the model's parameters, so
    the function does not depend on a module-level ``device`` variable.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(data, target)`` batches that supports ``len``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(output, target)`` returning the loss.

    Returns:
        Average of the per-batch loss values.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.train()
    first_param = next(model.parameters(), None)
    # A model without parameters has no device; leave batches where they are
    target_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0

    for batch_idx, (data, target) in enumerate(loader):
        # Move the batch next to the model and clear gradients
        if target_device is not None:
            data, target = data.to(target_device), target.to(target_device)
        optimizer.zero_grad()

        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update metrics
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # Print progress and clear cache
        if batch_idx % 5 == 0:
            print(f'Batch {batch_idx}/{len(loader)}, Loss: {batch_loss:.4f}')
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # Clear GPU cache periodically

    return total_loss / num_batches

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` and return the mean batch loss.

    Runs in evaluation mode with gradients disabled. Batches are moved to
    the device that holds the model's parameters, so the function does not
    depend on a module-level ``device`` variable.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(data, target)`` batches.
        criterion: Callable ``criterion(output, target)`` returning the loss.

    Returns:
        Average of the per-batch loss values.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    # A model without parameters has no device; leave batches where they are
    target_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data, target in loader:
            if target_device is not None:
                data, target = data.to(target_device), target.to(target_device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches

# Training configuration and early-stopping state
n_epochs, patience = 100, 10
best_val_loss, patience_counter = float('inf'), 0

print("Starting training...")
for epoch in range(n_epochs):
    print(f"\nEpoch {epoch+1}/{n_epochs}")

    # One pass over the training data, then one over the validation data
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss = validate(model, val_loader, criterion)

    print(f"Training Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")

    # The scheduler watches the validation loss
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        # New best: remember it, reset the patience budget, checkpoint weights
        best_val_loss, patience_counter = val_loss, 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1

    # Stop once validation has not improved for `patience` epochs in a row
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break

    # Release cached GPU memory before the next epoch
    torch.cuda.empty_cache()


In [None]:
# Import preprocessing function from previous cell
from __main__ import preprocess_volume

class OvarianCancerDataset(Dataset):
    """CT volume / segmentation mask pairs loaded from NIfTI files.

    Volumes come from ``Data_Subsample/CT`` and
    ``Data_Subsample/Segmentation``. The CT is min-max normalised and both
    volumes are resized to ``target_shape``.

    Args:
        ct_files: CT file names inside ``Data_Subsample/CT``.
        seg_files: Mask file names inside ``Data_Subsample/Segmentation``;
            position ``i`` belongs to ``ct_files[i]``.
        target_shape: Shape every volume is resized to.
    """

    def __init__(self, ct_files, seg_files, target_shape=(64, 64, 64)):
        self.ct_files = ct_files
        self.seg_files = seg_files
        self.target_shape = target_shape

    def normalize_volume(self, volume):
        """Min-max scale ``volume`` to the [0, 1] range.

        When every voxel has the same value there is nothing to stretch, so
        an all-zero array is returned and the result stays inside [0, 1].
        """
        min_val = np.min(volume)
        max_val = np.max(volume)
        if max_val - min_val == 0:
            return np.zeros_like(volume, dtype=float)
        return (volume - min_val) / (max_val - min_val)

    def load_volume(self, file_path):
        """Load a NIfTI volume and return its data"""
        return nib.load(file_path).get_fdata()

    def preprocess_volume(self, ct_path, seg_path):
        """Load one CT/mask pair, normalise the CT and resize both volumes.

        Returns:
            ``(ct_vol, seg_vol)`` with ``seg_vol`` cast to int64 labels.
        """
        # Load volumes
        ct_vol = self.load_volume(ct_path)
        seg_vol = self.load_volume(seg_path)

        # Normalize CT volume
        ct_vol = self.normalize_volume(ct_vol)

        # Resample to target shape
        if ct_vol.shape != self.target_shape:
            ct_vol = transform.resize(ct_vol, self.target_shape, mode='constant', anti_aliasing=True)
            # order=0 avoids interpolating between label values
            seg_vol = transform.resize(seg_vol, self.target_shape, mode='constant', order=0, anti_aliasing=False)

        # Ensure segmentation values are integers
        seg_vol = np.round(seg_vol).astype(np.int64)

        return ct_vol, seg_vol

    def __len__(self):
        return len(self.ct_files)

    def __getitem__(self, idx):
        """Return ``(ct, seg)`` for sample ``idx``.

        ``ct`` is a float tensor with a leading channel axis; ``seg`` is a
        long tensor of labels.
        """
        ct_path = os.path.join('Data_Subsample/CT', self.ct_files[idx])
        seg_path = os.path.join('Data_Subsample/Segmentation', self.seg_files[idx])

        # Load and preprocess
        ct_vol, seg_vol = self.preprocess_volume(ct_path, seg_path)

        # Convert to torch tensors and add channel dimension
        ct_vol = torch.FloatTensor(ct_vol).unsqueeze(0)
        seg_vol = torch.LongTensor(seg_vol)

        # Move to GPU if available
        # NOTE(review): producing CUDA tensors inside a Dataset tends to be
        # incompatible with DataLoader worker processes - verify before
        # raising num_workers.
        if torch.cuda.is_available():
            ct_vol = ct_vol.cuda()
            seg_vol = seg_vol.cuda()

        return ct_vol, seg_vol

# Reserve 20% of the cases for validation; the seed keeps the split fixed
train_ct, val_ct, train_seg, val_seg = train_test_split(
    ct_files,
    seg_files,
    random_state=42,
    test_size=0.2,
)

# Datasets use the class's default target shape
train_dataset = OvarianCancerDataset(ct_files=train_ct, seg_files=train_seg)
val_dataset = OvarianCancerDataset(ct_files=val_ct, seg_files=val_seg)

# Single-volume batches; shuffling only for training
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')


In [None]:
def normalize_volume(volume):
    """Min-max scale ``volume`` to the [0, 1] range.

    A constant volume (max equals min) has no intensity range to stretch,
    so it is returned as an all-zero float array. The result therefore
    always lies inside [0, 1].
    """
    min_val = np.min(volume)
    max_val = np.max(volume)
    if max_val - min_val == 0:
        return np.zeros_like(volume, dtype=float)
    return (volume - min_val) / (max_val - min_val)

def preprocess_volume(ct_path, seg_path, target_shape=(64, 64, 64)):
    """Load a CT/mask pair, normalise the CT and bring both to ``target_shape``.

    Args:
        ct_path: Path of the CT NIfTI file.
        seg_path: Path of the matching segmentation NIfTI file.
        target_shape: Shape both volumes are resized to when the CT differs.

    Returns:
        ``(ct_vol, seg_vol)``; the mask is rounded and cast to ``int64``.
    """
    ct_vol, seg_vol = load_volume(ct_path), load_volume(seg_path)
    ct_vol = normalize_volume(ct_vol)

    needs_resize = ct_vol.shape != target_shape
    if needs_resize:
        shared_opts = {'mode': 'constant'}
        # Smooth interpolation for intensities ...
        ct_vol = transform.resize(ct_vol, target_shape, anti_aliasing=True, **shared_opts)
        # ... nearest-neighbour (order=0) for labels so classes are not blended
        seg_vol = transform.resize(seg_vol, target_shape, order=0, anti_aliasing=False, **shared_opts)

    # Labels must be integers for the loss functions
    return ct_vol, np.round(seg_vol).astype(np.int64)

# Example preprocessing
ct_processed, seg_processed = preprocess_volume(ct_path, seg_path)
print('Processed shapes:', ct_processed.shape, seg_processed.shape)
print('Value ranges - CT:', ct_processed.min(), ct_processed.max(),
      '\nSegmentation:', seg_processed.min(), seg_processed.max())

# Visualize the middle slice of the processed data.
# The volume is resized to 64 slices by default (valid indices 0-63), so the
# index is derived from the shape instead of hard-coding an out-of-range 64.
plot_slices(ct_processed, seg_processed, [ct_processed.shape[2] // 2])


# 5. Dataset and DataLoader

Create a PyTorch dataset for efficient data handling:


In [None]:
class OvarianCancerDataset(Dataset):
    """CT / mask pairs preprocessed with the module-level ``preprocess_volume``.

    Args:
        ct_files: CT file names inside ``Data_Subsample/CT``.
        seg_files: Mask file names inside ``Data_Subsample/Segmentation``,
            paired with ``ct_files`` by position.
        transform: Optional callable applied to the CT tensor only.
    """

    def __init__(self, ct_files, seg_files, transform=None):
        self.ct_files, self.seg_files = ct_files, seg_files
        self.transform = transform

    def __len__(self):
        return len(self.ct_files)

    def __getitem__(self, idx):
        """Return ``(ct_tensor, seg_tensor)`` for sample ``idx``."""
        ct_name, seg_name = self.ct_files[idx], self.seg_files[idx]

        # Load, normalise and resize the pair
        ct_array, seg_array = preprocess_volume(
            os.path.join('Data_Subsample/CT', ct_name),
            os.path.join('Data_Subsample/Segmentation', seg_name),
        )

        # The CT gets a leading channel axis; the mask stays as integer labels
        ct_tensor = torch.FloatTensor(ct_array).unsqueeze(0)
        seg_tensor = torch.LongTensor(seg_array)

        if self.transform:
            ct_tensor = self.transform(ct_tensor)

        return ct_tensor, seg_tensor

# 80/20 train/validation split with a fixed seed
train_ct, val_ct, train_seg, val_seg = train_test_split(
    ct_files,
    seg_files,
    random_state=42,
    test_size=0.2,
)

# One dataset per split
train_dataset, val_dataset = (
    OvarianCancerDataset(cts, segs)
    for cts, segs in ((train_ct, train_seg), (val_ct, val_seg))
)

# Batch size 1; the training loader shuffles, the validation loader does not
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')


# 6. Model Architecture

Implement a simplified 3D U-Net for segmentation:

### Questions to consider:
1. Why is U-Net particularly suitable for medical image segmentation?
2. What modifications might improve performance?


In [None]:
# 6. Model Architecture and Training

# Status message printed when this cell starts; the model building blocks are defined below
print("Building and training 3D U-Net for medical image segmentation...")

# Questions to consider:
# 1. Why is U-Net particularly suitable for medical image segmentation?
# 2. What modifications might improve performance?
# 3. How does the training process affect segmentation quality?

# Define model components
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention.

    Args:
        channels: Number of channels of the 5-D input (batch, C, d1, d2, d3).
        reduction: Factor by which the hidden layer is narrower than
            ``channels``. The hidden width is at least 1, so a block with
            fewer channels than ``reduction`` does not get a zero-width
            bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Rescale each channel of ``x`` by a learned gate in (0, 1)."""
        b, c, _, _, _ = x.size()
        # Squeeze: one average per (sample, channel)
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel gate, reshaped to broadcast over the volume
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y

class DoubleConv(nn.Module):
    """Two 3x3x3 Conv-BatchNorm-ReLU stages with SE attention and a projected skip.

    The main path output is re-weighted by an ``SEBlock`` and added to a
    1x1x1-convolved, batch-normalised copy of the input before a final ReLU.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main path: the first stage changes the channel count, the second keeps it
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ])
            stage_in = out_channels
        self.double_conv = nn.Sequential(*stages)

        # Channel attention applied to the main path output
        self.se = SEBlock(out_channels)

        # Shortcut: 1x1x1 projection so the input matches the main path's channels
        projection = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.residual = nn.Sequential(projection, nn.BatchNorm3d(out_channels))

    def forward(self, x):
        attended = self.se(self.double_conv(x))
        shortcut = self.residual(x)
        return F.relu(attended + shortcut)

class UNet3D(nn.Module):
    """Simplified 3D U-Net built from ``DoubleConv`` stages.

    Every encoder stage is followed by 2x max-pooling and dropout. Every
    decoder stage upsamples with a stride-2 transposed convolution
    (+ BatchNorm + ReLU), concatenates the matching skip connection, fuses the
    result with a ``DoubleConv`` and applies dropout. A final 1x1x1 convolution
    maps the last feature maps to ``out_channels`` output channels.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of output channels of the final 1x1x1 convolution.
        features: Channel widths ordered from the shallowest stage to the
            bottleneck. The last entry is the bottleneck width; all earlier
            entries are used by one encoder and one decoder stage each.
            At least two entries are required.

    Raises:
        ValueError: If ``features`` has fewer than two entries.

    Note:
        Only ``len(features) - 1`` encoder stages are created. An extra stage
        for the last width was never called in ``forward`` (the bottleneck
        covers that level), so it only added unused parameters. State dicts
        saved from a model that still had it contain additional
        ``encoder.<last>.*`` keys.
    """
    def __init__(self, in_channels=1, out_channels=3, features=(16, 32, 64, 128)):
        super().__init__()
        features = list(features)
        if len(features) < 2:
            raise ValueError("features needs at least two entries (one encoder stage plus the bottleneck)")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per width except the last (the bottleneck)
        in_channels_temp = in_channels
        for feature in features[:-1]:
            self.encoder.append(DoubleConv(in_channels_temp, feature))
            in_channels_temp = feature

        # Decoder: walk adjacent (shallow, deep) width pairs from the deepest
        # level up. Pairing by position keeps this correct when a width repeats.
        for shallow, deep in reversed(list(zip(features[:-1], features[1:]))):
            # Upsampling
            self.decoder.append(
                nn.Sequential(
                    nn.ConvTranspose3d(deep, shallow, kernel_size=2, stride=2),
                    nn.BatchNorm3d(shallow),
                    nn.ReLU(inplace=True)
                )
            )
            # Double conv after concatenation (skip + upsampled => 2 * shallow)
            self.decoder.append(DoubleConv(shallow * 2, shallow))

        self.bottleneck = DoubleConv(features[-2], features[-1])
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

        # Regularization
        self.dropout = nn.Dropout3d(p=0.2)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the final_conv output."""
        skip_connections = []

        # Encoder: keep the pre-pooling activations for the skip connections
        for encoder in self.encoder:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)
            x = self.dropout(x)

        x = self.bottleneck(x)

        # Decoder: deepest skip connection first
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; resize the spatial dims to match before concat.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])

            concat_skip = torch.cat((skip, x), dim=1)
            x = self.decoder[idx + 1](concat_skip)
            x = self.dropout(x)

        return self.final_conv(x)

# Initialize model
# The setup cell assigns `device` only on its CUDA branch; define a fallback
# here so a CPU-only runtime does not stop with a NameError. An existing
# `device` is left untouched.
if 'device' not in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3D().to(device)
print(model)


In [None]:
# Define the complete U-Net model
class UNet3D(nn.Module):
    """3D U-Net architecture with residual connections and dropout.

    Every encoder stage (a ``DoubleConv``) is followed by 2x max-pooling and
    dropout. Every decoder stage upsamples with a stride-2 transposed
    convolution, concatenates the matching skip connection, fuses the result
    with a ``DoubleConv`` and applies dropout. A final 1x1x1 convolution maps
    the last feature maps to ``out_channels`` output channels.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of output channels of the final 1x1x1 convolution.
        features: Channel widths ordered from the shallowest stage to the
            bottleneck. The last entry is the bottleneck width; all earlier
            entries are used by one encoder and one decoder stage each.
            At least two entries are required.

    Raises:
        ValueError: If ``features`` has fewer than two entries.

    Note:
        Only ``len(features) - 1`` encoder stages are created. An extra stage
        for the last width was never called in ``forward`` (the bottleneck
        covers that level), so it only added unused parameters. State dicts
        saved from a model that still had it contain additional
        ``encoder.<last>.*`` keys.
    """
    def __init__(self, in_channels=1, out_channels=3, features=(32, 64, 128, 256)):
        super().__init__()
        features = list(features)
        if len(features) < 2:
            raise ValueError("features needs at least two entries (one encoder stage plus the bottleneck)")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per width except the last (the bottleneck)
        in_channels_temp = in_channels
        for feature in features[:-1]:
            self.encoder.append(DoubleConv(in_channels_temp, feature))
            in_channels_temp = feature

        # Decoder: walk adjacent (shallow, deep) width pairs from the deepest
        # level up. Pairing by position keeps this correct when a width repeats.
        for shallow, deep in reversed(list(zip(features[:-1], features[1:]))):
            self.decoder.append(
                nn.ConvTranspose3d(deep, shallow, kernel_size=2, stride=2)
            )
            # Skip + upsampled maps are concatenated => 2 * shallow input channels
            self.decoder.append(DoubleConv(shallow * 2, shallow))

        self.bottleneck = DoubleConv(features[-2], features[-1])
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

        # Dropout layers
        self.dropout = nn.Dropout3d(p=0.3)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the final_conv output."""
        skip_connections = []

        # Encoder: keep the pre-pooling activations for the skip connections
        for encoder in self.encoder:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)
            x = self.dropout(x)

        x = self.bottleneck(x)

        # Decoder: deepest skip connection first
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip; resize the spatial dims to match before concat.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])

            concat_skip = torch.cat((skip, x), dim=1)
            x = self.decoder[idx + 1](concat_skip)
            x = self.dropout(x)

        return self.final_conv(x)

# Initialize model and training components
model = UNet3D().to(device)
criterion = DiceLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# `verbose` is deprecated for LR schedulers (since PyTorch 2.2), so it is not
# passed; the training loop prints the current learning rate every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# Summarize the model and the training configuration chosen above
print("\nModel architecture:")
print(model)
print("\nTraining parameters:")
print("Optimizer: AdamW with learning rate 1e-3 and weight decay 0.01")
print("Loss function: Dice Loss")
print("Learning rate scheduler: ReduceLROnPlateau")


# 7. Training Functions and Loop

print("Setting up training process with loss functions and optimization...")


In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode.
        loader: Sized iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.

    Returns:
        The average of the per-batch loss values as a Python float.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.train()
    # Use the device the model lives on instead of a notebook-global `device`,
    # which only exists when the setup cell took its CUDA branch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0
    num_batches = 0

    for batch_idx, (data, target) in enumerate(loader):
        # Move the batch to the model's device and clear gradients
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update metrics (read the scalar once; .item() forces a device sync)
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # Print progress and clear cache
        if batch_idx % 5 == 0:
            print(f'Batch {batch_idx}/{len(loader)}, Loss: {batch_loss:.4f}')
            if device.type == 'cuda':
                torch.cuda.empty_cache()  # Clear GPU cache periodically

    return total_loss / num_batches

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean batch loss.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(data, target)`` batches.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.

    Returns:
        The average of the per-batch loss values as a Python float.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.eval()
    # Use the device the model lives on instead of a notebook-global `device`,
    # which only exists when the setup cell took its CUDA branch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches

# Training setup
criterion = DiceLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# `verbose` is deprecated for LR schedulers (since PyTorch 2.2), so it is not
# passed; the loop below prints the current learning rate every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# Training loop
n_epochs = 30
best_val_loss = float('inf')
patience = 5  # epochs without validation improvement before stopping early
patience_counter = 0

for epoch in range(n_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss = validate(model, val_loader, criterion)

    print(f'Epoch {epoch+1}/{n_epochs}:')
    print(f'Train Loss: {train_loss:.4f}')
    print(f'Val Loss: {val_loss:.4f}')

    # Learning rate scheduling (reduces the LR when val_loss stops improving)
    scheduler.step(val_loss)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

    # Print current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Current learning rate: {current_lr:.2e}\n')

# The loop ends with last-epoch weights, which can be several epochs past the
# best validation loss. Restore the checkpoint saved above so later cells
# evaluate the best model. Skipped if no checkpoint was written in this run.
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load('best_model.pth'))
    print(f'Restored best model weights (val loss {best_val_loss:.4f})')


# 8. Evaluation and Visualization

Evaluate the model and visualize results:

### Questions to consider:
1. How well does the model segment different classes?
2. What are the clinical implications of false positives/negatives?
3. How could the model be improved?


In [None]:
def predict_volume(model, ct_volume):
    """Predict a per-voxel class index map for a single volume.

    Args:
        model: Segmentation network; switched to evaluation mode.
        ct_volume: Tensor for one sample without a batch axis (a batch axis is
            added here before the forward pass).

    Returns:
        NumPy array holding, for every position, the index of the largest
        model output along the channel axis (dim 1 of the batched output).
    """
    model.eval()
    # Use the device the model lives on instead of a notebook-global `device`,
    # which only exists when the setup cell took its CUDA branch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        logits = model(ct_volume.unsqueeze(0).to(device))
        # Softmax is monotonic, so argmax over the raw logits picks the same
        # class without materialising a probability volume.
        pred = torch.argmax(logits, dim=1)
    return pred[0].cpu().numpy()

# Load a validation sample
val_ct, val_seg = val_dataset[0]
pred_seg = predict_volume(model, val_ct)

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# The panels below index the LAST axis, so take the middle of that axis
# (shape[2] is a different axis and may have a different length).
slice_idx = val_ct.shape[-1] // 2

axes[0].imshow(val_ct[0, :, :, slice_idx], cmap='gray')
axes[0].set_title('CT Slice')
axes[0].axis('off')

# Pin the colour scale to the three class labels (0, 1, 2) so a class gets the
# same colour in both label panels even when a slice lacks one of the classes.
axes[1].imshow(val_seg[:, :, slice_idx], cmap='jet', vmin=0, vmax=2)
axes[1].set_title('True Segmentation')
axes[1].axis('off')

axes[2].imshow(pred_seg[:, :, slice_idx], cmap='jet', vmin=0, vmax=2)
axes[2].set_title('Predicted Segmentation')
axes[2].axis('off')

plt.tight_layout()
plt.show()


# 9. Discussion Questions

Please answer the following questions based on your implementation and results:

1. **Data Analysis**
   - What challenges did you encounter with the medical imaging data?
   - How did you handle class imbalance?

2. **Model Performance**
   - How well did the model perform on different classes?
   - What were the main sources of error?

3. **Clinical Relevance**
   - How might this model be useful in a clinical setting?
   - What additional validation would be needed?

4. **Improvements**
   - What modifications could improve the model's performance?
   - How could the preprocessing pipeline be enhanced?

Write your answers below:

1. Data Analysis:
   > Your answer here

2. Model Performance:
   > Your answer here

3. Clinical Relevance:
   > Your answer here

4. Improvements:
   > Your answer here
