## Motion correction of dendritic GCaMP6 data with the CaImAn toolbox

### Imports

In [1]:
# Generic imports
from __future__ import absolute_import, division, print_function
from builtins import *

import os, sys, glob, platform
import re
import time
import numpy as np
import matplotlib.pyplot as plt

from IPython.display import clear_output

%matplotlib inline

In [2]:
# Linux-only environment setup: make a CaImAn checkout in ~/caiman importable
# and limit the MKL and OpenBLAS backends to a single thread each.
# NOTE(review): numpy is already imported in the first cell; BLAS libraries
# typically read these variables when they are first loaded, so setting them
# here may have no effect. Consider setting them before any numpy import — confirm.
if platform.system() == 'Linux':
    sys.path.append(os.path.expanduser(os.path.join('~', 'caiman')))
    os.environ.update(MKL_NUM_THREADS='1', OPENBLAS_NUM_THREADS='1')

In [3]:
# CaImAn imports
import caiman as cm
from caiman.motion_correction import tile_and_correct, motion_correction_piecewise
from caiman.motion_correction import motion_correct_batch_rigid, motion_correct_batch_pwrigid
from caiman.motion_correction import MotionCorrect
from caiman.mmapping import save_memmap_each

# Custom imports from utils
from utils import define_params

In [4]:
# Location of the raw data. The default is the author's local copy; set the
# DATA_FOLDER environment variable to use another location (for example the
# previously listed '/home/ubuntu/example_data/M2_2018-01-30/S2').
data_folder = os.environ.get(
    'DATA_FOLDER', '/Users/Henry/polybox/Data_temp/Dendrites_Gwen/M2_2018-01-30/S2')
exp_id = '11-33-48_Live'
file_id = 'test_A1_Ch0_ 0171.tif'  # note: the file name contains a space
frame_rate = 13.1316  # in Hz
# Number of initial frames to drop: passed as remove_init when memory-mapping
# the corrected movie and used to slice the raw movie for comparison.
del_frame = 0

path_to_images = os.path.join(data_folder, exp_id, file_id)

# There is only one channel per movie, so the list of TIFF files to process
# is just this single file.
tiff_files = [path_to_images]

In [6]:
# Load the preset parameter dictionary for dendritic recordings
# (its contents are displayed by the next cell).
params = define_params.get_params_movie(
    mode='dendrite',
)

In [7]:
# Display the parameter dictionary: Jupyter renders this bare last expression
# as the cell output.
params

{'K': 4,
 'alpha_snmf': None,
 'final_frate': 10,
 'gSig': [4, 4],
 'init_method': 'sparse_nmf',
 'is_dendrites': True,
 'max_deviation_rigid': 3,
 'max_shifts': (8, 4),
 'merge_thresh': 0.8,
 'niter_rig': 1,
 'num_splits_to_process_els': [28, None],
 'num_splits_to_process_rig': None,
 'overlaps': (24, 24),
 'p': 1,
 'rf': 15,
 'splits_els': 10,
 'splits_rig': 10,
 'stride_cnmf': 6,
 'strides': (48, 48),
 'upsample_factor_grid': 4}

In [5]:
# Unpack the motion-correction settings from `params`, the dictionary returned
# by define_params.get_params_movie above. (Reading from `params_movie`, which
# is not defined in this notebook, raises a NameError on a fresh kernel.)
# NOTE(review): setupMC below reads these values from `params` directly, so
# these module-level names are not used by the motion-correction loop.

# number of iterations of rigid motion correction
niter_rig = params['niter_rig']
# maximum allowed rigid shift
max_shifts = params['max_shifts']
# for parallelization, split the movie into num_splits chunks across time
splits_rig = params['splits_rig']
# if None, all the splits are processed and the movie is saved
num_splits_to_process_rig = params['num_splits_to_process_rig']
# intervals at which patches are laid out for motion correction
strides = params['strides']
# overlap between patches (patch size = strides + overlaps)
overlaps = params['overlaps']
# for parallelization, split the movie into num_splits chunks across time
splits_els = params['splits_els']
# if None, all the splits are processed and the movie is saved
num_splits_to_process_els = params['num_splits_to_process_els']
# upsample factor to avoid smearing when merging patches
upsample_factor_grid = params['upsample_factor_grid']
# maximum deviation allowed for a patch with respect to the rigid shift
max_deviation_rigid = params['max_deviation_rigid']

