# Run Parameters

In [None]:
# Model whose figures are produced below; also selects y-limits and error-bar offsets.
_model = 'roberta-large' # roberta-large or roberta-large-mnli
# When True, the "Combine Plotting Data" cells rebuild the files under eval_summary/plot_data.
overwrite_plotting_data = False # Set to True if running for first time with new experiment data

# Import

In [None]:
# Interactive matplotlib backend (IPython line magic; needs the ipympl package).
%matplotlib widget

In [None]:
import os
import collections

import pandas as pd
import numpy as np
import json
import joblib
import pickle
import torch

import matplotlib.pyplot as plt
import matplotlib.path as pth
import seaborn as sns

from scipy.stats import ttest_ind, f_oneway

import statsmodels.api as sm
from statsmodels.formula.api import ols

In [None]:
# Seaborn theme applied to every figure in this notebook.
sns.set_theme(style='whitegrid')

# Define

In [None]:
def get_val_summary(modifier, iteration, eval_dir, ):
    """Return the per-configuration validation table for round *iteration*.

    Reads ``<eval_dir>/r<iteration>/tables/configs.<modifier>.csv`` with the
    first column as the index, and keeps only the columns ``'1'`` through
    ``str(iteration)`` in ascending order.
    """
    table_path = os.path.join(
        eval_dir,
        'r{}'.format(iteration),
        'tables',
        'configs.{}.csv'.format(modifier),
    )
    round_columns = list(map(str, range(1, iteration + 1)))
    return pd.read_csv(table_path, index_col=0)[round_columns]


def get_itereval_summary(sub_keys, iteration, eval_dir, combined, ):
    """Return the iterative-eval table selected by *sub_keys* for round *iteration*.

    The values of *sub_keys* are joined with ``'.'`` to build the file key,
    with ``'/'`` written as ``'-'`` and ``';'`` as ``'--'`` so the key is a
    safe file name. The file read is
    ``<eval_dir>/r<iteration>/tables/<combined>/iterevals.<key>.csv``, and only
    the columns ``'1'`` through ``str(iteration)`` are kept.
    """
    # Both substitutions are independent single-character rewrites, so one
    # translate pass is equivalent to applying them one after another.
    escape_table = str.maketrans({'/': '-', ';': '--'})
    fname_key = '.'.join(sub_keys.values()).translate(escape_table)

    table_path = os.path.join(
        eval_dir, f'r{iteration}', 'tables', combined, f'iterevals.{fname_key}.csv'
    )
    round_columns = [str(n) for n in range(1, iteration + 1)]
    return pd.read_csv(table_path, index_col=0)[round_columns]
    

In [None]:
def get_mnli_tables(mnli_summary, subsetting='genre', model=None):
    """Load MNLI eval summaries and pivot them into per-(comb, subset) tables.

    Parameters
    ----------
    mnli_summary : str
        Path to a JSON-lines file. Each record needs at least the fields
        ``comb``, ``treat``, ``iter``, ``acc`` and the *subsetting* field.
    subsetting : str
        Field used to split each ``comb`` group (default ``'genre'``).
    model : optional
        First element of every key in the returned dict. Earlier versions
        silently read a notebook-global ``model`` here; pass it explicitly.

    Returns
    -------
    (summary, mnli_tables)
        ``summary`` is the raw DataFrame. ``mnli_tables`` maps
        ``(model, comb, subset)`` to a DataFrame with one row per ``treat``
        and one column per ``iter`` holding ``acc``. Missing *subsetting*
        values (NaN/None) form their own subset.
    """
    with open(mnli_summary, 'r') as f:
        summary = pd.DataFrame([json.loads(line) for line in f])

    mnli_tables = {}
    for comb in summary['comb'].unique():
        comb_sum = summary.loc[summary['comb'] == comb, :]

        # Only subsets that occur within this comb. Iterating the global unique
        # values yields empty groups, and pd.concat([]) then raises.
        for subset in comb_sum[subsetting].unique():
            # NaN/None never compares equal with ==, so select it with isna().
            if pd.isna(subset):
                subset_mask = comb_sum[subsetting].isna()
            else:
                subset_mask = comb_sum[subsetting] == subset
            subset_sum = comb_sum.loc[subset_mask, :]

            plot_tab = []
            for treat in subset_sum['treat'].unique():
                treat_sum = subset_sum.loc[subset_sum['treat'] == treat, :]
                # One row named after the treat, one column per iteration.
                s = treat_sum[['iter', 'acc']].set_index('iter').rename({'acc': treat}, axis=1).transpose()
                plot_tab.append(s)

            mnli_tables[(model, comb, subset)] = pd.concat(plot_tab)

    return summary, mnli_tables

In [None]:
def split_run_name(run_name, split_by='_'):
    """Split a run name into ``(treat, iter, input_type, comb)``.

    The name is split on *split_by*. The first two parts are always the
    treatment and the iteration, returned as strings. The remaining parts
    decide the last two fields:

    * 2 parts                   -> ``('full', 'combined')``
    * 3 parts, last is ``'hyp'`` -> ``('hyp', 'combined')``
    * 3 parts otherwise         -> ``('full', <last part>)``
    * any other count           -> ``(<last part>, <second-to-last part>)``
    """
    parts = run_name.split(split_by)
    n_parts = len(parts)

    if n_parts == 2:
        input_type, comb = 'full', 'combined'
    elif n_parts == 3 and parts[-1] == 'hyp':
        input_type, comb = 'hyp', 'combined'
    elif n_parts == 3:
        input_type, comb = 'full', parts[-1]
    else:
        input_type, comb = parts[-1], parts[-2]

    return (parts[0], parts[1], input_type, comb)

In [None]:
def unique_itereval(df, keys=('case', 'subcase', 'label', 'dataset', 'treat', 'iter', 'comb', 'sample_type', 'sample_partition')):
    """Drop rows that repeat the same combination of *keys*, keeping the first.

    The result is renumbered with a fresh 0..n-1 index. *keys* defaults to an
    immutable tuple, so no state is shared between calls. A list of column
    names is still accepted.
    """
    return df.drop_duplicates(subset=list(keys), ignore_index=True)

def load_sampled_results(sampled_base):
    """Load one round's sampled-result CSVs and add derived key columns.

    Reads ``collected.csv``, ``itereval.csv``, ``mnli.csv`` and ``anli.csv``
    from *sampled_base*, in that order. Adds these columns in place:

    * ``collected``: ``treat``, ``iter`` (int), ``mod`` and ``combined``,
      parsed from the ``run`` column by ``split_run_name``.
    * ``mnli`` and ``anli``: ``breakdown``, copied from ``genre`` and ``tag``
      respectively, with missing values filled as ``'combined'``.

    ``itereval`` is de-duplicated with ``unique_itereval``.

    Returns ``(collected, itereval, mnli, anli)``.
    """
    def _read(name):
        return pd.read_csv(os.path.join(sampled_base, name + '.csv'))

    collected, itereval, mnli, anli = (
        _read(name) for name in ('collected', 'itereval', 'mnli', 'anli')
    )

    # Parse each run name once, then fan the parts out into columns.
    parsed_runs = collected['run'].apply(split_run_name)
    collected['treat'] = parsed_runs.apply(lambda parts: parts[0])
    collected['iter'] = parsed_runs.apply(lambda parts: int(parts[1]))
    collected['mod'] = parsed_runs.apply(lambda parts: parts[2])
    collected['combined'] = parsed_runs.apply(lambda parts: parts[3])

    for frame, source_col in ((mnli, 'genre'), (anli, 'tag')):
        frame['breakdown'] = frame[source_col].fillna('combined')

    return collected, unique_itereval(itereval), mnli, anli

In [None]:
def load_all_sampled(sampled_base, upto=5):
    """Load sampled results for rounds 1..*upto* and stack each kind across rounds.

    Each round is read from ``<sampled_base>/r<round>`` by
    ``load_sampled_results``. Returns a dict with keys ``'collected'``,
    ``'itereval'``, ``'mnli'`` and ``'anli'``. Each value is the per-round
    DataFrames concatenated in round order with a fresh index.
    """
    result_names = ('collected', 'itereval', 'mnli', 'anli')
    per_round = [
        load_sampled_results(os.path.join(sampled_base, f'r{r}'))
        for r in range(1, upto + 1)
    ]

    return {
        name: pd.concat([loaded[position] for loaded in per_round], ignore_index=True)
        for position, name in enumerate(result_names)
    }
    

In [None]:
def get_ttest_pvals(dist_df, verbose=True):
    """Run independent two-sample t-tests on ``acc`` between protocol pairs.

    The pairs compared are (baseline, LotS), (baseline, LitL) and (LotS, LitL).
    Rows are selected by the ``treat`` column.

    Returns a dict mapping each pair to scipy's ``ttest_ind`` result (two-sided).
    When *verbose* is set, each pair's t statistic is printed together with the
    two-sided p value divided by 2 (the value labelled ``p``).
    """
    comparisons = (
        ('baseline', 'LotS'),
        ('baseline', 'LitL'),
        ('LotS', 'LitL'),
    )

    def _accs(treat):
        # 'acc' values of every row belonging to one protocol
        return dist_df.loc[dist_df['treat'] == treat, 'acc']

    ttest_dict = {
        pair: ttest_ind(_accs(pair[0]), _accs(pair[1]))
        for pair in comparisons
    }

    if not verbose:
        return ttest_dict

    for pair, result in ttest_dict.items():
        print('='*45)
        print(f"{pair}\nt: {result[0]:.5f} | p: {result[1]/2:.5f}")
    return ttest_dict


In [None]:
def two_way_anova(df, f1='iter', f2='treat', acc='acc', formula=None):
    """Fit an OLS model and return its type-II ANOVA table.

    The default model is ``acc ~ C(f1) + C(f2) + C(f1):C(f2)``: both factors
    as categoricals plus their interaction. A custom *formula* may be given.
    It can only reference *f1*, *f2* and *acc*, because the data is reduced to
    those three columns before fitting. The formula used is printed.
    """
    formula = formula or f'{acc} ~ C({f1}) + C({f2}) + C({f1}):C({f2})'
    print(formula)

    fitted = ols(formula, data=df[[f1, f2, acc]]).fit()
    return sm.stats.anova_lm(fitted, typ=2)

