## Preprocess and motion correct movies
Step 1 of the Caiman processing pipeline for dendritic two-photon calcium imaging movies.

### Imports & Setup
The first cells import the various Python modules required by the notebook. In particular, a number of modules are imported from the Caiman package. In addition, we set up the environment (Python path and thread settings) so that everything works as expected.

In [None]:
# Generic imports
# from __future__ import absolute_import, division, print_function
# from builtins import *
from __future__ import print_function

import os, sys, glob, platform, shutil
import json
import time, datetime
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
import skimage.transform
from scipy import interpolate
from tifffile import imread, imsave
from IPython.display import clear_output

%matplotlib inline

In [None]:
# on Linux, CaImAn is used from the ~/caiman folder, so make that folder importable
if platform.system() == 'Linux':
    sys.path.append(os.path.expanduser('~/caiman'))
# environment variables for parallel processing:
# MKL, OpenBLAS and vecLib are each restricted to a single thread
os.environ.update(dict.fromkeys(
    ('MKL_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'VECLIB_MAXIMUM_THREADS'), '1'))

In [None]:
# CaImAn imports
import caiman as cm
from caiman.motion_correction import MotionCorrect

### Specify experimental parameters

**Update!**

Here we specify where the files are located, how the files are called, the frame rate of the acquisition, how to crop the movies and how many files to process.
- data_folder ... the root folder of the data (e.g. on the volume `Data`); it is chosen per platform (Linux / macOS) and extended with the sub-folders below.
- animal_folder, date_folder, session_folder ... sub-folders that identify the session to process.
- max_files ... maximum number of files to process per session, e.g. for testing (if 0, all files will be processed)
- ext ... the extension of the TIF files (e.g. .tif)
- frame_rate ... the acquisition frame rate in Hz
- crop_pixel_xy ... crop movies by the specified number of pixels in x and y (e.g. to remove artifacts)

In [None]:
animal_folder = 'M1_for_processing'
date_folder = 'M1_2018-01-31'
session_folder = 'S1'
max_files = 50 # how many files to process per session (0 for all)
ext = '.tif'
frame_rate = 13.1316 # in Hz
crop_pixel_xy = (25,0) # crop movies by specified number of pixels in x and y

# root data folder of each supported machine; other platforms must be added here
# explicitly (otherwise data_folder would be undefined and fail with a NameError below)
if platform.system() == 'Linux':
    data_folder = '/home/ubuntu/Data/Henry_test'
elif platform.system() == 'Darwin':
    data_folder = '/Users/Henry/polybox/Data_temp/Dendrites_Gwen'
else:
    raise RuntimeError('No data folder configured for platform %s' % platform.system())
# create the complete path to the data folder
data_folder = os.path.join(data_folder, animal_folder, date_folder, session_folder)

In [None]:
# collect the session's TIF files in sorted order, skipping '_crop' outputs of earlier runs
tiff_files = [os.path.join(data_folder, name)
              for name in sorted(os.listdir(data_folder))
              if name.endswith(ext) and not name.endswith('_crop' + ext)]
# optionally keep only the first max_files files (0 keeps all of them)
if max_files:
    tiff_files = tiff_files[:max_files]
print('Selected %1.0f TIF files' % len(tiff_files))

### Load TIF files, crop and re-save

In [None]:
def cropTif(fname, crop_pixel):
    """
    Crop a multi-frame TIF file and save the result next to it as *_crop.tif.

    fname ... path of the input TIF file
    crop_pixel ... (x, y) number of pixels removed from the start of the last axis (x)
                   and of the middle axis (y) of the movie
    Files with fewer than 3 dimensions (single-page TIFs) are not cropped or saved.

    returns is_movie (True if the file was a movie and has been cropped and saved)
    """
    stack = imread(fname)
    if len(stack.shape) < 3:
        # not a movie - nothing to crop
        return False
    pixels_x, pixels_y = crop_pixel[0], crop_pixel[1]
    # cut pixels at the edges (e.g. due to artifacts) and write the cropped stack
    imsave(fname.replace('.tif', '_crop.tif'), stack[:, pixels_y:, pixels_x:])
    return True

In [None]:
def getFramesTif(fname):
    """
    Count the frames of a TIF file.

    fname ... path of the TIF file
    Returns the length of the first axis for multi-frame TIFs and 0 for
    single-page TIFs (fewer than 3 dimensions).
    """
    shape = imread(fname).shape
    return shape[0] if len(shape) >= 3 else 0

### Setup cluster
The default backend mode for parallel processing is through the multiprocessing package. This allows us to use several cores of the VM (`n_processes` sets how many). Note that the `cropTif` and `getFramesTif` functions have to be defined before starting the cluster, so that the pool workers have them available.

In [None]:
n_processes = 8 # number of worker processes (None lets Caiman pick the number)
# a cluster left over from an earlier run of this cell is shut down first
if 'dview' in locals():
    dview.terminate()
# start a local multiprocessing cluster
c, dview, n_processes = cm.cluster.setup_cluster(
    single_thread=False, n_processes=n_processes, backend='local')

Then we call the function through the multiprocessing `map` method to make use of multiple cores.

In [None]:
# crop every TIF on the cluster; each entry is cropTif's is_movie flag for that file
is_movie = dview.map(
    partial(cropTif, crop_pixel=crop_pixel_xy),
    tiff_files,
)

Then, create the list of cropped TIF files for motion correction, excluding files that are not movies.

In [None]:
# cropped counterparts of the files that turned out to be movies
tiff_files_crop = [source.replace('.tif', '_crop.tif')
                   for source, movie_flag in zip(tiff_files, is_movie) if movie_flag]
print('Processing %1.0f files:' % (len(tiff_files_crop)))
# show at most the first ten file names
print('\n'.join(tiff_files_crop[:10]))
if tiff_files_crop[10:]:
    print('...')

### Join cropped TIF files
Next, we create a large joined TIF file from individual cropped files. Further processing will be done on the joined file.

In [None]:
# load movies (all cropped files; total_frames is the length of the first axis)
movies = cm.load(tiff_files_crop)
total_frames = movies.shape[0]
dims = (movies.shape[1], movies.shape[2])
# derive joined file name and save
joined_tif = '%s_%s_Join_%1.0f_crop.tif' % (date_folder, session_folder, total_frames)
imsave(os.path.join(data_folder, joined_tif), movies)
print('Saved joined TIF file %s' % (joined_tif))
# frames per source movie, counted on the cropped files so the list has the same
# length and order as tiff_files_crop (non-movie TIFs are excluded from both);
# cropping only removes pixels, so the frame counts equal those of the originals
frames_per_movie = dview.map(getFramesTif, tiff_files_crop)
movies = None # free the memory

In [None]:
# JSON file next to the joined TIF: which cropped files it is made of and their frame counts
json_fname = joined_tif.replace('.tif', '.json')
meta = {
    "joined_file": joined_tif,
    "source_frames": frames_per_movie,
    "source_file": [crop_file.replace(data_folder + os.path.sep, '') for crop_file in tiff_files_crop],
}
with open(os.path.join(data_folder, json_fname), 'w') as fid:
    json.dump(meta, fid)
print('Created JSON metadata file %s' % (json_fname))

### Motion correction

First, setup the parameters for motion correction. The following parameters influence the **quality** of the motion correction:
- niter_rig ... number of iterations for rigid registration (larger = better). Little improvement likely above 5-10.
- strides ... intervals at which patches are laid out for motion correction (smaller = better)
- overlaps ... overlap between patches

Note that smaller values for strides / overlap will improve registration but also lead to NaNs in the output image. In general, there is a trade-off between the quality of registration and the presence / number of NaNs in the output (at least if there is significant motion).

In [None]:
# parameters for motion correction
params = dict(
    # iterations of rigid registration
    niter_rig=5,
    # maximum allowed rigid shift: one tenth of the frame size along each axis (rounded)
    max_shifts=(int(np.round(dims[0] / 10)), int(np.round(dims[1] / 10))),
    # if None, all the splits are processed and the movie is saved
    num_splits_to_process_rig=None,
    # intervals at which patches are laid out for motion correction
    strides=(24, 24),
    # overlap between patches (patch size = strides + overlaps)
    overlaps=(24, 24),
    # if None, all the splits are processed and the movie is saved
    num_splits_to_process_els=[28, None],
    # upsample factor to avoid smearing when merging patches
    upsample_factor_grid=4,
    # maximum deviation allowed for a patch with respect to the rigid shift
    max_deviation_rigid=3,
)

There are also some parameters for computing the quality metrics. These probably don't have to be changed.

In [None]:
# parameters for computing metrics (window size, dimension swap, downsampling for computing ROF)
winsize, swap_dim, resize_fact_flow = 100, False, 1

Next, we define some functions. See the function doc strings for further information.

In [None]:
def setupMC(fname, params):
    """
    Build a MotionCorrect object for one movie file.

    fname ... path of the movie to register
    params ... dict with the motion-correction settings (keys as in the params cell)

    The movie is loaded once to obtain its minimum value, which is passed as the
    second positional argument of MotionCorrect; the global cluster view `dview`
    is handed over as well.

    Return the configured MotionCorrect object.
    """
    movie_min = cm.load(fname).min()
    # these entries of params are forwarded as keyword arguments of the same name
    forwarded = ('max_shifts', 'niter_rig', 'num_splits_to_process_rig', 'strides',
                 'overlaps', 'num_splits_to_process_els', 'upsample_factor_grid',
                 'max_deviation_rigid')
    mc_kwargs = {key: params[key] for key in forwarded}
    return MotionCorrect(fname, movie_min, dview=dview,
                         shifts_opencv=True, nonneg_movie=False, **mc_kwargs)

In [None]:
def interpolateNans(frame, n=10):
    """
    Interpolate NaN values in frame with the average of the n closest non-NaN pixels.

    frame ... 2D image, indexed [row, column]
    n ... number of nearest valid pixels to average (all valid pixels are used
          if fewer than n exist)

    Returns frame itself if it contains no NaNs, otherwise an interpolated copy
    (the input is not modified). If every pixel is NaN there is nothing to
    interpolate from, and a copy with the NaNs left in place is returned.
    """
    nan_mask = np.isnan(frame)
    if not nan_mask.any():
        return frame
    frame_interp = frame.copy()
    # (row, col) of all valid pixels and their values, both in row-major order
    valid_pixel = np.argwhere(~nan_mask)
    if not len(valid_pixel):
        # all-NaN frame: avoid averaging empty selections
        return frame_interp
    valid_vals = frame[~nan_mask]
    k = min(n, len(valid_pixel))
    for pix in np.argwhere(nan_mask):
        # distance between NaN pixel and all valid pixels
        dist = np.linalg.norm(valid_pixel - pix, axis=1)
        if k < len(dist):
            # partial selection of the k closest pixels instead of a full sort
            closest = np.argpartition(dist, k - 1)[:k]
        else:
            closest = slice(None)
        # replace NaN with average
        frame_interp[pix[0], pix[1]] = np.mean(valid_vals[closest])
    return frame_interp

In [None]:
def computeMetrics(mc, bord_px_els, swap_dim, winsize, resize_fact_flow):
    """
    Compute the registration quality metrics for the raw, rigid and pw-rigid movies.

    mc ... MotionCorrect object after rigid and pw-rigid correction
    bord_px_els ... number of pixels subtracted from each dimension of
                    mc.total_template_els to obtain the evaluated size
    swap_dim, winsize, resize_fact_flow ... forwarded to
                    cm.motion_correction.compute_metrics_motion_correction

    Returns a dict with per-stage correlations, flows, norms and crispness
    (suffixes _orig, _rig, _els) plus the templates of the rigid and pw-rigid
    movies ('tmpl_rig', 'tmpl_els'); the raw movie's template is not kept.
    """
    # remove pixels in the boundaries
    size_xy = np.subtract(mc.total_template_els.shape, bord_px_els)

    def run_metrics(movie_fname):
        # one call of Caiman's metric computation with the shared settings
        return cm.motion_correction.compute_metrics_motion_correction(
            movie_fname, size_xy[0], size_xy[1], swap_dim,
            winsize=winsize, play_flow=False, resize_fact_flow=resize_fact_flow)

    # evaluated in the order raw -> rigid -> pw-rigid
    _, corr_o, flows_o, norms_o, crisp_o = run_metrics(mc.fname[0])
    tmpl_r, corr_r, flows_r, norms_r, crisp_r = run_metrics(mc.fname_tot_rig[0])
    tmpl_e, corr_e, flows_e, norms_e, crisp_e = run_metrics(mc.fname_tot_els[0])

    return {
        'tmpl_rig': tmpl_r,
        'corr_orig': corr_o,
        'flows_orig': flows_o,
        'crispness_orig': crisp_o,
        'norms_orig': norms_o,

        'corr_rig': corr_r,
        'flows_rig': flows_r,
        'crispness_rig': crisp_r,
        'norms_rig': norms_r,

        'tmpl_els': tmpl_e,
        'corr_els': corr_e,
        'flows_els': flows_e,
        'crispness_els': crisp_e,
        'norms_els': norms_e,
    }

In [None]:
def removeBoundaryPixels(movie, mc):
    """
    Remove the boundary pixels corresponding to the max. shift of the pw-rigid registration.

    movie ... registered movie indexed [frame, x, y]; its x / y size is expected to
              match mc.total_template_els
    mc ... MotionCorrect object after pw-rigid correction (x_shifts_els, y_shifts_els
           and total_template_els are read)

    The border width b is the largest absolute pw-rigid shift, rounded up. With
    matching sizes, ceil(b/2) pixels are removed at the start and floor(b/2) at the
    end of each spatial axis. mc.total_template_els is cropped the same way
    (mc is modified in place).

    Returns (cropped movie, mc).
    """
    # compute borders to exclude (builtin int: the np.int alias no longer exists in NumPy >= 1.24)
    bord_px_els = int(np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                         np.max(np.abs(mc.y_shifts_els)))))
    # size of the field of view that remains after removing the border
    final_size = np.subtract(mc.total_template_els.shape, bord_px_els)
    final_size_x = final_size[0]
    final_size_y = final_size[1]
    # spatial size of the movie passed in
    size_x = np.shape(movie)[1]
    size_y = np.shape(movie)[2]
    max_shft_x = int(np.ceil((size_x - final_size_x) / 2))
    max_shft_y = int(np.ceil((size_y - final_size_y) / 2))
    max_shft_x_1 = - ((size_x - max_shft_x) - (final_size_x))
    max_shft_y_1 = - ((size_y - max_shft_y) - (final_size_y))
    # an end index of 0 would produce an empty slice; None keeps everything to the end
    if max_shft_x_1 == 0:
        max_shft_x_1 = None
    if max_shft_y_1 == 0:
        max_shft_y_1 = None

    movie = movie[:, max_shft_x:max_shft_x_1, max_shft_y:max_shft_y_1]
    mc.total_template_els = mc.total_template_els[max_shft_x:max_shft_x_1, max_shft_y:max_shft_y_1]

    return movie, mc

Now we are ready to run motion correction for the joined TIF file. If there are a lot of concatenated trials, this might take a while to complete.

The following outputs will be saved:
- result of rigid motion correction in Python mmap format and as TIF file
- result of pw-rigid motion correction in Python mmap format and as TIF file

In [None]:
# Motion-correct the joined TIF: rigid, then pw-rigid registration; crop the border of the
# pw-rigid result, fill its NaNs, and save the corrected movies as mmap and TIF files.
t_start = time.time()
fname = os.path.join(data_folder, joined_tif)
# create mc object
mc = setupMC(fname, params)
# compute initial template by binned median filtering
# this template will be refined during the registration process
template = cm.load(fname).bin_median(window=10)

# apply rigid correction
mc.motion_correct_rigid(save_movie=True, template=template)
# apply piecewise rigid correction, starting from the template of the rigid step
mc.motion_correct_pwrigid(save_movie=True, template=mc.total_template_rig, show_template = False)

# load corrected movie - els
mov_els = cm.load(mc.fname_tot_els[0])

# crop the border given by the largest pw-rigid shift (mc.total_template_els is cropped too)
mov_els, mc = removeBoundaryPixels(mov_els, mc)

# interpolate NaNs frame by frame (each NaN -> mean of the nearest valid pixels)
mov_els_copy = mov_els.copy()
for ix in range(mov_els.shape[0]):
    mov_els_copy[ix,:,:] = interpolateNans(mov_els[ix,:,:])

# save pw-rigid corrected movies as TIF
# (temporary timestamped name; the file is renamed after the mmap below has been created)
dummy_fname = 'dummy_%s.tif' % (datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
dummy_fname = os.path.join(data_folder, dummy_fname)
imsave(dummy_fname, mov_els_copy)

# save pw-rigid corrected and interpolated movie as mmap (converted from the temporary TIF)
base_name = os.path.join(data_folder, '%s_%s_Join_crop_els_' % (date_folder, session_folder))
fname_new = cm.save_memmap([dummy_fname], base_name=base_name, order='F')

# remove previous mmap file and rename TIF file
# from here on mc.fname_tot_els points at the cropped / interpolated mmap, and the
# TIF gets the same name as that mmap with a .tif extension
os.remove(mc.fname_tot_els[0])
mc.fname_tot_els = [fname_new]
os.rename(dummy_fname, fname_new.replace('.mmap', '.tif'))

# clear memory
mov_els = None
mov_els_copy = None

# save rigid corrected movies as TIF
imsave(mc.fname_tot_rig[0].replace('.mmap','.tif'), cm.load(mc.fname_tot_rig[0]))

# clears the cell output printed so far, so only the timing summary remains
clear_output()

# print elapsed time
t_elapsed = time.time() - t_start
print('\nFinished MC in %1.2f s (%1.2f s per frame)' % (t_elapsed, t_elapsed/total_frames))

# audible notification on macOS
if platform.system() == 'Darwin':
    os.system('say "your program has finished"')

### Assess quality of motion correction
A number of key metrics can be calculated to assess how much motion correction reduced the overall motion.
1. Correlation
Correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction. The mean correlation gives an overall impression of motion. The minimum correlation indicates the parts of the movie that are worst affected by motion. Larger correlations indicate less motion.
2. Crispness
Crispness provides a measure of the smoothness of the corrected average image. Intuitively, a dataset with nonregistered motion will have a blurred mean image, resulting in a lower value for the total gradient field norm. Thus, larger values indicate a crisper average image and less residual motion. Crispness is calculated from the gradient field of the mean image (`np.gradient`).
3. Residual optical flow
Optic flow algorithms attempt to match each frame to the template by estimating locally smooth displacement fields. The output is an image where each pixel describes the local displacement between template and frame at this point. The smaller the local displacement, the better the registration. Here we compute the matrix norm of the optic flow matrix as a summary statistic.

In [None]:
# compute quality assessment metrics (bord_px_els = 0: no extra border is subtracted)
mtrs = computeMetrics(mc, 0, swap_dim, winsize, resize_fact_flow)
# one value per stage (raw / rigid / pw-rigid): crispness, mean residual-flow norm,
# and mean / minimum correlation with the template
crispness_metric = np.array([mtrs[key] for key in ('crispness_orig', 'crispness_rig', 'crispness_els')])
norms_metric = np.array([np.mean(mtrs[key]) for key in ('norms_orig', 'norms_rig', 'norms_els')])
corr_mean_metric = np.array([np.mean(mtrs[key]) for key in ('corr_orig', 'corr_rig', 'corr_els')])
corr_min_metric = np.array([np.min(mtrs[key]) for key in ('corr_orig', 'corr_rig', 'corr_els')])

clear_output()

Print different metrics for raw movie and rigid / pw-rigid corrected movies.

In [None]:
print('MC evaluation:')
# Each metric is printed for raw / rigid / pw-rigid. A line is highlighted in red when
# either correction made the metric worse than the raw movie: lower correlation or
# crispness, or a higher residual optic-flow norm. The number format is the same
# whether or not the line is highlighted (Norms previously switched to 0 decimals).
for label, values, fmt, higher_is_better in (
        ('Mean corr', corr_mean_metric, '{:.2f}', True),
        ('Min corr', corr_min_metric, '{:.2f}', True),
        ('Crispness', crispness_metric, '{:.0f}', True),
        ('Norms', norms_metric, '{:.2f}', False)):
    raw, rig, els = values
    if higher_is_better:
        worse = raw > rig or raw > els
    else:
        worse = raw < rig or raw < els
    line = label + ' - raw / rigid / pw_rigid: ' + str([fmt.format(i) for i in values])
    print('\x1b[1;03;31m' + line + '\x1b[0m' if worse else line)

#### Correlation with template image
Plot correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction.

In [None]:
# correlation of every frame with the template: time course (top) and scatter plots (bottom)
plt.figure(figsize=(20, 10))

# line plot: per-frame correlation of the raw, rigid and pw-rigid movies
plt.subplot(211)
plt.plot(mtrs['corr_orig'])
plt.plot(mtrs['corr_rig'])
plt.plot(mtrs['corr_els'])
plt.legend(['Original', 'Rigid', 'PW-Rigid'])
plt.xlabel('Frame')
plt.ylabel('Correlation')
axes = plt.gca()
axes.set_xlim([0, len(mtrs['corr_els'])])
axes.set_ylim([-0.1, 1])

# scatter plot: raw vs. rigid (points above the dashed diagonal = higher correlation after rigid MC)
plt.subplot(223)
plt.scatter(mtrs['corr_orig'], mtrs['corr_rig'])
plt.xlabel('Original')
plt.ylabel('Rigid')
plt.plot([0, 1], [0, 1], 'r--')
axes = plt.gca()
axes.set_xlim([0, 1])
axes.set_ylim([0, 1])
plt.axis('square')

# scatter plot: rigid vs. pw-rigid
plt.subplot(224)
plt.scatter(mtrs['corr_rig'], mtrs['corr_els'])
plt.xlabel('Rigid')
plt.ylabel('PW-Rigid')
plt.plot([0, 1], [0, 1], 'r--')
axes = plt.gca()
axes.set_xlim([0, 1])
axes.set_ylim([0, 1])
plt.axis('square');

#### Residual optic flow
Optic flow algorithms attempt to match each frame to the template by estimating locally smooth displacement fields. The output is an image where each pixel describes the local displacement between template and frame at this point. The smaller the local displacement, the better the registration. Norms are the matrix norms of the optic flow matrix.

In [None]:
# plot the results of Residual Optical Flow
# one row per stage (pw-rigid, rigid, raw); columns: mean image, correlation image and
# mean optical-flow magnitude. The *_metrics.npz files are presumably the ones written
# by compute_metrics_motion_correction via computeMetrics — confirm against Caiman.
# [:-4] drops the last 4 characters of the movie name ('.tif' or 'mmap').
metrics_files = [mc.fname_tot_els[0][:-4] + '_metrics.npz', mc.fname_tot_rig[0][:-4] +
       '_metrics.npz', mc.fname[0][:-4] + '_metrics.npz']

plt.figure(figsize = (20,10))
for cnt, fl, metr in zip(range(len(metrics_files)),metrics_files,['pw_rigid','rigid','raw']):
    with np.load(fl) as ld:
        # summary of the residual optic-flow norms for this stage
        print('Correction: %s' % (metr))
        print('Norms: %1.2f +- %1.2f' % (np.mean(ld['norms']), np.std(ld['norms'])))
        
        # column 1: mean image of the corresponding movie
        plt.subplot(len(metrics_files), 3, 1 + 3 * cnt)
        plt.ylabel(metr)
                   
        if metr == 'raw':
            mean_img = np.mean(cm.load(mc.fname[0]), axis=0)
        elif metr == 'rigid':
            mean_img = np.mean(cm.load(mc.fname_tot_rig[0]), axis=0)
        elif metr == 'pw_rigid':
            mean_img = np.mean(cm.load(mc.fname_tot_els[0]), axis=0)
        
        # clip the display range to the 0.5-99.5 percentiles (NaNs ignored)
        lq, hq = np.nanpercentile(mean_img, [.5, 99.5])
        plt.imshow(mean_img, vmin=lq, vmax=hq, cmap='gray')
        if not cnt:
            plt.title('Mean')
        # column 2: correlation image stored in the metrics file
        plt.subplot(len(metrics_files), 3, 3 * cnt + 2)
        plt.imshow(ld['img_corr'], vmin=0, vmax=.5, cmap='gray')
        if not cnt:
            plt.title('Correlation image')
        # column 3: magnitude of the 2-component flow vectors, averaged over axis 0
        # (presumably frames)
        plt.subplot(len(metrics_files), 3, 3 * cnt + 3)
        flows = ld['flows']
        plt.imshow(np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0), vmin=0, vmax=0.5, cmap='gray')
        plt.colorbar()
        if not cnt:
            plt.title('Mean optical flow')

### Detect frames with bad motion
Identify frames with significant residual motion (low correlation with template). Write a JSON file with criterion and indices of frames matching the criterion. This file can be used in further analysis to exclude the frames corrupted by motion.

In [None]:
def writeJsonBadFrames(criterion, thresh, frame_ix, mc, mc_type, data_folder):
    """
    Write a JSON file listing the frames flagged as corrupted by residual motion.

    criterion ... name of the metric used to flag frames (e.g. 'corr_els')
    thresh ... threshold that was applied to the criterion
    frame_ix ... indices of the flagged frames
    mc ... MotionCorrect object; the name of its output mmap determines the JSON name
    mc_type ... 'els' (pw-rigid output) or 'rig' (rigid output)
    data_folder ... folder joined with the JSON name (os.path.join ignores it when
                    the mmap path is absolute)

    Raises ValueError for any other mc_type (previously an UnboundLocalError).
    """
    if mc_type == 'els':
        mmap_fname = mc.fname_tot_els[0]
    elif mc_type == 'rig':
        mmap_fname = mc.fname_tot_rig[0]
    else:
        raise ValueError("mc_type must be 'els' or 'rig', got %r" % (mc_type,))
    json_fname = mmap_fname.replace('.mmap','') + 'badFrames' + '.json'
    # NOTE: the key "threshold:" (with trailing colon) is kept unchanged because
    # existing badFrames files and their readers may rely on it
    exclude_info = {"criterion": criterion,
        "threshold:": thresh,
        "frames": frame_ix}
    with open(os.path.join(data_folder, json_fname), 'w') as fid:
        json.dump(exclude_info, fid)
    print('Created JSON metadata file %s' % (json_fname))

In [None]:
thresh = 0.1 # flag frames whose criterion value is below this threshold
# (criterion, registration type, report) - pw-rigid first, then rigid registration
for criterion, mc_type, report in (
        ('corr_els', 'els', '%1.0f frames matching criterion after pw-rigid registration.'),
        ('corr_rig', 'rig', '\n%1.0f frames matching criterion after rigid registration.')):
    bad_frames = [frame_no for frame_no, value in enumerate(mtrs[criterion]) if value < thresh]
    print(report % (len(bad_frames)))
    writeJsonBadFrames(criterion, thresh, bad_frames, mc, mc_type, data_folder)

# Old code
Code below is kept from previous versions of the notebook. **Do NOT use!!!**

Now, loop over all selected files and run motion correction. For each file, we store the `mc` object in `mc_results`. Each correction also produces 2 output files in CaImAn mmap format (one for rigid and one for elastic registration). The mmap files are stored in the same folder as the original TIFF image. Finally, different metrics are computed to assess the quality of motion correction. The metrics are stored in `*metrics.npz` files. 

Note that Caiman cannot compute the metrics if the output image contains NaNs. This is called a "failed" registration here, although it may in fact be expected, for example if there is out-of-frame motion.

In [None]:
# Old per-file motion correction loop (kept for reference; see the note above).
t_start = time.time()
mc_results = []
bad_movies = []
crispness_metric = np.zeros((len(tiff_files_crop),3))
norms_metric = np.zeros((len(tiff_files_crop),3))
corr_min_metric = np.zeros((len(tiff_files_crop),3))
corr_mean_metric = np.zeros((len(tiff_files_crop),3))

for ix, fname in enumerate(tiff_files_crop):
    print('Running MC for %s' % (fname.replace(data_folder, '~')))

    # create mc object
    mc = setupMC(fname, params)

    # compute initial template by binned median filtering
    # this template will be refined during the registration process
    template = cm.load(fname).bin_median(window=10)

    # apply rigid correction
    mc.motion_correct_rigid(save_movie=True, template=template)
    # apply piecewise rigid correction
    mc.motion_correct_pwrigid(save_movie=True, template=mc.total_template_rig, show_template = False)

    # append results
    mc_results.append(mc)

    # compute borders to exclude (builtin int: np.int no longer exists in NumPy >= 1.24)
    bord_px_els = int(np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                         np.max(np.abs(mc.y_shifts_els)))))

    # compute metrics for the results
    try:
        mtrs = computeMetrics(mc, bord_px_els, swap_dim, winsize, resize_fact_flow)

        # correlations, crispness and norms of residual optic flow as indicators of registration quality
        crispness_metric[ix,:] = np.array([mtrs['crispness_orig'],
                                           mtrs['crispness_rig'], mtrs['crispness_els']])
        norms_metric[ix,:] = np.array([np.mean(mtrs['norms_orig']),
                                       np.mean(mtrs['norms_rig']), np.mean(mtrs['norms_els'])])
        corr_mean_metric[ix,:] = np.array([np.mean(mtrs['corr_orig']),
                                           np.mean(mtrs['corr_rig']), np.mean(mtrs['corr_els'])])
        corr_min_metric[ix,:] = np.array([np.min(mtrs['corr_orig']),
                                          np.min(mtrs['corr_rig']), np.min(mtrs['corr_els'])])
    except Exception:
        # sometimes the metrics cannot be computed, due to NaNs in the resulting pw-rigid image
        # this means that the registration is probably bad and / or the movie contains lots of motion
        # (Exception instead of a bare except, so KeyboardInterrupt still stops the loop)
        crispness_metric[ix,:] = np.array([np.nan, np.nan, np.nan])
        norms_metric[ix,:] = np.array([np.nan, np.nan, np.nan])
        corr_mean_metric[ix,:] = np.array([np.nan, np.nan, np.nan])
        corr_min_metric[ix,:] = np.array([np.nan, np.nan, np.nan])
        bad_movies.append(fname.replace(data_folder, '~'))

    clear_output()

