# Evaluation of UNET performance upon common distortions

The aim of this notebook is to provide some information about how a UNET initially trained to segment cells is perturbed by some common distortions applied to the input images. The tested distortions are an added 2D Gaussian, Gaussian noise, and a rescaling of the input images. To evaluate the performance of the UNET, several plots are generated: the accuracy, the Jaccard index and the number of detected cells as a function of the degree of degradation, which depends on the parameters of the distortions.

### $\bullet$ Importing libraries and utility functions

In [1]:
# Import python libraries.
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
%load_ext autoreload
%autoreload 2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow

# Import distortion filters and utility functions.
from distortions import add_gaussian, zoom_image, zoom_image_to_meet_shape, add_gaussian_noise
from utils import *
from image_processing_methods import DogImageProcessingMethod, DenoiserImageProcessingMethod
from plots import show_image_mask, show_image_pred, plot_results

In [2]:
# Set seeds of random number generators to guarantee reproducibility.
# NumPy and TensorFlow are seeded separately (each call below targets one library).
# NOTE(review): whether this is sufficient for bit-exact predictions depends on the
# TensorFlow ops/hardware in use - confirm if strict reproducibility matters.
np.random.seed(3)
tensorflow.random.set_seed(4)

### $\bullet$ Setting size of the input image and different paths

In [3]:
# Shape of the inputs of the deep network. It is forwarded to the dataset loader and
# used as the target shape when patches are rescaled further down in the notebook.
images_shape = (256, 256, 1)

# Variables defining the path to the dataset (paths are relative to the notebook's
# working directory): input images and their corresponding segmentation masks.
test_input_path  = '../Dataset/test/input/'
test_output_path = '../Dataset/test/output/'

# Variable defining where models are stored. Each directory is passed as-is to
# tensorflow.keras.models.load_model when the trained models are retrieved.
unet_original_save_path       = "../Models/Unet Original/"
unet_data_augmented_save_path = "../Models/Unet Data Augmented/"
unet_data_distorted_save_path = "../Models/Unet Data Distorted/"

### $\bullet$ Function to display the input images and predictions upon distortion, as well as a summary of values that shows the quality of the predicted images compared to the ground truth

