# Import

In [1]:
%matplotlib widget

In [2]:
import os

import pandas as pd
import numpy as np
import json
import joblib
import pickle
import torch

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import ttest_ind

In [3]:
# Use seaborn's 'whitegrid' theme for the figures created in this notebook.
sns.set_theme(style='whitegrid')

# Define

In [4]:
def get_val_summary(modifier, iteration, eval_dir, ):
    """Load one validation summary table, restricted to iterations 1..`iteration`.

    Reads ``<eval_dir>/r<iteration>/tables/configs.<modifier>.csv`` using the
    first CSV column as the index, then keeps only the columns named
    ``'1'`` .. ``str(iteration)`` (in that order).

    Args:
        modifier: Name fragment that selects the ``configs.<modifier>.csv`` file.
        iteration: Last iteration to include; also selects the ``r<iteration>`` folder.
        eval_dir: Directory that contains the ``r<iteration>`` folders.

    Returns:
        pandas.DataFrame with one column per iteration.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        KeyError: if any expected iteration column is missing from the CSV.
    """
    table_path = os.path.join(
        eval_dir, f'r{iteration}', 'tables', f'configs.{modifier}.csv'
    )
    # Column labels in the CSV header are strings, so build string labels.
    iteration_columns = list(map(str, range(1, iteration + 1)))
    return pd.read_csv(table_path, index_col=0)[iteration_columns]


def get_itereval_summary(sub_keys, iteration, eval_dir, combined, ):
    """Load one iterative-evaluation summary table, restricted to iterations 1..`iteration`.

    The file name is built by joining the values of `sub_keys` with ``'.'``
    (in dict order) and then replacing ``'/'`` with ``'-'`` and ``';'`` with
    ``'--'``. The table is read from
    ``<eval_dir>/r<iteration>/tables/<combined>/iterevals.<key>.csv``.

    Args:
        sub_keys: Mapping whose string values form the file-name key.
        iteration: Last iteration to include; also selects the ``r<iteration>`` folder.
        eval_dir: Directory that contains the ``r<iteration>`` folders.
        combined: Sub-directory of ``tables`` to read from.

    Returns:
        pandas.DataFrame indexed by the first CSV column, with columns
        ``'1'`` .. ``str(iteration)``.
    """
    # Neither replacement text contains '/' or ';', so a single translation
    # pass gives the same result as replacing one character after the other.
    path_safe = str.maketrans({'/': '-', ';': '--'})
    fname_key = '.'.join(sub_keys.values()).translate(path_safe)

    table_path = os.path.join(
        eval_dir, f'r{iteration}', 'tables', combined, f'iterevals.{fname_key}.csv'
    )
    iteration_columns = [str(n) for n in range(1, iteration + 1)]
    return pd.read_csv(table_path, index_col=0)[iteration_columns]
    

In [5]:
def get_mnli_tables(mnli_summary, subsetting='genre', model=None):
    """Build per-(model, comb, subset) accuracy tables from a JSON-lines summary.

    Each non-blank line of `mnli_summary` is parsed as one JSON record. The
    records must provide the fields ``comb``, ``treat``, ``iter``, ``acc`` and
    the field named by `subsetting`.

    For every ``comb`` value and every `subsetting` value that occurs under it,
    a table is built with one row per ``treat`` and one column per ``iter``,
    holding ``acc``.

    Args:
        mnli_summary: Path to the JSON-lines summary file.
        subsetting: Name of the field used to split each ``comb`` into subsets.
        model: First element of every key in the returned dict. When omitted,
            a module/notebook-level variable named ``model`` is used if one
            exists (the previous implementation read that global implicitly),
            otherwise ``None``.

    Returns:
        ``(summary, mnli_tables)`` where ``summary`` is the DataFrame of all
        records and ``mnli_tables`` maps ``(model, comb, subset)`` to a
        DataFrame.

    Notes:
        Subsets that select no rows (for example records whose `subsetting`
        value is missing, which never compare equal) are skipped instead of
        raising from ``pd.concat`` on an empty list.
    """
    if model is None:
        # Backward compatibility with callers that relied on a global `model`.
        model = globals().get('model')

    with open(mnli_summary, 'r') as f:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file).
        summary = pd.DataFrame([json.loads(line) for line in f if line.strip()])

    mnli_tables = {}
    for comb in summary['comb'].unique():
        comb_sum = summary.loc[summary['comb'] == comb, :]

        # Only look at subsets that actually occur under this `comb`.
        for subset in comb_sum[subsetting].unique():
            subset_sum = comb_sum.loc[comb_sum[subsetting] == subset, :]

            plot_tab = []
            for treat in subset_sum['treat'].unique():
                treat_sum = subset_sum.loc[subset_sum['treat'] == treat, :]
                # One row named after the treatment, one column per iteration.
                row = (
                    treat_sum[['iter', 'acc']]
                    .set_index('iter')
                    .rename({'acc': treat}, axis=1)
                    .transpose()
                )
                plot_tab.append(row)

            if not plot_tab:
                continue

            mnli_tables[(model, comb, subset)] = pd.concat(plot_tab)

    return summary, mnli_tables

