# Fig. S14-15

In [None]:
import json
import os
import pickle
from itertools import combinations

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator
from scipy import stats
from sklearn.metrics import confusion_matrix
from statannot import add_stat_annotation
from statsmodels.stats.multitest import fdrcorrection

from sylseq_paper.plotting import default_plot_settings, ucsf_sequential_color_palette as colors

# Apply the project's shared plotting defaults (Helvetica, 14 pt) for the cells below
default_plot_settings(font='Helvetica', fontsize=14)

import warnings

# NumPy 1.25 introduced numpy.exceptions.VisibleDeprecationWarning and NumPy 2.0 removed the
# top-level ``np.VisibleDeprecationWarning`` alias, so resolve the class from whichever
# location the installed NumPy provides.
try:
    from numpy.exceptions import VisibleDeprecationWarning
except ImportError:  # NumPy < 1.25
    VisibleDeprecationWarning = np.VisibleDeprecationWarning

# Silence NumPy's VisibleDeprecationWarning for the rest of the session
warnings.filterwarnings("ignore", category=VisibleDeprecationWarning)

# IPython magics (not plain Python): load the autoreload extension and set it to mode 2
%load_ext autoreload
%autoreload 2

In [None]:
# Set up paths: data and imaging folders are siblings of the current working directory
data_dir, img_dir = (
    os.path.abspath(os.path.join(os.getcwd(), '..', folder)) for folder in ('data', 'imaging')
)
# Template for the MNI brain images; the two fields are filled in with str.format by the plotting code
mni_img_path = os.path.join(img_dir, 'MNI_{}_{}_brain_2D.png')

# Significance threshold used as alpha for the FDR correction
sig_thresh = 0.05

# Subject IDs are the numeric codes prefixed with 'EC'
plot_subjects = [
    f'EC{subject_num}'
    for subject_num in (217, 219, 223, 237, 240, 241, 253, 254, 260, 263, 267, 276)
]

# Anatomical info table; the recon plot reads its subject, electrode, warp_x and warp_y columns
all_edf = pd.read_hdf(os.path.join(data_dir, 'all_anatomical_info.h5'))

## Fig. S14 - Unique sequence decoding

### ERPs for unique sequence

In [None]:
# The three target sequences, keyed '1'-'3'; every condition is tagged as complex/complex
sequence_conditions = {
    str(cond_num): {'sequence_type': 'complex', 'syllable_type': 'complex', 'target_sequence': target}
    for cond_num, target in enumerate(['blaa-draa-gloo', 'blaa-gloo-draa', 'draa-gloo-blaa'], start=1)
}
seq_colors = np.copy(colors)

# Example electrodes: seq_subjects and seq_elecs are paired by position
seq_subjects = ['EC260', 'EC253']
seq_elecs = [60, 7]

# Time windows (s) for each ERP alignment. Key order matters: it sets the column order of the ERP panels.
alignments = dict(
    # The window runs 1 s past the 2.5 s mark, which the ERP panel labels as the fixation cross
    tp_and_delay=dict(window=[-0.5, 2.5 + 1], phase='target_presentation', phase_letter='p'),
    speech=dict(window=[-0.5, 1.0], phase='first_syllable', phase_letter='e'),
)

# Periods shown (in this order) on the decoding-accuracy plot
decode_periods = ['mid_target_presentation', 'fixation_cue', 'pre_execution', 'during_execution']

# ERP trials, indexed as seq_complexity_trials[subject][alignment][condition] by the ERP panel
with open(os.path.join(data_dir, 'suppfig_unique_sequence_erps.pkl'), 'rb') as f:
    seq_complexity_trials = pickle.load(f)

# Load accuracy df
bacc = pd.read_hdf(os.path.join(data_dir, 'suppfig_unique_sequence_accuracy.h5'))

