In [None]:
import os

# Root of the results repository. Set the PERCEPTUAL_TUNING_RESULTS environment
# variable to run the notebook from another location; the default is the
# original author's local checkout.
root = os.environ.get('PERCEPTUAL_TUNING_RESULTS',
                      '/Users/admin/Documents/PhD/Code/perceptual-tuning-results/')

# Input folders (os.path.join also copes with a root lacking a trailing '/')
mp_folder = os.path.join(root, 'ABX/mp_scores')
analysis_folder = os.path.join(root, 'ABX/analyses/avg_error/resampling')

# Output figure files
fig_path = os.path.join(root, 'ABX/figures/avg.pdf')
fig_robustness_path = os.path.join(root, 'ABX/figures/avg_robustness.pdf')
fig_baselines_path = os.path.join(root, 'ABX/figures/avg_baselines.pdf')

In [None]:
# For development/debugging: comment out the PGF setup below and uncomment
# the inline-backend magic instead.
#%matplotlib inline


# PGF setup used to export the finalized figures (requires a LaTeX installation
# providing unicode-math and the Doulos SIL font).

import matplotlib as mpl
mpl.use("pgf")
pgf_with_custom_preamble = {
    "font.family": "serif",  # use serif/main font for text elements
    "text.usetex": True,     # use inline math for ticks
    "pgf.rcfonts": False,    # don't setup fonts from rc parameters
    # Matplotlib >= 3.3 expects the preamble as one string (lists are
    # deprecated), so the individual lines are joined with newlines.
    "pgf.preamble": "\n".join([
        r"\usepackage{unicode-math}",  # unicode math setup
        r"\setmainfont{Doulos SIL}",   # serif font via preamble
    ]),
}
mpl.rcParams.update(pgf_with_custom_preamble)


from scone_phobia import apply_analysis
from scone_phobia.utils.mp_scores import estimate_std
from scone_phobia.analyses.avg_error import avg_error
import scone_phobia.metadata.add_metadata as add_metadata
from scone_phobia.plots.catplot import custom_catplot
import matplotlib.patches as patches
import numpy as np
import seaborn as sns

We specialise the generic avg_error analysis slightly, by averaging over 4 train/test conditions and over C and V scores. When variability estimates of our error estimates are aggregated, dependency issues could arise. To avoid them, we apply the resampling directly to the specialised analysis. This is also necessary to perform permutation tests.

We do, however, cache the resampling results at the level of the generic avg_error analysis output. The specialised analysis is cheap to apply to resamples, and caching it would require writing a dedicated caching scheme.

In [None]:
# Loading (or computing if it's the first time) avg_error analysis results with full resamples


# select relevant models
dpgmm = 'dpgmm_novtln_vad'


def percTun_filt(mp_fname):
    """Return True if the mp-score file name belongs to a model analysed here.

    Kept models: the supervised HMM ('AMtri1_sat_small_LMtri1satsmall'),
    the MFCC baseline ('mfcc_novtln') and the unsupervised GMM (``dpgmm``).
    Matching is a plain substring test on the file name.
    """
    return any(tag in mp_fname
               for tag in ('AMtri1_sat_small_LMtri1satsmall', 'mfcc_novtln', dpgmm))

# Launch the avg_error analysis with resampling enabled; resamples are cached
# under analysis_folder (resample_caching_scheme='mp_file') and only computed
# when no cache is found.
analysis = avg_error
df_avg, boot_df_avg = apply_analysis(
    analysis,
    mp_folder,
    filt=percTun_filt,
    add_metadata=add_metadata.language_register,
    resampling=True,
    resample_caching_scheme='mp_file',
    analysis_folder=analysis_folder,
    pickle_encoding=None,
    resampled_pickle_encoding="latin1",
    verbose=0,
)

# The errors get aggregated further below and those aggregates may have
# dependencies, so the per-row resampled std estimates are of no use: drop them.
df_avg = df_avg.drop(columns='std')

In [None]:
# Label each result with its train/test relationship: same corpus, or
# same/different language crossed with same/different register.

def get_traintest_match(df_row=None):
    """Condition getter describing how training and test sets relate.

    Without argument, return the condition column names: ['train/test match'].
    With a row, return a one-element list holding its label:
      'NA'      if the training or test set is the string 'None';
      'same C'  if training and test sets are the same corpus;
      otherwise '<same|diff> L <same|diff> R', depending on whether the
      training/test languages (L) and registers (R) are equal.
    """
    if df_row is None:
        return ['train/test match']

    train_set, test_set = df_row['training set'], df_row['test set']
    train_lang, test_lang = df_row['training language'], df_row['test language']
    train_reg, test_reg = df_row['training register'], df_row['test register']

    if train_set == 'None' or test_set == 'None':
        return ['NA']  # not applicable
    if train_set == test_set:
        return ['same C']  # same corpus

    lang_part = 'same L' if train_lang == test_lang else 'diff L'
    reg_part = 'same R' if train_reg == test_reg else 'diff R'
    return [lang_part + ' ' + reg_part]


