In [None]:
import os, random, sys, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import dask
from dask.diagnostics import ProgressBar
import caiman as cm
import h5py
# from skimage.external import tifffile as tff
from sklearn.decomposition import PCA
import tifffile as tff
import joblib
import plotly.graph_objects as go
# import seaborn as sns

codeDir = r'V:/code/python/code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.behavior.FreeSwimBehavior as fsb
import apCode.behavior.headFixed as hf
import apCode.SignalProcessingTools as spt
import apCode.geom as geom
import importlib
from apCode import util as util
from apCode import hdf
from apCode.imageAnalysis.spim import regress
from apCode.behavior import gmm as my_gmm
from apCode.machineLearning.preprocessing import Scaler

plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


# Enable automatic reloading of edited project modules, but only when running
# under IPython. In a plain Python interpreter `__IPYTHON__` is undefined and
# the resulting NameError is deliberately ignored.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

# Seed the random number generators for reproducibility.
# `random.seed` must be *called*: assigning to it (`random.seed = seed`)
# replaces the seeding function with an int and seeds nothing.
seed = 143
random.seed(seed)
np.random.seed(seed)

print(time.ctime())

### *Read the xls sheet with all the data paths*

In [None]:
#%% Path to excel sheet storing paths to data and other relevant info
# NOTE(review): absolute, machine-specific network path; the notebook only runs
# where this drive mapping exists.
dir_xls = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging'
file_xls = 'GCaMP volumetric imaging summary.xlsx'
# Only the sheet named 'Sheet1' is read; later cells use its `FishIdx` and
# `Path` columns to find the data directory of a given fish.
xls = pd.read_excel(os.path.join(dir_xls, file_xls), sheet_name='Sheet1')
xls.head()

In [None]:
#%% Look up the data directory of the selected fish in the summary sheet
idx_fish = 9
# First `Path` entry among the rows whose `FishIdx` matches the selection.
path_now = np.array(xls.loc[xls.FishIdx == idx_fish, 'Path'])[0]
print(path_now)

# Drop any cached HDF5 path so that the cell that checks for `hFilePath`
# looks the file up again for the newly selected fish.
if 'hFilePath' in locals():
    del hFilePath

# Continue from the saved dataframe

In [None]:
%%time

startFresh = True # Reads hFile and df

# if (startFresh) & (('hFilePath' in locals()) | ('df' in locals())):
#     del hFilePath, df

#%% If stored dataframe exists in path read it
if 'hFilePath' not in locals():
    hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str = 'procData')[-1]
    hFilePath = os.path.join(path_now, hFileName)
with h5py.File(hFilePath, mode = 'r') as hFile:
    print(hFile.keys())

if 'df' not in locals():
    file_df = ft.findAndSortFilesInDir(path_now, ext = 'pickle', search_str = 'dataFrame')
    if len(file_df)>0:
        file_df = file_df[-1]
        path_df = os.path.join(path_now, file_df)
        print(path_df)
        print('Reading dataframe...')
        %time df = pd.read_pickle(path_df)       
    else:
        print('No dataframe found in path!')
    

In [None]:
# Report the shape of the registered Ca2+ dataset and read the per-trial
# stimulus-location labels from the HDF5 file.
with h5py.File(hFilePath, mode='r') as hFile:
    print(hFile['ca_trls_reg'].shape)
    # NOTE(review): `util.to_utf` presumably decodes stored byte strings to
    # text -- confirm against the helper's implementation.
    stimLoc = util.to_utf(hFile['stimLocVec'][()])
    print(len(stimLoc))

In [None]:
#%% Collect tail angles: one array per dataframe row, then join along axis 1
ta_trl = np.array(list(map(np.array, df.tailAngles)))
ta = np.concatenate(ta_trl, axis=1)

### *Can I get a clearer and crisper image volume than the offset map from registration to draw ROIs on?*

In [None]:
%%time
# Load the whole registered, trial-structured Ca2+ dataset into memory.
with h5py.File(hFilePath, mode='r') as hFile:
    ca = hFile['ca_trls_reg'][()]

