# Example: MARBLE applied to low-rank RNNs trained on the DMS task

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
import MARBLE
from MARBLE import geometry, plotting
from RNN_scripts import dms, helpers

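We start by downloading the intermediate data (pretrained networks, simulated trajectories, constructed datasets and trained MARBLE models) needed to reproduce the results.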

In [None]:
# download datasets
!mkdir data
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963161 -O data/dms_rank2_500.pt
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963162 -O data/dms_rank2_500_2.pt
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963166 -O data/dms_rank2_500_sampled_1.pt 
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963165 -O data/dms_rank2_500_sampled_2.pt 
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963163 -O data/RNN_trajectories11.pkl 
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963164 -O data/RNN_trajectories12.pkl 
!wget -nc https://dataverse.harvard.edu/api/access/datafile/6963158 -O data/RNN_trajectories2.pkl
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062817 -O data/data_solution_1.pkl
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7062816 -O data/data_all_solutions.pkl
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7509026 -O data/best_model_20231023-113502.pth
!wget -nc https://dataverse.harvard.edu/api/access/datafile/7509027 -O data/best_model_20231023-115754.pth

Load two RNN solutions pretrained on the DMS task and plot their connectivity coefficients.

In [None]:
# `load_network` returns a pair; only the second element (the network) is kept.
_, net1 = helpers.load_network('data/dms_rank2_500.pt')
_, net2 = helpers.load_network('data/dms_rank2_500_2.pt')

# Connectivity coefficients of both trained solutions.
for trained_net in (net1, net2):
    helpers.plot_coefficients(trained_net)

Display the input and output trajectories of the first network for an A-A and a B-A trial.

In [None]:
# Fix every epoch of the DMS trial to a deterministic length (min == max) so the
# two example trials below are aligned in time.
dms.stimulus1_duration_min = 500
dms.stimulus1_duration_max = 500
dms.delay_duration_min = 1000
dms.delay_duration_max = 1000
dms.stimulus2_duration_min = 500
dms.stimulus2_duration_max = 500
dms.decision_duration = 200
dms.setup()  # presumably recomputes derived trial settings from the values above — TODO confirm

# One example trial of each type; element 0 of the returned tuple is the input.
x1 = dms.generate_dms_data(1, type='A-A')[0]
x2 = dms.generate_dms_data(1, type='B-A')[0]

# Run both trials through the first network. The second return value is
# discarded on purpose: binding it to `traj2` would overwrite the `traj2`
# trajectories used for the phase portraits further down whenever this cell
# is re-run.
outp1, _ = net1.forward(x1)
outp2, _ = net1.forward(x2)
x1, x2 = x1.squeeze().numpy(), x2.squeeze().numpy()
outp1 = outp1.detach().squeeze().numpy()
outp2 = outp2.detach().squeeze().numpy()


def time_mapping(t):
    """Convert simulation step indices to seconds (assumes `dms.deltaT` is in ms — TODO confirm)."""
    return t * dms.deltaT / 1000


# Input channel 0 of both trials; the gray bar spans [0, 0.2] on the time axis.
fig, ax = plt.subplots(figsize=(3, 2))
ax.plot(time_mapping(np.arange(x1.shape[0])), x1[:, 0], c='#65BADA', zorder=30, lw=2)
ax.plot(time_mapping(np.arange(x2.shape[0])), x2[:, 0], c='#C82E6B', zorder=30, lw=2)
ax.plot([0, 0.2], [-.25, -.25], c='gray', lw=4)
ax.set(xlabel='time (s)', ylabel='input')

# Network output for both trials, with the same gray bar.
fig, ax = plt.subplots(figsize=(3, 2))
ax.plot(time_mapping(np.arange(outp1.shape[0])), outp1, color='#65BADA', zorder=30, lw=2)
ax.plot(time_mapping(np.arange(outp2.shape[0])), outp2, color='#C82E6B', zorder=30, lw=2)
ax.plot([0, 0.2], [-1.25, -1.25], c='gray', lw=4)
ax.set(xlabel='time (s)', ylabel='output')

Create two new networks by fitting a Gaussian mixture to the connectivity space of the first network and sampling from it with two different seeds.

In [None]:
# Two networks sampled from a Gaussian mixture fitted to net1's connectivity,
# one per seed; the first element of each returned pair is unused.
_, net1_sampled_1 = helpers.sample_network(net1, 'data/dms_rank2_500_sampled_1.pt', seed=0)
_, net1_sampled_2 = helpers.sample_network(net1, 'data/dms_rank2_500_sampled_2.pt', seed=1)

for sampled_net in (net1_sampled_1, net1_sampled_2):
    helpers.plot_coefficients(sampled_net)

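Design stimulus conditions: in each condition both stimuli are presented with the same gain, which decreases linearly from 1 to 0 across conditions.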

In [None]:
# Trial epoch boundaries, in simulation steps.
stim1_begin, stim1_end = 25, 50
stim2_begin, stim2_end = 200, 225
decision = 275
epochs = [0, stim1_begin, stim1_end, stim2_begin, stim2_end, decision]

# Stimulus gain per condition, from 1 (first condition) down to 0 (last).
n_gains = 20
gain = np.linspace(1, 0, n_gains)

# input_data is (condition, time step, input channel). Only channel 0 is driven,
# with the condition's gain during both stimulus windows; everything else is 0.
input_data = torch.zeros(n_gains, decision, 2)
gain_column = torch.tensor(gain, dtype=input_data.dtype).unsqueeze(1)
for window_begin, window_end in ((stim1_begin, stim1_end), (stim2_begin, stim2_end)):
    input_data[:, window_begin:window_end, 0] = gain_column

Generate synthetic trajectories for the two sampled networks and for the second trained network.

In [None]:
# Simulate `n_traj` trajectories for each network under every stimulus condition.
# NOTE(review): the `fname` argument presumably makes generate_trajectories load
# the pickle downloaded above when it exists and only run the (slow) simulation
# otherwise — TODO confirm in RNN_scripts.helpers.
n_traj = 200
traj11 = helpers.generate_trajectories(net1_sampled_1, input_data, epochs, n_traj, fname='data/RNN_trajectories11.pkl')
traj12 = helpers.generate_trajectories(net1_sampled_2, input_data, epochs, n_traj, fname='data/RNN_trajectories12.pkl')
traj2 = helpers.generate_trajectories(net2, input_data, epochs, n_traj, fname='data/RNN_trajectories2.pkl')

Plot the phase portraits of the two different dynamical solutions.

In [None]:
# Simulate a couple of trajectories of the first sampled network, for display
# only. A dedicated name is used so the global `n_traj` (200, used to generate
# the datasets above) is not silently overwritten.
n_traj_demo = 2
traj_demo = helpers.generate_trajectories(net1_sampled_1, input_data, epochs, n_traj_demo)
# rect presumably sets the plotted region (x_min, x_max, y_min, y_max) — TODO confirm
helpers.plot_experiment(net1_sampled_1, input_data, traj_demo, epochs, rect=(-6, 6, -4, 4), traj_to_show=1)

In [None]:
# Phase portrait of the second trained network, from its generated trajectories,
# using the same plotting window as the first solution.
helpers.plot_experiment(
    net2,
    input_data,
    traj2,
    epochs,
    rect=(-6, 6, -4, 4),
    traj_to_show=1,
)

In [None]:
# Task accuracy of net1 as a function of the stimulus gain.
n_val_trials = 10000  # validation trials generated per gain level
accuracy = []
for g in gain:
    _, _, _, x_val, y_val, mask_val = dms.generate_dms_data(n_val_trials, gain=g)
    _, acc = dms.test_dms(net1, x_val, y_val, mask_val)  # returns (loss, accuracy); loss unused
    accuracy.append(acc)

fig, ax = plt.subplots()
ax.plot(gain, accuracy)
ax.set(xlabel='stimulus gain', ylabel='accuracy', title='net1 accuracy vs. stimulus gain')

Aggregate the trajectories and create the datasets.

In [None]:
# Number of initial time steps clipped from every trajectory before aggregation.
transient = 15
pos11, vel11 = helpers.aggregate_data(traj11, epochs, transient)

Train a MARBLE model on network solution 1 (by default, the dataset and the trained model are loaded from the downloaded files).

In [None]:
# Load the pre-built MARBLE dataset for solution 1 (fast path). The context
# manager closes the file handle. NOTE: pickle can execute arbitrary code on
# load — only use files from a trusted source (here: the Dataverse download).
with open('./data/data_solution_1.pkl', 'rb') as fh:
    data = pickle.load(fh)
# To rebuild it from the trajectories instead (takes 5-10 mins):
#data = MARBLE.construct_dataset(pos11, features=vel11, graph_type='cknn', k=15, spacing=0.01)

In [None]:
# Restore the model pretrained on solution 1 rather than training it here.
model_file = 'best_model_20231023-113502.pth'
model = MARBLE.net(data, loadpath=f'./data/{model_file}')

# To train from scratch instead, replace the call above with:
#params = {'epochs': 40, 
#          'hidden_channels': 64, 
#          'out_channels': 5,
#          'diffusion': False,
#          'inner_product_features': False, #geometry-aware for maximal expressivity
#          }
#model = MARBLE.net(data, params=params)
#model.fit(data, outdir='data')

In [None]:
# Embed every dataset with the trained model, then compute pairwise distances
# between the embedded distributions (60 clusters); the result is read from `data.dist`.
data = MARBLE.distribution_distances(model.transform(data), n_clusters=60)

Cluster and plot the distance matrix.

In [None]:
from scipy.cluster.hierarchy import dendrogram                                                          
from scipy.cluster.hierarchy import linkage                                                             
from scipy.spatial.distance import squareform   

def cluster_matrix(df, distance=False, ax=None, no_plot=False):
    """Return the leaf order of a Ward hierarchical clustering of a square matrix.

    Parameters
    ----------
    df : array-like, shape (n, n)
        Symmetric matrix to cluster. Interpreted as distances when
        ``distance=True``. Otherwise it is treated as similarities and converted
        to distances by taking the element-wise reciprocal.
    distance : bool
        Set to True if ``df`` already holds distances.
    ax : matplotlib Axes, optional
        Axes on which ``scipy``'s ``dendrogram`` draws the tree.
    no_plot : bool
        If True, compute the leaf order without drawing the dendrogram.

    Returns
    -------
    numpy.ndarray
        Row/column indices of ``df`` in dendrogram leaf order.

    ``df`` itself is never modified; all clean-up happens on a float copy.
    """
    # Always work on a copy. Previously, with distance=True, `_data` aliased the
    # caller's matrix, and the capping / diagonal reset below mutated it.
    _data = np.array(df, dtype=float, copy=True)
    if not distance:
        with np.errstate(divide="ignore", invalid="ignore"):
            _data = 1.0 / _data

    # Replace huge or infinite distances (e.g. reciprocals of zero similarity)
    # with a large finite value so that linkage accepts the matrix.
    _data[_data > 1e10] = 1000
    # squareform requires a zero diagonal.
    np.fill_diagonal(_data, 0.0)
    dists = squareform(_data)
    Z = linkage(dists, "ward")
    labels = np.arange(0, len(_data))
    dn = dendrogram(Z, labels=labels, ax=ax, no_plot=no_plot)
    return labels[dn["leaves"]]

n_dist = len(data.dist)
# Red guide lines at 1/2 and 3/4 of the dataset index range.
# NOTE(review): presumably these separate groups of conditions — TODO confirm.
guide_positions = (n_dist / 2 - 0.5, 3 * n_dist / 4 - 0.5)

# Dendrogram of the MARBLE distance matrix.
plt.figure(figsize=(5, 2))
leaf_order = cluster_matrix(data.dist, distance=True, ax=plt.gca())
plt.xlabel('original labels')

# Position of every dataset in the dendrogram's leaf ordering.
plt.figure(figsize=(4, 3))
plt.plot(np.arange(0, n_dist), leaf_order)
plt.xlabel('original labels')
plt.ylabel('clustered labels')
for pos_line in guide_positions:
    plt.axvline(pos_line, c='r')

# Raw distance matrix with the same guide lines on both axes.
plt.figure()
im = plt.imshow(data.dist)
plt.colorbar(im)
for pos_line in guide_positions:
    plt.axhline(pos_line, c='r')
for pos_line in guide_positions:
    plt.axvline(pos_line, c='r')

# Train a MARBLE model on solution I and solution II together

In [None]:
# Clip the first `transient` steps of every trajectory and aggregate each
# network's trajectories (two networks sampled from solution I, then solution II).
transient = 15
(pos11, vel11), (pos12, vel12), (pos2, vel2) = [
    helpers.aggregate_data(trajectories, epochs, transient)
    for trajectories in (traj11, traj12, traj2)
]

# Concatenate in the fixed order 11, 12, 2.
pos = pos11 + pos12 + pos2
vel = vel11 + vel12 + vel2

In [None]:
# Load the pre-built MARBLE dataset covering all solutions (fast path). The
# context manager closes the file handle. NOTE: pickle can execute arbitrary
# code on load — only use files from a trusted source (here: the Dataverse download).
with open('./data/data_all_solutions.pkl', 'rb') as fh:
    data2 = pickle.load(fh)
# To rebuild it from the trajectories instead (takes up to 15 mins):
#data2 = MARBLE.construct_dataset(pos, features=vel, graph_type='cknn', k=15, spacing=0.01)

In [None]:
# Restore the model trained on all solutions. It must be built from `data2`,
# the dataset it was trained on (see the training code below), not from the
# solution-1 dataset `data`.
model_file = 'best_model_20231023-115754.pth'
model2 = MARBLE.net(data2, loadpath='./data/' + model_file)

# To train from scratch instead, replace the call above with:
#params = {'epochs': 40,
#          'order': 2,
#          'hidden_channels': 64,
#          'out_channels': 5,
#          'diffusion': False,
#          'inner_product_features': True, #geometry-agnostic as manifolds are differently oriented across networks
#         }

#model2 = MARBLE.net(data2, params=params)
#model2.fit(data2, outdir='data')

In [None]:
# Embed all datasets with model2, then compute pairwise distances between the
# embedded distributions (60 clusters); the result is read from `data2.dist`.
data2 = MARBLE.distribution_distances(model2.transform(data2), n_clusters=60)

In [None]:
# Split the datasets into six equal chunks of size n and keep chunks 1, 3 and 5
# (0-based), i.e. the second chunk of each consecutive pair.
# NOTE(review): presumably each network contributes two consecutive chunks, in
# the order traj11, traj12, traj2 — TODO confirm which conditions are kept.
n = len(data2.dist) // 6
ind = list(range(n, 2 * n)) + list(range(3 * n, 4 * n)) + list(range(5 * n, 6 * n))
dist = data2.dist[ind, :][:, ind]
im = plt.imshow(dist)
plt.colorbar(im)

In [None]:
# Colour every point by its condition's stimulus gain.
labels = np.array(gain)

fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot()
emb_MDS, _ = geometry.embed(dist, embed_typ='MDS')

# Same marker style for the three networks (n points each). Each network gets
# its own colormap; the first uses the default.
marker_kw = dict(s=30, alpha=1, axes_visible=True)
ax = plotting.embedding(emb_MDS[:n], labels, ax=ax, **marker_kw)
ax = plotting.embedding(emb_MDS[n:2*n], labels, ax=ax, cmap='PuOr', **marker_kw)
ax = plotting.embedding(emb_MDS[2*n:], labels, ax=ax, cmap='PRGn', **marker_kw)

# Compare results with Canonical Correlation Analysis (CCA)

In [None]:
# Same aggregation as for the MARBLE datasets, but with pca=False, for the CCA
# comparison below.
transient = 15
(pos11, vel11), (pos12, vel12), (pos2, vel2) = [
    helpers.aggregate_data(trajectories, epochs, transient, pca=False)
    for trajectories in (traj11, traj12, traj2)
]

# Concatenate in the fixed order 11, 12, 2.
pos = pos11 + pos12 + pos2
vel = vel11 + vel12 + vel2

In [None]:
from sklearn.cross_decomposition import CCA
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

n_comp = 3  # number of principal components kept AND canonical components fitted

scaler = StandardScaler() 
s = data2._slice_dict["x"]
# Number of datasets in data2 (the slice boundaries have one more entry than slices).
# NOTE(review): assumes data2's datasets are ordered like `pos` — TODO confirm.
ns = len(s)-1
cca = CCA(scale=False, n_components=n_comp) #define CCA
# The PCA dimension is tied to n_comp so that the two stay consistent.
pca = PCA(n_components=n_comp)

# NOTE: despite its name, this holds the MEAN CANONICAL CORRELATION between two
# datasets (a similarity: higher means more alike), not a distance.
dist_CCA = np.zeros([ns,ns])

# Standardise every dataset, then fit PCA on its transpose and transpose the
# projection back, yielding an (n_comp, n_columns_of_p) array per dataset.
pos_transform = []
for p in tqdm(pos):
    p = scaler.fit_transform(p)
    p = pca.fit_transform(p.T).T
    pos_transform.append(p)

for i in tqdm(range(ns)):
    for j in range(ns):
        u = pos_transform[i]
        v = pos_transform[j]

        cca.fit(u, v)
        u_c, v_c = cca.transform(u, v)

        # Correlation of each pair of canonical variates. `k` is used instead of
        # `i` so the outer loop index is not shadowed.
        comp_corr = np.array([np.corrcoef(u_c[:, k], v_c[:, k])[1, 0] for k in range(n_comp)])
        dist_CCA[i, j] = comp_corr.mean()

Note that CCA compares the loadings of the principal axes across datasets. It therefore does not pick up the dynamical changes, because these all correspond to variation within the same plane.

In [None]:
# Mean canonical correlation between every pair of datasets.
ax_cca = plt.gca()
im = ax_cca.imshow(dist_CCA)
plt.colorbar(im, ax=ax_cca)