# Finer-grained conditions for the extended data: every combination of
# training/test language and training/test register (16 conditions).
def get_full_condition(df_row=None):
    """Condition getter keeping the raw language and register columns.

    Without argument, return the four condition column names. With a row,
    return that row's values for those columns, in the same order.
    """
    cond_cols = ['training language', 'test language',
                 'training register', 'test register']
    if df_row is not None:
        return [df_row[name] for name in cond_cols]
    return cond_cols


def agg_conds(df, get_condition, df_is_resampled=False):
    """Add condition columns to ``df`` and average its numeric columns per condition.

    Parameters
    ----------
    df : pandas.DataFrame
        Results with 'model type' and 'contrast type' columns plus every column
        ``get_condition`` reads. It is not modified.
    get_condition : callable
        ``get_condition()`` returns the list of condition column names;
        ``get_condition(row)`` returns the matching list of values for a row.
    df_is_resampled : bool, optional
        If True, also group by 'batch ID', 'batch size' and 'boot ID' so that
        each resample is aggregated separately.

    Returns
    -------
    cond_cols : list of str
        The condition column names.
    agg_df : pandas.DataFrame
        Mean of the numeric columns for each (model type, contrast type,
        conditions[, resample]) group, with the grouping keys as columns.
    """
    df = df.copy(deep=True)  # we're adding columns, so make a copy
    cond_cols = get_condition()
    cond_rows = [get_condition(row) for _, row in df.iterrows()]
    # One list per condition column; unlike zip(*rows), this still creates
    # the (empty) columns when df has no rows.
    for i, col in enumerate(cond_cols):
        df[col] = [values[i] for values in cond_rows]
    groupby_cols = ['model type', 'contrast type'] + cond_cols
    if df_is_resampled:
        # ensure separate analysis for each resample
        groupby_cols = groupby_cols + ['batch ID', 'batch size', 'boot ID']
    # Non-key text columns (e.g. 'training set') cannot be averaged: older
    # pandas dropped them silently, pandas >= 2.0 raises unless told to skip them.
    agg_df = df.groupby(groupby_cols, as_index=False).mean(numeric_only=True)
    return cond_cols, agg_df

# Coarse train/test relationship conditions (main figure)
cond_cols_agg, df_agg = agg_conds(df_avg, get_traintest_match)
boot_df_agg = agg_conds(boot_df_avg, get_traintest_match, df_is_resampled=True)[1]

# All language x register combinations (robustness figure)
cond_cols_control, df_control = agg_conds(df_avg, get_full_condition)
boot_df_control = agg_conds(boot_df_avg, get_full_condition, df_is_resampled=True)[1]

In [None]:
# (C + V) / 2
def avg_C_V(df, cond_cols, df_is_resampled=False):
    """Average the 'C' and 'V' contrast-type scores into one score per condition.

    Rows whose 'contrast type' is neither 'C' nor 'V' are discarded. The
    remaining rows are averaged within each (model type, conditions[, resample])
    group.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of ``agg_conds`` (one row per contrast type and condition).
    cond_cols : list of str
        Condition column names to group by.
    df_is_resampled : bool, optional
        If True, also group by 'batch ID', 'batch size' and 'boot ID'.

    Returns
    -------
    pandas.DataFrame
        Mean of the numeric columns per group, grouping keys as columns.
    """
    res_df = df[df['contrast type'].isin(['C', 'V'])]
    groupby_cols = ['model type'] + cond_cols
    if df_is_resampled:
        # ensure separate analysis for each resample
        groupby_cols = groupby_cols + ['batch ID', 'batch size', 'boot ID']
    # 'contrast type' is text and not a grouping key; pandas >= 2.0 raises on
    # it unless non-numeric columns are explicitly excluded.
    res_df = res_df.groupby(groupby_cols, as_index=False).mean(numeric_only=True)
    return res_df

# Collapse C and V scores at both aggregation levels
df = avg_C_V(df_agg, cond_cols_agg)
boot_df = avg_C_V(boot_df_agg, cond_cols_agg, df_is_resampled=True)

