# Figure 1

In [None]:
import os
import pickle
from itertools import combinations

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from scipy import stats
from sklearn.preprocessing import MinMaxScaler
from statannotations.Annotator import Annotator
from statsmodels.stats import multitest

from distractors_figures.plotting_utils.plot_fncs import default_plot_settings
from distractors_figures.plotting_utils.utils import correlation_permutation, plot_images_and_elecs

# IPython magics: reload changed project modules (e.g. distractors_figures) before each cell runs.
%load_ext autoreload
%autoreload 2

# Plot settings
# Notebook-wide defaults for interactive plots; the "Overall figure" cell re-applies
# default_plot_settings with smaller, publication-sized fonts and line widths.
default_plot_settings(font='Helvetica', fontsize=14, linewidth=1.5)

# Project colour palettes for the distractor paradigms and detection-model variants.
from distractors_figures.plotting_utils.color_pals import distractors as distractor_colors, det_model_type as det_model_type_colors

## Setup

In [None]:
# Global matplotlib settings: keep text as editable text in SVG output and
# tighten the gap between ticks and tick labels.
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['xtick.major.pad'] = '1.25'
plt.rcParams['ytick.major.pad'] = '1.25'

# Participant IDs (as used in data file names) and their display labels.
subjects = ['bravo1', 'bravo3']
subject_labels = {
    'bravo1': 'Bravo-1',
    'bravo3': 'Bravo-3'
}

# Display labels for the detection-model variants.
model_type_labels = {
    'full_model_with_gate': 'Full model\n+verif.',
    'full_model_no_gate': 'Full model',
    'speech_only': 'Speech-only\nmodel',
}

# All data files are read from ../data relative to the notebook's working directory.
data_dir = os.path.abspath(os.path.join(os.getcwd(), '..', 'data'))

# Shared boxplot styling: no artist is clipped at the axes boundary, boxes are
# slightly transparent and fliers are drawn as open circles.
boxplot_kwargs = {}
for key in ['cap', 'whisker', 'flier', 'median', 'mean']:
    boxplot_kwargs[f'{key}props'] = dict(clip_on=False)
boxplot_kwargs['boxprops'] = dict(alpha=0.85, clip_on=False)
boxplot_kwargs['flierprops'].update(dict(marker='o', markerfacecolor='None'))

# Star thresholds for statannotations and the significance level used for
# tests and Holm corrections throughout the notebook.
pvalue_thresholds = [[0.0001, "****"], [0.001, "***"], [0.01, "**"], [0.05, "*"]]
sig_threshold = 0.05

# Palettes: the three paradigm colours, three model-variant colours and the
# colours of the long-form paradigms.
colors = {
    'paradigms': distractor_colors[:3],
    'models': [det_model_type_colors[i] for i in [2, 3, 5]],
    'long_paradigms': ['#feebcc', '#c5c3ce', '#724438']
}

# Axis limits applied by plot_cortex(..., closeup=True); presumably in pixel
# coordinates of the 2D brain image -- confirm.
brain_lims = {
    'bravo1': {
        'xlim': [175, 525],
        'ylim': [125, 550]
    },
    'bravo3': {
        'xlim': [175, 525],
        'ylim': [100, 525]
    }
}

# Offsets passed to sns.despine to detach the left/bottom spines.
despine_params = dict(left=3, bottom=3)
# Line width of significance annotations. Defined once with the 'linewidth' key
# that every panel reads (a stale duplicate using 'line_width' was removed).
plot_params = {'linewidth': 1}

## Load detection data

In [None]:
# Detection-model electrode saliences, one array per participant ('bravo1', 'bravo3').
det_sals_path = os.path.join(data_dir, 'detection_sals_fig1.pkl')
with open(det_sals_path, 'rb') as sal_file:
    det_sals = pickle.load(sal_file)

# Log-transform each participant's saliences, then min-max scale them to [0, 1].
for sal_key, raw_sals in det_sals.items():
    log_sals = np.log(raw_sals).reshape((-1, 1))
    det_sals[sal_key] = MinMaxScaler().fit_transform(log_sals).squeeze()

# True/false-positive counts and rates.
bp = pd.read_hdf(os.path.join(data_dir, 'fig1_bp.h5'))

# False-positive counts in long form.
fpc = pd.read_hdf(os.path.join(data_dir, 'fig1_fpc.h5'))

In [None]:
# Pre-set pairwise comparisons, each [subject, model_type1, model_type2, block_type1, block_type2]
# (18 per participant when there are 3 model types and 3 block types).
subject_ids = bp.subject.unique()
model_types = bp.model_type.unique()
block_types = bp.block_type.unique()

combos = []
for subject in subject_ids:
    # Same model, two different paradigms.
    combos.extend(
        [subject, mt, mt, bt_a, bt_b]
        for mt in model_types
        for bt_a, bt_b in combinations(block_types, 2)
    )
    # Same paradigm, two different models.
    combos.extend(
        [subject, mt_a, mt_b, bt, bt]
        for bt in block_types
        for mt_a, mt_b in combinations(model_types, 2)
    )

In [None]:
def _condition_counts(subject, model_type, block_type, count_col, total_col):
    """Return (count, total): `count_col` and `total_col` summed over the `bp` rows of one condition."""
    rows = bp.loc[(bp.subject == subject) & (bp.model_type == model_type) & (bp.block_type == block_type)]
    return rows[count_col].values.sum(), rows[total_col].values.sum()


# Comparison type -> (column with the event count, column with the matching total).
_comp_columns = {'tp': ('tpc', 'tp_total'), 'fp': ('fpc', 'fp_total')}

# For every pre-set comparison in `combos`, run a two-sided Fisher exact test on the
# 2x2 table [[count1, total1 - count1], [count2, total2 - count2]], once for true
# positives and once for false positives. `statistic` is scipy's (odds-ratio) statistic.
bp_stats_cols = ['subject', 'model_type1', 'model_type2', 'block_type1', 'block_type2', 'comp_type', 'statistic', 'pval']
records = []
for subject, model_type1, model_type2, block_type1, block_type2 in combos:
    for comp_type, (count_col, total_col) in _comp_columns.items():
        count1, total1 = _condition_counts(subject, model_type1, block_type1, count_col, total_col)
        count2, total2 = _condition_counts(subject, model_type2, block_type2, count_col, total_col)
        table = [[count1, total1 - count1],
                 [count2, total2 - count2]]
        statistic, pval = stats.fisher_exact(table, alternative='two-sided')
        records.append([subject, model_type1, model_type2, block_type1, block_type2, comp_type, statistic, pval])

bp_stats = pd.DataFrame(records, columns=bp_stats_cols)