def one_way_anova(df, f='iter', acc='acc', formula=None):
    """Fit an OLS model and return its type-I ANOVA table.

    The default model is ``acc ~ C(f)``, with a single categorical factor. A
    custom *formula* may only reference *f* and *acc*, because the data is
    reduced to those columns before fitting. The formula used is printed.
    """
    formula = formula or f'{acc} ~ C({f})'
    print(formula)

    fitted = ols(formula, data=df[[f, acc]]).fit()
    return sm.stats.anova_lm(fitted, typ=1)

# Combine Plotting Data

In [None]:
# '__file__' is a literal string here (a notebook has no __file__), so abspath gives
# <cwd>/__file__ and the two dirname calls resolve to the parent of the working directory.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
sample_type = 'cross_eval'
# Run-wide number of rounds; later cells pass it as `iteration` to the summary readers.
iteration = 5
# Output directory for the combined plotting data written below.
plot_data = os.path.join(repo, 'eval_summary', 'plot_data')
os.makedirs(plot_data, exist_ok=True)

In [None]:
# Models whose evaluation summaries are combined in the cells below.
models = ['roberta-large', 'roberta-large-mnli']

In [None]:
# distributions[model] -> dict with 'collected', 'itereval', 'mnli', 'anli' DataFrames,
# each stacked over rounds 1..iteration of eval_summary/<model>/sample/<sample_type>.
distributions = {}

for model in models:
    eval_dir = os.path.join(repo, 'eval_summary', model)
    distributions[model] = load_all_sampled(
        os.path.join(eval_dir, 'sample', sample_type), upto=iteration
    )

## MNLI Only Training

In [None]:
# Predictions from the MNLI-only trained model, and the validation data they were made on.
pred_base = os.path.join(repo, 'predictions', 'roberta-large-mnli_only')
data_base = os.path.join(repo, 'tasks', 'data')
# Restart runs live under best/<trial>/ (see the n_trials loop below).
best_data = os.path.join(pred_base, 'best')

# pred_dirs[i] is zipped with data_dirs[i] below, so the two lists must stay aligned.
# Results are keyed by the text before the first '_' (e.g. 'baseline', 'mnlieval', 'eval').
pred_dirs = [
    'baseline_5',
    'LotS_5',
    'LitL_5',
    'mnlieval_baseline_1',
    'anlieval_baseline_1',
    'eval_baseline_1',
]

data_dirs = [
    os.path.join('baseline_5', 'val_round5_base_combined.jsonl'),
    os.path.join('LotS_5', 'val_round5_LotS_combined.jsonl'),
    os.path.join('LitL_5', 'val_round5_LitL_combined.jsonl'),
    os.path.join('mnli_mismatched', 'val_mismatched_mnli.jsonl'),
    os.path.join('anli_combined', 'val_anli.jsonl'),
    os.path.join('iterative_eval', 'val_itercombined.jsonl'),
]

# NOTE(review): lrs and batches are not referenced in the cells shown here.
lrs = ['0.00001', '0.00002', '0.00003']
batches = ['16', '32']

n_trials = 10

In [None]:
def read_jsonl(file):
    """Parse a JSON-lines file and return one decoded object per line, in order."""
    records = []
    with open(file, 'r') as handle:
        for raw_line in handle:
            records.append(json.loads(raw_line))
    return records
    
def get_acc(
    preds,
    data,
    int2pred=None,
):
    """Attach string predictions and per-example correctness to *data*.

    Parameters
    ----------
    preds : sequence of int
        Predicted class ids, one per element of *data*, in the same order.
    data : list of dict
        Examples; each needs a ``'label'`` key holding the gold label string.
    int2pred : dict, optional
        Maps class id -> label string. Defaults to
        ``{0: 'contradiction', 1: 'entailment', 2: 'neutral'}``.

    Returns
    -------
    pandas.DataFrame
        *data* as a DataFrame, plus a ``'preds'`` column (label strings) and a
        boolean ``'correct'`` column (``label == preds``).

    Raises
    ------
    ValueError
        If *preds* and *data* have different lengths. Index alignment would
        otherwise silently fill NaN or drop predictions.
    KeyError
        If a prediction id has no entry in *int2pred*.
    """
    if int2pred is None:
        # Built per call so no mutable default is shared between calls.
        int2pred = {0: 'contradiction', 1: 'entailment', 2: 'neutral'}

    df = pd.DataFrame(data)
    pred_series = pd.Series(preds)
    if len(pred_series) != len(df):
        raise ValueError(
            f'got {len(pred_series)} predictions for {len(df)} examples'
        )
    df['preds'] = pred_series.apply(lambda x: int2pred[x])
    df['correct'] = df['label'].eq(df['preds'])
    return df

In [None]:
if overwrite_plotting_data:
    preds_and_data = {}

    for pred, data in zip(pred_dirs, data_dirs):
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted prediction files.
        temp_pred = torch.load(os.path.join(pred_base, pred, 'val_preds.p'))
        temp_data = read_jsonl(os.path.join(data_base, data))
        # Keyed by prefix: 'baseline', 'LotS', 'LitL', 'mnlieval', 'anlieval', 'eval'.
        preds_and_data[pred.split('_')[0]] = {'pred': temp_pred, 'data':temp_data}

    accs = {
        key: get_acc(val['pred']['mnli']['preds'], val['data']) for key, val in preds_and_data.items()
    }

    # Explicit copies: these boolean-mask selections are stored (and, for HANS,
    # modified right below), so make them independent of accs['eval'] instead of
    # mutating a derived slice (SettingWithCopyWarning).
    accs['glue'] = accs['eval'].loc[accs['eval']['dataset'] == 'glue'].copy()

    hans = accs['eval'].loc[accs['eval']['dataset'] == 'hans', :].copy()
    # HANS is two-way: collapse predicted 'neutral' into 'contradiction' (non-entailment).
    tempdict = {'contradiction':'contradiction', 'neutral':'contradiction', 'entailment':'entailment'}
    hans['preds'] = hans['preds'].apply(lambda x: tempdict[x])
    hans['case'] = hans['case'].apply(lambda x: x[0])
    hans['correct'] = hans['label'].eq(hans['preds'])

    accs['hans'] = hans

    with open(os.path.join(plot_data, 'mnli-only-training_accs.p'), 'wb') as f:
        pickle.dump(accs, f)

In [None]:
if overwrite_plotting_data:
    # The validation data is the same for every trial, so read each file once
    # rather than once per trial. Keys are the prefixes of pred_dirs, e.g.
    # 'baseline', 'LotS', 'LitL', 'mnlieval', 'anlieval', 'eval'.
    data_by_key = {}
    for pred, data in zip(pred_dirs, data_dirs):
        data_by_key[pred.split('_')[0]] = read_jsonl(os.path.join(data_base, data))

    for trial in range(1, n_trials+1):
        preds_and_data = {}

        for pred in pred_dirs:
            # NOTE(review): torch.load unpickles arbitrary objects; only load trusted prediction files.
            temp_pred = torch.load(os.path.join(best_data, pred, f'{trial}', 'val_preds.p'))
            key = pred.split('_')[0]
            preds_and_data[key] = {'pred': temp_pred, 'data': data_by_key[key]}

        # get_acc builds a new DataFrame from the data list, so sharing the
        # parsed lists across trials is safe.
        accs = {
            key: get_acc(val['pred']['mnli']['preds'], val['data']) for key, val in preds_and_data.items()
        }

        # Explicit copies so the stored/modified selections are independent of accs['eval'].
        accs['glue'] = accs['eval'].loc[accs['eval']['dataset'] == 'glue'].copy()

        hans = accs['eval'].loc[accs['eval']['dataset'] == 'hans', :].copy()
        # HANS is two-way: collapse predicted 'neutral' into 'contradiction' (non-entailment).
        tempdict = {'contradiction':'contradiction', 'neutral':'contradiction', 'entailment':'entailment'}
        hans['preds'] = hans['preds'].apply(lambda x: tempdict[x])
        hans['case'] = hans['case'].apply(lambda x: x[0])
        hans['correct'] = hans['label'].eq(hans['preds'])

        accs['hans'] = hans

        mnli_out_dir = os.path.join(plot_data, 'mnli_restarts', 'best', f'{trial}')
        os.makedirs(mnli_out_dir, exist_ok = True)
        with open(os.path.join(mnli_out_dir, 'mnli-only-training_accs.p'), 'wb') as f:
            pickle.dump(accs, f)

## Collected

In [None]:
# (combined, input_type) -> modifier used in the configs.<modifier>.csv table names.
select2mod = {
    ('combined', 'full'): 'combined',
    ('combined', 'hyp'): 'hyp',
    ('separate', 'full'): 'separate',
    ('separate', 'hyp'): 'separate_hyp',
}

In [None]:
# Build plot_data/collected.csv: the sampled 'collected' rows plus one long-format
# row per (treat, iteration) from every per-config validation summary table.
all_df = []

if overwrite_plotting_data:
    for model in models:
        collected = []
        eval_dir = os.path.join(repo, 'eval_summary', model)

        for combined, input_type in select2mod.keys():
            mod = select2mod[(combined, input_type)]
            temp = get_val_summary(mod, iteration, eval_dir, )
            # Each table row is one treat; its columns '1'..'<iteration>' become 'iter'.
            for idx, row in temp.iterrows():
                df = pd.DataFrame({
                    'acc': row,
                    'iter': [int(x) for x in row.index.values],
                    'treat':row.name,
                    'mod':input_type,
                    'combined':combined,
                    'model':model,
                })
                collected.append(df)
        collected_t = pd.concat(collected, ignore_index = True)
        # Adds a 'model' column to the sampled DataFrame in place.
        distributions[model]['collected']['model'] = model
        all_df.append(pd.concat([distributions[model]['collected'], collected_t], ignore_index = True))
    df = pd.concat(all_df, ignore_index = True)
    df.to_csv(os.path.join(plot_data, 'collected.csv'))

## GLUE

In [None]:
# GLUE diagnostic (case, subcase) pairs, read from the working directory.
glue_keys = pd.read_csv('glue_case_keys.csv')
print(glue_keys)
glue_labels = ['combined', 'entailment', 'neutral', 'contradiction']

