In [None]:
import os, random, sys, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import dask
from dask.diagnostics import ProgressBar
import caiman as cm
import h5py
import glob
from sklearn.preprocessing import StandardScaler
import tifffile as tff
import joblib

codeDir = r'V:/code/python/code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.behavior.FreeSwimBehavior as fsb
import apCode.behavior.headFixed as hf
import apCode.SignalProcessingTools as spt
import apCode.geom as geom
import seaborn as sns
import importlib
from apCode import util as util
from apCode import hdf
from apCode.imageAnalysis.spim import regress
from apCode.behavior import gmm as my_gmm


# Embed TrueType (Type 42) fonts so that text in saved PDF/PS figures remains
# editable in vector-graphics software.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


# Turn on autoreload when running under IPython. `__IPYTHON__` is undefined in
# a plain Python interpreter, so the NameError branch makes this a no-op there.
# `run_line_magic` replaces the deprecated `InteractiveShell.magic`.
try:
    if __IPYTHON__:
        _ipy = get_ipython()
        _ipy.run_line_magic('load_ext', 'autoreload')
        _ipy.run_line_magic('autoreload', '2')
except NameError:
    pass

# Set seeds for reproducibility.
# NOTE: `random.seed` must be *called*; assigning to it (`random.seed = seed`)
# replaces the function with an int and seeds nothing.
seed = 143
random.seed(seed)
np.random.seed(seed)

print(time.ctime())

### *Read the Excel sheet with the data paths*

In [None]:
#%% Path to excel sheet storing paths to data and other relevant info
dir_xls = r'\\Koyama-S2\Data3\Avinash\Projects\RS recruitment\GCaMP imaging'
# glob.glob returns matches in arbitrary order, so sort before taking the last
# entry; this makes the choice deterministic (lexicographically last name).
paths_xls = sorted(glob.glob(os.path.join(dir_xls, 'GCaMP volumetric imaging summary*.xlsx')))
if not paths_xls:
    raise FileNotFoundError(f'No summary excel sheet found in {dir_xls}')
path_xls = paths_xls[-1]


### *Select a fish to run regression analysis on*

In [None]:
#%% Read xl file
# Index of the fish to analyze; matched against the `FishIdx` column below.
idx_fish = 14
xls = pd.read_excel(path_xls, sheet_name='Sheet1')
# Take the `Path` entry of the first row whose FishIdx equals idx_fish.
is_selected_fish = xls.FishIdx == idx_fish
dir_fish = xls.loc[is_selected_fish, 'Path'].to_numpy()[0]
print(dir_fish)

### *Read registered Ca$^{2+}$ images. If not already registered, then register and read*

In [None]:
%%time
path_hdf = glob.glob(dir_fish + "/procData*.h5")[-1]
reg_ca = False
with h5py.File(path_hdf, mode='r') as hFile:
    if 'ca_reg' not in hFile:
        reg_ca = True 
if reg_ca:
    %time path_hdf = hf.register_piecewise_from_hdf(path_hdf)

df_ca = dict(ca=[], trlIdx=[], sessionIdx=[], stimLoc=[], trlLen_ca=[])
with h5py.File(path_hdf, mode='r') as hFile:
    trlIdx = np.array(hFile['trlIdx_ca'])
    sessionIdx = np.array(hFile['sessionIdx'])
    stimLoc = util.to_utf(np.array(hFile['stimLoc']))
    trlLen_ca = np.min(hFile['nImgsInTrl_ca'])
    tss = [[trl, sess, stim] for trl, sess, stim in zip(trlIdx, sessionIdx, stimLoc)]
    tss_un, inds_un = np.unique(tss, axis=0, return_index=True)
    tss_un = tss_un[np.argsort(inds_un)]
    nTrls = len(tss_un)
    for iTrl, tss_ in enumerate(tss_un):
        print(f'Trl {iTrl+1}/{nTrls}' )
        trl, sess, stim = int(tss_[0]), int(tss_[1]), tss_[2]
        inds = np.where((trlIdx==trl) & (sessionIdx==sess) & (stimLoc==stim))[0]
        ca_ = np.array(hFile['ca_reg'][:, inds]).swapaxes(0, 1)[:trlLen_ca]
        df_ca['ca'].append([ca_])
        df_ca['trlIdx'].append(trl)
        df_ca['sessionIdx'].append(sess)
        df_ca['stimLoc'].append(stim)
        df_ca['trlLen_ca'].append(trlLen_ca)