# Holm-Bonferroni correction applied separately within each subject x comparison type.
# `transform` writes the corrected p-values back by index label, so this does not rely
# on the frame's index coinciding with row positions.
bp_stats['hb_pval'] = (
    bp_stats.groupby(['subject', 'comp_type'])['pval']
    .transform(lambda p: multitest.multipletests(p.values, alpha=sig_threshold, method='holm')[1])
)

## Load classification data

In [None]:
# 10-word classifier accuracy table.
class_accs_path = os.path.join(data_dir, 'pseudo_df_classifier.pkl')
with open(class_accs_path, 'rb') as acc_file:
    class_accs = pickle.load(acc_file)

# Relabel paradigms for plotting ('None' -> 'Baseline', anything else -> 'Listen')
# and scale accuracies by 100 (the plots are labelled in %).
class_accs['paradigm'] = np.where(class_accs['paradigm'].values == 'None', 'Baseline', 'Listen')
class_accs['acc'] = class_accs['acc'].values * 100

# 10-word classifier electrode saliences.
class_sals_path = os.path.join(data_dir, 'classifier_sals_fig1.pkl')
with open(class_sals_path, 'rb') as sal_file:
    class_sals = pickle.load(sal_file)

# Log-transform each salience array, then min-max scale it to [0, 1].
for sal_key, raw_sals in class_sals.items():
    log_sals = np.log(raw_sals).reshape((-1, 1))
    class_sals[sal_key] = MinMaxScaler().fit_transform(log_sals).squeeze()

# Verification ("gate") model error-rate curves.
gate_curves_path = os.path.join(data_dir, 'gate_fpr_fneg.pkl')
with open(gate_curves_path, 'rb') as gate_file:
    gate_fpr_fnr = pickle.load(gate_file)

In [None]:
# Per participant, compare classifier correct/incorrect counts between every pair of
# paradigms with a two-sided Fisher exact test (no multiple-comparison correction).
stat_rows = []
for subject in class_accs.participant.unique():
    subject_rows = class_accs.loc[class_accs.participant == subject]
    for block_type1, block_type2 in combinations(class_accs.paradigm.unique(), 2):
        table = []
        for paradigm in (block_type1, block_type2):
            sel = subject_rows.loc[subject_rows.paradigm == paradigm]
            n_correct = sel['corr'].values.sum()
            n_total = sel['total_samps'].values.sum()
            table.append([n_correct, n_total - n_correct])
        stat, pval = stats.fisher_exact(table, alternative='two-sided')
        stat_rows.append({
            'subject': subject,
            'comparison1': block_type1,
            'comparison2': block_type2,
            'statistic': stat,
            'pval': pval,
        })

class_stats = pd.DataFrame(stat_rows, columns=['subject', 'comparison1', 'comparison2', 'statistic', 'pval'])

## Panels

### RT TPR

In [None]:
def rt_tpr(ax=None, ms=7, lp=10):
    """Strip plot of true-positive rate per participant and paradigm for the full model with verification.

    Holm-corrected TP comparisons between paradigms (from `bp_stats`) are drawn as
    significance stars; non-significant pairs are hidden.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 5x5 in figure is created if None.
    ms : float
        Dot size.
    lp : float
        Padding of the y-axis label.

    Returns
    -------
    The Axes drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    model_label = 'Full model\n+verif.'
    model_df = bp.loc[bp.model_type == model_label]

    ax = sns.stripplot(
        data=model_df, x='subject', y='tpr', hue='block_type',
        palette=colors['paradigms'], ax=ax, dodge=True, jitter=True,
        clip_on=False, linewidth=1, size=ms,
    )
    sns.despine(ax=ax, offset=despine_params)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)

    # Suppress the automatic hue legend.
    ax.legend([], [], frameon=False)

    tp_rows = bp_stats.loc[
        (bp_stats.model_type1 == model_label)
        & (bp_stats.model_type2 == model_label)
        & (bp_stats.comp_type == 'tp')
    ]
    pairs = [
        ((row.subject, row.block_type1), (row.subject, row.block_type2))
        for row in tp_rows.itertuples(index=False)
    ]
    annotator = Annotator(ax, pairs, data=model_df, x='subject', y='tpr', hue='block_type')
    annotator.configure(
        test=None, alpha=sig_threshold, line_width=plot_params['linewidth'], test_short_name='test',
        pvalue_thresholds=pvalue_thresholds, hide_non_significant=True,
    )
    annotator.set_pvalues(pvalues=tp_rows.hb_pval.values)
    annotator.annotate()

    ax.axes.set(ylim=(0, 100), xlabel='')
    ax.set_ylabel('True positive rate (%)', labelpad=lp)

    return ax

### RT FPR

In [None]:
def rt_fpr(ax=None, ms=7, lp=10):
    """Strip plot of false-positive rate per participant and paradigm for the full model with verification.

    Adds a legend (grey-edged, enlarged markers) and draws Holm-corrected FP comparisons
    between paradigms (from `bp_stats`) as significance stars.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 5x5 in figure is created if None.
    ms : float
        Dot size (legend markers are drawn at ``7 * ms``).
    lp : float
        Padding of the y-axis label.

    Returns
    -------
    The Axes drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    model_label = 'Full model\n+verif.'
    model_df = bp.loc[bp.model_type == model_label]

    ax = sns.stripplot(
        data=model_df, x='subject', y='fpr', hue='block_type',
        palette=colors['paradigms'], ax=ax, dodge=True, jitter=True,
        clip_on=False, linewidth=1, size=ms,
    )
    sns.despine(ax=ax, offset=despine_params)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)
    ax.axes.set(ylim=(0, 100), xlabel='')
    ax.set_ylabel('False positive rate (%)', labelpad=lp)

    handles, labels = ax.get_legend_handles_labels()
    for handle in handles:
        handle.set_edgecolor('#404040')
        handle.set_sizes([7 * ms])
    ax.legend(handles, labels, frameon=False, loc='upper right')

    fp_rows = bp_stats.loc[
        (bp_stats.model_type1 == model_label)
        & (bp_stats.model_type2 == model_label)
        & (bp_stats.comp_type == 'fp')
    ]
    pairs = [
        ((row.subject, row.block_type1), (row.subject, row.block_type2))
        for row in fp_rows.itertuples(index=False)
    ]
    annotator = Annotator(ax, pairs, data=model_df, x='subject', y='fpr', hue='block_type')
    annotator.configure(
        test=None, alpha=sig_threshold, line_width=plot_params['linewidth'], test_short_name='test',
        pvalue_thresholds=pvalue_thresholds, hide_non_significant=True,
    )
    annotator.set_pvalues(pvalues=fp_rows.hb_pval.values)
    annotator.annotate()

    return ax

### FPR ablation

