# Importing stuff

In [32]:
import os
import torch
import time
from functools import partial
from sklearn import metrics
from collections import Counter
from torch import nn
from joblib import Parallel, delayed
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.networks.nets import UNet, SwinUNETR, UNETR, SegResNet
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    AddChanneld, Compose, LoadImaged, RandCropByPosNegLabeld,
    Spacingd, ToTensord, NormalizeIntensityd, RandFlipd,
    RandRotate90d, RandShiftIntensityd, RandAffined, RandSpatialCropd,
    RandScaleIntensityd)
    
import numpy as np
import random
from glob import glob
import re

from scipy import ndimage

# Useful functions

## transformations and data loaders

In [3]:
def get_train_transforms():
    """Build the transform pipeline used for training on FLAIR images and masks.

    The pipeline, in order:
    - loads the 3D "image" and "label" volumes from file
    - adds a channel dimension to both
    - normalises the image intensity (``nonzero=True``)
    - applies random intensity shift and scale to the image
    - draws 32 label-guided crops of shape [128, 128, 128] (``pos=4, neg=1``)
      and cuts a random [96, 96, 96] patch out of each
    - applies random flips, 90-degree rotations and an affine warp
    - converts both entries to torch.Tensor
    """
    both = ["image", "label"]
    patch_size = (96, 96, 96)
    max_angle = np.pi / 12

    loading = [
        LoadImaged(keys=both),
        AddChanneld(keys=both),
    ]
    intensity = [
        NormalizeIntensityd(keys=["image"], nonzero=True),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
    ]
    cropping = [
        # A larger label-guided crop first, then a random sub-patch of the
        # final training size.
        RandCropByPosNegLabeld(keys=both, label_key="label", image_key="image",
                               spatial_size=(128, 128, 128), num_samples=32,
                               pos=4, neg=1),
        RandSpatialCropd(keys=both, roi_size=patch_size,
                         random_center=True, random_size=False),
    ]
    spatial = [RandFlipd(keys=both, prob=0.5, spatial_axis=(0, 1, 2))]
    spatial += [
        RandRotate90d(keys=both, prob=0.5, spatial_axes=plane)
        for plane in ((0, 1), (1, 2), (0, 2))
    ]
    spatial.append(
        RandAffined(keys=both, mode=('bilinear', 'nearest'), prob=1.0,
                    spatial_size=patch_size,
                    rotate_range=(max_angle, max_angle, max_angle),
                    scale_range=(0.1, 0.1, 0.1), padding_mode='border')
    )

    return Compose(loading + intensity + cropping + spatial
                   + [ToTensord(keys=both)])


def get_val_transforms(keys=None, image_keys=None):
    """ Get transforms for testing on FLAIR images and ground truth:
    - Loads 3D images and masks from Nifti file
    - Adds channel dimension
    - Applies intensity normalisation to scans
    - Converts to torch.Tensor()

    Args:
      keys: list of dictionary keys to load, add a channel to and convert to
            tensors. Defaults to ["image", "label"].
      image_keys: list of keys that receive intensity normalisation.
            Defaults to ["image"].
    Returns:
      monai.transforms.Compose() class object.
    """
    # `None` sentinels instead of list defaults: a list default is created
    # once and shared between every call of the function.
    if keys is None:
        keys = ["image", "label"]
    if image_keys is None:
        image_keys = ["image"]

    return Compose(
        [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            NormalizeIntensityd(keys=image_keys, nonzero=True),
            ToTensord(keys=keys),
        ]
    )


