# Evaluate motion correction
Compute metrics to evaluate the quality of the motion correction 

In [None]:
# Notebook conveniences: reload edited modules automatically and open figures
# in Qt windows. `run_line_magic` is used because `InteractiveShell.magic` is
# deprecated in IPython.
try:
    ipython_shell = get_ipython()
except NameError:
    # Running as a plain Python script: there are no magics to configure.
    ipython_shell = None

if ipython_shell is not None:
    for _magic, _argument in (('load_ext', 'autoreload'),
                              ('autoreload', '2'),
                              ('matplotlib', 'qt')):
        try:
            ipython_shell.run_line_magic(_magic, _argument)
        except Exception as err:
            # Best effort: a missing extension or GUI backend should not stop
            # the analysis, but the failure is reported instead of hidden.
            print(f"Could not run %{_magic} {_argument}: {err}")

import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
from shutil import copyfile

import caiman as cm
from caiman.utils.visualization import inspect_correlation_pnr
from caiman.motion_correction import compute_metrics_motion_correction

import logging
from fancylog import fancylog
import fancylog as package


# Ask OpenCV to run its functions sequentially (setNumThreads(0)).
# cv2 is imported here because nothing else in this cell imports it: without
# the import the call below always failed with a NameError that was silently
# swallowed, so the setting was never applied.
try:
    import cv2
    cv2.setNumThreads(0)
except ImportError:
    # OpenCV is not installed: nothing to configure.
    pass

from fcutils.file_io.io import load_yaml
from fcutils.file_io.utils import check_create_folder, get_file_name, check_file_exists
from fcutils.plotting.utils import clean_axes, save_figure, add_colorbar_to_img
from fcutils.plotting.colors import *

from movie_visualizer import compare_videos
from utils import load_mmap_video_caiman

## Get files and metadata

The evaluation of MC is based on 3 videos:

- raw (ffcsub) video
- raw with rigid MC
- raw with piecewise MC

Note that the shifts for the motion correction are computed on transformed and normalized video, but they're then used to correct the raw video. 
This allows us to estimate the quality of the motion correction without worrying about how the normalization altered the values in the videos (which would throw off some of the metrics used here)

In [None]:
# Recording to analyse: point `fld` at the folder of the session of interest
fld = 'D:\\Dropbox (UCL - SWC)\\Project_vgatPAG\\analysis\\doric\\BF164p2\\19JUN26'

# Analysis parameters are read from a YAML file inside the recording folder
metadata = load_yaml(os.path.join(fld, "01_PARAMS", "analysis_metadata.yml"))

# Plots and other outputs are saved in this folder
output_fld = os.path.join(fld, metadata['outputfld'])

# Start fancylog logging, using the recording's 02_LOGS folder
logging_file = fancylog.start_logging(
    os.path.join(fld, "02_LOGS"),
    package,
    file_log_level="INFO",
    verbose=False,
    filename='motion_correction_evaluation_logs',
)

# Keep a record of the parameters used for this run
logging.info("ANALYSIS METADATA FILE:")
for key, value in metadata.items():
    logging.info(f"{key}:  {value}")
logging.info(f"Output folder: {output_fld}")

# The paths to the videos are listed in a YAML file in the output folder;
# check_file_exists is asked to raise an error if that file is missing
paths_file = os.path.join(output_fld, "video_paths.yml")
check_file_exists(paths_file, raise_error=True)
video_paths = load_yaml(paths_file)
logging.info(f"Video paths: {video_paths}")


## Inspect videos
Use this code to look at the videos side by side. 

In [None]:
# Show the raw, rigid-MC and piecewise-MC videos together.
# If a runtime error is thrown, simply run this cell a second time.
compare_videos(
    raw=video_paths['raw'],
    rigid=video_paths['raw_rig_mc'],
    pw=video_paths['raw_pw_mc'],
)

## Compute quality metrics

