# Single-shot autofocus using deep learning
This notebook provides all the code necessary for single-shot autofocus using deep learning, and shows how every figure is produced.

In [27]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

from afutil import get_patch_metadata, read_or_calc_focal_planes, compile_deterministic_data,\
    feature_vector_generator_fn, MagellanWithAnnotation
from defocusnetwork import DefocusNetwork


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


The DataWrapper class wraps all the essential functionality needed for training single-shot autofocus. By default, it reads Micro-Magellan datasets (with an added subclass to be able to store the results of intermediate computations in an HDF5 file), but this could be replaced with any data source so long as it provides the following methods and fields:

In [65]:
class DataWrapper:
    """
    Thin adapter between a dataset object and the autofocus training code.

    Every method delegates to the wrapped ``magellan`` object, so a different data source can be
    used as long as a wrapper exposing these same methods is supplied.
    """

    def __init__(self, magellan):
        """
        :param magellan: dataset object. This class calls its read_image, get_z_slices_at,
            get_num_xy_positions, write_annotation, read_annotation, store_array and read_array
            methods and reads its image_width, image_height and pixel_size_z_um attributes
        """
        self.magellan = magellan

    def _read_float_image(self, channel_name, position_index, z_index):
        """
        Read one image from the wrapped dataset and convert it to float64

        :param channel_name: name of the channel to read
        :param position_index: index of xy position
        :param z_index: index of z slice (starting at 0)
        :return: the image as a float64 array
        """
        #z_index is 0-based, so offset it by the smallest slice index the dataset reports here
        first_slice = min(self.magellan.get_z_slices_at(position_index))
        image = self.magellan.read_image(channel_name=channel_name, pos_index=position_index,
                                         z_index=z_index + first_slice)
        #np.float was only an alias of the builtin float (float64) and has been removed from NumPy,
        #so the dtype is spelled out explicitly
        return image.astype(np.float64)

    #TODO: remember to change to float
    def read_ground_truth_image(self, position_index, z_index):
        """
        Read image in which focus quality can be measured from quality of image
        (the 'DPC_Bottom' channel)
        :param position_index: index of xy position
        :param z_index: index of z slice (starting at 0)
        :return: full, uncropped image as a float64 array
        """
        return self._read_float_image('DPC_Bottom', position_index, z_index)

    def read_prediction_image(self, position_index, z_index, patch_index, split_k):
        """
        Read image used for single shot prediction (i.e. single LED image, the 'autofocus' channel)
        and crop out one square patch of it
        :param position_index: index of xy position
        :param z_index: index of z slice (starting at 0)
        :param patch_index: index of the crop, counted row by row (row = patch_index // split_k,
            column = patch_index % split_k)
        :param split_k: number of crops along each dimension
        :return: the selected crop as a float64 array
        """
        patch_size, _ = get_patch_metadata((self.get_image_width(), self.get_image_height()), split_k)
        y_tile_index, x_tile_index = divmod(patch_index, split_k)
        y_start = y_tile_index * patch_size
        x_start = x_tile_index * patch_size
        image = self._read_float_image('autofocus', position_index, z_index)
        #crop: first axis is indexed with the y tile, second axis with the x tile
        return image[y_start:y_start + patch_size, x_start:x_start + patch_size]

    def get_image_width(self):
        """
        :return: image width in pixels
        """
        return self.magellan.image_width

    def get_image_height(self):
        """
        :return: image height in pixels
        """
        return self.magellan.image_height

    def get_num_z_slices_at(self, position_index):
        """
        return number of z slices (i.e. focal planes) at the given XY position
        :param position_index: index of xy position
        :return: number of z slices
        """
        return len(self.magellan.get_z_slices_at(position_index))

    def get_pixel_size_z_um(self):
        """
        :return: distance in um between consecutive z slices
        """
        return self.magellan.pixel_size_z_um

    def get_num_xy_positions(self):
        """
        :return: total number of xy positions in data set
        """
        return self.magellan.get_num_xy_positions()

    def store_focal_plane(self, name, focal_position):
        """
        Store the computed focal plane as a string, float pair
        :param name: key under which the focal plane is stored
        :param focal_position: focal plane value to store
        """
        self.magellan.write_annotation(name, focal_position)

    def read_focal_plane(self, name):
        """
        read a previously computed focal plane
        :param name: key corresponding to an xy position for which focal plane has already been computed
        :return: the stored value
        """
        return self.magellan.read_annotation(name)

    def store_array(self, name, array):
        """
        Store a numpy array containing the design matrix for training the non-deterministic part of the network (i.e.
        after the Fourier transform) so that it can be retrained quickly without having to recompute
        :param name: key under which the array is stored
        :param array: (n examples) x (d feature length) numpy array
        """
        self.magellan.store_array(name, array)

    def read_array(self, name):
        """
        Read and return a previously computed array
        :param name: key the array was stored under
        :return: the stored array
        """
        return self.magellan.read_array(name)

