# Figures S1–S3 - Participant behavior

In [None]:
import json
import os
from itertools import combinations

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation
from statsmodels.stats.multitest import multipletests

from sylseq_paper.plotting import default_plot_settings, ucsf_sequential_color_palette as colors

# Apply the project's default plot settings once, before any figure is created.
# NOTE(review): presumably configures matplotlib defaults -- see sylseq_paper.plotting.
default_plot_settings()

In [None]:
# The data live in a "data" folder that sits next to the current working directory.
data_dir = os.path.abspath(
    os.path.join(os.getcwd(), '..', 'data')
)

# Participant identifiers; the per-participant panels are drawn in this order.
subjects = [
    f'EC{number}'
    for number in (217, 219, 223, 237, 240, 241, 253,
                   254, 260, 263, 267, 276, 282, 289)
]

# Significance level used for the corrected p-values, and the
# (upper bound, label) pairs used when annotating them on the figures.
sig_thresh = 0.05
pvalue_thresholds = [
    [0.001, "***"],
    [0.01, "**"],
    [0.05, "*"],
    [10, "ns"],
]

# Trial-level behavior table (one row per trial), read from the HDF5 file.
adf = pd.read_hdf(
    os.path.join(data_dir, 'supp_participant_behavior.h5')
)

In [None]:
# Each row pairs a condition key (as stored in the data) with the label shown
# on the figures; the two columns are split into parallel arrays.
conditions, condition_labels = (
    np.array(column) for column in zip(
        ('noneSeq_simpleSyl', 'Isolated\nsimple syl.'),
        ('noneSeq_complexSyl', 'Isolated\ncomplex syl.'),
        ('simpleSeq_simpleSyl', 'Simple syl.,\nsimple seq.'),
        ('simpleSeq_complexSyl', 'Complex syl.,\nsimple seq.'),
        ('complexSeq_simpleSyl', 'Simple syl.,\ncomplex seq.'),
        ('complexSeq_complexSyl', 'Complex syl.,\ncomplex seq.'),
    )
)

# Lookup from condition key to display label.
condition_dict = dict(zip(conditions, condition_labels))

# Plot color for every condition; the insertion order doubles as the hue order
# of the box/strip plots.
condition_colors = dict(
    noneSeq_simpleSyl=colors[2],
    noneSeq_complexSyl=colors[7],
    simpleSeq_simpleSyl=colors[4],
    simpleSeq_complexSyl='sienna',
    complexSeq_simpleSyl=colors[1],
    complexSeq_complexSyl=colors[5],
)

# Mistake counts per participant

In [None]:
# Build a long-format table with two rows (correct / mistake) for every
# condition x participant combination that has at least one trial.
# Each row holds the number of trials of that type and its percentage of all
# trials the participant performed in that condition.
records = []
all_subjects = adf.subject.unique()

for condition in conditions:
    condition_trials = adf.loc[adf.condition == condition]

    for subject in all_subjects:
        subject_trials = condition_trials.loc[condition_trials.subject == subject]
        n_total = subject_trials.shape[0]

        # Participant did not perform this condition: no rows are emitted.
        if n_total == 0:
            continue

        is_mistake = subject_trials.decoder_data_mistake
        for trial_type, selector in (('correct', ~is_mistake), ('mistake', is_mistake)):
            n_trials = subject_trials.loc[selector].shape[0]
            records.append(
                (subject, condition, trial_type, n_trials, 100*n_trials / n_total)
            )

mistake_counts = pd.DataFrame(
    records, columns=['subject', 'condition', 'trial_type', 'count', 'percent']
)

## Fig. S3: percent correct vs. mistake trials

In [None]:
# One bar-plot panel per participant on a 4 x 4 grid: percentage of correct
# vs. mistake trials for every condition the participant performed.
fs = 12
default_plot_settings(font='Helvetica', fontsize=fs)

fig, axs = plt.subplots(4, 4, figsize=(14, 13), gridspec_kw=dict(wspace=0.5, hspace=1))
panel_axes = axs.ravel()

# Grid cells that have no participant assigned are blanked out.
for unused_ax in panel_axes[len(subjects):]:
    unused_ax.axis('off')

