# Introduction

This notebook calculates the two types of error described in [notebook: 1_errors_in_registration](1_errors_in_registration.ipynb) for a range of registration parameters, with the aim of choosing a set of parameters that minimises both. The errors are calculated for a specific experiment and for specific NT and gAA wells. These well types are chosen because they should represent the wells with the highest and lowest amounts of motion, respectively. The `report/registration-parameter-tuning.pptx` file contains the plots from this notebook for all experiments, together with the chosen registration parameters for each experiment.

- As described in the [previous notebook](1_errors_in_registration.ipynb), there are two types of errors which are plotted here:

    1. root mean square error (RMSE): registration process error
    2. similarity error: forwards and backwards pass difference
    
    
- The size of these errors is determined by the values, and the combination, of the three tunable registration parameters: `REGION_SIZE`, `PTRN_SIZE` and `MAX_WINDOW`. The impact of each parameter being too small or too large is summarised in the table below:

| Registration Parameter | Too small | Too large |
|------------------------|-----------|-----------|
| `PTRN_SIZE`            | Increases spurious matches between the pattern in the region at timestep T and the pattern in the region at timestep T+1. | The content of the pattern has changed between frames, so a match cannot be found. |
| `REGION_SIZE`          | The pattern could move outside the region, so a match cannot be found. | The pattern could be matched to the wrong location between two timesteps, leading to the wrong velocity vector. |
| `MAX_WINDOW`           | Long processing time; weaker signals due to lower intensity. | The signal is smeared out beyond recognition, leading to poor matching and wrong velocity vectors. |
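
To make the roles of `PTRN_SIZE` and `REGION_SIZE` concrete, the sketch below does an exhaustive pattern search inside a region and scores each candidate position by RMSE. It is only an illustration under stated assumptions: `best_match` is a hypothetical helper, the frames are synthetic, and nothing here is taken from the `fam13a.register` implementation.

```python
import numpy as np

def best_match(region, pattern):
    """Slide `pattern` over `region`; return (rmse, (dy, dx)) of the best candidate."""
    rh, rw = region.shape
    ph, pw = pattern.shape
    best = (np.inf, (0, 0))
    for dy in range(rh - ph + 1):
        for dx in range(rw - pw + 1):
            patch = region[dy:dy + ph, dx:dx + pw]
            rmse = float(np.sqrt(np.mean((patch - pattern) ** 2)))
            if rmse < best[0]:
                best = (rmse, (dy, dx))
    return best

rng = np.random.default_rng(0)
frame_t = rng.random((21, 21))
# synthetic next frame: the same content moved 2 rows down and 1 column right
frame_t1 = np.roll(frame_t, shift=(2, 1), axis=(0, 1))

ptrn_size, region_size = 5, 13
c = 10  # centre pixel of both the pattern and the search region
hp, hr = ptrn_size // 2, region_size // 2
pattern = frame_t[c - hp:c + hp + 1, c - hp:c + hp + 1]
region = frame_t1[c - hr:c + hr + 1, c - hr:c + hr + 1]

rmse, (dy, dx) = best_match(region, pattern)
# convert the offset inside the region into a shift relative to the centre
shift = (dy - (hr - hp), dx - (hr - hp))
n_candidates = (region_size - ptrn_size + 1) ** 2
```

In this toy setting the pattern can only be found if the true shift is at most `(region_size - ptrn_size) // 2` pixels in each direction, which is the "too small `REGION_SIZE`" failure mode in the table.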

- A single example of the two errors plotted against `PTRN_SIZE` is shown in the figure below. Here the RMSE clearly increases with increasing `PTRN_SIZE`, while the similarity error decreases. In this situation a compromise must be made between the two errors, so a `PTRN_SIZE` of ~9 can be chosen.

<img src="../markdown_images_for_notebooks/example-errors-graphs-compromise.PNG">
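
In this notebook the compromise is chosen by eye from the plots. Purely as an illustration, one way to express the same idea numerically is to min-max normalise each error curve and pick the `PTRN_SIZE` with the smallest combined score. The error values below are hypothetical placeholders, not measurements from any experiment, and `min_max` is an illustrative helper.

```python
ptrn_sizes = [5, 7, 9, 11, 13]
# hypothetical error values, for illustration only (not measured data)
rmse = [1.0, 1.4, 1.9, 2.8, 4.0]
similarity = [30.0, 18.0, 10.0, 8.0, 7.0]

def min_max(values):
    """Rescale a list of values to the range [0, 1]."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) for v in values]

# equal weighting of the two normalised errors is an arbitrary choice
combined = [r + s for r, s in zip(min_max(rmse), min_max(similarity))]
best_ptrn_size = ptrn_sizes[combined.index(min(combined))]
```

With these placeholder curves the combined score is lowest at a `PTRN_SIZE` of 9; a different weighting of the two errors could move the choice.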

- Since three parameters must be chosen, this notebook plots the errors against the registration parameters in two different ways ("view-1" and "view-2"), which are described in the plotting section below.

# Imports

In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import math

from fam13a import utils, image, register
from multiprocessing import Pool

import altair as alt
import time
from altair_saver import save
from IPython import display
from matplotlib import pyplot as plt
from matplotlib import patches
import matplotlib.image as mpimg

import subprocess

# Constants

In [None]:
PROJ_ROOT = utils.here()
# declare the data input directory
INTERIM_HBEC_ROOT = os.path.join(PROJ_ROOT, 'data', 'interim', 'hbec')
print(os.listdir(INTERIM_HBEC_ROOT))

In [None]:
# choose experiment data to load
EXP_ID = 'N67030-59_PBS'

# INTERIM_HBEC_ROOT already starts at PROJ_ROOT, so it is not prepended again
INTERIM_ROOT = os.path.join(INTERIM_HBEC_ROOT, EXP_ID)

# declare the various output directories
PROCESSED_ROOT = os.path.join(PROJ_ROOT, 'data', 'processed', 'hbec', EXP_ID)
ROI_ROOT = os.path.join(PROCESSED_ROOT, 'roi')
SEG_ROOT = os.path.join(PROCESSED_ROOT, 'segmented', 'movement')
MAX_FRAME_ROOT = os.path.join(PROCESSED_ROOT, 'max_frame')

OUTPUT_ROOT = os.path.join(PROCESSED_ROOT, 'register')
# set level of parallelisation
NCPUS = 16
# define the common file extension used in the input data files
EXTENSION = '.ome.tif'

# find all relevant data files in the data directory
files = sorted([_f for _f in os.listdir(INTERIM_ROOT) if _f.endswith(EXTENSION)])
# remove the extension from the file names, so we keep only the bit with useful information
files = [_f.split(EXTENSION)[0] for _f in files]
print(files)

# Load files

In [None]:
file = 'NT_PBS_3_1_MMStack_Pos0'
    
# load each of the associated interim/processed files
frames = utils.frames_from_stack(os.path.join(INTERIM_ROOT, f'{file}{EXTENSION}'))
roi = np.load(os.path.join(ROI_ROOT, f'{file}.npy'))
max_frame = np.load(os.path.join(MAX_FRAME_ROOT, f'{file}.npy'))
movement_mask = np.load(os.path.join(SEG_ROOT, f'{file}.npy')).astype(bool)

In [None]:
# positive time and negative time axis frames:
frames_positive = frames.copy()
frames_negative = frames[::-1]

# Setup

In [None]:
# define a very simple wrapper around the register.estimate_shifts function,
# purely to allow the inclusion of a progress bar during processing
# (note: it reads ptrn_size from the global scope)
def process(regions_pair):
    return register.estimate_shifts(*regions_pair, ptrn_size=ptrn_size)

In [None]:
# run the registration process, generate the velocity fields and the errors, and return a dict
# containing the velocity fields, the downsampled mask, and the mean and std of the RMSE
def run_registration_process_with_minimised_errors(zoomed_frames, zoomed_movement_mask, mw, ps, rs):
    reg_process = register.run_registration_process_for_auxillary_analysis(
        zoomed_frames, zoomed_movement_mask, mw, ps, rs, NCPUS
    )
    # isolate the errors
    errors = reg_process['errors']
    # get the min of the errors (RMSE) per timestep, then average over time too
    min_errors = errors.min(axis=3)
    min_errors_mean = min_errors.mean(axis=0)
    # remove the errors key in the dictionary and add the mean and std rmse
    reg_process.pop('errors')
    reg_process['mean_rmse'] = min_errors_mean.mean()
    reg_process['std_rmse'] = min_errors_mean.std()
    
    return reg_process

# generate a mean and a std for the similarity error across the chosen sample
def calculate_similarity(
    zoomed_frames_positive, zoomed_frames_negative, zoomed_movement_mask, mw, ps, rs
):
    
    positive = run_registration_process_with_minimised_errors(
        zoomed_frames_positive, zoomed_movement_mask, mw, ps, rs
    )
    negative = run_registration_process_with_minimised_errors(
        zoomed_frames_negative, zoomed_movement_mask, mw, ps, rs
    )
    
    difference = {}
    difference['mags'] = np.abs(positive['mags'] - negative['mags'])
    # fractional difference in speed (mags); zero speeds are mapped to a difference of 0
    difference['mags'] /= np.where(positive['mags'] == 0, np.inf, positive['mags'])

    # to look at differences between angles, we must take the cosine between the
    # positive and (-)negative vectors
    difference['cosines'] = (
        (positive['vx']*(-1)*negative['vx']) + (positive['vy']*(-1)*negative['vy']) 
    )
    # get rid of numerical errors (|cosine| > 1)
    difference['cosines'] = np.clip(difference['cosines'], -1, 1)
    difference['angles'] = (180/math.pi) * (np.arccos(difference['cosines']))
    
    sub_mask = positive['sub_mask']

    result_mean = {}
    result_std = {}
    for key in ['mags','angles']:
        result_mean[key] = np.mean(difference[key][sub_mask])
        result_std[key] = np.std(difference[key][sub_mask])
    result_mean['positive_rmse'] = positive['mean_rmse']
    result_std['positive_rmse'] = positive['std_rmse']
    
    max_mag_position = np.unravel_index(positive['mags'].argmax(), positive['mags'].shape)
        
    return result_mean, result_std, max_mag_position
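
The angle part of the similarity error above can be sketched in isolation. The version below normalises the vectors explicitly, which is an assumption made here: whether the `vx` and `vy` fields returned by the registration process are already unit length is not shown in this notebook. `angle_between_deg` is a hypothetical helper, not part of `fam13a`.

```python
import numpy as np

def angle_between_deg(vx_fwd, vy_fwd, vx_bwd, vy_bwd):
    """Angle in degrees between forward vectors and sign-flipped backward vectors."""
    fwd = np.stack([vx_fwd, vy_fwd], axis=-1).astype(float)
    # the backward pass should see the opposite motion, so flip its sign first
    bwd = -np.stack([vx_bwd, vy_bwd], axis=-1).astype(float)
    norms = np.linalg.norm(fwd, axis=-1) * np.linalg.norm(bwd, axis=-1)
    dots = np.sum(fwd * bwd, axis=-1)
    # treat zero-length vectors as perfectly aligned to avoid dividing by zero
    cosines = np.divide(dots, norms, out=np.ones_like(dots), where=norms > 0)
    # clip to [-1, 1] to guard against floating point round-off before arccos
    return np.degrees(np.arccos(np.clip(cosines, -1.0, 1.0)))

angles = angle_between_deg(
    np.array([1.0, 1.0, 0.0]), np.array([0.0, 0.0, 0.0]),
    np.array([-1.0, 0.0, 0.0]), np.array([0.0, -1.0, 0.0]),
)
```

A perfectly consistent forward and backward pass gives an angle of 0 degrees at every pixel, so larger mean angles indicate a larger similarity error.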

# Choose small patch to focus on

In [None]:
# Create figure and axes
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)

# Display the image
ax.imshow(max_frame)
ax.imshow(movement_mask, alpha=0.1)
# Create a Rectangle patch; top_left_corner is given as (x, y)
top_left_corner = (500, 1000)
box_width = 200
box_height = 200
rect = patches.Rectangle(
    top_left_corner, box_width, box_height, linewidth=1, edgecolor='r', facecolor='none'
)

# Add the patch to the Axes
ax.add_patch(rect)
plt.title(f'{EXP_ID}-{file}-sample-patch')
plt.show()

In [None]:
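# Get limits of the bounding box as array index ranges
# (box_x_lim is used for the first spatial axis, i.e. rows; box_y_lim for the second, i.e. columns)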
box_x_lim = (top_left_corner[1], top_left_corner[1] + box_height)
box_y_lim = (top_left_corner[0], top_left_corner[0] + box_width)
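
The mapping between the rectangle and the array slices is easy to get wrong, because `matplotlib.patches.Rectangle` takes its corner as `(x, y)` while NumPy images are indexed as `[row, column]`, i.e. `[y, x]`. The sketch below shows the same conversion on a small synthetic image; the image and box values are hypothetical.

```python
import numpy as np

# hypothetical small image: 6 rows (y) by 8 columns (x)
img = np.arange(6 * 8).reshape(6, 8)

top_left_corner = (5, 1)  # (x, y), the order used by matplotlib.patches.Rectangle
box_width, box_height = 2, 3

# rows come from the y coordinate and the height, columns from the x coordinate and the width
rows = slice(top_left_corner[1], top_left_corner[1] + box_height)
cols = slice(top_left_corner[0], top_left_corner[0] + box_width)
patch = img[rows, cols]
```

In the cell above, `box_x_lim` plays the role of `rows` and `box_y_lim` the role of `cols`. Because the box used in this notebook is square, the naming has no effect on the selected patch.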

# Loop through various registration parameters
- With `NCPUS=32`, the following loop takes ~600s for a 200x200 patch.

In [None]:
# window size to take along the time dimension
max_window_list = ['1', '1_5', '5', '7']
# set the sizes of the pattern and search regions
ptrn_size_list = [5,7,9,11,13]
region_size_list = [9,11,13,15,17,19,21]

# collect the results as a list of rows; the dataframe is built once after the loop
# (DataFrame.append was removed in pandas 2.0)
rows = []
start_time = time.time()
for max_window_ in max_window_list:

    if '_' in max_window_:
        # take every nth frame (alternative to max over time-window since we set MAX_WINDOW = 1)
        nth = int(max_window_.split('_')[1])
        zoomed_frames_positive = frames_positive[::nth, slice(*box_x_lim), slice(*box_y_lim)]
        zoomed_frames_negative = frames_negative[::nth, slice(*box_x_lim), slice(*box_y_lim)]
        max_window = int(max_window_.split('_')[0])
    else:
        zoomed_frames_positive = frames_positive[:, slice(*box_x_lim), slice(*box_y_lim)]
        zoomed_frames_negative = frames_negative[:, slice(*box_x_lim), slice(*box_y_lim)]
        max_window = int(max_window_)
    zoomed_motion_mask = movement_mask[slice(*box_x_lim), slice(*box_y_lim)]

    for ptrn_size_ in ptrn_size_list:
        ptrn_size = (ptrn_size_,)*2

        for region_size_ in region_size_list:
            region_size = (region_size_,)*2
            # make sure the pattern fits inside the region and when scanning in the region,
            # it must have more than 9 options (so min_options=16)
            if region_size_ - ptrn_size_ > 3:
                mean, std, max_mag_position = calculate_similarity(
                    zoomed_frames_positive, zoomed_frames_negative, zoomed_motion_mask,
                    max_window, ptrn_size, region_size
                )
                for key in ['mags', 'angles']:
                    rows.append({
                        'max_window': max_window_,
                        'ptrn_size': ptrn_size_,
                        'region_size': region_size_,
                        'mean': mean[key],
                        'std': std[key],
                        'mean_pos_rmse': mean['positive_rmse'],
                        'std_pos_rmse': std['positive_rmse'],
                        'field': key,
                        'max_mag_position': max_mag_position
                    })
# results dataframe
df = pd.DataFrame(rows)
print(time.time() - start_time)
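
To get a feel for the cost of the sweep above, the parameter combinations that pass the `region_size_ - ptrn_size_ > 3` filter can be counted without running any registration. This is a small standalone sketch using the same lists as the loop; `valid` and `n_registrations` are names introduced here for illustration.

```python
from itertools import product

max_window_list = ['1', '1_5', '5', '7']
ptrn_size_list = [5, 7, 9, 11, 13]
region_size_list = [9, 11, 13, 15, 17, 19, 21]

# keep only the combinations that the loop actually evaluates
valid = [
    (mw, ps, rs)
    for mw, ps, rs in product(max_window_list, ptrn_size_list, region_size_list)
    if rs - ps > 3
]
# each combination is registered twice: once forwards and once backwards in time
n_registrations = 2 * len(valid)
```

Each combination contributes two rows to `df` (one per field, `mags` and `angles`), so the number of rows in the dataframe equals `n_registrations` here.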

# Plot similarity, rmse vs registration parameters

- 4 plots are produced and saved in this section:
    1. rmse-view-1.svg
    2. sim-view-1.svg
    3. rmse-view-2.svg
    4. sim-view-2.svg
- "view-1" plots the error against the `PTRN_SIZE` on the x-axis, with the different colours of the points representing various `REGION_SIZE` values. The variation in error with respect to `MAX_WINDOW` is captured on different rows of the view. 
- "view-2" plots the error against the `REGION_SIZE` on the x-axis, with the different colours of the points representing various `MAX_WINDOW` values. The variation in error with respect to `PTRN_SIZE` is captured on different columns of the view.

In [None]:
dff=df.copy()
df.drop(columns=['max_mag_position'], inplace=True)

In [None]:
# generate base chart object with all the needed data
base = alt.Chart(df).transform_calculate(
    ymin='datum.mean_pos_rmse-datum.std_pos_rmse',
    ymax='datum.mean_pos_rmse+datum.std_pos_rmse'
)
# construct the points and error bars plots from the base chart
points = base.mark_point(size=100).encode(
    y=alt.Y('mean_pos_rmse:Q', axis=alt.Axis(title='rmse mean'),
        scale=alt.Scale(zero=False)),
    x=alt.X('ptrn_size:N', axis=alt.Axis(title='ptrn size')),
    color='region_size:N'
)
errors = base.mark_errorbar().encode(
    x=alt.X('ptrn_size:N', axis=alt.Axis(title='ptrn size')),
    y=alt.Y('ymin:Q', axis=alt.Axis(title='mean_pos_rmse')),
    y2='ymax:Q',
    color='region_size:N'
)

k = alt.layer(points).facet(
    row=alt.Row('max_window:N'),
    column=alt.Column('field:N'),
).configure_axis(
    labelFontSize=16,
    titleFontSize=20
).configure_legend(
    labelFontSize=20,
    titleFontSize=20
).configure_header(
    labelFontSize=20,
    titleFontSize=24
).resolve_scale(
    y='independent'
)
save(k, f'{EXP_ID}-{file}-rmse-view-1.svg')

In [None]:
# generate base chart object with all the needed data
base = alt.Chart(df).transform_calculate(
    ymin='datum.mean-datum.std',
    ymax='datum.mean+datum.std'
)
# construct the points and error bars plots from the base chart
points = base.mark_point(size=100).encode(
    y=alt.Y('mean:Q', axis=alt.Axis(title='similarity mean')),
    x=alt.X('ptrn_size:N', axis=alt.Axis(title='ptrn size')),
    color='region_size:N'
)
errors = base.mark_errorbar().encode(
    x=alt.X('ptrn_size:N', axis=alt.Axis(title='ptrn size')),
    y=alt.Y('ymin:Q', axis=alt.Axis(title='mean')),
    y2='ymax:Q',
    color='region_size:N'
)
k = alt.layer(points).facet(
    row=alt.Row('max_window:N'),
    column=alt.Column('field:N'),
).configure_axis(
    labelFontSize=16,
    titleFontSize=20
).configure_legend(
    labelFontSize=20,
    titleFontSize=20
).configure_header(
    labelFontSize=20,
    titleFontSize=24
).resolve_scale(
    y='independent'
)
save(k, f'{EXP_ID}-{file}-sim-view-1.svg')

In [None]:
# generate base chart object with all the needed data
base = alt.Chart(df).transform_calculate(
    ymin='datum.mean-datum.std',
    ymax='datum.mean+datum.std'
)
# construct the points and error bars plots from the base chart
points = base.mark_point(size=100).encode(
    y=alt.Y('mean:Q', axis=alt.Axis(title='similarity mean')),
    x=alt.X('region_size:N', axis=alt.Axis(title='region size')),
    color='max_window:N'
)
errors = base.mark_errorbar().encode(
    x=alt.X('region_size:N', axis=alt.Axis(title='region size')),
    y=alt.Y('ymin:Q', axis=alt.Axis(title='mean')),
    y2='ymax:Q',
    color='max_window:N'
)
k = alt.layer(points).facet(
    row=alt.Row('field:N'),
    column=alt.Column('ptrn_size:N'),
).configure_axis(
    labelFontSize=16,
    titleFontSize=20
).configure_legend(
    labelFontSize=20,
    titleFontSize=20
).configure_header(
    labelFontSize=20,
    titleFontSize=24
).resolve_scale(
    y='independent'
)
save(k, f'{EXP_ID}-{file}-sim-view-2.svg')

In [None]:
# generate base chart object with all the needed data
base = alt.Chart(df).transform_calculate(
    ymin='datum.mean_pos_rmse-datum.std_pos_rmse',
    ymax='datum.mean_pos_rmse+datum.std_pos_rmse'
)
# construct the points and error bars plots from the base chart
points = base.mark_point(size=100).encode(
    y=alt.Y('mean_pos_rmse:Q', axis=alt.Axis(title='rmse mean'),
        scale=alt.Scale(zero=False)),
    x=alt.X('region_size:N', axis=alt.Axis(title='region size')),
    color='max_window:N'
)
errors = base.mark_errorbar().encode(
    x=alt.X('region_size:N', axis=alt.Axis(title='region size')),
    y=alt.Y('ymin:Q', axis=alt.Axis(title='mean')),
    y2='ymax:Q',
    color='max_window:N'
)
k = alt.layer(points).facet(
    row=alt.Row('field:N'),
    column=alt.Column('ptrn_size:N'),
).configure_axis(
    labelFontSize=16,
    titleFontSize=20
).configure_legend(
    labelFontSize=20,
    titleFontSize=20
).configure_header(
    labelFontSize=20,
    titleFontSize=24
).resolve_scale(
    y='independent'
)
save(k, f'{EXP_ID}-{file}-rmse-view-2.svg')

# View rmse and similarity error charts side by side

Two images are produced in this section:
1. rmse and similarity error in view-1
2. rmse and similarity error in view-2

In [None]:
def display_images(image1, image2, view):
    """Show the rmse image (left) and the similarity image (right) for one view."""
    # read images
    img_A = mpimg.imread(image1)
    img_B = mpimg.imread(image2)

    # display images
    plt.figure(figsize=(30, 30))
    plt.subplot(121)
    plt.imshow(img_A)
    plt.axis('off')
    plt.title(f'{EXP_ID}-{file}-rmse-view-{view}')

    plt.subplot(122)
    plt.imshow(img_B)
    plt.axis('off')
    plt.title(f'{EXP_ID}-{file}-sim-view-{view}')

    plt.show()
    
# convert the SVGs to PNGs (uses the ImageMagick `convert` command, which must be on the PATH)
# so that they can be viewed side by side using the function above
for error_type in ['rmse', 'sim']:
    for view_type in [1,2]:
        subprocess.call([
            'convert', 
            f'{EXP_ID}-{file}-{error_type}-view-{view_type}.svg',
            f'{EXP_ID}-{file}-{error_type}-view-{view_type}.png'
        ])

In [None]:
display_images(f'{EXP_ID}-{file}-rmse-view-1.png', f'{EXP_ID}-{file}-sim-view-1.png', 1)

In [None]:
display_images(f'{EXP_ID}-{file}-rmse-view-2.png', f'{EXP_ID}-{file}-sim-view-2.png', 2)
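
The conversion step above calls an external `convert` command and ignores its return code, so a missing tool only shows up later when the PNG cannot be read. A minimal sketch of a more defensive wrapper is given below; `svg_to_png` is a hypothetical helper, not part of this project.

```python
import shutil
import subprocess

def svg_to_png(svg_path, png_path, tool='convert'):
    """Convert an SVG to a PNG with an external tool; return True on success."""
    # bail out early if the command is not available on the PATH
    if shutil.which(tool) is None:
        return False
    completed = subprocess.run([tool, svg_path, png_path], check=False)
    return completed.returncode == 0

# a tool name that does not exist is reported as a failure instead of raising
missing = svg_to_png('chart.svg', 'chart.png', tool='definitely-not-a-real-tool-xyz')
```

The boolean result can be checked before calling `display_images`, so that a failed conversion is reported at the point where it happens.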