def get_train_dataloader(flair_path, gts_path, num_workers, batch_size, cache_rate=0.1):
    """
    Get dataloader for training 
    Args:
      flair_path: `str`, path to directory with FLAIR images from Train set.
      gts_path:  `str`, path to directory with ground truth lesion segmentation 
                    binary masks images from Train set.
      num_workers:  `int`,  number of workers, passed to both the
                    CacheDataset and the DataLoader.
      batch_size:  `int`, number of subjects per batch passed to the DataLoader.
      cache_rate:  `float` in (0.0, 1.0], percentage of cached data in total.
    Returns:
      monai.data.DataLoader() class object.
    Raises:
      AssertionError: if the number of FLAIR images and ground truth masks
                    found on disk differ.
    """
    def numeric_id(path):
        # Sort on the digits of the file name only, so that digits in parent
        # directory names cannot leak into the key. Raw string: '\D' is not a
        # valid string escape sequence.
        return int(re.sub(r'\D', '', os.path.basename(path)))

    flair = sorted(glob(os.path.join(flair_path, "*FLAIR_isovox.nii.gz")),
                   key=numeric_id)  # Collect all flair images sorted
    segs = sorted(glob(os.path.join(gts_path, "*gt_isovox.nii.gz")),
                  key=numeric_id)  # Collect all corresponding ground truths

    # zip() below would silently drop unmatched files and could pair an image
    # with the wrong mask, so fail loudly like the validation loader does.
    assert len(flair) == len(segs), f"Some files must be missing: {[len(flair), len(segs)]}"

    files = [{"image": fl, "label": seg} for fl, seg in zip(flair, segs)]

    print("Number of training files:", len(files))

    ds = CacheDataset(data=files, transform=get_train_transforms(),
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers)


def get_val_dataloader(flair_path, gts_path, num_workers, cache_rate=0.1, bm_path=None):
    """
    Get dataloader for validation and testing. Either with or without brain masks.

    Args:
      flair_path: `str`, path to directory with FLAIR images.
      gts_path:  `str`, path to directory with ground truth lesion segmentation 
                    binary masks images.
      num_workers:  `int`,  number of workers, passed to both the
                    CacheDataset and the DataLoader.
      cache_rate:  `float` in (0.0, 1.0], percentage of cached data in total.
      bm_path:   `None|str`. If `str`, then defines path to directory with
                 brain masks. If `None`, dataloader does not return brain masks. 
    Returns:
      monai.data.DataLoader() class object (batch size 1, no shuffling).
    Raises:
      AssertionError: if the numbers of images, ground truths and (when
                 requested) brain masks found on disk differ.
    """
    def numeric_id(path):
        # Sort on the digits of the file name only, so that digits in parent
        # directory names cannot leak into the key. Raw string: '\D' is not a
        # valid string escape sequence.
        return int(re.sub(r'\D', '', os.path.basename(path)))

    flair = sorted(glob(os.path.join(flair_path, "*FLAIR_isovox.nii.gz")),
                   key=numeric_id)  # Collect all flair images sorted
    segs = sorted(glob(os.path.join(gts_path, "*_isovox.nii.gz")),
                  key=numeric_id)  # Collect all corresponding ground truths

    if bm_path is not None:
        bms = sorted(glob(os.path.join(bm_path, "*isovox_fg_mask.nii.gz")),
                     key=numeric_id)  # Collect all corresponding brain masks

        assert len(flair) == len(segs) == len(bms), f"Some files must be missing: {[len(flair), len(segs), len(bms)]}"

        files = [
            {"image": fl, "label": seg, "brain_mask": bm} for fl, seg, bm
            in zip(flair, segs, bms)
        ]

        val_transforms = get_val_transforms(keys=["image", "label", "brain_mask"])
    else:
        assert len(flair) == len(segs), f"Some files must be missing: {[len(flair), len(segs)]}"

        files = [{"image": fl, "label": seg} for fl, seg in zip(flair, segs)]

        val_transforms = get_val_transforms()

    print("Number of validation files:", len(files))

    ds = CacheDataset(data=files, transform=val_transforms,
                      cache_rate=cache_rate, num_workers=num_workers)
    return DataLoader(ds, batch_size=1, shuffle=False,
                      num_workers=num_workers)

## Metrics

In [33]:
def intersection_over_union(mask1, mask2):
    """
    Compute IoU for 2 binary masks.
    
    Args:
      mask1: `numpy.ndarray`, binary mask (0/1 integers, floats or booleans).
      mask2:  `numpy.ndarray`, binary mask of the same shape as `mask1`.
    Returns:
      Intersection over union between `mask1` and `mask2` (`float` in [0.0, 1.0]).
      Two empty masks have no overlap and give 0.0.
    """
    # Work on boolean views: arithmetic such as `mask1 - mask2` is not
    # defined for numpy boolean arrays, logical operations are defined for
    # every dtype.
    first = np.asarray(mask1).astype(bool)
    second = np.asarray(mask2).astype(bool)

    union = np.logical_or(first, second).sum()
    if union == 0:
        # Avoid 0 / 0 (NaN plus a RuntimeWarning) when both masks are empty.
        return 0.0
    return np.logical_and(first, second).sum() / union
    