In [None]:
# Build plot_data/glue.csv: sampled GLUE itereval rows plus one long-format row per
# (treat, iteration) for every (case, subcase, label, comb) summary table.
dataset = 'glue'

combineds = ['combined', 'separate']
all_df = []

if overwrite_plotting_data:
    for model in models:
        collected = []
        eval_dir = os.path.join(repo, 'eval_summary', model)

        for idx, caserow in glue_keys.iterrows():
            for label in glue_labels:
                for combined in combineds:
                    sub_keys = {
                        'dataset': dataset,     # either hans or glue
                        'case': caserow['case'],    # combined or specific to respective itereval set
                        'subcase': caserow['subcase'], # combined or specific to respective itereval set
                        'label': label,   # combined or [entailment, neutral, contradiction] for glue, [entailment, non-entailment] for hans
                    }

                    temp = get_itereval_summary(sub_keys, iteration, eval_dir, combined)
                    for idx, row in temp.iterrows():
                        df = pd.DataFrame({
                            'acc': row,
                            'iter': [int(x) for x in row.index.values],
                            'treat':row.name,
                            'case':sub_keys['case'],
                            'subcase':sub_keys['subcase'],
                            'label':sub_keys['label'],
                            'comb':combined,
                            'model':model,
                        })
                        collected.append(df)
        # Concatenate once, after all cases are gathered. This was previously
        # re-run inside the case loop: quadratic work with the same final result.
        collected_t = pd.concat(collected, ignore_index = True)

        temp_sampled = distributions[model]['itereval']
        # .copy() so adding 'model' does not write into a slice of the sampled frame.
        temp_sampled = temp_sampled.loc[temp_sampled['dataset'] == dataset, :].copy()
        temp_sampled['model'] = model
        all_df.append(pd.concat([temp_sampled, collected_t], ignore_index=True))
    df = pd.concat(all_df, ignore_index = True)
    df.to_csv(os.path.join(plot_data, 'glue.csv'))

## HANS

In [None]:
# HANS (case, subcase) pairs, read from the working directory.
hans_keys = pd.read_csv('hans_case_keys.csv')
print(hans_keys)
hans_labels = ['combined', 'entailment', 'non-entailment']

In [None]:
# Build plot_data/hans.csv: sampled HANS itereval rows plus one long-format row per
# (treat, iteration) for every (case, subcase, label, comb) summary table.
dataset = 'hans'

combineds = ['combined', 'separate']
all_df = []

if overwrite_plotting_data:
    for model in models:
        collected = []
        eval_dir = os.path.join(repo, 'eval_summary', model)

        for idx, caserow in hans_keys.iterrows():
            for label in hans_labels:
                for combined in combineds:
                    sub_keys = {
                        'dataset': dataset,     # either glue or hans
                        'case': caserow['case'],    # combined or specific to respective itereval set
                        'subcase': caserow['subcase'], # combined or specific to respective itereval set
                        'label': label,   # combined or [entailment, neutral, contradiction] for glue, [entailment, non-entailment] for hans
                    }

                    temp = get_itereval_summary(sub_keys, iteration, eval_dir, combined)
                    for idx, row in temp.iterrows():
                        df = pd.DataFrame({
                            'acc': row,
                            'iter': [int(x) for x in row.index.values],
                            'treat':row.name,
                            'case':sub_keys['case'],
                            'subcase':sub_keys['subcase'],
                            'label':sub_keys['label'],
                            'comb':combined,
                            'model':model,
                        })
                        collected.append(df)
        # Concatenate once, after all cases are gathered. This was previously
        # re-run inside the case loop: quadratic work with the same final result.
        collected_t = pd.concat(collected, ignore_index = True)

        temp_sampled = distributions[model]['itereval']
        # .copy() so adding 'model' does not write into a slice of the sampled frame.
        temp_sampled = temp_sampled.loc[temp_sampled['dataset'] == dataset, :].copy()
        temp_sampled['model'] = model
        all_df.append(pd.concat([temp_sampled, collected_t], ignore_index=True))
    df = pd.concat(all_df, ignore_index = True)
    df.to_csv(os.path.join(plot_data, 'hans.csv'))

## MNLI

In [None]:
# Build plot_data/mnli.csv: sampled MNLI rows plus each model's eval_summaries.jsonl.
all_df = []

if overwrite_plotting_data:
    for model in models:
        eval_dir = os.path.join(repo, 'eval_summary', model)
        mnli_summary = os.path.join(eval_dir, 'mnli_evals', 'eval_summaries.jsonl')

        with open(mnli_summary, 'r') as f:
            summary = pd.DataFrame([json.loads(line) for line in f])
        # Rows without a genre are the all-genre results; label them 'combined'
        # (same convention as load_sampled_results).
        summary['breakdown'] = summary['genre'].fillna('combined')
        summary['iter'] = summary['iter'].apply(lambda x: int(x))

        temp = pd.concat([distributions[model]['mnli'], summary], ignore_index=True)
        temp['model'] = model

        all_df.append(temp)
    df = pd.concat(all_df, ignore_index = True)
    df.to_csv(os.path.join(plot_data, 'mnli.csv'))

## ANLI

In [None]:
# Build plot_data/anli.csv from each model's final anli_by_annotation.csv.
all_df = []

if overwrite_plotting_data:
    for model in models:
        eval_dir = os.path.join(repo, 'eval_summary', model)
        df = pd.read_csv(os.path.join(eval_dir, 'sample', sample_type, 'final', 'anli_by_annotation.csv'))
        df['model'] = model

        all_df.append(df)
    df = pd.concat(all_df, ignore_index = True)
    df.to_csv(os.path.join(plot_data, 'anli.csv'))

# Plot

## Plot Params

In [None]:
# Validation variant to plot: used to filter the 'combined'/'comb' columns and in saved figure names.
combined = 'combined'

In [None]:
# Same repo resolution as above: parent of the working directory.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
# Inputs written by the "Combine Plotting Data" cells.
plot_dir = os.path.join(repo, 'eval_summary', 'plot_data')
# Figures are saved per selected model.
plot_out = os.path.join(repo, 'eval_summary', 'plots', _model)
os.makedirs(plot_out, exist_ok=True)

In [None]:
# NOTE(review): not referenced in the cells shown here.
acc_name = 'Performance'
diff_name = 'Over Baseline'

In [None]:
save_figs = True
figtype='pdf'

err_style='bars' # band or bars
err_kws={'elinewidth': 2, 'capsize': 3}

# Read as globals inside err_line_plots.
title_fontsize=18
label_fontsize=16
legend_fontsize=14

# Horizontal error-bar shifts so the three protocols' bars do not overlap.
# err_line_plots passes the line value as the bar LineCollection's `offsets`
# (presumably display units, not data units — confirm). It adds the cap value
# to the caps' x data.
err_line_offset = 4 if _model == 'roberta-large' else 5
err_cap_offset = 0.05


In [None]:
# Column renames applied to plot_df before plotting.
cols2plot = {'treat':'Protocol', 'model':'Model'}
# Display names: 'LotS' runs are shown as "LitL" and 'LitL' runs as "LitL Chat".
treat2plot = {'baseline':'Baseline', 'LotS':'LitL', 'LitL':'LitL Chat'}
model2plot = {'roberta-large': r'RoBERTa$_{\rm{Lg}}$', 'roberta-large-mnli': r'RoBERTa$_{\rm{Lg+MNLI}}$'}

hue='Protocol'
hue_order=['Baseline', 'LitL', 'LitL Chat']
style_key='Model'
# A single entry: err_line_plots then keeps only this model's rows.
style_order=[
    model2plot[_model]
]

# Keys must be the *displayed* protocol names (the values of treat2plot and
# hue_order), because 'treat' is mapped through treat2plot before plotting.
# The previous key 'baseline' never matched 'Baseline'.
# NOTE(review): err_line_plots currently accepts but does not forward palette.
palette={
    'Baseline':'tab:blue',
    'LitL':'tab:orange',
    'LitL Chat':'tab:green',
}

xlim = [0.8, 5.2]

xlim = [0.8, 5.2]

In [None]:
# Pickle written by the MNLI-only cell above; pickle.load must only be used on trusted local files.
with open(os.path.join(plot_dir, 'mnli-only-training_accs.p'), 'rb') as f:
    mnli_accs = pickle.load(f)

## Def plot

