# Multiple Sclerosis Lesion Segmentation Dataset Overview

---



- Introduction to the challenge {evaluation metrics, comparrison to other datasets} (with link to website [https://iplab.dmi.unict.it/mfs/ms-les-seg/] and paper)
- Introduction to the Dataset (creation, pre-processing, contents and distribution data with visualizations)

Table of contents

---

## What have winning strategies been so far?

1) Ensamble methods
2) Different Registration Spaces
3) Selecting only the specific data from the image that can better find the segmentation mask
4) Using only ONE modality (FLAIR) and performing Data Augmentations on it + Deep CNN Architecture
5) nnU-Net for automatic configuration to the dataset (Interesting)
6) Usage of Mamba-based LightM-UNet
7) Focal Loss + DICE Loss

---

## Personal Goals

1) Parameter Efficiency
2) (XAI) Explainable Architecture
3) Segmentation Accuracy (Mean DICE >= 60)
4) [Optional] --> Single Model (No Ensambles)

---

## Personal Model Proposition

1) Use ALL modalities
2) Preprocess with (Distance Based Labelling, Multi-Sized Labelling ~ Could aid the MoE variant)
3) nnU-Net + MoE (+ Deep Supervision ?)

### Optional

- Bring the model visualization to Extended Reality Domain ?

# The Experiments

---

In [1]:
import os
import torch
import nibabel as nib
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.transform import resize
from pathlib import Path
from tqdm.notebook import tqdm

## Dataset Preprocessing

- Expand more on the performed preprocessing



In [2]:
def load_nifti(file_path):
    return nib.load(file_path).get_fdata()

def normalize_intensity(img):
    # Normalize each modality independently (z-score)
    img = img.astype(np.float32)
    img = (img - np.mean(img)) / (np.std(img) + 1e-5)
    return img

def compute_distance_label(mask):
    mask = mask.astype(bool)
    pos_dist = distance_transform_edt(mask)
    neg_dist = distance_transform_edt(~mask)
    dist_map = pos_dist - neg_dist  # Signed distance transform
    return dist_map

def generate_multi_size_masks(mask, scales=[1.0, 0.5, 0.25]):
    multi_res = []
    for scale in scales:
        if scale == 1.0:
            multi_res.append(mask)
        else:
            resized = resize(mask, output_shape=tuple(int(s * scale) for s in mask.shape),
                             order=1, preserve_range=True, anti_aliasing=True)
            multi_res.append(resized.astype(mask.dtype))
    return multi_res

In [3]:
def preprocess_case(input_dir, output_dir, case_id):
    flair = normalize_intensity(load_nifti(input_dir / f"{case_id}_flair.nii.gz"))
    t1 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t1.nii.gz"))
    t2 = normalize_intensity(load_nifti(input_dir / f"{case_id}_t2.nii.gz"))
    seg = load_nifti(input_dir / f"{case_id}_seg.nii.gz").astype(np.uint8)

    # Stack input modalities into a tensor (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    input_tensor = torch.tensor(stacked, dtype=torch.float32).cuda()

    # Compute distance map from segmentation mask
    distance_map = compute_distance_label(seg)
    distance_tensor = torch.tensor(distance_map, dtype=torch.float32).unsqueeze(0).cuda()

    # Generate multi-size label masks
    multi_size_masks = generate_multi_size_masks(seg)
    multi_size_tensors = [torch.tensor(m, dtype=torch.uint8).unsqueeze(0).cuda() for m in multi_size_masks]

    # Save all
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")
    torch.save(distance_tensor, output_case_dir / "distance_map.pt")
    for i, mask in enumerate(multi_size_tensors):
        torch.save(mask, output_case_dir / f"multi_size_mask_{i}.pt")