### Load data and compute the ground truth focal planes as targets for training
The `show_output` flag creates a plot of the averaged high-frequency content of the log power spectrum. If new data is substituted in, the maximum of this plot should correspond to the correct focal plane.

In [66]:
#parameters for the deterministic part of the network
#TODO: better explain what these mean
deterministic_params = {'non_led_width': 0.1, 'led_width': 0.6, 'tile_split_k': 2}

#directory holding both datasets; edit this single line to point the notebook at another copy of the data
DATA_DIR = '/home/henry/data/2018-9-27 Cells and histology af data'

#load data
cell_data = DataWrapper(MagellanWithAnnotation(
    DATA_DIR + '/Neomounted cells 12x12 30um range 1um step_1'))
histology_data = DataWrapper(MagellanWithAnnotation(
    DATA_DIR + '/unstained path section 12x12 30um range 1um step_1'))

#load or compute target focal planes using 22 CPU cores to speed computation
#the result is a dict keyed by the DataWrapper objects themselves
#(the second dataset is histology_data; the previous name pathology_data was never defined)
focal_planes = {dataset: read_or_calc_focal_planes(dataset, split_k=deterministic_params['tile_split_k'],
                                    n_cores=22, show_output=True) for dataset in [cell_data, histology_data]}

Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading precomputed focal plane
Reading 

### Train network using cell data, compute validation on other cell data and histology section data

In [None]:
#split cell data into training and validation sets: first 90% of xy positions train, the rest validate
num_pos = cell_data.get_num_xy_positions()
num_train_positions = int(num_pos * 0.9)
train_positions = list(range(num_train_positions))
#start right after the last training position (also well defined when the training split is empty)
validation_positions = list(range(num_train_positions, num_pos))
#use subset of histology dataset to save compute time (same number of positions as the cell validation set)
validation_positions_histology = list(range(len(validation_positions)))

#Compute or load already computed design matrices
train_features, train_targets = compile_deterministic_data([cell_data], [train_positions], focal_planes,
                                                           deterministic_params=deterministic_params)
validation_features_cells, validation_targets_cells = compile_deterministic_data([cell_data], [validation_positions],
                                                    focal_planes, deterministic_params=deterministic_params)
validation_features_histology, validation_targets_histology = compile_deterministic_data([histology_data],
                        [validation_positions_histology], focal_planes, deterministic_params=deterministic_params)


#make generator function for providing training examples and separate validation generator for assessing its progress
#stop training once error on validation set stops decreasing
train_generator = feature_vector_generator_fn(train_features, train_targets, mode='all',
                                        split_k=deterministic_params['tile_split_k'])
val_generator = feature_vector_generator_fn(validation_features_cells, validation_targets_cells, mode='all',
                                        split_k=deterministic_params['tile_split_k'])
#this generator must be fed the histology design matrix; it previously reused the cell validation
#data, which made the "histology" results a copy of the cell validation results
val_generator_histology = feature_vector_generator_fn(validation_features_histology, validation_targets_histology,
                                        mode='all', split_k=deterministic_params['tile_split_k'])

#feed in the dimensions of the cropped input so the inference network knows what to expect
#although the inference network is not explicitly used in this notebook, it is created so that the model tensorflow
#creates could later be used on real data
patch_size, patches_per_image = get_patch_metadata((cell_data.get_image_width(),
                                        cell_data.get_image_height()), deterministic_params['tile_split_k'])

#Create network and train it
#input_shape is the second dimension of the training design matrix
#(presumably the feature-vector length -- confirm against compile_deterministic_data)
defocus_prediction_network = DefocusNetwork(input_shape=train_features.shape[1], train_generator=train_generator,
                             val_generator=val_generator, predict_input_shape=[patch_size, patch_size],
                             deterministic_params=deterministic_params, regressor_only=True, train_mode='train')

7892 sliceposition tuples
Building deterministic graph...
Evaluating deterministic graph over training set...
batch 0
batch 1
batch 2
batch 3
batch 4
batch 5
batch 6
batch 7
batch 8
batch 9
batch 10
batch 11
batch 12
batch 13
batch 14
batch 15
batch 16
batch 17
batch 18
batch 19
batch 20
batch 21
batch 22
batch 23
batch 24
batch 25
batch 26
batch 27
batch 28
batch 29
batch 30
batch 31
batch 32
batch 33
batch 34
batch 35
batch 36
batch 37
batch 38
batch 39
batch 40
batch 41
batch 42
batch 43
batch 44
batch 45
batch 46
batch 47
batch 48
batch 49
batch 50
batch 51
batch 52
batch 53
batch 54
batch 55
batch 56
batch 57
batch 58
batch 59
batch 60
batch 61
batch 62
batch 63
batch 64
batch 65
batch 66
batch 67
batch 68
batch 69
batch 70
batch 71
batch 72
batch 73
batch 74
batch 75
batch 76
batch 77
batch 78
batch 79
batch 80
batch 81
batch 82
batch 83
batch 84
batch 85
batch 86
batch 87
batch 88
batch 89
batch 90
batch 91
batch 92
batch 93
batch 94
batch 95
batch 96
batch 97
batch 98
batch 99