def lesion_f1_score(ground_truth, predictions, IoU_threshold=0.25, parallel_backend=None):
    """
    Compute lesion-scale F1 score.

    Lesions are the connected components of each binary mask. A predicted
    lesion is a true positive when its best IoU with any overlapping ground
    truth lesion reaches `IoU_threshold`, otherwise a false positive. A
    ground truth lesion is a false negative when its best IoU with any
    overlapping predicted lesion stays below `IoU_threshold`.

    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                     with shape [H, W, D].
      predictions:  `numpy.ndarray`, binary segmentation predictions,
                     with shape [H, W, D].
      IoU_threshold: `float` in [0.0, 1.0], IoU threshold for max IoU between 
                     predicted and ground truth lesions to classify them as
                     TP, FP or FN.
      parallel_backend: `joblib.Parallel`, used to process lesions in
                     parallel. A single-job backend is created when `None`.
    Returns:
      Lesion-scale F1 score, tp / (tp + 0.5 * (fp + fn)) (`float`);
      1.0 when there are no lesions in either mask.
    """

    def best_overlap(label, own_labels, other_labels):
        # Highest IoU between lesion `label` of `own_labels` and any lesion
        # of `other_labels` that intersects it; 0.0 when nothing intersects.
        own_mask = (own_labels == label).astype(int)
        candidates = [0.0]
        # Multiplying by the mask leaves only the labels that intersect.
        for other in np.unique(other_labels * own_mask):
            if other == 0.0:
                continue
            other_mask = (other_labels == other).astype(int)
            candidates.append(intersection_over_union(own_mask, other_mask))
        return max(candidates)

    if parallel_backend is None:
        parallel_backend = Parallel(n_jobs=1)

    pred_labels, _ = ndimage.label(predictions)
    gt_labels, _ = ndimage.label(ground_truth)

    # Predicted lesions: split into true and false positives.
    score_pred = partial(best_overlap, own_labels=pred_labels,
                         other_labels=gt_labels)
    pred_scores = list(parallel_backend(
        delayed(score_pred)(label)
        for label in np.unique(pred_labels) if label != 0))
    tp = float(sum(1 for score in pred_scores if score >= IoU_threshold))
    fp = float(len(pred_scores)) - tp

    # Ground truth lesions: count the ones that were missed.
    score_gt = partial(best_overlap, own_labels=gt_labels,
                       other_labels=pred_labels)
    gt_scores = list(parallel_backend(
        delayed(score_gt)(label)
        for label in np.unique(gt_labels) if label != 0))
    fn = float(sum(1 for score in gt_scores if score < IoU_threshold))

    denominator = tp + 0.5 * (fp + fn)
    if denominator == 0.0:
        return 1.0
    return tp / denominator

In [34]:
def dice_norm_metric(ground_truth, predictions):
    """
    Compute Normalised Dice Coefficient (nDSC) for a single example.

    False positives are rescaled by a factor `k` so that the ground truth
    positive fraction is mapped to the reference fraction r = 0.001.
    
    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                     with shape [H, W, D].
      predictions:  `numpy.ndarray`, binary segmentation predictions,
                     with shape [H, W, D].
    Returns:
      Normalised dice coefficient between `ground_truth` and `predictions`
      as a numpy scalar; 1.0 when both masks are empty.
    """

    # Reference for normalized DSC
    r = 0.001
    # Cast to float32 type
    gt = ground_truth.astype("float32")
    seg = predictions.astype("float32")

    gt_sum = np.sum(gt)
    if np.sum(seg) + gt_sum == 0:
        # Returned as a numpy scalar so that every code path gives callers
        # the same kind of object (np.float64 is also a Python float).
        return np.float64(1.0)

    n_background = gt.size - gt_sum
    if gt_sum == 0 or n_background == 0:
        # No lesion voxels, or no background voxels. In the latter case
        # there can be no false positives, so the scaling is irrelevant and
        # the division by zero is avoided.
        k = 1.0
    else:
        k = (1 - r) * gt_sum / (r * n_background)

    tp = np.sum(seg[gt == 1])
    fp = np.sum(seg[gt == 0])
    fn = np.sum(gt[seg == 0])
    fp_scaled = k * fp
    dsc_norm = 2. * tp / (fp_scaled + 2. * tp + fn)
    return dsc_norm

