In [None]:
import numpy as np
import pandas as pd
import os
from scipy.misc import imread
from sklearn.metrics import accuracy_score, log_loss
import json, pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.transform import rescale
import time

from AxonDeepSeg.apply_model import axon_segmentation
from AxonDeepSeg.testing.segmentation_scoring import *

# Intro

In this notebook we apply each model to the test (validation) images and compute the statistics used to compare the models.
Models whose folder already contains a `model_statistics_validation.json` file are skipped, so that models do not have to be re-run every time.

# Defining the parameters

In [None]:
# The key settings of this notebook: where the models to evaluate live,
# and where the validation data they are evaluated on lives.
path_models, path_testing = '../models/', '../data/baseline_validation/'

In [None]:
# General pixel size (gps) for each "<acquisition>_3c_<training-set suffix>" key;
# these values are passed to axon_segmentation as general_pixel_sizes.
# Reminder: the higher the resolution, the less rescaling is needed, hence the smaller the gps.
gps_dict = dict(
    SEM_3c_256=0.1,
    SEM_3c_512=0.1,  # same value as SEM_3c_256 on purpose (author note: not the same semantics here)
    TEM_3c_256=0.02,
    TEM_3c_512=0.01,
    TEM_3c_1024=0.005,
    TEM_3c_reduced=0.02,
    SEM_3c_reduced=0.2,
)

# Checkpoint name -> basename (without ".pkl") of the pickle holding its training statistics.
stats = dict(
    best_acc_model='best_acc_stats',
    best_loss_model='best_loss_stats',
    model='evolution_stats',
)

Other parameters for the segmentation, such as the crop value passed to `axon_segmentation`.

In [None]:
# Passed to axon_segmentation as `crop_value`.
# NOTE(review): presumably a border width (in pixels) cropped from each predicted patch — confirm.
crop_value = 25

Helper function that converts a grayscale PNG mask (intensities 0–255) into a mask of class labels.

In [None]:
def labellize(mask_raw, thresh=(0, 0.2, 0.8), max_value=255):
    '''Convert a grayscale mask into a mask of integer class labels.

    :param mask_raw: array of grayscale intensities in [0, max_value] (e.g. a PNG
        mask read with imread).
    :param thresh: thresholds expressed as fractions of max_value. The first entry
        (lower bound of class 0) is not used; every pixel >= thresh[k] * max_value
        receives label k. Later thresholds overwrite earlier ones, so with increasing
        thresholds the highest threshold reached wins.
    :param max_value: intensity corresponding to a fraction of 1.0 (255 for 8-bit masks).
    :return: array with the same shape and dtype as mask_raw holding the labels.
    '''
    mask = np.zeros_like(mask_raw)
    for label, fraction in enumerate(thresh[1:], 1):
        mask[mask_raw >= fraction * max_value] = label
    return mask

In [None]:
def binarize(mask_raw):
    '''One-hot encode a 2-D mask over the distinct values it actually contains.

    :param mask_raw: 2-D array of arbitrary values.
    :return: float array of shape (H, W, number of distinct values); channel k is 1.0
        where mask_raw equals the k-th smallest distinct value, 0.0 elsewhere.
    '''
    distinct_values = np.unique(mask_raw)
    height, width = mask_raw.shape[0], mask_raw.shape[1]
    one_hot = np.zeros((height, width, len(distinct_values)))
    for channel, value in enumerate(distinct_values):
        one_hot[..., channel] = (mask_raw == value)
    return one_hot

In [None]:
def volumize(mask_labellized, n_class):
    '''One-hot encode a 2-D label mask into n_class channels.

    :param mask_labellized: 2-D array with each class being indicated as its corresponding
    number. ex : [[0,0,1],[2,2,0],[0,1,2]].
    :param n_class: number of output channels; channel i marks the pixels labelled i.
    :return: float64 array of shape (H, W, n_class) containing 1.0 / 0.0.
    '''
    class_ids = np.arange(n_class)
    # Broadcasting (H, W, 1) against (n_class,) compares every pixel with every class id.
    one_hot = mask_labellized[:, :, np.newaxis] == class_ids
    return one_hot.astype(np.float64)

# Applying the segmentations

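Here we don't write any image to disk (`write_mode=False`): `axon_segmentation` only returns the predicted classes, together with the prediction probabilities (`pred_proba=True`) needed to compute the log loss.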

In [None]:
# The testing folders (one image each, name prefixed by SEM or TEM) do not depend
# on the model, so they are listed once instead of once per model.
testing_folders = [d for d in os.listdir(path_testing) if os.path.isdir(os.path.join(path_testing, d))]
path_testing_folders = [os.path.join(path_testing, d) for d in testing_folders]


def _load_training_stats(path_model, name_checkpoint):
    '''Return (epoch, accuracy, loss) training statistics for one checkpoint.

    Uses the checkpoint-specific pickle named in `stats` (e.g. best_acc_stats.pkl for
    the best_acc_model checkpoint) when it exists. Otherwise falls back to evolution.pkl:
    the max step and the mean accuracy/loss over its last 10 entries.
    '''
    name_stats = stats.get(name_checkpoint)
    if name_stats is not None:
        path_stats = os.path.join(path_model, name_stats + '.pkl')
        if os.path.exists(path_stats):
            # NOTE(review): pickle.load can execute arbitrary code; only use trusted model folders.
            with open(path_stats, 'rb') as f:
                res = pickle.load(f)
            return res['steps'], res['accuracy'], res['loss']

    print('No stats file found...')
    with open(os.path.join(path_model, 'evolution.pkl'), 'rb') as f:
        res = pickle.load(f)
    return max(res['steps']), np.mean(res['accuracy'][-10:]), np.mean(res['loss'][-10:])


def _softmax(logits):
    '''Numerically stable softmax over the last axis.

    Subtracting the per-pixel max does not change the result but avoids overflow in np.exp.
    NOTE(review): assumes the last axis has exactly n_classes channels — confirm.
    '''
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def _image_stats(testing_folder, path_testing_folder, prediction, pred_proba, n_classes):
    '''Compute accuracy, log loss and pixel-wise dice scores for one testing image.'''
    mask_raw = imread(os.path.join(path_testing_folder, 'mask.png'), flatten=True, mode='L')
    mask = labellize(mask_raw)

    # One-hot volumes are computed once and reused by every metric below.
    mask_volume = volumize(mask, n_classes)
    prediction_volume = volumize(prediction, n_classes)
    proba = _softmax(pred_proba)

    image_stats = {
        'type_image': testing_folder.split('_')[0],             # SEM or TEM
        'name_image': '_'.join(testing_folder.split('_')[1:]),  # Rest of the name of the image
        'accuracy': accuracy_score(mask.ravel(), prediction.ravel()),
        'log_loss': log_loss(np.reshape(mask_volume, (-1, n_classes)),
                             np.reshape(proba, (-1, n_classes))),
        # The axon class is the last channel.
        'pw_dice_axon': pw_dice(prediction_volume[:, :, -1], mask_volume[:, :, -1]),
    }
    if n_classes == 3:
        # With 3 classes, channel 1 is the myelin.
        image_stats['pw_dice_myelin'] = pw_dice(prediction_volume[:, :, 1], mask_volume[:, :, 1])
    # TODO: element-wise metrics (score_analysis / dice from segmentation_scoring) are
    # disabled because they need the raw image, which is not loaded here.
    return image_stats


for folder in tqdm(os.listdir(path_models), desc='models'):  # Loop over all models
    path_model = os.path.join(path_models, folder)
    path_file = os.path.join(path_model, 'model_statistics_validation.json')

    # Skip plain files, checkpoint folders, and models that were already evaluated.
    if not os.path.isdir(path_model) or folder.endswith('checkpoints') or os.path.exists(path_file):
        continue

    # First we retrieve some essential data from the config file
    with open(os.path.join(path_model, 'config_network.json'), 'r') as fd:
        config_network = json.load(fd)
    trainingset_name = config_network['network_trainingset']
    type_trainingset = trainingset_name.split('_')[0]
    n_classes = config_network['network_n_classes']

    # The gps of each image depends on its type (SEM/TEM) and on the training-set suffix.
    resolution_suffix = str(trainingset_name.split('_')[-1])
    L_gps = [gps_dict[d.split('_')[0] + '_3c_' + resolution_suffix] for d in testing_folders]

    print('Beginning segmentation of test images with ' + type_trainingset + ' model, id ' + str(folder) + '...')

    model_comparison_list = []
    for checkpoint in os.listdir(path_model):
        if not checkpoint.endswith('.ckpt.meta'):
            continue
        name_checkpoint = checkpoint[:-len('.ckpt.meta')]

        # Training statistics do not depend on the testing images.
        epoch_stats, acc_stats, loss_stats = _load_training_stats(path_model, name_checkpoint)
        result_model = {
            'id_model': folder,
            'ckpt': name_checkpoint,
            'config': config_network,
            'training_stats': {
                'training_epoch': epoch_stats,
                'training_mvg_avg10_acc': acc_stats,
                'training_mvg_avg10_loss': loss_stats,
            },
            'testing_stats': [],
        }

        # Inference on every testing image at once; nothing is written to disk.
        predictions, pred_probas = axon_segmentation(path_testing_folders,
                                                     path_model,
                                                     config_network,
                                                     ckpt_name=name_checkpoint,
                                                     crop_value=crop_value,
                                                     general_pixel_sizes=L_gps,
                                                     pred_proba=True,
                                                     write_mode=False,
                                                     gpu_per=0.3
                                                     )
        # zip below would silently truncate on a length mismatch.
        assert len(predictions) == len(testing_folders) == len(pred_probas)

        print('Statistics extraction...')
        for testing_folder, path_testing_folder, prediction, pred_proba in tqdm(
                zip(testing_folders, path_testing_folders, predictions, pred_probas),
                total=len(testing_folders)):
            result_model['testing_stats'].append(
                _image_stats(testing_folder, path_testing_folder, prediction, pred_proba, n_classes))

        model_comparison_list.append(result_model)

    # Finally we save the statistics of every checkpoint in a new json file
    # (it cannot already exist: such models were skipped above).
    with open(path_file, 'w') as f:
        json.dump({'data': model_comparison_list,
                   'date': time.strftime("%Y-%m-%d")}, f, indent=2)

# Converting the json files into CSV files (for Excel / Google Sheets)

Upload the resulting CSV files to Google Drive to view and share them.

In [None]:
import pandas as pd, os, json
import matplotlib.pyplot as plt
from functools import partial

## Displaying some metrics

In [None]:
class metrics():
    '''Collect, filter and aggregate the validation statistics of several models.

    Each registered model folder is expected to contain a
    `model_statistics_validation.json` file whose 'data' entry is a list of
    per-checkpoint records. Typical use: add_models -> load_models -> filter_ -> aggregate.
    '''

    def __init__(self):
        self.path_models = set()               # model folders registered with add_models
        self.stats = pd.DataFrame()            # one row per (checkpoint, testing image)
        self.filtered_stats = pd.DataFrame()   # last result of filter_ (all stats after load_models)
        self.aggregated_stats = pd.DataFrame() # last result of aggregate
        self.columns = ['id_model', 'ckpt', 'type_model','training_acc',
                               'training_loss', 'training_epoch', 'pw_dice_myelin',
                               'pw_dice_axon', 'testing_log_loss', 'testing_accuracy', 'testing_name_image',
                               'testing_type_image']

    def add_models(self,path_models):
        '''Register one model folder (str) or a list/tuple of model folders.'''
        if not isinstance(path_models, (list, tuple)):
            path_models = [path_models]
        self.path_models.update(path_models)

    def load_models(self):
        '''Append one row per (checkpoint, testing image) of every registered model to self.stats.

        Models whose statistics file is missing or unreadable are skipped with a message.
        Missing 'pw_dice_myelin' values (2-class models) are stored as NaN.
        self.filtered_stats is reset to a copy of all loaded stats.
        '''
        rows = []
        # Sorted so that the row order is reproducible (sets have no stable order).
        for path in sorted(self.path_models):
            path_file = os.path.join(path, 'model_statistics_validation.json')
            try:
                with open(path_file) as f:
                    stats_dict = json.load(f)['data']
            except (IOError, OSError, ValueError, KeyError):
                # Skip this model rather than reusing the previous model's data.
                print('Could not read statistics file ' + path_file)
                continue

            for ckpt in stats_dict:
                training_stats = ckpt['training_stats']
                type_model = ckpt['config']['network_trainingset'].split('_')[0]
                for testing_stats in ckpt['testing_stats']:
                    rows.append([ckpt['id_model'], ckpt['ckpt'], type_model,
                                 training_stats['training_mvg_avg10_acc'],
                                 training_stats['training_mvg_avg10_loss'],
                                 training_stats['training_epoch'],
                                 testing_stats.get('pw_dice_myelin', float('nan')),
                                 testing_stats['pw_dice_axon'],
                                 testing_stats['log_loss'],
                                 testing_stats['accuracy'],
                                 testing_stats['name_image'],
                                 testing_stats['type_image']])

        # Build the new rows in one go (growing a DataFrame row by row is quadratic).
        if rows:
            new_stats = pd.DataFrame(rows, columns=self.columns)
            if self.stats.empty:
                self.stats = new_stats
            else:
                self.stats = pd.concat([self.stats, new_stats], ignore_index=True)
        self.filtered_stats = self.stats.copy()

    def filter_(self, list_acquisitions = None, list_ckpt = None, write_mode=False, name_file=None):
        '''Keep the rows matching ALL given criteria; a criterion set to None is ignored.

        :param list_acquisitions: image type(s) to keep (e.g. 'SEM' or ['SEM', 'TEM']).
        :param list_ckpt: checkpoint name(s) to keep.
        :param write_mode: if True, also write the transposed result to a CSV file.
        :param name_file: CSV file name; a dated default name is built when None.
        :return: the filtered DataFrame (also stored in self.filtered_stats).
        '''
        filtered_stats = self.stats
        if list_acquisitions is not None:
            if not isinstance(list_acquisitions, (list, tuple)):
                list_acquisitions = [list_acquisitions]
            list_acquisitions = list(list_acquisitions)
            filtered_stats = filtered_stats[filtered_stats['testing_type_image'].isin(list_acquisitions)]
        if list_ckpt is not None:
            if not isinstance(list_ckpt, (list, tuple)):
                list_ckpt = [list_ckpt]
            filtered_stats = filtered_stats[filtered_stats['ckpt'].isin(list(list_ckpt))]
        # Copy so later modifications never act on a view of self.stats.
        filtered_stats = filtered_stats.copy()
        self.filtered_stats = filtered_stats

        if write_mode:
            if name_file is None:
                acquisitions = '_'.join(list_acquisitions) if list_acquisitions else 'all'
                name_file = 'filtered_' + acquisitions + '_' + time.strftime("%Y-%m-%d") + '.csv'
            filtered_stats.T.to_csv(name_file)

        # Outputting the filtered pandas dataframe.
        return filtered_stats

    def aggregate(self, list_metrics,write_mode=False, name_file = None):
        '''Apply each metric column-wise to the numeric columns of every (id_model, ckpt) group.

        :param list_metrics: a function or list of functions reducing a column to a
            scalar (e.g. np.mean, np.median). Output columns are named
            '<column>_<metric.__name__>'.
        :param write_mode: if True, also write the transposed result to a CSV file.
        :param name_file: CSV file name; a dated default name is built when None.
        :return: DataFrame indexed by (id_model, ckpt) (also stored in self.aggregated_stats).
        '''
        if not isinstance(list_metrics, list):
            list_metrics = [list_metrics]

        # Only numeric columns can be reduced; the grouping keys are strings.
        numeric_columns = list(self.filtered_stats.select_dtypes(include='number').columns)
        grouped = self.filtered_stats.groupby(['id_model', 'ckpt'])[numeric_columns]

        per_metric = []
        for metric in list_metrics:
            # group.apply(metric) reduces each column separately, for any pandas version.
            tmp = grouped.apply(lambda group, func=metric: group.apply(func))
            tmp.columns = [column + '_' + metric.__name__ for column in tmp.columns]
            per_metric.append(tmp)
        aggregated_stats = pd.concat(per_metric, axis=1) if per_metric else pd.DataFrame()
        self.aggregated_stats = aggregated_stats

        if write_mode:
            if name_file is None:
                name_file = ('agg_' + '_'.join(m.__name__ for m in list_metrics)
                             + '_' + time.strftime("%Y-%m-%d") + '.csv')
            aggregated_stats.T.to_csv(name_file)

        return aggregated_stats

In [None]:
# Every sub-folder of the test models directory is a model to load.
test_models_dir = '../test_models'
model_list = [os.path.join(test_models_dir, name)
              for name in os.listdir(test_models_dir)
              if os.path.isdir(os.path.join(test_models_dir, name))]

In [None]:
# Show the model folders that will be loaded.
list(model_list)

In [None]:
met = metrics()
met.add_models(model_list)
# Alternatively, register a single model folder, e.g.:
# met.add_models([os.path.join('../models','baseline-SEM256new-9438/')])

# Statistics must be loaded before filtering/aggregating; otherwise the following
# cells operate on an empty DataFrame and fail on a fresh kernel.
met.load_models()

In [None]:
# Preview only: the full table has one row per (checkpoint, testing image).
met.filtered_stats.head()

In [None]:
# Keep only the SEM testing images (nothing is written to disk here).
met.filter_(list_acquisitions=['SEM'])

In [None]:
# Mean of every numeric statistic per (model, checkpoint), also saved as CSV.
met.aggregate(np.mean, write_mode=True, name_file='onSEM-5185.csv')