In [None]:
import os, random, sys, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import dask
import caiman as cm
import h5py
# from skimage.external import tifffile as tff
from sklearn.decomposition import PCA
import tifffile as tff
import joblib

codeDir = r'V:/code/python/code'
sys.path.append(codeDir)
import apCode.FileTools as ft
import apCode.volTools as volt
from apCode.machineLearning import ml as mlearn
import apCode.behavior.FreeSwimBehavior as fsb
import apCode.behavior.headFixed as hf
import apCode.SignalProcessingTools as spt
import apCode.geom as geom
# import seaborn as sns
import importlib
from apCode import util as util
from apCode import hdf
from apCode.imageAnalysis.spim import regress
from apCode.behavior import gmm as my_gmm


# Font type 42 (TrueType) keeps text in saved PDF/PS figures as editable text
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42


# Enable autoreload when running under IPython so edits to the apCode modules
# are picked up without restarting the kernel. Outside IPython `__IPYTHON__`
# is undefined and the resulting NameError is deliberately ignored.
try:
    if __IPYTHON__:
        _ipy = get_ipython()
        # `InteractiveShell.magic` is deprecated; `run_line_magic` is its replacement
        _ipy.run_line_magic('load_ext', 'autoreload')
        _ipy.run_line_magic('autoreload', '2')
except NameError:
    pass

# Setting seed for reproducibility
seed = 143
# Call the seeding functions; assigning `random.seed = seed` would replace the
# function with an int and leave the generator unseeded.
random.seed(seed)
# Later cells draw from NumPy's global generator (np.random.choice / permutation)
np.random.seed(seed)

print(time.ctime())

In [None]:
#%% Path to excel sheet storing paths to data and other relevant info
# NOTE(review): absolute, machine-specific network path; the notebook only runs
# where this drive mapping exists.
dir_xls = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging'
file_xls = 'GCaMP volumetric imaging summary.xlsx'


In [None]:
#%% Read the excel summary sheet and look up the data path of the selected fish
# Value matched against the `FishIdx` column of the sheet
idx_fish = 4
xls = pd.read_excel(os.path.join(dir_xls, file_xls), sheet_name='Sheet1')
# First `Path` entry among the rows whose FishIdx matches; raises IndexError
# when no row matches.
path_now = np.array(xls.loc[xls.FishIdx == idx_fish].Path)[0]
print(path_now)

# Read and register Ca$^{2+}$ images first

In [None]:
%%time
#%% Read ScanImage tif files and store in trialized format to HDF file
regex = r'\d{1,5}_[ht]'
regMethod = 'cpwr' # ('st', 'cr', 'cpwr')


def _write_or_overwrite(h5group, key, data):
    """Create dataset `key` in `h5group` from `data`, or overwrite its contents in place.

    The in-place overwrite requires the existing dataset to be compatible with
    `data`; h5py raises otherwise.
    """
    if key not in h5group:
        h5group.create_dataset(key, data=data)
    else:
        h5group[key][...] = data


# Sub-directories of path_now whose names start with 1-5 digits followed by '_h' or '_t'
subDirs = [os.path.join(path_now, sd) for sd in ft.subDirsInDir(path_now)
           if re.match(regex, sd)]
hFileName = f'procData_{util.timestamp()}.h5'
hFilePath = os.path.join(path_now, hFileName)
stimLoc = []
with h5py.File(hFilePath, mode='a') as hFile:
    for iSub, sd in enumerate(subDirs):
        hort = os.path.split(sd)[-1]
        fn = ft.findAndSortFilesInDir(sd, ext='16ch')
        if len(fn) > 1:
            # More than one candidate file: ask the user which one to use
            idx = int(input('Enter index of "bas" file: '))
            path_bas = os.path.join(sd, fn[idx])
        else:
            path_bas = os.path.join(sd, fn[-1])
        path_tif = os.path.join(sd, 'ca')
        out = hf.readPeriStimulusTifImages_volumetric(path_tif, path_bas)
        images_trl = out['images_trl']
        nTrls = images_trl.shape[0]
        # One stimulus label (the sub-directory name) per trial read
        stimLoc.extend([hort]*nTrls)
        if iSub == 0:
            nImgsInTrl = images_trl.shape[1]
        _write_or_overwrite(hFile, 'nImgsInTrl', nImgsInTrl)
        keyName = 'ca_trls_raw'
        if keyName not in hFile:
            print(f'Creating {keyName} in h5 file')
            # Unlimited first axis so later sub-directories can be appended
            hFile.create_dataset(keyName, data=images_trl,
                                 maxshape=(None, *images_trl.shape[1:]),
                                 compression='lzf')
        else:
            print(f'Appending to {keyName} in h5 file')
            dset = hFile[keyName]
            dset.resize(dset.shape[0] + nTrls, axis=0)
            dset[-nTrls:] = images_trl
        keyName = f'inds_excluded/{hort}'
        _write_or_overwrite(hFile, keyName, out['inds_excluded'])
    stimLoc_ascii = util.to_ascii(stimLoc)
    _write_or_overwrite(hFile, 'stimLocVec', np.array(stimLoc_ascii))
        
    

In [None]:
# from apCode.volTools import Register
# nTrls = images.shape[0]
# trlLen = images.shape[1]
# volDims = images.shape[-3:]
# images_ser = np.swapaxes(images.reshape(-1, *volDims), 0, 1)
# images_reg = []
# for iSlc, slc in enumerate(images_ser):
#     print(f'{iSlc+1}/{volDims[0]}')
#     img = Register(regMethod='cpwr').fit(slc).transform(slc)
#     images_reg.append(img)
# images_reg = np.swapaxes(np.array(images_reg), 0, 1)
# images_reg = images_reg.reshape(nTrls, trlLen, *volDims)

In [None]:
%%time
#%% Read stored images from HDF file and register these images
filtSize = 1
regMethod = 'st' # ('st', 'cr', 'cpwr')

if not 'hFilePath' in locals():
    hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
    hFilePath = os.path.join(path_now, hFileName)
if not 'images' in locals():
    with h5py.File(os.path.join(path_now, hFileName), mode = 'r') as hFile:
        %time images = np.array(hFile['ca_trls_raw'])
nTrls = images.shape[0]
nTimePts = images.shape[1]
volDims = images.shape[-3:]
# try:
#     %time images_reg, regObj = hf.register_volumes_by_slices_and_trials(images, regMethod= regMethod,\
#                                                                     filtSize = filtSize)
# except:
#     print(f'{regMethod} failed, trying "st"...')
#     %time images_reg, regObj = hf.register_volumes_by_slices_and_trials(images, regMethod= 'st',\
#                                                                     filtSize = filtSize)

if regMethod.lower() =='st':
    images_reg, regObj = hf.register_trialized_volumes_by_slices(images, filtSize = filtSize,\
                                                                 regMethod = regMethod)
else:
    pass

In [None]:
%%time
#%% Save registered images to HDF file
# hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
keyName = 'ca_trls_reg'
with h5py.File(hFilePath, mode = 'a') as hFile:
    if keyName in hFile:
        # Dataset already present: overwrite its contents in place
        hFile[keyName][...] = images_reg
    else:
        hFile.create_dataset(keyName, data=images_reg)

In [None]:
images_reg.shape

In [None]:
# Merge the first two axes of images_reg into one, keeping the remaining axes
foo = images_reg.reshape(-1,*images_reg.shape[2:])
foo.shape

In [None]:
# iSlice is only used by the commented-out single-slice movie below
iSlice = 14
# Frame rate passed to the caiman movie
fps = 20
# cm.movie(foo[:,iSlice,...],fr = fps).play(magnification=1.5, q_min = 5, q_max = 99)
# Maximum over axis 1 of the serialized stack, played as a movie
foo_max = foo.max(axis=1)
cm.movie(foo_max, fr = fps).play(magnification=2, q_max = 95)

In [None]:
%time foo_max_den = volt.denoise_ipca(foo_max)

In [None]:
path_now

In [None]:
cm.movie(foo_max_den, fr = fps).play(magnification=2, q_max=95)

In [None]:
%%time
#%% Save each of the registered slices in a separate .tif file as well as the max-int movie
saveDir = os.path.join(path_now, f'registered_slices_{regMethod}_{util.timestamp()}')
os.makedirs(saveDir, exist_ok=True)
# Merge the first two axes of images_reg so that axis 1 indexes slices
foo = images_reg.reshape(-1,*images_reg.shape[2:])
# Slice axis first so each iteration yields the full series of one slice.
# astype already returns a new array, so no explicit copy is needed.
# NOTE(review): the int16 cast wraps values above 32767 - confirm the pixel range.
foo_int = np.swapaxes(foo,0,1).astype('int16')
for z, img in enumerate(foo_int):
    tff.imsave(os.path.join(saveDir,r'slice_{:03d}.tif'.format(z)),img)