In [None]:
# Cluster settings, consumed by cm.cluster.setup_cluster when the cluster is started.
cluster_processes = 1          # requested number of worker processes
cluster_single_thread = True   # NOTE(review): presumably runs without a worker pool — confirm
cluster_backend = 'local'      # CaImAn cluster backend

In [None]:
# Playback settings for the movie-viewing cells: playback rate (the
# acquisition frame rate) and the backend passed to movie.play.
display_fr = frame_rate
display_backend = 'notebook'

In [None]:
# Load the raw movie(s) listed in tiff_files and play them using the shared
# display settings (previously the backend was hard-coded here, so changing
# display_backend had no effect on this cell).
m_orig = cm.load_movie_chain(tiff_files)
m_orig.play(fr=display_fr, backend=display_backend)

### Start the cluster

In [None]:
# Start the cluster configured above. `dview` is the handle CaImAn functions
# accept for parallel work; `n_processes` is the number actually started.
print('Starting cluster')
c, dview, n_processes = cm.cluster.setup_cluster(
    backend=cluster_backend,
    n_processes=cluster_processes,
    single_thread=cluster_single_thread,
)
print('done')

### Run motion correction on all files

In [None]:
def setupMC(fname, params, dview=None):
    """Create a CaImAn ``MotionCorrect`` object for one movie file.

    Parameters
    ----------
    fname : str
        Path to the movie (here a TIFF file) to be motion corrected.
    params : dict
        Parameter dictionary (as returned by define_params.get_params_movie).
        Must contain 'max_shifts', 'niter_rig', 'splits_rig',
        'num_splits_to_process_rig', 'strides', 'overlaps', 'splits_els',
        'num_splits_to_process_els', 'upsample_factor_grid' and
        'max_deviation_rigid'; a missing key raises KeyError.
    dview : optional
        Cluster handle from cm.cluster.setup_cluster, forwarded to
        MotionCorrect. Defaults to None (the previously hard-coded value), so
        existing callers are unaffected; pass the cluster's dview to use it.

    Returns
    -------
    MotionCorrect
        Configured object; no correction has been run yet.
    """
    mc_keys = ('max_shifts', 'niter_rig', 'splits_rig', 'num_splits_to_process_rig',
               'strides', 'overlaps', 'splits_els', 'num_splits_to_process_els',
               'upsample_factor_grid', 'max_deviation_rigid')
    mc_kwargs = {key: params[key] for key in mc_keys}

    # The whole movie is loaded only to take its global minimum, which is
    # passed as MotionCorrect's second positional argument.
    min_mov = cm.load(fname).min()
    return MotionCorrect(fname, min_mov, dview=dview,
                         shifts_opencv=True, nonneg_movie=False,
                         **mc_kwargs)

In [None]:
# Run motion correction for each file and store the MotionCorrect objects in mc_results.
# Per file: a rigid pass (not saved; it only provides a template), then a
# piecewise-rigid (elastic) pass whose output is re-saved as a C-order CaImAn
# mmap file; the intermediate elastic file is then deleted.
# NOTE(review): output location is chosen by CaImAn (reportedly the folder of
# the original TIFF) — confirm.
t_start = time.time()
mc_results = []
for fname in tiff_files:
    print('Running MC for %s' % (fname))
    mc = setupMC(fname, params)
    # rigid motion correction for initial alignment
    mc.motion_correct_rigid(save_movie=False)
    # piecewise-rigid (elastic) correction for fine alignment, seeded with the rigid template
    mc.motion_correct_pwrigid(save_movie=True, template=mc.total_template_rig, show_template=False)

    # MEMORY MAPPING: re-save the elastic result in order 'C'.
    fname_els = mc.fname_tot_els
    # Largest absolute elastic shift, rounded up; passed as border_to_0.
    # Built-in int replaces np.int, which was removed in NumPy 1.24.
    bord_px_els = int(np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                         np.max(np.abs(mc.y_shifts_els)))))
    # splitext strips only the final extension (str.replace would also mangle
    # e.g. '.tiff' or a '.tif' occurring elsewhere in the name).
    base_name = os.path.splitext(os.path.basename(fname))[0] + '_els_'
    fname_new = cm.save_memmap(fname_els, base_name=base_name, order='C',
                               border_to_0=bord_px_els, remove_init=del_frame)
    # the intermediate elastic file is no longer needed once re-saved
    os.remove(fname_els[0])
    mc.fname_tot_els = fname_new
    mc_results.append(mc)

# clear output and print elapsed time
clear_output()
t_elapsed = time.time() - t_start
n_done = len(mc_results)
# guard against division by zero when tiff_files is empty
print('Finished MC for %1.0f files in %1.2f s (%1.2f s per file)' %
      (n_done, t_elapsed, t_elapsed / n_done if n_done else 0.0))

