In [1]:
%load_ext autoreload
%autoreload 2

In [28]:
import numpy as np
import sys 
import pdb 
import scipy
from neurosim.models.ssr import StateSpaceRealization as SSR
from neurosim.models.ssr import gen_random_model
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import TruncatedSVD
from glob import glob
from dca.cov_util import calc_pi_from_data, calc_pi_from_cross_cov_mats, calc_cov_from_cross_cov_mats, form_lag_matrix

In [4]:
sys.path.append('../..')

In [5]:
from loaders import load_sabes
from subspaces import CrossSubspaceIdentification, SubspaceIdentification, estimate_autocorrelation, BRSSID

### Scratch

In [15]:
indy_files = glob('/home/akumar/nse/neural_control/data/indy*')

In [16]:
dat = load_sabes(indy_files[0])

Processing spikes


100%|██████████| 1/1 [00:13<00:00, 13.62s/it]


In [18]:
# Collapse singleton axes of the loaded spike rates before analysis.
y = np.squeeze(dat['spike_rates'])
# NOTE(review): `pi` is not referenced again in the visible cells; presumably it is
# kept as a reference value for the cross-covariance based checks below — confirm.
pi = calc_pi_from_data(y, T=3)

In [28]:
ccmy = estimate_autocorrelation(y, 8)

In [46]:
calc_pi_from_cross_cov_mats(ccmy)

23.234529534801936

In [29]:
ssid = SubspaceIdentification()

In [30]:
# First verify that the PI is recovered from the canonical correlation coefficients

In [47]:
ht = ssid.form_hankel_toeplitz(ccmy, T=3)

In [49]:
cc = ht[1]

In [51]:
-0.5 * sum([np.log(1 - c**2) for c in cc])

23.234529534801784

In [27]:
# Next verify that the canonical correlation coefficients returned by psid yield the mutual information between neural data and behavior

In [53]:
def mutual_information(covjoint, covx, covy):
    """Combine the log-determinants of three covariance matrices.

    Returns ``0.5 * (log|det covx| + log|det covy| - log|det covjoint|)``,
    which is the Gaussian mutual information when ``covjoint`` is the joint
    covariance of the two variables whose marginal covariances are ``covx``
    and ``covy``.

    Only the log-magnitude reported by ``np.linalg.slogdet`` is used; the
    sign of each determinant is discarded.
    """
    _, logdet_x = np.linalg.slogdet(covx)
    _, logdet_y = np.linalg.slogdet(covy)
    _, logdet_joint = np.linalg.slogdet(covjoint)
    return 0.5 * (logdet_x + logdet_y - logdet_joint)


In [54]:
# Behavior signal with singleton axes removed.
z = np.squeeze(dat['behavior'])
# Autocorrelation estimates for the behavior alone and for the neural data and
# behavior stacked column-wise (y's columns first, then z's), using the same
# second argument (8) as the neural-only estimate `ccmy`.
ccmz = estimate_autocorrelation(z, 8)
ccm = estimate_autocorrelation(np.hstack([y, z]), 8)

In [67]:
mutual_information(calc_cov_from_cross_cov_mats(ccm[0:3]), calc_cov_from_cross_cov_mats(ccmy[0:3]), calc_cov_from_cross_cov_mats(ccmz[0:3]))

1.2246191502837576

In [57]:
# NOTE(review): PSIDSubspaceIdentification is not among the names imported from
# `subspaces` in the visible import cell (CrossSubspaceIdentification,
# SubspaceIdentification, estimate_autocorrelation, BRSSID), so this raises
# NameError on a fresh kernel unless it is defined elsewhere — confirm the intended class.
bsid = PSIDSubspaceIdentification()

In [65]:
# NOTE(review): the second argument is 2 here, while the covariance-based mutual
# information above is built from ccm[0:3]; if that argument is the window length T,
# the mismatch may explain why the two printed values differ (1.2246 vs 1.2979) — confirm.
bht = bsid.form_hankel_toeplitz(ccm, 2, y.shape[1])

In [68]:
-0.5 * sum([np.log(1 - c**2) for c in bht[1]])

1.297904220247693

In [64]:
# Can also test on synthetic data

In [69]:
from ppmi.gaussian import gaussian_model

In [77]:
y, z, _, _, _, _, _ = gaussian_model()

In [94]:
ccmy = estimate_autocorrelation(y, 8)
ccmz = estimate_autocorrelation(z, 8)
ccm = estimate_autocorrelation(np.hstack([y, z]), 8)

In [95]:
mutual_information(calc_cov_from_cross_cov_mats(ccm[0:3]), calc_cov_from_cross_cov_mats(ccmy[0:3]), calc_cov_from_cross_cov_mats(ccmz[0:3]))

21.162377448897853

In [91]:
# NOTE(review): unlike the real-data call, this one swaps the last two axes of ccm
# and passes z.shape[1] rather than y.shape[1]; the resulting value (12.37) does not
# match the covariance-based estimate (21.16) — confirm the intended argument convention.
bht = bsid.form_hankel_toeplitz(np.transpose(ccm, (0, 2, 1)), 3, z.shape[1])

In [92]:
bht[1].shape

(40,)

In [93]:
-0.5 * sum([np.log(1 - c**2) for c in bht[1]])

12.369749093748798

### Testing on sabes lab data 

In [14]:
from loaders import load_sabes

In [16]:
dat = load_sabes('/home/akumar/nse/neural_control/data/indy_20160624_03.mat')
#dat = load_sabes('/mnt/Secondary/data/sabes/indy_20160624_03.mat')

