# Tutorial based on a demo created by the Mathis lab on how to use CEBRA and reproduce figures from the CEBRA paper

## Install the required libraries before running this notebook.

In [None]:
!pip install --pre 'cebra[dev,demos]'

## Import all necessary libraries for data processing, visualization, and modeling.

In [None]:
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA

from matplotlib.collections import LineCollection
import pandas as pd

## Load the hippocampus dataset using the CEBRA library.

In [None]:
# Key of the single-rat hippocampus recording in CEBRA's dataset registry.
HIPPOCAMPUS_DATASET = 'rat-hippocampus-single-achilles'
hippocampus_pos = cebra.datasets.init(HIPPOCAMPUS_DATASET)

## Task: use plt.subplot() to create two subplots

### 1. A plot that shows neuronal activity over time.
### 2. A plot that shows the position over time.

In [None]:
# Preview the first N_PREVIEW_BINS samples: neural activity raster (left) and
# position (column 0 of continuous_index) over the same samples (right).
# Assumes axis 0 of both `neural` and `continuous_index` is time -- TODO confirm.
N_PREVIEW_BINS = 1000
preview_bins = np.arange(N_PREVIEW_BINS)

fig = plt.figure(figsize=(9, 3), dpi=150)
plt.subplots_adjust(wspace=0.3)

ax = plt.subplot(121)
# Transpose so that neurons run along the y-axis and time along the x-axis.
ax.imshow(hippocampus_pos.neural.numpy()[:N_PREVIEW_BINS].T, aspect='auto', cmap='gray_r')
ax.set_xlabel('Time bin')
ax.set_ylabel('Neuron #')
ax.set_title('Neural activity')

ax2 = plt.subplot(122)
ax2.scatter(preview_bins, hippocampus_pos.continuous_index.numpy()[:N_PREVIEW_BINS, 0], c='gray', s=1)
ax2.set_xlabel('Time bin')
ax2.set_ylabel('Position')
ax2.set_title('Position')

plt.show()

## Task: Train a CEBRA model to learn embeddings from the hippocampus dataset and visualize it

In [None]:
import cebra.integrations.plotly

# CEBRA-Behavior: embedding trained with the behavior labels (continuous_index)
# using the 'time_delta' conditional.
max_iterations = 10000  # also used by the models trained in the following cells
output_dimension = 3    # 3-D embeddings so they can be plotted directly

cebra_posdir3_model = CEBRA(model_architecture='offset10-model',
                            batch_size=512,
                            learning_rate=3e-4,
                            temperature=1,
                            output_dimension=output_dimension,
                            max_iterations=max_iterations,
                            distance='cosine',
                            conditional='time_delta',
                            device='cuda_if_available',
                            verbose=True,
                            time_offsets=10)

cebra_posdir3_model.fit(hippocampus_pos.neural, hippocampus_pos.continuous_index.numpy())
cebra_posdir3 = cebra_posdir3_model.transform(hippocampus_pos.neural)

# Interactive 3-D view of the embedding, colored by column 0 of the labels (position).
fig_behavior = cebra.integrations.plotly.plot_embedding_interactive(
    cebra_posdir3,
    embedding_labels=hippocampus_pos.continuous_index.numpy()[:, 0],
    title="CEBRA-Behavior",
    cmap="rainbow",
)
fig_behavior.show()

## Task: Train a control model with shuffled neural data.

In [None]:
# Control: the same hyperparameters as CEBRA-Behavior, but trained on labels whose
# time order has been shuffled, so they no longer line up with the neural data.
cebra_posdir_shuffled3_model = CEBRA(model_architecture='offset10-model',
                                     batch_size=512,
                                     learning_rate=3e-4,
                                     temperature=1,
                                     output_dimension=3,
                                     max_iterations=max_iterations,
                                     distance='cosine',
                                     conditional='time_delta',
                                     device='cuda_if_available',
                                     verbose=True,
                                     time_offsets=10)

# Permute the rows (samples) of the label matrix along axis 0; each row is kept
# intact. A seeded generator makes the control reproducible across re-runs.
shuffle_rng = np.random.default_rng(seed=0)
hippocampus_shuffled_posdir = shuffle_rng.permutation(hippocampus_pos.continuous_index.numpy())

