# **MS LESION SEGMENTATION TESTING SegResNet** 

Computes the performance metrics of a trained SegResNet model on the test set: nDSC, lesion F1 score, and nDSC R-AUC (reported here as the area above the nDSC retention curve).

## Install Libraries 

In [3]:
!pip install monai==0.9.0

Collecting monai==0.9.0
  Using cached monai-0.9.0-202206131636-py3-none-any.whl.metadata (7.3 kB)
Using cached monai-0.9.0-202206131636-py3-none-any.whl (939 kB)
Installing collected packages: monai
Successfully installed monai-0.9.0


## Libraries Import

In [4]:
import os
import json
import torch
import numpy as np
import re
from joblib import Parallel
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    AddChanneld, Compose, LoadImaged, RandCropByPosNegLabeld,
    Spacingd, ToTensord, NormalizeIntensityd, RandFlipd,
    RandRotate90d, RandShiftIntensityd, RandAffined, RandSpatialCropd,
    RandScaleIntensityd)
from glob import glob
from scipy import ndimage
from functools import partial
from collections import Counter
from joblib import Parallel, delayed
from sklearn import metrics

2024-05-09 15:17:35.343384: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-09 15:17:35.343549: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-09 15:17:35.504972: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Setup Functions

In [5]:
def get_default_device():
    """
    Pick the torch device to run on.

    Returns:
      `torch.device('cuda')` when CUDA is available (a short notice is printed
      in that case), otherwise `torch.device('cpu')`.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')

    print("Got CUDA!")
    return torch.device('cuda')

### Data Load

In [6]:
def remove_connected_components(segmentation, l_min=9):
    """
    Remove all lesions with less or equal amount of voxels than `l_min` from a
    binary segmentation mask `segmentation`.

    Lesions are the connected components found by `scipy.ndimage.label` with
    its default structuring element. The size of a lesion is its number of
    voxels.

    Args:
      segmentation: `numpy.ndarray` with a binary lesions segmentation mask,
                    typically of shape [H, W, D] (any dimensionality works).
      l_min:  `int`, minimal amount of voxels in a lesion.
    Returns:
      Binary lesion segmentation mask (`numpy.ndarray` with the same shape and
      dtype as `segmentation`) only with connected components that have more
      than `l_min` voxels.
    """
    segmentation = np.asarray(segmentation)
    labeled_seg, num_labels = ndimage.label(segmentation)

    # voxels_per_label[i] is the number of voxels carrying label i, so the
    # array index *is* the label id (index 0 is the background).
    voxels_per_label = np.bincount(labeled_seg.ravel(), minlength=num_labels + 1)

    keep_label = voxels_per_label > l_min
    keep_label[0] = False  # the background is never part of the output mask

    # Look the keep/drop decision up for every voxel at once.
    return keep_label[labeled_seg].astype(segmentation.dtype)

In [7]:
def get_val_transforms(keys=["image", "label"], image_keys=["image"]):
    """
    Build the preprocessing pipeline used for validation and testing.

    The composed steps, in order:
    - load every entry of `keys` from file (`LoadImaged`)
    - add a channel dimension to every entry of `keys` (`AddChanneld`)
    - normalise the intensities of the entries in `image_keys`, with
      `nonzero=True` (`NormalizeIntensityd`)
    - convert every entry of `keys` to `torch.Tensor` (`ToTensord`)

    Args:
      keys: names of all dictionary entries to load (scans and masks).
      image_keys: subset of `keys` holding the scans to be intensity-normalised.
    Returns:
      `monai.transforms.Compose` applying the steps above.
    """
    load_step = LoadImaged(keys=keys)
    channel_step = AddChanneld(keys=keys)
    normalise_step = NormalizeIntensityd(keys=image_keys, nonzero=True)
    tensor_step = ToTensord(keys=keys)

    return Compose([load_step, channel_step, normalise_step, tensor_step])

In [8]:
def get_val_dataloader(flair_path, gts_path, num_workers, cache_rate=0.1, bm_path=None):
    """
    Get dataloader for validation and testing. Either with or without brain masks.

    Files are collected per directory with a glob pattern and sorted by the
    integer formed from all digits of their path. Scans, ground truths and
    (optionally) brain masks are then paired by position in the sorted lists,
    so all directories must contain matching files. A matched path without any
    digit raises `ValueError` (from `int('')`).

    Args:
      flair_path: `str`, path to directory with FLAIR images
                    (files matching "*FLAIR_isovox.nii").
      gts_path:  `str`, path to directory with ground truth lesion segmentation
                    binary masks images (files matching "*_isovox.nii").
      num_workers:  `int`, number of workers passed to both `CacheDataset`
                    and `DataLoader`.
      cache_rate:  `float` in (0.0, 1.0], percentage of cached data in total.
      bm_path:   `None|str`. If `str`, then defines path to directory with
                 brain masks (files matching "*isovox_fg_mask.nii"). If `None`,
                 dataloader does not return brain masks.
    Returns:
      monai.data.DataLoader() class object (batch size 1, no shuffling).
    Raises:
      AssertionError: if the directories do not hold the same number of files.
    """
    def numeric_sorted(directory, pattern):
        # Raw string: '\D' in a normal string literal is an invalid escape sequence.
        return sorted(glob(os.path.join(directory, pattern)),
                      key=lambda path: int(re.sub(r'\D', '', path)))

    flair = numeric_sorted(flair_path, "*FLAIR_isovox.nii")  # all flair images, sorted
    segs = numeric_sorted(gts_path, "*_isovox.nii")  # corresponding ground truths

    if bm_path is not None:
        bms = numeric_sorted(bm_path, "*isovox_fg_mask.nii")  # corresponding brain masks

        assert len(flair) == len(segs) == len(bms), f"Some files must be missing: {[len(flair), len(segs), len(bms)]}"

        files = [
            {"image": fl, "label": seg, "brain_mask": bm} for fl, seg, bm
            in zip(flair, segs, bms)
        ]

        val_transforms = get_val_transforms(keys=["image", "label", "brain_mask"])
    else:
        assert len(flair) == len(segs), f"Some files must be missing: {[len(flair), len(segs)]}"

        files = [{"image": fl, "label": seg} for fl, seg in zip(flair, segs)]

        val_transforms = get_val_transforms()

    print("Number of validation files:", len(files))

    ds = CacheDataset(data=files, transform=val_transforms,
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=1, shuffle=False,
                      num_workers=num_workers)


### Metrics

In [9]:
def dice_norm_metric(ground_truth, predictions):
    """
    Compute Normalised Dice Coefficient (nDSC) for a single example.

    False positives are rescaled by a factor `k` chosen so that the score
    refers to a fixed reference lesion load `r = 0.001`:
    `nDSC = 2 * TP / (k * FP + 2 * TP + FN)`.

    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                     with shape [H, W, D].
      predictions:  `numpy.ndarray`, binary segmentation predictions,
                     with shape [H, W, D].
    Returns:
      Normalised dice coefficient (`float` in [0.0, 1.0]) between
      `ground_truth` and `predictions`. It is 1.0 when both masks are empty.
    """

    # Reference for normalized DSC
    r = 0.001
    # Cast to float32 type
    gt = ground_truth.astype("float32")
    seg = predictions.astype("float32")

    gt_sum = np.sum(gt)
    if np.sum(seg) + gt_sum == 0:
        # Nothing to find and nothing predicted: perfect agreement.
        return 1.0

    n_negative = gt.size - gt_sum
    if gt_sum == 0 or n_negative == 0:
        # No lesion voxels: no load to normalise by.
        # No background voxels: FP is necessarily 0, so `k` has no effect, but
        # the formula below would divide by zero and turn the score into NaN.
        k = 1.0
    else:
        k = (1 - r) * gt_sum / (r * n_negative)

    tp = np.sum(seg[gt == 1])
    fp = np.sum(seg[gt == 0])
    fn = np.sum(gt[seg == 0])
    fp_scaled = k * fp
    dsc_norm = 2. * tp / (fp_scaled + 2. * tp + fn)
    return dsc_norm

In [10]:
def ndsc_aac_metric(ground_truth, predictions, uncertainties, parallel_backend=None):
    """
    Compute area above Normalised Dice Coefficient (nDSC) retention curve for
    one subject. `ground_truth`, `predictions`, `uncertainties` - are flattened
    arrays of corresponding 3D maps within the foreground mask only.

    Voxels are ordered by increasing uncertainty. For each retention fraction
    the predictions of the retained (least uncertain) voxels are kept, the
    remaining ones are replaced by the ground truth, and the nDSC of that
    mixture is computed. The result is one minus the area under this curve.

    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                     with shape [H * W * D].
      predictions:  `numpy.ndarray`, binary segmentation predictions,
                     with shape [H * W * D].
      uncertainties:  `numpy.ndarray`, voxel-wise uncertainties,
                     with shape [H * W * D].
      parallel_backend: `joblib.Parallel`, for parallel computation
                     for different retention fractions. A single-job backend
                     is created when `None`.
    Returns:
      nDSC R-AAC (`float` in [0.0, 1.0]).
    """

    def compute_dice_norm(frac_, preds_, gts_, N_):
        pos = int(N_ * frac_)
        # Use the helper's own arguments only, so it does not depend on
        # variables captured from the enclosing scope.
        curr_preds = preds_ if pos == N_ else np.concatenate(
            (preds_[:pos], gts_[pos:]))
        return dice_norm_metric(gts_, curr_preds)

    if parallel_backend is None:
        parallel_backend = Parallel(n_jobs=1)

    ordering = uncertainties.argsort()
    gts = ground_truth[ordering].copy()
    preds = predictions[ordering].copy()
    N = len(gts)

    # Significant class imbalance means it is important to use logspacing between values
    # so that it is more granular for the higher retention fractions
    n_fracs = 200
    fracs_retained = np.log(np.arange(n_fracs + 1)[1:])
    fracs_retained /= np.amax(fracs_retained)

    process = partial(compute_dice_norm, preds_=preds, gts_=gts, N_=N)
    dsc_norm_scores = np.asarray(
        parallel_backend(delayed(process)(frac)
                         for frac in fracs_retained)
    )

    return 1. - metrics.auc(fracs_retained, dsc_norm_scores)

In [11]:
def intersection_over_union(mask1, mask2):
    """
    Intersection over union (IoU) of two binary masks.

    Args:
      mask1: `numpy.ndarray`, binary mask.
      mask2: `numpy.ndarray`, binary mask of the same shape as `mask1`.
    Returns:
      Sum of the element-wise product of the masks divided by the sum of their
      element-wise union (`float` in [0.0, 1.0]).
    """
    overlap = mask1 * mask2
    union = mask1 + mask2 - overlap
    return np.sum(overlap) / np.sum(union)

In [12]:
def lesion_f1_score(ground_truth, predictions, IoU_threshold=0.25, parallel_backend=None):
    """
    Compute lesion-scale F1 score.

    Lesions are the connected components of each mask (`scipy.ndimage.label`).
    A predicted lesion is a true positive when its best IoU with an overlapping
    ground truth lesion reaches `IoU_threshold`, otherwise a false positive.
    A ground truth lesion is a false negative when its best IoU with an
    overlapping predicted lesion stays below `IoU_threshold`.

    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                     with shape [H, W, D].
      predictions:  `numpy.ndarray`, binary segmentation predictions,
                     with shape [H, W, D].
      IoU_threshold: `float` in [0.0, 1.0], IoU threshold for max IoU between
                     predicted and ground truth lesions to classify them as
                     TP, FP or FN.
      parallel_backend: `joblib.Parallel`, for parallel computation over
                     lesions. A single-job backend is created when `None`.
    Returns:
      Lesion F1 score `TP / (TP + 0.5 * (FP + FN))` (`float` in [0.0, 1.0]);
      1.0 when there are neither predicted nor ground truth lesions.
    """

    def best_iou(label, own_labels, other_labels):
        # Highest IoU between lesion `label` of `own_labels` and any lesion of
        # `other_labels` that intersects it (0.0 when nothing intersects).
        own_mask = (own_labels == label).astype(int)
        candidates = [0.0]
        # iterate only intersections
        for other_label in np.unique(other_labels * own_mask):
            if other_label == 0.0:
                continue
            other_mask = (other_labels == other_label).astype(int)
            candidates.append(intersection_over_union(own_mask, other_mask))
        return max(candidates)

    def get_tp_fp(label_pred, mask_multi_pred, mask_multi_gt):
        matched = best_iou(label_pred, mask_multi_pred, mask_multi_gt) >= IoU_threshold
        return 'tp' if matched else 'fp'

    def get_fn(label_gt, mask_multi_pred, mask_multi_gt):
        missed = best_iou(label_gt, mask_multi_gt, mask_multi_pred) < IoU_threshold
        return 1 if missed else 0

    if parallel_backend is None:
        parallel_backend = Parallel(n_jobs=1)

    mask_multi_pred_, _ = ndimage.label(predictions)
    mask_multi_gt_, _ = ndimage.label(ground_truth)

    pred_labels = [label for label in np.unique(mask_multi_pred_) if label != 0]
    gt_labels = [label for label in np.unique(mask_multi_gt_) if label != 0]

    # Classify every predicted lesion as TP or FP.
    outcomes = Counter(parallel_backend(
        delayed(get_tp_fp)(label,
                           mask_multi_pred=mask_multi_pred_,
                           mask_multi_gt=mask_multi_gt_)
        for label in pred_labels))
    tp = float(outcomes['tp'])
    fp = float(outcomes['fp'])

    # Count the ground truth lesions that no prediction matches.
    missed_flags = parallel_backend(
        delayed(get_fn)(label,
                        mask_multi_pred=mask_multi_pred_,
                        mask_multi_gt=mask_multi_gt_)
        for label in gt_labels)
    fn = float(np.sum(missed_flags))

    denominator = tp + 0.5 * (fp + fn)
    if denominator == 0.0:
        return 1.0
    return tp / denominator

### Uncertainty

In [13]:
def renyi_entropy_of_expected(probs, alpha=0.8):
    """
    Renyi entropy (order `alpha`) of the ensemble-averaged probabilities.

    Renyi entropy generalises Shannon entropy, which it approaches as `alpha`
    tends to 1. `alpha` itself must differ from 1 because the result is scaled
    by `1 / (1 - alpha)`.
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param alpha: order of the Renyi entropy
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    ensemble_mean = np.mean(probs, axis=0)
    power_sum = np.sum(ensemble_mean**alpha, axis=-1)
    return (1. / (1. - alpha)) * np.log(power_sum)

