# U-Net 4 Multiple Sclerosis Lesion Segmentation - ICPR Challenge
### Authors: Andrew R. Darnall, Giovanni Spadaro @ UniCT
---

## 🎯 Competition Objective: MS Lesion Segmentation

The central goal of this competition is the **automatic segmentation of Multiple Sclerosis (MS) lesions** using **multi-modal MRI data** and **deep learning algorithms**.

### 🧪 Provided Data
Participants were given:
- **MRI scans** in three modalities:
  - **FLAIR**
  - **T1-weighted (T1-w)**
  - **T2-weighted (T2-w)**
- **Ground-truth segmentation masks**, which are:
  - **Binary masks**:  
    - **White pixels** → MS lesion regions  
    - **Black pixels** → Background

### 🧠 Task Description
- Participants could use **any or all modalities**, along with the ground-truth labels, to:
  - Develop **deep learning-based models** for **automatic lesion segmentation**
- MS lesions appear as **irregular clusters of pixels** with **high variability in size and shape**
- These lesions are often **difficult to detect** via visual inspection, requiring **expert-level interpretation**

The ultimate goal is to create **fully automated segmentation pipelines** that can robustly identify and delineate MS lesions from raw MRI data.
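The variability in lesion count and size can be measured directly on a ground-truth mask by labelling its connected components. The sketch below uses `scipy.ndimage.label` and makes a few assumptions. The mask is a binary 3D NumPy array in which lesion voxels are non-zero. Voxels are grouped with 26-connectivity, which is an illustrative choice (6-connectivity is also common). `lesion_stats` is a hypothetical helper name. The volumes are registered to a 1 mm³ grid, so voxel counts can be read as lesion volumes in mm³.

```python
import numpy as np
from scipy import ndimage


def lesion_stats(mask: np.ndarray):
    """Return the number of connected lesions in a 3D mask and their sizes in voxels."""
    binary = mask > 0
    # 26-connectivity: voxels touching by a face, an edge or a corner form one lesion
    structure = ndimage.generate_binary_structure(3, 3)
    labels, n_lesions = ndimage.label(binary, structure=structure)
    # bincount over label ids; index 0 is the background, so it is dropped
    sizes = np.bincount(labels.ravel())[1:]
    return n_lesions, sizes


# Toy mask: one 2x2x2 lesion and one isolated voxel
mask = np.zeros((10, 10, 10), dtype=np.uint8)
mask[1:3, 1:3, 1:3] = 1
mask[6, 6, 6] = 1
n, sizes = lesion_stats(mask)
print(n, sizes)
```

For a real case, the mask returned by the notebook's `load_nifti` helper (for example on a `*_seg.nii.gz` file) can be passed in directly.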


## 🧠 MSLesSeg Dataset Overview

As part of this competition, participants were provided with the **MSLesSeg Dataset** — a **comprehensively annotated, multi-modal MRI dataset** designed for advancing **lesion segmentation** research in medical imaging.

### 📊 Dataset Composition
- **Total Patients:** 75 (48 women, 27 men)  
- **Age Range:** 18–59 years (Mean: 37 ± 10.3 years)  
- **Longitudinal Timepoints:**  
  - 50 patients with 1 timepoint  
  - 15 patients with 2 timepoints  
  - 5 patients with 3 timepoints  
  - 5 patients with 4 timepoints  
- **Time Interval Between Scans:** ~1.27 ± 0.62 years  
- **Total MRI Series:** 115
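The series total is consistent with the timepoint breakdown above:

$$
\underbrace{50 + 15 + 5 + 5}_{\text{patients}} = 75,
\qquad
\underbrace{50\cdot 1 + 15\cdot 2 + 5\cdot 3 + 5\cdot 4}_{\text{MRI series}} = 50 + 30 + 15 + 20 = 115
$$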

### 🧬 Imaging Modalities
Each timepoint includes **three core MRI modalities**:
- **T1-weighted (T1-w)**
- **T2-weighted (T2-w)**
- **FLAIR (Fluid-Attenuated Inversion Recovery)**

### 🧑‍⚕️ Expert Annotation
- Lesions were **manually annotated** by clinical experts.
- **FLAIR sequences** were the primary reference for lesion labeling.
- **T1-w and T2-w** scans supported **multi-contrast lesion characterization**.

### 🧪 Dataset Splits
- **Training Set:** 53 patients (93 MRI series)  
- **Test Set:** 22 patients (22 MRI series)  

### ✅ Ethical Compliance
- **Ethical approval** was obtained from the corresponding Hospital Ethics Committee.
- **Informed consent** was acquired from all participating patients.

---

# The Experiment

Below is the code used for:

1) Preprocessing of the ***Brain MRI*** scans
2) Definition of the Dataset, DataLoader and LightningDataModule classes
3) The ***U-Net*** architecture
4) The ***PyTorch Lightning*** Trainer
5) Training & Evaluation
6) Model Explainability with the post-hoc method ***Grad-CAM++***

---

## 🛠️ Preprocessing & Annotation Workflow

The MSLesSeg dataset underwent a **comprehensive preprocessing pipeline** and **expert-driven manual annotation** to ensure **standardization** and **label quality** for downstream MS lesion segmentation tasks.

### 🧼 Preprocessing Pipeline
1. **Anonymization** of all MRI scans to protect patient privacy.
2. **DICOM to NIfTI conversion**, leveraging NIfTI's wide adoption in neuroimaging.
3. **Co-registration to the MNI152 1mm³ isotropic template** using **FLIRT** (FMRIB’s Linear Image Registration Tool), ensuring all scans are aligned to a **common anatomical space**.
4. **Brain extraction** via **BET** (Brain Extraction Tool) to remove non-brain tissues and isolate relevant structures.

This pipeline guarantees that all images are **standardized** and **aligned**, which is critical for **automated MS lesion segmentation algorithms**.
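The notebook starts from data that the organizers had already co-registered and skull-stripped. For reference, the FLIRT and BET steps listed above could be scripted roughly as shown below. This is a minimal sketch, not the organizers' exact pipeline, and several of its settings are assumptions:

- the template file `MNI152_T1_1mm.nii.gz` from `$FSLDIR/data/standard`, whose 1 mm grid is 182 × 218 × 182 voxels and matches the tensor shape used later;
- the 12-DOF affine registration;
- the BET fractional threshold `-f 0.5`;
- the helper name `build_fsl_commands`.

```python
import os
from pathlib import Path


def build_fsl_commands(in_path, out_dir, fsldir=None):
    """Return the FLIRT and BET command lines (as argument lists) for one NIfTI volume."""
    # Default FSL install location is an assumption; $FSLDIR takes precedence
    fsldir = fsldir or os.environ.get("FSLDIR", "/usr/local/fsl")
    template = os.path.join(fsldir, "data", "standard", "MNI152_T1_1mm.nii.gz")
    stem = Path(in_path).name.replace(".nii.gz", "")
    registered = os.path.join(out_dir, f"{stem}_mni.nii.gz")
    matrix = os.path.join(out_dir, f"{stem}_to_mni.mat")
    brain = os.path.join(out_dir, f"{stem}_mni_brain.nii.gz")

    # 1) Affine (12 DOF) co-registration to the MNI152 1 mm template
    flirt = ["flirt", "-in", in_path, "-ref", template,
             "-out", registered, "-omat", matrix, "-dof", "12"]
    # 2) Brain extraction on the registered volume
    bet = ["bet", registered, brain, "-f", "0.5"]
    return [flirt, bet]


for cmd in build_fsl_commands("raw/MSLS_000_flair.nii.gz", "mni", fsldir="/opt/fsl"):
    print(" ".join(cmd))
```

On a machine with FSL installed, each list could be executed with `subprocess.run(cmd, check=True)`.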

---

### 🖋️ Ground-Truth Annotation Protocol
- Lesions were **manually segmented** on the **FLAIR modality** for each patient and timepoint.
- **T1-w and T2-w** modalities were used to **cross-validate ambiguous cases**.
- Annotation was conducted by a **trained junior rater**, under supervision of:
  - A **senior neuroradiologist**
  - A **senior neurologist**
- Annotation sessions included:
  - Multiple **training meetings** to establish a **consistent segmentation strategy**
  - Use of **JIM9** — a high-end tool for **medical image segmentation and analysis**
  - Regular **expert validation checkpoints** to ensure consistency and accuracy

The final masks, reviewed and approved by senior experts, are considered the **gold-standard ground truth**.

---

## 🧾 Key Annotation Highlights
- **Independent segmentation** for each patient/timepoint to avoid bias
- Conducted on **FLAIR scans registered to MNI space**
- **Validated ground-truth masks** ready for training and evaluation



In [None]:
import os
import torch
import nibabel as nib
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.transform import resize
from pathlib import Path
from tqdm.notebook import tqdm

In [None]:
# Obtain the service account key for the Google Cloud Storage bucket
from google.colab import files
uploaded = files.upload()

In [None]:
# Attach the Google Cloud Storage Bucket to access the dataset
import os
from google.cloud import storage

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'segformer4mslesseg-ce06635ff776.json'

bucket_name = 'mslesseg_4_icpr_bucket'

client = storage.Client()
bucket = client.bucket(bucket_name)

In [None]:
# Make sure that gcsfs is installed
!pip install gcsfs



In [None]:
# Install other requirements that Colab likely does not have
!pip install pytorch-lightning
!pip install torchinfo
!pip install grad-cam



In [None]:
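# Helper function to load the NIfTI files from GCS
import gzip
import gcsfs
import nibabel as nib

fs = gcsfs.GCSFileSystem()

def load_nifti(file_path):
    # nib.load() expects a filename rather than an open file object, so read the raw
    # bytes from GCS, decompress .nii.gz in memory and parse them with from_bytes()
    with fs.open(file_path, 'rb') as f:
        raw = f.read()
    if file_path.endswith('.gz'):
        raw = gzip.decompress(raw)
    return nib.Nifti1Image.from_bytes(raw).get_fdata()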

In [None]:
def preprocess_case(input_dir, output_dir, case_id):
    flair = load_nifti(os.path.join(input_dir, f"{case_id}_flair.nii.gz"))
    t1 = load_nifti(os.path.join(input_dir, f"{case_id}_t1.nii.gz"))
    t2 = load_nifti(os.path.join(input_dir, f"{case_id}_t2.nii.gz"))
    seg = load_nifti(os.path.join(input_dir, f"{case_id}_seg.nii.gz")).astype(np.uint8)

    # Stack and convert to tensors
    stacked = np.stack([flair, t1, t2], axis=0)
    input_tensor = torch.tensor(stacked, dtype=torch.float32)
    seg_tensor = torch.tensor(seg, dtype=torch.uint8).unsqueeze(0)

    # Save output tensors to GCS
    output_case_dir = os.path.join(output_dir, case_id)
    fs.makedirs(output_case_dir, exist_ok=True)

    with fs.open(os.path.join(output_case_dir, "input_tensor.pt"), 'wb') as f:
        torch.save(input_tensor, f)
    with fs.open(os.path.join(output_case_dir, "seg_mask.pt"), 'wb') as f:
        torch.save(seg_tensor, f)
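`preprocess_case` stores raw scanner intensities, and their scale differs between FLAIR, T1-w and T2-w. A common extra step, which is not part of the original pipeline, is per-modality z-score normalization computed over brain voxels only. The sketch below assumes that background voxels are exactly 0 after brain extraction. `zscore_per_channel` is an illustrative name. If adopted, it would be applied to `input_tensor` before saving.

```python
import torch


def zscore_per_channel(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Z-score each channel of a [C, H, W, D] volume over its non-zero (brain) voxels.

    Background voxels (assumed to be exactly 0 after skull stripping) stay at 0.
    """
    out = torch.zeros_like(x)
    for c in range(x.shape[0]):
        brain = x[c] != 0
        if brain.any():
            vals = x[c][brain]
            mean = vals.mean()
            std = (vals - mean).pow(2).mean().sqrt()  # population standard deviation
            out[c][brain] = (vals - mean) / (std + eps)
    return out


x = torch.zeros(2, 4, 4, 4)
x[0, 1:3, 1:3, 1:3] = torch.arange(1.0, 9.0).reshape(2, 2, 2)
x[1, 1:3, 1:3, 1:3] = 5.0  # constant channel -> normalizes to 0
out = zscore_per_channel(x)
```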

In [None]:
from pathlib import Path
from tqdm import tqdm

def run_preprocessing(root_path, output_path):

    all_files = fs.ls(root_path)

    for case_dir in tqdm(all_files):
        case_id = case_dir.rstrip('/').split('/')[-1]
        output_case_path = os.path.join(output_path, case_id)

        if fs.exists(os.path.join(output_case_path, "input_tensor.pt")):
            print(f"✅ Skipping {case_id}, already processed.")
            continue

        try:
            preprocess_case(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Failed on {case_id}: {e}")

In [None]:
# Run the preprocessing on the training set
RAW_DATA_PATH = "gs://mslesseg_4_icpr_bucket/data/01-Pre-Processed-Data/train"
OUTPUT_PATH = "gs://mslesseg_4_icpr_bucket/data/02-Tensor-Data/train"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

  1%|          | 1/93 [00:00<00:21,  4.20it/s]

✅ Skipping MSLS_000, already processed.


100%|██████████| 93/93 [01:01<00:00,  1.50it/s]

✅ Skipping MSLS_092, already processed.





In [None]:
# Run the preprocessing on the test set
RAW_DATA_PATH = "gs://mslesseg_4_icpr_bucket/data/01-Pre-Processed-Data/test/test_MASK"
OUTPUT_PATH = "gs://mslesseg_4_icpr_bucket/data/02-Tensor-Data/test"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

  5%|▍         | 1/22 [00:00<00:09,  2.19it/s]

✅ Skipping MSLS_093, already processed.


100%|██████████| 22/22 [00:07<00:00,  2.94it/s]

✅ Skipping MSLS_114, already processed.





## Build the Dataset and Dataloaders for the MSLesSeg preprocessed data

In [None]:
# Check the initial GPU memory consumption of the project
import torch

# Check initial GPU memory usage
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 0.00 MB


In [None]:
import torch
from torch.utils.data import Dataset
import gcsfs
import os

class MSLesSegDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir  # e.g., "gs://your_bucket/path/to/data"
        self.fs = gcsfs.GCSFileSystem()
        self.patient_dirs = self._get_patient_dirs()

    def _get_patient_dirs(self):
        """
        Discover valid patient directories in the GCS bucket that contain the expected .pt files.
        """
        # List subdirectories under the root_dir
        all_dirs = self.fs.ls(self.root_dir)
        patient_dirs = []


        for entry in all_dirs:
            # Check if both required files exist inside each subdirectory
            input_tensor = f"{entry}/input_tensor.pt"
            seg_mask = f"{entry}/seg_mask.pt"
            if self.fs.exists(input_tensor) and self.fs.exists(seg_mask):
                patient_dirs.append(entry)

        return patient_dirs

    def __len__(self):
        return len(self.patient_dirs)

    def __getitem__(self, idx):
        """
        Load tensors from GCS.
        """
        patient_dir = self.patient_dirs[idx]
        input_path = f"{patient_dir}/input_tensor.pt"
        seg_path = f"{patient_dir}/seg_mask.pt"

        with self.fs.open(input_path, 'rb') as f:
            input_tensor = torch.load(f)

        with self.fs.open(seg_path, 'rb') as f:
            seg_mask = torch.load(f)

        # print(f"Returning input tensor shape: {input_tensor.shape}")
        # print(f"Returning seg mask shape: {seg_mask.shape}")
        return input_tensor, seg_mask
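The dataset expects one folder per case. Each folder holds `input_tensor.pt` ([3, H, W, D]) and `seg_mask.pt` ([1, H, W, D]), which is exactly the layout written by `preprocess_case`. For offline debugging without GCS credentials, the same layout can be read from local disk. The class below is a hypothetical local counterpart and is not part of the original project.

```python
from pathlib import Path

import torch
from torch.utils.data import Dataset


class LocalMSLesSegDataset(Dataset):
    """Hypothetical local-disk variant: <root>/<case_id>/{input_tensor.pt, seg_mask.pt}."""

    def __init__(self, root_dir):
        root = Path(root_dir)
        # Keep only case folders that contain both tensors, in a stable (sorted) order
        self.case_dirs = sorted(
            d for d in root.iterdir()
            if d.is_dir()
            and (d / "input_tensor.pt").exists()
            and (d / "seg_mask.pt").exists()
        )

    def __len__(self):
        return len(self.case_dirs)

    def __getitem__(self, idx):
        case_dir = self.case_dirs[idx]
        input_tensor = torch.load(case_dir / "input_tensor.pt")
        seg_mask = torch.load(case_dir / "seg_mask.pt")
        return input_tensor, seg_mask
```

Usage would look like `LocalMSLesSegDataset("/content/02-Tensor-Data/train")`, where the path is a placeholder for a local copy of the tensor folder.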


### PyTorch Lightning DataModule

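A ***LightningDataModule*** groups the train/validation split and the dataloaders in one reusable object, which is then passed to ***.fit()*** through the `datamodule` argument. Lightning 2.x still accepts dataloaders directly (`train_dataloaders` / `val_dataloaders`), but the DataModule keeps all data-handling logic in one place.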

In [None]:
# MSLesSeg (PyTorch) LightningDataModule definition
import os
import gcsfs
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

class MSLesSegDataModule(pl.LightningDataModule):

    def __init__(self, root_data_dir, batch_size, val_split, num_workers, seed=42):
        """
        root_data_dir: GCS path containing the `train/` and `test/` tensor folders
        seed: seed for the train/val split (42 is an arbitrary choice)
        """
        super().__init__()
        self.data_dir = root_data_dir
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_workers = num_workers
        self.seed = seed

    def setup(self, stage=None):
        # Load full training dataset
        full_dataset = MSLesSegDataset(root_dir=os.path.join(self.data_dir, 'train'))

        # Split into train and val with a seeded generator, so that repeated setup() calls
        # (Lightning may call it once per stage) always produce the same split.
        # Note: the split is per MRI series, so different timepoints of the same patient
        # may end up in both train and val.
        val_size = int(len(full_dataset) * self.val_split)
        train_size = len(full_dataset) - val_size
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(self.seed),
        )

        # Load test dataset (if it exists). os.path.exists() always returns False for
        # gs:// URLs, so the check has to go through gcsfs
        test_dir = os.path.join(self.data_dir, 'test')
        if gcsfs.GCSFileSystem().exists(test_dir):
            self.test_dataset = MSLesSegDataset(root_dir=test_dir)
        else:
            self.test_dataset = None

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        if self.test_dataset:
            return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return None



## The Model's Architecture

In [None]:
# This 3D U-Net implementation is adapted from the following 3D U-Net tutorial
# Link --> https://github.com/bnsreenu/python_for_image_processing_APEER/blob/master/tutorial122_3D_Unet.ipynb
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.GELU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.GELU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

# Downsampling path: conv block followed by max pooling.
# Returns the pre-pooling features (used as skip connection) and the pooled output.
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        conv_output = self.conv_block(x)
        pool_output = self.pool(conv_output)
        return conv_output, pool_output

# Upsampling path: the transposed conv doubles the resolution, then the skip features
# from the encoder are concatenated along the channel dim before the conv block.
class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.conv_transpose = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip_features):
        x = self.conv_transpose(x)
        x = torch.cat((x, skip_features), dim=1)
        x = self.conv_block(x)
        return x

In [None]:
# Helper function to zero-pad H, W, D (at the end of each axis) to a multiple of 16,
# since the four 2x max-pooling steps need sizes divisible by 2^4 = 16
def pad_to_multiple(x, multiple=16):
    _, _, h, w, d = x.shape
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    pad_d = (multiple - d % multiple) % multiple

    pad = [0, pad_d, 0, pad_w, 0, pad_h]  # D, W, H
    return F.pad(x, pad, mode='constant', value=0)
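For the MSLesSeg volumes, `pad_to_multiple` produces the shapes that appear in the model summary further down. Each axis of size $n$ is padded by $(16 - n \bmod 16) \bmod 16$ voxels, and the four 2× poolings then reduce the padded grid as follows:

$$
\begin{aligned}
H = D = 182:&\quad 182 \bmod 16 = 6 \;\Rightarrow\; \text{pad } 10 \;\Rightarrow\; 192 \to 96 \to 48 \to 24 \to 12\\
W = 218:&\quad 218 \bmod 16 = 10 \;\Rightarrow\; \text{pad } 6 \;\Rightarrow\; 224 \to 112 \to 56 \to 28 \to 14
\end{aligned}
$$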

In [None]:
# Pad the GT segmentation mask to the spatial shape of the network's output.
# pad_to_multiple pads the input only at the END of each axis, so the mask must be padded
# the same way; symmetric (centered) padding would shift it relative to the prediction.
def pad_to_match(tensor, target_shape):
    current_shape = tensor.shape

    # F.pad takes (left, right) pairs starting from the LAST dim, i.e. (D, W, H)
    padding = []
    for curr, tgt in zip(reversed(current_shape[-3:]), reversed(target_shape[-3:])):
        padding.extend([0, tgt - curr])  # all padding goes after the data

    return F.pad(tensor, padding)

In [None]:
# Center-cropping helper: crops the last three (H, W, D) dims of x to target_shape.
# Note: pad_to_multiple pads only at the end of each axis, so to undo that padding
# crop from the start instead, e.g. x[..., :H, :W, :D].
def center_crop(x, target_shape):
    h, w, d = x.shape[-3:]
    th, tw, td = target_shape
    h1 = (h - th) // 2
    w1 = (w - tw) // 2
    d1 = (d - td) // 2
    # The ellipsis keeps batch/channel dims intact for both 4D and 5D inputs
    return x[..., h1:h1+th, w1:w1+tw, d1:d1+td]
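Instead of padding the ground-truth mask up to the output shape, the prediction can be cropped back to the original grid. `pad_to_multiple` only appends zeros at the end of each axis, so the original voxels are always the leading ones and a plain slice undoes the padding. In this minimal sketch, `pad_end` re-implements the same end-padding so that the block is self-contained. Both function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def pad_end(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Zero-pad the last three dims of x at the end up to a multiple of `multiple`."""
    pad = []
    for size in reversed(x.shape[-3:]):  # F.pad order: D, W, H
        pad.extend([0, (multiple - size % multiple) % multiple])
    return F.pad(x, pad)


def crop_to_original(pred: torch.Tensor, spatial_shape) -> torch.Tensor:
    """Undo end-padding by keeping the leading voxels of each spatial axis."""
    h, w, d = spatial_shape
    return pred[..., :h, :w, :d]


x = torch.randn(1, 1, 5, 7, 6)
padded = pad_end(x)
restored = crop_to_original(padded, x.shape[-3:])
print(padded.shape, restored.shape)
```

Cropping the prediction keeps the loss focused on real voxels. Padding the mask instead also adds the padded region, where both the target and (ideally) the prediction are zero.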

In [None]:
# 3D U-Net Architecture, 3-level variant: one Down/Up stage fewer than the standard U-Net, to reduce the parameter count
class UNet(nn.Module):
    """
    The U-Net model will accept input Tensors of shape: [B, C, H, W, D]
    Which based on the used MSLesSeg dataset will be [B, 3, 182, 218, 182]
    The input Tensor is nothing more than the stacked [FLAIR, T1w, T2w] modalities

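    The output is a sigmoid probability map of shape [B, 1, 192, 224, 192]: the input
    is zero-padded to a multiple of 16 by pad_to_multiple, and the ground-truth mask
    must be padded to the same shape before computing the loss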
    """
    def __init__(self, in_channels):
        super(UNet, self).__init__()
        self.down_1 = Down(in_channels, 64)
        self.down_2 = Down(64, 128)
        self.down_3 = Down(128, 256)

        self.bottleneck = ConvBlock(256, 512)

        self.up_1 = Up(512, 256)
        self.up_2 = Up(256, 128)
        self.up_3 = Up(128, 64)

        self.classifier = nn.Conv3d(64, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        # Pad to multiple
        x = pad_to_multiple(x)

        # Downsampling path
        conv1, pool1 = self.down_1(x)
        conv2, pool2 = self.down_2(pool1)
        conv3, pool3 = self.down_3(pool2)

        # Bottleneck
        bottleneck = self.bottleneck(pool3)

        # Upsampling path
        upconv1 = self.up_1(bottleneck, conv3)
        upconv2 = self.up_2(upconv1, conv2)
        upconv3 = self.up_3(upconv2, conv1)  # skip connection from the first encoder level

        output = self.sigmoid(self.classifier(upconv3))

        return output

In [None]:
# 3D U-Net Architecture, standard 4-level version (~90 Million Parameters)
class UNet(nn.Module):
    """
    The U-Net model will accept input Tensors of shape: [B, C, H, W, D]
    Which based on the used MSLesSeg dataset will be [B, 3, 182, 218, 182]
    The input Tensor is nothing more than the stacked [FLAIR, T1w, T2w] modalities

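    The output is a sigmoid probability map of shape [B, 1, 192, 224, 192]: the input
    is zero-padded to a multiple of 16 by pad_to_multiple, and the ground-truth mask
    must be padded to the same shape before computing the loss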
    """
    def __init__(self, in_channels):
        super(UNet, self).__init__()
        self.down_1 = Down(in_channels, 64)
        self.down_2 = Down(64, 128)
        self.down_3 = Down(128, 256)
        self.down_4 = Down(256, 512)

        self.bottleneck = ConvBlock(512, 1024)

        self.up_1 = Up(1024, 512)
        self.up_2 = Up(512, 256)
        self.up_3 = Up(256, 128)
        self.up_4 = Up(128, 64)

        self.classifier = nn.Conv3d(64, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        # Pad to multiple
        x = pad_to_multiple(x)

        # Downsampling path
        conv1, pool1 = self.down_1(x)
        conv2, pool2 = self.down_2(pool1)
        conv3, pool3 = self.down_3(pool2)
        conv4, pool4 = self.down_4(pool3)

        # Bottleneck
        bottleneck = self.bottleneck(pool4)

        # Upsampling path
        upconv1 = self.up_1(bottleneck, conv4)
        upconv2 = self.up_2(upconv1, conv3)
        upconv3 = self.up_3(upconv2, conv2)
        upconv4 = self.up_4(upconv3, conv1)

        output = self.sigmoid(self.classifier(upconv4))

        return output

In [None]:
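# 3D U-Net Architecture - Parameter Reduced Version (16 base channels)
# As the last UNet cell, this definition overrides the two above and is the one summarized below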
class UNet(nn.Module):
    """
    The U-Net model will accept input Tensors of shape: [B, C, H, W, D]
    Which based on the used MSLesSeg dataset will be [B, 3, 182, 218, 182]
    The input Tensor is nothing more than the stacked [FLAIR, T1w, T2w] modalities

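    The output is a sigmoid probability map of shape [B, 1, 192, 224, 192]: the input
    is zero-padded to a multiple of 16 by pad_to_multiple, and the ground-truth mask
    must be padded to the same shape before computing the loss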
    """
    def __init__(self, in_channels):
        super(UNet, self).__init__()
        self.down_1 = Down(in_channels, 16)
        self.down_2 = Down(16, 32)
        self.down_3 = Down(32, 64)
        self.down_4 = Down(64, 128)

        self.bottleneck = ConvBlock(128, 256)

        self.up_1 = Up(256, 128)
        self.up_2 = Up(128, 64)
        self.up_3 = Up(64, 32)
        self.up_4 = Up(32, 16)

        self.classifier = nn.Conv3d(16, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        # Pad to multiple
        x = pad_to_multiple(x)

        # Downsampling path
        conv1, pool1 = self.down_1(x)
        conv2, pool2 = self.down_2(pool1)
        conv3, pool3 = self.down_3(pool2)
        conv4, pool4 = self.down_4(pool3)

        # Bottleneck
        bottleneck = self.bottleneck(pool4)

        # Upsampling path
        upconv1 = self.up_1(bottleneck, conv4)
        upconv2 = self.up_2(upconv1, conv3)
        upconv3 = self.up_3(upconv2, conv2)
        upconv4 = self.up_4(upconv3, conv1)

        output = self.sigmoid(self.classifier(upconv4))

        return output

In [None]:
# Summary of the U-Net model, in terms of parameter count and estimated VRAM consumption
from torchinfo import summary

model = UNet(in_channels=3)  # the most recently defined UNet (parameter-reduced version)
summary(model, input_size=(1, 3, 182, 218, 182))  # (batch_size, C, H, W, D)

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 1, 192, 224, 192]     --
├─Down: 1-1                              [1, 16, 192, 224, 192]    --
│    └─ConvBlock: 2-1                    [1, 16, 192, 224, 192]    --
│    │    └─Conv3d: 3-1                  [1, 16, 192, 224, 192]    1,312
│    │    └─BatchNorm3d: 3-2             [1, 16, 192, 224, 192]    32
│    │    └─GELU: 3-3                    [1, 16, 192, 224, 192]    --
│    │    └─Conv3d: 3-4                  [1, 16, 192, 224, 192]    6,928
│    │    └─BatchNorm3d: 3-5             [1, 16, 192, 224, 192]    32
│    │    └─GELU: 3-6                    [1, 16, 192, 224, 192]    --
│    └─MaxPool3d: 2-2                    [1, 16, 96, 112, 96]      --
├─Down: 1-2                              [1, 32, 96, 112, 96]      --
│    └─ConvBlock: 2-3                    [1, 32, 96, 112, 96]      --
│    │    └─Conv3d: 3-7                  [1, 32, 96, 112, 96]      13,856
│    
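The parameter counts in the summary follow from the Conv3d formula $C_{in}\cdot C_{out}\cdot k^3 + C_{out}$, that is, the weights plus one bias per output channel. Each BatchNorm3d layer adds $2C$ learnable parameters. For example, $3\cdot 16\cdot 27 + 16 = 1312$. Below is a small cross-check against PyTorch's own count. `conv3d_params` and `count_params` are illustrative helpers.

```python
import torch.nn as nn


def conv3d_params(c_in: int, c_out: int, k: int = 3, bias: bool = True) -> int:
    """Weights (c_in * c_out * k^3) plus one bias per output channel."""
    return c_in * c_out * k ** 3 + (c_out if bias else 0)


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# First three Conv3d layers of the parameter-reduced U-Net
for c_in, c_out in [(3, 16), (16, 16), (16, 32)]:
    layer = nn.Conv3d(c_in, c_out, kernel_size=3, padding=1)
    print(c_in, c_out, conv3d_params(c_in, c_out), count_params(layer))
```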

In [None]:
# Checkpoint to see where the GPU memory overhead is
initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 21.59 MB


## The Trainer

In [None]:
# Implementation of the Dice Score
def dice_score(pred, target, smooth=1e-6):
    """
    Computes the soft Dice coefficient for binary segmentation.
    Args:
        pred: Tensor of predicted probabilities (batch_size, 1, H, W, D).
        target: Tensor of ground truth (batch_size, 1, H, W, D).
        smooth: Smoothing factor to avoid division by zero.
    Returns:
        Scalar Dice score, averaged over the batch.
    """
    # print(f"Pred shape: {pred.shape}")
    # print(f"Target shape: {target.shape}")

    # Calculate intersection and union
    intersection = (pred * target).sum(dim=(2, 3, 4))
    union = pred.sum(dim=(2, 3, 4)) + target.sum(dim=(2, 3, 4))

    # Compute Dice Coefficient
    dice = (2. * intersection + smooth) / (union + smooth)

    # Return DiceScore
    return dice.mean()

# Implementation of the Dice Loss
def dice_loss(pred, target, smooth=1e-6):
  return 1 - dice_score(pred, target, smooth)
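`dice_score` is a *soft* Dice: it multiplies raw sigmoid probabilities with the mask, which keeps it differentiable for `dice_loss`. When reporting a segmentation metric, the prediction is usually binarized first. The sketch below contrasts the two on a toy 2 × 2 × 2 volume. `hard_dice` and its 0.5 threshold are assumptions for illustration, and `soft_dice` restates the formula above so the block runs on its own:

```python
import torch

def soft_dice(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
    # Same formula as dice_score: per-sample Dice over (H, W, D), averaged over the batch
    intersection = (pred * target).sum(dim=(2, 3, 4))
    cardinality = pred.sum(dim=(2, 3, 4)) + target.sum(dim=(2, 3, 4))
    return ((2.0 * intersection + smooth) / (cardinality + smooth)).mean()

def hard_dice(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
              smooth: float = 1e-6) -> torch.Tensor:
    # Binarize the probabilities, then apply the same formula (hypothetical helper)
    return soft_dice((pred > threshold).float(), target, smooth)

# Toy volume of shape (batch=1, channel=1, 2, 2, 2) containing a two-voxel lesion
target = torch.zeros(1, 1, 2, 2, 2)
target[0, 0, 0, 0, :] = 1.0
pred = torch.zeros(1, 1, 2, 2, 2)
pred[0, 0, 0, 0, :] = 0.9  # confident, correct prediction on the lesion voxels

print(soft_dice(pred, target).item())  # 2*1.8 / (1.8 + 2) ≈ 0.947
print(hard_dice(pred, target).item())  # perfect overlap after thresholding ≈ 1.0
```

The `smooth` term also makes an empty prediction on an empty mask score 1.0. This matters for scans with few or no lesion voxels.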

In [None]:
# PyTorch LightningModule wrapping the U-Net (training, validation and test logic)
import os

import pytorch_lightning as pl
import torch
import torch.optim as optim

class MSLesionSegmentationModel(pl.LightningModule):

    def __init__(self, model, checkpoint_dir, lr=1e-4):
        super(MSLesionSegmentationModel, self).__init__()
        self.model = model
        self.lr = lr
        self.checkpoint_dir = checkpoint_dir

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        # Separate the input Tensor from the Segmentation Mask
        input_tensor, gt = batch
        # Forward Pass on the model
        y_pred = self(input_tensor)
        # Pad the Ground Truth Seg Mask to match the shape of the y_pred
        gt = pad_to_match(gt, y_pred.shape)
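        # Compute DiceLoss
        loss = dice_loss(y_pred, gt)

        # Log the train loss per step and as an epoch average; per-step points alone
        # are throttled by the Trainer's default log_every_n_steps=50
        self.log("train_loss", loss, on_step=True, on_epoch=True)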

        return loss

    def validation_step(self, batch, batch_idx):
        # Separate the input Tensor from the Segmentation Mask
        input_tensor, gt = batch
        # Forward pass on the model
        y_pred = self(input_tensor)
        # Pad the Ground Truth Seg Mask to match the shape of the y_pred
        gt = pad_to_match(gt, y_pred.shape)
        # Compute DiceLoss
        loss = dice_loss(y_pred, gt)
        # Compute the Mean Dice Score
        dicescore = dice_score(y_pred, gt)
        # log the val loss
        self.log("val_loss", loss)
        # log the Mean Dice Score
        self.log("dice_score", dicescore)

        return loss

    def test_step(self, batch, batch_idx):
        # Separate the input Tensor from the Segmentation Mask
        input_tensor, gt = batch
        # Forward pass on the model
        y_pred = self(input_tensor)
        # Pad the Ground Truth Seg Mask to match the shape of the y_pred
        gt = pad_to_match(gt, y_pred.shape)
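        # Compute the Dice Loss and the mean Dice Score
        loss = dice_loss(y_pred, gt)
        dicescore = dice_score(y_pred, gt)
        # Log the test metrics so that trainer.test() reports them
        self.log("test_loss", loss)
        self.log("test_dice_score", dicescore)

        return loss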

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
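        # Save a checkpoint at the end of every epoch. The same file is overwritten each time,
        # so despite its name it holds the latest weights, not necessarily the best ones.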
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        save_path = os.path.join(self.checkpoint_dir, "unet_best_model.pth")

        scheduler = (
          self.trainer.lr_scheduler_configs[0].scheduler
          if self.trainer.lr_scheduler_configs else None
        )

        # Prepare the checkpoint dict
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.trainer.optimizers[0].state_dict() if self.trainer.optimizers else None,
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'learning_rate': self.lr,
        }

        # Save and flush to disk
        with open(save_path, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
            os.fsync(f.fileno())
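`configure_optimizers` pairs Adam (lr = 1e-4) with `StepLR(step_size=10, gamma=0.1)`. Lightning steps an epoch-level scheduler once per epoch by default, so the learning rate is divided by 10 every 10 epochs. It is already 1e-8 by epoch 40 and ends at 1e-4 × 0.1^9 = 1e-13 with `MAX_EPOCHS = 100`, so most of the later epochs apply negligible updates. The sketch below traces the schedule outside Lightning. The dummy parameter is an assumption made only so the optimizer has something to hold:

```python
import torch

def trace_step_lr(base_lr=1e-4, step_size=10, gamma=0.1, epochs=100):
    """Return the learning rate used in each epoch under StepLR."""
    dummy = torch.nn.Parameter(torch.zeros(1))  # placeholder parameter, never trained
    optimizer = torch.optim.Adam([dummy], lr=base_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    lrs = []
    for _ in range(epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # no gradients, so Adam skips the parameter; keeps the expected call order
        scheduler.step()   # one scheduler step per epoch, like Lightning's default interval
    return lrs

lrs = trace_step_lr()
for epoch in (0, 9, 10, 40, 99):
    print(f"epoch {epoch:3d}: lr = {lrs[epoch]:.0e}")
```

If this decay proves too aggressive, a larger `step_size` or a schedule such as `CosineAnnealingLR` are options to try. This notebook does not compare schedules.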

## The Training

In [None]:
# Register a free Weights & Biases account to obtain an API key, then expose it as the
# WANDB_API_KEY environment variable (never hard-code the key in the notebook)
import os
import wandb

api_key = os.environ.get("WANDB_API_KEY")

# Login to wandb
if api_key:
    wandb.login(key=api_key)
else:
    print("❌ WANDB_API_KEY environment variable not set.")

[34m[1mwandb[0m: Currently logged in as: [33mdrnnrw00m10c351s[0m ([33mfpv-perceivelab-unict[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Setup the Weights and Biases logger
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    project='MSLesSeg-4-ICPR',     # Change to your actual project name
    name='colab_UNet_run_1', # A specific run name
    log_model=True          # Optional: log model checkpoints
)

In [None]:
# Set the constant for the maximum number of epochs for the training
MAX_EPOCHS = 100

In [None]:
# Definition of Dataloader hyperparameters (split ratios, num workers and batch sizes)
TRAIN_SPLIT = 0.8
VAL_SPLIT = 0.2

TRAIN_NUM_WORKERS = 0
TEST_NUM_WORKERS = 0

TRAIN_BATCH_SIZE = 3
VAL_BATCH_SIZE = 2
TEST_BATCH_SIZE = 3

In [None]:
# Create the Trainer object
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    devices=1,
    logger=wandb_logger,
    benchmark=True,  # let cuDNN auto-tune convolution algorithms (pays off with fixed input shapes)
    detect_anomaly=False
)

INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
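The `fit` log further down warns that the A100's Tensor Cores are underused and suggests `torch.set_float32_matmul_precision`. Below is a minimal sketch of that suggestion, meant to run before `trainer.fit`. Whether it speeds up this convolution-heavy 3D U-Net is an assumption to verify, because convolutions are governed by a separate cuDNN flag, `torch.backends.cudnn.allow_tf32`:

```python
import torch

# Allow TF32 Tensor Core kernels for float32 matrix multiplications ("high"),
# trading a little float32 precision for speed, as the Lightning log suggests
torch.set_float32_matmul_precision("high")
print("float32 matmul precision:", torch.get_float32_matmul_precision())

# Convolutions use their own switch; report whether cuDNN may use TF32 for them
print("cuDNN TF32 convolutions allowed:", torch.backends.cudnn.allow_tf32)
```

Lightning's `Trainer` also accepts `precision="16-mixed"` (or `"bf16-mixed"` on the A100) for automatic mixed precision, which can additionally reduce activation memory. Its effect on the lesion Dice score is not measured here.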


In [None]:
# Prepare the data module
data_module = MSLesSegDataModule(
    root_data_dir="gs://mslesseg_4_icpr_bucket/data/02-Tensor-Data/",
    batch_size=TRAIN_BATCH_SIZE,
    val_split=VAL_SPLIT,
    num_workers=TRAIN_NUM_WORKERS
)

In [None]:
# Instantiate the PyTorch Lightning Module (wraps the U-Net nn.Module with the training, validation and test logic)
lightning_model = MSLesionSegmentationModel(
    model=UNet(in_channels=3),
    checkpoint_dir="./model_checkpoints"
)

In [None]:
# Check the initial GPU memory consumption of the project
import torch

initial_memory = torch.cuda.memory_allocated() / 1024 ** 2  # in MB
print(f"Initial memory usage: {initial_memory:.2f} MB")

Initial memory usage: 21.59 MB


In [None]:
# Estimated VRAM footprint per CUDA Tensor of shape [B, C, H, W, D]
def get_tensor_vram_mb(tensor):
    return tensor.element_size() * tensor.nelement() / (1024 ** 2)

input_tensor = torch.randn(1, 3, 182, 218, 182).cuda()
print(f"(CUDA) Input tensor size in MB: {get_tensor_vram_mb(input_tensor):.2f} MB")

(CUDA) Input tensor size in MB: 82.64 MB


In [None]:
# Estimated VRAM footprint for the CUDA model
def get_model_param_vram_mb(model):
    return sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)

model = UNet(in_channels=3).cuda()
print(f"(CUDA) Model parameters size: {get_model_param_vram_mb(model):.2f} MB")

(CUDA) Model parameters size: 21.55 MB
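The two estimates above cover only the input and the weights. Multiplying out the shapes in the `torchinfo` summary shows that activations dominate. A single float32 feature map at the first U-Net level, [1, 16, 192, 224, 192], already takes 504 MiB. The first `ConvBlock` produces six such tensors per sample (two each from `Conv3d`, `BatchNorm3d` and `GELU`), and autograd keeps several of them for the backward pass. `tensor_mib` is a hypothetical helper that applies the same element-count × bytes formula as `get_tensor_vram_mb` to a bare shape:

```python
import math

def tensor_mib(shape, bytes_per_element=4):
    """Size in MiB of a dense tensor with the given shape (4 bytes = float32)."""
    return math.prod(shape) * bytes_per_element / 1024 ** 2

input_mib = tensor_mib((1, 3, 182, 218, 182))     # matches the 82.64 MB printed above
level1_mib = tensor_mib((1, 16, 192, 224, 192))   # one feature map at the first level
print(f"input: {input_mib:.2f} MiB, one level-1 feature map: {level1_mib:.2f} MiB")

# Total size of the six level-1 outputs in the summary, if all were kept, at TRAIN_BATCH_SIZE = 3
batch_size = 3
print(f"first ConvBlock outputs, batch of {batch_size}: {6 * level1_mib * batch_size / 1024:.1f} GiB")
```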


In [None]:
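# Enable expandable allocator segments to reduce fragmentation. `!export` would only affect a
# throwaway subshell, so set the variable in the kernel's own environment instead. It generally
# takes effect only if set before PyTorch initializes CUDA (i.e. in the first cell, before any `.cuda()` call)
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"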

In [None]:
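import gc

# Run Python's garbage collector, then release unoccupied cached GPU memory held by PyTorch's allocator
gc.collect()
torch.cuda.empty_cache()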

In [None]:
# Train (fit) the model
trainer.fit(lightning_model, datamodule=data_module)

INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | UNet | 5.6 M  | train
---------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.595    Total estimated model params size (MB)
82        Modules in train mode
0         Modules in eval mode


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


In [None]:
# Load the best model's checkpoint and evaluate it
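One way to fill the cell above, sketched under assumptions. `on_train_epoch_end` saves a plain dict with the keys `epoch`, `model_state_dict`, `optimizer_state_dict`, `scheduler_state_dict` and `learning_rate` to `./model_checkpoints/unet_best_model.pth`, so restoring it takes a `torch.load` plus `load_state_dict`. `load_unet_checkpoint` is a hypothetical helper, and the round trip at the bottom uses a tiny `Conv3d` as a stand-in for `UNet` so the block runs on its own. This file holds the weights of the last finished epoch, not a validation-selected best.

```python
import os
import tempfile

import torch
import torch.nn as nn

def load_unet_checkpoint(model, checkpoint_path, optimizer=None, scheduler=None, map_location="cpu"):
    """Restore a checkpoint dict written by on_train_epoch_end (hypothetical helper)."""
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    # Optimizer/scheduler states are only needed to resume training, not for evaluation
    if optimizer is not None and checkpoint.get("optimizer_state_dict") is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"]

# Round trip with a tiny stand-in network (assumption: any nn.Module behaves the same way)
trained = nn.Conv3d(3, 1, kernel_size=1)
restored = nn.Conv3d(3, 1, kernel_size=1)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "unet_best_model.pth")
    torch.save({
        "epoch": 3,
        "model_state_dict": trained.state_dict(),
        "optimizer_state_dict": None,
        "scheduler_state_dict": None,
        "learning_rate": 1e-4,
    }, path)
    epoch = load_unet_checkpoint(restored, path)
print(f"restored epoch {epoch}, weights equal: {torch.equal(trained.weight, restored.weight)}")

# In the notebook (not executed here), the restored U-Net could then be evaluated with:
# unet = UNet(in_channels=3)
# load_unet_checkpoint(unet, "./model_checkpoints/unet_best_model.pth")
# trainer.test(MSLesionSegmentationModel(unet, checkpoint_dir="./model_checkpoints"), datamodule=data_module)
```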

In [None]:
# Evaluate the model on the test set
trainer.test(lightning_model, datamodule=data_module)

---