In [None]:
def err_line_plots(
    plot_df,
    ylim=[0,1],
    ystep=0.1,
    xlim=[1,5],
    title=None,
    xlabel=None,
    ylabel=None,
    tabletitle=None,
    tableon=True,
    x='iter',
    y='acc',
    err_style='bars',
    ci=95,
    estimator='mean',
    markers=True,
    hue='treat',
    hue_order=['baseline', 'LotS', 'LitL'],
    iteration=5,
    bbox_to_anchor=(1.01, 1),
    palette=None,
    style_key='combined',
    style_order=['combined', 'separate'],
    yaxis_visible = True,
    xaxis_visible = True,
    ylabel_visible = True,
    xlabel_visible = True,
    legend_visible = True,
    ax=None,
    figsize=(6.4, 4.8),
    err_kws={'elinewidth': 1, 'capsize': 2},
    err_alpha=0.6,
    linewidth=2,
    markersize=7,
    loc='best',
    ncol=1,
    error_offsets = [
        {
            'line':-err_line_offset,
            'cap':-err_cap_offset,
        },
        {
            'line':0,
            'cap':0,
        },
        {
            'line':+err_line_offset,
            'cap':+err_cap_offset,
        }
    ]
):
    """Draw a seaborn line plot of *y* against *x* with horizontally offset error bars.

    One line is drawn per *hue* level. If *style_order* has exactly one entry,
    *plot_df* is first filtered to rows where *style_key* equals it, and the
    style encoding then reuses *hue*/*hue_order*.

    error_offsets : list of dict, optional
        One ``{'line': ..., 'cap': ...}`` per errorbar container (its length
        must match, otherwise AssertionError). ``'line'`` is set as the
        ``offsets`` of the vertical-bar LineCollection; ``'cap'`` is added to
        the caps' x data. The default is built at definition time from the
        notebook globals ``err_line_offset`` / ``err_cap_offset``. Pass a falsy
        value to disable the shifting.

    X ticks are placed at 1..*iteration*. Font sizes come from the notebook
    globals ``title_fontsize``, ``label_fontsize`` and ``legend_fontsize``.
    *tabletitle*, *tableon*, *estimator* and *palette* are accepted for call
    compatibility but are not used.

    Returns the new Figure when *ax* is None; otherwise draws on *ax* and
    returns None.
    """
    no_ax = not ax
    if not ax:
        fig, ax = plt.subplots(figsize=figsize)

    if len(style_order) == 1:
        plot_df = plot_df.loc[plot_df[style_key] == style_order[0], :]
        style_key, style_order = hue, hue_order

    g = sns.lineplot(
        data=plot_df, x=x, y=y,
        hue=hue, hue_order=hue_order,
        style = style_key, style_order = style_order,
        # Copy so neither the shared default nor the caller's dict can be modified.
        err_style=err_style, err_kws=dict(err_kws),
        ci=ci, markers=markers,
        ax=ax,
        linewidth=linewidth, markersize=markersize,
    )

    if error_offsets:
        assert len(g.containers) == len(error_offsets), (
            f'{len(g.containers)} errorbar containers, error_offsets {len(error_offsets)}'
        )

        # ErrorbarContainer layout: (data_line, caplines, barlinecols).
        for container, offsets in zip(g.containers, error_offsets):
            # offset line
            plt.setp(container[2][0], offsets = [offsets['line'], 0.])

            # offset caps
            # NOTE(review): relies on private Line2D attributes (_xy, _path);
            # may break with other matplotlib versions.
            for cap in container[1]:
                temp = cap._xy
                temp[:, 0] = temp[:, 0] + offsets['cap']
                cap._path = pth.Path(temp)

    plt.setp(g.containers, alpha=err_alpha)

    ax.set_title(title, fontsize=title_fontsize)

    ax.set_xlabel(xlabel if xaxis_visible and xlabel_visible else '', fontsize=label_fontsize)
    # One tick per iteration (previously hard-coded to 1..5).
    ax.set_xticks(np.arange(1, iteration + 1, 1))
    ax.set_xlim(*xlim)
    if not xaxis_visible:
        ax.xaxis.set_ticklabels([])

    ax.set_ylabel(ylabel if yaxis_visible and ylabel_visible else '', fontsize=label_fontsize)
    ax.set_yticks(np.arange(ylim[0], ylim[1]+ystep, ystep))
    ax.set_ylim(*ylim)
    if not yaxis_visible:
        ax.yaxis.set_ticklabels([])

    ax.legend(
        bbox_to_anchor=bbox_to_anchor,
        loc=loc,
        ncol=ncol,
        fontsize=legend_fontsize,
    ).set_visible(legend_visible)

    if no_ax:
        fig.tight_layout()
        return fig

## Combined In-Domain

In [None]:
# Per-figure adjustment of the error-bar shifts for the in-domain panels.
# The order matches hue_order (left, centre, right); the middle protocol is not shifted.
indomain_l_offset = 0 if _model == 'roberta-large' else -1
indomain_c_offset = 0

error_offsets = [
        {
            'line':-(err_line_offset+indomain_l_offset),
            'cap':-(err_cap_offset+indomain_c_offset),
        },
        {
            'line':0,
            'cap':0,
        },
        {
            'line':+(err_line_offset+indomain_l_offset),
            'cap':+(err_cap_offset+indomain_c_offset),
        }
    ]

In [None]:
# ax[0]: in-domain validation; ax[1]: hypothesis-only input (filled by the cells below).
figsize=(7, 5)
fig, ax = plt.subplots(2, 1, figsize=figsize)

### In-domain: MNLI-only Restarts

In [None]:
# Accuracy of the MNLI-only restarts on each protocol's validation data, per trial:
# 'combined' rows are cumulative over rounds <= it; 'separate' rows cover round == it only.
mnli_collected = []
n_trials = 10

for trial in range(1, n_trials+1):
    with open(os.path.join(plot_dir, 'mnli_restarts', 'best', f'{trial}', 'mnli-only-training_accs.p'), 'rb') as f:
        temp_mnli_accs = pickle.load(f)

    for treat in ['baseline', 'LotS', 'LitL']:
        temp_df = temp_mnli_accs[treat]
        # Round number = last character of the 'round' field.
        temp_df['iter'] = temp_df['round'].apply(lambda x: int(x[-1]))

        # Loop variable is `it`, not `iteration`: in a notebook it is a global, and
        # reusing `iteration` overwrote the run-wide round count set earlier.
        for it in temp_df['iter'].unique():
            temp_acc = temp_df.loc[temp_df['iter'] <= it, :]
            mnli_collected.append(
                {
                    'treat': treat,
                    'iter': it,
                    'model': 'mnli-only',
                    'mod': 'full',
                    'combined': 'combined',
                    'trial': trial,
                    'acc': temp_acc['correct'].sum()/temp_acc.shape[0]
                }
            )

            temp_acc = temp_df.loc[temp_df['iter'] == it, :]
            mnli_collected.append(
                {
                    'treat': treat,
                    'iter': it,
                    'model': 'mnli-only',
                    'mod': 'full',
                    'combined': 'separate',
                    'trial': trial,
                    'acc': temp_acc['correct'].sum()/temp_acc.shape[0]
                }
            )

mnli_df = pd.DataFrame(mnli_collected)

In [None]:
# Keep the cumulative rows and print mean/std accuracy across trials per (treat, iteration).
mnli_df = mnli_df.loc[mnli_df['combined'] == 'combined', :]

for treat in mnli_df['treat'].unique():
    # Filter by treat once per treat rather than once per (treat, iteration).
    treat_df = mnli_df.loc[mnli_df['treat'] == treat, :]
    # `it` rather than `iteration`, so the notebook-global round count is not overwritten.
    for it in mnli_df['iter'].unique():
        temp_df = treat_df.loc[treat_df['iter'] == it, :]

        print(treat, it, f"acc mean: {temp_df['acc'].mean():.3f}", f"acc std: {temp_df['acc'].std():.3f}")

In [None]:
# 2-way ANOVA (iteration x protocol) on the MNLI-only restart accuracies.

iterations = [1, 5]

iter_df = mnli_df.loc[mnli_df['iter'].isin(iterations), :]
anova_table = two_way_anova(iter_df)

# Label the table with the model these rows belong to ('mnli-only'). The old
# print(model) showed a stale global left over from an earlier loop.
print(*iter_df['model'].unique())
print(anova_table)
print('-'*90)

### In Domain

In [None]:
# Long-format validation accuracies written by the "Collected" cell.
plot_df = pd.read_csv(os.path.join(plot_dir, 'collected.csv'))

In [None]:
# Full (premise + hypothesis) input, for the selected `combined` variant.
input_type = 'full'

plot_df = plot_df.loc[plot_df['combined'] == combined, :]
plot_df = plot_df.loc[plot_df['mod'] == input_type, :]

In [None]:
# 2-way ANOVA (iteration x treat) on rounds 1 and 5, separately for each model.

iterations = [1, 5]
models = ['roberta-large', 'roberta-large-mnli']

for model in models:
    iter_df = plot_df.loc[plot_df['iter'].isin(iterations), :]
    iter_df = iter_df.loc[iter_df['model'] == model, :]
    anova_table = two_way_anova(iter_df)
    
    print(model)
    print(anova_table)
    print('-'*90)

In [None]:
# Switch to display labels; 'treat'/'model' become 'Protocol'/'Model'.
plot_df['model'] = plot_df['model'].apply(lambda x: model2plot[x])
plot_df['treat'] = plot_df['treat'].apply(lambda x: treat2plot[x])
plot_df = plot_df.rename(columns=cols2plot)

In [None]:
# Top panel (ax[0]): in-domain validation accuracy per protocol and iteration.
ylims={
    'roberta-large': [0.6, 0.9],
    'roberta-large-mnli': [0.8, 1.0],
}
title=f'In-domain Validation'
xlabel='Iteration'
ylabel='Accuracy'
tabletitle='Median'
tableon=False

err_line_plots(
    plot_df,
    err_style=err_style,
    ylim=ylims[_model],
    title=title,
    xlabel=xlabel,
    ylabel=ylabel,
    tabletitle=tabletitle,
    palette=palette,
    tableon=tableon,
    style_key=style_key,
    style_order=style_order,
    figsize=figsize,
    hue=hue,
    hue_order=hue_order,
    xlim=xlim,
    err_kws=err_kws,
    ax=ax[0],
    # The MNLI-pretrained model's narrower y range gets finer ticks.
    ystep=0.05 if _model == 'roberta-large-mnli' else 0.1,
    legend_visible=True,
    xlabel_visible=False,
    error_offsets=error_offsets,
)

### Hyp

In [None]:
# Reload: the previous plot_df was filtered and relabelled in place.
plot_df = pd.read_csv(os.path.join(plot_dir, 'collected.csv'))

In [None]:
# Hypothesis-only input, for the selected `combined` variant.
input_type = 'hyp'

plot_df = plot_df.loc[plot_df['combined'] == combined, :]
plot_df = plot_df.loc[plot_df['mod'] == input_type, :]

In [None]:
# 2-way ANOVA (iteration x treat) on rounds 1 and 5 for hypothesis-only input, per model.

iterations = [1, 5]
models = ['roberta-large', 'roberta-large-mnli']

for model in models:
    iter_df = plot_df.loc[plot_df['iter'].isin(iterations), :]
    iter_df = iter_df.loc[iter_df['model'] == model, :]
    anova_table = two_way_anova(iter_df)
    
    print(model)
    print(anova_table)
    print('-'*90)

In [None]:
# Switch to display labels; 'treat'/'model' become 'Protocol'/'Model'.
plot_df['model'] = plot_df['model'].apply(lambda x: model2plot[x])
plot_df['treat'] = plot_df['treat'].apply(lambda x: treat2plot[x])
plot_df = plot_df.rename(columns=cols2plot)

