# Visualisations of MARBLE embeddings

This notebook visualises the MARBLE latent representations of the macaque arm-reaching data obtained from binned spike counts with 20ms bin size.

We would like to thank the authors of LFADS for making this data accessible and answering our questions about the data!

### Note: the notebook relies on plotly, which may not work on all browsers. If you encounter issues on one browser (e.g., Chrome), just change to another (e.g., Firefox).

In [None]:
%load_ext autoreload
%autoreload 2
    
!pip install plotly

import numpy as np
import pandas as pd

import matplotlib.pylab as pl
import matplotlib.pyplot as plt
import plotly
import plotly.graph_objs as go
from sklearn.decomposition import PCA

import pickle

from sklearn.metrics import pairwise_distances
import torch.nn as nn
import torch

from MARBLE import geometry 

## Load data

In [None]:
# insert the pickle file of results that you want to visualise
!mkdir data
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062022 -O data/marble_embeddings_out3_pca10_100ms.pkl

with open('./data/marble_embeddings_out3_pca10_100ms.pkl', 'rb') as handle:
    data = pickle.load(handle)
    
distance_matrices = data[0]
embeddings = data[1]
timepoints = data[2]
labels = data[3]
sample_inds = data[4]
trial_ids = data[5]

# condition labels
conditions=['DownLeft','Left','UpLeft','Up','UpRight','Right','DownRight']

# load kinematics
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6969885 -O data/kinematics.pkl

with open('data/kinematics.pkl', 'rb') as handle:
    kinematic_data =  pickle.load(handle)

# Generate 3D plots for a selection of sessions

Let's first do this for the MARBLE data.

In [None]:
# One viridis colour per reach condition; matplotlib returns RGBA floats in [0, 1].
colors = pl.cm.viridis(np.linspace(0,1,7))

# Configure Plotly to be rendered inline in the notebook.
plotly.offline.init_notebook_mode()

# Session indices to plot; these sessions were used in Figure S7.
examples = [5,6,8,11,14,15,18,23,26,32]
for session in examples:
    # Index by the session id itself (not by its position in `examples`), so the
    # sessions shown here match the ones used by the other cells of this notebook.
    emb = embeddings[session]
    label = labels[session]

    # One 3D scatter trace per condition. The list is deliberately not called
    # `data`, so the results loaded from the pickle are not overwritten.
    traces = []
    for c in range(len(colors)):
        in_condition = label == c
        # 'rgb(r,g,b)' strings expect 0-255 channel values, so rescale the
        # [0, 1] colormap output before formatting.
        r, g, b = (255 * colors[c, :3]).astype(int)
        trace = go.Scatter3d(
            x=emb[in_condition,0],
            y=emb[in_condition,1],
            z=emb[in_condition,2],
            mode='markers',
            marker={
                'size': 1,
                'opacity': 1,
                'color':'rgb({},{},{})'.format(r, g, b),
            },
        )
        traces.append(trace)

    # Configure the layout (transparent background, no grid).
    layout = go.Layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        xaxis=dict(showgrid=False,showline=False),
        yaxis=dict(showgrid=False,showline=False)
    )

    plot_figure = go.Figure(data=traces, layout=layout)
    # Hide the 3D axes so only the point cloud is shown.
    plot_figure.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )

    # Render the plot.
    plot_figure.show()

Let's now compare this with the LFADS embeddings.

In [None]:
# One viridis colour per reach condition; matplotlib returns RGBA floats in [0, 1].
colors = pl.cm.viridis(np.linspace(0,1,7))

# Configure Plotly to be rendered inline in the notebook.
plotly.offline.init_notebook_mode()

