# U-Net 4 Multiple Sclerosis Lesion Segmentation - ICPR Challenge
### Authors: Andrew R. Darnall, Giovanni Spadaro @ UniCT
---

## 🎯 Competition Objective: MS Lesion Segmentation

The central goal of this competition is the **automatic segmentation of Multiple Sclerosis (MS) lesions** using **multi-modal MRI data** and **deep learning algorithms**.

### 🧪 Provided Data
Participants were given:
- **MRI scans** in three modalities:
  - **FLAIR**
  - **T1-weighted (T1-w)**
  - **T2-weighted (T2-w)**
- **Ground-truth segmentation masks**, which are:
  - **Binary masks**:  
    - **White pixels** → MS lesion regions  
    - **Black pixels** → Background

### 🧠 Task Description
- Participants could use **any or all modalities**, along with the ground-truth labels, to:
  - Develop **deep learning-based models** for **automatic lesion segmentation**
- MS lesions appear as **irregular clusters of pixels** with **high variability in size and shape**
- These lesions are often **difficult to detect** via visual inspection, requiring **expert-level interpretation**

The ultimate goal is to create **fully automated segmentation pipelines** that can robustly identify and delineate MS lesions from raw MRI data.


## 🧠 MSLesSeg Dataset Overview

As part of this competition, participants were provided with the **MSLesSeg Dataset** — a **comprehensively annotated, multi-modal MRI dataset** designed for advancing **lesion segmentation** research in medical imaging.

### 📊 Dataset Composition
- **Total Patients:** 75 (48 women, 27 men)  
- **Age Range:** 18–59 years (Mean: 37 ± 10.3 years)  
- **Longitudinal Timepoints:**  
  - 50 patients with 1 timepoint  
  - 15 patients with 2 timepoints  
  - 5 patients with 3 timepoints  
  - 5 patients with 4 timepoints  
- **Time Interval Between Scans:** ~1.27 ± 0.62 years  
- **Total MRI Series:** 115

### 🧬 Imaging Modalities
Each timepoint includes **three core MRI modalities**:
- **T1-weighted (T1-w)**
- **T2-weighted (T2-w)**
- **FLAIR (Fluid-Attenuated Inversion Recovery)**

### 🧑‍⚕️ Expert Annotation
- Lesions were **manually annotated** by clinical experts.
- **FLAIR sequences** were the primary reference for lesion labeling.
- **T1-w and T2-w** scans supported **multi-contrast lesion characterization**.

### 🧪 Dataset Splits
- **Training Set:** 53 scans  
- **Test Set:** 22 scans  

### ✅ Ethical Compliance
- **Ethical approval** was obtained from the corresponding Hospital Ethics Committee.
- **Informed consent** was acquired from all participating patients.

---

# The Experiment

Below is the code used for the:

1) Preprocessing of the ***Brain MRI*** scans
2) Definition of Dataset, Dataloader and LihgtningDataModule classes
3) ***U-Net*** architecture
4) ***PyTorch Lightning*** Trainer
5) Training & Evaluation
6) Model Exaplainability with the post-hoc method ***GradCam++***

---

In [1]:
import os
import torch
import nibabel as nib
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.transform import resize
from pathlib import Path
from tqdm.notebook import tqdm

## 🛠️ Preprocessing & Annotation Workflow

The MSLesSeg dataset underwent a **comprehensive preprocessing pipeline** and **expert-driven manual annotation** to ensure **standardization** and **label quality** for downstream MS lesion segmentation tasks.

### 🧼 Preprocessing Pipeline
1. **Anonymization** of all MRI scans to protect patient privacy.
2. **DICOM to NIfTI conversion**, leveraging NIfTI's wide adoption in neuroimaging.
3. **Co-registration to the MNI152 1mm³ isotropic template** using **FLIRT** (FMRIB’s Linear Image Registration Tool), ensuring all scans are aligned to a **common anatomical space**.
4. **Brain extraction** via **BET** (Brain Extraction Tool) to remove non-brain tissues and isolate relevant structures.

This pipeline guarantees that all images are **standardized** and **aligned**, which is critical for **automated MS lesion segmentation algorithms**.

---

### 🖋️ Ground-Truth Annotation Protocol
- Lesions were **manually segmented** on the **FLAIR modality** for each patient and timepoint.
- **T1-w and T2-w** modalities were used to **cross-validate ambiguous cases**.
- Annotation was conducted by a **trained junior rater**, under supervision of:
  - A **senior neuroradiologist**
  - A **senior neurologist**
- Annotation sessions included:
  - Multiple **training meetings** to establish a **consistent segmentation strategy**
  - Use of **JIM9** — a high-end tool for **medical image segmentation and analysis**
  - Regular **expert validation checkpoints** to ensure consistency and accuracy

The final masks, reviewed and approved by senior experts, are considered the **gold-standard ground truth**.

---

## 🧾 Key Annotation Highlights
- **Independent segmentation** for each patient/timepoint to avoid bias
- Conducted on **FLAIR scans registered to MNI space**
- **Validated ground-truth masks** ready for training and evaluation