In [None]:
# Bottom panel (ax[1]): hypothesis-only accuracy per protocol and iteration.
ylims={
    'roberta-large': [0.3,0.7],
    'roberta-large-mnli': [0.3, 0.7],
}
title=f'Hypothesis-only Input'
xlabel='Iteration'
ylabel='Accuracy'
tabletitle='Median'
tableon=False

# NOTE(review): not passed to the call below (its legend is hidden anyway).
bbox_to_anchor = (0.5, -1.2)

err_line_plots(
    plot_df,
    err_style=err_style,
    ylim=ylims[_model],
    title=title,
    xlabel=xlabel,
    ylabel=ylabel,
    tabletitle=tabletitle,
    palette=palette,
    tableon=tableon,
    style_key=style_key,
    style_order=style_order,
    figsize=figsize,
    err_kws=err_kws,
    hue=hue,
    hue_order=hue_order,
    xlim=xlim,
    xlabel_visible=True,
    legend_visible=False,
    ax=ax[1],
    error_offsets=error_offsets,
)

In [None]:
# get majority class baseline per protocol and round
# Majority-class rate = (count of the most frequent label) / (number of examples)
# in NLI_data/<protocol_dir>/val_round<r>_<protocol><append>.jsonl.
append = '_combined' if combined == 'combined' else ''
nli_data = os.path.join(repo, 'NLI_data')
protocol2dir = {
    'base':'1_Baseline_protocol',
    'LotS':'2_Ling_on_side_protocol',
    'LitL':'3_Ling_in_loop_protocol',
}
rounds = range(1,6)

majority_class = []

# File-name protocol -> display name (same labels as treat2plot).
file2plot = {'base':'Baseline', 'LotS':'LitL', 'LitL':'LitL Chat'}

for protocol, protocol_dir in protocol2dir.items():
    for r in rounds:
        val_name = f'val_round{r}_{protocol}{append}.jsonl'
        val_path = os.path.join(nli_data, protocol_dir, val_name)
        
        labels2count = collections.defaultdict(int)
        with open(val_path, 'r') as f:
            for example in f.readlines():
                label = json.loads(example)['label']
                labels2count[label] += 1
        majority_class.append({
            'Protocol':file2plot[protocol],
            'Iteration':r,
            'Majority Class':max(labels2count.values())/sum(labels2count.values())
        })

In [None]:
# Overlay the majority-class rate on the hypothesis-only panel (ax[1]):
# averaged over protocols (dashed black line) or one line per protocol.
avg = True
majority_class_df = pd.DataFrame(majority_class)

if avg:
    # Average only the numeric column: 'Protocol' holds strings, and
    # groupby().mean() over it raises TypeError in pandas >= 2.0.
    avg_majority = majority_class_df.groupby(by='Iteration')[['Majority Class']].mean()
    # Draw on ax[1] explicitly instead of whatever axes pyplot considers current.
    ax[1].plot(
        avg_majority.index.values, avg_majority['Majority Class'],
        c='k', ls='--'
    )
else:
    sns.lineplot(
        data=majority_class_df, 
        x='Iteration', y='Majority Class', 
        hue='Protocol', style='Protocol', markers=True, ax=ax[1]
    )
ax[1].legend().set_visible(False)

### Save

In [None]:
# Save the two-panel in-domain figure as <combined>_indomain.<figtype> under plot_out.
fig.tight_layout()
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_indomain.{figtype}'))

## Diagnostic Sets

In [None]:
# ANOVA rounds and the display-name models (matching the relabelled 'Model' column).
iterations = [1, 5]
models = style_order
# Per-figure adjustment of the error-bar shifts for the diagnostic panels.
iter_l_offset = -0.75 if _model == 'roberta-large' else -1.75
iter_c_offset = +0.02

In [None]:
# Left / centre / right shifts in hue_order; the middle protocol is not shifted.
error_offsets = [
        {
            'line':-(err_line_offset+iter_l_offset),
            'cap':-(err_cap_offset+iter_c_offset),
        },
        {
            'line':0,
            'cap':0,
        },
        {
            'line':+(err_line_offset+iter_l_offset),
            'cap':+(err_cap_offset+iter_c_offset),
        }
    ]

In [None]:
figsize = (15, 5)
fig, ax = plt.subplots(2, 4, figsize=figsize) # top is GLUE bottom is HANS
# Only three HANS cases are plotted, so the fourth bottom panel is hidden.
ax[1, 3].set_axis_off()

### GLUE

In [None]:
# GLUE diagnostic accuracies written by the "GLUE" cell.
plot_df = pd.read_csv(os.path.join(plot_dir, 'glue.csv'))

In [None]:
# MNLI-only model accuracy on GLUE diagnostics: one overall row per case, plus
# one row per (case, label).
mnli_collected = []
temp_df = mnli_accs['glue']
# The case key is the first element of 'case' ('' when it is empty).
temp_df['case_text'] = temp_df['case'].apply(lambda x: x[0] if len(x) > 0 else '')

for case in temp_df['case_text'].unique():
    temp_acc = temp_df.loc[temp_df['case_text'] == case, :]
    mnli_collected.append({
        'acc': temp_acc['correct'].sum()/temp_acc.shape[0],
        'subcase': 'combined',
        'label': 'combined',
        'model': 'mnli-only',
        'case': case,
    })
    
    for label in temp_acc['label'].unique():
        temptemp_acc = temp_acc.loc[temp_acc['label'] == label, :]
        mnli_collected.append({
            'acc': temptemp_acc['correct'].sum()/temptemp_acc.shape[0],
            'subcase': 'combined',
            'label': label,
            'model': 'mnli-only',
            'case': case,
        })
mnli_df = pd.DataFrame(mnli_collected)

In [None]:
# Plot all labels together, per case (subcase 'combined'), for the selected variant.
label = 'combined'

plot_df = plot_df.loc[plot_df['comb'] == combined, :]
plot_df = plot_df.loc[plot_df['label'] == label, :]
plot_df = plot_df.loc[plot_df['subcase'] == 'combined', :]

In [None]:
# Same label/subcase selection for the MNLI-only reference values.
mnli_df = mnli_df.loc[mnli_df['label'] == label, :]
mnli_df = mnli_df.loc[mnli_df['subcase'] == 'combined', :]

In [None]:
# Switch to display labels; 'treat'/'model' become 'Protocol'/'Model'.
plot_df['model'] = plot_df['model'].apply(lambda x: model2plot[x])
plot_df['treat'] = plot_df['treat'].apply(lambda x: treat2plot[x])
plot_df = plot_df.rename(columns=cols2plot)

In [None]:
# Top row (ax[0, i]): one panel per GLUE diagnostic case, followed by a per-case ANOVA.
ylims={
    'roberta-large':{
        'Knowledge':[0.4,0.7],
        'Lexical Semantics':[0.5, 0.8],
        'Logic': [0.4,0.7],
        'Predicate-Argument Structure':[0.5,0.8],
    },
    'roberta-large-mnli':{
        'Knowledge':[0.4,0.7],
        'Lexical Semantics':[0.5, 0.8],
        'Logic': [0.4,0.7],
        'Predicate-Argument Structure':[0.5,0.8],
    },
}
title=f""
xlabel='Iteration'
ylabel='GLUE'
tabletitle='median'
tableon=False

glue_keys = pd.read_csv('glue_case_keys.csv')

i = 0
for case in glue_keys['case'].unique():
    if case == 'combined':
        continue
    
    temp_df = plot_df.loc[plot_df['case'] == case, :]
    err_line_plots(
        temp_df,
        err_style=err_style,
        ylim=ylims[_model][case],
        title=f"{case}",
        # x label only on the last panel; the "- 2" assumes 'combined' is one of the unique cases.
        xlabel=xlabel if i == len(glue_keys['case'].unique()) - 2 else None,
        ylabel=ylabel,
        tabletitle=tabletitle,
        palette=palette,
        tableon=tableon,
        style_key=style_key,
        style_order=style_order,
        ax=ax[0, i],
        ylabel_visible = i == 0,
        legend_visible = False,
        err_kws=err_kws,
        hue=hue,
        hue_order=hue_order,
        xlim=xlim,
        error_offsets=error_offsets,
    )
    
    # Horizontal reference line at the MNLI-only accuracy, for the MNLI-pretrained model only.
    if _model == 'roberta-large-mnli':
        temp_mnli = mnli_df.loc[mnli_df['case'] == case, :]
        ax[0, i].hlines(temp_mnli['acc'], xlim[0], xlim[1], label='mnli-only', zorder=10)
    
    i += 1
    
    # Columns were renamed, so the factor is 'Protocol' and models are display names.
    for model in models:
        iter_df = temp_df.loc[temp_df['iter'].isin(iterations), :]
        iter_df = iter_df.loc[iter_df['Model'] == model, :]
        anova_table = two_way_anova(iter_df, f2='Protocol')

        print(model, case)
        print(anova_table)
        print('-'*90)

### HANS non-entailment

In [None]:
# HANS accuracies written by the "HANS" cell.
plot_df = pd.read_csv(os.path.join(plot_dir, 'hans.csv'))

In [None]:
# MNLI-only model accuracy on HANS: one overall row per case, plus one row per
# (case, label).
mnli_collected = []
temp_df = mnli_accs['hans']
temp_df['case_text'] = temp_df['case']

for case in temp_df['case_text'].unique():
    temp_acc = temp_df.loc[temp_df['case_text'] == case, :]
    mnli_collected.append({
        'acc': temp_acc['correct'].sum()/temp_acc.shape[0],
        'subcase': 'combined',
        'label': 'combined',
        'model': 'mnli-only',
        'case': case,
    })
    
    for label in temp_acc['label'].unique():
        temptemp_acc = temp_acc.loc[temp_acc['label'] == label, :]
        mnli_collected.append({
            'acc': temptemp_acc['correct'].sum()/temptemp_acc.shape[0],
            'subcase': 'combined',
            # The HANS labels here use 'contradiction' for the non-entailment class;
            # rename it to match the 'non-entailment' label in hans.csv.
            'label': 'non-entailment' if label == 'contradiction' else label,
            'model': 'mnli-only',
            'case': case,
        })
mnli_df = pd.DataFrame(mnli_collected)

In [None]:
# Keep only the current combination, the non-entailment label, and the
# 'combined' subcase aggregate.
label = 'non-entailment'

