# Figure 5

In [None]:
import os
import json
import pickle
from datetime import datetime

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation
from statsmodels.stats.multitest import fdrcorrection

from sylseq_paper.plotting import default_plot_settings, ucsf_sequential_color_palette as colors, smoothed_weighted_histogram
from sylseq_paper.stimulation_info import stim_pairs

default_plot_settings(font='Helvetica', fontsize=11, linewidth=1)

%load_ext autoreload
%autoreload 2

# Path setup and load data

In [None]:
# Data and imaging folders are siblings of the directory the notebook is run from.
data_dir, img_dir = (
    os.path.abspath(os.path.join(os.getcwd(), '..', folder))
    for folder in ('data', 'imaging')
)

# Filename templates; the '{}' placeholders are filled with str.format where the images are read.
mni_img_path = os.path.join(img_dir, 'MNI_{}_{}_brain_2D.png')
subj_img_path = os.path.join(img_dir, '{}_lateral_brain_2D.png')

# Alpha level handed to the FDR correction in the statistics panels.
sig_thresh = 0.05

# [upper p-value bound, annotation text] pairs handed to the significance annotations.
pvalue_thresholds = [
    [bound, symbol]
    for bound, symbol in ((0.001, "***"), (0.01, "**"), (0.05, "*"), (10, "ns"))
]

In [None]:
# Scatter marker per subject, plus a per-subject factor that scales the marker area
# in the stimulation-site plot.
subject_markers = {
    subject_id: {'marker': marker, 'proportion': proportion}
    for subject_id, marker, proportion in (
        ('EC260', '^', 0.5),
        ('EC267', 'o', 0.5),
        ('EC276', 'D', 0.4),
        ('EC282', '*', 1.0),
        ('EC289', 'v', 0.5),
    )
}

# Line width of the significance annotation brackets.
annot_lw = 0.5

## Load data

In [None]:
# Tick label -> task attributes. The key order is the left-to-right category order
# of the error charts.
# NOTE(review): the attribute values look like column filters (stored as strings,
# including 'nan'); they are not read anywhere in the visible cells -- confirm.
task_order = {
    'Vocalization': {'task': 'vocalization', 'sequence_complexity': 'nan', 'real_word': 'nan'},
    'Orofacial\nmotor': {'task': 'facial motor', 'sequence_complexity': 'nan', 'real_word': 'nan'},
    'Simple\nsyl.\nseq.': {'sequence_complexity': 'simple', 'real_word': 'False'},
    'Isolated\nsyl. of\ncomplex\nseq.': {'sequence_complexity': 'isolated complex'},
    'Complex\nsyl.\nseq.': {'sequence_complexity': 'complex', 'real_word': 'False'},
    '4-syl.\nreal\nwords': {'sequence_complexity': 'complex', 'real_word': 'True'},
}

# NOTE(review): not referenced by any visible cell -- presumably annotation labels to ignore.
exclude_labels = [
    'patientbreath',
    'patient_breath',
    'breath',
    'breathing',
    'yeah this',
]

# Colors for the two hue levels, in hue order: 'No stim.' first, then 'Stim.'.
stim_palette = [colors[idx] for idx in (0, 1)]

In [None]:
# Subjects included in every panel, in plotting order.
plot_subjects = ['EC260', 'EC267', 'EC276', 'EC282', 'EC289']

# Error counts table (read by error_chart).
error_counts = pd.read_hdf(os.path.join(data_dir, 'fig5_stimulation_error_counts.h5'))

# Syllable duration annotations.
total_patient_annot = pd.read_hdf(os.path.join(data_dir, 'fig5_syllable_duration_annot.h5'))

# Syllable segmentation (inter-syllable duration) annotations.
total_intersyl_annot = pd.read_hdf(os.path.join(data_dir, 'fig5_syllable_segmentation_annot.h5'))

# Overlap density inputs for the stimulation-site background.
with open(os.path.join(data_dir, 'rt_seq_overlap_density_dict.pkl'), 'rb') as f:
    overlap_dict = pickle.load(f)

# Electrode anatomy, restricted to the plotted subjects.
total_edf = pd.read_hdf(os.path.join(data_dir, 'all_anatomical_info.h5'))
total_edf = total_edf.loc[total_edf['subject'].isin(plot_subjects)]

# Every plotted subject must have electrode information.
assert set(plot_subjects).issubset(total_edf['subject'].unique())

# Figure panels

## Recon with stim sites
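Plot stimulation sites on the MNI brain reconstruction, marking sites where stimulation produced sequencing errors, other sensorimotor effects, or no deficits.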

In [None]:
def _electrode_warp_xy(edf, electrode):
    """Return ``(warp_x, warp_y)`` from the first row of ``edf`` whose ``electrode`` column equals ``electrode``."""
    row = edf.loc[edf.electrode == electrode].iloc[0]
    return row.warp_x, row.warp_y


def stim_recon(ax=None, markersize=200, label_fontsize=9, color_other_errors=False, return_fig=False, plot_overlap_density=False):
    """Plot the stimulation sites of all ``plot_subjects`` on the left lateral MNI brain image.

    Each entry of ``stim_pairs[subject]`` is drawn with the subject's marker and classified as:

    * sequencing site (``seq_deficit`` set, neither ``inconsistent_seq_deficit`` nor
      ``motor_deficit``): filled with ``colors[1]``;
    * no sequencing effect, but a motor deficit at the current or a higher amplitude, or an
      ``other_deficit`` entry: grey (or ``colors[4]`` / ``colors[3]`` when ``color_other_errors``);
    * no sequencing, motor or other deficit: open marker.

    Any other combination is printed as ``subject clinical_elec_key`` and not drawn.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into. A new 6x6 inch figure is created when omitted.
    markersize : int
        Base scatter size; scaled by each subject's ``proportion`` in ``subject_markers``.
    label_fontsize : float
        Font size of both legends.
    color_other_errors : bool
        Color motor and other deficits separately instead of using grey for both.
    return_fig : bool
        Return ``(fig, ax)`` instead of ``ax``.
    plot_overlap_density : bool
        Draw the smoothed density computed from ``overlap_dict`` underneath the markers.

    Returns
    -------
    ax, or (fig, ax) when ``return_fig`` is true.
    """
    img = plt.imread(mni_img_path.format('lh', 'lateral'))

    if plot_overlap_density:

        weight_list = [overlap_dict['in_seq'], overlap_dict['in_rt']]

        dens = []
        for w in weight_list:
            overlap_density, xedges, yedges = smoothed_weighted_histogram(x=overlap_dict['cur_df'].x.values,
                                                                          y=overlap_dict['cur_df'].y.values,
                                                                          weights=w,
                                                                          xlim=[0, img.shape[1]],
                                                                          ylim=[0, img.shape[0]],
                                                                          bins=100,
                                                                          smooth=2,
                                                                          baseline_norm=True)
            # Normalize each map to unit sum so both weightings contribute equally to the total.
            overlap_density /= overlap_density.sum()
            dens.append(overlap_density)
        dens = np.stack(dens, axis=-1)
        dens = dens.sum(-1)
        dens /= dens.max()
        overlap_density = np.copy(dens)

        cmap = mpl.colors.LinearSegmentedColormap.from_list('cmap', ['white', colors[2]])

    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    else:
        # Bind fig as well, so return_fig=True works when the caller supplies the axes
        # (previously this path raised NameError).
        fig = ax.figure

    ax.imshow(img, alpha=0.5, zorder=0)
    ax.axis('off')

    if plot_overlap_density:
        # One alpha value per percent of density: transparent, then ramping up, then capped at 0.7.
        alphas = np.concatenate([np.zeros(15), np.linspace(0, 0.7, 55), 0.7*np.ones(30)])
        dx = 0.05
        # Draw the density in bands of width dx so that each band can get its own alpha.
        for i in np.arange(0, 1, dx):
            idx = np.where(~np.logical_and(overlap_density > i, overlap_density <= i + dx))
            temp = np.copy(overlap_density)
            temp[idx] = np.nan
            ax.pcolormesh(xedges, yedges, temp.T, alpha=alphas[int(100*i)], vmin=0, vmax=1, cmap=cmap, zorder=0, rasterized=True)

    marker_alpha = 0.5
    seq_alpha = 0.8

    for subject in plot_subjects:

        edf = total_edf.loc[total_edf.subject == subject]
        marker = subject_markers[subject]['marker']
        p = subject_markers[subject]['proportion']

        # Number of sequencing sites drawn so far for this subject; only the first is labelled.
        i = 0

        for clinical_elec_key in stim_pairs[subject].keys():

            cur_err_dict = stim_pairs[subject][clinical_elec_key]
            is_seq_site = False

            if cur_err_dict['seq_deficit'] and not cur_err_dict['inconsistent_seq_deficit'] and not cur_err_dict['motor_deficit']:
                # Sequencing elecs
                is_seq_site = True
                kwargs = dict(facecolor=colors[1], edgecolor='k', s=int(p*markersize), marker=marker, zorder=3, alpha=seq_alpha)
                i += 1

            elif not cur_err_dict['seq_deficit'] and not cur_err_dict['inconsistent_seq_deficit']:
                # No sequence effect
                grey_kwargs = dict(facecolor='grey', edgecolor='k', s=int(p*markersize//2), marker=marker, alpha=marker_alpha, zorder=2)
                motor_effect = cur_err_dict['motor_deficit'] or (
                    'higher_amp_motor_deficit' in cur_err_dict and cur_err_dict['higher_amp_motor_deficit'] == True)

                if motor_effect:
                    # Elecs with no sequence effect but motor at current or higher amplitudes
                    if color_other_errors:
                        kwargs = dict(facecolor=colors[4], edgecolor='k', s=markersize//2, marker=marker, alpha=0.65, zorder=2)
                    else:
                        kwargs = grey_kwargs

                # From here on motor_deficit is known to be falsy, so no further motor check is
                # needed (the original second condition was always true and its else unreachable).
                elif 'other_deficit' in cur_err_dict:
                    # Elecs with no motor deficits but other deficits (perceptual or sensory)
                    if color_other_errors:
                        kwargs = dict(facecolor=colors[3], edgecolor='k', s=markersize//2, marker=marker, alpha=0.5, zorder=2)
                    else:
                        kwargs = grey_kwargs

                else:
                    # Elecs with no motor deficits and no other deficits
                    kwargs = dict(facecolor='None', edgecolor='k', s=int(p*markersize//2), marker=marker, zorder=2, alpha=marker_alpha)

            else:
                # Inconsistent sequencing deficit, or sequencing combined with a motor deficit: report and skip.
                print(subject, clinical_elec_key)
                continue

            # NOTE(review): electrode numbers in stim_pairs are shifted by -1 before the lookup,
            # presumably 1-based numbering against a 0-based `electrode` column -- confirm.
            if cur_err_dict['center_research_elec'] == 'interpolate':
                # Place the marker midway between the two stimulated research electrodes.
                x1, y1 = _electrode_warp_xy(edf, cur_err_dict['research_elecs'][0] - 1)
                x2, y2 = _electrode_warp_xy(edf, cur_err_dict['research_elecs'][1] - 1)
                x = np.mean([x1, x2])
                y = np.mean([y1, y2])
            else:
                x, y = _electrode_warp_xy(edf, cur_err_dict['center_research_elec'] - 1)

            # The first sequencing site of each subject carries the subject's legend entry.
            label = subject if (is_seq_site and i == 1) else None

            ax.scatter(x, y, label=label, **kwargs)

    if color_other_errors:
        add_patches = [mpl.patches.Patch(color=colors[1]), mpl.patches.Patch(color=colors[3]), mpl.patches.Patch(color=colors[4]), mpl.patches.Patch(facecolor='None', edgecolor='k')]
        add_labels = ['Sequencing errors', 'Sensory effects', 'Direct motor effects', 'No deficits']
    else:
        add_patches = [mpl.patches.Patch(color=colors[1]), mpl.patches.Patch(color='grey'), mpl.patches.Patch(facecolor='None', edgecolor='k')]
        add_labels = ['Sequencing errors', 'Other sensorimotor effects', 'No deficits', ]

    handles, labels = ax.get_legend_handles_labels()

    # Subject legend first; add_artist keeps it when the second legend replaces ax's legend.
    subj_legend = ax.legend(handles, labels, frameon=False, ncol=3, loc='lower left', fontsize=label_fontsize,
                            bbox_to_anchor=(-0.05, -0.25), columnspacing=2.5)
    ax.add_artist(subj_legend)

    ax.legend(add_patches, add_labels, frameon=False, ncol=2, loc='lower left', fontsize=label_fontsize,
              bbox_to_anchor=(-0.05, -0.45), columnspacing=0.75)

    if return_fig:
        return fig, ax
    else:
        return ax

In [None]:
# Preview the stimulation-site panel on its own figure, with the overlap density drawn underneath.
f, _ = stim_recon(plot_overlap_density=True, return_fig=True);

## Error chart

In [None]:
def error_chart(ax=None, error_type=None, metric='count', plot_legend=True, plot_xticklabels=True, plot_title=False, 
                plot_ylabel=True, label_fontsize=None, include_only_complex=False, legend_inside=True, yticks=None, markersize=5):
    """Strip plot of one error type per task, split by stimulation condition.

    Rows of ``error_counts`` with the requested ``error_type`` are plotted with one
    ``sns.stripplot`` call per subject so that each subject keeps its own marker.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw into. A new 6x3 inch figure is created when omitted.
    error_type : str
        Value of ``error_counts.error_type`` to plot (``'all'`` gets a shorter y label).
    metric : {'count', 'percent'}
        Column of ``error_counts`` shown on the y axis. ``'percent'`` fixes the y range to 0-100.
    plot_legend : bool
        Show the stimulation legend; otherwise an empty legend is installed.
    plot_xticklabels : bool
        Show the task labels on the x axis.
    plot_title : bool
        With ``metric='percent'``, use the error type as the title and a generic y label.
    plot_ylabel : bool
        Show the y label.
    label_fontsize : float, optional
        Font size for title, legend and x tick labels (library defaults when None).
    include_only_complex : bool
        Restrict the tasks to complex syllable sequences and 4-syllable real words.
    legend_inside : bool
        Keep the legend inside the axes; otherwise anchor it to the right of the axes.
    yticks : sequence, optional
        Explicit y tick positions.
    markersize : float
        Marker size for the points and the legend handles.

    Returns
    -------
    The axes that were drawn into.

    Raises
    ------
    ValueError
        If ``error_type`` is missing or ``metric`` is not one of the supported values.
    """
    # Validate first, so bad arguments fail with a clear message before any figure is created.
    if error_type is None:
        raise ValueError("error_type is required, e.g. 'all' or another value of error_counts.error_type")
    if metric not in ('count', 'percent'):
        raise ValueError(f"metric must be 'count' or 'percent', got {metric!r}")

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 3))

    cur_error_counts = error_counts.loc[error_counts.error_type == error_type]

    if include_only_complex:
        torder = ['Complex\nsyl.\nseq.', '4-syl.\nreal\nwords']
        cur_error_counts = cur_error_counts.loc[cur_error_counts.task_label.isin(torder)]
    else:
        # A real list (not a dict view), since it is reused as plot order and as tick labels.
        torder = list(task_order.keys())

    title = None
    if metric == 'count':
        ylabel = 'Number of{}errors'
        ylim = (0, None)
    else:
        ylim = (0, 100)
        if plot_title:
            # The title names the error type ('syllable_*' types wrap onto two lines),
            # so the y label stays generic.
            title = error_type.replace('_', '\n' if error_type.startswith('syllable') else ' ')
            ylabel = 'Percent utterances with errors'
        else:
            ylabel = 'Percent utterances with{}errors'

    if plot_ylabel:
        # Templates without a placeholder are returned unchanged by format().
        ylabel = ylabel.format(' ' if error_type == 'all' else '\n' + error_type.replace('_', ' ') + ' ')
    else:
        ylabel = None

    for subject in plot_subjects:
        subject_error_counts = cur_error_counts.loc[cur_error_counts.subject == subject]

        if subject_error_counts.shape[0] == 0:
            continue
        ax = sns.stripplot(data=subject_error_counts,
                           x='task_label', y=metric, hue='stimulation_label', marker=subject_markers[subject]['marker'],
                           hue_order=['No stim.', 'Stim.'], palette=stim_palette, order=torder, ax=ax, clip_on=False, size=markersize)
    ax.set(xlabel='', ylabel=ylabel, ylim=ylim)

    if title is not None:
        # Show the requested title even when no explicit font size was given
        # (previously it was dropped unless label_fontsize was set).
        title_kwargs = {} if label_fontsize is None else {'fontsize': label_fontsize}
        ax.set_title(title, **title_kwargs)

    sns.despine(ax=ax, offset=dict(left=1, bottom=2.5))
    ax.spines['bottom'].set_visible(False)
    ax.xaxis.set_tick_params(length=0)

    if yticks is not None:
        ax.set_yticks(yticks)

    if plot_legend:
        handles, labels = ax.get_legend_handles_labels()
        for handle in handles:
            # Private attribute: sets the marker size used for the legend handles.
            handle._sizes = [markersize]
        legend_kwargs = dict(frameon=False, loc='upper left', fontsize=label_fontsize)
        if not legend_inside:
            legend_kwargs['bbox_to_anchor'] = (1, 1.05)
        # Every subject's stripplot contributes the same two hue entries; keep the first pair.
        ax.legend(handles[:2], labels[:2], **legend_kwargs)
    else:
        ax.legend([], [], frameon=False)

    if plot_xticklabels:
        ax.set_xticklabels(torder, fontsize=label_fontsize)
    else:
        ax.set(xticklabels=[])

    return ax

In [None]:
# Preview: percent of utterances with any error, for all tasks.
error_chart(error_type='all', metric='percent', include_only_complex=False);

In [None]:
# Preview: one stacked panel per error type, restricted to the two complex tasks.
fig, axs = plt.subplots(4, 1, figsize=(3, 8))

plot_legend = False
plot_ylabel = False

for ax, error_type in zip(axs.ravel(), ['distortion', 'syllable_segmentation', 'syllable_duration', 'stuttering']):

    # Only the bottom panel carries the task labels.
    plot_xticklabels = error_type == 'stuttering'

    error_chart(error_type=error_type,
                metric='percent',
                ax=ax,
                plot_legend=plot_legend,
                legend_inside=not plot_legend,
                plot_title=True,
                plot_ylabel=plot_ylabel,
                plot_xticklabels=plot_xticklabels,
                label_fontsize=11,
                # A real boolean: the former string 'True' only worked because any non-empty string is truthy.
                include_only_complex=True,
                yticks=[0, 50, 100])

fig.tight_layout();

## Syllable duration

In [None]:
def syllable_duration(axs=None, patient_annot=None, label_fontsize=9, verbose=1, markersize=2, source_stats=None, return_source_stats=False):
    """Box and strip plots of syllable duration per subject, stimulation vs. no stimulation.

    One panel is drawn per sequence type ('simple', 'isolated complex', 'complex'). Within a
    panel, each subject with at least 5 samples in both conditions is tested with a one-sided
    Wilcoxon rank-sum test (stimulation > no stimulation); the p-values of a panel are
    FDR-corrected together and drawn as significance annotations. Sample counts are written
    below the x axis.

    Parameters
    ----------
    axs : sequence of matplotlib Axes, optional
        Three axes, one per sequence type. A new 1x3 figure is created when omitted.
    patient_annot : pandas.DataFrame
        Annotations; the columns read here are ``sequence_complexity``, ``subject``,
        ``duration``, ``stimulation`` and ``stimulation_label``.
    label_fontsize : float
        Font size of tick labels, sample counts and legend.
    verbose : int
        Passed to ``add_stat_annotation``.
    markersize : float
        Strip plot marker size.
    source_stats : ignored
        Kept for backward compatibility; a fresh statistics table is built on every call.
    return_source_stats : bool
        Also return the statistics table.

    Returns
    -------
    axs, or (axs, source_stats DataFrame) when ``return_source_stats`` is true.

    Raises
    ------
    ValueError
        If ``patient_annot`` is not provided.
    """
    if patient_annot is None:
        raise ValueError('patient_annot must be provided (DataFrame of syllable annotations)')

    if axs is None:
        _, axs = plt.subplots(1, 3, figsize=(9, 4))

    # Always start from an empty table; whatever the caller passed as source_stats is not used.
    source_stats = {key: [] for key in ['Subject', 'Comparison', 'Sample sizes', 'Corrected P-value', 'Statistic']}

    order_labels = {
        'simple': 'Simple sequences', 
        'isolated complex': 'Isolated syllables of\ncomplex sequences', 
        'complex': 'Complex sequences'
    }
    ylims = [
        [0.0, 1],
        [0.0, 1],
        [0.0, 1]
    ]

    for cur_ax, (seq_label, ax) in enumerate(zip(order_labels.keys(), axs)):

        sns.despine(ax=ax, offset=dict(left=2.5, bottom=9))
        ax.spines['bottom'].set_visible(False)

        kwargs = dict(x='subject', y='duration', hue='stimulation_label', 
                      order=plot_subjects, hue_order=['No stim.', 'Stim.'])

        boxplot_kwargs = {}
        for key in ['cap', 'whisker', 'flier', 'median', 'mean']:
            boxplot_kwargs[f'{key}props'] = dict(clip_on=False)
        boxplot_kwargs['boxprops'] = dict(alpha=0.3, clip_on=False)

        # Copy, because the durations of the 'complex' panel are clipped below.
        cur_df = patient_annot.loc[patient_annot.sequence_complexity == seq_label].copy(deep=True)

        if cur_df.shape[0] == 0:
            continue

        if seq_label == 'complex':
            temp = cur_df.loc[cur_df.duration > 1]
            print(f'{temp.shape[0]} trials for {temp.subject.unique()} greater than 1 s, displayed at 1s on plot')
            cur_df['duration'] = np.clip(cur_df['duration'], a_min=0, a_max=1)

        ax = sns.boxplot(data=cur_df, ax=ax, showfliers=False, palette=stim_palette, **kwargs, **boxplot_kwargs)
        sns.stripplot(data=cur_df, ax=ax, dodge=True, clip_on=False, palette=stim_palette, s=markersize, **kwargs)

        ax.tick_params(axis='x', length=0)
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=label_fontsize, rotation=90, ha='center')
        ax.axes.set(xlabel=order_labels[seq_label], 
                    yticks=np.arange(ylims[cur_ax][0], ylims[cur_ax][1]+0.1, 0.5), ylim=ylims[cur_ax], 
                    ylabel='Syllable duration (s)')

        if cur_ax == 0:
            ax.annotate('# samples', (-0.7, -0.177), fontsize=label_fontsize, annotation_clip=False, va='bottom', ha='right', xycoords='data')

        box_pairs, box_pvals = [], []
        trial_labels = []
        for subject in plot_subjects:

            subj_df = cur_df.loc[cur_df.subject == subject]
            a = subj_df.loc[subj_df.stimulation == False].duration.values
            b = subj_df.loc[subj_df.stimulation == True].duration.values
            n_no_stim, n_stim = len(a), len(b)

            if n_no_stim < 5 or n_stim < 5:
                # Too few samples to test: show the counts only, blanking each count that is zero.
                # (The stim count used to be blanked based on the no-stim count, a copy-paste slip.)
                trial_labels.append(['' if n_no_stim == 0 else n_no_stim, '' if n_stim == 0 else n_stim])
                continue
            else:
                trial_labels.append([n_no_stim, n_stim])

            # One-sided test: are durations longer with stimulation than without?
            z, p = stats.ranksums(b, a, alternative='greater')

            box_pairs.append(((subject, 'Stim.'), (subject, 'No stim.')))
            box_pvals.append(p)

            source_stats['Subject'].append(subject)
            source_stats['Comparison'].append('{} -- no stim. vs. stim.'.format(order_labels[seq_label].replace('\n', ' ')))
            source_stats['Sample sizes'].append([n_no_stim, n_stim])
            source_stats['Statistic'].append(z)

        if box_pairs:
            # Correct across the subjects tested within this panel.
            _, box_pvals = fdrcorrection(box_pvals, alpha=sig_thresh, method='poscorr', is_sorted=False)
            source_stats['Corrected P-value'].extend(box_pvals)

        for cur_subj, label in enumerate(trial_labels):
            ax.annotate(label[0], (cur_subj, -0.115), fontsize=label_fontsize, color=stim_palette[0],
                        annotation_clip=False, va='bottom', ha='center', xycoords='data')
            ax.annotate(label[1], (cur_subj, -0.177), fontsize=label_fontsize, color=stim_palette[1],
                        annotation_clip=False, va='bottom', ha='center', xycoords='data')

        if box_pairs:
            # Nothing to annotate when no subject had enough samples in both conditions.
            add_stat_annotation(ax, data=cur_df, box_pairs=box_pairs, perform_stat_test=False,
                                pvalues=box_pvals, text_format='star', linewidth=annot_lw,
                                loc='outside', verbose=verbose, pvalue_thresholds=pvalue_thresholds, **kwargs)

        if cur_ax == 0:
            # Fix legend: keep one scatter handle per label (box and strip plots both add entries).
            hand, labl = ax.get_legend_handles_labels()
            handout, lablout = [], []
            for h, l in zip(hand, labl):
                h._sizes = [markersize]
                if l not in lablout and type(h) == mpl.collections.PathCollection:
                    lablout.append(l)
                    handout.append(h)
            lablout = [l.capitalize() for l in lablout]
            ax.legend(handout, lablout, frameon=False, loc='upper left', bbox_to_anchor=(-0.2, 1.05), fontsize=label_fontsize)

        else:
            ax.legend([],[], frameon=False)

        if cur_ax != 0:
            ax.axes.set(ylabel='')

    source_stats = pd.DataFrame(data=source_stats)

    if return_source_stats:
        return axs, source_stats
    else:
        return axs

In [None]:
# Preview the syllable duration panels; ss receives the per-subject statistics table.
_, ss = syllable_duration(patient_annot=total_patient_annot, verbose=1, return_source_stats=True);

## Inter-syllable duration (syllable segmentation)

In [None]:
def syllable_segmentation(axs=None, intersyl_annot=None, label_fontsize=9, verbose=1, markersize=3, source_stats=None, return_source_stats=False):
    """Box and strip plots of inter-syllable duration per subject, stimulation vs. no stimulation.

    One panel is drawn per sequence type ('simple', 'isolated complex', 'complex'), using only
    rows whose ``interword`` flag is False. Within a panel, each subject with at least 5 samples
    in both conditions is tested with a one-sided Wilcoxon rank-sum test (stimulation > no
    stimulation); the p-values of a panel are FDR-corrected together and drawn as significance
    annotations. Sample counts are written below the x axis.

    Parameters
    ----------
    axs : sequence of matplotlib Axes, optional
        Three axes, one per sequence type. A new 1x3 figure is created when omitted.
    intersyl_annot : pandas.DataFrame
        Annotations; the columns read here are ``interword``, ``sequence_complexity``,
        ``subject``, ``duration``, ``stimulation`` and ``stimulation_label``.
    label_fontsize : float
        Font size of tick labels, sample counts and legend.
    verbose : int
        Passed to ``add_stat_annotation``.
    markersize : float
        Strip plot marker size.
    source_stats : ignored
        Kept for backward compatibility; a fresh statistics table is built on every call.
    return_source_stats : bool
        Also return the statistics table.

    Returns
    -------
    axs, or (axs, source_stats DataFrame) when ``return_source_stats`` is true.

    Raises
    ------
    ValueError
        If ``intersyl_annot`` is not provided.
    """
    if intersyl_annot is None:
        raise ValueError('intersyl_annot must be provided (DataFrame of inter-syllable annotations)')

    if axs is None:
        _, axs = plt.subplots(1, 3, figsize=(9, 4))

    # Always start from an empty table; whatever the caller passed as source_stats is not used.
    source_stats = {key: [] for key in ['Subject', 'Comparison', 'Sample sizes', 'Corrected P-value', 'Statistic']}

    order_labels = {
        'simple': 'Simple sequences', 
        'isolated complex': 'Isolated syllables of\ncomplex sequences', 
        'complex': 'Complex sequences'
    }
    ylims = [
        [0.0, 1.5],
        [0.0, 1.5],
        [0.0, 1.5]
    ]

    for cur_ax, (seq_label, ax) in enumerate(zip(order_labels.keys(), axs)):

        sns.despine(ax=ax, offset=dict(left=2.5, bottom=9))
        ax.spines['bottom'].set_visible(False)

        kwargs = dict(x='subject', y='duration', hue='stimulation_label', 
                      order=plot_subjects, hue_order=['No stim.', 'Stim.'])

        boxplot_kwargs = {}
        for key in ['cap', 'whisker', 'flier', 'median', 'mean']:
            boxplot_kwargs[f'{key}props'] = dict(clip_on=False)
        boxplot_kwargs['boxprops'] = dict(alpha=0.3, clip_on=False)

        cur_df = intersyl_annot.loc[(intersyl_annot.interword == False) & (intersyl_annot.sequence_complexity == seq_label)]

        # Same guard as syllable_duration: leave the panel empty when there is nothing to plot.
        if cur_df.shape[0] == 0:
            continue

        ax = sns.boxplot(data=cur_df, ax=ax, showfliers=False, palette=stim_palette, **kwargs, **boxplot_kwargs)
        sns.stripplot(data=cur_df, ax=ax, dodge=True, clip_on=False, palette=stim_palette, s=markersize, **kwargs)

        ax.tick_params(axis='x', length=0)
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=label_fontsize, rotation=90, ha='center')
        ax.axes.set(xlabel=order_labels[seq_label], 
                    yticks=np.arange(ylims[cur_ax][0], ylims[cur_ax][1]+0.1, 0.5), ylim=ylims[cur_ax], 
                    ylabel='Inter-syllable\nduration (s)')

        if cur_ax == 0:
            ax.annotate('# samples', (-0.7, -0.28), fontsize=label_fontsize, annotation_clip=False, va='bottom', ha='right', xycoords='data')

        box_pairs, box_pvals = [], []
        trial_labels = []
        for subject in plot_subjects:

            subj_df = cur_df.loc[cur_df.subject == subject]
            a = subj_df.loc[subj_df.stimulation == False].duration.values
            b = subj_df.loc[subj_df.stimulation == True].duration.values
            n_no_stim, n_stim = len(a), len(b)

            if n_no_stim < 5 or n_stim < 5:
                # Too few samples to test: show the counts only, blanking each count that is zero.
                # (The stim count used to be blanked based on the no-stim count, a copy-paste slip.)
                trial_labels.append(['' if n_no_stim == 0 else n_no_stim, '' if n_stim == 0 else n_stim])
                continue
            else:
                trial_labels.append([n_no_stim, n_stim])

            # One-sided test: are inter-syllable durations longer with stimulation than without?
            z, p = stats.ranksums(b, a, alternative='greater')

            box_pairs.append(((subject, 'Stim.'), (subject, 'No stim.')))
            box_pvals.append(p)

            source_stats['Subject'].append(subject)
            source_stats['Comparison'].append('{} -- no stim. vs. stim.'.format(order_labels[seq_label].replace('\n', ' ')))
            source_stats['Sample sizes'].append([n_no_stim, n_stim])
            source_stats['Statistic'].append(z)

        if box_pairs:
            # Correct across the subjects tested within this panel.
            _, box_pvals = fdrcorrection(box_pvals, alpha=sig_thresh, method='poscorr', is_sorted=False)
            source_stats['Corrected P-value'].extend(box_pvals)

        for cur_subj, label in enumerate(trial_labels):
            ax.annotate(label[0], (cur_subj, -0.18), fontsize=label_fontsize, color=stim_palette[0],
                        annotation_clip=False, va='bottom', ha='center', xycoords='data')
            ax.annotate(label[1], (cur_subj, -0.28), fontsize=label_fontsize, color=stim_palette[1],
                        annotation_clip=False, va='bottom', ha='center', xycoords='data')

        if box_pairs:
            # Nothing to annotate when no subject had enough samples in both conditions.
            add_stat_annotation(ax, data=cur_df, box_pairs=box_pairs, perform_stat_test=False,
                                pvalues=box_pvals, text_format='star', pvalue_thresholds=pvalue_thresholds,
                                loc='outside', linewidth=annot_lw, verbose=verbose, **kwargs)

        if cur_ax == 0:
            # Fix legend: keep one scatter handle per label (box and strip plots both add entries).
            hand, labl = ax.get_legend_handles_labels()
            handout, lablout = [], []
            for h, l in zip(hand, labl):
                h._sizes = [markersize]
                if l not in lablout and type(h) == mpl.collections.PathCollection:
                    lablout.append(l)
                    handout.append(h)
            lablout = [l.capitalize() for l in lablout]
            ax.legend(handout, lablout, frameon=False, loc='upper left', bbox_to_anchor=(-0.2, 1.05), fontsize=label_fontsize)

        else:
            ax.legend([],[], frameon=False)

        if cur_ax != 0:
            ax.axes.set(ylabel='')

    source_stats = pd.DataFrame(data=source_stats)

    if return_source_stats:
        return axs, source_stats
    else:
        return axs

In [None]:
# Preview the inter-syllable duration panels; ss receives the per-subject statistics table.
_, ss = syllable_segmentation(intersyl_annot=total_intersyl_annot, verbose=1, return_source_stats=True);

## Microphone examples

In [None]:
# Expected syllables of each example utterance. stim_spec_panel highlights every annotation
# label that is not in the list for the plotted trial.
word_syl_dict = {
    'catastrophe': ['ca', 'ta', 'stro', 'phe'],
    'blaadraagloo': ['blaa', 'draa', 'gloo'],
    'papapa': ['pa', 'pa', 'pa']
}

# Index labels of each trial's annotation table (looked up with .loc in stim_spec_panel) to
# mark on the spectrograms:
#   'segmentation_indices': [start, stop] pairs, bracket from the offset of row `start`
#                           to the onset of row `stop`
#   'duration_indices':     rows, bracket from the row's onset to its offset
# An empty dict means no brackets are drawn for that trial.
spx_examples = {
    'catastrophe': {
        'segmentation_indices': [[325, 326], [323, 324], [330, 331], [332, 333]],
        'duration_indices': [323, 326, 331, 332]
    },
    'vocalization': {},
    'blaadraagloo': {
        'segmentation_indices': [[210, 211], [211, 212], [213, 214], [217, 218], [218, 219]],
        'duration_indices': [210, 213, 217, 218]
    },
    'papapa': {}
}

# Load panel data
# NOTE(review): pickle.load can execute arbitrary code; only load trusted project data files.
with open(os.path.join(data_dir, 'fig5_spectrogram_examples.pkl'), 'rb') as f:
    spx_plot_data = pickle.load(f)

In [None]:
def _draw_interval_bracket(ax, x_start, x_stop, height, color, tick=0.15):
    """Draw a horizontal line from x_start to x_stop at `height`, with a vertical end cap at each side."""
    ax.plot((x_start, x_stop), (height, height), color=color)
    ax.plot((x_start, x_start), (height - tick, height + tick), color=color)
    ax.plot((x_stop, x_stop), (height - tick, height + tick), color=color)


def stim_spec_panel(trial_key=None, fig=None, gs=None, label_fontsize=10):
    """Plot the stimulation trace above the spectrogram of one example trial.

    The trial's data is read from ``spx_plot_data[trial_key]``. Annotation labels are written
    above the spectrogram at their onsets; labels that are not expected syllables of the
    utterance (per ``word_syl_dict``) are highlighted. Brackets listed in
    ``spx_examples[trial_key]`` mark example inter-syllable and syllable intervals.

    Parameters
    ----------
    trial_key : str
        Key into ``spx_plot_data`` and ``spx_examples``; except for ``'vocalization'`` it must
        also be a key of ``word_syl_dict``.
    fig : matplotlib Figure, optional
        Figure to add the two axes to. Required when ``gs`` is given.
    gs : gridspec with at least 2 rows, optional
        ``gs[0, 0]`` receives the stimulation trace and ``gs[1, 0]`` the spectrogram. A new
        figure with two stacked axes is created when omitted.
    label_fontsize : float
        Font size of all text annotations.

    Returns
    -------
    The two axes: [stimulation trace, spectrogram].

    Raises
    ------
    ValueError
        If ``gs`` is given without ``fig``.
    """
    if gs is not None and fig is None:
        raise ValueError('fig must be provided together with gs')

    if gs is None:
        fig, axs = plt.subplots(2, 1, figsize=(10, 4), gridspec_kw={'height_ratios': [1, 5]}, sharex=True)
    else:
        axs = [fig.add_subplot(gs[i, 0]) for i in range(2)]
    stim_ax, spec_ax = axs[0], axs[1]

    panel = spx_plot_data[trial_key]
    stim_sr = panel['stim_sr']
    trial_stim = panel['trial_stim']
    S_db = panel['spectrogram']
    fig_annot = panel['fig_annot']
    start_time = panel['start_time']
    stop_time = panel['stop_time']

    # The time axis starts at 0 at the beginning of the plotted window.
    stim_x = np.linspace(0, stop_time - start_time, len(trial_stim))
    xlim = (stim_x[0], stim_x[-1])

    stim_ax.plot(stim_x, trial_stim, color='k')
    stim_ax.fill_between(stim_x, 0, trial_stim, color='lightgrey', rasterized=True)
    # Baseline drawn only where the stimulation trace is positive.
    stim_zero_line = np.where(trial_stim > 0, 0, np.nan)
    stim_ax.plot(stim_x, stim_zero_line, color='k')
    stim_ax.annotate('stimulation', (0, 0.2), fontstyle='italic', fontsize=label_fontsize)
    stim_ax.axes.set(xlim=xlim)
    stim_ax.axis('off')

    spec_t = np.linspace(stim_x[0], stim_x[-1], S_db.shape[1])
    # NOTE(review): stim_sr/2000 is presumably the Nyquist frequency in kHz for a rate in Hz -- confirm.
    spec_f = np.linspace(0, stim_sr/2000, S_db.shape[0])
    spec_ax.pcolormesh(spec_t, spec_f, S_db, cmap='Blues', vmin=-70, rasterized=True)
    spec_ax.axes.set(ylim=(0, 8), ylabel='Frequency (kHz)', yticks=[0, 4, 8], xlabel='Time (s)', xlim=xlim)

    # Defaults for the spacer annotation below, used when the annotation table is empty
    # (the spacer used to read leaked loop variables and raised NameError in that case).
    spacer_x = 0
    weight = 'regular'
    color = '#404040'

    for cur_label, (onset, label) in enumerate(zip(fig_annot.onset.values, fig_annot.label.values)):

        spacer_x = onset - start_time

        if trial_key != 'vocalization' and label not in word_syl_dict[trial_key]:
            # Not one of the expected syllables: highlight it.
            color = colors[3]
            weight = 'bold'
            alpha = 1.0
        else:
            color = '#404040'
            weight = 'regular'
            alpha = 0.7

        # Skip labels that start within the last 0.5 s of the window.
        if stop_time - onset < 0.5:
            continue

        # Alternate between two heights so neighbouring labels do not overlap.
        height = 8.1 if cur_label % 2 == 0 else 8.6

        spec_ax.annotate(label, (spacer_x, height), fontstyle='italic', weight=weight,
                         fontsize=label_fontsize, annotation_clip=False, color=color, alpha=alpha)

    # Blank annotation at the upper label height, placed at the last annotation's onset.
    # NOTE(review): presumably reserves space for the upper label row in every panel -- confirm.
    spec_ax.annotate(' ', (spacer_x, 8.6), fontstyle='italic', weight=weight,
                     fontsize=label_fontsize, annotation_clip=False, color=color)

    if 'segmentation_indices' in spx_examples[trial_key]:

        height = 3.5

        for (start, stop) in spx_examples[trial_key]['segmentation_indices']:
            # Gap between the end of one annotation and the start of the next one.
            gap_start = fig_annot.loc[start].offset - start_time
            gap_stop = fig_annot.loc[stop].onset - start_time
            _draw_interval_bracket(spec_ax, gap_start, gap_stop, height, colors[1])

        spec_ax.annotate('Syllable\nsegmentation', (stop_time - start_time, height), va='center', ha='left', color=colors[1], fontsize=label_fontsize, linespacing=1)

    if 'duration_indices' in spx_examples[trial_key]:

        height = 5.5

        for idx in spx_examples[trial_key]['duration_indices']:
            # Extent of a single annotation.
            syl_start = fig_annot.loc[idx].onset - start_time
            syl_stop = fig_annot.loc[idx].offset - start_time
            _draw_interval_bracket(spec_ax, syl_start, syl_stop, height, colors[2])

        spec_ax.annotate('Syllable\nduration', (stop_time - start_time, height), va='center', ha='left', color=colors[2], fontsize=label_fontsize, linespacing=1)

    return axs

In [None]:
# Preview a single spectrogram example on its own figure.
# NOTE(review): `subject` is not passed to stim_spec_panel; it only sets the notebook-level variable.
subject = 'EC260'
trial_key = 'catastrophe'
stim_spec_panel(trial_key=trial_key);

# Overall figure

In [None]:
##### Figure setup #####
# Smaller fonts and thinner lines for the print-size composite figure.
default_plot_settings(font='Helvetica', fontsize=5, linewidth=0.5, ticklength=2)

plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['axes.labelpad'] = 1
plt.rcParams['patch.linewidth'] = 0.25
plt.rcParams['axes.titlepad'] = 1
plt.rcParams['xtick.major.pad'] = 2
plt.rcParams['ytick.major.pad'] = 2
plt.rcParams['legend.handletextpad'] = 0.5

all_axs = {}
mm = 1/25.4  # inches per millimetre
fig = plt.figure(figsize=(180*mm, 160*mm))
# Panels are positioned by slicing a 100 x 100 layout grid.
gs = mpl.gridspec.GridSpec(100, 100, figure=fig, wspace=75, hspace=100)

#########################

##### Stimulation sites #####
all_axs['recon'] = fig.add_subplot(gs[:32, :27])
all_axs['recon'] = stim_recon(ax=all_axs['recon'], markersize=25, label_fontsize=5, plot_overlap_density=True);

##### Errors of any type, all tasks #####
all_axs['all_errors'] = fig.add_subplot(gs[4:24, 35:71])
all_axs['all_errors'] = error_chart(error_type='all', metric='percent', ax=all_axs['all_errors'], label_fontsize=5, markersize=2);

##### Individual error types, complex tasks only #####
# Five panels fill a 3 x 2 grid row by row; the last cell stays empty.
error_gs = mpl.gridspec.GridSpecFromSubplotSpec(3, 2, subplot_spec=gs[2:29, 75:], wspace=0.2, hspace=0.45)
error_types = ['syllable_segmentation', 'syllable_duration', 'distortion', 'stuttering', 'pause']
for i, error_type in enumerate(error_types):
    r, c = divmod(i, 2)

    plot_legend = (r == 2 and c == 0)   # legend only on the bottom-left panel
    plot_xticklabels = (r == 2)         # task labels only on the bottom row
    plot_ylabel = (r == 1 and c == 0)   # one shared y label on the middle-left panel

    all_axs[f'error_{error_type}'] = error_chart(error_type=error_type,
                                                 metric='percent',
                                                 ax=fig.add_subplot(error_gs[r, c]),
                                                 plot_legend=plot_legend,
                                                 legend_inside=not plot_legend,
                                                 plot_title=True,
                                                 plot_ylabel=plot_ylabel,
                                                 plot_xticklabels=plot_xticklabels,
                                                 label_fontsize=5,
                                                 # A real boolean: the former string 'True' only worked because it is truthy.
                                                 include_only_complex=True,
                                                 markersize=2);

##### Inter-syllable and syllable durations #####
seg_gs = mpl.gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=gs[43:61, :40], wspace=0.4)
all_axs['segmentation'] = [fig.add_subplot(seg_gs[i]) for i in range(3)]
all_axs['segmentation'] = syllable_segmentation(intersyl_annot=total_intersyl_annot, axs=all_axs['segmentation'], label_fontsize=5, markersize=1.5)

dur_gs = mpl.gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=gs[75:93, :40], wspace=0.4)
all_axs['duration'] = [fig.add_subplot(dur_gs[i]) for i in range(3)]
all_axs['duration'] = syllable_duration(patient_annot=total_patient_annot, axs=all_axs['duration'], label_fontsize=5, markersize=1.5)

##### Spectrogram examples #####
# Each example takes a region of the layout grid, split into a stimulation row and a spectrogram row.
spec_panels = [
    ('catastrophe', gs[35:53, 45:93]),
    ('blaadraagloo', gs[57:75, 45:93]),
    ('papapa', gs[79:97, 45:73]),
    ('vocalization', gs[79:97, 77:95]),
]
for trial_key, panel_spec in spec_panels:
    spec_gs = mpl.gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=panel_spec, height_ratios=[1, 6], hspace=0.3)
    all_axs[trial_key] = stim_spec_panel(trial_key=trial_key, gs=spec_gs, fig=fig, label_fontsize=5)


##### PANEL LABELS #####
panel_label_kwargs = dict(xycoords='axes fraction', ha='right', fontsize=7, weight='bold')
all_axs['recon'].annotate('a', (-0.1, 1.07), **panel_label_kwargs);
all_axs['all_errors'].annotate('b', (-0.15, 1.13), **panel_label_kwargs);
all_axs['error_syllable_segmentation'].annotate('c', (-0.45, 1.05), **panel_label_kwargs);
all_axs['segmentation'][0].annotate('d', (-0.45, 1.1), **panel_label_kwargs);
all_axs['duration'][0].annotate('e', (-0.45, 1.1), **panel_label_kwargs);
all_axs['catastrophe'][0].annotate('f', (-0.06, 0.95), **panel_label_kwargs);
all_axs['papapa'][0].annotate('g', (-0.1, 0.95), **panel_label_kwargs);

# Supplementary figures

## Fig. S16 - Recon with sensorimotor error types broken down

In [None]:
# Same stimulation-site plot, with motor and other (sensory/perceptual) deficits colored separately.
fig, _ = stim_recon(color_other_errors=True, plot_overlap_density=True, return_fig=True)

## Fig. S17 - All error types for all tasks

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(11, 9))
fs = 9

# 'all' is the aggregate across error types; only the individual types get a panel here.
panel_error_types = [error_type for error_type in error_counts.error_type.unique() if error_type != 'all']
assert len(panel_error_types) <= axs.size, 'more error types than axes; enlarge the subplot grid'

for ax, error_type in zip(axs.ravel(), panel_error_types):

    ax = error_chart(ax=ax, error_type=error_type, metric='percent',
                     include_only_complex=False, plot_legend=False,
                     plot_ylabel=False, label_fontsize=fs)
    ax.set_title(error_type.replace('_', ' ').capitalize(), fontsize=fs+3)
    ax.set_ylabel('Percent utterances with errors', fontsize=fs)

# Hide grid cells that did not receive a panel, instead of leaving empty frames.
for unused_ax in axs.ravel()[len(panel_error_types):]:
    unused_ax.axis('off')

fig.tight_layout();