df_ca['trlIdx_glob'] = np.arange(nTrls)
df_ca = pd.DataFrame(df_ca)

### *Now read behavior and merge imaging and behavior into a single dataframe*

In [None]:
# Pull the raw behavior arrays out of the file first, then build the table.
with h5py.File(path_hdf, mode='r') as hFile:
    stimLoc = util.to_utf(np.array(hFile['behav/stimLoc']))
    ta = np.array(hFile['behav/tailAngles'])

# Each behavior label is split on "_" into a session part and a stimulus part;
# the session part is converted to int and shifted down by one.
sess, stim = zip(*(sl.split("_") for sl in stimLoc))
sess = np.array(sess).astype(int)-1
stim = np.array(stim)

# Tail angles of all trials are stacked along rows, 50 rows per trial.
# NOTE(review): 50 is presumably the number of points along the tail — confirm.
nTrls = ta.shape[0]//50
ta_trl = np.vsplit(ta, nTrls)

df_ta = pd.DataFrame(dict(
    sessionIdx=sess,
    stimLoc=stim,
    trlIdx_glob=np.arange(len(sess)),
    tailAngles=ta_trl,
    trlLen_behav=np.repeat(ta.shape[1], nTrls),
))

# Inner merge on the columns the two tables share.
df = pd.merge(df_ca, df_ta, how='inner')

# Delete large redundant variables from memory once the merged frame exists.
# Note that this cell cannot be re-run afterwards without re-creating df_ca.
if 'df' in locals() and 'df_ca' in locals():
    del df_ca
if 'df' in locals() and 'df_ta' in locals():
    del df_ta

### *Optionally, denoise $Ca^{2+}$ images (can take up to an hour)*

In [None]:
%%time
ca = np.squeeze([np.array(_) for _ in np.array(df.ca)])
ca = ca.reshape(-1, *ca.shape[-3:]).swapaxes(0, 1)

ca_den =[]
for iSlc, slc in enumerate(ca):
    print(f'Slc {iSlc+1}/{ca.shape[0]}')
    ca_den.append(volt.denoise_ipca(slc))
ca_den = np.array(ca_den)
ca_den = ca_den.swapaxes(0, 1)

### *Play movie to see how the raw and filtered images compare*

In [None]:
# Slice whose raw and denoised movies are compared.
iSlc = 16
# Join the raw and denoised stacks of that slice along axis 1 so both appear
# in a single movie frame.
raw_and_den = np.hstack((ca[iSlc], ca_den[:, iSlc]))
slc = raw_and_den
cm.movie(slc, fr=10).play(magnification=1.5, q_max=99)

## *Regression*

#### *Start by defining some useful functions*

In [None]:
#%% Some useful functions


def padIr(ir_trl, pad_pre, pad_post):
    """
    Zero-pad each trial's impulse-response matrix along its last axis and
    flatten it, so that its time length can match that of the ca responses.
    Parameters
    ----------
    ir_trl: iterable of 2D arrays
        One 2D array per trial.
    pad_pre, pad_post: int
        Number of zeros added before and after along the last axis.
    Returns
    -------
    ir_ser: array, (nTrls, nFlattenedSamples)
        One flattened, padded row per trial.
    """
    pad_width = ((0, 0), (pad_pre, pad_post))
    return np.array([np.pad(trl, pad_width).flatten() for trl in ir_trl])

def serializeHyperstack(vol):
    """
    Flatten a hyperstack into a 2D (time x pixel) array, e.g. for regression.
    Pixels are ordered row-major over (row, column, slice).
    Parameters
    ----------
    vol: array, (nTimePoints, nSlices, nRows, nCols)
    Returns
    -------
    vol_ser: array, (nTimePoints, nPixels)
    """
    # Put time last, collapse (row, col, slice) into one pixel axis, then
    # transpose so that time comes first.
    pxls_first = np.transpose(vol, (2, 3, 1, 0))
    nTimePts = pxls_first.shape[-1]
    return pxls_first.reshape(-1, nTimePts).T

