In [None]:
#!mamba install -c conda-forge torchmetrics -y

In [19]:
import os 
import glob 

import torch
from torchmetrics import Dice
import nibabel as nib 

In [20]:
def binarize_image(img, threshold = 0.5, one_hot = False):
    if img.ndim == 4:
       img = img.unsqueeze(0)

    elif img.ndim == 3:
        img = img[None, None, :, :, :]

    assert img.ndim == 5, f'Binarize_image, tensor mismatch {img.shape}'

    n_channels = img.shape[1]

    # binary problem
    if n_channels == 1:
        nimg = img > threshold
    elif n_channels == 3:
        if img.dtype == torch.bool:
            nimg = img.float()
        else:
            nimg           = torch.zeros_like(img)
            argmax_indexes = torch.argmax(img, dim = 1)
            nimg.scatter_(1, argmax_indexes.unsqueeze(1), 1) 
    else:
        print(f"In binarize_image, number of channels {n_channels}")
    
    if nimg.dtype != torch.float:   nimg = nimg.float()
    return nimg

def calculate_overlap_metrics(pred, gt, target_label: int == 1):
    aneur_mask      = torch.where(gt == target_label, 1, 0)
    pred_image_bin  = binarize_image(pred)
    pred_aneur_mask = torch.mul(pred_image_bin, aneur_mask)

    # compute dice score recall and precision
    tp = torch.sum((pred_aneur_mask == 1) & (aneur_mask == 1))
    fp = torch.sum((pred_aneur_mask == 1) & (aneur_mask == 0))
    fn = torch.sum((pred_aneur_mask == 0) & (aneur_mask == 1))

    if 2*tp + fp + fn == 0: dice_aneur = 1e-12
    else: dice_aneur = 2*tp/(2*tp + fp + fn)

    if tp + fn == 0: recall_aneur = 1e-12
    else: recall_aneur = tp/(tp+fn)

    if fp + tp == 0: precision_aneur = 1e-12
    else: precision_aneur = tp/(tp+fp)
    metrics = {'dice_aneur':dice_aneur, 
               'recall_aneur':recall_aneur, 
               'precision_aneur':precision_aneur}
    
    return metrics

### Evaluate trained on source domain and predicting in holdout set of the same domain



#### UZS

##### Binary Segmentation All vessels

In [8]:
data_dir = '../../../data/'
ground_truth_dir = os.path.join(data_dir, 'raw', 'Dataset001_BinaryAllVessels', 'labelsTs')
predictions_dir = os.path.join(data_dir, 'nnUNet_predictions', 'Dataset001_BinaryAllVessels', '3d_fullres')

In [16]:
gt_vols_fp = glob.glob(os.path.join(ground_truth_dir, '*.nii.gz'))
pred_vols_fn = os.listdir(predictions_dir)

In [22]:
for gt_vol_fp in gt_vols_fp:
    vol_fn = os.path.basename(gt_vol_fp)
    
    assert vol_fn in pred_vols_fn, \
        f"No prediction for vol {vol_fn}"
    
    # load vols     
    gt = nib.load(gt_vol_fp).get_fdata()
    gt = torch.tensor(gt).float()
    
    pred = nib.load(os.path.join(predictions_dir, vol_fn))
    pred = torch.tensor(pred).float()
    
    # Calculate metrics
    metrics = calculate_overlap_metrics(pred, gt, target_label= 1)
    
    break

MemoryError: Unable to allocate 547. MiB for an array with shape (560, 640, 200) and data type float64