plot_df = plot_df.loc[plot_df['comb'] == combined, :]
plot_df = plot_df.loc[plot_df['label'] == label, :]
plot_df = plot_df.loc[plot_df['subcase'] == 'combined', :]

In [None]:
# Apply the same label/subcase selection to the MNLI-only baseline rows.
mnli_df = mnli_df.loc[mnli_df['label'] == label, :]
mnli_df = mnli_df.loc[mnli_df['subcase'] == 'combined', :]

In [None]:
# Map raw model/treatment identifiers to display names (a KeyError is raised
# for any value missing from the mapping dicts) and rename columns via
# cols2plot.
plot_df['model'] = plot_df['model'].apply(lambda x: model2plot[x])
plot_df['treat'] = plot_df['treat'].apply(lambda x: treat2plot[x])
plot_df = plot_df.rename(columns=cols2plot)

In [None]:
# Per-model, per-case y-axis limits for the HANS non-entailment panels.
ylims={
    'roberta-large':{
        'constituent': [0.0, 0.6],
        'lexical_overlap': [0.0, 0.8],
        'subsequence': [0.0, 0.6],
    },
    'roberta-large-mnli':{
        'constituent': [0.1, 0.5],
        'lexical_overlap': [0.6, 1.0],
        'subsequence': [0.1, 0.5],
    },
}
title=f""
xlabel='Iteration'
ylabel='HANS Non-Entailment'
tabletitle='median'
tableon=False

bbox_to_anchor = (1.15, 1)

# Case list, read relative to the current working directory.
hans_keys = pd.read_csv('hans_case_keys.csv')

case2title = {
    'constituent': 'Constituent',
    'lexical_overlap': 'Lexical Overlap',
    'subsequence': 'Subsequence',
}

# Row 1 of the shared figure: one panel per HANS case ('combined' skipped),
# followed by a printed two-way ANOVA per model.
i = 0
for case in hans_keys['case'].unique():
    if case == 'combined':
        continue
    
    temp_df = plot_df.loc[plot_df['case'] == case, :]
    err_line_plots(
        temp_df,
        err_style=err_style,
        ylim=ylims[_model][case],
        title=f"{case2title[case]}",
        xlabel=xlabel,
        ylabel=ylabel,
        tabletitle=tabletitle,
        palette=palette,
        tableon=tableon,
        style_key=style_key,
        style_order=style_order,
        ax=ax[1, i],
        ylabel_visible = i == 0,
        # Legend only on the last panel (assumes one 'combined' entry).
        legend_visible = i == len(hans_keys['case'].unique()) - 2,
        bbox_to_anchor=bbox_to_anchor,
        err_kws=err_kws,
        hue=hue,
        hue_order=hue_order,
        xlim=xlim,
        ystep=0.2,
        error_offsets=error_offsets,
    )
    
    # MNLI-only baseline reference line on the same panel.
    if _model == 'roberta-large-mnli':
        temp_mnli = mnli_df.loc[mnli_df['case'] == case, :]
        ax[1, i].hlines(temp_mnli['acc'], xlim[0], xlim[1], label='mnli-only', zorder=10)
    
    i += 1
    
    for model in models:
        iter_df = temp_df.loc[temp_df['iter'].isin(iterations), :]
        iter_df = iter_df.loc[iter_df['Model'] == model, :]
        anova_table = two_way_anova(iter_df, f2='Protocol')

        print(model, case)
        print(anova_table)
        print('-'*90)

### Save

In [None]:
# Finalize the figure populated by the panel loops above and optionally save it.
fig.tight_layout()
fig.subplots_adjust(wspace=2.5e-1)
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_itereval.{figtype}'))

## Combined Held-out

In [None]:
# Two stacked panels, one per held-out eval set (filled by the loop below).
figsize=(7, 5)
fig, ax = plt.subplots(2, 1, figsize=figsize)

In [None]:
# Iterations included in the ANOVA, and the model levels to test separately.
iterations = [1, 5]
models = style_order

In [None]:
# NOTE(review): this value is not read before the loop below reassigns
# mnli_df = mnli_accs[f'{case}eval'] for every case.
mnli_df = mnli_accs['mnlieval']

In [None]:
# Per-model y-axis limits for each held-out eval set.
ylims={
    'roberta-large':{
        'mnli':[0.6,0.9],
        'anli':[0.2,0.5]
    },
    'roberta-large-mnli':{
        'mnli':[0.8, 1.0],
        'anli':[0.3, 0.4]
    },
}
title=f""
xlabel='Iteration'
ylabel='Accuracy'
tabletitle='median'
tableon=False

bbox_to_anchor = (1.01, 1)

case2title = {
    'mnli': 'MNLI-mismatched',
    'anli': 'ANLI',
}

# For each held-out set: load its plot data, keep the current combination and
# the 'combined' breakdown, map display names, plot into ax[i], draw the
# MNLI-only reference accuracy, and print a two-way ANOVA per model.
i = 0
for case, title_name in case2title.items():
    plot_df = pd.read_csv(os.path.join(plot_dir, f'{case}.csv'))
    plot_df = plot_df.loc[plot_df['comb'] == combined, :]
    plot_df = plot_df.loc[plot_df['breakdown'] == 'combined', :]
    mnli_df = mnli_accs[f'{case}eval']
    
    plot_df['model'] = plot_df['model'].apply(lambda x: model2plot[x])
    plot_df['treat'] = plot_df['treat'].apply(lambda x: treat2plot[x])
    plot_df = plot_df.rename(columns=cols2plot)
    
    err_line_plots(
        plot_df,
        err_style=err_style,
        ylim=ylims[_model][case],
        title=f"{title_name}",
        xlabel=xlabel,
        ylabel=ylabel,
        tabletitle=tabletitle,
        palette=palette,
        tableon=tableon,
        style_key=style_key,
        style_order=style_order,
        ax=ax[i],
        xlabel_visible = i == 1,  # x label only on the bottom panel
        legend_visible = i == 0,  # legend only on the top panel
        bbox_to_anchor=bbox_to_anchor,
        err_kws=err_kws,
        hue=hue,
        hue_order=hue_order,
        xlim=xlim,
        # Finer y ticks for the narrower roberta-large-mnli limits.
        ystep=0.05 if _model == 'roberta-large-mnli' else 0.1,
    )
    
    # MNLI-only baseline accuracy = fraction of 'correct' over all rows.
    if _model == 'roberta-large-mnli':
        ax[i].hlines(mnli_df['correct'].sum()/mnli_df.shape[0], xlim[0], xlim[1], label='mnli-only', zorder=10)
    
    i += 1
    
    for model in models:
        iter_df = plot_df.loc[plot_df['iter'].isin(iterations), :]
        iter_df = iter_df.loc[iter_df['Model'] == model, :]
        anova_table = two_way_anova(iter_df, f2='Protocol')

        print(model, case)
        print(anova_table)
        print('-'*90)

### Save

In [None]:
# Finalize and optionally save the held-out (MNLI-mismatched / ANLI) figure.
fig.tight_layout()
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_val.{figtype}'))

## ANLI Breakdown

In [None]:
# Extra spacing added to err_line_offset / err_cap_offset in the ANLI
# breakdown error_offsets below.
anli_l_offset = -0
anli_c_offset = +0.075

In [None]:
# Three symmetric offset entries (negative, zero, positive), presumably one per
# hue level in hue_order; consumed by err_line_plots — verify against it.
error_offsets = [
    {
        'line':-(err_line_offset+anli_l_offset),
        'cap':-(err_cap_offset+anli_c_offset),
    },
    {
        'line':0,
        'cap':0,
    },
    {
        'line':+(err_line_offset+anli_l_offset),
        'cap':+(err_cap_offset+anli_c_offset),
    },
    ]

In [None]:
# ANLI output directory (not referenced in the cells shown here) and the two
# model display names plotted as separate rows of the breakdown grid.
anli_plot_out = os.path.join(repo, 'eval_summary', 'plots')
anli_style_order=[
    r'RoBERTa$_{\rm{Lg}}$', 
    r'RoBERTa$_{\rm{Lg+MNLI}}$'
]

In [None]:
# Per-run ANLI results, including a per-category 'breakdown' column.
plot_df = pd.read_csv(os.path.join(plot_dir, 'anli.csv'))

In [None]:
# ANLI annotation categories; 'combined' is skipped by the plotting loop and
# 'EventCoref' is currently commented out.
breakdowns = [
        'combined',
        'Basic',
#         'EventCoref',
        'Imperfection',
        'Numerical',
        'Reasoning',
        'Reference',
        'Tricky',
    ]

In [None]:
# MNLI-only predictions on ANLI, scored per category via get_anli_breakdown_acc.
mnli_df = mnli_accs['anlieval']

In [None]:
# ANLI category annotations (A1+A2 per the filename), stored one directory
# above the repo root. joblib.load unpickles the file, so it must be trusted.
repo_up = os.path.dirname(repo)
anli_annot_fname = os.path.join(repo_up, 'anli_annot_v0.2_combined_A1A2')
anli_annot = joblib.load(anli_annot_fname)

In [None]:
def get_anli_breakdown_acc(pred_df, anli_annot, breakdown):
    """Return the accuracy of ``pred_df`` on ANLI examples tagged with ``breakdown``.

    Predictions are inner-joined to the annotations on ``uid``; rows whose
    ``breakdown`` annotation equals the string ``'none'`` are dropped, and the
    result is ``sum(correct) / number_of_remaining_rows``.

    Args:
        pred_df: Table with at least ``uid`` and ``correct`` columns.
        anli_annot: Annotation table with ``uid`` and a ``breakdown`` column.
        breakdown: Name of the annotation column to select examples by.

    Returns:
        Fraction of retained rows marked correct. An empty selection is not
        special-cased.
    """
    joined = anli_annot[['uid', breakdown]].merge(pred_df[['uid', 'correct']], on='uid')
    tagged = joined[joined[breakdown] != 'none']
    n_tagged = len(tagged)
    return tagged['correct'].sum() / n_tagged

In [None]:
# Keep only the current combination.
plot_df = plot_df.loc[plot_df['comb'] == combined, :]