In [None]:
def fpr_ablation(axs=None, plot_full_model=False, ms=7, lp=10):
    """False-positive rate per model variant and paradigm, one panel per participant.

    Left panel: Bravo-1 (carries the y label and the legend). Right panel: Bravo-3
    (y axis hidden). Holm-corrected FP comparisons from `bp_stats` are drawn as
    significance stars; non-significant pairs are hidden.

    Parameters
    ----------
    axs : sequence of two matplotlib Axes, optional
        Axes to draw into; a new 1x2 figure (8x5 in) is created if None.
    plot_full_model : bool
        Also show the 'Full model' (no verification) variant. Only used for the
        supplementary version of this panel.
    ms : float
        Dot size (legend markers are drawn at ``7 * ms``).
    lp : float
        Padding of the y-axis label.

    Returns
    -------
    The two Axes.
    """
    if not plot_full_model:
        ax_bp = bp.loc[bp.model_type != 'Full model']
        x_order = ['Speech-only\nmodel', 'Full model\n+verif.']
    else:
        x_order = ['Speech-only\nmodel', 'Full model', 'Full model\n+verif.']
        ax_bp = bp.copy(deep=True)

    if axs is None:
        _, axs = plt.subplots(1, 2, figsize=(8, 5))

    def draw_strip(target_ax, subject_df):
        # Dodged strip plot of FPR by model variant, coloured by paradigm.
        return sns.stripplot(
            data=subject_df, x='model_type', y='fpr', hue='block_type', order=x_order,
            palette=colors['paradigms'], ax=target_ax, dodge=True, jitter=True,
            clip_on=False, linewidth=1, size=ms,
        )

    def annotate_fp(target_ax, subject, subject_df):
        # Significance stars for this participant's FP comparisons.
        fp_rows = bp_stats.loc[(bp_stats.subject == subject) & (bp_stats.comp_type == 'fp')]
        if not plot_full_model:
            # Skip comparisons that involve the variant that is not plotted.
            fp_rows = fp_rows.loc[(fp_rows.model_type1 != 'Full model') & (fp_rows.model_type2 != 'Full model')]
        pairs = [
            ((row.model_type1, row.block_type1), (row.model_type2, row.block_type2))
            for row in fp_rows.itertuples(index=False)
        ]
        annotator = Annotator(target_ax, pairs, data=subject_df, x='model_type', y='fpr', hue='block_type', order=x_order)
        annotator.configure(
            test=None, alpha=sig_threshold, line_width=plot_params['linewidth'], test_short_name='test',
            pvalue_thresholds=pvalue_thresholds, hide_non_significant=True,
        )
        annotator.set_pvalues(pvalues=fp_rows.hb_pval.values)
        annotator.annotate()

    # --- Left panel: Bravo-1 ---
    b1_df = ax_bp.loc[ax_bp.subject == 'Bravo-1']
    left = draw_strip(axs[0], b1_df)
    sns.despine(ax=left, offset=despine_params)
    left.spines['bottom'].set_visible(False)
    left.xaxis.set_tick_params(length=0)
    left.axes.set(ylim=(0, 100), xlabel='Bravo-1')
    left.set_ylabel('False positive rate (%)', labelpad=lp)
    annotate_fp(left, 'Bravo-1', b1_df)

    handles, labels = left.get_legend_handles_labels()
    for handle in handles:
        handle.set_edgecolor('#404040')
        handle.set_sizes([7 * ms])
    left.legend(handles[:3], labels[:3], frameon=False, loc='upper left')

    # --- Right panel: Bravo-3 (shares the left panel's y axis visually) ---
    b3_df = ax_bp.loc[ax_bp.subject == 'Bravo-3']
    right = draw_strip(axs[1], b3_df)
    sns.despine(ax=right, offset=despine_params)
    right.spines['bottom'].set_visible(False)
    right.spines['left'].set_visible(False)
    right.xaxis.set_tick_params(length=0)
    right.yaxis.set_tick_params(length=0)
    right.axes.set(ylim=(0, 100), yticklabels=[], xlabel='Bravo-3', ylabel='')
    right.get_legend().remove()
    annotate_fp(right, 'Bravo-3', b3_df)

    return axs

### Detection saliences