# The first two axes are taken as (trial, sample within trial); the last three
# axes are the volume dimensions.
nTrls, trlLen = ca.shape[0], ca.shape[1]
volDims = ca.shape[-3:]

# Collapse all leading axes into one, drop index 0 along the first volume axis
# and average over the collapsed axis.
ca = ca.reshape((-1,) + tuple(volDims))[:, 1:]
ca_avg = ca.mean(axis=0)
# Shift the average volume so that its minimum is exactly 1.
ca_avg = ca_avg - ca_avg.min() + 1

In [None]:
#%% Save average stack for future reference
# foo = ca_avg.copy()
# for z in range(ca_avg.shape[0]):
#     foo[z] = spt.stats.saturateByPerc(ca_avg[z], perc_up=99)
# foo = spt.stats.saturateByPerc(ca_avg, perc_up=99)

# `imsave` is a deprecated alias of `imwrite` that newer tifffile releases no
# longer provide; use `imwrite` when available and fall back to `imsave` only
# on old installations.
_write_tiff = getattr(tff, 'imwrite', None) or tff.imsave
# The average volume is written with the platform default integer dtype.
_write_tiff(os.path.join(path_now, 'averageCaImgVol.tif'), data=ca_avg.astype('int'))

## Make simple Ca$^{2+}$ response maps to distinguish head and tail-elicited responses

## ROI analysis

In [None]:
#%% Read ROIs
# dir_rois= r'Y:\Avinash\Head-fixed tail free\GCaMP imaging\2020-01-11\f1\figs\regression_ipca_flt_sigma-100_20200317-0507\betas\RoiSet.zip'

# Despite the `dir_` prefix this is the path of the ROI archive itself
# ('RoiSet.zip' inside the current data directory), not of a directory.
dir_rois = os.path.join(path_now, 'RoiSet.zip')
dir_rois

In [None]:
# Only referenced by the commented-out dataset name below.
filtSize = 1

# Re-discover the processed HDF5 file (last entry of the sorted 'procData' list).
hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
hFilePath = os.path.join(path_now, hFileName)

# NOTE(review): no cell in this notebook defines `images_blah`, so this guard
# is always true and the stack is re-read on every run -- presumably a leftover
# switch for forcing a reload.
if 'images_blah' not in locals():
    with h5py.File(hFilePath, mode = 'r') as hFile:
#         images = np.array(hFile[f'images_reg_ipca_flt_sigma-{int(filtSize*100)}'])
        # Trial-structured stack, read fully into memory.
        images_trl = np.array(hFile['ca_trls_reg'])
    # Collapse all leading axes so that only the last three (volume) axes remain
    # behind a single leading axis.
    images = images_trl.reshape(-1, *images_trl.shape[-3:])
#     images = images[:, 1:]
    

# if 'df' not in locals():
#     file_df = ft.findAndSortFilesInDir(path_now, ext = 'pickle', search_str='dataFrame')[-1]
#     %time df = pd.read_pickle(os.path.join(path_now, file_df))


In [None]:
#%% Some functions and reading of ROIs

def strip_suffices(strList):
    """Reduce dot-separated ROI names to their first two fields.

    'R.Mauthner.1' becomes 'R.Mauthner'. Names with fewer than two dots are
    returned unchanged and names with more than two dots are truncated after
    the second field, instead of raising a ValueError as a strict three-way
    unpack would.

    Parameters
    ----------
    strList : iterable of str
        ROI names.

    Returns
    -------
    numpy.ndarray
        Array with one stripped name per input name, in input order.
    """
    return np.array(['.'.join(name.split('.')[:2]) for name in strList])