for subject, ax in zip(subjects, panel_axes):
    subj_df = mistake_counts.loc[mistake_counts.subject == subject]

    # Keep only the conditions this participant has rows for, in canonical order.
    performed = np.isin(conditions, subj_df.condition.unique())
    subj_conditions = conditions[performed]
    subj_condition_labels = condition_labels[performed]

    g = sns.barplot(data=subj_df, x='condition', y='percent', hue='trial_type',
                    order=subj_conditions, ax=ax, alpha=0.75)
    ax.set(ylabel='Percent of trials', ylim=(0, 100), xlabel='')
    ax.set_title(subject, pad=17, fontsize=fs)
    ax.set_xticklabels(subj_condition_labels, fontsize=fs-3, rotation=90)

    # Write the total trial count above each condition; only the first
    # annotation carries the "n=" prefix.
    for position, cond in enumerate(subj_conditions):
        n_trials = subj_df.loc[subj_df.condition == cond, 'count'].values.sum()
        trial_label = f'n={n_trials}' if position == 0 else str(n_trials)
        ax.annotate(trial_label, (position, 100), ha='center', va='bottom', fontsize=fs-3)

    if subject == subjects[-1]:
        # Only the last participant's panel keeps a legend, with tidied labels.
        handles, labels = g.get_legend_handles_labels()
        labels = [l.capitalize().replace('_', ' ') for l in labels]
        ax.legend(handles, labels, frameon=False, bbox_to_anchor=(1, 1), loc='upper left', fontsize=fs-3)
    else:
        g.legend_.remove()

# Reaction/production times per participant

In [None]:
# Pairwise comparisons between conditions, run separately within each
# participant, for reaction time (rt_*) and production duration (pd_*).
# Only trials that are not flagged as mistakes and have a positive reaction
# time enter the tests. P-values are Holm-corrected within each participant.


def _correct_trials(subject):
    """Return the rows of `adf` for `subject` that are not flagged as
    mistakes and have a positive reaction time."""
    return adf.loc[(~adf.decoder_data_mistake) & (adf.reaction_time > 0) & (adf.subject == subject)]


def _holm_adjust(pvals):
    """Return the Holm-corrected p-values for one participant's family of tests."""
    if len(pvals) == 0:
        # Degenerate family: a participant with fewer than two conditions
        # has no pairs to compare, hence nothing to correct.
        return np.array([])
    _, adjusted, _, _ = multipletests(pvals, alpha=sig_thresh, method='holm',
                                      is_sorted=False, returnsorted=False)
    return adjusted


def _tests_between(boxpairs, pvals, cond_a, cond_b):
    """Yield (subject, p-value) for every test that compares `cond_a` with
    `cond_b` (in either order) and whose p-value is not NaN."""
    wanted = {cond_a, cond_b}
    for ((subject, c1), (_, c2)), pval in zip(boxpairs, pvals):
        if {c1, c2} == wanted and not np.isnan(pval):
            yield subject, pval


# Box pairs are ((subject, condition_1), (subject, condition_2)) tuples, the
# format expected by the `box_pairs` argument of add_stat_annotation.
rt_boxpairs, rt_pvals = [], []
pd_boxpairs, pd_pvals = [], []

for subject in adf.subject.unique():

    # Conditions are taken from ALL of the participant's trials, the data
    # for the tests only from the correct ones.
    # NOTE(review): a condition without any usable trial therefore yields
    # NaN p-values, which still appear to count towards the size of the
    # Holm family -- confirm this is intended.
    possible_conditions = adf.loc[(adf.subject == subject)].condition.unique()
    cur_adf = _correct_trials(subject)
    trials_by_condition = {cond: cur_adf.loc[cur_adf.condition == cond]
                           for cond in possible_conditions}

    subj_boxpairs = []
    subj_rt_pvals = []
    subj_pd_pvals = []

    for c1, c2 in combinations(possible_conditions, 2):
        first, second = trials_by_condition[c1], trials_by_condition[c2]
        subj_boxpairs.append(((subject, c1), (subject, c2)))

        _, pval = stats.ranksums(first.reaction_time.values,
                                 second.reaction_time.values)
        subj_rt_pvals.append(pval)

        _, pval = stats.ranksums(first.correct_production_time.values,
                                 second.correct_production_time.values)
        subj_pd_pvals.append(pval)

    subj_rt_pvals = _holm_adjust(subj_rt_pvals)
    rt_boxpairs.extend(subj_boxpairs)
    rt_pvals.extend(subj_rt_pvals)

    subj_pd_pvals = _holm_adjust(subj_pd_pvals)
    pd_boxpairs.extend(subj_boxpairs)
    pd_pvals.extend(subj_pd_pvals)

    print(subject, f'{len(subj_rt_pvals)}-way Holm-Bonferroni correction')