In [None]:
# Map raw identifiers to display names and rename columns via cols2plot.
plot_df['model'] = plot_df['model'].apply(lambda x: model2plot[x])
plot_df['treat'] = plot_df['treat'].apply(lambda x: treat2plot[x])
plot_df = plot_df.rename(columns=cols2plot)

In [None]:
# ANLI per-category breakdown (only produced for the roberta-large run): a
# 2 x (len(breakdowns) - 1) grid with one row per model in anli_style_order
# and one column per annotation category ('combined' is skipped).
if _model == 'roberta-large':
    ylim=[0.25,0.4]
    title=""
    xlabel='Iteration'
    ylabel='Accuracy'
    tabletitle='median'
    tableon=False

    bbox_to_anchor = (1.01, 1)
    figsize=(15, 6)

    anli_n_cols = len(breakdowns) - 1
    fig, ax = plt.subplots(2, anli_n_cols, figsize=figsize)

    i = 0  # running panel index across both rows
    for anli_model in anli_style_order:
        temp_df = plot_df.loc[plot_df['Model'] == anli_model, :]
        for case in breakdowns:
            if case == 'combined':
                continue

            # Map the running panel index onto (row, column) of the grid.
            anli_row, anli_col = divmod(i, anli_n_cols)
            temp_ax = ax[anli_row, anli_col]
            temptemp_df = temp_df.loc[temp_df['breakdown'] == case, :]
            err_line_plots(
                temptemp_df,
                err_style=err_style,
                ylim=ylim,
                title=f"{case}",
                xlabel=xlabel,
                ylabel=anli_model,
                tabletitle=tabletitle,
                palette=palette,
                tableon=tableon,
                style_key=style_key,
                style_order=[anli_model],
                ax=temp_ax,
                yaxis_visible = anli_col == 0,          # leftmost column only
                legend_visible = i == anli_n_cols - 1,  # last panel of row 0
                xlabel_visible = anli_row == 1,         # bottom row only
                bbox_to_anchor=bbox_to_anchor,
                err_kws=err_kws,
                hue=hue,
                hue_order=hue_order,
                xlim=xlim,
                error_offsets=error_offsets,
                ystep=0.05,
            )

            # MNLI-only baseline accuracy restricted to this ANLI category.
            temp_ax.hlines(get_anli_breakdown_acc(mnli_df, anli_annot, case), xlim[0], xlim[1], label='mnli-only', zorder=10)

            i += 1

    fig.tight_layout()
    if save_figs:
        print(os.path.dirname(plot_out))
        fig.savefig(os.path.join(os.path.dirname(plot_out), f'{combined}_anli_breakdown.{figtype}'))

## HANS Entailment

In [None]:
# NOTE(review): recomputed in the 'RoBERTa HANS Entailment' plot cells after
# _model is reset, so these values are overwritten before they are used there.
iterations = [1, 5]
models = style_order
iter_l_offset = -0.75 if _model == 'roberta-large' else -1.75
iter_c_offset = +0.02

In [None]:
# NOTE(review): redefined identically in the 'RoBERTa HANS Entailment' plot
# cells after err_line_offset is recomputed; this copy uses the earlier value.
error_offsets = [
        {
            'line':-(err_line_offset+iter_l_offset),
            'cap':-(err_cap_offset+iter_c_offset),
        },
        {
            'line':0,
            'cap':0,
        },
        {
            'line':+(err_line_offset+iter_l_offset),
            'cap':+(err_cap_offset+iter_c_offset),
        }
    ]

In [None]:
# Font sizes; not passed explicitly to err_line_plots in the cells shown here
# (presumably read as globals by plotting helpers — verify).
title_fontsize=16
label_fontsize=14
legend_fontsize=12

In [None]:
# 2x3 HANS-entailment figure: row 0 is filled by the RoBERTa-Lg cells and
# row 1 by the RoBERTa-Lg+MNLI cells below; one column per HANS case.
figsize = (15, 5)
fig, ax = plt.subplots(2, 3, figsize=figsize)

### RoBERTa HANS Entailment

#### Read Data

In [None]:
# Override the run parameters for the RoBERTa-Lg row of the figure.
combined = 'combined'
_model = 'roberta-large'

In [None]:
# os.path.abspath('__file__') resolves the literal name '__file__' against the
# current working directory, so the two dirname calls make repo the parent of
# the CWD.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
plot_dir = os.path.join(repo, 'eval_summary', 'plot_data')
plot_out = os.path.join(repo, 'eval_summary', 'plots', _model)
os.makedirs(plot_out, exist_ok=True)

In [None]:
# Axis-label strings; not referenced in the cells shown here.
acc_name = 'Performance'
diff_name = 'Over Baseline'

In [None]:
# Output and error-bar settings forwarded to err_line_plots / savefig.
save_figs = True
figtype='pdf'

err_style='bars' # band or bars
err_kws={'elinewidth': 2, 'capsize': 3}

# Base offsets combined into error_offsets below (larger for roberta-large-mnli).
err_line_offset = 4 if _model == 'roberta-large' else 5
err_cap_offset = 0.05


In [None]:
# Raw-name -> display-name mappings and plot styling.
cols2plot = {'treat':'Protocol', 'model':'Model'}
treat2plot = {'baseline':'Baseline', 'LotS':'LitL', 'LitL':'LitL Chat'}
model2plot = {'roberta-large': r'RoBERTa$_{\rm{Lg}}$', 'roberta-large-mnli': r'RoBERTa$_{\rm{Lg+MNLI}}$'}

hue='Protocol'
hue_order=['Baseline', 'LitL', 'LitL Chat']
style_key='Model'
style_order=[
    model2plot[_model]
]

# NOTE(review): the first key is 'baseline' (lowercase) while treat2plot and
# hue_order use 'Baseline'; if err_line_plots looks colours up by hue value,
# the Baseline colour will not be found — confirm.
palette={
        'baseline':'tab:blue',
        'LitL':'tab:orange',
        'LitL Chat':'tab:green',
    }

xlim = [0.8, 5.2]

In [None]:
# MNLI-only-training accuracies, indexed by eval-set name ('hans', ...).
# pickle.load can execute code from the file; only load trusted artifacts.
with open(os.path.join(plot_dir, 'mnli-only-training_accs.p'), 'rb') as f:
    mnli_accs = pickle.load(f)

#### Plot

In [None]:
# Iterations for the ANOVA, model levels to test, and extra error-bar spacing
# added in error_offsets below.
iterations = [1, 5]
models = style_order
iter_l_offset = -0.75 if _model == 'roberta-large' else -1.75
iter_c_offset = +0.02

In [None]:
# Symmetric (negative, zero, positive) offsets, presumably one per hue level.
error_offsets = [
        {
            'line':-(err_line_offset+iter_l_offset),
            'cap':-(err_cap_offset+iter_c_offset),
        },
        {
            'line':0,
            'cap':0,
        },
        {
            'line':+(err_line_offset+iter_l_offset),
            'cap':+(err_cap_offset+iter_c_offset),
        }
    ]

In [None]:
# Per-run HANS results; a filtered copy (plot_dff) is built below.
plot_df = pd.read_csv(os.path.join(plot_dir, 'hans.csv'))

In [None]:
# MNLI-only baseline accuracy table on HANS: one 'combined'-label row per
# case plus one row per gold label within that case.
mnli_collected = []
temp_df = mnli_accs['hans']
# NOTE: mutates the object stored in mnli_accs['hans'] (no copy is taken).
temp_df['case_text'] = temp_df['case']

for case in temp_df['case_text'].unique():
    temp_acc = temp_df.loc[temp_df['case_text'] == case, :]
    mnli_collected.append({
        'acc': temp_acc['correct'].sum()/temp_acc.shape[0],
        'subcase': 'combined',
        'label': 'combined',
        'model': 'mnli-only',
        'case': case,
    })
    
    for label in temp_acc['label'].unique():
        temptemp_acc = temp_acc.loc[temp_acc['label'] == label, :]
        mnli_collected.append({
            'acc': temptemp_acc['correct'].sum()/temptemp_acc.shape[0],
            'subcase': 'combined',
            # 'contradiction' is renamed to HANS's 'non-entailment'.
            'label': 'non-entailment' if label == 'contradiction' else label,
            'model': 'mnli-only',
            'case': case,
        })
mnli_df = pd.DataFrame(mnli_collected)

In [None]:
# Filtered copy for the entailment label and 'combined' subcase; plot_df
# itself is left unfiltered.
label = 'entailment'

plot_dff = plot_df.loc[plot_df['comb'] == combined, :]
plot_dff = plot_dff.loc[plot_dff['label'] == label, :]
plot_dff = plot_dff.loc[plot_dff['subcase'] == 'combined', :]

In [None]:
# Same label/subcase selection for the MNLI-only baseline rows.
mnli_df = mnli_df.loc[mnli_df['label'] == label, :]
mnli_df = mnli_df.loc[mnli_df['subcase'] == 'combined', :]

In [None]:
# Map raw identifiers to display names and rename columns via cols2plot.
plot_dff['model'] = plot_dff['model'].apply(lambda x: model2plot[x])
plot_dff['treat'] = plot_dff['treat'].apply(lambda x: treat2plot[x])
plot_dff = plot_dff.rename(columns=cols2plot)

In [None]:
# Per-model, per-case y-axis limits for the HANS entailment panels.
ylims={
    'roberta-large':{
        'constituent': [0.6, 1.05],
        'lexical_overlap': [0.6, 1.05],
        'subsequence': [0.6, 1.05],
    },
    'roberta-large-mnli':{
        'constituent': [0.6, 1.05],
        'lexical_overlap': [0.6, 1.05],
        'subsequence': [0.6, 1.05],
    },
}
title=""
xlabel='Iteration'
ylabel=r'RoBERTa$_{\rm{Lg}}$'
tabletitle='median'
tableon=False

bbox_to_anchor = (1.055, 1)

hans_keys = pd.read_csv('hans_case_keys.csv')

case2title = {
    'constituent': 'Constituent',
    'lexical_overlap': 'Lexical Overlap',
    'subsequence': 'Subsequence',
}