In [2]:
def load_nifti(file_path):
    return nib.load(file_path).get_fdata()

def normalize_intensity(img):
    # Normalize each modality independently (z-score)
    img = img.astype(np.float32)
    img = (img - np.mean(img)) / (np.std(img) + 1e-5)
    return img

def compute_distance_label(mask):
    mask = mask.astype(bool)
    pos_dist = distance_transform_edt(mask)
    neg_dist = distance_transform_edt(~mask)
    dist_map = pos_dist - neg_dist  # Signed distance transform
    return dist_map

def generate_multi_size_masks(mask, scales=[1.0, 0.5, 0.25]):
    multi_res = []
    for scale in scales:
        if scale == 1.0:
            multi_res.append(mask)
        else:
            resized = resize(mask, output_shape=tuple(int(s * scale) for s in mask.shape),
                             order=1, preserve_range=True, anti_aliasing=True)
            multi_res.append(resized.astype(mask.dtype))
    return multi_res

In [3]:
def preprocess_case(input_dir, output_dir, case_id):
    flair = normalize_intensity(load_nifti(input_dir / f"{case_id}_flair.nii.gz"))
    t1 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t1.nii.gz"))
    t2 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t2.nii.gz"))
    seg = load_nifti(input_dir / f"{case_id}_seg.nii.gz").astype(np.uint8)

    # Stack input modalities into a tensor (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    # input_tensor = torch.tensor(stacked, dtype=torch.float32).cuda()
    # Use CPU Tensors instead
    input_tensor = torch.tensor(stacked, dtype=torch.float32)
    
    # Load it back into a PyTorch Tensor for storing
    seg_tensor = torch.tensor(seg, dtype=torch.uint8)
    # Add the batch dimension in order to make it compatible with the other Tensor sizes
    seg_tensor = seg_tensor.unsqueeze(0)

    
    # Save all
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")
    torch.save(seg_tensor, output_case_dir / "seg_mask.pt")

In [4]:
from pathlib import Path
from tqdm import tqdm

