# Multiple Sclerosis Lesion Segmentation Dataset Overview

---



- Introduction to the challenge {evaluation metrics, comparrison to other datasets} (with link to website [https://iplab.dmi.unict.it/mfs/ms-les-seg/] and paper)
- Introduction to the Dataset (creation, pre-processing, contents and distribution data with visualizations)

Table of contents

---

## What have winning strategies been so far?

1) Ensamble methods
2) Different Registration Spaces
3) Selecting only the specific data from the image that can better find the segmentation mask
4) Using only ONE modality (FLAIR) and performing Data Augmentations on it + Deep CNN Architecture
5) nnU-Net for automatic configuration to the dataset (Interesting)
6) Usage of Mamba-based LightM-UNet
7) Focal Loss + DICE Loss

---

## Personal Goals

1) Parameter Efficiency
2) (XAI) Explainable Architecture
3) Segmentation Accuracy (Mean DICE >= 60)
4) [Optional] --> Single Model (No Ensambles)

---

## Personal Model Proposition

1) Use ALL modalities
2) Preprocess with (Distance Based Labelling, Multi-Sized Labelling ~ Could aid the MoE variant)
3) nnU-Net + MoE (+ Deep Supervision ?)

### Optional

- Bring the model visualization to Extended Reality Domain ?

# The Experiments

---

In [1]:
import os
import torch
import nibabel as nib
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.transform import resize
from pathlib import Path
from tqdm.notebook import tqdm

## Dataset Preprocessing

- Expand more on the performed preprocessing



In [2]:
def load_nifti(file_path):
    return nib.load(file_path).get_fdata()

def normalize_intensity(img):
    # Normalize each modality independently (z-score)
    img = img.astype(np.float32)
    img = (img - np.mean(img)) / (np.std(img) + 1e-5)
    return img

def compute_distance_label(mask):
    mask = mask.astype(bool)
    pos_dist = distance_transform_edt(mask)
    neg_dist = distance_transform_edt(~mask)
    dist_map = pos_dist - neg_dist  # Signed distance transform
    return dist_map

def generate_multi_size_masks(mask, scales=[1.0, 0.5, 0.25]):
    multi_res = []
    for scale in scales:
        if scale == 1.0:
            multi_res.append(mask)
        else:
            resized = resize(mask, output_shape=tuple(int(s * scale) for s in mask.shape),
                             order=1, preserve_range=True, anti_aliasing=True)
            multi_res.append(resized.astype(mask.dtype))
    return multi_res

In [3]:
def preprocess_case(input_dir, output_dir, case_id):
    flair = normalize_intensity(load_nifti(input_dir / f"{case_id}_flair.nii.gz"))
    t1 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t1.nii.gz"))
    t2 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t2.nii.gz"))
    seg = load_nifti(input_dir / f"{case_id}_seg.nii.gz").astype(np.uint8)

    # Stack input modalities into a tensor (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    # input_tensor = torch.tensor(stacked, dtype=torch.float32).cuda()
    # Use CPU Tensors instead
    input_tensor = torch.tensor(stacked, dtype=torch.float32)
    
    # Compute distance map from segmentation mask
    distance_map = compute_distance_label(seg)
    # distance_tensor = torch.tensor(distance_map, dtype=torch.float32).unsqueeze(0).cuda()
    # Use CPU Tensors instead
    distance_tensor = torch.tensor(distance_map, dtype=torch.float32).unsqueeze(0)
    
    # Generate multi-size label masks
    multi_size_masks = generate_multi_size_masks(seg)
    # multi_size_tensors = [torch.tensor(m, dtype=torch.uint8).unsqueeze(0).cuda() for m in multi_size_masks]
    # Use CPU Tensors instead
    multi_size_tensors = [torch.tensor(m, dtype=torch.uint8).unsqueeze(0) for m in multi_size_masks]

    
    # Save all
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")
    torch.save(distance_tensor, output_case_dir / "distance_map.pt")
    for i, mask in enumerate(multi_size_tensors):
        torch.save(mask, output_case_dir / f"multi_size_mask_{i}.pt")

