In [None]:
#%% Imports
import os
import pickle
import random
import shutil
import subprocess
from glob import glob

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from nilearn.image import resample_to_img, resample_img
from nilearn.masking import compute_background_mask, compute_epi_mask
from nilearn.plotting import plot_roi, plot_epi
from scipy import ndimage
from scipy.spatial.distance import directed_hausdorff

In [None]:
def _load_label_array(label):
    """Return a label volume as a numpy array.

    `label` is used as-is when it is already a numpy array; otherwise it is
    passed to `nib.load` (e.g. a path to a NIfTI file).
    """
    if isinstance(label, np.ndarray):
        return label
    # get_fdata() replaces get_data(), which is deprecated and removed in nibabel 5
    return nib.load(label).get_fdata()


def _safe_div(numerator, denominator):
    """Float division that returns NaN instead of failing on a zero denominator."""
    if denominator == 0:
        return float('nan')
    return float(numerator) / float(denominator)


def _hausdorff_distance(mask_a, mask_b):
    """Symmetric Hausdorff distance between the foreground voxels of two boolean masks.

    Distances are measured in voxel-index units (voxel spacing is NOT applied).
    Returns NaN when either mask has no foreground voxel.
    """
    # directed_hausdorff works on (n_points, n_dims) coordinate arrays, not on volumes
    points_a = np.argwhere(mask_a)
    points_b = np.argwhere(mask_b)
    if len(points_a) == 0 or len(points_b) == 0:
        return float('nan')
    return max(directed_hausdorff(points_a, points_b)[0],
               directed_hausdorff(points_b, points_a)[0])


def get_test_metrics(original_label, predicted_label):
    """Compute binary segmentation metrics of a predicted label against the ground truth.

    Voxels equal to 1 are foreground; every other value is background.

    Parameters
    ----------
    original_label, predicted_label : str or numpy.ndarray
        Ground-truth and predicted label volumes, either as arrays or as
        anything `nib.load` accepts (e.g. NIfTI file paths).

    Returns
    -------
    dict
        Confusion counts 'TP', 'TN', 'FP', 'FN' (ints) and the float rates
        'TPR', 'TNR', 'PPV', 'NPV', 'FNR', 'FPR', 'FDR', 'FOR', 'ACC',
        'F1S' (Dice), 'MCC', 'HD' (symmetric Hausdorff, voxel units) and
        'JC' (Jaccard). Ratios with a zero denominator are NaN.

    Raises
    ------
    ValueError
        If the two volumes do not have the same shape.
    """
    original_data = _load_label_array(original_label)
    predicted_data = _load_label_array(predicted_label)
    if original_data.shape != predicted_data.shape:
        raise ValueError('Label shapes differ: %s vs %s'
                         % (original_data.shape, predicted_data.shape))

    truth = original_data == 1
    pred = predicted_data == 1

    metrics = {}

    # Confusion-matrix counts, as Python ints so the MCC products below cannot overflow int64
    metrics['TP'] = int(np.count_nonzero(truth & pred))
    metrics['TN'] = int(np.count_nonzero(~truth & ~pred))
    metrics['FP'] = int(np.count_nonzero(~truth & pred))
    metrics['FN'] = int(np.count_nonzero(truth & ~pred))
    tp, tn, fp, fn = metrics['TP'], metrics['TN'], metrics['FP'], metrics['FN']

    # True positive rate (Sensitivity, Recall)
    metrics['TPR'] = _safe_div(tp, tp + fn)
    # True negative rate (Specificity)
    metrics['TNR'] = _safe_div(tn, tn + fp)
    # Positive predictive value (Precision)
    metrics['PPV'] = _safe_div(tp, tp + fp)
    # Negative predictive value
    metrics['NPV'] = _safe_div(tn, tn + fn)

    # False negative rate (Miss rate)
    metrics['FNR'] = 1 - metrics['TPR']
    # False positive rate (Fall-out)
    metrics['FPR'] = 1 - metrics['TNR']
    # False discovery rate
    metrics['FDR'] = 1 - metrics['PPV']
    # False omission rate
    metrics['FOR'] = 1 - metrics['NPV']

    # Accuracy
    metrics['ACC'] = _safe_div(tp + tn, tp + tn + fp + fn)

    # F1 Score (also known as DSC, Sørensen–Dice coefficient); equal to
    # 2*PPV*TPR/(PPV+TPR) but stays defined when TP == 0 and FP + FN > 0
    metrics['F1S'] = _safe_div(2 * tp, 2 * tp + fp + fn)

    # Matthews correlation coefficient: (TP*TN - FP*FN) / sqrt(product of marginals).
    # The MCC can be more appropriate when negatives actually mean something.
    mcc_denominator = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    metrics['MCC'] = _safe_div(tp * tn - fp * fn, mcc_denominator)

    # Symmetric Hausdorff distance between the foreground voxel sets
    metrics['HD'] = _hausdorff_distance(truth, pred)

    # Jaccard coefficient
    metrics['JC'] = _safe_div(tp, fn + fp + tp)

    return metrics
    