# print the "failed" registrations (max(..., 1) avoids division by zero without files)
print('%1.0f Failed registrations (%1.2f %%):' % (len(bad_movies), (len(bad_movies)/max(len(tiff_files_crop), 1))*100))
print(bad_movies)

# print elapsed time
t_elapsed = time.time() - t_start
print('\nFinished MC for %1.0f files in %1.2f s (%1.2f s per file)' % (len(mc_results),
                                                                     t_elapsed, t_elapsed/max(len(mc_results), 1)))
if platform.system() == 'Darwin':
    os.system('say "your program has finished"')

### Assess quality of motion correction
A number of key metrics can be calculated to assess how much motion correction reduced the overall motion.
1. Correlation
Correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction. The mean correlation gives an overall impression of motion. The minimum correlation indicates the parts of the movie that are worst affected by motion. Larger correlations indicate less motion.
2. Crispness
Crispness provides a measure of the smoothness of the corrected average image. Intuitively, a dataset with nonregistered motion will have a blurred mean image, resulting in a lower value for the total gradient field norm. Thus, larger values indicate a crisper average image and less residual motion. Crispness is calculated from the gradient field of the mean image (`np.gradient`).
3. Residual optical flow
Optic flow algorithms attempt to match each frame to the template by estimating locally smooth displacement fields. The output is an image where each pixel describes the local displacement between template and frame at this point. The smaller the local displacement, the better the registration. Here we compute the matrix norm of the optic flow matrix as a summary statistic.