### Correlation
Look at the correlation between each frame and a reference frame (e.g. avg frame across entire video). If the motion correction worked, the correlation of the motion corrected frames should be higher than that of the raw frames (if stuff moves, frames won't be correlated). This metric is in part influenced by neural activity but not too much. 

### Crispness
Another metric is crispness. If the motion correction worked, the average frame should be crisper (less blurry). 

### Optic flow
Finally the last metric checks the optic flow across the video. If the video has been correctly motion corrected, the residual optic flow should be minimal. 

In [None]:
# Compute the quality metrics for the three videos (takes ~5 mins)

# Load the rigid-MC video; in this cell only `dims` is used from the result
Yr, dims, T, images = load_mmap_video_caiman(video_paths['raw_rig_mc'])

winsize = 100
swap_dim = False
resize_fact_flow = .2    # downsample for computing ROF

logging.info("Computing quality metrics")

# Positional and keyword arguments shared by the three calls below
args = [dims[0], dims[1], swap_dim]
kwargs = {
    "winsize": winsize,
    "play_flow": False,
    "resize_fact_flow": resize_fact_flow,
}


def _metrics_for(video_key):
    """Run compute_metrics_motion_correction on the video stored under `video_key`."""
    return compute_metrics_motion_correction(video_paths[video_key], *args, **kwargs)


# Raw video
(tmpl_orig, correlations_orig, flows_orig,
 norms_orig, crispness_orig) = _metrics_for('raw')

# Rigid motion correction
(tmpl_rig, correlations_rig, flows_rig,
 norms_rig, crispness_rig) = _metrics_for('raw_rig_mc')

# Piecewise motion correction
(tmpl_els, correlations_els, flows_els,
 norms_els, crispness_els) = _metrics_for('raw_pw_mc')

In [None]:
# Copy metric files to the output directory
ttles = ["raw", "rigid", "piecewise"]
files = [video_paths['raw'], video_paths['raw_rig_mc'], video_paths['raw_pw_mc']]

# First candidate name for each metrics file, looked up in the recording folder
metric_files = [os.path.join(fld, get_file_name(f) + "._metrics.npz") for f in files]

# When the first candidate does not exist, fall back to the name without the
# extra dot before "_metrics"
_metric_files = [
    candidate if os.path.isfile(candidate)
    else candidate.replace("._metrics.npz", "_metrics.npz")
    for candidate in metric_files
]

# Each file keeps its name and goes into the output folder. The path is built
# from output_fld directly instead of joining it after the source directory,
# which only gave the right result because output_fld is an absolute path.
dests = [os.path.join(output_fld, os.path.basename(src)) for src in _metric_files]

for src, dest in zip(_metric_files, dests):
    if not os.path.isfile(src):
        # Same exception type copyfile would raise, with an explicit message
        raise FileNotFoundError(f"Metrics file not found: {src}")
    copyfile(src, dest)

## Plot quality metrics

### Correlation

In [None]:
# Figure summarising the frame-by-frame correlation with the mean frame
f = plt.figure(figsize=(20, 10))

# Top row: one correlation trace per condition, over all frames
ax = plt.subplot(211)
_traces = (
    (correlations_orig, goldenrod, "original", 3, 1),
    (correlations_rig, darkseagreen, "rigid", 2, .8),
    (correlations_els, salmon, "piecewise", 2, .6),
)
for _values, _colour, _label, _width, _opacity in _traces:
    ax.plot(_values, color=_colour, label=_label, lw=_width, alpha=_opacity)
ax.set(title="Frame by frame correlation to mean frame", xlabel="frame", ylabel="correlation", ylim=[0, 1])
ax.legend()

# Bottom row: pairwise scatter plots with the identity line for reference
# (original vs rigid, original vs piecewise, rigid vs piecewise)
_pairs = (
    (234, correlations_orig, correlations_rig, "original", "rigid", darkseagreen),
    (235, correlations_orig, correlations_els, "original", "piecewise", salmon),
    (236, correlations_rig, correlations_els, "rigid", "piecewise", blackboard),
)
for _position, _xs, _ys, _xname, _yname, _colour in _pairs:
    ax = plt.subplot(_position)
    ax.scatter(_xs, _ys, color=_colour, alpha=.3)
    ax.plot([0, 1], [0, 1], '--', lw=2, alpha=.8, color=[.4, .4, .4])
    ax.set(xlabel=_xname, ylabel=_yname, xlim=[.3, .7], ylim=[.3, .7])
    ax.axis('square')

# save
clean_axes(f)
save_figure(f, os.path.join(output_fld, "frames_correlation_to_reference"))

### Crispness

In [None]:
# Log the crispness of each condition, truncated to an integer
msg = (
    "Crispness:\n"
    f"  original: {int(crispness_orig)}\n"
    f"  rigid:  {int(crispness_rig)}\n"
    f"  piecewise: {int(crispness_els)}"
)
logging.info(msg)




### Optic flow

In [None]:
# plot the results of Residual Optical Flow
# One row per condition: mean image on the left, mean optical flow on the right
f, axarr = plt.subplots(figsize=(20, 10), ncols=2, nrows=3)

# Videos listed in the same order as `dests`/`ttles` (raw, rigid, piecewise),
# so that every row shows the mean image of its own video. Previously the
# video was guessed from the metrics file name, which never matched, and the
# raw video's mean image was drawn in all three rows.
videos = [video_paths['raw'], video_paths['raw_rig_mc'], video_paths['raw_pw_mc']]

for i, (fl, video, ttl) in enumerate(zip(dests, videos, ttles)):
    if not os.path.isfile(fl):
        raise ValueError(f"Metrics file not found: {fl}")

    # Mean frame of this condition's video, with 12 pixels trimmed on each side
    mean_img = np.mean(cm.load(video), 0)[12:-12, 12:-12]

    # Clip the colour range to the 0.5-99.5 percentiles of the mean image
    lq, hq = np.nanpercentile(mean_img, [.5, 99.5])

    axarr[i, 0].imshow(mean_img, vmin=lq, vmax=hq)
    axarr[i, 0].set(title="Mean image " + ttl)

    # Read the flows and close the .npz file straight away
    with np.load(fl) as ld:
        flows = ld['flows']

    # Magnitude of the two flow components, averaged over the first axis
    img = axarr[i, 1].imshow(np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0))
    add_colorbar_to_img(img, axarr[i, 1], f)

    axarr[i, 1].set(title="mean optical flow " + ttl)

# save
f.tight_layout()
clean_axes(f)
save_figure(f, os.path.join(output_fld, "optic_flow_summary"))

### Save residual optic flow for each condition
This is used later on to plot the components found by CNMF-E over the residual optic flow, to exclude components in poorly motion corrected parts of the frame

In [None]:
# Save the average residual flow for each condition as an image and as an array
cmap = "viridis"

for fl, ttl in zip(dests, ttles):
    # Read the flows once and close the .npz file straight away
    with np.load(fl) as ld:
        flows = ld['flows']

    # Magnitude of the two flow components, averaged over the first axis
    mean_flow = np.mean(np.sqrt(flows[:, :, :, 0]**2 + flows[:, :, :, 1]**2), 0)

    save_path = os.path.join(output_fld, f"residual_opticflow_{ttl}")

    # Give the image an explicit .png extension: without one, the output format
    # and file name are left to matplotlib's format inference, which is not
    # reliable for an extension-less path.
    plt.imsave(save_path + ".png", mean_flow, cmap=cmap)
    np.save(save_path + ".npy", mean_flow)

    logging.info(f"Saving residual optic flow for {ttl} at: {save_path}")