In [35]:
def ndsc_aac_metric(ground_truth, predictions, uncertainties, parallel_backend=None):
    """
    Compute area above Normalised Dice Coefficient (nDSC) retention curve for 
    one subject. `ground_truth`, `predictions`, `uncertainties` - are flattened 
    arrays of corresponding 3D maps within the foreground mask only.

    Voxels are sorted by increasing uncertainty. For each retention fraction
    the most certain voxels keep their prediction and the remaining ones are
    replaced by the ground truth before nDSC is computed.
    
    Args:
      ground_truth: `numpy.ndarray`, binary ground truth segmentation target,
                     with shape [H * W * D]. 
      predictions:  `numpy.ndarray`, binary segmentation predictions,
                     with shape [H * W * D].
      uncertainties:  `numpy.ndarray`, voxel-wise uncertainties,
                     with shape [H * W * D].
      parallel_backend: `joblib.Parallel`, for parallel computation
                     for different retention fractions. A single-job backend
                     is created when `None`.
    Returns:
      nDSC R-AAC (`float` in [0.0, 1.0]).
    """

    def compute_dice_norm(frac_, preds_, gts_, N_):
        pos = int(N_ * frac_)
        # Use the `preds_` argument here. The original read `preds` from the
        # enclosing scope, which only worked because both names refer to the
        # same array.
        curr_preds = preds_ if pos == N_ else np.concatenate(
            (preds_[:pos], gts_[pos:]))
        return dice_norm_metric(gts_, curr_preds)

    if parallel_backend is None:
        parallel_backend = Parallel(n_jobs=1)

    ordering = uncertainties.argsort()
    gts = ground_truth[ordering].copy()
    preds = predictions[ordering].copy()
    N = len(gts)

    # Significant class imbalance means it is important to use logspacing
    # between values so that it is more granular for the higher retention
    # fractions
    fracs_retained = np.log(np.arange(200 + 1)[1:])
    fracs_retained /= np.amax(fracs_retained)

    process = partial(compute_dice_norm, preds_=preds, gts_=gts, N_=N)
    dsc_norm_scores = np.asarray(
        parallel_backend(delayed(process)(frac)
                         for frac in fracs_retained)
    )

    return 1. - metrics.auc(fracs_retained, dsc_norm_scores)

In [36]:
def remove_connected_components(segmentation, l_min=9):
    """
    Remove all lesions with less or equal amount of voxels than `l_min` from a 
    binary segmentation mask `segmentation`.
    Args:
      segmentation: `numpy.ndarray` with a binary lesions segmentation mask,
                    typically of shape [H, W, D]; any number of dimensions
                    is accepted.
      l_min:  `int`, minimal amount of voxels in a lesion.
    Returns:
      Binary lesion segmentation mask (`numpy.ndarray` with the shape and
      dtype of `segmentation`) only with connected components that have more
      than `l_min` voxels.
    """
    segmentation = np.asarray(segmentation)
    labeled_seg, num_labels = ndimage.label(segmentation)

    # sizes[i] is the number of voxels carrying label i. Indexing by the
    # label value itself stays correct when the mask has no background
    # voxels, where the position in a list of unique labels would be off
    # by one.
    sizes = np.bincount(labeled_seg.ravel(), minlength=num_labels + 1)
    keep = sizes > l_min
    keep[0] = False  # label 0 is the background, never a lesion

    # Look up the keep/drop decision of every voxel's label in one pass;
    # this works for any number of dimensions.
    return keep[labeled_seg].astype(segmentation.dtype)


## Uncertainty estimates

In [23]:
def renyi_entropy_of_expected(probs, alpha=0.8):
    """
    Renyi entropy of the ensemble-averaged distribution. Renyi entropy is a
    generalised version of Shannon entropy; the two coincide in the limit
    alpha -> 1 (alpha == 1 itself divides by zero here).
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param alpha: order of the Renyi entropy
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    ensemble_mean = np.mean(probs, axis=0)
    power_sum = np.sum(ensemble_mean**alpha, axis=-1)
    return (1. / (1. - alpha)) * np.log(power_sum)

def renyi_expected_entropy(probs, alpha=0.8):
    """
    Renyi entropy of each ensemble member, averaged over the members.
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param alpha: order of the Renyi entropy (alpha == 1 divides by zero)
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    factor = 1. / (1. - alpha)
    per_model = factor * np.log(np.sum(probs**alpha, axis=-1))
    return np.mean(per_model, axis=0)


