# Sensitivity analysis

In this notebook, we compare MARBLE embeddings and decoding accuracy across different hyperparameter settings (number of PCA components and kernel width) to show that the results are robust to these choices.

In [None]:
! pip install cebra statannotations elephant

import numpy as np
import matplotlib.pyplot as plt
import pickle
from statannotations.Annotator import Annotator
import pandas as pd
import seaborn as sns

import cebra

import MARBLE
from rat_utils import *

## Load the data

In [None]:
!mkdir data
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7609512 -O data/rat_data.pkl

with open('data/rat_data.pkl', 'rb') as handle:
    hippocampus_pos = pickle.load(handle)
    
hippocampus_pos = hippocampus_pos['achilles']

In [None]:
# defining train and test splits of the data
def split_data(data, test_ratio):

    split_idx = int(data['neural'].shape[0] * (1-test_ratio))
    neural_train = data['neural'][:split_idx]
    neural_test = data['neural'][split_idx:]
    label_train = data['continuous_index'][:split_idx]
    label_test = data['continuous_index'][split_idx:]
    
    return neural_train.numpy(), neural_test.numpy(), label_train.numpy(), label_test.numpy()

# Hold out the last 20% of the recording as the test set
neural_train, neural_test, label_train, label_test = split_data(hippocampus_pos, test_ratio=0.2)

## Robustness to the number of PCA components

### Fetch and evaluate pretrained MARBLE models

In [None]:
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212902 -O data/marble_achilles_pca3.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212900 -O data/marble_achilles_pca5.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212891 -O data/marble_achilles_pca10.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212896 -O data/marble_achilles_pca20.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212901 -O data/marble_achilles_pca30.pth

In [None]:
# Embed the train/test splits with the pretrained MARBLE model for each PCA dimensionality.
# NOTE: data_train, data_test and label_*_marble are module-level names; after the loop
# they hold the values of the last setting (pca30).
training, testing, labels_train, labels_test = [], [], [], []

for pca_n in [3, 5, 10, 20, 30]:
    # The `pca` object returned for the training split is passed to the test split,
    # presumably so both are projected with the same PCA basis
    data_train, label_train_marble, pca = convert_spikes_to_rates(neural_train.T, label_train, pca_n=pca_n)
    data_test, label_test_marble, _ = convert_spikes_to_rates(neural_test.T, label_test, pca=pca)

    marble_model = MARBLE.net(data_train, loadpath=f"data/marble_achilles_pca{pca_n}.pth")
    data_train, data_test = marble_model.transform(data_train), marble_model.transform(data_test)

    training.append(data_train)
    testing.append(data_test)
    labels_train.append(label_train_marble)
    labels_test.append(label_test_marble)

### Fetch and evaluate pretrained CEBRA models for comparison

In [None]:
#Cebra-time
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7609517 -O data/cebra_time_achilles_32D.pt
cebra_time_model = cebra.CEBRA.load("data/cebra_time_achilles_32D.pt")
cebra_time_train = cebra_time_model.transform(neural_train)
cebra_time_test = cebra_time_model.transform(neural_test)

#Cebra-behaviour
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7609520 -O data/cebra_behaviour_achilles_32D.pt
cebra_behaviour_model = cebra.CEBRA.load("data/cebra_behaviour_achilles_32D.pt")
cebra_behaviour_train = cebra_behaviour_model.transform(neural_train)
cebra_behaviour_test = cebra_behaviour_model.transform(neural_test)

In [None]:
# Decode from the CEBRA embeddings; these two results are the baselines in every comparison below.
# MARBLE models are decoded per hyperparameter setting inside the results cells, so no MARBLE
# decoding is done here (doing so would only use whichever loop state happened to be left over).
cebra_time_decode = decoding_pos_dir(cebra_time_train, cebra_time_test, label_train, label_test)
cebra_behaviour_decode = decoding_pos_dir(cebra_behaviour_train, cebra_behaviour_test, label_train, label_test)

In [None]:
# 3D view of each MARBLE embedding: training split on the top row, test split on the bottom row,
# one column per PCA dimensionality; points are coloured by the first label column
# (presumably position — confirm against the dataset description).
pca_dims = [3, 5, 10, 20, 30]
num_models = len(pca_dims)
num_plots_per_model = 2
fig = plt.figure(figsize=(20,12))

