# IMPORTS

In [None]:
%matplotlib inline
#import mpld3
#mpld3.enable_notebook()

In [None]:
import json
import os
import shutil
import matplotlib.pyplot as plt
# numpy is used throughout; import it explicitly rather than relying on the star import below.
import numpy as np
from AxonDeepSeg.testing.segmentation_scoring import *
# Keep these after the star import so the explicit bindings win.
import pickle

from time import time
#from ipywidgets.widgets import *

In [None]:
from scipy.misc import imread, imsave

# 1/ Define the parameters to use for the segmentation

In [None]:
# input data to build the training set
# Dataset used to build the training set, and the resolution later passed to
# axon_segmentation as `resampled_resolutions`.
# NOTE(review): units of gps are not shown here (presumably um/pixel) - confirm.
dataset_name, gps = 'SEM_3c_512', 0.1

In [None]:
path_data = '../data/' + dataset_name + '/raw/'
path_testing = '../data/baseline_testing/SEM_data15/'
#path_testing = '../data/images_nyu_tem/'

# output path of the training data
path_training = '../data/' + dataset_name + '/training/'

# Invert the grey levels of the test image before segmentation.
# The untouched input is backed up once, and every run inverts that backup.
# Re-running this cell therefore always produces the same image.png, instead of
# flipping the contrast back and forth each time.
path_image = os.path.join(path_testing, 'image.png')
path_image_backup = os.path.join(path_testing, 'image_original.png')
if not os.path.exists(path_image_backup):
    import shutil  # local import keeps this cell self-contained
    shutil.copyfile(path_image, path_image_backup)

im = imread(path_image_backup, mode='L', flatten=True)
# NOTE(review): im is read with flatten=True; confirm imsave does not rescale
# the intensity range of the inverted array.
imsave(path_image, 255-im)

# 2/ Load the config file

Choose the name of the model to load.

In [None]:
# Name of the model directory to load (under ../models/).
# Previously used models, kept for reference:
#   cv_3c_d4_c2_k3__0-1062          TEM 3c 256
#   cv_3c_d3_c2_k3__13-5185/
#   baseline_3c_balanced_nobn-1957  SEM 3c 256 old
#   cv_3c_d4_c2_k3__0-9097          SEM 512
#   cv_3c_d4_c2_k3__0-7678          TEM 3c 512
#   cv_3c_d4_c2_k3__0-8580          mixed 3c 512
#   cv_3c_d4_c2_k3__0-4698          good working model
model_name = 'baseline_sem512' + '-3734'

In [None]:
# optional input path of a model to initialize the training
# NOTE(review): not referenced by any of the cells shown here.
#path_model_init = 'network_testing/test_2905'
path_model_init = None

# directory of the trained U-Net to load
path_model = '../models/' + model_name

path_configfile = path_model + '/config_network.json'

# The model directory is an input here and must already contain
# config_network.json. Do not create it when missing: an empty folder holds no
# config, so open() below would fail regardless. A missing model surfaces
# directly as the error raised by open().
with open(path_configfile, 'r') as fd:
    config_network = json.load(fd)

# OPTIONAL : specify the gpu one wants to use.
# NOTE(review): gpu_device is not passed to axon_segmentation in the cells shown here.
gpu_device = 'gpu:0' # or gpu_device = 'gpu:1' these are the only two possible targets for now.

# 3/ Apply the model to segment one image

#### Segmentation

In [None]:
from AxonDeepSeg.apply_model import axon_segmentation

In [None]:
# Segment the test image with the loaded model.
# NOTE(review): with write_mode=True this presumably writes AxonDeepSeg.png into
# path_testing, which later cells read back; pred_proba is the probability-like
# output requested by prediction_proba_activate=True.
pred, pred_proba = axon_segmentation(
    [path_testing],
    ["image.png"],
    path_model,
    config_network,
    overlap_value=25,
    resampled_resolutions=gps,
    prediction_proba_activate=True,
    write_mode=True,
    inference_batch_size=4,
    gpu_per=0.3,
    verbosity_level=0,
)

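**Transforming `pred_proba` into real probabilities:** apply a softmax over the last (class) axis so that, for every pixel, the class probabilities sum to 1.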

In [None]:
# Keep the result for the first (and only) input image.
# NOTE(review): assumes the segmentation returns one (rows, cols, classes)
# array per input image. The guard makes the cell idempotent: once pred_proba
# is a single 3-D image, re-running this cell leaves it unchanged instead of
# slicing off its first row.
if np.ndim(pred_proba) == 4:
    pred_proba = pred_proba[0]

In [None]:
# Softmax over the last (class) axis.
# Subtracting the per-pixel maximum before exponentiating does not change the
# result mathematically. It stops exp() from overflowing to inf for large
# scores, which would otherwise give inf/inf = nan probabilities.
a = np.exp(pred_proba - np.max(pred_proba, axis=-1, keepdims=True))
b = np.sum(a, axis=-1)
# Normalise every class channel by the per-pixel sum (broadcast over classes).
c = a / b[..., np.newaxis]

We then save the probabilities to a pickle file.

In [None]:
# Output folder for the saved probabilities (kept with a trailing separator).
# NOTE(review): this folder is unrelated to path_testing - confirm it is the intended destination.
path_saving = '/'.join(['..', 'data', 'SEM_3classes_reduced', 'testing', 'pilou', ''])

In [None]:
# Save the per-pixel class probabilities.
# Pickle protocols >= 1 are binary, so the file must be opened in 'wb' mode;
# text mode 'w' raises TypeError on Python 3.
if not os.path.isdir(path_saving):
    os.makedirs(path_saving)
with open(os.path.join(path_saving, 'pred_proba.pkl'), 'wb') as handle:
    pickle.dump(c, handle, protocol=pickle.HIGHEST_PROTOCOL)

# 4/ Visualisation of the segmentation

Load the images

In [None]:
# Load the test image (image.png, which section 1 overwrote with its grey-level
# inverse) and the segmentation AxonDeepSeg.png from the same folder.
# os.path.join avoids the doubled separator produced by path_testing + '/...'.
imorg = imread(os.path.join(path_testing, 'image.png'), flatten=True)
imads = imread(os.path.join(path_testing, 'AxonDeepSeg.png'), flatten=True)

Display the input image (note: this is the grey-level-inverted image.png written in section 1)

In [None]:
# Show the input image in grey levels.
fig, ax = plt.subplots(figsize=(13, 10))
ax.set_title('Original image')
ax.imshow(imorg, cmap='gray')
plt.show();

Display the segmentation and compare it with the original image

In [None]:
# Show the segmentation read back from AxonDeepSeg.png (default colormap).
fig, ax = plt.subplots(figsize=(13, 10))
ax.set_title('Segmented image')
ax.imshow(imads)
plt.show();

In [None]:
# Overlay the segmentation on the input image, each at 50% opacity.
# plt.show() matches the other plotting cells and keeps the AxesImage repr out
# of the cell output.
fig, ax = plt.subplots(figsize=(13, 10))
ax.set_title('Superposed images')
ax.imshow(imorg, cmap='gray', alpha=0.5)
ax.imshow(imads, cmap='viridis', alpha=0.5)
plt.show();

In [None]:
# Dimensions of the loaded input image.
tuple(imorg.shape)

# 5/ Metrics analysis

We now analyze the Dice score of each class. First, we load the ground-truth mask and the prediction, then compute a binary mask for each class.

In [None]:
# Ground-truth label image and the predicted label image from the same folder.
# NOTE: this rebinds `pred`, replacing the value returned by axon_segmentation.
mask = imread(os.path.join(path_testing, 'mask.png'), flatten=True)
pred = imread(os.path.join(path_testing, 'AxonDeepSeg.png'), flatten=True)

Create the per-class binary masks needed to compute the Dice score

In [None]:
# Grey-level thresholds that split a label image into per-class boolean masks.
# NOTE(review): presumably axons are the brightest labels, myelin mid-grey and
# background dark - confirm against the mask / AxonDeepSeg.png export format.
AXON_MIN_LEVEL = 200   # value > AXON_MIN_LEVEL            -> axon
MYELIN_MIN_LEVEL = 50  # MYELIN_MIN_LEVEL <= value <= AXON_MIN_LEVEL -> myelin


def split_axon_myelin(label_img, axon_min=AXON_MIN_LEVEL, myelin_min=MYELIN_MIN_LEVEL):
    """Return (axon_mask, myelin_mask) boolean arrays for a grey-level label image.

    axon_mask is ``label_img > axon_min``; myelin_mask is
    ``myelin_min <= label_img <= axon_min``. The two masks never overlap.
    """
    axon = label_img > axon_min
    myelin = np.logical_and(label_img >= myelin_min, label_img <= axon_min)
    return axon, myelin


# Apply the same thresholds to the ground truth and the prediction.
gt_axon, gt_myelin = split_axon_myelin(mask)
pred_axon, pred_myelin = split_axon_myelin(pred)

Display the prediction error for each class (prediction minus ground truth)

In [None]:
# Per-class error maps computed as pred - gt on 0/1 masks:
#   +1 = predicted but absent from the ground truth (false positive)
#   -1 = present in the ground truth but missed (false negative)
#    0 = agreement
# Both panels use the same fixed [-1, 1] diverging colour scale and each has a
# colorbar, so a colour means the same thing in both plots.
fig, (ax_axon, ax_myelin) = plt.subplots(2, 1, figsize=(13, 10))
for ax_err, pred_mask, gt_mask, label in (
        (ax_axon, pred_axon, gt_axon, 'Axon'),
        (ax_myelin, pred_myelin, gt_myelin, 'Myelin')):
    err_img = ax_err.imshow(pred_mask.astype(int) - gt_mask.astype(int),
                            cmap='bwr', vmin=-1, vmax=1)
    ax_err.set_title(label + ' prediction - ground truth')
    fig.colorbar(err_img, ax=ax_err)
plt.show();

In [None]:
# Signed myelin bias: (predicted myelin pixels - ground-truth myelin pixels) / image pixels.
# Positive means myelin is over-segmented overall. False positives and false
# negatives cancel, so this is not an error rate.
myelin_diff = pred_myelin.astype(int) - gt_myelin.astype(int)
n_pixels = imorg.shape[0] * imorg.shape[1]
float(myelin_diff.sum()) / n_pixels

### Computing the Dice score for axon and myelin

In [None]:
# Dice score of each class, prediction vs ground truth
# (pw_dice comes from AxonDeepSeg.testing.segmentation_scoring; presumably pixel-wise).
dice_axon = pw_dice(pred_axon, gt_axon)
dice_myelin = pw_dice(pred_myelin, gt_myelin)

print('Dice for Axon : ' + str(dice_axon))
print('Dice for myelin : ' + str(dice_myelin))

Compute sensitivity, precision and diffusion for the axon class

In [None]:
# Axon-class scores (the section heading names sensitivity, precision and diffusion).
# NOTE(review): confirm the returned values and their order in
# AxonDeepSeg.testing.segmentation_scoring.score_analysis.
score_analysis(imorg, gt_axon, pred_axon)

Compute the element-wise Dice for the axon class

In [None]:
# Element-wise Dice for the axon class; the result is indexed by 'dice' below.
# NOTE(review): min_area presumably drops objects smaller than this many pixels - confirm.
axon_min_area = 4
data_axon_dice = dice(imorg, gt_axon, pred_axon, min_area=axon_min_area)

In [None]:
# Summarise the element-wise axon Dice distribution by a few quantiles.
dice_quantile_levels = [0.1, 0.5, 0.9, 0.95]
axon_dice_scores = data_axon_dice['dice']
axon_dice_scores.quantile(dice_quantile_levels).values