In [None]:
def sequence_erps(axs=None, legend_bbox=(0.125, -0.08), annotate_elecs=True, return_fig=False, fs=11, legend_fs=10):
    """Plot condition-averaged ERPs for the example electrodes, one pair of panels per electrode.

    For each subject in ``seq_subjects`` two adjacent axes are filled: one aligned to
    target presentation (and the following delay) and one aligned to speech onset,
    following the key order of ``alignments``. Each condition in
    ``seq_complexity_trials[subject][alignment]`` is drawn as its mean over axis 0 with a
    shaded standard-error band (presumably axis 0 is trials and axis 1 is time -- confirm
    against the pickled data).

    Parameters
    ----------
    axs : sequence of Axes, optional
        Four axes ordered (electrode 1 presentation, electrode 1 speech, electrode 2
        presentation, electrode 2 speech). Created on a new figure when omitted.
    legend_bbox : tuple, optional
        Currently unused: the legend is anchored at a fixed position on ``axs[3]``.
        Kept so existing callers do not break.
    annotate_elecs : bool, optional
        Draw the circled ``e1``/``e2`` markers on each electrode's first panel.
    return_fig : bool, optional
        Return ``(fig, axs)`` instead of ``axs``.
    fs : int, optional
        Font size for axis labels, tick labels and annotations.
    legend_fs : int, optional
        Font size for the legend.

    Returns
    -------
    axs, or ``(fig, axs)`` when ``return_fig`` is true.
    """
    if axs is None:
        fig, axs = plt.subplots(1, 4, figsize=(10, 2), gridspec_kw=dict(hspace=0.2, width_ratios=[3, 1.75, 3, 1.75]))
    else:
        # Recover the parent figure so that return_fig=True also works with caller-supplied
        # axes (previously ``fig`` was unbound on this path and raised NameError).
        fig = axs[0].figure

    erp_plot_kwargs = {
        'ylim': [-1, 6],
    }
    erp_axline_kwargs = {
        'linestyle': '--',
        'linewidth': 1,
        'color': 'k'
    }
    y_top = erp_plot_kwargs['ylim'][1]

    # The ERP data is indexed by subject only; seq_elecs is zipped purely to keep the pairing length
    for r, (cur_subject, _cur_elec) in enumerate(zip(seq_subjects, seq_elecs)):

        axs[r*2].set_ylabel('HGA (z-score)', fontsize=fs)

        for col, alignment_label in enumerate(alignments.keys()):

            window = alignments[alignment_label]['window']
            ax = axs[2 * r + col]

            # Dashed line at the alignment event (t = 0)
            ax.axvline(x=0, alpha=0.7, zorder=1, **erp_axline_kwargs)

            if alignment_label == 'tp_and_delay':

                # Second event marker at 2.5 s, labelled as the fixation cross below
                ax.axvline(x=2.5, alpha=0.7, zorder=1, **erp_axline_kwargs)

                # Event labels are only written above the first electrode's panels
                if r == 0:
                    ax.annotate('Target\npres.', (0.0, y_top), ha='center', va='bottom', fontsize=fs)
                    ax.annotate('Fixation\ncross', (2.5, y_top), ha='center', va='bottom', fontsize=fs)

            elif alignment_label == 'speech' and r == 0:

                ax.annotate('Speech\nonset', (0.0, y_top), ha='center', va='bottom', fontsize=fs)

            # Only the first panel of each pair keeps its left spine and y axis
            sns.despine(ax=ax, offset=dict(bottom=5, left=5), left=(col != 0))
            if col != 0:
                ax.get_yaxis().set_visible(False)

            ax.axhline(y=0, alpha=0.7, zorder=1, **erp_axline_kwargs)

            # Set major xticks at 1, minor at 0.5
            ax.xaxis.set_major_locator(MultipleLocator(1))
            ax.xaxis.set_minor_locator(MultipleLocator(0.5))
            ax.yaxis.set_major_locator(MultipleLocator(2))
            ax.yaxis.set_minor_locator(MultipleLocator(1))

            for cur_cond, (cond_label, ecog_trials) in enumerate(seq_complexity_trials[cur_subject][alignment_label].items()):

                y = np.nanmean(ecog_trials, axis=0)
                err = stats.sem(ecog_trials, axis=0, nan_policy='omit')
                # Samples are assumed to span the alignment window evenly
                x = np.linspace(window[0], window[1], y.shape[0])
                ax.plot(x, y, color=seq_colors[cur_cond], clip_on=False, label=sequence_conditions[cond_label]['target_sequence'], zorder=2)
                ax.fill_between(x, y - err, y + err, color=seq_colors[cur_cond], alpha=0.2, clip_on=False, zorder=2)

            ax.set(xlim=window, **erp_plot_kwargs)

            ax.set_xlabel('Time (s)', fontsize=fs)
            ax.xaxis.set_tick_params(labelsize=fs)
            ax.yaxis.set_tick_params(labelsize=fs)

    an_kw = dict(fontsize=fs, annotation_clip=False)

    if annotate_elecs:
        # Circled "e1", "e2", ... marker in the upper-left of each electrode's first panel
        for i, _ in enumerate(seq_elecs):
            axs[i*2].scatter([-0.3], [5.25], s=300, facecolor='white', edgecolor='#404040', marker='o', clip_on=False)
            axs[i*2].annotate(f'e{i+1}', (-0.3, 5.25), va='center', ha='center', color='#404040', **an_kw)

    # A single legend is attached to the last panel
    axs[3].legend(loc='upper right', bbox_to_anchor=(1, 1.23), frameon=False,
                  fontsize=legend_fs, ncol=3, columnspacing=0.75, handlelength=1)

    if return_fig:
        return fig, axs
    else:
        return axs

