## Run CNMF source extraction on movies
Step 2 of the Caiman processing pipeline for dendritic two-photon calcium imaging movies. This part uses mmap files as input. These are created during motion correction with the Caiman toolbox (see `01_Preprocess_MC_3D.ipynb`). 

### Imports & Setup
The first cells import the Python modules required by the notebook, including a number of modules from the Caiman package. They also set up the environment (Python path and thread settings) so that everything works as expected.

In [None]:
# Generic imports
# from __future__ import absolute_import, division, print_function
# from builtins import *

import os, platform, glob, sys, re
import fnmatch
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import savemat

from IPython.display import clear_output

# Import Bokeh library
from bokeh.plotting import Figure, show
from bokeh.layouts import gridplot
from bokeh.models import Range1d, CrosshairTool, HoverTool, Legend
from bokeh.io import output_notebook, export_svgs
from bokeh.models.sources import ColumnDataSource

%matplotlib inline

In [None]:
# Make the Caiman source tree (expected at ~/caiman) importable on Linux.
if platform.system() == 'Linux':
    sys.path.append(os.path.expanduser('~/caiman'))

# Restrict the BLAS backends to one thread each (parallelism comes from the
# compute processes instead). NOTE(review): numpy is already imported at this
# point, so this kernel's own BLAS may ignore these; presumably they matter for
# processes started later (e.g. the cluster below), which inherit the environment.
for _blas_thread_var in ('MKL_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'VECLIB_MAXIMUM_THREADS'):
    os.environ[_blas_thread_var] = '1'

In [None]:
# CaImAn imports
import caiman as cm
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.components_evaluation import estimate_components_quality as estimate_q
from caiman.components_evaluation import estimate_components_quality_auto
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour
from caiman.source_extraction.cnmf import utilities as cnmf_utils
import caiman_utils as cm_utils

### Select files and parameters
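The following need to be specified:
- animal_folder, date_folder, session_folder ... location of the data; these are joined onto a machine-specific data root (set in the same cell) to give `data_folder`
- group_id ... identifier of the plane group; it appears in the names of the motion-corrected files
- mc_output ... select if output of rigid ('rig') or piece-wise rigid ('els') motion correction should be used (currently only 'rig' is tested and works)
- remove_bad_frames ... if True, the frames listed in the corresponding `badFrames.json` file are removed before source extraction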

In [None]:
# Recording to process: <data root>/<animal>/<date>/<session>, plus the plane
# group id that appears in the motion-correction output file names.
animal_folder = 'M3_October_2018'
date_folder = 'M3_2018-10-02'
session_folder = 'S1'
group_id = 'G0'

mc_output = 'rig'           # 'rig' (rigid) or 'els' (piece-wise rigid) motion-correction output
remove_bad_frames = True    # remove bad frames specified in Json file

# create the complete path to the data folder (the data root depends on the machine)
system_name = platform.system()
if system_name == 'Linux':
    data_folder = '/home/ubuntu/Data'
elif system_name == 'Darwin':
    data_folder = '/Users/Henry/Data/temp/Dendrites_Gwen'
else:
    # Without this branch data_folder would be undefined (NameError below), or on
    # a re-run it would keep the previous full path and get the subfolders appended twice.
    raise RuntimeError('No data root configured for platform %r; add one to this cell.' % system_name)
data_folder = os.path.join(data_folder, animal_folder, date_folder, session_folder)
data_folder

In [None]:
# Select the mmap files written during motion correction for this session,
# motion-correction type and plane group (one file per plane, sorted by name).
all_files = os.listdir(data_folder)
session_prefix = '%s_%s' % (date_folder, session_folder)
mmap_files = sorted(
    name for name in all_files
    if name.startswith(session_prefix)
    and name.endswith('.mmap')
    and mc_output in name
    and group_id in name
)
n_planes = len(mmap_files)

print('Found %d mmap files. Check allocation to planes!' % (n_planes))
for i_plane, mmap_name in enumerate(mmap_files):
    print('Plane %d: %s' % (i_plane, mmap_name))

In [None]:
mmap_files = [os.path.join(data_folder, x) for x in mmap_files]

# get metadata: the first '<date>_<session>_Join_<group>_*.json' file that is not
# a '...badFrames.json' file. The suffix is excluded explicitly because an fnmatch
# '[!...]' class matches a single character, not a word.
meta_pattern = '%s_%s_Join_%s_*.json' % (date_folder, session_folder, group_id)
meta = None
for file in sorted(os.listdir(data_folder)):
    if fnmatch.fnmatch(file, meta_pattern) and not file.endswith('badFrames.json'):
        with open(os.path.join(data_folder, file)) as fid:
            meta = json.load(fid)
        break
if meta is None:
    raise FileNotFoundError('No metadata file matching %s found in %s' % (meta_pattern, data_folder))

trial_index = np.array(meta['trial_index'])   # trial index of every frame of the joined movie
frame_rate = meta['frame_rate']

### Setup cluster
Caiman's default backend for parallel processing is the multiprocessing package; this notebook uses the `ipyparallel` backend instead. The number of compute processes is set by `n_processes` (4 below; `None` selects it automatically).

In [None]:
# start the cluster (if a cluster already exists, stop it first)
n_processes = 4 # number of compute processes (None to select automatically)
if 'dview' in locals():
    # cm.stop_server handles both backends: it terminates a multiprocessing pool
    # and shuts down the ipcluster otherwise. The 'ipyparallel' backend returns a
    # DirectView, which has no .terminate() method, so re-running this cell with
    # dview.terminate() would fail.
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='ipyparallel', n_processes=n_processes, single_thread=False)

### Parameters for source extraction
Next, we define the important parameters for calcium source extraction. These parameters will have to be iteratively refined for the respective datasets.

The parameters are stored in the Python dictionary `params_cnmf`.

In [None]:
# Parameters for source extraction and deconvolution. Most of them are passed
# positionally to cm_utils.run_cnmf_iterative / cm_utils.run_cnmf_single below.
decay_time = 0.4            # length of a typical transient in seconds
# NOTE(review): the name `p` is rebound to a Bokeh figure in a later plotting cell;
# re-run this cell before re-running CNMF after that.
p = 1                       # order of the autoregressive system (normally 1, 2 for fast indicators / imaging)
gnb = 2                     # number of global background components
merge_thresh = 0.8          # merging threshold, max correlation allowed
rf = 10                     # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50 / None: no patches
stride_cnmf = 5             # amount of overlap between the patches in pixels
K = None                    # number of components per patch (usually None)
gSig = [4, 4]               # expected half size of neurons
init_method = 'sparse_nmf'  # initialization method (if analyzing dendritic data use 'sparse_nmf')
is_dendrites = True         # flag for analyzing dendritic data (NOTE(review): not passed to the run_cnmf_* calls below — confirm)
# alpha_snmf: sparsity penalty for dendritic data analysis through sparse NMF (previously tried: 10e2, i.e. 1000)
alpha_snmf = 1e-6

method_deconvolution='oasis'# deconvolution method (oasis or cvxpy)

# parameters for component evaluation
min_SNR = 2.5               # signal to noise ratio for accepting a component
rval_thr = 0.8              # space correlation threshold for accepting a component
use_cnn = False             # whether to use CNN to filter components
cnn_thr = 0.8               # threshold for CNN based classifier

final_rate = frame_rate             # final frame rate in Hz (NOTE(review): not used by any later cell)

### Run CNMF on joined mmap file
According to the Caiman tutorials, CNMF source extraction should be run in several stages. First, we run CNMF on patches of the dataset and then evaluate the quality of the extracted components. Next, CNMF is run again but on the full field-of-view. Finally, the extracted components are again evaluated and classified (good and bad).

First define a helper that groups the bad frames by trial, then load the mmap file.

In [None]:
def getBadFramesByTrial(bad_frames, trial_index):
    """Group bad frames of the joined movie by the trial they belong to.

    Parameters
    ----------
    bad_frames : sequence of int
        Frame indices (into the joined, all-trials movie) of the bad frames.
    trial_index : 1d array-like
        Trial index of every frame of the joined movie.

    Returns
    -------
    dict
        Maps each trial index that contains bad frames to the list of bad-frame
        positions counted from that trial's first frame, in the order in which
        they appear in ``bad_frames``. Empty ``bad_frames`` gives an empty dict.
    """
    trial_index = np.asarray(trial_index)
    # First frame of every trial, computed once (instead of one np.where scan per bad frame).
    trial_ids, first_frames = np.unique(trial_index, return_index=True)
    trial_start = dict(zip(trial_ids.tolist(), first_frames.tolist()))

    bad_frames_by_trial = dict()
    for frame in bad_frames:
        trial = trial_index[frame]
        bad_frames_by_trial.setdefault(trial, []).append(frame - trial_start[trial])
    return bad_frames_by_trial

In [None]:
# load data: pick the plane to process from the mmap files listed above
# (fname was previously never defined, so this cell raised a NameError)
plane_idx = 0  # index into mmap_files (see the plane listing printed above)
fname = mmap_files[plane_idx]
Yr, dims, T = cm.load_memmap(fname)
d1, d2 = dims

# offset if movie is negative (minimum computed once; this scans the whole movie)
Yr_min = np.min(Yr)
if Yr_min < 0:
    Yr = Yr - Yr_min

# if file is in F order, convert to C order (required for cnmf)
# TODO: check if this can be done more efficiently by saving directly in C order
if np.isfortran(Yr):
    Yr = np.ascontiguousarray(Yr)

# images: frames first (T x d1 x d2); Y: frames last (d1 x d2 x T)
images = np.reshape(Yr.T, [T] + list(dims), order='F')
Y = np.reshape(Yr, dims + (T,), order='F')

In [None]:
# remove bad frames specified in corresponding Json file
# NOTE: this shortens trial_index in place; re-running the cell requires re-running
# the metadata and load cells first.
bad_frames = np.array([], dtype=int)
bad_frames_by_trial = dict()
if remove_bad_frames:
    with open(fname.replace('.mmap','badFrames.json')) as fid:
        bad_frames_meta = json.load(fid)
    # integer dtype keeps np.delete / indexing valid even when the frame list is empty
    bad_frames = np.array(bad_frames_meta['frames'], dtype=int)
    # map bad frames to trials BEFORE trial_index is shortened below
    bad_frames_by_trial = getBadFramesByTrial(bad_frames, trial_index)
    Yr = np.delete(Yr, bad_frames, axis=1)
    trial_index = np.delete(trial_index, bad_frames, axis=0)
    T = Yr.shape[1]
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    # make sure movie is not negative: offset handed to save_memmap
    # (presumably added to every pixel, so the saved minimum becomes 0)
    add_to_movie = - np.min(images)
    fname_new = cm.save_memmap([images], base_name=os.path.join(data_folder, 'removedFrames'), 
                               add_to_movie=add_to_movie, order='C')
    Yr, dims, T = cm.load_memmap(fname_new)
    images = np.reshape(Yr.T, [T] + list(dims), order='F')
    Y = np.reshape(Yr, dims + (T,), order='F')
    print('Deleted %1d frames. Saved to new file %s.' % (len(bad_frames), os.path.basename(fname_new)))
    print('Deleted frames:')
    print(bad_frames)

In [None]:
# checks on movies (might take time if large!)
# NaN pixels would break CNMF; the message points to motion-correction borders as the usual cause.
if np.isnan(images).any():
    raise ValueError('Movie contains nan! You did not remove enough borders')

# correlation image; NaN entries are set to 0 so the image can be displayed
Cn = cm.local_correlations(Y)
Cn[np.isnan(Cn)] = 0

In [None]:
# Sanity check: frame average (left) next to the local correlation image (right).
avg_img = np.mean(images, axis=0)
plt.figure(figsize=(10, 20))
overview_panels = [(avg_img, 'gray', 'Frame average'), (Cn, 'jet', 'Correlation image')]
for panel_pos, (panel_img, panel_cmap, panel_title) in enumerate(overview_panels, start=1):
    plt.subplot(1, 2, panel_pos)
    plt.imshow(panel_img, cmap=panel_cmap)
    plt.title(panel_title)

### CNMF source extraction

#### Iterative CNMF first on patches, then full FoV
This routine runs an initial CNMF on patches (in parallel), followed by classification into good and bad components. The good components are then re-run through the CNMF algorithm. This procedure follows the suggested workflow in the CaImAn tutorials.

In [None]:
# Patch-wise CNMF, component evaluation, then CNMF on the full FoV (project helper
# cm_utils.run_cnmf_iterative). Arguments are positional: keep their order in sync
# with the helper's signature. Returns an object exposing A, C, b, f, YrA, S, sn
# (used below) plus the indices of good (idx_comps) and bad (idx_comps_bad) components.
# NOTE: the next cell (single CNMF without patches) assigns the same names, so when
# both cells are run only the result of the last one is kept.
cnmf_out, idx_comps, idx_comps_bad = cm_utils.run_cnmf_iterative(images, frame_rate, decay_time, dims, n_processes, K, 
                                                              gSig, merge_thresh, p, dview, rf, stride_cnmf, 
                                                              init_method, alpha_snmf, gnb, method_deconvolution, 
                                                              min_SNR, rval_thr, use_cnn, cnn_thr)
# clear_output()

#### Run single CNMF on full FoV, without patches
Alternatively, one can also run a single CNMF on the full FoV. Since there are no patches, the source extraction is not run in parallel. After CNMF, components are evaluated as above.

In [None]:
# Alternative to the previous cell: a single CNMF run on the full FoV (project
# helper cm_utils.run_cnmf_single). It takes the same positional arguments as
# run_cnmf_iterative and overwrites cnmf_out / idx_comps / idx_comps_bad.
# NOTE(review): rf and stride_cnmf are still passed although no patches are used
# here — confirm how the helper treats them.
cnmf_out, idx_comps, idx_comps_bad = cm_utils.run_cnmf_single(images, frame_rate, decay_time, dims, n_processes, K, 
                                                              gSig, merge_thresh, p, dview, rf, stride_cnmf, 
                                                              init_method, alpha_snmf, gnb, method_deconvolution, 
                                                              min_SNR, rval_thr, use_cnn, cnn_thr)

In [None]:
# Summary of the automatic classification.
n_good_comps, n_bad_comps = len(idx_comps), len(idx_comps_bad)
print('Number of detected components')
for count_label, count_value in (('Total', n_good_comps + n_bad_comps),
                                 ('Good', n_good_comps),
                                 ('Bad', n_bad_comps)):
    print('%s: %1.0f' % (count_label, count_value))

In [None]:
# Copy the CNMF results into notebook-level names used by the cells below.
A, C, b, f, YrA, S, sn = (getattr(cnmf_out, attr_name)
                          for attr_name in ('A', 'C', 'b', 'f', 'YrA', 'S', 'sn'))
# A    ... spatial footprints, n_pixel x n_components sparse matrix (.todense()/.tocsc() are used below)
# C    ... n_components x t array of fitted (denoised) traces
# b, f ... presumably the spatial / temporal parts of the gnb global background components — confirm
# YrA  ... n_components x t residual; C + YrA gives the raw ROI trace below
# S    ... deconvolved activity (spike-rate-like)
# sn   ... one value per pixel; presumably a noise estimate rather than an SNR — confirm

In [None]:
# Contours of accepted (left) and rejected (right) components on the correlation image.
plt.figure(figsize=(20, 30))

plt.subplot(121)
crd_good = plot_contours(A[:, idx_comps], Cn, thr=.8, vmax=0.75)
plt.title('Contour plots of accepted components')

plt.subplot(122)
crd_bad = plot_contours(A[:, idx_comps_bad], Cn, thr=.8, vmax=0.75)
plt.title('Contour plots of rejected components')

In [None]:
# One row per good component: contour on the frame average (left) and the
# footprint as an image (right). Titles use the index among the good components,
# which is the index expected by comps_to_exclude further below.
A_dense = A.todense()
n_rows = len(idx_comps)
plt.figure(figsize=(20, 40))
for i_comp, comp_id in enumerate(idx_comps):
    plt.subplot(n_rows, 2, 2 * i_comp + 1)
    contour_info = plot_contours(A[:, comp_id], avg_img, cmap='gray',
                                 colors='r', display_numbers=False)
    component_img = np.array(np.reshape(A_dense[:, comp_id], dims, order='F'))
    plt.subplot(n_rows, 2, 2 * i_comp + 2)
    plt.imshow(component_img)
    plt.title('Component %1.0f' % (i_comp))

plt.tight_layout()

### Manually reclassify components
Exclude some 'good' components by listing their index (e.g. 0, 1, 2, as shown in the plot above). These will be moved to the list of bad components.

In [None]:
# Classification before manual changes.
for heading, comp_list in (('Good components: ', idx_comps),
                           ('Bad components: ', idx_comps_bad)):
    print(heading)
    print(comp_list)

In [None]:
comps_to_exclude = [20,21,22,23,24,25,26,27,28,29,30] # should be index of the good components (i.e. 0,1,2 as shown in plot above)
# NOTE: running this cell twice excludes components twice; re-run CNMF/unpacking first.

# Integer arrays: fancy indexing needs an ndarray (not a list), and np.append on an
# empty list would produce a float array that can no longer index A.
idx_comps = np.asarray(idx_comps, dtype=int)
idx_comps_bad = np.asarray(idx_comps_bad, dtype=int)
exclude = np.asarray(comps_to_exclude, dtype=int)
if exclude.size and (exclude.min() < 0 or exclude.max() >= len(idx_comps)):
    raise IndexError('comps_to_exclude must be indices in [0, %d) into the good components; got %s'
                     % (len(idx_comps), comps_to_exclude))

# add to bad components
idx_comps_bad = np.sort(np.append(idx_comps_bad, idx_comps[exclude]))
# remove from good components
idx_comps = np.delete(idx_comps, exclude)

In [None]:
# Classification after the manual changes.
for heading, comp_list in (('Good components: ', idx_comps),
                           ('Bad components: ', idx_comps_bad)):
    print(heading)
    print(comp_list)

In [None]:
# create component_matrix with good components: footprints stacked along the last
# axis, shape (d1, d2, n_good). Always 3-D (also for 0 or 1 components), built in
# one pass instead of repeated np.dstack copies. With no good components the
# previous loop left component_matrix undefined.
component_imgs = [np.array(np.reshape(A_dense[:, comp_id], dims, order='F'))
                  for comp_id in idx_comps]
if component_imgs:
    component_matrix = np.stack(component_imgs, axis=-1)
else:
    component_matrix = np.zeros(tuple(dims) + (0,))

### Save and view results

In [None]:
# Save the CNMF results (dense footprints, traces, classification) as .npz in the data folder.
npz_basename = '%s_%s_Join_%s_results_CNMF.npz' % (date_folder, session_folder, group_id)
npz_name = os.path.join(data_folder, npz_basename)
cnmf_results = dict(Cn=Cn, A=A.todense(), C=C, b=b, f=f, YrA=YrA, sn=sn, S=S,
                    d1=d1, d2=d2, idx_components=idx_comps, idx_components_bad=idx_comps_bad)
np.savez(npz_name, **cnmf_results)

Interactive viewer for traces of accepted and rejected components.

In [None]:
# This has to be in a separate cell, otherwise it won't work.
# Route Bokeh output to the notebook, using bokeh.resources.INLINE for the
# BokehJS resources.
from bokeh import resources
output_notebook(resources=resources.INLINE)

In [None]:
# Interactive viewer (Caiman nb_view_patches) for the accepted components only.
if not len(idx_comps):
    print("No accepted components!")
else:
    nb_view_patches(Yr, A.tocsc()[:, idx_comps], C[idx_comps],
                    b, f, dims[0], dims[1], YrA=YrA[idx_comps], image_neurons=Cn,
                    denoised_color='red')

### Calculate DFF and plot traces
The CaImAn function `detrend_df_f` uses a sliding window percentile filter to determine the baseline and compute DFF.
Note: for noisy traces and/or high levels of activity, `detrend_df_f` sometimes seems to produce unexpected results (i.e. traces whose shape differs a lot from the extracted component traces). It might be better to use the extracted component traces (see below) for downstream analysis.

In [None]:
# DFF (sliding-percentile baseline) for all components, then keep only the good ones.
all_dff = cnmf_utils.detrend_df_f(A, b, C, f, YrA=YrA, quantileMin=8, frames_window=500)
F_dff = all_dff[idx_comps, :]

# time axis in seconds, one entry per frame
n_frames = F_dff.shape[-1]
t = np.arange(n_frames) / frame_rate

### Extract component traces

In [None]:
# Traces of the good components (same combination as in Caiman's nb_view_patches).
C_good = C[idx_comps]       # denoised signal
YrA_good = YrA[idx_comps]   # residual
Y_r = C_good + YrA_good     # ROI signal = denoised + residual
S_good = S[idx_comps]       # deconvolved signal
x = np.arange(T)            # frame index axis

### Plot DFF and component trace for selected component

In [None]:
comp_to_plot = 0  # index (among the good components) of the component to plot

# Top panel: ROI signal (blue) and denoised signal (red) against frame index.
# NOTE(review): nb_view_patches is said to divide by 100; no scaling is applied here — confirm.
trace_columns = dict(x=x, y=Y_r[comp_to_plot], y2=C_good[comp_to_plot])
source = ColumnDataSource(data=trace_columns)
p1 = Figure(plot_width=800, plot_height=300, title='Caiman components')
for column_name, line_color in (('y', 'blue'), ('y2', 'red')):
    p1.line('x', column_name, source=source, line_width=1, line_alpha=0.6, color=line_color)

# Bottom panel: DFF trace of the same component (blue).
dff_columns = dict(x=x, y=F_dff[comp_to_plot])
source_dff = ColumnDataSource(data=dff_columns)
p2 = Figure(plot_width=800, plot_height=300, title='Caiman DFF')
p2.line('x', 'y', source=source_dff, line_width=1, line_alpha=0.6, color='blue')

# Stack both panels in one column with a shared toolbar on the left.
grid = gridplot([[p1], [p2]], sizing_mode='fixed', toolbar_location='left')
show(grid)

### Create stacked plot of all components
This plot also shows the trial for each frame.

In [None]:
# Per-source-file bookkeeping from the metadata: file names and frame counts
# (their sum is checked against the number of time points further below).
source_files, source_frames = meta['source_file'], np.array(meta['source_frames'])

In [None]:
# Trial name (source file name without '_crop.tif') for every frame of the movie.
trial_names = [src_file.replace('_crop.tif', '') for src_file in source_files]
trial_names_frames = list(map(trial_names.__getitem__, trial_index))

Next, we need to define some functions for plotting.

In [None]:
def getHover():
    """Create the hover tool used by the timeseries plots.

    The tooltip shows the data index, the cursor position and the trial; the
    '@trial_idx' and '@trial_name' fields are read from the ColumnDataSource
    columns of the hovered line.
    """
    tooltip_rows = [
        ("index", "$index"),
        ("(x,y)", "($x, $y)"),
        ("trial", "@trial_idx (@trial_name)"),
    ]
    return HoverTool(tooltips=tooltip_rows)

In [None]:
def plotTimeseries(p, t, y, legend=None, stack=True, xlabel='', ylabel='', output_backend='canvas',
                   trial_index=trial_index, trial_names_frames=trial_names_frames):
    """
    Plot one or more timeseries in Figure p using the Bokeh library and show it.

    Input arguments:
    p ... Bokeh figure
    t ... 1d time axis vector (numpy array)
    y ... 2d data array (number of traces x time); a single 1d trace or an
          np.matrix is also accepted
    legend ... list of legend labels, one per trace
               (default: 'Trace 0', 'Trace 1', ...)
    stack ... whether to stack traces (True) or draw them on top of each other (False)
    xlabel ... label for x-axis
    ylabel ... label for y-axis
    output_backend ... 'canvas' or 'svg'
    trial_index ... trial index for each frame (shown in the hover tooltip)
    trial_names_frames ... trial name for each frame (shown in the hover tooltip)

    Note: the defaults of trial_index / trial_names_frames are the notebook
    variables as they were when this cell was executed.

    Returns the figure p (after show(p) has been called).
    """
    colors_list = ['red', 'green', 'blue', 'yellow', 'cyan', 'orange', 'magenta', 'black', 'gray']
    p.add_tools(CrosshairTool(), getHover())

    # Plain 2-D ndarray: rows of an np.matrix would stay 2-D, and a 1d trace becomes one row.
    y = np.atleast_2d(np.asarray(y))
    if legend is None:
        # previously legend=None crashed at legend[i]
        legend = ['Trace %d' % i for i in range(y.shape[0])]

    offset = 0
    for i, trace in enumerate(y):
        plot_trace = trace
        if stack:
            # shift the trace so it starts at the top of the previous one
            plot_trace = plot_trace - np.min(plot_trace) + offset
            offset = np.max(plot_trace)

        # one ColumnDataSource per trace; the trial columns feed the hover tooltip
        data_source = ColumnDataSource({
            'x': t,
            'y': plot_trace,
            'trial_idx': trial_index,
            'trial_name': trial_names_frames,
        })

        # colors repeat cyclically when there are more traces than colors
        p.line('x', 'y', source=data_source, line_width=2, legend=legend[i],
               color=colors_list[i % len(colors_list)])

    # clicking a legend entry hides the corresponding trace
    p.legend.click_policy = "hide"

    # format plot
    p.xaxis.axis_label = xlabel
    p.yaxis.axis_label = ylabel

    p.x_range = Range1d(np.min(t), np.max(t))

    p.background_fill_color = None
    p.border_fill_color = None

    p.output_backend = output_backend

    show(p)

    return p

The next cell plots the figure. The function `plotTimeseries` can plot traces on top of each other (`stack=False`) or stacked (`stack=True`). The interactive toolbar on the right of the figures allows panning, zooming, saving etc. One can also hide traces by clicking the corresponding legend item. To save the figure, click the disk icon in the plotting toolbar. With the default `output_backend` ('canvas'), a png file will be saved. To save to svg format, change `output_backend` to 'svg'.

Here we can also select to plot only a subset of good components and what type of data to plot (DFF, Y_r or C).

In [None]:
comp_idx = [0,1,2] # select index of components to plot (index among the good components, as in the plot above), e.g. [0,1,2] / use None to plot all components
source = 'S_good' # select the data that should be plotted ('F_dff', 'Y_r', 'C_good', 'S_good')

available_sources = {'F_dff': F_dff, 'Y_r': Y_r, 'C_good': C_good, 'S_good': S_good}
if source not in available_sources:
    raise ValueError('Specified source_data is not implemented')
source_data = available_sources[source]

if comp_idx is not None:
    source_data_plot = source_data[comp_idx,:]
    ix = np.asarray(idx_comps)[comp_idx]   # Caiman component ids of the selection
    comp_labels = comp_idx
else:
    source_data_plot = source_data
    ix = idx_comps
    comp_labels = range(len(idx_comps))

p = Figure(plot_width=900, plot_height=600, 
           title=('%s %s CNMF Results' % (date_folder, session_folder)))
# Label each trace with its index among the good components (as in the component
# plot and comp_idx), not with its position in the selected subset.
legend_text = ['Component %1d' % (k) for k in comp_labels]
# this is the call to the plotting function (change args. as required)
plotTimeseries(p, t, source_data_plot, 
               legend=legend_text, 
               stack=True, 
               xlabel='Time [s]', ylabel=source,
               output_backend='canvas',
               trial_index=trial_index,
               trial_names_frames=trial_names_frames
              )

### Split up by trials and save as .mat

In [None]:
# Consistency check: frames listed in the metadata minus the removed bad frames
# must equal the number of time points in the traces.
n_expected_frames = np.sum(source_frames) - len(bad_frames)
if n_expected_frames != F_dff.shape[-1]:
    raise Exception('Sum of source frames minus number of bad frames must be equal to number of timepoints.')

In [None]:
results_dff = dict()
results_Yr = dict()
results_C = dict()
results_S = dict()
removed_frames = dict()
for ix, trial_file in enumerate(source_files):
    # indices of this trial's frames (trial_index already excludes removed bad frames)
    trial_indices = np.where(trial_index==ix)[0]

    # create a valid Matlab variable / field name: the 'x' prefix avoids a leading
    # digit, and every character other than letters, digits and '_' becomes '_'
    # (previously only '-' was replaced)
    field_name = re.sub(r'[^0-9A-Za-z_]', '_', 'x' + str(trial_file).replace('.tif',''))
    results_dff[field_name] = F_dff[:,trial_indices]
    results_Yr[field_name] = Y_r[:,trial_indices]
    results_C[field_name] = C_good[:,trial_indices]
    results_S[field_name] = S_good[:,trial_indices]
    # bad frames of this trial, counted from the trial's first frame (empty if none)
    removed_frames[field_name] = bad_frames_by_trial.get(ix, [])

In [None]:
# Dictionary for savemat: each key becomes a variable name in Matlab.
mdict = dict(
    trials=list(map(str, source_files)),
    dff_trial=results_dff,
    Yr_trial=results_Yr,
    C_trial=results_C,
    Deconv_trial=results_S,
    removed_frames=removed_frames,
    mean_image=avg_img,
    spatial_components=component_matrix,
)

In [None]:
# save the .mat file next to the data (matfile_name already contains data_folder,
# so it must not be joined with data_folder a second time)
matfile_name = os.path.join(data_folder, '%s_%s_Join_%s_results_CNMF.mat' % (date_folder, session_folder, group_id))
savemat(matfile_name, mdict=mdict)