The cell below prints summary statistics of the different metrics for all processed files. See further below for code to generate detailed plots for individual files.

In [None]:
for ix, file_metrix in enumerate(norms_metric):
    if ix:
        print('')
    print('MC evaluation - %s:' % (tiff_files_crop[ix]))
    # Each metric for raw / rigid / pw-rigid of this file. A line is highlighted in red
    # when either correction made the metric worse than the raw movie: lower correlation
    # or crispness, or a higher residual optic-flow norm. Formats no longer change
    # with highlighting (Norms used to switch to 0 decimals).
    for label, values, fmt, higher_is_better in (
            ('Mean corr', corr_mean_metric[ix,:], '{:.2f}', True),
            ('Min corr', corr_min_metric[ix,:], '{:.2f}', True),
            ('Crispness', crispness_metric[ix,:], '{:.0f}', True),
            ('Norms', norms_metric[ix,:], '{:.2f}', False)):
        raw, rig, els = values
        if higher_is_better:
            worse = raw > rig or raw > els
        else:
            worse = raw < rig or raw < els
        line = label + ' - raw / rigid / pw_rigid: ' + str([fmt.format(i) for i in values])
        print('\x1b[1;03;31m' + line + '\x1b[0m' if worse else line)

Summary scatter plots of the above numbers. These can be useful to detect failed registrations or optimise parameters.