In [None]:
# Preview the ERP panels on their own figure (no axes passed, so the function creates them)
sequence_erps();

### ERP recon

In [None]:
def erp_recon(ax=None):
    """Mark the example ERP electrodes on the left-hemisphere lateral MNI brain image.

    Each (subject, electrode) pair from ``seq_subjects`` / ``seq_elecs`` is looked up
    in ``all_edf`` and drawn at its ``warp_x`` / ``warp_y`` position with an
    ``e1``, ``e2``, ... label.

    Parameters
    ----------
    ax : Axes, optional
        Axes to draw on; a new 5x5 inch figure is created when omitted.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    # Semi-transparent brain image as the background
    brain_img = plt.imread(mni_img_path.format('lh', 'lateral'))
    ax.imshow(brain_img, alpha=0.5)

    for label_num, (subject, electrode) in enumerate(zip(seq_subjects, seq_elecs), start=1):
        subject_rows = all_edf.loc[all_edf.subject == subject]
        electrode_rows = subject_rows.loc[subject_rows.electrode == electrode]

        # First matching row gives the electrode position in image coordinates
        elec_x = electrode_rows.warp_x.values[0]
        elec_y = electrode_rows.warp_y.values[0]

        ax.scatter(elec_x, elec_y, color='k', s=20)
        # Label is anchored at y - 10 relative to the marker
        ax.annotate(f'e{label_num}', (elec_x, elec_y - 10))

    ax.axis('off')

    return ax

In [None]:
# Preview the electrode recon on its own figure (no axes passed, so the function creates one)
erp_recon();

### Logistic regression decoding of unique sequence

In [None]:
def log_reg_accuracy(ax=None, fs=10):
    """Plot per-subject decoding accuracy of the unique sequence across task periods.

    Uses the rows of ``bacc`` whose ``area`` is ``'all_sig'`` and draws one grey
    line per subject over the periods in ``decode_periods``, with a dashed line at
    the 1/3 chance level.

    Parameters
    ----------
    ax : Axes, optional
        Axes to draw on; a new 10x4 inch figure is created when omitted.
    fs : int, optional
        Font size for tick labels, the y label and the chance annotations.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    grey = '#404040'
    all_sig_acc = bacc.loc[bacc.area == 'all_sig']
    # Every subject gets the same grey; hue only separates the lines
    subject_palette = [grey] * len(all_sig_acc.subject.unique())

    ax = sns.pointplot(data=all_sig_acc, x='period', y='accuracy', hue='subject', order=decode_periods,
                       scale=0.75, errorbar=None, palette=subject_palette, ax=ax)
    ax.legend_.remove()
    # Let markers and lines draw past the axes box
    for artists in (ax.collections, ax.lines):
        plt.setp(artists, clip_on=False)
    ax.set(ylim=(0.0, 1.0), xlabel='', yticks=[0.0, 0.5, 1.0])

    # Chance level
    ax.axhline(y=1/3, color='k', linestyle='--', linewidth=1)

    # NOTE(review): two "Chance" labels are drawn -- one anchored to the right edge of the
    # figure, one to the right edge of the axes. This looks redundant; confirm which is intended.
    chance_kw = dict(va='top', ha='right', color=grey, fontsize=fs)
    ax.annotate('Chance', (1, 1/3), xycoords=('figure fraction', 'data'), annotation_clip=False, **chance_kw)
    ax.annotate('Chance', (1, 0.31), xycoords=('axes fraction', 'data'), **chance_kw)

    # despine re-shows the bottom spine, so it must be hidden afterwards
    sns.despine(ax=ax, offset=dict(bottom=5, left=5), left=False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='x', length=0)

    ax.set_xticklabels(['Encoding', 'Delay', 'Pre-speech', 'Speech'], fontsize=fs)
    ax.set_yticklabels(ax.get_yticks(), fontsize=fs)
    ax.set_ylabel('Average decoding accuracy', fontsize=fs)

    return ax