In [None]:
# Set working directory (override the default with the DISS_ROOT environment variable)
os.chdir(os.environ.get('DISS_ROOT', '/home/uziel/DISS'))
# Root directory containing the trained models to be post-processed
root = "./milestones_3"
# Glob pattern selecting model variants, e.g. "DM_V0_[0-4]" (glob has no {0..4} brace expansion)
model_variant = '*'
trained_models = sorted(glob(os.path.join(root, model_variant)))

**POSTPROCESSING FOR TEST CASES**

Upsample predicted labels and compute test metrics

In [None]:
##################################################################
##### POSTPROCESSING FOR K-FOLD CROSS-VALIDATION MODELS (0) ######
##################################################################

root_data = './data/ISLES2017/testing'
results = {}

# Every path under root_data, collected once; used to find the folder of each subject
data_paths = [y
              for x in os.walk(root_data)
              for y in glob(os.path.join(x[0], '*'))]

for model in trained_models:
    model_name = os.path.basename(model)
    # Dedicated local names so the module-level `root` (experiment directory,
    # used by later cells to save summaries) is not overwritten.
    preds_dir = os.path.join(model, 'output/predictions/testSession/predictions')
    session_dir = os.path.dirname(preds_dir)

    preds = sorted(glob(os.path.join(preds_dir, '*Segm.nii.gz')))

    results[model_name] = []

    # Resize each prediction to its subject's original grid for final result validation
    for pred_path in preds:
        # Second-to-last '_'-separated token of the file name; its last '.'-part
        # is the code used to find the subject
        pred_token = os.path.basename(pred_path).split('_')[-2]
        subject_code = pred_token.split('.')[-1]

        # Find subject that contains the code in pred.
        subject = sorted(p for p in data_paths if subject_code in p)[0].split('/')[-2]

        subject_channels = sorted([y
                                   for x in os.walk(os.path.join(root_data, subject))
                                   for y in glob(os.path.join(x[0], '*MR_*.nii'))
                                   if '4DPWI' not in y
                                  ])

        subject_label = sorted([y
                                for x in os.walk(os.path.join(root_data, subject))
                                for y in glob(os.path.join(x[0], '*OT*.nii'))
                               ])[0]

        # Load ADC channel (first channel in sorted order) as reference
        original_img = nib.load(subject_channels[0])

        # Load prediction and upsample it to the original size
        pred_img = resample_img(nib.load(pred_path),
                                original_img.affine,
                                original_img.shape,
                                interpolation='nearest')

        # Save prediction
        pred_label = os.path.join(session_dir, pred_token + '.nii')
        nib.save(pred_img, pred_label)

        # Compute metrics between original and predicted label
        metrics = get_test_metrics(subject_label, pred_label)

        results[model_name].append([subject, subject_channels, subject_label, pred_label, metrics])

    # Save this model's per-subject results (only this model, not every model seen so far)
    with open(os.path.join(model, 'test_results.txt'), 'wb') as output:
        pickle.dump({model_name: results[model_name]}, output, pickle.HIGHEST_PROTOCOL)

    # Mean and variance, over this model's subjects, of every metric
    subject_metrics = [entry[4] for entry in results[model_name]]
    metric_names = list(subject_metrics[0]) if subject_metrics else []
    test_metrics = {
        'mean': {k: np.mean([m[k] for m in subject_metrics]) for k in metric_names},
        'var': {k: np.var([m[k] for m in subject_metrics]) for k in metric_names},
    }

    # Save each model's metrics
    with open(os.path.join(model, 'test_metrics.txt'), 'wb') as output:
        pickle.dump(test_metrics, output, pickle.HIGHEST_PROTOCOL)