def consolidate_rois(rois, volDims):
    """Merge ROIs that share a base name into one volumetric mask per name.

    Parameters
    ----------
    rois : dict
        Maps ROI names to dicts that provide at least 'position' (index of the
        plane along the first axis of the volume) and 'mask' (2D mask for that
        plane).
    volDims : tuple
        Shape of the output volume of each mask.

    Returns
    -------
    masks : numpy.ndarray
        One mask volume of shape `volDims` per unique base name.
    roiNames_unique : numpy.ndarray
        Unique base names (as returned by `strip_suffices`), in the same order
        as `masks`.
    """
    roiNames_orig = list(rois.keys())
    roiNames = strip_suffices(roiNames_orig)
    roiNames_unique = np.unique(roiNames)
    masks = []
    for rn in roiNames_unique:
        # Group by exact name equality so that a name which is a prefix of
        # another (e.g. 'R.MiD2' vs 'R.MiD2i') can never pick up the other's
        # ROIs, whatever matching rule a string-search helper would apply.
        inds = np.flatnonzero(roiNames == rn)
        mask = np.zeros(volDims)
        for ind in inds:
            roi_ = rois[roiNames_orig[ind]]
            z = roi_['position']
            # Union with what is already on this plane: several ROIs with the
            # same base name may lie on the same plane, and a plain assignment
            # would keep only the last one.
            mask[z] = np.maximum(mask[z], roi_['mask'])
        masks.append(mask)
    return np.array(masks), roiNames_unique


# Shape of a single plane (last two axes) and of a whole volume (last three axes).
imgDims = images.shape[-2:]
volDims = images.shape[-3:]

# Read the ROI archive (the first return value is discarded) and merge ROIs
# that share a base name into one volumetric mask each.
_, rois = mlearn.readImageJRois(dir_rois, imgDims, multiLevel=False)
masks, roiNames = consolidate_rois(rois, volDims)


In [None]:
%%time 
def func_now(images, mask):
    """Return one value per entry along axis 0: the mean of `images * mask`.

    `mask` is broadcast along axis 0 of the 4D `images` array and the product
    is averaged over axes 1, 2 and 3 in turn.

    NOTE(review): the mean runs over the whole volume, not only over the
    non-zero mask elements, so the result scales with the mask size.
    """
    reduced = np.asarray(images * mask[np.newaxis, ...])
    for axis in (1, 2, 3):
        reduced = reduced.mean(axis=axis, keepdims=True)
    return reduced.flatten()
# One time series per ROI mask; a simple counter is printed as progress.
roi_ts = []
for iMask, mask in enumerate(masks):
    print(f'{iMask+1}/{masks.shape[0]}')
    roi_ts.append(func_now(images, mask))
roi_ts = np.array(roi_ts)

In [None]:
# Index of the first ROI name matched by 'R.Mauthner'.
ind = util.findStrInList('R.Mauthner', roiNames)[0]

nTrls = images_trl.shape[0]

# NOTE(review): `stimInds` is not used in any visible cell.
stimInds = np.arange(nTrls)
# NOTE(review): recent matplotlib releases renamed the bundled seaborn styles
# (e.g. 'seaborn-v0_8-paper'); 'seaborn-paper' may not resolve there.
plt.style.use(('fivethirtyeight', 'seaborn-paper'))
plt.figure(figsize=(20, 5))
# Full time series of the selected ROI.
plt.plot(roi_ts[ind])

In [None]:
scale = False

# Fall back to the stimulus labels stored in the dataframe when none were read
# from the HDF5 file. The check has to precede the first use of `stimLoc`;
# placed after it, it could never trigger.
if 'stimLoc' not in locals():
    stimLoc = np.array(df.stimLoc)

# Keep only the last character of each label; 'h' and 't' select the head and
# tail trials below.
stimLoc = np.array([sl[-1] for sl in stimLoc])
stimLoc_unique = np.unique(stimLoc)
if len(stimLoc_unique)==1:
    # NOTE(review): when only one stimulus location is present, the last half
    # of the trials is relabelled with the other location. This looks like a
    # workaround for incomplete labels -- confirm it is intended for this fish.
    n = len(stimLoc)//2
    stimLoc[-n:] = np.setdiff1d(['h', 't'], stimLoc_unique)


