## Motion Correction using the NoRMCorre algorithm

Can use both rigid and piecewise motion correction. The original motion correction algorithm is described in this paper:
https://www.sciencedirect.com/science/article/pii/S0165027017302753?via%3Dihub

More details and tips here: https://caiman.readthedocs.io/en/master/CaImAn_Tips.html#motion-correction-tips

And an example pipeline here:  https://github.com/flatironinstitute/CaImAn/blob/6c5c3b6117b71b6e8b44f62fc26fd3b3d914de12/demos/notebooks/demo_motion_correction.ipynb

In [None]:
# Best-effort interactive setup: enable autoreload and the Qt matplotlib backend.
# `run_line_magic` is used instead of the deprecated `magic()` helper.
# Any failure (e.g. not running under IPython, where `get_ipython` is undefined)
# is reported and skipped so the rest of the notebook can still run.
try:
    _ipython = get_ipython()
    _ipython.run_line_magic('load_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')
    _ipython.run_line_magic('matplotlib', 'qt')
except Exception as err:
    print(f"Skipping IPython setup: {err!r}")



import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
from shutil import copyfile
from mpl_toolkits.axes_grid1 import make_axes_locatable

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import params as params
from caiman.utils.visualization import inspect_correlation_pnr

import logging
from fancylog import fancylog
import fancylog as package


# Ask OpenCV to run without its own threading (setNumThreads(0)).
# cv2 is imported here because nothing else in this cell imports it; without the
# import the call below could never run. Missing OpenCV is tolerated.
try:
    import cv2
    cv2.setNumThreads(0)
except ImportError:
    pass
import bokeh.plotting as bpl

from fcutils.file_io.utils import check_create_folder, get_file_name
from fcutils.plotting.utils import clean_axes, save_figure
from fcutils.plotting.colors import *

from movie_visualizer import compare_videos
from utils import start_server, load_params, add_to_params_dict

# Configure bokeh to render its output inside the notebook
bpl.output_notebook()

# start_server comes from the local `utils` module and returns three values.
# NOTE(review): presumably a cluster handle, a parallel view (passed to
# MotionCorrect as `dview` below) and the number of processes — confirm in utils.
c, dview, n_processes = start_server()

from IPython.display import clear_output
# Set the inline backend figure format to "retina"
%config InlineBackend.figure_format = "retina"

## Get file paths

In [None]:
# Locate the recording to motion correct and the folder collecting the outputs
fld = "D:\\Dropbox (UCL - SWC)\\Project_vgatPAG\\analysis\\doric\\BF164p1\\19JUN05"  # session data folder

# Single-entry list with the full path of the video to process
video_name = '19JUN05_BF164p1_v1_ds126_crop_ffcSub.tiff'
fnames = [os.path.join(fld, video_name)]

# File name without extension, used as prefix when saving memory-mapped data
base_name = pathlib.Path(fnames[0]).stem

# Plots and other results are saved in this sub-folder of the session folder
output_fld = os.path.join(fld, "MC_output3")
check_create_folder(output_fld)  # fcutils helper; presumably creates the folder if missing



## Set up MC params

Each session's parameters are saved in fld > 01_PARAMS > params.yml

In [None]:
# Parameters that depend on the dataset
frate = 10.      # movie frame rate
decay_time = 2.  # length of a typical transient in seconds

# Read the 'motion_correction' section of the session parameters and add the
# file names, frame rate and decay time to it
prms = add_to_params_dict(
    load_params(fld)['motion_correction'],
    fnames=fnames,
    fr=frate,
    decay_time=decay_time,
)

opts = params.CNMFParams(params_dict=prms)

# Start logging to the session's log folder
logs_fld = os.path.join(fld, "02_LOGS")
logging_file = fancylog.start_logging(
    logs_fld,
    package,
    file_log_level="INFO",
    variables=[opts],
    verbose=False,
    filename='motion_correction_logs',
)

