# Sensitivity analysis

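This notebook compares how accurately kinematics can be decoded from MARBLE representations computed with different hyperparameter choices: the number of PCA components (5, 7, 10) and the width of the Gaussian smoothing filter (25, 50, 100 ms).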

In [None]:
!pip install statannotations ipympl

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
from statannotations.Annotator import Annotator
from sklearn.model_selection import KFold
from macaque_reaching_helpers import *
from tqdm import tqdm

# Load kinematics data

In [None]:
!mkdir data
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969885 -O data/kinematics.pkl
    
with open('data/kinematics.pkl', 'rb') as handle:
    data = pickle.load(handle)

# Load MARBLE embeddings

In [None]:
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10209904 -O data/marble_embeddings_out20_pca5_25ms.pkl

with open('./data/marble_embeddings_out20_pca5_25ms.pkl', 'rb') as handle:
    _, marble_embeddings_5_25, _, _, trial_ids, _  = pickle.load(handle)

!wget -nc https://dataverse.harvard.edu/api/access/datafile/10209903 -O data/marble_embeddings_out20_pca5_50ms.pkl

with open('./data/marble_embeddings_out20_pca5_50ms.pkl', 'rb') as handle:
    _, marble_embeddings_5_50, _, _, trial_ids, _  = pickle.load(handle)

!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062022 -O data/marble_embeddings_out20_pca5_100ms.pkl

with open('./data/marble_embeddings_out20_pca5_100ms.pkl', 'rb') as handle:
    _, marble_embeddings_5, _, _, trial_ids, _  = pickle.load(handle)

!wget -nc https://dataverse.harvard.edu/api/access/datafile/10209907 -O data/marble_embeddings_out20_pca7_100ms.pkl

with open('./data/marble_embeddings_out20_pca7_100ms.pkl', 'rb') as handle:
    _, marble_embeddings_7, _, _, _, _  = pickle.load(handle)

!wget -nc https://dataverse.harvard.edu/api/access/datafile/10209905 -O data/marble_embeddings_out20_pca10_100ms.pkl

with open('./data/marble_embeddings_out20_pca10_100ms.pkl', 'rb') as handle:
    _, marble_embeddings_10, _, _, _, _  = pickle.load(handle)

# define conditions of movement
conditions=['DownLeft','Left','UpLeft','Up','UpRight','Right','DownRight']  

# Match embeddings to kinematics

In [None]:
# Attach each MARBLE embedding variant to its trial in `data`, using the saved trial ids
# (one per embedding row) to match embedding samples to trials.

# Embeddings were only computed for the first 20 sessions to save compute time.
N_SESSIONS = 20
days = list(range(N_SESSIONS))

# data[d][t] key -> per-session embedding arrays (rows are samples, selected by trial id)
embedding_keys = {
    'marble_emb_5': marble_embeddings_5,
    'marble_emb_7': marble_embeddings_7,
    'marble_emb_10': marble_embeddings_10,
    'marble_emb_5_25': marble_embeddings_5_25,
    'marble_emb_5_50': marble_embeddings_5_50,
}

for d in days:
    session_trial_ids = trial_ids[d]
    for t in np.unique(session_trial_ids):
        trial = data[d][t]
        # Trim from an untouched copy of the kinematics, so re-running this cell does
        # not drop one more sample each time.
        trial.setdefault('kinematics_raw', trial['kinematics'])
        # Drop the last kinematic sample, presumably so the kinematics line up with this
        # trial's embedding samples (TODO confirm the reason).
        trial['kinematics'] = trial['kinematics_raw'][:, :-1]
        in_trial = session_trial_ids == t
        for key, embeddings in embedding_keys.items():
            # Transpose so samples run along the second axis, like the kinematics.
            trial[key] = embeddings[d][in_trial, :].T

# Decode across all sessions

Let's now loop over every session, decode the kinematics from each MARBLE embedding variant using 5-fold cross-validation over trials, and compute quantitative comparisons with the ground-truth kinematics.

In [None]:
SEED = 0  # fixed seed so the cross-validation folds, and hence the scores, are reproducible
DT = 20   # passed to decode_kinematics as `dt`; presumably the sample spacing in ms (TODO confirm)

kf = KFold(n_splits=5, shuffle=True, random_state=SEED)  # 5-fold split over trials

r2_marble_vel_5 = []; r2_marble_pos_5 = []
r2_marble_vel_5_25 = []; r2_marble_pos_5_25 = []
r2_marble_vel_5_50 = []; r2_marble_pos_5_50 = []
r2_marble_vel_7 = []; r2_marble_pos_7 = []
r2_marble_vel_10 = []; r2_marble_pos_10 = []

# Variant suffix -> (position r-squared list, velocity r-squared list).
# Suffix: <number of PCA components>[_<Gaussian filter width in ms>]; no suffix = 100 ms filter.
r2_results = {
    '5': (r2_marble_pos_5, r2_marble_vel_5),
    '5_25': (r2_marble_pos_5_25, r2_marble_vel_5_25),
    '5_50': (r2_marble_pos_5_50, r2_marble_vel_5_50),
    '7': (r2_marble_pos_7, r2_marble_vel_7),
    '10': (r2_marble_pos_10, r2_marble_vel_10),
}

for d in days:
    unique_trial_ids = np.unique(trial_ids[d])

    # All variants share the same folds within a session, so their scores are paired.
    for train_index, test_index in kf.split(unique_trial_ids):
        train_ids = unique_trial_ids[train_index]
        test_ids = unique_trial_ids[test_index]
        for variant in r2_results:
            # Fit the linear decoder on the training trials only, then decode the held-out trials.
            Lw = train_OLE(data[d], train_ids, representation=f'marble_emb_{variant}')
            for tr in test_ids:
                data[d][tr][f'marble_decoded_{variant}'] = decode_kinematics(
                    data[d][tr], Lw, dt=DT, representation=f'marble_emb_{variant}'
                )

    # Per-session r-squared of decoded vs. true kinematics, for position and velocity.
    for variant, (r2_pos_list, r2_vel_list) in r2_results.items():
        r2_pos, r2_vel = correlation(data[d], unique_trial_ids, representation=f'marble_decoded_{variant}')
        r2_pos_list.append(r2_pos)
        r2_vel_list.append(r2_vel)

How does velocity decoding accuracy compare across the hyperparameter settings?

In [None]:
# One column per hyperparameter setting and one row per session, melted to long format.
vel_results = pd.DataFrame({
    'marble_5': r2_marble_vel_5,
    'marble_7': r2_marble_vel_7,
    'marble_10': r2_marble_vel_10,
    'marble_5_25': r2_marble_vel_5_25,
    'marble_5_50': r2_marble_vel_5_50,
}).melt(var_name='model', value_name='accuracy')

fig, ax = plt.subplots(figsize=(4, 7))
sns.despine(ax=ax, bottom=True, left=True)

# One dot per session.
sns.stripplot(
    data=vel_results, x='model', y='accuracy',
    dodge=True, alpha=.5, zorder=1, ax=ax,
)

# Diamond marker: mean across sessions (errorbar=None hides the confidence interval).
sns.pointplot(
    data=vel_results, x='model', y='accuracy',
    join=False, dodge=.8 - .8 / 3, palette='dark',
    markers='d', scale=.75, errorbar=None, ax=ax,
)
ax.set(xlabel='hyperparameter setting', ylabel='velocity decoding $R^2$')
ax.tick_params(axis='x', rotation=45)

# Compare each variant against the reference setting (5 PCs, 100 ms filter).
pairs = [('marble_5', 'marble_7'), ('marble_5', 'marble_10'),
         ('marble_5', 'marble_5_25'), ('marble_5', 'marble_5_50')]

annotator = Annotator(ax, pairs, data=vel_results, x='model', y='accuracy')
annotator.configure(test='Wilcoxon', text_format='star', loc='outside')
annotator.apply_and_annotate();

Can a classifier trained on the true kinematics predict the movement condition from the decoded kinematics? This tells us how accurately the position trajectories were decoded.

In [None]:
marble_model_acc_5 = []
marble_model_acc_5_25 = []
marble_model_acc_5_50 = []
marble_model_acc_7 = []
marble_model_acc_10 = []

# Plot label -> (decoded representation key, per-session score list).
# Labels match those used in the velocity figure.
classifier_scores = {
    'marble_5': ('marble_decoded_5', marble_model_acc_5),
    'marble_5_25': ('marble_decoded_5_25', marble_model_acc_5_25),
    'marble_5_50': ('marble_decoded_5_50', marble_model_acc_5_50),
    'marble_7': ('marble_decoded_7', marble_model_acc_7),
    'marble_10': ('marble_decoded_10', marble_model_acc_10),
}

for d in days:
    unique_trial_ids = np.unique(trial_ids[d])

    # Fit the classifier on the true kinematics, then score it on each decoded version.
    clf = fit_classifier(data[d], conditions, unique_trial_ids, representation='kinematics')
    for representation, scores in classifier_scores.values():
        scores.append(
            transform_classifier(clf, data[d], conditions, unique_trial_ids, representation=representation)
        )

# Column order matches the velocity figure.
column_order = ['marble_5', 'marble_7', 'marble_10', 'marble_5_25', 'marble_5_50']
results = (
    pd.DataFrame({label: scores for label, (_, scores) in classifier_scores.items()})[column_order]
    .melt(var_name='model', value_name='accuracy')
)

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
sns.despine(ax=ax, bottom=True, left=True)

# One dot per session.
sns.stripplot(
    data=results, x='model', y='accuracy',
    dodge=True, alpha=.5, zorder=1, ax=ax,
)

# Diamond marker: mean across sessions (errorbar=None hides the confidence interval).
sns.pointplot(
    data=results, x='model', y='accuracy',
    join=False, dodge=.8 - .8 / 3, palette='dark',
    markers='d', scale=.75, errorbar=None, ax=ax,
)
ax.set(xlabel='hyperparameter setting', ylabel='reach-condition classification accuracy', ylim=(0, 1))
ax.tick_params(axis='x', rotation=45);