nTrls = images_trl.shape[0]
# nTrls = df.shape[0]
# Arrange as (ROI, trial, time) and subtract the value at time index 1 of every
# trial as baseline. `reshape` returns a view of `roi_ts`, so the subtraction
# is done out of place to leave `roi_ts` untouched.
roi_ts_trls = roi_ts.reshape(roi_ts.shape[0], nTrls, -1)
roi_ts_trls = roi_ts_trls - roi_ts_trls[...,1][...,None]
trls_head = np.where(stimLoc=='h')[0]
trls_tail = np.where(stimLoc=='t')[0]

roi_ts_head = roi_ts_trls[:, trls_head]
roi_ts_tail = roi_ts_trls[:, trls_tail]


def _mean_and_sem(ts_trls):
    """Return mean and standard error across trials (axis 1) of an (ROI, trial, time) array.

    The standard error divides by the square root of the number of trials
    (`ts_trls.shape[1]`), not of the number of ROIs.
    """
    n_trials = ts_trls.shape[1]
    return ts_trls.mean(axis=1), ts_trls.std(axis=1)/np.sqrt(n_trials)


mu_head, sem_head = _mean_and_sem(roi_ts_head)
mu_tail, sem_tail = _mean_and_sem(roi_ts_tail)

if scale:
    # Scale head trials so that the peak of the mean 'R.Mauthner' response
    # matches the peak of the mean tail response of the same ROI.
    ind_m = util.findStrInList('R.Mauthner', roiNames)[0]
    sf = mu_tail[ind_m].max()/mu_head[ind_m].max()
    print(f'Scaling factor = {sf}')
    roi_ts_head = roi_ts_head*sf
    mu_head, sem_head = _mean_and_sem(roi_ts_head)


# plt.figure(figsize=(20, 20*nRows/nCols))
nCols = 3
nRows = int(np.ceil(len(roiNames)/nCols))
# squeeze=False always returns a 2D axes array, so flatten() also works for a
# single-row grid.
fh, ax = plt.subplots(nrows=nRows, ncols=nCols, sharex=True, figsize=(20, 20*nRows/nCols),
                      squeeze=False)
ax = ax.flatten()
fh.tight_layout()

# Time axis with a spacing of 0.5 per sample (presumably seconds at 2 volumes/s
# -- confirm against the acquisition rate).
t = np.arange(mu_head.shape[1])*(1/2)
for iRoi, roi_ in enumerate(mu_head):
#     ax[iRoi].plot(mu_head[iRoi], c=plt.cm.tab10(0), label='Head')
    ax[iRoi].fill_between(t, mu_head[iRoi]-sem_head[iRoi], mu_head[iRoi]+sem_head[iRoi],
                          color=plt.cm.tab10(0), alpha=0.5, label='Head')
    ax[iRoi].fill_between(t, mu_tail[iRoi]-sem_tail[iRoi], mu_tail[iRoi]+sem_tail[iRoi],
                          color=plt.cm.tab10(1), alpha=0.5, label='Tail')
#     ax[iRoi].plot(t,mu_tail[iRoi], c=plt.cm.tab10(1), label='Tail')
    ax[iRoi].set_yticks([])
    ax[iRoi].set_title(r'${}$'.format(roiNames[iRoi]), fontsize=20)
    if iRoi==0:
        ax[iRoi].legend(loc='best', fontsize=20)
fh.suptitle('Average Ca$^{2+}$ response for escape trials_Head vs tail stimulation\n R = ipi, L = contra', \
           fontsize=24);
fh.subplots_adjust(top=0.955, hspace=0.12)

dir_figs = os.path.join(path_now, 'figs')
# Creates the folder when missing and is a no-op when it already exists.
os.makedirs(dir_figs, exist_ok=True)
fn = f'Fig-{util.timestamp("minute")}_Trial-averaged Ca2+ responses_head and tail trials'
# NOTE(review): both savefig calls are commented out, so nothing is written
# even though the message below is printed.
# fh.savefig(os.path.join(dir_figs, fn + '.pdf'), dpi='figure', format='pdf')
# fh.savefig(os.path.join(dir_figs, fn + '.png'), dpi='figure', format='png')
print(f'Saved at \n{dir_figs}')

