In [None]:
import os, random, sys, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import dask
from dask.diagnostics import ProgressBar
import dask.dataframe as dask_df
import caiman as cm
import h5py
# from skimage.external import tifffile as tff
from sklearn.decomposition import PCA
import tifffile as tff
import joblib
import glob
import re
# import plotly.graph_objects as go


import bokeh.plotting as bpl
import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.utils.utils import download_demo
from caiman.utils.visualization import plot_contours, nb_view_patches, nb_plot_contour
bpl.output_notebook()


# Machine-specific location of the apCode package; adjust when running elsewhere.
codeDir = r'V:/code/python/code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.behavior.FreeSwimBehavior as fsb
import apCode.behavior.headFixed as hf
import apCode.SignalProcessingTools as spt
import apCode.geom as geom
import importlib
from apCode import util as util
from apCode import hdf
from apCode.imageAnalysis.spim import regress
from apCode.behavior import gmm as my_gmm
from apCode.machineLearning.preprocessing import Scaler


# Use Type 42 (TrueType) font embedding in PDF/PS output.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


# Auto-reload edited modules (e.g. apCode) before each cell runs.
# `run_line_magic` replaces the deprecated `InteractiveShell.magic` API.
# Outside IPython, `__IPYTHON__`/`get_ipython` are undefined -> NameError.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass



# Seed the random number generators for reproducibility.
# Call the seeding functions; assigning to `random.seed` would replace the
# function with an int and seed nothing.
seed = 143
random.seed(seed)
np.random.seed(seed)

print(time.ctime())

In [None]:
#%% Path to excel sheet storing paths to data and other relevant info
dir_xls = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging'
file_xls = 'GCaMP volumetric imaging summary.xlsx'


In [None]:
#%% Read xl file
idx_fish = 8
xls = pd.read_excel(os.path.join(dir_xls, file_xls), sheet_name='Sheet1')
path_now = np.array(xls.loc[xls.FishIdx == idx_fish].Path)[0]
print(path_now)

### *Open an existing HDF file or create a new one, if none exists*

In [None]:
# Reuse the last procData*.h5 file in the fish directory (lexicographic order;
# presumably the most recent if the timestamp format sorts chronologically --
# TODO confirm), otherwise create a new, timestamped one.
# Check the glob result before indexing so an empty directory does not raise.
paths_h5 = sorted(glob.glob(os.path.join(path_now, 'procData*.h5')))
if len(paths_h5) > 0:
    path_hFile = paths_h5[-1]
else:
    path_hFile = os.path.join(path_now, f'procData_{util.timestamp()}.h5')
    with h5py.File(path_hFile, mode='a') as hFile:
        pass  # opening in append mode creates the empty file
print(path_hFile)

