## Run CNMF source extraction on movies
Step 2 of the Caiman processing pipeline for dendritic two-photon calcium imaging movies. This part uses mmap files as input. These are created during motion correction with the Caiman toolbox (see `01_Preprocess_MC_3D.ipynb`). 

### Imports & Setup
The first cells import the various Python modules required by the notebook. In particular, a number of modules are imported from the CaImAn package. In addition, we also set up the environment so that everything works as expected.

In [None]:
# Generic imports
# from __future__ import absolute_import, division, print_function
# from builtins import *

import os, platform, glob, sys, re, copy
import fnmatch
import json
import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import savemat

from IPython.display import clear_output

# Import Bokeh library
import bokeh.plotting as plotting
from bokeh.plotting import Figure, show
from bokeh.layouts import gridplot
from bokeh.models import Range1d, CrosshairTool, HoverTool, Legend
from bokeh.io import output_notebook, export_svgs
from bokeh.models.sources import ColumnDataSource

%matplotlib inline

In [None]:
# Configure Bokeh notebook output using the INLINE resources.
# This has to be in a separate cell, otherwise it won't work.
from bokeh import resources
output_notebook(resources=resources.INLINE)

In [None]:
# On Linux the caiman folder in the home directory has to be added to the
# Python path.
if platform.system() == 'Linux':
    sys.path.append(os.path.expanduser('~/caiman'))

# Environment variables for parallel processing: one thread per library.
# NOTE(review): numpy is already imported in an earlier cell -- confirm that
# these variables still take effect when they are set at this point.
for _thread_var in ('MKL_NUM_THREADS',
                    'OPENBLAS_NUM_THREADS',
                    'VECLIB_MAXIMUM_THREADS'):
    os.environ[_thread_var] = '1'

In [None]:
# CaImAn imports
import caiman as cm
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.components_evaluation import estimate_components_quality as estimate_q
from caiman.components_evaluation import estimate_components_quality_auto
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour
from caiman.source_extraction.cnmf import utilities as cnmf_utils
import caiman_utils as cm_utils
import utils as utils

In [None]:
# Re-import the two local helper modules so that edits made to them on disk
# are picked up without restarting the kernel.
from importlib import reload
reload(cm_utils)
reload(utils)

### Select files
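The following need to be specified:
- animal_folder, date_folder, session_folder ... sub-folders that locate the session data below the platform-specific data root (e.g. `/home/ubuntu/Data` on Linux); together they form `data_folder`
- group_id ... identifier of the file group that should be processed (e.g. 'G0')
- mc_output ... select if output of rigid ('rig') or piece-wise rigid ('els') motion correction should be used (currently only 'rig' is tested and works)
- remove_bad_frames ... whether the bad frames specified in the Json file should be removed
- max_files ... maximum number of files to process, e.g. for testing (if 0, all files will be processed); note that this option is currently not defined in the cell below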

In [None]:
# --- Dataset selection -------------------------------------------------------
# Sub-folders (below the data root) that identify the session to process.
animal_folder = 'M3_October_2018'
date_folder = 'M3_2018-10-02'
session_folder = 'S1'
group_id = 'G0'

mc_output = 'rig'
remove_bad_frames = True # remove bad frames specified in Json file

# create the complete path to the data folder
# The data root depends on the machine the notebook is running on.
_data_roots = {
    'Linux': '/home/ubuntu/Data',
    'Darwin': '/Users/Henry/Data/temp/Dendrites_Gwen',
}
_system = platform.system()
if _system not in _data_roots:
    # Fail with a clear message instead of a NameError on `data_folder`.
    raise RuntimeError('No data root configured for platform %r. '
                       'Add an entry to _data_roots.' % (_system,))
data_folder = os.path.join(_data_roots[_system], animal_folder, date_folder, session_folder)

In [None]:
# Select the mmap files created during motion correction: the name has to
# start with "<date>_<session>", end in ".mmap", contain the motion correction
# type and the group id, and must not contain "remFrames".
file_prefix = '%s_%s' % (date_folder, session_folder)
all_files = os.listdir(data_folder)

mmap_files = []
for candidate in all_files:
    if not (candidate.startswith(file_prefix) and candidate.endswith('.mmap')):
        continue
    if mc_output not in candidate or group_id not in candidate:
        continue
    if 'remFrames' in candidate:
        continue
    mmap_files.append(candidate)
mmap_files.sort()

# every selected file is treated as one imaging plane
n_planes = len(mmap_files)

print('Found %d mmap files. Check allocation to planes!' % (n_planes))
for i_plane, mmap_name in enumerate(mmap_files):
    print('Plane %d: %s' % (i_plane, mmap_name))

In [None]:
# turn the file names into complete paths
mmap_files = [os.path.join(data_folder, x) for x in mmap_files]

# get metadata
# The metadata is stored in the "Join" Json file of this session / group.
# The companion "...badFrames.json" file has to be skipped. This is done with
# an explicit suffix test: an fnmatch pattern like "[!badFrames]" is a
# character class (it only looks at ONE character) and would wrongly reject
# every file whose name ends in one of the letters b, a, d, F, r, m, e or s.
meta_pattern = '%s_%s_Join_%s_*.json' % (date_folder, session_folder, group_id)
meta = None
for file in os.listdir(data_folder):
    if fnmatch.fnmatch(file, meta_pattern) and not file.endswith('badFrames.json'):
        # context manager makes sure the file handle is closed again
        with open(os.path.join(data_folder, file)) as meta_file:
            meta = json.load(meta_file)
        break
if meta is None:
    raise FileNotFoundError('No metadata file matching %r found in %s'
                            % (meta_pattern, data_folder))

trial_index = np.array(meta['trial_index'])
# the frame rate in the metadata is divided by the number of planes
frame_rate = meta['frame_rate'] / n_planes

### Load data and remove bad frames

In [None]:
fname_list = []
images_list = []

# First, build the sorted list of unique bad frame indices over all planes
# combined, so that the same frames are removed from every plane.
# NOTE(review): the `remove_bad_frames` flag from the file selection cell is
# not consulted here -- the frames are always removed.
bad_frames_per_file = [cm_utils.getBadFrames(fname) for fname in mmap_files]
bad_frames = np.unique(
    np.concatenate([np.array([], dtype='int64')] + bad_frames_per_file))

# Then remove these frames from each file; keep the resulting file names and
# image stacks per plane.
for fname in mmap_files:
    Yr, dims = cm_utils.loadData(fname)
    cleaned = cm_utils.removeBadFrames(fname, trial_index, Yr, dims,
                                       bad_frames, data_folder)
    images, Y, fname_rem, bad_frames_by_trial, trial_idx = cleaned
    fname_list.append(fname_rem)
    images_list.append(images)

# from here on use the trial index returned by the last removeBadFrames call
trial_index = trial_idx

### Display frame average for each plane

In [None]:
# One subplot per plane showing the average over the first axis of the stack.
# Note: after the loop `avg_img` holds the average of the LAST plane.
plt.figure(figsize=(30, 60))
for ix_plane in range(n_planes):
    avg_img = np.mean(images_list[ix_plane], axis=0)
    plt.subplot(1, n_planes, ix_plane + 1)
    plt.imshow(avg_img, cmap='gray')
    plt.title('Frame average - Plane %d' % (ix_plane))

### Specify if plane contains dendritic signals
CaImAn uses different initialization methods depending on whether the signals are dendritic or somatic. Therefore, we need to specify the types of signal expected in each plane.

In [None]:
# One flag per plane (indexed by plane number in the CNMF loop):
# True selects the 'sparse_nmf' initialization, False selects 'greedy_roi'.
is_dendritic = [True, False, True]

# Fail early: too few entries would otherwise raise an IndexError in the
# middle of the CNMF run, and surplus entries would be silently ignored.
if len(is_dendritic) != n_planes:
    raise ValueError('is_dendritic has %d entries, but %d planes were found'
                     % (len(is_dendritic), n_planes))

### Setup cluster
The default backend mode for parallel processing is through the multiprocessing package. This will allow us to use all the cores in the VM.

In [None]:
# (Re-)start the cluster: a cluster left over from a previous run of this
# cell is terminated first.
n_processes = 4 # number of compute processes (None to select automatically)

_previous_dview = locals().get('dview')
if _previous_dview is not None:
    _previous_dview.terminate()

c, dview, n_processes = cm.cluster.setup_cluster(backend='local',
                                                 n_processes=n_processes,
                                                 single_thread=False)

### Parameters for source extraction
Next, we define the important parameters for calcium source extraction. These parameters will have to be iteratively refined for the respective datasets.


In [None]:
# Parameter values for CNMF. They are collected into the parameter dictionary
# (`opts_dict`) in the next cell.

# dataset dependent parameters
decay_time = 0.4                            # length of a typical transient in seconds

# parameters for source extraction and deconvolution
# `p` is used for the refit in the CNMF loop (the first pass there runs with p=0).
p = 1                       # order of the autoregressive system
gnb = 2                     # number of global background components (passed as 'nb')
# NOTE(review): merge_thresh and gSig are defined here but are not passed on to opts_dict.
merge_thresh = 0.8          # merging threshold, max correlation allowed
rf = [25, 50]                     # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride_cnmf = 5             # amount of overlap between the patches in pixels
K = 20                       # number of components per patch
gSig = [4, 4]               # expected half size of neurons in pixels

# NOTE: method_init is overridden per plane in the CNMF loop, depending on `is_dendritic`.
method_init = 'sparse_nmf'  # initialization method (if analyzing dendritic data use 'sparse_nmf', else 'greedy_roi')
# previously used value: alpha_snmf = 10e2
alpha_snmf = 100            # sparsity penalty for dendritic data analysis through sparse NMF
normalize_init = True      # default is True

ssub = 1                    # spatial subsampling during initialization
tsub = 1                    # temporal subsampling during initialization

# parameters for component evaluation
min_SNR = 2.0               # signal to noise ratio for accepting a component
rval_thr = 0.85              # space correlation threshold for accepting a component
cnn_thr = 0.99              # threshold for CNN based classifier (passed as 'min_cnn_thr')
cnn_lowest = 0.1 # neurons with cnn probability lower than this value are rejected

In [None]:
# Collect the settings for the CaImAn Parameters object.
# Unspecified parameters get default values.
# NOTE(review): `fname_rem` is the file name left over from the last iteration
# of the bad-frame removal loop, i.e. it refers to a single plane only.
opts_dict = dict(
    fnames=fname_rem,
    fr=frame_rate,
    decay_time=decay_time,
    p=1,
    nb=gnb,
    rf=rf,
    K=K,
    stride=stride_cnmf,
    method_init=method_init,
    alpha_snmf=alpha_snmf,
    normalize_init=normalize_init,
    rolling_sum=True,
    only_init=True,
    ssub=ssub,
    tsub=tsub,
    min_SNR=min_SNR,
    rval_thr=rval_thr,
    use_cnn=True,
    min_cnn_thr=cnn_thr,
    cnn_lowest=cnn_lowest,
)

opts = params.CNMFParams(params_dict=opts_dict)

To get a dict with all parameters, use `opts.to_dict()`

#### Run CNMF on patches

In [None]:
# For each plane:
#   1. extract spatial and temporal components on patches and combine them;
#      for this step deconvolution is turned off (p=0)
#   2. re-run CNMF (refit) on the result with the AR order `p` to refine the
#      components and perform deconvolution
opts.set('temporal', {'p': 0})
cnm_list = []

t_start = time.time()
for ix_plane in range(n_planes):
    # every plane gets its own copy of the parameters
    opts_plane = copy.deepcopy(opts)
    # Record this plane's own file; the shared `opts` only carries the file
    # name of the plane that was loaded last.
    opts_plane.set('data', {'fnames': fname_list[ix_plane]})
    # the initialization method depends on the type of signal in the plane
    init_method = 'sparse_nmf' if is_dendritic[ix_plane] else 'greedy_roi'
    opts_plane.set('init', {'method_init': init_method})

    cnm = cnmf.CNMF(n_processes, params=opts_plane, dview=dview)
    cnm.fit(images_list[ix_plane])

    # switch deconvolution back on for the refit
    cnm.params.set('temporal', {'p': p})
    cnm2 = cnm.refit(images_list[ix_plane], dview=dview)

    cnm_list.append(cnm2)

    # discard the output produced while processing this plane
    clear_output()

t_elapsed = time.time() - t_start
print('\nFinished Source Extract in %1.2f s' % (t_elapsed))

In [None]:
# Evaluate the components of every plane and report how many were classified
# as good / bad. The evaluation is called on the estimates object stored in
# cnm_list, so nothing has to be written back to the list.
for ix_plane, cnm in enumerate(cnm_list):
    plane_estimates = cnm.estimates
    plane_estimates.evaluate_components(images_list[ix_plane], cnm.params, dview=dview)

    n_good = len(plane_estimates.idx_components)
    n_bad = len(plane_estimates.idx_components_bad)
    print('\nPlane %d' % (ix_plane))
    print('Found %d good / %d bad components\n' % (n_good, n_bad))

Plot contours of the selected (accepted) components for each plane. Rejected components can be inspected with the viewer cells below.

In [None]:
# For every plane: compute the local correlation image (NaNs set to 0) and
# draw the contours of the components in idx_components on top of it.
for ix_plane, cnm in enumerate(cnm_list):
    plane_movie = images_list[ix_plane].transpose(1, 2, 0)
    Cn = cm.local_correlations(plane_movie)
    nan_mask = np.isnan(Cn)
    Cn[nan_mask] = 0
    cnm.estimates.plot_contours_nb(img=Cn, idx=cnm.estimates.idx_components)

View traces of accepted and rejected components. Note that if you get a data rate error, you can start the Jupyter notebook server with: `jupyter notebook --NotebookApp.iopub_data_rate_limit=1.0e10`

In [None]:
# Interactive viewer for the accepted components of one plane.
plane_ix = 0
cnm = cnm_list[plane_ix]

# local correlation image as background (NaNs replaced by 0)
plane_movie = images_list[plane_ix].transpose(1, 2, 0)
Cn = cm.local_correlations(plane_movie)
Cn[np.isnan(Cn)] = 0

cnm.estimates.nb_view_components(img=Cn, idx=cnm.estimates.idx_components)

In [None]:
# Interactive viewer for the rejected components of one plane.
plane_ix = 0
cnm = cnm_list[plane_ix]

# local correlation image as background (NaNs replaced by 0)
plane_movie = images_list[plane_ix].transpose(1, 2, 0)
Cn = cm.local_correlations(plane_movie)
Cn[np.isnan(Cn)] = 0

rejected_idx = cnm.estimates.idx_components_bad
if len(rejected_idx) > 0:
    cnm.estimates.nb_view_components(img=Cn, idx=rejected_idx)
else:
    print("No components were rejected.")

In [None]:
# Pull the results of one plane out of the estimates object.
plane_ix = 0
cnm = cnm_list[plane_ix]
plane_estimates = cnm.estimates

A = plane_estimates.A      # n_pixel x n_components sparse matrix (component locations)
C = plane_estimates.C      # n_component x t np.array (fitted signal)
b = plane_estimates.b      # ? np.array
f = plane_estimates.f      # ? np.array (b / f related to global background components)
YrA = plane_estimates.YrA  # n_component x t np.array (residual)
S = plane_estimates.S      # deconvolved signal (spike rate(ish))
sn = plane_estimates.sn    # n_pixel np.array (SNR?)

# dense copy of the component locations
A_dense = A.todense()

# indices of the components classified as good / bad
idx_comps = plane_estimates.idx_components
idx_comps_bad = plane_estimates.idx_components_bad

In [None]:
# Plot good components on background image and as component map:
# one row per good component, contour on the frame average in the left
# column and the component map in the right column.

# Use the frame average of the plane that is being inspected. The `avg_img`
# left over from the overview plot belongs to the LAST plane, which is only
# correct if plane_ix happens to be the last plane.
avg_img = np.mean(images_list[plane_ix], axis=0)

n_good = len(idx_comps)
plt.figure(figsize=(20, 40))
for i_comp, comp_id in enumerate(idx_comps):
    # left column: contour of the component on the frame average
    plt.subplot(n_good, 2, 2 * i_comp + 1)
    cm.utils.visualization.plot_contours(A[:, comp_id], avg_img, cmap='gray',
                                         colors='r', display_numbers=False)
    # right column: component map reshaped to the image dimensions
    component_img = np.array(np.reshape(A_dense[:, comp_id], dims, order='F'))
    plt.subplot(n_good, 2, 2 * i_comp + 2)
    plt.imshow(component_img)
    # the number in the title is the position within the good components
    plt.title('Component %1.0f' % (i_comp))

plt.tight_layout()

In [None]:
# Component indices before the manual re-classification.
print('Good components: ', idx_comps, sep='\n')
print('Bad components: ', idx_comps_bad, sep='\n')

In [None]:
# Positions WITHIN the list of good components (i.e. 0,1,2 as shown in the
# plot above) that should be moved to the bad components.
# Note: the positions refer to the current content of idx_comps, so running
# this cell a second time removes further components.
comps_to_exclude = [0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16]

# look up the component indices before the good list is modified
newly_rejected = idx_comps[comps_to_exclude]

# add to bad components (kept sorted)
idx_comps_bad = np.sort(
    np.concatenate((np.ravel(idx_comps_bad), np.ravel(newly_rejected))))
# remove from good components
idx_comps = np.delete(idx_comps, comps_to_exclude)

In [None]:
# Component indices after the manual re-classification.
print('Good components: ', idx_comps, sep='\n')
print('Bad components: ', idx_comps_bad, sep='\n')

In [None]:
# Create component_matrix with the maps of the good components stacked along
# the third axis (one slice per good component).
component_imgs = [np.array(np.reshape(A_dense[:, comp_id], dims, order='F'))
                  for comp_id in idx_comps]
if component_imgs:
    # A single dstack call always yields a 3-D array (also for exactly one
    # component) and avoids re-copying the growing stack for every component.
    component_matrix = np.dstack(component_imgs)
else:
    # No good components left: store an empty stack instead of leaving
    # component_matrix undefined (or stale from a previous run).
    component_matrix = np.empty(tuple(dims) + (0,), dtype=A_dense.dtype)

In [None]:
# Save the results of the selected plane as .npz file in the data folder.
# NOTE(review): `Cn` is whatever correlation image was computed last in the
# cells above -- confirm it belongs to plane_ix.
npz_basename = '%s_%s_Join_%s_P%d_results_CNMF.npz' % (date_folder, session_folder,
                                                       group_id, plane_ix)
npz_name = os.path.join(data_folder, npz_basename)

npz_content = dict(Cn=Cn, A=A.todense(), C=C, b=b, f=f, YrA=YrA, sn=sn, S=S,
                   dims=dims, idx_components=idx_comps,
                   idx_components_bad=idx_comps_bad)
np.savez(npz_name, **npz_content)

#### Extract DF/F values and select high-quality components

In [None]:
# Compute DF/F and keep only the selected components for EVERY plane.
# The loop variable `cnm` is used directly: re-assigning it from
# cnm_list[plane_ix] inside the loop would process the single plane
# `plane_ix` over and over and leave all other planes untouched.
for ix_plane, cnm in enumerate(cnm_list):
    cnm.estimates.detrend_df_f(quantileMin=8, frames_window=250)
    # use the component indices stored in the estimates object
    cnm.estimates.select_components(use_object=True)

In [None]:
# Interactive viewer for the components of one plane after the selection.
plane_ix = 0
cnm = cnm_list[plane_ix]

# local correlation image as background (NaNs replaced by 0)
plane_movie = images_list[plane_ix].transpose(1, 2, 0)
Cn = cm.local_correlations(plane_movie)
Cn[np.isnan(Cn)] = 0

cnm.estimates.nb_view_components(img=Cn, denoised_color='red')

### Calculate DFF and plot traces
The CaImAn function `detrend_df_f` uses a sliding window percentile filter to determine the baseline and compute DFF.
Note: for noisy traces and/or high levels of activity, `detrend_df_f` sometimes seems to produce unexpected results (i.e. traces whose shape differs a lot from the extracted component traces). It might be better to use the extracted component traces (see below) for downstream analysis.

In [None]:
# DF/F for all components of the unravelled plane ...
dff_all_comps = cnmf_utils.detrend_df_f(A, b, C, f, YrA=YrA,
                                        quantileMin=8, frames_window=500)
# ... of which only the good components are kept
F_dff = dff_all_comps[idx_comps, :]

# time axis: frame number divided by the frame rate
n_timepoints = F_dff.shape[-1]
t = np.arange(n_timepoints) / frame_rate

In [None]:
# Traces of the good components (code from nb_view_patches)
C_good, YrA_good, S_good = (traces[idx_comps] for traces in (C, YrA, S))
# C_good   ... denoised signal
# YrA_good ... residual
# S_good   ... deconvolved signal

# ROI signal = denoised signal + residual
Y_r = C_good + YrA_good

In [None]:
# source file names and frame counts from the metadata
source_frames = np.array(meta['source_frames'])
source_files = meta['source_file']

# trial name = source file name without the '_crop.tif' part
trial_names = []
for source_name in source_files:
    trial_names.append(source_name.replace('_crop.tif', ''))

# get corresponding trial name for each frame
trial_names_frames = [trial_names[trial_no] for trial_no in trial_index]

### Create stacked plot of all components
This plot also shows the trial for each frame.

In [None]:
comp_idx = [0,1,2] # select index of components to plot, e.g. [0,1,2] / use None to plot all components
source = 'Y_r' # select the data that should be plotted ('F_dff', 'Y_r', 'C_good', 'S_good')

# pick the traces that belong to the selected source
if source == 'F_dff':
    source_data = F_dff
elif source == 'Y_r':
    source_data = Y_r
elif source == 'C_good':
    source_data = C_good
elif source == 'S_good':
    source_data = S_good
else:
    raise Exception('Specified source_data is not implemented')

# restrict to the requested components (positions within the good components)
if comp_idx is not None:
    source_data_plot = source_data[comp_idx,:]
    ix = idx_comps[comp_idx]
else:
    source_data_plot = source_data
    ix = idx_comps

# The figure is deliberately NOT called `p`: that name holds the order of the
# autoregressive system and is read again when the CNMF cell is re-run.
fig = Figure(plot_width=900, plot_height=600,
             title=('%s %s CNMF Results' % (date_folder, session_folder)))
# one legend entry per plotted trace, numbered consecutively from 0
legend_text = ['Component %1d' % (x) for x in range(len(ix))]
# this is the call to the plotting function (change args. as required)
utils.plotTimeseries(fig, t, source_data_plot,
               legend=legend_text,
               stack=True,
               xlabel='Time [s]', ylabel=source,
               output_backend='canvas',
               trial_index=trial_index,
               trial_names_frames=trial_names_frames
              )

### Split up by trials and save as .mat

In [None]:
# Sanity check: the frames of all source files minus the removed bad frames
# have to add up to the number of timepoints in the traces.
n_expected_frames = np.sum(source_frames) - len(bad_frames)
if n_expected_frames != F_dff.shape[-1]:
    raise Exception('Sum of source frames minus number of bad frames must be equal to number of timepoints.')

In [None]:
# Split all traces by trial. Every dictionary gets one entry per source file,
# keyed by a name that is valid as Matlab variable / field name.
results_dff = dict()
results_Yr = dict()
results_C = dict()
results_S = dict()
removed_frames = dict()
for ix, trial_file in enumerate(source_files):
    # get indices for current trial's frames
    trial_indices = np.where(trial_index==ix)[0]

    # frames that were removed from this trial (if any)
    removed_frames_trial = bad_frames_by_trial[ix] if ix in bad_frames_by_trial else []

    # create a valid Matlab variable / field name from the part of the source
    # file name before the first '/'.
    # split() is used instead of slicing with find('/'): find returns -1 if
    # there is no '/', which would silently cut off the last character.
    trial_folder = trial_file.split('/', 1)[0]
    field_name = ('x' + trial_folder).replace('_Live','').replace('-','_')

    results_dff[field_name] = F_dff[:,trial_indices]
    results_Yr[field_name] = Y_r[:,trial_indices]
    results_C[field_name] = C_good[:,trial_indices]
    results_S[field_name] = S_good[:,trial_indices]
    removed_frames[field_name] = removed_frames_trial

In [None]:
# prepare the dictionary for saving as mat file
# the field names will be the variable names in Matlab
mdict={'trials': [str(x) for x in source_files],
       'dff_trial': results_dff,
       'Yr_trial': results_Yr,
       'C_trial': results_C,
       'Deconv_trial': results_S,
       'removed_frames': removed_frames,
       # Frame average of the plane whose results are saved. It is computed
       # here because the `avg_img` left over from the overview plot belongs
       # to the LAST plane, not necessarily to plane_ix.
       'mean_image': np.mean(images_list[plane_ix], axis=0),
       'spatial_components': component_matrix
      }

In [None]:
# save the .mat file
# matfile_name is already the complete path inside data_folder, so it must not
# be joined with data_folder a second time (for a relative data_folder that
# would point to a nested, non-existing folder).
matfile_name = os.path.join(data_folder, '%s_%s_Join_%s_P%d_results_CNMF.mat' % (date_folder, session_folder, group_id, plane_ix))
savemat(matfile_name, mdict=mdict, long_field_names=True)