## Temporary notebook for tuning of CNMF source extraction parameters
The input to this notebook is a single-plane MMAP file that has already been concatenated across trials and motion-corrected.

## Imports

In [None]:
import os, sys, time, copy
import numpy as np
import scipy.spatial.distance as distance
import matplotlib.pyplot as plt
import ipyparallel

%matplotlib inline

In [None]:
import caiman as cm
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params

In [None]:
# Pin the BLAS backends (MKL / OpenBLAS) to one thread each; parallelism is
# provided by the ipyparallel engines started further below.
# NOTE(review): numpy is imported in an earlier cell, so this presumably only
# affects processes launched afterwards (e.g. the ipcluster engines started
# from the %%bash cell, which inherit this environment) -- confirm.
for _blas_var in ('MKL_NUM_THREADS', 'OPENBLAS_NUM_THREADS'):
    os.environ[_blas_var] = '1'

### Input file

In [None]:
# Folder containing the input MMAP file; output figures are saved here as well.
data_folder = '/Users/Henry/Data/temp/Dendrites_Gwen/M5.2/20181211/S1/parameterTuning'
# Alternative location on the Linux workstation:
# data_folder = '/home/luetcke/test/M5.2/20181211/S1/parameterTuning'
# Imaging rate, passed to CNMF as the `fr` parameter.
frame_rate = 10.2

In [None]:
# this cell is tagged as `parameters`
# (presumably so papermill can override mmap_file; keep it a plain assignment).
# File name only -- it is joined with data_folder in the next cell.
mmap_file = '20181211_S1_Join_G0_F707_P0_rig_remFrames_d1_84_d2_508_d3_1_order_C_frames_704_.mmap'

In [None]:
# Turn the file name into a full path inside data_folder
# (os.path.join keeps mmap_file unchanged if it is already absolute).
mmap_file = os.path.join(data_folder, mmap_file)

print('\nUsing input file %s\n' % (mmap_file))

### Compute resources

In [None]:
ncpus = 4  # number of ipyparallel engines to start (passed to the %%bash cell below)

### Setup cluster

In [None]:
%%bash -s "$ncpus"
# (Re)start the ipyparallel cluster with $1 (= ncpus from Python) engines.
# Activate the 'caiman' conda env: absolute Anaconda path first, else whatever `activate` is on PATH.
source /opt/Anaconda3-5.1.0-Linux-x86_64/bin/activate caiman || source activate caiman
# Stop a cluster that may still be running from an earlier session, then give it time to shut down.
ipcluster stop
sleep 5
# Start the engines in the background so this cell returns immediately.
ipcluster start --daemonize -n $1

In [None]:
# Give the freshly started ipcluster a moment to come up, then connect a client.
time.sleep(5)
client = ipyparallel.Client()
time.sleep(2)

# Wait until all ncpus engines have registered. The wait is bounded so the
# cell fails with a clear error instead of spinning forever when an engine
# never comes up.
engine_timeout = 300  # seconds
t_wait = time.time()
while len(client) < ncpus:
    if time.time() - t_wait > engine_timeout:
        raise RuntimeError('Only %d of %d ipyparallel engines registered after %d s'
                           % (len(client), ncpus, engine_timeout))
    sys.stdout.write(".")  # Give some visual feedback of things starting
    sys.stdout.flush()     # (de-buffered)
    time.sleep(0.5)

# create dview object
# NOTE(review): the '__a=1' execute looks like a blocking round-trip check
# that every engine responds before the view is used -- confirm.
client.direct_view().execute('__a=1', block=True)
dview = client[:]
n_processes = len(client)
print('\n\nThe cluster appears to be setup. Number of parallel processes: %d' % (n_processes))

### Parameters for source extraction
Define the important parameters for calcium source extraction.

In [None]:
# dataset dependent parameters
decay_time = 0.4              # length of a typical transient in seconds

# parameters for source extraction and deconvolution
p = 1                         # order of the autoregressive system
gnb = 2                       # number of global background components
merge_thresh = 0.8            # merging threshold, max correlation allowed
# Half-size of the patches in pixels (e.g. rf=25 -> 50x50 patches).
# None disables patches and processes the whole FOV at once;
# an alternative that was tried is rf = [7, 14].
rf = None
stride_cnmf = 3               # amount of overlap between the patches in pixels
K = 20                        # max. number of components per patch
gSig = [7,35]                 # expected half size of neurons in pixels