#-- Also, save the movie of the max-z projection of slices
# Written to the parent directory of saveDir (i.e. path_now)
foo_max_int = foo.max(axis = 1).astype('int16')
tff.imsave(os.path.join(os.path.split(saveDir)[0],r'maxInt_z_movie.tif'), foo_max_int)

In [None]:
### Serialize and make caiman movie object from the max intensity projection over selected slices
# iSlices = np.array([14, 15, 16])
# Indices along the slice axis that enter the projection (index 0 is left out)
iSlices = np.arange(1,30)

# Collapse all leading axes so that the last three axes of images_reg are kept
images_reg_ser = images_reg.reshape(-1, *images_reg.shape[-3:])
# mov = cm.movie(images_reg_ser.mean(axis = 1), fr = 10)
mov = cm.movie(images_reg_ser[:,iSlices].max(axis = 1), fr = 20)
# Gaussian-filtered copy of the movie
%time mov_flt = cm.movie(volt.img.gaussFilt(mov, sigma = 0.75), fr = 20)

In [None]:
mov_flt = mov_flt - mov_flt.min()
mov_flt.play(magnification=1.5, q_min = 5, q_max = 95)


In [None]:
tff.imsave(os.path.join(path_now, 'Movie_maxIntProjZ_flt.tif'), mov_flt.astype('int16'))

In [None]:
#%% Run NMF on the filtered movie (50 components)
%time nmf_flt = mov_flt.NonnegativeMatrixFactorization(n_components=50)
# NOTE(review): later cells index nmf_flt[0] and nmf_flt[1], so this is a pair of
# arrays; np.save has to pack them into a single array - confirm this still works
# (and loads back) with the installed NumPy version.
np.save(os.path.join(path_now, 'nmf_flt.npy'), nmf_flt)

In [None]:
# Upper percentile passed to saturateByPerc for each component image
q_max = 99