In [6]:
def split_run_name(run_name, split_by='_'):
    """Split a run name into ``(first, second, input_type, comb)``.

    The name is split on `split_by`. The first two pieces are returned as-is;
    the remaining two values depend on how many pieces there are:

    * 2 pieces: ``input_type='full'``, ``comb='combined'``
    * 3 pieces ending in ``'hyp'``: ``input_type='hyp'``, ``comb='combined'``
    * 3 pieces otherwise: ``input_type='full'``, ``comb=<last piece>``
    * any other count: ``input_type=<last piece>``, ``comb=<second-to-last piece>``

    Raises:
        IndexError: if the name yields fewer than two pieces.
    """
    parts = run_name.split(split_by)
    n_parts = len(parts)
    tail = parts[-1]

    if n_parts == 2:
        input_type, comb = 'full', 'combined'
    elif n_parts == 3 and tail == 'hyp':
        input_type, comb = tail, 'combined'
    elif n_parts == 3:
        input_type, comb = 'full', tail
    else:
        input_type, comb = tail, parts[-2]

    return (parts[0], parts[1], input_type, comb)

In [7]:
def load_sampled_results(sampled_base):
    """Read the four sampled-result CSVs from `sampled_base` and add key columns.

    Reads ``collected.csv``, ``itereval.csv``, ``mnli.csv`` and ``anli.csv``.

    * ``collected`` gains ``treat``, ``iter`` (int), ``mod`` and ``combined``,
      all derived from its ``run`` column via ``split_run_name``.
    * ``mnli`` gains ``breakdown`` = ``genre`` with missing values set to ``'combined'``.
    * ``anli`` gains ``breakdown`` = ``tag`` with missing values set to ``'combined'``.

    Returns:
        Tuple ``(collected, itereval, mnli, anli)`` of DataFrames.
    """
    def read_table(name):
        return pd.read_csv(os.path.join(sampled_base, name))

    collected = read_table('collected.csv')
    itereval = read_table('itereval.csv')
    mnli = read_table('mnli.csv')
    anli = read_table('anli.csv')

    # Key columns recovered from the run name, added in this order.
    run_name_fields = {
        'treat': lambda run: split_run_name(run)[0],
        'iter': lambda run: int(split_run_name(run)[1]),
        'mod': lambda run: split_run_name(run)[2],
        'combined': lambda run: split_run_name(run)[3],
    }
    for column, extract in run_name_fields.items():
        collected[column] = collected['run'].apply(extract)

    # Rows without a genre/tag are labelled 'combined'.
    mnli['breakdown'] = mnli['genre'].fillna('combined')
    anli['breakdown'] = anli['tag'].fillna('combined')

    return collected, itereval, mnli, anli

In [8]:
def load_all_sampled(sampled_base, upto=5):
    """Load sampled results for rounds ``r1`` .. ``r<upto>`` and stack them per table.

    Calls ``load_sampled_results`` on ``<sampled_base>/r<r>`` for each round and
    concatenates the per-round frames (with a fresh integer index).

    Returns:
        dict with keys ``'collected'``, ``'itereval'``, ``'mnli'`` and ``'anli'``,
        each mapping to the concatenated DataFrame.
    """
    # Order matches the tuple returned by load_sampled_results.
    table_names = ('collected', 'itereval', 'mnli', 'anli')

    per_round = [
        load_sampled_results(os.path.join(sampled_base, f'r{r}'))
        for r in range(1, upto + 1)
    ]

    return {
        name: pd.concat([loaded[position] for loaded in per_round], ignore_index=True)
        for position, name in enumerate(table_names)
    }
    