print()


# Effect of syllable complexity on production duration (complex sequences only).
total_pd = 0
sig_pd = 0
for _, pv in _tests_between(pd_boxpairs, pd_pvals,
                            'complexSeq_simpleSyl', 'complexSeq_complexSyl'):
    total_pd += 1
    if pv < sig_thresh:
        sig_pd += 1

print(f'{sig_pd} / {total_pd} artic. complexity with significant production duration difference')

# Effect of sequence complexity on production duration (simple syllables only).
total_pd = 0
sig_pd = 0
for subj, pv in _tests_between(pd_boxpairs, pd_pvals,
                               'simpleSeq_simpleSyl', 'complexSeq_simpleSyl'):
    total_pd += 1
    if pv < sig_thresh:
        sig_pd += 1

        # Report participants whose median duration is LONGER for the simple
        # sequences than for the complex ones.
        cur_adf = _correct_trials(subj)
        a = cur_adf.loc[(cur_adf.condition == 'simpleSeq_simpleSyl')].correct_production_time.values
        b = cur_adf.loc[(cur_adf.condition == 'complexSeq_simpleSyl')].correct_production_time.values

        if np.median(a) > np.median(b):
            print(subj, np.median(a), np.median(b))

print(f'{sig_pd} / {total_pd} seq. complexity with significant production duration difference')

# show significant RT (NaN p-values fail the comparison and are dropped)
rt_pvals = np.array(rt_pvals)
rt_keep = rt_pvals < sig_thresh
rt_boxpairs = [rtb for rtb, keep in zip(rt_boxpairs, rt_keep) if keep]
rt_pvals = rt_pvals[rt_keep]

# show only the non-sig production (NaN p-values are dropped here as well)
pd_pvals = np.array(pd_pvals)
pd_keep = pd_pvals > sig_thresh
pd_boxpairs = [pdb for pdb, keep in zip(pd_boxpairs, pd_keep) if keep]
pd_pvals = pd_pvals[pd_keep]

## Fig. S1: reaction times

In [None]:
# Reaction time per participant and condition: box plots with the individual
# trials overlaid, annotated with the significant pairwise comparisons.
fs = 12
default_plot_settings(font='Helvetica', fontsize=fs)

fig, ax = plt.subplots(figsize=(15, 5))

# Trials that are not flagged as mistakes and have a positive reaction time.
# The same frame feeds the box plot, the strip plot and the annotations.
plot_df = adf.loc[(~adf.decoder_data_mistake) & (adf.reaction_time > 0)]

# Plain lists (rather than dict views) for the hue levels and their colors.
hue_order = list(condition_colors)
palette = list(condition_colors.values())
plot_kwargs = dict(data=plot_df, x='subject', y='reaction_time', hue='condition',
                   hue_order=hue_order)

# Disable clipping for every box-plot element.
boxplot_kwargs = {f'{key}props': dict(clip_on=False)
                  for key in ['cap', 'whisker', 'flier', 'median', 'mean']}
boxplot_kwargs['boxprops'] = dict(alpha=0.3, clip_on=False)

ax = sns.boxplot(palette=palette, showfliers=False, ax=ax, zorder=2,
                 **plot_kwargs, **boxplot_kwargs)
ax = sns.stripplot(palette=palette, clip_on=False, ax=ax, zorder=3, dodge=True, s=3,
                   **plot_kwargs)

# The p-values were computed and corrected beforehand, so no test is run here.
add_stat_annotation(text_format='star', ax=ax, verbose=0, pvalue_thresholds=pvalue_thresholds,
                    box_pairs=rt_boxpairs, perform_stat_test=False, pvalues=list(rt_pvals),
                    loc='inside', line_height=0.005, line_offset=0.01, text_offset=0.01,
                    **plot_kwargs)

sns.despine(ax=ax, offset=dict(left=5, bottom=5))
ax.spines['bottom'].set_visible(False)
ax.tick_params(axis='x', length=0)