In [None]:
def plot_cortex(vals, plot_subject=None, ax=None, closeup=False, es=25):
    """Draw per-electrode values on a participant's 2D lateral brain image.

    Electrode positions, the brain image and a precentral-gyrus mask are read from
    `data_dir` using the `plot_subject` prefix. Electrodes are sized/shaded from `vals`
    by `plot_images_and_elecs` with a 'Greys' colormap. An ROI image
    ('<plot_subject>_mprcg_roi.png') is overlaid at alpha 0.75.

    Parameters
    ----------
    vals : array-like
        One value per electrode (passed as `elec_weights`).
    plot_subject : str
        'bravo1' or 'bravo3'; any other value uses the Bravo-3 electrode scaling.
    ax : matplotlib Axes, optional
        Axes to draw into; a new 8x8 in figure is created if None.
    closeup : bool
        Zoom to `brain_lims[plot_subject]`.
    es : float
        Base electrode size scale (multiplied by 2.6 for non-'bravo1' participants).

    Returns
    -------
    The Axes drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    elec_scale = es if plot_subject == 'bravo1' else 2.6 * es

    all_plot_params = {
        'elec_loc_file_path': os.path.join(data_dir, f'{plot_subject}_lateral_elecmat_2D.npy'),
        'all_image_params': [
            {
                'file_name': os.path.join(data_dir, f'{plot_subject}_lateral_brain_2D.png'),
                'invert_y': True,
                'alpha': 0.35,
            },
            {
                'file_name': os.path.join(data_dir, f'{plot_subject}_precentral_gyrus_mask.npy'),
                'invert_y': True,
                'alpha': 0.2,
                'only_apply_alpha_to_nonzero': True,
                'mask_pixel_values': ('#00000000', '#1E547D'),
            },
        ],
        'y_scale': -1.0,
        'add_height': True,
        'elec_plot_params': {'linewidths': 0.0, 'zorder': 100000.0},
        # Colour bar disabled.
        'plot_colorbar': False,
        'colorbar_title': 'Color title',
        'show_fig': False,
        'elec_size_color_params': {
            'color_spec': 'Greys',
            'color_params': {'min': 0.0, 'max': 1.0, 'relative': False, 'scale': 1.0},
            'size_params': {'min': 5.0, 'max': 100.0, 'relative': True, 'shift': 0.0, 'scale': elec_scale},
            'alpha_params': {'min': 0.5, 'max': 1.0, 'relative': True, 'exponent': 1},
        },
        'elec_weights': vals,
    }
    plot_images_and_elecs(ax=ax, **all_plot_params)

    roi_image = plt.imread(os.path.join(data_dir, f'{plot_subject}_mprcg_roi.png'))
    ax.imshow(np.flipud(roi_image), alpha=0.75)

    if closeup:
        ax.axes.set(**brain_lims[plot_subject])

    return ax

def adjust_order(arr, with_string=False, fill_value=0, drop_idxs=(107, 112, 117)):
    """Remove fixed entries from `arr`, shift the rest forward, and pad the tail.

    The entries at `drop_idxs` (default 107, 112 and 117) are removed. The remaining
    entries are written to the front of `arr` in their original order. The freed slots
    at the end are filled with `fill_value`, or with the string 'dummy' when
    `with_string` is True. The array keeps its length.

    NOTE: `arr` is modified in place *and* returned. Pass a copy (as the figure code
    does with `np.copy`) if the input must be preserved.

    Parameters
    ----------
    arr : 1-D numpy array
        Values to reorder; must accept `fill_value` (or 'dummy') as an element.
    with_string : bool
        Fill freed slots with 'dummy' instead of `fill_value` (for string/object arrays).
    fill_value : scalar
        Value written to the freed slots when `with_string` is False.
    drop_idxs : iterable of int
        Positions to remove. Positions outside ``[0, len(arr))`` are ignored.

    Returns
    -------
    numpy array
        The same object as `arr`, reordered.
    """
    n = len(arr)
    drop = np.array([i for i in drop_idxs if 0 <= i < n], dtype=int)
    keep = np.ones(n, dtype=bool)
    keep[drop] = False
    # Boolean indexing returns a copy, so overwriting `arr` below does not affect `kept`.
    kept = arr[keep]

    arr[:] = 'dummy' if with_string else fill_value
    # Write back exactly as many values as were kept. The old hard-coded `arr[:250]`
    # only matched 253-element inputs and raised a shape error for other lengths.
    arr[:len(kept)] = kept
    return arr

### Classification accuracy

In [None]:
def class_acc_bp(ax=None, ms=7, lp=10):
    """Strip plot of 10-word classifier accuracy per participant, baseline vs. listening.

    Each dot is one row of `class_accs`. A '>' marker at a fixed x offset shows the
    pooled accuracy (total correct / total samples) of each participant x paradigm.
    Fisher-test results from `class_stats` are drawn as significance stars (uncorrected
    p-values; non-significant pairs hidden). A dotted line marks 10 % chance.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 5x5 in figure is created if None.
    ms : float
        Dot size (pooled markers use ``4 * ms``, legend markers ``7 * ms``).
    lp : float
        Padding of the y-axis label.

    Returns
    -------
    The Axes drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    # The hard-coded x positions of the pooled-accuracy markers below assume exactly
    # this participant order and this (dodged) paradigm order. Pin both explicitly
    # instead of relying on the order in which values appear in `class_accs`.
    subject_order = ['Bravo-1', 'Bravo-3']
    paradigm_order = ['Baseline', 'Listen']

    # NOTE(review): 'Listen' uses colors['paradigms'][2] here, whereas fpr_lineplot
    # colours listening blocks with colors['paradigms'][1]. Confirm which is intended.
    cur_colors = np.array(colors['paradigms'])[np.array([0, 2])]

    ax = sns.stripplot(data=class_accs, x='participant', y='acc', hue='paradigm',
                       order=subject_order, hue_order=paradigm_order,
                       palette=cur_colors, ax=ax, dodge=True, jitter=True, clip_on=False,
                       linewidth=1, size=ms)

    sns.despine(ax=ax, offset=despine_params)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)

    # Pooled accuracy per participant x paradigm, in the same order as the x positions.
    vals = []
    for subject in subject_order:
        for block_type in paradigm_order:
            cur_df = class_accs.loc[(class_accs.participant == subject) & (class_accs.paradigm == block_type)]
            vals.append(100 * (cur_df['corr'].values.sum() / cur_df['total_samps'].values.sum()))

    ax.scatter([-0.35, 0.15, 0.65, 1.15], vals, marker='>', s=ms*4,
               color=np.array(colors['paradigms'])[np.array([0, 2, 0, 2])])

    box_pairs = [
        ((row.subject, row.comparison1), (row.subject, row.comparison2))
        for row in class_stats.itertuples(index=False)
    ]
    annotator = Annotator(ax, box_pairs, data=class_accs, x='participant', y='acc', hue='paradigm',
                          order=subject_order, hue_order=paradigm_order, dodge=True)
    annotator.configure(test=None, alpha=sig_threshold, line_width=plot_params['linewidth'], test_short_name='test',
                        pvalue_thresholds=pvalue_thresholds, hide_non_significant=True)
    annotator.set_pvalues(pvalues=class_stats.pval.values)
    annotator.annotate()

    ax.axes.set(ylim=(0, 100), xlabel='')
    ax.set_ylabel('Classifier accuracy (%)', labelpad=lp)
    # Chance level for a 10-class problem.
    ax.axhline(y=100/10, color='k', linestyle=':')
    ax.annotate('Chance', (-0.35, 13), annotation_clip=False)

    # Legend proxy for the pooled-accuracy markers.
    triangle = Line2D([], [], markerfacecolor='white', markeredgecolor='k', marker='>', linestyle='None',
                      markersize=ms, alpha=0.7)

    h, l = ax.get_legend_handles_labels()
    for ha in h:
        ha.set_edgecolor('#404040')
        ha.set_sizes([7*ms])

    h.append(triangle)
    l.append('Overall accuracy')
    ax.legend(h, l, frameon=False, bbox_to_anchor=(-0.1, 0.57), loc='upper left')

    return ax

### Classifier saliences

In [None]:
def class_sal_scatter(ax=None, plot_subject=None, plot_xlabel=True, ms=30):
    """Scatter a participant's classifier electrode saliences: baseline (x) vs. listening (y).

    Annotates the correlation and p-value returned by `correlation_permutation`
    (presumably a permutation test). Draws the identity line on [0, 1] axes.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 5x5 in figure is created if None.
    plot_subject : {'bravo1', 'bravo3'}
        Selects the `class_sals` entries ('b1_*' or 'b3_*').
    plot_xlabel : bool
        Whether to label the x axis.
    ms : float
        Marker size.

    Returns
    -------
    The Axes drawn into.

    Raises
    ------
    ValueError
        If `plot_subject` is not 'bravo1' or 'bravo3'.
    """
    sal_keys = {
        'bravo1': ('b1_baseline', 'b1_listening'),
        'bravo3': ('b3_baseline', 'b3_listening'),
    }
    if plot_subject not in sal_keys:
        raise ValueError(f"plot_subject must be 'bravo1' or 'bravo3', got {plot_subject!r}")

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    baseline_key, listen_key = sal_keys[plot_subject]
    baseline_sal = class_sals[baseline_key]
    listen_sal = class_sals[listen_key]

    # Report the tightest threshold the p-value falls below. If it clears none of them
    # (p >= 0.01), show the value itself; indexing an empty match previously raised IndexError.
    corr, pval = correlation_permutation(baseline_sal, listen_sal)
    pthresholds = np.array([0.0001, 0.001, 0.01])
    passed = pthresholds[pthresholds > pval]
    p_text = f'P < {passed[0]}' if passed.size else f'P = {pval:.2g}'

    ax.scatter(baseline_sal, listen_sal, color='k', alpha=0.5, s=ms, clip_on=False)
    ax.annotate(f'r={corr:0.2f}\n{p_text}', (1, 0.05), va='bottom', ha='right')

    ax.axes.set(xlim=(0, 1), ylim=(0, 1), xticks=[0, 1], yticks=[0, 1])
    ax.set_aspect('equal')
    sns.despine(ax=ax, offset=despine_params)
    ax.plot([0, 1], [0, 1], color='k', linestyle='--', linewidth=1, alpha=0.7)

    if plot_xlabel:
        ax.axes.set(xlabel='Electrode contribution\nduring baseline')

    return ax