In [9]:
def get_ttest_pvals(dist_df, verbose=True):
    """Run pairwise independent t-tests on ``acc`` between treatments.

    The compared pairs are (baseline, LotS), (baseline, LitL) and (LotS, LitL),
    selected through the ``treat`` column of `dist_df`.

    Args:
        dist_df: DataFrame with ``treat`` and ``acc`` columns.
        verbose: When true, print the t statistic and half of the p-value
            returned by ``ttest_ind`` for each pair (presumably to report a
            one-tailed p-value; confirm the intended direction).

    Returns:
        dict mapping each ``(treat_a, treat_b)`` pair to the unmodified
        ``ttest_ind`` result.
    """
    treatments = ['baseline', 'LotS', 'LitL']
    # Every unordered pair, in the order the treatments are listed.
    comparisons = [
        (first, second)
        for position, first in enumerate(treatments)
        for second in treatments[position + 1:]
    ]

    def acc_of(treat):
        return dist_df.loc[dist_df['treat'] == treat, 'acc']

    ttest_dict = {
        (first, second): ttest_ind(acc_of(first), acc_of(second))
        for first, second in comparisons
    }

    if verbose:
        for pair, result in ttest_dict.items():
            statistic, pvalue = result[0], result[1]
            print('='*45)
            print(f"{pair}\nt: {statistic:.5f} | p: {pvalue/2:.5f}")

    return ttest_dict
    

# Combine Plotting Data

In [10]:
# Set to True to rebuild the CSVs written under `plot_data`; while False, the
# bodies of the "Combine Plotting Data" cells below are skipped.
overwrite_plotting_data = False

In [11]:
# '__file__' is a string literal here (not the module attribute), so abspath()
# resolves it against the current working directory; two dirname() calls then
# give the parent of the working directory.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
sample_type = 'cross_eval'  # sub-directory of <eval_dir>/sample to read from
iteration = 5  # last round/iteration loaded and tabulated
# Output directory for the combined plotting CSVs; created if missing.
plot_data = os.path.join(repo, 'eval_summary', 'plot_data')
os.makedirs(plot_data, exist_ok=True)

In [12]:
# Each name is used as a sub-directory of eval_summary and as the 'model' column value.
models = ['roberta-large', 'roberta-large-mnli']

In [13]:
# distributions[model] -> dict of the stacked sampled tables
# ('collected', 'itereval', 'mnli', 'anli') for rounds 1..iteration.
distributions = {}

# NOTE: `model` and `eval_dir` stay defined at notebook level after this loop.
for model in models:
    eval_dir = os.path.join(repo, 'eval_summary', model)
    distributions[model] = load_all_sampled(
        os.path.join(eval_dir, 'sample', sample_type), upto=iteration
    )

## Collected

In [14]:
# Maps (combined, input_type) to the `modifier` used in the
# configs.<modifier>.csv file name read by get_val_summary.
select2mod = {
    ('combined', 'full'): 'combined',
    ('combined', 'hyp'): 'hyp',
    ('separate', 'full'): 'separate',
    ('separate', 'hyp'): 'separate_hyp',
}

In [15]:
all_df = []

# Build collected.csv: for each model, the sampled 'collected' rows followed by
# the validation-summary rows reshaped to one row per (treat, iteration).
if overwrite_plotting_data:
    for model in models:
        eval_dir = os.path.join(repo, 'eval_summary', model)

        collected = []
        for (combined, input_type), mod in select2mod.items():
            val_table = get_val_summary(mod, iteration, eval_dir)
            for _, accs in val_table.iterrows():
                # `accs` is one table row: index = iteration labels, name = treatment.
                collected.append(pd.DataFrame({
                    'acc': accs,
                    'iter': [int(x) for x in accs.index.values],
                    'treat': accs.name,
                    'mod': input_type,
                    'combined': combined,
                    'model': model,
                }))
        collected_t = pd.concat(collected, ignore_index=True)

        # Adds the column on the frame stored in `distributions` (in place).
        sampled = distributions[model]['collected']
        sampled['model'] = model
        all_df.append(pd.concat([sampled, collected_t], ignore_index=True))

    df = pd.concat(all_df, ignore_index=True)
    df.to_csv(os.path.join(plot_data, 'collected.csv'))

## GLUE

