In [None]:
import sys, os
import numpy as np
import torch
from scipy.stats import zscore
from neuropop import nn_prediction
from rastermap import Rastermap, utils

### use cuda version of torch if available, otherwise fall back to CPU
# torch.device('cuda') does not itself check availability; moving a model or
# tensor to it fails later on CPU-only machines, so choose the device explicitly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# path to paper code (provides the fig2 module imported below)
sys.path.insert(0, '/github/rastermap/paper')
import fig2

# path to directory with data etc
### *** CHANGE THIS TO WHEREVER YOU ARE DOWNLOADING THE DATA ***
# (alternatively set the RASTERMAP_PAPER_ROOT environment variable; the default
# below is used when it is not set)
root = os.environ.get("RASTERMAP_PAPER_ROOT", "/media/carsen/ssd2/rastermap_paper/")
# (in this folder we have a "data" folder and a "results" folder)
os.makedirs(os.path.join(root, "data"), exist_ok=True)
os.makedirs(os.path.join(root, "results"), exist_ok=True)

### Load spontaneous-activity data

(These data will be made available upon publication of the paper.)


In [None]:
# np.load on an .npz returns an NpzFile that keeps the archive open; use it as a
# context manager so the file handle is released once every array has been read.
with np.load(os.path.join(root, "data", "spont_data.npz")) as dat:
    spks, U, sv, V = dat["spks"], dat["U"], dat["sv"], dat["V"]
    xpos, ypos = dat["xpos"], dat["ypos"]
    tcam, tneural = dat["tcam"], dat["tneural"]
    run, beh, beh_names = dat["run"], dat["beh"], dat["beh_names"]

### Predict neural activity from behavior and run Rastermap

In [None]:
# Seed numpy, torch (CPU) and torch (CUDA) so the fits are repeatable run to run.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

### fit linear and non-linear model from behavior to neural activity
# Regression targets: V multiplied (with broadcasting) by sv.
Vfit = V.copy() * sv
n_beh, n_targets = beh.shape[-1], Vfit.shape[-1]

# Two PredictionNetwork configurations, fitted strictly in this order (the order
# matters for the seeded RNG state). Whatever the loop leaves behind
# (pred_model, y_pred_all, ve_all, itest) therefore belongs to the SECOND config.
network_configs = (
    dict(n_in=n_beh, n_kp=22, n_out=n_targets),
    dict(n_in=n_beh, n_kp=22, identity=False,
         n_filt=10, n_latents=0,
         n_out=n_targets, n_core_layers=1,
         relu_wavelets=False, relu_latents=False),
)
for config in network_configs:
    pred_model = nn_prediction.PredictionNetwork(**config)
    pred_model.to(device)
    y_pred_all, ve_all, itest = pred_model.train_model(
        beh, Vfit, tcam, tneural, delay=-1,
        learning_rate=1e-3, n_iter=400,
        device=device, verbose=True,
    )

# Keep the second configuration's test-set predictions for the comparison below.
y_pred_nn = y_pred_all.copy()

### run rastermap
# Fit Rastermap to the full spike matrix; isort is used below to reorder its rows
# (neurons), and a copy of the fitted model's `cc` attribute is kept for saving.
rastermap_kwargs = dict(n_clusters=100, n_PCs=128, locality=0.75, time_lag_window=5)
model = Rastermap(**rastermap_kwargs).fit(spks)
isort = model.isort
cc_nodes = model.cc.copy()

### bin full data into superneurons
# utils.bin1d is applied along the neuron axis (axis=0) of the sorted spikes,
# then every resulting row is z-scored across time (axis=1).
nbin = 200
sn = zscore(utils.bin1d(spks[isort], nbin, axis=0), axis=1)

# Control: the same binning after shuffling the neuron order with a fixed seed
# (the seed must be set right before the permutation is drawn).
np.random.seed(0)
shuffled_order = np.random.permutation(spks.shape[0])
sn_rand = zscore(utils.bin1d(spks[shuffled_order], nbin, axis=0), axis=1)

