In [None]:
import importlib
import sys, os

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.stats import zscore

from rastermap import Rastermap
from rastermap.svd import SVD
from rastermap.utils import bin1d

# the paper helper modules (metrics, simulations, fig1) live in the repository's
# paper/ folder, which must be put on sys.path before importing them
sys.path.insert(0, '/github/rastermap/paper/')
import metrics, simulations, fig1

# Root folder for the saved simulations, benchmark results and figures.
# Set the RASTERMAP_PAPER_ROOT environment variable to run on a different machine;
# otherwise the original author's path is used.
root = os.environ.get("RASTERMAP_PAPER_ROOT", "/media/carsen/ssd2/rastermap_paper/")
os.makedirs(os.path.join(root, "simulations/"), exist_ok=True)


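### Make simulations

Generate ten simulated datasets (random seeds 0–9) and save each one to `simulations/sim_{seed}.npz` under `root`.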

In [None]:

# Simulate one dataset per random seed and cache it on disk, so the
# benchmarking cells below can reload it instead of re-simulating.
n_per_module = 1000
for random_state in range(0, 10):
    (spks, xi_all, stim_times_all,
     psth, psth_spont, iperm) = simulations.make_full_simulation(n_per_module=n_per_module,
                                                                 random_state=random_state)
    sim_path = os.path.join(root, "simulations/", f"sim_{random_state}.npz")
    # stim_times_all is stored as an object array, so reading it back needs allow_pickle=True
    np.savez(sim_path,
             spks=spks,
             xi_all=xi_all,
             stim_times_all=np.array(stim_times_all, dtype=object),
             psth=psth,
             psth_spont=psth_spont,
             iperm=iperm)



### Run embedding algorithms and benchmark performance

In [None]:
# Re-import paper/simulations.py so that edits made since the import cell ran are used.
# (`imp` was never imported and was removed in Python 3.12; importlib.reload is the replacement.)
importlib.reload(simulations)
simulations.embedding_performance(root, save=True)

### Make Figure 1

In [None]:
# `root` must already contain the "simulations" folder with the saved results;
# the figure is saved into a "figures" folder under `root`, created here if missing.
fig_dir = os.path.join(root, "figures/")
os.makedirs(fig_dir, exist_ok=True)
fig1.fig1(root, save_figure=True)

In [None]:
#d = 0
#div_map = [[5, 42], [43, 108], [109, 136], [136, 170], [176, 200]]
#plt.figure(figsize=(12,3))
#plt.imshow(X_emb[div_map[d][0] : div_map[d][1],:8000], aspect="auto", vmax=2, vmin=0)

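### Supplementary analyses

Sweep the t-SNE perplexity and the UMAP `n_neighbors` settings on every simulation, then compute neighbor scores for the original embeddings saved in `sim_performance.npz`.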

In [None]:
# Run t-SNE with different perplexities on the Rastermap PCs of every simulation
# and score each embedding against the ground truth (xi_all).
knn = np.array([10, 50, 100, 200, 500])

# Each entry is [p, 0] for a run with the single perplexity p, or [30, p] for a run
# given the two perplexities 30 and p (only added for p > 60).
# The number of embeddings per simulation follows from this list instead of being hard-coded.
perplexities = []
for perplexity in [10, 30, 60, 100, 200]:
    perplexities.append([perplexity, 0])
    if perplexity > 60:
        perplexities.append([30, perplexity])

embs_list, scores_list, mnn_list, rho_list = [], [], [], []
for random_state in range(10):
    print(random_state)
    sim_path = os.path.join(root, "simulations", f"sim_{random_state}.npz")
    # `with` closes the archive; arrays are read from it lazily, so all reads stay inside
    with np.load(sim_path, allow_pickle=True) as dat:
        spks = dat["spks"]
        # run rastermap to get PCs (model.Usv is the t-SNE input)
        model = Rastermap(n_clusters=100, n_PCs=200, locality=0.8,
                          time_lag_window=10, time_bin=10).fit(spks)
        # one embedding per entry of `perplexities`, in the same order, stored as float64
        embs = np.array([metrics.run_TSNE(model.Usv,
                                          perplexities=[p0] if p1 == 0 else [p0, p1])
                         for p0, p1 in perplexities], dtype=float)
        contamination_scores, triplet_scores = metrics.benchmarks(dat["xi_all"], embs)
        # knn is copied because embedding_quality_gt may modify it in place
        mnn, rho = metrics.embedding_quality_gt(dat["xi_all"], embs, knn=knn.copy())
    embs_list.append(embs)
    scores_list.append(np.stack((contamination_scores, triplet_scores), axis=0))
    mnn_list.append(mnn)
    rho_list.append(rho)

# stack per-seed results: leading axis is random_state
embs_all = np.array(embs_list, dtype=float)
scores_all = np.array(scores_list, dtype=float)
mnn_all = np.array(mnn_list, dtype=float)
rho_all = np.array(rho_list, dtype=float)

np.savez(os.path.join(root, "simulations", "sim_performance_tsne.npz"),
         embs_all=embs_all, scores_all=scores_all,
         mnn_all=mnn_all, rho_all=rho_all, knn=knn,
         perplexities=perplexities)

In [None]:
# Run UMAP with different n_neighbors on the Rastermap PCs of every simulation
# and score each embedding against the ground truth (xi_all).
knn = np.array([10, 50, 100, 200, 500])
n_neighbors = np.array([5, 15, 30, 60, 100, 200])

embs_list, scores_list, mnn_list, rho_list = [], [], [], []
for random_state in range(10):
    print(random_state)
    sim_path = os.path.join(root, "simulations", f"sim_{random_state}.npz")
    # `with` closes the archive; arrays are read from it lazily, so all reads stay inside
    with np.load(sim_path, allow_pickle=True) as dat:
        spks = dat["spks"]
        # run rastermap to get PCs (model.Usv is the UMAP input)
        # NOTE(review): same fit as in the t-SNE sweep; caching model.Usv per seed would avoid refitting
        model = Rastermap(n_clusters=100, n_PCs=200, locality=0.8,
                          time_lag_window=10, time_bin=10).fit(spks)
        # one embedding per n_neighbors value, in order, stored as float64
        embs = np.array([metrics.run_UMAP(model.Usv, n_neighbors=nneigh)
                         for nneigh in n_neighbors], dtype=float)
        contamination_scores, triplet_scores = metrics.benchmarks(dat["xi_all"], embs)
        # knn is copied because embedding_quality_gt may modify it in place
        mnn, rho = metrics.embedding_quality_gt(dat["xi_all"], embs, knn=knn.copy())
    embs_list.append(embs)
    scores_list.append(np.stack((contamination_scores, triplet_scores), axis=0))
    mnn_list.append(mnn)
    rho_list.append(rho)

# stack per-seed results: leading axis is random_state
embs_all = np.array(embs_list, dtype=float)
scores_all = np.array(scores_list, dtype=float)
mnn_all = np.array(mnn_list, dtype=float)
rho_all = np.array(rho_list, dtype=float)

np.savez(os.path.join(root, "simulations", "sim_performance_umap.npz"),
         embs_all=embs_all, scores_all=scores_all,
         mnn_all=mnn_all, rho_all=rho_all, knn=knn,
         n_neighbors=n_neighbors)

In [None]:
# Compute neighbor scores for the original embeddings saved in sim_performance.npz.
# knn is defined here (same values as in the sweep cells) so this cell also works
# on a fresh kernel without running the sweeps first.
knn = np.array([10, 50, 100, 200, 500])

with np.load(os.path.join(root, "simulations", "sim_performance.npz"), allow_pickle=True) as d2:
    # read the array once; indexing the archive inside the loop would re-read it every time
    embs_all_orig = d2["embs_all"]

mnn_list, rho_list = [], []
for random_state in range(10):
    sim_path = os.path.join(root, "simulations", f"sim_{random_state}.npz")
    with np.load(sim_path, allow_pickle=True) as dat:
        xi_all = dat["xi_all"]
    embs = embs_all_orig[random_state].squeeze()
    # knn is copied because embedding_quality_gt may modify it in place
    mnn, rho = metrics.embedding_quality_gt(xi_all, embs, knn=knn.copy())
    mnn_list.append(mnn)
    rho_list.append(rho)

# leading axis is random_state; the method axis follows from the saved embeddings
mnn_all = np.array(mnn_list, dtype=float)
rho_all = np.array(rho_list, dtype=float)
np.savez(os.path.join(root, "simulations", "sim_performance_neigh.npz"),
         mnn_all=mnn_all, rho_all=rho_all, knn=knn)


In [None]:
# Plot the supplementary benchmark-score figure for `root`.
# NOTE(review): presumably reads the sim_performance_*.npz files saved by the cells above — confirm in paper/fig1.py
fig1.suppfig_scores(root)