# Example workflow and comparison between MARBLE and CEBRA

In [None]:
%load_ext autoreload
%autoreload 2

! pip install ipympl
%matplotlib widget

import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import MARBLE
from macaque_reaching_helpers import fit_pca, format_data

!pip install cebra
from cebra import CEBRA

## Load data

This part is data specific and you will need to adapt it to your own dataset.

In [None]:
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969883 -O data/rate_data_20ms_100ms.pkl

with open('data/rate_data_20ms_100ms.pkl', 'rb') as handle:
    rates = pickle.load(handle)

!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963200 -O data/trial_ids.pkl

with open('data/trial_ids.pkl', 'rb') as handle:
    trial_ids = pickle.load(handle)

conditions = ["DownLeft", "Left", "UpLeft", "Up", "UpRight", "Right", "DownRight"]

## Linear dimensionality reduction and filtering of data

In [None]:
# Settings shared by the two helper calls below.
day = 5  # load one session
filter_data = True
pca_n = 5  # presumably the number of principal components -- confirm in fit_pca

# Fit the PCA for the selected session and conditions.
pca = fit_pca(rates, day, conditions, pca_n=pca_n, filter_data=filter_data)

# Build the per-session inputs for the models below; the fitted PCA and the
# same filtering flag are passed through to the helper.
(pos, vel, timepoints, condition_labels, trial_indexes) = format_data(
    rates,
    trial_ids,
    day,
    conditions,
    filter_data=filter_data,
    pca=pca,
)

## Run CEBRA

In [None]:
# Stack the entries of `pos` into one array and flatten the labels so that
# CEBRA receives a single input / label pair.
pos_all = np.vstack(pos)
condition_labels = np.hstack(condition_labels)

# CEBRA hyperparameters, gathered in one place for easier tweaking.
cebra_kwargs = {
    "model_architecture": 'offset10-model',
    "batch_size": 512,
    "learning_rate": 0.0001,
    "temperature": 1,
    "output_dimension": 3,  # 3-D embedding, plotted in the last cell
    "max_iterations": 5000,
    "distance": 'euclidean',
    "conditional": 'time_delta',
    "device": 'cpu',
    "verbose": True,
    "time_offsets": 10,
}
cebra_model = CEBRA(**cebra_kwargs)

# Fit with the condition labels, then embed the same data.
cebra_model.fit(pos_all, condition_labels)
cebra_pos = cebra_model.transform(pos_all)

## Run MARBLE

In [None]:
# Build the MARBLE dataset from the anchor points (`pos`) and the vectors
# attached to them (`vel`).
data = MARBLE.construct_dataset(anchor=pos, vector=vel, k=30, spacing=0.0, delta=1.5)

# Model hyperparameters.
params = dict(
    epochs=120,  # optimisation epochs
    order=2,  # order of derivatives
    hidden_channels=100,  # number of internal dimensions in MLP
    out_channels=3,  # 3-D embedding, plotted in the last cell
    inner_product_features=False,
    diffusion=True,
)

model = MARBLE.net(data, params=params)

# Train, passing a session-specific output directory, then embed the data.
# `data` is rebound to the transformed dataset, whose `.emb` attribute is
# read by the plotting cell.
session_outdir = "data/session_{}_20ms".format(day)
model.fit(data, outdir=session_outdir)
data = model.transform(data)

## Plot embeddings

In [None]:
def _scatter_by_condition(ax, emb, labels, colors, title):
    """Draw a 3-D scatter of an embedding, one colour per condition.

    Args:
        ax: 3-D matplotlib axes to draw on.
        emb: embedding indexed as ``emb[mask, dim]``; the first three
            columns are used as the x, y and z coordinates.
        labels: 1-D array of integer condition labels, one per row of ``emb``.
        colors: sequence of colours; entry ``i`` is used for label ``i``.
        title: text shown above the axes.
    """
    for i, color in enumerate(colors):
        # Select the points belonging to condition i.
        mask = labels == i
        ax.scatter(
            emb[mask, 0],
            emb[mask, 1],
            emb[mask, 2],
            s=10,  # marker size
            color=color,
            label=f'Condition {i}',
            alpha=0.8,
        )

    # Hide the grid and ticks so only the point cloud is visible.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_title(title)
    ax.legend()


label = np.hstack(condition_labels)

# One colour per condition, evenly spaced along the viridis colormap.
colors = mpl.cm.viridis(np.linspace(0, 1, len(conditions)))
fig = plt.figure(figsize=(10, 8))

# Left panel: CEBRA embedding.
ax1 = fig.add_subplot(121, projection='3d')
_scatter_by_condition(ax1, cebra_pos, label, colors, 'CEBRA')

# Right panel: MARBLE embedding. This must be subplot 122; reusing 121 would
# place these axes on top of the CEBRA panel instead of beside it.
ax2 = fig.add_subplot(122, projection='3d')
_scatter_by_condition(ax2, data.emb, label, colors, 'MARBLE')

plt.show()