# sort in time: Rastermap on the transposed superneuron matrix (rows = timepoints)
model2 = Rastermap(n_PCs=128, n_clusters=100, locality=0.).fit(sn.T)
isort2 = model2.isort

### bin test data and prediction into superneurons
# Sorted loadings, used both for the prediction and for the maxstim input below.
U_sorted = U[isort]
# Select sorted rows and test columns in a single indexing step. The previous
# `spks[isort][:, itest.flatten()]` first built a full sorted copy of spks only
# to discard the training columns. np.ix_ yields the same values without that copy.
sn_test = utils.bin1d(spks[np.ix_(isort, itest.flatten())], nbin, axis=0)
# Predicted activity (sorted loadings @ predicted components), binned the same
# way as the data. Assumes y_pred_nn rows align with the test timepoints in
# itest (required for the elementwise product below) - TODO confirm.
sn_pred_test = utils.bin1d(U_sorted @ y_pred_nn.T, nbin, axis=0)
# Express the prediction on the data's scale: normalise with the DATA's
# per-superneuron mean/std. This must happen before sn_test itself is z-scored.
sn_pred_test -= sn_test.mean(axis=1, keepdims=True)
sn_pred_test /= sn_test.std(axis=1, keepdims=True)
sn_test = zscore(sn_test, axis=1)
# Per-superneuron Pearson correlation (mean of the product of two ddof=0 z-scores).
# NOTE(review): cc_pred is neither saved to spont_proc.npz nor used later here.
cc_pred = (sn_test * zscore(sn_pred_test, axis=1)).mean(axis=1)
# sort and bin PCs for maxstim estimation
U_sn = utils.bin1d(U_sorted, nbin)

### maxstim estimation for superneurons (receptive fields)
# Re-seed so the stimulus optimisation starts from the same RNG state no matter
# how much randomness earlier cells consumed.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Built from pred_model as left by the fitting loop (its final iteration).
ms_model = nn_prediction.MaxStimModel(pred_model)
# NOTE(review): if MaxStimModel is a torch nn.Module, this only sets a plain
# Python attribute and does not freeze any parameters (that would be
# `.requires_grad_(False)`, applied to the wrapped network only). Confirm whether
# MaxStimModel reads this attribute; otherwise this line has no effect.
ms_model.requires_grad = False
# torch.from_numpy keeps U_sn's dtype (no cast) before moving it to `device`.
u = torch.from_numpy(U_sn).to(device)
# Tensors created from numpy already have requires_grad=False; kept for explicitness.
u.requires_grad = False
xr = ms_model.train_batch(u, n_iter=200, learning_rate=1e-2)
# Detach from the graph and bring the optimised stimuli back to host numpy.
rfs = xr.detach().cpu().numpy()


# save results
# Every array below is written into one .npz archive; each keyword name becomes
# the name of the corresponding array inside the archive.
results = dict(
    sn_test=sn_test, sn_pred_test=sn_pred_test, itest=itest,
    sn=sn, sn_rand=sn_rand, isort2=isort2,
    rfs=rfs, isort=isort, cc_nodes=cc_nodes,
    xpos=xpos, ypos=ypos,
    tcam=tcam, tneural=tneural,
    run=run, beh=beh, beh_names=beh_names,
)
results_path = os.path.join(root, "results", "spont_proc.npz")
np.savez(results_path, **results)


### Make Figure 2

In [None]:
# fig2.fig2 reads the saved results under `root` ("results" folder);
# with save_figure=True figures go to the "figures" folder, created here first.
figures_dir = os.path.join(root, "figures/")
os.makedirs(figures_dir, exist_ok=True)
fig2.fig2(root, save_figure=True)

### Make supplementary figure

In [None]:
# Supplementary figure from the same `root` as fig2.fig2; presumably the
# random-neuron-order control built from `sn_rand` (confirm in fig2.py).
fig2.suppfig_random(root, save_figure=True)