In [None]:
import sys, os
import numpy as np
import torch
from scipy.stats import zscore
from neuropop import nn_prediction
from rastermap import Rastermap, utils

### use cuda version of torch if available, otherwise fall back to the cpu
# (to force the cpu use >>> device = torch.device('cpu'))
# A hard-coded 'cuda' device only fails later, when a model or tensor is moved
# to it, so the choice is made here from what this machine actually offers.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# path to paper code
sys.path.insert(0, '/github/rastermap/paper')
import fig3

# base directory for everything this notebook reads and writes
### *** CHANGE THIS TO WHEREVER YOU ARE DOWNLOADING THE DATA ***
root = "/media/carsen/ssd2/rastermap_paper/"
# make sure the "data" and "results" folders exist inside root
for subfolder in ("data", "results"):
    os.makedirs(os.path.join(root, subfolder), exist_ok=True)

### load spont data

(These data will be made available upon publication of the paper.)


In [None]:
# Load the spontaneous-activity recording from <root>/data/spont_data.npz.
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open;
# indexing it reads each array into memory, so everything is pulled out inside
# a `with` block and the archive is closed as soon as the block ends.
with np.load(os.path.join(root, "data/", "spont_data.npz")) as dat:
    # spks is indexed as (neurons, timepoints) further down in this notebook;
    # U, sv, V are presumably an SVD of spks -- TODO confirm against the data release
    spks, U, sv, V = dat["spks"], dat["U"], dat["sv"], dat["V"]
    xpos, ypos = dat["xpos"], dat["ypos"]
    # timestamps passed to the model fit and to resample_data below
    tcam, tneural = dat["tcam"], dat["tneural"]
    run, beh, beh_names = dat["run"], dat["beh"], dat["beh_names"]

### predict neural activity from behavior and run rastermap

In [None]:
# seed every RNG used below so the two fits are repeatable
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

### fit linear and non-linear model from behavior to neural activity
# regression target: V scaled by sv
Vfit = V.copy() * sv
for i in range(2):
    # The network trained on the LAST pass is the one whose prediction is stored
    # as `y_pred_nn` and whose `pred_model` is reused after this loop (maxstim
    # estimation). The reduced configuration therefore has to be pass 0, so that
    # the default network is fit second, matching the order in the heading above.
    if i==0:
        # reduced model: one core layer, no latents, ReLUs switched off
        # (presumably the linear baseline -- confirm against neuropop)
        pred_model = nn_prediction.PredictionNetwork(n_in=beh.shape[-1], n_kp=22, identity=False, 
                                                n_filt=10, n_latents=0,
                                n_out=Vfit.shape[-1], n_core_layers=1,
                                relu_wavelets=False, relu_latents=False)
    else:
        # network with the library's default architecture
        pred_model = nn_prediction.PredictionNetwork(n_in=beh.shape[-1], n_kp=22, n_out=Vfit.shape[-1], )
    pred_model.to(device)

    # train_model returns the prediction, a variance-explained value and the
    # test indices; with verbose=True the fit reports its own progress
    y_pred_all, ve_all, itest = pred_model.train_model(beh, Vfit, tcam, tneural, delay=-1,
                                                        learning_rate=1e-3, n_iter=400,
                                                    device=device, verbose=True)
    if i==1:
        # keep the prediction of the second (default) network
        y_pred_nn = y_pred_all.copy()

### run rastermap
# fit on the full spks matrix, then pull the outputs used later off the model
model = Rastermap(n_clusters=100, n_PCs=128, locality=0.0, time_lag_window=5).fit(spks)
isort = model.isort
cc_nodes = model.cc.copy()


In [None]:

### bin full data into superneurons
nbin = 50
# rows in rastermap order, averaged in groups of nbin, then z-scored per row
spks_sorted_binned = utils.bin1d(spks[isort], nbin, axis=0)
sn = zscore(spks_sorted_binned, axis=1)
# control: same binning with a seeded random row order
np.random.seed(0)
rand_order = np.random.permutation(spks.shape[0])
sn_rand = zscore(utils.bin1d(spks[rand_order], nbin, axis=0), axis=1)

# sort in time: second rastermap fit on the transposed superneuron matrix
model2 = Rastermap(n_clusters=100, locality=0., n_PCs=128).fit(sn.T)
isort2 = model2.isort

### bin test data and prediction into superneurons
test_inds = itest.flatten()
sn_test = utils.bin1d(spks[isort][:, test_inds], nbin, axis=0)
sn_pred_test = utils.bin1d(U[isort] @ y_pred_nn.T, nbin, axis=0)
# scale the prediction with the mean and std of the *measured* test activity,
# taken before sn_test itself is z-scored
test_mean = sn_test.mean(axis=1, keepdims=True)
test_std = sn_test.std(axis=1, keepdims=True)
sn_pred_test -= test_mean
sn_pred_test /= test_std
sn_test = zscore(sn_test, axis=1)
# per-superneuron mean product of the z-scored prediction and z-scored data
cc_pred = (zscore(sn_pred_test, axis=1) * sn_test).mean(axis=1)
# sort and bin PCs for maxstim estimation
U_sn = utils.bin1d(U[isort], nbin)