for session in examples:
    trials = kinematic_data[session]

    # Group the LFADS factors of every trial by reach condition.
    lfads_data = [
        [trials[t]['lfads_factors'] for t in trials.keys() if trials[t]['condition'] == cond]
        for cond in conditions
    ]

    # Pool all trials (condition by condition) to fit a single PCA per session.
    all_data = np.hstack([factors for per_condition in lfads_data for factors in per_condition])
    lfads_data = [np.hstack(per_condition) for per_condition in lfads_data]

    # need to PCA high dimension lfads data
    # assumes each 'lfads_factors' array is (n_factors, n_timepoints), hence the
    # transpose to (samples, features) -- TODO confirm
    pca = PCA(n_components=3)
    pca.fit(all_data.T)

    # One 3D scatter trace per condition. The list is deliberately not called
    # `data`, so the results loaded from the pickle are not overwritten.
    traces = []
    for c, factors in enumerate(lfads_data):
        emb = pca.transform(factors.T)
        # 'rgb(r,g,b)' strings expect 0-255 channel values, so rescale the
        # [0, 1] colormap output before formatting.
        r, g, b = (255 * colors[c, :3]).astype(int)
        trace = go.Scatter3d(
            x=emb[:,0],
            y=emb[:,1],
            z=emb[:,2],
            mode='markers',
            marker={
                'size': 1,
                'opacity': 1,
                'color':'rgb({},{},{})'.format(r, g, b),
            },
        )
        traces.append(trace)

    # Configure the layout (transparent background, no grid).
    layout = go.Layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        xaxis=dict(showgrid=False,showline=False),
        yaxis=dict(showgrid=False,showline=False)
    )

    plot_figure = go.Figure(data=traces, layout=layout)
    # Hide the 3D axes so only the point cloud is shown.
    plot_figure.update_scenes(xaxis_visible=False, yaxis_visible=False,zaxis_visible=False )

    # Render the plot.
    plot_figure.show()

# Average distance matrix across sessions 

Let's see what the average distance matrix looks like across sessions for MARBLE.

In [None]:
# Heatmap of the MARBLE distance matrix averaged over all sessions
mean_distance_marble = np.dstack(distance_matrices).mean(axis=2)

plt.figure()
plt.imshow(mean_distance_marble)
plt.colorbar()

# MDS embedding of the averaged matrix; points are coloured by condition index
emb_MDS, _ = geometry.embed(mean_distance_marble, embed_typ = 'MDS')
condition_index = np.linspace(0,6,7)

plt.figure()
plt.scatter(emb_MDS[:,0], emb_MDS[:,1], c=condition_index)

How does this compare with LFADS?

In [None]:
# No distance matrices are loaded for LFADS, so compute one per session here.

distance_matrices_lfads = []
n_conditions = len(conditions)

for d in range(len(embeddings)):
    session = kinematic_data[d]

    # For each condition, stack the LFADS factors of its trials side by side and
    # transpose, so that rows are the samples compared by pairwise_distances.
    lfads_data = [
        np.hstack([
            session[t]['lfads_factors']
            for t in session.keys()
            if session[t]['condition'] == cond
        ]).T
        for cond in conditions
    ]

    # Mean pairwise distance between the samples of every pair of conditions;
    # the diagonal keeps its initial value of zero.
    distances = np.zeros([n_conditions, n_conditions])
    for i in range(n_conditions):
        for j in range(n_conditions):
            if i == j:
                continue
            distances[i,j] = pairwise_distances(lfads_data[i], lfads_data[j]).mean()

    # Rescale each session's matrix by its own standard deviation.
    distance_matrices_lfads.append(distances/np.std(distances))

In [None]:
# Heatmap of the LFADS distance matrix averaged over all sessions
mean_distance_lfads = np.dstack(distance_matrices_lfads).mean(axis=2)

plt.figure()
plt.imshow(mean_distance_lfads)
plt.colorbar()

# MDS embedding of the averaged matrix; points are coloured by condition index
emb_MDS, _ = geometry.embed(mean_distance_lfads, embed_typ='MDS')
condition_index = np.linspace(0,6,7)

plt.figure()
plt.scatter(emb_MDS[:,0], emb_MDS[:,1], c=condition_index)

Both are pretty good in terms of their average embeddings!

# Plotting individual session embeddings

Here we just want to plot the distance matrix for individual sessions (Fig S7).

In [None]:
# Grid with one column per example session and four rows:
#   row 0: MARBLE distance matrix      row 1: LFADS distance matrix
#   row 2: MDS of the MARBLE matrix    row 3: MDS of the LFADS matrix
n_sessions = len(examples)
fig, axs = plt.subplots(4, n_sessions, figsize=(15,5))
condition_index = np.linspace(0,6,7)

for col, session in enumerate(examples):
    marble_dist = distance_matrices[session]
    lfads_dist = distance_matrices_lfads[session]

    # distance matrices
    axs[0, col].imshow(marble_dist)
    axs[1, col].imshow(lfads_dist)

    # MDS embedding of the MARBLE distance matrix
    marble_mds, _ = geometry.embed(marble_dist, embed_typ = 'MDS')
    axs[2, col].scatter(marble_mds[:,0], marble_mds[:,1], c=condition_index)

    # MDS embedding of the LFADS distance matrix
    lfads_mds, _ = geometry.embed(lfads_dist, embed_typ = 'MDS')
    axs[3, col].scatter(lfads_mds[:,0], lfads_mds[:,1], c=condition_index)