df_control, boot_df_control = (
    avg_C_V(df_control, cond_cols_control),
    avg_C_V(boot_df_control, cond_cols_control, df_is_resampled=True),
)

In [None]:
# Estimate standard deviations from the resampled results
df, df_control = (estimate_std(df, boot_df),
                  estimate_std(df_control, boot_df_control))

### Main figure

In [None]:
col_order = [dpgmm]
col_labels = ['']
x_order = ['same C', 'same L diff R', 'diff L same R', 'diff L diff R']
xticklabels = ['Same language\n Same register', 'Same language\n Different register',
               'Different language\n Same register', 'Different language\n Different register']
err_args = {'ecolor': 'k',
            'capsize': 2,
            'elinewidth': 2,
            'markeredgewidth': 2}
# colors courtesy of Marianne
# my_red, my_blue = (0.85, .325, .098, 1), (0, .447, .741, 1)
my_red, my_blue = 'r', 'b'

# main part: one bar per train/test relationship for the dpgmm model,
# red when training and test languages differ, blue otherwise
palette = {e: (my_red if 'diff L' in e else my_blue) for e in df['train/test match']}
g, x_dict = custom_catplot(x='train/test match', y="error", yerr="std",
                           col="model type",
                           data=df,
                           kind="bar",
                           err_args=err_args,
                           order=x_order,
                           col_order=col_order,
                           legend=False,
                           palette=palette)


# labels, fontsize etc.
g.set_xticklabels(xticklabels, rotation=45, ha='right', fontsize=15)
# Tick.label was removed in Matplotlib 3.8; tick_params sets the y tick label size
g.axes[0, 0].tick_params(axis='y', labelsize=20)
# raw string: '\%' is a LaTeX escape (text.usetex), not a Python one
g.set_ylabels(r'ABX error rate (in \%)', fontsize=20)
g.set_xlabels('Train/Test relationship', fontsize=20)
for ax, t in zip(g.axes.flatten(), col_labels):
    ax.tick_params(axis='both', which='both', width=0, length=0)
    ax.set_axisbelow(True)
    ax.grid(axis='y')
    ax.set_title(t, fontsize=25)
g.despine(left=True)
# y range set to half that for fig2?
g.axes[0, 0].set_ylim([0, 11])
g.axes[0, 0].set_xlim([-.48, 3.48])


g.savefig(fig_path)
# legend: blue is matched language, red mismatched language

### Figure showing robustness of the results across training and test sets

In [None]:
def plot_res_robustness(data, model_type, fig_path=None, ymax=16):
    """Plot one model's ABX error for every train/test language x register pair.

    Draws a 2x2 grid of bar plots with one panel per test register (rows:
    'Read', 'Spontaneous') and test language (columns: 'American English',
    'Japanese'). Within a panel the x axis is the training register and the
    bar colour the training language (blue: American English, green:
    Japanese); error bars use the 'std' column (passed as ``yerr``).

    Parameters
    ----------
    data : pandas.DataFrame
        Aggregated results with 'model type', 'error', 'std', 'training
        language', 'test language', 'training register' and 'test register'.
    model_type : str
        Value of the 'model type' column to plot.
    fig_path : str or None, optional
        Where to save the figure; nothing is saved when None.
    ymax : float, optional
        Upper y-axis limit applied to every panel.
    """
    # Keep only the requested model
    df = data[data["model type"] == model_type]
    # Panel titles, row by row (same order as g.axes.flatten())
    facet_labels = ["Read American English test", "Read Japanese test",
                    "Spont. American English test", "Spont. Japanese test"]
    col_order = ["American English", "Japanese"]
    row_order = ["Read", "Spontaneous"]
    x_order = ['Read', 'Spontaneous']
    hue_order = ["American English", "Japanese"]
    xticklabels = ["Read training", "Spont. training"]
    err_args = {'ecolor': 'k',
                'capsize': 2,
                'elinewidth': 2,
                'markeredgewidth': 2}
    palette = {'American English': 'b', 'Japanese': 'g'}

    g, _ = custom_catplot(x="training register", y="error", yerr="std",
                          order=x_order,
                          col="test language",
                          col_order=col_order,
                          row="test register",
                          row_order=row_order,
                          hue="training language",
                          hue_order=hue_order,
                          data=df,
                          kind="bar",
                          err_args=err_args,
                          legend=False,
                          sharex=False,
                          palette=palette)

    # labels, fontsize etc.
    g.set_xticklabels(xticklabels, fontsize=15)
    # Enlarge y tick labels of the first panel in each row; Tick.label was
    # removed in Matplotlib 3.8, tick_params sets the same font size.
    for row_axes in g.axes:
        row_axes[0].tick_params(axis='y', labelsize=20)
    # raw string: '\%' is a LaTeX escape (text.usetex), not a Python one
    g.set_ylabels(r'ABX error rate (in \%)', fontsize=18)
    g.set_xlabels('Training register', fontsize=18)
    for ax, t in zip(g.axes.flatten(), facet_labels):
        ax.tick_params(axis='both', which='both', width=0, length=0)
        ax.set_axisbelow(True)
        ax.grid(axis='y')
        ax.set_title(t, fontsize=20)
        ax.set_xlim([-.48, 1.48])
        ax.set_ylim([0, ymax])
    g.despine(left=True)
    g.fig.tight_layout()

    if fig_path is not None:
        g.savefig(fig_path)
    # legend: blue is American English, green Japanese (training language)