def deserializeToHyperstack(arr, volDims):
    """
    Reshape a 2D array whose second axis holds pixels ordered row-major over
    (row, column, slice) back into a hyperstack.
    Parameters
    ----------
    arr: array, (nTimePoints, nPixels)
    volDims: sequence, (nSlices, nRows, nCols)
    Returns
    -------
    vol: array, (nTimePoints, nSlices, nRows, nCols)
    """
    # Pixels are laid out as (row, col, slice), so reshape in that order
    # before moving the slice axis next to time.
    rows_cols_slcs = np.take(volDims, [1, 2, 0])
    stack = arr.reshape(arr.shape[0], *rows_cols_slcs)
    return np.transpose(stack, (0, 3, 1, 2))

def pxlsToVol(pxls, volDims):
    """
    Reshape a flat pixel vector, ordered row-major over (row, column, slice),
    into a volume.
    Parameters
    ----------
    pxls: array, (nPixels, )
    volDims: sequence, (nSlices, nRows, nCols)
    Returns
    -------
    vol: array, (nSlices, nRows, nCols)
    """
    rows_cols_slcs = np.take(volDims, [1, 2, 0])
    return np.transpose(pxls.reshape(*rows_cols_slcs), (2, 0, 1))


def resample(t, y, tt):
    """
    Resample a signal onto new time points using piecewise-linear
    ('slinear') interpolation.

    The signal is first extended with a zero-valued sample at tt[0] and at
    tt[-1], so that the whole range of `tt` is covered by the interpolant.
    NOTE(review): when `t` already starts/ends at tt[0]/tt[-1], this padding
    duplicates those sample points — confirm interp1d handles that as intended.
    Parameters
    ----------
    t: array, (n, )
        Sample times of the signal.
    y: array, (n, )
        Signal values at `t`.
    tt: array, (m, )
        Times at which to evaluate the interpolated signal.
    Returns
    -------
    array, (m, )
        Interpolated values at `tt`.
    """
    from scipy.interpolate import interp1d
    t_ext = np.concatenate(([tt[0]], t, [tt[-1]]))
    zero = np.zeros(1, dtype=int)
    y_ext = np.concatenate((zero, y, zero))
    interpolant = interp1d(t_ext, y_ext, kind='slinear')
    return interpolant(tt)


def regOutsToVol(ro, volDims):
    """
    Reshape per-pixel regression outputs into volume(s).
    Parameters
    ----------
    ro: array, (nPixels, ) or (nPixels, nOutputs)
        A 1D input is treated as a single output column.
    volDims: sequence, (nSlices, nRows, nCols)
    Returns
    -------
    vol: array
        One volume per output column, with singleton axes squeezed out.
    """
    if np.ndim(ro) < 2:
        ro = ro[:, np.newaxis]
    # Iterate over output columns; each holds one value per pixel.
    return np.squeeze([pxlsToVol(pxls, volDims) for pxls in ro.T])


def convolve_trlwise(ir_trl, ker, regInds):
    """
    Convolve impulse trains with the Ca kernel, trial by trial.
    Parameters
    ----------
    ir_trl: array, (nTrls, nRegressors, nTimePtsInTrl)
    ker: array, (kernelLen, )
    regInds: array, (n, )
        Indices of the regressors to convolve; all others are passed through
        unchanged.
    Returns
    --------
    reg_trl: array, (*ir_trl.shape)
    """
    def _convolve_same_length(x, ker):
        # Full convolution truncated to the input length.
        return np.convolve(x, ker, mode='full')[:len(x)]

    # Build the (lazy) task list: delayed convolutions for selected regressors,
    # the raw rows for everything else; evaluate them all in one dask call.
    tasks = [
        [dask.delayed(_convolve_same_length)(reg, ker) if iReg in regInds else reg
         for iReg, reg in enumerate(trl)]
        for trl in ir_trl
    ]
    return np.array(dask.compute(*tasks))

### *Load the GMM model and predict labels on tail angles*

In [None]:
dir_gmm = os.path.join(dir_xls, 'Group')
path_gmm = glob.glob(os.path.join(dir_gmm, 'gmm_headFixed_*.pkl'))[-1]
gmm_model = joblib.load(path_gmm)

ta_trl = np.array([np.array(_) for _ in df.tailAngles])
nTrls = len(ta_trl)
ta = np.concatenate(ta_trl, axis=1)
%time ta = hf.cleanTailAngles(ta, svd=gmm_model.svd)[0]
ta_trl = np.array(np.hsplit(ta, nTrls))
labels, features = gmm_model.predict(ta)