def entropy_of_expected(probs, epsilon=1e-10):
    """
    Shannon entropy of the ensemble-averaged distribution.
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithm to avoid log(0)
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    ensemble_mean = np.mean(probs, axis=0)
    weighted_log = ensemble_mean * np.log(ensemble_mean + epsilon)
    return -np.sum(weighted_log, axis=-1)

def expected_entropy(probs, epsilon=1e-10):
    """
    Shannon entropy of each ensemble member, averaged over the members.
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside the logarithm to avoid log(0)
    :return: array [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    per_model = -np.sum(probs * np.log(probs + epsilon), axis=-1)
    return np.mean(per_model, axis=0)


def ensemble_uncertainties_classification(probs, epsilon=1e-10):
    """
    Compute several voxel-wise uncertainty measures from ensemble predictions.
    :param probs: array [num_models, num_voxels_X, num_voxels_Y, num_voxels_Z, num_classes]
    :param epsilon: small constant added inside logarithms to avoid log(0)
    :return: Dictionary of uncertainties with keys 'confidence' (negated
             maximum of the mean probabilities), 'entropy_of_expected',
             'expected_entropy', 'mutual_information', 'epkl' and
             'reverse_mutual_information'; each value is an array
             [num_voxels_X, num_voxels_Y, num_voxels_Z,]
    """
    avg_probs = np.mean(probs, axis=0)
    avg_log_probs = np.mean(np.log(probs + epsilon), axis=0)

    total = entropy_of_expected(probs, epsilon)
    data = expected_entropy(probs, epsilon)

    mutual_info = total - data
    epkl = -np.sum(avg_probs * avg_log_probs, axis=-1) - data

    uncertainty = {}
    # Negated so that, like the other measures, larger means more uncertain.
    uncertainty['confidence'] = -1 * np.max(avg_probs, axis=-1)
    uncertainty['entropy_of_expected'] = total
    uncertainty['expected_entropy'] = data
    uncertainty['mutual_information'] = mutual_info
    uncertainty['epkl'] = epkl
    uncertainty['reverse_mutual_information'] = epkl - mutual_info
    return uncertainty

## Training function

In [7]:
def train(model, train_loader, optimizer, loss_function, gamma_focal, dice_weight,
          focal_weight, device, sub_batch_size=8):
    """
    Run one training epoch and return the mean loss per optimisation step.

    Each batch from `train_loader` is processed in chunks of
    `sub_batch_size` samples; every chunk is one optimisation step on
    `dice_weight * loss_function + focal_weight * focal loss`.

    Args:
      model: `torch.nn.Module` to train.
      train_loader: iterable with a length, yielding dictionaries with
                    "image" and "label" tensors.
      optimizer: `torch.optim.Optimizer` updating the parameters of `model`.
      loss_function: callable `(outputs, labels) -> loss tensor`, used as the
                    Dice term.
      gamma_focal: `float`, focusing exponent of the focal loss.
      dice_weight: `float`, weight of the `loss_function` term.
      focal_weight: `float`, weight of the focal loss term.
      device: `torch.device` on which the computation runs.
      sub_batch_size: `int`, number of samples per optimisation step.
    Returns:
      Mean loss over all optimisation steps of the epoch (`float`);
      0.0 if the loader yielded no data.
    """
    log_interval = 25  # print progress every `log_interval` steps

    model.train()
    # The criterion has no state, so build it once instead of at every step.
    ce_loss = nn.CrossEntropyLoss(reduction='none')
    epoch_loss = 0.0
    n_steps = 0
    for batch_data in train_loader:
        n_samples = batch_data["image"].size(0)
        # Expected number of steps in the epoch, assuming every batch holds
        # as many samples as the current one.
        steps_per_batch = (n_samples + sub_batch_size - 1) // sub_batch_size
        total_steps = len(train_loader) * steps_per_batch
        for m in range(0, n_samples, sub_batch_size):
            n_steps += 1
            inputs = batch_data["image"][m:m + sub_batch_size].to(device)
            labels = batch_data["label"][m:m + sub_batch_size].long().to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # Dice loss
            loss1 = loss_function(outputs, labels)
            # Focal loss: cross entropy down-weighted for voxels that are
            # already classified with high probability.
            ce = ce_loss(outputs, torch.squeeze(labels, dim=1))
            pt = torch.exp(-ce)
            loss2 = torch.mean((1 - pt)**gamma_focal * ce)
            loss = dice_weight * loss1 + focal_weight * loss2

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            if n_steps % log_interval == 0:
                print(f"{n_steps}/{total_steps}, train_loss: {loss.item():.4f}")

    # Average over the steps actually taken. The previous divisor was only
    # assigned inside the logging branch, so short epochs raised an
    # UnboundLocalError and long ones were divided by an unrelated count.
    return epoch_loss / n_steps if n_steps else 0.0

## Validation function

In [9]:
def validation(model, val_loader, dice_norm_metric, roi_size, sw_batch_size, act, thresh, device):
    """
    Evaluate `model` on `val_loader` and return the mean metric value.

    Args:
      model: `torch.nn.Module` to evaluate; it is switched to eval mode.
      val_loader: iterable yielding dictionaries with "image" and "label"
                  tensors; batch item 0 of each is evaluated.
      dice_norm_metric: callable `(ground_truth=, predictions=) -> score`
                  applied to the flattened ground truth and binary prediction.
      roi_size: window size passed to `sliding_window_inference`.
      sw_batch_size: `int`, number of windows processed at once.
      act: callable turning the network output into class probabilities;
                  channel 1 of its result is thresholded.
      thresh: `float`, probability threshold for the binary segmentation.
      device: `torch.device` on which inference runs.
    Returns:
      Mean of the metric over all examples of `val_loader` (`float`).
    """
    # nn.Module has no `val()` method; `eval()` is the evaluation switch.
    model.eval()
    metric_sum = 0.0
    metric_count = 0
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs, val_labels = (
                val_data["image"].to(device),
                val_data["label"].to(device)
            )

            val_outputs = sliding_window_inference(val_inputs, roi_size,
                                                   sw_batch_size,
                                                   model, mode='gaussian')

            gt = np.squeeze(val_labels.cpu().numpy())

            # Probability of class 1 for the first item of the batch.
            seg = act(val_outputs).cpu().numpy()
            seg = np.squeeze(seg[0, 1])
            seg[seg >= thresh] = 1
            seg[seg < thresh] = 0

            value = dice_norm_metric(ground_truth=gt.flatten(), predictions=seg.flatten())

            metric_count += 1
            # The metric may be a numpy scalar or a plain Python float;
            # np.sum accepts both, whereas `value.sum()` fails on a float.
            metric_sum += float(np.sum(value))

    return metric_sum / metric_count

# Training Arguments

In [45]:
class Args:
    """Hard-coded configuration for the training and inference cells."""

    def __init__(self) -> None:
        # All data lives below one project directory. Paths are assembled by
        # plain string concatenation so they are identical on every platform.
        project = r"C:\Users\Talal\Desktop\ML703\project"
        train_dir = project + "\\train"
        val_dir = project + "\\dev_in"
        test_dir = project + r"\shifts_ms_pt2\shifts_ms_pt2\best\eval_in"

        # Optimisation
        self.lr = 1e-5
        self.n_epochs = 300
        self.seed = 1

        # Training data: FLAIR images, ground truth masks, foreground masks
        self.path_train_flair = train_dir + "\\flair"
        self.path_train_gts = train_dir + "\\gt"
        self.path_train_bm = train_dir + "\\fg_mask"

        # Validation data
        self.path_val_flair = val_dir + "\\flair"
        self.path_val_gts = val_dir + "\\gt"
        self.path_val_bm = val_dir + "\\fg_mask"

        # Test data
        self.path_test_flair = test_dir + "\\flair"
        self.path_test_gts = test_dir + "\\gt"
        self.path_test_bm = test_dir + "\\fg_mask"

        # Trained models: directory and number of ensemble members
        self.path_model = project
        self.num_models = 1

        # Runtime settings
        self.num_workers = 4
        self.n_jobs = 2
        self.path_save = ""
        self.val_interval = 5
        self.threshold = 0.4
        

# Training

## Fixing the seed

In [46]:
# Global configuration shared by all cells below.
args = Args()
path_save = args.path_save
seed_val = args.seed

# Seed every random number generator in use (PyTorch CPU and CUDA, NumPy,
# Python) with the same value.
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
np.random.seed(seed_val)
random.seed(seed_val)

# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Initialise dataloaders

In [49]:
# Training loader over the training FLAIR images and ground truth masks,
# one subject per batch.
train_loader = get_train_dataloader(flair_path=args.path_train_flair, 
                                        gts_path=args.path_train_gts, 
                                        num_workers=args.num_workers, 
                                        batch_size=1)
# Evaluation loader including brain masks.
# NOTE(review): this reads the *test* paths (args.path_test_*), not
# args.path_val_*. The training loop uses `val_loader` to select the best
# checkpoint, so model selection would see test data - confirm this is intended.
val_loader = get_val_dataloader(flair_path=args.path_test_flair, 
                                    gts_path=args.path_test_gts, 
                                    num_workers=args.num_workers,
                                    bm_path=args.path_test_bm)

Number of training files: 33


Loading dataset: 100%|██████████| 3/3 [00:00<00:00,  4.53it/s]

Number of validation files: 9





## Initialise the model 

### CNN based

In [19]:
# 3D U-Net with one input channel (FLAIR) and two output channels, five
# resolution levels and no residual units.
model = UNet(spatial_dims=3,
             in_channels=1,
             out_channels=2,
             channels=(32, 64, 128, 256, 512),
             strides=(2, 2, 2, 2),
             num_res_units=0)
model = model.to(device)

In [None]:
# SegResNet with its default configuration. Moved to `device` like the other
# model cells; otherwise the training function, which sends every batch to
# `device`, would feed GPU tensors to a CPU model.
model = SegResNet().to(device)

### Transformer based

In [None]:
# UNETR for single-channel input and 96^3 patches.
# NOTE(review): out_channels=4 differs from the UNet cell (out_channels=2);
# validation and inference only read output channel 1 - confirm that four
# output channels are intended here.
model = UNETR(in_channels=1,
             out_channels=4,
             img_size=(96,96,96),
             feature_size=32,
             norm_name='batch').to(device)

In [None]:
# SwinUNETR for single-channel input and 96^3 patches.
# NOTE(review): out_channels=4 differs from the UNet cell (out_channels=2);
# validation and inference only read output channel 1 - confirm that four
# output channels are intended here.
model = SwinUNETR(img_size=(96,96,96),
                        in_channels=1,
                        out_channels=4,
                        feature_size=48).to(device)

## setting the loss function & optimizer

In [21]:
# Dice term of the training loss: softmax over the network outputs, labels
# converted to one-hot, background channel excluded.
loss_function = DiceLoss(include_background=False,
                         to_onehot_y=True,
                         sigmoid=False,
                         softmax=True)

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Converts network outputs to class probabilities along the channel axis.
act = nn.Softmax(dim=1)

## Training loop

In [None]:
# Training schedule and loss configuration.
epoch_num = args.n_epochs
val_interval = args.val_interval
thresh = args.threshold
gamma_focal = 2.0
dice_weight = 0.5
focal_weight = 1.0
# Sliding-window inference settings used during validation.
roi_size = (96, 96, 96)
sw_batch_size = 4

best_metric, best_metric_epoch = -1, -1
epoch_loss_values, metric_values = [], []

for epoch in range(epoch_num):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epoch_num}")
    epoch_loss = train(model, train_loader, optimizer, loss_function,
                       gamma_focal, dice_weight, focal_weight, device)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # Validation every `val_interval` epochs; the checkpoint with the best
    # metric so far is kept.
    if (epoch + 1) % val_interval == 0:
        # Arguments follow the signature of `validation`:
        # (model, val_loader, dice_norm_metric, roi_size, sw_batch_size,
        #  act, thresh, device).
        metric = validation(model, val_loader, dice_norm_metric, roi_size,
                            sw_batch_size, act, thresh, device)
        metric_values.append(metric)
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(path_save, "Best_model_finetuning.pth"))
            print("saved new best metric model")

        print(f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
              f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
              )