In [None]:
plot_res_robustness(data=df_control, model_type=dpgmm, fig_path=fig_robustness_path)

### Extended Figure: with baseline and topline

In [None]:
# Break down into MFCC baseline and other scores for plot
is_mfcc = df["model type"] == 'mfcc_novtln'
mfcc_df = df[is_mfcc]
# The baseline is drawn as a single horizontal line, so exactly one MFCC row is
# expected: .item() enforces that (float() on a Series is deprecated in pandas 2.1).
mfcc_err, mfcc_std = mfcc_df['error'].item(), mfcc_df['std'].item()
main_df = df[~is_mfcc]

In [None]:
# Figure: unsupervised GMM and supervised HMM side by side, with the MFCC
# baseline shown as a dashed line inside a grey +/- 1 std band
col_order = [dpgmm, 'AMtri1_sat_small_LMtri1satsmall']
col_labels = ['GMM (unsupervised)', 'HMM (supervised)']
x_order = ['same C', 'same L diff R', 'diff L same R', 'diff L diff R']
xticklabels = ['Same language\n Same register', 'Same language\n Different register',
               'Different language\n Same register', 'Different language\n Different register']
err_args = {'ecolor': 'k',
            'capsize': 2,
            'elinewidth': 2,
            'markeredgewidth': 2}
# colors courtesy of Marianne
# my_red, my_blue = (0.85, .325, .098, 1), (0, .447, .741, 1)
my_red, my_blue = 'r', 'b'

# main part: red when training and test languages differ, blue otherwise
palette = {e: (my_red if 'diff L' in e else my_blue) for e in df['train/test match']}
g, x_dict = custom_catplot(x="train/test match", y="error", yerr="std",
                           col="model type",
                           data=main_df,
                           kind="bar",
                           err_args=err_args,
                           order=x_order,
                           col_order=col_order,
                           legend=False,
                           palette=palette)
# baseline
for ax in g.axes.flatten():
    mi, ma = ax.get_xlim()
    # dashed horizontal line at the MFCC error rate, put in the background
    line = ax.plot([mi, ma], [mfcc_err, mfcc_err], 'k--')
    line[0].set_zorder(0)
    # grey band spanning mfcc_err +/- mfcc_std, also in the background
    rect = patches.Rectangle((mi, mfcc_err - mfcc_std), ma - mi, 2 * mfcc_std,
                             edgecolor=(.8, .8, .8, 1),
                             facecolor=(.8, .8, .8, 1))
    rect.set_zorder(0)
    ax.add_patch(rect)

# labels, fontsize etc.
g.set_xticklabels(xticklabels, rotation=45, ha='right', fontsize=15)
# Tick.label was removed in Matplotlib 3.8; tick_params sets the y tick label size
g.axes[0, 0].tick_params(axis='y', labelsize=20)
# raw string: '\%' is a LaTeX escape (text.usetex), not a Python one
g.set_ylabels(r'ABX error rate (in \%)', fontsize=20)
g.set_xlabels('Train/Test relationship', fontsize=20)
for ax, t in zip(g.axes.flatten(), col_labels):
    ax.tick_params(axis='both', which='both', width=0, length=0)
    ax.set_axisbelow(True)
    ax.grid(axis='y')
    ax.set_title(t, fontsize=25)
g.despine(left=True)
# y range set to half that for fig2?
g.axes[0, 0].set_ylim([0, 15])
g.axes[0, 0].set_xlim([-.48, 3.48])
g.axes[0, 1].set_xlim([-.48, 3.48])


g.savefig(fig_baselines_path)
# legend: blue is matched language, red mismatched language; black dashed line
# and grey band are the MFCC baseline +/- 1 std