In [37]:
import os
import SimpleITK as sitk
import numpy as np
import pandas as pd
import json

from reader import ALL_PROTOCOLS
test_csv = 'test.csv'

# Read every column as a string (dtype=object) and disable NA parsing
# (keep_default_na=False with an empty na_values list) so that no cell is
# turned into NaN. DataFrame.as_matrix() was removed in pandas 1.0; .values
# returns the same ndarray on both old and new pandas versions.
test_filenames = pd.read_csv(test_csv,
                             dtype=object,
                             keep_default_na=False,
                             na_values=[]).values
# NOTE(review): not read by any cell shown here; remove if nothing else uses it.
results = {}

In [38]:
def dice(predictions, labels, num_classes):
    """Per-class (categorical) Dice similarity coefficient.

    For each class index c in ``range(num_classes)`` the score is
    ``2 * |P_c & L_c| / (|P_c| + |L_c|)``, where P_c and L_c are the elements
    of ``predictions`` and ``labels`` equal to c. A class that appears in
    neither input is scored NaN instead of dividing by zero.

    Args:
        predictions (np.ndarray): predicted label map.
        labels (np.ndarray): ground-truth label map, compared elementwise
            with ``predictions``.
        num_classes (int): number of class indices (0 .. num_classes-1) to score.

    Returns:
        np.ndarray: float32 array of length ``num_classes`` holding one Dice
        score per class (NaN where the class is absent from both inputs).
    """
    scores = np.zeros(num_classes)
    for cls in range(num_classes):
        pred_mask = predictions == cls
        true_mask = labels == cls
        total = np.sum(pred_mask) + np.sum(true_mask)
        if not total > 0:
            # Class absent from both maps: Dice is undefined.
            scores[cls] = np.nan
            continue
        # Elementwise product of the boolean masks is their intersection.
        scores[cls] = 2. * np.sum(pred_mask * true_mask) / total
    return scores.astype(np.float32)

In [43]:
def evaluate(config, test_filenames):
    """Compute per-class Dice scores of predicted vs. ground-truth segmentations.

    For every subject row in ``test_filenames`` and every protocol listed in
    ``config["protocols"]``, the ground-truth label image and the predicted
    segmentation are read with SimpleITK, cast to int32 and compared with
    ``dice``.

    Args:
        config (dict): run configuration. Keys read:
            ``"protocols"`` - protocol names, each looked up in ``ALL_PROTOCOLS``;
            ``"num_classes"`` - sequence indexed in parallel with ``"protocols"``;
            ``"out_segm_path"`` - root directory of the predicted segmentations.
        test_filenames (sequence of rows): one row per subject. Column 0 is the
            subject id; column ``2 + ALL_PROTOCOLS.index(protocol)`` is the path
            of that protocol's ground-truth label image.

    Returns:
        dict: ``{subject_id: {protocol: per-class Dice np.ndarray}}``. A subject
        still gets an (empty) entry when ``config["protocols"]`` is empty.
    """
    # Number of leading columns before the per-protocol label paths
    # (column 0 is the subject id; column 1 is not used here).
    label_col_offset = 2

    protocols = config["protocols"]
    res = {}
    for row in test_filenames:
        # Resolve the subject id once per subject. It used to be assigned inside
        # the protocol loop, so an empty protocol list raised NameError (or
        # stored results under the previous subject's id).
        subj_id = str(row[0])

        res_individual = {}
        for j, protocol in enumerate(protocols):
            # Load the ground truth label; its column follows the protocol's
            # position in the global ALL_PROTOCOLS ordering.
            idx = ALL_PROTOCOLS.index(protocol)
            lbl = sitk.GetArrayFromImage(
                sitk.ReadImage(str(row[label_col_offset + idx]))).astype(np.int32)

            # Load the predicted segmentation: <out_segm_path>/<subj_id>/<protocol>.nii.gz
            seg_fn = os.path.join(config["out_segm_path"], subj_id, protocol + '.nii.gz')
            seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_fn)).astype(np.int32)

            res_individual[protocol] = dice(seg, lbl, config["num_classes"][j])

        res[subj_id] = res_individual
    return res

def print_stats(res, config):
    """Print summary statistics of the Dice scores for each configured protocol.

    Each subject's score for a protocol is the NaN-ignoring mean of its
    per-class Dice array (every class in the array, class 0 included).
    The mean, standard deviation, minimum and maximum of those per-subject
    scores are then printed, one line per protocol.

    Args:
        res (dict): ``{subject_id: {protocol: per-class Dice array}}`` as
            returned by ``evaluate``.
        config (dict): run configuration; only ``config["protocols"]`` is read.
    """
    for p in config["protocols"]:
        r = [np.nanmean(val[p]) for val in res.values()]
        if not r:
            # np.min / np.max raise ValueError on an empty sequence; report the
            # missing results instead of aborting the whole summary.
            print('protocol: {}; no results'.format(p))
            continue
        print('protocol: {}; mean: {:0.4f}; std: {:0.4f}; min: {:0.4f}; max: {:0.4f}'.format(
            p, np.mean(r), np.std(r), np.min(r), np.max(r)))

In [44]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_fsl_fast.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_fsl_fast = evaluate(config, test_filenames)
print_stats(res_fsl_fast, config)

protocol: fsl_fast; mean: 0.9511; std: 0.0190; min: 0.8338; max: 0.9763


In [45]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_fsl_first.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_fsl_first = evaluate(config, test_filenames)
print_stats(res_fsl_first, config)

protocol: fsl_first; mean: 0.9504; std: 0.0286; min: 0.3753; max: 0.9685


In [49]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_malp_em.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_malp_em = evaluate(config, test_filenames)
print_stats(res_malp_em, config)

protocol: malp_em; mean: 0.8734; std: 0.0260; min: 0.7151; max: 0.9089


In [46]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_malp_em_tissue.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_malp_em_tissue = evaluate(config, test_filenames)
print_stats(res_malp_em_tissue, config)

protocol: malp_em_tissue; mean: 0.9497; std: 0.0149; min: 0.8597; max: 0.9689


In [47]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_spm_tissue.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_spm_tissue = evaluate(config, test_filenames)
print_stats(res_spm_tissue, config)

protocol: spm_tissue; mean: 0.9484; std: 0.0346; min: 0.2652; max: 0.9692


In [48]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_tissue.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_tissue = evaluate(config, test_filenames)
print_stats(res_tissue, config)

protocol: fsl_fast; mean: 0.9457; std: 0.0197; min: 0.7961; max: 0.9721
protocol: spm_tissue; mean: 0.9490; std: 0.0337; min: 0.2658; max: 0.9691
protocol: malp_em_tissue; mean: 0.9391; std: 0.0269; min: 0.6763; max: 0.9649


In [50]:
# Parse the run config (JSON is UTF-8 by spec, so don't rely on the locale's
# default text encoding)
cfg_fn = 'config_all.json'
with open(cfg_fn, encoding='utf-8') as f:
    config = json.load(f)

# Evaluate and print stats
res_all = evaluate(config, test_filenames)
print_stats(res_all, config)

protocol: fsl_fast; mean: 0.9473; std: 0.0182; min: 0.8013; max: 0.9734
protocol: fsl_first; mean: 0.9380; std: 0.0355; min: 0.3817; max: 0.9622
protocol: spm_tissue; mean: 0.9491; std: 0.0337; min: 0.2653; max: 0.9699
protocol: malp_em; mean: 0.8592; std: 0.0309; min: 0.5755; max: 0.8972
protocol: malp_em_tissue; mean: 0.9430; std: 0.0162; min: 0.7959; max: 0.9651