In [None]:
# one scatter plot per metric and stage: raw (x) vs. corrected (y), with the identity line dashed
plt.figure(figsize = (10,20))
# Mean correlation
plt.subplot(421); plt.scatter(corr_mean_metric[:,0], corr_mean_metric[:,1]); plt.xlabel('Original');
plt.ylabel('Rigid'); plt.plot([0,1],[0,1],'r--')
axes = plt.gca(); axes.set_xlim([0,1]); axes.set_ylim([0,1]); plt.axis('square'); plt.title('Mean Correlation')
plt.subplot(422); plt.scatter(corr_mean_metric[:,0], corr_mean_metric[:,2]); plt.xlabel('Original');
plt.ylabel('PW-Rigid'); plt.plot([0,1],[0,1],'r--')
axes = plt.gca(); axes.set_xlim([0,1]); axes.set_ylim([0,1]); plt.axis('square');
# Min correlation
plt.subplot(423); plt.scatter(corr_min_metric[:,0], corr_min_metric[:,1]); plt.xlabel('Original');
plt.ylabel('Rigid'); plt.plot([0,1],[0,1],'r--')
axes = plt.gca(); axes.set_xlim([0,1]); axes.set_ylim([0,1]); plt.axis('square'); plt.title('Min. Correlation')
plt.subplot(424); plt.scatter(corr_min_metric[:,0], corr_min_metric[:,2]); plt.xlabel('Original');
plt.ylabel('PW-Rigid'); plt.plot([0,1],[0,1],'r--')
axes = plt.gca(); axes.set_xlim([0,1]); axes.set_ylim([0,1]); plt.axis('square');
# Crispness (both panels use the same 1000-3000 range as their identity line)
plt.subplot(425); plt.scatter(crispness_metric[:,0], crispness_metric[:,1]); plt.xlabel('Original');
plt.ylabel('Rigid'); plt.plot([1000,3000],[1000,3000],'r--')
axes = plt.gca(); axes.set_xlim([1000,3000]); axes.set_ylim([1000,3000]); plt.axis('square'); plt.title('Crispness')
plt.subplot(426); plt.scatter(crispness_metric[:,0], crispness_metric[:,2]); plt.xlabel('Original');
plt.ylabel('PW-Rigid'); plt.plot([1000,3000],[1000,3000],'r--')
axes = plt.gca(); axes.set_xlim([1000,3000]); axes.set_ylim([1000,3000]); plt.axis('square');
# Norms
plt.subplot(427); plt.scatter(norms_metric[:,0], norms_metric[:,1]); plt.xlabel('Original');
plt.ylabel('Rigid'); plt.plot([0,60],[0,60],'r--')
axes = plt.gca(); axes.set_xlim([0,60]); axes.set_ylim([0,60]); plt.axis('square'); plt.title('ROF Norms')
plt.subplot(428); plt.scatter(norms_metric[:,0], norms_metric[:,2]); plt.xlabel('Original');
plt.ylabel('PW-Rigid'); plt.plot([0,60],[0,60],'r--')
axes = plt.gca(); axes.set_xlim([0,60]); axes.set_ylim([0,60]); plt.axis('square');