### Gate classifier threshold

In [None]:
def gate_curve(axs=None, ms=80, lp=10, fixed_threshold=0.65):
    """Plot the verification (gate) model's error rates against its speech-probability threshold.

    One panel per participant, from `gate_fpr_fnr['b1']` (Bravo-1) and
    `gate_fpr_fnr['b3']` (Bravo-3). The false-positive rate is drawn solid and the
    false-negative rate dotted, both in %. Two green triangles at 10 % mark thresholds:
    a filled one at `fixed_threshold`, and an open one at the largest threshold that
    attains the maximum of the stored `accs`.

    Parameters
    ----------
    axs : numpy array of two matplotlib Axes, optional
        Axes to draw into; a new 1x2 figure (8x3 in) is created if None.
    ms : float
        Size of the triangle markers.
    lp : float
        Y-label padding (the x label uses ``lp - 2``).
    fixed_threshold : float
        Threshold marked with the filled triangle (previously hard-coded to 0.65).

    Returns
    -------
    The Axes array.
    """
    if axs is None:
        _, axs = plt.subplots(1, 2, figsize=(8, 3))

    subj_labels = {'b1': 'Bravo-1', 'b3': 'Bravo-3'}
    for subj, ax in zip(['b1', 'b3'], axs.ravel()):
        curves = gate_fpr_fnr[subj]
        # asarray: the elementwise comparison and boolean indexing below need arrays.
        thresholds = np.asarray(curves['thresholds'])
        accs = np.asarray(curves['accs'])

        ax.plot(thresholds, 100*np.array(curves['fpr']), color='k', linestyle='-', linewidth=1, label='False pos.', clip_on=False)
        ax.plot(thresholds, 100*np.array(curves['fneg']), color='k', linestyle=':', linewidth=1, label='False neg.', clip_on=False)

        # Filled marker: the fixed threshold.
        t = fixed_threshold
        ax.scatter([t], [10], marker='v', s=ms, color='g', edgecolor='g')
        ax.plot([t, t], [0, 10], color='g', linestyle='--', linewidth=1, alpha=0.6)

        # Open marker: largest threshold reaching the best accuracy.
        t = thresholds[accs == np.max(accs)].max()
        ax.scatter([t], [10], marker='v', s=ms, edgecolor='g', facecolor='white', zorder=1)
        ax.plot([t, t], [0, 10], color='g', linestyle='--', linewidth=1, alpha=0.6, zorder=0)

        ax.axes.set(ylim=[0, 15], yticks=[0, 5, 10, 15], xticks=[0, 1], xlim=[0, 1])
        ax.set_ylabel('Error rate (%)', labelpad=lp)
        ax.set_xlabel('Speech probability\nthreshold', labelpad=lp-2)
        sns.despine(ax=ax, offset=despine_params)
        ax.annotate(subj_labels[subj], (0.5, 15), ha='center', va='top')

        if subj == 'b3':
            # Right panel: carries the legend and hides its y axis.
            ax.legend(frameon=False, bbox_to_anchor=(-0.85, 0.93), loc='upper left', handlelength=1, ncol=2)
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_tick_params(length=0)
            ax.axes.set(yticks=[], ylabel='')

    return axs

## Overall figure

In [None]:
##### Figure setup #####
# Publication settings: same rcParams as the setup cell, but smaller fonts/line widths.
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['xtick.major.pad'] = '1.25'
plt.rcParams['ytick.major.pad'] = '1.25'
plot_params = {'linewidth': 1}
lp = 0  # y-label padding passed to every panel function
default_plot_settings(font='Helvetica', fontsize=6, linewidth=1, ticklength=4)

all_axs = {}
fig = plt.figure(figsize=(7.25, 7.75))
# 100x100 grid: every panel below is positioned by grid rows/columns.
gs = mpl.gridspec.GridSpec(100, 100, figure=fig, hspace=15, wspace=3)

# Top 30 rows: empty, axis-less area reserved for the schematic.
all_axs['schematic'] = fig.add_subplot(gs[:30, :])
all_axs['schematic'].axis('off');

all_axs['fpr_ablation'] = np.array([fig.add_subplot(gs[30:51, :18]), fig.add_subplot(gs[30:51, 18:35])])
all_axs['rt_tpr'] = fig.add_subplot(gs[30:51, 42:57])

# Detection-salience brains; the two axes overlap on purpose (columns 72-77).
all_axs['det_sal_b1'] = fig.add_subplot(gs[30:45, 58:77])
all_axs['det_sal_b3'] = fig.add_subplot(gs[34:53, 72:])

all_axs['fp_count'] = fig.add_subplot(gs[55:74, :34])

all_axs['class_acc'] = fig.add_subplot(gs[59:78, 42:59])

all_axs['class_sal_b1'] = fig.add_subplot(gs[57:70, 59:78])
all_axs['class_sal_b3'] = fig.add_subplot(gs[70:83, 59:78])

all_axs['class_sal_b1_scatter'] = fig.add_subplot(gs[58:69, 78:])
all_axs['class_sal_b3_scatter'] = fig.add_subplot(gs[72:83, 78:])
# Invisible axes spanning both scatter plots, used only to place their shared y label.
all_axs['class_sal_scatter_label'] = fig.add_subplot(gs[58:84, 78:])


##########

all_axs['fpr_ablation'] = fpr_ablation(axs=all_axs['fpr_ablation'], ms=4, lp=lp)
all_axs['rt_tpr'] = rt_tpr(ax=all_axs['rt_tpr'], ms=4, lp=lp)

all_axs['det_sal_b1'] = plot_cortex(det_sals['bravo1'], plot_subject='bravo1', ax=all_axs['det_sal_b1'], es=1)
# Bravo-3 values are reordered by adjust_order first (on a copy); its padded
# slots get the minimum salience value.
all_axs['det_sal_b3'] = plot_cortex(adjust_order(np.copy(det_sals['bravo3']), fill_value=min(det_sals['bravo3'])), 
                                    plot_subject='bravo3', ax=all_axs['det_sal_b3'], es=1)

