# MS Lesion Segmentation UNET Inference

Run inference with an ensemble of models and, for each scan, save:
* a 3D NIfTI image of the predicted probability map averaged across the ensemble models (saved to `*pred_prob.nii.gz` files),
* a binary segmentation map obtained by thresholding the averaged prediction and removing all connected components of 9 voxels or fewer (saved to `*pred_seg.nii.gz` files),
* an uncertainty map of the voxel-wise reverse mutual information measure (saved to `*uncs_rmi.nii.gz` files).

## Install libraries 

In [1]:
!pip install monai==0.9.0

Collecting monai==0.9.0
  Downloading monai-0.9.0-202206131636-py3-none-any.whl.metadata (7.3 kB)
Downloading monai-0.9.0-202206131636-py3-none-any.whl (939 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m939.7/939.7 kB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-0.9.0


## Libraries import

In [2]:
import os
import re
import torch
import numpy as np
from glob import glob
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet
from monai.data import write_nifti
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    AddChanneld, Compose, LoadImaged, RandCropByPosNegLabeld,
    Spacingd, ToTensord, NormalizeIntensityd, RandFlipd,
    RandRotate90d, RandShiftIntensityd, RandAffined, RandSpatialCropd,
    RandScaleIntensityd)
from scipy import ndimage
#from data_load import remove_connected_components, get_flair_dataloader
#from uncertainty import ensemble_uncertainties_classification

2024-04-26 15:41:35.252519: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-26 15:41:35.252624: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-26 15:41:35.384221: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Setup functions

