# Example workflow and comparison between MARBLE and CEBRA

In [32]:
%load_ext autoreload
%autoreload 2

#import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import numpy as np
#import pandas as pd
import pickle
import plotly
import plotly.graph_objects as go
#import seaborn as sns
#from statannotations.Annotator import Annotator
#from sklearn.model_selection import KFold
from macaque_reaching_helpers import fit_pca, format_data
#from tqdm import tqdm

import MARBLE

!pip install cebra
from cebra import CEBRA

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load data

This part is data-specific; you will need to adapt it to your own dataset.

In [16]:
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969883 -O data/rate_data_20ms_100ms.pkl

with open('data/rate_data_20ms_100ms.pkl', 'rb') as handle:
    rates = pickle.load(handle)

!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963200 -O data/trial_ids.pkl

with open('data/trial_ids.pkl', 'rb') as handle:
    trial_ids = pickle.load(handle)

conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

File ‘data/rate_data_20ms_100ms.pkl’ already there; not retrieving.
File ‘data/trial_ids.pkl’ already there; not retrieving.


## Linear dimensionality reduction and filtering of the data

In [30]:
# Preprocessing settings shared by the two helper calls below.
pca_n = 5           # forwarded to fit_pca (presumably the number of principal components kept)
filter_data = True  # forwarded to both fit_pca and format_data
day = 5             # load one session

# Fit the PCA for the selected session.
pca = fit_pca(rates, day, conditions, filter_data=filter_data, pca_n=pca_n)

# Format the same session with the fitted PCA. The helper returns five values;
# `pos` and `vel` are later used as MARBLE anchors and vectors respectively.
pos, vel, timepoints, condition_labels, trial_indexes = format_data(
    rates, trial_ids, day, conditions, pca=pca, filter_data=filter_data
)

## Run CEBRA

In [29]:
# CEBRA settings (same values as before, grouped by purpose).
cebra_model = CEBRA(
    # architecture / embedding
    model_architecture='offset10-model',
    output_dimension=3,
    time_offsets=10,
    # contrastive objective
    conditional='time_delta',
    distance='euclidean',
    temperature=1,
    # optimisation
    batch_size=512,
    learning_rate=0.0001,
    max_iterations=5000,
    # runtime
    device='cpu',
    verbose=True,
)

# Stack the per-condition arrays so that row i of `pos_all` corresponds to
# entry i of the flattened label vector.
# NOTE: this rebinds `condition_labels` (from the formatting cell) to a flat array.
pos_all = np.vstack(pos)
condition_labels = np.hstack(condition_labels)

# Fit on positions with the condition labels as auxiliary variable,
# then embed the same points.
cebra_model.fit(pos_all, condition_labels)
cebra_pos = cebra_model.transform(pos_all)

pos:  0.3790 neg:  5.1307 total:  5.5096 temperature:  1.0000: 100%|█| 5000/5000


## Run MARBLE

In [33]:
# Build the MARBLE dataset: PCA-space positions are the anchors and the
# corresponding velocities are the vectors attached to them.
data = MARBLE.construct_dataset(anchor=pos, vector=vel, k=30, spacing=0.0, delta=1.5)

# Training settings; anything not listed here falls back to the defaults
# echoed under "---- Settings:" in the cell output.
params = dict(
    epochs=120,  # optimisation epochs
    order=2,  # order of derivatives
    hidden_channels=100,  # number of internal dimensions in MLP
    out_channels=3,
    inner_product_features=False,
    diffusion=True,
)

model = MARBLE.net(data, params=params)

# NOTE(review): the saved output of this cell ends with
# "RuntimeError: mat1 and mat2 shapes cannot be multiplied (2210x5 and 2x2)"
# right after "Representation alignment". The cause is not visible from this
# notebook -- TODO investigate before relying on the MARBLE embedding.
model.fit(data, outdir="data/session_{}_20ms".format(day))
data = model.transform(data)


---- Embedding dimension: 5
---- Signal dimension: 5
---- Computing kernels ... 
---- Computing eigendecomposition ... 
---- Settings: 

epochs : 120
order : 2
hidden_channels : [100]
out_channels : 3
inner_product_features : False
diffusion : True
batch_size : 32
lr : 0.01
momentum : 0.9
dropout : 0.0
batch_norm : batch_norm
bias : True
architecture : GAT
GAT_hidden_layers : 50
GAT_attention_heads : 1
align_datasets : True
frac_sampled_nb : -1
include_positions : False
include_self : True
vec_norm : False
emb_norm : False
seed : 0
dim_signal : 5
dim_emb : 5
slices : tensor([    0,  2278,  4488,  6596,  8568, 10812, 12954, 15130])
n_sampled_nb : -1

---- Number of features to pass to the MLP:  15
---- Total number of parameters:  1911

Using device cpu

---- Training network ...

---- Timestamp: 20241104-112714


Representation alignment


RuntimeError: mat1 and mat2 shapes cannot be multiplied (2210x5 and 2x2)

## Plot embeddings

In [None]:

# Configure Plotly to be rendered inline in the notebook.
plotly.offline.init_notebook_mode()


def plot_embedding_3d(emb, label, n_conditions=7):
    """Build a 3D scatter plot of an embedding with one trace per condition.

    Parameters
    ----------
    emb : array-like, shape (n_points, >= 3)
        Embedding coordinates; only the first three columns are drawn.
    label : array-like, shape (n_points,)
        Integer condition label for each row of ``emb``. Rows with label ``i``
        form trace ``i`` for ``i`` in ``range(n_conditions)``; rows with any
        other label are not drawn.
    n_conditions : int, optional
        Number of conditions, i.e. number of traces and colours.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with transparent background and hidden scene axes.

    Raises
    ------
    ValueError
        If ``emb`` and ``label`` do not have the same number of rows.
    """
    emb = np.asarray(emb)
    label = np.asarray(label)
    if emb.shape[0] != label.shape[0]:
        raise ValueError(
            "emb has {} rows but label has {} entries".format(emb.shape[0], label.shape[0])
        )

    # viridis returns RGBA floats in [0, 1], whereas Plotly's 'rgb(r,g,b)'
    # strings are on a 0-255 scale; without rescaling every trace is near-black.
    colors = pl.cm.viridis(np.linspace(0, 1, n_conditions))
    colors_255 = np.rint(255 * colors[:, :3]).astype(int)

    traces = []
    for i in range(n_conditions):
        mask = label == i
        r, g, b = colors_255[i]
        traces.append(
            go.Scatter3d(
                x=emb[mask, 0],
                y=emb[mask, 1],
                z=emb[mask, 2],
                mode='markers',
                marker={
                    'size': 1,
                    'opacity': 1,
                    'color': 'rgb({},{},{})'.format(r, g, b),
                },
            )
        )

    layout = go.Layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )

    figure = go.Figure(data=traces, layout=layout)
    # 3D traces live in a "scene"; hide its three axes.
    figure.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
    return figure


# Plot the CEBRA embedding: rows of `cebra_pos` follow `pos_all`, which is
# stacked in the same order as the flattened condition labels.
# To inspect the MARBLE embedding instead, pass the array produced by
# model.transform (presumably `data.emb` -- TODO confirm the attribute name and
# that its row order matches the labels).
emb = cebra_pos
label = np.hstack(condition_labels)

plot_figure = plot_embedding_3d(emb, label, n_conditions=len(conditions))

# Render the plot.
plot_figure.show()