from skimage.util import montage
plt.figure(figsize = (20,60))
# One saturated image per entry of nmf_flt[0]
nmf_flt_space = np.array([spt.stats.saturateByPerc(_, perc_up=q_max) for _ in nmf_flt[0]])
# Montage with 4 columns and enough rows to hold all components
m = montage(nmf_flt_space, rescale_intensity = True, grid_shape=(len(nmf_flt[0])//4+1, 4))
plt.imshow(m)


# plt.colorbar()

In [None]:
# Index of the NMF component to display
iComp = 30
# Offset (in frames) of the stimulus marker from the start of each trial
nPre = 1

# Take both trial count and trial length from the registered images that the
# NMF was computed on (the name `images` is rebound to behavior images later in
# the notebook, so it cannot be relied on here).
nTrls, trlLen = images_reg.shape[:2]
# Last character of each stimulus label, used as tick label
sl = [_[-1] for _ in stimLoc]

# Position of the stimulus marker of every trial in the serialized time axis
stimInds = np.arange(nTrls)*trlLen + nPre
plt.figure(figsize = (20,15))
plt.subplot(211)
plt.imshow(spt.stats.saturateByPerc(nmf_flt[0][iComp], perc_up = 95))
plt.subplot(212)
plt.plot(nmf_flt[1][:,iComp])
for si in stimInds:
    plt.axvline(si, ls = '--', c = 'r', alpha = 0.3)
plt.xticks(stimInds, sl)
plt.xlim(0, nmf_flt[1].shape[0])

In [None]:
#%% Save nmf spatial component images (pdf, png and tif)
saveDir = os.path.join(path_now, 'proc/nmf_flt/spatial comps/')
tifDir = os.path.join(saveDir,'tifs')
# Creates saveDir and tifDir as needed, also when saveDir exists but tifDir does not
os.makedirs(tifDir, exist_ok=True)
for iComp, comp in enumerate(nmf_flt[0]):
    # Start from an empty figure so images do not pile up across iterations
    plt.clf()
    plt.imshow(comp, cmap = 'viridis')
    plt.axis('off')
    plt.title(f'nmf {iComp}')
    plt.savefig(os.path.join(saveDir, r'nmf_{:03}.pdf'.format(iComp)), format = 'pdf', dpi = 'figure')
    plt.savefig(os.path.join(saveDir, r'nmf_{:03}.png'.format(iComp)), format = 'png', dpi = 'figure')
    # `comp` is nmf_flt[0][iComp]; the previous code referenced an undefined name `nmf`.
    # NOTE(review): the 2**16-1 scaling presumes component values lie in [0, 1] - confirm.
    foo = (comp*(2**16-1)).astype('int')
    tff.imsave(os.path.join(tifDir, f'nmf_{iComp}.tif'),foo)

  

In [None]:
# Re-open the last procData*.h5 file returned by findAndSortFilesInDir and load
# the registered images and the stimulus labels from it.
hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
hFilePath = os.path.join(path_now,hFileName)
with h5py.File(hFilePath, mode = 'r') as hFile:
#     %time images_reg = np.array(hFile['ca_trls_reg_trialByTrial'])
    %time images_reg = np.array(hFile['ca_trls_reg'])
    # Labels were written through util.to_ascii; util.to_utf is applied to read them back
    stimLoc = util.to_utf(hFile['stimLocVec'])
    print(hFile.keys())

In [None]:
#%% Now save the time-averaged volume of every trial as a separate .tif file
nPre = 3
n_comps = 5

def denoise_stack(stack, n_comps = 10):
    """
    Apply volt.ipca_denoise to every sub-array obtained by iterating over
    axis 1 of `stack`, scheduling the calls through dask.

    Parameters
    ----------
    stack: array whose axes 0 and 1 are swapped before iteration
    n_comps: int
        Forwarded to volt.ipca_denoise as `components`.
    Returns
    -------
    Array of the denoised sub-arrays with axes 0 and 1 swapped back.
    """
    tasks = []
    for slc in np.swapaxes(stack, 0, 1):
        tasks.append(dask.delayed(volt.ipca_denoise)(slc, components = n_comps))
    denoised = np.array(dask.compute(*tasks))
    return np.swapaxes(denoised, 0, 1)

saveDir = os.path.join(path_now, 'trialVolumes_avg2')
if not os.path.exists(saveDir):
    os.mkdir(saveDir)

for iTrl, stack in enumerate(images_reg):
    # Alternatives tried earlier (kept for reference):
    #     pre = stack[:nPre].mean(axis = 0)
    #     post = stack[nPre:].mean(axis = 0)
    #     vol = post-pre
    #     stack = denoise_stack(stack, n_comps = n_comps)
    #     vol = np.median(stack, axis = 0)
    vol = stack.mean(axis = 0)
    # File name: last character of the trial's stimulus label + zero-padded trial index
    loc = stimLoc[iTrl][-1]
    fn = '{}_{:03d}.tif'.format(loc, iTrl)
    # The first entry along axis 0 of the averaged volume is dropped before saving
    vol_int = vol[1:].astype('int16')
    tff.imsave(os.path.join(saveDir, fn), vol_int, metadata = {'axes': 'ZYX'})


In [None]:
# saveDir = r'Y:\Avinash\Head-fixed tail free\GCaMP imaging\2019-11-07\f1_alx-gal4_xa316_uas-gcamp6s\registered_images_cpwr'

# reg = volt.Register(backend='dask', regMethod='cpwr')
# foo = np.swapaxes(images.reshape(-1,*images.shape[2:]), 0, 1)
# foo_reg = []
# for z, img in enumerate(foo):
#     print(f'{z+1}/{len(foo)}')
#     foo_reg.append(reg.fit(img).transform(img))

# foo_reg = np.array(foo_reg)
# foo_int = (spt.standardize(foo_reg)*(2**8)).astype('uint8')
# print('Saving...')
# for z, img in enumerate(foo_int):
#     tff.imsave(os.path.join(saveDir,r'slice_{:03d}.tif'.format(z)),img)


# Load a pre-trained U-net first, assess performance, and retrain if need be

In [None]:
# Parent directory of the current data directory
dir_u = os.path.split(path_now)[0]
# Last .h5 file in that directory whose name matches 'trainedU'
path_unet = os.path.join(dir_u,ft.findAndSortFilesInDir(dir_u, search_str='trainedU', ext = 'h5')[-1])

%time unet = mlearn.loadPreTrainedUnet(path_unet)


In [None]:
#%% Get the paths to all the behavior directories
roots, dirs, files = zip(*[out for out in os.walk(dir_u)])
inds = util.findStrInList('Autosave', roots)
behavDirs =np.array(roots)[inds]

In [None]:
#%% Copy images for additional training
# Frame range within a trial and number of images, both forwarded to
# mlearn.copyImgsNNTraining
prefFrameRangeInTrl = (510, 550)
nImgsForTraining = 100
# Only the second return value (the directory the images were saved to) is kept
_, dir_train = mlearn.copyImgsNNTraining(behavDirs, prefFrameRangeInTrl=prefFrameRangeInTrl,\
                                         nImgsForTraining=nImgsForTraining)

print(f'Saved at {dir_train}')

In [None]:
#%% Predict on copied images to see how good pre-trained network is
%time images_raw = volt.img.readImagesInDir(dir_train)
# Resize to the spatial input size of the network.
# NOTE: this rebinds `images`, which earlier cells use for the raw Ca2+ image array.
images = volt.img.resize(images_raw, unet.input_shape[1:3])
# Add a trailing channel axis for the network and squeeze singleton axes from the output
images_pred = np.squeeze(unet.predict(images[...,np.newaxis],verbose = 1))

In [None]:
mlearn.plotMontageOfImageCollections(images, images_pred);

In [None]:
# Greedy match from the first to the last digit in the training directory name;
# the match is reused as the suffix of the masks file name.
regex = r'(\d.*\d)'
dir_train_imgs = dir_train
p, s = os.path.split(dir_train_imgs)
sffx = re.findall(regex, s)[0]
# Path of the masks .zip file, expected next to the training image directory
dir_train_masks =  os.path.join(os.path.split(dir_train)[0], f'masks_train_{sffx}.zip')

# aug_set = ('rn','sig','log','inv','heq','rot', 'et', 'rs')
aug_set = ('rn','sig','log','inv','heq','rot', 'rs') # For some reason 'et' augmentation stopped working (some opencv error)

# Rebinds `unet` to the network returned by mlearn.retrainU
unet = mlearn.retrainU(unet, dir_train_imgs, dir_train_masks, upSample=15, epochs=100, verbose=2,\
                       aug_set = aug_set)


In [None]:
%%time
#%% Test trained u-net on a set of continuous images
# Index into behavDirs of the directory whose images are read
iTrl = 1
images_raw = volt.img.readImagesInDir(behavDirs[iTrl])
# Resize to the spatial input size of the network (rebinds `images`)
images = volt.img.resize(images_raw, unet.input_shape[1:3])
images_pred = np.squeeze(unet.predict(images[...,np.newaxis],verbose = 2))


In [None]:
#%% Look at a few images by plotting
imgInds = np.arange(510,640,2)
mlearn.plotMontageOfImageCollections(images[imgInds], images_pred[imgInds]);

# Extract and store behavior data in HDF file

In [None]:
%time hFilePath = hf.extractAndStoreBehaviorData_singleFish(path_now,uNet=unet)

In [None]:
%%time
# Read the stored tail angles and clean them with hf.cleanTailAngles
with h5py.File(hFilePath, mode = 'r') as hFile:
    ta = np.array(hFile['behav/tailAngles'])
    # Split the first axis into groups of 50 rows
    # NOTE(review): 50 is hard-coded - presumably the number of tail points per trial; confirm
    ta = ta.reshape(-1,50,ta.shape[-1])
    # Last row of every group
    ta_trl = np.array([_[-1] for _ in ta])
    # Join the groups end to end along the last (time) axis
    ta_ser = np.concatenate(ta,axis = 1)
    %time ta_ser_clean, ta_svd, svd = hf.cleanTailAngles(ta_ser, None, dt = 1/500)
    # Last row of the cleaned series, cut back into rows of the per-trial length
    ta_trl_clean = (ta_ser_clean[-1]).reshape(-1,ta_trl.shape[1])


In [None]:
# Plot every trial's raw and cleaned tail-angle trace, stacked with vertical offsets
# Sampling interval of the behavior traces in seconds
dt_behav= 1/500
# Subtracted from the time axis so that t = 0 falls 1 s into the trial
t_stim = 1
# Only used by the commented-out plt.xlim(xl) at the end of this cell
xl = (-0.15, 4.5)
# inds = np.random.permutation(ta_trl.shape[0])
inds = np.arange(ta_trl.shape[0])
# Per-trace vertical offsets from util.yOffMat
yOff = util.yOffMat(ta_trl[inds])
t_trl = (np.arange(ta_trl.shape[1])*dt_behav) - t_stim
plt.figure(figsize = (20, 60))
plt.plot(t_trl,(ta_trl[inds]-yOff).T, c = plt.cm.tab10(0))
plt.plot(t_trl,(ta_trl_clean[inds]-yOff).T, c = plt.cm.tab10(1));
plt.xlim(t_trl.min(), t_trl.max())
# NOTE(review): yOff is broadcast against 2-D data above; confirm that passing it
# unflattened to plt.yticks works as intended.
plt.yticks(-yOff, np.arange(len(inds)))
plt.xlabel('Time (s)')
plt.ylabel('Trial #')

# plt.xlim(xl);


In [None]:
# Plot the cleaned tail-angle traces in random order and save the figure as pdf and png
# Sampling interval of the behavior traces in seconds
dt_behav= 1/500
# Subtracted from the time axis so that t = 0 falls 1 s into the trial
t_stim = 1
# x-limits of the plot in seconds
xl = (-0.15,4.5)
figName = f'Fig-{util.timestamp()}_Behavior trials_tail bend amplitude'
figDir = os.path.join(path_now,'figs')
if not os.path.exists(figDir):
    os.mkdir(figDir)
# inds = np.random.permutation(ta_trl.shape[0])
# NOTE(review): the `-1` leaves out the last trial - confirm this is intended
inds = np.arange(ta_trl.shape[0]-1)
# Shuffled order, drawn from NumPy's global random generator
inds = np.random.permutation(inds)
yOff = util.yOffMat(ta_trl_clean[inds])
t_trl = (np.arange(ta_trl.shape[1])*dt_behav) - t_stim
plt.figure(figsize = (20, 30))
# plt.plot(t_trl,(ta_trl[inds]-yOff).T, c = plt.cm.tab10(0))
plt.plot(t_trl,(ta_trl_clean[inds]-yOff).T);
plt.xlim(xl)
# Tick labels count plotted rows from 1; they are not the original trial indices
plt.yticks(-yOff.ravel(), np.arange(len(inds))+1);
plt.ylabel('Trial #')
plt.xlabel('Time (s)')
plt.title('Behavior trials (total tail bend angle)')
plt.savefig(os.path.join(figDir,figName + '.pdf'), format = 'pdf', dpi = 'figure')
plt.savefig(os.path.join(figDir,figName + '.png'), format = 'png', dpi = 'figure')


# Read Ca$^{2+}$ and behavior trials from the HDF file and put them in a dataframe

In [None]:
%%time
#%% Read Ca2+ and matching behavior trials from HDF file and create a dataframe

def get_session_inds(stimLocs, regex = r'\d{1,5}'):
    """
    Extract the leading integer found in each stimulus-location string.

    Parameters
    ----------
    stimLocs: iterable of str
        Each string must contain at least one match for `regex`.
    regex: str
        Pattern whose first match in each string is converted to int.
    Returns
    -------
    inds: array, (len(stimLocs),)
        The first regex match of every string, as integers.
    Raises
    ------
    IndexError if a string has no match for `regex`.
    """
    import re
    # Built-in int replaces the `np.int` alias, which was removed from NumPy
    return np.array([int(re.findall(regex, s)[0]) for s in stimLocs])

if not 'hFilePath' in locals():
    
with h5py.File(hFilePath, mode = 'r') as hFile:
    ta_trl = hFile['behav/tailAngles'][()]
    ta_trl = ta_trl.reshape(-1, 50, ta_trl.shape[-1])
    ta_ser = np.concatenate(ta_trl,axis = 1)
    %time ta_ser_clean, ta_svd, svd = hf.cleanTailAngles(ta_ser, None, dt = 1/500)
    ta_trl_clean = np.transpose(ta_ser_clean.reshape(ta_ser_clean.shape[0],-1,ta_trl.shape[-1]),\
                                (1,0,2))
    stimLoc_ca = [sl for sl in np.array(util.to_utf(hFile['stimLocVec']))]
    stimLoc_behav = util.to_utf(hFile['behav/stimLoc'][()])
    ca_trl = np.array(hFile['ca_trls_reg'][()])
    blocks, inds_blocks = util.get_blocks_of_repeats(stimLoc_behav)
    inds_excluded = hFile['inds_excluded']
    keys = list(inds_excluded.keys())
    foo = []
    for iSub, inds_ in enumerate(inds_blocks):
        k = keys[iSub]
        inds_del = inds_excluded[k]
        inds_now = np.delete(inds_, inds_del, axis = 0)
        foo.extend(ta_trl_clean[inds_now])
    ta_trl_clean = np.array(foo)
    

# blocks_ca, inds_ca = util.get_blocks_of_repeats(stimLoc_ca)
# blocks_behav, inds_behav = util.get_blocks_of_repeats(stimLoc_behav)
# inds_keep_ca, inds_keep_behav = [],[]
# for iBlock in range(len(blocks_behav)):
#     ic, ib = inds_ca[iBlock], inds_behav[iBlock]
#     n = np.minimum(len(ic), len(ib))
#     inds_keep_ca.extend(ic[:n])
#     inds_keep_behav.extend(ib[:n])
# stimLoc_ca, stimLoc_behav = np.array(stimLoc_ca)[inds_keep_ca], np.array(stimLoc_behav)[inds_keep_behav]
# inds_session = get_session_inds(stimLoc_ca)[inds_keep_ca]
# getLastChar = lambda x: np.array([_[-1] for _ in x])
# stimLoc_ca, stimLoc_behav = map(getLastChar, (stimLoc_ca, stimLoc_behav))
# ca_trl = np.take(ca_trl, inds_keep_ca, axis = 0)
# ta_trl_clean = np.take(ta_trl_clean, inds_keep_behav, axis = 0)

# dic = dict(tailAngles = list(ta_trl_clean), ca = list(ca_trl), \
#            trlNum = np.arange(ta_trl_clean.shape[0]), stimLoc = stimLoc_ca,\
#            sessionIdx = list(inds_session), hFilePath = hFilePath)

# df = pd.DataFrame(dic, columns= dic.keys())

# path_df = os.path.join(path_now, 'dataFrame.pickle')
# %time df.to_pickle(path_df)

# Continue from the saved dataframe

In [None]:
#%% Read saved pickle file for continuing analysis from here
# Last .pickle file in path_now whose name matches 'dataFrame'
file_df = ft.findAndSortFilesInDir(path_now, ext = 'pickle', search_str='dataFrame')[-1]
# NOTE: unpickling executes arbitrary code; only load files from a trusted source
%time df = pd.read_pickle(os.path.join(path_now, file_df))
# One array per dataframe row, stacked into a single array
ta_trl = np.array([np.array(_) for _ in df.tailAngles])
# All trials joined end to end along axis 1 of the per-trial arrays
ta = np.concatenate(ta_trl,axis = 1)

## Use the group data trained GMM model to predict posture labels

In [None]:
#%% Load the trained model

# NOTE(review): absolute, machine-specific network path
dir_group = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging\Group'
fName = 'gmm_fitter_svd-3_gmm-22.pkl' 
# NOTE: joblib.load unpickles the file, which can execute arbitrary code; only
# load model files from a trusted source.
fitter = joblib.load(os.path.join(dir_group, fName))

### *This is a model with 22 Gaussian components because that number appears to be the global minimum of AIC/BIC as a function of the number of Gaussian components*

In [None]:
#%% Fit GMM model with specified number of components, predict labels for data, 
## and plot in low dimensions with PCA using labels as colors

subSample = 1
alpha = 0.5
cmap = plt.cm.gist_ncar

%matplotlib inline

%time labels, arr_feat =  my_gmm.predict_on_tailAngles_svd(ta, fitter, subSample=subSample)

%time x_pca = PCA(n_components = 3, random_state = 143).fit_transform(arr_feat)

fh,ax = plt.subplots(2,2,figsize = (15,10))
ax = ax.flatten()


clrs = cmap(spt.standardize(labels))

ax[0].scatter(x_pca[:,0], x_pca[:,1], s = 10, c = clrs, alpha = alpha)
ax[0].set_xlabel('pca 1')
ax[0].set_ylabel('pca 2')

ax[1].scatter(x_pca[:,0], x_pca[:,2], s = 10, c = clrs, alpha = alpha)
ax[1].set_xlabel('pca 1')
ax[1].set_ylabel('pca 3')

ax[2].scatter(x_pca[:,1], x_pca[:,2], s = 10, c = clrs, alpha = alpha)
ax[2].set_xlabel('pca 2')
ax[2].set_ylabel('pca 3')
fh.tight_layout()

x = np.unique(labels)
y = np.ones_like(x)
plt.figure(figsize = (20,5))
plt.scatter(x,y, c = cmap(spt.standardize(x)),s = 4000, marker = 's')
plt.yticks([])
plt.xticks(x, fontsize = 20);
plt.title('Norm-ordered colors', fontsize = 20);


In [None]:
#%% Check out a few trials with predictions from the GMM

iTrl = 6  # (Struggles = {9, 11})
# iTrl = np.random.choice(np.arange(df.shape[0]),size = 1)[0]
yShift = 1.1  # Multiplier applied to the util.yOffMat offsets between segment traces
loop = False  # True: plot every trial in df.trlNum; False: plot only iTrl
xl = (-0.1, 1)  # x-limits in seconds, or the string 'auto'
# xl = 'auto'
figDir = os.path.join(path_now, 'figs/behav_trls_colored_by_gmm_label')
onOffThr = 0  # Samples whose crest envelope is <= this value are drawn in white
figExts = ('png','pdf')
cmap = plt.cm.gist_ncar
dt_behav = 1/500  # Sampling interval of the behavior traces in seconds


# makedirs (unlike mkdir) also creates the intermediate 'figs' directory when missing
os.makedirs(figDir, exist_ok=True)

from IPython import display
if loop:
    trls = np.unique(df.trlNum)
else:
    trls = [iTrl]

# Explicit string test instead of comparing a tuple with 'auto'
xl_is_auto = isinstance(xl, str) and xl == 'auto'

# Crest (max) envelope of the last row of the serialized tail angles
maxEnv_full = spt.emd.envelopesAndImf(ta[-1])['env']['max']
for iTrl in trls:
    df_now = df.loc[df.trlNum == iTrl]
#     arr = np.array(df_features.loc[df_features.trlNum == iTrl].drop(columns = 'trlNum'))
#     ta_now = np.array(df_now.iloc[0]['tailAngles'])
    ta_now = ta_trl[iTrl]
    trlLen = ta_now.shape[-1]
    # Samples of this trial within the serialized time axis
    inds = np.arange(iTrl*trlLen, (iTrl+1)*trlLen)
    maxEnv = maxEnv_full[inds]
    # Four evenly spaced rows of ta_now; their successive differences give three traces
    posInds = np.linspace(0, len(ta_now)-1,4).astype(int)
    z = np.diff(ta_now[posInds],axis = 0)
    ys = util.yOffMat(z)*yShift
    z = z - ys
    y = ta_now[-1]
    x = (np.arange(len(y))*dt_behav) - 1
    lbls_now, x_now = my_gmm.predict_on_tailAngles_svd(ta_now, fitter)
    # Scaled by the label range of the full data set (`labels` from the prediction cell)
    lbls_norm = lbls_now/(labels.max()-labels.min())
    clrs = cmap(lbls_norm)
    zerInds = np.where(maxEnv <=onOffThr)[0]
    clrs[zerInds,:] = 1 # Make this invisible
    # clrs[:,-1] = 0.8 # Alpha value of points
    fh = plt.figure(figsize = (20,8))
    # plt.scatter(x,y,c = clrs,s= 15)
    for iLine, z_ in enumerate(z):
        plt.plot(x,z_, c = plt.cm.tab10(iLine), alpha = 0.2, label = f'loc = {iLine}')
        plt.scatter(x,z_, c = clrs, s = 15)
    plt.legend(loc = 'upper right', fontsize= 15)
    plt.xlabel('Time (s)', fontsize = 20)
    plt.ylabel('Total tail bending for segment ($^o$)', fontsize = 20)
    maxInd = []
    if xl_is_auto:
        # End the x-axis at the last sample above the envelope threshold
        nonZerInds = np.setdiff1d(np.arange(len(y)),zerInds)
        if len(nonZerInds)>0:
            maxInd = np.max(nonZerInds)
        else:
            maxInd = len(x)-31
        plt.xlim(-0.1,x[maxInd])
    else:
        plt.xlim(xl)
    # y-limits span at least (-100, 80) and grow to contain the data
    yl = np.min((-100, z.min())), np.max((80, z.max()))
    plt.ylim(yl)
    plt.yticks([0,-100])
    plt.title(f'Total tail curvature with points colored by cluster label (trl = {iTrl})', fontsize = 20)
    plt.show()
    for fe in figExts:
        fn = f'Fig-{util.timestamp()}_Total tail curvature timeseries with colored clustered points_trl-{iTrl}.{fe}'
#         fh.savefig(os.path.join(figDir, fn), format = f'{fe}', dpi = 'figure')
    display.clear_output(wait = True)
    time.sleep(0.05)


In [None]:
#%% Load the trained model

# NOTE(review): absolute, machine-specific network path
dir_group = r'Y:\Avinash\Projects\RS recruitment\GCaMP imaging\Group'
file_model = 'gmm_svd-3_env_pca-9_gmm-20_20200129-18.pkl' 
# NOTE: joblib.load unpickles the file, which can execute arbitrary code; only
# load model files from a trusted source.
gmm_model = joblib.load(os.path.join(dir_group, file_model))

In [None]:
#%% Fit GMM model with specified number of components, predict labels for data, 
## and plot in low dimensions with PCA using labels as colors

subSample = 1
alpha = 0.5
cmap = plt.cm.tab20

%matplotlib inline

%time labels, features =  gmm_model.predict(ta[:,::subSample])

%time x_pca = PCA(n_components = 3, random_state = 143).fit_transform(features)

fh,ax = plt.subplots(2,2,figsize = (15,10))
ax = ax.flatten()


clrs = cmap(spt.standardize(labels))

ax[0].scatter(x_pca[:,0], x_pca[:,1], s = 10, c = clrs, alpha = alpha)
ax[0].set_xlabel('pca 1')
ax[0].set_ylabel('pca 2')

ax[1].scatter(x_pca[:,0], x_pca[:,2], s = 10, c = clrs, alpha = alpha)
ax[1].set_xlabel('pca 1')
ax[1].set_ylabel('pca 3')

ax[2].scatter(x_pca[:,1], x_pca[:,2], s = 10, c = clrs, alpha = alpha)
ax[2].set_xlabel('pca 2')
ax[2].set_ylabel('pca 3')
fh.tight_layout()

x = np.unique(labels)
y = np.ones_like(x)
plt.figure(figsize = (20,5))
plt.scatter(x,y, c = cmap(spt.standardize(x)),s = 4000, marker = 's')
plt.yticks([])
plt.xticks(x, fontsize = 20);
plt.title('Norm-ordered colors', fontsize = 20);


In [None]:
#%% Check out a few trials with predictions from the GMM

# iTrl = 10  # (Struggles = {9, 11})
# Random trial, drawn from NumPy's global random generator
iTrl = np.random.choice(np.arange(df.shape[0]),size = 1)[0]
yShift = 1.1  # Multiplier applied to the util.yOffMat offsets between segment traces
loop = False  # True: plot every trial in df.trlNum; False: plot only iTrl
xl = (-0.1, 4.5)  # x-limits in seconds, or the string 'auto'
# xl = 'auto'
figDir = os.path.join(path_now, 'figs/behav_trls_colored_by_gmm_label')
onOffThr = 5  # Samples whose crest envelope is <= this value are drawn in white
figExts = ('png','pdf')
cmap = plt.cm.tab20
dt_behav = 1/500  # Sampling interval of the behavior traces in seconds
annotate_markers = True


# makedirs (unlike mkdir) also creates the intermediate 'figs' directory when missing
os.makedirs(figDir, exist_ok=True)

from IPython import display
if loop:
    trls = np.unique(df.trlNum)
else:
    trls = [iTrl]

# Explicit string test instead of comparing a tuple with 'auto'
xl_is_auto = isinstance(xl, str) and xl == 'auto'

# Crest (max) envelope of the last row of the serialized tail angles
maxEnv_full = spt.emd.envelopesAndImf(ta[-1])['env']['max']
for iTrl in trls:
    df_now = df.loc[df.trlNum == iTrl]
    ta_now = ta_trl[iTrl]
    trlLen = ta_now.shape[-1]
    # Samples of this trial within the serialized time axis
    inds = np.arange(iTrl*trlLen, (iTrl+1)*trlLen)
    maxEnv = maxEnv_full[inds]
    # Four evenly spaced rows of ta_now; their successive differences give three traces
    posInds = np.linspace(0, len(ta_now)-1,4).astype(int)
    z = np.diff(ta_now[posInds],axis = 0)
    ys = util.yOffMat(z)*yShift
    z = z - ys
    y = ta_now[-1]
    x = (np.arange(len(y))*dt_behav) - 1
    lbls_now, _ = gmm_model.predict(ta_now)
    # Scaled by the label range of the full data set (`labels` from the prediction cell)
    lbls_norm = lbls_now/(labels.max()-labels.min())
    clrs = cmap(lbls_norm)
    zerInds = np.where(maxEnv <=onOffThr)[0]
    clrs[zerInds,:] = 1 # Make this invisible
    # clrs[:,-1] = 0.8 # Alpha value of points
    fh = plt.figure(figsize = (20,8))
    for iLine, z_ in enumerate(z):
#         plt.plot(x,z_, c = cmap(iLine), alpha = 0.2, label = f'loc = {iLine}')
        plt.plot(x,z_, c = 'k', alpha = 0.2, label = f'loc = {iLine}')
        plt.scatter(x,z_, c = clrs, s = 15)
    plt.legend(loc = 'upper right', fontsize= 15)
    plt.xlabel('Time (s)', fontsize = 20)
    plt.ylabel('Total tail bending for segment ($^o$)', fontsize = 20)
    maxInd = []
    if xl_is_auto:
        # End the x-axis at the last sample above the envelope threshold
        nonZerInds = np.setdiff1d(np.arange(len(y)),zerInds)
        if len(nonZerInds)>0:
            maxInd = np.max(nonZerInds)
        else:
            maxInd = len(x)-31
        plt.xlim(-0.1,x[maxInd])
    else:
        plt.xlim(xl)
    # y-limits span at least (-100, 80) and grow to contain the data
    yl = np.min((-100, z.min())), np.max((80, z.max()))
    plt.ylim(yl)
    plt.yticks([0,-100])
    plt.title(f'Total tail curvature with points colored by cluster label (trl = {iTrl})', fontsize = 20)
    plt.show()
    for fe in figExts:
        fn = f'Fig-{util.timestamp()}_Total tail curvature timeseries with colored clustered points_trl-{iTrl}.{fe}'
#         fh.savefig(os.path.join(figDir, fn), format = f'{fe}', dpi = 'figure')
    display.clear_output(wait = True)
    time.sleep(0.05)


In [None]:
#%% Check ouf a few trials with predictions from the GMM with annotated markers

# iTrl = 10  # (Struggles = {9, 11}
iTrl = np.random.choice(np.arange(df.shape[0]),size = 1)[0]
yShift = 1.1
loop = False
xl = (-0.1, 4.5)
# xl = 'auto'
figDir = os.path.join(path_now, 'figs/behav_trls_colored_by_gmm_label')
onOffThr = 0
figExts = ('png','pdf')
cmap = plt.cm.tab20
dt_behav = 1/500

%matplotlib qt
# %matplotlib notebook


if not os.path.exists(figDir):
    os.mkdir(figDir)

trls = [iTrl]    
maxEnv_full = spt.emd.envelopesAndImf(ta[-1])['env']['max']    
for iTrl in trls:
    df_now = df.loc[df.trlNum == iTrl]
    ta_now = ta_trl[iTrl]
    trlLen = ta_now.shape[-1]
    inds = np.arange(iTrl*trlLen, (iTrl+1)*trlLen)
    maxEnv = maxEnv_full[inds]
    posInds = np.linspace(0, len(ta_now)-1,4).astype(int)
    y = ta_now[-1]
    x = (np.arange(len(y))*dt_behav) - 1   
    lbls_now, _ = gmm_model.predict(ta_now)
    lbls_norm = lbls_now/(labels.max()-labels.min())
    clrs = cmap(lbls_norm)
    zerInds = np.where(maxEnv <=onOffThr)[0]
    clrs[zerInds,:] = 1 # Make this invisible
    # clrs[:,-1] = 0.8 # Alpha value of points
    fh = plt.figure(figsize = (20,8))
    plt.plot(x,y, c = 'k', alpha = 0.2)
    for lbl_ in np.unique(lbls_now):        
        inds = np.where(lbls_now == lbl_)[0]
        inds = inds[::2]
        if len(inds)>0:
            plt.scatter(x[inds],y[inds], c = clrs[inds], s = 150, marker = f"${str(lbl_)}$")    
    plt.xlabel('Time (s)', fontsize = 20)
    plt.ylabel('Total tail bending for segment ($^o$)', fontsize = 20)
    maxInd = []
    if np.any(xl == 'auto'):
        nonZerInds = np.setdiff1d(np.arange(len(y)),zerInds)        
        if len(nonZerInds)>0:
            maxInd = np.max(nonZerInds)
        else:
            maxInd = len(x)-31
        plt.xlim(-0.1,x[maxInd])
    else:
        plt.xlim(xl)
    yl = np.min((-100, z.min())), np.max((80, z.max()))
    plt.ylim(yl)
    plt.yticks([0,-100])
    plt.title(f'Total tail curvature with points colored by cluster label (trl = {iTrl})', fontsize = 20)
    for fe in figExts:
        fn = f'Fig-{util.timestamp()}_Total tail curvature timeseries with colored clustered points_trl-{iTrl}.{fe}'
#         fh.savefig(os.path.join(figDir, fn), format = f'{fe}', dpi = 'figure')
    



In [None]:

# def labelsToIrs(labels, maxEnv, nClust:int = 5, thr = 10, subSample:int = 1):
#     """Impulse resonse array from labels, (nLabels, nTimePoints)"""
#     lbls_unique = np.unique(labels)
#     zerInds = np.where(maxEnv<thr)[0]
#     ir = np.zeros((nClust, len(labels)))
#     for lbl in np.sort(lbls_unique):
#         inds = np.where(labels == lbl)[0]
#         ir[lbl,inds] = 1
#     subVec = np.zeros_like(labels)
#     subVec[::subSample] = 1
#     ir[:,zerInds] = 0     
#     ir = ir*subVec[np.newaxis,:]    
#     return ir*subVec[np.newaxis,:]

def superSample(t,y,tt):
    """
    Resample the signal (t, y) at the time points tt by first-order spline
    ('slinear') interpolation.

    The signal is extended with a zero-valued sample at tt[0] and another at
    tt[-1] before interpolating, so tt may extend beyond the range of t.

    Parameters
    ----------
    t: array, (nTimePoints,) - sample times of y
    y: array, (nTimePoints,) - signal values
    tt: array - time points at which to evaluate the interpolant
    Returns
    -------
    Array of interpolated values, one per entry of tt.
    """
    import numpy as np
    from scipy.interpolate import interp1d
    first_t, last_t = tt[0].ravel(), tt[-1].ravel()
    zero = np.zeros(1, dtype=int)
    t_ext = np.concatenate((first_t, t, last_t))
    y_ext = np.concatenate((zero, y, zero))
    interpolant = interp1d(t_ext, y_ext, kind = 'slinear')
    return interpolant(tt)

def padIr(ir_trl, pad_pre, pad_post):
    """
    Zero-pad every 2D entry of `ir_trl` along its last axis and flatten it.

    Used to lengthen the impulse response timeseries obtained from
    predictions on the behavioral feature matrix so that their time
    length matches that of the ca responses.

    Parameters
    ----------
    ir_trl: iterable of 2D arrays
    pad_pre, pad_post: int
        Number of zeros added before/after along the last axis.
    Returns
    -------
    Array with one flattened, padded row per entry of `ir_trl`.
    """
    pad_width = ((0, 0), (pad_pre, pad_post))
    return np.array([np.pad(chunk, pad_width).flatten() for chunk in ir_trl])

def serializeHyperstack(vol):
    """
    Flatten every volume of a hyperstack into a row of pixels (for regression, etc.).
    Within a row the slice index varies fastest, then the column, then the row.
    Parameters
    ----------
    vol: array, (nTimePoints, nSlices, nRows, nCols)
    Returns
    -------
    vol_ser: array, (nTimePoints, nPixels)
    """
    # Move the slice axis last, then merge (rows, cols, slices) into one pixel axis
    rows_cols_slices = np.transpose(vol, (0, 2, 3, 1))
    return rows_cols_slices.reshape(rows_cols_slices.shape[0], -1)

def deserializeToHyperstack(arr, volDims):
    """
    Reshape rows of serialized pixels back into volumes.
    Each row is read with the slice index varying fastest, then the column,
    then the row.
    Parameters
    ----------
    arr: array, (nTimePoints, nPixels)
    volDims: sequence, (nSlices, nRows, nCols)
    Returns
    -------
    vol: array, (nTimePoints, nSlices, nRows, nCols)
    """
    dims = np.array(volDims)
    # Pixels are laid out as (rows, cols, slices) within each row of `arr`
    stacked = arr.reshape(arr.shape[0], dims[1], dims[2], dims[0])
    # Bring the slice axis forward, right after the time axis
    return np.moveaxis(stacked, -1, 1)

def pxlsToVol(pxls, volDims):
    """
    Reshape one vector of serialized pixels back into a single volume.
    The vector is read with the slice index varying fastest, then the column,
    then the row.
    Parameters
    ----------
    pxls: array, (nPixels,)
    volDims: sequence, (nSlices, nRows, nCols)
    Returns
    -------
    vol: array, (nSlices, nRows, nCols)
    """
    dims = np.array(volDims)
    # Pixels are laid out as (rows, cols, slices)
    stacked = pxls.reshape(dims[1], dims[2], dims[0])
    # Bring the slice axis to the front
    return np.moveaxis(stacked, -1, 0)

def superSample_arr(t, arr, tt, n_jobs = 32):
    """
    Apply superSample to every row of `arr` in parallel (joblib).

    Parameters
    ----------
    t: array, (nTimePoints,) - sample times of the signals
    arr: array, (nSignals, nTimePoints)
    tt: array - time points at which to evaluate the interpolated signals
    n_jobs: int or None
        Upper bound on the number of parallel jobs. The value actually used is
        capped at the number of CPUs; None means "use all CPUs".
    Returns
    -------
    arr_sup: array with one resampled signal per row of `arr`
    """
    from joblib import Parallel, delayed
    n_cpus = os.cpu_count() or 1
    # Respect the caller's n_jobs (it used to be overwritten unconditionally)
    # while still never asking for more workers than there are CPUs.
    n_jobs = n_cpus if n_jobs is None else min(n_jobs, n_cpus)
#     from dask import delayed, compute
#     arr_sup = compute(*[delayed(superSample)(t,y,tt) for y in arr], scheduler = 'processes')
    arr_sup = Parallel(n_jobs=n_jobs,verbose=1)(delayed(superSample)(t, y, tt) for y in arr)
    return np.array(arr_sup)

def betasToVol(betas, volDims):
    """
    Reshape per-pixel regression values into one volume per regressor.

    Parameters
    ----------
    betas: array, (nPixels,) or (nPixels, nRegressors)
        A 1D input is treated as a single regressor.
    volDims: sequence
        Shape of one volume; its product must equal nPixels.

    Returns
    -------
    B: array, (nRegressors, *volDims) with singleton axes squeezed out
    """
    if np.ndim(betas) < 2:
        betas = betas[:, None]
    nRegressors = betas.shape[1]
    vols = np.reshape(betas.T, (nRegressors, *volDims))
    return np.squeeze(vols)

In [None]:
#%% Select labels based on amplitude threshold because many labels are fitting background activity.
cmap= plt.cm.tab20
%matplotlib inline

labels_unique = np.unique(labels)
ampMeans = np.zeros((len(labels_unique),))
for lbl in labels_unique:
    inds = np.where(labels == lbl)[0]
    ampMeans[lbl] = maxEnv_full[inds].mean()
sortInds = np.argsort(ampMeans)
plt.figure(figsize = (14,8))
clrs = cmap(spt.standardize(labels_unique))
x = np.arange(len(ampMeans))
for i in x:
    plt.scatter(i, ampMeans[sortInds][i],c =clrs[i].reshape(1,-1),  marker = f"${str(sortInds[i])}$", s= 200)
plt.xticks(x)
plt.xlabel('Label #')
plt.ylabel('Mean crest envelope amplitude for label')
plt.grid()
    

In [None]:
%%time
#%% Predict labels on full time series, match lengths of behavior and ca trials, and make full set of impulse 
### trains and other regressors
thr_labelAmp = 10 # Set amplitude threshold based on the above graph
tPeriStim_behav = (-1,6) # Pre- and pos-stim periods in seconds for behavior trials
tPeriStim_ca = (-1,10) # Pre- and post-stim periods in seconds for ca trials
Fs_behav = 500

labels_sel = np.where(ampMeans>=thr_labelAmp)[0]
ir, names_ir = hf.impulse_trains_from_labels(labels, ta, labels_sel= labels_sel)
# yOff = util.yOffMat(ir)
# plt.figure(figsize = (20,20))
# plt.plot((ir-yOff).T);
# plt.yticks(-yOff, names_ir);

pad_post = int((tPeriStim_ca[-1]-tPeriStim_behav[-1])*Fs_behav)
n_pre_behav = int(np.abs(tPeriStim_behav[0])*Fs_behav)
stimLoc = np.array(df.stimLoc)
stimLoc_unique = np.unique(stimLoc)
sessionIdx  = np.array(df.sessionIdx)
sessionIdx_unique = np.unique(sessionIdx)
nSessions = len(sessionIdx_unique)

nTrls = df.shape[0]
ir_trl = np.transpose(ir.reshape(ir.shape[0], nTrls,-1),(1,0,2))
foo = []
count = 1
for sl, trl in zip(stimLoc, ir_trl):
    ht = np.zeros((len(stimLoc_unique),trl.shape[-1]))
    ind = np.where(stimLoc_unique == sl)[0]
    ht[ind,n_pre_behav-1]=1 
    trl_ht = np.r_[trl,ht]
    blah = np.pad(trl_ht,((0,0),(0,pad_post)), mode = 'constant')
    trl_prog = np.ones((1,blah.shape[-1]))*(count/ir_trl.shape[0])
    session_now = sessionIdx[count-1]
    session_idx = np.zeros((nSessions,blah.shape[-1]))*(count/ir_trl.shape[0])
    session_idx[session_now-1,:] = 1
    foo.append(np.r_[blah,trl_prog, session_idx])
    count += 1
ir_trl = np.array(foo)
ir_ser = np.concatenate(ir_trl,axis = 1)


In [None]:
#%% Display impulse trains & other regressors
def getStimName(s):
    """Return 'Head' for stimulus code 'h' and 'Tail' for any other code."""
    return 'Head' if s == 'h' else 'Tail'

# Row names: motor labels, stimulus locations, trial progress, then sessions
# ytl = [f'Mot-{i}' for i in np.arange(ir_ser.shape[0]-len(stimLoc_unique)-1)]
# ytl = [f'Mot-{i}' for i in np.arange(len(names_ir))]
ytl = list(names_ir)
ytl += [getStimName(s) for s in stimLoc_unique]
ytl += ['Trial progress']
ytl += [f'Session-{idx}' for idx in sessionIdx_unique]
regNames = ytl

t_full = np.arange(ir_ser.shape[-1])*(1/Fs_behav)
yOff = util.yOffMat(ir_ser)
# ytl = np.arange(ir_ser.shape[0])
yt = -np.arange(ir_ser.shape[0])

plt.figure(figsize = (16,8))
plt.plot(t_full, (ir_ser-yOff).T);
plt.yticks(yt, regNames)
plt.xlabel('Time (s)')
plt.title('Impulse responses & other regressors');

In [None]:
#%% Read saved dataframe if continuing from here
# Take the last entry of the sorted list of 'dataFrame' pickle files found in path_now
file_df = ft.findAndSortFilesInDir(path_now, ext = 'pickle', search_str='dataFrame')[-1]
# NOTE: unpickling can execute arbitrary code; only load pickles from a trusted location
%time df = pd.read_pickle(os.path.join(path_now,file_df))

In [None]:
%%time
#%% CIRF in slightly subSampled behavAndScan time, followed by convolution to generate regressors
# Builds a calcium impulse response kernel, convolves the impulse trains with it,
# resamples the result to the number of ca frames and stores it in the HDF file.
tLen = 6 # Length of kernel
tau_rise = 0.2 # Rise constant
tau_decay = 1 # Decay constant
dt_behav = 1/500 # NOTE(review): presumably 1/Fs_behav (Fs_behav = 500 in the impulse-train cell) - keep in sync
# dt_behav_new = 1/50

### CIRF
t_cirf = np.arange(0,tLen,dt_behav)
cirf = spt.generateEPSP(t_cirf,tau_rise, tau_decay,1,0)

# Every row except the last is convolved with the CIRF (cropped back to the
# original length) and standardized; the last row is appended unmodified.
# NOTE(review): ir_ser rows are ordered motor, stimulus, 'Trial progress', then
# one row per session, so the unmodified last row is a session indicator while
# 'Trial progress' and the remaining session rows ARE convolved. Confirm this
# is intended (it looks like it predates the session rows).
regressors= []
for y in ir_ser[:-1]:
    regressors.append(spt.standardize(np.convolve(y, cirf, mode = 'full')[:len(y)]))
regressors.append(ir_ser[-1])
regressors = np.array(regressors)
# t = np.arange(ir_ser.shape[-1])*dt_behav
# t_sub = np.arange(t[0], t[-1],dt_behav_new)

ca_trl = np.array([np.array(_) for _ in np.array(df['ca'])])
# Both time axes are normalized to [0, 1]; only the number of samples differs
t_behav = np.linspace(0,1,regressors.shape[1])
t_ca = np.linspace(0,1,ca_trl.shape[0]*ca_trl.shape[1])

# Resample every regressor from behavior samples to ca frames
regressors = superSample_arr(t_behav, regressors, t_ca)

# NOTE(review): ca_ser is not used in the rest of this cell
%time ca_ser = serializeHyperstack(np.concatenate(ca_trl,axis = 0))
# t_behav = np.arange(regressors.shape[-1])*dt_behav_new
# t_ca = np.linspace(t_behav[0], t_behav[-1], ca_ser.shape[1])

# Fall back to the last 'procData' HDF file when no path is defined in this session
if 'hFilePath' not in locals():
    hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
    hFilePath = os.path.join(path_now, hFileName)
    
with h5py.File(hFilePath, mode = 'r+') as hFile:
    # Replace any previously stored 'regression' group
    if 'regression' in hFile:
        del hFile['regression']
    grp = hFile.create_group('regression')   
    # Stored transposed, i.e. (nTimePoints, nRegressors)
    grp.create_dataset('regressors', data = regressors.T)
    grp.create_dataset('regressor_names', data = util.to_ascii(regNames))
    grp.create_dataset('impulse_trains', data = ir_ser)


In [None]:
#%% Plot all regressors
# Each regressor is drawn on its own row, offset vertically by yOff
yOff = util.yOffMat(regressors)
plt.figure(figsize = (20,10))
plt.plot(t_ca,(regressors-yOff).T)
plt.xlim(t_ca.min(), t_ca.max())
plt.yticks(-yOff, ytl)
# t_ca is created with np.linspace(0, 1, nFrames): a normalized axis, not seconds
plt.xlabel('Time (normalized, 0-1)')
plt.title('Regressors');

In [None]:
# inds = [10,11]
# inds_diff = np.setdiff1d(np.arange(regressors.shape[0]), inds)
# V = regressors[inds].T
# W = regressors[inds_diff].T
# W_orth = spt.linalg.orthogonalizeOnSpace(V,W).T
# X_stimHT_orth = np.c_[V,W_orth]
# print(X_stimHT_orth.shape)

# yOff = util.yOffMat(X.T).T
# plt.figure(figsize = (20,8))
# plt.plot(X-yOff)

# yOff = util.yOffMat(X_stimHT_orth.T).T
# plt.figure(figsize = (20,8))
# plt.plot(X_stimHT_orth-yOff);

In [None]:
#%% Read saved dataframe if continuing from here
# Take the last entry of the sorted list of 'dataFrame' pickle files found in path_now
file_df = ft.findAndSortFilesInDir(path_now, ext = 'pickle', search_str='dataFrame')[-1]
# NOTE: unpickling can execute arbitrary code; only load pickles from a trusted location
%time df = pd.read_pickle(os.path.join(path_now,file_df))
# Stack the per-trial entries of the 'ca' column into one array (trial axis first)
%time ca_trl = np.array([np.array(_) for _ in np.array(df['ca'])])

In [None]:
%%time
#%% Denoise and filter images before regression
filtSize = 0.75

images_reg_ser = ca_trl.reshape(-1, *ca_trl.shape[2:])
images_reg_ipca_flt = []
for iSlc, slc in enumerate(np.swapaxes(images_reg_ser,0,1)):
    print(f'{iSlc + 1}/{images_reg_ser.shape[1]}')
    slc_den = volt.denoise_ipca(slc)
    slc_flt = volt.img.gaussFilt(slc_den, sigma = filtSize)
    images_reg_ipca_flt.append(slc_flt)

images_reg_ipca_flt = np.swapaxes(np.array(images_reg_ipca_flt),0,1)

if 'hFilePath' not in locals():
    hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
    hFilePath = os.path.join(path_now, hFileName)
    
with h5py.File(hFilePath, mode = 'r+') as hFile:
    keyName = f'images_reg_ipca_flt_sigma-{int(filtSize*100)}'
    if keyName in hFile:
        del hFile[keyName]
    %time hFile.create_dataset(keyName, data = images_reg_ipca_flt)

In [None]:
%%time
#%% Read relevant variables for regression
# Uses the last 'procData' HDF file in the sorted listing of path_now
hFileName = ft.findAndSortFilesInDir(path_now, ext = 'h5', search_str='procData')[-1]
hFilePath = os.path.join(path_now, hFileName)
with h5py.File(hFilePath, mode = 'r') as hFile:
    print(hFile.keys())
    # NOTE(review): filtSize must already be defined (it is set in the denoise/filter
    # cell) for this dataset name to resolve when continuing from here
    images = np.array(hFile[f'images_reg_ipca_flt_sigma-{int(filtSize*100)}'])
    X_reg = np.array(hFile['regression/regressors'])
    regNames = util.to_utf(np.array(hFile['regression/regressor_names']))
    # Drop the in-memory copy of the filtered images, if one exists, now that it has been re-read as `images`
    if 'images_reg_ipca_flt' in locals():
        del images_reg_ipca_flt

## Regression

In [None]:
#%% Regress

ca_ser = images.reshape(images.shape[0],-1)
%time regObj = regress(X_reg,ca_ser, n_jobs=-1, fit_intercept=True)

betas_vol = betasToVol(regObj.coef_, images.shape[-3:])
intercept_vol = betasToVol(regObj.intercept_, images.shape[-3:])
t_vol = betasToVol(regObj.T_, images.shape[-3:])
r_vol = betasToVol(regObj.Rsq_adj_,images.shape[-3:])

In [None]:
# Show the coefficient map of one regressor as a maximum projection along
# axis 0 of its volume (first entry along that axis skipped), with intensities
# clipped to the [q_min, q_max] percentile range
iReg = 23
q_max = 99
q_min = 10
print(regNames[iReg])
beta_mip = betas_vol[iReg][1:].max(axis= 0)
beta_mip_sat = spt.stats.saturateByPerc(beta_mip, perc_up = q_max, perc_low = q_min)
plt.figure(figsize = (20,10))
plt.imshow(beta_mip_sat)
plt.colorbar()

In [None]:
%%time
#%% Save regression images
figDir = os.path.join(path_now, f'figs/regression_ipca_flt_sigma-{int(filtSize*100)}')
t_mult = 1000 # Multiply t-values by this value before converting to integer type because of low bit-depth otherwise

if not os.path.exists(figDir):
    os.mkdir(figDir)

### First save coefficients
foo = betas_vol.astype(int)
dir_now = os.path.join(figDir, 'betas')
if not os.path.exists(dir_now):
    os.mkdir(dir_now)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor-{regNames[iReg]}_coef.tif'),vol[1:])
tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor_intercept_coef.tif'),intercept_vol)
    
foo = ((t_vol*t_mult).astype(int))[1:]
dir_now = os.path.join(figDir, 'tValues')
if not os.path.exists(dir_now):
    os.mkdir(dir_now)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor-{regNames[iReg]}_tVals.tif'),vol[1:])
tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor_intercept_T.tif'),foo[0])
  

## Run regression with normalize parameter set to true

In [None]:
#%% Regress

ca_ser = images.reshape(images.shape[0],-1)
%time regObj = regress(X_reg,ca_ser, n_jobs=-1, fit_intercept=True, normalize = True) # NB: If fit_intercept= False, then normalize is ignored

betas_vol = betasToVol(regObj.coef_, images.shape[-3:])
intercept_vol = betasToVol(regObj.intercept_, images.shape[-3:])
t_vol = betasToVol(regObj.T_, images.shape[-3:])
r_vol = betasToVol(regObj.Rsq_adj_,images.shape[-3:])

In [None]:
# Show the coefficient map of one regressor as a maximum projection along
# axis 0 of its volume, with the upper intensities clipped at the q_max percentile
iReg = 0
q_max = 95
print(regNames[iReg])
beta_mip = betas_vol[iReg].max(axis= 0)
beta_mip_sat = spt.stats.saturateByPerc(beta_mip, perc_up = q_max)
plt.figure(figsize = (20,10))
plt.imshow(beta_mip_sat)
plt.colorbar();

In [None]:
# Scratch check of a volume's shape.
# NOTE(review): `vol` is not defined in this cell; presumably it is the loop
# variable left over from a save-images loop - this fails on a fresh kernel.
vol.shape

In [None]:
%%time
#%% Save regression images
figDir = os.path.join(path_now, f'figs/regression_ipca_flt_sigma-{int(filtSize*100)}_normalizedRegressors')
if not os.path.exists(figDir):
    os.mkdir(figDir)

### First save coefficients
foo = np.round(betas_vol).astype(int)
dir_now = os.path.join(figDir, 'betas')
if not os.path.exists(dir_now):
    os.mkdir(dir_now)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor-{regNames[iReg]}_coef.tif'),vol[1:])
tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor_intercept_coef.tif'),
           intercept_vol.astype(int)[1:])
    
foo = (np.round(t_vol).astype(int))[1:]
dir_now = os.path.join(figDir, 'tValues')
if not os.path.exists(dir_now):
    os.mkdir(dir_now)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor-{regNames[iReg]}_tVals.tif'),vol[1:])
tff.imsave(os.path.join(dir_now,f'Fig-{util.timestamp()}_regressor_intercept_T.tif'),foo[0][1:])
  

## Some extra stuff

In [None]:
# Collapse all motor regressors into one 'head' and one 'tail' regressor,
# depending on the stimulus location of the trial they occur in.
nTrls = df.shape[0]
trlLen = X_reg.shape[0]//nTrls
# (nTimePoints, nRegressors) -> (nTrls, trlLen, nRegressors)
X_reg_trl = X_reg.reshape(nTrls, trlLen, -1)
stimLoc = np.array(df.stimLoc)
# NOTE(review): relies on motor regressor names containing 'Mot' - confirm
inds_mot = np.asarray(util.findStrInList('Mot',util.to_utf(regNames)), dtype = int)
inds_nonMot = np.setdiff1d(np.arange(X_reg.shape[1]),inds_mot)
inds_head = np.where(stimLoc== 'h')[0]
inds_tail = np.where(stimLoc == 't')[0]
# Trial masks: 1 for every sample of a head (tail) trial, 0 elsewhere
bool_head = np.zeros_like(X_reg_trl)
bool_tail = np.zeros_like(X_reg_trl)
bool_head[inds_head] = 1
bool_tail[inds_tail]= 1
# Sum over the motor columns only. Summing over every column also folded the
# non-motor regressors into these two, although they are appended separately below.
reg_head = (bool_head*X_reg_trl)[:, :, inds_mot].sum(axis = 2).flatten()
reg_tail = (bool_tail*X_reg_trl)[:, :, inds_mot].sum(axis = 2).flatten()
X_reg_ht_motor = spt.standardize(np.c_[reg_head,reg_tail, X_reg[:,inds_nonMot]],axis = 0)
# Plain concatenation keeps the names in column order; np.union1d returns the
# sorted, de-duplicated union and so scrambled the name-to-column mapping
regNames_ht_motor = np.concatenate((['head_motor', 'tail_motor'], np.asarray(regNames)[inds_nonMot]))

In [None]:
# Plot every column of X_reg_ht_motor, shifting column k down by k
col_offsets = np.arange(X_reg_ht_motor.shape[1])[np.newaxis, :]
plt.figure(figsize = (20,5))
plt.plot(X_reg_ht_motor - col_offsets);

In [None]:
%%time
#%% Save regression images
figDir = os.path.join(path_now, 'figs/regression_den_flt_ht_motor')
if not os.path.exists(figDir):
    os.mkdir(figDir)

### First save coefficients
foo = betas_vol.astype(int)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(figDir,f'Fig-{util.timestamp()}_regressor-{regNames_ht_motor[iReg]}_coef.tif'),vol)

