# 🔬 Ovarian Cancer Segmentation Lab

Welcome to this comprehensive lab on medical image segmentation for ovarian cancer detection! In this lab, you'll work with volumetric CT scan data to develop an advanced deep learning solution for automated cancer tissue identification.

## 📋 Task Overview
Your goal is to develop a U-Net model that segments CT volumes, processed slice by slice in 2D, into three distinct classes:
- **Class 0**: Background tissue
- **Class 1**: Primary ovarian cancer
- **Class 2**: Metastatic tissue

## 🎯 Learning Objectives
By completing this lab, you will:
- Master working with medical imaging data in NIfTI format
- Implement and understand the U-Net architecture (here a 2D variant applied to individual CT slices)
- Learn effective training strategies for medical image segmentation
- Develop skills in evaluating and validating medical imaging models
- Gain practical experience with real-world medical data

## 🔍 Clinical Relevance
Accurate segmentation of ovarian cancer tissues is crucial for:
- Early detection and diagnosis
- Treatment planning and monitoring
- Assessment of disease progression
- Research and clinical trials

Let's dive in and build a solution that could make a real difference in healthcare! 🚀


# 1️⃣ Environment Setup and Dependencies

Before we begin our implementation, let's set up our development environment with all necessary packages and configurations.

## 📦 Required Packages
We'll be using the following key libraries:
- **PyTorch**: For deep learning model implementation
- **NiBabel**: For handling medical imaging data in NIfTI format
- **scikit-image**: For image processing and transformations
- **NumPy**: For numerical computations
- **Matplotlib**: For visualization

## 🖥️ Hardware Requirements
- GPU with CUDA support (recommended)
- Sufficient RAM for 3D volume processing
- Adequate storage for medical imaging data
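
As a rough guide to the RAM requirement, the footprint of one loaded volume can be estimated from its shape and data type. The sketch below assumes a hypothetical 512 × 512 × 300 volume (not a shape taken from the dataset) and uses 8 bytes per voxel, since NiBabel's `get_fdata()` returns 64-bit floats by default. `volume_memory_mib` is an illustrative helper, not part of the lab code.

```python
def volume_memory_mib(shape, bytes_per_voxel=8):
    """Estimate the in-memory size of a dense volume in MiB."""
    n_voxels = 1
    for dim in shape:
        n_voxels *= dim
    return n_voxels * bytes_per_voxel / (1024 ** 2)

# Hypothetical CT volume shape (not taken from the dataset)
example_shape = (512, 512, 300)
print(f"float64: {volume_memory_mib(example_shape):.0f} MiB")
print(f"float32: {volume_memory_mib(example_shape, bytes_per_voxel=4):.0f} MiB")
```

Loading a CT volume and its mask together roughly doubles this figure, which is worth keeping in mind when several data-loading workers run in parallel.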

## ⚙️ Configuration
We'll set up:
- CUDA device if available
- Random seeds for reproducibility
- Memory optimization settings


In [None]:
# Install required packages
!pip install numpy --quiet
!pip install scipy --quiet
!pip install scikit-learn --quiet
!pip install scikit-image nibabel gdown torch torchvision --quiet

# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from skimage import transform
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

# Set up GPU if available (device is defined on both paths so later code can rely on it)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Tensors are moved to the GPU explicitly later on; changing the global default
    # tensor type is deprecated and can break DataLoader shuffling and worker processes
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("WARNING: CUDA is not available. Please make sure to enable GPU in Runtime > Change runtime type")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


# 2️⃣ Data Acquisition and Preprocessing

## 📥 Dataset Download
First, we'll download our dataset containing CT scans and their corresponding segmentation masks. The data is stored in NIfTI format (`.nii.gz`), which is commonly used for medical imaging.

## 🗂️ Data Organization
The dataset is organized into two main directories:
- `Data_Subsample/CT/`: Contains the CT scan volumes
- `Data_Subsample/Segmentation/`: Contains the corresponding segmentation masks
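
Because every CT volume has to be matched with its mask, it can help to check the pairing before training. The sketch below assumes that a scan and its mask share the same file name in the two directories. That naming convention is an assumption, not something the dataset description confirms, and `check_pairing` is a hypothetical helper.

```python
def check_pairing(ct_names, seg_names):
    """Return (paired, ct_only, seg_only) for two lists of file names."""
    ct_set, seg_set = set(ct_names), set(seg_names)
    paired = sorted(ct_set & seg_set)
    ct_only = sorted(ct_set - seg_set)
    seg_only = sorted(seg_set - ct_set)
    return paired, ct_only, seg_only

# Toy file names for illustration only
paired, ct_only, seg_only = check_pairing(
    ["case_01.nii.gz", "case_02.nii.gz", "case_03.nii.gz"],
    ["case_01.nii.gz", "case_03.nii.gz", "case_04.nii.gz"],
)
print(paired)    # volumes that have both a scan and a mask
print(ct_only)   # scans without a mask
print(seg_only)  # masks without a scan
```

In the notebook, the two arguments would be the directory listings of `Data_Subsample/CT` and `Data_Subsample/Segmentation`. If the files turn out to follow a different naming scheme, the comparison key would need to be adapted.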

## 💾 Data Loading
Let's download and extract the dataset, then verify our data structure:


In [None]:
# Download and extract dataset
file_id = '1Wo4h6ZVIFygVvqd68ApwWIdPQk3l7gkO'
output = 'Data_Subsample.zip'

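if not os.path.exists(output):
    import subprocess
    # Pass the file ID positionally (the --id flag is deprecated in recent gdown releases);
    # check=True raises an error if the download fails
    subprocess.run(['gdown', file_id, '-O', output], check=True)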

# Extract data if not already extracted
if not os.path.exists('Data_Subsample'):
    import zipfile
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall('.')

# List available files
ct_files = sorted([f for f in os.listdir('Data_Subsample/CT') if f.endswith('.nii.gz')])
seg_files = sorted([f for f in os.listdir('Data_Subsample/Segmentation') if f.endswith('.nii.gz')])

print(f'Number of CT volumes: {len(ct_files)}')
print(f'Number of segmentation masks: {len(seg_files)}')


# 3️⃣ Loss Functions and Metrics

For medical image segmentation, choosing appropriate loss functions is crucial. We'll implement two key components:

## 🎯 Dice Loss
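The Dice coefficient (equivalent to the F1 score when computed on binary masks) is particularly useful for segmentation tasks because it:
- Is less sensitive to class imbalance than pixel-wise accuracy, since large background regions do not dominate the score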
- Focuses on overlap between predictions and ground truth
- Ranges from 0 (no overlap) to 1 (perfect overlap)
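
To make the definition concrete, here is a minimal sketch that computes the hard (thresholded) Dice score per class on a tiny pair of label sequences. It uses plain Python lists rather than tensors, and `dice_per_class` is an illustrative helper, not the loss used for training.

```python
def dice_per_class(pred, target, n_classes):
    """Hard Dice score per class for two flat lists of integer labels."""
    scores = []
    for c in range(n_classes):
        pred_c = [p == c for p in pred]
        target_c = [t == c for t in target]
        intersection = sum(p and t for p, t in zip(pred_c, target_c))
        total = sum(pred_c) + sum(target_c)
        # Convention used here: a class absent from both masks scores 1.0
        scores.append(1.0 if total == 0 else 2.0 * intersection / total)
    return scores

pred   = [0, 1, 1, 0, 2, 2]
target = [0, 1, 0, 0, 2, 2]
print(dice_per_class(pred, target, n_classes=3))
```

Class 1 is over-predicted by one pixel here, so its score drops to 2/3 even though five of the six pixels are labelled correctly. This is the sensitivity to small structures that makes Dice attractive for lesion segmentation.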

## 🔄 Combined Loss
We'll combine Dice Loss with weighted Cross-Entropy to:
- Balance between pixel-wise and region-based segmentation quality
- Handle class imbalance through dynamic class weights
- Provide smoother gradients during training
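
Written out, the combination takes the following form. The notation is introduced here for illustration: $p_{c,i}$ is the softmax probability of class $c$ at pixel $i$, $g_{c,i}$ is the one-hot ground truth, $y_i$ is the true class of pixel $i$, $\epsilon$ is the smoothing constant, $w_c$ are the class weights, $C$ is the number of classes, and $\lambda$ is the cross-entropy weight (`ce_weight` in the code).

```latex
D_c = \frac{2\sum_i p_{c,i}\,g_{c,i} + \epsilon}{\sum_i p_{c,i} + \sum_i g_{c,i} + \epsilon},
\qquad
\mathcal{L}_{\mathrm{Dice}} = 1 - \frac{1}{C}\sum_{c=0}^{C-1} D_c

\mathcal{L}_{\mathrm{CE}} = -\frac{\sum_i w_{y_i}\,\log p_{y_i,i}}{\sum_i w_{y_i}},
\qquad
\mathcal{L} = \mathcal{L}_{\mathrm{Dice}} + \lambda\,\mathcal{L}_{\mathrm{CE}}
```

In the implementation, the Dice term is also averaged over the samples in a batch, and the weighted cross-entropy follows PyTorch's convention of normalizing by the sum of the weights of the target pixels.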

Let's implement these loss functions:


In [None]:
class DiceLoss(nn.Module):
    """Dice Loss for multi-class 2D segmentation"""
    def __init__(self, smooth=1e-5):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions, targets):
        # predictions shape: (batch_size, n_classes, height, width)
        # targets shape: (batch_size, height, width)

        # Convert predictions to probabilities
        predictions = F.softmax(predictions, dim=1)

        # One-hot encode targets
        n_classes = predictions.shape[1]
        one_hot_targets = F.one_hot(targets, n_classes).permute(0, 3, 1, 2).float()

        # Calculate Dice score for each class
        numerator = 2 * (predictions * one_hot_targets).sum(dim=(2, 3))
        denominator = predictions.sum(dim=(2, 3)) + one_hot_targets.sum(dim=(2, 3))
        dice_scores = (numerator + self.smooth) / (denominator + self.smooth)

        # Average over classes and batch
        return 1 - dice_scores.mean()

class CombinedLoss(nn.Module):
    """Combined Dice and weighted Cross-Entropy loss with focus on cancer classes"""
    def __init__(self, smooth=1e-5, ce_weight=0.5, background_weight=0.1):
        super(CombinedLoss, self).__init__()
        self.smooth = smooth
        self.ce_weight = ce_weight
        self.background_weight = background_weight
        self.dice_loss = DiceLoss(smooth=smooth)

    def forward(self, predictions, targets):
        # Calculate class weights with reduced background weight
        n_classes = predictions.shape[1]
        class_counts = torch.bincount(targets.flatten(), minlength=n_classes).float()
        total_pixels = class_counts.sum()
        
        # Modify weights to focus on cancer classes
        class_weights = torch.zeros_like(class_counts)
        class_weights[0] = self.background_weight  # Background class
        class_weights[1:] = (1.0 - self.background_weight) / (n_classes - 1)  # Cancer classes
        
        # Scale weights by inverse frequency within the cancer classes that are
        # present in this batch; absent classes get zero weight so they cannot
        # absorb the entire cancer weight after normalization
        cancer_counts = class_counts[1:]  # Counts for cancer classes
        present = cancer_counts > 0
        if present.any():
            cancer_weights = torch.zeros_like(cancer_counts)
            cancer_weights[present] = total_pixels / (cancer_counts[present] * n_classes)
            cancer_weights = cancer_weights / cancer_weights.sum()  # Normalize
            class_weights[1:] *= cancer_weights

        class_weights = class_weights.to(predictions.device)

        # Dice Loss (averaged over all classes, including background)
        dice_loss = self.dice_loss(predictions, targets)

        # Weighted Cross Entropy Loss
        ce_loss = F.cross_entropy(predictions, targets, weight=class_weights)

        # Combine losses
        return dice_loss + self.ce_weight * ce_loss


# 4️⃣ Dataset and Model Architecture

## 📊 Dataset Implementation
We'll create a custom PyTorch Dataset class that:
- Extracts 2D slices from the 3D CT volumes and their segmentation masks
- Handles resizing, intensity normalization (window/level), and augmentation
- Feeds a `DataLoader`, which assembles the training batches
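
The normalization step relies on CT window/level scaling: intensities are clipped to the range [level − window/2, level + window/2] and then rescaled to [0, 1]. The scalar sketch below mirrors that arithmetic with the default values used by the dataset class in this notebook (window 4000, level 400). `window_level` is a hypothetical stand-alone helper, not part of the lab code.

```python
def window_level(hu, window=4000, level=400):
    """Clip a Hounsfield-unit value to the window and rescale it to [0, 1]."""
    lo = level - window / 2
    hi = level + window / 2
    clipped = min(max(hu, lo), hi)
    return (clipped - lo) / (hi - lo)

# Values below the window map to 0, values above it map to 1
for hu in (-2000, -1600, 400, 2400, 3000):
    print(hu, window_level(hu))
```

A narrower window spreads the same [0, 1] range over fewer Hounsfield units, which increases contrast between similar tissues. Which setting works best for this dataset is something to verify experimentally.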

## 🏗️ Model Architecture
Our U-Net implementation (a 2D network that operates on individual slices) includes:
- Encoder path with increasing feature channels
- Decoder path with skip connections
- Advanced features:
  - Batch normalization for stable training
  - Residual connections for better gradient flow
  - Dropout for regularization
  - Squeeze-and-Excitation blocks for channel attention
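
To see how the encoder and decoder paths mirror each other, the sketch below traces channel counts and spatial sizes through a U-Net-style network. It assumes a 128 × 128 input, feature widths of (16, 32, 64, 128), size-preserving convolutions, and 2 × 2 pooling and upsampling. `unet_shapes` is an illustrative helper that only does the bookkeeping; it does not build a model.

```python
def unet_shapes(input_size=128, features=(16, 32, 64, 128)):
    """Trace (stage, channels, spatial size) through a U-Net-style network."""
    trace = []
    size = input_size
    # Encoder: each stage keeps the spatial size, then 2x2 max pooling halves it
    for ch in features[:-1]:
        trace.append(("encoder", ch, size))
        size //= 2
    # Bottleneck at the lowest resolution
    trace.append(("bottleneck", features[-1], size))
    # Decoder: each stage doubles the spatial size and merges the matching skip
    for ch in reversed(features[:-1]):
        size *= 2
        trace.append(("decoder", ch, size))
    return trace

for stage, channels, size in unet_shapes():
    print(f"{stage:<10} channels={channels:<4} size={size}x{size}")
```

Each decoder stage ends at the same resolution and channel count as the encoder stage it is paired with, which is what allows the skip connection to be concatenated along the channel axis.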

Let's implement these components:


In [None]:
class OvarianCancerDataset(Dataset):
    """Dataset class for 2D slice-based ovarian cancer segmentation"""
    def __init__(self, ct_files, seg_files, target_size=128, is_train=True, window=4000, level=400):
        self.ct_files = ct_files
        self.seg_files = seg_files
        self.target_size = target_size
        self.is_train = is_train
        self.window = window
        self.level = level
        
        # Store slice indices with disease for each volume
        self.slice_indices = []
        self._find_disease_slices()
        
    def _find_disease_slices(self):
        """Find slices containing disease in each volume"""
        for ct_file, seg_file in zip(self.ct_files, self.seg_files):
            ct_path = os.path.join('Data_Subsample/CT', ct_file)
            seg_path = os.path.join('Data_Subsample/Segmentation', seg_file)
            
            # Load segmentation volume
            seg_vol = nib.load(seg_path).get_fdata()
            
            # Find slices with disease (class 1 or 2)
            disease_slices = []
            for z in range(seg_vol.shape[2]):
                if np.any(seg_vol[:,:,z] > 0):  # Any non-background pixels
                    disease_slices.append(z)
            
            if len(disease_slices) > 0:
                # Add up to 5 context slices before and after the disease range
                min_z = max(0, min(disease_slices) - 5)
                # range() excludes its upper bound, hence the extra + 1
                max_z = min(seg_vol.shape[2], max(disease_slices) + 5 + 1)
                disease_slices = list(range(min_z, max_z))
                
                self.slice_indices.extend([(ct_file, seg_file, z) for z in disease_slices])

    def apply_window_level(self, ct_slice):
        """Apply window/level adjustment to CT slice"""
        window = self.window
        level = self.level
        min_value = level - window/2
        max_value = level + window/2
        ct_slice = np.clip(ct_slice, min_value, max_value)
        ct_slice = (ct_slice - min_value) / (max_value - min_value)
        return ct_slice

    def augment_slice(self, ct_slice, seg_slice):
        """Apply 2D augmentations"""
        if not self.is_train:
            return ct_slice, seg_slice

        # Make copies
        ct_slice = ct_slice.copy()
        seg_slice = seg_slice.copy()

        # Random flip
        if np.random.random() > 0.5:
            ct_slice = np.flip(ct_slice, axis=0)
            seg_slice = np.flip(seg_slice, axis=0)
        if np.random.random() > 0.5:
            ct_slice = np.flip(ct_slice, axis=1)
            seg_slice = np.flip(seg_slice, axis=1)

        # Random rotation
        if np.random.random() > 0.5:
            angle = np.random.uniform(-15, 15)
            ct_slice = transform.rotate(ct_slice, angle, mode='reflect', preserve_range=True)
            seg_slice = transform.rotate(seg_slice, angle, mode='reflect', order=0, preserve_range=True)

        return ct_slice, seg_slice

    def preprocess_slice(self, ct_slice, seg_slice):
        """Preprocess a single slice pair"""
        # Resize
        if ct_slice.shape != (self.target_size, self.target_size):
            ct_slice = transform.resize(ct_slice, 
                                     (self.target_size, self.target_size),
                                     order=1,
                                     preserve_range=True,
                                     anti_aliasing=True)
            seg_slice = transform.resize(seg_slice,
                                      (self.target_size, self.target_size),
                                      order=0,
                                      preserve_range=True,
                                      anti_aliasing=False)

        # Apply window/level
        ct_slice = self.apply_window_level(ct_slice)

        # Apply augmentation
        if self.is_train:
            ct_slice, seg_slice = self.augment_slice(ct_slice, seg_slice)

        # Ensure segmentation values are integers
        seg_slice = np.round(seg_slice).astype(np.int64)

        return ct_slice, seg_slice

    def __len__(self):
        return len(self.slice_indices)

    def __getitem__(self, idx):
        ct_file, seg_file, z_idx = self.slice_indices[idx]
        
        # Load volumes
        ct_path = os.path.join('Data_Subsample/CT', ct_file)
        seg_path = os.path.join('Data_Subsample/Segmentation', seg_file)
        
        ct_vol = nib.load(ct_path).get_fdata()
        seg_vol = nib.load(seg_path).get_fdata()
        
        # Extract and preprocess slices
        ct_slice = ct_vol[:,:,z_idx]
        seg_slice = seg_vol[:,:,z_idx]
        
        ct_slice, seg_slice = self.preprocess_slice(ct_slice, seg_slice)
        
        # Ensure arrays are contiguous
        ct_slice = np.ascontiguousarray(ct_slice)
        seg_slice = np.ascontiguousarray(seg_slice)
        
        # Convert to tensors
        ct_slice = torch.FloatTensor(ct_slice).unsqueeze(0)  # Add channel dimension
        seg_slice = torch.LongTensor(seg_slice)
        
        return ct_slice, seg_slice

# Split data into training and validation sets
train_ct, val_ct, train_seg, val_seg = train_test_split(
    ct_files, seg_files, test_size=0.2, random_state=42
)

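# Create datasets (augmentation is disabled for validation so scores are comparable across epochs)
train_dataset = OvarianCancerDataset(train_ct, train_seg, is_train=True)
val_dataset = OvarianCancerDataset(val_ct, val_seg, is_train=False)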

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Create dataloaders with larger batch sizes for 2D
train_loader = DataLoader(
    train_dataset, 
    batch_size=32,  # Increased batch size for 2D slices
    shuffle=True,
    num_workers=4 if torch.cuda.is_available() else 0,  # Use multiple workers for faster loading
    pin_memory=True if torch.cuda.is_available() else False,
    drop_last=True  # Drop incomplete batches for stable training
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=32,  # Same batch size for consistent validation
    shuffle=False,
    num_workers=4 if torch.cuda.is_available() else 0,
    pin_memory=True if torch.cuda.is_available() else False,
    drop_last=False  # Keep all validation samples
)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')


In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class DoubleConv(nn.Module):
    """Enhanced double convolution block with SE attention"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.se = SEBlock(out_channels)
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        main = self.double_conv(x)
        main = self.se(main)
        residual = self.residual(x)
        return F.relu(main + residual)

class UNet2D(nn.Module):
    """Enhanced 2D U-Net with SE attention, residual connections, and deep supervision"""
    def __init__(self, in_channels=1, out_channels=3, features=[32, 64, 128, 256]):
        super(UNet2D, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.deep_supervision = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        in_channels_temp = in_channels
        for feature in features:
            self.encoder.append(DoubleConv(in_channels_temp, feature))
            in_channels_temp = feature

        # Decoder with deep supervision
        for i, feature in enumerate(reversed(features[:-1])):
            # Upsampling
            self.decoder.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        features[features.index(feature)+1],
                        feature,
                        kernel_size=2,
                        stride=2
                    ),
                    nn.BatchNorm2d(feature),
                    nn.ReLU(inplace=True)
                )
            )
            # Double conv after concatenation
            self.decoder.append(DoubleConv(feature * 2, feature))
            
            # Deep supervision outputs
            self.deep_supervision.append(
                nn.Conv2d(feature, out_channels, kernel_size=1)
            )

        self.bottleneck = DoubleConv(features[-2], features[-1])
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        # Advanced regularization
        self.dropout = nn.Dropout2d(p=0.3)
        self.spatial_dropout = nn.Dropout2d(p=0.1)

    def forward(self, x):
        skip_connections = []
        deep_outputs = []

        # Encoder
        for encoder in self.encoder[:-1]:
            x = encoder(x)
            x = self.spatial_dropout(x)  # Spatial dropout for feature map augmentation
            skip_connections.append(x)
            x = self.pool(x)
            x = self.dropout(x)

        x = self.bottleneck(x)

        # Decoder with deep supervision
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip = skip_connections[idx//2]

            # Handle different sizes
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])

            concat_skip = torch.cat((skip, x), dim=1)
            x = self.decoder[idx+1](concat_skip)
            x = self.dropout(x)

            # Deep supervision output
            deep_out = self.deep_supervision[idx//2](x)
            deep_outputs.append(deep_out)

        # Final output
        final_out = self.final_conv(x)
        
        if self.training:
            # During training, return main output and deep supervision outputs
            return final_out, deep_outputs
        else:
            # During inference, return only the main output
            return final_out


# 5️⃣ Training and Evaluation

## 🏃‍♂️ Training Process
Our training pipeline includes:
- Batch-wise training with GPU acceleration
- Learning rate scheduling with a one-cycle policy (`OneCycleLR`)
- Early stopping to prevent overfitting
- Model checkpointing to save best weights
- Memory optimization with periodic cache clearing
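
Early stopping and checkpointing share the same bookkeeping: remember the best validation score, reset a counter when it improves, and stop once the counter reaches the patience limit. The sketch below isolates that logic in a small hypothetical `EarlyStopping` class; the scores are made up for illustration, and the class is not part of the lab code.

```python
class EarlyStopping:
    """Track the best validation score and signal when to stop."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_score = float("-inf")
        self.counter = 0

    def update(self, score):
        """Return (is_best, should_stop) for the latest validation score."""
        if score > self.best_score:
            self.best_score = score
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience

# Made-up validation Dice scores, for illustration only
stopper = EarlyStopping(patience=2)
for epoch, dice in enumerate([0.40, 0.55, 0.54, 0.53, 0.60], start=1):
    is_best, should_stop = stopper.update(dice)
    print(epoch, dice, "checkpoint" if is_best else "", "stop" if should_stop else "")
    if should_stop:
        break
```

With a patience of 2, the run stops at epoch 4 and never sees the later improvement to 0.60. That is the trade-off the patience setting controls: a larger value tolerates plateaus at the cost of longer training.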

## 📈 Evaluation Metrics
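We'll monitor:
- Dice coefficient per class
- Mean Dice over the two cancer classes (used for checkpointing and early stopping)
- Training and validation loss curves
- Optionally, overall accuracy and class-wise precision and recall (not computed by the code in this section)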
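
Class-wise precision and recall can be derived from true-positive, false-positive, and false-negative counts. The sketch below shows one way to compute them from flat label sequences, using plain Python lists and toy labels. `precision_recall` is a hypothetical helper, and returning 0.0 for an empty denominator is a convention chosen here.

```python
def precision_recall(pred, target, cls):
    """Precision and recall of one class from flat lists of integer labels."""
    tp = sum(p == cls and t == cls for p, t in zip(pred, target))
    fp = sum(p == cls and t != cls for p, t in zip(pred, target))
    fn = sum(p != cls and t == cls for p, t in zip(pred, target))
    # Convention used here: report 0.0 when the denominator is empty
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

pred   = [0, 1, 1, 0, 2, 2, 1, 0]
target = [0, 1, 0, 0, 2, 1, 1, 0]
for cls in range(3):
    p, r = precision_recall(pred, target, cls)
    print(f"class {cls}: precision={p:.2f} recall={r:.2f}")
```

For a single class, Dice is the harmonic mean of these two quantities, so reporting them alongside Dice shows whether errors come from over-segmentation (low precision) or under-segmentation (low recall).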

## 🔍 Visualization
During and after training, we'll visualize:
- Sample predictions on validation data
- Training progress and learning curves
- Segmentation overlays on CT slices

Let's implement the training loop and evaluation functions:


In [None]:
# Function to visualize predictions
def visualize_predictions(model, data_loader, num_samples=3):
    """Visualize model predictions alongside ground truth"""
    model.eval()
    with torch.no_grad():
        # Get a batch of data
        data, target = next(iter(data_loader))
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        
        # Make predictions
        output = model(data)
        pred = F.softmax(output, dim=1)
        pred = torch.argmax(pred, dim=1)
        
        # Move tensors to CPU for visualization
        data = data.cpu()
        target = target.cpu()
        pred = pred.cpu()
        
        # Create a figure
        fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
        
        for i in range(num_samples):
            # Original CT slice
            axes[i,0].imshow(data[i,0], cmap='gray')
            axes[i,0].set_title('Input CT')
            axes[i,0].axis('off')
            
            # Ground truth
            axes[i,1].imshow(target[i], cmap='tab10')
            axes[i,1].set_title('Ground Truth')
            axes[i,1].axis('off')
            
            # Prediction
            axes[i,2].imshow(pred[i], cmap='tab10')
            axes[i,2].set_title('Prediction')
            axes[i,2].axis('off')
        
        plt.tight_layout()
        plt.show()

# Initialize 2D model and move to GPU if available
model = UNet2D(in_channels=1, out_channels=3, features=[16, 32, 64, 128])  # Reduced feature maps
if torch.cuda.is_available():
    model = model.cuda()

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model architecture:\n{model}")

# Initial training setup with focus on cancer classes
# (the training cell further down redefines these objects before the loop runs)
criterion = CombinedLoss(
    ce_weight=0.7,  # Balance between Dice and CE
    background_weight=0.1,  # Reduce focus on background class
    smooth=1e-5
)

# AdamW optimizer (gradient clipping is applied inside train_epoch)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,  # Slightly higher learning rate for faster convergence
    weight_decay=0.01,
    amsgrad=True
)

# Learning rate scheduler
# OneCycleLR counts calls to scheduler.step(); the training loop steps once per
# epoch, so the schedule length must equal the number of epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=20,  # Reduced epochs
    steps_per_epoch=1,
    pct_start=0.3,
    div_factor=10,
    final_div_factor=1000
)

# Training loop with improved monitoring
n_epochs = 20  # Reduced epochs
best_val_dice = 0.0
patience = 5  # Reduced patience
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'val_dice': [], 'lr': []}


In [None]:
# Visualization function for training progress
def plot_training_history(history):
    """Plot training metrics over time"""
    epochs = range(1, len(history['train_loss']) + 1)
    
    plt.figure(figsize=(12, 4))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    plt.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
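    # Plot Dice score (left axis) and learning rate (right axis, log scale)
    ax_dice = plt.subplot(1, 2, 2)
    ax_dice.plot(epochs, history['val_dice'], 'g-', label='Validation Dice')
    ax_dice.set_title('Validation Dice Score and Learning Rate')
    ax_dice.set_xlabel('Epoch')
    ax_dice.set_ylabel('Dice')
    ax_dice.grid(True)
    ax_lr = ax_dice.twinx()  # Separate axis: the learning rate is orders of magnitude smaller
    ax_lr.plot(epochs, history['lr'], 'y-', label='Learning Rate')
    ax_lr.set_yscale('log')
    ax_lr.set_ylabel('Learning rate')
    lines = ax_dice.get_lines() + ax_lr.get_lines()
    ax_dice.legend(lines, [line.get_label() for line in lines])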
    
    plt.tight_layout()
    plt.show()


In [None]:
def train_epoch(model, loader, optimizer, criterion, deep_weight=0.5):
    """Train the model for one epoch with deep supervision"""
    model.train()
    total_loss = 0
    num_batches = len(loader)
    
    # Progress tracking
    processed_batches = 0
    report_interval = max(1, num_batches // 10)  # Report roughly 10 times per epoch

    for batch_idx, (data, target) in enumerate(loader):
        # Move data to GPU if available
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()

        # Clear gradients
        optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()

        # Forward pass with deep supervision
        main_output, deep_outputs = model(data)
        
        # Calculate main loss
        main_loss = criterion(main_output, target)
        
        # Calculate deep supervision losses
        deep_loss = 0
        for deep_out in deep_outputs:
            # Resize deep supervision output to match target size if needed
            if deep_out.shape[2:] != target.shape[1:]:  # Compare with target's spatial dimensions
                deep_out = F.interpolate(deep_out, size=target.shape[1:], mode='bilinear', align_corners=False)
            deep_loss += criterion(deep_out, target)
        
        # Combine losses
        loss = main_loss + deep_weight * (deep_loss / len(deep_outputs))

        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()

        # Update metrics
        total_loss += loss.item()
        processed_batches += 1

        # Print progress
        if (batch_idx + 1) % report_interval == 0:
            avg_loss = total_loss / processed_batches
            progress = (batch_idx + 1) / num_batches * 100
            print(f'Progress: {progress:.1f}%, Average Loss: {avg_loss:.4f}')

    return total_loss / num_batches

def validate(model, loader, criterion):
    """Validate the model"""
    model.eval()
    total_loss = 0
    class_dice_scores = {i: [] for i in range(3)}  # Track per-class Dice scores
    num_batches = 0

    with torch.no_grad():
        for data, target in loader:
            # Move data to GPU if available
            if torch.cuda.is_available():
                data = data.cuda()
                target = target.cuda()
            
            # Model returns only main output during validation
            output = model(data)
            loss = criterion(output, target)
            
            # Calculate per-class Dice scores
            pred = F.softmax(output, dim=1)
            pred = torch.argmax(pred, dim=1)
            
            # Calculate batch Dice scores for each class
            for i in range(3):  # 3 classes
                class_pred = pred == i
                class_target = target == i
                # Skip classes that are absent from both the prediction and the target in this batch
                if class_target.sum() > 0 or class_pred.sum() > 0:
                    dice = calculate_dice_score(class_pred, class_target)
                    class_dice_scores[i].append(dice)
            
            total_loss += loss.item()
            num_batches += 1

    # Calculate average Dice scores
    class_avg_dice = {i: np.mean(scores) if scores else 0.0 
                     for i, scores in class_dice_scores.items()}
    
    # Print per-class Dice scores
    print("\nValidation Dice scores:")
    print(f"Background: {class_avg_dice[0]:.4f}")
    print(f"Primary cancer: {class_avg_dice[1]:.4f}")
    print(f"Metastatic: {class_avg_dice[2]:.4f}")
    
    # Overall Dice score: unweighted mean over the two cancer classes (background excluded)
    avg_dice = (class_avg_dice[1] + class_avg_dice[2]) / 2
    
    return total_loss / num_batches, avg_dice

def calculate_dice_score(pred, target):
    """Calculate Dice score for binary (boolean) masks"""
    intersection = (pred & target).sum().float()
    total = pred.sum() + target.sum()  # Sum of both mask sizes (the Dice denominator, not the set union)
    if total == 0:
        return 1.0  # Define empty as perfect match
    return (2.0 * intersection / total).item()

def predict_slice(model, ct_slice):
    """Generate predictions for a single slice"""
    model.eval()
    with torch.no_grad():
        if torch.cuda.is_available():
            ct_slice = ct_slice.cuda()
        pred = model(ct_slice.unsqueeze(0))
        pred = F.softmax(pred, dim=1)
        pred = torch.argmax(pred, dim=1)
    return pred[0].cpu().numpy()


# Training setup with focus on cancer classes
criterion = CombinedLoss(
    ce_weight=0.7,  # Balance between Dice and CE
    background_weight=0.1,  # Reduce focus on background class
    smooth=1e-5
)

# AdamW optimizer (gradient clipping is applied inside train_epoch)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
    amsgrad=True  # Use AMSGrad variant
)

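# Learning rate scheduler
# OneCycleLR counts calls to scheduler.step(); the loop below steps once per
# epoch, so the schedule length must equal the number of epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=50,
    steps_per_epoch=1,  # One scheduler step per epoch
    pct_start=0.3,  # Warm-up period
    div_factor=25,  # Initial lr = max_lr/25
    final_div_factor=1000,  # Min lr = initial_lr/1000
)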

# Training loop with improved monitoring
n_epochs = 50  # Increased epochs
best_val_dice = 0.0  # Track best Dice score instead of loss
patience = 10  # Increased patience
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'val_dice': [], 'lr': []}

print("Starting training...")
for epoch in range(n_epochs):
    print(f"\nEpoch {epoch+1}/{n_epochs}")
    
    # Training phase with deep supervision
    train_loss = train_epoch(model, train_loader, optimizer, criterion, deep_weight=0.5)
    print(f"Training Loss: {train_loss:.4f}")
    
    # Validation phase
    val_loss, val_dice = validate(model, val_loader, criterion)
    print(f"Validation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}")
    
    # Learning rate scheduling
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current learning rate: {current_lr:.2e}")
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)
    history['lr'].append(current_lr)
    
    # Plot training progress
    plot_training_history(history)
    
    # Visualize predictions every 5 epochs
    if (epoch + 1) % 5 == 0:
        print("\nTraining set predictions:")
        visualize_predictions(model, train_loader)
        print("\nValidation set predictions:")
        visualize_predictions(model, val_loader)
    
    # Model checkpointing based on Dice score
    if val_dice > best_val_dice:
        best_val_dice = val_dice
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_val_dice,
            'history': history
        }, 'best_model.pth')
        patience_counter = 0
        print(f"Saved new best model with Dice score: {best_val_dice:.4f}!")
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        print(f"Best validation Dice score: {best_val_dice:.4f}")
        break
    
    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# After training, visualize final predictions on the validation set
print("\nFinal predictions on validation set:")
visualize_predictions(model, val_loader, num_samples=5)  # Show more samples for final evaluation


# 6️⃣ Discussion and Future Work

## 🔍 Technical Analysis
Consider and discuss the following aspects of your implementation:

### Data Challenges
- What difficulties did you encounter with the medical imaging data?
- How effective was your preprocessing pipeline?
- What additional data augmentation techniques could be beneficial?

### Model Performance
- How well did the model segment different tissue types?
- What were the main sources of errors?
- How could the architecture be improved?

## 🏥 Clinical Impact
Reflect on the clinical applications:

### Current Capabilities
- How reliable is the model for clinical use?
- What are the limitations of the current implementation?
- How does it compare to human expert performance?

### Future Improvements
- What additional validation would be needed for clinical deployment?
- How could the model be integrated into clinical workflows?
- What safety measures should be implemented?

## 🚀 Next Steps
Consider these potential improvements:

### Technical Enhancements
- Implement additional data augmentation techniques
- Experiment with different model architectures
- Add uncertainty quantification
- Optimize for inference speed

### Clinical Integration
- Develop a user-friendly interface
- Add reporting and visualization tools
- Implement quality assurance measures
- Design clinical validation studies

Write your answers and reflections below:


# 7️⃣ Discussion Questions

Please answer the following questions based on your implementation and results:

1. **Data Analysis**
   - What challenges did you encounter with the medical imaging data?
   - How did you handle class imbalance?

2. **Model Performance**
   - How well did the model perform on different classes?
   - What were the main sources of error?

3. **Clinical Relevance**
   - How might this model be useful in a clinical setting?
   - What additional validation would be needed?

4. **Improvements**
   - What modifications could improve the model's performance?
   - How could the preprocessing pipeline be enhanced?

Write your answers below:

1. Data Analysis:
   > Your answer here

2. Model Performance:
   > Your answer here

3. Clinical Relevance:
   > Your answer here

4. Improvements:
   > Your answer here
