# Ovarian Cancer Segmentation Lab

This lab covers medical image segmentation of ovarian cancer in CT scans. You will work with volumetric data in NIfTI format and build and train a 3D U-Net that separates primary tumor tissue from metastases.

## Task Overview
You will segment CT volumes into three classes:
- Class 0: Background
- Class 1: Primary ovarian cancer
- Class 2: Metastasis

## 🎯 Learning Objectives
- Work with medical imaging data in NIfTI format
- Implement a 3D U-Net architecture
- Train a segmentation model
- Evaluate medical imaging results

Let's begin! 🚀


# 1. Environment Setup

First, let's install the required packages:


In [None]:
# Install dependencies in the correct order
!pip install numpy==1.24.3 --quiet
!pip install torch torchvision --quiet
# scikit-learn is needed for train_test_split in the Dataset section
!pip install nibabel matplotlib scikit-image scikit-learn gdown --quiet
!pip install monai --quiet

import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from skimage import transform
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from monai.losses import DiceLoss

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


# 2. Data Download and Extraction

Download and extract the dataset containing CT volumes and segmentation masks:


In [None]:
# Download data if not already present
file_id = '1Wo4h6ZVIFygVvqd68ApwWIdPQk3l7gkO'
output = 'Data_Subsample.zip'

if not os.path.exists(output):
    # Pass the full URL; the --id flag is deprecated in recent gdown releases
    !gdown "https://drive.google.com/uc?id=$file_id" -O $output

# Extract data if not already extracted
if not os.path.exists('Data_Subsample'):
    !unzip -o Data_Subsample.zip

# List available files
ct_files = sorted([f for f in os.listdir('Data_Subsample/CT') if f.endswith('.nii.gz')])
seg_files = sorted([f for f in os.listdir('Data_Subsample/Segmentation') if f.endswith('.nii.gz')])

print(f'Number of CT volumes: {len(ct_files)}')
print(f'Number of segmentation masks: {len(seg_files)}')
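
Equal counts do not guarantee that every CT volume has a matching mask. The later train/validation split pairs the two sorted lists by index, so a single mismatch would silently misalign images and labels. The sketch below assumes that a CT file and its mask share the same filename (an assumption about this dataset, not something the lab states); `find_unpaired` and the demo filenames are hypothetical.

```python
def find_unpaired(ct_names, seg_names):
    """Return (CT files without a mask, masks without a CT), compared by filename."""
    ct_set, seg_set = set(ct_names), set(seg_names)
    return sorted(ct_set - seg_set), sorted(seg_set - ct_set)

# Hypothetical filenames for illustration only
demo_ct = ["case_01.nii.gz", "case_02.nii.gz", "case_03.nii.gz"]
demo_seg = ["case_01.nii.gz", "case_03.nii.gz", "case_04.nii.gz"]

missing_masks, missing_cts = find_unpaired(demo_ct, demo_seg)
print("CT without mask:", missing_masks)
print("Mask without CT:", missing_cts)
```

On the real data you would call `find_unpaired(ct_files, seg_files)` and expect two empty lists.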


# 3. Exploratory Data Analysis (EDA)

Let's explore the data structure and visualize some examples.

### Questions to consider:
1. What are the typical dimensions of the CT volumes?
2. How are the classes distributed in the segmentation masks?
3. What preprocessing steps might be necessary?
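
For question 2, one way to quantify class balance is to count the voxels per label. The following is a minimal sketch on a synthetic mask; `class_fractions` is a hypothetical helper, and the three-class assumption comes from the task overview.

```python
import numpy as np

def class_fractions(mask, num_classes=3):
    """Return the fraction of voxels belonging to each class label."""
    # Masks loaded with get_fdata() are floats, so round before counting
    labels = np.rint(mask).astype(np.int64).ravel()
    counts = np.bincount(labels, minlength=num_classes)
    return counts / counts.sum()

# Synthetic 4x4x4 mask standing in for a real segmentation volume
demo_mask = np.zeros((4, 4, 4))
demo_mask[0, 0, :2] = 1   # two voxels of class 1
demo_mask[1, 1, 0] = 2    # one voxel of class 2

fractions = class_fractions(demo_mask)
for label, frac in enumerate(fractions):
    print(f"Class {label}: {frac:.4f}")
```

In real CT masks the background fraction is typically far larger than the tumor fractions, which is worth keeping in mind when choosing a loss function.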


In [None]:
def load_volume(file_path):
    """Load a NIfTI volume and return its data"""
    return nib.load(file_path).get_fdata()

def plot_slices(ct_volume, seg_mask=None, slice_nums=None, cmap='gray'):
    """Plot multiple slices from a volume with optional segmentation overlay"""
    if slice_nums is None:
        slice_nums = [ct_volume.shape[2]//2]
    
    fig, axes = plt.subplots(1, len(slice_nums), figsize=(15, 5))
    if len(slice_nums) == 1:
        axes = [axes]
    
    for ax, slice_num in zip(axes, slice_nums):
        ax.imshow(ct_volume[:,:,slice_num], cmap=cmap)
        if seg_mask is not None:
            # Mask out background (label 0) so only labeled voxels are overlaid
            mask_slice = np.ma.masked_where(seg_mask[:,:,slice_num] == 0, seg_mask[:,:,slice_num])
            ax.imshow(mask_slice, alpha=0.3, cmap='jet')
        ax.axis('off')
        ax.set_title(f'Slice {slice_num}')
    plt.tight_layout()
    plt.show()

# Load and examine first volume
ct_path = os.path.join('Data_Subsample/CT', ct_files[0])
seg_path = os.path.join('Data_Subsample/Segmentation', seg_files[0])

ct_vol = load_volume(ct_path)
seg_vol = load_volume(seg_path)

print('CT volume shape:', ct_vol.shape)
print('Segmentation mask shape:', seg_vol.shape)
print('\nUnique classes in segmentation:', np.unique(seg_vol))

# Plot middle slices
middle_slice = ct_vol.shape[2]//2
plot_slices(ct_vol, seg_vol, [middle_slice-20, middle_slice, middle_slice+20])


# 4. Data Preprocessing

This section covers three preprocessing steps:
1. Intensity normalization
2. Resampling to a common size
3. Data augmentation (left as an exercise; the cell below implements steps 1 and 2)

### Questions to consider:
1. Why is normalization important for medical images?
2. What are appropriate augmentation techniques for 3D medical data?
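
For question 2, the key constraint is that any spatial augmentation must be applied identically to the CT volume and its mask. Below is a minimal sketch of joint random flipping with NumPy; `random_flip` is a hypothetical helper, and whether flips along every axis are anatomically plausible for this dataset is an assumption you should check.

```python
import numpy as np

def random_flip(ct, seg, rng, p=0.5):
    """Flip CT and mask together along each spatial axis with probability p."""
    for axis in range(ct.ndim):
        if rng.random() < p:
            ct = np.flip(ct, axis=axis)
            seg = np.flip(seg, axis=axis)
    # np.flip returns views with negative strides; copy so that
    # torch.from_numpy can consume the result later
    return np.ascontiguousarray(ct), np.ascontiguousarray(seg)

rng = np.random.default_rng(0)
ct_demo = np.arange(27, dtype=np.float32).reshape(3, 3, 3)
seg_demo = (ct_demo > 20).astype(np.int64)

ct_aug, seg_aug = random_flip(ct_demo, seg_demo, rng)
# The mask still marks exactly the voxels with intensity above 20
print(np.array_equal(seg_aug, (ct_aug > 20).astype(np.int64)))
```

Intensity augmentations (noise, brightness shifts) are different: they should change only the CT volume and leave the mask untouched.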


In [None]:
def normalize_volume(volume):
    """Min-max normalize a volume to the [0, 1] range"""
    min_val = np.min(volume)
    max_val = np.max(volume)
    if max_val - min_val == 0:
        # Constant volume: return zeros so the output still lies in [0, 1]
        return np.zeros_like(volume)
    return (volume - min_val) / (max_val - min_val)

def preprocess_volume(ct_path, seg_path, target_shape=(128, 128, 128)):
    """Load and preprocess a single volume pair"""
    # Load volumes
    ct_vol = load_volume(ct_path)
    seg_vol = load_volume(seg_path)
    
    # Normalize CT volume
    ct_vol = normalize_volume(ct_vol)
    
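    # Resample to target shape
    if ct_vol.shape != target_shape:
        # Linear interpolation with anti-aliasing for the image
        ct_vol = transform.resize(ct_vol, target_shape, mode='constant',
                                  anti_aliasing=True, preserve_range=True)
        # Nearest-neighbor (order=0) for the mask so no new label values are created
        seg_vol = transform.resize(seg_vol, target_shape, mode='constant', order=0,
                                   anti_aliasing=False, preserve_range=True)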
    
    return ct_vol, seg_vol

# Example preprocessing
ct_processed, seg_processed = preprocess_volume(ct_path, seg_path)
print('Processed shapes:', ct_processed.shape, seg_processed.shape)
print('Value ranges - CT:', ct_processed.min(), ct_processed.max(),
      '\nSegmentation:', seg_processed.min(), seg_processed.max())

# Visualize processed data
plot_slices(ct_processed, seg_processed, [64])
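
Min-max scaling over the whole volume is sensitive to extreme values such as air or metal, which can compress the soft-tissue range into a narrow band. A common alternative for CT is to clip intensities to a fixed window before rescaling. The sketch below assumes the volumes are stored in Hounsfield units; `window_normalize` is a hypothetical helper and the window bounds are illustrative, not tuned for this dataset.

```python
import numpy as np

def window_normalize(volume, hu_min=-150.0, hu_max=250.0):
    """Clip to [hu_min, hu_max] and rescale linearly to [0, 1]."""
    clipped = np.clip(volume, hu_min, hu_max)
    return ((clipped - hu_min) / (hu_max - hu_min)).astype(np.float32)

# Values below/above the window saturate at 0 and 1
demo = np.array([-1000.0, -150.0, 50.0, 250.0, 3000.0])
scaled = window_normalize(demo)
print(scaled)
```

Because the window is fixed, the same tissue maps to the same value in every scan, which min-max scaling does not guarantee.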


# 5. Dataset and DataLoader

Create a PyTorch dataset for efficient data handling:


In [None]:
class OvarianCancerDataset(Dataset):
    def __init__(self, ct_files, seg_files, transform=None):
        self.ct_files = ct_files
        self.seg_files = seg_files
        self.transform = transform
    
    def __len__(self):
        return len(self.ct_files)
    
    def __getitem__(self, idx):
        ct_path = os.path.join('Data_Subsample/CT', self.ct_files[idx])
        seg_path = os.path.join('Data_Subsample/Segmentation', self.seg_files[idx])
        
        # Load and preprocess
        ct_vol, seg_vol = preprocess_volume(ct_path, seg_path)
        
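        # Convert to torch tensors (float32 image, int64 labels)
        ct_vol = torch.from_numpy(ct_vol.astype(np.float32)).unsqueeze(0)  # Add channel dimension
        seg_vol = torch.from_numpy(np.rint(seg_vol).astype(np.int64))
        
        # Note: the transform is applied to the image only, so it should be
        # intensity-based; spatial transforms would misalign image and mask
        if self.transform:
            ct_vol = self.transform(ct_vol)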
        
        return ct_vol, seg_vol

# Split data
from sklearn.model_selection import train_test_split

train_ct, val_ct, train_seg, val_seg = train_test_split(
    ct_files, seg_files, test_size=0.2, random_state=42
)

# Create datasets
train_dataset = OvarianCancerDataset(train_ct, train_seg)
val_dataset = OvarianCancerDataset(val_ct, val_seg)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')


# 6. Model Architecture

Implement a simplified 3D U-Net for segmentation:

### Questions to consider:
1. Why is U-Net particularly suitable for medical image segmentation?
2. What modifications might improve performance?
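
A skip connection can only concatenate encoder and decoder features if their spatial sizes match. The sketch below traces one axis of a 128-voxel input through the layer settings used in the `UNet3D` cell of this section, using the standard output-size formulas for convolution and transposed convolution (dilation 1, no output padding). The helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output length of a transposed convolution along one axis."""
    return (size - 1) * stride - 2 * padding + kernel

size = 128
enc1 = conv_out(size, kernel=3, padding=1)            # 3x3x3 conv, padding 1
pooled = conv_out(enc1, kernel=2, stride=2)           # MaxPool3d(2)
enc2 = conv_out(pooled, kernel=3, padding=1)          # 3x3x3 conv, padding 1
dec1 = conv_transpose_out(enc2, kernel=2, stride=2)   # ConvTranspose3d, kernel 2, stride 2
print(enc1, pooled, enc2, dec1)
```

The decoder output has the same size as the first encoder stage, so the two can be concatenated along the channel axis. Input sizes that are not divisible by 2 would break this match after pooling.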


In [None]:
class UNet3D(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(UNet3D, self).__init__()
        
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv3d(in_channels, 16, 3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True)
        )
        self.enc2 = nn.Sequential(
            nn.MaxPool3d(2),
            nn.Conv3d(16, 32, 3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True)
        )
        
        # Decoder
        self.dec1 = nn.Sequential(
            nn.ConvTranspose3d(32, 16, 2, stride=2),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True)
        )
        self.dec2 = nn.Sequential(
            nn.Conv3d(32, 16, 3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, out_channels, 1)
        )
        
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        
        # Decoder
        d1 = self.dec1(e2)
        d2 = self.dec2(torch.cat([d1, e1], dim=1))
        
        return d2

# Initialize model
model = UNet3D().to(device)
print(model)


# 7. Training Loop

Set up the training process with a Dice loss and the Adam optimizer:


In [None]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        output = model(data)
        loss = criterion(output, target)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    return total_loss / len(loader)

def validate(model, loader, criterion):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            
    return total_loss / len(loader)

# Training setup
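# MONAI's DiceLoss needs input and target of matching shape. The dataset yields
# integer masks of shape (B, H, W, D), so add a channel axis and let MONAI one-hot encode.
dice_loss = DiceLoss(to_onehot_y=True, softmax=True)

def criterion(output, target):
    return dice_loss(output, target.unsqueeze(1))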
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop
n_epochs = 10
best_val_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss = validate(model, val_loader, criterion)
    
    print(f'Epoch {epoch+1}/{n_epochs}:')
    print(f'Train Loss: {train_loss:.4f}')
    print(f'Val Loss: {val_loss:.4f}')
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
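
With `softmax=True`, MONAI's `DiceLoss` is computed from class probabilities rather than hard labels. Assuming the library defaults (no squared terms, equal smoothing in numerator and denominator), the loss for class $c$ over voxels $i$ can be written as:

$$
\mathcal{L}_c = 1 - \frac{2 \sum_i p_{i,c}\, g_{i,c} + \epsilon}{\sum_i p_{i,c} + \sum_i g_{i,c} + \epsilon}
$$

Here $p_{i,c}$ is the softmax probability of class $c$ at voxel $i$, $g_{i,c}$ is the one-hot ground truth, and $\epsilon$ is a small smoothing constant. The reported value is the mean of $\mathcal{L}_c$ over classes and batch items, so the background class contributes to the average unless it is explicitly excluded.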


# 8. Evaluation and Visualization

Evaluate the model and visualize results:

### Questions to consider:
1. How well does the model segment different classes?
2. What are the clinical implications of false positives/negatives?
3. How could the model be improved?
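
For question 1, a per-class Dice score is more informative than a single averaged loss, because the background class tends to dominate. Below is a minimal NumPy sketch on toy label arrays; `dice_per_class` is a hypothetical helper, and returning NaN for a class absent from both arrays is a design choice rather than a standard.

```python
import numpy as np

def dice_per_class(pred, target, num_classes=3):
    """Return the Dice score of each class (NaN if the class is absent from both arrays)."""
    scores = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        denom = pred_c.sum() + target_c.sum()
        if denom == 0:
            scores.append(float("nan"))
        else:
            overlap = np.logical_and(pred_c, target_c).sum()
            scores.append(2.0 * overlap / denom)
    return scores

# Toy labels standing in for flattened prediction and ground-truth volumes
target_demo = np.array([0, 0, 1, 1, 2, 2])
pred_demo = np.array([0, 0, 1, 0, 2, 1])

scores = dice_per_class(pred_demo, target_demo)
for label, score in enumerate(scores):
    print(f"Class {label}: Dice = {score:.3f}")
```

The same function works on full 3D arrays, for example the predicted and true label volumes of a validation case.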


In [None]:
def predict_volume(model, ct_volume):
    model.eval()
    with torch.no_grad():
        pred = model(ct_volume.unsqueeze(0).to(device))
        pred = F.softmax(pred, dim=1)
        pred = torch.argmax(pred, dim=1)
    return pred[0].cpu().numpy()

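# Restore the best checkpoint saved during training, if one exists
if os.path.exists('best_model.pth'):
    model.load_state_dict(torch.load('best_model.pth', map_location=device))

# Load a validation sample (new names avoid overwriting the val_ct/val_seg file lists)
sample_ct, sample_seg = val_dataset[0]
pred_seg = predict_volume(model, sample_ct)

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
slice_idx = sample_ct.shape[-1] // 2  # middle slice along the last axis

axes[0].imshow(sample_ct[0, :, :, slice_idx].numpy(), cmap='gray')
axes[0].set_title('CT Slice')
axes[0].axis('off')

# Fix the color range so the same label gets the same color in both plots
axes[1].imshow(sample_seg[:, :, slice_idx].numpy(), cmap='jet', vmin=0, vmax=2)
axes[1].set_title('True Segmentation')
axes[1].axis('off')

axes[2].imshow(pred_seg[:, :, slice_idx], cmap='jet', vmin=0, vmax=2)
axes[2].set_title('Predicted Segmentation')
axes[2].axis('off')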

plt.tight_layout()
plt.show()


# 9. Discussion Questions

Please answer the following questions based on your implementation and results:

1. **Data Analysis**
   - What challenges did you encounter with the medical imaging data?
   - How did you handle class imbalance?

2. **Model Performance**
   - How well did the model perform on different classes?
   - What were the main sources of error?

3. **Clinical Relevance**
   - How might this model be useful in a clinical setting?
   - What additional validation would be needed?

4. **Improvements**
   - What modifications could improve the model's performance?
   - How could the preprocessing pipeline be enhanced?

Write your answers below:

1. Data Analysis:
   > Your answer here

2. Model Performance:
   > Your answer here

3. Clinical Relevance:
   > Your answer here

4. Improvements:
   > Your answer here