In [None]:
def ca_tifs_to_hdf(fishDir, hFilePath=None, grpName='ca', regex=r'\d{3}_[h/t]', verbose=True):
    """
    Write the Ca2+ tif images of every imaging session of one fish to an HDF file.

    Reads all the tif images in the subdirectories '<fishDir>/*/ca/' (one per
    session, e.g. '<fishDir>/001_h/ca/'), baseline-adjusts them and writes them
    to the HDF file at `hFilePath`. If `hFilePath` is None, the last
    'procData*.h5' file in `fishDir` is used, or a new timestamped path in
    `fishDir` if there is none. The paths to the tif files are stored in
    hFile['filePaths_<grpName>'], and the images and all other relevant info in
    hFile['/<grpName>/'].

    Parameters
    ----------
    fishDir: str
        Root directory containing all the image subdirectories.
    hFilePath: str or None
        Full path to an HDF file. If None, the program looks for a
        'procData*.h5' file in `fishDir`. If there is none, it uses the path
        'procData_{timestamp}.h5' in `fishDir`; the file is created on first
        write.
    grpName: str
        Name of the group in the HDF file under which to store all the relevant
        information. Any existing group of this name, and the dataset
        'filePaths_<grpName>', is deleted before writing.
    regex: str
        Pattern identifying the session/stimulus-location token (e.g. '001_h')
        in each image directory path.
    verbose: bool
        Passed on to `hdf.createOrAppendToHdf`.

    Returns
    -------
    path_hFile: str or None
        Path of the HDF file written to, or None if no Ca image subdirectories
        were found.

    Raises
    ------
    ValueError
        If `regex` does not match an image directory path.
    """
    import os
    import glob
    import re
    import numpy as np
    import h5py
    from apCode import util
    from apCode import hdf
    from apCode import volTools as volt

    # Sort the directories so that session numbers follow the directory names
    # rather than the filesystem-dependent order glob returns.
    imgDirs = sorted(glob.glob(os.path.join(fishDir, '*/ca/'), recursive=True))
    if len(imgDirs) == 0:
        print('No Ca image subdirectories found, check path')
        return None
    print(f'{len(imgDirs)} Ca subdirectories found')

    if hFilePath is None:
        # Search `fishDir` itself, and check for an empty glob result before
        # indexing into it.
        paths_h5 = sorted(glob.glob(os.path.join(fishDir, 'procData*.h5')))
        if len(paths_h5) > 0:
            path_hFile = paths_h5[-1]
        else:
            path_hFile = os.path.join(fishDir, f'procData_{util.timestamp()}.h5')
    else:
        path_hFile = hFilePath

    for iSession, imgDir in enumerate(imgDirs):
        print(f'Session # {iSession}, {imgDir}')
        imgs_ca, tifInfo = volt.dask_array_from_scanimage_tifs(imgDir)
        print('Baseline adjusting imgs_ca')
        # Subtract the mean over axis 0, then shift so the global minimum is 1.
        imgs_ca = imgs_ca - imgs_ca.mean(axis=0, keepdims=True)
        imgs_ca = imgs_ca - imgs_ca.min() + 1
        nImgs = len(imgs_ca)
        filePaths = util.to_ascii(tifInfo['filePaths'])
        print(f'{len(filePaths)} tif files in directory')
        sessionNum = np.array([iSession]*nImgs)
        # Label every image with the token matched in its directory path.
        # Use exactly one match (the last one, i.e. closest to 'ca/') so that
        # stimLoc always has nImgs entries, even if `fishDir` also matches.
        matches = re.findall(regex, imgDir)
        if len(matches) == 0:
            raise ValueError(f'Pattern {regex!r} not found in {imgDir}')
        stimLoc = util.to_ascii(np.array([matches[-1]]*nImgs))
        keys = [f'filePaths_{grpName}', f'{grpName}/imgs_raw', f'{grpName}/sessionNum', f'{grpName}/stimLoc']
        vals = [filePaths, imgs_ca, sessionNum, stimLoc]
        with h5py.File(path_hFile, mode='a') as hFile:
            if iSession == 0:
                # Start clean: drop anything written by a previous run.
                if grpName in hFile:
                    del hFile[grpName]
                if f'filePaths_{grpName}' in hFile:
                    del hFile[f'filePaths_{grpName}']
            for key, val in zip(keys, vals):
                if key == f'{grpName}/imgs_raw':
                    # Images go in one dataset per session:
                    # '/<grpName>/imgs_raw/<token>'
                    sfx = f'{util.to_utf([stimLoc[0]])[0]}'
                    key_new = '/' + key + f'/{sfx}'
                    print(f'Dask array to HDF: {key_new}')
                    val.to_hdf5(path_hFile, key_new)
                else:
                    hFile = hdf.createOrAppendToHdf(hFile, key, val, verbose=verbose)
    return path_hFile
        
    
    

In [None]:
%time path_hFile = ca_tifs_to_hdf(path_now, verbose=True)

In [None]:
#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(backend='local', n_processes=None,\
                                                 single_thread=False)

In [None]:
# %time foo = cm.load(path_hFile,  var_name_hdf5='ca/imgs_raw/001_h')

In [None]:
dirs_ca= glob.glob(os.path.join(path_now, '*/ca/'), recursive=True)
dirs_ca

In [None]:
fnames = glob.glob(os.path.join(dirs_ca[0], '*.tif'))
len(fnames)

In [None]:
%time fname_new = cm.save_memmap(fnames, is_3D=True, dview=dview, order='C')

In [None]:
%time Yr, dims, T = cm.load_memmap(fname_new, mode='r')

