# Figure 3

In [None]:
import json
import os
import copy
import pickle
from IPython import display as ipd

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.legend_handler import HandlerTuple
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch
from matplotlib import patches as mpatches
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator
from scipy import stats
from statannot import add_stat_annotation

from sylseq_paper.plotting import default_plot_settings, ucsf_sequential_color_palette as colors, smoothed_weighted_histogram
from sylseq_paper.statistics import correlation_permutation

default_plot_settings(font='Helvetica', fontsize=14)

import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 

%load_ext autoreload
%autoreload 2

In [None]:
# Set up paths: both folders sit in the parent of the current working directory
data_dir, img_dir = (os.path.abspath(os.path.join(os.getcwd(), '..', folder)) for folder in ('data', 'imaging'))

# Boxplot artist properties: nothing is clipped to the axes, and boxes are semi-transparent
boxplot_kwargs = {f'{part}props': dict(clip_on=False) for part in ('cap', 'whisker', 'flier', 'median', 'mean')}
boxplot_kwargs['boxprops'] = dict(alpha=0.75, clip_on=False)

# (upper p-value bound, label) pairs for the significance-star annotations
pvalue_thresholds = [[0.001, "***"], [0.01, "**"], [0.05, "*"], [10, "ns"]]

akt_r_cutoff = 0.1  # electrodes with AKT_r below this are excluded from the f0-weight comparison
annot_lw = 0.5      # line width passed to add_stat_annotation

# Target sequences grouped by task, for the sequence- and syllable-complexity comparisons
task_target_sequence_compare = dict(
    simpleSeq_simpleSyl=['baa-baa-baa', 'daa-daa-daa', 'gaa-gaa-gaa'],
    complexSeq_simpleSyl=['baa-daa-gaa', 'daa-baa-gaa', 'gaa-daa-baa', 'baa-gaa-daa', 'daa-gaa-baa', 'gaa-baa-daa'],
)
task_target_syllable_compare = dict(
    complexSeq_simpleSyl=['baa-daa-gaa', 'daa-baa-gaa', 'gaa-daa-baa', 'baa-gaa-daa', 'daa-gaa-baa', 'gaa-baa-daa'],
    complexSeq_complexSyl=['blaa-draa-gloo', 'draa-blaa-gloo', 'gloo-draa-blaa', 'blaa-gloo-draa', 'draa-gloo-blaa', 'gloo-blaa-draa'],
)

# Load bad channels.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
bad_channels_file = os.path.join(data_dir, 'bad_channels.pkl')
with open(bad_channels_file, mode='rb') as f:
    bad_channels = pickle.load(f)
    
# Subjects included in the figure panels
plot_subjects = [
    'EC217', 'EC219', 'EC223', 'EC237', 'EC240', 'EC241',
    'EC253', 'EC254', 'EC260', 'EC263', 'EC276',
]
# subject -> alignment period whose rows are removed for that subject
exclude_time_periods = dict(EC276='pre_exec')
all_subjects = list(plot_subjects)  # independent shallow copy of the subject list

exclude_areas = ['superiorparietal', 'inferiortemporal', 'paracentral', 'medial']

## Figure panels

### Significant electrodes (`all_sig_elecs`)

In [None]:
# Brain views to draw
plot_views = ['lateral']

# Threshold applied to the FDR-corrected p-values
sig_thresh = 0.05

# Image/coordinate path templates; they contain `{}` placeholders to be filled in per use
mni_img_path = os.path.join(img_dir, 'MNI_{}_{}_brain_2D.png')
mprcg_roi_path = os.path.join(img_dir, 'MNI_{}_{}_brain_2D_mprcg_roi.png')
warped_coord_paths = dict(
    lateral=os.path.join(img_dir, '{}_lateral_elecmat_2D_warped.npy'),
    medial=os.path.join(img_dir, '{}_medial_elecmat_2D_warped.npy'),
)

# Analysis periods in display order, with their axis labels
plot_periods = ['target_pres', 'delay', 'pre_exec']
period_labels = dict(target_pres='Encoding', delay='Delay', pre_exec='Pre-speech')

# For every top-level key of `bad_channels`: all inner channel arrays concatenated, de-duplicated and sorted
all_bc = {
    outer_key: np.unique(np.concatenate([*inner.values()]))
    for outer_key, inner in bad_channels.items()
}

# Load electrode and phase significant information
all_sig_elecs = pd.read_hdf(os.path.join(data_dir, 'fig3_electrode_information.h5'))

In [None]:
# Line color for each analysis period
period_colors = dict(target_pres='C0', delay='orange', pre_exec='green')

# Color range (and colorbar ticks every 25) for the significant-electrode reconstructions
sig_elec_vmin, sig_elec_vmax = 25, 75
recon_cbar_ticks = np.arange(sig_elec_vmin, sig_elec_vmax + 1, 25)

# One white -> period-color colormap (100 levels) per period
period_cmaps = {
    period: mpl.colors.LinearSegmentedColormap.from_list('custom', ['white', color], N=100)
    for period, color in period_colors.items()
}

### Load `seq_vs_syl_df`

In [None]:
# Load sequence and articulatory (syl) complexity df
seq_vs_syl_df = pd.read_hdf(os.path.join(data_dir, 'fig3_seq_vs_syl_df.h5'))
seq_syl_points = pd.read_hdf(os.path.join(data_dir, 'fig3_seq_syl_points.h5'))

seq_vs_syl_df = seq_vs_syl_df.loc[seq_vs_syl_df.subject.isin(plot_subjects)]

# Exclude EC276 pre-speech bc of long reaction times.
# Filter with a boolean mask instead of `drop(<index labels>, inplace=True)`: dropping by label also
# removes unrelated rows whenever the loaded index has duplicate labels, and mutating a `.loc`
# slice in place can trigger SettingWithCopyWarning.
for subject, period in exclude_time_periods.items():
    excluded = (seq_vs_syl_df.subject == subject) & (seq_vs_syl_df.alignment == period)
    seq_vs_syl_df = seq_vs_syl_df.loc[~excluded]

density_cmap = mpl.colors.LinearSegmentedColormap.from_list('custom', [colors[3], '#3dc2ff'], N=100)

# Color for each complexity-encoding category, plus a white -> color colormap per category
seq_syl_encoding_colors = dict(syllable=colors[3], sequence=colors[0], both=colors[2])
seq_syl_encoding_cmaps = {
    area: mpl.colors.LinearSegmentedColormap.from_list(area, ["white", color_hex])
    for area, color_hex in seq_syl_encoding_colors.items()
}

# Display text for the complexity categories and feature types
new_seq_syl_labels = dict(
    sequence='Seq. complexity',
    syllable='Artic. complexity',
    both='Seq. & Artic. complexity',
    neither='Neither',
)
full_name_feature_labels = dict(sequence='Sequence', syllable='Articulatory')

# Add complexity category for all significant elecs.
# Category is looked up from (syllable contrast significant, sequence contrast significant).
_category_by_sig = {
    (True, True): 'both',
    (True, False): 'syllable',
    (False, True): 'sequence',
    (False, False): 'neither',
}

complexity_category, max_mag = [], []
for cur_subject, cur_elec, cur_period in zip(all_sig_elecs.subject.values,
                                             all_sig_elecs.electrode.values,
                                             all_sig_elecs.period.values):
    se_df = seq_vs_syl_df.loc[(seq_vs_syl_df.subject == cur_subject)
                              & (seq_vs_syl_df.electrode == cur_elec)
                              & (seq_vs_syl_df.alignment == cur_period)]

    if len(se_df) == 0:
        # No complexity statistics for this electrode/period
        complexity_category.append('none')
        max_mag.append(np.nan)
        continue

    syl_row = se_df.loc[se_df.feature_type == 'syllable']
    seq_row = se_df.loc[se_df.feature_type == 'sequence']
    syl_sig = bool(syl_row.fdr_pval.values[0] < sig_thresh)
    seq_sig = bool(seq_row.fdr_pval.values[0] < sig_thresh)

    complexity_category.append(_category_by_sig[(syl_sig, seq_sig)])
    # Magnitude is read from the syllable-contrast row
    max_mag.append(syl_row.max_complex_hga_magnitude.values[0])

all_sig_elecs['complexity_significance'] = complexity_category
all_sig_elecs['max_complex_hga_magnitude'] = max_mag

#### Load sustained-cluster electrodes

In [None]:
# Load sustained electrode list (the pickle holds a dict with a 'sustained_se' entry).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(data_dir, 'fig1_ndf_sustained_se.pkl'), mode='rb') as f:
    results = pickle.load(f)
sustained_se = results['sustained_se']

### Load auditory response and AKT model dataframes

In [None]:
# Auditory-response table and AKT model table, both read from HDF5 files in `data_dir`
adf, akt_model = (
    pd.read_hdf(os.path.join(data_dir, fname))
    for fname in ('auditory_responses.h5', 'seqsyl_AKT_model_table.h5')
)

### Sequence complexity ERP

In [None]:
# Legend text for the three complexity conditions; assumed to follow the same order as the
# condition keys in `seq_complexity_trials` (the ERP legend indexes it by that order).
seq_complexity_labels = [
    'Simple syl., simple seq.\n(baa-baa-baa)',
    'Simple syl., complex seq.\n(baa-daa-gaa)',
    'Complex syl., complex seq.\n(blaa-draa-gloo)',
]
complexity_colors = dict(zip(
    ['Repeated CV', 'CV sequence', 'CCV sequence'],
    [colors[4], colors[1], colors[5]],
))
# Preview swatches of the condition colors
sns.palplot(list(complexity_colors.values()))

In [None]:
# Trial-selection rules for the three complexity conditions
seq_complexity_conditions = {
    'Repeated CV': {'sequence_type': 'simple', 'syllable_type': 'simple'},
    'CV sequence': {'sequence_type': 'complex', 'syllable_type': 'simple'},
    'CCV sequence': {'sequence_type': 'complex', 'syllable_type': 'complex'}
}
seq_complexity_window = [-1, 1]
erp_sr = 200
# NOTE: `sig_thresh` is defined once, in the significant-electrode cell above. It used to be
# re-bound here as well, which let the two cells silently disagree if only one was edited.
seq_complexity_subjects = ['EC260', 'EC254', 'EC217']
seq_complexity_elecs = [60, 135, 7]

# "<subject>_<electrode>" identifiers, matching the `subject_electrode` columns used below
seq_complexity_se = [f'{s}_{e}' for s, e in zip(seq_complexity_subjects, seq_complexity_elecs)]

# Plotting windows (s) per ERP column. The 'tp_and_delay' window extends past the fixation cue,
# which is drawn 2.5 s after target presentation.
alignments = {
    'tp_and_delay': {
        'window': [-0.5, 2.5 + 1],
        'phase': 'target_presentation',
        'phase_letter': 'p'
    },
    'speech': {
        'window': [-0.5, 1.0],
        'phase': 'first_syllable',
        'phase_letter': 'e'
    }
}

# Analysis windows (s) relative to each period's phase event
sig_time_periods = {
    'target_pres': {
        'window': [0.5, 1.5],
        'phase' : 'target_presentation'
    },
    'delay'      : {
        'window': [0.0, 0.75],
        'phase' : 'fixation_cue'
    },
    'pre_exec'   : {
        'window': [-0.5, -0.1],
        'phase' : 'first_syllable'
    }
}

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(data_dir, 'fig3_erp_examples.pkl'), 'rb') as f:
    seq_complexity_trials = pickle.load(f)

In [None]:
def complexity_erps(axs=None, legend_bbox=(0.125, -0.08), annotate_elecs=True, return_fig=False, fs=12, elec_fs=18, plot_first_elec_only=False, s=500):
    """Plot condition-averaged HGA traces (mean +/- SEM) for the example electrodes.

    One row per example electrode (`seq_complexity_subjects` / `seq_complexity_elecs`) and one
    column per entry of `alignments` ('tp_and_delay', then 'speech'). Each trace averages the
    trials in `seq_complexity_trials[subject][alignment][condition]`. Grey bands mark the analysis
    windows in `sig_time_periods`; a thin colored strip at the top of each band shows the
    electrode's `complexity_significance` category for that period.

    Parameters
    ----------
    axs : 2D array of Axes, optional
        Grid indexed as ``axs[row, col]`` with at least 3 rows and 2 columns. A 3 x 2 grid is
        created when None.
    legend_bbox : tuple
        ``bbox_to_anchor`` of the shared legend drawn on ``axs[0, 0]``.
    annotate_elecs : bool
        Draw circled "e1", "e2", ... labels in the first column.
    return_fig : bool
        Also return the Figure that contains ``axs``.
    fs : int
        Legend font size.
    elec_fs : int
        Font size of the electrode labels.
    plot_first_elec_only : bool
        Only draw the first example electrode; the row-1 axes are switched off.
    s : float
        Marker size of the circles behind the electrode labels.

    Returns
    -------
    axs, or ``(fig, axs)`` when ``return_fig`` is True.
    """
    if axs is None:
        fig, axs = plt.subplots(3, 2, figsize=(10, 12), gridspec_kw=dict(hspace=0.25, width_ratios=[3, 1.75]))
    else:
        # Recover the parent figure so `return_fig=True` also works with caller-supplied axes
        # (previously `fig` was unbound here and the return raised UnboundLocalError).
        fig = np.ravel(axs)[0].figure

    erp_plot_kwargs = {
        'ylim': [-1, 4],
    }
    erp_axline_kwargs = {
        'linestyle': '--',
        'color': '#404040'
    }

    def _complexity_category(se, period):
        """Complexity category of electrode `se` in `period` (first matching row of `all_sig_elecs`)."""
        match = all_sig_elecs.loc[(all_sig_elecs.subject_electrode == se) &
                                  (all_sig_elecs.period == period)]
        return match.complexity_significance.values[0]

    for row, (cur_subject, cur_elec) in enumerate(zip(seq_complexity_subjects, seq_complexity_elecs)):

        if plot_first_elec_only and row != 0:
            continue

        axs[row, 0].axes.set(ylabel='HGA (z-score)')

        se = f'{cur_subject}_{cur_elec}'

        for col, alignment_label in enumerate(alignments.keys()):

            ax = axs[row, col]

            if alignment_label == 'tp_and_delay':

                # Find complexity effects. Delay windows are shifted by 2.5 s because this column
                # is aligned to target presentation and the fixation cross appears at 2.5 s.
                tp_x = sig_time_periods['target_pres']['window']
                delay_x = list(np.array(sig_time_periods['delay']['window']) + 2.5)
                tp_sig = _complexity_category(se, 'target_pres')
                delay_sig = _complexity_category(se, 'delay')
                ax.axvspan(tp_x[0], tp_x[1], 0.95, 1, ec='None',
                           color=seq_syl_encoding_colors[tp_sig], alpha=0.5)
                ax.axvspan(delay_x[0], delay_x[1], 0.95, 1, ec='None',
                           color=seq_syl_encoding_colors[delay_sig], alpha=0.5)

                ax.axvline(x=0, alpha=0.7, zorder=1, **erp_axline_kwargs)
                ax.axvline(x=2.5, alpha=0.7, zorder=1, **erp_axline_kwargs)

                if row == 0:
                    ax.annotate('Target\npres.', (0.0, erp_plot_kwargs['ylim'][1]), ha='center', va='bottom')
                    ax.annotate('Fixation\ncross', (2.5, erp_plot_kwargs['ylim'][1]), ha='center', va='bottom')
                    ax.annotate('Encoding', (sig_time_periods['target_pres']['window'][0], erp_plot_kwargs['ylim'][0]-0.1),
                                annotation_clip=False, ha='left', va='top')
                    ax.annotate('Delay', (sig_time_periods['delay']['window'][0] + 2.5, erp_plot_kwargs['ylim'][0]-0.1),
                                annotation_clip=False, ha='left', va='top')

                ax.axvspan(sig_time_periods['target_pres']['window'][0],
                           sig_time_periods['target_pres']['window'][1],
                           fc='k', alpha=0.1, zorder=0, ec='None')
                ax.axvspan(sig_time_periods['delay']['window'][0] + 2.5,
                           sig_time_periods['delay']['window'][1] + 2.5,
                           fc='k', alpha=0.1, zorder=0, ec='None')

            elif alignment_label == 'speech':

                # Find complexity effects
                pe_x = sig_time_periods['pre_exec']['window']
                pe_sig = _complexity_category(se, 'pre_exec')
                ax.axvspan(pe_x[0], pe_x[1], 0.95, 1, ec='None',
                           color=seq_syl_encoding_colors[pe_sig], alpha=0.5)

                ax.axvline(x=0, alpha=0.7, zorder=1, **erp_axline_kwargs)

                if row == 0:
                    ax.annotate('Speech\nonset', (0.0, erp_plot_kwargs['ylim'][1]), ha='center', va='bottom')
                    ax.annotate('Pre-speech', (sig_time_periods['pre_exec']['window'][0], erp_plot_kwargs['ylim'][0]-0.1),
                                annotation_clip=False, ha='left', va='top')

                ax.axvspan(sig_time_periods['pre_exec']['window'][0],
                           sig_time_periods['pre_exec']['window'][1],
                           fc='k', alpha=0.1, zorder=0, ec='None')

            # Larger bottom offset on the first row; only the first column keeps its left spine
            despine_kw = dict(offset=dict(bottom=8 if row == 0 else 2.5, left=2.5))
            if col != 0:
                despine_kw['left'] = True
            sns.despine(ax=ax, **despine_kw)

            if col != 0:
                ax.get_yaxis().set_visible(False)

            ax.axhline(y=0, alpha=0.7, zorder=1, **erp_axline_kwargs)

            # Major xticks every 1 s, minor every 0.5 s; major yticks every 2, minor every 1
            ax.xaxis.set_major_locator(MultipleLocator(1))
            ax.xaxis.set_minor_locator(MultipleLocator(0.5))
            ax.yaxis.set_major_locator(MultipleLocator(2))
            ax.yaxis.set_minor_locator(MultipleLocator(1))

            for cur_cond, (cond_label, ecog_trials) in enumerate(seq_complexity_trials[cur_subject][alignment_label].items()):

                # Mean and SEM across trials (axis 0), ignoring NaNs
                y = np.nanmean(ecog_trials, axis=0)
                err = stats.sem(ecog_trials, axis=0, nan_policy='omit')
                x = np.linspace(alignments[alignment_label]['window'][0], alignments[alignment_label]['window'][1], y.shape[0])
                ax.plot(x, y, color=complexity_colors[cond_label], clip_on=False, label=seq_complexity_labels[cur_cond], zorder=2)
                ax.fill_between(x, y - err, y + err, color=complexity_colors[cond_label], alpha=0.2, clip_on=False, zorder=2)

            ax.axes.set(xlim=alignments[alignment_label]['window'], **erp_plot_kwargs)

    an_kw = dict(fontsize=elec_fs, annotation_clip=False)

    if annotate_elecs:
        for i in range(len(seq_complexity_elecs)):
            axs[i, 0].scatter([-0.3], [3.25], s=s, facecolor='white', edgecolor='k', marker='o', clip_on=False)
            axs[i, 0].annotate(f'e{i+1}', (-0.3, 3.25), va='center', ha='center', **an_kw)

    axs[2, 0].axes.set(xlabel='Time (s)')
    axs[2, 1].axes.set(xlabel='Time (s)')

    if plot_first_elec_only:
        axs[0, 0].axes.set(xlabel='Time (s)')
        axs[0, 1].axes.set(xlabel='Time (s)')
        axs[1, 0].axis('off')
        axs[1, 1].axis('off')

    # Custom legend because it will span multiple panels. The condition labels are taken from the
    # last example subject's last alignment explicitly, instead of relying on loop variables
    # leaking out of the plotting loop above.
    legend_trials = seq_complexity_trials[seq_complexity_subjects[-1]][list(alignments)[-1]]
    handles, labels = [], []
    for cur_cond, cond_label in enumerate(legend_trials):
        handles.append((mpatches.Patch(color=complexity_colors[cond_label]), Line2D([0], [0], color=complexity_colors[cond_label])))
        labels.append(seq_complexity_labels[cur_cond])
    axs[0, 0].legend(handles=handles, labels=labels, handler_map={tuple: HandlerTuple(ndivide=None)},
                     loc='lower left', bbox_to_anchor=legend_bbox, frameon=False,
                     fontsize=fs, ncol=3)

    if return_fig:
        return fig, axs
    else:
        return axs

In [None]:
# Render all three example electrodes, with the shared legend placed above the top-left panel
f, _ = complexity_erps(
    legend_bbox=(0.125, 1.15),
    plot_first_elec_only=False,
    annotate_elecs=False,
    return_fig=True,
    fs=10,
);

### Single electrode encoding plot

In [None]:
def single_elec_stat_plot(axs=None, fs=12, source_stats=None, return_source_stats=False):
    """Box plots of time-averaged HGA for simple vs. complex points at each example electrode.

    Rows are the example electrodes (`seq_complexity_subjects` / `seq_complexity_elecs`); columns
    are the sequence contrast, then the articulatory ('syllable') contrast. Each panel shows one
    simple/complex box pair per period in `plot_periods`. Pairs whose FDR-corrected p-value
    (`seq_vs_syl_df.fdr_pval`) is below `sig_thresh` get a star annotation, and the number of
    points in each box is written under the x-axis in that box's color.

    Parameters
    ----------
    axs : 2D array of Axes, optional
        Grid indexed ``axs[row, col]`` with at least 3 rows x 2 columns; created when None.
    fs : int
        Font size for the counts, star annotations and x tick labels.
    source_stats : ignored
        Kept only for backward compatibility (the table is always rebuilt from scratch). The
        default is None instead of the previous shared mutable ``{}``.
    return_source_stats : bool
        Also return the per-panel statistics as a DataFrame with columns 'Subject', 'Comparison',
        'Sample sizes', 'Corrected P-value' and 'Statistic'.

    Returns
    -------
    axs, or ``(axs, stats_df)`` when ``return_source_stats`` is True.
    """
    if axs is None:
        _, axs = plt.subplots(3, 2, figsize=(5, 10))

    # The `source_stats` argument is intentionally not used: always start a fresh table.
    stats_table = {key: [] for key in ['Subject', 'Comparison', 'Sample sizes', 'Corrected P-value', 'Statistic']}

    # (simple, complex) box colors for each contrast
    feature_palettes = {
        'sequence': [complexity_colors[key] for key in ['Repeated CV', 'CV sequence']],
        'syllable': [complexity_colors[key] for key in ['CV sequence', 'CCV sequence']],
    }
    comparison_labels = {
        'sequence': 'Simple syl., simple seq. vs. Simple syl., complex seq.',
        'syllable': 'Simple syl., complex seq. vs. Complex syl., complex seq.',
    }

    for row, (subj, elec) in enumerate(zip(seq_complexity_subjects, seq_complexity_elecs)):

        # The third example electrode gets extra headroom
        ylim = (-1, 5) if row == 2 else (-1, 4)

        for col, feat_type in enumerate(['sequence', 'syllable']):

            ax = axs[row, col]
            se = f'{subj}_{elec}'
            palette = feature_palettes[feat_type]

            cur_df = seq_syl_points.loc[(seq_syl_points.subject_electrode == se) & (seq_syl_points.feature_type == feat_type)]

            ax = sns.boxplot(data=cur_df, x='alignment', y='point_value', hue='point_type',
                             order=plot_periods, hue_order=['simple', 'complex'],
                             palette=palette, width=0.5,
                             showfliers=False, ax=ax, **boxplot_kwargs)

            sns.despine(ax=ax, offset=dict(left=5, bottom=7))
            ax.spines['bottom'].set_visible(False)
            ax.xaxis.set_tick_params(length=0)

            box_pairs, box_pvals = [], []
            trial_labels = []
            for pp in plot_periods:

                cur_seq_df = seq_vs_syl_df.loc[(seq_vs_syl_df.alignment == pp) &
                                               (seq_vs_syl_df.subject_electrode == se) &
                                               (seq_vs_syl_df.feature_type == feat_type)]

                # [n_simple, n_complex] points in this period (presumably one point per trial)
                trial_labels.append([cur_df.loc[(cur_df.alignment == pp) & (cur_df.point_type == 'simple')].shape[0],
                                     cur_df.loc[(cur_df.alignment == pp) & (cur_df.point_type == 'complex')].shape[0]])

                pval = cur_seq_df.fdr_pval.values[0]
                z = cur_seq_df.statistic.values[0]

                stats_table['Subject'].append(subj)
                stats_table['Comparison'].append(comparison_labels[feat_type])
                stats_table['Sample sizes'].append(trial_labels[-1])
                stats_table['Corrected P-value'].append(pval)
                stats_table['Statistic'].append(z)

                if pval < sig_thresh:
                    box_pairs.append(((pp, 'simple'), (pp, 'complex')))
                    box_pvals.append(pval)

            # Write "n_simple, n_complex" just below the axis, each count in its box color
            for cur_per, label in enumerate(trial_labels):
                ax.annotate(f'{label[0]}', (cur_per-0.025, -1.18), fontsize=fs, color=palette[0],
                            annotation_clip=False, va='top', ha='right', xycoords='data')
                ax.annotate(',', (cur_per, -1.18), fontsize=fs, color='#404040',
                            annotation_clip=False, va='top', ha='center', xycoords='data')
                ax.annotate(f'{label[1]}', (cur_per+0.025, -1.18), fontsize=fs, color=palette[1],
                            annotation_clip=False, va='top', ha='left', xycoords='data')

            if len(box_pairs) != 0:
                # P-values are precomputed (FDR-corrected), so the annotation performs no test itself
                add_stat_annotation(data=cur_df, x='alignment', y='point_value', hue='point_type',
                                    order=plot_periods, hue_order=['simple', 'complex'],
                                    text_format='star', ax=ax, verbose=1, linewidth=annot_lw,
                                    box_pairs=box_pairs, perform_stat_test=False, pvalues=box_pvals,
                                    loc='outside', pvalue_format_string='{:.3f}', fontsize=fs, pvalue_thresholds=pvalue_thresholds)

            ax.axes.set(ylim=ylim, xlabel='', ylabel='',
                        xticklabels=[])

            # Major yticks every 2, minor every 1
            ax.yaxis.set_major_locator(MultipleLocator(2))
            ax.yaxis.set_minor_locator(MultipleLocator(1))

            if row == 2:
                ax.axes.set_xticklabels([period_labels[p] for p in plot_periods], fontsize=fs)

                if col == 0:
                    ax.axes.set(ylabel='Time-averaged HGA')

                # Fix legend: keep one PathCollection handle per label
                hand, labl = ax.get_legend_handles_labels()
                handout, lablout = [], []
                for h, l in zip(hand, labl):
                    if l not in lablout and type(h) is mpl.collections.PathCollection:
                        lablout.append(l)
                        handout.append(h)
                lablout = [f'{l.capitalize()}\n{feat_type}' for l in lablout]
                ax.legend(handout, lablout, frameon=False, bbox_to_anchor=(0.5, 1), loc='center', ncol=2)
            else:
                # seaborn may not have created a legend (e.g. no data in this panel)
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()

    stats_df = pd.DataFrame(data=stats_table)

    if return_source_stats:
        return axs, stats_df
    else:
        return axs

In [None]:
# Draw the example-electrode box plots and keep the per-panel statistics table in `ss`
_, ss = single_elec_stat_plot(
    return_source_stats=True,
);

### Sequence vs syllable complexity plots

In [None]:
def seq_syl_z_scatter(axs=None, horizontal=False, return_fig=False, fs=14, elec_fs=18, ms=20, label_s=500):
    """Scatter |z| of the articulatory (x) vs. sequence (y) complexity contrasts, one panel per period.

    Only electrodes in `sustained_se` with at least one contrast below `sig_thresh` (FDR-corrected)
    are drawn, colored by category (`seq_syl_encoding_colors`). The example electrodes
    `seq_complexity_se` are drawn larger with a black edge and linked to a circled "e1"/"e2"/"e3"
    label placed at hand-picked coordinates.

    Parameters
    ----------
    axs : sequence of Axes, optional
        One axis per period in `plot_periods`. Created when None (1 x 3 if `horizontal`, else 3 x 1).
    horizontal : bool
        Layout of the figure created when `axs` is None.
    return_fig : bool
        Also return the Figure that contains `axs`.
    fs : int
        Axis-label font size.
    elec_fs : int
        Font size of the example-electrode labels.
    ms : float
        Marker size of ordinary electrodes (example electrodes use ``2 * ms``).
    label_s : float
        Marker size of the circles behind the example-electrode labels.

    Returns
    -------
    axs, or ``(fig, axs)`` when ``return_fig`` is True.
    """
    # Remember whether we own the figure: `axs` is re-bound below, so testing `axs is None`
    # afterwards (as before) never triggered tight_layout.
    created_fig = axs is None
    if created_fig:
        if horizontal:
            fig, axs = plt.subplots(1, 3, figsize=(12, 3))
        else:
            fig, axs = plt.subplots(3, 1, figsize=(5, 10))
    else:
        # Recover the parent figure so `return_fig=True` works with caller-supplied axes
        fig = np.ravel(axs)[0].figure

    alim = [0, 10]
    ticks = [0, 2, 4, 6, 8, 10]

    # Hand-placed (x, y) label positions per example electrode, one entry per period in `plot_periods`
    elec_labels = {
        'e1': [[3, 9], [1, 9.5], [1, 9.5]],
        'e2': [[1, 9.5], [3, 9], [3, 9]],
        'e3': [[9, 7], [8, 2], [8, 2]]
    }

    # Short legend text for each category label attached to the points, in legend order.
    # Mapping by label (rather than listing texts positionally) keeps the legend correct no matter
    # which category happens to be encountered first in the data.
    legend_text = {
        new_seq_syl_labels['sequence']: 'Seq. only',
        new_seq_syl_labels['both']: 'Seq. & artic.',
        new_seq_syl_labels['syllable']: 'Artic. only',
    }

    for cur_period, (ax, alignment) in enumerate(zip(axs, plot_periods)):

        sns.despine(ax=ax, offset=dict(left=5, bottom=5))

        # Identity line: equal sequence and articulatory |z|
        ax.plot(alim, alim, linestyle='--', color='#404040', zorder=0, alpha=0.6)

        cur_df = seq_vs_syl_df.loc[(seq_vs_syl_df.alignment == alignment) &
                                   (seq_vs_syl_df.subject_electrode.isin(sustained_se))]

        first_seq, first_syl, first_both = True, True, True

        # Make sure the example electrodes are included, but draw each electrode only once
        # (examples already present in `cur_df` used to be plotted and labeled twice).
        plot_ses = list(dict.fromkeys(list(cur_df.subject_electrode.unique()) + list(seq_complexity_se)))

        for se in plot_ses:

            seq_pval = cur_df.loc[(cur_df.subject_electrode == se) & (cur_df.feature_type == 'sequence')].fdr_pval.values[0]
            syl_pval = cur_df.loc[(cur_df.subject_electrode == se) & (cur_df.feature_type == 'syllable')].fdr_pval.values[0]

            # Color by category; only the first electrode of each category carries a legend label
            if seq_pval < sig_thresh and syl_pval < sig_thresh:
                kwargs = {'fc': seq_syl_encoding_colors['both']}
                if first_both:
                    kwargs['label'] = new_seq_syl_labels['both']
                    first_both = False

            elif seq_pval < sig_thresh:
                kwargs = {'fc': seq_syl_encoding_colors['sequence']}
                if first_seq:
                    kwargs['label'] = new_seq_syl_labels['sequence']
                    first_seq = False

            elif syl_pval < sig_thresh:
                kwargs = {'fc': seq_syl_encoding_colors['syllable']}
                if first_syl:
                    kwargs['label'] = new_seq_syl_labels['syllable']
                    first_syl = False

            else:
                continue

            if se in seq_complexity_se:
                kwargs['ec'] = 'k'
                kwargs['s'] = 2*ms
                kwargs['alpha'] = 1.0
                kwargs['zorder'] = 2
            else:
                kwargs['s'] = ms
                kwargs['alpha'] = 0.75
                kwargs['zorder'] = 0

            a = np.abs(cur_df.loc[(cur_df.subject_electrode == se) & (cur_df.feature_type == 'syllable')].statistic.values[0])
            b = np.abs(cur_df.loc[(cur_df.subject_electrode == se) & (cur_df.feature_type == 'sequence')].statistic.values[0])
            ax.scatter(a, b, clip_on=False, **kwargs)

            if np.isnan(a) or np.isnan(b):
                print('NaN z-value for', se)

            if se in seq_complexity_se:
                # Label example elecs
                cur_se = list(seq_complexity_se).index(se)
                x, y = elec_labels[f'e{cur_se+1}'][cur_period]
                ax.plot([x, a], [y, b], color='k', linewidth=0.25, zorder=1)

                an_kw = dict(fontsize=elec_fs, annotation_clip=False)
                ax.scatter([x], [y], s=label_s, facecolor='white', edgecolor='k', marker='o', clip_on=False, zorder=2)
                ax.annotate(f'e{cur_se+1}', (x, y), va='center', ha='center', zorder=3, **an_kw)

        # Axis setup is loop-invariant, so it is done once per panel (also when no electrode is drawn)
        ax.axes.set(box_aspect=1, xlim=alim, ylim=alim, xticks=ticks, yticks=ticks)
        ax.set_xlabel('{}\nabs(z-value)'.format(new_seq_syl_labels['syllable']), fontsize=fs)
        ax.set_ylabel('{}\nabs(z-value)'.format(new_seq_syl_labels['sequence']), fontsize=fs)

        if cur_period == 0:
            handles, labels = ax.get_legend_handles_labels()
            handle_by_label = dict(zip(labels, handles))
            shown = [lbl for lbl in legend_text if lbl in handle_by_label]
            l = ax.legend([handle_by_label[lbl] for lbl in shown], [legend_text[lbl] for lbl in shown],
                          frameon=False, bbox_to_anchor=(0.5, 0.2), loc='lower left', handletextpad=-0.25)
            l.set_zorder(20)

    if created_fig:
        fig.tight_layout()

    if return_fig:
        return fig, axs
    else:
        return axs

In [None]:
# Vertical layout: one panel per period, stacked
f, _ = seq_syl_z_scatter(
    return_fig=True,
    horizontal=False,
    fs=14,
);

### Auditory overlap

In [None]:
def auditory_control(ax=None, s=15, plot_legend=False, legend_anchor=-0.3):
    """Compare auditory-response strength with sequence-complexity encoding.

    For sustained electrodes (`sustained_se`) whose `complexity_significance` is 'sequence' or
    'both' in any period, plot |auditory z| (`adf.z_val`) against the largest |z| of the sequence
    contrast across `plot_periods`. Electrodes in 'middleprecentral' / 'posteriorsuperiortemporal'
    that also encode sequence complexity in the pre-speech period are highlighted as mPrCG / pSTG.
    Filled markers have a significant auditory response (`adf.significant`), hollow ones do not.
    The title reports `correlation_permutation` (1000 permutations, seed 46) computed over the
    auditory-significant electrodes only.

    Parameters
    ----------
    ax : Axes, optional
        Target axis; a 5 x 5 figure is created when None.
    s : float
        Marker size for 'other' electrodes (highlighted areas use ``3 * s``).
    plot_legend : bool
        Draw the legend, including an entry for hollow (non-significant) markers.
    legend_anchor : float
        x of the legend's ``bbox_to_anchor`` (anchored at its upper right).

    Returns
    -------
    ax
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    is_seq_coding = all_sig_elecs.complexity_significance.isin(['sequence', 'both'])
    is_sustained = all_sig_elecs.subject_electrode.isin(sustained_se)
    aud_control_se = all_sig_elecs.loc[is_seq_coding & is_sustained].subject_electrode.unique()
    pre_exec_seq_se = all_sig_elecs.loc[(all_sig_elecs.period == 'pre_exec') &
                                        is_sustained & is_seq_coding].subject_electrode.unique()

    zs, auds, sigs, locs = [], [], [], []
    feat_type = 'sequence'

    for se in aud_control_se:

        if se not in adf.subject_electrode.values:
            continue

        potential_zs = seq_vs_syl_df.loc[(seq_vs_syl_df.subject_electrode == se) &
                                         (seq_vs_syl_df.feature_type == 'sequence') &
                                         (seq_vs_syl_df.alignment.isin(plot_periods))].statistic.abs().dropna().values
        if len(potential_zs) == 0:
            # No non-NaN sequence statistic: max() would raise, so skip the electrode before
            # appending anything (keeps the per-electrode lists aligned).
            continue

        cur_adf = adf.loc[adf.subject_electrode == se]
        auds.append(abs(cur_adf.z_val.values[0]))
        zs.append(max(potential_zs))
        sigs.append(cur_adf.significant.values[0])

        loc = all_sig_elecs.loc[all_sig_elecs.subject_electrode == se].location.values[0]

        if loc == 'middleprecentral' and se in pre_exec_seq_se:
            locs.append('mprcg')
        elif loc == 'posteriorsuperiortemporal' and se in pre_exec_seq_se:
            locs.append('pstg')
        else:
            locs.append('other')

    auds = np.array(auds)
    zs = np.array(zs)
    sigs = np.array(sigs)
    locs = np.array(locs)

    ylim = [0, 10]
    xlim = [0, 25]

    sns.despine(ax=ax, offset=dict(left=5, bottom=5))

    color = seq_syl_encoding_colors[feat_type]
    # (location group, auditory-significant?, scatter style); labelled groups are listed in legend order
    marker_groups = [
        ('other', True, dict(s=s, ec='k', fc='k', alpha=0.25, label='Other areas', zorder=1)),
        ('other', False, dict(s=s, ec='k', fc='None', alpha=0.15, zorder=1)),
        ('pstg', True, dict(s=3*s, fc=color, ec='k', label='pSTG', zorder=2)),
        ('pstg', False, dict(s=3*s, ec=color, fc='None', zorder=2)),
        ('mprcg', True, dict(s=3*s, marker='^', fc=color, ec='k', label='mPrCG', zorder=3)),
        ('mprcg', False, dict(s=3*s, marker='^', ec=color, fc='None', zorder=3)),
    ]
    for group_loc, group_sig, style in marker_groups:
        mask = (locs == group_loc) & ((sigs == True) if group_sig else (sigs == False))
        if np.any(mask):
            ax.scatter(auds[mask], zs[mask], clip_on=False, **style)

    if plot_legend:
        handles, labels = ax.get_legend_handles_labels()
        handles.append(mpl.lines.Line2D([0], [0], color='w', markerfacecolor='None', markeredgecolor='k', marker='s', markersize=s//2))
        labels.append('No significant\nauditory response')
        ax.legend(handles, labels, loc='upper right', bbox_to_anchor=(legend_anchor, 1.01), frameon=False)

    # Compute the correlation for electrodes with significant auditory responses
    sig_idx = np.where(sigs)[0]
    corr, p = correlation_permutation(auds[sig_idx], zs[sig_idx], n_permute=1000, random_seed=46)

    # Raw string: '\i' is an invalid escape sequence (SyntaxWarning on Python 3.12+)
    ax.set_title(rf'$\it{{r}}$={corr:0.2f} (P={p:0.3f})')
    ax.axes.set(xlabel='Auditory response\nabs(z-value)', ylabel=f'{feat_type.capitalize()} complexity\nabs(z-value)')
    ax.axes.set(ylim=ylim, xlim=xlim, xticks=range(xlim[0], xlim[1]+1, 5))

    return ax

In [None]:
# Standalone auditory-overlap panel, including its legend
auditory_control(
    plot_legend=True,
    s=20,
);

### AKT electrode overlap

In [None]:
def akt_control(ax=None, weight_type=None, plot_legend=False, s=20):
    """Scatter sequence-complexity encoding against an articulatory (AKT) model metric.

    For every sustained electrode (``sustained_se``) with significant sequence
    complexity encoding (``complexity_significance`` in {'sequence', 'both'}) that
    also appears in ``akt_model``, plot an AKT metric (x) against the maximum
    absolute sequence-complexity statistic across ``plot_periods`` (y). The title
    reports the permutation-test correlation between the two.

    Electrodes that are also sequence-significant in the 'pre_exec' period are
    drawn as mPrCG triangles / pSTG circles by anatomical location; all others
    are drawn as small 'Other areas' dots.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 5x5 figure is created when None.
    weight_type : {'akt', 'f0'}
        'akt' plots the AKT model r. 'f0' plots abs(max_f0_weight), restricted to
        electrodes with AKT r >= ``akt_r_cutoff`` and a non-NaN f0 weight.
    plot_legend : bool
        Draw the location legend to the right of the axes.
    s : float
        Base marker size; highlighted locations are drawn at 3*s.

    Returns
    -------
    ax : the Axes drawn into.

    Raises
    ------
    ValueError
        If ``weight_type`` is not 'akt' or 'f0'.
    """
    # Validate before touching anything: any other value (including the default
    # None) used to leave rs/zs misaligned and xlim/ylim unbound.
    if weight_type not in ('akt', 'f0'):
        raise ValueError(f"weight_type must be 'akt' or 'f0', got {weight_type!r}")

    if ax is None:
        fig, ax = plt.subplots(figsize=(5, 5))

    feat_type = 'sequence'

    seq_sig = all_sig_elecs.complexity_significance.isin(['sequence', 'both'])
    sustained = all_sig_elecs.subject_electrode.isin(sustained_se)
    akt_control_se = all_sig_elecs.loc[seq_sig & sustained].subject_electrode.unique()
    # Only electrodes sequence-significant before execution get a location highlight
    pre_exec_seq_se = all_sig_elecs.loc[(all_sig_elecs.period == 'pre_exec') &
                                        sustained & seq_sig].subject_electrode.unique()

    zs, rs, locs = [], [], []

    for se in akt_control_se:

        if se not in akt_model.subject_electrode.values:
            continue

        cur_akt = akt_model.loc[akt_model.subject_electrode == se]
        akt_r = cur_akt.AKT_r.values[0]

        if weight_type == 'akt':
            x_val = akt_r
        else:
            # Exclude low AKT r elecs (and missing weights) for f0 weight
            f0_weight = cur_akt.max_f0_weight.values[0]
            if akt_r < akt_r_cutoff or np.isnan(f0_weight):
                continue
            x_val = abs(f0_weight)

        potential_zs = seq_vs_syl_df.loc[(seq_vs_syl_df.subject_electrode == se) &
                                         (seq_vs_syl_df.feature_type == 'sequence') &
                                         (seq_vs_syl_df.alignment.isin(plot_periods))].statistic.abs().dropna().values
        # No usable statistic: skip the electrode entirely so rs, zs and locs stay
        # aligned (max() of an empty array would raise).
        if len(potential_zs) == 0:
            continue

        rs.append(x_val)
        zs.append(potential_zs.max())

        loc = all_sig_elecs.loc[all_sig_elecs.subject_electrode == se].location.values[0]

        if loc == 'middleprecentral' and se in pre_exec_seq_se:
            locs.append('mprcg')
        elif loc == 'posteriorsuperiortemporal' and se in pre_exec_seq_se:
            locs.append('pstg')
        else:
            locs.append('other')

    rs = np.array(rs)
    zs = np.array(zs)
    locs = np.array(locs)

    ylim = [0, 10]
    yticks = np.arange(ylim[0], ylim[1]+1, 5)
    if weight_type == 'akt':
        xlim = [-0.1, 0.4]
        xticks = np.arange(xlim[0], xlim[1]+0.07, 0.1)
    else:
        xlim = [0, 0.004]
        xticks = np.arange(xlim[0], xlim[1]+0.001, 0.002)

    sns.despine(ax=ax, offset=dict(left=5, bottom=5))

    idx = np.where(locs=='other')[0]
    ax.scatter(rs[idx], zs[idx], s=s, ec='k', fc='k', clip_on=False, alpha=0.25, label='Other areas', zorder=1)

    idx = np.where(locs=='pstg')[0]
    ax.scatter(rs[idx], zs[idx], s=3*s, fc=seq_syl_encoding_colors[feat_type], ec='k', clip_on=False, label='pSTG', zorder=2)

    idx = np.where(locs=='mprcg')[0]
    ax.scatter(rs[idx], zs[idx], s=3*s, marker='^', fc=seq_syl_encoding_colors[feat_type], ec='k', clip_on=False, label='mPrCG', zorder=3)

    if plot_legend:
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1.01), frameon=False)

    corr, p = correlation_permutation(rs, zs, n_permute=1000, random_seed=46)
    # Raw strings: '\i' is not a valid Python escape sequence
    ax.set_title(r'$\it{r}$=' + f'{corr:0.2f}' + ' (P=' + f'{p:0.3f})')
    ax.axes.set(ylim=ylim, xlim=xlim, yticks=yticks, xticks=xticks)

    ylabel = f'{feat_type.capitalize()} complexity\nabs(z-value)'
    if weight_type == 'akt':
        ax.axes.set(xlabel=r'AKT model $\it{r}$', ylabel=ylabel)
    else:
        ax.axes.set(xlabel='AKT model - f0 weight', ylabel=ylabel)

    return ax

In [None]:
akt_control(plot_legend=True, weight_type='akt');

In [None]:
akt_control(plot_legend=True, weight_type='f0');

### Seq/syl brain recons

In [None]:
def seq_syl_recon(axs=None, plot_view='lateral', plot_hemis=['lh'], density=False, highlight_elecs=True, plot_nonsig_elecs=False,
                  s=30, density_feature_type='sequence', return_fig=False, orientation='vertical', plot_titles=True,
                 label_s=500, elec_fs=18):
    """Plot complexity-encoding electrodes on template brain images, one panel per (period, hemisphere).

    Electrodes come from ``all_sig_elecs`` restricted to the panel's hemisphere,
    view and period, to sustained electrodes (``sustained_se``), and to rows whose
    ``complexity_significance`` is not 'none'.

    Parameters
    ----------
    axs : 2-D array of Axes, optional
        Grid to draw into, shaped (n_periods, n_hemis) for 'vertical' orientation
        or (n_hemis, n_periods) otherwise. Created when None.
    plot_view : str
        Brain view; selects both the background image and electrode coordinates.
    plot_hemis : list of str
        Hemispheres to plot (e.g. 'lh', 'rh').
    density : bool
        If True, draw a KDE of electrodes significant for ``density_feature_type``
        (or 'both') instead of individual electrodes.
    highlight_elecs : bool
        On the LH lateral view, outline the example electrodes in
        ``seq_complexity_se`` and label them e1, e2, ... in the last period's
        first-hemisphere panel.
    plot_nonsig_elecs : bool
        Also draw electrodes whose significance is 'neither' as small translucent dots.
    s : float
        Electrode marker size (example electrodes are drawn at 5/3 * s).
    density_feature_type : str
        Feature type whose density is drawn when ``density`` is True.
    return_fig : bool
        Return ``(fig, axs)`` instead of ``axs``.
    orientation : str
        'vertical' puts periods on rows; any other value puts hemispheres on rows.
    plot_titles : bool
        Title the first-hemisphere panels with the period label.
    label_s, elec_fs : float
        Marker size and font size of the example-electrode labels.

    Returns
    -------
    axs, or (fig, axs) when ``return_fig`` is True.
    """
    # Label positions (image pixel coordinates) for the example electrodes
    elec_labels = {
        'e1': [150, 0],
        'e2': [750, 150],
        'e3': [675, 25]
    }

    vertical = orientation == 'vertical'

    if axs is None:
        if vertical:
            figshape, figsize = (len(plot_periods), len(plot_hemis)), (7, 12)
        else:
            figshape, figsize = (len(plot_hemis), len(plot_periods)), (12, 10)
        # squeeze=False always yields a 2-D grid, even for a 1x1 layout
        # (previously a bare Axes came back and .reshape failed).
        fig, axs = plt.subplots(*figshape, figsize=figsize, squeeze=False)
    else:
        axs = np.asarray(axs)
        # Previously `fig` was undefined here, so return_fig=True with caller axes crashed
        fig = axs.flat[0].figure

    for cur_period, period in enumerate(plot_periods):
        for cur_hemi, hemi in enumerate(plot_hemis):

            ax = axs[cur_period, cur_hemi] if vertical else axs[cur_hemi, cur_period]

            brain_img = plt.imread(mni_img_path.format(hemi, plot_view))
            ax.imshow(brain_img, zorder=1, alpha=0.5)

            cur_df = all_sig_elecs.loc[(all_sig_elecs.hemisphere == hemi) &
                                       (all_sig_elecs.view == plot_view) &
                                       (all_sig_elecs.period == period) &
                                       (all_sig_elecs.subject_electrode.isin(sustained_se)) &
                                       (all_sig_elecs.complexity_significance != 'none')]

            # plot non-sig elecs
            if plot_nonsig_elecs:
                nonsig_df = cur_df.loc[(cur_df.complexity_significance == 'neither')]
                ax.scatter(nonsig_df.x, nonsig_df.y, color='k', alpha=0.25, s=2, zorder=1)

            if density:
                cur_sig_df = cur_df.loc[cur_df.complexity_significance.isin([density_feature_type, 'both'])]

                sns.kdeplot(x=cur_sig_df.x.values, y=cur_sig_df.y.values, cmap=seq_syl_encoding_cmaps[density_feature_type],
                            fill=True, thresh=0.25, bw_adjust=0.25, alpha=0.75, ax=ax)

            else:
                for sig_type in ['syllable', 'both', 'sequence']:
                    cur_sig_df = cur_df.loc[cur_df.complexity_significance == sig_type]

                    ax.scatter(cur_sig_df.x, cur_sig_df.y,
                               c=cur_sig_df.max_complex_hga_magnitude,
                               cmap=seq_syl_encoding_cmaps[sig_type], vmin=0, vmax=2,
                               s=s, alpha=0.5)

                if plot_view == 'lateral' and hemi == 'lh' and highlight_elecs:
                    for cur_se, se in enumerate(seq_complexity_se):

                        cur_sig_df = cur_df.loc[cur_df.subject_electrode == se]
                        # An example electrode may not pass this panel's filters;
                        # skip it instead of failing on .values[0] below.
                        if cur_sig_df.empty:
                            continue

                        # Scalar coordinates: a list mixing a number and a Series
                        # becomes a ragged array under NumPy.
                        elec_x = cur_sig_df.x.values[0]
                        elec_y = cur_sig_df.y.values[0]

                        if cur_period == len(plot_periods)-1 and cur_hemi == 0:
                            # Label example elecs
                            x, y = elec_labels[f'e{cur_se+1}']
                            ax.plot([x, elec_x], [y, elec_y], linewidth=0.25, color='k', zorder=2)

                            an_kw = dict(fontsize=elec_fs, annotation_clip=False)
                            ax.scatter([x], [y], s=label_s, facecolor='white', edgecolor='k', marker='o', clip_on=False, zorder=3)
                            ax.annotate(f'e{cur_se+1}', (x, y), va='center', ha='center', zorder=4, **an_kw)

                        ax.scatter(cur_sig_df.x, cur_sig_df.y,
                                   c=cur_sig_df.max_complex_hga_magnitude, ec='k',
                                   cmap=seq_syl_encoding_cmaps[cur_sig_df.complexity_significance.values[0]], vmin=0, vmax=2,
                                   s=s*(5/3), alpha=1, zorder=4)

            if cur_hemi == 0 and plot_titles:
                ax.axes.set(title=period_labels[period])

            ax.axis('off')

    if return_fig:
        return fig, axs
    return axs

In [None]:
f, _ = seq_syl_recon(orientation='horizontal', s=25, return_fig=True, density=False, plot_hemis=['lh']);

#### Combined seq/syl density

In [None]:
def seq_syl_combined_density(ax=None, fig=None, return_fig=False, highlight_first_complexity_elec=False, period=None,
                             plot_mprcg_roi=False, plot_colorbar=True, figsize=(10, 10), plot_title=True, precentral_zoom=False, draw_rectangle=False,
                            label_fontsize=11):
    """Plot the differential (sequence minus syllable) encoding density on the LH lateral brain.

    Two smoothed, baseline-normalized histograms are computed with
    ``smoothed_weighted_histogram`` over sustained electrodes in ``period``: one
    weighting electrodes significant for 'sequence' or 'both', the other
    weighting those significant for 'syllable' or 'both'. Their difference is
    drawn with ``density_cmap`` (vmin=-1, vmax=1). Opacity rises with the larger
    of the two densities in 0.05-wide bands.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure of ``figsize`` is created when None.
    fig : matplotlib Figure, optional
        Returned when ``return_fig`` is True and ``ax`` is given. Defaults to
        ``ax.figure``.
    return_fig : bool
        Return ``(fig, ax)`` instead of ``ax``.
    highlight_first_complexity_elec : bool
        Outline and label (e1) the first electrode in ``seq_complexity_se``.
    period : str
        Task period used to select electrodes and the title label.
    plot_mprcg_roi : bool
        Overlay the mPrCG ROI image.
    plot_colorbar : bool
        Draw a horizontal colorbar labelled with the syllable/sequence labels.
    figsize : tuple
        Size of the figure created when ``ax`` is None.
    plot_title : bool
        Title the axes with ``period_labels[period]``.
    precentral_zoom : bool
        Zoom into the precentral box and outline it.
    draw_rectangle : bool
        Outline the precentral box without zooming.
    label_fontsize : float
        Font size of the colorbar tick labels.

    Returns
    -------
    ax, or (fig, ax) when ``return_fig`` is True.
    """
    plot_hemi = 'lh'
    plot_view = 'lateral'

    # Precentral zoom box in image pixels: [[x_min, x_max], [y_max, y_min]]
    # (y is listed top-down to keep the image's inverted y axis).
    box_lims = [[200, 375], [425, 150]]

    elec_labels = {
        'e1': [150, 0],
        'e2': [750, 150],
        'e3': [675, 25]
    }

    brain_img = plt.imread(mni_img_path.format(plot_hemi, plot_view))

    # Both densities share the same grid covering the whole brain image
    hist_kwargs = dict(xlim=[0, brain_img.shape[1]],
                       ylim=[0, brain_img.shape[0]],
                       bins=100,
                       smooth=2,
                       baseline_norm=True)

    def _density(df, sig_types):
        """Smoothed histogram of df's electrodes, weighting those whose significance is in sig_types."""
        w = np.where(df.complexity_significance.isin(sig_types).values, 1, 0)
        return smoothed_weighted_histogram(x=df.x.values, y=df.y.values, weights=w, **hist_kwargs)

    base_condition = ((all_sig_elecs.hemisphere == plot_hemi) &
                      (all_sig_elecs.view == plot_view) &
                      (all_sig_elecs.period == period) &
                      (all_sig_elecs.subject_electrode.isin(sustained_se)))

    # Get sequence density (electrodes with significance 'none' excluded)
    seq_density, xedges, yedges = _density(
        all_sig_elecs.loc[base_condition & (all_sig_elecs.complexity_significance != 'none')],
        ['both', 'sequence'])

    # Get syllable density.
    # NOTE(review): unlike the sequence density, 'none' electrodes are NOT
    # excluded here, so the two baseline-normalized densities may be computed
    # over different electrode sets before being subtracted. Kept as-is to
    # reproduce the published figure; confirm the asymmetry is intended.
    cur_df = all_sig_elecs.loc[base_condition]
    syl_density, xedges, yedges = _density(cur_df, ['both', 'syllable'])

    # Get differential density; the larger of the two densities sets opacity below
    seq_minus_syl_density = seq_density - syl_density
    seq_minus_syl_preference = np.maximum(seq_density, syl_density)

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    ax.imshow(brain_img, zorder=0, alpha=0.5)

    if plot_mprcg_roi:
        mprcg_roi = plt.imread(mprcg_roi_path.format(plot_hemi, plot_view))
        ax.imshow(mprcg_roi, zorder=0, alpha=0.7)

    # Opacity lookup indexed by preference in percent (0-100): transparent up to
    # 15, linear ramp to 0.9, then constant 0.9.
    alphas = np.concatenate([np.zeros(15), np.linspace(0, 0.9, 56), 0.9*np.ones(30)])
    dx = 0.05

    for lower in np.arange(0, 1, dx):
        # Draw only bins whose preference lies in (lower, lower + dx]
        outside_band = ~np.logical_and(seq_minus_syl_preference > lower, seq_minus_syl_preference <= lower + dx)
        temp = np.copy(seq_minus_syl_density)
        temp[outside_band] = np.nan
        # round() rather than int(): float error in 100*lower (e.g. 14.999...)
        # would otherwise truncate to the wrong alpha bin.
        alpha = alphas[int(round(abs(100*lower)))]
        im = ax.pcolormesh(xedges, yedges, temp.T, alpha=alpha, vmin=-1, vmax=1, cmap=density_cmap, zorder=0, rasterized=True)

    if plot_colorbar:
        cbar = ax.figure.colorbar(im, ax=ax, location='bottom', ticks=[-1, 1], shrink=0.45, pad=0.025, label='Normalized density')
        # set_ticklabels returns None, so don't rebind cbar to its result
        cbar.set_ticklabels([new_seq_syl_labels['syllable'], new_seq_syl_labels['sequence']], fontsize=label_fontsize)

    if highlight_first_complexity_elec:

        cur_se = 0
        cur_sig_df = cur_df.loc[cur_df.subject_electrode == seq_complexity_se[0]]

        # Label example elecs
        x, y = elec_labels[f'e{cur_se+1}']

        an_kw = dict(fontsize=18, annotation_clip=False)
        ax.scatter([x], [y], s=500, facecolor='none', edgecolor='k', marker='o', clip_on=False, zorder=2)
        ax.annotate(f'e{cur_se+1}', (x, y), va='center', ha='center', zorder=3, **an_kw)

        ax.scatter(cur_sig_df.x, cur_sig_df.y, fc='none', ec='k', s=50, alpha=1, zorder=3)

        # Connect label to electrode with scalar coordinates (Series coordinates
        # are not a valid xy point); skip if the electrode is absent.
        if not cur_sig_df.empty:
            elec_xy = (cur_sig_df.x.values[0], cur_sig_df.y.values[0])
            cp = ConnectionPatch((x, y), elec_xy,
                                 coordsA='data', coordsB='data', axesA=ax, axesB=ax,
                                 shrinkA=12, shrinkB=4,
                                 linewidth=1)
            ax.add_patch(cp)

    ax.axis('off')

    if precentral_zoom:
        ax.axes.set(xlim=box_lims[0], ylim=box_lims[1])

    if precentral_zoom or draw_rectangle:
        rect = mpatches.Rectangle((box_lims[0][0], box_lims[1][1]),
                                  abs(box_lims[0][0]-box_lims[0][1]),
                                  abs(box_lims[1][0]-box_lims[1][1]),
                                  edgecolor='#404040', linewidth=0.5, facecolor='none', clip_on=False)
        ax.add_patch(rect)

    if plot_title:
        ax.axes.set(title=period_labels[period])

    if return_fig:
        # A caller-supplied ax without fig used to return None for the figure
        return (fig if fig is not None else ax.figure), ax
    return ax

In [None]:
_density_demo_kwargs = dict(return_fig=True, figsize=(5, 5), period='pre_exec')
f, _ = seq_syl_combined_density(**_density_demo_kwargs, draw_rectangle=True, precentral_zoom=False, plot_colorbar=True, plot_title=True);
f, _ = seq_syl_combined_density(**_density_demo_kwargs, precentral_zoom=True, plot_colorbar=False, plot_title=False);

# Overall figure

In [None]:
##### Figure setup #####
default_plot_settings(font='Helvetica', fontsize=5, linewidth=0.5, ticklength=2)
label_fontsize = 5

# Fonttype 'none'/42 writes text as text / TrueType in SVG, PDF and PS exports;
# the remaining settings tighten spacing and line widths for print size.
plt.rcParams.update({
    'svg.fonttype': 'none',
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'axes.labelpad': 1,
    'patch.linewidth': 0.25,
    'hatch.linewidth': 0.25,
    'lines.markeredgewidth': 0.25,
    'axes.titlepad': 2,
    'xtick.major.pad': 2,
    'ytick.major.pad': 2,
    'legend.handletextpad': 1,
})


def _add_grid_axes(fig, subgrid):
    """Add one Axes per cell of `subgrid` (row-major) and return them as an (nrows, ncols) array."""
    # get_geometry() is the public API for the grid shape (vs. private _nrows/_ncols)
    nrows, ncols = subgrid.get_geometry()
    return np.array([[fig.add_subplot(subgrid[r, c]) for c in range(ncols)] for r in range(nrows)])


all_axs = {}
mm = 1/25.4  # inches per millimetre
fig = plt.figure(figsize=(180*mm, 240*mm))
gs = mpl.gridspec.GridSpec(100, 100, figure=fig, hspace=3, wspace=2)


erp_gs = mpl.gridspec.GridSpecFromSubplotSpec(3, 2,
                                              subplot_spec=gs[:35, :57],
                                              width_ratios=[3, 1.75],
                                              hspace=0.5,
                                              wspace=0.05)
stat_gs = mpl.gridspec.GridSpecFromSubplotSpec(3, 2,
                                              subplot_spec=gs[:35, 65:],
                                              hspace=0.45,
                                              wspace=0.25)
scatter_gs = mpl.gridspec.GridSpecFromSubplotSpec(3, 1,
                                                  subplot_spec=gs[40:, :25],
                                                  hspace=0.35)
recon_gs = mpl.gridspec.GridSpecFromSubplotSpec(3, 1,
                                                subplot_spec=gs[40:, 25:55],
                                                hspace=0.35)
controls_gs = mpl.gridspec.GridSpecFromSubplotSpec(2, 2,
                                                  subplot_spec=gs[69:97, 65:],
                                                  hspace=0.75,
                                                  wspace=0.75)

##########

all_axs['erps'] = complexity_erps(axs=_add_grid_axes(fig, erp_gs), fs=label_fontsize, elec_fs=6, s=75, legend_bbox=(0.65, 1.17))

all_axs['erp_stats'] = single_elec_stat_plot(axs=_add_grid_axes(fig, stat_gs), fs=label_fontsize)

all_axs['scatters'] = seq_syl_z_scatter(axs=np.array([fig.add_subplot(scatter_gs[i]) for i in range(3)]),
                                        fs=label_fontsize, elec_fs=6, ms=6, label_s=75)

all_axs['recons'] = seq_syl_recon(axs=_add_grid_axes(fig, recon_gs), plot_hemis=['lh'], label_s=75, elec_fs=6, s=7)

all_axs['combined_density'] = seq_syl_combined_density(ax=fig.add_subplot(gs[40:65, 50:95]), period='pre_exec', plot_title=True,
                                                       label_fontsize=label_fontsize, plot_colorbar=True, draw_rectangle=True)
# Inset overlaps the top-right of the combined density panel
all_axs['density_inset'] = seq_syl_combined_density(ax=fig.add_subplot(gs[40:58, 83:]), period='pre_exec', plot_title=False,
                                                    plot_colorbar=False, precentral_zoom=True)

# controls_gs[0, 0] is intentionally left empty
all_axs['auditory_control'] = auditory_control(ax=fig.add_subplot(controls_gs[0, 1]), plot_legend=True, legend_anchor=-0.75, s=7)
all_axs['akt_control'] = akt_control(ax=fig.add_subplot(controls_gs[1, 0]), weight_type='akt', s=7)
all_axs['f0_control'] = akt_control(ax=fig.add_subplot(controls_gs[1, 1]), weight_type='f0', s=7)


# Panel labels: (axes, letter, position in axes-fraction coordinates)
panel_fs = 7
panel_labels = [
    (all_axs['erps'][0, 0], 'a', (-0.1, 1.1)),
    (all_axs['erp_stats'][0, 0], 'b', (-0.29, 1.1)),
    (all_axs['erp_stats'][0, 1], 'c', (-0.2, 1.1)),
    (all_axs['scatters'][0], 'd', (-0.2, 1.04)),
    (all_axs['recons'][0, 0], 'e', (0.0, 1.04)),
    (all_axs['combined_density'], 'f', (0.0, 1.03)),
    (all_axs['auditory_control'], 'g', (-0.3, 1.15)),
    (all_axs['akt_control'], 'h', (-0.3, 1.15)),
    (all_axs['f0_control'], 'i', (-0.4, 1.15)),
]
for label_ax, letter, xy in panel_labels:
    label_ax.annotate(letter, xy, xycoords='axes fraction', ha='right', fontsize=panel_fs, weight='bold')

# Supplementary/presentation figures

## Fig. S10 - Medial and right-hemisphere views of Fig. 3E (alongside the LH lateral view)

In [None]:
fig, axs = plt.subplots(4, 3, figsize=(13, 13), gridspec_kw={'hspace': 0.05, 'wspace': 0.05})

# One row per (view, hemisphere); only the top row carries period titles.
s10_rows = [('medial', 'lh'), ('medial', 'rh'), ('lateral', 'lh'), ('lateral', 'rh')]
for s10_row, (s10_view, s10_hemi) in enumerate(s10_rows):
    seq_syl_recon(axs=axs[s10_row:s10_row + 1, :], plot_hemis=[s10_hemi], plot_view=s10_view, density=False, s=25,
                  orientation='horizontal', plot_titles=(s10_row == 0), highlight_elecs=False, plot_nonsig_elecs=True)

## Fig. S11 - Auditory control, per task phase

In [None]:
default_plot_settings(font='Helvetica', fontsize=11, linewidth=0.5, ticklength=2)

fig, axs = plt.subplots(2, 3, figsize=(10, 8))

for i, feat_type in enumerate(['sequence', 'syllable']):
    for j, period in enumerate(plot_periods):

        ax = axs[i, j]
        sns.despine(ax=ax, offset=dict(left=10, bottom=10))

        sig_df = seq_vs_syl_df.loc[(seq_vs_syl_df.alignment == period) &
                                   (seq_vs_syl_df.fdr_pval < sig_thresh) &
                                   (seq_vs_syl_df.feature_type == feat_type) &
                                   (seq_vs_syl_df.subject_electrode.isin(sustained_se))]

        # rs: auditory |z|, zs: complexity |z|, plus per-point marker styling.
        # aud_sig_idx holds positions of electrodes with a significant auditory
        # response; only those enter the correlation.
        zs, rs, facecolors, edgecolors, s = [], [], [], [], []
        aud_sig_idx = []

        for se, z in zip(sig_df.subject_electrode.values, sig_df.statistic.values):

            aud_row = adf.loc[adf.subject_electrode == se]
            if aud_row.empty:
                continue

            rs.append(abs(aud_row.z_val.values[0]))
            zs.append(abs(z))

            if aud_row.significant.values[0]:
                facecolors.append('C0')
                edgecolors.append('C0')
                s.append(20)
                aud_sig_idx.append(len(rs) - 1)
            else:
                facecolors.append('grey')
                edgecolors.append('grey')
                s.append(10)

        ax.scatter(rs, zs, facecolors=facecolors, edgecolors=edgecolors, s=s, clip_on=False, alpha=0.7)

        rs = np.array(rs)
        zs = np.array(zs)
        # dtype=int: an empty list would otherwise become a float array, which
        # NumPy refuses to use as an index.
        aud_sig_idx = np.array(aud_sig_idx, dtype=int)

        corr, p = correlation_permutation(rs[aud_sig_idx], zs[aud_sig_idx], n_permute=1000, random_seed=46)

        # '\n' must stay a real newline; the LaTeX parts are raw strings because
        # '\i' is not a valid Python escape sequence.
        ax.axes.set(title=period_labels[period] + '\n' + r'$\it{r}$=' + f'{corr:0.2f}' + r' ($\it{P}$=' + f'{p:0.3f})')
        ax.axes.set(xlabel='Auditory response\nabs(z-value)',
                    ylabel='{} complexity\nabs(z-value)'.format(full_name_feature_labels[feat_type]))
        ax.axes.set(ylim=(0, 10), xlim=(0, 20))

fig.tight_layout();

## Fig. S12, S13 - Articulatory control, per task phase
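Set `weight_type` in the cell below to `'akt'` for Fig. S12 (AKT model $r$) or to `'f0'` for Fig. S13 (AKT model f0 weights, restricted to electrodes with AKT $r \geq$ `akt_r_cutoff`).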

In [None]:
weight_type = 'akt'
# weight_type = 'f0'

# Any other value would silently misalign rs/zs and leave xlim undefined
if weight_type not in ('akt', 'f0'):
    raise ValueError(f"weight_type must be 'akt' or 'f0', got {weight_type!r}")

default_plot_settings(font='Helvetica', fontsize=11, linewidth=0.5, ticklength=2)

fig, axs = plt.subplots(2, 3, figsize=(10, 8))

ylim = [0, 10]
xlim = [0, 0.4] if weight_type == 'akt' else [0, 0.004]

for i, feat_type in enumerate(['sequence', 'syllable']):
    for j, period in enumerate(plot_periods):

        ax = axs[i, j]
        sns.despine(ax=ax, offset=dict(left=10, bottom=10))

        sig_df = seq_vs_syl_df.loc[(seq_vs_syl_df.alignment == period) &
                                   (seq_vs_syl_df.fdr_pval < sig_thresh) &
                                   (seq_vs_syl_df.feature_type == feat_type) &
                                   (seq_vs_syl_df.subject_electrode.isin(sustained_se))]
        zs = []
        rs = []

        for se, z in zip(sig_df.subject_electrode.values, sig_df.statistic.values):

            if se not in akt_model.subject_electrode.values:
                continue

            cur_akt = akt_model.loc[akt_model.subject_electrode == se]
            akt_r = cur_akt.AKT_r.values[0]

            if weight_type == 'akt':
                rs.append(abs(akt_r))
            else:
                # Exclude low AKT r elecs (and missing weights) for f0 weights
                f0_weight = cur_akt.max_f0_weight.values[0]
                if akt_r < akt_r_cutoff or np.isnan(f0_weight):
                    continue
                rs.append(abs(f0_weight))

            zs.append(abs(z))

        rs = np.array(rs)
        zs = np.array(zs)

        ax.scatter(rs, zs, s=20, alpha=0.7, clip_on=False, zorder=2)
        corr, p = correlation_permutation(rs, zs, n_permute=1000, random_seed=46)
        # '\n' must stay a real newline; LaTeX parts are raw strings ('\i' is an invalid escape)
        ax.axes.set(title=period_labels[period] + '\n' + r'$\it{r}$=' + f'{corr:0.2f}' + r' ($\it{P}$=' + f'{p:0.3f})')
        ax.axes.set(ylim=ylim, xlim=xlim)

        ylabel = '{} complexity\nabs(z-value)'.format(full_name_feature_labels[feat_type])
        if weight_type == 'akt':
            # Dashed line marks the AKT r cutoff used to select electrodes for Fig. S13
            ax.axvline(x=akt_r_cutoff, linewidth=1, linestyle='--', color='k', zorder=1)
            ax.axes.set(xlabel=r'AKT model $\it{r}$', ylabel=ylabel)
        else:
            ax.axes.set(xlabel='AKT model - f0 weight', ylabel=ylabel)

fig.tight_layout();