### maxstim estimation for superneurons (receptive fields)
# re-seed all RNGs so this step is repeatable on its own
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# wraps whichever network the fitting loop left in `pred_model` (its last pass)
ms_model = nn_prediction.MaxStimModel(pred_model)
# NOTE(review): this assigns a plain attribute called `requires_grad` on the
# object. If MaxStimModel is a torch.nn.Module, that does not freeze its
# parameters (`requires_grad_(False)` does) -- confirm what is intended here.
ms_model.requires_grad = False
# binned, rastermap-sorted U rows as a tensor on the compute device
u = torch.from_numpy(U_sn).to(device)
u.requires_grad = False
# train_batch returns a tensor; it is detached and copied to the CPU as a numpy array
xr = ms_model.train_batch(u, n_iter=200, learning_rate=1e-2)
rfs = xr.detach().cpu().numpy()


# save results
# everything goes into one archive, <root>/results/spont_proc.npz
results_file = os.path.join(root, "results", "spont_proc.npz")
np.savez(
    results_file,
    # test-set superneurons, their prediction and the test indices
    sn_test=sn_test, sn_pred_test=sn_pred_test, itest=itest,
    # full-recording superneurons (sorted / random) and the time sorting
    sn=sn, sn_rand=sn_rand, isort2=isort2,
    # maxstim output and rastermap outputs
    rfs=rfs, isort=isort, cc_nodes=cc_nodes,
    # arrays passed through unchanged from the input file
    xpos=xpos, ypos=ypos,
    tcam=tcam, tneural=tneural,
    run=run, beh=beh, beh_names=beh_names,
)


### make figure

In [None]:
# fig3 is pointed at root, whose "results" folder holds the saved results;
# figures are saved to the "figures" folder, so create it up front
figures_dir = os.path.join(root, "figures/")
os.makedirs(figures_dir, exist_ok=True)
fig3.fig3(root, save_figure=True)

### make supplementary figure (random sorting)

In [None]:
# draw the supplementary figure via fig3.suppfig_random, handing it the same
# root folder as the main figure (presumably it reads the saved results there
# -- confirm in fig3.py); save_figure=True asks for the figure to be written out
fig3.suppfig_random(root, save_figure=True)

### supp behavior sorting

In [None]:
from neuropop.utils import resample_data

# resample behavior for correlation sorting
# The result gets its own name: overwriting `beh` would make this cell resample
# already-resampled data when it is run a second time.
beh_resampled = resample_data(beh, tcam, tneural)

# frame-to-frame speed from column pairs 1:3 and 3:5 of the behavior matrix
# (presumably whisker and nose x/y positions -- confirm against beh_names)
whisk_speed = (np.diff(beh_resampled[:,1:3], axis=0)**2).sum(axis=-1)**0.5
nose_speed = (np.diff(beh_resampled[:,3:5], axis=0)**2).sum(axis=-1)**0.5
# np.diff drops one sample, so the other two variables are trimmed by one to match
# (named beh_vars so the builtin `vars` is not shadowed)
beh_vars = np.stack((run[:-1], whisk_speed, nose_speed, beh_resampled[:-1,0]), axis=-1)

# train mask over the trimmed time axis: everything that is not a test index
itest_flat = itest.flatten()
itrain = np.ones(spks.shape[1]-1, "bool")
itrain[itest_flat] = False
# training slice of spks is the same for every variable, so build it once
spks_train = spks[:,:-1][:,itrain]

isorts_beh = []
corrs_beh = []
sn_beh = []
for i in range(beh_vars.shape[1]):
    # mean product of each neuron's training activity with the z-scored variable
    corr = (spks_train * zscore(beh_vars[itrain,i])).mean(axis=1)
    # local name on purpose: `isort` holds the rastermap sorting from the earlier cell
    isort_beh = corr.argsort()
    isorts_beh.append(isort_beh)
    corrs_beh.append(corr)
    # superneurons on the test indices, with rows ordered by this correlation
    sn_beh.append(zscore(utils.bin1d(spks[isort_beh][:,itest_flat], nbin, axis=0), axis=1))
del spks_train  # large temporary, not needed past the loop
corrs_beh = np.array(corrs_beh)
isorts_beh = np.array(isorts_beh)
sn_beh = np.array(sn_beh)


# archive keys are unchanged ("vars", "sn_beh", "corrs_beh", "isorts_beh")
np.savez(os.path.join(root, "results/spont_corrs_beh.npz"), 
         vars=beh_vars,
         sn_beh=sn_beh,
         corrs_beh=corrs_beh,
         isorts_beh=isorts_beh)