In [4]:
def run_preprocessing(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name  # e.g., MSLS_000
        try:
            preprocess_case(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Failed on {case_id}: {e}")

In [5]:
# Run the preprocessing on the training set
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/train"
OUTPUT_PATH = "../data/02-Tensor-Data/train"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 93 cases.


  0%|          | 0/93 [00:00<?, ?it/s]

In [8]:
# Run the preprocessing on the test set (with gt mask)
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test_MASK"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test_MASK"

run_preprocessing(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


  0%|          | 0/22 [00:00<?, ?it/s]

In [11]:
# Variation for the test set WITHOUT the gt seg mask
def preprocess_case_inputs_only(input_dir, output_dir, case_id):
    # Required input modalities only
    required_files = {
        "flair": input_dir / f"{case_id}_flair.nii.gz",
        "t1": input_dir / f"{case_id}_t1.nii.gz",
        "t2": input_dir / f"{case_id}_t2.nii.gz",
    }

    for key, path in required_files.items():
        if not path.exists():
            raise FileNotFoundError(f"Missing file for '{key}': {path}")

    # Load and normalize each modality
    flair = normalize_intensity(load_nifti(required_files["flair"]))
    t1 = normalize_intensity(load_nifti(required_files["t1"]))
    t2 = normalize_intensity(load_nifti(required_files["t2"]))

    # Stack modalities (C, D, H, W)
    stacked = np.stack([flair, t1, t2], axis=0)
    input_tensor = torch.tensor(stacked, dtype=torch.float32).cuda()

    # Save preprocessed tensor
    output_case_dir = output_dir / case_id
    output_case_dir.mkdir(parents=True, exist_ok=True)
    torch.save(input_tensor, output_case_dir / "input_tensor.pt")


In [12]:
# Variation for the test set WITHOUT the gt seg mask
def run_preprocessing_inputs_only(root_path, output_path):
    input_path = Path(root_path)
    output_path = Path(output_path)

    all_case_dirs = [d for d in input_path.iterdir() if d.is_dir()]
    print(f"Found {len(all_case_dirs)} cases.")

    for case_dir in tqdm(all_case_dirs):
        case_id = case_dir.name
        try:
            preprocess_case_inputs_only(case_dir, output_path, case_id)
        except Exception as e:
            print(f"❌ Skipping {case_id}: {e}")


In [13]:
# Run the preprocessing on the test set (no gt mask)
RAW_DATA_PATH = "../data/01-Pre-Processed-Data/test/test"
OUTPUT_PATH = "../data/02-Tensor-Data/test/test"

run_preprocessing_inputs_only(RAW_DATA_PATH, OUTPUT_PATH)

Found 22 cases.


  0%|          | 0/22 [00:00<?, ?it/s]

## Build the Dataset and Dataloaders for the MSLesSeg preprocessed data

In [119]:
from torch.utils.data import Dataset
from pathlib import Path
import torch
import torch.nn.functional as F

class MSLesionDataset(Dataset):
    def __init__(self, root_dir, include_labels=True, sample_ids=None, transform=None):
        """
        Args:
            root_dir: The root directory containing the sample directories.
            include_labels: Whether to include the segmentation masks and distance maps.
            sample_ids: List of sample directories to use (if performing k-fold).
            transform: Optional transformation to apply to the input tensor.
        """
        self.root_dir = Path(root_dir)
        all_samples = sorted([d for d in self.root_dir.iterdir() if d.is_dir()])

        if sample_ids:
            self.sample_dirs = [d for d in all_samples if d.name in sample_ids]
        else:
            self.sample_dirs = all_samples

        self.include_labels = include_labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_dirs)

        def __getitem__(self, idx):
        case_dir = self.sample_dirs[idx]

        # Load the input tensor
        input_tensor = torch.load(case_dir / "input_tensor.pt").float()

        # Get the size of the input tensor (height, width)
        target_size = input_tensor.shape[1:]  # H, W

        if self.transform:
            input_tensor = self.transform(input_tensor)

        # If no labels are required (inference or test without labels)
        if not self.include_labels:
            return input_tensor, str(case_dir.name)

        # Load labels (distance map and multi-size masks)
        distance_map = torch.load(case_dir / "distance_map.pt").float()
        multi_size_masks = []

        for i in range(3):  # Hardcoded for 3 multi-size masks
            path = case_dir / f"multi_size_mask_{i}.pt"
            if path.exists():
                mask = torch.load(path).float()
            else:
                mask = torch.zeros_like(distance_map)  # Fallback

            # Get the current size of the mask
            current_size = mask.shape
            print(f"Mask {i} shape: {current_size}")

            # If the mask is not the same size as the input tensor, pad it
            if current_size != target_size:
                print(f"Padding mask from {current_size} to {target_size}")
                pad_height = target_size[0] - current_size[0]
                pad_width = target_size[1] - current_size[1]

                # Apply padding (this pads with zeros)
                padding = (0, pad_width, 0, pad_height)  # For width and height
                mask_padded = F.pad(mask, padding, mode='constant', value=0)
            else:
                mask_padded = mask

            print(f"Padded mask {i} shape: {mask_padded.shape}")
            multi_size_masks.append(mask_padded)

        # Concatenate the padded masks along the 0th dimension
        try:
            return {
                "input": input_tensor,
                "distance": distance_map,
                "multi_masks": torch.cat(multi_size_masks, dim=0),
            }
        except Exception as e:
            print(f"Error during concatenation: {e}")
            for i, mask in enumerate(multi_size_masks):
                print(f"Shape of mask {i}: {mask.shape}")
            raise e


IndentationError: expected an indented block after function definition on line 29 (3461500416.py, line 30)

In [120]:
from torch.utils.data import DataLoader

def get_dataloader(data_root, batch_size=2, num_workers=4, shuffle=True, include_labels=True, sample_ids=None):
    """
    Get the DataLoader for the given dataset directory.

    Args:
        data_root: Path to the dataset root directory.
        batch_size: Number of samples per batch.
        num_workers: Number of workers for loading data in parallel.
        shuffle: Whether to shuffle the data.
        include_labels: Whether to include segmentation masks and distance maps.
        sample_ids: A list of sample ids for a specific split (for k-fold).
    """
    dataset = MSLesionDataset(data_root, include_labels=include_labels, sample_ids=sample_ids)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False
    )

In [121]:
# Definition of some (non training) parameters
NUM_WORKERS_TRAIN = 0
NUM_WORKERS_VAL = 0
NUM_WORKERS_TEST = 0

BATCH_SIZE_TRAIN = 6
BATCH_SIZE_VAL = 3
BATCH_SIZE_TEST = 4

In [122]:
import os
import random

TRAIN_TENSOR_DATA_PATH = "../data/02-Tensor-Data/train"

# List of all available case IDs in the train directory
all_cases = sorted(os.listdir(TRAIN_TENSOR_DATA_PATH))
random.seed(42)

# For k-fold cross-validation, split train into train and validation (80-20)
split = int(0.8 * len(all_cases))
train_ids = all_cases[:split]
val_ids = all_cases[split:]

train_loader = get_dataloader(TRAIN_TENSOR_DATA_PATH, sample_ids=train_ids, batch_size=BATCH_SIZE_TRAIN, num_workers=NUM_WORKERS_TRAIN)
val_loader = get_dataloader(TRAIN_TENSOR_DATA_PATH, sample_ids=val_ids, batch_size=BATCH_SIZE_VAL, num_workers=NUM_WORKERS_VAL, shuffle=False)

In [123]:
TEST_TENSOR_DATA_PATH_NOMASK = "../data/02-Tensor-Data/test/test"
TEST_TENSOR_DATA_PATH_MASK = "../data/02-Tensor-Data/test/test_MASK"

# Test set WITHOUT labels (for inference or prediction)
test_loader_no_labels = get_dataloader(TEST_TENSOR_DATA_PATH_NOMASK, include_labels=False, batch_size=BATCH_SIZE_TEST, num_workers=NUM_WORKERS_TEST)

# Test set WITH labels (for final evaluation or performance metrics)
test_loader_with_labels = get_dataloader(TEST_TENSOR_DATA_PATH_MASK, include_labels=True, batch_size=BATCH_SIZE_TEST, num_workers=NUM_WORKERS_TEST)

In [124]:
# In order to perform K-Fold Cross Validation
from sklearn.model_selection import KFold

def get_k_fold_loaders(data_root, k=5, batch_size=BATCH_SIZE_TRAIN, num_workers=NUM_WORKERS_TRAIN):
    # Get all case IDs
    all_cases = sorted(os.listdir(data_root))
    
    kf = KFold(n_splits=k, shuffle=True, random_state=42)
    
    fold_loaders = []

    for train_idx, val_idx in kf.split(all_cases):
        train_ids = [all_cases[i] for i in train_idx]
        val_ids = [all_cases[i] for i in val_idx]

        train_loader = get_dataloader(data_root, sample_ids=train_ids, batch_size=batch_size, num_workers=num_workers)
        val_loader = get_dataloader(data_root, sample_ids=val_ids, batch_size=batch_size, num_workers=num_workers, shuffle=False)

        fold_loaders.append((train_loader, val_loader))

    return fold_loaders

In [125]:
fold_loaders = get_k_fold_loaders(TRAIN_TENSOR_DATA_PATH, k=5, batch_size=2, num_workers=4)

# Example: Access the first fold
# train_loader, val_loader = fold_loaders[0]

### PyTorch Lightning DataModule

This particular version of PyTorch Lightning, and in general from version 2.x onward require a ***LightningDataModule*** instead of passing the dataloaders directly to the ***.fit()*** method

In [126]:
from pytorch_lightning import LightningDataModule

class MSLesionSegmentationDataModule(LightningDataModule):
    def __init__(self, train_dataloader, val_dataloader):
        super().__init__()
        self.train_dataloader_instance = train_dataloader
        self.val_dataloader_instance = val_dataloader

    def train_dataloader(self):
        return self.train_dataloader_instance

    def val_dataloader(self):
        return self.val_dataloader_instance

    def test_dataloader(self):
        return self.val_dataloader_instance  # Optional, if you want to test on the validation set as well.

## The Model's Architecture

In [127]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    """
    Mixture of Experts with dynamic input size support via GAP.
    """
    def __init__(self, input_dim, num_experts=4, expert_dim=512):
        super(MoE, self).__init__()
        self.num_experts = num_experts

        self.gate = nn.Linear(input_dim, num_experts)
        self.experts = nn.ModuleList([nn.Linear(input_dim, expert_dim) for _ in range(num_experts)])

    def forward(self, x):
        # x is expected to be (B, C), already pooled
        gate_weights = F.softmax(self.gate(x), dim=-1)
        expert_outputs = [expert(x) for expert in self.experts]
        output = sum(gate_weights[:, i].unsqueeze(-1) * expert_outputs[i] for i in range(self.num_experts))
        return output

class UNetMoE(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, num_experts=4, expert_dim=512):
        super(UNetMoE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(2)
        )

        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            nn.Conv3d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, out_channels, kernel_size=1)
        )

        # Will be set dynamically after seeing one input
        self.moe = None
        self.num_experts = num_experts
        self.expert_dim = expert_dim

        self.deep_supervision = nn.Conv3d(128, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder[0:2](x)
        enc2 = self.encoder[2:](enc1)
        mid = self.middle(enc2)
        dec = self.decoder[0:2](mid)
        dec_out = self.decoder[2:](dec)

        # Global pooling before MoE
        pooled = F.adaptive_avg_pool3d(dec, 1).view(x.size(0), -1)

        # Init MoE if not yet initialized
        if self.moe is None:
            self.moe = MoE(input_dim=pooled.size(1), num_experts=self.num_experts, expert_dim=self.expert_dim).to(x.device)

        moe_out = self.moe(pooled)
        deep_supervision_out = self.deep_supervision(mid)

        return dec_out, deep_supervision_out, moe_out


In [None]:
# Example usage for testing (TO REMOVE)
model = UNetMoE(in_channels=3, out_channels=1)
dummy_input = torch.randn(2, 3, 128, 128, 128)  # e.g., batch size 2, 3 modalities
dec_out, deep_out, moe_out = model(dummy_input)
print(dec_out.shape, deep_out.shape, moe_out.shape)

In [None]:
# The Summary of the Architecture
from torchsummary import summary

# Set the device and the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetMoE(in_channels=3, out_channels=1, num_experts=4, expert_dim=512).to(device)

# Get the summary of the model (use an example input size, e.g., 128x128x128 for 3D)
summary(model, (3, 128, 128, 128))  # (channels, depth, height, width)

In [None]:
# For the graphical visualization
from torchviz import make_dot

# Dummy input
dummy_input = torch.randn(1, 3, 128, 128, 128).cuda()  # Adjust for batch size and input shape

# Forward pass through the model
output, _, _ = model(dummy_input)

# Visualize the graph
make_dot(output, params=dict(model.named_parameters())).render("model_architecture", format="png")

## Visualized Model Architecture

![Model Architecture Plot](./model_architecture.png)

## The Trainers

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torchmetrics import MetricCollection
from torchmetrics.segmentation import DiceScore

class MSLeasionSegmentationModel(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super(MSLeasionSegmentationModel, self).__init__()
        
        # Initialize the model (nnU-Net + MoE)
        self.model = model
        
        # Learning rate for optimizer
        self.lr = lr
        
        # Set up Dice Score metric for evaluation
        self.dice_metric = DiceScore(num_classes=2)  # Binary segmentation: lesion or background
        
    def forward(self, x):
        # Forward pass (output, deep supervision, MoE output)
        output, deep_supervision, moe_output = self.model(x)
        return output, deep_supervision, moe_output
    
    def training_step(self, batch, batch_idx):
        # Extract input data and target segmentation masks from batch
        x = batch['input_tensor']  # Shape: (B, 3, D, H, W)
        y = batch['segmentation_mask']  # Shape: (B, 1, D, H, W)
        
        # Get model output
        output, _, _ = self(x)
        
        # Compute loss (using binary cross entropy for binary segmentation)
        loss = F.binary_cross_entropy_with_logits(output, y)
        
        # Log the loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        # Extract input data and target segmentation masks from batch
        x = batch['input_tensor']  # Shape: (B, 3, D, H, W)
        y = batch['segmentation_mask']  # Shape: (B, 1, D, H, W)
        
        # Get model output
        output, _, _ = self(x)
        
        # Calculate the Dice score
        dice_score = self.dice_metric(output.sigmoid(), y)
        
        # Log validation metrics
        self.log('val_dice', dice_score, on_step=True, on_epoch=True, prog_bar=True)
        
        return dice_score
    
    def test_step(self, batch, batch_idx):
        # Extract input data and target segmentation masks from batch
        x = batch['input_tensor']  # Shape: (B, 3, D, H, W)
        y = batch['segmentation_mask']  # Shape: (B, 1, D, H, W)
        
        # Get model output
        output, _, _ = self(x)
        
        # Calculate the Dice score
        dice_score = self.dice_metric(output.sigmoid(), y)
        
        # Log test metrics
        self.log('test_dice', dice_score, on_step=True, on_epoch=True, prog_bar=True)
        
        return dice_score
    
    def configure_optimizers(self):
        # Set up Adam optimizer with the specified learning rate
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        
        # Optionally, add a scheduler
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_dice'}

## The Training

In [None]:
# Register a free account with Weights and Biases, and create a new project in order to obtain an API Key for the training
import os
from dotenv import load_dotenv
import wandb

# Load the .env file
load_dotenv()

# Get the API key from the env variable
api_key = os.getenv("WANDB_API_KEY")

# Login to wandb
if api_key:
    os.environ["WANDB_API_KEY"] = api_key
    wandb.login()
else:
    print("❌ WANDB_API_KEY not found in .env file.")

In [None]:
# Setup the Weights and Biases logger
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    project='MSLesSeg-4-ICPR',     # Change to your actual project name
    name='nnUNet_MoE_run1', # A specific run name
    log_model=True          # Optional: log model checkpoints
)

In [None]:
# Set the multiprocessing start to 'spawn' instead of 'fork' due to CUDA issues with the Dataloader
import multiprocessing
import torch

# Set the start method for multiprocessing
multiprocessing.set_start_method('spawn', force=True)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import os

# Initialize the model
model = MSLeasionSegmentationModel(model=UNetMoE(in_channels=3, out_channels=1, num_experts=4, expert_dim=512))

# Set up callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_dice',  # Monitor the validation Dice score
    dirpath='checkpoints/',
    filename='best_checkpoint',
    save_top_k=1,
    mode='max',
)

