In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import time
import sys
import pickle
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from dca.cov_util import form_lag_matrix, calc_cross_cov_mats_from_data
import glob
import pdb
from statsmodels.tsa import stattools
from dca_research.lqg import LQGComponentsAnalysis as LQGCA

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from pyuoi.linear_model.var import VAR

In [4]:
sys.path.append('/home/akumar/nse/neural_control')
from loaders import load_sabes
from decoders import lr_decoder

In [7]:
# Time scale analysis: 
# A few questions:
# (1) Use the subspaces derived at fast timescales to assess coupling/decoding at slower/the same timescales
# (2) Use subspaces derived at slow timeslcaes to assess coupling/decoding at same/faster time scales
# (3) Types of analyses:
# (a) Decoding vs. dim, decoding vs. t --> By "decoding" at slower timescales, we could try some kind of smoothing
# (b) Time-resolved CCA, look at the different populations of neurons

### Preprocess and save

In [8]:
# Select the indy data file that has both M1/S1 and one loco datafile that seems to give decent decoding performance
data_files = ['/mnt/Secondary/data/sabes/indy_20160426_01.mat', '/mnt/Secondary/data/sabes/loco_20170227_04.mat']

In [9]:

# Collect into bin width and filter param tuples

processing_params = [
    # Baseline - non-overlapping windows with no window
    (5, dict(filter_fn='none', filter_kwargs={})),
    (10, dict(filter_fn='none', filter_kwargs={})),
    (15, dict(filter_fn='none', filter_kwargs={})),
    (25, dict(filter_fn='none', filter_kwargs={})),
    (50, dict(filter_fn='none', filter_kwargs={})),

    # Narrow bins with smoothing timescales
    (5, dict(filter_fn='gaussian', filter_kwargs={'sigma':1})),
    (5, dict(filter_fn='gaussian', filter_kwargs={'sigma':3})),
    (5, dict(filter_fn='gaussian', filter_kwargs={'sigma':5})),

    (10, dict(filter_fn='gaussian', filter_kwargs={'sigma':1})),
    (10, dict(filter_fn='gaussian', filter_kwargs={'sigma':3})),
    (10, dict(filter_fn='gaussian', filter_kwargs={'sigma':5})),


    (25, dict(filter_fn='gaussian', filter_kwargs={'sigma':1})),
    (25, dict(filter_fn='gaussian', filter_kwargs={'sigma':3})),
    (25, dict(filter_fn='gaussian', filter_kwargs={'sigma':5})),

    # Narrow bins with boxcar filter of various widths
    (5, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':3})),
    (5, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':5})),
    (5, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':10})),


    (10, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':3})),
    (10, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':5})),
    (10, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':10})),


    (25, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':3})),
    (25, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':5})),
    (25, dict(filter_fn='window', filter_kwargs={'window_name':'boxcar', 'window_length':10})),

    # Hann windows
    (5, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':3})),
    (5, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':5})),
    (5, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':10})),


    (10, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':3})),
    (10, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':5})),
    (10, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':10})),


    (25, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':3})),
    (25, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':5})),
    (25, dict(filter_fn='window', filter_kwargs={'window_name':'hann', 'window_length':10})),
]

for didx, data_file in enumerate(data_files):
    for i, loader_param in tqdm(enumerate(processing_params)):
        if os.path.exists('/mnt/Secondary/data/sabes_tmp/didx_%d_loader_idx_%d.pkl' % (didx, i)):
            print('Skipping')
            continue
        else:
            print('Loading M1')
            datM1 = load_sabes(data_file, bin_width=loader_param[0], filter_fn=loader_param[1]['filter_fn'], filter_kwargs=loader_param[1]['filter_kwargs'], region='M1')
            print('Loading S1')
            datS1 = load_sabes(data_file, bin_width=loader_param[0], filter_fn=loader_param[1]['filter_fn'], filter_kwargs=loader_param[1]['filter_kwargs'], region='S1')

            # Save away with the loader parameters
            with open('/mnt/Secondary/data/sabes_tmp/didx_%d_loader_idx_%d.pkl' % (didx, i), 'wb') as f:
                f.write(pickle.dumps(datM1))
                f.write(pickle.dumps(datS1))
                f.write(pickle.dumps(loader_param))

0it [00:00, ?it/s]

Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Skipping
Loading M1
Loading S1
96


15it [03:36, 14.41s/it]

Loading M1
Loading S1
96


16it [07:12, 32.01s/it]

Loading M1
Loading S1
96


17it [10:47, 52.20s/it]

Loading M1
Loading S1
96


18it [12:35, 59.77s/it]

Loading M1
Loading S1
96


19it [14:23, 67.60s/it]

Loading M1
Loading S1
96


20it [16:11, 75.19s/it]

Loading M1
Loading S1
96


21it [16:54, 68.47s/it]

Loading M1
Loading S1
96


22it [17:37, 62.63s/it]

Loading M1
Loading S1
96


23it [18:22, 58.00s/it]

Loading M1
Loading S1
96


24it [21:56, 99.25s/it]

Loading M1
Loading S1
96


25it [25:30, 130.53s/it]

Loading M1
Loading S1
96


26it [29:05, 154.07s/it]

Loading M1
Loading S1
96


27it [30:52, 140.68s/it]

Loading M1
Loading S1
96


28it [32:39, 130.84s/it]

Loading M1
Loading S1
96


29it [34:27, 124.14s/it]

Loading M1
Loading S1
96


30it [35:10, 100.40s/it]

Loading M1
Loading S1
96


31it [35:54, 83.54s/it] 

Loading M1
Loading S1
96


32it [36:37, 68.68s/it]
0it [00:00, ?it/s]

Loading M1
Loading S1
96


1it [04:12, 252.28s/it]

Loading M1
Loading S1
96


2it [06:18, 178.21s/it]

Loading M1
Loading S1
96


3it [07:42, 135.23s/it]

Loading M1
Loading S1
96


4it [08:34, 102.31s/it]

Loading M1
Loading S1
96


5it [09:00, 74.64s/it] 

Loading M1
Loading S1
96


6it [13:15, 136.06s/it]

Loading M1
Loading S1
96


7it [17:30, 175.04s/it]

Loading M1
Loading S1
96


8it [21:47, 201.13s/it]

Loading M1
Loading S1
96


9it [23:55, 178.34s/it]

Loading M1
Loading S1
96


10it [26:04, 163.00s/it]

Loading M1
Loading S1
96


11it [28:13, 152.54s/it]

Loading M1
Loading S1
96


12it [29:05, 121.84s/it]

Loading M1
Loading S1
96


13it [29:57, 100.83s/it]

Loading M1
Loading S1
96


14it [30:49, 86.00s/it] 

Loading M1
Loading S1
96


15it [35:03, 136.79s/it]

Loading M1
Loading S1
96


16it [39:18, 172.28s/it]

Loading M1
Loading S1
96


17it [43:33, 197.21s/it]

Loading M1
Loading S1
96


18it [45:41, 176.41s/it]

Loading M1
Loading S1
96


19it [47:49, 161.90s/it]

Loading M1
Loading S1
96


20it [49:56, 151.38s/it]

Loading M1
Loading S1
96


21it [50:47, 121.33s/it]

Loading M1
Loading S1
96


22it [51:40, 100.65s/it]

Loading M1
Loading S1
96


23it [52:32, 86.15s/it] 

Loading M1
Loading S1
96


24it [56:46, 136.33s/it]

Loading M1
Loading S1
96


25it [1:00:59, 171.51s/it]

Loading M1
Loading S1
96


26it [1:05:14, 196.62s/it]

Loading M1
Loading S1
96


27it [1:07:21, 175.79s/it]

Loading M1
Loading S1
96


28it [1:09:29, 161.36s/it]

Loading M1
Loading S1
96


29it [1:11:38, 151.45s/it]

Loading M1
Loading S1
96


30it [1:12:29, 121.44s/it]

Loading M1
Loading S1
96


31it [1:13:21, 100.59s/it]

Loading M1
Loading S1
96


32it [1:14:13, 139.16s/it]


### Calculations

In [15]:
# Analysis #1: Effective single neuron autocorrelation times

In [5]:
fls = glob.glob('/mnt/Secondary/data/sabes_tmp/*.pkl')

In [10]:
acfl = []
for fl in tqdm(fls):
    r = {}
    r['didx'] = int(fl.split('didx_')[1].split('_')[0])
    with open(fl, 'rb') as f:
        datM1 = pickle.load(f)
        datS1 = pickle.load(f)
        lparam = pickle.load(f)
    r['bin_width'] = lparam[0]
    r['filter_fn'] = lparam[1]['filter_fn']
    r['filter_kwargs'] = lparam[1]['filter_kwargs']

    XM1 = datM1['spike_rates'].squeeze()
    m1acf = []
    for j in range(XM1.shape[1]):
        m1acf.append(stattools.acf(XM1[:, j]))
    
    m1acf = np.array(m1acf).T

    XS1 = datS1['spike_rates'].squeeze()
    s1acf = []
    for j in range(XS1.shape[1]):
        s1acf.append(stattools.acf(XS1[:, j]))
    
    s1acf = np.array(s1acf).T

    r['M1acf'] = m1acf
    r['S1acf'] = s1acf
    acfl.append(r) 

 47%|████▋     | 30/64 [02:14<02:58,  5.26s/it]

In [None]:
with open('/home/akumar/nse/neural_control/data/timescales/acf.pkl', 'wb') as f:
    f.write(pickle.dumps(acfl))

In [6]:
import time

In [23]:
# Analysis 2: Cross-correlations between M1/S1 neurons
ccfl = []
for fl in tqdm(fls):
    r = {}
    r['didx'] = int(fl.split('didx_')[1].split('_')[0])
    with open(fl, 'rb') as f:
        datM1 = pickle.load(f)
        datS1 = pickle.load(f)
        lparam = pickle.load(f)
    r['bin_width'] = lparam[0]
    r['filter_fn'] = lparam[1]['filter_fn']
    r['filter_kwargs'] = lparam[1]['filter_kwargs']

    XM1 = datM1['spike_rates'].squeeze()
    XS1 = datS1['spike_rates'].squeeze()

    x = np.hstack([XM1, XS1])
    ccm = calc_cross_cov_mats_from_data(x, 30, chunks=10)
    r['ccf'] = ccm
    ccfl.append(r) 

  0%|          | 0/64 [00:00<?, ?it/s]

In [14]:
# Analysis #2: VAR spectra

In [7]:
# varl = []
# for fl in tqdm(fls):
#     r = {}
#     r['didx'] = int(fl.split('didx_')[1].split('_')[0])
#     with open(fl, 'rb') as f:
#         datM1 = pickle.load(f)
#         datS1 = pickle.load(f)
#         lparam = pickle.load(f)
#     r['bin_width'] = lparam[0]
#     r['filter_fn'] = lparam[1]['filter_fn']
#     r['filter_kwargs'] = lparam[1]['filter_kwargs']

#     XM1 = datM1['spike_rates'].squeeze()
#     varmodel = VAR(estimator='ols', order=2)
#     varmodel.fit(XM1)
#     r['AM1'] = varmodel.coef_ 
    
#     XS1 = datS1['spike_rates'].squeeze()
#     varmodel = VAR(estimator='ols', order=2)
#     varmodel.fit(XS1)
#     r['AS1'] = varmodel.coef_

#     varl.append(r)

If you wish to scale the data, use Pipeline with a StandardScaler in a preprocessing stage. To reproduce the previous behavior:

from sklearn.pipeline import make_pipeline

model = make_pipeline(StandardScaler(with_mean=False), VAR_OLS_Wrapper())

If you wish to pass a sample_weight parameter, you need to pass it as a fit parameter to each step of the pipeline as follows:

kwargs = {s[0] + '__sample_weight': sample_weight for s in model.steps}
model.fit(X, y, **kwargs)


If you wish to scale the data, use Pipeline with a StandardScaler in a preprocessing stage. To reproduce the previous behavior:

from sklearn.pipeline import make_pipeline

model = make_pipeline(StandardScaler(with_mean=False), VAR_OLS_Wrapper())

If you wish to pass a sample_weight parameter, you need to pass it as a fit parameter to each step of the pipeline as follows:

kwargs = {s[0] + '__sample_weight': sample_weight for s in model.steps}
model.fit(X, y, **kwargs)


If you wish to scale the data, use Pipeline with

In [9]:
# Analysis 3: Fit LQGCA and PCA across dimensions
dimreducl = []
dimvals = np.arange(1, 31, 2)
for fl in tqdm(fls):
    with open(fl, 'rb') as f:
        datM1 = pickle.load(f)
        datS1 = pickle.load(f)
        lparam = pickle.load(f)

    XM1 = datM1['spike_rates'].squeeze()
    XS1 = datS1['spike_rates'].squeeze()

    pcaM1 = PCA().fit(XM1).components_.T
    pcaS1 = PCA().fit(XS1).components_.T

    T = [1, 3, 5]

    for k, dim in enumerate(dimvals):
        for t in T:
            r = {}
            r['didx'] = int(fl.split('didx_')[1].split('_')[0])
            r['bin_width'] = lparam[0]
            r['filter_fn'] = lparam[1]['filter_fn']
            r['filter_kwargs'] = lparam[1]['filter_kwargs']
            r['dim'] = dim
            r['T'] = T
            r['pcacoefM1'] = pcaM1[:, 0:dim]
            r['pcacoefS1'] = pcaS1[:, 0:dim]


            lqgmodel = LQGCA(T=t, d=dim, n_init=10).fit(XM1)
            r['fcacoefM1'] = lqgmodel.coef_
            lqgmodel = LQGCA(T=t, d=dim, n_init=10).fit(XS1)
            r['fcacoefS1'] = lqgmodel.coef_
            dimreducl.append(r)

  0%|          | 0/64 [00:00<?, ?it/s]

In [6]:
# Analysis 4: Canonical correlation analysis
ccal = []
lags = np.array([4, 2, 0])
windows = np.array([5, 3, 1])

for fl in tqdm(fls):
    with open(fl, 'rb') as f:
        datM1 = pickle.load(f)
        datS1 = pickle.load(f)
        lparam = pickle.load(f)

    Y = datM1['spike_rates'].squeeze()
    X = datS1['spike_rates'].squeeze()

    for k, lag in enumerate(lags):
        for w, window in enumerate(windows):
            t0 = time.time()
            r = {}
            r['didx'] = int(fl.split('didx_')[1].split('_')[0])
            r['bin_width'] = lparam[0]
            r['filter_fn'] = lparam[1]['filter_fn']
            r['filter_kwargs'] = lparam[1]['filter_kwargs']
            r['lag'] = lag
            r['win'] = w

            # Apply window and lag relative to each other
            if lag != 0:
                x = X[:-lag, :]
                y = Y[lag:, :]
            else:
                x = X
                y = Y

            if w > 1:
                x = form_lag_matrix(x, w)
                y = form_lag_matrix(y, w)

            ccamodel = CCA(n_components=6)
            ccamodel.fit(x, y)
            r['ccamodel'] = ccamodel
            ccal.append(r)
            print(time.time() - t0)            

  0%|          | 0/64 [00:00<?, ?it/s]

37.03797483444214
37.027458906173706