In [None]:
# Preview the decoding-accuracy plot on its own figure (no axes passed, so the function creates one)
log_reg_accuracy();

### Overall figure - unique sequence decoding

In [None]:
##### Figure setup #####
# Switch the shared defaults to 1 pt lines for the composite figure
default_plot_settings(font='Helvetica', linewidth=1)

fig = plt.figure(figsize=(10, 6))
fs = 10
all_axs = {}

# 100 x 100 grid so that panels can be placed by percentage of the figure
gs = mpl.gridspec.GridSpec(100, 100, figure=fig)

# Top third: two electrode blocks side by side, each split into a wide and a narrow ERP panel
outer_erp_gs = mpl.gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[:33, :90], wspace=0.2)
erp1_gs, erp2_gs = (
    mpl.gridspec.GridSpecFromSubplotSpec(1, 2,
                                         subplot_spec=outer_erp_gs[0, block],
                                         width_ratios=[2.5, 1.5],
                                         wspace=0.1)
    for block in range(2)
)

########################

# Plot ERPs: axes are created electrode block by electrode block, left to right
erp_axes = [
    fig.add_subplot(block_gs[c])
    for block_gs in (erp1_gs, erp2_gs)
    for c in range(block_gs.get_geometry()[1])
]
all_axs['erps'] = sequence_erps(axs=erp_axes)

# Plot recon
all_axs['recon'] = erp_recon(ax=fig.add_subplot(gs[50:, :35]))

# Plot decoding results
all_axs['decoding'] = log_reg_accuracy(ax=fig.add_subplot(gs[50:, 46:95]))

########################

# Panel letters: each sits on a small invisible axes positioned on the grid
panel_kw = dict(ha='right', fontsize=14, weight='bold', color='#404040')
panel_slots = {
    'a': gs[2:7, 2:10],
    'b': gs[56:61, 4:14],
    'c': gs[56:61, 45:55],
}
panel_axs = {panel: fig.add_subplot(slot) for panel, slot in panel_slots.items()}

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (-35, 35), xycoords='axes points', **panel_kw)

## Fig. S15 - Temporal cross correlations for target presentation, delay, and execution

In [None]:
# Periods with a correlation matrix panel, in panel order
corr_periods = [
    'target_presentation_condition',
    'delay_condition',
    'warp_execution_condition',
    'warp_execution_akt',
    'warp_execution_akt_75jitter',
]
# Periods compared in the correlation-distribution plot, in x-axis order
corr_dist_compare_periods = ['delay_condition', 'warp_execution_akt', 'warp_execution_akt_75jitter']

# Time window (s) spanned by each period's matrix; all execution-aligned periods share one window
corr_windows = {
    'target_presentation_condition': [-0.25, 2.5],
    'delay_condition': [-0.25, 0.75],
    **{period: [-0.5, 1.5] for period in corr_periods if period.startswith('warp_execution')},
}
# Upper colour limit for the correlation matrices and their colorbar
corr_lim = 0.4

# Load cross correlation results
corr_dist_compare = pd.read_hdf(os.path.join(data_dir, 'suppfig_cross_corr_compare.h5'))

with open(os.path.join(data_dir, 'suppfig_cross_corr_mats.pkl'), 'rb') as f:
    avg_plot_mats = pickle.load(f)

### Correlation distribution plot