method_init = 'sparse_nmf'    # initialization method (if analyzing dendritic data use 'sparse_nmf', else 'greedy_roi')
alpha_snmf = 1                # sparsity penalty for sparse NMF (10e2 is an alternative for dendritic data)
normalize_init = True         # default is True
sigma_smooth_snmf = (0.5, 1, 1) # defaults to (0.5, 0.5, 0.5)
max_iter_snmf = 500           # defaults to 500

ssub = 1                      # spatial subsampling during initialization
tsub = 1                      # temporal subsampling during initialization

In [None]:
# Create the CNMFParams object; parameters not listed here keep CaImAn's defaults.
opts_dict = {
    'fnames': [mmap_file],
    'fr': frame_rate,
    'decay_time': decay_time,
    'p': p,
    'nb': gnb,
    'merge_thr': merge_thresh,  # without this entry merge_thresh would be ignored
    'rf': rf,
    'K': K,
    'gSig': gSig,
    'stride': stride_cnmf,
    'method_init': method_init,
    'alpha_snmf': alpha_snmf,
    'normalize_init': normalize_init,
    'sigma_smooth_snmf': sigma_smooth_snmf,
    'max_iter_snmf': max_iter_snmf,
    'rolling_sum': True,
    'only_init': True,
    'ssub': ssub,
    'tsub': tsub,
}

opts = params.CNMFParams(params_dict=opts_dict)

### Prepare data for source extraction

In [None]:
# Memory-map the movie. Yr has one row per pixel and one column per frame,
# so its transpose is reshaped (column-major) into a (frames, *dims) stack.
Yr, dims, T = cm.load_memmap(mmap_file)
frame_shape = [T] + list(dims)
images = Yr.T.reshape(frame_shape, order='F')

### Display frame average

In [None]:
# Mean over frames; avg_img is also reused by later cells for its shape.
avg_img = images.mean(axis=0)
plt.figure(figsize=(30, 30))
plt.imshow(avg_img, cmap='gray')
plt.title('Frame average', fontsize=32);

In [None]:
frame_ix = 100  # index of the single frame to inspect
plt.figure(figsize=(30, 30))
plt.imshow(images[frame_ix], cmap='gray')
plt.title('Frame %d' % frame_ix, fontsize=32);

### Run CNMF

In [None]:
t_start = time.time()

# First pass: extract spatial and temporal components with deconvolution
# turned off (p=0). opts.set mutates the same opts object handed to CNMF.
# NOTE(review): opts_dict sets rf=None and only_init=True, so this presumably
# runs initialization on the whole FOV rather than on patches -- confirm.
opts.set('temporal', {'p': 0})
cnm = cnmf.CNMF(n_processes, params=opts, dview=dview)
cnm.fit(images)
     
# Second pass: restore the autoregressive order p and re-run CNMF seeded with
# the first-pass components to refine them and perform deconvolution.
cnm.params.set('temporal', {'p': p})
cnm2 = cnm.refit(images, dview=dview)

t_elapsed = time.time() - t_start
print('\nFinished Source Extract in %1.2f s' % (t_elapsed))

### Evaluate components

In [None]:
# Drop the reference to the parallel view before copying.
# NOTE(review): presumably the ipyparallel view cannot be deep-copied -- confirm.
cnm2.dview = None
# Evaluate a copy so that cnm2 keeps the unfiltered CNMF result.
cnm_final = copy.deepcopy(cnm2)

In [None]:
# Component-evaluation thresholds (CaImAn 'quality' parameter group).
quality_params = dict(
    min_SNR=3,         # signal to noise ratio for accepting a component
    rval_thr=0.99,     # space correlation threshold for accepting a component
    use_cnn=False,     # use CNN classifier
    cnn_thr=0.95,      # threshold for CNN based classifier
    cnn_lowest=0.1,    # neurons with cnn probability lower than this value are rejected
)

opts.set('quality', quality_params)
cnm_final.estimates.evaluate_components(images, opts, dview=dview)

n_good = len(cnm_final.estimates.idx_components)
n_bad = len(cnm_final.estimates.idx_components_bad)
print('Found %d good / %d bad components\n' % (n_good, n_bad))

In [None]:
# calculate correlation image
# images is (frames, *dims); the transpose moves the time axis last for
# cm.local_correlations. NaN entries are set to 0 so the image can be displayed.
cc = cm.local_correlations(images.transpose(1,2,0))
cc[np.isnan(cc)] = 0

