# Basis profile curve identification to understand electrical stimulation effects in human brain networks

The Basis Profile Curve identification method was developed in MATLAB by Kai J. Miller and is described here:
- Miller, K. J., Müller, K. R., & Hermes, D. (2021). Basis profile curve identification to understand electrical stimulation effects in human brain networks. *PLoS Computational Biology*, 17(9), e1008710. doi: https://doi.org/10.1371/journal.pcbi.1008710

> **Abstract (Miller et al. 2021)**
<br /> Brain networks can be explored by delivering brief pulses of electrical current in one area while measuring voltage responses in other areas. We propose a convergent paradigm to study brain dynamics, focusing on a single brain site to observe the average effect of stimulating each of many other brain sites. Viewed in this manner, visually-apparent motifs in the temporal response shape emerge from adjacent stimulation sites. This work constructs and illustrates a data-driven approach to determine characteristic spatiotemporal structure in these response shapes, summarized by a set of unique “basis profile curves” (BPCs). Each BPC may be mapped back to underlying anatomy in a natural way, quantifying projection strength from each stimulation site using simple metrics. Our technique is demonstrated for an array of implanted brain surface electrodes in a human patient. This framework enables straightforward interpretation of single-pulse brain stimulation data, and can be applied generically to explore the diverse milieu of interactions that comprise the connectome.

This project was supported by the National Institute of Mental Health of the National Institutes of Health under Award Number R01MH122258. The content is solely the responsibility of the authors and does not necessarily represent the official views of the National Institutes of Health.

This Jupyter Notebook was written by Alex Rockhill, Tal Pal Attia, Harvey Huang, Max vd Boom and Dora Hermes.

# Overview

This notebook will walk you through the following four steps to compute and visualize Basis Profile Curves (BPCs) in an interactive manner:
1. Python packages
2. Load BIDS data and metadata using MNE, and look at stimulation-driven inputs to one electrode
3. Group these inputs into Basis Profile Curves (BPCs)
4. Visualize the BPCs


# 1. Python packages

Dependencies:
 - [MNE](https://mne.tools/stable/index.html)
 - [MNE-BIDS](https://mne.tools/mne-bids/stable/index.html)
 - [openneuro-py](https://pypi.org/project/openneuro-py/)
 - numpy (version>=1.24.4)
 - pandas (version>=2.0.3)
 - scipy (version>=1.10.1)
 - scikit-learn (imported as `sklearn`)
 - matplotlib
 - tqdm
 - ipykernel
 - nilearn


In [None]:
# import packages
from pathlib import Path
import numpy as np
import openneuro
import mne
import mne_bids
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import NMF
from scipy import stats


# 2. Loading and checking the BIDS data and metadata

An example iEEG dataset is available on OpenNeuro ([link](https://openneuro.org/datasets/ds003708)). This dataset is formatted according to the Brain Imaging Data Structure ([BIDS](https://bids.neuroimaging.io/)) and contains one subject to work with in this tutorial. 

It includes electrocorticography (ECoG) recordings with single-pulse stimulation and accompanying metadata, such as electrode positions, channel information, and stimulation events.

In [None]:
# You may download the data using openneuro-py

# dataset = 'ds003708'
# root = Path('..') / '..' / dataset
# openneuro.download(dataset=dataset, target_dir=root)

In [None]:
# If you already downloaded the data, you can just set the root path to the BIDS directory

"""
# The preprocessed Brainvision data are located in /derivatives/preprocessed/
"""

root = Path('/home/jovyan/shared/ds003708-download/derivatives/preprocessed/') # '/full/path/to/Basis_profile_curve/data'

## 2.1 Load BIDS metadata

We use MNE-BIDS to build a `BIDSPath` that points to the recording under the specified root folder (should be 1 subject, 1 session, 1 run), and then read the data and metadata from it.

In [None]:
"""
Specify the subject, session, task and run 
"""

bids_sub = '01' # The subject label
bids_ses = 'ieeg01' # The session label
bids_task = 'ccep' # The task name
bids_run = '01' # The run name

In [None]:
# Read the raw BIDS structure using mne_bids

path = mne_bids.BIDSPath(
    subject=bids_sub, session=bids_ses, task=bids_task, run=bids_run, root=root)
raw = mne_bids.read_raw_bids(path)

# let's look at some of the metadata
display(raw)

In [None]:
# Render an MNI brain and plot electrode positions

trans = mne.transforms.Transform(fro='head', to='mri', trans=np.eye(4))  # identity
fig = mne.viz.plot_alignment(
    raw.info, trans=trans, subject='fsaverage', surfaces='pial')
mne.viz.set_3d_view(fig, azimuth=190)

"""
Rotate the image to focus on an area of interest 
"""

In [None]:
# Make a snapshot of the current view 

xy, im = mne.viz.snapshot_brain_montage(fig, raw.info)

In [None]:
# Plot the snapshot with electrode labels added

%matplotlib
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis('off')
ax.imshow(im)
for name, pos in xy.items():
    if pos[0] >= 0 and pos[1] >= 0:  # no NaN locations
        ax.text(*pos, name, ha='center', va='center', fontsize=8)

fig.show()


## 2.2 Select electrode of interest and load data


In [None]:
""" 
Pick an electrode of interest and plot epoched data
"""

contact = 'LMS2'

# Indicate how long epochs should be
tmin, tmax = -1, 2
# What is the baseline interval?
bl_tmin, bl_tmax = -0.5, -0.05

# Import events
events, event_id = mne.events_from_annotations(raw)

# Read events from the BIDS events.tsv file
metadata = pd.read_csv(path.update(suffix='events'), sep='\t')
keep = metadata.trial_type == 'electrical_stimulation'
if 'status' in metadata:
    keep = np.logical_and(keep, metadata.status == 'good')
metadata = metadata[keep]
epochs = mne.Epochs(raw, events[keep],
                    tmin=tmin, tmax=tmax,
                    baseline=(bl_tmin, bl_tmax), picks=[contact],
                    preload=True)
# try ``baseline=None`` for no baseline correction to play around
epochs.metadata = metadata  # contains stimulation location information

# unpack each pair separated by a hyphen, only use trials where
# stimulation was delivered to channels other than the channel of
# interest
epochs.metadata['site1'], epochs.metadata['site2'] = np.array([
    sites.split('-') for sites in
    epochs.metadata.electrical_stimulation_site]).T
exclude = np.isin(epochs.metadata.site1, contact) | \
    np.isin(epochs.metadata.site2, contact)
epochs = epochs[~exclude]

epochs.plot_image(picks=[contact], cmap='viridis', vmin=-250, vmax=250)

# 3. Calculate BPCs

In this section we identify basis profile curves (BPCs) that group characteristic shapes in the convergent CCEPs. The input for the BPC calculation is the convergent matrix (V), which holds the signals from one measurement channel for all trials across all stimulation pairs.

## 3.1 Select time-frame for BPC extraction

In [None]:
""" 
Select the epoch times to enter in the BPC analyses in seconds 
"""

bpc_tmin, bpc_tmax = 0.015, 1

## 3.2 Calculate the significance matrix

To calculate the significance matrix, we project each unit-normalized stimulation trial onto all other trials. We then calculate a t-value from the projections for every combination of stimulation-pair subgroups, including each subgroup with itself.
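
As a rough illustration of this step, the sketch below builds a toy convergent matrix from two made-up response shapes and compares the t-value within one stimulation pair to the t-value between two pairs. The names `shape_a`, `shape_b`, `V_toy`, `labels` and `t_value`, as well as the pair labels, are hypothetical and not part of the dataset or the published method.

```python
import numpy as np

rng = np.random.default_rng(0)
n_times = 200
t = np.linspace(0, 1, n_times)

# Two hypothetical, nearly orthogonal response shapes (not from the dataset)
shape_a = np.sin(2 * np.pi * 3 * t)
shape_b = np.sin(2 * np.pi * 5 * t)

# Ten noisy trials for each of two hypothetical stimulation pairs
labels = np.array(['A1-A2'] * 10 + ['B1-B2'] * 10)
V_toy = np.vstack([shape_a + 0.5 * rng.standard_normal((10, n_times)),
                   shape_b + 0.5 * rng.standard_normal((10, n_times))])

# Project unit-normalized trials onto all trials (trials x trials)
V0_toy = V_toy / np.linalg.norm(V_toy, axis=1)[:, None]
P_toy = V0_toy @ V_toy.T


def t_value(block, same_group):
    """t-value of the projections in one block of the projection matrix."""
    if same_group:  # drop self-projections on the diagonal
        block = block[~np.eye(block.shape[0], dtype=bool)]
    b = np.ravel(block)
    return np.mean(b) * np.sqrt(len(b)) / np.std(b, ddof=1)


is_a = labels == 'A1-A2'
is_b = labels == 'B1-B2'
t_within = t_value(P_toy[np.ix_(is_a, is_a)], same_group=True)
t_between = t_value(P_toy[np.ix_(is_a, is_b)], same_group=False)
print(f'within-pair t: {t_within:.1f}, between-pair t: {t_between:.1f}')
```

Trials that share a response shape project strongly onto each other, so the within-pair t-value should be much larger than the between-pair one.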

In [None]:
# stim_sites contains the stimulation sites for each epoch
stim_sites = epochs.metadata.electrical_stimulation_site
V = epochs.get_data(tmin=bpc_tmin, tmax=bpc_tmax)[:, 0]  # select only channel
times = epochs.times[(epochs.times >= bpc_tmin) & (epochs.times <= bpc_tmax)]
V0 = V / np.linalg.norm(V, axis=1)[:, None]  # L2 norm each trial
P = V0 @ V.T  # calculate internal cross-trial projections

# pairs contains the unique stimulation pairs (subgroups)
# we calculate tmat, where each index contains a t-value 
# t-values indicate cross-subgroup interactions
pairs = np.array(sorted(np.unique(stim_sites)))
tmat = np.zeros((len(pairs), len(pairs)))
for i, pair1 in enumerate(pairs):
    for j, pair2 in enumerate(pairs):
        b = P[np.ix_(stim_sites == pair1, stim_sites == pair2)]
        if i == j:  # subset without diagonal
            b = np.concatenate([b[np.tril_indices(b.shape[0], k=-1)],
                                b[np.triu_indices(b.shape[0], k=1)]])
        b = b.ravel()
        tmat[i, j] = np.mean(b) * np.sqrt(len(b)) / np.std(b, ddof=1)

fig, ax = plt.subplots(figsize=(10, 10))
img = ax.imshow(tmat, vmin=0, vmax=10)
ax.set_xticks(range(tmat.shape[0]))
ax.set_xticklabels(pairs, rotation=90, fontsize=8)
ax.set_xlabel('Stimulation Pair')
ax.set_yticks(range(tmat.shape[0]))
ax.set_yticklabels(pairs, fontsize=8)
ax.set_ylabel('Stimulation Pair')
ax.set_title(r'Significance Matrix $\Xi$', fontsize=15)
fig.colorbar(img, ax=ax)
fig.subplots_adjust(bottom=0.2)
fig.show()

## 3.3 Iteratively decrease inner components of non-negative matrix factorization

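We apply Non-Negative Matrix Factorization (NMF) to the significance matrix to cluster stimulation sites that produce similar measured responses. Starting from a maximum number of components, the inner dimension is decreased until the resulting clusters show little overlap.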
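
The overlap between clusters can be scored by summing the inner products between all pairs of unit-normalized rows of H; the code cell in this section stops reducing the inner dimension once that score drops below 1. The sketch below shows the score on two hand-made loading matrices. The helper `off_diagonal_score` and the matrices `H_distinct` and `H_redundant` are hypothetical and only meant to illustrate the idea.

```python
import numpy as np


def off_diagonal_score(H):
    """Sum of inner products between all pairs of unit-normalized rows of H."""
    H = H / np.linalg.norm(H, axis=1)[:, None]
    return np.triu(H @ H.T, k=1).sum()


# Hypothetical loadings: rows are clusters, columns are stimulation pairs
H_distinct = np.array([[1.0, 0.9, 0.0, 0.0],
                       [0.0, 0.0, 0.8, 1.0]])
H_redundant = np.array([[1.0, 0.9, 0.1, 0.0],
                        [0.9, 1.0, 0.0, 0.1],
                        [0.0, 0.0, 0.8, 1.0]])

score_distinct = off_diagonal_score(H_distinct)
score_redundant = off_diagonal_score(H_redundant)
print(f'distinct clusters: {score_distinct:.2f}')
print(f'redundant clusters: {score_redundant:.2f}')
```

In `H_redundant` the first two rows load on the same stimulation pairs, so the score is high and one component would be dropped; in `H_distinct` the rows do not overlap and the score is zero.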

In [None]:
# Find two non-negative matrices (W, H) whose product approximates the
# non-negative, rescaled significance matrix (t0).
# At the end, the matrix H has size of number of clusters by stimulation pair sub-groups.

t0 = tmat.copy()
t0[t0 < 0] = 0
t0[np.isnan(t0)] = 0
t0 /= (np.max(t0))

cluster_dim = 9
n_reruns = 20
tol = 1e-5
random_state = 11
for n_components in range(cluster_dim, 1, -1):
    this_error = None
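    for k in range(n_reruns):
        # vary the seed per rerun so that each rerun starts from a different
        # random initialization (a fixed seed would repeat the same fit)
        model = NMF(n_components=n_components, init='random', solver='mu',
                    tol=tol, max_iter=10000, random_state=random_state + k).fit(t0)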
        if this_error is None or model.reconstruction_err_ < this_error:
            this_error = model.reconstruction_err_
            W = model.transform(t0)
            H = model.components_
    H /= np.linalg.norm(H, axis=1)[:, None]
    nmf_penalty = np.triu(H @ H.T, k=1).sum()
    print(f'Inner dimension: {n_components}, off diagonal score: {nmf_penalty}')
    if nmf_penalty < 1:
        break


## 3.4 Identification of Basis Profile Curves 

BPCs are identified from the clustered groups (rows of H) using linear kernel PCA.
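
With a linear kernel, kernel PCA amounts to decomposing the small trials-by-trials Gram matrix and mapping its eigenvectors back to the time domain, which yields the same first component as a direct SVD of the data. The sketch below checks this on synthetic data; `waveform`, `X`, `gram` and the other names are hypothetical, and it uses `np.linalg.eigh` rather than the SVD call in the notebook cell.

```python
import numpy as np

rng = np.random.default_rng(1)
n_times, n_trials = 300, 8

# Hypothetical data: one shared waveform with varying amplitude plus noise,
# arranged as time x trials
waveform = np.sin(np.linspace(0, 4 * np.pi, n_times))
amplitudes = rng.uniform(0.5, 1.5, n_trials)
X = waveform[:, None] * amplitudes + 0.2 * rng.standard_normal((n_times, n_trials))

# Kernel trick: eigendecompose the small trials x trials Gram matrix
gram = X.T @ X
eigvals, eigvecs = np.linalg.eigh(gram)  # eigenvalues in ascending order
order = np.argsort(eigvals)[::-1]
eigvals, eigvecs = eigvals[order], eigvecs[:, order]

# Map back to the time domain and divide by the singular values
E = (X @ eigvecs) / np.sqrt(eigvals)
first_pc = E[:, 0]

# Direct route for comparison: left singular vectors of X
U, S, _ = np.linalg.svd(X, full_matrices=False)
match = abs(first_pc @ U[:, 0])  # 1 means identical up to sign
print(f'agreement with direct SVD: {match:.4f}')
```

The sign of a principal component is arbitrary, which is why the notebook flips the basis vector when its mean projection onto the trials is negative.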

In [None]:
# Output will show the subgroup numbers clustered in each BPC

# find stimulation trials for every BPC using linear kernel PCA
def kpca(X):
    # singular values S and eigenvectors F of the trial-by-trial Gram matrix X.T @ X
    F, S, _ = np.linalg.svd(X.T)
    ES = X @ F  # kernel trick
    # divide through to obtain unit-normalized eigenvectors
    E = ES / (np.ones((X.shape[0], 1)) @ S[None])
    return E

# find significant pairs per BPC; must be > threshold and greater than other BPCs
bpc_pairs = np.zeros((len(pairs))) * np.nan  # index of bpc
Bs = np.zeros((n_components, V.shape[1]))  # n_BPCs x n_times
for bpc_idx in range(n_components):  # loop over BPCs
    bpc_pair_idxs = np.where((H[bpc_idx] == np.max(H, axis=0)) &
                             (H[bpc_idx] > 1 / (2 * np.sqrt(len(pairs)))))[0]
    bpc_pairs[bpc_pair_idxs] = bpc_idx
    bpc_trials = np.concatenate([np.where(stim_sites == pairs[idx])[0]
                                 for idx in bpc_pair_idxs])
    Bs[bpc_idx] = kpca(V[bpc_trials].T)[:, 0]  # basis vector is 1st PC
    if np.mean(Bs[bpc_idx] @ V[bpc_trials].T) < 0:
        Bs[bpc_idx] *= -1  # sign flip
    print(bpc_idx, bpc_pair_idxs)
excluded_pairs = pairs[np.isnan(bpc_pairs)]

# 4. Visualize BPCs

## 4.1 Plot Calculated BPCs

Plot the BPC waveforms

In [None]:
# plot BPCs
colors = cm.tab10(np.linspace(0, 1, 10))

fig, ax = plt.subplots(figsize=(5, 4))
for i, bpc in enumerate(Bs):
    ax.plot(times, bpc, color=colors[i], label=i)
ax.set_xlabel('Time from stimulation (s)')
ax.set_ylabel('Normalized weight of BPCs')
ax.set_title('Calculated BPCs', fontsize=15)
ax.legend()
fig.tight_layout()
fig.show()

## 4.2 Spatial representation of the BPCs

Render the BPC weights on the cortical surface
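
The size of each rendered dot reflects how well a BPC explains the trials of a stimulation pair: each trial is projected onto the unit-norm BPC to get a coefficient (alpha), and alpha is divided by the square root of the residual energy that the BPC leaves unexplained. The sketch below works through these quantities on synthetic trials; `B`, `true_alpha`, `trials` and the other names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
n_times = 250

# Hypothetical unit-norm basis curve
B = np.exp(-np.linspace(0, 5, n_times))
B /= np.linalg.norm(B)

# Six hypothetical trials: scaled copies of B plus unit-variance noise
true_alpha = 20.0
trials = true_alpha * B + rng.standard_normal((6, n_times))

alphas = trials @ B                          # projection weight of each trial
residuals = trials - np.outer(alphas, B)     # part of each trial not explained by B
epsilon2 = np.sum(residuals ** 2, axis=1)    # residual energy per trial
weight = np.mean(alphas / np.sqrt(epsilon2)) # signal-to-noise style summary
print(f'mean alpha: {np.mean(alphas):.1f}, plot weight: {weight:.2f}')
```

Because `B` has unit norm, the residual of every trial is orthogonal to `B`, so alpha and the residual energy split the trial energy cleanly.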

In [None]:
# curve statistics for each stim pair
alphas = np.zeros((len(stim_sites))) * np.nan
epsilon2s = np.zeros((len(stim_sites))) * np.nan
V2s = np.zeros((len(stim_sites))) * np.nan
errxprojs = np.zeros((len(pairs))) * np.nan
p_vals = np.zeros((len(pairs))) * np.nan
plotweights = np.zeros((len(pairs))) * np.nan
for bpc_idx in range(n_components):  # loop over BPCs
    # alpha coefficient weights for basis curve into V
    bpc_alphas = Bs[bpc_idx] @ V.T
    # residual epsilon (error timeseries) for basis bb after alpha*B coefficient fit
    bpc_epsilon2 = V - (Bs[bpc_idx][:, None] @ bpc_alphas[None]).T
    errxproj = bpc_epsilon2  @ bpc_epsilon2.T  # calculate all projections of error
    V_selfproj = V @ V.T  # power in each trial

    # cycle through pair types represented by this basis curve
    for pair_idx in np.where(bpc_pairs == bpc_idx)[0]:
        trials = stim_sites == pairs[pair_idx]
        # alpha coefficient weights for basis curve bb into V
        alphas[trials] = bpc_alphas[trials]
        # self-submatrix of error projections
        a = errxproj[np.ix_(trials, trials)]
        epsilon2s[trials] = np.diag(a)
        # sum-squared individual trials
        V2s[trials] = np.diag(V_selfproj[np.ix_(trials, trials)])

        # gather all off-diagonal elements from self-submatrix
        b = np.concatenate([a[np.tril_indices(a.shape[0], k=-1)],
                            a[np.triu_indices(a.shape[0], k=1)]])

        # systematic residual structure within a stim pair group for a given basis will be
        # given by set of native normalized internal cross-projections
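        # store in errxprojs (per pair); writing to errxproj would overwrite
        # a row of the trial-by-trial error projection matrix
        errxprojs[pair_idx] = np.mean(b) * np.sqrt(len(b)) / np.std(b, ddof=1)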

        plotweights[pair_idx] = np.mean(alphas[trials] / np.sqrt(epsilon2s[trials]))
        T_stat, p_val = stats.ttest_1samp((alphas[trials] / np.sqrt(epsilon2s[trials])), 0)
        p_vals[pair_idx] = p_val

In [None]:
# Render the BPC weights from each stimulation pair on the cortical surface.

colors = cm.tab10(np.linspace(0, 1, 10))

fig, ax = plt.subplots()
ax.axis('off')
ax.imshow(im)
for i, name in enumerate(pairs):
    if np.isnan(bpc_pairs[i]):
        continue
    ch0, ch1 = name.split('-')
    pos = (xy[ch0] + xy[ch1]) / 2
    # pos is (x, y) in pixels, whereas im.shape is (height, width, ...)
    if pos[0] < 0 or pos[0] > im.shape[1] or pos[1] < 0 or pos[1] > im.shape[0]:
        continue
    color = colors[int(bpc_pairs[i])]
    size = plotweights[i] * 200
    ax.scatter(*pos, color=color[:3], s=[size], alpha=0.75)
fig.show()


## 4.3 Optional parameters to change

Once you have completed this, you can also select a different electrode in Section 2.2 to look at the various inputs into different regions.

You could also go back to Section 3.1 and change the time interval over which the BPCs are calculated (e.g. 0.2 - 1 sec) and look at the effects on the outputs.

Alternatively, we shared CCEP data from 74 patients in BIDS format on OpenNeuro, accompanying [a study on developmental changes in transmission speed](https://www.nature.com/articles/s41593-023-01272-0). Check out these data: https://openneuro.org/datasets/ds004080/versions/1.2.4