# Left blank in this notebook (no plotting function is called for it here).
all_axs['fp_count'].axis('off')

all_axs['class_acc'] = class_acc_bp(ax=all_axs['class_acc'], ms=4, lp=lp)

all_axs['class_sal_b1'] = plot_cortex(class_sals['b1_overall'], plot_subject='bravo1', ax=all_axs['class_sal_b1'], closeup=True, es=1)
all_axs['class_sal_b3'] = plot_cortex(class_sals['b3_overall'], plot_subject='bravo3', ax=all_axs['class_sal_b3'], closeup=True, es=1)
all_axs['class_sal_b1'].annotate('Bravo-1', (-0.1, 0.5), rotation=90, ha='center', va='center', annotation_clip=False, xycoords='axes fraction')
all_axs['class_sal_b3'].annotate('Bravo-3', (-0.1, 0.5), rotation=90, ha='center', va='center', annotation_clip=False, xycoords='axes fraction')

all_axs['class_sal_b1_scatter'] = class_sal_scatter(plot_subject='bravo1', ax=all_axs['class_sal_b1_scatter'], plot_xlabel=False, ms=1);
all_axs['class_sal_b3_scatter'] = class_sal_scatter(plot_subject='bravo3', ax=all_axs['class_sal_b3_scatter'], ms=1);

all_axs['class_sal_scatter_label'].axis('off');
all_axs['class_sal_scatter_label'].annotate('Electrode contribution\nwith listening distractor', (-0.01, 0.5), rotation=90, ha='center', va='center', annotation_clip=False, xycoords='axes fraction')


#### Panel labels
# Small invisible axes used only as anchors for the bold panel letters.
panel_axs = {}
panel_axs['D'] = fig.add_subplot(gs[32:37, :5])
panel_axs['E'] = fig.add_subplot(gs[32:37, 41:46])
panel_axs['F'] = fig.add_subplot(gs[32:37, 63:68])
panel_axs['G'] = fig.add_subplot(gs[61:65, :5])
panel_axs['H'] = fig.add_subplot(gs[61:65, 40:45])
panel_axs['I'] = fig.add_subplot(gs[61:65, 65:70])
panel_axs['J'] = fig.add_subplot(gs[61:65, 82:87])

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (-18, 30), xycoords='axes points', ha='right', fontsize=9, weight='bold');

## Supplementary

### Fig. S6: Detection vs gate saliences

In [None]:
# Verification (gate) model electrode saliences, keyed by participant ('b1', 'b3')
with open(os.path.join(data_dir, 'gate_sals.pkl'), 'rb') as f:
    gate_sals = pickle.load(f)

## for verification-model saliences, apply log and min-max normalization to [0, 1]
for key in gate_sals.keys():
    new_sals = np.log(gate_sals[key])
    scaler = MinMaxScaler()
    gate_sals[key] = scaler.fit_transform(new_sals.reshape((-1, 1))).squeeze()

In [None]:
def det_gate_scatter(ax=None, plot_subject=None, plot_ylabel=False, s=30):
    """Scatter a participant's detection-model (x) vs. verification-model (y) electrode saliences.

    Annotates the correlation and p-value returned by `correlation_permutation`
    (presumably a permutation test). Draws the identity line on [0, 1] axes, titled
    with the participant's display label.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into; a new 5x5 in figure is created if None.
    plot_subject : {'bravo1', 'bravo3'}
        Key into `det_sals`; mapped to 'b1'/'b3' for `gate_sals`.
    plot_ylabel : bool
        Whether to label the y axis.
    s : float
        Marker size.

    Returns
    -------
    The Axes drawn into.

    Raises
    ------
    ValueError
        If `plot_subject` is not 'bravo1' or 'bravo3'.
    """
    gate_keys = {'bravo1': 'b1', 'bravo3': 'b3'}
    if plot_subject not in gate_keys:
        raise ValueError(f"plot_subject must be 'bravo1' or 'bravo3', got {plot_subject!r}")

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 5))

    det_sal = det_sals[plot_subject]
    gate_sal = gate_sals[gate_keys[plot_subject]]

    # Report the tightest threshold the p-value falls below. If it clears none of them
    # (p >= 0.01), show the value itself; indexing an empty match previously raised IndexError.
    corr, pval = correlation_permutation(det_sal, gate_sal)
    pthresholds = np.array([0.0001, 0.001, 0.01])
    passed = pthresholds[pthresholds > pval]
    p_text = f'P < {passed[0]}' if passed.size else f'P = {pval:.2g}'

    ax.scatter(det_sal, gate_sal, color='k', alpha=0.5, s=s, clip_on=False)
    ax.annotate(f'r={corr:0.2f}\n{p_text}', (0.75, 0.1), va='top')

    ax.axes.set(xlim=(0, 1), ylim=(0, 1), xticks=[0, 1], yticks=[0, 1], title=subject_labels[plot_subject],
                xlabel='Detection model\nelectrode contribution')
    ax.set_aspect('equal')
    sns.despine(ax=ax, offset=despine_params)
    ax.plot([0, 1], [0, 1], color='k', linestyle='--', linewidth=1, alpha=0.7)

    if plot_ylabel:
        ax.axes.set(ylabel='Verification model\nelectrode contribution')

    return ax

In [None]:
# Fig. S6: detection vs. verification electrode contributions, one panel per participant.
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

for panel_idx, (panel_ax, panel_subject) in enumerate(zip(axs, ['bravo1', 'bravo3'])):
    # Only the left panel carries the y label.
    det_gate_scatter(plot_subject=panel_subject, ax=panel_ax, plot_ylabel=(panel_idx == 0))

### Fig. S3: System ablations