### *Make a set of labels-based impulse response functions for regression*

In [None]:
%%time

tPeriStim_behav = (-1, 6) # Pre- and pos-stim periods in seconds for behavior trials
tPeriStim_ca = (-1, 10) # Pre- and post-stim periods in seconds for ca trials
Fs_behav = 500

getStimName = lambda s: 'Head' if s == 'h' else 'Tail'

ir, names_ir = hf.impulse_trains_from_labels(labels, ta, split_lr=False)

pad_post = int((tPeriStim_ca[-1]-tPeriStim_behav[-1])*Fs_behav)
n_pre_behav = int(np.abs(tPeriStim_behav[0])*Fs_behav)
stimLoc = np.array(df.stimLoc)
stimLoc_unique = np.unique(stimLoc)
sessionIdx  = np.array(df.sessionIdx)
sessionIdx_unique = np.unique(sessionIdx)
nSessions = len(sessionIdx_unique)

nTrls = df.shape[0]
ir_trl = np.transpose(ir.reshape(ir.shape[0], nTrls,-1),(1 ,0, 2))
names_ir = list(names_ir)
foo = []
count = 1
for sl, trl in zip(stimLoc, ir_trl):
    ht = np.zeros((len(stimLoc_unique), trl.shape[-1]))
    ind = np.where(stimLoc_unique == sl)[0]
    ht[ind, n_pre_behav-1]=1 
    trl_ht = np.r_[trl, ht]
    blah = np.pad(trl_ht,((0,0),(0,pad_post)), mode = 'constant')
    session_now = sessionIdx[count-1]
    session_idx = np.zeros((nSessions,blah.shape[-1]))*(count/ir_trl.shape[0])
    session_idx[session_now-1,:] = 1
    foo.append(np.r_[blah, session_idx])
    count += 1
ir_trl = np.array(foo)
ir_ser = np.concatenate(ir_trl,axis = 1)

names_ir = list(names_ir)
names_ir.extend([getStimName(s) for s in stimLoc_unique])
for idx in sessionIdx_unique:
    names_ir.extend([f'Session-{idx}'])
regNames = names_ir.copy()



### *Plot to see what these look like* 

In [None]:
#%% Display impulse trains and other regressors
# Time axis in seconds at the behavior sampling rate.
t_full = np.arange(ir_ser.shape[-1])*(1/Fs_behav)
yOff = util.yOffMat(ir_ser)
# One tick per regressor, at successive negative integers.
yt = -np.arange(ir_ser.shape[0])

fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(t_full, (ir_ser-yOff).T)
ax.set_yticks(yt)
ax.set_yticklabels(regNames)
ax.set_xlim(t_full[0], t_full[-1])
ax.set_xlabel('Time (s)')
ax.set_title('Impulse responses & other regressors', fontsize=20);

### *Convolve trial-by-trial with $Ca^{2+}$ kernel to produce final regressors*

In [None]:
%%time
#%% CIRF in slightly subSampled behavAndScan time, followed by convolution to generate regressors
tLen = 6 # Length of kernel
tau_rise = 0.2 # Rise constant
tau_decay = 1 # Decay constant
dt_behav = 1/500

### CIRF
t_cirf = np.arange(0, tLen, dt_behav)
cirf = spt.generateEPSP(t_cirf, tau_rise, tau_decay, 1, 0)

ind = util.findStrInList('session', regNames, case_sensitive=False)[0]
regInds = np.arange(ind)
%time regressors = convolve_trlwise(ir_trl, cirf, regInds)
regressors = np.concatenate(regressors, axis=1)
scaler = StandardScaler(with_mean=False)
regressors = scaler.fit_transform(regressors.T).T


%time ca_ser = serializeHyperstack(ca_den)

t_behav = np.linspace(0, 1, regressors.shape[1])
t_ca = np.linspace(0, 1, ca_ser.shape[0])


regressors = dask.compute(*[dask.delayed(resample)(t_behav, reg, t_ca) for reg in regressors])
regressors = np.array(regressors)


if 'path_hdf' not in locals():
    path_hdf = glob.glob(os.path.join(dir_fish, 'procData*.h5'))[-1]
       