In [4]:
def evaluation_summary(result, parameter_name, parameter_val, image, mask, distorted_image, models):
    """
    Print the metrics gathered for one distortion setting and display example images.

    A header line with the tested parameter and its value is printed first. Then, for every
    model in result, its accuracy, Jaccard score, precision, recall and the two cell counts
    are printed. Finally the example image/mask pair and the predictions of the models on
    the distorted image are shown through show_image_mask and show_image_pred.

    Args:
        result::[dict]
            Dictionary with the name of the model as key and as value another dictionary holding
            the metrics computed for the predictions of that model (keys "accuracy", "jaccard",
            "precision", "recall", "number_cells_predictions" and "number_cells_masks").
        parameter_name::[string]
            Name of the parameter that was varied when result was collected. Only printed
            (upper-cased) next to its value for reference.
        parameter_val::[float]
            Value of the parameter that was varied when result was collected. Only printed
            in the summary for reference.
        image::[np.array]
            Numpy array of shape (n_lines, n_columns, n_channels) with one image taken from the
            images used to collect result, displayed to the user as an example.
        mask::[np.array]
            Numpy array of same shape as image with the corresponding true segmentation mask,
            displayed as an example.
        distorted_image::[np.array]
            Numpy array of same shape as image with the corresponding distorted image that was
            used for prediction, displayed as an example.
        models::[dict]
            Dictionary with the name of the model as key and as value an object implementing a
            predict method as required by utils.get_binary_predictions.

    """
    # (printed label, key in the per-model metrics dictionary), in display order.
    report_rows = (
        ("Accuracy", "accuracy"),
        ("Jaccard score", "jaccard"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("Number of cells in predictions", "number_cells_predictions"),
        ("Number of cells in masks", "number_cells_masks"),
    )

    print("{:<50}: {}".format(parameter_name.upper(), parameter_val))
    for model_name, metrics in result.items():
        print("{}".format(model_name.upper()))
        for label, metric_key in report_rows:
            print("     {:<45}: {}".format(label, metrics[metric_key]))

    # Visual check: reference image with its mask, then each model's prediction.
    show_image_mask(image, mask)
    show_image_pred(distorted_image, models)
    
    
def evaluate_models(images, masks, models, number_cells_masks=None):
    """
    Compute the evaluation metrics of every model in models on the given images.

    For each model the binary segmentation masks are predicted, then compared with the true
    masks to obtain the pixel accuracy, the Jaccard score, the precision and recall, and the
    number of cells found in the predictions.

    Args:
        images::[np.array]
            Numpy array of shape (n_images, n_lines, n_columns, n_channels) containing the images
            to segment.
        masks::[np.array]
            Numpy array of same shape as images containing the true segmentation masks
            corresponding to images.
        models::[dict]
            Dictionary with the name of the model as key and as value an object implementing a
            predict method as required by utils.get_binary_predictions.
        number_cells_masks::[int]
            The total number of cells contained in all the masks. If None, it is computed from
            masks the first time it is needed and reused for the remaining models.

    Returns:
        results_each_model::[dict]
            Dictionary with the name of the model as key and as value another dictionary with
            keys "accuracy", "jaccard", "precision", "recall", "number_cells_predictions" and
            "number_cells_masks". Empty if models is empty.

    """
    results_each_model = {}
    for model_name, model in models.items():
        predicted = get_binary_predictions(images, model)

        # Accuracy is the fraction of entries on which prediction and mask agree.
        metrics = {"accuracy": np.mean(predicted == masks)}
        metrics["jaccard"] = compute_jaccard_score(predicted, masks)
        metrics["precision"], metrics["recall"] = compute_precision_recall(predicted, masks)
        metrics["number_cells_predictions"] = get_number_cells(predicted)

        # The ground-truth count does not depend on the model: count at most once.
        if number_cells_masks is None:
            number_cells_masks = get_number_cells(masks)
        metrics["number_cells_masks"] = number_cells_masks

        results_each_model[model_name] = metrics

    return results_each_model


def apply_distortion_to_all(function, images, params_for_images=None):
    """
    Function applying the distortion given by the function parameter to every image in images,
    using the keyword arguments passed in params_for_images.

    Args:
        function::[function]
            Function taking an image as first positional argument and accepting the kwargs in
            params_for_images; it is called once per image in images.
        images::[np.array]
            Numpy array of shape (n_images, n_lines, n_columns, n_channels) containing the images
            to apply the distortion onto.
        params_for_images::[dict]
            Dictionary containing the kwargs needed to call the function parameter. If None
            (the default), function is called with the image only.

    Returns:
        distorted_images::[np.array]
            Numpy array stacking the results of function in the same order as images. It has the
            same shape as images whenever function preserves the shape of each image.

    """
    # A None sentinel replaces the former mutable `{}` default, so no dictionary object is
    # shared between calls.
    kwargs = {} if params_for_images is None else params_for_images

    return np.array([function(image, **kwargs) for image in images])

### $\bullet$ Retrieve trained models

In [5]:
# Load the three trained UNET variants. The keys are the model names that are later
# printed in the evaluation summaries and handed to plot_results.
models = {
    "unet original": tensorflow.keras.models.load_model(unet_original_save_path),
    "unet data augmented": tensorflow.keras.models.load_model(unet_data_augmented_save_path),
    "unet data distorted": tensorflow.keras.models.load_model(unet_data_distorted_save_path),
}

### $\bullet$ Load images from the dataset

In [6]:
# Load the test images together with their true masks; images_shape is forwarded to the loader.
test_images, test_masks = get_dataset_from_folders(test_input_path, test_output_path, images_shape)

# Count the cells of the true masks once, so that the analyses working on the unmodified
# masks can pass this value to evaluate_models instead of recounting at every step.
number_cells_masks = get_number_cells(test_masks)

print(f'Test set contains {len(test_images)} images of shape {test_images[0].shape}.')
print(f"{number_cells_masks} cells were counted in total over all masks.")

Test set contains 510 images of shape (256, 256, 1).
16183 cells were counted in total over all masks.


### $\bullet$ Analysis of the perturbations caused by an added gaussian on the UNET performance

In [7]:
parameter_name = "Amplitude of Gaussian (Before Normalization)"

# Evaluate the UNETs side by side with a DogImageProcessingMethod entry on the same inputs.
# `models` itself is left untouched: a new dictionary is built from its entries.
models_added_gaussian = {**models, "image processing method": DogImageProcessingMethod(ksize_low=20)}

# 41 evenly spaced amplitudes from 0 to 4000 inclusive, i.e. increments of 100.
amplitudes = np.linspace(0, 4000, 41)
results = {}

for amplitude in amplitudes:
    added_gaussian_test_images = apply_distortion_to_all(
        add_gaussian, test_images, {"amplitude": amplitude}
    )

    # Only the inputs are distorted, so the masks and their precomputed cell count are reused.
    results[amplitude] = result = evaluate_models(
        added_gaussian_test_images, test_masks, models_added_gaussian, number_cells_masks
    )

    # Show the first test image as the running example for this amplitude.
    evaluation_summary(
        result,
        parameter_name,
        amplitude,
        test_images[0],
        test_masks[0],
        added_gaussian_test_images[0],
        models_added_gaussian,
    )

plot_results(results, parameter_name, models_added_gaussian.keys())

KeyboardInterrupt: 

### $\bullet$ Analysis of the perturbations caused by a gaussian noise on the UNET performance

In [None]:
parameter_name = "Standard Deviation"

# Evaluate the UNETs side by side with a DenoiserImageProcessingMethod entry on the same
# noisy inputs. `models` itself is left untouched: a new dictionary is built from its entries.
models_noisy = {**models, "image processing method": DenoiserImageProcessingMethod(denoiser_strength=25)}

# The noise mean stays at 0 while sigma sweeps 41 evenly spaced values from 0 to 200
# inclusive, i.e. increments of 5.
mean = 0
sigmas = np.linspace(0, 200, 41)
results = {}

for sigma in sigmas:
    noisy_test_images = apply_distortion_to_all(
        add_gaussian_noise, test_images, {"mean": mean, "sigma": sigma}
    )

    # Only the inputs are distorted, so the masks and their precomputed cell count are reused.
    results[sigma] = result = evaluate_models(
        noisy_test_images, test_masks, models_noisy, number_cells_masks
    )

    # Show the first test image as the running example for this sigma.
    evaluation_summary(
        result,
        parameter_name,
        sigma,
        test_images[0],
        test_masks[0],
        noisy_test_images[0],
        models_noisy,
    )

plot_results(results, parameter_name, models_noisy.keys())

### $\bullet$ Analysis of the perturbations caused by a rescaling of the images on the UNET performance

In [None]:
parameter_name = "Zooming Factor from Patches"

# Zoom factors: a fine grid of 15 values from 0.5 up to (but excluding) 2 in steps of 0.1,
# followed by a coarser grid of 7 values from 2 to 5 inclusive in steps of 0.5.
zooms = np.concatenate((np.linspace(0.5, 2, 15, endpoint=False), np.linspace(2, 5, 7)))
results = {}

for zoom in zooms:
    # Images and masks get the same zoom factor so that they stay aligned.
    # NOTE(review): val_padding=0 is presumably the fill value used for the mask when the
    # zoomed content does not cover the whole frame - confirm in distortions.zoom_image.
    zoomed_test_images = apply_distortion_to_all(zoom_image, test_images, {"zoom_factor": zoom})
    zoomed_test_masks = apply_distortion_to_all(
        zoom_image, test_masks, {"zoom_factor": zoom, "val_padding": 0}
    )

    # No precomputed cell count is passed: the masks changed, so evaluate_models recounts.
    results[zoom] = result = evaluate_models(zoomed_test_images, zoomed_test_masks, models)

    # The zoomed image serves both as reference image and as prediction input in the summary.
    evaluation_summary(
        result,
        parameter_name,
        zoom,
        zoomed_test_images[0],
        zoomed_test_masks[0],
        zoomed_test_images[0],
        models,
    )

plot_results(results, parameter_name, models.keys())

### $\bullet$ Get full input images from dataset

In [None]:
# Load the test set again, this time without passing a target shape to the loader.
# NOTE(review): presumably this returns the full-size images rather than patches - confirm
# in utils.get_dataset_from_folders.
test_full_images, test_full_masks = get_dataset_from_folders(test_input_path, test_output_path)

# Shape of one loaded image with a trailing channel dimension of size 1 appended.
full_images_shape = tuple(test_full_images[0].shape) + (1,)

print(f'Test set contains {len(test_full_images)} images of shape {full_images_shape}.')

### $\bullet$ Analysis of the perturbations caused by having different sizes of cells on the UNET performance

In [None]:
parameter_name = "Zooming Factor from Original Images"

# Zoom factors: a fine grid of 15 values from 0.5 up to (but excluding) 2 in steps of 0.1,
# followed by a coarser grid of 7 values from 2 to 5 inclusive in steps of 0.5.
zooms = np.concatenate((np.linspace(0.5, 2, 15, endpoint=False), np.linspace(2, 5, 7)))
results = {}

for zoom in zooms:
    # A patch of side (network input side / zoom) reaches the network input size once rescaled,
    # so the content of the patch ends up magnified by `zoom`.
    patch_shape = tuple(int(side / zoom) for side in images_shape[:2]) + (1,)

    # Ratio between the full image size and the patch size, kept within [6, 80] and truncated
    # to an unsigned integer.
    # NOTE(review): presumably the number of patches extracted per image - confirm in
    # utils.split_images_and_masks_into_patches.
    n_patches = np.clip(
        np.prod(np.divide(full_images_shape, patch_shape)), 6, 80
    ).astype('uint')
    patch_images, patch_masks = split_images_and_masks_into_patches(
        test_full_images, test_full_masks, patch_shape, n_patches
    )

    # Bring patches and their masks back to the shape expected by the network.
    zoomed_test_images = apply_distortion_to_all(
        zoom_image_to_meet_shape, patch_images, {"shape": images_shape}
    )
    zoomed_test_masks = apply_distortion_to_all(
        zoom_image_to_meet_shape, patch_masks, {"shape": images_shape}
    )

    # No precomputed cell count is passed: the masks differ per zoom, so evaluate_models recounts.
    results[zoom] = result = evaluate_models(zoomed_test_images, zoomed_test_masks, models)

    # The zoomed image serves both as reference image and as prediction input in the summary.
    evaluation_summary(
        result,
        parameter_name,
        zoom,
        zoomed_test_images[0],
        zoomed_test_masks[0],
        zoomed_test_images[0],
        models,
    )

plot_results(results, parameter_name, models.keys())