In [16]:
# (case, subcase) pairs to tabulate; the path is relative to the working directory.
glue_keys = pd.read_csv('glue_case_keys.csv')
print(glue_keys)
# Label values used in the file-name key ('combined' plus the three label names).
glue_labels = ['combined', 'entailment', 'neutral', 'contradiction']

                            case                                subcase
0                       combined                               combined
1                      Knowledge                               combined
2                      Knowledge                           Common sense
3                      Knowledge                        World knowledge
4              Lexical Semantics                               combined
..                           ...                                    ...
64  Predicate-Argument Structure  Relative clauses;Anaphora/Coreference
65  Predicate-Argument Structure         Relative clauses;Restrictivity
66  Predicate-Argument Structure                          Restrictivity
67  Predicate-Argument Structure     Restrictivity;Anaphora/Coreference
68  Predicate-Argument Structure         Restrictivity;Relative clauses

[69 rows x 2 columns]


In [17]:
dataset = 'glue'

combineds = ['combined', 'separate']
all_df = []

# Build glue.csv: for each model, the sampled itereval rows for this dataset
# followed by the summary-table rows reshaped to one row per (treat, iteration).
if overwrite_plotting_data:
    for model in models:
        collected = []
        eval_dir = os.path.join(repo, 'eval_summary', model)

        for _, caserow in glue_keys.iterrows():
            for label in glue_labels:
                for combined in combineds:
                    sub_keys = {
                        'dataset': dataset,     # either hans or glue
                        'case': caserow['case'],    # combined or specific to respective itereval set
                        'subcase': caserow['subcase'], # combined or specific to respective itereval set
                        'label': label,   # combined or [entailment, neutral, contradiction] for glue, [entailment, non-entailment] for hans
                    }

                    temp = get_itereval_summary(sub_keys, iteration, eval_dir, combined)
                    for _, row in temp.iterrows():
                        # `row` is one table row: index = iteration labels, name = treatment.
                        collected.append(pd.DataFrame({
                            'acc': row,
                            'iter': [int(x) for x in row.index.values],
                            'treat': row.name,
                            'case': sub_keys['case'],
                            'subcase': sub_keys['subcase'],
                            'label': sub_keys['label'],
                            'comb': combined,
                            'model': model,
                        }))

        # Concatenate once per model; doing it inside the label loop re-copied
        # the growing list on every pass without changing the final result.
        collected_t = pd.concat(collected, ignore_index=True)

        temp_sampled = distributions[model]['itereval']
        # .copy() so the column assignment below targets an independent frame
        # rather than a slice of the stored one (avoids SettingWithCopyWarning).
        temp_sampled = temp_sampled.loc[temp_sampled['dataset'] == dataset, :].copy()
        temp_sampled['model'] = model
        all_df.append(pd.concat([temp_sampled, collected_t], ignore_index=True))
    df = pd.concat(all_df, ignore_index=True)
    df.to_csv(os.path.join(plot_data, 'glue.csv'))

## HANS

In [18]:
# (case, subcase) pairs to tabulate; the path is relative to the working directory.
hans_keys = pd.read_csv('hans_case_keys.csv')
print(hans_keys)
# Label values used in the file-name key ('combined' plus the two label names).
hans_labels = ['combined', 'entailment', 'non-entailment']

               case                         subcase
0          combined                        combined
1       constituent                       ce_adverb
2       constituent           ce_after_since_clause
3       constituent                  ce_conjunction
4       constituent         ce_embedded_under_since
5       constituent          ce_embedded_under_verb
6       constituent                       cn_adverb
7       constituent              cn_after_if_clause
8       constituent                  cn_disjunction
9       constituent            cn_embedded_under_if
10      constituent          cn_embedded_under_verb
11      constituent                        combined
12  lexical_overlap                        combined
13  lexical_overlap  le_around_prepositional_phrase
14  lexical_overlap       le_around_relative_clause
15  lexical_overlap                  le_conjunction
16  lexical_overlap                      le_passive
17  lexical_overlap              le_relative_clause
18  lexical_

In [19]:
dataset = 'hans'

combineds = ['combined', 'separate']
all_df = []

