# DMFC recordings in macaque (Sohn, Narain et al 2019)

Adapted from the original data-loading notebook at: https://neurallatents.github.io/datasets#dmfcrsg

In [None]:
## Download dataset and required packages if necessary
!pip install git+https://github.com/neurallatents/nlb_tools.git
!pip install dandi
!dandi download https://gui.dandiarchive.org/#/dandiset/000130

### Load dataset

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from nlb_tools.nwb_interface import NWBDataset
from scipy.stats import zscore

sys.path.insert(0, "/github/rastermap/paper")
import fig5

# Output directory for figures and other saved results.
# *** CHANGE THIS TO WHEREVER YOU ARE SAVING YOUR DATA OUTPUTS ***
root = "/media/carsen/ssd2/rastermap_paper/"

## Load dataset
# NWB files from the downloaded dandiset; "*train" is passed as the file pattern
# (presumably the training split — confirm against nlb_tools docs)
nwb_dir = "000130/sub-Haydn/"
dataset = NWBDataset(nwb_dir, "*train", split_heldout=False)

# Re-bin the data into 20 ms bins. The return value is not used, so the call
# presumably modifies `dataset` in place.
bin_ms = 20
dataset.resample(bin_ms)

# convert neural bin times from nanoseconds to seconds
neural_time = (dataset.data.index.to_numpy() / 1e3).astype("float") / 1e6

# convert task event times from nanoseconds to seconds (missing events stay NaN)
trial_info = dataset.trial_info
event_cols = ["ready_time", "set_time", "go_time"]
ready_time, set_time, go_time = (
    (trial_info[col].to_numpy() / 1e3).astype("float") / 1e6 for col in event_cols
)

# valid trials: ready, set and go times are all present, AND the +/- nt_sec window
# around set_time lies inside the recording (starts after t = 0, ends before the last bin)
nt_sec = 3
igood = ~trial_info[event_cols].isna().any(axis=1).to_numpy()
igood &= (set_time - nt_sec) > 0
igood &= (set_time + nt_sec - neural_time[-1]) < 0

ready_time = ready_time[igood]
set_time = set_time[igood]
go_time = go_time[igood]
# cast block labels to bool so that `~is_short` / `~is_eye` downstream are logical NOTs
is_short = trial_info.is_short.to_numpy()[igood].astype(bool)
is_eye = trial_info.is_eye.to_numpy()[igood].astype(bool)
iti = trial_info.iti.to_numpy()[igood]

assert len(is_eye) == len(is_short) == len(set_time), "per-trial arrays disagree in length"
print(f"number of trials: {len(set_time)}")

# Some spike time bins are NaN; replace each such bin (all neurons) with the
# nearest time bin that has no NaN in any neuron. Ties go to the earlier bin.
# Data is transposed to neurons x time (one copy, C-contiguous, float32).
spks = np.ascontiguousarray(dataset.data.to_numpy().T, dtype="float32")
bad_bins = np.isnan(spks).any(axis=0)  # a bin is bad if ANY neuron is NaN there
good_idx = np.flatnonzero(~bad_bins)
bad_idx = np.flatnonzero(bad_bins)
if bad_idx.size > 0:  # nothing to do (and nothing to index) when there are no NaN bins
    assert good_idx.size > 0, "every time bin contains NaNs"
    # position of each bad bin among the sorted good bins -> its left/right good neighbours
    pos = np.searchsorted(good_idx, bad_idx)
    left = good_idx[np.clip(pos - 1, 0, good_idx.size - 1)]
    right = good_idx[np.clip(pos, 0, good_idx.size - 1)]
    # nearest good neighbour; `<=` prefers the earlier bin on ties
    nearest = np.where(bad_idx - left <= right - bad_idx, left, right)
    spks[:, bad_idx] = spks[:, nearest]
print(spks.shape)

### run rastermap

In [None]:
from rastermap import Rastermap

# Rastermap settings: sort single neurons (no clustering, no grid upsampling)
rastermap_kwargs = dict(
    n_clusters=None,     # None turns off clustering and sorts single neurons
    n_PCs=48,            # use fewer PCs than neurons
    locality=0.5,        # some locality in sorting (this is a value from 0-1)
    time_lag_window=20,  # use future timepoints to compute correlation
    grid_upsample=0,     # 0 turns off upsampling since we're using single neurons
    mean_time=True,
    bin_size=1,
    time_bin=1,
)
model = Rastermap(**rastermap_kwargs).fit(spks, compute_X_embedding=True)

# neuron order used below to sort rows of the trial-aligned activity
isort = model.isort
y = model.embedding  # neurons x 1

### Reshape spikes into trials

In [None]:
# analysis window: nt_sec seconds either side of each trial's set time
bin_sec = 0.02  # bin width in seconds; must match the 20 ms resampling used when loading
nt = int(round(nt_sec / bin_sec))  # number of bins on each side of set
print(nt)

# time-bin index closest to each trial's set time, expanded into a window of
# 2*nt+1 bins centred on it -> set_idx is trials x time
set_idx = np.array([np.abs(neural_time - t).argmin() for t in set_time])
set_idx = set_idx[:, np.newaxis] + np.arange(-nt, nt + 1)
# fancy indexing already returns a copy: neurons x trials x time
spks_trials = spks[:, set_idx]

# Group trials by ready->set interval. 10 bin edges give 9 interval bins (sr = 0..8).
# Trials outside [0.46, 1.25) get sr = -1 or 9 and fall in no group.
bins = list(np.linspace(0.46, 0.85, 6))
bins.extend(list(np.arange(0.95, 1.3, 0.1)))
sr = np.digitize(set_time - ready_time, bins) - 1
short = np.asarray(is_short, dtype=bool)  # ensure `~` is a logical NOT
ttypes = [sr == i for i in range(9)]
# interval bin 4 is split by prior block: short-prior trials, then long-prior trials
ttypes[4:5] = [(sr == 4) & short, (sr == 4) & ~short]

n_per_group = [int(t.sum()) for t in ttypes]
assert min(n_per_group) > 0, f"empty trial group(s), counts per group: {n_per_group}"

# Ready time of each group, as a bin index within the trial window (edge + 0.05 s before set).
# The entry at index 4 is duplicated to match the split group.
# NOTE(review): rts ends up with len(bins) + 1 = 11 entries for 10 groups; the last one
# (edge 1.25 s) has no group — presumably ignored by fig5, confirm.
rts = (nt - (np.array(bins) + 0.05) / bin_sec).astype("int")
rts = list(rts)
rts.insert(4, rts[4])

# mean set->go interval of each group, as a bin index within the trial window
gts = np.array([(go_time[t] - set_time[t]).mean() for t in ttypes])
gts = (nt + gts / bin_sec).astype("int")

# Trial-averaged activity per group, neurons in rastermap order -> groups x neurons x time.
# Sort the neurons once rather than re-indexing the full array for every group.
spks_sorted = spks_trials[isort]
psths = np.array([spks_sorted[:, t].mean(axis=1) for t in ttypes])

# z-score each neuron jointly across all groups and timepoints
n_neurons = spks.shape[0]
psths = psths.transpose(1, 0, 2).reshape(n_neurons, -1)
psths = zscore(psths, axis=1)
psths = psths.reshape(n_neurons, len(ttypes), -1).transpose(1, 0, 2)

### make figure

In [None]:
import os  # required for os.makedirs / os.path.join below

# create the output folder, then draw figure 5
# (save_figure=True presumably writes into root/figures/ — confirm in fig5)
os.makedirs(os.path.join(root, "figures/"), exist_ok=True)
fig5.fig5(root, psths, rts, gts, save_figure=True)