def renyi_expected_entropy(probs, alpha=0.8):
    """
    Ensemble average of the per-model Renyi entropies (order `alpha`).

    `alpha` must differ from 1 because each entropy is scaled by `1 / (1 - alpha)`.
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param alpha: order of the Renyi entropy
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    power_sums = np.sum(probs**alpha, axis=-1)
    per_model_entropy = (1. / (1. - alpha)) * np.log(power_sums)
    return np.mean(per_model_entropy, axis=0)


def entropy_of_expected(probs, epsilon=1e-10):
    """
    Shannon entropy of the ensemble-averaged probabilities.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithm to avoid log(0)
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    ensemble_mean = np.mean(probs, axis=0)
    neg_log_mean = np.negative(np.log(ensemble_mean + epsilon))
    return (ensemble_mean * neg_log_mean).sum(axis=-1)

def expected_entropy(probs, epsilon=1e-10):
    """
    Ensemble average of the per-model Shannon entropies.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithm to avoid log(0)
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    neg_log_probs = np.negative(np.log(probs + epsilon))
    per_model_entropy = (probs * neg_log_probs).sum(axis=-1)
    return per_model_entropy.mean(axis=0)


def ensemble_uncertainties_classification(probs, epsilon=1e-10):
    """
    Compute several voxel-wise uncertainty measures from ensemble probabilities.

    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithms to avoid log(0)
    :return: Dictionary of uncertainties, each an array
             [num_voxels_X, num_voxels_Y, num_voxels_Z,], with keys
             'confidence' (negated maximum of the mean probabilities),
             'entropy_of_expected', 'expected_entropy',
             'mutual_information' (entropy_of_expected - expected_entropy),
             'epkl' and 'reverse_mutual_information' (epkl - mutual_information).
    """
    ensemble_mean = np.mean(probs, axis=0)
    mean_log_probs = np.mean(np.log(probs + epsilon), axis=0)

    eoe = entropy_of_expected(probs, epsilon)
    exe = expected_entropy(probs, epsilon)
    mutual_info = eoe - exe

    # Cross term between the mean probabilities and the mean log-probabilities,
    # minus the expected entropy.
    epkl = -np.sum(ensemble_mean * mean_log_probs, axis=-1) - exe

    # Confidence is negated so that, like the other entries, larger values
    # mean more uncertainty.
    negated_confidence = -1 * np.max(ensemble_mean, axis=-1)

    return {
        'confidence': negated_confidence,
        'entropy_of_expected': eoe,
        'expected_entropy': exe,
        'mutual_information': mutual_info,
        'epkl': epkl,
        'reverse_mutual_information': epkl - mutual_info,
    }

## Testing Function

In [14]:
"""
Testing function for the SegResNet model.
Args:
    Data Args:
        path_test_data: `str`, path to directory with FLAIR images from test set.
        path_test_gts: `str`, path to directory with ground truth lesion segmentation from test set.
        path_test_bm: `str`, path to directory with brain masks from test set.
    Model Args:
        threshold: `float`, threshold for binary segmentation mask.
        num_models: `int`, number of models to ensemble.
        path_model: `str`, path to directory with trained models.
        num_workers: `int`, number of worker threads to use for parallel processing of images.
        n_jobs: `int`, number of jobs to run in parallel.