# Top row (ax[0, :]) of the 2x3 HANS-entailment figure: one panel per HANS
# case ('combined' skipped), then a printed two-way ANOVA per model.
i = 0
for case in hans_keys['case'].unique():
    if case == 'combined':
        continue

    temp_df = plot_dff.loc[plot_dff['case'] == case, :]
    err_line_plots(
        temp_df,
        err_style=err_style,
        ylim=ylims[_model][case],
        title=f"{case2title[case]}",
        xlabel=xlabel,
        ylabel=ylabel,
        tabletitle=tabletitle,
        palette=palette,
        tableon=tableon,
        style_key=style_key,
        style_order=style_order,
        ax=ax[0, i],
        ylabel_visible = i == 0,
        legend_visible = False,
        bbox_to_anchor=bbox_to_anchor,
        err_kws=err_kws,
        hue=hue,
        hue_order=hue_order,
        xlim=xlim,
        ystep=0.2,
        error_offsets=error_offsets,
    )

    if _model == 'roberta-large-mnli':
        temp_mnli = mnli_df.loc[mnli_df['case'] == case, :]
        # Draw the MNLI-only reference on the panel this loop plots into
        # (row 0), not on the RoBERTa-Lg+MNLI row.
        ax[0, i].hlines(temp_mnli['acc'], xlim[0], xlim[1], label='mnli-only', zorder=10)

    i += 1

    for model in models:
        iter_df = temp_df.loc[temp_df['iter'].isin(iterations), :]
        iter_df = iter_df.loc[iter_df['Model'] == model, :]
        anova_table = two_way_anova(iter_df, f2='Protocol')

        print(model, case)
        print(anova_table)
        print('-'*90)

### RoBERTa MNLI HANS Entailment

#### Read Data

In [None]:
# Override the run parameters for the RoBERTa-Lg+MNLI row of the figure.
combined = 'combined'
_model = 'roberta-large-mnli'

In [None]:
# os.path.abspath('__file__') resolves the literal name '__file__' against the
# current working directory, so repo becomes the parent of the CWD.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
plot_dir = os.path.join(repo, 'eval_summary', 'plot_data')
plot_out = os.path.join(repo, 'eval_summary', 'plots', _model)
os.makedirs(plot_out, exist_ok=True)

In [None]:
# Axis-label strings; not referenced in the cells shown here.
acc_name = 'Performance'
diff_name = 'Over Baseline'

In [None]:
# Output and error-bar settings forwarded to err_line_plots / savefig.
save_figs = True
figtype='pdf'

err_style='bars' # band or bars
err_kws={'elinewidth': 2, 'capsize': 3}

# Base offsets combined into error_offsets below (larger for roberta-large-mnli).
err_line_offset = 4 if _model == 'roberta-large' else 5
err_cap_offset = 0.05


In [None]:
# Raw-name -> display-name mappings and plot styling.
cols2plot = {'treat':'Protocol', 'model':'Model'}
treat2plot = {'baseline':'Baseline', 'LotS':'LitL', 'LitL':'LitL Chat'}
model2plot = {'roberta-large': r'RoBERTa$_{\rm{Lg}}$', 'roberta-large-mnli': r'RoBERTa$_{\rm{Lg+MNLI}}$'}

hue='Protocol'
hue_order=['Baseline', 'LitL', 'LitL Chat']
style_key='Model'
style_order=[
    model2plot[_model]
]

# NOTE(review): 'baseline' (lowercase) does not match the 'Baseline' hue
# value produced by treat2plot / listed in hue_order — confirm how
# err_line_plots consumes palette.
palette={
        'baseline':'tab:blue',
        'LitL':'tab:orange',
        'LitL Chat':'tab:green',
    }

xlim = [0.8, 5.2]

In [None]:
# MNLI-only-training accuracies, indexed by eval-set name ('hans', ...).
# pickle.load can execute code from the file; only load trusted artifacts.
with open(os.path.join(plot_dir, 'mnli-only-training_accs.p'), 'rb') as f:
    mnli_accs = pickle.load(f)

#### Plot

In [None]:
# Iterations for the ANOVA, model levels to test, and extra error-bar spacing
# added in error_offsets below.
iterations = [1, 5]
models = style_order
iter_l_offset = -0.75 if _model == 'roberta-large' else -1.75
iter_c_offset = +0.02

In [None]:
# Symmetric (negative, zero, positive) offsets, presumably one per hue level.
error_offsets = [
        {
            'line':-(err_line_offset+iter_l_offset),
            'cap':-(err_cap_offset+iter_c_offset),
        },
        {
            'line':0,
            'cap':0,
        },
        {
            'line':+(err_line_offset+iter_l_offset),
            'cap':+(err_cap_offset+iter_c_offset),
        }
    ]

In [None]:
# Per-run HANS results; a filtered copy (plot_dff) is built below.
plot_df = pd.read_csv(os.path.join(plot_dir, 'hans.csv'))

In [None]:
# MNLI-only baseline accuracy table on HANS: one 'combined'-label row per
# case plus one row per gold label within that case.
mnli_collected = []
temp_df = mnli_accs['hans']
# NOTE: mutates the object stored in mnli_accs['hans'] (no copy is taken).
temp_df['case_text'] = temp_df['case']

for case in temp_df['case_text'].unique():
    temp_acc = temp_df.loc[temp_df['case_text'] == case, :]
    mnli_collected.append({
        'acc': temp_acc['correct'].sum()/temp_acc.shape[0],
        'subcase': 'combined',
        'label': 'combined',
        'model': 'mnli-only',
        'case': case,
    })
    
    for label in temp_acc['label'].unique():
        temptemp_acc = temp_acc.loc[temp_acc['label'] == label, :]
        mnli_collected.append({
            'acc': temptemp_acc['correct'].sum()/temptemp_acc.shape[0],
            'subcase': 'combined',
            # 'contradiction' is renamed to HANS's 'non-entailment'.
            'label': 'non-entailment' if label == 'contradiction' else label,
            'model': 'mnli-only',
            'case': case,
        })
mnli_df = pd.DataFrame(mnli_collected)

In [None]:
# Filtered copy for the entailment label and 'combined' subcase; plot_df
# itself is left unfiltered.
label = 'entailment'

plot_dff = plot_df.loc[plot_df['comb'] == combined, :]
plot_dff = plot_dff.loc[plot_dff['label'] == label, :]
plot_dff = plot_dff.loc[plot_dff['subcase'] == 'combined', :]

In [None]:
# Same label/subcase selection for the MNLI-only baseline rows.
mnli_df = mnli_df.loc[mnli_df['label'] == label, :]
mnli_df = mnli_df.loc[mnli_df['subcase'] == 'combined', :]

In [None]:
# Map raw identifiers to display names and rename columns via cols2plot.
plot_dff['model'] = plot_dff['model'].apply(lambda x: model2plot[x])
plot_dff['treat'] = plot_dff['treat'].apply(lambda x: treat2plot[x])
plot_dff = plot_dff.rename(columns=cols2plot)

In [None]:
# Per-model, per-case y-axis limits for the HANS entailment panels.
ylims={
    'roberta-large':{
        'constituent': [0.6, 1.05],
        'lexical_overlap': [0.6, 1.05],
        'subsequence': [0.6, 1.05],
    },
    'roberta-large-mnli':{
        'constituent': [0.6, 1.05],
        'lexical_overlap': [0.6, 1.05],
        'subsequence': [0.6, 1.05],
    },
}
title=f""
xlabel='Iteration'
ylabel=r'RoBERTa$_{\rm{Lg+MNLI}}$'
tabletitle='median'
tableon=False

# Legend anchor used together with loc='lower center' below (passed through
# to err_line_plots).
bbox_to_anchor = (-.75, -1)

hans_keys = pd.read_csv('hans_case_keys.csv')

case2title = {
    'constituent': 'Constituent',
    'lexical_overlap': 'Lexical Overlap',
    'subsequence': 'Subsequence',
}

# Bottom row (ax[1, :]) of the 2x3 HANS-entailment figure: one panel per
# HANS case ('combined' skipped), then a printed two-way ANOVA per model.
i = 0
for case in hans_keys['case'].unique():
    if case == 'combined':
        continue
    
    temp_df = plot_dff.loc[plot_dff['case'] == case, :]
    err_line_plots(
        temp_df,
        err_style=err_style,
        ylim=ylims[_model][case],
        title=f"{case2title[case]}",
        xlabel=xlabel,
        ylabel=ylabel,
        tabletitle=tabletitle,
        palette=palette,
        tableon=tableon,
        style_key=style_key,
        style_order=style_order,
        ax=ax[1, i],
        ylabel_visible = i == 0,
        # Legend only on the last panel (assumes one 'combined' entry).
        legend_visible = i == len(hans_keys['case'].unique())-2,
        loc='lower center',
        # Column count equals the number of unique cases, 'combined' included.
        ncol=len(hans_keys['case'].unique()),
        bbox_to_anchor=bbox_to_anchor,
        err_kws=err_kws,
        hue=hue,
        hue_order=hue_order,
        xlim=xlim,
        ystep=0.2,
        error_offsets=error_offsets,
    )
    
    # MNLI-only baseline reference line on the same bottom-row panel.
    if _model == 'roberta-large-mnli':
        temp_mnli = mnli_df.loc[mnli_df['case'] == case, :]
        ax[1, i].hlines(temp_mnli['acc'], xlim[0], xlim[1], label='mnli-only', zorder=10)
    
    i += 1
    
    for model in models:
        iter_df = temp_df.loc[temp_df['iter'].isin(iterations), :]
        iter_df = iter_df.loc[iter_df['Model'] == model, :]
        anova_table = two_way_anova(iter_df, f2='Protocol')

        print(model, case)
        print(anova_table)
        print('-'*90)

### Save

In [None]:
# The combined HANS-entailment figure gets its own output directory.
plot_out = os.path.join(repo, 'eval_summary', 'plots', 'HANS_entailment')
os.makedirs(plot_out, exist_ok=True)

In [None]:
# Finalize and optionally save the 2x3 HANS-entailment figure. The extension
# comes from figtype, like the other save cells, instead of a hard-coded 'pdf'.
fig.tight_layout()
fig.subplots_adjust(wspace=2e-1)
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_HANS_entailment.{figtype}'))