for i, pca_n in enumerate(pca_dims):
    data_train, data_test = training[i], testing[i]
    label_train_marble, label_test_marble = labels_train[i], labels_test[i]

    top = fig.add_subplot(num_plots_per_model, num_models, i + 1, projection='3d')
    ax = cebra.plot_embedding(ax=top, embedding=data_train.emb, embedding_labels=label_train_marble[:,0], markersize=0.2, title=f'MARBLE-train_pca{pca_n}')

    bottom = fig.add_subplot(num_plots_per_model, num_models, num_models + i + 1, projection='3d')
    ax = cebra.plot_embedding(ax=bottom, embedding=data_test.emb, embedding_labels=label_test_marble[:,0], markersize=1, title=f'MARBLE-test_pca{pca_n}')

plt.tight_layout()

In [None]:
# Compare decoding of the CEBRA baselines with MARBLE at each PCA dimensionality.
pca_dims = [3, 5, 10, 20, 30]
pca_labels = [f'pca{n}' for n in pca_dims]

# Element 4 of decoding_pos_dir's output is what gets plotted as 'accuracy'
# (one value per test sample, judging by the vstack/.T below — TODO confirm).
# NOTE(review): the CEBRA vectors are trimmed by their last element, presumably so their length
# matches the MARBLE vectors (np.vstack requires equal lengths) — confirm.
results = [cebra_time_decode[4][:-1], cebra_behaviour_decode[4][:-1]]
for i in range(len(pca_dims)):
    data_train = training[i]
    data_test = testing[i]
    label_train_marble = labels_train[i]
    label_test_marble = labels_test[i]

    marble_decode = decoding_pos_dir(data_train.emb, data_test.emb, label_train_marble, label_test_marble)
    results.append(marble_decode[4])

# One column per model, then long format (model, accuracy) for seaborn / statannotations
results = pd.DataFrame(data=np.vstack(results).T, columns=['c-time', 'c-behaviour', *pca_labels])
results = results.melt()
results.columns = ['model', 'accuracy']

f, ax = plt.subplots(figsize=(8,8))
sns.despine(bottom=True, left=True)

# x-axis in increasing PCA dimensionality (previously pca5 and pca3 were swapped)
order = ['c-time', 'c-behaviour', *pca_labels]

# Overlay 200 randomly chosen individual points (drawn with replacement); seeded for reproducibility
rng = np.random.default_rng(0)
random_sampling = rng.integers(low=0, high=results.shape[0], size=200)
sns.stripplot(
    data=results.iloc[random_sampling, :], x="model", y="accuracy", order=order, ax=ax,
    dodge=True, alpha=.5, zorder=1, color='gray',
)

# Per-model point estimate (seaborn's default estimator), without error bars
sns.pointplot(
    data=results, x="model", y="accuracy", order=order, ax=ax,
    join=False, dodge=.8 - .8 / 3, palette="dark",
    markers="d", scale=.75, errorbar=None
)

plt.ylim([0,0.4])

# Wilcoxon tests between neighbouring models on the x-axis
pairs = list(zip(order[:-1], order[1:]))
annotator = Annotator(ax, pairs, data=results, x="model", y="accuracy", order=order)
annotator.configure(test='Wilcoxon', text_format='star', loc='outside')
annotator.apply_and_annotate()
plt.tight_layout()

## Robustness against kernel width

### Fetch pretrained MARBLE models

In [None]:
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212895 -O data/marble_achilles_kw3.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212898 -O data/marble_achilles_kw5.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212892 -O data/marble_achilles_kw10.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212893 -O data/marble_achilles_kw20.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212894 -O data/marble_achilles_kw30.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212897 -O data/marble_achilles_kw50.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/10212899 -O data/marble_achilles_kw100.pth

### Embed the train and test data with each model

In [None]:
# Embed the train/test splits with the pretrained MARBLE model for each rate-kernel width,
# keeping the PCA dimensionality fixed.
# NOTE: data_train, data_test and label_*_marble are module-level names; after the loop
# they hold the values of the last setting (kw100).
training, testing, labels_train, labels_test = [], [], [], []
pca_n = 20

for kw in [3, 5, 10, 20, 30, 50, 100]:
    # The `pca` object returned for the training split is passed to the test split,
    # presumably so both are projected with the same PCA basis
    data_train, label_train_marble, pca = convert_spikes_to_rates(neural_train.T, label_train, pca_n=pca_n, kernel_width=kw)
    data_test, label_test_marble, _ = convert_spikes_to_rates(neural_test.T, label_test, kernel_width=kw, pca=pca)

    marble_model = MARBLE.net(data_train, loadpath=f"data/marble_achilles_kw{kw}.pth")
    data_train, data_test = marble_model.transform(data_train), marble_model.transform(data_test)

    training.append(data_train)
    testing.append(data_test)
    labels_train.append(label_train_marble)
    labels_test.append(label_test_marble)