def run_preprocessing(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    # Get all case directories
    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name  # e.g., MSLS_000
        output_case_path = output_path / case_id  # Define output path for each case

        # Check if the case has already been processed (e.g., output directory or file exists)
        if output_case_path.exists():
            print(f"✅ Skipping {case_id}, already processed.")
            continue  # Skip processing if the case already exists

        try:
            preprocess_case(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Failed on {case_id}: {e}")

In [5]:
# Run the preprocessing on the training set
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/train"
OUTPUT_PATH = "../data/02-Tensor-Data/train"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 93 cases.


100%|█████████████████████████████████████████| 93/93 [00:00<00:00, 7383.50it/s]

✅ Skipping MSLS_019, already processed.
✅ Skipping MSLS_039, already processed.
✅ Skipping MSLS_059, already processed.
✅ Skipping MSLS_000, already processed.
✅ Skipping MSLS_001, already processed.
✅ Skipping MSLS_002, already processed.
✅ Skipping MSLS_003, already processed.
✅ Skipping MSLS_004, already processed.
✅ Skipping MSLS_005, already processed.
✅ Skipping MSLS_006, already processed.
✅ Skipping MSLS_007, already processed.
✅ Skipping MSLS_008, already processed.
✅ Skipping MSLS_009, already processed.
✅ Skipping MSLS_010, already processed.
✅ Skipping MSLS_011, already processed.
✅ Skipping MSLS_012, already processed.
✅ Skipping MSLS_013, already processed.
✅ Skipping MSLS_014, already processed.
✅ Skipping MSLS_015, already processed.
✅ Skipping MSLS_016, already processed.
✅ Skipping MSLS_017, already processed.
✅ Skipping MSLS_018, already processed.
✅ Skipping MSLS_020, already processed.
✅ Skipping MSLS_021, already processed.
✅ Skipping MSLS_022, already processed.





In [6]:
# Run the preprocessing on the test set (with gt mask)
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test_MASK"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test_MASK"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


100%|█████████████████████████████████████████| 22/22 [00:00<00:00, 7596.50it/s]

✅ Skipping MSLS_093, already processed.
✅ Skipping MSLS_094, already processed.
✅ Skipping MSLS_095, already processed.
✅ Skipping MSLS_096, already processed.
✅ Skipping MSLS_097, already processed.
✅ Skipping MSLS_098, already processed.
✅ Skipping MSLS_099, already processed.
✅ Skipping MSLS_100, already processed.
✅ Skipping MSLS_101, already processed.
✅ Skipping MSLS_102, already processed.
✅ Skipping MSLS_103, already processed.
✅ Skipping MSLS_104, already processed.
✅ Skipping MSLS_105, already processed.
✅ Skipping MSLS_106, already processed.
✅ Skipping MSLS_107, already processed.
✅ Skipping MSLS_108, already processed.
✅ Skipping MSLS_109, already processed.
✅ Skipping MSLS_110, already processed.
✅ Skipping MSLS_111, already processed.
✅ Skipping MSLS_112, already processed.
✅ Skipping MSLS_113, already processed.
✅ Skipping MSLS_114, already processed.





In [7]:
# Variation for the test set WITHOUT the gt seg mask
def preprocess_case_no_seg_mask(input_dir, output_dir, case_id):
    flair = normalize_intensity(load_nifti(input_dir / f"{case_id}_flair.nii.gz"))
    t1 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t1.nii.gz"))
    t2 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t2.nii.gz"))

    # Stack input modalities into a tensor (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    # input_tensor = torch.tensor(stacked, dtype=torch.float32).cuda()
    # Use CPU Tensors instead
    input_tensor = torch.tensor(stacked, dtype=torch.float32)
    
    # Cannot compute the distance maps and multi scale masks due to lack of segmentation mask
    
    # Save all
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")

In [8]:
from pathlib import Path
from tqdm import tqdm

# Variation for the test set WITHOUT the gt seg mask
def run_preprocess_no_seg_mask(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    # Get all case directories
    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name
        output_case_path = output_path / case_id  # Define output path for each case

        # Check if the case has already been processed (e.g., output directory or file exists)
        if output_case_path.exists():
            print(f"✅ Skipping {case_id}, already processed.")
            continue  # Skip processing if the case already exists

        try:
            preprocess_case_no_seg_mask(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Skipping {case_id}: {e}")

In [9]:
# Run the preprocessing on the test set (no gt mask)
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test"

run_preprocess_no_seg_mask(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


100%|█████████████████████████████████████████| 22/22 [00:00<00:00, 7843.15it/s]

✅ Skipping MSLS_093, already processed.
✅ Skipping MSLS_094, already processed.
✅ Skipping MSLS_095, already processed.
✅ Skipping MSLS_096, already processed.
✅ Skipping MSLS_097, already processed.
✅ Skipping MSLS_098, already processed.
✅ Skipping MSLS_099, already processed.
✅ Skipping MSLS_100, already processed.
✅ Skipping MSLS_101, already processed.
✅ Skipping MSLS_102, already processed.
✅ Skipping MSLS_103, already processed.
✅ Skipping MSLS_104, already processed.
✅ Skipping MSLS_105, already processed.
✅ Skipping MSLS_106, already processed.
✅ Skipping MSLS_107, already processed.
✅ Skipping MSLS_108, already processed.
✅ Skipping MSLS_109, already processed.
✅ Skipping MSLS_110, already processed.
✅ Skipping MSLS_111, already processed.
✅ Skipping MSLS_112, already processed.
✅ Skipping MSLS_113, already processed.
✅ Skipping MSLS_114, already processed.





## Build the Dataset and Dataloaders for the MSLesSeg preprocessed data

In [10]:
# Check the initial memory consuption (GPU) of the project
import torch

# Check initial GPU memory usage
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [11]:
from torch.utils.data import Dataset
from pathlib import Path
import torch
import torch.nn.functional as F

class MSLesionDataset(Dataset):
    def __init__(self, root_dir, include_labels=True, sample_ids=None, transform=None):
        self.root_dir = Path(root_dir)
        all_samples = sorted([d for d in self.root_dir.iterdir() if d.is_dir()])
        self.sample_dirs = [d for d in all_samples if not sample_ids or d.name in sample_ids]
        self.include_labels = include_labels
        self.transform = transform

    def pad_to_match(self, tensor, reference_shape):
        """Pads a tensor to match a 4D (C, D, H, W) reference shape."""
        current_shape = tensor.shape
        pad_sizes = []
        for i in range(3, 0, -1):  # W, H, D
            diff = reference_shape[i] - current_shape[i]
            if diff < 0:
                raise ValueError(f"Tensor dimension {i} is larger than reference: {current_shape[i]} > {reference_shape[i]}")
            pad_sizes.extend([0, diff])
        return F.pad(tensor, pad_sizes, mode='constant', value=0)

    def __len__(self):
        return len(self.sample_dirs)

    def __getitem__(self, idx):
        try:
            case_dir = self.sample_dirs[idx]
    
            # === Load FLAIR, T1, T2 ===
            input_tensor = torch.load(case_dir / "input_tensor.pt").float()  # shape: [3, D, H, W]
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.unsqueeze(0)
            reference_shape = input_tensor.shape  # (C, D, H, W)
    
            # === Load distance map ===
            distance_map = torch.load(case_dir / "distance_map.pt").float()
            if distance_map.dim() == 3:
                distance_map = distance_map.unsqueeze(0)
            distance_map = self.pad_to_match(distance_map, reference_shape)
    
            # === Load multi-scale masks ===
            multi_size_masks = []
            for i in range(3):
                path = case_dir / f"multi_size_mask_{i}.pt"
                mask = torch.load(path).float() if path.exists() else torch.zeros_like(distance_map)
                if mask.dim() == 3:
                    mask = mask.unsqueeze(0)
                multi_size_masks.append(self.pad_to_match(mask, reference_shape))
            multi_masks_cat = torch.cat(multi_size_masks, dim=0)  # [3, D, H, W]
    
            # === Final input: [7, D, H, W] ===
            final_input = torch.cat([input_tensor, distance_map, multi_masks_cat], dim=0)
    
            if self.transform:
                final_input = self.transform(final_input)
    
            # === Optional target: segmentation mask ===
            if self.include_labels:
                seg_mask_path = case_dir / "seg_mask.pt"
                seg_mask = torch.load(seg_mask_path).float()
                if seg_mask.dim() == 3:
                    seg_mask = seg_mask.unsqueeze(0)
                seg_mask = self.pad_to_match(seg_mask, reference_shape)
                return final_input, seg_mask
    
            return final_input, str(case_dir.name)
        except Exception as e:
            print(f"Dataset raised an exception:\t{e}")

In [12]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [13]:
from torch.utils.data import DataLoader

def get_dataloader(data_root, batch_size=2, num_workers=4, shuffle=True, include_labels=True, sample_ids=None, transform=None):
    """
    Returns a PyTorch DataLoader for the MSLesionDataset.

    Args:
        data_root (str or Path): Root path containing sample subdirectories.
        batch_size (int): Batch size.
        num_workers (int): Number of worker threads.
        shuffle (bool): Whether to shuffle the dataset.
        include_labels (bool): If True, return segmentation masks (for training).
        sample_ids (list of str): Optional subset of sample names.
        transform (callable): Optional transform to apply to the input tensor.
    """
    dataset = MSLesionDataset(
        root_dir=data_root,
        include_labels=include_labels,
        sample_ids=sample_ids,
        transform=transform
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )

In [14]:
# Definition of some (non training) parameters
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_VAL = 0
NUM_WORKERS_TEST = 0

BATCH_SIZE_TRAIN = 2
BATCH_SIZE_VAL = 2
BATCH_SIZE_TEST = 4

In [15]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [16]:
import os
import random
from pathlib import Path

# === Config ===
TRAIN_TENSOR_DATA_PATH = Path("../data/02-Tensor-Data/train")
BATCH_SIZE_TRAIN = 2
BATCH_SIZE_VAL = 2
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_VAL = 0

# === Get all valid sample folders ===
all_cases = sorted([d for d in os.listdir(TRAIN_TENSOR_DATA_PATH) if os.path.isdir(TRAIN_TENSOR_DATA_PATH / d)])

# === Split for 80-20 train/val ===
random.seed(42)
random.shuffle(all_cases)
split = int(0.8 * len(all_cases))
train_ids = all_cases[:split]
val_ids = all_cases[split:]

# === Loaders ===
train_loader = get_dataloader(
    data_root=TRAIN_TENSOR_DATA_PATH,
    sample_ids=train_ids,
    batch_size=BATCH_SIZE_TRAIN,
    num_workers=NUM_WORKERS_TRAIN,
    shuffle=True,
    include_labels=True
)

val_loader = get_dataloader(
    data_root=TRAIN_TENSOR_DATA_PATH,
    sample_ids=val_ids,
    batch_size=BATCH_SIZE_VAL,
    num_workers=NUM_WORKERS_VAL,
    shuffle=False,
    include_labels=True
)


In [17]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [18]:
TEST_TENSOR_DATA_PATH_NOMASK = Path("../data/02-Tensor-Data/test/test")
TEST_TENSOR_DATA_PATH_MASK = Path("../data/02-Tensor-Data/test/test_MASK")

# Test set WITHOUT labels (for inference or prediction)
test_loader_no_labels = get_dataloader(
    data_root=TEST_TENSOR_DATA_PATH_NOMASK,
    include_labels=False,
    batch_size=BATCH_SIZE_TEST,
    num_workers=NUM_WORKERS_TEST,
    shuffle=False  # No shuffling for inference
)

# Test set WITH labels (for final evaluation or performance metrics)
test_loader_with_labels = get_dataloader(
    data_root=TEST_TENSOR_DATA_PATH_MASK,
    include_labels=True,
    batch_size=BATCH_SIZE_TEST,
    num_workers=NUM_WORKERS_TEST,
    shuffle=False  # No shuffling for evaluation
)


In [19]:
from sklearn.model_selection import KFold
import os

def get_k_fold_loaders(data_root, k=5, batch_size=BATCH_SIZE_TRAIN, num_workers=NUM_WORKERS_TRAIN, include_labels=True):
    # Get all case IDs
    all_cases = sorted(os.listdir(data_root))
    
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    
    fold_loaders = []

    for train_idx, val_idx in kf.split(all_cases):
        train_ids = [all_cases[i] for i in train_idx]
        val_ids = [all_cases[i] for i in val_idx]

        # Updated to pass `include_labels`
        train_loader = get_dataloader(data_root, sample_ids=train_ids, batch_size=batch_size, num_workers=num_workers, include_labels=include_labels)
        val_loader = get_dataloader(data_root, sample_ids=val_ids, batch_size=batch_size, num_workers=num_workers, include_labels=include_labels, shuffle=False)

        fold_loaders.append((train_loader, val_loader))

    return fold_loaders

In [20]:
# Get K-Fold DataLoaders
fold_loaders = get_k_fold_loaders(TRAIN_TENSOR_DATA_PATH, k=5, batch_size=2, num_workers=0)

# Access the first fold
train_loader, val_loader = fold_loaders[0]

### PyTorch Lightning DataModule

This particular version of PyTorch Lightning, and in general from version 2.x onward require a ***LightningDataModule*** instead of passing the dataloaders directly to the ***.fit()*** method

In [21]:
from pytorch_lightning import LightningDataModule
import random
import os
from pathlib import Path

class MSLesionSegmentationDataModule(LightningDataModule):
    def __init__(self, data_root, train_split=0.8, batch_size=2, num_workers=4, include_labels=True):
        """
        Args:
            data_root: Path to the root directory containing the dataset.
            train_split: Fraction of data to use for training (default 80%).
            batch_size: Batch size for the dataloader.
            num_workers: Number of workers for the dataloader.
            include_labels: Whether to include segmentation masks during training/validation.
        """
        super().__init__()
        self.data_root = data_root
        self.train_split = train_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.include_labels = include_labels

        self.train_dataloader_instance = None
        self.val_dataloader_instance = None

    def train_dataloader(self):
        return self.train_dataloader_instance

    def val_dataloader(self):
        return self.val_dataloader_instance

    def test_dataloader(self):
        return self.val_dataloader_instance  # Optional, if you want to test on the validation set as well.

    def prepare_data(self):
        """Prepare any datasets if needed, such as downloading or preprocessing."""
        pass

    def setup(self, stage=None):
        """Assign train/val datasets for use in dataloaders."""
        # List all sample directories
        all_samples = sorted([d for d in Path(self.data_root).iterdir() if d.is_dir()])

        # Shuffle the sample list before splitting
        random.seed(42)  # To ensure reproducibility
        random.shuffle(all_samples)

        # Split the dataset into training and validation sets
        split_idx = int(len(all_samples) * self.train_split)
        train_samples = all_samples[:split_idx]
        val_samples = all_samples[split_idx:]

        # Create the train and val dataloaders, passing the `include_labels` parameter
        self.train_dataloader_instance = get_dataloader(
            self.data_root,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sample_ids=[d.name for d in train_samples],
            include_labels=self.include_labels,  # Pass this argument
        )

        self.val_dataloader_instance = get_dataloader(
            self.data_root,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sample_ids=[d.name for d in val_samples],
            include_labels=self.include_labels,  # Pass this argument
        )

In [22]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


## The Model's Architecture

In [3]:
import torch
import torch.nn as nn

# Standard double convolution block
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )

# Final output layer: Conv3D + Sigmoid for probability map
def final_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=1),
        nn.Sigmoid()
    )

In [7]:
# 3D U-Net Architecture
class UNet(nn.Module):
    def __init__(self, num_classes, in_channels):
        super(UNet, self).__init__()

        self.max_pool3d = nn.MaxPool3d(kernel_size=2, stride=2)

        # == Contracting Path == #
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = dobule_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # == Expanding Path == #
        self.up_transpose_1 = nn.ConvTranspose3d(
            in_channels=1024,
            out_channels=512,
            kernel_size=2,
        )
        self.up_conv_1 = double_conv(1024, 512)
        
        self.up_transpose_2 = nn.ConvTranspose3d(
            in_channels=512,
            out_channels=256,
            kernel_size=2,
        )
        self.up_conv_2 = double_conv(512, 256)

        self.up_transpose_3 = nn.ConvTranspose3d(
            in_channels=256,
            out_channels=128,
            kernel_size=2,
        )
        self.up_conv_3 = double_conv(256, 128)

        self.up_transpose_4 = nn.ConvTranspose3d(
            in_channels=128,
            out_channels=64,
            kernel_size=2,
        )
        self.up_conv_4 = double_conv(128, 64)

        self.prob_out = final_conv(64, num_classes)

        self.out = nn.Conv3d(
            in_channels=64,
            out_channels=num_classes,
            kernel_size=1
        )

    def forward(self, x):
        
        down_1 = self.down_conv_1(x)
        down_2 = self.max_pool3d(down_1)
        down_3 = self.down_conv_2(down_2)
        down_4 = self.max_pool3d(down_3)
        down_5 = self.down_conv_3(down_4)
        down_6 = self.max_pool3d(down_5)
        down_7 = self.down_conv_4(down_6)
        down_8 = self.max_pool3d(down_7)
        down_9 = self.down_conv_5(down_8)

        up_1 = self.up_transpose_1(down_9)
        x = self.up_conv_1(torch.cat([down_7, up_1], 1))
        
        up_2 = self.up_transpose_2(x)
        x = self.up_conv_2(torch.cat([down_5, up_2], 1))
        
        up_3 = self.up_transpose_3(x)
        x = self.up_conv_3(torch.cat([down_3, up_3], 1))
        
        up_4 = self.up_transpose_4(x)
        x = self.up_conv_4(torch.cat([down_1, up_4], 1))
        
        out = self.out(x)
        prob_out = self.prob_out(x)
        
        return out, prob_out

NameError: name 'nn' is not defined

In [8]:
# Perform a sanity check on the model
input_image = torch.rand((512, 512, 512))
model = UNet(num_classes=2m in_channels=1)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
outputs, prob_outputs = model(input_image)
print(f"U-Net outputs shape:\t{outputs.shape}")
print(f"U-Net probability map outputs shape:\t{prob_outputs.shape}")

SyntaxError: invalid decimal literal (825341197.py, line 3)

In [24]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [27]:
# The Summary of the Architecture
from torchsummary import summary

# Set the device and the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1).to(device)

# Get the summary of the model (for 3D input: channels, depth, height, width)
summary(model, (128, 128, 128))  # 7 input channels now

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           6,080
              ReLU-2    [-1, 32, 128, 128, 128]               0
            Conv3d-3    [-1, 64, 128, 128, 128]          55,360
              ReLU-4    [-1, 64, 128, 128, 128]               0
         MaxPool3d-5       [-1, 64, 64, 64, 64]               0
            Conv3d-6      [-1, 128, 64, 64, 64]         221,312
              ReLU-7      [-1, 128, 64, 64, 64]               0
            Conv3d-8      [-1, 128, 64, 64, 64]         442,496
              ReLU-9      [-1, 128, 64, 64, 64]               0
  ConvTranspose3d-10    [-1, 64, 128, 128, 128]          65,600
           Conv3d-11    [-1, 64, 128, 128, 128]         110,656
             ReLU-12    [-1, 64, 128, 128, 128]               0
           Conv3d-13    [-1, 32, 128, 128, 128]          55,328
             ReLU-14    [-1, 32, 128, 1

In [28]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 12.29 MB


In [29]:
# For the graphical visualization
from torchviz import make_dot
import gc

# Dummy input with correct number of channels (7), batch size = 1
dummy_input = torch.randn(128, 128, 128).cuda()

# Explicitly instruct the model to not save the gradients to avoid computation graph creation
with torch.no_grad():
    # Forward pass through the model
    output, _ = model(dummy_input)

    # Visualize the graph and save as PNG
    make_dot(output, params=dict(model.named_parameters())).render("model_architecture", format="png")

# Force garbage collection
gc.collect()

# Clear non-essential GPU memory
del dummy_input
del output
torch.cuda.empty_cache()

In [30]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 12.29 MB


In [31]:
# Clear the GPU memory used for the dummy input and re-initialize the model
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(num_classes=2, in_channels=1).to(device)

In [32]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 11.78 MB


## Visualized Model Architecture

![Model Architecture Plot](./model_architecture.png)

## The Trainers

In [44]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics.segmentation import DiceScore
from torch.nn import BCELoss

class MSLesionSegmentationModel(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()

        self.model = model
        self.lr = lr

        # DiceScore for binary segmentation
        self.dice_metric = DiceScore(num_classes=2, average='micro')

        # Binary Cross-Entropy loss
        self.bce_loss = BCELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch  # x: [B, D, H, W], y: [B, D, H, W]

        # Forward pass: get both voxel prediction and probability map
        output, prob_map = self(x)

        # Compute Dice Loss for segmentation (voxel)
        dice_loss = 1 - self.dice_metric(torch.sigmoid(output), y.int())

        # Compute BCELoss for the probability map
        bce_loss = self.bce_loss(torch.sigmoid(prob_map), y.float())  # sigmoid for probability map

        # Combine losses (you can adjust the weighting based on your need)
        loss = dice_loss + bce_loss

        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch  # x: [B, D, H, W], y: [B, D, H, W]
        
        # Forward pass: get both voxel prediction and probability map
        output, prob_map = self(x)

        # Compute Dice Loss for segmentation (voxel)
        dice_loss = 1 - self.dice_metric(torch.sigmoid(output), y.int())

        # Compute BCELoss for the probability map
        bce_loss = self.bce_loss(torch.sigmoid(prob_map), y.float())

        # Total validation loss
        val_loss = dice_loss + bce_loss

        # Calculate Dice score for the voxel output (segmentation mask)
        preds = torch.sigmoid(output)
        dice = self.dice_metric(preds, y.int())

        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_dice', dice, on_step=False, on_epoch=True, prog_bar=True)

        return {"val_loss": val_loss, "val_dice": dice}

    def test_step(self, batch, batch_idx):
        x, y = batch  # x: [B, D, H, W], y: [B, D, H, W]
        
        # Forward pass: get both voxel prediction and probability map
        output, prob_map = self(x)

        # Calculate Dice score for the voxel output (segmentation mask)
        preds = torch.sigmoid(output)
        dice = self.dice_metric(preds, y.int())

        self.log('test_dice', dice, on_step=False, on_epoch=True, prog_bar=True)
        return dice

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_dice'}

    def on_train_epoch_end(self):
        # Log the epoch's progress
        self.log('epoch', self.current_epoch, prog_bar=False)
        
        model_path = f"./model_checkpoints/model_epoch_{self.current_epoch}.pth"
        torch.save({
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.trainer.optimizers[0].state_dict(),
            'lr': self.lr,
            'dice_metric': self.dice_metric.compute().item(),  # Save metric if desired
        }, model_path)

## The Training

In [45]:
# Register a free account with Weights and Biases, and create a new project in order to obtain an API Key for the training
import os
from dotenv import load_dotenv
import wandb

# Load the .env file
load_dotenv()

# Get the API key from the env variable
api_key = os.getenv("WANDB_API_KEY")

# Login to wandb
if api_key:
    os.environ["WANDB_API_KEY"] = api_key
    wandb.login()
else:
    print("❌ WANDB_API_KEY not found in .env file.")



In [46]:
# Setup the Weights and Biases logger
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    project='MSLesSeg-4-ICPR',     # Change to your actual project name
    name='nnUNet_MoE_run_3', # A specific run name
    log_model=True          # Optional: log model checkpoints
)

In [47]:
# Set the multiprocessing start to 'spawn' instead of 'fork' due to CUDA issues with the Dataloader
import multiprocessing
import torch

# Set the start method for multiprocessing
multiprocessing.set_start_method('spawn', force=True)

In [48]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 262.15 MB


In [49]:
# Add the profiler to keep track of GPU overhead
from pytorch_lightning.profilers import PyTorchProfiler

# Slightly more sophisticated profiler
profiler = PyTorchProfiler(
    on_trace_ready=lambda prof: print(prof.key_averages().table(
        sort_by="self_cuda_memory_usage", row_limit=15  # Change as needed
    )),
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
)

In [50]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 262.15 MB


In [51]:
# A computation to estimate the GPU memory consumption of the model

# Print the total number of parameters in your model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

# Calculate the approximate memory size of the model (in MB)
model_size_MB = total_params * 4 / (1024 ** 2)  # 4 bytes per parameter for float32
print(f"Estimated model size: {model_size_MB:.2f} MB")

Total number of parameters: 1090383
Estimated model size: 4.16 MB


In [52]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [53]:
import torch
import gc

# Run garbage collection to clean up unused objects
gc.collect()

# Empty the PyTorch cache
torch.cuda.empty_cache()

In [54]:
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar

# Initialize the TQDM Progress bar for global performance metrics
progress_bar = TQDMProgressBar(refresh_rate=1)

# Initialize the model with 7 input channels
model = MSLesionSegmentationModel(
    model=UNetMoE(in_channels=7, out_channels=1, num_experts=4, expert_dim=512)
)

# Set up checkpointing to save best model
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',
    dirpath='checkpoints/',
    filename='best_checkpoint',
    save_top_k=1,
    mode='max',
)

# Early stopping to prevent overfitting
early_stop_callback = EarlyStopping(
    monitor='val_dice',
    patience=10,
    mode='max',
)

# Path to training data
data_root = TRAIN_TENSOR_DATA_PATH  # Replace with actual path if needed

# Initialize the data module
data_module = MSLesionSegmentationDataModule(
    data_root=data_root,
    train_split=0.8,
    batch_size=1,
    num_workers=0,  # Can be tuned depending on your CPU
)

# Initialize the trainer
trainer = Trainer(
    accelerator="gpu",       # Or "auto" if unsure
    devices=1,
    precision=16,
    max_epochs=100,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=wandb_logger,
    # profiler="simple"  # Optional, if performance profiling is needed
)

# Train
trainer.fit(model, datamodule=data_module)

/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type      | Params | Mode 
--------------------------------------------------
0 | model       | UNetMoE   | 957 K  | train
1 | dice_metric | DiceScore | 0      | train
--------------------------------------------------
957 K     Trainable params
0

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/drew/miniconda3/envs/mslesseg4icpr/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:
# Load the best model's checkpoint and evaluate it

In [None]:
# Evaluate the model on the test set
trainer.test(model, test_dataloader=test_loader)

## Visualizing the Learned Representations

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Function to extract features from the model
def extract_features(model, dataloader, device='cuda'):
    model.eval()
    features = []
    labels = []
    
    with torch.no_grad():
        for data in dataloader:
            inputs, target = data['image'].to(device), data['mask'].to(device)  # Replace 'image' and 'mask' keys as per your dataset
            
            # Forward pass through the model
            output = model(inputs)  # This is where we extract the feature; adjust according to your model's architecture
            
            # Extract features from MoE (or any other layer you want to visualize)
            # Here I assume output is the final feature map after the MoE layer
            feature_map = output.view(output.size(0), -1)  # Flatten the features to (batch_size, features)
            features.append(feature_map.cpu().numpy())
            labels.append(target.cpu().numpy())  # Add the target (or ground truth) labels for color coding in t-SNE

    features = np.concatenate(features, axis=0)  # Combine all the feature maps
    labels = np.concatenate(labels, axis=0)  # Combine all the labels

    return features, labels

# Function to apply t-SNE on features
def plot_tsne(features, labels):
    # Standardize the features (optional but recommended for t-SNE)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(features)
    
    # Plotting
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap='jet', alpha=0.5)
    plt.colorbar(scatter)
    plt.title('t-SNE Visualization of Learned Representations')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.show()

# Example Usage
train_dataloader = # Your train DataLoader here

# Extract features and labels from the model
features, labels = extract_features(model, train_dataloader, device='cuda')

# Plot t-SNE visualization
plot_tsne(features, labels)

## Post Hoc Model Explainability - GradCAM++

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from captum.attr import LayerGradCam, LayerGradCamPlusPlus
from captum.attr import GuidedBackprop

# Function to visualize GradCAM++ output
def visualize_gradcam_plus_plus(model, input_tensor, target_class=None, layer_name='decoder', device='cuda'):
    # Move input tensor to the appropriate device
    input_tensor = input_tensor.to(device)

    # Set model to evaluation mode
    model.eval()
    
    # Create GradCAM++ object for the specified layer
    gradcam_pp = LayerGradCamPlusPlus(model, model.decoder[2])  # Assuming 'decoder' is your final layer, adjust accordingly.
    
    # Apply GradCAM++ to input tensor
    attributions = gradcam_pp.attribute(input_tensor, target=target_class)
    
    # Convert attributions to numpy
    attributions = attributions.cpu().detach().numpy()
    
    # Normalize the heatmap
    heatmap = np.sum(attributions[0], axis=0)  # Summing over channels for a single-channel heatmap
    heatmap = np.maximum(heatmap, 0)  # ReLU to ignore negative values
    heatmap = cv2.resize(heatmap, (input_tensor.shape[2], input_tensor.shape[3]))  # Resize to input size
    heatmap = cv2.normalize(heatmap, None, 0, 1, cv2.NORM_MINMAX)  # Normalize the heatmap to [0, 1]
    
    # Convert the original image to numpy and rescale
    input_image = input_tensor[0].cpu().detach().numpy().transpose(1, 2, 0)
    input_image = cv2.resize(input_image, (input_tensor.shape[2], input_tensor.shape[3]))  # Resize to input size
    
    # Create a colormap for the heatmap
    colormap = plt.get_cmap('jet')
    colored_heatmap = colormap(heatmap)  # Apply colormap
    
    # Overlay the heatmap on top of the original image
    superimposed_img = np.uint8(input_image * 255)  # Convert original image back to [0, 255] range
    superimposed_img = cv2.addWeighted(superimposed_img, 0.7, np.uint8(colored_heatmap[:, :, :3] * 255), 0.3, 0)
    
    # Plot the result
    plt.imshow(superimposed_img)
    plt.title('GradCAM++ Heatmap')
    plt.axis('off')
    plt.show()

# Example Usage (assuming `model` is your trained model and `input_tensor` is an input image)
input_tensor = torch.randn(1, 3, 128, 128).to(device)  # Example input, use your actual input tensor
target_class = None  # You can provide a target class or leave as None for model's top predicted class

visualize_gradcam_plus_plus(model, input_tensor, target_class=target_class, layer_name='decoder', device='cuda')

In [5]:
# Check the size of the stored (pre-processed) Tensors
import torch
from pathlib import Path

def inspect_pt_file(file_path):
    file_path = Path(file_path)
    
    if not file_path.exists():
        print(f"File does not exist: {file_path}")
        return
    
    data = torch.load(file_path, map_location='cpu')

    print(f"\nLoaded file: {file_path}")
    
    if isinstance(data, torch.Tensor):
        print(f"Single Tensor - Shape: {data.shape}")
    
    elif isinstance(data, dict):
        print("Dictionary of tensors:")
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: shape = {value.shape}")
            else:
                print(f"  {key}: type = {type(value)}")
    
    elif isinstance(data, (list, tuple)):
        print(f"{type(data).__name__} of tensors:")
        for idx, item in enumerate(data):
            if isinstance(item, torch.Tensor):
                print(f"  [{idx}]: shape = {item.shape}")
            else:
                print(f"  [{idx}]: type = {type(item)}")
    
    else:
        print(f"Unknown type loaded: {type(data)}")
        print(f"Unkown type shape: {data.shape}")

INPUT_PATH_PREFIX = "../data/02-Tensor-Data/train/MSLS_000/"

In [6]:
# Input Tensor Shape (Stacked Modalities)
inspect_pt_file(INPUT_PATH_PREFIX + "input_tensor.pt")


Loaded file: ../data/02-Tensor-Data/train/MSLS_000/input_tensor.pt
Single Tensor - Shape: torch.Size([3, 182, 218, 182])


In [None]:
# Distance Map Shape
inspect_pt_file(INPUT_PATH_PREFIX + "distance_map.pt")

In [None]:
# Multi Size Mask - 0
inspect_pt_file(INPUT_PATH_PREFIX + "multi_size_mask_0.pt")

In [None]:
# Multi Size Mask - 1
inspect_pt_file(INPUT_PATH_PREFIX + "multi_size_mask_1.pt")

In [None]:
# Multi Size Mask - 2
inspect_pt_file(INPUT_PATH_PREFIX + "multi_size_mask_2.pt")

In [None]:
# Segmentation Mask
inspect_pt_file(INPUT_PATH_PREFIX + "seg_mask.pt")