# Build hans.csv: for each model, the sampled itereval rows for this dataset
# followed by the summary-table rows reshaped to one row per (treat, iteration).
if overwrite_plotting_data:
    for model in models:
        collected = []
        eval_dir = os.path.join(repo, 'eval_summary', model)

        for _, caserow in hans_keys.iterrows():
            for label in hans_labels:
                for combined in combineds:
                    sub_keys = {
                        'dataset': dataset,     # either glue or hans
                        'case': caserow['case'],    # combined or specific to respective itereval set
                        'subcase': caserow['subcase'], # combined or specific to respective itereval set
                        'label': label,   # combined or [entailment, neutral, contradiction] for glue, [entailment, non-entailment] for hans
                    }

                    temp = get_itereval_summary(sub_keys, iteration, eval_dir, combined)
                    for _, row in temp.iterrows():
                        # `row` is one table row: index = iteration labels, name = treatment.
                        collected.append(pd.DataFrame({
                            'acc': row,
                            'iter': [int(x) for x in row.index.values],
                            'treat': row.name,
                            'case': sub_keys['case'],
                            'subcase': sub_keys['subcase'],
                            'label': sub_keys['label'],
                            'comb': combined,
                            'model': model,
                        }))

        # Concatenate once per model; doing it inside the label loop re-copied
        # the growing list on every pass without changing the final result.
        collected_t = pd.concat(collected, ignore_index=True)

        temp_sampled = distributions[model]['itereval']
        # .copy() so the column assignment below targets an independent frame
        # rather than a slice of the stored one (avoids SettingWithCopyWarning).
        temp_sampled = temp_sampled.loc[temp_sampled['dataset'] == dataset, :].copy()
        temp_sampled['model'] = model
        all_df.append(pd.concat([temp_sampled, collected_t], ignore_index=True))
    df = pd.concat(all_df, ignore_index=True)
    df.to_csv(os.path.join(plot_data, 'hans.csv'))

## MNLI

In [20]:
all_df = []

# Build mnli.csv: for each model, the sampled MNLI rows followed by the rows
# of that model's eval_summaries.jsonl.
if overwrite_plotting_data:
    for model in models:
        eval_dir = os.path.join(repo, 'eval_summary', model)
        mnli_summary = os.path.join(eval_dir, 'mnli_evals', 'eval_summaries.jsonl')

        # One JSON record per line.
        with open(mnli_summary, 'r') as f:
            summary = pd.DataFrame([json.loads(line) for line in f])
        # Records without a genre are labelled 'combined'.
        summary['breakdown'] = summary['genre'].fillna('combined')
        summary['iter'] = summary['iter'].map(int)

        temp = pd.concat([distributions[model]['mnli'], summary], ignore_index=True)
        temp['model'] = model
        all_df.append(temp)

    df = pd.concat(all_df, ignore_index=True)
    df.to_csv(os.path.join(plot_data, 'mnli.csv'))

## ANLI

In [21]:
all_df = []

# Build anli.csv by stacking each model's anli_by_annotation.csv with a 'model' column.
if overwrite_plotting_data:
    for model in models:
        eval_dir = os.path.join(repo, 'eval_summary', model)
        df = pd.read_csv(
            os.path.join(eval_dir, 'sample', sample_type, 'final', 'anli_by_annotation.csv')
        ).assign(model=model)
        all_df.append(df)

    df = pd.concat(all_df, ignore_index=True)
    df.to_csv(os.path.join(plot_data, 'anli.csv'))

# Plot

## Define the plotting helper