In [None]:
(Yr.shape[-1]//30)*30

In [None]:
getsize = lambda path, p: os.path.getsize(path)/(1024)**p

In [None]:
s = np.array([getsize(fn, 2) for fn in fnames])


# Continue from the saved dataframe

In [None]:
%%time

startFresh = True # Reads hFile and df

if (startFresh) & (('hFilePath' in locals()) | ('df' in locals())):
    if 'hFilePath' in locals():
        del hFilePath
    if 'df' in locals():
        del df

#%% If stored dataframe exists in path read it
if 'hFilePath' not in locals():
    hFilePath = glob.glob(os.path.join(path_now, 'procData*.h5'))[-1]
#     hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str = 'procData')[-1]
#     hFilePath = os.path.join(path_now, hFileName)
with h5py.File(hFilePath, mode = 'r') as hFile:
    print(hFile.keys())

if 'df' not in locals():
    file_df = ft.findAndSortFilesInDir(path_now, ext = 'pickle', search_str = 'dataFrame')
    if len(file_df)>0:
        file_df = file_df[-1]
        path_df = os.path.join(path_now, file_df)
        print(path_df)
        print('Reading dataframe...')
        %time df = pd.read_pickle(path_df)       
    else:
        print('No dataframe found in path!')
    

In [None]:
var_name_hdf = 'images_reg_ipca_flt_sigma-100'
images = cm.load([hFilePath], fr=2, is3D=True, var_name_hdf5=var_name_hdf)
images = images[:,1:]
images -= images.min()-1
images = images.transpose(0, 2, 3, 1)
print(f'img dims = {images.shape}')

## *Try CNMF*

In [None]:
#%% Save as memory mapped file (required for patch-based approach to CNMF)
fname = 'mov_cnmf_3d.mmap'
fname = os.path.join(path_now, fname)
%time fname = images.save(fname, order='C')



In [None]:
# now load the file
Yr, dims, T = cm.load_memmap(fname)
images_mm = np.reshape(Yr.T, [T] + list(dims), order='F')
print(images_mm.shape)

In [None]:
#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(backend='local', n_processes=None,\
                                                 single_thread=False)

### *Initialize CNMF object*

In [None]:
#%% set parameters
fr = 2
merge_thresh = 0.85         # merging threshold, max correlation allowed
p = 2                       # order of the autoregressive system
# gnb = 2                     # number of global background components
rf = 15                     # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride = 1                  # amount of overlap between the patches in pixels
K = 12                      # number of components per patch
gSig = [2, 2, 2]            # expected half size of neurons in pixels
method_init = 'greedy_roi'  # initialization method (if analyzing dendritic data using 'sparse_nmf')
ssub = 1                    # spatial subsampling during initialization
tsub = 1                    # temporal subsampling during intialization

# parameters for component evaluation
min_SNR = 2.0               # signal to noise ratio for accepting a component
rval_thr = 1.0              # space correlation threshold for accepting a component

remove_very_bad_comps = True
cnn_thr = 0.99              # threshold for CNN based classifier
cnn_lowest = 0.1 # neurons with cnn probability lower than this value are rejected


In [None]:
# INIT
cnm = cnmf.CNMF(n_processes, fr=fr, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p,
                rf=rf, dview=dview)

%time cnm = cnm.fit(images_mm)
nComps = cnm.estimates.A.shape[-1]
print(f'{nComps} components')

In [None]:
# cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims);

plt.figure(figsize=(20, 10)); 
plt.subplot(211)
foo = cnm.estimates.A.toarray().max(axis=-1).reshape(dims, order='F')
plt.imshow(spt.stats.saturateByPerc(foo.max(axis=-1)))

# nmf_time = cnm.estimates.C
# plt.subplot(212)
# plt.plot(nmf_time[i])

In [None]:
#%% COMPONENT EVALUATION
# the components are evaluated in two ways:
#   a) the shape of each component must be correlated with the data
#   b) a minimum peak SNR is required over the length of a transient

decay_time = 1.5  # length of typical transient in seconds 
use_cnn = False  # CNN classifier is designed for 2d (real) data
min_SNR = 2      # accept components with that peak-SNR or higher
rval_thr = 0.7   # accept components with space correlation threshold or higher
cnm.params.change_params(params_dict={'fr': fr,
                                      'decay_time': decay_time,
                                      'min_SNR': min_SNR,
                                      'rval_thr': rval_thr,
                                      'use_cnn': use_cnn})