cebra_posdir_shuffled3_model.fit(hippocampus_pos.neural, hippocampus_shuffled_posdir)
cebra_posdir_shuffled3 = cebra_posdir_shuffled3_model.transform(hippocampus_pos.neural)

## Task: Train a model that uses only time, without the behavior information.

In [None]:
# CEBRA-Time: conditional='time', and no behavior labels are passed to fit().
# Note the temperature (1.12) differs from the behavior-based models (1).
cebra_time3_model = CEBRA(model_architecture='offset10-model',
                          batch_size=512,
                          learning_rate=3e-4,
                          temperature=1.12,
                          output_dimension=3,
                          max_iterations=max_iterations,
                          distance='cosine',
                          conditional='time',
                          device='cuda_if_available',
                          verbose=True,
                          time_offsets=10)

# Fit on the neural data ONLY.
cebra_time3_model.fit(hippocampus_pos.neural)
cebra_time3 = cebra_time3_model.transform(hippocampus_pos.neural)

## Task: Train a model that uses both time and positional information.

In [None]:
# CEBRA-Hybrid: uses both time and the behavior labels.
cebra_hybrid_model = CEBRA(model_architecture='offset10-model',
                           batch_size=512,
                           learning_rate=3e-4,
                           temperature=1,
                           output_dimension=3,
                           max_iterations=max_iterations,
                           distance='cosine',
                           conditional='time_delta',
                           device='cuda_if_available',
                           verbose=True,
                           time_offsets=10,
                           hybrid=True)  # NOTE: this flag is what makes the model hybrid

# Fit exactly like the behavior model: neural data plus the behavior labels.
cebra_hybrid_model.fit(hippocampus_pos.neural, hippocampus_pos.continuous_index.numpy())
cebra_hybrid = cebra_hybrid_model.transform(hippocampus_pos.neural)

## Task: Visualize the embeddings from CEBRA-Behavior, the shuffled control, CEBRA-Time and CEBRA-Hybrid

In [None]:
def plot_hippocampus(ax, embedding, label, gray=False, idx_order=(0, 1, 2)):
    """Scatter a 3-D embedding on ``ax`` as two groups of samples.

    Rows whose ``label[:, 1] == 1`` are drawn first (colormap 'cool'), then rows
    whose ``label[:, 2] == 1`` (colormap 'viridis'). In both groups the points are
    colored by ``label[:, 0]``. Presumably columns 1/2 flag rightward/leftward
    running and column 0 is position -- TODO confirm against the dataset docs.

    Args:
        ax: A matplotlib 3-D axes (e.g. created with ``projection='3d'``).
        embedding: Array indexable as ``embedding[mask, k]``.
        label: Array with at least three columns, as described above.
        gray: If True, draw every point in plain gray with no colormap.
        idx_order: Three embedding column indices used as x, y and z.

    Returns:
        The same ``ax``, with the grid disabled and the background panes unfilled
        and white-edged.
    """
    x_col, y_col, z_col = idx_order
    groups = (
        (label[:, 1] == 1, 'cool'),
        (label[:, 2] == 1, 'viridis'),
    )

    for mask, colormap in groups:
        if gray:
            colors, colormap = 'gray', None
        else:
            colors = label[mask, 0]
        ax.scatter(embedding[mask, x_col],
                   embedding[mask, y_col],
                   embedding[mask, z_col],
                   c=colors,
                   cmap=colormap,
                   s=0.5)

    ax.grid(False)
    panes = [axis.pane for axis in (ax.xaxis, ax.yaxis, ax.zaxis)]
    for pane in panes:
        pane.fill = False
    for pane in panes:
        pane.set_edgecolor('w')

    return ax

In [None]:
%matplotlib notebook
# Hint: we need 4 subplots
# START YOUR CODE BELOW

fig = plt.figure(figsize=(10,2))
ax1 = plt.subplot(141, projection='3d')
...

# Hint: use the function created above, plot_hippocampus, which takes the ax and the embedding, i.e. the model you want to plot 
# against the label you want to use, i.e. the corresponding hippocampus index

ax1 = ... 

ax1.set_title('CEBRA-Behavior')
...
plt.show()