In [22]:
def err_line_plots(
    plot_df,
    ylim=[0,1],
    title=None,
    xlabel=None,
    ylabel=None,
    tabletitle=None,
    tableon=True,
    x='iter',
    y='acc',
    hue='treat',
    err_style='bars',
    ci=95,
    estimator=lambda x: np.median(x),
    markers=True,
    hue_order=['baseline', 'LotS', 'LitL'],
    iteration=5,
    bbox_to_anchor=(1.01, 1),
    palette=None,
    style_key='combined',
    style_order=['combined', 'separate'],
    yaxis_visible = True,
    xaxis_visible = True,
    legend_visible = True,
    ax=None,
):
    """Draw a seaborn line plot with error bars on `ax` (or on a new figure).

    Args:
        plot_df: Long-form DataFrame containing the `x`, `y`, `hue` and
            `style_key` columns.
        ylim: ``(low, high)`` limits for the y axis.
        title: Axes title.
        xlabel, ylabel: Axis labels; replaced by ``''`` when the matching
            ``*_visible`` flag is false.
        tabletitle, tableon: Accepted for call compatibility; not used by
            this function.
        x, y, hue: Column names passed to ``sns.lineplot``.
        err_style, ci, markers: Passed through to ``sns.lineplot``.
        estimator: Aggregation applied to the `y` values at each `x`
            (median by default). Previously accepted but not forwarded, so
            seaborn's default aggregation was used instead.
        hue_order: Order of the hue levels. Previously accepted but not
            forwarded.
        iteration: Highest x tick; ticks are placed at 1..`iteration`.
            Previously ignored (ticks were fixed at 1..5, which the default
            still produces).
        bbox_to_anchor: Legend anchor, passed to ``ax.legend``.
        palette: Colours for the hue levels. Previously accepted but not
            forwarded.
        style_key, style_order: Column and level order for line styles.
        yaxis_visible, xaxis_visible: When false, the axis label and tick
            labels of that axis are blanked.
        legend_visible: Whether the legend is shown.
        ax: Axes to draw on. When omitted, a new figure and axes are created.

    Returns:
        The new ``Figure`` (after ``tight_layout``) when `ax` was omitted,
        otherwise ``None``.
    """
    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots()

    sns.lineplot(
        data=plot_df, x=x, y=y,
        hue=hue, hue_order=hue_order, palette=palette,
        estimator=estimator,
        err_style=err_style, ci=ci, markers=markers,
        style=style_key, style_order=style_order,
        ax=ax,
    )

    ax.set_title(title)

    ax.set_xlabel(xlabel if xaxis_visible else '')
    ax.set_xticks(np.arange(1, iteration + 1, 1))
    if not xaxis_visible:
        ax.xaxis.set_ticklabels([])

    ax.set_ylabel(ylabel if yaxis_visible else '')
    ax.set_ylim(*ylim)
    if not yaxis_visible:
        ax.yaxis.set_ticklabels([])

    # The legend is always created, then hidden if not wanted.
    ax.legend(bbox_to_anchor=bbox_to_anchor).set_visible(legend_visible)

    if created_figure:
        fig.tight_layout()
        return fig

## Plot Params

In [23]:
# Value matched against the 'combined'/'comb' column when filtering plot data;
# also used as the prefix of the saved figure file names.
combined = 'separate'

In [24]:
# '__file__' is a string literal, so this resolves to the parent of the
# current working directory.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
# Directory the plotting CSVs are read from.
plot_dir = os.path.join(repo, 'eval_summary', 'plot_data')
# Directory the figures are saved to; created if missing.
plot_out = os.path.join(repo, 'eval_summary', 'plots')
os.makedirs(plot_out, exist_ok=True)

In [25]:
# NOTE(review): neither name is referenced in the cells visible here — confirm they are still needed.
acc_name = 'Performance'
diff_name = 'Over Baseline'

In [26]:
# When True, each plotting cell writes its figure into `plot_out`.
save_figs = True
# File extension used for the saved figures.
figtype='jpg'

## In-domain

In [27]:
# Reload the combined table so the filters in the next cell start from the full data.
plot_df = pd.read_csv(os.path.join(plot_dir, 'collected.csv'))

In [28]:
input_type = 'full'

# Keep the rows matching both the chosen `combined` setting and the input type.
plot_df = plot_df.loc[
    (plot_df['combined'] == combined) & (plot_df['mod'] == input_type), :
]

In [29]:
# Figure text and options for the in-domain validation plot.
ylim = [0, 1]
title = 'In-domain Validation'
xlabel = 'Iteration'
ylabel = 'Accuracy'
tabletitle = 'Median'
err_style = 'bars'
tableon = False

# Line style distinguishes the two models.
style_key = 'model'
style_order = ['roberta-large', 'roberta-large-mnli']

fig = err_line_plots(
    plot_df,
    err_style=err_style, ylim=ylim, title=title,
    xlabel=xlabel, ylabel=ylabel, tabletitle=tabletitle,
    palette=dict(baseline='tab:blue', LotS='tab:orange', LitL='tab:green'),
    tableon=tableon, style_key=style_key, style_order=style_order,
)
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_indomain_val.{figtype}'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Hypothesis-only

In [30]:
# Reload the combined table so the filters in the next cell start from the full data.
plot_df = pd.read_csv(os.path.join(plot_dir, 'collected.csv'))

In [31]:
input_type = 'hyp'