### Interactively assess quality of motion correction
`compute_metrics_motion_correction` can be used to calculate different metrics to assess the quality of motion correction. Below, select a specific file, calculate the metrics and then inspect the results as plots / images.

In [None]:
# select file to assess (1-based index into mc_results)
file_ix = 1 # e.g. 1, 2, 3, ...
# Check both bounds: without the lower bound, file_ix = 0 (or a negative
# value) would silently wrap around and select the last result(s).
if not 1 <= file_ix <= len(mc_results):
    raise ValueError('MC results only available for %d files (file_ix must be 1..%d, got %r)!'
                     % (len(mc_results), len(mc_results), file_ix))
mc = mc_results[file_ix-1]

In [None]:
%%capture

# compute borders to exclude
bord_px_els = np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                 np.max(np.abs(mc.y_shifts_els)))).astype(np.int)
    
# compute metrics for the results
mtrs = computeMetrics(mc, bord_px_els, swap_dim, winsize, resize_fact_flow)

#### Correlation with template image
Plot correlations of each frame with the template image (binned median) for original, rigid correction and pw-rigid correction.

In [None]:
plt.figure(figsize = (20,10))
# Top row: per-frame correlation with the template for all three versions.
plt.subplot(211)
for corr_key in ('corr_orig', 'corr_rig', 'corr_els'):
    plt.plot(mtrs[corr_key])
