# Figure 1

In [None]:
import copy
import json
import os
import pickle
from datetime import datetime

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
from scipy import stats

from sylseq_paper.file_utils import loadmat
from sylseq_paper.plotting import default_plot_settings, ucsf_sequential_color_palette as colors, fancy_location_colors, smoothed_weighted_histogram, area_names
from sylseq_paper.statistics import fdr_omitnans

# Apply the project's default plot settings (font family and base size); presumably sets global
# matplotlib rcParams -- see sylseq_paper.plotting.default_plot_settings.
default_plot_settings(font='Helvetica', fontsize=16)

# Reload edited modules (e.g. sylseq_paper.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

## Path setup

In [None]:
# Data and imaging folders are siblings of the notebook's folder (one level above the cwd).
_project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
data_dir = os.path.join(_project_root, 'data')
img_dir = os.path.join(_project_root, 'imaging')

In [None]:
# Areas to exclude
exclude_areas = 'superiorparietal inferiortemporal paracentral medial'.split()

# Alignment name -> phrase used in "Time from <...> (s)" axis labels.
time_periods = dict(
    target_presentation='visual pres.',
    target_presentation_delay='visual pres.',
    delay='delay',
    go_cue='go-cue',
    pre_exec='speech onset',
    pre_exec_speech='speech onset',
)

# Renaming `SMA` label to `medial SFG`
# Column name -> {old label: new label}; applied by rename_df_areas.
renamed_areas = dict(
    location=dict(supplementarymotor='medialsuperiorfrontal'),
    fancy_location={'SMA': 'medial SFG'},
)

def rename_df_areas(df, renames=None):
    """Relabel anatomical-area columns of ``df`` in place and return it.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain every column named as a key of ``renames``.
    renames : dict of {column: {old_label: new_label}}, optional
        Relabelling to apply. Defaults to the module-level ``renamed_areas``
        (SMA -> medial SFG).

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object (modified in place), returned for convenience.
    """
    if renames is None:
        renames = renamed_areas
    for label_type, label_map in renames.items():
        for old_label, new_label in label_map.items():
            current = df[label_type].values
            df[label_type] = np.where(current == old_label, new_label, current)
    return df

In [None]:
# Sequence/syllable condition of interest (not referenced in the cells shown here).
condition = {'sequence_type': 'complex', 'syllable_type': 'complex'}
# Analysis windows in seconds; presumably relative to the event named by 'phase' (TODO confirm).
# 'state' and 'sig_label' are labels consumed elsewhere -- their semantics are not visible here.
alignments = {
    'target_presentation': {
        'window': [0.0, 2.5],
        'phase': 'target_presentation',
        'state': 'p',
        'sig_label': 'target_pres'
    },
    'delay': {
        'window': [0.0, 0.75],
        'phase': 'fixation_cue',
        'state': 'f',
        'sig_label': 'delay'
    },
    'pre_exec': {
        'window': [-0.85, -0.1],
        'phase': 'first_syllable',
        'state': 'e',
        'sig_label': 'pre_exec'
    },
    'speech': {
        'window': [0.0, 2.0],
        'phase': 'first_syllable',
        'state': 'e',
        'sig_label': 'speech'
    }
}
# Windows (s) used for the trial-averaged traces and rasters; the panels draw a vertical line at
# t=0 for the aligned event. Key order sets the iteration order when ERP traces are stacked.
erp_alignments = {
    'target_presentation_delay': {
        # From 0.5 s before target onset through the 2.5 s encoding period plus 0.75 s of delay.
        'window': [-0.5, 2.5 + 0.75],
        'phase': 'target_presentation',
        'state': 'p',
        'sig_label': 'target_pres'
    },
    'go_cue': {
        'window': [-0.25, 0.25],
        'phase': 'go_cue',
        'state': 'g',
        'sig_label': 'pre_exec'
    },
    'pre_exec_speech': {
        'window': [-0.5, 2.0],
        'phase': 'first_syllable',
        'state': 'e',
        'sig_label': 'pre_exec'
    }
}
# Short and human-readable names per analysis alignment (not referenced in the cells shown here).
alignment_shorthand = {
    'target_presentation': 'target_pres',
    'delay': 'delay',
    'pre_exec': 'pre_exec',
    'speech': 'speech'
}
alignment_labels = {
    'target_presentation': 'Encoding',
    'delay': 'Delay',
    'pre_exec': 'Pre-speech',
    'speech': 'Speech production'
}
# Samples per second of the trial-averaged traces (plot_raster multiplies window seconds by sr).
sr = 100
# Significance level (not referenced in the cells shown here).
sig_thresh = 0.05

# Alignments / subjects selected for plotting (not referenced in the cells shown here).
plot_alignments = ['target_presentation', 'go_cue', 'pre_exec']
plot_subjects = ['EC217', 'EC219', 'EC223', 'EC237', 'EC240', 'EC241', 'EC253', 'EC254', 'EC260', 'EC263', 'EC276']

# Per-view templates for warped 2-D electrode coordinates; '{}' is presumably the subject ID.
warped_coord_paths = {
    'lateral': os.path.join(img_dir, '{}_lateral_elecmat_2D_warped.npy'),
    'medial': os.path.join(img_dir, '{}_medial_elecmat_2D_warped.npy')
}
# MNI brain background image, filled as mni_img_path.format(hemisphere, view).
mni_img_path = os.path.join(img_dir, 'MNI_{}_{}_brain_2D.png')

## Load data

In [None]:
def _load_fig1_pickle(file_name):
    """Unpickle ``file_name`` from ``data_dir``.

    NOTE: pickle.load can execute arbitrary code -- only use on trusted, locally produced files.
    """
    with open(os.path.join(data_dir, file_name), 'rb') as handle:
        return pickle.load(handle)


# Load bad channels
bad_channels = _load_fig1_pickle('bad_channels.pkl')

# Load ERP trials (trial-averaged traces, indexed [alignment][subject][:, electrode] below)
all_trial_avgs = _load_fig1_pickle('fig1_trial_averages.pkl')

# Electrode-level results table plus the IDs of the sustained electrodes
results = _load_fig1_pickle('fig1_ndf_sustained_se.pkl')
neural_df, sustained_se = results['neural_df'], results['sustained_se']

# Apply the ROI relabelling, then keep rows flagged significant above baseline at any time
neural_df = rename_df_areas(neural_df)
ndf = neural_df.loc[neural_df.significant_above_baseline_anytime]

# Non-sig electrode location information (drawn as faint dots in the brain panel)
non_sig_edf = pd.read_hdf(os.path.join(data_dir, 'fig1_nonsig_edf.h5'))

# Single trial raster data
single_elec_raster_dict = _load_fig1_pickle('fig1_single_elec_raster.pkl')

# Total and significant electrode counts per anatomical region. The counts are the same for
# every alignment, so only the 'target_presentation' rows are used (one row per electrode).
_tp_rows = neural_df.loc[neural_df.alignment == 'target_presentation']
_region_records = []
for location in neural_df.fancy_location.unique():
    in_region = _tp_rows.fancy_location == location
    _region_records.append({
        'fancy_location': location,
        'total_electrodes': _tp_rows.loc[in_region].shape[0],
        'sig_above_baseline_anytime': _tp_rows.loc[in_region & _tp_rows.significant_above_baseline_anytime].shape[0],
    })
total_elec_per_region = pd.DataFrame(
    _region_records, columns=['fancy_location', 'total_electrodes', 'sig_above_baseline_anytime'])
    
## Add alignments to be used for visualization
# The raster/ERP panels use the alignment names 'target_presentation_delay' and 'pre_exec_speech'.
# Duplicate the matching electrode rows under those names and set their significance flags to NaN.
_blank_sig_flags = dict.fromkeys(
    ['significant_above_baseline_anytime', 'ranksum_above_baseline_significant'], np.nan)

# `.assign` returns new frames. The previous chained writes (`tpdelay[col].iloc[:] = ...`) raise
# SettingWithCopyWarning and silently do nothing under pandas Copy-on-Write, and setting NaN
# into bool columns in place is a deprecated incompatible-dtype assignment.
tpdelay = ndf.loc[ndf.alignment == 'target_presentation'].assign(
    alignment='target_presentation_delay', **_blank_sig_flags)
pespeech = ndf.loc[ndf.alignment == 'pre_exec'].assign(
    alignment='pre_exec_speech', **_blank_sig_flags)

ndf = pd.concat([ndf, tpdelay, pespeech], axis=0)
    