with h5py.File(path_hdf, mode = 'r+') as hFile:
    if 'regression' in hFile:
        del hFile['regression']
    grp = hFile.create_group('regression')   
    grp.create_dataset('regressors', data=regressors.T)
    grp.create_dataset('regressor_names', data=util.to_ascii(regNames))
    grp.create_dataset('impulse_trains', data=ir_ser)


### *Plot regressors* 

In [None]:
#%% Plot all regressors
yOff = util.yOffMat(regressors)
plt.figure(figsize = (20, 15))
plt.plot(t_ca, (regressors-yOff).T)
plt.xlim(t_ca.min(), t_ca.max())
plt.yticks(-yOff, regNames)
# t_ca is a normalized axis (np.linspace(0, 1, nFrames)), not seconds, so the
# axis label must not claim seconds.
plt.xlabel('Time (normalized)')
plt.title('Regressors');

### *Filter images a bit to improve regression (optional)*

In [None]:
%%time
filtSize = 0.5

ca_den_flt = []
for iSlc, slc in enumerate(ca_den.swapaxes(0, 1)):
    print(f'{iSlc + 1}/{ca_den.shape[1]}')
    slc_flt = volt.img.gaussFilt(slc, sigma=filtSize)
    ca_den_flt.append(slc_flt)
ca_den_flt = np.array(ca_den_flt).swapaxes(0, 1)

ca_ser = serializeHyperstack(ca_den)

if 'path_hdf' not in locals():
    path_hdf = glob.glob(os.path.join(dir_fish, 'procData*.h5'))[-1]
    
with h5py.File(path_hd, mode = 'r+') as hFile:
    keyName = f'ca_den_flt_sigma-{int(filtSize*100)}'
    if keyName in hFile:
        del hFile[keyName]
    %time hFile.create_dataset(keyName, data=ca_den_flt)

### *Regress*

In [None]:
#%% Regress
# Regress the serialized pixel time series (`ca_ser`, time x pixels) onto the
# regressors (transposed to time x regressors), with an intercept term.
# Later cells read `coef_`, `intercept_` and `T_` from the returned object.
# NOTE(review): n_jobs=-1 presumably means "use all cores" — confirm against
# apCode.imageAnalysis.spim.regress.
%time regObj = regress(regressors.T, ca_ser, n_jobs=-1, fit_intercept=True)


### *Reshape regression outputs into volumes*

In [None]:
# Volume dimensions are the last three axes of the denoised stack.
volDims = ca_den.shape[-3:]
# Reshape coefficients, intercepts and t-values from per-pixel vectors to volumes.
betas_vol, intercept_vol, t_vol = (
    regOutsToVol(regOut, volDims)
    for regOut in (regObj.coef_, regObj.intercept_, regObj.T_)
)


### *Plot max-int z projections of regression outputs for quick visual examination*

In [None]:
iReg = 5     # Index of the regressor to display
q_max = 99   # Upper percentile passed to the saturation step
q_min = 10   # Lower percentile passed to the saturation step

# Maximum projection of the selected beta volume along its first axis,
# saturated between the chosen percentiles for display.
beta_proj = np.max(betas_vol[iReg], axis=0)
beta_proj_sat = spt.stats.saturateByPerc(beta_proj, perc_up=q_max, perc_low=q_min)

fig, ax = plt.subplots(figsize=(20, 10))
im = ax.imshow(beta_proj_sat)
ax.set_title(f'Regressors: {regNames[iReg]}')
fig.colorbar(im, ax=ax);

### *Save beta and t-value maps*

In [None]:
%%time
#%% Save regression images
figDir = os.path.join(dir_fish, f'figs/regression')
os.makedirs(figDir, exist_ok=True)

### First save coefficients
foo = betas_vol.astype('float32')
dir_now = os.path.join(figDir, 'betas')
os.makedirs(dir_now, exist_ok=True)

for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(dir_now, f'Fig-{util.timestamp()}_regressor-{regNames[iReg]}_coef.tif'),vol[1:])
tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor_intercept.tif'), intercept_vol)
    
foo = t_vol.astype('float32')[1:]
dir_now = os.path.join(figDir, 'tValues')
os.makedirs(dir_now, exist_ok=True)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor-{regNames[iReg]}_tVals.tif'),vol[1:])

print(f'Saved at \n{figDir}')