%time cnm.estimates.evaluate_components(images_mm, cnm.params, dview=dview);
print(('Keeping ' + str(len(cnm.estimates.idx_components)) +
       ' and discarding  ' + str(len(cnm.estimates.idx_components_bad))))

In [None]:
del foo

In [None]:
# cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims);

plt.figure(figsize=(20, 10)); 
plt.subplot(211)
foo = cnm.estimates.A.toarray().max(axis=-1).reshape(dims, order='F')
plt.imshow(spt.stats.saturateByPerc(foo.max(axis=-1)))

# nmf_time = cnm.estimates.C
# plt.subplot(212)
# plt.plot(nmf_time[i])

## *Now we re-run CNMF on whole FOV seeded with the accepted components*

In [None]:
%%time
cnm.params.set('temporal', {'p': p})
cnm2 = cnm.refit(images_mm, dview=dview)

print(f'nComps before = {cnm.estimates.A.shape[-1]}, nComps now = {cnm2.estimates.A.shape[-1]}')


In [None]:
print(cnm2.estimates.A.shape[-1])
cnm2.estimates.select_components(use_object=True)
print(cnm2.estimates.A.shape[-1])

In [None]:
plt.figure(figsize=(20, 10)); 
# plt.subplot(211)
foo = cnm2.estimates.A.toarray().max(axis=-1).reshape(dims, order='F')
plt.imshow(spt.stats.saturateByPerc(foo.max(axis=-1), perc_up=95))

# nmf_time = cnm.estimates.C
# plt.subplot(212)
# plt.plot(nmf_time[i])

In [None]:
nmf_time = cnm2.estimates.C
nmf_space = cnm2.estimates.A.toarray().reshape(*dims, -1, order='F')
img_avg_zProj = images.mean(axis=0).max(axis=-1)
img_avg_zProj_norm = spt.standardize(spt.stats.saturateByPerc(img_avg_zProj, perc_up=95))

In [None]:
iComp = 28  # index of the CNMF component to inspect

# Spatial footprint (red channel) over the mean anatomy (blue channel),
# both max-projected over the last axis.
a = spt.standardize(nmf_space[..., iComp].max(axis=-1))
img = np.zeros((*a.shape, 3))
img[..., 0] = a
img[..., 2] = img_avg_zProj_norm

plt.figure(figsize=(20, 10))
plt.subplot(1, 2, 1)
plt.imshow(img.transpose(1, 0, 2)[:, ::-1])

# Time series: split into one equal-length segment per dataframe row, subtract
# each segment's first sample, and average separately for head/tail trials.
ts = nmf_time[iComp]
stimLoc = df.stimLoc
inds_head, inds_tail = (np.where(stimLoc == loc)[0] for loc in ('h', 't'))
nTrls = df.shape[0]
ts_trl = ts.reshape(nTrls, -1)
ts_trl_head, ts_trl_tail = (ts_trl[inds] - ts_trl[inds][:, :1]
                            for inds in (inds_head, inds_tail))
mu_head, mu_tail = ts_trl_head.mean(axis=0), ts_trl_tail.mean(axis=0)

plt.subplot(1, 2, 2)
for mu, lbl in ((mu_head, 'Head'), (mu_tail, 'Tail')):
    plt.plot(mu, label=lbl)
plt.xlim(0, len(mu_head) - 1)
plt.legend(fontsize=20)

In [None]:
# #%% Extract DF/F values
# cnm2 = cnm2.estimates.detrend_df_f(quantileMin=8, frames_window=250)
# dff = cnm2.F_dff

In [None]:

#%% reconstruct denoised movie
denoised = (cnm2.estimates.A.dot(cnm2.estimates.C) + \
            cnm2.estimates.b.dot(cnm2.estimates.f)).reshape(dims + (-1,), order='F')
denoised = denoised.transpose(3, 0, 1, 2)
print(denoised.shape)
mov = cm.movie(denoised.max(axis=-1), fr=2)

In [None]:
mov.play(magnification=3, q_max=95)

In [None]:
#%% Standard NMF
mov -= mov.min()
nmf_space, nmf_time = mov.NonnegativeMatrixFactorization(n_components=30)