In [None]:
def nb_view_components(cnm_list, img, good_or_bad='good'):
    '''
    View components in the Caiman slider plot, one imaging plane at a time.

    Parameters
    ----------
    cnm_list : list of CNMF objects, one per plane, on which
        evaluate_components has been run.
    img : background image handed to estimates.nb_view_components
        (e.g. the correlation image).
    good_or_bad : 'good' shows accepted components, 'bad' rejected ones.

    Raises
    ------
    ValueError
        If good_or_bad is neither 'good' nor 'bad'.
    '''
    if good_or_bad == 'good':
        attr_name = 'idx_components'
    elif good_or_bad == 'bad':
        attr_name = 'idx_components_bad'
    else:
        raise ValueError("good_or_bad must be 'good' or 'bad', got %r" % (good_or_bad,))

    for ix_plane, cnm in enumerate(cnm_list):
        print('Plane %d' % (ix_plane))
        component_list = getattr(cnm.estimates, attr_name)

        # NOTE(review): the index list is presumably None if evaluate_components
        # was never run -- treat that like "no components".
        if component_list is None or len(component_list) == 0:
            print('No valid %s components in this plane!\n\n' % (good_or_bad))
            continue

        if len(component_list) == 1:
            # Address a Caiman bug with a single component by duplicating it.
            print('Found 1 %s component. Duplicating due to Caiman bug.' % (good_or_bad))
            component_list = np.append(component_list, component_list[0])

        cnm.estimates.nb_view_components(img=img, idx=component_list)

### Interactive component plot

In [None]:
# accepted components
# nb_view_components([cnm_final], cc, good_or_bad='good')

In [None]:
# rejected components
# nb_view_components([cnm_final], cc, good_or_bad='bad')

In [None]:
def plot_component_contours(cnm, images, idx_comps, data_folder, file_stem, label='good'):
    '''
    Plot selected CNMF spatial components and save the figure as a PNG.

    Each selected component gets one row with two panels:
    - left: its contour drawn over the frame-average image;
    - right: the spatial footprint itself.
    The figure is saved to
    ``<data_folder>/<file_stem>ComponentContours_<label>_test.png`` and then shown.

    Parameters
    ----------
    cnm : CNMF object whose ``estimates.A`` holds one spatial component per column.
    images : movie stack (frames first); only its frame average is used.
    idx_comps : column indices of A to plot. If empty, nothing is plotted or saved.
    data_folder, file_stem : output folder and file-name prefix.
    label : tag inserted into the output file name.
    '''
    idx_comps = list(idx_comps)
    if not idx_comps:
        # plt.subplot(0, 2, 1) would raise, so there is nothing sensible to draw.
        print('No components to plot for label "%s".' % (label))
        return

    avg_img = np.mean(images, axis=0)
    A = cnm.estimates.A
    n_rows = len(idx_comps)

    plt.figure(figsize=(30, 30))
    for row, comp_ix in enumerate(idx_comps):
        # Left panel: contour of the component over the frame average.
        plt.subplot(n_rows, 2, 2 * row + 1)
        if row == 0:
            plt.title('CNMF Components', fontsize=24)
        cm.utils.visualization.plot_contours(A[:, comp_ix], avg_img, cmap='gray',
                                             colors='r', display_numbers=False)

        # Right panel: the footprint, reshaped column-major into the image shape.
        # Only this column is densified instead of the whole of A.
        component_img = np.reshape(A[:, comp_ix].toarray(), avg_img.shape, order='F')
        plt.subplot(n_rows, 2, 2 * row + 2)
        plt.imshow(component_img)
        # Label with the index into A so titles match indices printed elsewhere.
        plt.title('Component %d' % (comp_ix), fontsize=24)

    plt.tight_layout()

    out_file = '%sComponentContours_%s_test.png' % (file_stem, label)
    fig_name = os.path.join(data_folder, out_file)

    plt.savefig(fig_name)
    plt.show()

In [None]:
# Prefix for saved figure names: the MMAP base name with '.mmap' removed.
file_stem = os.path.basename(mmap_file).replace('.mmap','')

In [None]:
# Plot and save the components accepted by evaluate_components.
plot_component_contours(cnm_final, images, cnm_final.estimates.idx_components, data_folder, file_stem, label='good')

### Component post-processing

First, remove components that consist of many small spots spread over a large part of the field of view. Two measures are used:
1. Calculate component sparsity (i.e. the fraction of pixels with 0)
2. The cosine distance between the row coordinates and the column coordinates of the non-zero pixels.

For good components, the sparsity should be high (e.g. > 0.99) and the distance between component pixels should be small (e.g. < 0.01).

To make the distinction clearer, it helps to threshold each component map first, e.g. at 10% of its maximum value.

In [None]:
# Spatial footprints (one component per column) and a dense copy for
# per-component reshaping, plus the indices accepted by evaluate_components.
A = cnm_final.estimates.A
A_dense = A.todense()
idx_components = cnm_final.estimates.idx_components

