In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pdb
import numpy as np
import matplotlib.pyplot as plt
import scipy 
import time
import torch
import glob
import pickle
import pandas as pd
from tqdm import tqdm
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

In [3]:
import sys

In [4]:
sys.path.append('/home/akumar/nse/neural_control')
from utils import apply_df_filters, calc_loadings, calc_cascaded_loadings
from loaders import load_sabes, load_peanut, load_cv
from decoders import lr_decoder, lr_encoder
from subspaces import SubspaceIdentification, IteratedStableEstimator, estimate_autocorrelation

In [5]:
from dca.cov_util import calc_cross_cov_mats_from_data, calc_pi_from_cross_cov_mats
from dca_research.kca import calc_mmse_from_cross_cov_mats
from dca_research.lqg import build_loss as build_lqg_loss

In [6]:
# Directory holding the pickled analysis dataframes loaded in the following cells
data_path = "/mnt/sdb1/nc_data/"

In [7]:
# Sabes dimreduc/decoding results (locally generated pickle; only load trusted files)
with open(data_path + '/sabes_decoding_df.dat', 'rb') as df_file:
    sabes_df = pickle.load(df_file)

In [8]:
# Peanut dimreduc results (locally generated pickle; only load trusted files)
with open(data_path + '/peanut_dimreduc_df.dat', 'rb') as df_file:
    peanut_df = pickle.load(df_file)

In [9]:
# CV dimreduc results (locally generated pickle; only load trusted files)
with open(data_path + '/cv_dimreduc_df.dat', 'rb') as df_file:
    cv_df = pickle.load(df_file)

In [10]:
# Loco dimreduc results (locally generated pickle; only load trusted files)
with open(data_path + '/loco_dimreduc_df.dat', 'rb') as df_file:
    loco_df = pickle.load(df_file)

### Single-unit calculations (no trialization)

In [9]:
# Inspect the first row to see the preprocessing/decoding settings stored alongside each fit
sabes_df.iloc[0, :]