# One colour per area from the project palette, plus a white -> area-colour colormap for rasters.
all_areas = ndf.fancy_location.unique()
raster_colors = [fancy_location_colors[area_name] for area_name in all_areas]
color_dict = dict(zip(all_areas, raster_colors))
color_maps = {
    area_name: mpl.colors.LinearSegmentedColormap.from_list(area_name, ['white', hex_color])
    for area_name, hex_color in color_dict.items()
}

### Load cluster data

In [None]:
# NMF clustering results: per-electrode component weights and the clustered electrode IDs.
results = loadmat(os.path.join(data_dir, 'fig1_nmf_clustering.mat'))
nmf_weights, clustered_se = results['nmf_weights'], results['clustered_se']

# The clustered electrodes must be exactly the sustained ones (same size, full overlap).
assert len(clustered_se) == len(sustained_se) == len(np.intersect1d(clustered_se, sustained_se))

num_clusters = 4

# Cluster adjustment.
# Map each descriptive cluster name to its NMF component (column of nmf_weights). Insertion order
# is the display order: the per-cluster average panel labels these Cluster 1..4 in this order.
cluster_names_to_idx = {
    'equal_tp_motor': 1,
    'higher_motor': 2,
    'higher_tp': 0,
    'equal': 3
}

# Reverse lookup keyed by the component index as a string, e.g. '0' -> 'higher_tp'.
cluster_idx_to_names = {str(component): name for name, component in cluster_names_to_idx.items()}

cluster_name_colors = {
    'higher_motor': colors[3],
    'higher_tp': colors[4],
    'equal': colors[2],
    'equal_tp_motor': colors[1]
}

# Hard assignment: each electrode joins the component with its largest NMF weight.
cluster_labels = np.argmax(nmf_weights, axis=-1)
counts = np.unique(cluster_labels, return_counts=True)[1]

# White -> cluster-colour colormaps, plus grey maps for unclustered / other electrodes.
cluster_cmaps = {
    name: mpl.colors.LinearSegmentedColormap.from_list(name, ['white', hue])
    for name, hue in cluster_name_colors.items()
}
cluster_cmaps.update(none='Greys', other='Greys')

# One row per clustered electrode: metadata from its first neural_df row, the hard cluster
# assignment, and the electrode's weight on every named NMF component.
_cluster_df_columns = (['x', 'y', 'subject_electrode', 'view', 'hemisphere', 'location', 'fancy_location',
                        'cluster_label', 'cluster_num_label']
                       + [f'{name}_weight' for name in cluster_names_to_idx])

# Not filled in the cells shown here; kept so any later reference still resolves.
warpx, warpy, cluster_locations = [], [], []

_cluster_rows = []
for row_num, electrode_id in enumerate(clustered_se):
    elec_rows = neural_df.loc[neural_df.subject_electrode == electrode_id]
    if elec_rows.empty:
        # No metadata for this electrode: leave it out of cluster_df.
        continue

    anat_label = elec_rows['location'].values[0]
    label_num = cluster_labels[row_num]
    record = {
        'x': elec_rows['x'].values[0],
        'y': elec_rows['y'].values[0],
        'subject_electrode': electrode_id,
        'view': elec_rows['view'].values[0],
        'hemisphere': elec_rows['hemisphere'].values[0],
        'location': anat_label,
        'fancy_location': area_names[anat_label],
        'cluster_label': cluster_idx_to_names[str(label_num)],
        'cluster_num_label': label_num,
    }
    record.update({f'{name}_weight': nmf_weights[row_num, component]
                   for name, component in cluster_names_to_idx.items()})
    _cluster_rows.append(record)

cluster_df = pd.DataFrame(_cluster_rows, columns=_cluster_df_columns)

# Rename ROIs (apply renamed_areas to the new frame)
cluster_df = rename_df_areas(cluster_df)

# One row per unique electrode: return_index gives each electrode's first row in ndf.
_, idx = np.unique(ndf.subject_electrode.values, return_index=True)
reduced_ndf = ndf.iloc[idx]

# Cluster membership per (area, cluster):
#   count              - clustered electrodes of that area in the cluster
#   percent            - count as % of the area's clustered electrodes
#   percent_of_cluster - count as % of the cluster's electrodes over all areas (NaN if the cluster is empty)
# Cluster sizes do not depend on the area, so compute them once.
cluster_totals = {key: int((cluster_df.cluster_label == key).sum()) for key in cluster_names_to_idx}

cluster_count_df = {key: [] for key in ['fancy_location', 'cluster_label', 'count', 'percent', 'percent_of_cluster']}
for area in cluster_df.fancy_location.unique():

    cur_df = cluster_df.loc[cluster_df.fancy_location == area]
    num_area_elecs = cur_df.shape[0]

    for key in cluster_names_to_idx.keys():

        num_in_cluster = cur_df.loc[cur_df.cluster_label == key].shape[0]
        num_cluster_total_elecs = cluster_totals[key]

        cluster_count_df['count'].append(num_in_cluster)
        cluster_count_df['percent'].append(100*(num_in_cluster / num_area_elecs))
        # A cluster can be empty (no electrode has it as its argmax component); this used to
        # raise ZeroDivisionError.
        cluster_count_df['percent_of_cluster'].append(
            100*(num_in_cluster / num_cluster_total_elecs) if num_cluster_total_elecs else np.nan)
        cluster_count_df['cluster_label'].append(str(key))
        cluster_count_df['fancy_location'].append(area)

cluster_count_df = pd.DataFrame(data=cluster_count_df)

In [None]:
def _gather_erp_traces(electrode_ids, align_key):
    """Stack trial-averaged traces at ``align_key`` for IDs of the form '<subject>_<electrode index>'.

    Returns an array with one column per electrode (stacked on axis=1), in input order.
    """
    traces = []
    for electrode_id in electrode_ids:
        subj, elec = electrode_id.split('_')
        traces.append(all_trial_avgs[align_key][subj][:, int(elec)])
    return np.stack(traces, axis=1)


# Significant electrodes that are not in the sustained / clustered set.
_nonsustained_se = np.setdiff1d(ndf.subject_electrode.unique(), clustered_se)

cluster_erp_trials = {align_key: _gather_erp_traces(clustered_se, align_key) for align_key in erp_alignments}
nonsustained_trials = {align_key: _gather_erp_traces(_nonsustained_se, align_key) for align_key in erp_alignments}

# Electrode IDs in the same column order as the stacked arrays above.
erp_se_data = {
    'cluster_erp_trials': list(clustered_se),
    'nonsustained_trials': list(_nonsustained_se),
}

### Add cluster results to ndf

In [None]:
# Attach each ndf row's cluster label and its weight on that cluster. Unclustered electrodes get
# label 'none' and a sentinel weight of -15 (presumably below any NMF weight), so they sort last
# when ordering by cluster_weight descending.
# Build the lookup once (first cluster_df row per electrode, as before) instead of scanning
# cluster_df for every ndf row.
_cluster_lookup = {}
for row_pos, (se, label) in enumerate(zip(cluster_df.subject_electrode.values, cluster_df.cluster_label.values)):
    if se not in _cluster_lookup:
        _cluster_lookup[se] = (label, cluster_df[f'{label}_weight'].values[row_pos])

ndf_cluster_assignments = []
ndf_cluster_weight = []
for se in ndf.subject_electrode.values:
    label, weight = _cluster_lookup.get(se, ('none', -15))
    ndf_cluster_assignments.append(label)
    ndf_cluster_weight.append(weight)

ndf['cluster_label'] = ndf_cluster_assignments
ndf['cluster_weight'] = ndf_cluster_weight

# Categorical with a fixed order ('none' last).
ndf['cluster_label'] = ndf.cluster_label.astype('category').cat.set_categories(
    ['higher_tp', 'equal_tp_motor', 'equal', 'higher_motor', 'none'])

### Order areas

In [None]:
# sort areas by cluster membership
fancy_sorted_areas = cluster_count_df.groupby('fancy_location').sum().sort_values('count', ascending=False).index.values
fancy_sorted_areas = np.concatenate([fancy_sorted_areas, np.setdiff1d(ndf.fancy_location.unique(), fancy_sorted_areas)])

ndf.fancy_location = ndf.fancy_location.astype('category')
ndf.fancy_location = ndf.fancy_location.cat.set_categories(fancy_sorted_areas)

In [None]:
# Report the sustained-electrode count relative to electrodes significant in the delay window.
is_delay_sig = (ndf.alignment == 'delay') & (ndf.ranksum_above_baseline_significant)
a = len(ndf.loc[is_delay_sig])
b = sustained_se.shape[0]
c = b * 100 / a
print(f'{b}/{a} delay significant electrodes are sustained ({c:0.2f}%)')