# Keep a pointer to the log file next to the other outputs
log_pointer_path = os.path.join(output_fld, "log_file_path.txt")
with open(log_pointer_path, "w+") as pointer_file:
    pointer_file.write(logging_file)

logging.info("Starting motion correction")
logging.info(f"Output folder: {output_fld}")

# Motion Correction
The background signal in micro-endoscopic data is very strong and makes the motion correction challenging. 
As a first step the algorithm performs a high pass spatial filtering with a Gaussian kernel to remove the bulk of the background and enhance spatial landmarks. 
The size of the kernel is given from the parameter `gSig_filt`. If this is left to the default value of `None` then no spatial filtering is performed (default option, used in 2p data).
After spatial filtering, the NoRMCorre algorithm is used to determine the motion in each frame. The inferred motion is then applied to the *original* data so no information is lost.



## Rigid motion correction

In [None]:
# Build the motion correction object from the 'motion' group of the parameters
motion_kwargs = opts.get_group('motion')
mc = MotionCorrect(fnames, dview=dview, **motion_kwargs)

In [None]:
# Run rigid motion correction and save the result (in memory mapped form)
_ = mc.motion_correct(save_movie=True)

# Now save in C order for CNMF-E
# Strings are compared with `==`: identity (`is`) against a literal is not a
# reliable equality test.
# NOTE(review): bord_px starts at 0, so it is 0 whichever branch is taken.
bord_px = 0
bord_px = 0 if prms['border_nan'] == 'copy' else bord_px
fname_new = cm.mmapping.save_memmap(mc.mmap_file, base_name=base_name+"_rig_", order='C', border_to_0=bord_px)


logging.info(f"Rigid motion corrected video was saved at:\n {fname_new}")


### Piecewise rigid motion correction

In [None]:
# Run piecewise rigid motion correction on top of the rigid results
mc.pw_rigid = True  # turn the flag to True for pw-rigid motion correction
mc.template = mc.mmap_file  # use the results of the rigid motion correction to save in computation
_ = mc.motion_correct(save_movie=True, template=mc.total_template_rig)

# Now save in C order for CNMF-E
# `bord_px` is the value left by the rigid motion correction cell; strings are
# compared with `==` because identity (`is`) against a literal is unreliable.
bord_px = 0 if prms['border_nan'] == 'copy' else bord_px
fname_new_els = cm.mmapping.save_memmap(mc.mmap_file, base_name=base_name+"_els_", order='C', border_to_0=bord_px)

logging.info(f"Piecewise motion corrected video was saved at:\n {fname_new_els}")



# MC quality evaluation

Visualize the results in plots and video and compute metrics to quantify the quality of the motion correction.


In [None]:
# Which motion corrected video to evaluate: "els" selects the piecewise
# result, any other value selects the rigid one
MC_TYPE = "els"
eval_video = fname_new_els if MC_TYPE == "els" else fname_new

In [None]:
# Load the motion corrected movie selected for evaluation
m_rig = cm.load(eval_video)

# Show the template from the rigid motion correction in grayscale
template_figsize = (20, 10)
plt.figure(figsize=template_figsize)
_ = plt.imshow(mc.total_template_rig, cmap='gray')


## Look at correlation and PNR in motion corrected movie

You can use this to find suitable values of `min_corr` and `min_pnr` for CNMF fitting.

In [None]:
# Load the memmapped motion corrected video selected through MC_TYPE
# (`eval_video`), so the summary images match the MC_TYPE prefix used when
# the figure built from them is saved.
Yr, dims, T = cm.load_memmap(eval_video)
images = Yr.T.reshape((T,) + dims, order='F')

# compute some summary images (correlation and peak to noise)
cn_filter, pnr = cm.summary_images.correlation_pnr(images, gSig=opts.init['gSig'][0], swap_dim=False)

# inspect the summary images and set the parameters
inspect_correlation_pnr(cn_filter, pnr)

### Plot background
Create a figure with the background and other images from the MC video and save it to file