"""
def testSegResNet(path_test_data, path_test_gts, path_test_bm, threshold = 0.35,
                  num_models = 3, path_model = '', num_workers = 1, n_jobs = 1,
                  results_path = "/kaggle/working/SegResNet_results.json",
                  checkpoint_name = "Best_model_finetuning.pth"):
    """
    Evaluate a SegResNet ensemble on the test set and store the metrics as JSON.

    For every test subject the ensemble's averaged lesion probability is
    thresholded, small connected components are removed, and three metrics are
    computed: nDSC, lesion F1 score (IoU threshold 0.5) and the area above the
    nDSC retention curve (printed and stored under the name "nDSC R-AUC"),
    using the expected entropy as voxel-wise uncertainty inside the brain mask.
    Mean and standard deviation over subjects (in percent) are printed and
    written to `results_path`.

    Args:
        path_test_data: `str`, path to directory with FLAIR images from test set.
        path_test_gts: `str`, path to directory with ground truth lesion segmentation from test set.
        path_test_bm: `str`, path to directory with brain masks from test set.
        threshold: `float`, threshold for binary segmentation mask.
        num_models: `int`, number of models to ensemble.
        path_model: `str`, path to directory with trained models.
        num_workers: `int`, number of worker threads to use for parallel processing of images.
        n_jobs: `int`, number of jobs to run in parallel.
        results_path: `str`, file the JSON summary is written to.
        checkpoint_name: `str`, file name of the checkpoint inside `path_model`.
    Returns:
        None.
    """
    # Setting up the device
    device = get_default_device()
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Initialise dataloaders
    val_loader = get_val_dataloader(flair_path=path_test_data,
                                    gts_path=path_test_gts,
                                    num_workers=num_workers,
                                    bm_path=path_test_bm)

    # Load trained models. `map_location=device` covers the CPU and the GPU case.
    # NOTE(review): every ensemble member loads the same checkpoint file, so
    # with num_models > 1 the members are identical -- confirm this is intended.
    state_dict = torch.load(os.path.join(path_model, checkpoint_name),
                            map_location=device)
    models = []
    for _ in range(num_models):
        model = SegResNet(spatial_dims=3,
                          in_channels=1,
                          out_channels=2).to(device)
        model.load_state_dict(state_dict)
        model.eval()
        models.append(model)

    act = torch.nn.Softmax(dim=1)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    ndsc, f1, ndsc_aac = [], [], []

    # Evaluation loop
    with Parallel(n_jobs=n_jobs) as parallel_backend:
        with torch.no_grad():
            for count, batch_data in enumerate(val_loader):
                inputs = batch_data["image"].to(device)
                gt = np.squeeze(batch_data["label"].cpu().numpy())
                brain_mask = np.squeeze(batch_data["brain_mask"].cpu().numpy())

                # get ensemble predictions: lesion-class probability per model
                all_outputs = []
                for model in models:
                    outputs = sliding_window_inference(inputs, roi_size,
                                                       sw_batch_size, model,
                                                       mode='gaussian')
                    outputs = act(outputs).cpu().numpy()
                    all_outputs.append(np.squeeze(outputs[0, 1]))
                all_outputs = np.asarray(all_outputs)

                # obtain binary segmentation mask
                mean_prob = np.mean(all_outputs, axis=0)
                seg = np.squeeze((mean_prob >= threshold).astype(mean_prob.dtype))
                seg = remove_connected_components(seg)

                # compute expected entropy uncertainty map from the two-class
                # probabilities (lesion, 1 - lesion) of every model
                uncs_map = ensemble_uncertainties_classification(
                    np.stack((all_outputs, 1. - all_outputs), axis=-1)
                )['expected_entropy']

                # compute metrics
                in_brain = brain_mask == 1
                ndsc += [dice_norm_metric(ground_truth=gt, predictions=seg)]
                f1 += [lesion_f1_score(ground_truth=gt,
                                       predictions=seg,
                                       IoU_threshold=0.5,
                                       parallel_backend=parallel_backend)]
                ndsc_aac += [ndsc_aac_metric(ground_truth=gt[in_brain].flatten(),
                                             predictions=seg[in_brain].flatten(),
                                             uncertainties=uncs_map[in_brain].flatten(),
                                             parallel_backend=parallel_backend)]

                # for nervous people
                if count % 10 == 0:
                    print(f"Processed {count}/{len(val_loader)}")

    ndsc = np.asarray(ndsc) * 100.
    f1 = np.asarray(f1) * 100.
    ndsc_aac = np.asarray(ndsc_aac) * 100.
    print(f"nDSC:\t{np.mean(ndsc):.4f} +- {np.std(ndsc):.4f}")
    print(f"Lesion F1 score:\t{np.mean(f1):.4f} +- {np.std(f1):.4f}")
    print(f"nDSC R-AUC:\t{np.mean(ndsc_aac):.4f} +- {np.std(ndsc_aac):.4f}")

    # Plain Python floats: NumPy scalar types such as float32 are not JSON serialisable.
    results = {
        "model_name": "SegResNet",
        "nDSC": float(np.mean(ndsc)),
        "nDSC std": float(np.std(ndsc)),
        "Lesion F1 score": float(np.mean(f1)),
        "Lesion F1 score std": float(np.std(f1)),
        "nDSC R-AUC": float(np.mean(ndsc_aac)),
        "nDSC R-AUC std": float(np.std(ndsc_aac))
    }

    with open(results_path, "w") as file:
        json.dump(results, file)

## Test The Model

In [15]:
# Test-set directories (FLAIR scans, ground truth masks, foreground masks) and
# the directory holding the trained SegResNet checkpoint.
path_test_data = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Test/FLAIR'
path_test_gts = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Test/GroundTruth'
path_test_bm = '/kaggle/input/sdcombinedextracted/ShiftsDatasetCombinedExtracted/Test/FgMasks'
path_model = '/kaggle/input/segresnet/SegResNet'

# Single-model evaluation; n_jobs=-1 is handed to joblib.Parallel for the metric computation.
testSegResNet(
    path_test_data=path_test_data,
    path_test_gts=path_test_gts,
    path_test_bm=path_test_bm,
    threshold=0.35,
    num_models=1,
    path_model=path_model,
    num_workers=4,
    n_jobs=-1,
)

Number of validation files: 33


Loading dataset: 100%|██████████| 3/3 [00:01<00:00,  1.60it/s]
  pid = os.fork()
  self.pid = os.fork()


Processed 0/33
Processed 10/33
Processed 20/33
Processed 30/33


  self.pid = os.fork()


nDSC:	71.7560 +- 8.6960
Lesion F1 score:	35.1109 +- 14.2631
nDSC R-AUC:	8.7450 +- 8.1628
