# Multiple Sclerosis Lesion Segmentation Dataset Overview

---



- Introduction to the challenge (evaluation metrics, comparison to other datasets), with links to the website [https://iplab.dmi.unict.it/mfs/ms-les-seg/] and the paper
- Introduction to the Dataset (creation, pre-processing, contents and distribution data with visualizations)

Table of contents

---

## What have winning strategies been so far?

1) Ensemble methods
2) Different registration spaces
3) Restricting the input to the image regions/data most informative for predicting the segmentation mask
4) Using only ONE modality (FLAIR) with data augmentation + a deep CNN architecture
5) nnU-Net, which automatically configures itself to the dataset (interesting)
6) The Mamba-based LightM-UNet
7) Focal Loss + Dice Loss

---

## Personal Goals

1) Parameter efficiency
2) Explainable architecture (XAI)
3) Segmentation accuracy (mean Dice >= 60%)
4) [Optional] Single model (no ensembles)

---

## Personal Model Proposition

1) Use ALL modalities (FLAIR, T1, T2)
2) Preprocess with distance-based labelling and multi-size labelling (could aid the MoE variant)
3) nnU-Net + Mixture of Experts (MoE) (+ deep supervision?)

### Optional

- Bring the model visualization to the Extended Reality (XR) domain?

# The Code

---

In [1]:
import os
import torch
import nibabel as nib
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.transform import resize
from pathlib import Path
from tqdm.notebook import tqdm

In [2]:
def load_nifti(file_path):
    """Read a NIfTI volume from disk and return its voxel data.

    The data is obtained through nibabel's ``get_fdata`` (a floating-point
    NumPy array); callers that need labels cast it themselves.
    """
    image = nib.load(file_path)
    voxels = image.get_fdata()
    return voxels

def normalize_intensity(img):
    """Z-score a single modality volume.

    The input is cast to float32 (a new array; ``img`` is not modified), then
    shifted by its global mean and divided by its global standard deviation.
    The small epsilon guards against division by zero on constant volumes.
    Statistics are computed over the whole volume, background included.
    """
    volume = img.astype(np.float32)
    mu = np.mean(volume)
    sigma = np.std(volume)
    return (volume - mu) / (sigma + 1e-5)

def compute_distance_label(mask):
    """Return a signed Euclidean distance map for a (binary) mask.

    Any non-zero voxel counts as foreground. Foreground voxels receive their
    distance to the nearest background voxel (positive values); background
    voxels receive minus their distance to the nearest foreground voxel.
    Distances are in voxel units (no spacing is passed to the transform).

    NOTE(review): behaviour for a mask with no foreground (or no background)
    depends on how ``distance_transform_edt`` handles an input without any
    zero element -- confirm before relying on it for lesion-free cases.
    """
    foreground = mask.astype(bool)
    inside = distance_transform_edt(foreground)
    outside = distance_transform_edt(np.logical_not(foreground))
    return inside - outside

def generate_multi_size_masks(mask, scales=(1.0, 0.5, 0.25)):
    """Return the label mask resampled to several scales.

    Args:
        mask: Label array (e.g. the uint8 segmentation). Its dtype is kept
            for every output.
        scales: Iterable of scale factors applied to every axis. A scale of
            exactly ``1.0`` returns ``mask`` itself (not a copy). A tuple is
            used as the default to avoid a shared mutable default argument.

    Returns:
        List of arrays, one per scale, in the order of ``scales``. Each
        down-scaled shape is ``int(dim * scale)`` per axis, clamped to at
        least 1 so very small axes never collapse to length 0.
    """
    multi_res = []
    for scale in scales:
        if scale == 1.0:
            multi_res.append(mask)
            continue
        out_shape = tuple(max(1, int(dim * scale)) for dim in mask.shape)
        # Nearest-neighbour (order=0) without anti-aliasing keeps the output
        # restricted to label values that exist in ``mask``. Linear
        # interpolation + anti-aliasing produces fractional values that the
        # integer cast below truncates toward 0, which erases small lesions
        # and lesion borders at coarse scales.
        resized = resize(mask, output_shape=out_shape, order=0,
                         preserve_range=True, anti_aliasing=False)
        multi_res.append(resized.astype(mask.dtype))
    return multi_res

