In [1]:
import os 
import glob 
from tqdm import tqdm

import torch
import pandas as pd
import numpy as np
import skimage.metrics 
from torchmetrics.functional.classification import dice, recall, precision
from torchmetrics import Dice, Recall, Precision 
import nibabel as nib 

In [2]:
def binarize_image(img, threshold = 0.5, one_hot = False):
    if img.ndim == 4:
       img = img.unsqueeze(0)

    elif img.ndim == 3:
        img = img[None, None, :, :, :]

    assert img.ndim == 5, f'Binarize_image, tensor mismatch {img.shape}'

    n_channels = img.shape[1]

    # binary problem
    if n_channels == 1:
        nimg = img > threshold
    elif n_channels == 3:
        if img.dtype == torch.bool:
            nimg = img.float()
        else:
            nimg           = torch.zeros_like(img)
            argmax_indexes = torch.argmax(img, dim = 1)
            nimg.scatter_(1, argmax_indexes.unsqueeze(1), 1) 
    else:
        print(f"In binarize_image, number of channels {n_channels}")
    
    if nimg.dtype != torch.float:   nimg = nimg.float()
    
    return nimg

def calculate_overlap_metrics(pred, gt, target_label: int == 1):
    aneur_mask      = torch.where(gt == target_label, 1, 0)
    pred_image_bin  = binarize_image(pred)
    pred_aneur_mask = torch.mul(pred_image_bin, aneur_mask)

    # compute dice score recall and precision
    tp = torch.sum((pred_aneur_mask == 1) & (aneur_mask == 1))
    fp = torch.sum((pred_aneur_mask == 1) & (aneur_mask == 0))
    fn = torch.sum((pred_aneur_mask == 0) & (aneur_mask == 1))

    if 2*tp + fp + fn == 0: dice_aneur = 1e-12
    else: dice_aneur = (2*tp/(2*tp + fp + fn)).item()

    if tp + fn == 0: recall_aneur = 1e-12
    else: recall_aneur = (tp/(tp+fn)).item()

    if fp + tp == 0: precision_aneur = 1e-12
    else: precision_aneur = (tp/(tp+fp)).item()
    metrics = {'dice_aneur':dice_aneur, 
               'recall_aneur':recall_aneur, 
               'precision_aneur':precision_aneur}
    
    return metrics

In [18]:
def calculate_metrics(num_classes, target_aneurysm_class, gt_vols_fp, pred_vols_fn,
                      collapse_into_single_uia_class=False, untreated_aneurysm_only=False):
    dice = Dice(num_classes=num_classes, ignore_index=0, average='micro')
    recall = Recall(num_classes=num_classes, num_labels=num_classes, ignore_index=0, average='micro', task='binary')
    precision = Precision(num_classes=num_classes,  num_labels=num_classes, ignore_index=0, average='micro', task='binary')
    
    metrics_tm = {'dice': dice, 'recall': recall, 'precision': precision}
    
    results = []
    for gt_vol_fp in tqdm(gt_vols_fp):
        vol_fn = os.path.basename(gt_vol_fp)
    
        try:
            assert vol_fn in pred_vols_fn, \
                f"No prediction for vol {vol_fn}"
        except AssertionError as e:
            print(e)
            continue
        
        # load vols     
        gt = nib.load(gt_vol_fp).get_fdata()
        gt = torch.tensor(gt).int()
        gt = torch.where(gt == target_aneurysm_class, 1, 0)
        
        pred = nib.load(os.path.join(predictions_dir, vol_fn)).get_fdata()
        pred = torch.tensor(pred).int()
        pred = torch.where(pred > 0, 1, 0)

        # Calculate metrics
        metrics = calculate_overlap_metrics(pred.float(), gt.float(), target_label= 1)
        metrics = {f'{k}_kostas':v for k,v in metrics.items()}
        
        for metric_name, metric_tm in metrics_tm.items():
            try:
                metrics[f'{metric_name}_tm'] = metric_tm(pred, gt).item()
            except:
                print(f'Error calculating {metric_name} for {vol_fn}')
                metrics[f'{metric_name}_tm'] = np.nan
                
        metrics['mhd'] = skimage.metrics.hausdorff_distance(gt.cpu().numpy(), pred.cpu().numpy(), method='modified')
        
        # Add metrics to the results list
        results.append({
            'vol_name': vol_fn,
            **metrics
        })
    
    return pd.DataFrame(results)
    

# Evaluate trained on source domain and predicting in target Domain (ADAM)

## USZ

In [19]:
data_dir = '../../../data/'
results_dir = os.path.join(data_dir, 'results')

os.makedirs(results_dir, exist_ok=True)

ground_truth_dir = os.path.join(data_dir, 'preprocessed/Mathijs/Dataset004_21Classes')
target_aneurysm_class = 4

## Model Trained on Binary Segmentation Aneurysm (treated and untreated) vs Background

In [20]:
gt_vols_fp = glob.glob(os.path.join(ground_truth_dir, 'labelsTs', '*.nii.gz'))

predictions_dir = os.path.join(data_dir, 'nnUNet_predictions', 'train_on_SD_predict_on_TD', 'train_on_Dataset006_ADAMBinaryAneurysmsOnly', 'USZ', 'imagesTs', '3d_fullres')
pred_vols_fn = os.listdir(predictions_dir)
pred_vols_fn[0:5]

['18.nii.gz',
 'predict_from_raw_data_args.json',
 '19.nii.gz',
 '61.nii.gz',
 '49.nii.gz']

In [21]:
results_df = calculate_metrics(num_classes=2, target_aneurysm_class=target_aneurysm_class, gt_vols_fp=gt_vols_fp, pred_vols_fn=pred_vols_fn,
                               collapse_into_single_uia_class=True, untreated_aneurysm_only=True)

100%|██████████| 13/13 [12:34<00:00, 58.01s/it]


In [22]:
results_df

Unnamed: 0,vol_name,dice_aneur_kostas,recall_aneur_kostas,precision_aneur_kostas,dice_tm,recall_tm,precision_tm,mhd
0,18.nii.gz,0.807497,0.677144,1.0,0.79894,0.677144,1.0,0.394475
1,19.nii.gz,0.552023,0.381238,1.0,0.443213,0.381238,1.0,122.672084
2,61.nii.gz,0.392369,0.244066,1.0,0.391373,0.244066,1.0,3.708302
3,49.nii.gz,0.515468,0.347226,1.0,0.512007,0.347226,1.0,2.10125
4,31.nii.gz,0.749429,0.59927,1.0,0.729994,0.59927,1.0,0.602211
5,51.nii.gz,0.679374,0.514434,1.0,0.658768,0.514434,1.0,0.628222
6,48.nii.gz,0.677949,0.512801,1.0,0.039608,0.512801,1.0,205.533844
7,27.nii.gz,0.692432,0.529557,1.0,0.64275,0.529557,1.0,0.631643
8,1.nii.gz,0.137681,0.07393,1.0,0.128668,0.07393,1.0,68.483097
9,0.nii.gz,0.614242,0.443253,1.0,0.599829,0.443253,1.0,1.920295


In [23]:
results_df.to_csv(os.path.join(data_dir, 'results', 'sd_adam__td_usz__model_trained_on_binary_seg_treated_and_untreated_UIAs.csv'), index=False)

In [24]:
results_df.dice_tm.mean()

0.48745218033973986

In [25]:
results_df.mhd.mean()


47.88768604809257