In [None]:
# One panel per summary image: rigid template, median frame, correlation
# image and peak-to-noise ratio
f, axarr = plt.subplots(ncols=4, figsize=(20, 8))
ttls = ["templates", "median bg", "cn_filter", "pnr"]
imgs = [mc.total_template_rig, np.median(images, axis=0), cn_filter, pnr]

for panel_idx, panel_title in enumerate(ttls):
    panel = axarr[panel_idx]
    panel.imshow(imgs[panel_idx], cmap="viridis")
    panel.set(title=panel_title)

# Tidy up the axes and save the figure in the output folder
clean_axes(f)
save_figure(f, os.path.join(output_fld, MC_TYPE+"mc_representative_images"))

## Inspect videos
In addition to these plots, it's worth using `compare_videos` in `movie_visualizer.py` to compare raw, rigid mc and pw mc videos:

```compare_videos(raw = fnames[0], rigid=fname_new, piecewise=fname_new_els)``` [better to do it outside of this notebook as it might block the kernel; you can find the paths to the files in the jupyter notebook.]


In [None]:
# Log where the raw and motion corrected videos are stored
logging.info(
    f"\n\nRaw video {fnames[0]}"
    f"\n\nRigid mc: {fname_new}"
    f"\n\nPiecewise mc: {fname_new_els}\n\n"
)

## Compute quality metrics

In [None]:
# Compute quality metrics for raw, rigid and piecewise videos (takes ~5 mins)
final_size = np.subtract(mc.total_template_els.shape, 2 * bord_px)  # remove pixels in the boundaries
winsize = 100
swap_dim = False
resize_fact_flow = .2  # downsample for computing ROF

logging.info("Computing quality metrics")


def _mc_metrics(video):
    """Compute the motion correction metrics of `video` with the settings above."""
    return cm.motion_correction.compute_metrics_motion_correction(
        video, final_size[0], final_size[1], swap_dim,
        winsize=winsize, play_flow=False, resize_fact_flow=resize_fact_flow)


# Raw video
tmpl_orig, correlations_orig, flows_orig, norms_orig, crispness_orig = _mc_metrics(fnames[0])

# Rigid MC
tmpl_rig, correlations_rig, flows_rig, norms_rig, crispness_rig = _mc_metrics(mc.fname_tot_rig[0])

# Piecewise MC
tmpl_els, correlations_els, flows_els, norms_els, crispness_els = _mc_metrics(mc.fname_tot_els[0])

# Copy the metric files to the output directory.
# NOTE(review): the metric files are expected in `fld` with a "._metrics.npz"
# suffix — presumably written there by the calls above; confirm against the
# installed CaImAn version.
ttles = ["raw", "rigid", "piecewise"]
files = [fnames[0], mc.fname_tot_rig[0], mc.fname_tot_els[0]]

metric_files = []
dests = []
for video in files:
    metric_file = os.path.join(fld, get_file_name(video)+"._metrics.npz")
    metric_dir, metric_name = os.path.split(metric_file)
    metric_files.append(metric_file)
    dests.append(os.path.join(metric_dir, output_fld, metric_name))

for metric_src, metric_dst in zip(metric_files, dests):
    copyfile(metric_src, metric_dst)