## Figure panels

### Task design

In [None]:
# Example trial for the task-design panel: audio trace, its sample rate and the cue times.
trial_target_sequence = 'blaa-draa-gloo'

with open(os.path.join(data_dir, 'fig1_task_design_panel.pkl'), 'rb') as f:
    results = pickle.load(f)

mic_trial, mic_sr, mic_delay, mic_go, tp_cue, rt_cue = (
    results[name] for name in ('mic_trial', 'mic_sr', 'mic_delay', 'mic_go', 'tp_cue', 'rt_cue'))

In [None]:
def figure1_task_design(ax=None, return_fig=False, fs=14):
    """Draw the task-design schematic above the example trial's microphone trace.

    Three screens are drawn left to right: target presentation (encoding), fixation cross
    (delay) and green fixation cross (go cue / pre-speech). Each has a double-headed arrow
    marking its epoch.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. When omitted, a new 5.61 x 3.3 inch figure is created.
    return_fig : bool, default False
        If True return ``(fig, ax)``, otherwise ``ax``. When ``ax`` is supplied, ``fig`` is
        ``ax.figure``.
    fs : float, default 14
        Annotation font size (the '+' glyphs use ``2 * fs``).

    Notes
    -----
    Reads the module-level ``mic_trial``, ``mic_sr``, ``tp_cue``, ``rt_cue`` and
    ``trial_target_sequence``.
    """
    # Schematic time axis: the whole microphone trace is stretched over x = 0..3, one unit per screen.
    mic_time = np.linspace(0, 3, len(mic_trial))
    # End of the pre-speech (RT) arrow on that schematic axis.
    speech_onset_x = mic_time[int((rt_cue - tp_cue) * mic_sr)]

    if ax is None:
        fig, ax = plt.subplots(figsize=(5.61, 3.3))
    else:
        # Previously `fig` was unbound here, so return_fig=True with a supplied ax raised.
        fig = ax.figure

    ax.plot(mic_time, 0.5*mic_trial, color='k', alpha=0.75, linewidth=0.25, clip_on=False)

    line_height = 1.15   # y of the connector tick joining each screen to its vertical marker
    vert_align = 0.65    # bottom edge of the screen boxes
    ymax = 2.15
    arrow_height = 0.15  # y of the epoch-duration arrows

    def draw_screen(left_align):
        """Draw one screen box spanning left_align+0.1..left_align+0.9 plus its connector and marker."""
        ax.add_patch(mpl.patches.Rectangle((left_align + 0.1, vert_align), 0.8, 1, fc='w', ec='k', linewidth=0.5))
        ax.plot([left_align, left_align + 0.1], [line_height, line_height], color='k')
        ax.vlines(left_align, ymin=0, ymax=line_height, color='k')

    # Target presentation
    left_align = 0.0
    draw_screen(left_align)
    ax.annotate(trial_target_sequence, (left_align + 0.5, vert_align + 0.5), ha='center', va='center', fontsize=fs, color='k')
    ax.annotate('', xytext=(0, arrow_height), xy=(1, arrow_height), arrowprops=dict(arrowstyle='<->'))
    ax.annotate('Encoding\n(2.5 s)', (left_align + 0.5, vert_align - 0.3), ha='center', va='center', fontsize=fs, color='k')
    ax.annotate('Target pres.', (left_align + 0.1, vert_align + 1.15), ha='left', va='center', fontsize=fs, color='k')

    # Delay
    left_align = 1.0
    draw_screen(left_align)
    ax.annotate('+', (left_align + 0.5, vert_align + 0.5), ha='center', va='center', fontsize=fs*2, color='k', weight='bold')
    ax.annotate('', xytext=(1, arrow_height), xy=(2, arrow_height), arrowprops=dict(arrowstyle='<->'))
    ax.annotate('Delay\n(~1 s)', (left_align + 0.5, vert_align - 0.3), ha='center', va='center', fontsize=fs, color='k')
    ax.annotate('Fix. cross', (left_align + 0.1, vert_align + 1.15), ha='left', va='center', fontsize=fs, color='k')

    # Go-cue
    left_align = 2.0
    draw_screen(left_align)
    ax.annotate('+', (left_align + 0.5, vert_align + 0.5), ha='center', va='center', fontsize=fs*2, color='green', weight='bold')
    ax.annotate('', xytext=(2, arrow_height), xy=(speech_onset_x, arrow_height), arrowprops=dict(arrowstyle='<->'))
    ax.annotate('Pre-speech\n(RT)', (left_align + 0.03, vert_align - 0.3), ha='left', va='center', fontsize=fs, color='k')
    ax.annotate('Go-cue', (left_align + 0.1, vert_align + 1.15), ha='left', va='center', fontsize=fs, color='k')

    ax.axes.set(ylim=(-0.15, ymax), xlim=(0, None))
    ax.axis('off')

    # Target sequence in italics, right-aligned under the end of the microphone trace.
    ax.annotate(trial_target_sequence, (mic_time[-1], -0.4),
                va='top', ha='right', annotation_clip=False, fontstyle='italic', fontsize=fs)

    return (fig, ax) if return_fig else ax

In [None]:
# Render the task-design schematic at a smaller annotation size.
f, _ = figure1_task_design(fs=12, return_fig=True);

### Single trial rasters

In [None]:
def single_elec_raster(axs=None, fig=None, fs=12, s=1, data_dict=single_elec_raster_dict):
    """Plot single-trial HGA rasters of one example electrode across the three task epochs.

    Parameters
    ----------
    axs : sequence of 3 Axes, optional
        Panels for target presentation + delay, go cue and speech onset. Created (with ``fig``)
        when omitted.
    fig : matplotlib.figure.Figure, optional
        Figure that hosts the colorbar. Defaults to the figure owning the third panel.
    fs : float
        Font size.
    s : float
        Marker size of the per-trial event dots in the speech-onset panel.
    data_dict : dict
        ``data_dict['sr']`` is the sampling rate. For each alignment,
        ``data_dict[alignment]`` holds:

        - 'ecog_trials': trials x time
        - 'ecog_counts': trials per condition block
        - 'trial_lengths': in samples from the window start
        - 'production_times': plotted directly on the time axis

        Defaults to the raster data loaded at module level (bound at definition time).

    Returns
    -------
    The panels ``axs``.
    """
    if axs is None:
        fig, axs = plt.subplots(1, 3, figsize=(12, 5), gridspec_kw=dict(width_ratios=[1.5, 0.5, 1.5]))

    # Condition-block tick labels, bottom to top. Assumed to follow the order of `ecog_counts`
    # in data_dict (TODO confirm against the raster export). `\\it` keeps a literal backslash
    # for mathtext italics; the former '\it' was an invalid escape (SyntaxWarning on 3.12+).
    condition_labels = {
        'noneSeq_simpleSyl': 'CV\n$\\it{baa}$',
        'noneSeq_complexSyl': 'CCV\n$\\it{blaa}$',
        'simpleSeq_simpleSyl': 'CV simple seq.\n$\\it{baa}$-$\\it{baa}$-$\\it{baa}$',
        'simpleSeq_complexSyl': 'CCV simple seq.\n$\\it{blaa}$-$\\it{blaa}$-$\\it{blaa}$',
        'complexSeq_simpleSyl': 'CV complex seq.\n$\\it{baa}$-$\\it{daa}$-$\\it{gaa}$',
        'complexSeq_complexSyl': 'CCV complex seq.\n$\\it{blaa}$-$\\it{draa}$-$\\it{gloo}$',
    }

    vlim = (0, 1)
    cmap = color_maps['middle PrCG']

    sr = data_dict['sr']

    for cur_align, alignment in enumerate(['target_presentation_delay', 'go_cue', 'pre_exec_speech']):

        ax = axs[cur_align]
        ax.tick_params(axis='both', labelsize=fs)

        panel = data_dict[alignment]
        ecog_trials = panel['ecog_trials']
        ecog_counts = panel['ecog_counts']
        trial_lengths = panel['trial_lengths']
        production_times = panel['production_times']

        window = erp_alignments[alignment]['window']
        n = np.arange(ecog_trials.shape[0])
        x = np.linspace(*window, ecog_trials.shape[1])

        im = ax.pcolormesh(x, n, ecog_trials, vmin=vlim[0], vmax=vlim[1], cmap=cmap, shading='auto', rasterized=True)
        ax.axes.set(ylim=(n[0], n[-1]), xlim=(x[0], x[-1]))

        if cur_align == 2:
            # `fig` is only created here when axs was omitted; otherwise fall back to the panel's figure.
            owner = fig if fig is not None else ax.figure
            cb = owner.colorbar(im, ax=ax, label='HGA (z-score)')
            cb.ax.tick_params(labelsize=fs)

        for spine in ['left', 'right', 'top', 'bottom']:
            ax.spines[spine].set_visible(True)

        if cur_align == 1:
            ax.axes.set(xlabel='Relative time (s)')

        # Separator line at each condition-block boundary; ticks sit at the centre of each block.
        ecog_ticks = np.cumsum(ecog_counts)
        for y in ecog_ticks:
            ax.axhline(y=y, color='k')
        yticks = ecog_ticks - np.asarray(ecog_counts) / 2

        if cur_align == 0:
            ax.axes.set(yticks=yticks)
            ax.set_yticklabels(condition_labels.values(), fontsize=fs, va='center', linespacing=0.95)
            ax.yaxis.set_tick_params(length=0)
        else:
            ax.axes.set(yticks=[])

        # Set major xticks at 1, minor at 0.25
        ax.xaxis.set_major_locator(MultipleLocator(1))
        ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
        ax.xaxis.set_minor_locator(MultipleLocator(0.25))

        ax.axvline(x=0, color='k')

        if alignment == 'pre_exec_speech':
            # One dot per trial at its end (samples -> seconds from window start) and at its production time.
            trial_end_times = np.asarray(trial_lengths) / sr + window[0]
            ax.scatter(trial_end_times, np.arange(len(trial_end_times)), color='k', s=s, alpha=0.25)
            ax.scatter(production_times, np.arange(len(production_times)), color='k', s=s, alpha=0.25)

        if alignment == 'target_presentation_delay':
            ax.axvline(x=2.5, color='k')
            ax.annotate('Encode', (-0.25, n[-1]+2), ha='left', va='bottom', annotation_clip=False, fontsize=fs)
            ax.annotate('Delay', (2.5, n[-1]+2), ha='center', va='bottom', annotation_clip=False, fontsize=fs)
        elif alignment == 'go_cue':
            ax.annotate('Go\ncue', (0, n[-1]+2), ha='center', va='bottom', annotation_clip=False, fontsize=fs)
        elif alignment == 'pre_exec_speech':
            ax.annotate('Speech\nonset', (0, n[-1]+2), ha='center', va='bottom', annotation_clip=False, fontsize=fs)

    return axs