In [None]:
#%% Save a dataframe of roi timeseries that can be quickly accessed later
df_roi = dict(roiName=[], stimLoc=[], ca=[])
for iRoi, roi_ in enumerate(roiNames):
    for iStim, sl in enumerate(stimLoc):
        df_roi['roiName'].append(roi_)
        df_roi['stimLoc'].append(sl)
        df_roi['ca'].append(roi_ts_trls[iRoi, iStim])
print('Converting from dict to dataframe...')
df_roi = pd.DataFrame(df_roi)

print('Saving dataframe...')
%time df_roi.to_pickle(os.path.join(path_now, f'roi_ts_dataframe_{util.timestamp()}.pkl'))

In [None]:
path_now.replace("\\", "/"), f'fishIdx = {idx_fish}' 

In [None]:
import glob

# Several timestamped ROI dataframes can exist in the folder and glob returns
# them in arbitrary order, so pick the most recently modified file explicitly
# instead of whatever happens to come first.
paths_roi_df = glob.glob(os.path.join(path_now, 'roi_ts_*.pkl'))
if len(paths_roi_df) == 0:
    raise FileNotFoundError(f'No roi_ts_*.pkl file found in {path_now}')
path_ = max(paths_roi_df, key=os.path.getmtime)
df_roi = pd.read_pickle(path_)

In [None]:
df_roi.head()

In [None]:
# NOTE(review): this rebinds `roiNames` from the per-ROI name array used by the
# cells above to the per-row name column of `df_roi`; re-running earlier cells
# afterwards will see the new value.
roiNames = np.array(df_roi.roiName)
# inds = util.findStrInList('R.MiDi', roiNames)
# Rename every row labelled exactly 'R.MiD2' to 'R.MiD2i' and write the names
# back into a new dataframe.
inds = np.where(roiNames=='R.MiD2')
roiNames[inds] = 'R.MiD2i'
df_roi = df_roi.assign(roiName=roiNames)

In [None]:
inds

In [None]:

print('Saving dataframe...')
%time df_roi.to_pickle(os.path.join(path_now, f'roi_ts_dataframe_{util.timestamp()}.pkl'))

In [None]:
# The trial count is taken from `images_trl`, as in the cells above; `ca_trl`
# is not defined anywhere in this notebook and raised a NameError on a fresh
# kernel.
nTrls = images_trl.shape[0]
# Arrange as (ROI, trial, time) and subtract the value at time index 0 of every
# trial as baseline. Done out of place because `reshape` returns a view of
# `roi_ts`.
# NOTE(review): the trial-averaged figure above uses time index 1 as baseline.
roi_ts_trls = roi_ts.reshape(roi_ts.shape[0], nTrls, -1)
roi_ts_trls = roi_ts_trls - roi_ts_trls[...,0][...,None]
trls_head = np.where(stimLoc=='h')[0]
trls_tail = np.where(stimLoc=='t')[0]

roi_ts_head = roi_ts_trls[:, trls_head]
roi_ts_tail = roi_ts_trls[:, trls_tail]

plt.figure(figsize=(10, 10))
# Mean and standard deviation across trials for every ROI.
mu_head = roi_ts_head.mean(axis=1)
sigma_head = roi_ts_head.std(axis=1)
mu_tail = roi_ts_tail.mean(axis=1)
sigma_tail = roi_ts_tail.std(axis=1)


# Vertical offset per ROI (twice the largest mean head response) so that the
# traces are stacked instead of overlapping.
yOff = 2*np.max(mu_head)*np.arange(roi_ts_trls.shape[0])[:, None]
# yOff = util.yOffMat(mu_head)*2

# Left panel: head trials (mean, and mean + 1 SD in grey).
plt.subplot(121)
plt.plot((mu_head-yOff).T);
plt.plot((mu_head+sigma_head-yOff).T, c='k', alpha=0.25);


# Right panel: tail trials, using the same offsets.
plt.subplot(122)
plt.plot((mu_tail-yOff).T);
plt.plot((mu_tail+sigma_tail-yOff).T, c='k', alpha=0.25);