## Figure 2: Network performance on different parts of the same sample and different samples 

In [6]:
#push the training set and both validation sets through the trained network;
#each call is unpacked into a (predicted defocus, target defocus) pair
(train_prediction_defocus,
 train_target_defocus) = defocus_prediction_network.predict_validation(train_generator)
(val_prediction_defocus_cells,
 val_target_defocus_cells) = defocus_prediction_network.predict_validation(val_generator)
(test_prediction_defocus_histology,
 test_target_defocus_histoloy) = defocus_prediction_network.predict_validation(val_generator_histology)

def average_predictions(pred, target, block_size):
    """
    Collapse per-crop values into one consensus value per raw image.

    Each consecutive run of block_size entries is treated as the set of crops cut from a single
    original image and is replaced by its median (despite the function name, no mean is taken).

    :param pred: flat sequence of predictions whose length is a multiple of block_size
    :param target: flat sequence of targets, laid out the same way as pred
    :param block_size: number of crops that came from each raw image
    :return: (pred_consensus, target_consensus), each with one entry per raw image
    """
    def _blockwise_median(values):
        #one row per raw image, one column per crop
        per_image = np.reshape(values, [-1, block_size])
        return np.median(per_image, axis=1)

    return _blockwise_median(pred), _blockwise_median(target)


#every raw image was split into tile_split_k x tile_split_k crops, so that many consecutive
#predictions are reduced to one consensus value
crops_per_image = deterministic_params['tile_split_k'] ** 2

train_pred_avg, train_target_avg = average_predictions(
    train_prediction_defocus, train_target_defocus, crops_per_image)
val_pred_avg_cells, val_target_avg_cells = average_predictions(
    val_prediction_defocus_cells, val_target_defocus_cells, crops_per_image)
val_pred_avg_histology, val_target_avg_histology = average_predictions(
    test_prediction_defocus_histology, test_target_defocus_histoloy, crops_per_image)

def plot_results(pred, target, name=None, draw_rect=False):
    """
    Scatter predicted against target defocus on the current axes and print the RMSE

    :param pred: predicted defocus values
    :param target: target defocus values, same length as pred
    :param name: label used as the prefix of the printed RMSE line; omitted from the line when None
    :param draw_rect: if True, also draw the pred == target line and a translucent red band along it,
        both spanning the range of target
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    plt.plot(target, pred, '.')
    plt.xlabel('Target defocus (um)')
    plt.ylabel('Predicted defocus (um)')
    if draw_rect:
        min_target = np.min(target)
        max_target = np.max(target)
        #the band runs along the diagonal, hence the sqrt(2) factor on its length
        height = (max_target - min_target)*np.sqrt(2)
        #band width, presumably in the same um units as the axes -- confirm
        width = 2
        #the line is added before the band so that legend handles come out as
        #"all lines, then the patch" whichever ordering rule matplotlib applies
        plt.plot([min_target, max_target], [min_target, max_target], 'r-')
        plt.gca().add_patch(mpatches.Rectangle([min_target, min_target+width/np.sqrt(2)], width, height,
                                               angle=-45, color=[1, 0, 0, 0.2]))
    rmse = np.sqrt(np.mean((pred - target) ** 2))
    if name is None:
        print('RMSE: {}'.format(rmse))
    else:
        print('{} RMSE: {}'.format(name, rmse))


plt.figure()
plot_results(train_prediction_defocus, train_target_defocus, 'Training (cells)')
plot_results(val_prediction_defocus_cells, val_target_defocus_cells, 'Test (cells)')
#line and band are drawn once only, on the last call, so the five legend entries below line up with
#the five artists on the axes; `test_target_defocus_histoloy` (sic) is the name assigned where the
#predictions are computed
plot_results(test_prediction_defocus_histology, test_target_defocus_histoloy, 'Test (histology)', draw_rect=True)
plt.legend(['Training set (cells)',  'Test set (cells)', 'Test set (histology section)',
            'Ground truth', 'Objective depth of focus'])
plt.show()

NameError: name 'defocus_prediction_network' is not defined

LEDs in vertical axis of array:
4 12 28 48 83 119 187
3 11 27 47 84 120 188
3 has a defect in it, 4 doesn't