In [None]:
# Example-electrode rasters with small fonts (s=1 is the default dot size).
single_elec_raster(fs=5, s=1);

### Significant electrodes

In [None]:
def figure1_sig_elecs(ax=None, view='lateral', highlight_se=None, plot_nonsig_elecs=True, s=10):
    """Plot significant electrodes, coloured by anatomical area, on the left-hemisphere MNI brain.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw into; a new 6 x 6 inch figure is created when omitted.
    view : str
        Brain view; selects the background image (``mni_img_path``) and the electrodes drawn.
    highlight_se : str, optional
        ``subject_electrode`` ID to outline with a black ring.
    plot_nonsig_elecs : bool
        Also draw non-significant electrodes as faint black dots.
    s : float
        Marker size. Non-significant dots use ``s/2`` and the highlight ring ``3*s``.

    Returns
    -------
    ``(fig, ax)`` when the figure was created here, otherwise ``ax`` (same contract as before).
    """
    # Explicit flag instead of the old bare `except` around `return fig, ax`.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(6, 6))

    hemi = 'lh'
    img = plt.imread(mni_img_path.format(hemi, view))

    ax.imshow(img, alpha=0.5)

    if plot_nonsig_elecs:
        cur_nonsig_df = non_sig_edf.loc[(non_sig_edf.hemisphere == hemi) & (non_sig_edf.view == view)]
        ax.scatter(cur_nonsig_df.warp_x.values, cur_nonsig_df.warp_y.values, color='k', s=s/2, alpha=0.2)

    # ndf holds one row per electrode and alignment; keep one row per electrode.
    cur_df = ndf.loc[(ndf.hemisphere == hemi) & (ndf.view == view)].drop_duplicates(subset=['subject_electrode'])
    cur_colors = [color_dict[area_names[loc]] for loc in cur_df.location.values]
    ax.scatter(cur_df.x.values, cur_df.y.values, color=cur_colors, s=s)

    ax.axis('off')

    if highlight_se:
        highlighted = cur_df.loc[cur_df.subject_electrode == highlight_se]
        ax.scatter(highlighted.x.values, highlighted.y.values, fc='None', linewidth=0.5, ec='k', s=s*3)

    if created_fig:
        return fig, ax
    return ax

In [None]:
# Significant electrodes on the lateral view (the default view).
f, ax = figure1_sig_elecs(view='lateral');

### Rasters

In [None]:
def plot_raster(ax=None, cur_col=None, xlim=None, df=None, trial_dict=None, cmaps=None, window=None,
                sr=None, alignment=None, plot_cluster_labels=False, fs=10, s=10):
    """
    Draw one column of trial-averaged HGA rasters, one axes row per anatomical area.

    df should be ordered how you want it to appear. The ``fancy_location`` column selects both
    the axes row and the colormap.

    Parameters
    ----------
    ax : 2-D array of Axes
        Grid indexed ``ax[area_row, cur_col]``. Areas fill rows in order of first appearance in df.
    cur_col : int
        Column of ``ax`` to draw into. Column 0 also gets the area label with electrode counts.
    xlim : unused
        Kept for backward compatibility.
    df : pandas.DataFrame
        Electrodes to plot. Needs 'fancy_location', 'subject' and 'electrode' columns, plus
        'cluster_label' when ``plot_cluster_labels`` is True.
    trial_dict : dict
        ``trial_dict[alignment][subject][:, electrode]`` is one electrode's time course.
    cmaps : dict
        Colormap per fancy_location.
    window : (float, float)
        Start and end time (s) of the traces.
    sr : float, unused
        The time axis is now derived from the trace length; kept for backward compatibility.
    alignment : str
        Key into ``trial_dict``.
    plot_cluster_labels : bool
        Add a thin strip right of each raster with one dot per clustered electrode.
    fs, s : float
        Font size of the area labels; marker size of the cluster dots.

    Returns
    -------
    The same ``ax`` grid.

    Raises
    ------
    ValueError
        If a trace cannot be found in ``trial_dict``. This replaces the old debug path, which
        printed 'Skipping' and then crashed on ``[].shape``.

    Notes
    -----
    Reads the module-level ``total_elec_per_region`` for the "n=k/N" labels.
    """
    for cur_location, location in enumerate(df.fancy_location.unique()):

        cur_ax = ax[cur_location, cur_col]
        cur_df = df.loc[df.fancy_location == location]

        # get the number of total available electrodes in the anatomical region. same for every alignment.
        total_N = total_elec_per_region.loc[total_elec_per_region.fancy_location == location].total_electrodes.values[0]

        try:
            traces = [trial_dict[alignment][subj][:, elec]
                      for subj, elec in zip(cur_df.subject.values, cur_df.electrode.values)]
        except (KeyError, IndexError) as err:
            raise ValueError(f'Could not collect {alignment!r} traces for area {location!r}') from err

        # Flip so the first electrode in df lands in the top row of the raster.
        y = np.flipud(np.stack(traces, axis=0))

        # Time axis sized from the data. The old length int(np.abs(window).sum() * sr) was wrong
        # for windows that do not straddle 0 and could be truncated by float rounding.
        x = np.linspace(window[0], window[1], y.shape[1])

        cur_ax.pcolormesh(x, np.arange(y.shape[0]), y, vmin=0, vmax=0.75, cmap=cmaps[location], shading='auto', rasterized=True)

        for spine in ['left', 'right', 'top', 'bottom']:
            cur_ax.spines[spine].set_visible(True)

        cur_ax.axes.set(yticks=[], xticks=[])

        if cur_col == 0:
            # Rows with fewer than 35 electrodes get a one-line label, taller rows a two-line one.
            if y.shape[0] < 35:
                ylabel = f'{location} (n={y.shape[0]}/{total_N})'
            else:
                ylabel = f'{location}\n(n={y.shape[0]}/{total_N})'
            cur_ax.set_ylabel(ylabel, rotation=0, ha='right', va='center', fontsize=fs)

        if plot_cluster_labels:
            # Thin strip to the right of the raster: a dot for every clustered electrode, in the
            # same bottom-to-top order as the flipped raster rows.
            cax = make_axes_locatable(cur_ax).append_axes('right', size='3%', pad='1%')
            cax.axis('off')
            cax.spines['top'].set_visible(True)
            cax.axes.set(ylim=(0, cur_df.shape[0]))

            for cur_label, clabel in enumerate(np.flip(cur_df.cluster_label.values)):
                if clabel != 'none':
                    cax.scatter(0.5, cur_label, color='k', clip_on=False, s=s)

    return ax