In [None]:
# 3D view of each MARBLE embedding: training split on the top row, test split on the bottom row,
# one column per kernel width; points are coloured by the first label column
# (presumably position — confirm against the dataset description).
kernel_widths = [3, 5, 10, 20, 30, 50, 100]
num_models = len(kernel_widths)
num_plots_per_model = 2
fig = plt.figure(figsize=(20,12))

for i, kw in enumerate(kernel_widths):
    data_train, data_test = training[i], testing[i]
    label_train_marble, label_test_marble = labels_train[i], labels_test[i]

    top = fig.add_subplot(num_plots_per_model, num_models, i + 1, projection='3d')
    ax = cebra.plot_embedding(ax=top, embedding=data_train.emb, embedding_labels=label_train_marble[:,0], markersize=0.2, title=f'MARBLE-train_kw{kw}')

    bottom = fig.add_subplot(num_plots_per_model, num_models, num_models + i + 1, projection='3d')
    ax = cebra.plot_embedding(ax=bottom, embedding=data_test.emb, embedding_labels=label_test_marble[:,0], markersize=1, title=f'MARBLE-test_kw{kw}')

plt.tight_layout()
plt.savefig('marble_cebra_embeddings_3D_kernelwidth_scan.png')

In [None]:

# Compare decoding of the CEBRA baselines with MARBLE at each rate-kernel width.
kernel_widths = [3, 5, 10, 20, 30, 50, 100]
kw_labels = [f'kw{kw}' for kw in kernel_widths]
# Exactly one label per score vector collected below: the two CEBRA baselines, then one per kernel
# width. (Previously a third baseline 'c-pos' was listed that is never computed, so pd.DataFrame
# raised a shape mismatch.)
baseline_labels = ['c-time', 'c-behaviour']

# Element 4 of decoding_pos_dir's output is what gets plotted as 'accuracy'
# (one value per test sample, judging by the vstack/.T below — TODO confirm).
# NOTE(review): the CEBRA vectors are trimmed by their last element, presumably so their length
# matches the MARBLE vectors (np.vstack requires equal lengths) — confirm.
results = [cebra_time_decode[4][:-1], cebra_behaviour_decode[4][:-1]]
for i in range(len(kernel_widths)):
    data_train = training[i]
    data_test = testing[i]
    label_train_marble = labels_train[i]
    label_test_marble = labels_test[i]

    marble_decode = decoding_pos_dir(data_train.emb, data_test.emb, label_train_marble, label_test_marble)
    results.append(marble_decode[4])

order = baseline_labels + kw_labels

# One column per model, then long format (model, accuracy) for seaborn / statannotations
results = pd.DataFrame(data=np.vstack(results).T, columns=order)
results = results.melt()
results.columns = ['model', 'accuracy']

f, ax = plt.subplots(figsize=(8,8))
sns.despine(bottom=True, left=True)

# Overlay 200 randomly chosen individual points (drawn with replacement); seeded for reproducibility
rng = np.random.default_rng(0)
random_sampling = rng.integers(low=0, high=results.shape[0], size=200)

sns.stripplot(
    data=results.iloc[random_sampling, :], x="model", y="accuracy", order=order, ax=ax,
    dodge=True, alpha=.5, zorder=1, color='gray',
)

# Per-model point estimate (seaborn's default estimator), without error bars
sns.pointplot(
    data=results, x="model", y="accuracy", order=order, ax=ax,
    join=False, dodge=.8 - .8 / 3, palette="dark",
    markers="d", scale=.75, errorbar=None
)

plt.ylim([0,0.4])

# Wilcoxon tests: each baseline vs. the smallest kernel width, then neighbouring kernel widths
pairs = [(baseline, kw_labels[0]) for baseline in baseline_labels] + list(zip(kw_labels[:-1], kw_labels[1:]))

annotator = Annotator(ax, pairs, data=results, x="model", y="accuracy", order=order)
annotator.configure(test='Wilcoxon', text_format='star', loc='outside')
annotator.apply_and_annotate()
plt.tight_layout()
plt.savefig('decoding_accuracy_rat_32output_violin_kwscan.svg')