Load each model's test metrics and compute their mean and variance across models. This is the final result of an experiment and determines its performance.

In [None]:
# Each model's test_metrics.txt holds {'mean': {metric: value}, 'var': {metric: value}}
# over that model's subjects. The experiment result is the mean and variance,
# across models, of the per-model mean of every metric.
model_test_metrics = []
for model in trained_models:
    with open(os.path.join(model, 'test_metrics.txt'), 'rb') as f:
        model_test_metrics.append(pickle.load(f))

metric_names = list(model_test_metrics[0]['mean']) if model_test_metrics else []
test_metrics = {
    'mean': {k: np.mean([m['mean'][k] for m in model_test_metrics]) for k in metric_names},
    'var': {k: np.var([m['mean'][k] for m in model_test_metrics]) for k in metric_names},
}

# Experiment directory = parent of the model directories (earlier cells may rebind `root`)
experiment_root = os.path.dirname(trained_models[0]) if trained_models else root

# Save final experiment metrics
with open(os.path.join(experiment_root, model_variant + '_test_metrics.txt'), 'wb') as output:
    pickle.dump(test_metrics, output, pickle.HIGHEST_PROTOCOL)

Plot original and predicted labels for test cases

In [None]:
# Plot the ground-truth label (left column) and the predicted label (right column)
# of every test subject on top of the subject's reference channel.
for model in trained_models:
    model_name = os.path.basename(model)
    model_results = results[model_name]
    if not model_results:
        continue  # no predictions for this model: nothing to plot
    n_rows = len(model_results)

    plt.close('all')
    fig = plt.figure(figsize=(16, n_rows * 2))

    # The trailing metrics entry of each result is not needed for plotting
    for row, (subject, subject_channels, subject_label, pred_label, _metrics) in enumerate(model_results):
        original_img = nib.load(subject_channels[0])
        original_label_img = nib.load(subject_label)
        predicted_label_img = nib.load(pred_label)

        ax = fig.add_subplot(n_rows, 2, 2 * row + 1)
        gt_display = plot_roi(original_label_img, original_img, display_mode='z',
                              cut_coords=4, figure=fig, axes=ax)
        ax = fig.add_subplot(n_rows, 2, 2 * row + 2)
        # Reuse the ground truth's slice positions so both columns show the same slices
        plot_roi(predicted_label_img, original_img, display_mode='z',
                 cut_coords=gt_display.cut_coords, figure=fig, axes=ax)

    plt.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(os.path.join(model, 'testSegResults_' + model_name + '.pdf'), bbox_inches='tight')

**PROCESS TRAINING AND VALIDATION RESULTS**

Plot and save training progress

In [None]:
# Plot and save each model's training progress with DeepMedic's plotting script.
# The moves below expect the script to write trainingProgress.pdf and
# trainingProgress.txt into the current working directory.
for model in trained_models:
    train_log = os.path.join(model, 'output/logs/trainSession.txt')
    # Argument list instead of a shell string: paths are passed verbatim, no shell parsing
    exit_code = subprocess.call(['python', 'ischleseg/deepmedic/plotSaveTrainingProgress.py',
                                 train_log, '-d', '-m', '20', '-s'])
    if exit_code != 0:
        # Do not move files: they could be stale outputs belonging to a previous model
        print('plotSaveTrainingProgress.py failed for %s (exit code %d); skipping' % (model, exit_code))
        continue
    # Move files to the corresponding model directory
    shutil.move('trainingProgress.pdf',
                os.path.join(model, 'trainingProgress_' + os.path.basename(model) + '.pdf'))
    shutil.move('trainingProgress.txt', os.path.join(model, 'trainingProgress.txt'))


Load the training-progress metrics of every model and compute their mean and variance across models (covers both training and validation metrics).

In [None]:
# Load "measuredMetricsFromAllExperiments"
# 1st dimension: "Validation" (0), "Training" (1)
# 2nd dimension: ? (0)
# 3rd dimension: "Mean Accuracy" (0), "Sensitivity" (1), "Specificity" (2), "DSC (samples)" (3), "DSC (full-segm)" (4)
# Each leaf is presumably a per-subepoch series (35 epochs x 20 subepochs) -- TODO confirm