In [None]:
# Shared electrode ordering for all raster columns:
#   1. area (category order of fancy_location)
#   2. cluster_weight, largest first (unclustered electrodes carry -15)
#   3. pre_exec_speech_reaction_time, ascending
order_sort_by = {
    'by': ['fancy_location', 'cluster_weight', 'pre_exec_speech_reaction_time'],
    'ascending': [True, False, True],
}

# Electrodes per area (one row per electrode via the target_presentation alignment).
_tp_sorted = ndf.loc[ndf.alignment == 'target_presentation'].sort_values(**order_sort_by)
raster_area_counts = _tp_sorted.fancy_location.value_counts(sort=False).values

In [None]:
def figure1_full_rasters(axs=None, order_by_cluster=False, return_fig=False, raster_fs=10, raster_s=10):
    """Trial-averaged HGA rasters of all significant electrodes, grouped by area.

    Columns: target presentation + delay, go cue, speech onset. Rows: one per area, with heights
    proportional to the area's electrode count. Electrodes are ordered by ``order_sort_by``.

    Parameters
    ----------
    axs : 2-D array of Axes, optional
        Grid of shape (n_areas, 3); created when omitted.
    order_by_cluster : bool
        If True, mark clustered electrodes with a dot right of the speech-onset column.
        Row order is always ``order_sort_by``.
    return_fig : bool
        Return ``(fig, axs)`` instead of ``axs``. With supplied axs, ``fig`` is ``axs[0, 0].figure``.
    raster_fs, raster_s : float
        Font size and cluster-dot marker size passed to ``plot_raster``.
    """
    if axs is None:

        counts = ndf.loc[ndf.alignment == 'target_presentation'].sort_values(**order_sort_by).fancy_location.value_counts(sort=False).values

        width_ratios = [1.5, 0.5, 1.5]

        fig, axs = plt.subplots(len(counts), len(width_ratios), figsize=(16, 20),
                                gridspec_kw={'height_ratios': counts, 'wspace': 0.1, 'width_ratios': width_ratios})
    else:
        # Previously `fig` was unbound here, so return_fig=True with supplied axs raised.
        fig = axs[0, 0].figure

    # Per column:
    #   - trial/window alignment
    #   - ndf alignment that selects the rows (the go-cue column reuses the 'pre_exec' rows)
    #   - vertical event lines
    #   - x ticks
    #   - label only every other tick on the bottom row
    #   - draw cluster dots
    panels = [
        (0, 'target_presentation_delay', 'target_presentation_delay', (0.0, 2.5),
         np.arange(-0.5, 3.3, 0.25), True, False),
        (1, 'go_cue', 'pre_exec', (0,), [-0.25, 0.0, 0.25], False, False),
        (2, 'pre_exec_speech', 'pre_exec_speech', (0,),
         np.arange(-0.5, 2.1, 0.25), True, order_by_cluster),
    ]

    for col, alignment, row_alignment, event_times, xticks, every_other_label, cluster_dots in panels:
        plot_raster(ax=axs,
                    cur_col=col,
                    df=ndf.loc[ndf.alignment == row_alignment].sort_values(**order_sort_by),
                    trial_dict=all_trial_avgs,
                    cmaps=color_maps,
                    window=erp_alignments[alignment]['window'],
                    sr=sr,
                    alignment=alignment,
                    plot_cluster_labels=cluster_dots,
                    fs=raster_fs,
                    s=raster_s)

        for ax in axs[:, col]:
            for t in event_times:
                ax.axvline(x=t, color='k', alpha=0.75)
            ax.tick_params(axis='x', direction='in')
            ax.axes.set(xticks=xticks, xticklabels=[])

        if every_other_label:
            bottom_labels = [tick if i % 2 == 0 else '' for i, tick in enumerate(xticks)]
        else:
            bottom_labels = list(xticks)
        axs[-1, col].axes.set(xticklabels=bottom_labels, xlabel=f'Time from\n{time_periods[alignment]} (s)')

    for ax in axs.ravel():
        ax.tick_params(axis='both', labelsize=raster_fs)

    return (fig, axs) if return_fig else axs

In [None]:
# Full raster figure with cluster-membership dots on the speech-onset column.
f, _ = figure1_full_rasters(return_fig=True, order_by_cluster=True);

### Cluster averages

In [None]:
def figure1_hga_cluster_avg_per_cluster(axs=None, return_fig=False, fs=12):
    """Mean +/- SEM HGA of each NMF cluster across the three task epochs.

    Rows follow ``cluster_names_to_idx`` order and are labelled 'Cluster 1', 'Cluster 2', ...
    Columns are target presentation + delay, go cue and speech onset. Prints each cluster's name
    and the shape of its electrode-index array.

    Parameters
    ----------
    axs : 2-D array of Axes, optional
        Grid of shape (num_clusters, 3); created when omitted.
    return_fig : bool
        Return ``(fig, axs)`` instead of ``axs``. With supplied axs, ``fig`` is ``axs[0, 0].figure``.
    fs : float
        Font size of the epoch annotations above the top row.
    """
    ylim = [-0.5, 2.0]
    epochs = ['target_presentation_delay', 'go_cue', 'pre_exec_speech']

    if axs is None:

        fig, axs = plt.subplots(num_clusters, 3, figsize=(6, 6),
                                gridspec_kw={'width_ratios': [1.5, 0.5, 1.5], 'hspace': 0.5})
    else:
        # Previously `fig` was unbound here, so return_fig=True with supplied axs raised.
        fig = axs[0, 0].figure

    for cur_cluster, cluster_key in enumerate(cluster_names_to_idx):

        # Columns of cluster_erp_trials follow clustered_se; select this cluster's members.
        member_se = cluster_df.loc[cluster_df.cluster_label == cluster_key].subject_electrode.values
        idx = np.where(np.isin(clustered_se, member_se))[0]
        print(cluster_key, idx.shape)
        cluster_color = cluster_name_colors[cluster_key]

        for col, alignment in enumerate(epochs):

            ax = axs[cur_cluster, col]

            # Stacked traces are (time, electrode): time runs along axis 0.
            cur_trials = cluster_erp_trials[alignment]
            x = np.linspace(*erp_alignments[alignment]['window'], cur_trials.shape[0])

            y = cur_trials[:, idx].mean(1)
            err = stats.sem(cur_trials[:, idx], axis=1)

            ax.plot(x, y, color=cluster_color, label=cluster_key.capitalize(), zorder=1, clip_on=False)
            ax.fill_between(x, y - err, y + err, alpha=0.3, color=cluster_color, zorder=1, clip_on=False)

            ax.axes.set(xlim=erp_alignments[alignment]['window'])

        # Set major yticks at 1, minor at 0.5
        axs[cur_cluster, 0].yaxis.set_major_locator(MultipleLocator(1))
        axs[cur_cluster, 0].yaxis.set_major_formatter(FormatStrFormatter('%d'))
        axs[cur_cluster, 0].yaxis.set_minor_locator(MultipleLocator(0.5))

        for col in range(len(epochs)):
            a = axs[cur_cluster, col]

            # Set major xticks at 1, minor at 0.25
            a.xaxis.set_major_locator(MultipleLocator(1))
            a.xaxis.set_major_formatter(FormatStrFormatter('%d'))
            a.xaxis.set_minor_locator(MultipleLocator(0.25))

            a.axes.set(ylim=ylim)
            a.axvline(x=0, color='k', alpha=0.75, linestyle='-', zorder=0)
            a.axhline(y=0, color='k', alpha=0.75, linestyle='-', zorder=0)

            if col != 0:
                # Only the first column keeps a y axis.
                a.spines['left'].set_visible(False)
                a.axes.set(yticks=[])
                sns.despine(ax=a, left=True, offset=dict(left=2.5, bottom=2.5))
            else:
                sns.despine(ax=a, offset=dict(left=2.5, bottom=2.5))

        # End of the 2.5 s encoding period.
        axs[cur_cluster, 0].axvline(x=2.5, color='k', alpha=0.75, linestyle='-', zorder=0)
        axs[cur_cluster, 0].axes.set(ylabel=f'Cluster {cur_cluster+1}')

    axs[0, 0].annotate('Delay', (2.75, ylim[1]), va='bottom', ha='center', fontsize=fs, clip_on=False)
    axs[0, 0].annotate('Encode', (-0.5, ylim[1]), va='bottom', ha='left', fontsize=fs, clip_on=False)
    axs[0, 1].annotate('Go\ncue', (0.0, ylim[1]), va='bottom', ha='center', fontsize=fs, clip_on=False)
    axs[0, 2].annotate('Speech\nonset', (0.0, ylim[1]), va='bottom', ha='center', fontsize=fs, clip_on=False)

    axs[-1, 1].axes.set(xlabel='Relative time (s)')

    return (fig, axs) if return_fig else axs