In [3]:
def preprocess_case(input_dir, output_dir, case_id, *, device="cuda"):
    """Preprocess one labelled case and save its tensors to ``output_dir / case_id``.

    Each modality is z-scored independently with ``normalize_intensity``. The
    segmentation is cast to ``uint8`` and used to derive a signed distance map
    and a set of resampled masks.

    Files written:

    * ``input_tensor.pt`` -- float32, shape ``(3, *volume_shape)``, with
      channels in the order FLAIR, T1, T2.
    * ``distance_map.pt`` -- float32, shape ``(1, *seg_shape)``.
    * ``multi_size_mask_{i}.pt`` -- uint8, shape ``(1, *scaled_shape)``, one
      file per mask returned by ``generate_multi_size_masks``.

    Args:
        input_dir: ``pathlib.Path`` to the case folder holding
            ``{case_id}_flair.nii.gz``, ``_t1``, ``_t2`` and ``_seg`` files.
        output_dir: ``pathlib.Path`` to the output root. A ``case_id``
            sub-folder is created if missing.
        case_id: Case identifier used in the input file names and the output
            folder name.
        device: Device the tensors are created on before saving.
            ``torch.save`` records the device in the file, so ``"cpu"``
            produces files that load on machines without a GPU. The default
            ``"cuda"`` matches the previous hard-coded ``.cuda()`` behaviour.

    Raises:
        Whatever ``load_nifti`` raises for a missing or unreadable file, and
        ``ValueError`` from ``np.stack`` if the modalities differ in shape.
    """
    modalities = [
        normalize_intensity(load_nifti(input_dir / f"{case_id}_{name}.nii.gz"))
        for name in ("flair", "t1", "t2")
    ]
    seg = load_nifti(input_dir / f"{case_id}_seg.nii.gz").astype(np.uint8)

    # Stack input modalities into one (C, *volume_shape) tensor.
    input_tensor = torch.tensor(np.stack(modalities, axis=0),
                                dtype=torch.float32, device=device)

    # Signed distance map derived from the segmentation, with a channel axis.
    distance_tensor = torch.tensor(compute_distance_label(seg),
                                   dtype=torch.float32, device=device).unsqueeze(0)

    # Resampled label masks, each with a leading channel axis.
    multi_size_tensors = [
        torch.tensor(m, dtype=torch.uint8, device=device).unsqueeze(0)
        for m in generate_multi_size_masks(seg)
    ]

    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")
    torch.save(distance_tensor, output_case_dir / "distance_map.pt")
    for i, mask in enumerate(multi_size_tensors):
        torch.save(mask, output_case_dir / f"multi_size_mask_{i}.pt")

In [4]:
def run_preprocessing(root_path, output_path):
    """Run ``preprocess_case`` on every case sub-folder of ``root_path``.

    Each immediate sub-directory is treated as one case, and its folder name
    is used as the case id. A failure in one case is reported and the run
    continues with the next case.

    Args:
        root_path: Folder containing one sub-folder per case (str or Path).
        output_path: Root folder for the generated tensors (str or Path).

    Returns:
        List of case ids that raised an exception (empty if all succeeded),
        so callers can detect failures without scanning the printed log.
    """
    input_path = Path(root_path)
    output_path = Path(output_path)

    # Path.iterdir yields entries in arbitrary order; sort so processing and
    # reporting are reproducible across runs and machines.
    all_case_dirs = sorted(d for d in input_path.iterdir() if d.is_dir())
    print(f"Found {len(all_case_dirs)} cases.")

    failed = []
    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name  # e.g., MSLS_000
        try:
            preprocess_case(case_dir, output_path, case_id)
        except Exception as e:  # batch boundary: report and keep going
            print(f"❌ Failed on {case_id}: {e}")
            failed.append(case_id)

    if failed:
        print(f"{len(failed)}/{len(all_case_dirs)} cases failed: {', '.join(failed)}")
    return failed