n_epochs, n_subepochs = 35, 20  # length of the placeholder series below

metrics = {}
for model in trained_models:
    with open(os.path.join(model, 'trainingProgress.txt'), 'rb') as f:
        metrics[os.path.basename(model)] = pickle.load(f)

# Compute mean and variance across models, point-wise along each metric's series
metrics_mean = {}
metrics_var = {}
metrics_values = list(metrics.values())  # list() so this also works with Python 3 dict views
metrics_names_0 = ['Validation', 'Training']
metrics_names_1 = ['Mean Accuracy', 'Sensitivity', 'Specificity', 'DSC (Samples)', 'DSC (full-segm)']

for i, split_name in enumerate(metrics_names_0):
    metrics_mean[split_name] = {}
    metrics_var[split_name] = {}
    for j, metric_name in enumerate(metrics_names_1):
        if i == 1 and j == 4:  # Skip DSC_full for training (is never calculated)
            metrics_mean[split_name][metric_name] = np.zeros(n_epochs * n_subepochs)
            metrics_var[split_name][metric_name] = np.zeros(n_epochs * n_subepochs)
            continue
        # Stack the per-model series into (n_models, n_points) and reduce over models only
        series = np.array([np.asarray(values[i][0][j], dtype=float) for values in metrics_values])
        metrics_mean[split_name][metric_name] = np.mean(series, axis=0)
        metrics_var[split_name][metric_name] = np.var(series, axis=0)

train_val_metrics = {}
train_val_metrics['mean'] = metrics_mean
train_val_metrics['var'] = metrics_var

# Experiment directory = parent of the model directories (earlier cells may rebind `root`)
experiment_root = os.path.dirname(trained_models[0]) if trained_models else root

# Save final experiment metrics
with open(os.path.join(experiment_root, model_variant + '_train_val_metrics.txt'), 'wb') as output:
    pickle.dump(train_val_metrics, output, pickle.HIGHEST_PROTOCOL)

Plot mean training and validation metrics of all trained models

In [None]:
# Plot the across-model mean of each training/validation metric against epoch
plot_validation = False  # models in milestones_3 did no validation

plt.close('all')
rows, cols = len(metrics_names_0), len(metrics_names_1)
fig = plt.figure(figsize=(cols * 6, rows * 4))
for i, split_name in enumerate(metrics_names_0):
    if i == 0 and not plot_validation:
        continue  # Skip validation data
    for j, metric_name in enumerate(metrics_names_1):
        curve = np.atleast_1d(metrics_mean[split_name][metric_name])
        # 20 points (subepochs) per epoch; x is derived from the curve so lengths always match
        epochs = np.arange(len(curve)) / 20.0
        ax = fig.add_subplot(rows, cols, i * cols + 1 + j)
        ax.plot(epochs, curve, 'r')
        ax.set_xlim(0, 35)
        ax.set_ylim(0, 1.0)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(split_name)
        ax.set_title(metric_name)

# Experiment directory = parent of the model directories (earlier cells may rebind `root`)
experiment_root = os.path.dirname(trained_models[0]) if trained_models else root

# Save mean training and validation metrics of all trained models
fig.savefig(os.path.join(experiment_root, model_variant + 'meanTrainProgress.pdf'), bbox_inches='tight')

In [None]:
#################################################################
##### POSTPROCESSING FOR TRAINING + VALIDATION SETS MODELS ######
#################################################################

root_data = './data/ISLES2017/training'

root = "./milestones_1"
trained_models = sorted(glob(os.path.join(root, '*DM_V[0-1]_*[0-4]')))
results = {}

# Every path under root_data, collected once; used to find the folder of each subject
data_paths = [y
              for x in os.walk(root_data)
              for y in glob(os.path.join(x[0], '*'))]