early_stop_callback = EarlyStopping(
    monitor='val_dice',  # Early stopping based on the validation Dice score
    patience=10,
    mode='max',
)

# Intialize the Trainer
trainer = Trainer(
    accelerator="gpu",        # or "cpu" or "auto"
    devices=1,                # Number of GPUs to use
    max_epochs=100,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=wandb_logger,
)

# Instantiate the PyTorch LighningDataModule to run the training
data_module = MSLesionSegmentationDataModule(train_loader, val_loader)

# Train the model
trainer.fit(model, datamodule=data_module)

In [None]:
# Evaluate the model on the test set
trainer.test(model, test_dataloader=test_loader)

## Visualizing the Learned Representations

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Function to extract features from the model
def extract_features(model, dataloader, device='cuda'):
    model.eval()
    features = []
    labels = []
    
    with torch.no_grad():
        for data in dataloader:
            inputs, target = data['image'].to(device), data['mask'].to(device)  # Replace 'image' and 'mask' keys as per your dataset
            
            # Forward pass through the model
            output = model(inputs)  # This is where we extract the feature; adjust according to your model's architecture
            
            # Extract features from MoE (or any other layer you want to visualize)
            # Here I assume output is the final feature map after the MoE layer
            feature_map = output.view(output.size(0), -1)  # Flatten the features to (batch_size, features)
            features.append(feature_map.cpu().numpy())
            labels.append(target.cpu().numpy())  # Add the target (or ground truth) labels for color coding in t-SNE

    features = np.concatenate(features, axis=0)  # Combine all the feature maps
    labels = np.concatenate(labels, axis=0)  # Combine all the labels

    return features, labels