plt.legend(['Original','Rigid','PW-Rigid'])
# Bottom row: pairwise scatter plots against the dashed identity line;
# points above the line mean the y-axis version correlates better.
for panel, x_key, y_key, x_label, y_label in ((223, 'corr_orig', 'corr_rig', 'Original', 'Rigid'),
                                              (224, 'corr_rig', 'corr_els', 'Rigid', 'PW-Rigid')):
    plt.subplot(panel)
    plt.scatter(mtrs[x_key], mtrs[y_key])
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.plot([0,1],[0,1],'r--')
    axes = plt.gca()
    axes.set_xlim([0,1])
    axes.set_ylim([0,1])
    plt.axis('square')

#### Crispness
This provides a measure of the sharpness (crispness) of the average image. Intuitively, a dataset with unregistered motion will have a blurred mean image, resulting in a lower value for the total gradient field norm. Thus, larger values indicate a crisper average image and less residual motion. Crispness is calculated from the gradient field of the mean image (`np.gradient`).

In [None]:
# Report the crispness of the mean image for each version
# (larger = crisper mean image, i.e. less residual motion).
for template, crisp_key in (('Crispness original: %1.0f', 'crispness_orig'),
                            ('Crispness rigid: %1.0f', 'crispness_rig'),
                            ('Crispness elastic: %1.0f', 'crispness_els')):
    print(template % mtrs[crisp_key])