foo = t_vol.astype(int)
for iReg, vol in enumerate(foo):
    tff.imsave(os.path.join(figDir,f'Fig-{util.timestamp()}_regressor-{regNames_ht_motor[iReg]}_tVals.tif'),vol)

In [None]:
# Save one T map per regressor as an int32 tif in the current figDir.
# Row 0 of regObj.T_.T is skipped (presumably the intercept - confirm).
foo = regObj.T_.T[1:]
for iReg, r in enumerate(foo):
    # NOTE(review): volDims is not defined in the visible cells - confirm where it is set
    vol = pxlsToVol(r, volDims)
    # By operator precedence this is zscore(vol)*65536 - 1.
    # NOTE(review): if scaling by (2**16 - 1) was intended, parentheses are missing
    vol_int = (spt.zscore(vol)*(2**16)-1).astype('int32')
#     vol_int  = np.round(vol).astype('int32')
    tff.imsave(os.path.join(figDir,f'Fig-{util.timestamp()}_regressor-{iReg}_T.tif'),vol_int)

In [None]:
# Simulate a response from known coefficients: draw one integer beta in
# [-100, 100) per column of X_reg, form X_reg @ betas and add unit-variance
# Gaussian noise.
# NOTE(review): numpy's global RNG is not seeded in the visible cells, so the
# draw differs between runs.
betas_real = np.random.choice(np.arange(-100,100), size = X_reg.shape[1], replace = True)
Y_real = X_reg@betas_real.reshape(-1,1)
Y_real = Y_real + np.random.randn(*Y_real.shape)