In [None]:
#%% NMF in ROI-masked areas
# NOTE(review): `masks_zProj` is not used by any visible cell; the projection
# actually used below is `masks_zPproj` (note the different spelling).
masks_zProj = masks==1
# Collapse the masks over axis 1 and then over the ROI axis into a single 2D
# footprint covering all ROIs.
masks_zPproj = masks.max(axis=1).max(axis=0)
# Mean over axis 1 of the image stack, restricted to the combined footprint.
images_zProj = images.mean(axis=1)
images_zProj_mask = masks_zPproj[None, ...]*images_zProj
mov = cm.movie(images_zProj_mask)
# Shift so that the movie minimum is zero.
mov -= mov.min()

In [None]:
%time nmf_space, nmf_time = mov.NonnegativeMatrixFactorization()

In [None]:
iComp = 10
plt.figure(figsize=(10, 10))
plt.subplot(211)
plt.imshow(spt.standardize(nmf_space[iComp]), vmax=0.5)
plt.subplot(212)
plt.plot(nmf_time.T[iComp])



## *Try CNMF*

In [None]:
import bokeh.plotting as bpl
import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.utils.utils import download_demo
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour
bpl.output_notebook()

# *CNMF* 

## *If data is small enough use a single patch approach*

In [None]:
# Drop index 0 along axis 1 of the stack and flatten everything after axis 0
# into one vector per entry of axis 0.
images_now = images[:,1:]
images_ser = images_now.reshape(images_now.shape[0], -1)
nTrls = df.shape[0]
# NOTE(review): true division makes `trlLen` a float even when the length is a
# whole number; other cells take the trial count from `images_trl.shape[0]`
# rather than from `df`.
trlLen = images_now.shape[0]/nTrls

# %time regObj = regress(X_reg[:,-2:], images_ser, n_jobs=-1, fit_intercept=True)


In [None]:
# NOTE(review): `X_reg` is not defined in any visible cell of this notebook, so
# this line fails on a fresh kernel unless it is created elsewhere.
x = np.arange(X_reg.shape[0])


In [None]:
# Drop index 0 along axis 1, average over that axis and wrap the result as a
# CaImAn movie with fr=2.
images_now = images[:,1:]
imgs_proj = images_now.mean(axis=1)
mov = cm.movie(imgs_proj, fr=2)
# mov -= mov.min()
# computeDFF returns the normalised movie together with the baseline it used.
df_ca, baseline = mov.computeDFF()
# df = mov.bilateral_blur_2D()
# df = mov.copy()
# Convert to a plain array and shift so that its minimum is zero.
df_ca = np.array(df_ca)
df_ca -= df_ca.min()
print(df_ca.shape)

In [None]:
#%% Save as memory mapped file
fn_new = cm.save_memmap([df_ca], order='C', base_name='Yr9')


In [None]:
# now load the file
Yr, dims, T = cm.load_memmap(fn_new)
images_now = np.reshape(Yr.T, [T] + list(dims), order='F')
print(images_now.shape)

In [None]:
#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
# `dview` from an earlier run of this cell is used to stop the old cluster first.
if 'dview' in locals():
    cm.stop_server(dview=dview)
# n_processes=None leaves the number of worker processes to CaImAn.
c, dview, n_processes = cm.cluster.setup_cluster(backend='local', n_processes=None,\
                                                 single_thread=False)

### *Initialize CNMF object*

In [None]:

# import bokeh.plotting as bpl
# bpl.output_notebook()

# set parameters
# NOTE(review): of the values defined in this cell, only fr, K, gSig,
# merge_thresh, p, rf, min_SNR, rval_thr and remove_very_bad_comps are passed
# to the CNMF constructor in the next cell; gnb, stride_cnmf, method_init,
# ssub, tsub, cnn_thr and cnn_lowest are defined but not used there.
fr = 2
# K = 20  # number of neurons expected per patch
# gSig = [2, 2]  # expected half size of neurons
merge_thresh = 0.9  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system