In [3]:
def get_default_device():
    """Return the torch device to run on: CUDA when available, otherwise CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    # Announce GPU usage so it is visible in the notebook output.
    print("Got CUDA!")
    return torch.device('cuda')

### Data load

In [4]:
def get_val_transforms(keys=["image", "label"], image_keys=["image"]):
    """ Build the deterministic preprocessing pipeline used for validation/inference.

    The transforms are applied in this order:
    1. `LoadImaged`: load the files referenced by `keys`.
    2. `AddChanneld`: add a channel dimension to every entry in `keys`.
    3. `NormalizeIntensityd`: intensity normalisation of the entries in
       `image_keys` only (called with `nonzero=True`).
    4. `ToTensord`: convert every entry in `keys` to a torch tensor.

    Args:
      keys: dictionary keys of all items to load (images and masks).
      image_keys: subset of `keys` holding scans that must be normalised.
    Returns:
      A `Compose` object chaining the four transforms.
    """
    pipeline = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        NormalizeIntensityd(keys=image_keys, nonzero=True),
        ToTensord(keys=keys),
    ]
    return Compose(pipeline)

In [5]:
def get_flair_dataloader(flair_path, num_workers, cache_rate=0.1, bm_path=None):
    """
    Get dataloader with FLAIR images (and optionally brain masks) for inference.

    Files are matched with the glob patterns "*FLAIR_isovox.nii" and
    "*isovox_fg_mask.nii" and ordered by the integer formed from all digits
    found in their path. Images and masks are paired by position after this
    sort, so both directories must follow the same numbering.

    Args:
      flair_path: `str`, path to directory with FLAIR images.
      num_workers:  `int`,  number of workers passed to both the cached
                    dataset and the dataloader.
      cache_rate:  `float` in (0.0, 1.0], percentage of cached data in total.
      bm_path:   `None|str`. If `str`, then defines path to directory with
                 brain masks. If `None`, dataloader does not return brain masks.
    Returns:
      monai.data.DataLoader() class object (batch size 1, no shuffling).
    Raises:
      AssertionError: if the number of brain masks differs from the number
                      of FLAIR images.
      ValueError: if a matched path contains no digits at all.
    """
    def numeric_key(path):
        # Raw string on purpose: '\D' in a plain literal is an invalid escape
        # sequence and triggers a warning on recent Python versions.
        return int(re.sub(r'\D', '', path))

    # Collect all flair images sorted
    flair = sorted(glob(os.path.join(flair_path, "*FLAIR_isovox.nii")),
                   key=numeric_key)

    if bm_path is not None:
        # Collect all corresponding brain masks
        bms = sorted(glob(os.path.join(bm_path, "*isovox_fg_mask.nii")),
                     key=numeric_key)

        assert len(flair) == len(bms), f"Some files must be missing: {[len(flair), len(bms)]}"

        files = [{"image": fl, "brain_mask": bm} for fl, bm in zip(flair, bms)]

        val_transforms = get_val_transforms(keys=["image", "brain_mask"])
    else:
        files = [{"image": fl} for fl in flair]

        val_transforms = get_val_transforms(keys=["image"])

    print("Number of FLAIR files:", len(files))

    ds = CacheDataset(data=files, transform=val_transforms,
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=1, shuffle=False,
                      num_workers=num_workers)

In [6]:
def remove_connected_components(segmentation, l_min=9):
    """
    Remove all lesions with less or equal amount of voxels than `l_min` from a
    binary segmentation mask `segmentation`.

    Connected components are found with `scipy.ndimage.label` using its
    default structuring element. The size of a component is its voxel count,
    so the input is expected to be binary (non-zero means lesion).

    Args:
      segmentation: `numpy.ndarray` with a binary lesions segmentation mask,
                    typically of shape [H, W, D]; any dimensionality works.
      l_min:  `int`, components with `l_min` voxels or fewer are removed.
    Returns:
      Binary lesion segmentation mask (`numpy.ndarray` with the same shape and
      dtype as the input) only with connected components that have more than
      `l_min` voxels.
    """
    segmentation = np.asarray(segmentation)
    labeled_seg, num_labels = ndimage.label(segmentation)

    # sizes[k] is the number of voxels carrying label k; index 0 is background.
    # Indexing by the label value itself (rather than by position in the list
    # of unique labels) keeps sizes and labels aligned even when the mask has
    # no background voxels at all.
    sizes = np.bincount(labeled_seg.ravel(), minlength=num_labels + 1)

    keep = sizes > l_min
    keep[0] = False  # background is never part of the output mask

    # Look up the keep flag of every voxel's label in one vectorised pass.
    return keep[labeled_seg].astype(segmentation.dtype)

### Uncertainty

In [7]:
def renyi_entropy_of_expected(probs, alpha=0.8):
    """
    Renyi entropy of the ensemble-averaged distribution.

    Renyi entropy is a generalised version of Shannon - the two are equivalent
    for alpha=1. The closed-form expression divides by (1 - alpha), so
    alpha=1 is handled explicitly by returning the Shannon entropy.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param alpha: order of the Renyi entropy.
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    mean_probs = np.mean(probs, axis=0)
    if alpha == 1:
        # Limit alpha -> 1: Shannon entropy. The small constant guards log(0),
        # matching the default epsilon of the Shannon helpers in this file.
        return -np.sum(mean_probs * np.log(mean_probs + 1e-10), axis=-1)
    scale = 1. / (1. - alpha)
    return scale * np.log(np.sum(mean_probs**alpha, axis=-1))

def renyi_expected_entropy(probs, alpha=0.8):
    """
    Renyi entropy of each ensemble member's distribution, averaged over members.

    The closed-form expression divides by (1 - alpha), so alpha=1 is handled
    explicitly by returning the averaged Shannon entropy (the alpha -> 1 limit).

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param alpha: order of the Renyi entropy.
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    if alpha == 1:
        # The small constant guards log(0), matching the default epsilon of
        # the Shannon helpers in this file.
        return np.mean(-np.sum(probs * np.log(probs + 1e-10), axis=-1), axis=0)
    scale = 1. / (1. - alpha)
    return np.mean(scale * np.log(np.sum(probs**alpha, axis=-1)), axis=0)


def entropy_of_expected(probs, epsilon=1e-10):
    """
    Shannon entropy of the distribution obtained by averaging over models.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithm to avoid log(0).
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    avg = np.mean(probs, axis=0)
    # Per-class contribution -p * log(p + eps), summed over the class axis.
    return np.sum(avg * -np.log(avg + epsilon), axis=-1)