#### Residual optic flow
Optic flow algorithms attempt to match each frame to the template by estimating locally smooth displacement fields. The output is an image where each pixel describes the local displacement between template and frame at that point. The smaller the local displacement, the better the registration. Norms are the matrix norms of the optic flow matrix.

In [None]:
# plot the results of Residual Optical Flow


def _strip_border(img, bord):
    """Crop `bord` pixels from each side of the first two axes of `img`.

    A border of 0 returns the image unchanged; a plain `img[0:-0, 0:-0]`
    slice would return an empty array instead.
    """
    stop = -bord if bord > 0 else None
    return img[bord:stop, bord:stop]


# Movie file each correction type was computed from.
movie_files = {'raw': mc.fname[0],
               'rigid': mc.fname_tot_rig[0],
               'pw_rigid': mc.fname_tot_els[0]}
corrections = ['pw_rigid', 'rigid', 'raw']
# Metrics file = movie file name minus its last 4 characters + '_metrics.npz'
# (presumably the naming used when the metrics were saved -- confirm before
# changing to e.g. os.path.splitext).
metrics_files = [movie_files[metr][:-4] + '_metrics.npz' for metr in corrections]

n_rows = len(metrics_files)
plt.figure(figsize = (20,10))
for cnt, (fl, metr) in enumerate(zip(metrics_files, corrections)):
    with np.load(fl) as ld:
        print('Correction: %s' % (metr))
        print('Norms: %1.2f +- %1.2f' % (np.mean(ld['norms']), np.std(ld['norms'])))

        # Column 1: mean image of the movie with the motion border removed.
        plt.subplot(n_rows, 3, 1 + 3 * cnt)
        plt.ylabel(metr)
        mean_img = _strip_border(np.mean(cm.load(movie_files[metr]), axis=0), bord_px_els)
        # Clip the display range to the 0.5-99.5 percentiles to suppress outliers.
        lq, hq = np.nanpercentile(mean_img, [.5, 99.5])
        plt.imshow(mean_img, vmin=lq, vmax=hq, cmap='gray')
        if cnt == 0:
            plt.title('Mean')

        # Column 2: correlation image stored in the metrics file.
        plt.subplot(n_rows, 3, 3 * cnt + 2)
        plt.imshow(ld['img_corr'], vmin=0, vmax=.5, cmap='gray')
        if cnt == 0:
            plt.title('Correlation image')

        # Column 3: magnitude of the two flow components, averaged over the
        # first axis of `flows` (presumably frames -- confirm).
        plt.subplot(n_rows, 3, 3 * cnt + 3)
        flows = ld['flows']
        plt.imshow(np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0), vmin=0, vmax=0.5, cmap='gray')
        plt.colorbar()
        if cnt == 0:
            plt.title('Mean optical flow')