gnb = 2                     # number of global background components
rf = 45                     # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride_cnmf = 10             # amount of overlap between the patches in pixels
K = 4                       # number of components per patch
gSig = [20, 20]               # expected half size of neurons in pixels
method_init = 'greedy_roi'  # initialization method (if analyzing dendritic data using 'sparse_nmf')
ssub = 1                    # spatial subsampling during initialization
tsub = 1                    # temporal subsampling during initialization

# parameters for component evaluation
min_SNR = 2.0               # signal to noise ratio for accepting a component
rval_thr = 1.0              # space correlation threshold for accepting a component

remove_very_bad_comps = False
cnn_thr = 0.99              # threshold for CNN based classifier
cnn_lowest = 0.1 # neurons with cnn probability lower than this value are rejected


In [None]:
# INIT
# NOTE(review): `stride_cnmf`, `gnb`, `method_init`, `ssub` and `tsub` from the
# parameter cell are not passed here, so CaImAn's defaults apply for them --
# confirm that this is intended.
cnm = cnmf.CNMF(n_processes, fr=fr, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, rf=rf, dview=dview,\
                min_SNR=min_SNR, rval_thr=rval_thr, remove_very_bad_comps=remove_very_bad_comps)

# Fit on the memory-mapped movie loaded above and report how many components
# were found (number of columns of the spatial footprint matrix A).
%time cnm = cnm.fit(images_now)
nComps = cnm.estimates.A.shape[-1]
print(f'{nComps} components')

In [None]:
#%% plot contours of found components
Cn = cm.local_correlations(images_now.transpose(1,2,0))
Cn[np.isnan(Cn)] = 0
cnm.estimates.plot_contours(img=Cn, thr=0.8);


In [None]:
# `i` is treated as a 1-based component number, so column/row i-1 is used for
# BOTH the spatial footprint and the temporal trace; indexing the trace with
# `i` would show the trace of the next component.
i = 23
plt.figure(figsize=(10, 5)); 
plt.subplot(211)
# Column i-1 of A, reshaped to image dimensions in Fortran order.
plt.imshow(np.reshape(cnm.estimates.A[:,i-1].toarray(), dims, order='F'))

# Rows of C are the temporal traces of the components.
nmf_time = cnm.estimates.C
plt.subplot(212)
plt.plot(nmf_time[i-1])

In [None]:
# %%capture
#%% RE-RUN seeded CNMF on accepted patches to refine and perform deconvolution 
%time cnm2 = cnm.refit(images_now, dview=dview)
print(f'{cnm2.estimates.A.shape[-1]} components')

In [None]:
# the components are evaluated in three ways:
#   a) the shape of each component must be correlated with the data
#   b) a minimum peak SNR is required over the length of a transient
#   c) each shape passes a CNN based classifier

# cnm2.estimates.evaluate_components(images_now, cnm2.params, dview=dview)

In [None]:
cnm2.estimates.plot_contours(img=Cn, idx=cnm2.estimates.idx_components)

In [None]:
# #%% Extract DF/F values
# cnm2 = cnm2.estimates.detrend_df_f(quantileMin=8, frames_window=250)
# dff = cnm2.F_dff

In [None]:
iSlc = 16
slc = images[:,iSlc,...]
plt.imshow(slc.max(axis=0))

In [None]:
#%% Standard NMF
mov -= mov.min()
nmf_space, nmf_time = mov.NonnegativeMatrixFactorization(n_components=30)

In [None]:
iComp = 14
plt.figure(figsize=(10, 5))
plt.subplot(211)
# Spatial map of the selected component.
plt.imshow(nmf_space[iComp])

plt.subplot(212)
# The time components are arranged as (time, component), as the earlier NMF
# plot's `nmf_time.T[iComp]` already assumes; select the component's column
# rather than row `iComp`, which would be a single time point across components.
plt.plot(nmf_time[:, iComp])

## *3D version*

In [None]:
#%% Rearrange dimensions to put txyz format
images_txyz = np.transpose(images, (0, 2, 3, 1))[...,1:]
images_txyz.shape