# Keep the rows matching both the chosen `combined` setting and the input type.
plot_df = plot_df.loc[
    (plot_df['combined'] == combined) & (plot_df['mod'] == input_type), :
]

In [32]:
# Figure text and options for the hypothesis-only plot.
ylim = [0, 1]
title = 'Hypothesis-only Input'
xlabel = 'Iteration'
ylabel = 'Accuracy'
tabletitle = 'Median'
err_style = 'bars'
tableon = False

# Line style distinguishes the two models.
style_key = 'model'
style_order = ['roberta-large', 'roberta-large-mnli']

fig = err_line_plots(
    plot_df,
    err_style=err_style, ylim=ylim, title=title,
    xlabel=xlabel, ylabel=ylabel, tabletitle=tabletitle,
    palette=dict(baseline='tab:blue', LotS='tab:orange', LitL='tab:green'),
    tableon=tableon, style_key=style_key, style_order=style_order,
)
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_hypothesis_only.{figtype}'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## GLUE

In [33]:
# NOTE(review): the captured output below looks like the tail of a pandas
# DtypeWarning (mixed column types) — consider passing dtype= or low_memory=False; confirm.
plot_df = pd.read_csv(os.path.join(plot_dir, 'glue.csv'))

  interactivity=interactivity, compiler=compiler, result=result)


In [34]:
label = 'combined'

# Keep the case-level rows ('combined' subcase) for the chosen `combined`
# setting and label.
plot_df = plot_df.loc[
    (plot_df['comb'] == combined)
    & (plot_df['label'] == label)
    & (plot_df['subcase'] == 'combined'),
    :
]

In [35]:
# Figure text and options shared by every GLUE panel.
ylim = [0, 1]
title = ""
xlabel = 'Iteration'
ylabel = 'Accuracy'
tabletitle = 'median'
err_style = 'bars'
tableon = False

bbox_to_anchor = (1.05, 1)
figsize = (15, 5)

# Line style distinguishes the two models.
style_key = 'model'
style_order = ['roberta-large', 'roberta-large-mnli']

glue_keys = pd.read_csv('glue_case_keys.csv')
# One panel per case; the 'combined' case is skipped, hence the "- 1".
fig, ax = plt.subplots(1, len(glue_keys['case'].unique()) - 1, figsize=figsize)

panel_cases = [case for case in glue_keys['case'].unique() if case != 'combined']
for i, case in enumerate(panel_cases):
    err_line_plots(
        plot_df.loc[plot_df['case'] == case, :],
        err_style=err_style, ylim=ylim, title=f"{case}",
        xlabel=xlabel, ylabel=ylabel, tabletitle=tabletitle,
        palette=dict(baseline='tab:blue', LotS='tab:orange', LitL='tab:green'),
        tableon=tableon, style_key=style_key, style_order=style_order,
        ax=ax[i],
        # y axis only on the first panel, legend only on the last one
        yaxis_visible=(i == 0),
        legend_visible=(i == len(glue_keys['case'].unique()) - 2),
        bbox_to_anchor=bbox_to_anchor,
    )

fig.tight_layout()
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_glue.{figtype}'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## HANS non-entailment

In [36]:
# Reload the combined table so the filters in the next cell start from the full data.
plot_df = pd.read_csv(os.path.join(plot_dir, 'hans.csv'))

In [37]:
label = 'non-entailment'

# Keep the case-level rows ('combined' subcase) for the chosen `combined`
# setting and label.
plot_df = plot_df.loc[
    (plot_df['comb'] == combined)
    & (plot_df['label'] == label)
    & (plot_df['subcase'] == 'combined'),
    :
]

In [38]:
# Figure text and options shared by every HANS panel.
ylim = [0, 1]
title = ""
xlabel = 'Iteration'
ylabel = 'Accuracy'
tabletitle = 'median'
err_style = 'bars'
tableon = False

bbox_to_anchor = (1.9, 1)
figsize = (10, 5)

# Line style distinguishes the two models.
style_key = 'model'
style_order = ['roberta-large', 'roberta-large-mnli']

hans_keys = pd.read_csv('hans_case_keys.csv')
# One panel per case; the 'combined' case is skipped, hence the "- 1".
fig, ax = plt.subplots(1, len(hans_keys['case'].unique()) - 1, figsize=figsize)