In [None]:
def tpr_ablation(axs=None, plot_full_model=False):
    """True-positive rate per model variant and paradigm, one panel per participant.

    Left panel: Bravo-1 (with y label, legend hidden). Right panel: Bravo-3 (y axis
    hidden). Holm-corrected TP comparisons from `bp_stats` are drawn as significance
    stars; non-significant pairs are hidden.

    Parameters
    ----------
    axs : sequence of two matplotlib Axes, optional
        Axes to draw into; a new 1x2 figure (8x5 in) is created if None.
    plot_full_model : bool
        Also show the 'Full model' (no verification) variant. Only used for the
        supplementary version of this panel.

    Returns
    -------
    The two Axes.
    """
    if not plot_full_model:
        ax_bp = bp.loc[bp.model_type != 'Full model']
        x_order = ['Speech-only\nmodel', 'Full model\n+verif.']
    else:
        x_order = ['Speech-only\nmodel', 'Full model',  'Full model\n+verif.']
        ax_bp = bp.copy(deep=True)

    if axs is None:
        _, axs = plt.subplots(1, 2, figsize=(8, 5))

    def draw_strip(target_ax, subject_df):
        # Dodged strip plot of TPR by model variant, coloured by paradigm.
        return sns.stripplot(
            data=subject_df, x='model_type', y='tpr', hue='block_type', order=x_order,
            palette=colors['paradigms'], ax=target_ax, dodge=True, jitter=True,
            clip_on=False, linewidth=1, size=7,
        )

    def annotate_tp(target_ax, subject, subject_df):
        # Significance stars for this participant's TP comparisons.
        tp_rows = bp_stats.loc[(bp_stats.subject == subject) & (bp_stats.comp_type == 'tp')]
        if not plot_full_model:
            # Skip comparisons that involve the variant that is not plotted.
            tp_rows = tp_rows.loc[(tp_rows.model_type1 != 'Full model') & (tp_rows.model_type2 != 'Full model')]
        pairs = [
            ((row.model_type1, row.block_type1), (row.model_type2, row.block_type2))
            for row in tp_rows.itertuples(index=False)
        ]
        annotator = Annotator(target_ax, pairs, data=subject_df, x='model_type', y='tpr', hue='block_type', order=x_order)
        annotator.configure(test=None, alpha=sig_threshold, test_short_name='test',
                            pvalue_thresholds=pvalue_thresholds, hide_non_significant=True)
        annotator.set_pvalues(pvalues=tp_rows.hb_pval.values)
        annotator.annotate()

    # --- Left panel: Bravo-1 ---
    b1_df = ax_bp.loc[ax_bp.subject == 'Bravo-1']
    left = draw_strip(axs[0], b1_df)
    sns.despine(ax=left, offset=despine_params)
    left.spines['bottom'].set_visible(False)
    left.xaxis.set_tick_params(length=0)
    annotate_tp(left, 'Bravo-1', b1_df)
    left.axes.set(ylim=(0, 100), ylabel='True positive rate (%)', xlabel='Bravo-1')
    left.legend([], [], frameon=False)

    # --- Right panel: Bravo-3 ---
    b3_df = ax_bp.loc[ax_bp.subject == 'Bravo-3']
    right = draw_strip(axs[1], b3_df)
    sns.despine(ax=right, offset=despine_params)
    right.spines['bottom'].set_visible(False)
    right.spines['left'].set_visible(False)
    right.xaxis.set_tick_params(length=0)
    right.yaxis.set_tick_params(length=0)
    right.get_legend().remove()
    annotate_tp(right, 'Bravo-3', b3_df)
    right.axes.set(ylim=(0, 100), yticklabels=[], yticks=[0, 20], xlabel='Bravo-3', ylabel='')

    return axs

In [None]:
def fpr_lineplot(axs=None):
    
    if axs is None:
        fig, axs = plt.subplots(2, 1, figsize=(5, 6), gridspec_kw={'height_ratios': [1, 4]})

    ax = axs[0]
    ax = sns.pointplot(data=fpc.loc[(fpc.model_type != 'Speech-only\nmodel') & (fpc.block_type == 'Baseline')], x='model_type', ax=ax,
                 y='false_positive_count_gauss', hue='subject_block_type_task_type', palette=2*[colors['paradigms'][0]], linestyles=[':', '-'])
    ax.axes.set(yticks=[0, 4, 8], ylim=[0, 8], ylabel='', xlabel='', xticks=[])
    plt.setp(ax.collections, clip_on=False)
    plt.setp(ax.lines, clip_on=False);
    ax.get_legend().set_visible(False)
    sns.despine(ax=ax, offset=despine_params)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)


    ax = axs[1]
    cur_fpc = fpc.loc[(fpc.model_type != 'Speech-only\nmodel') & (fpc.block_type != 'Baseline')]
    sns.pointplot(data=cur_fpc.loc[cur_fpc.subject == 'Bravo-1'], x='model_type', ax=ax,
                 y='false_positive_count_gauss', hue='block_type_task_type', hue_order=['Listen_regular', 'Read_regular', 'Listen_long', 'Read_long', 'Memory_long'],
                  palette=colors['paradigms'][1:]+colors['long_paradigms'])
    sns.pointplot(data=cur_fpc.loc[cur_fpc.subject == 'Bravo-3'], x='model_type', ax=ax,
                 y='false_positive_count_gauss', hue='block_type_task_type', hue_order=['Listen_regular', 'Read_regular', 'Listen_long', 'Read_long', 'Memory_long'],
                  palette=colors['paradigms'][1:]+colors['long_paradigms'], linestyles=':')
    ax.axes.set(yticks=[0, 2, 4, 6, 8], ylim=[0, 8], ylabel='', xlabel='', xlim=[-0.5, 1.5])
    plt.setp(ax.collections, clip_on=False)
    plt.setp(ax.lines, clip_on=False);
    ax.get_legend().set_visible(False)
    sns.despine(ax=ax, offset=despine_params)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)

    box_pairs = [('Full model', 'Full model\n+verif.')]
    a = fpc.loc[(fpc.model_type == 'Full model') & (fpc.block_type != 'Baseline')].false_positive_count.values
    b = fpc.loc[(fpc.model_type == 'Full model\n+verif.') & (fpc.block_type != 'Baseline')].false_positive_count.values
    _, pval = stats.wilcoxon(a, b, alternative='greater')
    print(pval)

    ax.plot([0, 1], [7.8, 7.8], linewidth=1.5, color='k', clip_on=False)
    ax.plot([0, 0], [7.7, 7.8], linewidth=1.5, color='k', clip_on=False)
    ax.plot([1, 1], [7.7, 7.8], linewidth=1.5, color='k', clip_on=False)
    
    
    if pval < 0.0001:
        m = '****'
    elif pval < 0.001:
        m = '***'
    elif pval < 0.01:
        m = '**'
    elif pval < 0.05:
        m = '*'
    
    ax.annotate(m, (0.5, 7.9), fontsize=14, annotation_clip=False)

    ax.set_ylabel('Number of false positives', y=1)


    # Custom legend
    handles = []
    handles.append(Line2D([0], [0], label='Baseline', marker='o', markersize=9, 
                   markeredgecolor=colors['paradigms'][0], markerfacecolor=colors['paradigms'][0], linestyle=''))
    handles.append(Line2D([0], [0], label='Listen', marker='o', markersize=9, 
                   markeredgecolor=colors['paradigms'][1], markerfacecolor=colors['paradigms'][1], linestyle=''))
    handles.append(Line2D([0], [0], label='Listen, long', marker='o', markersize=9, 
                   markeredgecolor=colors['long_paradigms'][0], markerfacecolor=colors['long_paradigms'][0], linestyle=''))
    handles.append(Line2D([0], [0], label='Read', marker='o', markersize=9, 
                   markeredgecolor=colors['paradigms'][2], markerfacecolor=colors['paradigms'][2], linestyle=''))
    handles.append(Line2D([0], [0], label='Read, long', marker='o', markersize=9, 
                   markeredgecolor=colors['long_paradigms'][1], markerfacecolor=colors['long_paradigms'][1], linestyle=''))
    handles.append(Line2D([0], [0], label='Memory, long', marker='o', markersize=9, 
                   markeredgecolor=colors['long_paradigms'][2], markerfacecolor=colors['long_paradigms'][2], linestyle=''))
    handles.append(Line2D([0], [0], label='Bravo-1', color='k', linestyle='-', linewidth=1.5))
    handles.append(Line2D([0], [0], label='Bravo-3', color='k', linestyle=':', linewidth=1.5))
    ax.legend(handles=handles, frameon=False, bbox_to_anchor=(0.85, 1.25), loc='upper left');
    
    return axs