In [None]:
#%% Save as memory mapped file
fn_new = cm.save_memmap([images_txyz], order='C', base_name='Yr_3d2', is_3D=True)


In [None]:
# now load the file
Yr, dims, T = cm.load_memmap(fn_new)
Y = np.reshape(Yr.T, [T] + list(dims), order='F')
print(Y.shape)

In [None]:
# Cn = cm.local_correlations(Y)
plt.imshow(Cn.max(0) if len(Cn.shape) == 3 else Cn, cmap='viridis',
           vmin=np.percentile(Cn, 70), vmax=np.percentile(Cn, 99.9))
plt.show()

## *Single patch approach for small data*

In [None]:
# set parameters
K = 20  # number of neurons expected per patch
gSig = [2, 2, 2]  # expected half size of neurons
merge_thresh = 0.8  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system

In [None]:
# INIT
cnm = cnmf.CNMF(n_processes, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, dview=dview)

In [None]:
%%time
# %%capture
# FIT
images_now = np.reshape(Yr.T, [T] + list(dims), order='F')    # reshape data in Python format (T x X x Y x Z)
cnm = cnm.fit(images_now)

In [None]:
cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims);

## *Patch approach for larger datasets*

In [None]:
# set parameters
rf = 18  # half-size of the patches in pixels. rf=25, patches are 50x50
stride = 10  # amount of overlap between the patches in pixels
K = 12  # number of neurons expected per patch
gSig = [8, 8, 2]  # expected half size of neurons
merge_thresh = 0.8  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system

In [None]:
# %%capture
#%% RUN ALGORITHM ON PATCHES

cnm = cnmf.CNMF(n_processes, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, dview=dview,
                rf=rf, stride=stride, only_init_patch=True)

%time cnm = cnm.fit(images)
print(('Number of components:' + str(cnm.estimates.A.shape[-1])))

In [None]:
# cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims);

In [None]:
#%% COMPONENT EVALUATION
# the components are evaluated in two ways:
#   a) the shape of each component must be correlated with the data
#   b) a minimum peak SNR is required over the length of a transient

fr = 2 # approx final rate  (after eventual downsampling )
decay_time = 1.  # length of typical transient in seconds 
use_cnn = False  # CNN classifier is designed for 2d (real) data
min_SNR = 3      # accept components with that peak-SNR or higher
rval_thr = 0.7   # accept components with space correlation threshold or higher
cnm.params.change_params(params_dict={'fr': fr,
                                      'decay_time': decay_time,
                                      'min_SNR': min_SNR,
                                      'rval_thr': rval_thr,
                                      'use_cnn': use_cnn});
%time cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

print(('Keeping ' + str(len(cnm.estimates.idx_components)) +
       ' and discarding  ' + str(len(cnm.estimates.idx_components_bad))))

In [None]:
# %%capture
cnm.params.set('temporal', {'p': p})
%time cnm2 = cnm.refit(images_now)

In [None]:

cnm2.estimates.nb_view_components_3d(image_type='corr', dims=dims, Yr=Yr,\
                                     denoised_color='red', max_projection=True);

In [None]:
cnm2.estimates.nb_view_components_3d(image_type='max', dims=dims, Yr=Yr,\
                                     denoised_color='red', max_projection=True);

In [None]:
# cnm.estimates.nb_view_components_3d(image_type='max', dims=dims, Yr=Yr,\
#                                      denoised_color='red', max_projection=True);

In [None]:
# Maximum over all components of the spatial footprints, one value per pixel.
m = cnm2.estimates.A.max(1).toarray()
# The pixel axis of A is flattened in Fortran order (every other reshape of
# CaImAn data in this notebook uses order='F'); a default C-order reshape
# would scramble the volume.
m = m.reshape(images_now.shape[-3:], order='F')
# Move the last axis to the front so that planes can be indexed as m[i].
m = m.transpose(2, 0, 1)
m.shape

In [None]:
plt.imshow(m[5])

In [None]:
m.transpose(2, 0, 1)