In [None]:
# Shape heuristics for each accepted component, after thresholding:
#  - sparsity: fraction of zero pixels in the footprint image
#  - dist: cosine distance between the row-coordinate vector and the
#    column-coordinate vector of the remaining non-zero pixels
thresh = 0.1 # threshold at thresh*max
sparsity = []
dist = []
for ix in idx_components:
    component_img = np.array(np.reshape(A_dense[:, ix], avg_img.shape, order='F'))
    component_img[component_img < thresh * np.max(component_img)] = 0
    sparsity.append(np.count_nonzero(component_img == 0) / component_img.size)
    # np.nonzero gives (rows, cols); stacked into a 2 x N array, pdist returns
    # exactly one value: the cosine distance between those two vectors.
    rows_cols = np.vstack(np.nonzero(component_img))
    dist.append(distance.pdist(rows_cols, metric='cosine')[0])

In [None]:
# A component is kept only if sparsity > sparsity_threshold
# AND its cosine distance < distance_threshold.
sparsity_threshold, distance_threshold = 0.99, 0.01

In [None]:
# Keep components that pass both shape heuristics. sparsity and dist were
# filled in the same order as idx_components, so positions line up.
idx_components_new = []
for pos, comp in enumerate(idx_components):
    if sparsity[pos] > sparsity_threshold and dist[pos] < distance_threshold:
        idx_components_new.append(comp)
idx_components_new

In [None]:
# Plot and save only the components that survived the shape heuristics.
plot_component_contours(cnm_final, images, idx_components_new, data_folder, file_stem, label='processed')

In [None]:
from scipy.signal import correlate2d as correlate2d

In [None]:
# Pairwise Pearson correlation between the flattened spatial footprints of the
# retained components, each unordered pair printed once (lower index first).
# The coefficient is stored in `r` so the correlation image `cc` computed
# earlier is not overwritten.
from itertools import combinations

for ix1, ix2 in combinations(sorted(idx_components_new), 2):
    comp1 = np.asarray(A_dense[:, ix1]).ravel()
    comp2 = np.asarray(A_dense[:, ix2]).ravel()
    r = np.corrcoef(comp1, comp2)[0, 1]
    print('Component %d - %d:\t%1.3f' % (ix1, ix2, r))

In [None]:
# Footprints of the retained components (one per column) as a plain ndarray.
# A_dense comes from A.todense(), i.e. an np.matrix, and scikit-learn >= 1.2
# raises TypeError when np.matrix is passed to PCA.fit.
A_good = np.asarray(A_dense[:, idx_components_new])
A_good.shape

In [None]:
# PCA on spatial components
# Rows of A_good (pixels) are the samples and columns (components) the
# features; keep as many PCs as there are components to inspect the spectrum.
from sklearn.decomposition import PCA
pca = PCA(n_components=A_good.shape[1])
pca.fit(A_good)

In [None]:
# Cumulative explained variance vs. number of PCs, to help choose n_comps below.
plt.plot(np.arange(1,A_good.shape[1]+1), np.cumsum(pca.explained_variance_ratio_),'o-')
plt.xlabel('number of components')
plt.ylabel('cumulative explained variance');

In [None]:
# Transform spatial components
# Refit with only the first n_comps PCs and project the footprints onto them:
# A_pca keeps one row per pixel and gets one column per retained PC.
n_comps = 1 # number of output components
pca = PCA(n_components=n_comps)
pca.fit(A_good)
A_pca = pca.transform(A_good)
print("original shape:   ", A_good.shape)
print("transformed shape:", A_pca.shape)

In [None]:
# plot PCA component(s)
# Each column of A_pca has one value per pixel; reshape it column-major into
# the image shape and show it. Plot the reshaped PCA component itself, not a
# leftover footprint from an earlier cell.
n_pca = A_pca.shape[1]
plt.figure(figsize=(30,30))
for counter, comp in enumerate(A_pca.T, start=1):
    plt.subplot(n_pca, 1, counter)
    img = np.reshape(comp, avg_img.shape, order='F')
    plt.imshow(img)
    plt.title('PCA Component %d' % (counter), fontsize=24)

### Stop cluster

In [None]:
%%bash
# Shut down the ipyparallel cluster started at the top of the notebook.
# Activate the 'caiman' conda env: absolute Anaconda path first, else whatever `activate` is on PATH.
source /opt/Anaconda3-5.1.0-Linux-x86_64/bin/activate caiman || source activate caiman
ipcluster stop
sleep 1