In [None]:
import sys, os
import numpy as np
from scipy.stats import zscore
from rastermap import Rastermap, utils

# path to the paper code (provides `loaders` and `fig3`);
# can be overridden with the RASTERMAP_PAPER_CODE environment variable
sys.path.insert(0, os.environ.get("RASTERMAP_PAPER_CODE", "/github/rastermap/paper"))
from loaders import tuning_curves_VR
import fig3

# path to the directory with data etc.
### *** CHANGE THIS TO WHEREVER YOU ARE DOWNLOADING THE DATA ***
### (or set the RASTERMAP_PAPER_ROOT environment variable instead of editing this line)
# this folder contains a "data" folder (inputs) and a "results" folder (outputs);
# both are created here if they do not exist yet
root = os.environ.get("RASTERMAP_PAPER_ROOT", "/media/carsen/ssd2/rastermap_paper/")
os.makedirs(os.path.join(root, "data"), exist_ok=True)
os.makedirs(os.path.join(root, "results"), exist_ok=True)

### Load virtual reality (VR) task data

(This data will be made available upon publication of the paper.)


In [None]:
# the neural recording and the VR behavior are stored as two .npz archives in <root>/data
data_dir = os.path.join(root, "data")
dat = np.load(os.path.join(data_dir, "corridor_neur.npz"))
corridor = np.load(os.path.join(data_dir, "corridor_behavior.npz"))

xpos = dat["xpos"]
ypos = dat["ypos"]
# z-score every row of the activity matrix along axis 1
# (rows are the neurons that the Rastermap sort below reorders)
spks = zscore(dat["spks"], axis=1)

### Run Rastermap and compute tuning curves

In [None]:
# sort the neurons (rows of spks) with Rastermap
model = Rastermap(
    n_clusters=100,
    n_PCs=200,
    time_lag_window=10,
    locality=0.75,
).fit(spks)
isort = model.isort
cc_nodes = model.cc

# bin the sorted neurons in groups of `bin_size` along axis 0 (utils.bin1d),
# then z-score each resulting row along axis 1
bin_size = 100
sn = zscore(utils.bin1d(spks[isort], bin_size, axis=0), axis=1)
corridor_tuning = tuning_curves_VR(sn, corridor["VRpos"], corridor["corridor_starts"])

# sort in time: a second Rastermap fit on the transposed binned matrix
model2 = Rastermap(
    n_clusters=100,
    n_splits=0,
    locality=0.,
    n_PCs=200,
).fit(sn.T)
isort2 = model2.isort

# persist everything fig3 needs under <root>/results
proc_file = os.path.join(root, "results", "corridor_proc.npz")
np.savez(proc_file,
         sn=sn, xpos=xpos, ypos=ypos, isort=isort, isort2=isort2,
         cc_nodes=cc_nodes, corridor_tuning=corridor_tuning)

### Make figure 3

In [None]:
# fig3.fig3 takes the root path, which holds the saved "results" folder;
# figures are written to a "figures" folder, so make sure it exists first
figures_dir = os.path.join(root, "figures/")
os.makedirs(figures_dir, exist_ok=True)
fig3.fig3(root)

### Supplementary analysis: compare against t-SNE and UMAP sortings

In [None]:
import metrics

# embeddings of the same neural PCs (model.Usv) from metrics.run_TSNE and
# metrics.run_UMAP, used as alternative neuron sortings for comparison
ys = [metrics.run_TSNE(model.Usv), 
      metrics.run_UMAP(model.Usv)]

snys = []
ctunings = []
for y in ys:
    # sort neurons by the first embedding dimension
    isorty = y[:, 0].argsort()
    # bin and z-score exactly like `sn` in the Rastermap cell (same `bin_size`,
    # z-score each binned row along axis 1) so the sortings are compared on equal
    # footing; zscore's default axis=0 would instead normalize across rows
    sny = zscore(utils.bin1d(spks[isorty], bin_size, axis=0), axis=1)
    ctuning = tuning_curves_VR(sny, corridor["VRpos"], corridor["corridor_starts"])
    snys.append(sny)
    ctunings.append(ctuning)

np.savez(os.path.join(root, "results", "corridor_supp.npz"),
         snys=snys, ctunings=ctunings, 
         corridor_starts=corridor["corridor_starts"], 
         corridor_widths=corridor["corridor_widths"], 
         reward_inds=corridor["reward_inds"])

In [None]:
# load the supplementary results into memory and close the .npz archive right away;
# np.load returns an NpzFile that holds the file open until it is closed
with np.load(os.path.join(root, "results", "corridor_supp.npz")) as supp:
    d = dict(supp)
# every saved array is passed to the plotting function as a keyword argument
fig = fig3._suppfig_vr_algs(**d)