In [5]:
# Training split: every case ships a ground-truth mask, so the full pipeline
# (input tensor + distance map + multi-size masks) is run.
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/train"
OUTPUT_PATH = "../data/02-Tensor-Data/train"
run_preprocessing(root_path=RAW_DATA_PATH, output_path=OUTPUT_PATH)

Found 93 cases.


  0%|          | 0/93 [00:00<?, ?it/s]

In [8]:
# Test split WITH ground-truth masks: processed exactly like the training
# split (inputs, distance map and multi-size masks).
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test_MASK"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test_MASK"
run_preprocessing(root_path=RAW_DATA_PATH, output_path=OUTPUT_PATH)

Found 22 cases.


  0%|          | 0/22 [00:00<?, ?it/s]

In [11]:
# Variation for the test set WITHOUT the gt seg mask
def preprocess_case_inputs_only(input_dir, output_dir, case_id, *, device="cuda"):
    """Preprocess one unlabelled case: save only the stacked input tensor.

    The FLAIR, T1 and T2 volumes are z-scored independently with
    ``normalize_intensity``. They are stacked (channel order FLAIR, T1, T2) and
    saved as float32 to ``output_dir / case_id / "input_tensor.pt"``.

    Args:
        input_dir: ``pathlib.Path`` to the case folder holding
            ``{case_id}_flair.nii.gz``, ``_t1`` and ``_t2`` files.
        output_dir: ``pathlib.Path`` to the output root. A ``case_id``
            sub-folder is created if missing.
        case_id: Case identifier used in the input file names and the output
            folder name.
        device: Device the tensor is created on before saving.
            ``torch.save`` records the device in the file, so ``"cpu"``
            produces files that load on machines without a GPU. The default
            ``"cuda"`` matches the previous hard-coded ``.cuda()`` behaviour.

    Raises:
        FileNotFoundError: if any required modality file is missing; checked
            up front so the message names the missing modality.
    """
    # Required input modalities only (insertion order = channel order).
    required_files = {
        name: input_dir / f"{case_id}_{name}.nii.gz"
        for name in ("flair", "t1", "t2")
    }

    for key, path in required_files.items():
        if not path.exists():
            raise FileNotFoundError(f"Missing file for '{key}': {path}")

    # Load, normalize and stack modalities into a (C, *volume_shape) array.
    stacked = np.stack(
        [normalize_intensity(load_nifti(path)) for path in required_files.values()],
        axis=0,
    )
    input_tensor = torch.tensor(stacked, dtype=torch.float32, device=device)

    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")


In [12]:
# Variation for the test set WITHOUT the gt seg mask
def run_preprocessing_inputs_only(root_path, output_path):
    """Run ``preprocess_case_inputs_only`` on every case sub-folder of ``root_path``.

    Each immediate sub-directory is treated as one case, and its folder name
    is used as the case id. A failing case is reported and skipped, and the run
    continues.

    Args:
        root_path: Folder containing one sub-folder per case (str or Path).
        output_path: Root folder for the generated tensors (str or Path).

    Returns:
        List of case ids that were skipped because of an exception (empty if
        all succeeded).
    """
    input_path = Path(root_path)
    output_path = Path(output_path)

    # Path.iterdir yields entries in arbitrary order; sort for reproducibility.
    all_case_dirs = sorted(d for d in input_path.iterdir() if d.is_dir())
    print(f"Found {len(all_case_dirs)} cases.")

    skipped = []
    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name
        try:
            preprocess_case_inputs_only(case_dir, output_path, case_id)
        except Exception as e:  # batch boundary: report and keep going
            print(f"❌ Skipping {case_id}: {e}")
            skipped.append(case_id)

    if skipped:
        print(f"{len(skipped)}/{len(all_case_dirs)} cases skipped: {', '.join(skipped)}")
    return skipped


In [13]:
# Test split WITHOUT ground-truth masks: no segmentation is available, so
# only the stacked input tensor is produced for each case.
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test"
run_preprocessing_inputs_only(root_path=RAW_DATA_PATH, output_path=OUTPUT_PATH)

Found 22 cases.


  0%|          | 0/22 [00:00<?, ?it/s]