In [4]:
def run_preprocessing(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name  # e.g., MSLS_000
        try:
            preprocess_case(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Failed on {case_id}: {e}")

In [5]:
# Run the preprocessing on the training set
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/train"
OUTPUT_PATH = "../data/02-Tensor-Data/train"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 93 cases.


  0%|          | 0/93 [00:00<?, ?it/s]

In [6]:
# Run the preprocessing on the test set (with gt mask)
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test_MASK"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test_MASK"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


  0%|          | 0/22 [00:00<?, ?it/s]

In [7]:
# Variation for the test set WITHOUT the gt seg mask
def preprocess_case_inputs_only(input_dir, output_dir, case_id):
    # Required input modalities only
    required_files = {
        "flair": input_dir / f"{case_id}_flair.nii.gz",
        "t1": input_dir / f"{case_id}_t1.nii.gz",
        "t2": input_dir / f"{case_id}_t2.nii.gz",
    }

    for key, path in required_files.items():
        if not path.exists():
            raise FileNotFoundError(f"Missing file for '{key}': {path}")

    # Load and normalize each modality
    flair = normalize_intensity(load_nifti(required_files["flair"]))
    t1 = normalize_intensity(load_nifti(required_files["t1"]))
    t2 = normalize_intensity(load_nifti(required_files["t2"]))

    # Stack modalities (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    input_tensor = torch.tensor(stacked, dtype=torch.float32).cuda()

    # Save preprocessed tensor
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")


In [8]:
# Variation for the test set WITHOUT the gt seg mask
def run_preprocessing_inputs_only(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name
        try:
            preprocess_case_inputs_only(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Skipping {case_id}: {e}")


In [9]:
# Run the preprocessing on the test set (no gt mask)
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test"

run_preprocessing_inputs_only(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


  0%|          | 0/22 [00:00<?, ?it/s]

## Build the Dataset and Dataloaders for the MSLesSeg preprocessed data

In [2]:
# Check the initial memory consuption (GPU) of the project
import torch

# Check initial GPU memory usage
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [3]:
from torch.utils.data import Dataset
from pathlib import Path
import torch
import torch.nn.functional as F

class MSLesionDataset(Dataset):
    def __init__(self, root_dir, include_labels=True, sample_ids=None, transform=None):
        self.root_dir = Path(root_dir)
        all_samples = sorted([d for d in self.root_dir.iterdir() if d.is_dir()])
        
        if sample_ids:
            self.sample_dirs = [d for d in all_samples if d.name in sample_ids]
        else:
            self.sample_dirs = all_samples
        
        self.include_labels = include_labels
        self.transform = transform

    def pad_to_match(self, tensor, reference_shape):
        """
        Pads a tensor of shape (C, D, H, W) to match reference_shape.
        """
        current_shape = tensor.shape
        assert len(current_shape) == 4, f"Expected 4D tensor (C, D, H, W), got {current_shape}"
        assert len(reference_shape) == 4, f"Reference shape must be 4D, got {reference_shape}"

        pad_sizes = []
        for i in range(3, 0, -1):  # W, H, D — reverse order
            diff = reference_shape[i] - current_shape[i]
            if diff < 0:
                raise ValueError(f"Tensor dimension {i} is larger than reference: {current_shape[i]} > {reference_shape[i]}")
            pad_sizes.extend([0, diff])  # pad only after

        padded_tensor = F.pad(tensor, pad_sizes, mode='constant', value=0)
        return padded_tensor

    def __len__(self):
        return len(self.sample_dirs)

    def __getitem__(self, idx):
        case_dir = self.sample_dirs[idx]

        # === Load input tensor ===
        input_tensor = torch.load(case_dir / "input_tensor.pt").float()
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.unsqueeze(0)  # (D, H, W) → (1, D, H, W)

        reference_shape = input_tensor.shape
        print(f"Input tensor shape: {reference_shape}")

        if self.transform:
            input_tensor = self.transform(input_tensor)

        if not self.include_labels:
            return input_tensor, str(case_dir.name)

        # === Load and pad distance map ===
        distance_map = torch.load(case_dir / "distance_map.pt").float()
        if distance_map.dim() == 3:
            distance_map = distance_map.unsqueeze(0)

        distance_map = self.pad_to_match(distance_map, reference_shape)

        # === Load and pad multi-size masks ===
        multi_size_masks = []
        for i in range(3):
            path = case_dir / f"multi_size_mask_{i}.pt"
            if path.exists():
                mask = torch.load(path).float()
            else:
                mask = torch.zeros_like(distance_map)

            if mask.dim() == 3:
                mask = mask.unsqueeze(0)

            mask = self.pad_to_match(mask, reference_shape)
            multi_size_masks.append(mask)

        multi_masks_cat = torch.cat(multi_size_masks, dim=0)  # shape: (3, D, H, W)

        return {
            "input": input_tensor,
            "distance": distance_map,
            "multi_masks": multi_masks_cat,
        }


In [4]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [5]:
from torch.utils.data import DataLoader

def get_dataloader(data_root, batch_size=2, num_workers=4, shuffle=True, include_labels=True, sample_ids=None):
    """
    Get the DataLoader for the given dataset directory.

    Args:
        data_root: Path to the dataset root directory.
        batch_size: Number of samples per batch.
        num_workers: Number of workers for loading data in parallel.
        shuffle: Whether to shuffle the data.
        include_labels: Whether to include segmentation masks and distance maps.
        sample_ids: A list of sample ids for a specific split (for k-fold).
    """
    dataset = MSLesionDataset(
        data_root,
        include_labels=include_labels,
        sample_ids=sample_ids
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )

In [6]:
# Definition of some (non training) parameters
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_VAL = 0
NUM_WORKERS_TEST = 0

BATCH_SIZE_TRAIN = 2
BATCH_SIZE_VAL = 2
BATCH_SIZE_TEST = 4

In [7]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [8]:
import os
import random

TRAIN_TENSOR_DATA_PATH = "../data/02-Tensor-Data/train"

# List of all available case IDs in the train directory
all_cases = sorted(os.listdir(TRAIN_TENSOR_DATA_PATH))
random.seed(42)

# For k-fold cross-validation, split train into train and validation (80-20)
split = int(0.8 * len(all_cases))
train_ids = all_cases[:split]
val_ids = all_cases[split:]

train_loader = get_dataloader(TRAIN_TENSOR_DATA_PATH, sample_ids=train_ids, batch_size=BATCH_SIZE_TRAIN, num_workers=NUM_WORKERS_TRAIN)
val_loader = get_dataloader(TRAIN_TENSOR_DATA_PATH, sample_ids=val_ids, batch_size=BATCH_SIZE_VAL, num_workers=NUM_WORKERS_VAL, shuffle=False)

In [9]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [10]:
TEST_TENSOR_DATA_PATH_NOMASK = "../data/02-Tensor-Data/test/test"
TEST_TENSOR_DATA_PATH_MASK = "../data/02-Tensor-Data/test/test_MASK"

# Test set WITHOUT labels (for inference or prediction)
test_loader_no_labels = get_dataloader(TEST_TENSOR_DATA_PATH_NOMASK, include_labels=False, batch_size=BATCH_SIZE_TEST, num_workers=NUM_WORKERS_TEST)

# Test set WITH labels (for final evaluation or performance metrics)
test_loader_with_labels = get_dataloader(TEST_TENSOR_DATA_PATH_MASK, include_labels=True, batch_size=BATCH_SIZE_TEST, num_workers=NUM_WORKERS_TEST)

In [11]:
# In order to perform K-Fold Cross Validation
from sklearn.model_selection import KFold

def get_k_fold_loaders(data_root, k=5, batch_size=BATCH_SIZE_TRAIN, num_workers=NUM_WORKERS_TRAIN):
    # Get all case IDs
    all_cases = sorted(os.listdir(data_root))
    
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    
    fold_loaders = []

    for train_idx, val_idx in kf.split(all_cases):
        train_ids = [all_cases[i] for i in train_idx]
        val_ids = [all_cases[i] for i in val_idx]

        train_loader = get_dataloader(data_root, sample_ids=train_ids, batch_size=batch_size, num_workers=num_workers)
        val_loader = get_dataloader(data_root, sample_ids=val_ids, batch_size=batch_size, num_workers=num_workers, shuffle=False)

        fold_loaders.append((train_loader, val_loader))

    return fold_loaders

In [12]:
fold_loaders = get_k_fold_loaders(TRAIN_TENSOR_DATA_PATH, k=5, batch_size=2, num_workers=4)

# Example: Access the first fold
# train_loader, val_loader = fold_loaders[0]

### PyTorch Lightning DataModule

This particular version of PyTorch Lightning, and in general from version 2.x onward require a ***LightningDataModule*** instead of passing the dataloaders directly to the ***.fit()*** method

In [13]:
from pytorch_lightning import LightningDataModule
import random
import os
from pathlib import Path

class MSLesionSegmentationDataModule(LightningDataModule):
    def __init__(self, data_root, train_split=0.8, batch_size=2, num_workers=4):
        """
        Args:
            data_root: Path to the root directory containing the dataset.
            train_split: Fraction of data to use for training (default 80%).
            batch_size: Batch size for the dataloader.
            num_workers: Number of workers for the dataloader.
        """
        super().__init__()
        self.data_root = data_root
        self.train_split = train_split
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.train_dataloader_instance = None
        self.val_dataloader_instance = None

    def train_dataloader(self):
        return self.train_dataloader_instance

    def val_dataloader(self):
        return self.val_dataloader_instance

    def test_dataloader(self):
        return self.val_dataloader_instance  # Optional, if you want to test on the validation set as well.

    def prepare_data(self):
        """Prepare any datasets if needed, such as downloading or preprocessing."""
        pass

    def setup(self, stage=None):
        """Assign train/val datasets for use in dataloaders."""
        # List all sample directories
        all_samples = sorted([d for d in Path(self.data_root).iterdir() if d.is_dir()])

        # Shuffle the sample list before splitting
        random.seed(42)  # To ensure reproducibility
        random.shuffle(all_samples)

        # Split the dataset into training and validation sets
        split_idx = int(len(all_samples) * self.train_split)
        train_samples = all_samples[:split_idx]
        val_samples = all_samples[split_idx:]

        # Create the train and val dataloaders
        self.train_dataloader_instance = get_dataloader(
            self.data_root,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sample_ids=[d.name for d in train_samples],
        )

        self.val_dataloader_instance = get_dataloader(
            self.data_root,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sample_ids=[d.name for d in val_samples],
        )


In [14]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


## The Model's Architecture

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    """
    Mixture of Experts with dynamic input size support via GAP.
    """
    def __init__(self, input_dim, num_experts=4, expert_dim=512):
        super(MoE, self).__init__()
        self.num_experts = num_experts

        self.gate = nn.Linear(input_dim, num_experts)
        self.experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])

    def forward(self, x):
        # x is expected to be (B, C), already pooled
        gate_weights = F.softmax(self.gate(x), dim=-1)
        expert_outputs = [expert(x) for expert in self.experts]
        output = sum(gate_weights[:, i].unsqueeze(-1) * expert_outputs[i] for i in range(self.num_experts))
        return output

class UNetMoE(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, num_experts=4, expert_dim=512):
        super(UNetMoE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2)
        )

        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, out_channels, kernel_size=1)
        )

        # Will be set dynamically after seeing one input
        self.moe = None
        self.num_experts = num_experts
        self.expert_dim = expert_dim

        self.deep_supervision = nn.Conv3d(128, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder[0:2](x)
        enc2 = self.encoder[2:](enc1)
        mid = self.middle(enc2)
        dec = self.decoder[0:2](mid)
        dec_out = self.decoder[2:](dec)

        # Global pooling before MoE
        pooled = F.adaptive_avg_pool3d(dec, 1).view(x.size(0), -1)

        # Init MoE if not yet initialized
        if self.moe is None:
            self.moe = MoE(input_dim=pooled.size(1), num_experts=self.num_experts, expert_dim=self.expert_dim).to(x.device)

        moe_out = self.moe(pooled)
        deep_supervision_out = self.deep_supervision(mid)

        return dec_out, deep_supervision_out, moe_out


In [16]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [17]:
# Example usage for testing (TO REMOVE)
model = UNetMoE(in_channels=3, out_channels=1)
dummy_input = torch.randn(2, 3, 128, 128, 128)  # e.g., batch size 2, 3 modalities
dec_out, deep_out, moe_out = model(dummy_input)
print(dec_out.shape, deep_out.shape, moe_out.shape)

torch.Size([2, 1, 64, 64, 64]) torch.Size([2, 1, 64, 64, 64]) torch.Size([2, 512])


In [18]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [19]:
# The Summary of the Architecture
from torchsummary import summary

# Set the device and the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetMoE(in_channels=3, out_channels=1, num_experts=4, expert_dim=512).to(device)

# Get the summary of the model (use an example input size, e.g., 128x128x128 for 3D)
summary(model, (3, 128, 128, 128))  # (channels, depth, height, width)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           2,624
              ReLU-2    [-1, 32, 128, 128, 128]               0
            Conv3d-3    [-1, 64, 128, 128, 128]          55,360
              ReLU-4    [-1, 64, 128, 128, 128]               0
         MaxPool3d-5       [-1, 64, 64, 64, 64]               0
            Conv3d-6      [-1, 128, 64, 64, 64]         221,312
              ReLU-7      [-1, 128, 64, 64, 64]               0
            Conv3d-8      [-1, 128, 64, 64, 64]         442,496
              ReLU-9      [-1, 128, 64, 64, 64]               0
           Conv3d-10       [-1, 64, 64, 64, 64]         221,248
             ReLU-11       [-1, 64, 64, 64, 64]               0
           Conv3d-12       [-1, 32, 64, 64, 64]          55,328
             ReLU-13       [-1, 32, 64, 64, 64]               0
           Conv3d-14        [-1, 1, 64,

In [20]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 12.45 MB


In [21]:
# For the graphical visualization
from torchviz import make_dot
import gc

# Dummy input
dummy_input = torch.randn(1, 3, 128, 128, 128).cuda()  # Adjust for batch size and input shape

# Explicitly instruct the model to not save the gratients to avoid computation graph creation
with torch.no_grad():
    # Forward pass through the model
    output, _, _ = model(dummy_input)

    # Visualize the graph
    make_dot(output, params=dict(model.named_parameters())).render("model_architecture", format="png")

# Force garbage collection
gc.collect()

# Clear the non-essentialy GPU memory usage
del dummy_input  # Delete the dummy input tensor
del output       # Optionally, delete the output if you don't need it anymore
torch.cuda.empty_cache()

In [22]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 12.45 MB


In [23]:
# Clear the GPU memory used for the dummy input and re-initialize the model
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetMoE(in_channels=3, out_channels=1, num_experts=4, expert_dim=512).to(device)

In [24]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 11.94 MB


## Visualized Model Architecture

![Model Architecture Plot](./model_architecture.png)

## The Trainers

In [25]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import MetricCollection
from torchmetrics.segmentation import DiceScore

class MSLeasionSegmentationModel(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super(MSLeasionSegmentationModel, self).__init__()
        
        # Initialize the model (nnU-Net + MoE)
        self.model = model
        
        # Learning rate for optimizer
        self.lr = lr
        
        # Set up Dice Score metric for evaluation
        self.dice_metric = DiceScore(num_classes=2)  # Binary segmentation: lesion or background
        
    def forward(self, x):
        # Forward pass (output, deep supervision, MoE output)
        output, deep_supervision, moe_output = self.model(x)
        return output, deep_supervision, moe_output
    
    def training_step(self, batch, batch_idx):
        # Extract input data and target segmentation masks from batch
        x = batch['input_tensor']  # Shape: (B, 3, D, H, W)
        y = batch['segmentation_mask']  # Shape: (B, 1, D, H, W)
        
        # Get model output
        output, _, _ = self(x)
        
        # Compute loss (using binary cross entropy for binary segmentation)
        loss = F.binary_cross_entropy_with_logits(output, y)
        
        # Log the loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        # Extract input data and target segmentation masks from batch
        x = batch['input']  # Change 'input_tensor' to 'input' to match dataset
        y = batch['multi_masks']  # Change 'segmentation_mask' to 'multi_masks' to match dataset
        
        # Get model output
        output, _, _ = self(x)
        
        # Calculate the Dice score
        dice_score = self.dice_metric(output.sigmoid(), y)
        
        # Log validation metrics
        self.log('val_dice', dice_score, on_step=True, on_epoch=True, prog_bar=True)
        
        return dice_score

    
    def test_step(self, batch, batch_idx):
        # Extract input data and target segmentation masks from batch
        x = batch['input_tensor']  # Shape: (B, 3, D, H, W)
        y = batch['segmentation_mask']  # Shape: (B, 1, D, H, W)
        
        # Get model output
        output, _, _ = self(x)
        
        # Calculate the Dice score
        dice_score = self.dice_metric(output.sigmoid(), y)
        
        # Log test metrics
        self.log('test_dice', dice_score, on_step=True, on_epoch=True, prog_bar=True)
        
        return dice_score
    
    def configure_optimizers(self):
        # Set up Adam optimizer with the specified learning rate
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        
        # Optionally, add a scheduler
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_dice'}

## The Training

In [26]:
# Register a free account with Weights and Biases, and create a new project in order to obtain an API Key for the training
import os
from dotenv import load_dotenv
import wandb

# Load the .env file
load_dotenv()

# Get the API key from the env variable
api_key = os.getenv("WANDB_API_KEY")

# Login to wandb
if api_key:
    os.environ["WANDB_API_KEY"] = api_key
    wandb.login()
else:
    print("❌ WANDB_API_KEY not found in .env file.")

[34m[1mwandb[0m: Currently logged in as: [33mdrnnrw00m10c351s[0m ([33mfpv-perceivelab-unict[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
# Setup the Weights and Biases logger
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    project='MSLesSeg-4-ICPR',     # Change to your actual project name
    name='nnUNet_MoE_run1', # A specific run name
    log_model=True          # Optional: log model checkpoints
)

In [28]:
# Set the multiprocessing start to 'spawn' instead of 'fork' due to CUDA issues with the Dataloader
import multiprocessing
import torch

# Set the start method for multiprocessing
multiprocessing.set_start_method('spawn', force=True)

In [29]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 11.94 MB


In [30]:
# Add the profiler to keep track of GPU overhead
from pytorch_lightning.profilers import PyTorchProfiler

# Slightly more sophisticated profiler
profiler = PyTorchProfiler(
    on_trace_ready=lambda prof: print(prof.key_averages().table(
        sort_by="self_cuda_memory_usage", row_limit=15  # Change as needed
    )),
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
)

In [31]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 11.94 MB


In [32]:
# A computation to estimate the GPU memory consumption of the model

# Print the total number of parameters in your model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

# Calculate the approximate memory size of the model (in MB)
model_size_MB = total_params * 4 / (1024 ** 2)  # 4 bytes per parameter for float32
print(f"Estimated model size: {model_size_MB:.2f} MB")

Total number of parameters: 998530
Estimated model size: 3.81 MB


In [33]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [34]:
import torch
import gc

# Run garbage collection to clean up unused objects
gc.collect()

# Empty the PyTorch cache
torch.cuda.empty_cache()

In [35]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import os

# Initialize the model
model = MSLeasionSegmentationModel(model=UNetMoE(in_channels=3, out_channels=1, num_experts=4, expert_dim=512))

# Set up callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',  # Monitor the validation Dice score
    dirpath='checkpoints/',
    filename='best_checkpoint',
    save_top_k=1,
    mode='max',
)

early_stop_callback = EarlyStopping(
    monitor='val_dice',  # Early stopping based on the validation Dice score
    patience=10,
    mode='max',
)

# Specify the data root where the preprocessed data is located
data_root = TRAIN_TENSOR_DATA_PATH  # Replace this with your actual path

# Instantiate the PyTorch Lightning DataModule
data_module = MSLesionSegmentationDataModule(
    data_root=data_root,    # Dataset root directory
    train_split=0.8,         # Use an 80/20 split for train/validation
    batch_size=4,            # Adjust batch size as needed
    num_workers=0,           # Number of workers for data loading
)

# Initialize the Trainer
trainer = Trainer(
    accelerator="gpu",        # or "cpu" or "auto"
    devices=1,                # Number of GPUs to use
    precision=16,             # Sets a given precision to save GPU memory
    max_epochs=100,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=wandb_logger,       # Ensure `wandb_logger` is initialized beforehand
    # profiler=profiler          # Ensures that the model is being profiled for efficient usage
)

# Train the model using the data module
trainer.fit(model, datamodule=data_module)

/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type      | Params | Mode 
--------------------------------------------------
0 | model       | UNetMoE   | 998 K  | train
1 | dice_metric | DiceScore | 0      | train
--------------------------------------------------
998 K     Trainable params
0         Non-trainable params
998 K     Total params
3.994     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Input tensor shape: torch.Size([3, 182, 218, 182])
Input tensor shape: torch.Size([3, 182, 218, 182])
Input tensor shape: torch.Size([3, 182, 218, 182])
Input tensor shape: torch.Size([3, 182, 218, 182])


RuntimeError: Predictions and targets are expected to have the same shape, but got torch.Size([4, 1, 91, 109, 91]) and torch.Size([4, 3, 182, 218, 182]).

In [None]:
# Evaluate the model on the test set
trainer.test(model, test_dataloader=test_loader)

## Visualizing the Learned Representations

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Function to extract features from the model
def extract_features(model, dataloader, device='cuda'):
    model.eval()
    features = []
    labels = []
    
    with torch.no_grad():
        for data in dataloader:
            inputs, target = data['image'].to(device), data['mask'].to(device)  # Replace 'image' and 'mask' keys as per your dataset
            
            # Forward pass through the model
            output = model(inputs)  # This is where we extract the feature; adjust according to your model's architecture
            
            # Extract features from MoE (or any other layer you want to visualize)
            # Here I assume output is the final feature map after the MoE layer
            feature_map = output.view(output.size(0), -1)  # Flatten the features to (batch_size, features)
            features.append(feature_map.cpu().numpy())
            labels.append(target.cpu().numpy())  # Add the target (or ground truth) labels for color coding in t-SNE

    features = np.concatenate(features, axis=0)  # Combine all the feature maps
    labels = np.concatenate(labels, axis=0)  # Combine all the labels

    return features, labels

# Function to apply t-SNE on features
def plot_tsne(features, labels):
    # Standardize the features (optional but recommended for t-SNE)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(features)
    
    # Plotting
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap='jet', alpha=0.5)
    plt.colorbar(scatter)
    plt.title('t-SNE Visualization of Learned Representations')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.show()

# Example Usage
train_dataloader = # Your train DataLoader here

# Extract features and labels from the model
features, labels = extract_features(model, train_dataloader, device='cuda')

# Plot t-SNE visualization
plot_tsne(features, labels)

## Post Hoc Model Explainability - GradCAM++

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from captum.attr import LayerGradCam, LayerGradCamPlusPlus
from captum.attr import GuidedBackprop

# Function to visualize GradCAM++ output
def visualize_gradcam_plus_plus(model, input_tensor, target_class=None, layer_name='decoder', device='cuda'):
    # Move input tensor to the appropriate device
    input_tensor = input_tensor.to(device)

    # Set model to evaluation mode
    model.eval()
    
    # Create GradCAM++ object for the specified layer
    gradcam_pp = LayerGradCamPlusPlus(model, model.decoder[2])  # Assuming 'decoder' is your final layer, adjust accordingly.
    
    # Apply GradCAM++ to input tensor
    attributions = gradcam_pp.attribute(input_tensor, target=target_class)
    
    # Convert attributions to numpy
    attributions = attributions.cpu().detach().numpy()
    
    # Normalize the heatmap
    heatmap = np.sum(attributions[0], axis=0)  # Summing over channels for a single-channel heatmap
    heatmap = np.maximum(heatmap, 0)  # ReLU to ignore negative values
    heatmap = cv2.resize(heatmap, (input_tensor.shape[2], input_tensor.shape[3]))  # Resize to input size
    heatmap = cv2.normalize(heatmap, None, 0, 1, cv2.NORM_MINMAX)  # Normalize the heatmap to [0, 1]
    
    # Convert the original image to numpy and rescale
    input_image = input_tensor[0].cpu().detach().numpy().transpose(1, 2, 0)
    input_image = cv2.resize(input_image, (input_tensor.shape[2], input_tensor.shape[3]))  # Resize to input size
    
    # Create a colormap for the heatmap
    colormap = plt.get_cmap('jet')
    colored_heatmap = colormap(heatmap)  # Apply colormap
    
    # Overlay the heatmap on top of the original image
    superimposed_img = np.uint8(input_image * 255)  # Convert original image back to [0, 255] range
    superimposed_img = cv2.addWeighted(superimposed_img, 0.7, np.uint8(colored_heatmap[:, :, :3] * 255), 0.3, 0)
    
    # Plot the result
    plt.imshow(superimposed_img)
    plt.title('GradCAM++ Heatmap')
    plt.axis('off')
    plt.show()

# Example Usage (assuming `model` is your trained model and `input_tensor` is an input image)
input_tensor = torch.randn(1, 3, 128, 128).to(device)  # Example input, use your actual input tensor
target_class = None  # You can provide a target class or leave as None for model's top predicted class

visualize_gradcam_plus_plus(model, input_tensor, target_class=target_class, layer_name='decoder', device='cuda')

In [None]:
# Check the size of the stored (pre-processed) Tensors
import torch
from pathlib import Path

def inspect_pt_file(file_path):
    file_path = Path(file_path)
    
    if not file_path.exists():
        print(f"File does not exist: {file_path}")
        return
    
    data = torch.load(file_path, map_location='cpu')

    print(f"\nLoaded file: {file_path}")
    
    if isinstance(data, torch.Tensor):
        print(f"Single Tensor - Shape: {data.shape}")
    
    elif isinstance(data, dict):
        print("Dictionary of tensors:")
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: shape = {value.shape}")
            else:
                print(f"  {key}: type = {type(value)}")
    
    elif isinstance(data, (list, tuple)):
        print(f"{type(data).__name__} of tensors:")
        for idx, item in enumerate(data):
            if isinstance(item, torch.Tensor):
                print(f"  [{idx}]: shape = {item.shape}")
            else:
                print(f"  [{idx}]: type = {type(item)}")
    
    else:
        print(f"Unknown type loaded: {type(data)}")

INPUT_PATH_PREFIX = "../data/02-Tensor-Data/train/MSLS_000/"

In [None]:
# Input Tensor Shape (Stacked Modalities)
inspect_pt_file(INPUT_PATH_PREFIX + "input_tensor.pt")

In [None]:
# Distance Map Shape
inspect_pt_file(INPUT_PATH_PREFIX + "distance_map.pt")

In [None]:
# Multi Size Mask - 0
inspect_pt_file(INPUT_PATH_PREFIX + "multi_size_mask_0.pt")

In [None]:
# Multi Size Mask - 1
inspect_pt_file(INPUT_PATH_PREFIX + "multi_size_mask_1.pt")

In [None]:
# Multi Size Mask - 2
inspect_pt_file(INPUT_PATH_PREFIX + "multi_size_mask_2.pt")