dim                                                               10
fold_idx                                                           0
train_idxs         [1983, 1984, 1985, 1986, 1987, 1988, 1989, 199...
test_idxs          [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...
dimreduc_method                                                  KCA
dimreduc_args        {'T': 3, 'causal_weights': (1, 1), 'n_init': 5}
coef               [[0.04434436239872305, -0.1047764350198416, -0...
score                          tensor(141.1882, dtype=torch.float64)
bin_width                                                         50
filter_fn                                                       none
filter_kwargs                                                     {}
boxcox                                                           0.5
spike_threshold                                                  100
dim_vals           [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...
n_folds                           

In [11]:
# Single-unit and population encoding/decoding statistics for each Sabes session, computed on
# contiguous (unshuffled) K-fold splits of the full recording (no trialization).
sabes_su_l = []
decoder_params = {'trainlag': 4, 'testlag': 4, 'decoding_window': 3}
n_folds = 5
dims = np.arange(2, 11)                              # projection dimensions evaluated per method
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

# Raw-session directory. A dedicated name avoids re-binding the global `data_path`
# (the directory of the pickled dataframes), which other cells read.
sabes_data_path = '/mnt/sdb1/nc_data/sabes'
data_files = np.unique(sabes_df['data_file'].values)
# Preprocessing settings are taken from the first dataframe row
loader_row = sabes_df.iloc[0]

for data_file in tqdm(data_files):
    dat = load_sabes('%s/%s' % (sabes_data_path, data_file), bin_width=loader_row['bin_width'],
                     filter_fn=loader_row['filter_fn'], filter_kwargs=loader_row['filter_kwargs'],
                     boxcox=loader_row['boxcox'], spike_threshold=loader_row['spike_threshold'])

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold results; averaged across folds after the fold loop
    decoding_weights, encoding_weights = [], []
    su_decoding_weights, su_encoding_weights = [], []
    su_r2_pos, su_r2_vel, su_r2_enc = [], [], []

    # Single-unit statistics: (fold, neuron)
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding loadings after projection: (fold, method, dim, neuron)
    proj_dw_pos = np.zeros((n_folds, len(dimreduc_methods), dims.size, n_neurons))
    proj_dw_vel = np.zeros_like(proj_dw_pos)
    proj_ew = np.zeros_like(proj_dw_pos)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):
        ztrain, ztest = Z[train_idxs, :], Z[test_idxs, :]
        xtrain, xtest = X[train_idxs, :], X[test_idxs, :]

        # Training-data cross-covariances, used by the single-unit MMSE/PI/FCCA scores
        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        # Population-level decoding/encoding: keep the coefficients of the linear fit
        _, _, _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Single-unit decoding/encoding: fit each neuron on its own
        su_dw, su_ew = [], []
        sur2pos, sur2vel, sur2enc = [], [], []
        for neu_idx in range(n_neurons):
            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            r2_pos, r2_vel, _, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_dw.append(dr.coef_)
            sur2pos.append(r2_pos)
            sur2vel.append(r2_vel)

            r2_enc, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_ew.append(er.coef_)
            # One score per neuron (previously the whole running list was appended)
            sur2enc.append(r2_enc)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))
        su_r2_pos.append(np.array(sur2pos))
        su_r2_vel.append(np.array(sur2vel))
        su_r2_enc.append(np.array(sur2enc))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        # Single-unit statistics, computed on the training split only
        for neu_idx in range(n_neurons):
            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])
            # Coordinate projection that selects neuron `neu_idx`
            proj = np.zeros((n_neurons, 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(ccm[0:9, neu_idx, neu_idx][:, None, None])
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Decoding/encoding weights after first projecting onto each dimreduc subspace
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(sabes_df, data_file=data_file, fold_idx=fold_idx, dim=d,
                                       dimreduc_method=dimreduc_method)
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components so PCA is compared at the same dimension
                    # as the other methods (previously fixed at 2)
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, _, _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                # Coefficient rows 0:2 are treated as position and 2:4 as velocity outputs
                proj_dw_pos[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_[0:2, :].T, d2=decoder_params['decoding_window'])
                proj_dw_vel[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_[2:4, :].T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    result = {
        'data_file': data_file,
        'decoding_weights': np.mean(np.array(decoding_weights), axis=0),
        'encoding_weights': np.mean(np.array(encoding_weights), axis=0),
        'su_decoding_weights': np.mean(np.array(su_decoding_weights), axis=0),
        'su_encoding_weights': np.mean(np.array(su_encoding_weights), axis=0),
        'su_r2_pos': np.mean(np.array(su_r2_pos), axis=0),
        'su_r2_vel': np.mean(np.array(su_r2_vel), axis=0),
        'su_r2_enc': np.mean(np.array(su_r2_enc), axis=0),
        'su_var': np.mean(su_var, axis=0),
        'su_mmse': np.mean(su_mmse, axis=0),
        'su_pi': np.mean(su_pi, axis=0),
        'su_fcca': np.mean(su_fcca, axis=0),
        'proj_dw_pos': np.mean(proj_dw_pos, axis=0),
        'proj_dw_vel': np.mean(proj_dw_vel, axis=0),
        'proj_ew': np.mean(proj_ew, axis=0),
        'decoder_params': decoder_params,
    }
    sabes_su_l.append(result)

0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:25<00:00, 25.11s/it]
1it [01:09, 69.11s/it]

Processing spikes


100%|██████████| 1/1 [00:34<00:00, 34.96s/it]
2it [02:41, 82.52s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.11s/it]
3it [03:09, 57.71s/it]

Processing spikes


100%|██████████| 1/1 [00:48<00:00, 48.27s/it]
4it [05:18, 86.01s/it]

Processing spikes


100%|██████████| 1/1 [00:20<00:00, 20.25s/it]
5it [05:59, 69.81s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.43s/it]
6it [06:18, 52.52s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.10s/it]
7it [06:36, 41.18s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.70s/it]
8it [07:01, 36.12s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.89s/it]
9it [07:22, 31.43s/it]

Processing spikes


100%|██████████| 1/1 [00:05<00:00,  5.34s/it]
10it [07:40, 27.08s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.12s/it]
11it [08:12, 28.74s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.97s/it]
12it [08:37, 27.47s/it]

Processing spikes


100%|██████████| 1/1 [00:09<00:00,  9.87s/it]
13it [09:10, 29.32s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.49s/it]
14it [09:33, 27.37s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.26s/it]
15it [09:57, 26.36s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.42s/it]
16it [10:22, 25.82s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.77s/it]
17it [10:46, 25.31s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.15s/it]
18it [11:12, 25.51s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.40s/it]
19it [11:36, 25.21s/it]

Processing spikes


100%|██████████| 1/1 [00:08<00:00,  8.40s/it]
20it [12:02, 25.40s/it]

Processing spikes


100%|██████████| 1/1 [00:10<00:00, 10.67s/it]
21it [12:42, 29.58s/it]

Processing spikes


100%|██████████| 1/1 [00:06<00:00,  6.66s/it]
22it [13:05, 27.81s/it]

Processing spikes


100%|██████████| 1/1 [00:07<00:00,  7.97s/it]
23it [13:31, 27.05s/it]

Processing spikes


100%|██████████| 1/1 [00:08<00:00,  8.20s/it]
24it [13:55, 26.28s/it]

Processing spikes


100%|██████████| 1/1 [00:08<00:00,  8.72s/it]
25it [14:28, 28.25s/it]

Processing spikes


100%|██████████| 1/1 [00:08<00:00,  8.55s/it]
26it [14:55, 27.92s/it]

Processing spikes


100%|██████████| 1/1 [00:10<00:00, 10.58s/it]
27it [15:31, 30.21s/it]

Processing spikes


100%|██████████| 1/1 [00:11<00:00, 11.82s/it]
28it [16:03, 34.42s/it]


In [12]:
# One row per Sabes session, one column per stored result
sabes_su_df = pd.DataFrame(data=sabes_su_l)

In [13]:
# Persist the list of per-session result dicts (note: the list itself, not the DataFrame,
# despite the file name). The literal path is used because `data_path` may have been
# re-bound by an earlier cell.
with open('/mnt/sdb1/nc_data/sabes_su_df.dat', 'wb') as out_file:
    pickle.dump(sabes_su_l, out_file)

In [91]:
# Loco (S1) sessions: same single-unit analysis as the Sabes cell, on contiguous K-fold splits.
loco_su_l = []
decoder_params = {'trainlag': 4, 'testlag': 4, 'decoding_window': 3}
# Copied from the submit file
loader_params = {'bin_width':50, 'filter_fn':'none', 'filter_kwargs':{}, 'boxcox':0.5, 'spike_threshold':100, 'region': 'S1'}
n_folds = 5
dims = np.arange(2, 11)                              # projection dimensions evaluated per method
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

# `data_file` entries are passed to the loader as stored in loco_df (no directory is prepended).
data_files = np.unique(loco_df['data_file'].values)

for data_file in tqdm(data_files):
    dat = load_sabes(data_file, bin_width=loader_params['bin_width'],
                     filter_fn=loader_params['filter_fn'], filter_kwargs=loader_params['filter_kwargs'],
                     boxcox=loader_params['boxcox'], spike_threshold=loader_params['spike_threshold'],
                     region=loader_params['region'])

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold results; averaged across folds after the fold loop
    decoding_weights, encoding_weights = [], []
    su_decoding_weights, su_encoding_weights = [], []
    su_r2_pos, su_r2_vel, su_r2_enc = [], [], []

    # Single-unit statistics: (fold, neuron)
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding loadings after projection: (fold, method, dim, neuron)
    proj_dw_pos = np.zeros((n_folds, len(dimreduc_methods), dims.size, n_neurons))
    proj_dw_vel = np.zeros_like(proj_dw_pos)
    proj_ew = np.zeros_like(proj_dw_pos)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):
        ztrain, ztest = Z[train_idxs, :], Z[test_idxs, :]
        xtrain, xtest = X[train_idxs, :], X[test_idxs, :]

        # Training-data cross-covariances, used by the single-unit MMSE/PI/FCCA scores
        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        # Population-level decoding/encoding: keep the coefficients of the linear fit
        _, _, _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Single-unit decoding/encoding: fit each neuron on its own
        su_dw, su_ew = [], []
        sur2pos, sur2vel, sur2enc = [], [], []
        for neu_idx in range(n_neurons):
            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            r2_pos, r2_vel, _, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_dw.append(dr.coef_)
            sur2pos.append(r2_pos)
            sur2vel.append(r2_vel)

            r2_enc, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_ew.append(er.coef_)
            # One score per neuron (previously the whole running list was appended)
            sur2enc.append(r2_enc)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))
        su_r2_pos.append(np.array(sur2pos))
        su_r2_vel.append(np.array(sur2vel))
        su_r2_enc.append(np.array(sur2enc))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        # Single-unit statistics, computed on the training split only
        for neu_idx in range(n_neurons):
            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])
            # Coordinate projection that selects neuron `neu_idx`
            proj = np.zeros((n_neurons, 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(ccm[0:9, neu_idx, neu_idx][:, None, None])
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Decoding/encoding weights after first projecting onto each dimreduc subspace
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(loco_df, data_file=data_file, fold_idx=fold_idx, dim=d,
                                       dimreduc_method=dimreduc_method, region=loader_params['region'])
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components so PCA is compared at the same dimension
                    # as the other methods (previously fixed at 2)
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, _, _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                # Coefficient rows 0:2 are treated as position and 2:4 as velocity outputs
                proj_dw_pos[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_[0:2, :].T, d2=decoder_params['decoding_window'])
                proj_dw_vel[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_[2:4, :].T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    result = {
        'data_file': data_file,
        'decoding_weights': np.mean(np.array(decoding_weights), axis=0),
        'encoding_weights': np.mean(np.array(encoding_weights), axis=0),
        'su_decoding_weights': np.mean(np.array(su_decoding_weights), axis=0),
        'su_encoding_weights': np.mean(np.array(su_encoding_weights), axis=0),
        'su_r2_pos': np.mean(np.array(su_r2_pos), axis=0),
        'su_r2_vel': np.mean(np.array(su_r2_vel), axis=0),
        'su_r2_enc': np.mean(np.array(su_r2_enc), axis=0),
        'su_var': np.mean(su_var, axis=0),
        'su_mmse': np.mean(su_mmse, axis=0),
        'su_pi': np.mean(su_pi, axis=0),
        'su_fcca': np.mean(su_fcca, axis=0),
        'proj_dw_pos': np.mean(proj_dw_pos, axis=0),
        'proj_dw_vel': np.mean(proj_dw_vel, axis=0),
        'proj_ew': np.mean(proj_ew, axis=0),
        'decoder_params': decoder_params,
    }
    loco_su_l.append(result)

0it [00:00, ?it/s]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.61s/it]
1it [00:38, 39.00s/it]

Processing spikes


100%|██████████| 1/1 [00:22<00:00, 22.38s/it]
2it [01:32, 47.38s/it]

Processing spikes


100%|██████████| 1/1 [00:29<00:00, 29.43s/it]
3it [02:38, 55.89s/it]

Processing spikes


100%|██████████| 1/1 [00:10<00:00, 10.44s/it]
4it [03:05, 44.40s/it]

Processing spikes


100%|██████████| 1/1 [00:24<00:00, 24.97s/it]
5it [04:01, 48.73s/it]

Processing spikes


100%|██████████| 1/1 [00:17<00:00, 17.29s/it]
6it [04:43, 46.57s/it]

Processing spikes


100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
7it [05:30, 46.67s/it]

Processing spikes


100%|██████████| 1/1 [00:16<00:00, 16.72s/it]
8it [06:13, 45.36s/it]

Processing spikes


100%|██████████| 1/1 [00:10<00:00, 10.64s/it]
9it [06:40, 39.64s/it]

Processing spikes


100%|██████████| 1/1 [00:21<00:00, 21.42s/it]
10it [07:31, 45.14s/it]


In [94]:
# Peanut: same single-unit analysis per epoch, with position-only decoding/encoding.
peanut_su_l = []
decoder_params = {'trainlag': 0, 'testlag': 0, 'decoding_window': 6}
# Decode/encode position only (no velocity or acceleration features)
behavior_kwargs = {'include_velocity': False, 'include_acc': False}
n_folds = 5
dims = np.arange(2, 11)                              # projection dimensions evaluated per method
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

fpath = '/mnt/Secondary/data/peanut/data_dict_peanut_day14.obj'
epochs = np.unique(peanut_df['epoch'].values)
# Preprocessing settings are taken from the first dataframe row
loader_row = peanut_df.iloc[0]

for epoch in tqdm(epochs):
    dat = load_peanut(fpath, epoch, speed_threshold=loader_row['speed_threshold'],
                      bin_width=loader_row['bin_width'], filter_fn='none', filter_kwargs={},
                      spike_threshold=loader_row['spike_threshold'], boxcox=loader_row['boxcox'])

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold results; averaged across folds after the fold loop
    decoding_weights, encoding_weights = [], []
    su_decoding_weights, su_encoding_weights = [], []
    su_r2_pos, su_r2_enc = [], []

    # Single-unit statistics: (fold, neuron)
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding loadings after projection: (fold, method, dim, neuron)
    proj_dw = np.zeros((n_folds, len(dimreduc_methods), dims.size, n_neurons))
    proj_ew = np.zeros_like(proj_dw)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):
        ztrain, ztest = Z[train_idxs, :], Z[test_idxs, :]
        xtrain, xtest = X[train_idxs, :], X[test_idxs, :]

        # Training-data cross-covariances, used by the single-unit MMSE/PI/FCCA scores
        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        # Population-level decoding/encoding: keep the coefficients of the linear fit
        _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params, **behavior_kwargs)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params, **behavior_kwargs)
        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Single-unit decoding/encoding: fit each neuron on its own
        su_dw, su_ew = [], []
        sur2pos, sur2enc = [], []
        for neu_idx in range(n_neurons):
            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            r2_pos, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params, **behavior_kwargs)
            su_dw.append(dr.coef_)
            sur2pos.append(r2_pos)

            r2_enc, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params, **behavior_kwargs)
            su_ew.append(er.coef_)
            sur2enc.append(r2_enc)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))
        # Per-neuron scores (previously only the last neuron's position R^2 was stored)
        su_r2_pos.append(np.array(sur2pos))
        su_r2_enc.append(np.array(sur2enc))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        # Single-unit statistics, computed on the training split only
        for neu_idx in range(n_neurons):
            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])
            # Coordinate projection that selects neuron `neu_idx`
            proj = np.zeros((n_neurons, 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(ccm[0:9, neu_idx, neu_idx][:, None, None])
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Decoding/encoding weights after first projecting onto each dimreduc subspace
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(peanut_df, epoch=epoch, fold_idx=fold_idx, dim=d,
                                       dimreduc_method=dimreduc_method)
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components so PCA is compared at the same dimension
                    # as the other methods (previously fixed at 2)
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain,
                                                  **decoder_params, **behavior_kwargs)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain,
                                                  **decoder_params, **behavior_kwargs)

                proj_dw[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_.T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    result = {
        'epoch': epoch,
        'decoding_weights': np.mean(np.array(decoding_weights), axis=0),
        'encoding_weights': np.mean(np.array(encoding_weights), axis=0),
        'su_decoding_weights': np.mean(np.array(su_decoding_weights), axis=0),
        'su_encoding_weights': np.mean(np.array(su_encoding_weights), axis=0),
        'su_r2_pos': np.mean(np.array(su_r2_pos), axis=0),
        'su_r2_enc': np.mean(np.array(su_r2_enc), axis=0),
        'su_var': np.mean(su_var, axis=0),
        'su_mmse': np.mean(su_mmse, axis=0),
        'su_pi': np.mean(su_pi, axis=0),
        'su_fcca': np.mean(su_fcca, axis=0),
        'proj_dw': np.mean(proj_dw, axis=0),
        'proj_ew': np.mean(proj_ew, axis=0),
        # Stored for consistency with the Sabes/Loco result records
        'decoder_params': decoder_params,
    }
    peanut_su_l.append(result)

8it [01:13,  9.20s/it]


#### Visualization

In [14]:
# Barycentric plotting
import plotly
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go

### Single-unit calculations with trialization

In [96]:
# TODO: this requires dimensionality reduction fit on trialized data and a search for the best decoding parameters in the segmented setting

In [65]:
# Per-session start offsets, presumably for reach_segment_sabes's `start_time` argument.
# Every session defaults to 0; only the exceptions are listed in the update below.
# NOTE(review): the unit of these offsets (bins vs. samples) is not established here — confirm.
start_times = dict.fromkeys([
    'indy_20160426_01', 'indy_20160622_01', 'indy_20160624_03', 'indy_20160627_01',
    'indy_20160630_01', 'indy_20160915_01', 'indy_20160921_01', 'indy_20160930_02',
    'indy_20160930_05', 'indy_20161005_06', 'indy_20161006_02', 'indy_20161007_02',
    'indy_20161011_03', 'indy_20161013_03', 'indy_20161014_04', 'indy_20161017_02',
    'indy_20161024_03', 'indy_20161025_04', 'indy_20161026_03', 'indy_20161027_03',
    'indy_20161206_02', 'indy_20161207_02', 'indy_20161212_02', 'indy_20161220_02',
    'indy_20170123_02', 'indy_20170124_01', 'indy_20170127_03', 'indy_20170131_02',
], 0)
# Sessions with a non-zero start (update() keeps the key order established above)
start_times.update({
    'indy_20160622_01': 1700,
    'indy_20160624_03': 500,
    'indy_20160930_05': 300,
    'indy_20161006_02': 350,
    'indy_20161007_02': 950,
    'indy_20161027_03': 500,
    'indy_20161206_02': 5500,
})

In [95]:
from segmentation import reach_segment_sabes
from loaders import segment_peanut

In [None]:
# Same single-unit analysis as the non-trialized Sabes cell, but on reach-segmented data.
# TODO: the projections below still use the non-trialized dimreduc fits in `sabes_df`;
# trialized dimreduc fits and a search over decoder parameters are still needed.
_su_trialized_l = []
decoder_params = {'trainlag': 4, 'testlag': 4, 'decoding_window': 3}
n_folds = 5
dims = np.arange(2, 11)                              # projection dimensions evaluated per method
dimreduc_methods = ['DCA', 'KCA', 'LQGCA', 'PCA']

raw_data_path = '/mnt/Secondary/data/sabes'
data_files = np.unique(sabes_df['data_file'].values)
# Preprocessing settings are taken from the first dataframe row
loader_row = sabes_df.iloc[0]

for data_file in tqdm(data_files):
    dat = load_sabes('%s/%s' % (raw_data_path, data_file), bin_width=loader_row['bin_width'],
                     filter_fn=loader_row['filter_fn'], filter_kwargs=loader_row['filter_kwargs'],
                     boxcox=loader_row['boxcox'], spike_threshold=loader_row['spike_threshold'])

    # NOTE(review): assumes data_file is '<session>.<ext>' so its stem keys `start_times`, and
    # that reach_segment_sabes returns a dict with 'spike_rates'/'behavior' — confirm both.
    session = data_file.split('.')[0]
    dat = reach_segment_sabes(dat, start_time=start_times[session])

    X = np.squeeze(dat['spike_rates'])
    Z = dat['behavior']
    # Size arrays from this session's data (previously read a stale `xtrain` from another cell)
    n_neurons = X.shape[-1]

    kfold = KFold(n_splits=n_folds, shuffle=False)

    # Per-fold results; averaged across folds after the fold loop
    decoding_weights, encoding_weights = [], []
    su_decoding_weights, su_encoding_weights = [], []
    su_r2_pos, su_r2_vel, su_r2_enc = [], [], []

    # Single-unit statistics: (fold, neuron)
    su_var = np.zeros((n_folds, n_neurons))
    su_mmse = np.zeros((n_folds, n_neurons))
    su_pi = np.zeros((n_folds, n_neurons))
    su_fcca = np.zeros((n_folds, n_neurons))

    # Decoding/encoding loadings after projection: (fold, method, dim, neuron)
    proj_dw_pos = np.zeros((n_folds, len(dimreduc_methods), dims.size, n_neurons))
    proj_dw_vel = np.zeros_like(proj_dw_pos)
    proj_ew = np.zeros_like(proj_dw_pos)

    for fold_idx, (train_idxs, test_idxs) in enumerate(kfold.split(X)):
        ztrain, ztest = Z[train_idxs, :], Z[test_idxs, :]
        xtrain, xtest = X[train_idxs, :], X[test_idxs, :]

        # Training-data cross-covariances, used by the single-unit MMSE/PI/FCCA scores
        ccm = torch.tensor(calc_cross_cov_mats_from_data(xtrain, T=10))

        # Population-level decoding/encoding: keep the coefficients of the linear fit
        _, _, _, decodingregressor = lr_decoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        _, encodingregressor = lr_encoder(xtest, xtrain, ztest, ztrain, **decoder_params)
        decoding_weights.append(decodingregressor.coef_)
        encoding_weights.append(encodingregressor.coef_)

        # Single-unit decoding/encoding: fit each neuron on its own
        su_dw, su_ew = [], []
        sur2pos, sur2vel, sur2enc = [], [], []
        for neu_idx in range(n_neurons):
            xtrain_su = X[train_idxs, neu_idx][:, np.newaxis]
            xtest_su = X[test_idxs, neu_idx][:, np.newaxis]

            r2_pos, r2_vel, _, dr = lr_decoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_dw.append(dr.coef_)
            sur2pos.append(r2_pos)
            sur2vel.append(r2_vel)

            r2_enc, er = lr_encoder(xtest_su, xtrain_su, ztest, ztrain, **decoder_params)
            su_ew.append(er.coef_)
            sur2enc.append(r2_enc)

        su_decoding_weights.append(np.array(su_dw))
        su_encoding_weights.append(np.array(su_ew))
        # Per-neuron scores (previously only the last neuron's R^2 was stored)
        su_r2_pos.append(np.array(sur2pos))
        su_r2_vel.append(np.array(sur2vel))
        su_r2_enc.append(np.array(sur2enc))

        lqg_loss = build_lqg_loss(ccm, 1, ortho_lambda=0., project_mmse=False, loss_type='trace')

        # Single-unit statistics, computed on the training split only
        for neu_idx in range(n_neurons):
            su_var[fold_idx, neu_idx] = np.var(X[train_idxs, neu_idx])
            # Coordinate projection that selects neuron `neu_idx`
            proj = np.zeros((n_neurons, 1))
            proj[neu_idx] = 1
            proj = torch.tensor(proj)
            su_mmse[fold_idx, neu_idx] = calc_mmse_from_cross_cov_mats(ccm[0:4, ...], proj=proj).numpy()
            su_pi[fold_idx, neu_idx] = calc_pi_from_cross_cov_mats(ccm[0:9, neu_idx, neu_idx][:, None, None])
            su_fcca[fold_idx, neu_idx] = lqg_loss(proj).numpy()

        # Decoding/encoding weights after first projecting onto each dimreduc subspace
        for dr_idx, dimreduc_method in enumerate(dimreduc_methods):
            for didx, d in enumerate(dims):
                df_ = apply_df_filters(sabes_df, data_file=data_file, fold_idx=fold_idx, dim=d,
                                       dimreduc_method=dimreduc_method)
                if dimreduc_method == 'LQGCA':
                    df_ = apply_df_filters(df_, dimreduc_args={'T': 3, 'loss_type': 'trace', 'n_init': 5})
                V = df_.iloc[0]['coef']
                if dimreduc_method == 'PCA':
                    # Keep the leading d components so PCA is compared at the same dimension
                    # as the other methods (previously fixed at 2)
                    V = V[:, 0:d]

                xtrain_proj = xtrain @ V
                xtest_proj = xtest @ V

                _, _, _, decodingregressor = lr_decoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                _, encodingregressor = lr_encoder(xtest_proj, xtrain_proj, ztest, ztrain, **decoder_params)
                # Split position/velocity loadings as in the non-trialized Sabes analysis
                # (coefficient rows 0:2 treated as position, 2:4 as velocity)
                proj_dw_pos[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_[0:2, :].T, d2=decoder_params['decoding_window'])
                proj_dw_vel[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(
                    V, decodingregressor.coef_[2:4, :].T, d2=decoder_params['decoding_window'])
                proj_ew[fold_idx, dr_idx, didx, :] = calc_cascaded_loadings(V, encodingregressor.coef_)

    # Average results across folds
    result = {
        'data_file': data_file,
        'decoding_weights': np.mean(np.array(decoding_weights), axis=0),
        'encoding_weights': np.mean(np.array(encoding_weights), axis=0),
        'su_decoding_weights': np.mean(np.array(su_decoding_weights), axis=0),
        'su_encoding_weights': np.mean(np.array(su_encoding_weights), axis=0),
        'su_r2_pos': np.mean(np.array(su_r2_pos), axis=0),
        'su_r2_vel': np.mean(np.array(su_r2_vel), axis=0),
        'su_r2_enc': np.mean(np.array(su_r2_enc), axis=0),
        'su_var': np.mean(su_var, axis=0),
        'su_mmse': np.mean(su_mmse, axis=0),
        'su_pi': np.mean(su_pi, axis=0),
        'su_fcca': np.mean(su_fcca, axis=0),
        'proj_dw_pos': np.mean(proj_dw_pos, axis=0),
        'proj_dw_vel': np.mean(proj_dw_vel, axis=0),
        'proj_ew': np.mean(proj_ew, axis=0),
        'decoder_params': decoder_params,
    }
    # Collect into the trialized list (previously appended to the non-trialized `sabes_su_l`)
    _su_trialized_l.append(result)