In [None]:
# Per-cluster average HGA traces (fs=12 is the default annotation size).
f, _ = figure1_hga_cluster_avg_per_cluster(fs=12, return_fig=True);

### Cluster counts

#### MSE proportions of clusters 1-3 (all clusters except 'equal')

In [None]:
# For each area, how evenly its clustered electrodes split across the non-'equal' clusters:
#   mse_proportion - mean squared deviation of the observed proportions from a uniform split
#   chi_square     - goodness-of-fit p-value of the counts against a uniform split
mse_proportions = {key: [] for key in ['area', 'count', 'mse_proportion', 'chi_square']}

for area in cluster_count_df.fancy_location.unique():
    cur_df = cluster_count_df.loc[(cluster_count_df.fancy_location == area) & (cluster_count_df.cluster_label != 'equal')]
    group_counts = cur_df['count'].values
    t = group_counts.sum()

    if t == 0:
        continue

    # Number of clusters compared (3 when 'equal' is excluded from the 4 clusters).
    n_groups = len(group_counts)
    mse = sum(((1/n_groups) - (c/t))**2 for c in group_counts)

    mse_proportions['area'].append(area)
    mse_proportions['count'].append(t)
    mse_proportions['mse_proportion'].append(mse / n_groups)
    # The test must use counts. Previously it used proportions (which sum to 1); that shrinks the
    # statistic by a factor of t, so with 3 groups p could never fall below ~0.37.
    mse_proportions['chi_square'].append(stats.chisquare(group_counts)[1])

mse_proportions = pd.DataFrame(data=mse_proportions)

In [None]:
# Areas from most to least evenly split across clusters 1-3.
mse_proportions.sort_values(by='mse_proportion')

