
# MC_Maze Dataset

[DANDI](https://dandiarchive.org/#/dandiset/000128)

## 1 Overview

The MC_Maze dataset includes data from four recording sessions of a macaque performing delayed center-out reaches, with neural activity recorded from the primary motor and dorsal premotor cortices. This data was provided by Krishna Shenoy, Mark Churchland, and Matt Kaufman from Stanford University, and you can learn more about the task design, data collection, and their analyses of the data in a number of papers, including [this](https://pubmed.ncbi.nlm.nih.gov/21040842/) (Churchland et al. 2010).

### 1.1 Task

The maze task is a delayed center-out reaching task, meaning that there was a period between target presentation and the go cue when the monkey could plan and prepare its movement. The reaches took place in a number of maze configurations, resulting in a variety of straight and curved reaches. In some trials, three targets were presented, though only one was reachable, further complicating the task for the monkey. The delayed reaching paradigm allows for the examination of neural activity during both movement preparation and execution.

### 1.2 Data

For these datasets, neural activity was recorded from two Utah arrays: one implanted in the dorsal premotor cortex, which is thought to play a role in movement planning, and one in the primary motor cortex. The recorded data was spike sorted offline into the provided unit spike times. In addition to the neural data, cursor position and the monkey's hand and gaze positions were recorded during the experiment, and we estimated the hand velocity offline from the recorded hand position.

The MC_Maze datasets are entirely trialized, and no data was recorded between trials. As a result, though the data is presented here as a single continuous block, trials are separated by NaN margins to indicate when the data is discontinuous. In addition, in three of our dataset files, we reduced the number of trials to a fixed amount for evaluation of model performance on limited data.
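
Because trials are delimited by NaN margins, contiguous trial segments can be recovered from the continuous block by finding runs of rows that contain no NaNs. Below is a minimal sketch on a small synthetic DataFrame; the helper `split_on_nan` and the column name are hypothetical, and in practice the `nlb_tools` method `make_trial_data` (used later in this notebook) extracts trial-aligned data directly.

```python
import numpy as np
import pandas as pd


def split_on_nan(frame):
    """Split a DataFrame into contiguous blocks of rows without NaNs (hypothetical helper)."""
    valid = frame.notna().all(axis=1).to_numpy()
    # A segment starts at every valid row whose predecessor is invalid (or missing)
    starts = valid & ~np.r_[False, valid[:-1]]
    segment_id = np.cumsum(starts)
    return [frame[valid & (segment_id == k)] for k in range(1, segment_id.max() + 1)]


# Synthetic example: two "trials" separated by a NaN margin
toy_data = pd.DataFrame({"hand_pos_x": [0.0, 1.0, 2.0, np.nan, np.nan, 5.0, 6.0]})
toy_segments = split_on_nan(toy_data)
print([len(s) for s in toy_segments])  # [3, 2]
```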

## 2 Exploring the data

### 2.1 Setup

First, let's make the necessary imports and load the dataset.

In [None]:
## Download dataset and required packages if necessary
!pip install git+https://github.com/neurallatents/nlb_tools.git
!pip install dandi
!dandi download https://gui.dandiarchive.org/dandiset/000128

In [None]:
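import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='.*unverified HTTPS request.*')
# nlb_tools pins pandas<=1.3.4, which may not install cleanly on current runtimes;
# force-reinstall a newer pandas version afterwards
!pip install pandas==2.2.2 --force-reinstall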

In [None]:
## Imports

# %matplotlib widget # uncomment for interactive plots
from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

After running the `pandas` reinstall cell above, re-run the Imports cell to ensure all libraries are loaded with the correct `pandas` version. If the issue persists, you might need to restart the runtime and run all cells up to that point.

In [None]:
## Load dataset
dataset = NWBDataset("/content/000128/sub-Jenkins/", "*train", split_heldout=False)

In [None]:
#prepping the dataset
dataset_copy = dataset.data.copy()
df = dataset_copy.reset_index()
times_trials = dataset.trial_info.copy()  # copy so the ms columns added below don't modify dataset.trial_info

#converting times into milliseconds
df["milliseconds"]=df["clock_time"].dt.total_seconds()*1000
times_trials["start_ms"]=times_trials["start_time"].dt.total_seconds()*1000
times_trials["end_ms"]=times_trials["end_time"].dt.total_seconds()*1000

In [None]:
#mapping each millisecond to the trial index based on the trial start & end times
df['corresponding_trial'] = -1  # -1 marks bins that fall outside every trial

for idx, row in times_trials.iterrows():
    mask = (df['milliseconds'] >= row['start_ms']) & (df['milliseconds'] <= row['end_ms'])
    df.loc[mask, 'corresponding_trial'] = row['trial_id']
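
The loop above rescans the whole time column once per trial. Because trials do not overlap, the same mapping can be sketched with a single binary search via `np.searchsorted`, assuming the trial table is sorted by start time. The helper `map_to_trials` and the toy numbers are illustrative only.

```python
import numpy as np


def map_to_trials(times_ms, start_ms, end_ms, trial_ids):
    """Return the trial ID for each time point, or -1 if it falls outside every trial.

    Assumes trials are sorted by start time and do not overlap.
    """
    # Index of the last trial that starts at or before each time point
    idx = np.searchsorted(start_ms, times_ms, side="right") - 1
    out = np.full(len(times_ms), -1, dtype=int)
    inside = idx >= 0
    # Of those, keep only time points at or before that trial's end (inclusive, as in the loop)
    inside[inside] = times_ms[inside] <= end_ms[idx[inside]]
    out[inside] = trial_ids[idx[inside]]
    return out


toy_times = np.array([0, 5, 10, 15, 20, 25, 30], dtype=float)
toy_starts = np.array([0, 20], dtype=float)
toy_ends = np.array([10, 25], dtype=float)
toy_trial_ids = np.array([0, 1])
print(map_to_trials(toy_times, toy_starts, toy_ends, toy_trial_ids))  # [ 0  0  0 -1  1  1 -1]
```

With this notebook's variables, the call should look like `map_to_trials(df['milliseconds'].to_numpy(), times_trials['start_ms'].to_numpy(), times_trials['end_ms'].to_numpy(), times_trials['trial_id'].to_numpy())`, provided `times_trials` is sorted by `start_ms`.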

In [None]:
neurons = df["spikes"].columns.to_list() #list of recorded neurons
n_dict = {} #for each neuron, the IDs of the trials in which it fired at least once
for N in neurons:
    neuron_trials = df[df[("spikes", N)] > 0.0]["corresponding_trial"].unique()
    n_dict[N]=neuron_trials

In [None]:
#defining a padding function to make the lists equal in length in order to create a df
def pad_lists(list_dict):
    max_len = max(len(v) for v in list_dict.values())
    padded = {
        k: list(v) + [np.nan] * (max_len - len(v))   # convert to list first
        for k, v in list_dict.items()
    }
    return pd.DataFrame(padded)
n_df = pad_lists(n_dict)

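#calculating the correlation matrix across the padded columns. note: row k of each column
#holds the k-th trial (in order of occurrence) in which that neuron fired, so this compares
#positions in the trial lists rather than whether two neurons fired in the same trials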
corr = n_df.corr()
print(corr)
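
To measure whether two neurons tend to fire in the same trials, one option is a binary trial-by-neuron matrix (1 if the neuron fired at least once in that trial, else 0) whose columns are then correlated. The sketch below uses a small synthetic `n_dict`-style input; the function name `trial_cofiring_corr` is hypothetical.

```python
import numpy as np
import pandas as pd


def trial_cofiring_corr(trials_by_neuron, all_trials):
    """Correlate neurons by whether they fired in the same trials.

    trials_by_neuron: dict mapping neuron ID -> array of trial IDs with at least one spike
    all_trials: every trial ID to include (trials where a neuron is silent count as 0)
    """
    fired = pd.DataFrame(
        {n: np.isin(all_trials, t).astype(float) for n, t in trials_by_neuron.items()},
        index=all_trials,
    )
    # Pearson correlation of binary columns (the phi coefficient);
    # a neuron that fires in every trial has zero variance, so its correlations are NaN
    return fired.corr()


toy_trials = {"n1": np.array([0, 1, 2]), "n2": np.array([0, 1, 2]), "n3": np.array([3, 4])}
print(trial_cofiring_corr(toy_trials, np.arange(6)).round(2))
```

With the notebook's variables this would be `trial_cofiring_corr(n_dict, times_trials['trial_id'].to_numpy())`; entries of -1 (bins outside any trial) in `n_dict` are ignored because -1 is not a trial ID.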

In [None]:
#possible reason for the result: lots of zeros in the matrix
#strategy shift: instead of focusing on trial-based firing correspondence,
#look at the finer timescale of milliseconds, i.e., which neurons consistently
#fire together within a few milliseconds of each other?

In [None]:
# creating a spike train array for each neuron, i.e., the neuron's spike count in each 1 ms bin
# (fillna(0) guards against NaN margins between trials before casting to integers)
spikes = {nid: df[("spikes", nid)].fillna(0).to_numpy().astype(np.uint8)
          for nid in neurons}

In [None]:
#defining a cross-correlogram (CCG) function: how are the spike times of one neuron related
#to the spike times of another across the recording? useful for inferring functional connections

#how does it work:
#step 1: expand each spike count array into an array of spike times (one entry per spike)
#step 2: set the range of lags: we chose -80 ms to +80 ms, in line with the best-performing lag reported in the paper
#step 3: for each spike of neuron a, compute its time difference to every spike of neuron b and keep the differences within the lag range
#step 4: count how many (a, b) spike pairs fall at each lag (positive lags: b fires after a)
def ccg_count(a, b, max_lag=80):
    spk_a = np.repeat(np.arange(len(a)), a)
    spk_b = np.repeat(np.arange(len(b)), b)

    cc = np.zeros(2*max_lag+1, dtype=np.int32)

    for t in spk_a:
        diffs = spk_b - t
        valid = diffs[(diffs >= -max_lag) & (diffs <= max_lag)]
        np.add.at(cc, valid + max_lag, 1)

    lags = np.arange(-max_lag, max_lag+1)
    return lags, cc
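
The per-spike loop in `ccg_count` does work proportional to the product of the two neurons' spike counts. Since both inputs are binned spike counts, the same histogram can be sketched as one dot product per lag: the count at lag k is the sum over t of `a[t] * b[t + k]`. The name `ccg_count_fast` is ours, and the inputs are cast to `int64` because `uint8` products would silently overflow.

```python
import numpy as np


def ccg_count_fast(a, b, max_lag=80):
    """Cross-correlogram of two equal-length binned spike-count arrays via lagged dot products.

    Positive lags count spikes of b that occur after spikes of a,
    matching the sign convention of ccg_count (diffs = spk_b - t).
    """
    a = np.asarray(a, dtype=np.int64)  # uint8 products would silently overflow
    b = np.asarray(b, dtype=np.int64)
    n = len(a)
    lags = np.arange(-max_lag, max_lag + 1)
    cc = np.zeros(len(lags), dtype=np.int64)
    for i, k in enumerate(lags):
        if abs(k) >= n:
            continue  # no overlapping bins at this lag
        if k >= 0:
            cc[i] = np.dot(a[: n - k], b[k:])   # pairs (t, t + k)
        else:
            cc[i] = np.dot(a[-k:], b[: n + k])  # pairs (t, t + k) with k < 0
    return lags, cc


# Usage on synthetic 1 ms spike counts
rng = np.random.default_rng(0)
toy_a = rng.poisson(0.05, size=2000).astype(np.uint8)
toy_b = rng.poisson(0.05, size=2000).astype(np.uint8)
toy_lags, toy_cc = ccg_count_fast(toy_a, toy_b, max_lag=10)
```

Each lag costs one pass over the arrays, so the runtime scales with recording length times the number of lags rather than with the product of spike counts; which version is faster in practice depends on firing rates and recording length.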

In [None]:
#jitter correction: a precaution against spurious correlation
#how does it work?
#two neurons can show a broad CCG peak just because both change their firing rates on slow
#timescales (e.g., both become more active around movement), without any precise spike-timing
#relationship between them. jittering each spike by a random offset of up to ±window ms keeps
#these slow rate changes but destroys millisecond-precise timing. the average CCG of many
#jittered copies is therefore a baseline for what slow rate changes and chance alone produce;
#subtracting it leaves only the excess that depends on precise spike timing.
#analogy: my classmate and I both arrive at WLH around the same time every day because we
#have the same class (a shared slow drive), and now and then I bump into a friend who just
#happens to be walking by (chance). if every arrival time is shifted randomly by a few minutes,
#my classmate and I still arrive around class time and I still occasionally run into friends,
#so both kinds of coincidence stay in the baseline and get subtracted. only timing that is
#precise on the scale of the jitter, e.g. two people who always walk through the door
#together, is destroyed by jittering and therefore survives the subtraction.
def jitter_spikes_count(train, window=20):
    spk_times = np.repeat(np.arange(len(train)), train)  # one entry per spike
    jit = np.zeros_like(train)

    for t in spk_times:
        # shift the spike by a uniform random offset in [-window, window] ms
        t_new = t + np.random.randint(-window, window+1)
        if 0 <= t_new < len(train):  # spikes pushed outside the recording are dropped
            jit[t_new] += 1  # maintain multiple spikes per bin
    return jit

#applying jittering to the ccg output: raw CCG minus the mean CCG of jittered spike trains
def jitter_corrected_ccg(a, b, max_lag=80, window=20, n_shuffles=50):
    lags, raw = ccg_count(a, b, max_lag)

    baseline = np.zeros_like(raw, dtype=float)
    for _ in range(n_shuffles):
        aj = jitter_spikes_count(a, window)
        bj = jitter_spikes_count(b, window)
        _, sh = ccg_count(aj, bj, max_lag)
        baseline += sh

    baseline /= n_shuffles
    return lags, raw - baseline
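
To sanity-check jitter correction, one can plant a known synchrony in synthetic data. In the sketch below, all data are synthetic and `ccg`/`jitter` are our own vectorized re-implementations of the idea rather than the functions above: neuron b copies every spike of neuron a 2 ms later, on top of independent background spikes. After subtracting the average CCG of jittered spike trains, the peak should remain at +2 ms.

```python
import numpy as np

jitter_rng = np.random.default_rng(1)  # seeded so the example is reproducible


def ccg(a, b, max_lag):
    """Spike-pair counts at each lag (b relative to a) via lagged dot products; assumes max_lag < len(a)."""
    a = np.asarray(a, dtype=np.int64)
    b = np.asarray(b, dtype=np.int64)
    n = len(a)
    lags = np.arange(-max_lag, max_lag + 1)
    cc = np.array([np.dot(a[: n - k], b[k:]) if k >= 0 else np.dot(a[-k:], b[: n + k]) for k in lags])
    return lags, cc


def jitter(train, window):
    """Shift each spike by a uniform random offset in [-window, window]; drop spikes pushed out of range."""
    t = np.repeat(np.arange(len(train)), train)
    t = t + jitter_rng.integers(-window, window + 1, size=len(t))
    t = t[(t >= 0) & (t < len(train))]
    return np.bincount(t, minlength=len(train))


T, max_lag, window, n_shuffles = 20_000, 20, 20, 50
syn_a = (jitter_rng.random(T) < 0.02).astype(np.int64)  # ~400 background spikes
syn_b = (jitter_rng.random(T) < 0.02).astype(np.int64)
syn_b[2:] += syn_a[:-2]  # planted: b also fires 2 ms after every spike of a

lags, raw = ccg(syn_a, syn_b, max_lag)
baseline = np.mean(
    [ccg(jitter(syn_a, window), jitter(syn_b, window), max_lag)[1] for _ in range(n_shuffles)], axis=0
)
corrected = raw - baseline
print("peak lag (ms):", lags[np.argmax(corrected)])
```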

In [None]:
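#computes a directed functional connectivity score for every ordered pair of neurons (i, j):
#how much more often j spikes shortly after i than the jittered baseline predicts
#note: this runs (1 + n_shuffles) CCGs for each of the N*(N-1) pairs, so it can be very slow on the full dataset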
def compute_connectivity(spikes, max_lag=80):
    ids = list(spikes.keys())
    N = len(ids)
    conn = np.zeros((N, N))

    for i in range(N):
        for j in range(N):
            if i == j:
                continue

            lags, corr = jitter_corrected_ccg(spikes[ids[i]], spikes[ids[j]], max_lag)

            # connectivity metric: mean corrected CCG over lags +1 to +6 ms (slice end is exclusive)
            c = corr[max_lag+1 : max_lag+7].mean()
            conn[i, j] = c

    return ids, conn

In [None]:
!pip install networkx

In [None]:
import networkx as nx

ids, conn = compute_connectivity(spikes)

# convert to an undirected graph with an edge wherever the connectivity score is positive
# (no significance test is applied; a stricter threshold than 0 could be used instead)
G = nx.from_numpy_array(conn > 0)

# relabel nodes from indices to neuron IDs
mapping = {i: ids[i] for i in range(len(ids))}
G = nx.relabel_nodes(G, mapping)

# community detection by greedy modularity maximization (Clauset-Newman-Moore)
from networkx.algorithms.community import greedy_modularity_communities
communities = greedy_modularity_communities(G)

for i, community in enumerate(communities):
    print(f"Community {i+1}: {sorted(community)}")
conn_df = pd.DataFrame(conn, index=ids, columns=ids)
print(conn_df.head())
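
To see what `greedy_modularity_communities` returns on a thresholded connectivity matrix, here is a toy example with two planted groups of neurons, strongly connected within each group and not across groups. The matrix and labels are made up, and the `toy_` names avoid overwriting `conn`, `G`, and `communities` from the cell above.

```python
import numpy as np
import networkx as nx
from networkx.algorithms.community import greedy_modularity_communities

# Toy 6x6 connectivity matrix: neurons 0-2 and 3-5 form two groups (made-up values)
toy_conn = np.full((6, 6), -0.1)
toy_conn[:3, :3] = 1.0
toy_conn[3:, 3:] = 1.0
np.fill_diagonal(toy_conn, 0.0)  # no self-connections

# Edge wherever the score is positive; from_numpy_array builds an undirected Graph by default
toy_G = nx.from_numpy_array((toy_conn > 0).astype(int))
toy_G = nx.relabel_nodes(toy_G, {i: f"n{i}" for i in range(6)})

toy_communities = greedy_modularity_communities(toy_G)
print(sorted(sorted(c) for c in toy_communities))  # [['n0', 'n1', 'n2'], ['n3', 'n4', 'n5']]
```

Note that with a threshold of 0, every weakly positive (possibly noisy) score in real data becomes an edge, which can blur such block structure.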


In [None]:
pos = nx.spring_layout(G, seed=42)
colors = ['red', 'blue', 'green', 'orange', 'purple', 'cyan', 'magenta']
plt.figure(figsize=(8,6))

for i, community in enumerate(communities):
    nx.draw_networkx_nodes(G, pos, nodelist=community,
                           node_color=colors[i % len(colors)], node_size=500)
nx.draw_networkx_edges(G, pos)
nx.draw_networkx_labels(G, pos)
plt.title("Neuron Communities")
plt.show()

### 2.2 Continuous data

The continuous data provided with the MC_Maze datasets includes:
* `cursor_pos` - x and y position of the cursor controlled by the monkey
* `eye_pos` - x and y position of the monkey's point of gaze on the screen, in mm
* `hand_pos` - x and y position of the monkey's hand, in mm
* `hand_vel` - x and y velocities of the monkey's hand, in mm/s, computed offline using `np.gradient`
* `spikes` - spike counts in 1 ms bins, one column per sorted unit
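
`np.gradient` estimates derivatives with central differences in the interior and one-sided differences at the edges. A minimal sketch of turning positions sampled every 1 ms (in mm) into velocities (in mm/s) on synthetic data; the dataset authors' exact preprocessing, such as any filtering, is not described here and may differ.

```python
import numpy as np

dt = 0.001                    # sample spacing in seconds (1 ms bins)
t_s = np.arange(10) * dt      # 10 time samples, in seconds
# Synthetic hand position (mm): x moves at 100 mm/s, y stays at 5 mm
toy_pos = np.column_stack([100.0 * t_s, np.full_like(t_s, 5.0)])

# Differentiate each column along time; passing dt converts mm per sample to mm/s
toy_vel = np.gradient(toy_pos, dt, axis=0)
print(toy_vel[:3])
```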

In [None]:
## View 'dataset.data'
dataset.data

### 2.3 Trial metadata

The trial info dataframe has a number of fields containing information about each trial:
* `trial_id` - a number assigned to each trial during loading
* `start_time` - time when the trial begins
* `end_time` - time when the trial ends
* `trial_type` - the maze configuration that was used for the trial
* `trial_version` - a number 0-2 indicating which variant of the maze is presented. 0 is 1-target no-barrier, 1 is 1-target with barriers, 2 is 3-target with barriers
* `maze_id` - a unique identifier for the maze configuration used. Different maze sets were used for each session, so `trial_type` is not unique across dataset files
* `success` - whether the trial was successful. In provided training data, unsuccessful trials have already been removed
* `target_on_time` - time of target presentation
* `go_cue_time` - time of go cue
* `move_onset_time` - time of movement onset, calculated offline with a robust algorithm
* `rt` - reaction time in ms
* `delay` - time between target presentation and go cue in ms
* `num_targets` - number of targets displayed in the maze
* `target_pos` - x and y position of the target(s)
* `num_barriers` - number of barriers in the maze
* `barrier_pos` - position of the barrier(s). First two values are the x and y positions of the center of the barrier, last two values are the half-width and half-height of the barrier
* `active_target` - which target is reachable and was hit by the monkey. Its value corresponds to the index of the target in `target_pos`
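
As a worked example of these fields, the sketch below picks the reachable target out of `target_pos` using `active_target`, computes its reach angle with `np.arctan2` (as the plotting code later in this notebook does), and converts a `barrier_pos` entry from center and half-extents to rectangle corners. The values are synthetic, and we assume `target_pos` holds one (x, y) row per target, as the notebook's own indexing `target_pos.iloc[0][active_target.iloc[0]]` suggests.

```python
import numpy as np

# Synthetic trial: three targets, of which the one at index 1 is reachable
toy_target_pos = np.array([[-100.0, 0.0], [0.0, 80.0], [100.0, 0.0]])  # one (x, y) row per target
toy_active_target = 1
toy_barrier_pos = np.array([0.0, 40.0, 30.0, 5.0])  # x center, y center, half-width, half-height

# Reachable target and its reach angle (radians, counterclockwise from +x)
toy_target_xy = toy_target_pos[toy_active_target]
toy_reach_angle = np.arctan2(toy_target_xy[1], toy_target_xy[0])

# Barrier rectangle as (x_min, y_min, x_max, y_max)
cx, cy, hw, hh = toy_barrier_pos
toy_barrier_rect = (cx - hw, cy - hh, cx + hw, cy + hh)

print(np.degrees(toy_reach_angle))
```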

In [None]:
## View 'dataset.trial_info'
dataset.trial_info

### 2.4 Reach conditions

The full MC_Maze dataset has 108 different reach conditions, and the reduced-size datasets each have 27. Because of the maze barriers, reaches take on a variety of straight and curved trajectories. Here, we'll plot the average trajectory per condition to see what typical reaches look like.

In [None]:
## Optional resampling
# It may be beneficial to resample the data before you proceed to the analysis sections,
# as they may be fairly memory-intensive. However, we have not tested this notebook at bin sizes
# of over 20 ms, so we cannot guarantee that everything will work as intended at
# those larger bin sizes. If you still have memory issues, you could also use the reduced size
# datasets instead of the full one.
dataset.resample(5)
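
For spike counts, rebinning from 1 ms to 5 ms amounts to summing each group of five consecutive bins, which preserves total spike counts. The sketch below shows only that idea on a synthetic array with a hypothetical helper `rebin_counts`; `nlb_tools`' `resample` also handles the continuous fields, and its exact implementation may differ.

```python
import numpy as np


def rebin_counts(counts, factor):
    """Sum consecutive groups of `factor` bins along axis 0; leftover trailing bins are dropped."""
    n_keep = (len(counts) // factor) * factor
    return counts[:n_keep].reshape(-1, factor, *counts.shape[1:]).sum(axis=1)


spk_1ms = np.array([[0, 1], [1, 0], [0, 0], [2, 0], [0, 1], [1, 1], [0, 0]])  # 7 bins x 2 neurons
print(rebin_counts(spk_1ms, 5))  # [[3 2]]
```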

In [None]:
## Plot trial-averaged reaches

# Find unique conditions
conds = dataset.trial_info.set_index(['trial_type', 'trial_version']).index.unique().tolist()

# Initialize plot
fig = plt.figure(figsize=(6, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

# Loop over conditions and compute average trajectory
for cond in conds:
    # Find trials in condition
    mask = np.all(dataset.trial_info[['trial_type', 'trial_version']] == cond, axis=1)
    # Extract trial data
    trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-50, 450), ignored_trials=(~mask))
    # Average hand position across trials
    traj = trial_data.groupby('align_time')[[('hand_pos', 'x'), ('hand_pos', 'y')]].mean().to_numpy()
    # Determine reach angle for color
    active_target = dataset.trial_info[mask].target_pos.iloc[0][dataset.trial_info[mask].active_target.iloc[0]]
    reach_angle = np.arctan2(*active_target[::-1])
    # Plot reach
    ax.plot(traj[:, 0], traj[:, 1], linewidth=0.7, color=plt.cm.hsv(reach_angle / (2*np.pi) + 0.5))

plt.axis('off')
plt.show()

### 2.5 Single-neuron responses

As shown above, there are a large number of conditions in the full MC_Maze dataset which differ in the position and number of targets and barriers. While many of the reaches follow similar trajectories, the differences in the exact positions of targets and barriers may result in slight differences in observed neural activity, especially during pre-movement planning periods. Here, we'll plot PSTHs for a single neuron for a subset of the conditions.
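
Concretely, each PSTH below is a trial average of Gaussian-smoothed spike counts, rescaled from spikes per bin to spikes per second by dividing by the bin width in ms and multiplying by 1000. A minimal sketch on synthetic counts, using `scipy.ndimage.gaussian_filter1d` as a stand-in for `dataset.smooth_spk` (whose kernel details, such as truncation and edge handling, may differ):

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

rng = np.random.default_rng(0)
bin_width = 5        # ms per bin (as after dataset.resample(5))
sigma_ms = 50        # Gaussian kernel standard deviation in ms
n_trials, n_bins = 40, 100

# Synthetic spike counts: 20 spk/s for the first 50 bins, then 60 spk/s
toy_rate_hz = np.where(np.arange(n_bins) < 50, 20.0, 60.0)
toy_counts = rng.poisson(toy_rate_hz * bin_width / 1000, size=(n_trials, n_bins))

# Smooth each trial along time (sigma converted from ms to bins), then average across trials
toy_smoothed = gaussian_filter1d(toy_counts.astype(float), sigma=sigma_ms / bin_width, axis=1)
toy_psth = toy_smoothed.mean(axis=0) / bin_width * 1000  # spikes per bin -> spikes per second
```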

In [None]:
## Plot PSTHs

# Seed generator for consistent plots
np.random.seed(2468)
n_conds = 8 # number of conditions to plot

# Smooth spikes with 50 ms std Gaussian
dataset.smooth_spk(50, name='smth_50')

# Plot random neuron
neur_num = np.random.choice(dataset.data.spikes.columns)

# Find unique conditions
conds = dataset.trial_info.set_index(['trial_type', 'trial_version']).index.unique().tolist()

# Plot random subset of conditions
for i in np.random.choice(len(conds), size=n_conds, replace=False):
    cond = conds[i]
    # Find trials in condition
    mask = np.all(dataset.trial_info[['trial_type', 'trial_version']] == cond, axis=1)
    # Extract trial data
    trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-50, 450), ignored_trials=(~mask))
    # Average smoothed spikes across trials and convert from spikes/bin to spikes/s
    psth = trial_data.groupby('align_time')[[('spikes_smth_50', neur_num)]].mean().to_numpy() / dataset.bin_width * 1000
    # Color PSTHs by reach angle
    active_target = dataset.trial_info[mask].target_pos.iloc[0][dataset.trial_info[mask].active_target.iloc[0]]
    reach_angle = np.arctan2(*active_target[::-1])
    # Plot PSTH
    plt.plot(np.arange(-50, 450, dataset.bin_width), psth, label=cond, color=plt.cm.hsv(reach_angle / (2*np.pi) + 0.5))

# Add labels
plt.ylim(bottom=0)
plt.xlabel('Time after movement onset (ms)')
plt.ylabel('Firing rate (spk/s)')
plt.title(f'Neur {neur_num} PSTH')
plt.legend(title='condition', loc='upper right')
plt.show()

As you can see, this neuron displays a variety of responses across conditions. However, the information provided by a single neuron is limited, and this method relies on averaging across trials, which discards single-trial variability as noise. Instead of averaging across trials, we can use the activity of the whole neural population on a single-trial basis to extract behaviorally relevant information.

### 2.6 Decoding hand kinematics

Next, we'll try to decode the monkey's hand velocity solely from smoothed population spiking activity. Since it takes time for signals to travel from the motor cortex to muscles, we lag the kinematics data relative to neural data. 80 ms lag is where we have seen the best results, but feel free to vary the value and compare performance.
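
The effect of lagging can be seen on synthetic data: if velocity is generated from population activity 80 ms earlier, a decoder that pairs each neural sample with the velocity 80 ms later should fit best at that lag. The sketch below builds such data at 1 ms resolution (all values synthetic, with the 80 ms delay planted by construction) and sweeps candidate lags with `Ridge`, scoring on held-out samples. In the cell below, the same lag is implemented by extracting the neural window (`align_range=(-130, 370)`) 80 ms earlier than the kinematic window (`align_range=(-50, 450)`).

```python
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
T, n_neurons, true_lag = 5000, 30, 80           # 1 ms samples; planted 80 ms delay
toy_rates = rng.normal(size=(T, n_neurons))     # synthetic "neural activity"
W = rng.normal(size=(n_neurons, 2))             # synthetic readout weights
toy_vel = np.zeros((T, 2))
toy_vel[true_lag:] = toy_rates[:-true_lag] @ W  # vel[t] is driven by activity at t - 80 ms
toy_vel += 0.1 * rng.normal(size=toy_vel.shape)


def lagged_r2(lag):
    """Pair activity at t with velocity at t + lag; fit on the first half, report R^2 on the second."""
    X, y = toy_rates[: T - lag], toy_vel[lag:]
    half = len(X) // 2
    model = Ridge(alpha=1.0).fit(X[:half], y[:half])
    return model.score(X[half:], y[half:])


for lag in (0, 40, 80, 120):
    print(f"lag {lag:3d} ms: held-out R2 = {lagged_r2(lag):.3f}")
```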

In [None]:
## Kinematic decoding

# Extract neural data and lagged hand velocity
trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-130, 370))
lagged_trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-50, 450))
rates = trial_data.spikes_smth_50.to_numpy()
vel = lagged_trial_data.hand_vel.to_numpy()

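# Fit decoder, choosing the ridge penalty by cross-validation
gscv = GridSearchCV(Ridge(), {'alpha': np.logspace(-4, 0, 5)})
gscv.fit(rates, vel)
# Predictions from the best model refit on all data (in-sample)
pred_vel = gscv.predict(rates)
# best_score_ is the mean held-out R2 across CV folds for the best alpha
print(f"Decoding R2: {gscv.best_score_}")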

# Merge predictions back to continuous data
pred_vel_df = pd.DataFrame(pred_vel, index=lagged_trial_data.clock_time, columns=pd.MultiIndex.from_tuples([('pred_vel', 'x'), ('pred_vel', 'y')]))
dataset.data = pd.concat([dataset.data, pred_vel_df], axis=1)

We got an R2 of around 0.6, which is not too bad. Let's visualize how our predicted kinematics compare to the true data. We'll plot results for only one condition to keep things from getting too cluttered.

In [None]:
## Plot predicted vs true kinematics

# Choose the condition at index 23 (i.e., the 24th condition) to plot
cond = conds[23]

# Find trials in condition and extract data
mask = np.all(dataset.trial_info[['trial_type', 'trial_version']] == cond, axis=1)
trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-50, 450), ignored_trials=(~mask))

# Initialize figure
fig, axs = plt.subplots(2, 3, figsize=(10, 4))
t = np.arange(-50, 450, dataset.bin_width)

# Loop through trials in condition
for _, trial in trial_data.groupby('trial_id'):
    # True and predicted x velocity
    axs[0][0].plot(t, trial.hand_vel.x, linewidth=0.7, color='black')
    axs[1][0].plot(t, trial.pred_vel.x, linewidth=0.7, color='blue')
    # True and predicted y velocity
    axs[0][1].plot(t, trial.hand_vel.y, linewidth=0.7, color='black')
    axs[1][1].plot(t, trial.pred_vel.y, linewidth=0.7, color='blue')
    # True and predicted trajectories
    true_traj = np.cumsum(trial.hand_vel.to_numpy(), axis=0) * dataset.bin_width / 1000
    pred_traj = np.cumsum(trial.pred_vel.to_numpy(), axis=0) * dataset.bin_width / 1000
    axs[0][2].plot(true_traj[:, 0], true_traj[:, 1], linewidth=0.7, color='black')
    axs[1][2].plot(pred_traj[:, 0], pred_traj[:, 1], linewidth=0.7, color='blue')

# Set up shared axes
for i in range(2):
    axs[i][0].set_xlim(-50, 450)
    axs[i][1].set_xlim(-50, 450)
    axs[i][2].set_xlim(-180, 180)
    axs[i][2].set_ylim(-130, 130)

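# Add labels (top row: true kinematics in black; bottom row: predicted in blue)
axs[0][0].set_title('X velocity (mm/s)')
axs[0][1].set_title('Y velocity (mm/s)')
axs[0][2].set_title('Reach trajectory')
axs[0][0].set_ylabel('True')
axs[1][0].set_ylabel('Predicted')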
plt.show()

As you can see, decoding from single-trial smoothed spikes recovers the general trends in hand kinematics but not the precise trajectories, particularly in the y-direction.
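
Note that `gscv.best_score_` is a cross-validated R2 (averaged over held-out folds for the best `alpha`), whereas `pred_vel` comes from the refit model predicting the same samples it was trained on, so the plotted predictions are in-sample. One way to obtain held-out predictions instead is `sklearn.model_selection.cross_val_predict`. The sketch below demonstrates the difference on synthetic data (names prefixed `toy_` are illustrative), where many features relative to samples make in-sample R2 overly optimistic.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(0)
toy_X = rng.normal(size=(200, 150))  # few samples, many "neurons"
toy_y = toy_X[:, 0] + 0.5 * toy_X[:, 1] + rng.normal(size=200)

toy_model = Ridge(alpha=1e-2)
# In-sample: fit and predict on the same samples (like gscv.predict(rates) above)
r2_in = r2_score(toy_y, toy_model.fit(toy_X, toy_y).predict(toy_X))
# Held-out: each sample is predicted by a model fit on the other 4 folds
r2_out = r2_score(toy_y, cross_val_predict(toy_model, toy_X, toy_y, cv=5))
print(f"in-sample R2: {r2_in:.2f}, held-out R2: {r2_out:.2f}")
```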

### 2.7 Neural trajectories

Finally, we'll look at how neural population activity evolves over time in each condition by applying PCA to trial-averaged smoothed spikes. We'll plot the resulting trajectories and compare them across conditions.

In [None]:
## Plot neural trajectories for subset of conditions

# Seed generator for consistent plots
np.random.seed(2021)
n_conds = 27 # number of conditions to plot

# Get unique conditions
conds = dataset.trial_info.set_index(['trial_type', 'trial_version']).index.unique().tolist()

# Loop through conditions
rates = []
colors = []
for i in np.random.choice(len(conds), n_conds, replace=False):
    cond = conds[i]
    # Find trials in condition
    mask = np.all(dataset.trial_info[['trial_type', 'trial_version']] == cond, axis=1)
    # Extract trial data
    trial_data = dataset.make_trial_data(align_field='move_onset_time', align_range=(-50, 450), ignored_trials=(~mask))
    # Append averaged smoothed spikes for condition
    rates.append(trial_data.groupby('align_time')[trial_data[['spikes_smth_50']].columns].mean().to_numpy())
    # Append reach angle-based color for condition
    active_target = dataset.trial_info[mask].target_pos.iloc[0][dataset.trial_info[mask].active_target.iloc[0]]
    reach_angle = np.arctan2(*active_target[::-1])
    colors.append(plt.cm.hsv(reach_angle / (2*np.pi) + 0.5))

# Stack data and apply PCA
rate_stack = np.vstack(rates)
rate_scaled = StandardScaler().fit_transform(rate_stack)
pca = PCA(n_components=3)
traj_stack = pca.fit_transform(rate_scaled)
traj_arr = traj_stack.reshape((n_conds, len(rates[0]), -1))

# Loop through trajectories and plot
fig = plt.figure(figsize=(9, 6))
ax = fig.add_subplot(111, projection='3d')
for traj, col in zip(traj_arr, colors):
    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=col)
    ax.scatter(traj[0, 0], traj[0, 1], traj[0, 2], color=col)

# Add labels
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
plt.show()

You might see that conditions with similar reach angles tend to be clustered together. Note that coloring by angle to the target does not take into account the curved paths of the reaches toward the target, so you should not expect all similarly colored conditions to have similar neural trajectories.

Using an alternate dimensionality reduction method called jPCA, researchers have revealed rotational dynamics in neural population activity during these maze reaches. You can read more about that in this [paper](https://www.nature.com/articles/nature11129).

## 3 Summary

In this notebook, we:
* introduced the MC_Maze dataset, describing the task and provided data
* looked at what specific continuous and trial data is included
* demonstrated the task by plotting average reach trajectories for each condition
* explored single-neuron responses by plotting PSTHs for some conditions
* evaluated how accurately spike-smoothed population activity can decode hand velocity
* visualized the timecourse of neural population activity by extracting trial-averaged neural trajectories with PCA

The Maze datasets are exceptional in their combination of behavioral richness (number of task configurations), stereotyped behavior across repeated trials (tens of repeats for each task configuration), and high total trial counts (thousands). Due to the instructed delay paradigm and lack of unpredictable task events, population activity can be well-modeled as an autonomous dynamical system. Because of this, the dataset has been used for validating a number of latent variable models.