# Inspect/Visualise CNMF-E fit to motion-corrected data

In [None]:
# Enable autoreload and the Qt matplotlib backend when running under IPython/Jupyter.
# Best effort: outside IPython (plain `python`), `get_ipython` is undefined and the
# magics are skipped; a magic that fails (e.g. no Qt bindings for `%matplotlib qt`)
# is reported but does not stop the others from running.
try:
    ipython_shell = get_ipython()
except NameError:
    ipython_shell = None

if ipython_shell is not None:
    for magic_name, magic_args in (
        ("load_ext", "autoreload"),
        ("autoreload", "2"),
        ("matplotlib", "qt"),
    ):
        try:
            # `run_line_magic` replaces the deprecated `InteractiveShell.magic`.
            ipython_shell.run_line_magic(magic_name, magic_args)
        except Exception as err:
            print(f"Could not run %{magic_name} {magic_args}: {err}")

import matplotlib.pyplot as plt
import numpy as np
import os


import caiman as cm
from caiman.source_extraction.cnmf.spatial import threshold_components
from caiman.utils.visualization import inspect_correlation_pnr

import logging
from fancylog import fancylog
import fancylog as package


import cv2
# Turn off OpenCV's internal threading (threads == 0 runs its functions sequentially),
# presumably to avoid oversubscribing the CPU alongside the CaImAn worker processes.
# Best effort: builds lacking `setNumThreads`, or an OpenCV error, are tolerated.
try:
    cv2.setNumThreads(0)
except (AttributeError, cv2.error):
    pass
import bokeh.plotting as bpl

from fcutils.video.utils import open_cvwriter, get_cap_from_images_folder, save_videocap_to_video
from fcutils.plotting.utils import clean_axes, save_figure

from utils import print_cnmfe_components, plot_components_over_image, load_fit_cnfm_and_data, load_fit_cnfm
from utils import start_server, load_params,  log_cnmfe_components

import warnings

# Render bokeh plots inline and start the CaImAn parallel cluster used by later cells.
bpl.output_notebook()
c, dview, n_processes = start_server()

# TODO: add a method to load the metrics files and draw the component contours on top
# of the optic-flow plots.
# Emitted as a warning rather than raised: an exception here would abort "Run All"
# after start_server() has already spawned the cluster, leaving it running with no
# cm.stop_server() call.
warnings.warn("add method to load metrics files and add contours on top of optic flow stuff")


## Load Model and data

In [None]:
# Load everything
# Session folder holding the fitted CNMF-E model and the motion-corrected data.
fld = "D:\\Dropbox (UCL - SWC)\\Project_vgatPAG\\analysis\\doric\\BF164p1\\19JUN05"

# Load the fitted model (`cnm`, plus the path it was read from) and the associated
# data and summary images used by the cells below.
(
    cnm, model_filepath, Yr, dims, T,
    images, smooth_bg, cn_filter, pnr,
) = load_fit_cnfm_and_data(fld, n_processes, dview, mc_type="els")



In [None]:
# Start logging
# Log to the session's 02_LOGS folder, recording the CNMF-E params in use.
log_dir = os.path.join(fld, "02_LOGS")
fancylog.start_logging(
    log_dir,
    package,
    file_log_level="INFO",
    variables=[cnm.params],
    verbose=True,
    filename='cnmfe_component_inspection_logs',
)
logging.info("Starting CNMF-E component inspection")

# Inspect components

Apply a number of quality-control steps to the components to keep only the good ones.
Quality-control steps include:

- evaluate components
- threshold spatial components -> remove duplicates [duplicate removal currently disabled]
- remove small/large components [currently not implemented]

There's no undo for these operations, so if you need to restart from scratch you'll need
to reload `cnm`:

    cnm = load_fit_cnfm(model_filepath, n_processes, dview)

To check the params used for quality check use:

    cnm.params.quality

In [None]:
# Display the current quality-control parameters (the notebook echoes the cell's last expression).
cnm.params.quality

In [None]:
# Component counts before any quality control: shown in the notebook, then written to the log.
print_cnmfe_components(cnm, msg="Before quality control:")
log_cnmfe_components(cnm, msg="Before quality control")

In [None]:
# Reference for the evaluation logic:
# https://github.com/flatironinstitute/CaImAn/blob/6c33118a5f55e5e178a0e18c896329f416ffef55/caiman/source_extraction/cnmf/estimates.py#L943
#
# Components are evaluated in three ways:
#   a) the shape of each component must be correlated with the data
#   b) a minimum peak SNR is required over the length of a transient
#   c) each shape passes a CNN-based classifier

# Load the evaluation thresholds for this session and install them as the 'quality' params.
# NOTE(review): the key is spelled 'cnmf_evalutaion'; it must match the params file,
# so confirm there before correcting the spelling here.
quality_params = load_params(fld)['cnmf_evalutaion']
cnm.params.set('quality', quality_params)
logging.info(f"Components quality estimation, params: {quality_params}")

# Score every component against the motion-corrected images.
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

# Report the outcome to both the notebook and the log.
print_cnmfe_components(cnm, msg="After evaluate components")
log_cnmfe_components(cnm, msg="After evaluate components")

### Remove overlapping components

In [None]:
# Estimate spatial components
# Threshold the spatial footprints. The FOV dims are assigned first, presumably
# because the thresholding reshapes footprints using `estimates.dims`.
estimates = cnm.estimates
estimates.dims = dims
estimates.threshold_spatial_components(dview=dview)

In [None]:
# Remove large/small and duplicate neurons
# cnm.estimates.remove_small_large_neurons()
# _ = cnm.estimates.remove_duplicates(plot_duplicates=False)

# print_cnmfe_components(cnm, msg="After Remove overlapping")
# log_cnmfe_components(cnm,  msg="After Remove overlapping")

## Visualise Components location

In [None]:
# Visualise good/bad components over the cn_filter image
# Contours of every component (thr=0.2 with thr_method="max") drawn over the cn_filter
# image; `only="good"` goes in the left panel, `only="bad"` in the right.
coordinates = cm.utils.visualization.get_contours(
    cnm.estimates.A, smooth_bg.shape, thr=.2, thr_method="max"
)
good_components = cnm.estimates.idx_components

f, axarr = plt.subplots(figsize=(15, 10), ncols=2)
panels = (("good", "GOOD components"), ("bad", "BAD components"))
for ax, (which, title) in zip(axarr, panels):
    plot_components_over_image(
        cn_filter, ax, coordinates, 2, good_components, cmap="gray", only=which
    )
    ax.set(title=title)
clean_axes(f)

# Keep a copy of the pre-curation figure next to the data.
img_filepath = os.path.join(fld, "components_not_curated")
save_figure(f, img_filepath)
logging.info(f"Saved image with contours at {img_filepath}")

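# MANUAL COMPONENTS CURATION
TODO: manual curation is not implemented yet, so the "curated" plot below is drawn from the same, unmodified `cnm` as the plot above.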

In [None]:
# Plot again contours of components to compare with earlier results
# Redraw the component contours after curation so they can be compared with the
# "components_not_curated" figure: GOOD panel on the left, BAD panel on the right.
good_components = cnm.estimates.idx_components
coordinates = cm.utils.visualization.get_contours(
    cnm.estimates.A, smooth_bg.shape, thr=.2, thr_method="max"
)

f, axarr = plt.subplots(figsize=(15, 10), ncols=2)
left_ax, right_ax = axarr[0], axarr[1]
plot_components_over_image(cn_filter, left_ax, coordinates, 2, good_components, cmap="gray", only="good")
plot_components_over_image(cn_filter, right_ax, coordinates, 2, good_components, cmap="gray", only="bad")
left_ax.set(title="GOOD components")
right_ax.set(title="BAD components")
clean_axes(f)

img_filepath = os.path.join(fld, "components_curated")
save_figure(f, img_filepath)
logging.info(f"Saved image with contours at {img_filepath}")

## Save updated cnmfe model
If you're happy with the results, you can save the model for further analysis


In [None]:
# Persist the quality-controlled model in the session folder for downstream analysis.
curated_name = "cnmfe_fit_curated.hdf5"
new_model_path = os.path.join(fld, curated_name)
cnm.save(new_model_path)
logging.info(f"Saving updated model at {new_model_path}")

# Visualise components signal
Create plots and videos to visualise the location and signals of the components.

In [None]:

# Interactive bokeh viewer of the selected components over the cn_filter image.
# To browse the rejected components instead, set `selected` to cnm.estimates.idx_components_bad.
selected = cnm.estimates.idx_components
cnm.estimates.nb_view_components(
    img=cn_filter,
    idx=selected,
    denoised_color='red',
    cmap='gray',
)

## Make video with components over motion corrected data

In [None]:
# Render one PNG for every 100th frame, showing the component contours over the
# smoothed frame divided by `smooth_bg`. The PNGs are then stitched into a video
# with ffmpeg (see the command at the end of this cell).
frames_fld = "D:\\Dropbox (UCL - SWC)\\Project_vgatPAG\\analysis\\doric\\Fede\\frames"  # folder where frames will be saved before creating video
os.makedirs(frames_fld, exist_ok=True)

MAKE_FRAMES = True  # set to False to skip re-rendering the frames

tot_frames = images.shape[0]
frames = np.arange(0, tot_frames, 100)

# Normalised 11x11 box (mean) filter for cv2.filter2D; the /121 (= 11 * 11)
# normalisation implies this size. NOTE(review): confirm the intended kernel size.
kernel = np.ones((11, 11), dtype=np.float32) / 121

if MAKE_FRAMES:
    plt.ioff()  # don't open a window for every figure created below
    try:
        for n, fnum in enumerate(frames):
            # 120x100 inches at dpi=5 -> 600x500 px (even dims, as yuv420p requires)
            f, ax = plt.subplots(figsize=(120, 100), dpi=5)
            img = cv2.filter2D(images[fnum, :, :].copy(), -1, kernel)
            plot_components_over_image(img / smooth_bg, ax, coordinates, 17)
            # Sequential names 0.png, 1.png, ... match ffmpeg's `%1d.png` pattern.
            f.savefig(os.path.join(frames_fld, f"{n}.png"), dpi=5)
            plt.close(f)
            print(f"Saved frame {n + 1}/{len(frames)}", end="\r")
    finally:
        # Restore interactive mode even if rendering fails part-way.
        plt.ion()

# Stitch the images into a video by running ffmpeg from the parent folder of
# `frames_fld`:
#
#     ffmpeg -framerate 10 -i frames\%1d.png -c:v libx264 -pix_fmt yuv420p out.mp4
#
# `-framerate` sets the input rate of the image sequence. Using `-vf fps=10` instead
# would read the images at the default 25 fps and then drop frames to reach 10 fps.

# Stop cluster

In [None]:
# Shut down the parallel cluster started by start_server() in the setup cell.
cm.stop_server(dview=dview)