In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import sys 
import pdb 
from neurosim.models.ssr import StateSpaceRealization as SSR
from glob import glob
from dca.cov_util import calc_pi_from_data, calc_pi_from_cross_cov_mats, calc_cov_from_cross_cov_mats

In [3]:
sys.path.append('../..')

In [4]:
from loaders import load_sabes
from subspaces import CrossSubspaceIdentification, SubspaceIdentification, estimate_autocorrelation, BRSSID

### Scratch

In [15]:
indy_files = glob('/home/akumar/nse/neural_control/data/indy*')

In [16]:
dat = load_sabes(indy_files[0])

Processing spikes


100%|██████████| 1/1 [00:13<00:00, 13.62s/it]


In [18]:
y = np.squeeze(dat['spike_rates'])
pi = calc_pi_from_data(y, T=3)

In [28]:
ccmy = estimate_autocorrelation(y, 8)

In [46]:
calc_pi_from_cross_cov_mats(ccmy)

23.234529534801936

In [29]:
ssid = SubspaceIdentification()

In [30]:
# First verify that the PI is recovered from the canonical correlation coefficients

In [47]:
ht = ssid.form_hankel_toeplitz(ccmy, T=3)

In [49]:
cc = ht[1]

In [51]:
-0.5 * sum([np.log(1 - c**2) for c in cc])

23.234529534801784

In [27]:
# Next verify that the canonical correlation coefficients returned by psid yield the mutual informaton between neural data and behavior

In [53]:
def mutual_information(covjoint, covx, covy):
    return 0.5 * (np.linalg.slogdet(covx)[1] + np.linalg.slogdet(covy)[1] - np.linalg.slogdet(covjoint)[1])


In [54]:
z = np.squeeze(dat['behavior'])
ccmz = estimate_autocorrelation(z, 8)
ccm = estimate_autocorrelation(np.hstack([y, z]), 8)

In [67]:
mutual_information(calc_cov_from_cross_cov_mats(ccm[0:3]), calc_cov_from_cross_cov_mats(ccmy[0:3]), calc_cov_from_cross_cov_mats(ccmz[0:3]))

1.2246191502837576

In [57]:
bsid = PSIDSubspaceIdentification()

In [65]:
bht = bsid.form_hankel_toeplitz(ccm, 2, y.shape[1])

In [68]:
-0.5 * sum([np.log(1 - c**2) for c in bht[1]])

1.297904220247693

In [64]:
# Can also test on synthetic data

In [69]:
from ppmi.gaussian import gaussian_model

In [77]:
y, z, _, _, _, _, _ = gaussian_model()

In [94]:
ccmy = estimate_autocorrelation(y, 8)
ccmz = estimate_autocorrelation(z, 8)
ccm = estimate_autocorrelation(np.hstack([y, z]), 8)

In [95]:
mutual_information(calc_cov_from_cross_cov_mats(ccm[0:3]), calc_cov_from_cross_cov_mats(ccmy[0:3]), calc_cov_from_cross_cov_mats(ccmz[0:3]))

21.162377448897853

In [91]:
bht = bsid.form_hankel_toeplitz(np.transpose(ccm, (0, 2, 1)), 3, z.shape[1])

In [92]:
bht[1].shape

(40,)

In [93]:
-0.5 * sum([np.log(1 - c**2) for c in bht[1]])

12.369749093748798

### Testing on sabes lab data 

In [5]:
from loaders import load_sabes

In [6]:
dat = load_sabes('/home/akumar/nse/neural_control/data/indy_20160624_03.mat')

Processing spikes


100%|██████████| 1/1 [00:14<00:00, 14.28s/it]


In [7]:
y = np.squeeze(dat['spike_rates'])
z = np.squeeze(dat['behavior'])

In [14]:
bsid1 = CrossSubspaceIdentification()
bsid2 = BRSSID()

In [11]:
z.shape

(9998, 2)

In [22]:
bsid2.identify(y, z, order=6, T=5)

> [0;32m/home/akumar/nse/neural_control/subspaces.py[0m(690)[0;36midentify[0;34m()[0m
[0;32m    688 [0;31m        [0mxt[0m[0;34m,[0m [0mxt1[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_predictor_space[0m[0;34m([0m[0my[0m[0;34m,[0m [0mz[0m[0;34m,[0m [0mT[0m[0;34m,[0m [0mint[0m[0;34m([0m[0morder[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    689 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 690 [0;31m        [0mA[0m[0;34m,[0m [0mC[0m[0;34m,[0m [0mCbar[0m[0;34m,[0m [0mrho_A[0m[0;34m,[0m [0mrho_C[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mestimator[0m[0;34m.[0m[0mfit[0m[0;34m([0m[0mnp[0m[0;34m.[0m[0mhstack[0m[0;34m([0m[0;34m[[0m[0my[0m[0;34m[[0m[0mT[0m[0;34m:[0m[0;34m-[0m[0mT[0m[0;34m][0m[0;34m,[0m [0mz[0m[0;34m[[0m[0mT[0m[0;34m:[0m[0;34m-[0m[0mT[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34m,[0

BdbQuit: 

In [13]:
bsid1.identify(y, z, 5)

> [0;32m/home/akumar/nse/neural_control/subspaces.py[0m(675)[0;36mget_predictor_space[0;34m()[0m
[0;32m    673 [0;31m[0;34m[0m[0m
[0m[0;32m    674 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 675 [0;31m        [0mXt[0m [0;34m=[0m [0mSigmabart[0m[0;34m.[0m[0mT[0m [0;34m@[0m [0mnp[0m[0;34m.[0m[0mlinalg[0m[0;34m.[0m[0minv[0m[0;34m([0m[0mTm1_y[0m[0;34m)[0m [0;34m@[0m [0mymt1[0m[0;34m[[0m[0mm[0m[0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    676 [0;31m        [0mXt1[0m [0;34m=[0m [0mSigmabart1[0m[0;34m.[0m[0mT[0m [0;34m@[0m [0mnp[0m[0;34m.[0m[0mlinalg[0m[0;34m.[0m[0minv[0m[0;34m([0m[0mTm1_y[0m[0;34m)[0m [0;34m@[0m [0mymt1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    677 [0;31m[0;34m[0m[0m
[0m
(596, 596)
(447, 5)
(596, 9995)
(596, 9995)


BdbQuit: 