## Preprocess and motion correct 3D movies
Step 1 of the Caiman processing pipeline for multi-layer two-photon calcium imaging movies. Assumes the movie format acquired with the Scope setup (i.e. the different layers are arranged as a mosaic on top of each other).

### Imports & Setup
The first cells import the various Python modules required by the notebook. In particular, a number of modules are imported from the Caiman package. In addition, we also set up the environment so that everything works as expected.

In [None]:
# TODO: check for unnecessary imports
# Generic imports
# from __future__ import absolute_import, division, print_function
# from builtins import *
from __future__ import print_function

import os, sys, glob, platform, shutil, re, math
import json
import time, datetime
from functools import partial
import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
import skimage.transform
from scipy import interpolate
from tifffile import TiffFile, imread, imsave
from IPython.display import clear_output

# Import Bokeh library
from bokeh.plotting import Figure, show
from bokeh.layouts import gridplot
from bokeh.models import Range1d, CrosshairTool, HoverTool, Legend
from bokeh.io import output_notebook, export_svgs
from bokeh.models.sources import ColumnDataSource

%matplotlib inline

In [None]:
# On Linux the caiman folder in the home directory has to be added to the
# module search path.
if platform.system() == 'Linux':
    sys.path.append(os.path.expanduser('~/caiman'))
# Environment variables for parallel processing: one thread each for
# MKL, OpenBLAS and vecLib.
os.environ.update({
    'MKL_NUM_THREADS': '1',
    'OPENBLAS_NUM_THREADS': '1',
    'VECLIB_MAXIMUM_THREADS': '1',
})

In [None]:
# Import CaImAn and custom functions
import caiman as cm
from caiman.motion_correction import MotionCorrect
import utils

In [None]:
# This has to be in a separate cell, otherwise it won't work.
# Configure Bokeh for notebook output using the INLINE resources.
from bokeh import resources
output_notebook(resources=resources.INLINE)

### Specify experimental parameters

**Update!**

Here we specify where the files are located, how the files are called, the frame rate of the acquisition, how to crop the movies and how many files to process.
- animal_folder ... the folder where the data is located on the volume `Data`.
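- day_folder ... the sub-folder of `animal_folder` for the recording day
- area_folder ... the sub-folder of `day_folder` for the imaged area
- max_sessions ... how many sessions to process per area (0 for all)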
- n_planes ... number of planes recorded in the movies
- x_crop ... fraction of x pixels to crop (i.e. 0.5 for half or >1 to specify the exact number of pixels)

In [None]:
# Folder names that identify the dataset below the platform-specific data root.
animal_folder = 'M3_October_2018'
day_folder = 'M3_2018-10-02'
area_folder = 'S1'
max_sessions = 10 # how many sessions to process per area

n_planes = 3 # number of planes in movie
x_crop = 0.5 # fraction of x pixels to crop (i.e. 0.5 for half)

max_group_size = 75 # if there are more files, they will be processed in groups

n_processes = 4 # number of parallel processes (None to select automatically)

# create the complete path to the data folder
if platform.system() == 'Linux':
    data_folder = '/home/ubuntu/Data'
elif platform.system() == 'Darwin':
    data_folder = '/Users/Henry/Data/temp/Dendrites_Gwen'
else:
    # Without this branch data_folder stays unbound and the join below fails
    # with a NameError that does not point at the real cause.
    raise RuntimeError('No data folder configured for platform %r' % (platform.system(),))
data_folder = os.path.join(data_folder, animal_folder, day_folder, area_folder)
# bare expression: shows the resulting path as cell output
data_folder

In [None]:
# select sessions for processing
# regular expression that should match the folder names (ie. 01-23-45_Live);
# a raw string keeps the \d escapes away from Python's string-escape handling
p = re.compile(r'\d\d-\d\d-\d\d_Live')
# Filter on the pattern first and apply the limit afterwards, so that directory
# entries that do not match do not count towards max_sessions.
session_names = [x for x in sorted(os.listdir(data_folder)) if p.match(x)]
if max_sessions:  # 0 or None: keep all matching sessions
    session_names = session_names[:max_sessions]
sessions = [os.path.join(data_folder, x) for x in session_names]

In [None]:
# For every session folder take the first '.tif' file whose name does not
# contain 'stacked' and the first file whose name contains 'parameters.xml'.
# "First" is in os.listdir order; an IndexError is raised if nothing matches.
tiff_files = []
xml_files = []
for session_dir in sessions:
    tiff_matches = [os.path.join(session_dir, entry) for entry in os.listdir(session_dir)
                    if entry.endswith('.tif') and 'stacked' not in entry]
    tiff_files.append(tiff_matches[0])

    xml_matches = [os.path.join(session_dir, entry) for entry in os.listdir(session_dir)
                   if 'parameters.xml' in entry]
    xml_files.append(xml_matches[0])

In [None]:
# read frame rate from parameters.xml
frame_rates = []
for ix, i_session in enumerate(sessions):
    root = ET.parse(xml_files[ix]).getroot()
    # Note: only area 0!
    session_rates = [float(child.find('Framerate_Hz').text)
                     for child in root if child.tag == 'area0']
    frame_rates.extend(session_rates)
    if session_rates:
        # the metadata written for the joined TIF files reads `frame_rate`
        frame_rate = session_rates[-1]
    # only the first sessions are listed to keep the output short
    if ix <= 10:
        session_name = os.path.split(i_session)[1]
        if session_rates:
            print('Frame rate: %1.4f Hz (%s)' % (frame_rate, session_name))
        else:
            # previously the rate of the preceding session was printed here
            print('Frame rate: no area0 entry found (%s)' % (session_name))
# same condition as "last index > 10", but also defined for an empty session list
if len(sessions) > 11:
    print('...')

### Setup cluster
The default backend mode for parallel processing is through the multiprocessing package. This will allow us to use all the cores in the VM. Note that the `cropTif` function has to be defined before starting the cluster, so that pool workers have the function available.

In [None]:
# start the cluster (if a cluster already exists terminate it)
# `dview` is presumably left over from an earlier run of this cell -- the name
# check avoids a NameError on the first run.
if 'dview' in locals():
    dview.terminate()
# setup_cluster returns three values; the third one overwrites n_processes
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=n_processes, single_thread=False)

### Convert TIFF files to ImageJ hyperstack files
This cell calls the `mosaic_to_stack` function in `utils` through the multiprocessing `starmap` method to make use of multiple cores.

In [None]:
# Convert the mosaic of planes in every TIFF file to a 3D stack (saved as IJ
# hyperstack by utils.mosaic_to_stack). starmap unpacks one argument tuple
# (file, n_planes, x_crop) per call and distributes the calls over the pool.
stacked_files = dview.starmap(
    utils.mosaic_to_stack,
    [(tiff_path, n_planes, x_crop) for tiff_path in tiff_files],
)

In [None]:
# List the converted files; long lists are shortened to the first 10 and the
# last 5 entries.
print('Processing %1.0f files:' % (len(stacked_files)))
if len(stacked_files) > 15:
    print(*stacked_files[:10], sep='\n')
    print('...')
    print(*stacked_files[-5:], sep='\n')
else:
    # Up to 15 entries are printed completely: showing head and tail
    # separately would repeat entries and put a '...' where nothing is omitted.
    print(*stacked_files, sep='\n')

### Join cropped TIF files
Next, we create a large joined TIF file from individual cropped files. Further processing will be done on the joined file.

In [None]:
# Split the stacked files into the smallest number of groups for which no
# group has more than max_group_size files, with the files spread evenly.
n_stacked_files = len(stacked_files)
groups = int(math.ceil(n_stacked_files / float(max_group_size)))
# no files -> no groups; avoids dividing by zero
files_per_group = int(math.ceil(n_stacked_files / float(groups))) if groups else 0
stacked_files_by_group = []
print('Processing files in %d groups' % (groups))
for i_groups in range(groups):
    start_ix = i_groups * files_per_group
    # the last group may be shorter: clamp so the printed range matches the slice
    stop_ix = min((i_groups+1) * files_per_group, n_stacked_files)
    group_files = stacked_files[start_ix:stop_ix]
    stacked_files_by_group.append(group_files)

    print('Group %d (%d - %d): %d files' % (i_groups+1, start_ix, stop_ix, len(group_files)))

In [None]:
# show the files of the first group as cell output
stacked_files_by_group[0]

In [None]:
# load the stacked files of the first group as one chained movie (int16 output)
movies = cm.load_movie_chain(stacked_files_by_group[0], is3D=True, outtype=np.int16)

In [None]:
# show the dimensions of the loaded movie as cell output
movies.shape

In [None]:
# play index 1 of the second axis in the notebook
# NOTE(review): presumably the second axis is the plane axis -- confirm
movies[:,1,:,:].play(backend='notebook', fr=10)

In [None]:
# Join the files of every group into one TIF file and write a JSON file with
# information about the source files next to it.
# NOTE(review): tiff_files_by_group, date_folder and session_folder are not
# defined in the visible part of this notebook (the cells above define
# stacked_files_by_group and day_folder) -- confirm before running this cell.
joined_tif_list = []
json_fname_list = []
total_frames_list = []
trial_indices_list = []
for ix, tiff_files_crop_group in enumerate(tiff_files_by_group):
    # load movies
    movies = cm.load(tiff_files_crop_group, outtype=np.int16)

    total_frames = movies.shape[0]
    total_frames_list.append(total_frames)
    # read again by the motion correction parameters (max_shifts)
    dims = (movies.shape[1], movies.shape[2])
    # derive joined file name and save
    joined_tif = '%s_%s_Join_G%d_%d_crop.tif' % (date_folder, session_folder, ix, total_frames)
    imsave(os.path.join(data_folder, joined_tif), movies)
    frames_per_movie = dview.map(utils.getFramesTif, tiff_files_crop_group)

    movies = None # free the memory

    # trial index for each frame; a separate loop variable keeps the group
    # index `ix` of the outer loop intact
    trial_indices = []
    for trial_ix, frame_count in enumerate(frames_per_movie):
        trial_indices.extend([trial_ix] * frame_count)
    trial_indices_list.append(trial_indices)

    # create a Json file with information about source files
    # NOTE(review): frame_rate is the value left over from the frame rate cell
    # (a single value for all source files) -- confirm this is intended
    meta = {"joined_file": joined_tif,
            "source_frames": frames_per_movie,
            "source_file": [x.replace(data_folder + os.path.sep,'') for x in tiff_files_crop_group],
            "trial_index": trial_indices,
            "frame_rate": frame_rate
           }
    json_fname = joined_tif.replace('.tif','.json')
    with open(os.path.join(data_folder, json_fname), 'w') as fid:
        json.dump(meta, fid)

    # save output file names in list
    joined_tif_list.append(joined_tif)
    json_fname_list.append(json_fname)

    print('Saved joined TIF file %s' % (joined_tif))
    print('Created JSON metadata file %s' % (json_fname))

### Display average signal intensity
This step is optional and useful as a sanity check of what the imported data looks like.

In [None]:
# select group (0, 1, ...)
group_ix = 0

# load the joined TIF file of the selected group (int16 output)
mov = cm.load(os.path.join(data_folder, joined_tif_list[group_ix]), outtype=np.int16)

In [None]:
# plot average signal intensity per frame
# (averaging over axis 1 twice leaves one value per frame for a 3D movie array)
frame_avg = np.mean(np.mean(mov, axis=1), axis=1)
frames = np.array(range(len(frame_avg)))

# trial names: source file paths without the data folder, '.tif' and '_crop'
path_prefix = data_folder + os.path.sep
trial_names = []
for tiff_path in tiff_files_by_group[group_ix]:
    trial_names.append(tiff_path.replace(path_prefix, '').replace('.tif', '').replace('_crop', ''))
# one trial name per frame, looked up through the trial index of the frame
trial_names_frames = [trial_names[trial_ix] for trial_ix in trial_indices_list[group_ix]]

data = dict(x=frames,
            y=frame_avg,
            trial_idx=trial_indices_list[group_ix],
            trial_name=trial_names_frames)
data_source = ColumnDataSource(data)

plot_title = 'Frame average - Group %d' % (group_ix)
p = Figure(plot_width=900, plot_height=300, title=plot_title)
p.add_tools(CrosshairTool(), utils.getHover())
p.line('x', 'y', source=data_source, line_width=2, legend='Original', color='blue')

show(p)

### Adjust image intensity
Optional step to adjust image intensity for each trial. This is useful when the average intensity fluctuates a lot from trial to trial, as shown by the previous plot.

In [None]:
# Subtract the mean intensity of every trial from the frames of that trial.
# The joined TIF file is overwritten with the adjusted movie.
for ix, joined_tif in enumerate(joined_tif_list):
    joined_path = os.path.join(data_folder, joined_tif)
    mov = cm.load(joined_path, outtype=np.int16)
    frame_trial_ix = np.array(trial_indices_list[ix])
    # largest trial index in this group
    n_trials = max(trial_indices_list[ix])

    for i_trial in range(n_trials+1):
        trial_idx = np.where(frame_trial_ix == i_trial)[0]
        # subtract average intensity (plain assignment: the float result is
        # cast back to the dtype of mov)
        trial_frames = mov[trial_idx,:,:]
        mov[trial_idx,:,:] = trial_frames - np.mean(trial_frames)
    imsave(joined_path, mov)

### Motion correction

First, setup the parameters for motion correction. The following parameters influence the **quality** of the motion correction:
- niter_rig ... number of iterations for rigid registration (larger = better). Little improvement likely above 5-10.
- strides ... intervals at which patches are laid out for motion correction (smaller = better)
- overlaps ... overlap between patches

Note that smaller values for strides / overlap will improve registration but also lead to NaNs in the output image. In general, there is a trade-off between the quality of registration and the presence / number of NaNs in the output (at least if there is significant motion).

In [None]:
# parameters for motion correction
params = {'niter_rig': 5,
          'max_shifts': (int(np.round(dims[0]/10)), int(np.round(dims[1]/10))),  # maximum allow rigid shift
          # for parallelization split the movies in  num_splits chuncks across time
          'splits_rig': 50,
          # if none all the splits are processed and the movie is saved
          'num_splits_to_process_rig': None,
          # intervals at which patches are laid out for motion correction
          'strides': (24, 24),
          # overlap between pathes (size of patch strides+overlaps)
          'overlaps': (24, 24),
          # for parallelization split the movies in  num_splits chuncks across time
          'splits_els': 50,
          # if none all the splits are processed and the movie is saved
          'num_splits_to_process_els': [28, None],
          'upsample_factor_grid': 4,  # upsample factor to avoid smearing when merging patches
          # maximum deviation allowed for patch with respect to rigid shift
          'max_deviation_rigid': 3,
          # Specifies how to deal with borders. (True, False, 'copy', 'min')
          'border_nan': False,
         }

There are also some parameters for computing the quality metrics. These probably don't have to be changed.

In [None]:
# parameters for computing metrics
winsize = 100
swap_dim = False
resize_fact_flow = 0.2    # downsample for computing ROF

Next, we define some functions. See the function doc strings for further information.

In [None]:
def setupMC(fname, params):
    """
    Configure a motion correction object for one input file.

    fname ... file name of the movie
    params ... dict with the motion correction parameters (see the parameter cell)

    Return the MotionCorrect object mc
    """
    # the minimum of the movie is the second positional argument of MotionCorrect
    min_mov = cm.load(fname).min()
    # entries of params that are passed on under the same keyword name
    forwarded_keys = ('max_shifts', 'niter_rig', 'splits_rig', 'num_splits_to_process_rig',
                      'strides', 'overlaps', 'splits_els', 'num_splits_to_process_els',
                      'upsample_factor_grid', 'max_deviation_rigid', 'border_nan')
    mc_kwargs = {key: params[key] for key in forwarded_keys}
    return MotionCorrect(fname, min_mov, dview=dview,
                         shifts_opencv=True, nonneg_movie=False,
                         **mc_kwargs)

In [None]:
def interpolateNans(frame, n=10):
    """
    Replace every NaN pixel of a 2D frame by the mean of the n closest
    (Euclidean distance in pixel coordinates) non-NaN pixels.

    frame ... 2D array, is not modified
    n ... number of closest valid pixels to average

    Return the input frame itself if it contains no NaN, otherwise an
    interpolated copy
    """
    nan_mask = np.isnan(frame)
    nan_positions = np.argwhere(nan_mask)
    if len(nan_positions) == 0:
        return frame

    valid_positions = np.argwhere(~nan_mask)
    filled = frame.copy()
    for nan_pos in nan_positions:
        # distances from this NaN pixel to all valid pixels
        distances = np.linalg.norm(valid_positions - nan_pos, axis=1)
        nearest = valid_positions[np.argsort(distances)[:n], :]
        # values are always read from the input so that already filled pixels
        # do not influence later ones
        filled[nan_pos[0], nan_pos[1]] = np.mean(frame[nearest[:, 0], nearest[:, 1]])

    return filled

In [None]:
def computeMetrics(mc, bord_px_els, swap_dim, winsize, resize_fact_flow):
    """
    Compute the registration quality metrics for the original, the rigid
    corrected and the pw-rigid corrected movie of mc.

    mc ... motion correction object (fname, fname_tot_rig, fname_tot_els and
           total_template_els are read)
    bord_px_els ... number of pixels subtracted from the template size
    swap_dim, winsize, resize_fact_flow ... passed on to
           compute_metrics_motion_correction

    Return dict with corr_*, flows_*, crispness_* and norms_* for orig / rig /
    els plus tmpl_rig and tmpl_els
    """
    # template size without the boundary pixels
    final_size = np.subtract(mc.total_template_els.shape, bord_px_els)

    def metrics_for(movie_fname):
        # same settings for all three movies, only the file differs
        return cm.motion_correction.compute_metrics_motion_correction(
            movie_fname, final_size[0], final_size[1], swap_dim,
            winsize=winsize, play_flow=False, resize_fact_flow=resize_fact_flow)

    # the template of the uncorrected movie is not part of the result
    _, corr_orig, flows_orig, norms_orig, crispness_orig = metrics_for(mc.fname[0])
    tmpl_rig, corr_rig, flows_rig, norms_rig, crispness_rig = metrics_for(mc.fname_tot_rig[0])
    tmpl_els, corr_els, flows_els, norms_els, crispness_els = metrics_for(mc.fname_tot_els[0])

    return {
        'tmpl_rig': tmpl_rig,
        'corr_orig': corr_orig,
        'flows_orig': flows_orig,
        'crispness_orig': crispness_orig,
        'norms_orig': norms_orig,

        'corr_rig': corr_rig,
        'flows_rig': flows_rig,
        'crispness_rig': crispness_rig,
        'norms_rig': norms_rig,

        'tmpl_els': tmpl_els,
        'corr_els': corr_els,
        'flows_els': flows_els,
        'crispness_els': crispness_els,
        'norms_els': norms_els,
    }

In [None]:
def removeBoundaryPixels(movie, mc, mc_type):
    """
    Remove the boundary pixels corresponding to the max. shift of the registration.

    movie ... movie array, cropped along axes 1 and 2
    mc ... motion correction object; the shifts and the total template of the
           selected registration are read and the template is replaced by its
           cropped version
    mc_type ... 'els' (pw-rigid) or 'rig' (rigid)

    Return cropped movie and mc. Raise ValueError for any other mc_type.
    """
    # largest absolute shift and template of the selected registration
    if mc_type == 'els':
        max_shift = np.maximum(np.max(np.abs(mc.x_shifts_els)), np.max(np.abs(mc.y_shifts_els)))
        template = mc.total_template_els
    elif mc_type == 'rig':
        # absolute value as for 'els': negative shifts need a border as well
        max_shift = np.max(np.abs(mc.shifts_rig))
        template = mc.total_template_rig
    else:
        raise ValueError("mc_type must be 'els' or 'rig', got %r" % (mc_type,))

    # builtin int instead of np.int, which is not available in current NumPy
    bord_px = int(np.ceil(max_shift))
    final_size = np.subtract(template.shape, bord_px)

    # remove pixels in the boundaries
    final_size_x = final_size[0]
    final_size_y = final_size[1]
    max_shft_x = int(np.ceil((np.shape(movie)[1] - final_size_x) / 2))
    max_shft_y = int(np.ceil((np.shape(movie)[2] - final_size_y) / 2))
    max_shft_x_1 = int(-((np.shape(movie)[1] - max_shft_x) - final_size_x))
    max_shft_y_1 = int(-((np.shape(movie)[2] - max_shft_y) - final_size_y))
    # a stop index of 0 would give an empty slice; None means "up to the end"
    if max_shft_x_1 == 0:
        max_shft_x_1 = None
    if max_shft_y_1 == 0:
        max_shft_y_1 = None

    crop_x = slice(max_shft_x, max_shft_x_1)
    crop_y = slice(max_shft_y, max_shft_y_1)
    movie = movie[:, crop_x, crop_y]

    # crop the template that belongs to the selected registration
    if mc_type == 'els':
        mc.total_template_els = template[crop_x, crop_y]
    else:
        mc.total_template_rig = template[crop_x, crop_y]

    return movie, mc

Now we are ready to run motion correction for the joined TIF file. If there are a lot of concatenated trials, this might take a while to complete.

The following outputs will be saved:
- result of rigid motion correction in Python mmap format and as TIF file
- result of pw-rigid motion correction in Python mmap format and as TIF file

In [None]:
def runMotionCorrection(fname, params, interp_nans=False):
    """
    Run rigid and pw-rigid motion correction for single input file fname using
    parameters in params.

    fname ... file name of the movie
    params ... dict with the motion correction parameters
    interp_nans ... if True, NaN pixels of the pw-rigid corrected movie are
                    interpolated frame by frame before saving (default False)

    The pw-rigid result (boundary pixels removed) is saved as TIF and as mmap
    file, the rigid result additionally as TIF file.

    Return motion correction object mc
    """
    # create mc object
    mc = setupMC(fname, params)

    # apply rigid correction
    mc.motion_correct_rigid(save_movie=True)

    # apply pw-rigid correction
    mc.motion_correct_pwrigid(save_movie=True, template=mc.total_template_rig, show_template = False)

    # load corrected movie - els
    mov_els = cm.load(mc.fname_tot_els[0])
    # remove boundary pixels
    mov_els, mc = removeBoundaryPixels(mov_els, mc, 'els')

    # interpolate NaNs in a copy; without interpolation the movie is saved as
    # it is and no second copy is held in memory
    if interp_nans:
        mov_els_out = mov_els.copy()
        for ix in range(mov_els.shape[0]):
            mov_els_out[ix,:,:] = interpolateNans(mov_els[ix,:,:])
    else:
        mov_els_out = mov_els

    # save pw-rigid corrected movies as TIF (temporary name with time stamp)
    dummy_fname = 'dummy_%s.tif' % (datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
    dummy_fname = os.path.join(data_folder, dummy_fname)
    imsave(dummy_fname, mov_els_out)

    # save pw-rigid corrected (and interpolated) movie as mmap; the base name
    # is the rigid file name up to '_rig_' with '_els_' appended
    base_name = mc.fname_tot_rig[0][:mc.fname_tot_rig[0].find('_rig_')] + '_els_'
    fname_new = cm.save_memmap([dummy_fname], base_name=base_name, order='F')

    # remove previous mmap file and rename TIF file
    os.remove(mc.fname_tot_els[0])
    mc.fname_tot_els = [fname_new]
    os.rename(dummy_fname, fname_new.replace('.mmap', '.tif'))

    # save rigid corrected movies as TIF
    imsave(mc.fname_tot_rig[0].replace('.mmap','.tif'), cm.load(mc.fname_tot_rig[0]))

    return mc

### Correlation with template - uncorrected data
First, we just get the correlation with the template for the uncorrected data. This can be used to re-assign groups, for example. This part can also be skipped.

In [None]:
# Correlation of every frame with a binned median template for the
# uncorrected movie of each group.
mc_init_list = []
corr_orig_list = []
# keyword arguments of the metric computation, identical for all groups
metric_kwargs = dict(winsize=winsize, play_flow=False, resize_fact_flow=0.01)

t_start = time.time()
for i_file in joined_tif_list:
    fname = os.path.join(data_folder, i_file)
    mc_init = setupMC(fname, params)

    # compute initial template by binned median filtering
    template = cm.load(fname).bin_median(window=10)

    # compute metric for orig (only the correlations are kept)
    tmpl, corr_orig, flows_orig, norms_orig, crispness_orig = \
        cm.motion_correction.compute_metrics_motion_correction(
            mc_init.fname[0], template.shape[0], template.shape[1],
            swap_dim, **metric_kwargs)

    mc_init_list.append(mc_init)
    corr_orig_list.append(corr_orig)

    clear_output()

# print elapsed time
t_elapsed = time.time() - t_start
print('\nFinished pre-MC in %1.2f s (%1.2e s per frame)' % (t_elapsed, t_elapsed/sum(total_frames_list)))

In [None]:
# select group (0, 1, ...)
group_ix = 0

# trial names: source file paths without the data folder and '_crop.tif'
path_prefix = data_folder + os.path.sep
trial_names = [tiff_path.replace(path_prefix, '').replace('_crop.tif', '')
               for tiff_path in tiff_files_by_group[group_ix]]
# one trial name per frame, looked up through the trial index of the frame
trial_names_frames = [trial_names[trial_ix] for trial_ix in trial_indices_list[group_ix]]

group_corr = corr_orig_list[group_ix]
frames = np.array(range(len(group_corr)))
data = dict(x=list(range(len(group_corr))),
            y=group_corr,
            trial_idx=trial_indices_list[group_ix],
            trial_name=trial_names_frames)
data_source = ColumnDataSource(data)

plot_title = 'Correlation with template - Group %d' % (group_ix)
p = Figure(plot_width=900, plot_height=300, title=plot_title)
p.add_tools(CrosshairTool(), utils.getHover())
p.line('x', 'y', source=data_source, line_width=2, legend='Original', color='blue')

show(p)

### Run the motion correction
Next, we run the MC itself. This can be time consuming for large datasets.

In [None]:
# Run the motion correction for the joined TIF file of every group and keep
# the resulting motion correction objects in mc_list.
t_start = time.time()
mc_list = []
for i_file in joined_tif_list:
    fname = os.path.join(data_folder, i_file)
    mc = runMotionCorrection(fname, params)
    mc_list.append(mc)

clear_output()

# print elapsed time
t_elapsed = time.time() - t_start
n_frames_total = sum(total_frames_list)
print('\nFinished MC in %1.2f s (%1.2f s per frame)' % (t_elapsed, t_elapsed/n_frames_total))

# spoken notification on macOS
if platform.system() == 'Darwin':
    os.system('say "your program has finished"')

### Assess quality of motion correction
A number of key metrics can be calculated to assess how much motion correction improved overall motion. 
1. Correlation
Correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction. The mean correlation gives an overall impression of motion. The minimum correlation indicates the parts of the movie that are worst affected by motion. Larger correlations indicate less motion.
2. Crispness
Crispness provides a measure of the smoothness of the corrected average image. Intuitively, a dataset with nonregistered motion will have a blurred mean image, resulting in a lower value for the total gradient field norm. Thus, larger values indicate a crisper average image and less residual motion. Crispness is calculated from the gradient field of the mean image (`np.gradient`).
3. Residual optical flow
Optic flow algorithms attempt to match each frame to the template by estimating locally smooth displacement fields. The output is an image where each pixel describes the local displacement between template and frame at this point. The smaller the local displacement, the better the registration. Here we compute the matrix norm of the optic flow matrix as a summary statistic.

In [None]:
# compute quality assessment metrics
crispness = []
norms = []
corr_mean = []
corr_min = []
metrics = []
for mc in mc_list:
    bord_px = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)), np.max(np.abs(mc.y_shifts_els)))).astype(np.int)
    mtrs = computeMetrics(mc, bord_px, swap_dim, winsize, resize_fact_flow)
    metrics.append(mtrs)
    # correlations, crispness and norms of residual optic flow as indicators of registration quality
    crispness.append(np.array([mtrs['crispness_orig'], mtrs['crispness_rig'], mtrs['crispness_els']]))
    norms.append(np.array([np.mean(mtrs['norms_orig']), np.mean(mtrs['norms_rig']), np.mean(mtrs['norms_els'])]))
    corr_mean.append(np.array([np.mean(mtrs['corr_orig']), np.mean(mtrs['corr_rig']), np.mean(mtrs['corr_els'])]))
    corr_min.append(np.array([np.min(mtrs['corr_orig']), np.min(mtrs['corr_rig']), np.min(mtrs['corr_els'])]))

clear_output()

Print different metrics for raw movie and rigid / pw-rigid corrected movies.

In [None]:
def _print_metric(label, values, fmt, highlight):
    """
    Print one line with the raw / rigid / pw_rigid values of a metric.

    label ... name of the metric
    values ... the three values (raw, rigid, pw_rigid)
    fmt ... format string applied to every value
    highlight ... if True the line is wrapped in ANSI codes (bold italic red)
    """
    line = label + ' - raw / rigid / pw_rigid: ' + str([fmt.format(i) for i in values])
    if highlight:
        line = '\x1b[1;03;31m' + line + '\x1b[0m'
    print(line)


# A line is highlighted when the raw movie scores better than one of the
# corrected movies.
for ix in range(len(mc_list)):
    print('MC evaluation - Group %d:' % (ix))
    # correlations and crispness: larger is better
    for label, values, fmt in (('Mean corr', corr_mean[ix], '{:.2f}'),
                               ('Min corr', corr_min[ix], '{:.2f}'),
                               ('Crispness', crispness[ix], '{:.0f}')):
        _print_metric(label, values, fmt, values[0] > values[1] or values[0] > values[2])
    # norms of the residual optic flow: smaller is better; two decimals in
    # both cases (the highlighted line used to be rounded to integers)
    _print_metric('Norms', norms[ix], '{:.2f}',
                  norms[ix][0] < norms[ix][1] or norms[ix][0] < norms[ix][2])

    print('\n')

#### Correlation with template image
Plot correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction. The bokeh plotting library provides a toolbar for interaction with the plot.

In [None]:
# select group (0, 1, ...)
group_ix = 0
group_metrics = metrics[group_ix]

# correlation of every frame with the template for the three movies
p1 = Figure(plot_width=900, plot_height=300, title=('Correlation with template - Group %d' % (group_ix)))
frames = np.array(range(len(group_metrics['corr_orig'])))
for corr_key, corr_label, corr_color in (('corr_orig', 'Original', 'blue'),
                                         ('corr_rig', 'Rigid', 'orange'),
                                         ('corr_els', 'PW-Rigid', 'green')):
    p1.line(frames, np.array(group_metrics[corr_key]), line_width=2, legend=corr_label, color=corr_color)

# scatter plot original vs. rigid with dashed identity line
p2 = Figure(plot_width=250, plot_height=250)
p2.circle(np.array(group_metrics['corr_orig']), np.array(group_metrics['corr_rig']), size=5)
p2.line([0,1],[0,1], line_width=1, color='black', line_dash='dashed')
p2.xaxis.axis_label = 'Original'
p2.yaxis.axis_label = 'Rigid'

# scatter plot rigid vs. pw-rigid with dashed identity line
p3 = Figure(plot_width=250, plot_height=250)
p3.circle(np.array(group_metrics['corr_rig']), np.array(group_metrics['corr_els']), size=5)
p3.line([0,1],[0,1], line_width=1, color='black', line_dash='dashed')
p3.xaxis.axis_label = 'Rigid'
p3.yaxis.axis_label = 'PW-Rigid'

# make a grid
grid = gridplot([[p1, None], [p2, p3]], sizing_mode='fixed', toolbar_location='left')

show(grid)

#### Residual optic flow
Optic flow algorithms attempt to match each frame to the template by estimating locally smooth displacement fields. The output is an image where each pixel describes the local displacement between template and frame at this point. The smaller the local displacement, the better the registration. Norms are the matrix norms of the optic flow matrix.

In [None]:
# select group (0, 1, ...)
group_ix = 0

# plot the results of Residual Optical Flow
# All files are taken from the motion correction object of the selected group
# (the mean images used to be loaded from the loop variable `mc` left over
# from an earlier cell, i.e. not necessarily from the selected group).
mc_group = mc_list[group_ix]
movie_files = [('pw_rigid', mc_group.fname_tot_els[0]),
               ('rigid', mc_group.fname_tot_rig[0]),
               ('raw', mc_group.fname[0])]
# metrics file name: movie file name without its last 4 characters + '_metrics.npz'
metrics_files = [movie_file[:-4] + '_metrics.npz' for _, movie_file in movie_files]

# one row per movie: mean image, correlation image, mean optical flow
plt.figure(figsize = (20,10))
for cnt, ((metr, movie_file), fl) in enumerate(zip(movie_files, metrics_files)):
    with np.load(fl) as ld:
        print('Correction: %s' % (metr))
        print('Norms: %1.2f +- %1.2f' % (np.mean(ld['norms']), np.std(ld['norms'])))

        plt.subplot(len(metrics_files), 3, 1 + 3 * cnt)
        plt.ylabel(metr)

        mean_img = np.mean(cm.load(movie_file), axis=0)

        # gray scale limits: 0.5 and 99.5 percentile, NaNs ignored
        lq, hq = np.nanpercentile(mean_img, [.5, 99.5])
        plt.imshow(mean_img, vmin=lq, vmax=hq, cmap='gray')
        if not cnt:
            plt.title('Mean')
        plt.subplot(len(metrics_files), 3, 3 * cnt + 2)
        plt.imshow(ld['img_corr'], vmin=0, vmax=.5, cmap='gray')
        if not cnt:
            plt.title('Correlation image')
        plt.subplot(len(metrics_files), 3, 3 * cnt + 3)
        flows = ld['flows']
        # magnitude of the two flow components, averaged over the first axis
        plt.imshow(np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0), vmin=0, vmax=0.5, cmap='gray')
        plt.colorbar()
        if not cnt:
            plt.title('Mean optical flow')
plt.suptitle('Residual Optic Flow - Group %d' % (group_ix), fontsize=22);

### Detect frames with bad motion
Identify frames with significant residual motion (low correlation with template). Write a JSON file with criterion and indices of frames matching the criterion. This file can be used in further analysis to exclude the frames corrupted by motion.

In [None]:
def writeJsonBadFrames(criterion, thresh, frame_ix, mc, mc_type, data_folder):
    """
    Write a JSON file with the criterion, the threshold and the indices of the
    frames matching the criterion.

    criterion ... name of the metric used to select the frames
    thresh ... threshold applied to the metric
    frame_ix ... list with the indices of the selected frames
    mc ... motion correction object (fname_tot_els / fname_tot_rig is read)
    mc_type ... 'els' (pw-rigid) or 'rig' (rigid)
    data_folder ... folder the file name is joined with

    The file name is the mmap file name of the selected registration with
    '.mmap' removed and 'badFrames.json' appended.

    Raise ValueError for any other mc_type (before anything is written).
    """
    if mc_type == 'els':
        mmap_fname = mc.fname_tot_els[0]
    elif mc_type == 'rig':
        mmap_fname = mc.fname_tot_rig[0]
    else:
        raise ValueError("mc_type must be 'els' or 'rig', got %r" % (mc_type,))

    # NOTE(review): the key "threshold:" ends with a colon, which looks like a
    # typo; it is kept because code reading these files may rely on it.
    exclude_info = {"criterion": criterion,
        "threshold:": thresh,
        "frames": frame_ix}
    json_fname = mmap_fname.replace('.mmap','') + 'badFrames' + '.json'
    with open(os.path.join(data_folder, json_fname), 'w') as fid:
        json.dump(exclude_info, fid)
    print('Created JSON metadata file %s' % (json_fname))

In [None]:
thresh = [0.4] # find frames where value is less than criterion (one value per group)

# criterion, registration type and message for both registrations
# (pw-rigid first, then rigid)
registrations = (
    ('corr_els', 'els', '%1.0f frames matching criterion after pw-rigid registration.'),
    ('corr_rig', 'rig', '\n%1.0f frames matching criterion after rigid registration.'),
)

for i_thr, group_thresh in enumerate(thresh):
    print('Group %d' % (i_thr))
    for criterion, mc_type, message in registrations:
        # indices of the frames with a value below the threshold of the group
        bad_frames = [frame_ix for frame_ix, value in enumerate(metrics[i_thr][criterion])
                      if value < group_thresh]
        print(message % (len(bad_frames)))
        writeJsonBadFrames(criterion, group_thresh, bad_frames, mc_list[i_thr], mc_type, data_folder)
    print('\n')