# U-Net 4 Multiple Sclerosis Lesion Segmentation - ICPR Challenge
### Authors: Andrew R. Darnall, Giovanni Spadaro @ UniCT
---

## 🎯 Competition Objective: MS Lesion Segmentation

The central goal of this competition is the **automatic segmentation of Multiple Sclerosis (MS) lesions** using **multi-modal MRI data** and **deep learning algorithms**.

### 🧪 Provided Data
Participants were given:
- **MRI scans** in three modalities:
  - **FLAIR**
  - **T1-weighted (T1-w)**
  - **T2-weighted (T2-w)**
- **Ground-truth segmentation masks**, which are:
  - **Binary masks**:  
    - **White pixels** → MS lesion regions  
    - **Black pixels** → Background

### 🧠 Task Description
- Participants could use **any or all modalities**, along with the ground-truth labels, to:
  - Develop **deep learning-based models** for **automatic lesion segmentation**
- MS lesions appear as **irregular clusters of pixels** with **high variability in size and shape**
- These lesions are often **difficult to detect** via visual inspection, requiring **expert-level interpretation**

The ultimate goal is to create **fully automated segmentation pipelines** that can robustly identify and delineate MS lesions from raw MRI data.


## 🧠 MSLesSeg Dataset Overview

As part of this competition, participants were provided with the **MSLesSeg Dataset** — a **comprehensively annotated, multi-modal MRI dataset** designed for advancing **lesion segmentation** research in medical imaging.

### 📊 Dataset Composition
- **Total Patients:** 75 (48 women, 27 men)  
- **Age Range:** 18–59 years (Mean: 37 ± 10.3 years)  
- **Longitudinal Timepoints:**  
  - 50 patients with 1 timepoint  
  - 15 patients with 2 timepoints  
  - 5 patients with 3 timepoints  
  - 5 patients with 4 timepoints  
- **Time Interval Between Scans:** ~1.27 ± 0.62 years  
- **Total MRI Series:** 115
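
The series count follows from the timepoint breakdown above. A quick arithmetic check (the dictionary name is only illustrative):

```python
# Number of patients for each count of longitudinal timepoints (from the list above)
patients_by_timepoints = {1: 50, 2: 15, 3: 5, 4: 5}

total_patients = sum(patients_by_timepoints.values())
total_series = sum(t * n for t, n in patients_by_timepoints.items())

print(total_patients)  # 75
print(total_series)    # 115
```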

### 🧬 Imaging Modalities
Each timepoint includes **three core MRI modalities**:
- **T1-weighted (T1-w)**
- **T2-weighted (T2-w)**
- **FLAIR (Fluid-Attenuated Inversion Recovery)**

### 🧑‍⚕️ Expert Annotation
- Lesions were **manually annotated** by clinical experts.
- **FLAIR sequences** were the primary reference for lesion labeling.
- **T1-w and T2-w** scans supported **multi-contrast lesion characterization**.

### 🧪 Dataset Splits
- **Training Set:** 53 patients (93 scans)  
- **Test Set:** 22 patients (22 scans)  

### ✅ Ethical Compliance
- **Ethical approval** was obtained from the corresponding Hospital Ethics Committee.
- **Informed consent** was acquired from all participating patients.

---

# The Experiment

Below is the code used for:

1) Preprocessing of the ***Brain MRI*** scans
2) Definition of the Dataset, DataLoader and LightningDataModule classes
3) ***U-Net*** architecture
4) ***PyTorch Lightning*** Trainer
5) Training & Evaluation
6) Model Explainability with the post-hoc method ***GradCAM++***

---

## 🛠️ Preprocessing & Annotation Workflow

The MSLesSeg dataset underwent a **comprehensive preprocessing pipeline** and **expert-driven manual annotation** to ensure **standardization** and **label quality** for downstream MS lesion segmentation tasks.

### 🧼 Preprocessing Pipeline
1. **Anonymization** of all MRI scans to protect patient privacy.
2. **DICOM to NIfTI conversion**, leveraging NIfTI's wide adoption in neuroimaging.
3. **Co-registration to the MNI152 1mm³ isotropic template** using **FLIRT** (FMRIB’s Linear Image Registration Tool), ensuring all scans are aligned to a **common anatomical space**.
4. **Brain extraction** via **BET** (Brain Extraction Tool) to remove non-brain tissues and isolate relevant structures.

This pipeline guarantees that all images are **standardized** and **aligned**, which is critical for **automated MS lesion segmentation algorithms**.
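
For readers who want to reproduce steps 3 and 4 on their own scans, the sketch below only builds the FSL command lines; it does not run them. The file names are hypothetical, and the exact options used by the dataset authors are not stated in this notebook, so only the basic `flirt -in/-ref/-out/-omat` and `bet input output` forms are shown.

```python
from pathlib import Path

def build_fsl_commands(scan, template, out_dir):
    """Return the FLIRT and BET command lines for one scan, without running them."""
    scan, out_dir = Path(scan), Path(out_dir)
    stem = scan.name.replace(".nii.gz", "")
    registered = out_dir / f"{stem}_mni.nii.gz"
    matrix = out_dir / f"{stem}_to_mni.mat"
    brain = out_dir / f"{stem}_mni_brain.nii.gz"

    # Step 3: linear co-registration of the scan to the template
    flirt_cmd = ["flirt", "-in", str(scan), "-ref", str(template),
                 "-out", str(registered), "-omat", str(matrix)]
    # Step 4: brain extraction on the registered scan
    bet_cmd = ["bet", str(registered), str(brain)]
    return [flirt_cmd, bet_cmd]

# Hypothetical paths, for illustration only
for cmd in build_fsl_commands("P1_T1_flair.nii.gz", "MNI152_T1_1mm.nii.gz", "out"):
    print(" ".join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True)` on a machine where FSL is installed.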

---

### 🖋️ Ground-Truth Annotation Protocol
- Lesions were **manually segmented** on the **FLAIR modality** for each patient and timepoint.
- **T1-w and T2-w** modalities were used to **cross-validate ambiguous cases**.
- Annotation was conducted by a **trained junior rater**, under supervision of:
  - A **senior neuroradiologist**
  - A **senior neurologist**
- Annotation sessions included:
  - Multiple **training meetings** to establish a **consistent segmentation strategy**
  - Use of **JIM9** — a high-end tool for **medical image segmentation and analysis**
  - Regular **expert validation checkpoints** to ensure consistency and accuracy

The final masks, reviewed and approved by senior experts, are considered the **gold-standard ground truth**.

---

## 🧾 Key Annotation Highlights
- **Independent segmentation** for each patient/timepoint to avoid bias
- Conducted on **FLAIR scans registered to MNI space**
- **Validated ground-truth masks** ready for training and evaluation



In [1]:
import os
import torch
import nibabel as nib
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.transform import resize
from pathlib import Path
from tqdm.notebook import tqdm

In [2]:
# Helper function to load the nifti files
def load_nifti(file_path):
    return nib.load(file_path).get_fdata()

In [3]:
def preprocess_case(input_dir, output_dir, case_id):
    flair = load_nifti(input_dir / f"{case_id}_flair.nii.gz")
    t1 = load_nifti(input_dir / f"{case_id}_t1.nii.gz")
    t2 = load_nifti(input_dir / f"{case_id}_t2.nii.gz")
    seg = load_nifti(input_dir / f"{case_id}_seg.nii.gz").astype(np.uint8)

    # Stack input modalities into a tensor (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    # Use CPU Tensors instead
    input_tensor = torch.tensor(stacked, dtype=torch.float32)
    
    # Convert the mask to a PyTorch Tensor for storing
    seg_tensor = torch.tensor(seg, dtype=torch.uint8)
    # Add a leading channel dimension so the mask is (1, D, H, W), matching the (C, D, H, W) layout of the input
    seg_tensor = seg_tensor.unsqueeze(0)

    
    # Save all
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")
    torch.save(seg_tensor, output_case_dir / "seg_mask.pt")

In [4]:
from pathlib import Path
from tqdm import tqdm

def run_preprocessing(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    # Get all case directories
    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name  # e.g., MSLS_000
        output_case_path = output_path / case_id  # Define output path for each case

        # Check if the case has already been processed (e.g., output directory or file exists)
        if output_case_path.exists():
            print(f"✅ Skipping {case_id}, already processed.")
            continue  # Skip processing if the case already exists

        try:
            preprocess_case(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Failed on {case_id}: {e}")

In [5]:
# Run the preprocessing on the training set
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/train"
OUTPUT_PATH = "../data/02-Tensor-Data/train"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 93 cases.


100%|████████████████████████████████████████| 93/93 [00:00<00:00, 22907.58it/s]

✅ Skipping MSLS_019, already processed.
✅ Skipping MSLS_039, already processed.
✅ Skipping MSLS_059, already processed.
✅ Skipping MSLS_000, already processed.
✅ Skipping MSLS_001, already processed.
✅ Skipping MSLS_002, already processed.
✅ Skipping MSLS_003, already processed.
✅ Skipping MSLS_004, already processed.
✅ Skipping MSLS_005, already processed.
✅ Skipping MSLS_006, already processed.
✅ Skipping MSLS_007, already processed.
✅ Skipping MSLS_008, already processed.
✅ Skipping MSLS_009, already processed.
✅ Skipping MSLS_010, already processed.
✅ Skipping MSLS_011, already processed.
✅ Skipping MSLS_012, already processed.
✅ Skipping MSLS_013, already processed.
✅ Skipping MSLS_014, already processed.
✅ Skipping MSLS_015, already processed.
✅ Skipping MSLS_016, already processed.
✅ Skipping MSLS_017, already processed.
✅ Skipping MSLS_018, already processed.
✅ Skipping MSLS_020, already processed.
✅ Skipping MSLS_021, already processed.
✅ Skipping MSLS_022, already processed.





In [6]:
# Run the preprocessing on the test set
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test_MASK"
OUTPUT_PATH = "../data/02-Tensor-Data/test"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


100%|████████████████████████████████████████| 22/22 [00:00<00:00, 17273.43it/s]

✅ Skipping MSLS_093, already processed.
✅ Skipping MSLS_094, already processed.
✅ Skipping MSLS_095, already processed.
✅ Skipping MSLS_096, already processed.
✅ Skipping MSLS_097, already processed.
✅ Skipping MSLS_098, already processed.
✅ Skipping MSLS_099, already processed.
✅ Skipping MSLS_100, already processed.
✅ Skipping MSLS_101, already processed.
✅ Skipping MSLS_102, already processed.
✅ Skipping MSLS_103, already processed.
✅ Skipping MSLS_104, already processed.
✅ Skipping MSLS_105, already processed.
✅ Skipping MSLS_106, already processed.
✅ Skipping MSLS_107, already processed.
✅ Skipping MSLS_108, already processed.
✅ Skipping MSLS_109, already processed.
✅ Skipping MSLS_110, already processed.
✅ Skipping MSLS_111, already processed.
✅ Skipping MSLS_112, already processed.
✅ Skipping MSLS_113, already processed.
✅ Skipping MSLS_114, already processed.
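
`run_preprocessing` skips a case as soon as its output directory exists. A stricter check, sketched below as an assumption and not as part of the original pipeline, treats a case as done only when both tensor files are present, so an interrupted run is not mistaken for a finished one. The helper name `is_case_complete` is hypothetical.

```python
import tempfile
from pathlib import Path

EXPECTED_FILES = ("input_tensor.pt", "seg_mask.pt")

def is_case_complete(case_dir):
    """Return True only if every expected output file exists in case_dir."""
    case_dir = Path(case_dir)
    return all((case_dir / name).is_file() for name in EXPECTED_FILES)

# Small demonstration in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    case = Path(tmp) / "MSLS_000"
    case.mkdir()
    (case / "input_tensor.pt").touch()
    partial = is_case_complete(case)   # only one of the two files exists
    (case / "seg_mask.pt").touch()
    complete = is_case_complete(case)  # both files exist

print(partial, complete)  # False True
```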





## Build the Dataset and Dataloaders for the MSLesSeg preprocessed data

In [7]:
# Check the initial GPU memory consumption of the project
import torch

# Check initial GPU memory usage
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [8]:
# MSLesSeg Tensor Dataset class
import os
import glob
import torch
from torch.utils.data import Dataset

class MSLesSegDataset(Dataset):
    
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.patient_dirs = self._get_patient_dirs()

    def _get_patient_dirs(self):
        """
        Helper function to search for all patient directories within train/test directories.
        """
        patient_dirs = []
        
        # Get all patient directories within the root_dir (either train or test)
        for patient_dir in os.listdir(self.root_dir):
            patient_path = os.path.join(self.root_dir, patient_dir)
            
            # Make sure it's a directory
            if os.path.isdir(patient_path):
                # Check if both 'input_tensor.pt' and 'seg_mask.pt' exist
                input_tensor_path = os.path.join(patient_path, 'input_tensor.pt')
                seg_mask_path = os.path.join(patient_path, 'seg_mask.pt')
                
                if os.path.exists(input_tensor_path) and os.path.exists(seg_mask_path):
                    patient_dirs.append(patient_path)
        
        return patient_dirs

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to retrieve.
        
        Returns:
            tuple: (input_tensor, seg_mask) for the requested sample.
        """
        patient_dir = self.patient_dirs[idx]
        
        # Load the input tensor and segmentation mask
        input_tensor = torch.load(os.path.join(patient_dir, 'input_tensor.pt'))  # Shape: [3, 182, 218, 182]
        seg_mask = torch.load(os.path.join(patient_dir, 'seg_mask.pt'))  # Shape: [1, 182, 218, 182]
        
        return input_tensor, seg_mask

### PyTorch Lightning DataModule

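A ***LightningDataModule*** groups the train, validation and test dataloaders in a single class. PyTorch Lightning 2.x still accepts dataloaders passed directly to the ***.fit()*** method, but handing over a DataModule through the `datamodule` argument keeps the data setup in one place and lets ***.fit()*** and ***.test()*** share it.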

In [9]:
# MSLesSeg (PyTorch) LightningDataModule definition
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

class MSLesSegDataModule(pl.LightningDataModule):
    
    def __init__(self, root_data_dir, batch_size, val_split, num_workers):
        """
        root_data_dir: Is the path to the train and test data
        """
        super().__init__()
        self.data_dir = root_data_dir
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Load full training dataset
        full_dataset = MSLesSegDataset(root_dir=os.path.join(self.data_dir, 'train'))

        # Split into train and val
        val_size = int(len(full_dataset) * self.val_split)
        train_size = len(full_dataset) - val_size
        self.train_dataset, self.val_dataset = random_split(full_dataset, [train_size, val_size])

        # Load test dataset (if it exists)
        test_dir = os.path.join(self.data_dir, 'test')
        if os.path.exists(test_dir):
            self.test_dataset = MSLesSegDataset(root_dir=test_dir)
        else:
            self.test_dataset = None

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        if self.test_dataset:
            return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return None
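
With the 93 training cases found during preprocessing and a `val_split` of 0.2, the sizes computed in `setup` work out as shown below. This is plain arithmetic mirroring the two lines in `setup`; the function name is illustrative.

```python
def split_sizes(n_samples, val_split):
    """Mirror the size computation in MSLesSegDataModule.setup."""
    val_size = int(n_samples * val_split)  # int() truncates towards zero
    train_size = n_samples - val_size
    return train_size, val_size

train_size, val_size = split_sizes(93, 0.2)
print(train_size, val_size)  # 75 18
```

Note that `random_split` draws a different partition on each run unless the global seed is fixed or a seeded `torch.Generator` is passed through its `generator` argument.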



## The Model's Architecture

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Standard double convolution block (with GeLU instead of ReLU)
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.GELU(),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.GELU()
    )

# Final output layer: Conv3D + Sigmoid for probability map
def final_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=1),
        nn.Sigmoid()
    )

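# Center cropping: trim `tensor` symmetrically to the spatial size of `target_tensor`
def center_crop(tensor, target_tensor):
    _, _, h, w, d = target_tensor.shape
    # Offsets that leave an equal margin on both sides of each spatial dimension
    h0, w0, d0 = [(src - dst) // 2 for src, dst in zip(tensor.shape[2:], (h, w, d))]
    tensor = tensor[:, :, h0:h0 + h, w0:w0 + w, d0:d0 + d]
    return tensor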

# Pad the two tensors to match the shapes
def pad_to_match_shape(tensor1, tensor2):
    """
    Pads the smaller tensor to match the shape of the larger tensor along all spatial dimensions (H, W, D),
    but leaves batch (B) and channel (C) dimensions untouched.
    
    Args:
        tensor1 (Tensor): The first tensor.
        tensor2 (Tensor): The second tensor.
        
    Returns:
        Tensor, Tensor: Both tensors padded to the same shape.
    """
    # Get the shapes of both tensors
    shape1 = tensor1.shape
    shape2 = tensor2.shape
    
    # Initialize padding lists for tensor1 and tensor2
    padding1 = []
    padding2 = []
    
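    # Compare shapes and calculate the padding for each tensor.
    # F.pad expects the padding pairs ordered from the LAST dimension backwards (D, W, H),
    # so the spatial dimensions are visited in reverse order.
    for dim1, dim2 in zip(reversed(shape1[2:]), reversed(shape2[2:])):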
        if dim1 < dim2:
            # Calculate padding for tensor1
            pad_left = (dim2 - dim1) // 2
            pad_right = dim2 - dim1 - pad_left
            padding1.extend([pad_left, pad_right])
            padding2.extend([0, 0])  # No padding needed for tensor2 in this dimension
        elif dim1 > dim2:
            # Calculate padding for tensor2
            pad_left = (dim1 - dim2) // 2
            pad_right = dim1 - dim2 - pad_left
            padding1.extend([0, 0])  # No padding needed for tensor1 in this dimension
            padding2.extend([pad_left, pad_right])
        else:
            # No padding needed for this dimension, keep it symmetric
            padding1.extend([0, 0])
            padding2.extend([0, 0])

    # Pad tensor1 and tensor2 with the calculated padding
    padded_tensor1 = F.pad(tensor1, padding1)
    padded_tensor2 = F.pad(tensor2, padding2)
    
    return padded_tensor1, padded_tensor2
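
`F.pad` reads its padding argument in pairs that start at the last dimension and move backwards, so for a `[B, C, H, W, D]` tensor the expected order is `(D_left, D_right, W_left, W_right, H_left, H_right)`. The pure-Python sketch below builds such a tuple for two spatial shapes; `build_pad` is a hypothetical helper, not part of the notebook.

```python
def build_pad(shape, target_shape):
    """Return an F.pad-style tuple that grows `shape` to `target_shape`.

    Both arguments are spatial shapes such as (H, W, D). Dimensions that are
    already large enough receive zero padding.
    """
    pad = []
    # F.pad expects the last dimension first, so iterate in reverse
    for size, target in zip(reversed(shape), reversed(target_shape)):
        diff = max(target - size, 0)
        left = diff // 2
        pad.extend([left, diff - left])
    return tuple(pad)

# Skip connection (182, 218, 182) padded up to the decoder output (191, 223, 191)
print(build_pad((182, 218, 182), (191, 223, 191)))  # (4, 5, 2, 3, 4, 5)
```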

In [11]:
# 3D U-Net Architecture
class UNet(nn.Module):
    """
    The U-Net model will accept input Tensors of shape: [B, C, H, W, D]
    Which based on the used MSLesSeg dataset will be [B, 3, 182, 218, 182]
    The input Tensor is nothing more than the stacked [FLAIR, T1w, T2w] modalities

    The output will be the segmentation mask of shape [B, 1, 182, 218, 182]
    """
    def __init__(self, num_classes, in_channels):
        super(UNet, self).__init__()

        self.max_pool3d = nn.MaxPool3d(kernel_size=2, stride=2)

        # == Contracting Path == #
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # == Expanding Path == #
        self.up_transpose_1 = nn.ConvTranspose3d(
            in_channels=1024,
            out_channels=512,
            kernel_size=2,
            stride=2,
            output_padding=1
        )
        self.up_conv_1 = double_conv(1024, 512)
        
        self.up_transpose_2 = nn.ConvTranspose3d(
            in_channels=512,
            out_channels=256,
            kernel_size=2,
            stride=2,
            output_padding=1
        )
        self.up_conv_2 = double_conv(512, 256)

        self.up_transpose_3 = nn.ConvTranspose3d(
            in_channels=256,
            out_channels=128,
            kernel_size=2,
            stride=2,
            output_padding=1
        )
        self.up_conv_3 = double_conv(256, 128)

        self.up_transpose_4 = nn.ConvTranspose3d(
            in_channels=128,
            out_channels=64,
            kernel_size=2,
            stride=2,
            output_padding=1
        )
        self.up_conv_4 = double_conv(128, 64)

        self.prob_out = final_conv(64, num_classes)

        self.out = nn.Conv3d(
            in_channels=64,
            out_channels=num_classes,
            kernel_size=1
        )

    def forward(self, x):
        
        down_1 = self.down_conv_1(x)
        down_2 = self.max_pool3d(down_1)
        down_3 = self.down_conv_2(down_2)
        down_4 = self.max_pool3d(down_3)
        down_5 = self.down_conv_3(down_4)
        down_6 = self.max_pool3d(down_5)
        down_7 = self.down_conv_4(down_6)
        down_8 = self.max_pool3d(down_7)
        down_9 = self.down_conv_5(down_8)

        up_1 = self.up_transpose_1(down_9)
        up_1, down_7 = pad_to_match_shape(up_1, down_7)
        x = self.up_conv_1(torch.cat([down_7, up_1], dim=1))
        
        up_2 = self.up_transpose_2(x)
        up_2, down_5 = pad_to_match_shape(up_2, down_5)
        x = self.up_conv_2(torch.cat([down_5, up_2], dim=1))
        
        up_3 = self.up_transpose_3(x)
        up_3, down_3 = pad_to_match_shape(up_3, down_3)
        x = self.up_conv_3(torch.cat([down_3, up_3], dim=1))
        
        up_4 = self.up_transpose_4(x)
        up_4, down_1 = pad_to_match_shape(up_4, down_1)
        x = self.up_conv_4(torch.cat([down_1, up_4], dim=1))

        
        out = self.out(x)
        prob_out = self.prob_out(x)
        
        return out, prob_out

In [12]:
# Obtain the model summary
from torchinfo import summary

# Assuming your model is already defined (UNet or similar)
model = UNet(num_classes=1, in_channels=3)

# Example input size: (Batch Size, Channels, Height, Width, Depth)
input_tensor = torch.randn(1, 3, 182, 218, 182)  # Modify based on your use case

# Use summary from torchinfo to display the model summary
summary(model, input_data=input_tensor)

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 1, 191, 223, 191]     --
├─Sequential: 1-1                        [1, 64, 182, 218, 182]    --
│    └─Conv3d: 2-1                       [1, 64, 182, 218, 182]    5,248
│    └─GELU: 2-2                         [1, 64, 182, 218, 182]    --
│    └─Conv3d: 2-3                       [1, 64, 182, 218, 182]    110,656
│    └─GELU: 2-4                         [1, 64, 182, 218, 182]    --
├─MaxPool3d: 1-2                         [1, 64, 91, 109, 91]      --
├─Sequential: 1-3                        [1, 128, 91, 109, 91]     --
│    └─Conv3d: 2-5                       [1, 128, 91, 109, 91]     221,312
│    └─GELU: 2-6                         [1, 128, 91, 109, 91]     --
│    └─Conv3d: 2-7                       [1, 128, 91, 109, 91]     442,496
│    └─GELU: 2-8                         [1, 128, 91, 109, 91]     --
├─MaxPool3d: 1-4                         [1, 128, 45, 54, 45]      
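
The `[1, 1, 191, 223, 191]` output shape in the summary can be traced by hand. `MaxPool3d(kernel_size=2, stride=2)` halves each dimension and rounds down, while `ConvTranspose3d(kernel_size=2, stride=2, output_padding=1)` maps a size `n` to `2n + 1`; `pad_to_match_shape` then grows the smaller of the two tensors. The sketch below replays that arithmetic in plain Python (function names are illustrative).

```python
def pool(n):
    # MaxPool3d(kernel_size=2, stride=2): floor division by 2
    return n // 2

def up(n):
    # ConvTranspose3d(kernel_size=2, stride=2, output_padding=1):
    # (n - 1) * 2 + (2 - 1) + 1 + 1 = 2n + 1
    return 2 * n + 1

def trace(size, depth=4):
    """Follow one spatial dimension down and back up the U-Net."""
    skips = [size]
    for _ in range(depth):
        skips.append(pool(skips[-1]))
    x = skips.pop()                   # bottleneck size
    while skips:
        x = max(up(x), skips.pop())   # pad_to_match_shape keeps the larger size
    return x

print([trace(n) for n in (182, 218, 182)])  # [191, 223, 191]
```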

In [13]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


## The Trainer

In [14]:
# First we implement the DiceLoss
def dice_loss_3d(pred, target, smooth=1):
    """
    Computes Dice Loss for 3D segmentation tasks.
    Args:
    pred: Tensor of raw logits (batch_size, C, D, H, W).
    target: Tensor of binary ground-truth labels (batch_size, 1, D, H, W).
    smooth: Smoothing factor.
    Returns:
    Scalar Dice Loss.
    """
    pred = F.softmax(pred, dim=1)
    num_classes = pred.shape[1]
    # One-hot encode the mask so the target has one channel per class, like pred
    target = (target.squeeze(1) > 0).long()
    target = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
    dice = 0
    for c in range(num_classes):
        pred_c = pred[:, c]
        target_c = target[:, c]
        # pred_c and target_c are 4D (B, D, H, W): reduce over the spatial dims only
        intersection = (pred_c * target_c).sum(dim=(1, 2, 3))
        union = pred_c.sum(dim=(1, 2, 3)) + target_c.sum(dim=(1, 2, 3))
        dice += (2. * intersection + smooth) / (union + smooth)
        
    return 1 - dice.mean() / num_classes

# Then we implement the Mean Dice
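def mean_dice(pred, target, smooth=1):
    """
    Computes the mean Dice score for 3D segmentation tasks.
    Args:
    pred: Tensor of raw logits (batch_size, C, D, H, W).
    target: Tensor of binary ground-truth labels (batch_size, 1, D, H, W).
    smooth: Smoothing factor.
    Returns:
    Scalar mean Dice score, averaged over the batch and the classes.
    """
    pred = F.softmax(pred, dim=1)
    num_classes = pred.shape[1]
    # One-hot encode the mask so the target has one channel per class, like pred
    target = (target.squeeze(1) > 0).long()
    target = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
    dice = 0
    for c in range(num_classes):
        pred_c = pred[:, c]
        target_c = target[:, c]
        # pred_c and target_c are 4D (B, D, H, W): reduce over the spatial dims only
        intersection = (pred_c * target_c).sum(dim=(1, 2, 3))
        union = pred_c.sum(dim=(1, 2, 3)) + target_c.sum(dim=(1, 2, 3))
        dice += (2. * intersection + smooth) / (union + smooth)
        
    return dice.mean() / num_classes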
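
For reference, the Dice coefficient of two binary masks A and B is 2|A ∩ B| / (|A| + |B|), and the smoothing term keeps the ratio defined when both masks are empty. A tiny pure-Python example on flattened masks (values chosen for illustration only):

```python
def dice_score(pred, target, smooth=1.0):
    """Dice coefficient of two equally long sequences of 0/1 values."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2.0 * intersection + smooth) / (total + smooth)

pred   = [1, 1, 0, 0, 1, 0]
target = [1, 0, 0, 0, 1, 1]

print(dice_score(pred, target, smooth=0.0))  # 0.6666666666666666
print(dice_score([0, 0], [0, 0]))            # 1.0
```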

In [15]:
# PyTorch Lightning Trainer for the U-Net model
import torch.optim as optim

class MSLesionSegmentationModel(pl.LightningModule):
    
    def __init__(self, model, checkpoint_dir, lr=1e-4):
        super(MSLesionSegmentationModel, self).__init__()
        self.model = model
        self.lr = lr
        self.checkpoint_dir = checkpoint_dir

    def forward(self, x):
        out, prob_out = self.model(x)
        return out, prob_out 
    
    def training_step(self, batch, batch_idx):
        # Separate the input Tensor from the Segmentation Mask
        input_tensor, gt = batch
        # Forward Pass on the model
        y_pred = self(input_tensor)[0]
        # Compute DiceLoss
        loss = dice_loss_3d(y_pred, gt)

        # log the train loss
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        # Separate the input Tensor from the Segmentation Mask
        input_tensor, gt = batch
        # Forward pass on the model
        y_pred = self(input_tensor)[0]
        # Compute DiceLoss
        loss = dice_loss_3d(y_pred, gt)
        # Compute the Mean Dice Score
        dice_score = mean_dice(y_pred, gt)
        # log the val loss
        self.log("val_loss", loss)
        # log the Mean Dice Score
        self.log("dice_score", dice_score)

        return loss

    def test_step(self, batch, batch_idx):
        # Separate the input Tensor from the Segmentation Mask
        input_tensor, gt = batch
        # Forward pass on the model
        y_pred = self(input_tensor)[0]
        # Compute DiceLoss
        loss = dice_loss_3d(y_pred, gt)

        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]
        
    def on_train_epoch_end(self):
        # Save the model's checkpoint at the end of the epoch
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        save_path = os.path.join(self.checkpoint_dir, "unet_best_model.pth")

        # Retrieve the optimizer and scheduler created in configure_optimizers()
        optimizer = self.optimizers(use_pl_optimizer=False)
        scheduler = self.lr_schedulers()
        
        # Prepare the checkpoint dict
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'learning_rate': self.lr,
        }

        # Save and flush to disk
        with open(save_path, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
            os.fsync(f.fileno())
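
`StepLR(step_size=10, gamma=0.1)` multiplies the learning rate by 0.1 every 10 scheduler steps. The closed form below reproduces that schedule in plain Python, assuming the scheduler is stepped once per epoch (Lightning's default interval); the function name is illustrative.

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Learning rate after `epoch` completed scheduler steps."""
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 9, 10, 25, 99):
    print(epoch, f"{step_lr(1e-4, epoch):.0e}")
# 0 1e-04
# 9 1e-04
# 10 1e-05
# 25 1e-06
# 99 1e-13
```

Over the 100 epochs configured for this run the rate would fall to about 1e-13, which is effectively zero, so a larger `step_size` or `gamma` may be worth considering.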

## The Training

In [16]:
# Register a free account with Weights and Biases, and create a new project in order to obtain an API Key for the training
import os
from dotenv import load_dotenv
import wandb

# Load the .env file
load_dotenv()

# Get the API key from the env variable
api_key = os.getenv("WANDB_API_KEY")

# Login to wandb
if api_key:
    os.environ["WANDB_API_KEY"] = api_key
    wandb.login()
else:
    print("❌ WANDB_API_KEY not found in .env file.")

[34m[1mwandb[0m: Currently logged in as: [33mdrnnrw00m10c351s[0m ([33mfpv-perceivelab-unict[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
# Setup the Weights and Biases logger
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    project='MSLesSeg-4-ICPR',     # Change to your actual project name
    name='UNet_run_1', # A specific run name
    log_model=True          # Optional: log model checkpoints
)

In [18]:
# Set the constant for the maximum number of epochs for the training
MAX_EPOCHS = 100

In [19]:
# Definition of the DataLoader hyperparameters (splits, number of workers and batch sizes)
TRAIN_SPLIT = 0.8
VAL_SPLIT = 0.2

TRAIN_NUM_WORKERS = 0
TEST_NUM_WORKERS = 0

TRAIN_BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1

In [20]:
# Create the Trainer object
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    precision=16,
    logger=wandb_logger
)

/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [21]:
# Prepare the data module
data_module = MSLesSegDataModule(
    root_data_dir="../data/02-Tensor-Data/",
    batch_size=TRAIN_BATCH_SIZE,
    val_split=VAL_SPLIT,
    num_workers=TRAIN_NUM_WORKERS
)

In [22]:
# Instantiate the PyTorch Lightning Module (wrapper for the PyTorch nn.Module ~ Architecture)
lightning_model = MSLesionSegmentationModel(
    model=UNet(in_channels=3, num_classes=2),
    checkpoint_dir="./model_checkpoints"
)

In [23]:
# Check the initial GPU memory consumption of the project
import torch

# Check initial GPU memory usage
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [24]:
# Train (fit) the model
trainer.fit(lightning_model, datamodule=data_module)

You are using a CUDA device ('NVIDIA GeForce RTX 4070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | UNet | 90.3 M | train
---------------------------------------
90.3 M    Trainable params
0         Non-trainable params
90.3 M    Total params
361.185   Total estimated model params size (MB)
55        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


RuntimeError: The size of tensor a (191) must match the size of tensor b (182) at non-singleton dimension 3
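
The sizes in the error match the model summary: the network returns a `191 × 223 × 191` volume while the masks are `182 × 218 × 182`, so the element-wise product in the Dice loss cannot broadcast. One possible remedy, offered here as a sketch and not as the authors' fix, is to center-crop the prediction to the mask's spatial size before computing the loss. The helper below only computes the slices; `center_crop_slices` is a hypothetical name.

```python
def center_crop_slices(src_shape, dst_shape):
    """Return one slice per spatial dimension that center-crops src to dst."""
    slices = []
    for src, dst in zip(src_shape, dst_shape):
        if dst > src:
            raise ValueError(f"cannot crop size {src} to larger size {dst}")
        start = (src - dst) // 2
        slices.append(slice(start, start + dst))
    return tuple(slices)

crop = center_crop_slices((191, 223, 191), (182, 218, 182))
print(crop)
# (slice(4, 186, None), slice(2, 220, None), slice(4, 186, None))
```

With a prediction tensor `y_pred` of shape `[B, C, 191, 223, 191]`, indexing with `y_pred[(..., *crop)]` would keep the batch and channel dimensions and return a `[B, C, 182, 218, 182]` tensor. Another option that may be worth testing is dropping `output_padding=1` from the transposed convolutions, since `pad_to_match_shape` already reconciles odd sizes between the decoder and the skip connections.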

In [None]:
# Load the best model's checkpoint and evaluate it

In [None]:
# Evaluate the model on the test set
trainer.test(lightning_model, datamodule=data_module)

## Visualizing the Learned Representations

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Function to extract features from the model
def extract_features(model, dataloader, device='cuda'):
    model.eval()
    features = []
    labels = []
    
    with torch.no_grad():
        for inputs, target in dataloader:  # MSLesSegDataset yields (input_tensor, seg_mask) tuples
            inputs = inputs.to(device)
            
            # Forward pass through the model; the U-Net returns (logits, probabilities)
            _, output = model(inputs)
            
            # Flatten the output probability map to (batch_size, features).
            # To visualize an intermediate layer instead, capture its activations with a forward hook.
            feature_map = output.reshape(output.size(0), -1)
            features.append(feature_map.cpu().numpy())
            # One scalar per scan (its number of lesion voxels) for color coding in t-SNE
            labels.append(target.flatten(1).sum(dim=1).cpu().numpy())

    features = np.concatenate(features, axis=0)  # Combine all the feature maps
    labels = np.concatenate(labels, axis=0)  # Combine all the labels

    return features, labels

# Function to apply t-SNE on features
def plot_tsne(features, labels):
    # Standardize the features (optional but recommended for t-SNE)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(features)
    
    # Plotting
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap='jet', alpha=0.5)
    plt.colorbar(scatter)
    plt.title('t-SNE Visualization of Learned Representations')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.show()

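# Example Usage: reuse the DataModule and the U-Net wrapped by the LightningModule
data_module.setup()
train_dataloader = data_module.train_dataloader()
model = lightning_model.model.to('cuda')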

# Extract features and labels from the model
features, labels = extract_features(model, train_dataloader, device='cuda')

# Plot t-SNE visualization
plot_tsne(features, labels)

## Post Hoc Model Explainability - GradCAM++

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# NOTE: Captum provides LayerGradCam but, to our knowledge, no LayerGradCamPlusPlus class,
# so the GradCAM++ weights are computed here with plain PyTorch hooks
# (the closed-form weights commonly used in GradCAM++ implementations).

# Function to visualize GradCAM++ output on the middle slice of a 3D volume
def visualize_gradcam_plus_plus(model, input_tensor, target_class=1, layer=None, device='cuda'):
    # Move the model and the input tensor to the appropriate device
    model = model.to(device)
    input_tensor = input_tensor.to(device)

    # Set model to evaluation mode
    model.eval()

    # Default to the last decoder block of the U-Net
    layer = layer if layer is not None else model.up_conv_4

    # Capture the layer's activations and the gradients flowing back into them
    store = {}
    def forward_hook(module, inputs, output):
        store['activations'] = output
        output.register_hook(lambda grad: store.update(gradients=grad))
    handle = layer.register_forward_hook(forward_hook)

    # Forward pass; the U-Net returns (logits, probabilities).
    # The target-class logits are summed into one scalar score and back-propagated.
    logits, _ = model(input_tensor)
    model.zero_grad()
    logits[:, target_class].sum().backward()
    handle.remove()

    # GradCAM++ channel weights from the second and third powers of the gradients
    acts, grads = store['activations'].detach(), store['gradients'].detach()
    spatial = tuple(range(2, acts.dim()))
    denom = 2 * grads ** 2 + acts.sum(dim=spatial, keepdim=True) * grads ** 3
    alpha = grads ** 2 / (denom + 1e-7)
    alpha = torch.where(grads != 0, alpha, torch.zeros_like(alpha))
    weights = (alpha * F.relu(grads)).sum(dim=spatial, keepdim=True)

    # Weighted sum over channels, ReLU to ignore negative values, resize to the input size
    heatmap = F.relu((weights * acts).sum(dim=1, keepdim=True))
    heatmap = F.interpolate(heatmap, size=tuple(input_tensor.shape[2:]), mode='trilinear', align_corners=False)
    heatmap = heatmap[0, 0].cpu().numpy()
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)  # Normalize to [0, 1]

    # Overlay the heatmap on the middle slice of the FLAIR channel (channel 0)
    mid = input_tensor.shape[-1] // 2
    plt.imshow(input_tensor[0, 0, :, :, mid].cpu().numpy(), cmap='gray')
    plt.imshow(heatmap[:, :, mid], cmap='jet', alpha=0.3)
    plt.title('GradCAM++ Heatmap')
    plt.axis('off')
    plt.show()

# Example usage on one preprocessed scan (target_class=1 is assumed to be the lesion class)
input_tensor = torch.load("../data/02-Tensor-Data/train/MSLS_000/input_tensor.pt").unsqueeze(0)
visualize_gradcam_plus_plus(lightning_model.model, input_tensor, target_class=1, device='cuda')

In [None]:
# Check the size of the stored (pre-processed) Tensors
import torch
from pathlib import Path

def inspect_pt_file(file_path):
    file_path = Path(file_path)
    
    if not file_path.exists():
        print(f"File does not exist: {file_path}")
        return
    
    data = torch.load(file_path, map_location='cpu')

    print(f"\nLoaded file: {file_path}")
    
    if isinstance(data, torch.Tensor):
        print(f"Single Tensor - Shape: {data.shape}")
    
    elif isinstance(data, dict):
        print("Dictionary of tensors:")
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: shape = {value.shape}")
            else:
                print(f"  {key}: type = {type(value)}")
    
    elif isinstance(data, (list, tuple)):
        print(f"{type(data).__name__} of tensors:")
        for idx, item in enumerate(data):
            if isinstance(item, torch.Tensor):
                print(f"  [{idx}]: shape = {item.shape}")
            else:
                print(f"  [{idx}]: type = {type(item)}")
    
    else:
        print(f"Unknown type loaded: {type(data)}")
        print(f"Unknown type shape: {getattr(data, 'shape', 'n/a')}")

INPUT_PATH_PREFIX = "../data/02-Tensor-Data/train/MSLS_000/"

In [None]:
# Input Tensor Shape (Stacked Modalities)
inspect_pt_file(INPUT_PATH_PREFIX + "input_tensor.pt")

In [None]:
# Segmentation Mask
inspect_pt_file(INPUT_PATH_PREFIX + "seg_mask.pt")
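
The expected footprint of the two tensors can be estimated from their shapes and dtypes (`float32` inputs and `uint8` masks, as set in `preprocess_case`). The figures below ignore the small serialization overhead added by `torch.save`; the function name is illustrative.

```python
from math import prod

def tensor_mib(shape, bytes_per_element):
    """Size of a dense tensor in MiB."""
    return prod(shape) * bytes_per_element / 1024 ** 2

input_mib = tensor_mib((3, 182, 218, 182), 4)  # float32: 4 bytes per element
mask_mib = tensor_mib((1, 182, 218, 182), 1)   # uint8: 1 byte per element

print(f"input: {input_mib:.1f} MiB, mask: {mask_mib:.1f} MiB")
# input: 82.6 MiB, mask: 6.9 MiB
```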