# Fix legend: keep one entry per condition, using the scatter (strip plot)
# handles only, and show the display label instead of the condition key.
hand, labl = ax.get_legend_handles_labels()
handout, lablout = [], []
seen_conditions = set()
for h, l in zip(hand, labl):
    # De-duplicate on the raw condition key; `lablout` holds display labels.
    if l in seen_conditions or not isinstance(h, mpl.collections.PathCollection):
        continue
    seen_conditions.add(l)
    lablout.append(condition_dict[l])
    handout.append(h)
ax.legend(handout, lablout, frameon=False, bbox_to_anchor=(1, 1), loc='upper left', fontsize=fs-3)

# Dotted horizontal reference line at 1.1 s.
# NOTE(review): document what this value marks.
ax.axhline(y=1.1, color='k', linewidth=1, linestyle=':', zorder=1)
ax.axes.set(ylim=(0, 5), ylabel='Reaction time (s)', xlabel='');

In [None]:
# Summary of the reaction times pooled over all participants, restricted to
# trials that are not flagged as mistakes and have a positive reaction time.
is_valid_trial = ~adf.decoder_data_mistake & (adf.reaction_time > 0)
cur_adf = adf.loc[is_valid_trial]
rts = cur_adf.reaction_time.values

n_participants = len(cur_adf.subject.unique())

# NOTE(review): the interval reported as "95% CI" is the 2.5th-97.5th
# percentile range of the pooled trials -- confirm this is the intended quantity.
print(
    f'Over {n_participants} participants',
    f'{rts.mean():0.2f} s on average',
    f'{rts.std():0.2f} s st dev',
    f'{np.percentile(rts, [2.5, 97.5])} 95% CI',
    f'{len(rts)} total trials',
    sep='\n',
)

## Fig. S2: production durations

In [None]:
# Production duration per participant and condition: box plots with the
# individual trials overlaid, annotated with the NON-significant pairwise
# comparisons.
fs = 12
default_plot_settings(font='Helvetica', fontsize=fs)

fig, ax = plt.subplots(figsize=(15, 5))

# Trials that are not flagged as mistakes and have a positive reaction time.
# The same frame feeds the box plot, the strip plot and the annotations.
plot_df = adf.loc[(~adf.decoder_data_mistake) & (adf.reaction_time > 0)]

# Plain lists (rather than dict views) for the hue levels and their colors.
hue_order = list(condition_colors)
palette = list(condition_colors.values())
plot_kwargs = dict(data=plot_df, x='subject', y='correct_production_time', hue='condition',
                   hue_order=hue_order)

# Disable clipping for every box-plot element.
boxplot_kwargs = {f'{key}props': dict(clip_on=False)
                  for key in ['cap', 'whisker', 'flier', 'median', 'mean']}
boxplot_kwargs['boxprops'] = dict(alpha=0.3, clip_on=False)

ax = sns.boxplot(palette=palette, showfliers=False, ax=ax, zorder=2,
                 **plot_kwargs, **boxplot_kwargs)
ax = sns.stripplot(palette=palette, clip_on=False, ax=ax, zorder=3, dodge=True, s=3,
                   **plot_kwargs)

# The p-values were computed and corrected beforehand, so no test is run here.
add_stat_annotation(text_format='star', ax=ax, verbose=1, pvalue_thresholds=pvalue_thresholds,
                    box_pairs=pd_boxpairs, perform_stat_test=False, pvalues=list(pd_pvals),
                    loc='inside', line_height=0.005, line_offset=0.01, text_offset=0.01,
                    **plot_kwargs)

sns.despine(ax=ax, offset=dict(left=5, bottom=5))
ax.spines['bottom'].set_visible(False)
ax.tick_params(axis='x', length=0)

# Fix legend: keep one entry per condition, using the scatter (strip plot)
# handles only, and show the display label instead of the condition key.
hand, labl = ax.get_legend_handles_labels()
handout, lablout = [], []
seen_conditions = set()
for h, l in zip(hand, labl):
    # De-duplicate on the raw condition key; `lablout` holds display labels.
    if l in seen_conditions or not isinstance(h, mpl.collections.PathCollection):
        continue
    seen_conditions.add(l)
    lablout.append(condition_dict[l])
    handout.append(h)
ax.legend(handout, lablout, frameon=False, bbox_to_anchor=(1, 1), loc='upper left', fontsize=fs-3)

ax.axes.set(ylim=(0, 5), ylabel='Production duration (s)', xlabel='');