In [None]:
iComp = 14
plt.figure(figsize=(10, 5))
plt.subplot(211)
plt.imshow(nmf_space[iComp])

plt.subplot(212)
plt.plot(nmf_time[iComp])

## *3D version*

In [None]:
#%% Rearrange dimensions to put txyz format
images_txyz = np.transpose(images, (0, 2, 3, 1))[...,1:]
images_txyz.shape

In [None]:
#%% Save as memory mapped file
fn_new = cm.save_memmap([images_txyz], order='C', base_name='Yr_3d2', is_3D=True)


In [None]:
# now load the file
Yr, dims, T = cm.load_memmap(fn_new)
Y = np.reshape(Yr.T, [T] + list(dims), order='F')
print(Y.shape)

In [None]:
# Cn = cm.local_correlations(Y)
plt.imshow(Cn.max(0) if len(Cn.shape) == 3 else Cn, cmap='viridis',
           vmin=np.percentile(Cn, 70), vmax=np.percentile(Cn, 99.9))
plt.show()

## *Single patch approach for small data*

In [None]:
# set parameters
K = 20  # number of neurons expected per patch
gSig = [2, 2, 2]  # expected half size of neurons
merge_thresh = 0.8  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system

In [None]:
# INIT
cnm = cnmf.CNMF(n_processes, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, dview=dview)

In [None]:
%%time
# %%capture
# FIT
images_now = np.reshape(Yr.T, [T] + list(dims), order='F')    # reshape data in Python format (T x X x Y x Z)
cnm = cnm.fit(images_now)

In [None]:
cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims);

## *Patch approach for larger datasets*

In [None]:
# set parameters
rf = 18  # half-size of the patches in pixels. rf=25, patches are 50x50
stride = 10  # amounpl.it of overlap between the patches in pixels
K = 12  # number of neurons expected per patch
gSig = [8, 8, 2]  # expected half size of neurons
merge_thresh = 0.8  # merging threshold, max correlation allowed
p = 2  # order of the autoregressive system

In [None]:
# %%capture
#%% RUN ALGORITHM ON PATCHES

cnm = cnmf.CNMF(n_processes, k=K, gSig=gSig, merge_thresh=merge_thresh, p=p, dview=dview,
                rf=rf, stride=stride, only_init_patch=True)

%time cnm = cnm.fit(images)
print(('Number of components:' + str(cnm.estimates.A.shape[-1])))

In [None]:
# cnm.estimates.nb_view_components_3d(image_type='mean', dims=dims);

In [None]:
#%% COMPONENT EVALUATION
# the components are evaluated in two ways:
#   a) the shape of each component must be correlated with the data
#   b) a minimum peak SNR is required over the length of a transient

fr = 2 # approx final rate  (after eventual downsampling )
decay_time = 1.  # length of typical transient in seconds 
use_cnn = False  # CNN classifier is designed for 2d (real) data
min_SNR = 3      # accept components with that peak-SNR or higher
rval_thr = 0.7   # accept components with space correlation threshold or higher
cnm.params.change_params(params_dict={'fr': fr,
                                      'decay_time': decay_time,
                                      'min_SNR': min_SNR,
                                      'rval_thr': rval_thr,
                                      'use_cnn': use_cnn});
%time cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

print(('Keeping ' + str(len(cnm.estimates.idx_components)) +
       ' and discarding  ' + str(len(cnm.estimates.idx_components_bad))))

In [None]:
# %%capture
cnm.params.set('temporal', {'p': p})
%time cnm2 = cnm.refit(images_now)

In [None]:

cnm2.estimates.nb_view_components_3d(image_type='corr', dims=dims, Yr=Yr,\
                                     denoised_color='red', max_projection=True);

In [None]:
cnm2.estimates.nb_view_components_3d(image_type='max', dims=dims, Yr=Yr,\
                                     denoised_color='red', max_projection=True);

In [None]:
# cnm.estimates.nb_view_components_3d(image_type='max', dims=dims, Yr=Yr,\
#                                      denoised_color='red', max_projection=True);

In [None]:
m = cnm2.estimates.A.max(1).toarray()
m = m.reshape(*images_now.shape[-3:])
m = m.transpose(2, 0, 1)
m.shape

In [None]:
plt.imshow(m[5])

In [None]:
m.transpose(2, 0, 1)