In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import pdb
import numpy as np
import matplotlib.pyplot as plt
import scipy 
import time
import torch
import glob
import pickle
import pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

In [5]:
import sys

In [6]:
sys.path.append('/home/akumar/nse/neural_control')
from utils import apply_df_filters, calc_loadings, calc_cascaded_loadings
from loaders import load_sabes, load_peanut, load_cv
from decoders import lr_decoder, lr_encoder
from subspaces import SubspaceIdentification, IteratedStableEstimator, estimate_autocorrelation

In [7]:
from dca.cov_util import calc_cross_cov_mats_from_data, calc_pi_from_cross_cov_mats
from dca_research.kca import calc_mmse_from_cross_cov_mats
from dca_research.lqg import build_loss as build_lqg_loss

In [8]:
# Directory holding the pickled result dataframes that the next cells load.
# NOTE(review): later cells rebind `data_path` to the raw-recording directory,
# so re-run this cell before re-running any of the dataframe-loading cells.
#data_path = '/mnt/sdb1/nc_data/'
data_path = '/home/akumar/nse/neural_control/data'

In [13]:
# Pickled dataframe of Sabes results; later cells filter it by data_file,
# fold_idx, dim and dimreduc_method. pickle executes arbitrary code on load,
# so only point this at trusted files.
sabes_df_path = '%s/sabes_decoding_df.dat' % data_path
with open(sabes_df_path, 'rb') as sabes_fh:
    sabes_df = pickle.load(sabes_fh)

In [8]:
# Pickled dataframe of peanut results; later cells filter it by epoch,
# fold_idx, dim and dimreduc_method. Only unpickle trusted files.
peanut_df_path = '%s/peanut_dimreduc_df.dat' % data_path
with open(peanut_df_path, 'rb') as peanut_fh:
    peanut_df = pickle.load(peanut_fh)

In [9]:
# Pickled cv dataframe. It is not referenced by any of the analysis cells
# visible in this part of the notebook. Only unpickle trusted files.
cv_df_path = '%s/cv_dimreduc_df.dat' % data_path
with open(cv_df_path, 'rb') as cv_fh:
    cv_df = pickle.load(cv_fh)

In [10]:
# Pickled dataframe of loco results; later cells filter it by data_file,
# fold_idx, dim, dimreduc_method and region. Only unpickle trusted files.
loco_df_path = '%s/loco_dimreduc_df.dat' % data_path
with open(loco_df_path, 'rb') as loco_fh:
    loco_df = pickle.load(loco_fh)

### Single unit calculations - no trialization

In [9]:
# Show the first record: its bin_width, filter_fn, filter_kwargs, boxcox and
# spike_threshold fields are reused below as the load_sabes arguments.
sabes_df.iloc[0]