panel_cases = [case for case in hans_keys['case'].unique() if case != 'combined']
for i, case in enumerate(panel_cases):
    err_line_plots(
        plot_df.loc[plot_df['case'] == case, :],
        err_style=err_style, ylim=ylim, title=f"{case}",
        xlabel=xlabel, ylabel=ylabel, tabletitle=tabletitle,
        palette=dict(baseline='tab:blue', LotS='tab:orange', LitL='tab:green'),
        tableon=tableon, style_key=style_key, style_order=style_order,
        ax=ax[i],
        # y axis only on the first panel, legend only on the last one
        yaxis_visible=(i == 0),
        legend_visible=(i == len(hans_keys['case'].unique()) - 2),
        bbox_to_anchor=bbox_to_anchor,
    )

fig.tight_layout()
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_hans.{figtype}'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## MNLI

In [39]:
# Reload the combined table so the filter in the next cell starts from the full data.
plot_df = pd.read_csv(os.path.join(plot_dir, 'mnli.csv'))

In [40]:
# Keep only the rows for the chosen `combined` setting.
# NOTE(review): no filter on 'breakdown' is applied here, so every breakdown
# present in the file is pooled into the plot — confirm that is intended.
plot_df = plot_df.loc[plot_df['comb'] == combined, :]

In [41]:
# Figure text and options for the MNLI plot.
ylim = [0, 1]
title = 'MNLI-mismatched'
xlabel = 'Iteration'
ylabel = 'Accuracy'
tabletitle = 'Median'
err_style = 'bars'
tableon = False

# Line style distinguishes the two models.
style_key = 'model'
style_order = ['roberta-large', 'roberta-large-mnli']

fig = err_line_plots(
    plot_df,
    err_style=err_style, ylim=ylim, title=title,
    xlabel=xlabel, ylabel=ylabel, tabletitle=tabletitle,
    palette=dict(baseline='tab:blue', LotS='tab:orange', LitL='tab:green'),
    tableon=tableon, style_key=style_key, style_order=style_order,
)

if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_mnli.{figtype}'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## ANLI

In [42]:
# Reload the combined table so the filter below starts from the full data.
plot_df = pd.read_csv(os.path.join(plot_dir, 'anli.csv'))

In [43]:
# Values of the 'breakdown' column to plot, one panel each; 'combined' is
# listed but skipped by the plotting loop.
breakdowns = [
        'combined',
        'Basic',
        'EventCoref',
        'Imperfection',
        'Numerical',
        'Reasoning',
        'Reference',
        'Tricky',
    ]

In [44]:
# Keep only the rows for the chosen `combined` setting.
plot_df = plot_df.loc[plot_df['comb'] == combined, :]

In [45]:
# Figure text and options shared by every ANLI panel.
ylim=[0,1]
title=f""
xlabel='Iteration'
ylabel='Accuracy'
tabletitle='median'
err_style='bars'
tableon=False

# Anchor just right of the last panel (same value as the GLUE figure). The
# earlier (6.8, 1) offset only worked because the legend was attached to the
# third panel; re-tune visually if needed.
bbox_to_anchor = (1.05, 1)
figsize=(20, 5)

# Line style distinguishes the two models.
style_key='model'
style_order=['roberta-large', 'roberta-large-mnli']

# One panel per breakdown; 'combined' is skipped in the loop, hence the "- 1".
n_panels = len(breakdowns) - 1
fig, ax = plt.subplots(1, n_panels, figsize=figsize)

i = 0
for case in breakdowns:
    if case == 'combined':
        continue

    temp_df = plot_df.loc[plot_df['breakdown'] == case, :]
    err_line_plots(
        temp_df,
        err_style=err_style,
        ylim=ylim,
        title=f"{case}",
        xlabel=xlabel,
        ylabel=ylabel,
        tabletitle=tabletitle,
        palette={
            'baseline':'tab:blue',
            'LotS':'tab:orange',
            'LitL':'tab:green',
        },
        tableon=tableon,
        style_key=style_key,
        style_order=style_order,
        ax=ax[i],
        yaxis_visible = i == 0,
        # Legend on the last ANLI panel. This used to be derived from
        # `hans_keys`, which tied this cell to the HANS section and put the
        # legend on the third panel.
        legend_visible = i == n_panels - 1,
        bbox_to_anchor=bbox_to_anchor
    )
    i += 1

fig.tight_layout()
if save_figs:
    fig.savefig(os.path.join(plot_out, f'{combined}_anli.{figtype}'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