for model in trained_models:
    model_name = os.path.basename(model)
    # Local name so the module-level `root` keeps pointing at the experiment directory
    preds_dir = os.path.join(model, 'output/predictions/trainSession/predictions')

    segms = sorted(glob(os.path.join(preds_dir, '*Segm.nii.gz')))
    prob_maps_class0 = sorted(glob(os.path.join(preds_dir, '*ProbMapClass0.nii.gz')))
    prob_maps_class1 = sorted(glob(os.path.join(preds_dir, '*ProbMapClass1.nii.gz')))
    # Files are paired by sorted position, so the three lists must have equal length
    if not len(segms) == len(prob_maps_class0) == len(prob_maps_class1):
        raise ValueError('%s: %d segmentations but %d/%d probability maps'
                         % (model, len(segms), len(prob_maps_class0), len(prob_maps_class1)))
    results[model_name] = []

    # Resize each prediction to its subject's original grid for final result validation
    for segm_path, pmap0_path, pmap1_path in zip(segms, prob_maps_class0, prob_maps_class1):
        # Find subject that contains the code in pred.
        subject_code = os.path.basename(segm_path).split('_')[-2].split('.')[-1]
        subject = sorted(p for p in data_paths if subject_code in p)[0].split('/')[-2]

        subject_channels = sorted([y
                                   for x in os.walk(os.path.join(root_data, subject))
                                   for y in glob(os.path.join(x[0], '*MR_*.nii'))
                                   if '4DPWI' not in y
                                  ])
        subject_label = sorted([y
                                for x in os.walk(os.path.join(root_data, subject))
                                for y in glob(os.path.join(x[0], '*OT*.nii'))
                               ])

        # Load ADC channel (first channel in sorted order) as reference
        original_img = nib.load(subject_channels[0])

        # Load predictions and upsample them to the original size:
        # nearest for the label map, continuous for the probability maps
        pred = resample_img(nib.load(segm_path),
                            original_img.affine,
                            original_img.shape,
                            interpolation='nearest')
        pmap_0 = resample_img(nib.load(pmap0_path),
                              original_img.affine,
                              original_img.shape,
                              interpolation='continuous')
        pmap_1 = resample_img(nib.load(pmap1_path),
                              original_img.affine,
                              original_img.shape,
                              interpolation='continuous')

        results[model_name].append([subject_channels, subject_label[0], pred, pmap_0, pmap_1])

In [None]:
# Choose which prediction session to post-process:
#   sflag falsy  -> 'testSession_0', subjects read from the ISLES2017 training folder
#   sflag truthy -> 'testSession_1', subjects read from the ISLES2017 testing folder
sflag = 0

session, root_data = (
    ('testSession_1', './data/ISLES2017/testing') if sflag
    else ('testSession_0', './data/ISLES2017/training')
)

In [None]:
###########################################################
##### POSTPROCESSING FOR TRAINING + TEST SETS MODELS ######
###########################################################

root = "./milestones_3"
trained_models = sorted(glob(os.path.join(root, '*DM_V[2-3]')))
results = {}

# Every path under root_data, collected once; used to find the folder of each subject
data_paths = [y
              for x in os.walk(root_data)
              for y in glob(os.path.join(x[0], '*'))]

for model in trained_models:
    model_name = os.path.basename(model)
    # Local names so the module-level `root` keeps pointing at the experiment directory
    session_dir = os.path.join(model, 'output/predictions/' + session)
    preds_dir = os.path.join(session_dir, 'predictions')

    # Load predictions, resample them to the original size and save them
    # (as .nii) matching the SMIR required format.
    preds = sorted(glob(os.path.join(preds_dir, '*Segm.nii.gz')))

    results[model_name] = []

    for pred_path in preds:
        # Second-to-last '_'-separated token of the file name; its last '.'-part
        # is the code used to find the subject
        pred_token = os.path.basename(pred_path).split('_')[-2]
        subject_code = pred_token.split('.')[-1]

        # Find subject that contains the code in pred.
        subject = sorted(p for p in data_paths if subject_code in p)[0].split('/')[-2]

        subject_channels = sorted([y
                                   for x in os.walk(os.path.join(root_data, subject))
                                   for y in glob(os.path.join(x[0], '*MR_*.nii'))
                                   if '4DPWI' not in y
                                  ])

        # Load ADC channel (first channel in sorted order) as reference
        original_img = nib.load(subject_channels[0])

        # Load prediction and upsample it to the original size
        pred = resample_img(nib.load(pred_path),
                            original_img.affine,
                            original_img.shape,
                            interpolation='nearest')

        # Save prediction
        nib.save(pred, os.path.join(session_dir, pred_token + '.nii'))

        results[model_name].append([subject, subject_channels, pred])