# Function to apply t-SNE on features
def plot_tsne(features, labels):
    # Standardize the features (optional but recommended for t-SNE)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(features)
    
    # Plotting
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap='jet', alpha=0.5)
    plt.colorbar(scatter)
    plt.title('t-SNE Visualization of Learned Representations')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.show()

# Example Usage
train_dataloader = # Your train DataLoader here

# Extract features and labels from the model
features, labels = extract_features(model, train_dataloader, device='cuda')

# Plot t-SNE visualization
plot_tsne(features, labels)

## Post Hoc Model Explainability - GradCAM++

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from captum.attr import LayerGradCam, LayerGradCamPlusPlus
from captum.attr import GuidedBackprop

# Function to visualize GradCAM++ output
def visualize_gradcam_plus_plus(model, input_tensor, target_class=None, layer_name='decoder', device='cuda'):
    # Move input tensor to the appropriate device
    input_tensor = input_tensor.to(device)

    # Set model to evaluation mode
    model.eval()
    
    # Create GradCAM++ object for the specified layer
    gradcam_pp = LayerGradCamPlusPlus(model, model.decoder[2])  # Assuming 'decoder' is your final layer, adjust accordingly.
    
    # Apply GradCAM++ to input tensor
    attributions = gradcam_pp.attribute(input_tensor, target=target_class)
    
    # Convert attributions to numpy
    attributions = attributions.cpu().detach().numpy()
    
    # Normalize the heatmap
    heatmap = np.sum(attributions[0], axis=0)  # Summing over channels for a single-channel heatmap
    heatmap = np.maximum(heatmap, 0)  # ReLU to ignore negative values
    heatmap = cv2.resize(heatmap, (input_tensor.shape[2], input_tensor.shape[3]))  # Resize to input size
    heatmap = cv2.normalize(heatmap, None, 0, 1, cv2.NORM_MINMAX)  # Normalize the heatmap to [0, 1]
    
    # Convert the original image to numpy and rescale
    input_image = input_tensor[0].cpu().detach().numpy().transpose(1, 2, 0)
    input_image = cv2.resize(input_image, (input_tensor.shape[2], input_tensor.shape[3]))  # Resize to input size
    
    # Create a colormap for the heatmap
    colormap = plt.get_cmap('jet')
    colored_heatmap = colormap(heatmap)  # Apply colormap
    
    # Overlay the heatmap on top of the original image
    superimposed_img = np.uint8(input_image * 255)  # Convert original image back to [0, 255] range
    superimposed_img = cv2.addWeighted(superimposed_img, 0.7, np.uint8(colored_heatmap[:, :, :3] * 255), 0.3, 0)
    
    # Plot the result
    plt.imshow(superimposed_img)
    plt.title('GradCAM++ Heatmap')
    plt.axis('off')
    plt.show()

# Example Usage (assuming `model` is your trained model and `input_tensor` is an input image)
input_tensor = torch.randn(1, 3, 128, 128).to(device)  # Example input, use your actual input tensor
target_class = None  # You can provide a target class or leave as None for model's top predicted class

visualize_gradcam_plus_plus(model, input_tensor, target_class=target_class, layer_name='decoder', device='cuda')