def expected_entropy(probs, epsilon=1e-10):
    """
    Shannon entropy of each model's distribution, averaged over models.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithm to avoid log(0).
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    # Entropy per model (sum over classes), then the mean over the model axis.
    per_model_entropy = np.sum(probs * -np.log(probs + epsilon), axis=-1)
    return np.mean(per_model_entropy, axis=0)


def ensemble_uncertainties_classification(probs, epsilon=1e-10):
    """
    Compute several voxel-wise uncertainty measures from ensemble predictions.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside logarithms to avoid log(0).
    :return: Dictionary of uncertainties with keys 'confidence' (negated
             maximum mean probability), 'entropy_of_expected',
             'expected_entropy', 'mutual_information', 'epkl' and
             'reverse_mutual_information'; each value has the class and model
             axes reduced away.
    """
    eoe = entropy_of_expected(probs, epsilon)
    exe = expected_entropy(probs, epsilon)

    avg_probs = np.mean(probs, axis=0)
    avg_log_probs = np.mean(np.log(probs + epsilon), axis=0)

    # Mutual information: entropy of the mean minus the mean entropy.
    mi = eoe - exe
    # epkl: cross term between mean probabilities and mean log-probabilities,
    # minus the mean entropy.
    epkl = -np.sum(avg_probs * avg_log_probs, axis=-1) - exe

    return {
        'confidence': -1 * np.max(avg_probs, axis=-1),
        'entropy_of_expected': eoe,
        'expected_entropy': exe,
        'mutual_information': mi,
        'epkl': epkl,
        'reverse_mutual_information': epkl - mi,
    }

### Inference function

In [11]:
def _load_ensemble_models(path_model, num_models, device):
    """Build `num_models` UNets on `device` and load their trained weights in eval mode."""
    models = []
    for i in range(num_models):
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            num_res_units=0).to(device)
        checkpoint = os.path.join(path_model,
                                  f"seed{i + 1}",
                                  "Best_model_finetuning.pth")
        # map_location=device lets the same call serve both CPU and CUDA
        # runs, so no separate branch per device type is needed.
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.eval()
        models.append(model)
    return models


def _prediction_path(path_pred, flair_filename, suffix):
    """Return the output path obtained by swapping the FLAIR suffix of `flair_filename` for `suffix`."""
    # Plain string replacement: the pattern is a literal file suffix, so its
    # dots must not be interpreted as regex wildcards.
    return os.path.join(path_pred,
                        flair_filename.replace("FLAIR_isovox.nii", suffix))


def inferenceUNET(path_pred, path_data, path_bm, threshold = 0.35, num_models = 3, path_model = '', num_workers = 1):
    """
    Run ensemble inference on FLAIR scans and save the results as Nifti files.

    For every scan found in `path_data`, three files are written to `path_pred`
    (the "FLAIR_isovox.nii" part of the input filename is replaced):
      - "pred_prob.nii.gz": lesion probability map averaged across the models,
      - "pred_seg.nii.gz": binary map of voxels whose averaged probability is
        >= `threshold`, after `remove_connected_components`,
      - "uncs_rmi.nii.gz": voxel-wise reverse mutual information, multiplied
        by the brain mask when one is available.

    Args:
      path_pred: `str`, output directory (created if missing).
      path_data: `str`, directory with FLAIR images.
      path_bm: `None|str`, directory with brain masks. If `None`, the
               uncertainty map is saved without masking.
      threshold: `float`, probability threshold for the binary segmentation.
      num_models: `int`, number of ensemble members. Weights of member `i`
                  (starting at 1) are read from
                  `path_model/seed{i}/Best_model_finetuning.pth`.
      path_model: `str`, directory containing the `seed{i}` sub-directories.
      num_workers: `int`, number of workers for data loading.
    """
    # Setting up output directory
    os.makedirs(path_pred, exist_ok=True)

    # Setting up device (queried once and reused everywhere below)
    device = get_default_device()
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Initialise dataloaders
    val_loader = get_flair_dataloader(flair_path=path_data,
                                      num_workers=num_workers,
                                      bm_path=path_bm)

    # Load trained models
    models = _load_ensemble_models(path_model, num_models, device)

    act = torch.nn.Softmax(dim=1)
    th = threshold
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    # Predictions loop
    with torch.no_grad():
        for batch_data in val_loader:
            inputs = batch_data["image"].to(device)
            # The dataloader only provides brain masks when path_bm is given.
            if "brain_mask" in batch_data:
                foreground_mask = batch_data["brain_mask"].numpy()[0, 0]
            else:
                foreground_mask = None

            # get ensemble predictions: one lesion-probability volume
            # (channel 1 of the softmax output) per model
            all_outputs = []
            for model in models:
                outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, model, mode='gaussian')
                outputs = act(outputs).cpu().numpy()
                outputs = np.squeeze(outputs[0, 1])
                all_outputs.append(outputs)
            all_outputs = np.asarray(all_outputs)

            # get image metadata
            original_affine = batch_data['image_meta_dict']['original_affine'][0]
            affine = batch_data['image_meta_dict']['affine'][0]
            spatial_shape = batch_data['image_meta_dict']['spatial_shape'][0]
            filename_or_obj = batch_data['image_meta_dict']['filename_or_obj'][0]
            filename_or_obj = os.path.basename(filename_or_obj)

            # obtain and save probability maps averaged across models in an ensemble
            outputs_mean = np.mean(all_outputs, axis=0)

            filepath = _prediction_path(path_pred, filename_or_obj, 'pred_prob.nii.gz')
            write_nifti(outputs_mean, filepath,
                        affine=original_affine,
                        target_affine=affine,
                        output_spatial_shape=spatial_shape)

            # obtain and save binary segmentation masks
            seg = (outputs_mean >= th).astype(outputs_mean.dtype)
            seg = np.squeeze(seg)
            seg = remove_connected_components(seg)

            filepath = _prediction_path(path_pred, filename_or_obj, 'pred_seg.nii.gz')
            write_nifti(seg, filepath,
                        affine=original_affine,
                        target_affine=affine,
                        mode='nearest',
                        output_spatial_shape=spatial_shape)

            # obtain and save uncertainty map (voxel-wise reverse mutual information);
            # the two-class distribution is rebuilt from the lesion probability p
            # as (p, 1 - p) along a new last axis
            uncs_map = ensemble_uncertainties_classification(np.concatenate(
                (np.expand_dims(all_outputs, axis=-1),
                 np.expand_dims(1. - all_outputs, axis=-1)),
                axis=-1))['reverse_mutual_information']
            if foreground_mask is not None:
                uncs_map = uncs_map * foreground_mask

            filepath = _prediction_path(path_pred, filename_or_obj, 'uncs_rmi.nii.gz')
            write_nifti(uncs_map, filepath,
                        affine=original_affine,
                        target_affine=affine,
                        output_spatial_shape=spatial_shape)

## Using the model

In [17]:
# Output directory for the predictions written by inferenceUNET.
path_pred = '/kaggle/working/predictions'
# Input locations: FLAIR scans, foreground masks and trained model weights.
path_data = "/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Test/FLAIR"
path_bm = "/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Test/FgMasks"
path_model = "/kaggle/input/sdcombinedextracted/baselinetrained"

# Note: num_models=1 runs a single model (seed1) rather than a multi-model ensemble.
inferenceUNET(
    path_pred,
    path_data,
    path_bm,
    threshold=0.35,
    num_models=1,
    path_model=path_model,
    num_workers=1,
)
print("All Done!")

Got CUDA!
Number of FLAIR files: 33


Loading dataset: 100%|██████████| 3/3 [00:00<00:00,  5.06it/s]


Got CUDA!
All Done!