dim                                                               10
fold_idx                                                           0
train_idxs         [1983, 1984, 1985, 1986, 1987, 1988, 1989, 199...
test_idxs          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...
dimreduc_method                                                  KCA
dimreduc_args        {'T': 3, 'causal_weights': (1, 1), 'n_init': 5}
coef               [[0.04434436239872305, -0.1047764350198416, -0...
score                          tensor(141.1882, dtype=torch.float64)
bin_width                                                         50
filter_fn                                                       none
filter_kwargs                                                     {}
boxcox                                                           0.5
spike_threshold                                                  100
dim_vals           [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...
n_folds                           

In [9]:
# Single-unit statistics on the continuous (non-trialized) Sabes recordings.
#
# For every data file referenced in sabes_df, and over 5 contiguous
# (unshuffled) folds, this cell fits:
#   * population decoders/encoders on all neurons,
#   * one decoder/encoder per neuron (weights and r^2),
#   * per-neuron variance, MMSE, PI and FCCA (LQG loss) scores,
#   * decoders/encoders on DCA/KCA/LQGCA/PCA projections, summarised per neuron
#     through calc_cascaded_loadings.
# Fold results are averaged and one dict per data file is appended to sabes_su_l.
sabes_su_l = []
decoder_params = {'trainlag': 4, 'testlag': 4, 'decoding_window': 3}

# Raw recordings live in a different directory than the pickled dataframes.
# A dedicated name is used so that `data_path` (needed by the dataframe-loading
# cells) is not clobbered when this cell runs.
#sabes_data_path = '/mnt/sdb1/nc_data/sabes'
sabes_data_path = '/mnt/Secondary/data/sabes'
data_files = np.unique(sabes_df['data_file'].values)

n_folds = 5
dims = np.arange(2, 11)
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

# Preprocessing settings are read from the first row of sabes_df.
# NOTE(review): assumes every row of sabes_df shares these settings - confirm.
loader_row = sabes_df.iloc[0]

for data_file in tqdm(data_files):
    dat = load_sabes('%s/%s' % (sabes_data_path, data_file), bin_width=loader_row["bin_width"],
                     filter_fn=loader_row['filter_fn'], filter_kwargs=loader_row['filter_kwargs'],
                     boxcox=loader_row['boxcox'], spike_threshold=loader_row['spike_threshold'])

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold accumulators (averaged across folds after the fold loop)
    decoding_weights = []
    encoding_weights = []
    su_decoding_weights = []
    su_encoding_weights = []
    su_r2_pos = []
    su_r2_vel = []
    su_r2_enc = []

    # Single unit statistics, indexed [fold, neuron]
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding weights after projection, indexed [fold, method, dim, neuron]
    proj_shape = (n_folds, len(dimreduc_methods), dims.size, n_neurons)
    proj_dw_pos = np.zeros(proj_shape)
    proj_dw_vel = np.zeros(proj_shape)
    proj_ew = np.zeros(proj_shape)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):

        ztrain = Z[train_idxs, :]
        ztest = Z[test_idxs, :]

        # Population level decoding/encoding - use the coefficients of the linear fit
        xtrain = X[train_idxs, :]
        xtest = X[test_idxs, :]

        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        _, _, _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params)

        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Fit every neuron on its own
        su_dw, su_ew = [], []
        sur2pos, sur2vel, sur2enc = [], [], []
        for neu_idx in range(n_neurons):

            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            # Decoding
            r2_pos, r2_vel, _, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_dw.append(dr.coef_)
            sur2pos.append(r2_pos)
            sur2vel.append(r2_vel)

            # Encoding
            r2_enc, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_ew.append(er.coef_)
            # Store this neuron's score. The previous code appended the running
            # list itself, which turned su_r2_enc into an N x N array of aliases.
            sur2enc.append(r2_enc)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))

        su_r2_pos.append(np.array(sur2pos))
        su_r2_vel.append(np.array(sur2vel))
        su_r2_enc.append(np.array(sur2enc))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        for neu_idx in range(n_neurons):

            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])

            # One-hot projection that selects a single neuron
            proj = np.zeros((ccm.shape[-1], 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            # PI is computed from this neuron's own autocovariance sequence (lags 0..8)
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(torch.unsqueeze(torch.unsqueeze(ccm[0:9, neu_idx ,neu_idx], 1), 2))
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Calculate decoding weights based on projection of the data first
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(sabes_df, data_file=data_file, fold_idx=fold_idx, dim=d, dimreduc_method=dimreduc_method)
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components. The slice used to be
                    # hard-coded to 0:2, which made every PCA entry along the
                    # `dims` axis identical to the 2-dimensional result.
                    # NOTE(review): assumes the PCA coef has at least d columns - confirm.
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, _, _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                # Rows 0:2 / 2:4 of the decoder coefficients are stored as the
                # position / velocity weights respectively.
                proj_dw_pos[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_[0:2, :].T, d2=decoder_params['decoding_window'])
                proj_dw_vel[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_[2:4, :].T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    decoding_weights = np.mean(np.array(decoding_weights), axis=0)
    encoding_weights = np.mean(np.array(encoding_weights), axis=0)
    su_decoding_weights = np.mean(np.array(su_decoding_weights), axis=0)
    su_encoding_weights = np.mean(np.array(su_encoding_weights), axis=0)

    su_r2_pos = np.mean(np.array(su_r2_pos), axis=0)
    su_r2_vel = np.mean(np.array(su_r2_vel), axis=0)
    su_r2_enc = np.mean(np.array(su_r2_enc), axis=0)

    su_var = np.mean(su_var, axis=0)
    su_mmse = np.mean(su_mmse, axis=0)
    su_pi = np.mean(su_pi, axis=0)
    su_fcca = np.mean(su_fcca, axis=0)

    proj_dw_pos = np.mean(proj_dw_pos, axis=0)
    proj_dw_vel = np.mean(proj_dw_vel, axis=0)
    proj_ew = np.mean(proj_ew, axis=0)

    # Explicit dict instead of eval() on variable names; same keys, same order.
    result = {
        'data_file': data_file,
        'decoding_weights': decoding_weights,
        'encoding_weights': encoding_weights,
        'su_decoding_weights': su_decoding_weights,
        'su_encoding_weights': su_encoding_weights,
        'su_r2_pos': su_r2_pos,
        'su_r2_vel': su_r2_vel,
        'su_r2_enc': su_r2_enc,
        'su_var': su_var,
        'su_mmse': su_mmse,
        'su_pi': su_pi,
        'su_fcca': su_fcca,
        'proj_dw_pos': proj_dw_pos,
        'proj_dw_vel': proj_dw_vel,
        'proj_ew': proj_ew,
        'decoder_params': decoder_params,
    }

    sabes_su_l.append(result)

0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:15<00:00, 15.80s/it]
1it [00:44, 44.60s/it]

Processing spikes


100%|██████████| 1/1 [00:22<00:00, 22.28s/it]
2it [01:43, 52.78s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.54s/it]
3it [01:58, 35.50s/it]

Processing spikes


100%|██████████| 1/1 [00:30<00:00, 30.95s/it]
4it [03:15, 52.24s/it]

Processing spikes


100%|██████████| 1/1 [00:13<00:00, 13.40s/it]
5it [03:44, 43.77s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.54s/it]
6it [03:56, 32.91s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.32s/it]
7it [04:07, 25.77s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.33s/it]
8it [04:21, 21.93s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.80s/it]
9it [04:33, 18.98s/it]

Processing spikes


100%|██████████| 1/1 [00:03<00:00,  3.44s/it]
10it [04:44, 16.55s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.65s/it]
11it [04:59, 15.90s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.50s/it]
12it [05:13, 15.30s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.14s/it]
13it [05:31, 16.25s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.68s/it]
14it [05:46, 15.68s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.75s/it]
15it [06:01, 15.50s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.72s/it]
16it [06:14, 14.99s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.32s/it]
17it [06:28, 14.58s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.57s/it]
18it [06:43, 14.74s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.61s/it]
19it [06:57, 14.41s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.29s/it]
20it [07:13, 14.92s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.76s/it]
21it [07:31, 15.90s/it]

Processing spikes


100%|██████████| 1/1 [00:04<00:00,  4.08s/it]
22it [07:44, 14.89s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.14s/it]
23it [07:58, 14.76s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.20s/it]
24it [08:12, 14.52s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.59s/it]
25it [08:27, 14.76s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.37s/it]
26it [08:44, 15.22s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.73s/it]
27it [09:04, 16.61s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.53s/it]
28it [09:24, 20.16s/it]


In [10]:
# One row per data file; columns are the keys of the per-file result dicts.
sabes_su_df = pd.DataFrame(sabes_su_l)

In [12]:
# Persist the per-file results. What gets pickled is the list of dicts
# (sabes_su_l), not the DataFrame built from it, despite the "_df" file name.
# Alternate destination on another machine: '/mnt/sdb1/nc_data/sabes_su_df.dat'
with open('/home/akumar/nse/neural_control/data/sabes_su_df.dat', 'wb') as out_fh:
    pickle.dump(sabes_su_l, out_fh)

In [None]:
# TODO: augment the single-unit results with the orientation-model encoding r^2

In [91]:
# Loco
# Same single-unit analysis as the Sabes cell, restricted to region 'S1'.
# For every data file referenced in loco_df, and over 5 contiguous (unshuffled)
# folds, this fits population and per-neuron decoders/encoders, per-neuron
# variance / MMSE / PI / FCCA scores, and decoders/encoders on
# DCA/KCA/LQGCA/PCA projections. Fold results are averaged and one dict per
# data file is appended to loco_su_l.
loco_su_l = []
decoder_params = {'trainlag': 4, 'testlag': 4, 'decoding_window': 3}
# Copied from the submit file
loader_params = {'bin_width':50, 'filter_fn':'none', 'filter_kwargs':{}, 'boxcox':0.5, 'spike_threshold':100, 'region': 'S1'}

# The data_file entries of loco_df are handed to load_sabes as-is (no directory
# prefix), so no data directory variable is needed in this cell.
data_files = np.unique(loco_df['data_file'].values)

n_folds = 5
dims = np.arange(2, 11)
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

for data_file in tqdm(data_files):
    # loader_params holds exactly the keyword arguments that used to be
    # spelled out one by one (including region='S1').
    dat = load_sabes(data_file, **loader_params)

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold accumulators (averaged across folds after the fold loop)
    decoding_weights = []
    encoding_weights = []
    su_decoding_weights = []
    su_encoding_weights = []
    su_r2_pos = []
    su_r2_vel = []
    su_r2_enc = []

    # Single unit statistics, indexed [fold, neuron]
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding weights after projection, indexed [fold, method, dim, neuron]
    proj_shape = (n_folds, len(dimreduc_methods), dims.size, n_neurons)
    proj_dw_pos = np.zeros(proj_shape)
    proj_dw_vel = np.zeros(proj_shape)
    proj_ew = np.zeros(proj_shape)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):

        ztrain = Z[train_idxs, :]
        ztest = Z[test_idxs, :]

        # Population level decoding/encoding - use the coefficients of the linear fit
        xtrain = X[train_idxs, :]
        xtest = X[test_idxs, :]

        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        _, _, _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params)

        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Fit every neuron on its own
        su_dw, su_ew = [], []
        sur2pos, sur2vel, sur2enc = [], [], []
        for neu_idx in range(n_neurons):

            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            # Decoding
            r2_pos, r2_vel, _, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_dw.append(dr.coef_)
            sur2pos.append(r2_pos)
            sur2vel.append(r2_vel)

            # Encoding
            r2_enc, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_ew.append(er.coef_)
            # Store this neuron's score. The previous code appended the running
            # list itself, which turned su_r2_enc into an N x N array of aliases.
            sur2enc.append(r2_enc)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))

        su_r2_pos.append(np.array(sur2pos))
        su_r2_vel.append(np.array(sur2vel))
        su_r2_enc.append(np.array(sur2enc))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        for neu_idx in range(n_neurons):

            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])

            # One-hot projection that selects a single neuron
            proj = np.zeros((ccm.shape[-1], 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            # PI is computed from this neuron's own autocovariance sequence (lags 0..8)
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(torch.unsqueeze(torch.unsqueeze(ccm[0:9, neu_idx ,neu_idx], 1), 2))
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Calculate decoding weights based on projection of the data first
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(loco_df, data_file=data_file, fold_idx=fold_idx, dim=d, dimreduc_method=dimreduc_method,
                                       region=loader_params['region'])
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components. The slice used to be
                    # hard-coded to 0:2, which made every PCA entry along the
                    # `dims` axis identical to the 2-dimensional result.
                    # NOTE(review): assumes the PCA coef has at least d columns - confirm.
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, _, _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                # Rows 0:2 / 2:4 of the decoder coefficients are stored as the
                # position / velocity weights respectively.
                proj_dw_pos[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_[0:2, :].T, d2=decoder_params['decoding_window'])
                proj_dw_vel[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_[2:4, :].T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    decoding_weights = np.mean(np.array(decoding_weights), axis=0)
    encoding_weights = np.mean(np.array(encoding_weights), axis=0)
    su_decoding_weights = np.mean(np.array(su_decoding_weights), axis=0)
    su_encoding_weights = np.mean(np.array(su_encoding_weights), axis=0)

    su_r2_pos = np.mean(np.array(su_r2_pos), axis=0)
    su_r2_vel = np.mean(np.array(su_r2_vel), axis=0)
    su_r2_enc = np.mean(np.array(su_r2_enc), axis=0)

    su_var = np.mean(su_var, axis=0)
    su_mmse = np.mean(su_mmse, axis=0)
    su_pi = np.mean(su_pi, axis=0)
    su_fcca = np.mean(su_fcca, axis=0)

    proj_dw_pos = np.mean(proj_dw_pos, axis=0)
    proj_dw_vel = np.mean(proj_dw_vel, axis=0)
    proj_ew = np.mean(proj_ew, axis=0)

    # Explicit dict instead of eval() on variable names; same keys, same order.
    result = {
        'data_file': data_file,
        'decoding_weights': decoding_weights,
        'encoding_weights': encoding_weights,
        'su_decoding_weights': su_decoding_weights,
        'su_encoding_weights': su_encoding_weights,
        'su_r2_pos': su_r2_pos,
        'su_r2_vel': su_r2_vel,
        'su_r2_enc': su_r2_enc,
        'su_var': su_var,
        'su_mmse': su_mmse,
        'su_pi': su_pi,
        'su_fcca': su_fcca,
        'proj_dw_pos': proj_dw_pos,
        'proj_dw_vel': proj_dw_vel,
        'proj_ew': proj_ew,
        'decoder_params': decoder_params,
    }

    loco_su_l.append(result)

0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.61s/it]
1it [00:38, 39.00s/it]

Processing spikes


100%|██████████| 1/1 [00:22<00:00, 22.38s/it]
2it [01:32, 47.38s/it]

Processing spikes


100%|██████████| 1/1 [00:29<00:00, 29.43s/it]
3it [02:38, 55.89s/it]

Processing spikes


100%|██████████| 1/1 [00:10<00:00, 10.44s/it]
4it [03:05, 44.40s/it]

Processing spikes


100%|██████████| 1/1 [00:24<00:00, 24.97s/it]
5it [04:01, 48.73s/it]

Processing spikes


100%|██████████| 1/1 [00:17<00:00, 17.29s/it]
6it [04:43, 46.57s/it]

Processing spikes


100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
7it [05:30, 46.67s/it]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.72s/it]
8it [06:13, 45.36s/it]

Processing spikes


100%|██████████| 1/1 [00:10<00:00, 10.64s/it]
9it [06:40, 39.64s/it]

Processing spikes


100%|██████████| 1/1 [00:21<00:00, 21.42s/it]
10it [07:31, 45.14s/it]


In [94]:
# Peanut
# Single-unit analysis for every epoch referenced in peanut_df. Decoders and
# encoders are fit with include_velocity=False and include_acc=False, so only
# one decoding r^2 per fit is returned. Over 5 contiguous (unshuffled) folds
# this fits population and per-neuron decoders/encoders, per-neuron
# variance / MMSE / PI / FCCA scores, and decoders/encoders on
# DCA/KCA/LQGCA/PCA projections. Fold results are averaged and one dict per
# epoch is appended to peanut_su_l.
peanut_su_l = []
decoder_params = {'trainlag': 0, 'testlag': 0, 'decoding_window': 6}

fpath = '/mnt/Secondary/data/peanut/data_dict_peanut_day14.obj'
epochs = np.unique(peanut_df['epoch'].values)

n_folds = 5
dims = np.arange(2, 11)
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']
# Shared by every lr_decoder/lr_encoder call in this cell
fit_kwargs = dict(decoder_params, include_velocity=False, include_acc=False)

# Preprocessing settings are read from the first row of peanut_df.
# NOTE(review): assumes every row of peanut_df shares these settings - confirm.
loader_row = peanut_df.iloc[0]

for epoch in tqdm(epochs):
    dat = load_peanut(fpath, epoch, speed_threshold=loader_row['speed_threshold'],
                      bin_width=loader_row['bin_width'], filter_fn='none', filter_kwargs={},
                      spike_threshold=loader_row['spike_threshold'], boxcox=loader_row['boxcox'])

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold accumulators (averaged across folds after the fold loop)
    decoding_weights = []
    encoding_weights = []
    su_decoding_weights = []
    su_encoding_weights = []
    su_r2_pos = []
    su_r2_enc = []

    # Single unit statistics, indexed [fold, neuron]
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding weights after projection, indexed [fold, method, dim, neuron]
    proj_shape = (n_folds, len(dimreduc_methods), dims.size, n_neurons)
    proj_dw = np.zeros(proj_shape)
    proj_ew = np.zeros(proj_shape)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):

        ztrain = Z[train_idxs, :]
        ztest = Z[test_idxs, :]

        # Population level decoding/encoding - use the coefficients of the linear fit
        xtrain = X[train_idxs, :]
        xtest = X[test_idxs, :]

        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **fit_kwargs)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **fit_kwargs)

        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Fit every neuron on its own
        r2_pos_decoding, r2_encoding = [], []
        su_dw, su_ew = [], []
        for neu_idx in range(n_neurons):

            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            # Decoding
            r2_pos, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **fit_kwargs)
            r2_pos_decoding.append(r2_pos)
            su_dw.append(dr.coef_)

            # Encoding
            r2_encoding_, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **fit_kwargs)
            r2_encoding.append(r2_encoding_)
            su_ew.append(er.coef_)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))

        # Store the scores of all neurons. The previous code stored `r2_pos`,
        # i.e. only the value left over from the last neuron of the loop.
        su_r2_pos.append(np.array(r2_pos_decoding))
        su_r2_enc.append(np.array(r2_encoding))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        for neu_idx in range(n_neurons):

            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])

            # One-hot projection that selects a single neuron
            proj = np.zeros((ccm.shape[-1], 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            # PI is computed from this neuron's own autocovariance sequence (lags 0..8)
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(torch.unsqueeze(torch.unsqueeze(ccm[0:9, neu_idx ,neu_idx], 1), 2))
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Calculate decoding weights based on projection of the data first
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(peanut_df, epoch=epoch, fold_idx=fold_idx, dim=d, dimreduc_method=dimreduc_method)
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components. The slice used to be
                    # hard-coded to 0:2, which made every PCA entry along the
                    # `dims` axis identical to the 2-dimensional result.
                    # NOTE(review): assumes the PCA coef has at least d columns - confirm.
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **fit_kwargs)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **fit_kwargs)

                proj_dw[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_.T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    decoding_weights = np.mean(np.array(decoding_weights), axis=0)
    encoding_weights = np.mean(np.array(encoding_weights), axis=0)
    su_decoding_weights = np.mean(np.array(su_decoding_weights), axis=0)
    su_encoding_weights = np.mean(np.array(su_encoding_weights), axis=0)

    su_r2_pos = np.mean(np.array(su_r2_pos), axis=0)
    su_r2_enc = np.mean(np.array(su_r2_enc), axis=0)

    su_var = np.mean(su_var, axis=0)
    su_mmse = np.mean(su_mmse, axis=0)
    su_pi = np.mean(su_pi, axis=0)
    su_fcca = np.mean(su_fcca, axis=0)

    proj_dw = np.mean(proj_dw, axis=0)
    proj_ew = np.mean(proj_ew, axis=0)

    # Explicit dict instead of eval() on variable names; same keys, same order.
    result = {
        'epoch': epoch,
        'decoding_weights': decoding_weights,
        'encoding_weights': encoding_weights,
        'su_decoding_weights': su_decoding_weights,
        'su_encoding_weights': su_encoding_weights,
        'su_r2_pos': su_r2_pos,
        'su_r2_enc': su_r2_enc,
        'su_var': su_var,
        'su_mmse': su_mmse,
        'su_pi': su_pi,
        'su_fcca': su_fcca,
        'proj_dw': proj_dw,
        'proj_ew': proj_ew,
    }

    peanut_su_l.append(result)

8it [01:13,  9.20s/it]


### Single Unit Calcs with Trialization

In [96]:
# This requires trialization dimreduc and a search for best decoding parameters in the segmented setting

In [9]:
# Per-recording `start_time` handed to reach_segment_sabes, keyed by the data
# file name with its '.mat' extension stripped.
# NOTE(review): the unit of these values is not visible here - confirm
# against reach_segment_sabes.
start_times = {
    'indy_20160426_01': 0,
    'indy_20160622_01': 1700,
    'indy_20160624_03': 500,
    'indy_20160627_01': 0,
    'indy_20160630_01': 0,
    'indy_20160915_01': 0,
    'indy_20160921_01': 0,
    'indy_20160930_02': 0,
    'indy_20160930_05': 300,
    'indy_20161005_06': 0,
    'indy_20161006_02': 350,
    'indy_20161007_02': 950,
    'indy_20161011_03': 0,
    'indy_20161013_03': 0,
    'indy_20161014_04': 0,
    'indy_20161017_02': 0,
    'indy_20161024_03': 0,
    'indy_20161025_04': 0,
    'indy_20161026_03': 0,
    'indy_20161027_03': 500,
    'indy_20161206_02': 5500,
    'indy_20161207_02': 0,
    'indy_20161212_02': 0,
    'indy_20161220_02': 0,
    'indy_20170123_02': 0,
    'indy_20170124_01': 0,
    'indy_20170127_03': 0,
    'indy_20170131_02': 0,
}

In [10]:
from segmentation import reach_segment_sabes
from loaders import segment_peanut

In [33]:
# Same sort of calculation but now trialized
# Each Sabes recording is cut into reach segments by reach_segment_sabes, and
# the 5 contiguous (unshuffled) folds are taken over segments rather than over
# time bins. For every data file this fits population and per-neuron
# decoders/encoders, per-neuron variance / MMSE / PI / FCCA scores, and
# decoders/encoders on DCA/KCA/LQGCA/PCA projections. Fold results are
# averaged and one dict per data file is appended to su_trialized_l.
su_trialized_l = []
decoder_params = {'trainlag': 4, 'testlag': 4, 'decoding_window': 3}

# Dedicated name so that `data_path` (needed by the dataframe-loading cells)
# is not clobbered when this cell runs.
sabes_data_path = '/mnt/Secondary/data/sabes'
data_files = np.unique(sabes_df['data_file'].values)

n_folds = 5
dims = np.arange(2, 11)
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

# Preprocessing settings are read from the first row of sabes_df.
# NOTE(review): assumes every row of sabes_df shares these settings - confirm.
loader_row = sabes_df.iloc[0]

for data_file in tqdm(data_files):
    dat = load_sabes('%s/%s' % (sabes_data_path, data_file), bin_width=loader_row["bin_width"],
                     filter_fn=loader_row['filter_fn'], filter_kwargs=loader_row['filter_kwargs'],
                     boxcox=loader_row['boxcox'], spike_threshold=loader_row['spike_threshold'], segment=False)

    dat_segment = reach_segment_sabes(dat, start_time=start_times[data_file.split('.mat')[0]])

    # Exclude exceedingly short transitions
    # NOTE(review): 11 is presumably chosen so that every segment is long
    # enough for the T=10 cross-covariance estimate below - confirm.
    T = 11
    t = np.array([t_[1] - t_[0] for t_ in dat_segment['transition_times']])
    valid_transitions = np.flatnonzero(t >= T)

    # Segments have different lengths, so they are stored in 1-D object arrays
    # (one 2-D array per segment). Building them explicitly avoids the
    # ValueError that np.array() raises for ragged input on NumPy >= 1.24 and
    # keeps fancy indexing by fold indices working.
    X = np.empty(valid_transitions.size, dtype=object)
    Z = np.empty(valid_transitions.size, dtype=object)
    for seg_idx, idx in enumerate(valid_transitions):
        bounds = dat_segment['transition_times'][idx]
        X[seg_idx] = dat['spike_rates'][0, bounds[0]:bounds[1]]
        Z[seg_idx] = dat['behavior'][bounds[0]:bounds[1]]

    n_neurons = X[0].shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold accumulators (averaged across folds after the fold loop)
    decoding_weights = []
    encoding_weights = []
    su_decoding_weights = []
    su_encoding_weights = []
    su_r2_pos = []
    su_r2_vel = []
    su_r2_enc = []

    # Single unit statistics, indexed [fold, neuron]
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding weights after projection, indexed [fold, method, dim, neuron]
    proj_shape = (n_folds, len(dimreduc_methods), dims.size, n_neurons)
    proj_dw_pos = np.zeros(proj_shape)
    proj_dw_vel = np.zeros(proj_shape)
    proj_ew = np.zeros(proj_shape)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):

        ztrain = list(Z[train_idxs])
        ztest = list(Z[test_idxs])

        # Population level decoding/encoding - use the coefficients of the linear fit
        xtrain = list(X[train_idxs])
        xtest = list(X[test_idxs])

        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        _, _, _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params)

        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Fit every neuron on its own
        r2_pos_decoding, r2_vel_decoding, r2_encoding = [], [], []
        su_dw, su_ew = [], []
        for neu_idx in range(n_neurons):

            xtrain_su = [x[:, neu_idx][:, np.newaxis] for x in xtrain]
            xtest_su = [x[:, neu_idx][:, np.newaxis] for x in xtest]

            # Decoding
            r2_pos, r2_vel, _, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            r2_pos_decoding.append(r2_pos)
            r2_vel_decoding.append(r2_vel)
            su_dw.append(dr.coef_)

            # Encoding
            r2_encoding_, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            r2_encoding.append(r2_encoding_)
            su_ew.append(er.coef_)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))

        # Store the scores of all neurons. The previous code stored `r2_pos`
        # and `r2_vel`, i.e. only the values left over from the last neuron.
        su_r2_pos.append(np.array(r2_pos_decoding))
        su_r2_vel.append(np.array(r2_vel_decoding))
        su_r2_enc.append(np.array(r2_encoding))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        for neu_idx in range(n_neurons):

            # Variance of this neuron over all training segments, concatenated
            su_var[fold_idx, neu_idx] = np.var(np.vstack([x[:, neu_idx][:, np.newaxis] for x in xtrain]))

            # One-hot projection that selects a single neuron
            proj = np.zeros((ccm.shape[-1], 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            # PI is computed from this neuron's own autocovariance sequence (lags 0..8)
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(torch.unsqueeze(torch.unsqueeze(ccm[0:9, neu_idx ,neu_idx], 1), 2))
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Calculate decoding weights based on projection of the data first
        # NOTE(review): the projections are looked up in sabes_df, the same
        # dataframe used for the non-trialized analysis; its fold_idx may not
        # correspond to these segment-level folds. A trialized dimreduc run is
        # presumably what should be used here - confirm.
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(sabes_df, data_file=data_file, fold_idx=fold_idx, dim=d, dimreduc_method=dimreduc_method)
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components. The slice used to be
                    # hard-coded to 0:2, which made every PCA entry along the
                    # `dims` axis identical to the 2-dimensional result.
                    # NOTE(review): assumes the PCA coef has at least d columns - confirm.
                    V = V[:, 0:d]

                xtrain_proj = [x @ V for x in xtrain]
                xtest_proj = [x @ V for x in xtest]

                _, _, _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)

                # Rows 0:2 / 2:4 of the decoder coefficients are stored as the
                # position / velocity weights respectively.
                proj_dw_pos[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_[0:2, :].T, d2=decoder_params['decoding_window'])
                proj_dw_vel[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, decodingregressor.coef_[2:4, :].T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    decoding_weights = np.mean(np.array(decoding_weights), axis=0)
    encoding_weights = np.mean(np.array(encoding_weights), axis=0)
    su_decoding_weights = np.mean(np.array(su_decoding_weights), axis=0)
    su_encoding_weights = np.mean(np.array(su_encoding_weights), axis=0)

    su_r2_pos = np.mean(np.array(su_r2_pos), axis=0)
    su_r2_vel = np.mean(np.array(su_r2_vel), axis=0)
    su_r2_enc = np.mean(np.array(su_r2_enc), axis=0)

    su_var = np.mean(su_var, axis=0)
    su_mmse = np.mean(su_mmse, axis=0)
    su_pi = np.mean(su_pi, axis=0)
    su_fcca = np.mean(su_fcca, axis=0)

    proj_dw_pos = np.mean(proj_dw_pos, axis=0)
    proj_dw_vel = np.mean(proj_dw_vel, axis=0)
    proj_ew = np.mean(proj_ew, axis=0)

    # Explicit dict instead of eval() on variable names; same keys, same order.
    result = {
        'data_file': data_file,
        'decoding_weights': decoding_weights,
        'encoding_weights': encoding_weights,
        'su_decoding_weights': su_decoding_weights,
        'su_encoding_weights': su_encoding_weights,
        'su_r2_pos': su_r2_pos,
        'su_r2_vel': su_r2_vel,
        'su_r2_enc': su_r2_enc,
        'su_var': su_var,
        'su_mmse': su_mmse,
        'su_pi': su_pi,
        'su_fcca': su_fcca,
        'proj_dw_pos': proj_dw_pos,
        'proj_dw_vel': proj_dw_vel,
        'proj_ew': proj_ew,
    }

    su_trialized_l.append(result)

0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:15<00:00, 15.78s/it]


> [0;32m/tmp/ipykernel_21463/3381761480.py[0m(28)[0;36m<module>[0;34m()[0m
[0;32m     26 [0;31m                  for idx in valid_transitions])
[0m[0;32m     27 [0;31m    [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m    [0mr[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0mr[0m[0;34m[[0m[0;34m'data_file'[0m[0;34m][0m [0;34m=[0m [0mdata_file[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m[0;34m[0m[0m
[0m
(35240, 2)


0it [01:34, ?it/s]


BdbQuit: 