In [None]:
def correlation_distribution(ax=None, fs=10):
    """Plot per-subject ``mean_90_corr`` across periods with pairwise significance stars.

    One grey line per subject is drawn over ``corr_dist_compare_periods``. Every pair
    of periods present in ``corr_dist_compare`` is compared with a Wilcoxon
    signed-rank test; the p-values are FDR-corrected (statsmodels ``fdrcorrection``,
    method ``'poscorr'``, alpha ``sig_thresh``) and drawn above the plot.

    Samples are paired by sorting each period's rows on ``subject``; this assumes
    every subject has exactly one row per period -- TODO confirm against the data.

    Parameters
    ----------
    ax : Axes, optional
        Axes to draw on; a new 5x5 inch figure is created when omitted.
    fs : int, optional
        Font size for tick labels and the y label.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(5, 5))

    ax = sns.pointplot(data=corr_dist_compare, x='period', y='mean_90_corr', hue='subject',
                       palette=['#404040' for _ in corr_dist_compare.subject.unique()],
                       order=corr_dist_compare_periods, errorbar=None, ax=ax, scale=0.75)
    ax.legend_.remove()
    # Let markers and lines draw past the axes box
    plt.setp(ax.collections, clip_on=False)
    plt.setp(ax.lines, clip_on=False)
    ax.set(ylim=(0.2, 0.70), xlabel='', yticks=[0.20, 0.45, 0.70])

    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='x', length=0)

    # Subject-sorted values per period, computed once instead of once per pair
    values_by_period = {
        period: corr_dist_compare.loc[corr_dist_compare.period == period]
                                 .sort_values('subject').mean_90_corr.values
        for period in corr_dist_compare.period.unique()
    }

    # Signed-rank stats test
    box_pairs, box_pair_pvals = [], []
    for c1, c2 in combinations(values_by_period, 2):
        stat, pval = stats.wilcoxon(values_by_period[c1], values_by_period[c2])
        print((c1, c2), stat)
        box_pairs.append((c1, c2))
        box_pair_pvals.append(pval)

    # FDR correction
    _, box_pair_pvals = fdrcorrection(box_pair_pvals, alpha=sig_thresh, method='poscorr', is_sorted=False)

    add_stat_annotation(ax, data=corr_dist_compare, x='period', y='mean_90_corr', order=corr_dist_compare_periods,
                        perform_stat_test=False, pvalues=box_pair_pvals, box_pairs=box_pairs, color='#404040', linewidth=1,
                        text_format='star', loc='outside', fontsize='smaller', verbose=True, line_offset=0.01)

    ax.set_xticklabels(['Delay', 'Speech\n(AKT)', 'Jittered\nspeech (AKT)'], fontsize=fs)
    # Resize the y tick labels without rewriting their text: on Matplotlib versions where
    # tick-label strings are only populated at draw time, feeding get_yticklabels() back
    # into set_yticklabels() would replace the labels with empty strings.
    ax.tick_params(axis='y', labelsize=fs)
    # The backslash is escaped so that the mathtext command is not an invalid escape sequence
    ax.set_ylabel('Mean of $\\mathregular{90^{th}}$ percentile\ncorrelation distribution', fontsize=fs)

    ax.set_xlim(-0.25, None)

    return ax

In [None]:
# Preview the correlation-distribution plot on its own figure (no axes passed, so the function creates one)
correlation_distribution();

### Correlation matrices

In [None]:
def correlation_matrices(axs=None, plot_linewidth=1, fig=None, fs=10):
    """Draw one correlation matrix per period in ``corr_periods``.

    Axes are consumed in ``axs.ravel()`` order and paired with ``corr_periods``;
    any extra axes are left untouched. Each matrix from ``avg_plot_mats`` is drawn
    on a time grid spanning ``corr_windows[period]`` on both axes, colour-scaled
    from 0 to ``corr_lim``, with lines marking time zero.

    Parameters
    ----------
    axs : numpy array of Axes, optional
        Axes to draw on (must support ``.ravel()``); a 2x3 grid on a new 10x10 inch
        figure is created when omitted.
    plot_linewidth : float, optional
        Line width of the time-zero marker lines.
    fig : Figure, optional
        Not used by the function; kept so existing callers do not break.
    fs : int, optional
        Font size for tick labels and x labels.

    Returns
    -------
    The array of axes.

    Raises
    ------
    ValueError
        If a period name matches none of the known alignment labels.
    """
    if axs is None:
        fig, axs = plt.subplots(2, 3, figsize=(10, 10))

    # Same colormap for every panel, so build it once
    cmap = sns.color_palette('flare_r', as_cmap=True)

    for ax, cur_period in zip(axs.ravel(), corr_periods):

        # Event the time axis is aligned to. 'jitter' is checked before 'execution'
        # because jittered period names also contain 'execution'.
        if cur_period.startswith('target_presentation'):
            time_label = 'target pres.'
        elif cur_period.startswith('delay'):
            time_label = 'fixation cue'
        elif 'jitter' in cur_period:
            time_label = 'jittered speech onset'
        elif 'execution' in cur_period:
            time_label = 'speech onset'
        else:
            # Previously this fell through and either raised NameError or silently
            # reused the label of the preceding panel.
            raise ValueError(f'No time-axis label defined for period {cur_period!r}')

        plot_mat = avg_plot_mats[cur_period]

        # One time point per matrix row/column, evenly spaced over the period's window
        x = np.linspace(*corr_windows[cur_period], plot_mat.shape[0])
        ax.pcolormesh(x, x, plot_mat, vmin=0, vmax=corr_lim, cmap=cmap)
        ax.axvline(x=0, color='k', linewidth=plot_linewidth)
        ax.axhline(y=0, color='k', linewidth=plot_linewidth)

        # Square panel with a full box around it
        ax.set_aspect('equal', adjustable='box')
        ax.spines['top'].set_visible(True)
        ax.spines['right'].set_visible(True)

        # Set major ticks at 0.5, minor at 0.25 on both axes
        ax.xaxis.set_major_locator(MultipleLocator(0.5))
        ax.xaxis.set_minor_locator(MultipleLocator(0.25))
        ax.yaxis.set_major_locator(MultipleLocator(0.5))
        ax.yaxis.set_minor_locator(MultipleLocator(0.25))

        # Set tick label fontsize
        ax.tick_params(axis='both', labelsize=fs, pad=1)

        ax.set(xlim=(x[0], x[-1]), ylim=(x[0], x[-1]))

        # Only the x axis gets a label
        ax.set_xlabel(f'Time from {time_label} (s)', fontsize=fs, labelpad=5)

    return axs

### Overall figure - cross-trial correlations

In [None]:
##### Figure setup #####
# Switch the shared defaults to 1 pt lines for the composite figure
default_plot_settings(font='Helvetica', linewidth=1)

fig = plt.figure(figsize=(10, 6))
fs = 10
all_axs = {}

# 100 x 100 grid so that panels can be placed by percentage of the figure
gs = mpl.gridspec.GridSpec(100, 100, figure=fig)
corr_gs = mpl.gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[:, :89], wspace=0.2, hspace=0.4)

########################

# Cross-correlations: build the 2 x 3 axes array row by row
n_rows, n_cols = corr_gs.get_geometry()
cc_axes = np.array([
    [fig.add_subplot(corr_gs[row, col]) for col in range(n_cols)]
    for row in range(n_rows)
])
all_axs['cc'] = correlation_matrices(axs=cc_axes)
# The last grid cell is not used for a matrix, so hide it
all_axs['cc'][-1, -1].axis('off')

# Cross-correlation colorbar, built from the same colormap name and limits as the matrices
all_axs['cc_cbar'] = fig.add_subplot(gs[:42, 89:91])
cmap = sns.color_palette('flare_r', as_cmap=True)
norm = mpl.colors.Normalize(vmin=0, vmax=corr_lim)
cb1 = mpl.colorbar.ColorbarBase(all_axs['cc_cbar'], cmap=cmap, norm=norm, orientation='vertical')
cb1.set_label('Correlation coefficient', fontsize=10)
cb1.ax.tick_params(labelsize=10)

# Cross-correlation distribution, placed in the lower-right corner of the grid
all_axs['cc_dist'] = correlation_distribution(ax=fig.add_subplot(gs[67:, 70:95]))

########################

panel_kw = dict(ha='right', fontsize=14, weight='bold', color='#404040')

# Panel 'f' sits on a small invisible axes positioned on the grid
panel_axs = {'f': fig.add_subplot(gs[71:76, 68:78])}

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (-35, 65), xycoords='axes points', **panel_kw)

# Panels 'a'-'e' are labelled relative to the first five matrix axes
for panel, ax in zip('abcde', all_axs['cc'].ravel()):
    ax.annotate(panel, (-0.15, 1.05), xycoords='axes fraction', **panel_kw)