In [None]:
# make figure
# fig3.suppfig_beh is given the same root folder that the cell above saved
# results/spont_corrs_beh.npz into (presumably it loads that file -- confirm in
# fig3.py); save_figure=True asks for the figure to be written out
fig3.suppfig_beh(root, save_figure=True)


### supp locality changes

In [None]:
import imp 
from metrics import distance_matrix
from tqdm import tqdm

localities = [0, 0.25, 0.5, 0.75, 1.0]


def _sorting_scores(cc_nodes, BBt):
    """Score one fitted Rastermap model; returns ``np.array([score_global, score_local])``.

    Parameters
    ----------
    cc_nodes : square 2D array
        The model's ``cc`` matrix. It is not modified (a copy is edited).
    BBt : square 2D array
        The model's ``BBt`` matrix, same size as ``cc_nodes``.

    Returns
    -------
    numpy array of length 2
        ``score_global`` is the mean product of the z-scored strict upper
        triangles of ``BBt`` and ``cc_nodes`` (their Pearson correlation).
        ``score_local`` is built from neighbour ranks: for node ``j`` the rank
        (0 = largest) of node ``j+1`` within row ``j`` and of node ``j-1``
        within column ``j``, diagonal zeroed. A node adds 1 when either rank is
        0, and 1 more when, in addition, either rank is 1. The sum is divided
        by the number of ranks computed.
    """
    inds = np.triu_indices(BBt.shape[0], k=1)
    cc = cc_nodes[inds[0], inds[1]].copy()
    bb = BBt[inds[0], inds[1]]
    score_global = (zscore(bb) * zscore(cc)).mean()

    cc = cc_nodes.copy()
    cc -= np.diag(np.diag(cc))
    score_local = 0
    ntot = 0
    # NaN marks the two ranks that do not exist (first column, last row);
    # NaN compares unequal to 0 and 1, so those entries never score
    lscores = np.nan * np.zeros((len(cc), 2))
    for j in range(len(cc)):
        if j < len(cc)-1:
            csort = cc[j].argsort()[::-1]
            lscores[j, 0] = np.nonzero(csort==(j+1))[0][0]
            ntot += 1
        if j > 0:
            csort = cc[:,j].argsort()[::-1]
            lscores[j, 1] = np.nonzero(csort==(j-1))[0][0]
            ntot += 1
        if lscores[j,0]==0 or lscores[j,1]==0:
            score_local += 1
            if lscores[j,0]==1 or lscores[j,1]==1:
                score_local += 1
    score_local /= ntot

    return np.array([score_global, score_local])


### spontaneous activity: refit rastermap at every locality
# Results are read straight off the model so that `cc_nodes` / `isort` from the
# earlier cells are not overwritten. BBt matrices are not collected here: only
# the ones from the VR fits below are printed and saved.
ccs_spont, isorts_spont, scores_spont = [], [], []
for loc in tqdm(localities):  # iterate the list itself so the bar knows its length
    model = Rastermap(n_clusters=100, n_PCs=128, locality=loc,
                        time_lag_window=5, verbose=False).fit(spks)
    ccs_spont.append(model.cc.copy())
    isorts_spont.append(model.isort)

    scores_spont.append(_sorting_scores(model.cc, model.BBt))
    print(scores_spont[-1])


### virtual-reality corridor data: same sweep
# Loaded under its own name. Reusing `spks` here would make a second run of
# this cell fit the "spont" models above on the corridor recording.
with np.load(os.path.join(root, "data/", "corridor_neur.npz")) as dat_vr:
    spks_vr = zscore(dat_vr["spks"], axis=1)

isorts_vr, ccs_vr, BBts, scores_vr = [], [], [], []
for loc in tqdm(localities):
    model = Rastermap(n_clusters=100, n_PCs=200, 
                        time_lag_window=10, 
                        locality=loc, bin_size=100).fit(spks_vr)
    isorts_vr.append(model.isort)
    ccs_vr.append(model.cc)
    BBts.append(model.BBt)

    scores_vr.append(_sorting_scores(model.cc, model.BBt))
    print(scores_vr[-1])

# one off-diagonal BBt entry at the first and at the last locality
print(BBts[0][0,1], BBts[-1][0,1])

np.savez(os.path.join(root, "results/asym_vr_spont.npz"), 
         ccs_vr=np.array(ccs_vr), 
         ccs_spont=np.array(ccs_spont),
         isorts_vr=np.array(isorts_vr), 
         isorts_spont=np.array(isorts_spont), 
         localities=localities, 
         BBts=np.array(BBts),
         scores_vr=np.array(scores_vr),
         scores_spont=np.array(scores_spont))


In [None]:
# make figure
# fig3.suppfig_locality is given the same root folder that the cell above saved
# results/asym_vr_spont.npz into (presumably it loads that file -- confirm in
# fig3.py); save_figure=True asks for the figure to be written out
fig3.suppfig_locality(root, save_figure=True)