In [None]:
def figure1_cluster_counts_stacked(ax=None, metric='percent', fs=10, legend=True, return_fig=False, min_elecs=5):
    """Stacked bar chart of cluster membership per area, from the module-level `cluster_count_df`.

    Areas are ordered by their total `count` (descending). Areas whose total is below
    `min_elecs` are omitted. Each bar stacks the 'equal_tp_motor', 'higher_tp',
    'higher_motor' and 'equal' clusters, in that order, coloured via `cluster_name_colors`.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 15 x 5 in figure is created when omitted.
    metric : str
        Column of `cluster_count_df` used as the bar height (e.g. 'count' or 'percent').
    fs : int
        Font size of the axis label and tick labels.
    legend : bool
        Draw a legend of the (capitalised) cluster labels.
    return_fig : bool
        Return `(fig, ax)` instead of only `ax`.
    min_elecs : int
        Minimum total `count` an area needs to be plotted (default 5).

    Returns
    -------
    ax, or (fig, ax) when `return_fig` is True.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(15, 5))
    else:
        # Lets `return_fig=True` work with a caller-supplied axes as well.
        fig = ax.figure

    # Aggregate only the numeric `count` column; summing string/categorical columns
    # is wasteful and raises for categoricals in recent pandas.
    area_totals = (cluster_count_df.groupby('fancy_location')[['count']].sum()
                   .sort_values('count', ascending=False))

    # Don't plot areas with fewer than `min_elecs` sustained elecs for the bar charts
    location_order = area_totals.index.values[area_totals['count'].values >= min_elecs]

    bottoms = np.zeros(len(location_order))
    for hue in ['equal_tp_motor', 'higher_tp', 'higher_motor', 'equal']:

        cur_df = cluster_count_df.loc[(cluster_count_df.cluster_label == hue) &
                                      cluster_count_df.fancy_location.isin(location_order)]

        # Pass `ax` explicitly: without it seaborn draws on plt.gca(), which need not be `ax`.
        ax = sns.barplot(data=cur_df, x='fancy_location', y=metric, bottom=bottoms,
                         order=location_order, color=cluster_name_colors[hue], width=0.7,
                         label=hue, ax=ax)

        # The next cluster is stacked on top of this one (first row per area, in plot order).
        bottoms += np.array([cur_df.loc[cur_df.fancy_location == loc][metric].values[0] for loc in location_order])

    yticks = np.arange(0, 76, 25)
    ax.axes.set(ylim=(0, 75), yticks=yticks, xlabel='')
    ax.set_ylabel('Number of electrodes', fontsize=fs)
    ax.set_xticklabels(location_order, rotation=90, fontsize=fs, ha='center')
    ax.set_yticklabels(yticks, fontsize=fs)

    if legend:
        current_handles, current_labels = ax.get_legend_handles_labels()
        ax.legend(frameon=False, title='', handles=current_handles, labels=[c.capitalize() for c in current_labels],
                  ncol=num_clusters, loc='upper right')

    return (fig, ax) if return_fig else ax

In [None]:
# Stand-alone preview of the stacked cluster counts (raw electrode counts, no legend).
fig, _ = figure1_cluster_counts_stacked(return_fig=True, fs=15, legend=False, metric='count')

### Cluster recons

In [None]:
def figure1_cluster_recons(axs=None, cluster_key=None, view='lateral', hemi='lh', layout='vertical', s=10, return_fig=False, 
                           same_recon=False, plot_title=False, plot_density=False, same_color=False):
    """Draw the electrodes of each cluster on the MNI brain image for one view/hemisphere.

    The background is read from ``mni_img_path.format(hemi, view)``. For every cluster in
    ``cluster_names_to_idx`` the rows of ``cluster_df`` with ``cluster_df[cluster_key] == cluster``
    in that view/hemisphere are drawn. They are sorted by their ``'{cluster}_weight'`` column,
    so higher-weight electrodes are drawn last.

    Parameters
    ----------
    axs : Axes or array of Axes, optional
        A single Axes when ``same_recon`` is True, otherwise one Axes per cluster.
        Created according to ``layout`` when omitted.
    cluster_key : str
        Column of ``cluster_df`` holding each electrode's cluster label.
    view, hemi : str
        Brain view and hemisphere; select both the background image and the electrodes.
    layout : {'vertical', 'horizontal'}
        Arrangement of newly created axes (only used when ``axs`` is None and not ``same_recon``).
    s : float
        Scatter marker size.
    return_fig : bool
        Return ``(fig, axs)`` instead of ``axs``.
    same_recon : bool
        Draw every cluster onto one shared brain image.
    plot_title : bool
        Title each axes with the capitalised cluster name.
    plot_density : bool
        Instead of a scatter, draw a smoothed histogram of all electrodes in the view,
        weighted by membership in the cluster (see ``smoothed_weighted_histogram``).
    same_color : bool
        Use one white -> ``colors[5]`` colormap for all clusters instead of ``cluster_cmaps``.

    Raises
    ------
    ValueError
        If axes must be created and ``layout`` is neither 'horizontal' nor 'vertical'.
    """
    fig = None
    if axs is None:
        if same_recon:
            fig, axs = plt.subplots(figsize=(6, 6))
        elif layout == 'horizontal':
            fig, axs = plt.subplots(1, num_clusters, figsize=(20, 8))
        elif layout == 'vertical':
            fig, axs = plt.subplots(num_clusters, 1, figsize=(6, 15))
        else:
            raise ValueError(f"layout must be 'horizontal' or 'vertical', got {layout!r}")

    img = plt.imread(mni_img_path.format(hemi, view))

    if same_recon:
        axs.imshow(img, alpha=0.5)

    for i, cur_cluster in enumerate(list(cluster_names_to_idx.keys())):

        if same_recon:
            ax = axs
        else:
            ax = axs[i]
            ax.imshow(img, alpha=0.5)

        cur_df = cluster_df.loc[(cluster_df[cluster_key] == cur_cluster) &
                                (cluster_df.view == view) &
                                (cluster_df.hemisphere == hemi)]
        cur_df = cur_df.sort_values(by=[f'{cur_cluster}_weight'])

        ax.axis('off')

        if plot_title:
            ax.axes.set(title=cur_cluster.capitalize())

        if cur_df.shape[0] == 0:
            continue

        if plot_density:

            # All electrodes of this view/hemisphere; weight 1 for members of the current cluster.
            recon_df = cluster_df.loc[(cluster_df.view == view) & (cluster_df.hemisphere == hemi)]
            w = np.where(recon_df[cluster_key].values == cur_cluster, 1, 0)

            # Get smoothed Gaussian
            cluster_density, xedges, yedges = smoothed_weighted_histogram(x=recon_df.x.values,
                                                                          y=recon_df.y.values,
                                                                          weights=w,
                                                                          xlim=[0, img.shape[1]],
                                                                          ylim=[0, img.shape[0]],
                                                                          bins=100,
                                                                          smooth=2,
                                                                          baseline_norm=True)

            # Draw the density in bands of width dx; each band's alpha is looked up from
            # `alphas` (index = 100 * band lower edge), so low-density bands stay transparent
            # and alpha ramps up to 0.9 for high-density bands.
            alphas = np.concatenate([np.zeros(15), np.linspace(0, 0.9, 55), 0.9*np.ones(30)])
            dx = 0.05
            for lower in np.arange(0, 1, dx):
                outside = np.where(~np.logical_and(cluster_density > lower, cluster_density <= lower + dx))
                band = np.copy(cluster_density)
                band[outside] = np.nan
                # round() keeps float error in 100*lower from truncating to the previous index.
                ax.pcolormesh(xedges, yedges, band.T, alpha=alphas[int(round(100*lower))], vmin=0, vmax=1,
                              cmap=cluster_cmaps[cur_cluster], zorder=1, rasterized=True)

        else:

            if same_color:
                # The colormap name is only a label; it must not depend on undefined globals.
                cmap = mpl.colors.LinearSegmentedColormap.from_list(f'{cur_cluster}_same_color', ['white', colors[5]])
            else:
                cmap = cluster_cmaps[cur_cluster]
            # Colour by weight relative to the strongest electrode of this cluster.
            max_weight = np.max(cur_df[f'{cur_cluster}_weight'].values)
            ax.scatter(cur_df.x.values, cur_df.y.values, c=cur_df[f'{cur_cluster}_weight'].values/max_weight,
                       cmap=cmap, s=s, vmin=0, vmax=0.5, alpha=0.75)

    if not return_fig:
        return axs
    if fig is None:
        # Axes were supplied by the caller: recover their figure.
        fig = (axs if same_recon else np.ravel(axs)[0]).figure
    return fig, axs

# Overall Figure

In [None]:
##### Figure setup #####
default_plot_settings(font='Helvetica', fontsize=5, linewidth=0.5, ticklength=2)
label_fontsize = 5

# Keep text editable in vector exports (SVG text, TrueType fonts in PDF/PS)
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['axes.labelpad'] = 1
plt.rcParams['patch.linewidth'] = 0.25
plt.rcParams['hatch.linewidth'] = 0.25
plt.rcParams['lines.markeredgewidth'] = 0.25
plt.rcParams['axes.titlepad'] = 2
plt.rcParams['xtick.major.pad'] = 2
plt.rcParams['ytick.major.pad'] = 2
plt.rcParams['legend.handletextpad'] = 1

all_axs = {}
mm = 1/25.4  # inches per millimetre
fig = plt.figure(figsize=(180*mm, 233*mm))
# Fine 100 x 100 grid; every panel below is placed by slicing it.
gs = mpl.gridspec.GridSpec(100, 100, figure=fig, hspace=1, wspace=0.5)

single_raster_gs = mpl.gridspec.GridSpecFromSubplotSpec(1, 3,
                                                        subplot_spec=gs[2:17, 50:73],
                                                        width_ratios=[1.5, 0.5, 1.5],
                                                        wspace=0.15,
                                                        hspace=0.5)
raster_gs = mpl.gridspec.GridSpecFromSubplotSpec(len(raster_area_counts), 3,
                                                 subplot_spec=gs[21:90, 10:56],
                                                 width_ratios=[1.5, 0.5, 1.5],
                                                 height_ratios=raster_area_counts,
                                                 wspace=0.15)
cluster_gs = mpl.gridspec.GridSpecFromSubplotSpec(num_clusters, 3,
                                                  subplot_spec=gs[53:90, 64:],
                                                  width_ratios=[1.5, 0.5, 1.5],
                                                  wspace=0.15,
                                                  hspace=0.5)
# NOTE(review): `counts_gs` is not referenced in this cell.
counts_gs = mpl.gridspec.GridSpecFromSubplotSpec(num_clusters, 1,
                                                 subplot_spec=gs[65:85, 66:],
                                                 hspace=0.25)

##########

# NOTE(review): `axline_kwargs` is not referenced in this cell.
axline_kwargs = {
    'linestyle': '--',
    'linewidth': 2
}

all_axs['task'] = fig.add_subplot(gs[:15, 6:41])
all_axs['task'] = figure1_task_design(ax=all_axs['task'], fs=label_fontsize)

all_axs['sig_elecs_medial'] = fig.add_subplot(gs[:8, 78:88])
all_axs['sig_elecs_medial'] = figure1_sig_elecs(ax=all_axs['sig_elecs_medial'], s=0.5, view='medial')

all_axs['sig_elecs_lateral'] = fig.add_subplot(gs[5:21, 78:])
all_axs['sig_elecs_lateral'] = figure1_sig_elecs(ax=all_axs['sig_elecs_lateral'], highlight_se='EC260_76', s=1, view='lateral')

# Public get_geometry() -> (nrows, ncols) instead of the private _nrows/_ncols attributes.
all_axs['single_raster'] = [fig.add_subplot(single_raster_gs[c])
                            for c in range(single_raster_gs.get_geometry()[1])]
all_axs['single_raster'] = single_elec_raster(axs=all_axs['single_raster'], fs=label_fontsize, s=0.5, fig=fig)

# One axes per (area row, alignment column), created row-major; kept as a 2-D object array.
raster_nrows, raster_ncols = raster_gs.get_geometry()
all_axs['rasters'] = np.array([[fig.add_subplot(raster_gs[r, c]) for c in range(raster_ncols)]
                               for r in range(raster_nrows)])
all_axs['rasters'] = figure1_full_rasters(axs=all_axs['rasters'], order_by_cluster=True,
                                          raster_fs=label_fontsize, raster_s=1)

all_axs['nmf_medial'] = fig.add_subplot(gs[21:29, 59:76])
all_axs['nmf_medial'] = figure1_cluster_recons(axs=all_axs['nmf_medial'], cluster_key='cluster_label',
                                               view='medial', layout='horizontal', s=4, same_recon=True)

all_axs['nmf_lateral'] = fig.add_subplot(gs[21:41, 74:])
all_axs['nmf_lateral'] = figure1_cluster_recons(axs=all_axs['nmf_lateral'], cluster_key='cluster_label',
                                                view='lateral', layout='horizontal', s=4, same_recon=True)

all_axs['cluster_counts'] = fig.add_subplot(gs[38:44, 62:])
all_axs['cluster_counts'] = figure1_cluster_counts_stacked(metric='count', ax=all_axs['cluster_counts'], legend=False, fs=label_fontsize)
# Transparent background so overlapping neighbouring panels stay visible.
all_axs['cluster_counts'].patch.set_alpha(0.0)

# One axes per (cluster row, alignment column), created row-major.
cluster_nrows, cluster_ncols = cluster_gs.get_geometry()
all_axs['nmf_avg'] = np.array([[fig.add_subplot(cluster_gs[r, c]) for c in range(cluster_ncols)]
                               for r in range(cluster_nrows)])
all_axs['nmf_avg'] = figure1_hga_cluster_avg_per_cluster(axs=all_axs['nmf_avg'], fs=label_fontsize)

# Invisible axes spanning the cluster-average block, used only to place a shared y-label.
all_axs['nmf_avg_ylabel'] = fig.add_subplot(gs[53:90, 61:])
all_axs['nmf_avg_ylabel'].axis('off')
all_axs['nmf_avg_ylabel'].annotate('HGA (z-score)', (-0.07, 0.5), rotation=90,
                                   xycoords='axes fraction', ha='center', va='center', fontsize=label_fontsize)


#### Panel labels
# Hidden axes whose only purpose is to anchor the bold panel letters.
panel_axs = {}
panel_axs['a'] = fig.add_subplot(gs[2:7, :10])
panel_axs['b'] = fig.add_subplot(gs[2:7, 42:52])
panel_axs['c'] = fig.add_subplot(gs[2:7, 77:87])
panel_axs['d'] = fig.add_subplot(gs[22:27, :10])
panel_axs['e'] = fig.add_subplot(gs[23:28, 58:68])
panel_axs['f'] = fig.add_subplot(gs[37:42, 58:68])
panel_axs['g'] = fig.add_subplot(gs[53:58, 58:68])

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (0, 30), xycoords='axes points', ha='right', fontsize=7, weight='bold')

# Supplementary figures

## Fig. S5 - Cluster recons with density maps

In [None]:
fig, axs = plt.subplots(3, 4, figsize=(13, 8), gridspec_kw=dict(wspace=0.025))

# One row per panel: (a) medial scatter, (b) lateral scatter, (c) lateral density map.
recon_row_kwargs = [
    dict(view='medial', s=20, plot_title=False),
    dict(view='lateral', s=20),
    dict(view='lateral', plot_density=True),
]
for row, row_kwargs in enumerate(recon_row_kwargs):
    axs[row, :] = figure1_cluster_recons(cluster_key='cluster_label', layout='horizontal',
                                         axs=axs[row, :], **row_kwargs)

# Bold panel letters at the top-left corner of each row.
for row, letter in enumerate('abc'):
    axs[row, 0].annotate(letter, (0, 1.0), xycoords='axes fraction', ha='right', fontsize=18, weight='bold')

## Fig. S6a
The remaining Fig. S6 panels (b-e) were plotted in MATLAB; see `figureS6bcde-cluster_separation.m`.

In [None]:
default_plot_settings(fontsize=12)

# Time axes (s) of the three alignment windows (columns of the figure).
tp_x = np.linspace(-0.5, 3.25, 375)
go_x = np.linspace(-0.25, 0.25, 50)
pe_x = np.linspace(-0.5, 2.0, 250)

cluster_key = 'cluster_label'

cluster_order = cluster_names_to_idx.keys()
# Electrodes per cluster -> row height ratios, so every electrode row has the same height.
per_cluster_counts = [np.count_nonzero(cluster_df[cluster_key].values == c) for c in cluster_order]

fig, ax = plt.subplots(len(cluster_order), 3, figsize=(7, 9),
                       gridspec_kw={'height_ratios': per_cluster_counts,
                                    'width_ratios': [1.5, 0.5, 1.5],
                                    'hspace': 0.1})

# Shared symmetric colour scale for every heat map.
vmax = 2
vmin = -vmax
cur_cmap = 'RdBu_r'

# (key into cluster_erp_trials, time axis) for each column.
column_specs = [('target_presentation_delay', tp_x), ('go_cue', go_x), ('pre_exec_speech', pe_x)]

for cur_c, c in enumerate(cluster_order):

    # Electrodes of this cluster, ordered by increasing cluster weight.
    ordered_se = (cluster_df.loc[cluster_df[cluster_key] == c]
                  .sort_values(by=[f'{c}_weight'])
                  .subject_electrode.values)
    # Row of each electrode in the clustered response arrays.
    idx = np.array([np.where(clustered_se == se)[0][0] for se in ordered_se])
    print(len(idx))

    for col, (erp_key, t) in enumerate(column_specs):
        a = ax[cur_c, col]
        im = a.pcolormesh(t, range(len(idx)), cluster_erp_trials[erp_key].T[idx, :],
                          vmin=vmin, vmax=vmax, cmap=cur_cmap)

        a.tick_params(axis='x', direction='in', width=1, length=5)
        a.axes.set(yticks=[], xticks=np.arange(t[0], t[-1]+0.03, 0.25), xticklabels=[])
        a.axvline(x=0, color='k', linewidth=1, alpha=0.75)

        # Boxed panels: show and thicken every spine.
        for spine in a.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1)

    ax[cur_c, 0].axes.set(ylabel=f'Cluster {cur_c+1}\nelectrodes')

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.25, 0.025, 0.5])
fig.colorbar(im, cax=cbar_ax, label='HGA (z-score)')

for a, alignment in zip(ax[-1, :], ['target_presentation', 'go_cue', 'pre_exec']):
    a.axes.set(xlabel=f'Time relative to\n{time_periods[alignment]} (s)')

# 2.5 s in the first column: end of the [0, 2.5] s target-presentation window.
for a in ax[:, 0]:
    a.axvline(x=2.5, color='k', linewidth=1, alpha=0.75)

# Label only every other 0.25 s tick on the two long panels.
ax[-1, 0].axes.set(xticklabels=[tick if i % 2 == 0 else '' for i, tick in enumerate(np.arange(-0.5, 3.3, 0.25))])
ax[-1, 1].axes.set(xticklabels=[-0.25, 0, 0.25])
ax[-1, 2].axes.set(xticklabels=[tick if i % 2 == 0 else '' for i, tick in enumerate(np.arange(-0.5, 2.1, 0.25))]);

## Fig. S7

In [None]:
def _plot_region_outline_bars(ax, y, ylabel, location_order, fs):
    """Unfilled bars of `total_elec_per_region[y]` per area (in `location_order`), value-labelled."""
    ax = sns.barplot(data=total_elec_per_region, x='fancy_location', y=y,
                     order=location_order, color='none', width=0.7, fill=False, ax=ax)

    yticks = np.arange(0, 251, 50)
    ax.axes.set(ylim=(yticks[0], yticks[-1]), yticks=yticks, xlabel='')
    ax.set_ylabel(ylabel, fontsize=fs)
    ax.set_xticklabels(location_order, rotation=90, fontsize=fs, ha='center')
    ax.set_yticklabels(yticks, fontsize=fs)
    ax.bar_label(ax.containers[-1])
    return ax


def figure1_cluster_counts_supp_roi(axs=None, metric='percent', fs=10, return_fig=False):
    """Three stacked panels of per-area electrode counts (Fig. S7).

    Areas are ordered by their total `count` in `cluster_count_df` (descending).
    Panel 0 shows `total_elec_per_region['total_electrodes']`.
    Panel 1 shows `total_elec_per_region['sig_above_baseline_anytime']`.
    Panel 2 shows `cluster_count_df[metric]` stacked by cluster ('equal_tp_motor',
    'higher_tp', 'higher_motor', 'equal'); the top segment's edge is labelled with the value.

    Parameters
    ----------
    axs : sequence of 3 Axes, optional
        Created as a shared-x 3 x 1 grid when omitted.
    metric : str
        Column of `cluster_count_df` used for the stacked bars (e.g. 'count').
    fs : int
        Font size of axis labels and tick labels.
    return_fig : bool
        Return `(fig, axs)` instead of `axs`.
    """
    fig = None
    if axs is None:
        fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)

    # Aggregate only the numeric `count` column (non-numeric columns can't/shouldn't be summed).
    area_totals = (cluster_count_df.groupby('fancy_location')[['count']].sum()
                   .sort_values('count', ascending=False))
    location_order = area_totals.index.values

    # plot the total sampling
    _plot_region_outline_bars(axs[0], 'total_electrodes', 'Number of\nrecording electrodes',
                              location_order, fs)

    # plot the number of significant electrodes
    _plot_region_outline_bars(axs[1], 'sig_above_baseline_anytime',
                              'Number of electrodes\nwith sig. task-evoked response\nduring any phase',
                              location_order, fs)

    # plot the cluster counts
    ax = axs[2]
    bottoms = np.zeros(len(location_order))
    for hue in ['equal_tp_motor', 'higher_tp', 'higher_motor', 'equal']:

        cur_df = cluster_count_df.loc[cluster_count_df.cluster_label == hue]

        ax = sns.barplot(data=cur_df, x='fancy_location', y=metric, bottom=bottoms,
                         order=location_order, color=cluster_name_colors[hue], width=0.7, ax=ax)

        # The next cluster is stacked on top of this one.
        bottoms += np.array([cur_df.loc[cur_df.fancy_location == loc][metric].values[0] for loc in location_order])

    yticks = np.arange(0, 76, 25)
    ax.axes.set(ylim=(yticks[0], yticks[-1]), yticks=yticks, xlabel='')
    ax.set_ylabel('Number of\nsustained electrodes', fontsize=fs)
    ax.set_xticklabels(location_order, rotation=90, fontsize=fs, ha='center')
    ax.set_yticklabels(yticks, fontsize=fs)
    # 'edge' labels show the end point of the top segment, i.e. the stacked total.
    ax.bar_label(ax.containers[-1])

    if not return_fig:
        return axs
    if fig is None:
        # Axes were supplied by the caller: recover their figure.
        fig = axs[0].figure
    return fig, axs

In [None]:
# Fig. S7: per-area sampling, significant electrodes and cluster counts.
default_plot_settings(fontsize=12)
fig, _ = figure1_cluster_counts_supp_roi(return_fig=True, fs=15, metric='count')