Processing spikes


100%|██████████| 1/1 [00:08<00:00,  8.24s/it]


In [17]:
y = np.squeeze(dat['spike_rates'])
z = np.squeeze(dat['behavior'])

In [23]:
# How good is basic subspace identification
T = 5
ydim = y.shape[1]
zdim = z.shape[1]

yt = form_lag_matrix(y, 2*T)
zt = form_lag_matrix(z, 2*T)

# "Past" of y and "Future" of z
ypast = yt[:, :T*ydim]
zfut = zt[:, -T*zdim:]

In [25]:
linmodel = LinearRegression().fit(ypast, zfut)

In [27]:
linmodel.score(ypast, zfut)

0.4463327317412852

In [30]:
# Behaviorally relevant state
Z = scipy.linalg.lstsq(ypast, zfut)[0].T @ ypast.T

svd = TruncatedSVD(n_components=Z.shape[0] - 1)
svd.fit(Z.T)

U = svd.components_.T
Gamma_t = U @ np.diag(np.sqrt(svd.singular_values_))
Xt = np.linalg.pinv(Gamma_t) @ Z

In [34]:
Xt.shape

(9, 9989)

(10, 9989)

In [45]:
# Regress behavior on the estimated state, fitting and scoring on the same
# trimmed target z[:-9] so that its row count matches Xt.T.
# NOTE(review): presumably z has 9 more rows than Xt has columns because of the
# 2*T lag window; dropping the trailing rows pairs each state with the first
# sample of its window — confirm this is the intended alignment.
linmodel = LinearRegression().fit(Xt.T, z[:-9])
linmodel.score(Xt.T, z[:-9])

0.2724127559060829

In [48]:
from subspaces import brssid

In [49]:
brssid(y, z, 6, 5)

(0.3856758697459933,
 array([[ 1.10888572e-02,  1.87695522e-02, -4.52275281e-02,
         -4.49973443e-03, -6.69901021e-02, -1.05202299e-01],
        [-1.11765504e-02, -1.06505708e-02, -3.46171902e-04,
          1.78712328e-02, -2.46493286e-02, -8.88320316e-02],
        [ 6.99177799e-03, -6.71297370e-02,  4.55281591e-01,
         -3.29268667e-02, -1.13228261e-01,  4.71852872e-01],
        [-1.02273248e-02,  7.89429273e-03,  3.33586207e-02,
          4.61172847e-03, -6.12541809e-02,  1.04995682e-01],
        [-1.81443997e-03,  2.20271576e-02, -5.39270127e-02,
         -1.19622246e-01, -8.45296312e-02, -2.34750074e-01],
        [-9.03421747e-03, -5.75864409e-02,  1.79035044e-01,
         -1.53862898e-02, -4.69384625e-01,  1.04014644e+00],
        [-6.71747751e-03,  3.72016325e-02, -6.40972739e-03,
         -2.95132102e-02,  5.95741302e-02, -6.81051348e-02],
        [ 1.02635306e-01, -1.44866209e-01,  1.91831482e-01,
          5.20779439e-02,  1.54055935e-03, -2.51729609e-02],
        [-2

In [26]:
# Test the KF decoding from the bsid

In [None]:
# Test on synthetic data

In [6]:
A, B, C = gen_random_model(size=6)

In [7]:
C = scipy.stats.ortho_group.rvs(dim=20)[:, 0:6]


In [8]:
ssr = SSR(A, B, C)

In [9]:
X, x = ssr.trajectory(T=int(1e5), return_state=True)

In [10]:
# Split the simulated output column-wise: the first 18 columns become y and the
# remaining columns become z (2 columns, per the z.shape output shown further down).
# NOTE(review): presumably y stands in for neural data and z for behavior, as in
# the real-data cells — confirm.
y = x[:, 0:18]
z = x[:, 18:]

In [11]:
bsid2 = BRSSID()

In [12]:
Ass,  _, _, r2 = bsid2.identify(y, z, order=6, T=5)

In [13]:
r2

0.09618722855310002

In [100]:
np.linalg.eigvals(Ass)

array([-0.62863086+0.j        , -0.50425209+0.3089909j ,
       -0.50425209-0.3089909j ,  0.15608834+0.51544694j,
        0.15608834-0.51544694j,  0.26256323+0.j        ])

In [99]:
np.linalg.eigvals(A)

array([-0.51242636+0.27857446j, -0.51242636-0.27857446j,
        0.13910177+0.51559358j,  0.13910177-0.51559358j,
        0.17536933+0.01531477j,  0.17536933-0.01531477j])

In [98]:
np.linalg.norm(A - Ass)

2.4086981303915054

In [54]:
np.arange(100)[5:-5 + 1]

array([ 5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
       22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
       39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
       56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
       73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
       90, 91, 92, 93, 94, 95])

In [53]:
xt.shape
z.shape

(100001, 2)

In [49]:
from sklearn.linear_model import LinearRegression

In [61]:
linmodel = LinearRegression().fit(z[4:-5], xt.T)

In [62]:
linmodel.score(z[4:-5], xt.T)

0.2831890412843523

In [64]:
linmodel.coef_ = C[-2:, :]

In [None]:
linmodel.score(z, )

In [101]:
U = scipy.stats.unitary_group.rvs(10)

In [None]:
delta = np.linspace(0, 0.5, 20)
nn = np.zeros(delta.size)

for i, alpha_ in enumerate(alpha):
    P = 