In [None]:
# Gather the t-values of every fitted model in reg_ols and drop the first column
# (presumably the intercept - confirm).
# NOTE(review): reg_ols is not defined in the visible cells.
tVals = np.array([r.tvalues for r in reg_ols])
tVals = tVals[:,1:]
tVals.shape

In [None]:
# Scratch: build a 2x2 covariance matrix from two Gaussian-shaped signals and
# a grid of positions on which a 2D density can be evaluated.
t1 = np.linspace(-1,1,1000)
t2 = np.linspace(-1,1,1000)
X,Y = np.meshgrid(t1,t2)
# pos[i, j] = (X[i, j], Y[i, j])
pos = np.dstack([X,Y])
x = spt.gaussFun(t1, mu = 1)
y = spt.gaussFun(t2, sigma = 1.1)
# x = np.sin(2*np.pi*t1)
# y = np.cos(2*np.pi*t1)
V = np.c_[x,y].T
# Element-wise absolute value of the covariance between the two signals
S = np.abs(np.cov(V))
# S = np.eye(2)
# Pseudo-inverse of S; N and T are not used in the rest of this cell
N = np.linalg.pinv(S)
T = np.c_[t1,t2].T
print(S)

In [None]:
# Evaluate a zero-mean 2D normal density with covariance S on the pos grid and
# draw it as filled contours.
# NOTE(review): multivariate_normal is not imported in the visible cells
# (presumably scipy.stats.multivariate_normal) - confirm it is imported.
mv = multivariate_normal(cov = S, mean = [0,0])
foo = mv.pdf(pos)
plt.figure(figsize = (10,10))
plt.contourf(foo)


In [None]:
# Display the mean of the distribution object created in the previous cell
mv.mean