In [None]:
def rt_fpc(ax=None, ytarget=None, return_fig=False, plot_legend=True):
    """Bar + strip plot of a per-block metric from ``fpc`` for long-task blocks.

    Only rows with ``task_type == 'long'`` are plotted. x is
    ``subject_block_type`` and hue is ``model_type``. Each x tick label is
    reduced to the text between its first and second underscore. When
    ``ytarget == 'false_positive_sum'``, the rows of the full model with
    verification are dropped and only the first two model colours are used.

    Relies on the notebook globals ``fpc``, ``colors`` and ``despine_params``.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; a new 7x5 figure is created when omitted.
    ytarget : str
        Column of ``fpc`` to plot. 'false_positive_count' gets a 0-8 y-range
        and a false-positive label. Any other value gets a 0-2500 y-range and
        the P(speech)>0.5 timepoint label.
    return_fig : bool
        If True, return ``(fig, ax)`` instead of ``ax``.
    plot_legend : bool
        If True, show a legend of the strip-plot markers; otherwise hide it.

    Returns
    -------
    ``ax``, or ``(fig, ax)`` when ``return_fig`` is True.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 5))
    else:
        # Bind fig here too, so return_fig=True works with a caller-supplied ax.
        fig = ax.figure

    if ytarget == 'false_positive_sum':
        cur_fpc = fpc.loc[fpc.model_type != 'Full model\n+verif.']
        model_colors = colors['models'][:2]
    else:
        cur_fpc = fpc.copy(deep=True)
        model_colors = colors['models']

    cur_fpc = cur_fpc.loc[cur_fpc.task_type == 'long']

    # Use the palette trimmed to the models actually plotted (it was computed
    # but previously ignored in favour of the full palette).
    ax = sns.barplot(data=cur_fpc, x='subject_block_type', y=ytarget, hue='model_type',
                     palette=model_colors, ax=ax, alpha=0.9)
    ax = sns.stripplot(data=cur_fpc, x='subject_block_type', y=ytarget, hue='model_type',
                       palette=model_colors, ax=ax, dodge=True, jitter=False, clip_on=False,
                       linewidth=1, size=7)
    sns.despine(ax=ax, offset=despine_params)
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)

    xlabs = [tick.get_text().split('_')[1] for tick in ax.get_xticklabels()]
    ax.axes.set(xticklabels=xlabs)

    if plot_legend:
        # barplot and stripplot each contribute one handle per model. Keep the
        # second half (the strip markers) instead of a hard-coded offset tied
        # to the number of models.
        h, l = ax.get_legend_handles_labels()
        half = len(h) // 2
        ax.legend(h[half:], l[half:], frameon=False, loc='upper left', handlelength=1.5)
    else:
        ax.legend([], [], frameon=False)

    if ytarget == 'false_positive_count':
        ax.axes.set(ylim=(0, 8), ylabel='Number of false positives', xlabel='')
    else:
        ax.axes.set(ylim=(0, 2500), ylabel='Number of timepoints\nwhere P(speech)>0.5', xlabel='')

    # NOTE(review): the subject label x positions (1 and 3.5, in data units)
    # are hard-coded and assume a fixed number of blocks per subject — confirm.
    ax.annotate('Bravo-1', (1, -0.175), xycoords=('data', 'axes fraction'), ha='center', annotation_clip=False)
    ax.annotate('Bravo-3', (3.5, -0.175), xycoords=('data', 'axes fraction'), ha='center', annotation_clip=False)

    if return_fig:
        return fig, ax
    return ax

In [None]:
# ---- Figure layout: a 100x100 GridSpec, one entry per plotted panel ----
default_plot_settings(font='Helvetica', fontsize=11, linewidth=1.5, ticklength=4)

fig = plt.figure(figsize=(15, 9))
gs = mpl.gridspec.GridSpec(100, 100, figure=fig, hspace=15, wspace=2)

# Subplots are added in a fixed order (ablations on top, false-positive
# panels below); the dict preserves that insertion order.
all_axs = {
    'fpr_ablation': np.array([fig.add_subplot(gs[:37, :22]),
                              fig.add_subplot(gs[:37, 24:45])]),
    'tpr_ablation': np.array([fig.add_subplot(gs[:37, 54:76]),
                              fig.add_subplot(gs[:37, 78:])]),
    'fp_count': fig.add_subplot(gs[50:, :30]),
    'fp_sum': fig.add_subplot(gs[50:, 37:66]),
    'fp_trend': np.array([fig.add_subplot(gs[50:62, 71:89]),
                          fig.add_subplot(gs[66:, 71:89])]),
}

# ---- Draw each panel; every plotting helper returns the axes it drew on ----
all_axs['tpr_ablation'] = tpr_ablation(axs=all_axs['tpr_ablation'], plot_full_model=True)
all_axs['fpr_ablation'] = fpr_ablation(axs=all_axs['fpr_ablation'], plot_full_model=True)
all_axs['fp_count'] = rt_fpc(ax=all_axs['fp_count'], ytarget='false_positive_count')
all_axs['fp_sum'] = rt_fpc(ax=all_axs['fp_sum'], ytarget='false_positive_sum')
all_axs['fp_trend'] = fpr_lineplot(axs=all_axs['fp_trend'])

# ---- Panel letters: tiny invisible axes used only as text anchors ----
panel_axs = {
    'A': fig.add_subplot(gs[8:13, :5]),
    'B': fig.add_subplot(gs[8:13, 53:58]),
    'C': fig.add_subplot(gs[57:62, :5]),
    'D': fig.add_subplot(gs[57:62, 35:40]),
    'E': fig.add_subplot(gs[57:62, 70:75]),
}

for panel, ax in panel_axs.items():
    ax.axis('off')
    ax.annotate(panel, (-35, 65), xycoords='axes points', ha='right', fontsize=18, weight='bold')