### Correlation
Look at the correlation between each frame and a reference frame (e.g. the average frame across the entire video). If the motion correction worked, the correlation of the motion corrected frames should be higher than that of the raw frames (if things move, frames won't be correlated). This metric is in part influenced by neural activity, but not too much.

In [None]:
# Figure comparing frame-by-frame correlations before and after motion correction
f = plt.figure(figsize = (20,10))

# Top row: correlation traces for all frames, one line per video
ax = plt.subplot(211)
traces = (
    (correlations_orig, goldenrod, "original", 3, 1),
    (correlations_rig, darkseagreen, "rigid", 2, .8),
    (correlations_els, salmon, "piecewise", 2, .6),
)
for values, colour, name, width, opacity in traces:
    ax.plot(values, color=colour, label=name, lw=width, alpha=opacity)
ax.set(title="Frame by frame correlation to mean frame", xlabel="frame", ylabel="correlation")
ax.legend()

# Bottom row: pairwise scatter plots with a dashed identity line
# (original vs rigid, original vs piecewise, rigid vs piecewise)
pairs = (
    (234, correlations_orig, correlations_rig, darkseagreen, "original", "rigid"),
    (235, correlations_orig, correlations_els, salmon, "original", "piecewise"),
    (236, correlations_rig, correlations_els, blackboard, "rigid", "piecewise"),
)
for position, xs, ys, colour, x_name, y_name in pairs:
    ax = plt.subplot(position)
    ax.scatter(xs, ys, color=colour, alpha=.3)
    ax.plot([0, 1], [0, 1], '--', lw=2, alpha=.8, color=[.4, .4, .4])
    ax.set(xlabel=x_name, ylabel=y_name,
           xlim=[.3, .7], ylim=[.3, .7])
    _ = ax.axis('square')

# save
clean_axes(f)
save_figure(f, os.path.join(output_fld, "frames_correlation_to_reference"))

### Crispness
Another metric is crispness. If the motion correction worked, the average frame should be crisper (less blurry).

In [None]:
# Log the crispness of each video, truncated to integers
msg = (
    "Crispness:\n"
    f"  original: {int(crispness_orig)}\n"
    f"  rigid:  {int(crispness_rig)}\n"
    f"  piecewise: {int(crispness_els)}"
)
logging.info(msg)




### Optic flow
Finally the last metric checks the optic flow across the video. If the video has been correctly motion corrected, the residual optic flow should be minimal. 

In [None]:
# Plot the results of Residual Optical Flow: one row per video (raw, rigid,
# piecewise) with the mean image on the left and the mean optical flow
# magnitude on the right.
f, axarr = plt.subplots(figsize = (20,10), ncols=2, nrows=3)

# `files` holds the raw/rigid/piecewise videos in the same order as `dests`
# (their copied metric files) and `ttles`. The mean image is computed from the
# video itself: `dests` entries are .npz files, so they cannot be used to tell
# which movie to load.
for i, (video, fl, ttl) in enumerate(zip(files, dests, ttles)):
    if not os.path.isfile(fl):
        raise ValueError(f"Metrics file not found: {fl}")

    # Read the flows and close the npz archive straight away
    with np.load(fl) as ld:
        flows = ld['flows']

    # Mean frame of this row's video, cropped by 12 pixels on each side
    mean_img = np.mean(cm.load(video), 0)[12:-12, 12:-12]

    lq, hq = np.nanpercentile(mean_img, [.5, 99.5])

    axarr[i, 0].imshow(mean_img, vmin=lq, vmax=hq)
    axarr[i, 0].set(title="Mean image " + ttl)

    # Flow magnitude from the two components in the last axis, averaged over axis 0
    img = axarr[i, 1].imshow(np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0))

    divider = make_axes_locatable(axarr[i, 1])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    f.colorbar(img, cax=cax, orientation='vertical')

    axarr[i, 1].set(title="mean optical flow " + ttl)

# save
f.tight_layout()
clean_axes(f)
save_figure(f, os.path.join(output_fld, "optic_flow_summary"))

In [None]:
# Save the average residual flow for each condition as an image and as an array
cmap = "viridis"

for fl, ttl in zip(dests, ttles):
    # Read the flows once and close the npz archive straight away
    with np.load(fl) as ld:
        flows = ld['flows']
    mean_flow = np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0)

    save_path = os.path.join(output_fld, f"residual_opticflow_{ttl}")

    # Explicit ".png" so the image format does not depend on how matplotlib
    # handles a file name without extension
    plt.imsave(save_path+".png", mean_flow, cmap=cmap)
    np.save(save_path+".npy", mean_flow)

    logging.info(f"Saving residual optic flow for {ttl} at: {save_path}")