In [None]:
# STOP CLUSTER and clean up log files written to the working directory.
print('Stopping cluster')
# Pass the cluster handle so that a multiprocessing pool started by
# setup_cluster is terminated; without it, stop_server cannot reach that pool.
# With dview=None this is the same call as before.
cm.stop_server(dview=dview)

print('Clean up log files')
for log_file in glob.glob('*_LOG_*'):
    os.remove(log_file)

### Results of motion correction
CaImAn's `compute_metrics_motion_correction` computes metrics for assessing the quality of motion correction. `computeMetric` below applies it to the uncorrected and the elastically corrected movie of one file.

In [None]:
def computeMetric(mc):
    """Compute motion-correction quality metrics for one MotionCorrect result.

    Runs ``cm.motion_correction.compute_metrics_motion_correction`` on the
    uncorrected movie (``mc.fname``) and on the elastic (piecewise-rigid)
    corrected movie (``mc.fname_tot_els``). Each frame is cropped by a border
    equal to the largest absolute elastic shift.

    Parameters
    ----------
    mc : MotionCorrect
        Object on which elastic correction has been run; this function reads
        ``x_shifts_els``, ``y_shifts_els``, ``total_template_els``, ``fname``
        and ``fname_tot_els``.

    Returns
    -------
    tuple
        ``(corr_tmpl_unc, corr_frame_unc, corr_tmpl_els, corr_frame_els)``:
        the 2nd and 3rd outputs of the metrics function for the uncorrected
        and elastic movies, respectively.
    """
    # NOTE(review): the unpacking below expects a 6-tuple. This matches the
    # modified compute_metrics_motion_correction that returns frame-by-template
    # and frame-by-previous-frame correlations; the stock CaImAn function may
    # return a different tuple — confirm the patched version is installed.

    # Border to exclude: largest absolute elastic shift, rounded up.
    # Built-in int replaces np.int, which was removed in NumPy 1.24.
    bord_px_els = int(np.ceil(np.maximum(np.max(np.abs(mc.x_shifts_els)),
                                         np.max(np.abs(mc.y_shifts_els)))))
    final_size = np.subtract(mc.total_template_els.shape, 2 * bord_px_els)

    winsize = 100
    swap_dim = False
    resize_fact_flow = .2

    def _correlations(movie_fname):
        # Only the two correlation outputs are used; template, flows, norms
        # and smoothness are discarded.
        _, corr_tmpl, corr_frame, _, _, _ = \
            cm.motion_correction.compute_metrics_motion_correction(
                movie_fname, final_size[0], final_size[1], swap_dim,
                winsize=winsize, play_flow=False, resize_fact_flow=resize_fact_flow)
        return corr_tmpl, corr_frame

    # The rigid movie is no longer evaluated. Its metrics were never returned,
    # and the MC loop runs motion_correct_rigid(save_movie=False), so
    # mc.fname_tot_rig need not refer to a saved movie.
    corr_tmpl_unc, corr_frame_unc = _correlations(mc.fname)
    corr_tmpl_els, corr_frame_els = _correlations(mc.fname_tot_els)

    return corr_tmpl_unc, corr_frame_unc, corr_tmpl_els, corr_frame_els

In [None]:
# Compare the uncorrected and elastic-corrected movies for one selected file.
file_ix = 1  # 1-based index into mc_results / tiff_files
# Reject out-of-range indices on both sides: 0 or negative values would
# otherwise silently wrap around to the end of the list.
if not 1 <= file_ix <= len(mc_results):
    raise IndexError('MC results only available for %1.0f files!' % (len(mc_results)))
mc_sel = mc_results[file_ix - 1]

# Load uncorrected and corrected movies. The first del_frame frames are dropped
# from the raw movie, mirroring remove_init=del_frame used when saving the
# elastic movie, so the two movies have matching lengths.
mov_uc = cm.load(mc_sel.fname)[del_frame:, :, :]
mov_els = cm.load(mc_sel.fname_tot_els)

# compare movies
print('Movie comparison for %s\n(uncorrected, elastic)' %
      (tiff_files[file_ix - 1]))
# NOTE(review): axis=2 presumably places the movies side by side along the
# image width (frames assumed to be (time, height, width)) — confirm.
mov_all = cm.concatenate([mov_uc, mov_els], axis=2)
mov_all.play(fr=display_fr, backend=display_backend, magnification=10)