# Inference 

In [50]:
# Load trained models: one UNet per ensemble member, with the same
# architecture as the UNet training cell.
K = args.num_models
models = []
for i in range(K):
    models.append(UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        num_res_units=0).to(device)
                    )

for i, model in enumerate(models):
    # Member i is read from <path_model>/seed<i+1>/Best_model_finetuning.pth.
    # map_location places the weights on the current device, so a checkpoint
    # saved on a GPU also loads on a CPU-only machine.
    model.load_state_dict(torch.load(os.path.join(args.path_model,
                                                  f"seed{i + 1}",
                                                  "Best_model_finetuning.pth"),
                                     map_location=device))
    model.eval()

# Inference settings: softmax over channels, probability threshold and
# sliding-window configuration.
act = torch.nn.Softmax(dim=1)
th = args.threshold
roi_size = (96, 96, 96)
sw_batch_size = 4

# Per-subject metric values.
ndsc, f1, ndsc_aac = [], [], []

''' Evaluatioin loop '''
# For every subject of `val_loader`: run all ensemble members, average their
# class-1 probabilities, threshold and clean the result, then record nDSC,
# lesion F1 and the nDSC retention-curve metric. One joblib pool is shared by
# the metric functions for the whole loop.
with Parallel(n_jobs=args.n_jobs) as parallel_backend:
    with torch.no_grad():
        for count, batch_data in enumerate(val_loader):
            # The image goes to the compute device; ground truth and brain
            # mask are only needed as numpy arrays for the metrics.
            inputs, gt, brain_mask = (
                batch_data["image"].to(device),
                batch_data["label"].cpu().numpy(),
                batch_data["brain_mask"].cpu().numpy()
            )

            # get ensemble predictions: channel 1 of the softmax output of
            # batch item 0, one map per model
            all_outputs = []
            for model in models:
                outputs = sliding_window_inference(inputs, roi_size,
                                                    sw_batch_size, model,
                                                    mode='gaussian')
                outputs = act(outputs).cpu().numpy()
                outputs = np.squeeze(outputs[0, 1])
                all_outputs.append(outputs)
            all_outputs = np.asarray(all_outputs)

            # obtain binary segmentation mask: threshold the ensemble mean,
            # then drop small connected components
            seg = np.mean(all_outputs, axis=0)
            seg[seg >= th] = 1
            seg[seg < th] = 0
            seg = np.squeeze(seg)
            seg = remove_connected_components(seg)

            gt = np.squeeze(gt)
            brain_mask = np.squeeze(brain_mask)

            # compute reverse mutual information uncertainty map; the
            # per-model probabilities p and 1 - p are stacked along a new
            # last axis to form the two-class input
            uncs_map = ensemble_uncertainties_classification(np.concatenate(
                (np.expand_dims(all_outputs, axis=-1),
                    np.expand_dims(1. - all_outputs, axis=-1)),
                axis=-1))['reverse_mutual_information']

            # compute metrics; the retention-curve metric only uses voxels
            # where brain_mask == 1
            ndsc += [dice_norm_metric(ground_truth=gt, predictions=seg)]
            f1 += [lesion_f1_score(ground_truth=gt,
                                    predictions=seg,
                                    IoU_threshold=0.5,
                                    parallel_backend=parallel_backend)]
            ndsc_aac += [ndsc_aac_metric(ground_truth=gt[brain_mask == 1].flatten(),
                                            predictions=seg[brain_mask == 1].flatten(),
                                            uncertainties=uncs_map[brain_mask == 1].flatten(),
                                            parallel_backend=parallel_backend)]

            # progress report every 10 subjects
            if count % 10 == 0:
                print(f"Processed {count}/{len(val_loader)}")

# Report all metrics in percent as mean +- standard deviation over subjects.
ndsc = np.asarray(ndsc) * 100.
f1 = np.asarray(f1) * 100.
ndsc_aac = np.asarray(ndsc_aac) * 100.

print(f"nDSC:\t{np.mean(ndsc):.4f} +- {np.std(ndsc):.4f}")
print(f"Lesion F1 score:\t{np.mean(f1):.4f} +- {np.std(f1):.4f}")
# NOTE(review): the label says "R-AUC", but the values come from
# ndsc_aac_metric, which returns 1 - AUC (area above the curve) - confirm
# which label is intended.
print(f"nDSC R-AUC:\t{np.mean(ndsc_aac):.4f} +- {np.std(ndsc_aac):.4f}")

Processed 0/9
nDSC:	62.0949 +- 3.9508
Lesion F1 score:	20.7869 +- 5.8680
nDSC R-AUC:	31.6925 +- 6.2028
