In [None]:
import numpy as np
import pandas as pd
import os
from scipy.misc import imread
from sklearn.metrics import accuracy_score, log_loss
import json, pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.transform import rescale

from AxonDeepSeg.apply_model import axon_segmentation
from AxonDeepSeg.testing.segmentation_scoring import *

# Intro

In this notebook we apply each model to the test images and compute the statistics used to compare the models.
The results are appended to the existing `model_comparison.json` file (if there is one), so that models evaluated previously do not have to be relaunched.

# Defining the parameters

In [None]:
# Most important part: define where to find the models to test as well as the data to test on
# path_models: every sub-directory is treated as one model (its config_network.json and its
# '*.ckpt.meta' checkpoints are read); model_comparison.json is also written in this directory.
# path_testing: every sub-directory is one test sample and must contain image.png and mask.png.
path_models = '../test_models/'
path_testing = '../data/baseline_testing/'

In [None]:
# General pixel size (gps) handed to the segmentation, looked up with the key
# '<image type>_3c_<last field of the training set name>'.
# Reminder: the more resolution we have the less rescaling we have to do (thus the smaller the gps is!)
gps_dict = dict(
    SEM_3c_256=0.1,
    SEM_3c_512=0.1,  # Because not same semantics as SEM_3c_256 here
    TEM_3c_256=0.02,
    TEM_3c_512=0.01,
    TEM_3c_1024=0.005,
)

# Base names (without the '.pkl' extension) of the training-statistics pickles.
# The keys look like checkpoint names -- presumably one entry per saved checkpoint; verify.
stats = dict(
    best_acc_model='best_acc_stats',
    best_loss_model='best_loss_stats',
    model='evolution_stats',
)

Other parameters passed to the segmentation function (currently only the crop value).

In [None]:
# Forwarded unchanged as the crop_value argument of axon_segmentation.
# NOTE(review): presumably expressed in pixels -- confirm against axon_segmentation.
crop_value = 25

Helper function to convert a grayscale PNG mask into a mask of class labels.

In [None]:
def labellize(mask_raw, thresh=(0, 0.2, 0.8)):
    """Convert a grayscale mask into a mask of class labels.

    Every threshold after the first one is scaled by 255 and applied in order:
    pixels whose value is greater than or equal to ``thresh[k] * 255`` receive
    the label ``k``, later thresholds overwriting earlier ones. Pixels that
    reach none of those thresholds keep the label 0. ``thresh[0]`` itself is
    never compared against the data.

    Parameters
    ----------
    mask_raw : numpy array
        Grayscale mask; the 255 scaling implies values on a 0-255 scale.
    thresh : sequence of numbers
        Thresholds expressed as fractions of 255. They should be in increasing
        order for the labels to be nested as described above.

    Returns
    -------
    numpy array
        Array with the same shape and dtype as ``mask_raw`` holding the labels.
        An empty input yields an empty output.
    """
    mask = np.zeros_like(mask_raw)
    # Label k corresponds to thresh[k]; label 0 is the default fill value.
    for label, level in enumerate(thresh[1:], 1):
        mask[mask_raw >= level * 255] = label

    return mask

In [None]:
def binarize(mask_raw):
    """One-hot encode a 2D mask along a new last axis.

    The distinct values found in ``mask_raw`` are sorted (``np.unique``) and
    channel ``k`` of the result is 1.0 where the mask equals the k-th value
    and 0.0 elsewhere. The number of channels therefore depends on the values
    actually present in the mask.
    """
    classes = np.unique(mask_raw)
    n_rows, n_cols = mask_raw.shape[0], mask_raw.shape[1]
    one_hot = np.zeros((n_rows, n_cols, len(classes)))
    for channel in range(len(classes)):
        one_hot[..., channel] = np.equal(mask_raw, classes[channel])
    return one_hot

# Applying the segmentations

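Here we don't want to write a segmentation image to disk (`write_mode=False`): we only need the predictions, together with the per-class scores (`pred_proba=True`) that are used to compute the log loss.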

In [None]:
# Evaluate every checkpoint of every model folder found in path_models on every
# sample folder found in path_testing, then append the statistics of each model
# to <path_models>/model_comparison.json.
#
# Expected layout (as read below):
#   <path_models>/<model>/config_network.json   network configuration
#   <path_models>/<model>/<name>.ckpt.meta      one file per checkpoint
#   <path_models>/<model>/<stats[name]>.pkl     optional per-checkpoint statistics
#   <path_models>/<model>/evolution.pkl         fallback training history
#   <path_testing>/<TYPE>_<name>/image.png, mask.png
#
# NOTE(security): pickle.load can execute arbitrary code -- only point
# path_models at model folders you trust.
for folder in os.listdir(path_models):  # Loop over all models
    path_model = os.path.join(path_models, folder)
    if not os.path.isdir(path_model):
        continue

    model_comparison_list = []
    # First we have to retrieve some essential data from the config file
    with open(os.path.join(path_model, 'config_network.json'), 'r') as fd:
        config_network = json.loads(fd.read())
    trainingset_name = config_network['network_trainingset']
    type_trainingset = trainingset_name.split('_')[0]
    n_classes = config_network['network_n_classes']

    # We are now going to loop over all checkpoint files.
    for checkpoint in os.listdir(path_model):
        if not checkpoint.endswith('.ckpt.meta'):
            continue

        name_checkpoint = checkpoint[:-10]  # strip the '.ckpt.meta' suffix
        result_model = {'id_model': folder,
                        'ckpt': name_checkpoint,
                        'config': config_network}

        # First we compute the training statistics, which are independent of the testing images

        # >> Validation 10-moving average statistics
        try:
            # The lookup uses the *variable* name_checkpoint. A checkpoint
            # without an entry in `stats` (KeyError) or without its statistics
            # file (IOError/OSError) falls back to evolution.pkl below.
            path_stats = os.path.join(path_model, stats[name_checkpoint] + '.pkl')
            with open(path_stats, 'rb') as f:
                res = pickle.load(f)
            # NOTE(review): these values are stored as-is and later dumped to
            # JSON -- confirm the stats files hold JSON-serialisable scalars.
            acc_stats = res['accuracy']
            loss_stats = res['loss']
            epoch_stats = res['steps']
        except (KeyError, IOError, OSError, EOFError, pickle.UnpicklingError):
            print('No stats file found...')
            with open(os.path.join(path_model, 'evolution.pkl'), 'rb') as f:
                res = pickle.load(f)
            epoch_stats = max(res['steps'])
            acc_stats = np.mean(res['accuracy'][-10:])
            loss_stats = np.mean(res['loss'][-10:])

        result_model.update({
            'training_stats': {
                'training_epoch': epoch_stats,
                'training_mvg_avg10_acc': acc_stats,
                'training_mvg_avg10_loss': loss_stats
            },
            'testing_stats': []
        })

        # We now want to test on each SEM and TEM image
        # We use each folder of the baseline_testing in data. They have a prefix depending if SEM or TEM
        for testing_folder in tqdm(os.listdir(path_testing)):
            path_testing_folder = os.path.join(path_testing, testing_folder)
            if not os.path.isdir(path_testing_folder):
                continue

            # Reading the images and processing them if needed
            mask_raw = imread(os.path.join(path_testing_folder, 'mask.png'), flatten=True, mode='L')
            # img_raw is only needed by the disabled debugging / element-wise code below.
            img_raw = imread(os.path.join(path_testing_folder, 'image.png'), flatten=True, mode='L')
            mask = labellize(mask_raw)

            # We infer the name of the different files
            type_image = testing_folder.split('_')[0]  # SEM or TEM
            name_image = '_'.join(testing_folder.split('_')[1:])  # Rest of the name of the image
            testing_stats_dict = {'type_image': type_image,
                                  'name_image': name_image}

            # Choosing the gps
            name_gps = type_image + '_3c_' + str(trainingset_name.split('_')[-1])
            gps = gps_dict[name_gps]

            # Disabled debugging code: display the rescaled input image.
            # print 'GPS chosen: ' + str(gps)
            # file = open(path_testing_folder + '/pixel_size_in_micrometer.txt', 'r')
            # pixel_size = float(file.read())
            # plt.figure()
            # plt.imshow(rescale(img_raw, float(pixel_size)/gps, preserve_range=True)[0:256,0:256], cmap='gray')
            # plt.colorbar()
            # plt.show();

            # Making the prediction
            print('Beginning segmentation of ' + type_image + ' image with ' + type_trainingset + ' model, id ' + str(folder) + '...')
            prediction, pred_proba = axon_segmentation(path_testing_folder,
                                                       path_model,
                                                       config_network,
                                                       ckpt_name=name_checkpoint,
                                                       crop_value=crop_value,
                                                       general_pixel_size=gps,
                                                       pred_proba=True,
                                                       write_mode=False)

            print('Statistics extraction...')
            # Processing pred_proba into statistics: softmax over the last axis.
            # Subtracting the per-pixel maximum leaves the result mathematically
            # unchanged but keeps np.exp from overflowing on large scores.
            # NOTE(review): assumes pred_proba is an (H, W, n_classes) array of
            # unnormalised scores -- confirm against axon_segmentation.
            a = np.exp(pred_proba - np.max(pred_proba, axis=-1, keepdims=True))
            b = np.sum(a, axis=-1)
            pred_proba = np.stack([np.divide(a[:, :, i], b) for i in range(n_classes)], axis=-1)

            # Computation of metrics (the one-hot encodings are computed once and reused)
            # NOTE(review): binarize creates one channel per value actually
            # present, so a mask/prediction missing a class has fewer than
            # n_classes channels and the reshapes below would then misalign.
            mask_onehot = binarize(mask)
            prediction_onehot = binarize(prediction)
            vec_pred_proba = np.reshape(pred_proba, (-1, n_classes))
            vec_mask = np.reshape(mask_onehot, (-1, n_classes))
            # >> Accuracy and XEntropy loss
            testing_stats_dict.update({
                'accuracy': accuracy_score(mask.ravel(), prediction.ravel()),
                'log_loss': log_loss(vec_mask, vec_pred_proba)
            })
            # >> Pixel wise dice, both classes, and element wise dice
            # The axon class is taken to be the last channel.
            gt_axon = mask_onehot[:, :, -1]
            pred_axon = prediction_onehot[:, :, -1]
            pw_dice_axon = pw_dice(pred_axon, gt_axon)
            testing_stats_dict.update({
                'pw_dice_axon': pw_dice_axon})

            # Disabled: element-wise statistics on the axons.
            # ew_sensitivity_axon, ew_precision_axon, ew_diffusion_axon = score_analysis(img_raw,
            #                                                                   gt_axon,
            #                                                                   pred_axon)
            #
            # data_axon_dice = dice(img_raw, gt_axon, pred_axon, min_area=4)
            # ew_dice_mean_axon = data_axon_dice['dice'].mean()
            # ew_dice_quant_axon = data_axon_dice['dice'].quantile([0.1, 0.5, 0.9, 0.95])
            #
            # result_model.update({name_testing_image:{
            #     'ew_dice_mean_axon':ew_dice_mean_axon,
            #     'ew_dice_quant10_axon':ew_dice_quant_axon.values[0],
            #     'ew_dice_quant50_axon':ew_dice_quant_axon.values[1],
            #     'ew_dice_quant90_axon':ew_dice_quant_axon.values[2],
            #     'ew_dice_quant95_axon':ew_dice_quant_axon.values[3],
            #     'ew_sentivity_axon':ew_sensitivity_axon,
            #     'ew_precision_axon':ew_precision_axon,
            #     'ew_diffusion_axon':ew_diffusion_axon
            #                                       }
            #                      })

            if n_classes == 3:
                # With three classes, channel 1 is used as the myelin class.
                gt_myelin = mask_onehot[:, :, 1]
                pred_myelin = prediction_onehot[:, :, 1]
                pw_dice_myelin = pw_dice(pred_myelin, gt_myelin)

                testing_stats_dict.update({
                    'pw_dice_myelin': pw_dice_myelin})

            result_model['testing_stats'].append(testing_stats_dict)

        model_comparison_list.append(result_model)

    # Finally we save the model in a new json file (or append it to the
    # existing one). This is done once per model so that the results already
    # computed survive a failure on a later model.
    path_file = os.path.join(path_models, 'model_comparison.json')
    if os.path.exists(path_file):
        with open(path_file, 'r') as fd:
            existing_dict = json.loads(fd.read())
    else:
        existing_dict = {'data': []}

    existing_dict['data'].append(model_comparison_list)

    with open(path_file, 'w') as f:
        json.dump(existing_dict, f, indent=2)

# Converting the json file into an Excel file

Use GDrive