#### Play movies
Finally, concatenate the three movies and play them.

In [None]:
# Load the uncorrected movie and both motion-corrected versions (in that order).
movies = [cm.load(movie_file) for movie_file in (mc.fname, mc.fname_tot_rig, mc.fname_tot_els)]
mov_uc, mov_rig, mov_els = movies
# Concatenate the three movies along axis 2 and play them together.
current_file = tiff_files_crop[file_ix - 1]
print('Movie comparison for %s\n(uncorrected, rigid, elastic)' % (current_file))
mov_all = cm.concatenate(movies, axis=2)
mov_all.play(fr=frame_rate, backend='notebook')

### Reorganise files and delete empty files

**This part is now done with a separate Matlab script.**

This cell changes the folder structure from `data_folder/HH_MM_SS_Live/file_stem_ 1234.tif` to `data_folder/file_stem_1234.tif`. All the sub-folders (HH_MM_SS_Live) get deleted. The `parameters.xml` file is renamed to `file_stem_1234.xml`. This only needs to be done once. Once the files have been reorganised, running this again *should* have no effect.

In [None]:
# NOTE: this cell is intentionally disabled (every line is commented out);
# the reorganisation is now done by a separate Matlab script.
# NOTE(review): if re-enabled, `shutil.rmtree` runs inside the inner loop,
# so a sub-folder holding more than one matching TIFF is deleted after the
# first match while the loop keeps iterating the stale `os.listdir` result
# (the next `os.rename` would then fail).
# # count folders that are processed
# processed_folders = 0
# # loop over all items in data_folder
# for item_1 in os.listdir(data_folder):
#     # select only directories
#     if os.path.isdir(os.path.join(data_folder, item_1)):
#         # loop over all items in sub-directory
#         for item_2 in os.listdir(os.path.join(data_folder, item_1)):
#             # only do something if a file name starts with file_stem and ends with ext
#             if item_2.startswith(file_stem) and item_2.endswith(ext):
#                 # determine the trial ID (i.e. the number at the end of the TIF file)
#                 trial_id = item_2[item_2.rfind(ext)-4:item_2.rfind(ext)]
#                 # create the new filename (without space!)
#                 new_file = file_stem + trial_id + ext
#                 # rename and move the TIFF file to the top-level folder
#                 os.rename(os.path.join(data_folder, item_1, item_2), os.path.join(data_folder, new_file))
#                 # rename and move the XML file to the top-level folder
#                 os.rename(os.path.join(data_folder, item_1, 'parameters.xml'), \
#                           os.path.join(data_folder, new_file.replace('.tif','.xml')))
#                 # delete the subfolder
#                 shutil.rmtree(os.path.join(data_folder, item_1))
#                 processed_folders += 1
# print("Processed %1.0f folders" % (processed_folders))