# Import

In [1]:
%matplotlib widget

In [2]:
import os

import pandas as pd
import numpy as np
import json
import joblib
import pickle
import torch

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import ttest_ind

In [3]:
# Use seaborn's 'whitegrid' theme for the figures drawn in this notebook.
sns.set_theme(style='whitegrid')

# Define

In [4]:
def get_val_summary(modifier, iteration, eval_dir):
    """Read the validation config table of one round, limited to rounds 1..iteration.

    Parameters
    ----------
    modifier : str
        Inserted into the file name ``configs.<modifier>.csv``.
    iteration : int
        Round number; selects the ``r<iteration>`` directory and the columns
        named ``'1'`` .. ``str(iteration)``.
    eval_dir : str
        Directory that contains the ``r<iteration>/tables`` sub-directories.

    Returns
    -------
    pandas.DataFrame
        The CSV contents (first CSV column used as index) restricted to the
        per-round columns, in ascending round order.

    Raises
    ------
    KeyError
        If one of the expected round columns is missing from the CSV.
    """
    table_path = os.path.join(
        eval_dir, f'r{iteration}', 'tables', f'configs.{modifier}.csv'
    )
    # Column headers are the round numbers stored as strings.
    round_columns = list(map(str, range(1, iteration + 1)))

    return pd.read_csv(table_path, index_col=0)[round_columns]


def get_itereval_summary(sub_keys, iteration, eval_dir, combined):
    """Read the iterative-evaluation table of one round, limited to rounds 1..iteration.

    Parameters
    ----------
    sub_keys : dict
        Its values (strings) are joined with ``'.'`` to form the file-name key.
    iteration : int
        Round number; selects the ``r<iteration>`` directory and the columns
        named ``'1'`` .. ``str(iteration)``.
    eval_dir : str
        Directory that contains the ``r<iteration>/tables`` sub-directories.
    combined : str
        Sub-directory of ``tables`` that holds the ``iterevals.*.csv`` files.

    Returns
    -------
    pandas.DataFrame
        The CSV contents (first CSV column used as index) restricted to the
        per-round columns.
    """
    # '/' and ';' in the key are written as '-' and '--' in the file name.
    sanitise = str.maketrans({'/': '-', ';': '--'})
    fname_key = '.'.join(sub_keys.values()).translate(sanitise)

    table_path = os.path.join(
        eval_dir, f'r{iteration}', 'tables', combined, f'iterevals.{fname_key}.csv'
    )
    round_columns = list(map(str, range(1, iteration + 1)))

    return pd.read_csv(table_path, index_col=0)[round_columns]
    

In [5]:
def get_mnli_tables(mnli_summary, subsetting='genre', model=None):
    """Load an MNLI evaluation summary (JSON lines) and pivot it into per-subset tables.

    Parameters
    ----------
    mnli_summary : str
        Path to a JSON-lines file. Each record needs the keys ``'comb'``,
        ``'treat'``, ``'iter'``, ``'acc'`` and the key named by ``subsetting``.
    subsetting : str, default 'genre'
        Column used to split each ``comb`` group into subsets.
    model : optional
        First element of every key of the returned dictionary. When omitted,
        the module-level variable ``model`` is used, which is what this
        function always read implicitly; pass it explicitly to avoid relying
        on notebook execution order.

    Returns
    -------
    (pandas.DataFrame, dict)
        The raw summary frame and a dictionary mapping
        ``(model, comb, subset)`` to a frame with one row per treatment and
        one column per ``iter`` value, holding ``acc``.

    Raises
    ------
    NameError
        If ``model`` is not given and no module-level ``model`` exists.
    """
    if model is None:
        # Backward compatible fallback to the notebook global.
        try:
            model = globals()['model']
        except KeyError:
            raise NameError(
                "name 'model' is not defined; pass model=... explicitly"
            ) from None

    with open(mnli_summary, 'r') as f:
        # Blank lines (e.g. a trailing newline) are not records.
        summary = pd.DataFrame([json.loads(line) for line in f if line.strip()])

    mnli_tables = {}
    if summary.empty:
        return summary, mnli_tables

    for comb in summary['comb'].unique():
        comb_sum = summary.loc[summary['comb'] == comb, :]

        # Only look at subsets that occur for this comb; iterating over the
        # subsets of the whole summary produced empty selections.
        for subset in comb_sum[subsetting].unique():
            subset_sum = comb_sum.loc[comb_sum[subsetting] == subset, :]

            plot_tab = []
            for treat in subset_sum['treat'].unique():
                treat_sum = subset_sum.loc[subset_sum['treat'] == treat, :]
                # One row labelled by the treatment, one column per iteration.
                s = (
                    treat_sum[['iter', 'acc']]
                    .set_index('iter')
                    .rename({'acc': treat}, axis=1)
                    .transpose()
                )
                plot_tab.append(s)

            if not plot_tab:
                # A NaN subset never compares equal to itself, so the
                # selection is empty; pd.concat([]) would raise.
                continue

            mnli_tables[(model, comb, subset)] = pd.concat(plot_tab)

    return summary, mnli_tables

In [6]:
def split_run_name(run_name, split_by='_'):
    """Decompose a run name into ``(first, second, input_type, comb)``.

    The name is split on ``split_by``; the first two pieces are returned
    unchanged and the remaining two fields depend on the number of pieces:

    * 2 pieces: ``input_type='full'``, ``comb='combined'``.
    * 3 pieces ending in ``'hyp'``: ``input_type='hyp'``, ``comb='combined'``.
    * 3 pieces otherwise: ``input_type='full'``, ``comb`` is the last piece.
    * any other count: ``input_type`` is the last piece and ``comb`` the
      second-to-last one.

    Raises
    ------
    IndexError
        If the name splits into a single piece.
    """
    parts = run_name.split(split_by)
    n_parts = len(parts)

    if n_parts == 2:
        input_type, comb = 'full', 'combined'
    elif n_parts == 3 and parts[2] == 'hyp':
        input_type, comb = parts[2], 'combined'
    elif n_parts == 3:
        input_type, comb = 'full', parts[2]
    else:
        input_type, comb = parts[-1], parts[-2]

    return (parts[0], parts[1], input_type, comb)

In [7]:
def load_sampled_results(sampled_base):
    """Read the four result CSVs of one sampled round and add derived key columns.

    Reads ``collected.csv``, ``itereval.csv``, ``mnli.csv`` and ``anli.csv``
    from ``sampled_base``.

    * ``collected`` gains ``'treat'``, ``'iter'`` (int), ``'mod'`` and
      ``'combined'``, all parsed from its ``'run'`` column with
      ``split_run_name``.
    * ``mnli`` gains ``'breakdown'``: its ``'genre'`` column with missing
      values replaced by ``'combined'``.
    * ``anli`` gains ``'breakdown'``: its ``'tag'`` column with missing
      values replaced by ``'combined'``.
    * ``itereval`` is returned as read.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(collected, itereval, mnli, anli)``.
    """
    def read_table(stem):
        return pd.read_csv(os.path.join(sampled_base, f'{stem}.csv'))

    collected = read_table('collected')
    itereval = read_table('itereval')
    mnli = read_table('mnli')
    anli = read_table('anli')

    # Derived key columns, added in this order: treat, iter, mod, combined.
    run_names = collected['run']
    key_parsers = (
        ('treat', lambda x: split_run_name(x)[0]),
        ('iter', lambda x: int(split_run_name(x)[1])),
        ('mod', lambda x: split_run_name(x)[2]),
        ('combined', lambda x: split_run_name(x)[3]),
    )
    for column, parser in key_parsers:
        collected[column] = run_names.apply(parser)

    # Rows without a genre / tag are the overall ('combined') scores.
    for frame, source_column in ((mnli, 'genre'), (anli, 'tag')):
        frame['breakdown'] = frame[source_column].fillna('combined')

    return collected, itereval, mnli, anli

In [8]:
def load_all_sampled(sampled_base, upto=5):
    """Load the sampled results of rounds 1..upto and stack them per result type.

    Parameters
    ----------
    sampled_base : str
        Directory containing one ``r<round>`` sub-directory per round.
    upto : int, default 5
        Last round (inclusive) to load.

    Returns
    -------
    dict
        Keys ``'collected'``, ``'itereval'``, ``'mnli'`` and ``'anli'``; each
        value is the row-wise concatenation (fresh integer index) of that
        table over all loaded rounds.
    """
    # Names follow the order of the tuple returned by load_sampled_results.
    result_names = ('collected', 'itereval', 'mnli', 'anli')

    per_round = [
        load_sampled_results(os.path.join(sampled_base, f'r{r}'))
        for r in range(1, upto + 1)
    ]

    return {
        name: pd.concat([tables[position] for tables in per_round], ignore_index=True)
        for position, name in enumerate(result_names)
    }
    

In [9]:
def get_ttest_pvals(dist_df, verbose=True):
    """Run pairwise independent-samples t-tests on ``acc`` between treatments.

    Parameters
    ----------
    dist_df : pandas.DataFrame
        Needs a ``'treat'`` column (values ``'baseline'``, ``'LotS'``,
        ``'LitL'``) and an ``'acc'`` column.
    verbose : bool, default True
        When true, print the t statistic and *half* of the p-value returned
        by ``ttest_ind`` for every pair.

    Returns
    -------
    dict
        Maps each ``(treat_a, treat_b)`` pair to the unmodified result of
        ``scipy.stats.ttest_ind``. The halving is applied to the printed
        value only.
    """
    treatment_pairs = (
        ('baseline', 'LotS'),
        ('baseline', 'LitL'),
        ('LotS', 'LitL'),
    )

    def acc_of(treat):
        return dist_df.loc[dist_df['treat'] == treat, 'acc']

    ttest_dict = {
        (first, second): ttest_ind(acc_of(first), acc_of(second))
        for first, second in treatment_pairs
    }

    if not verbose:
        return ttest_dict

    for pair, ttest_results in ttest_dict.items():
        print('='*45)
        print(f"{pair}\nt: {ttest_results[0]:.5f} | p: {ttest_results[1]/2:.5f}")

    return ttest_dict
    

In [10]:
def err_line_plots(
    plot_df,
    ylim=[0,1],
    title=None,
    xlabel=None,
    ylabel=None,
    tabletitle=None,
    tableon=True,
    x='iter',
    y='acc',
    hue='treat',
    err_style='bars',
    ci=95,
    estimator=lambda x: np.median(x),
    markers=True,
    hue_order=['baseline', 'LotS', 'LitL'],
    iteration=5,
    bbox_to_anchor=(1.01, 1),
    palette=None,
    style_key='combined',
    ax=None,
):
    """Draw a seaborn line plot with error bars, one line per ``hue``/``style_key`` group.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Long-format data containing the ``x``, ``y``, ``hue`` and
        ``style_key`` columns.
    ylim : sequence of two numbers
        Lower and upper y-axis limits. The default list is only read.
    title, xlabel, ylabel : str, optional
        Axes title and axis labels.
    x, y, hue, err_style, ci, markers
        Forwarded unchanged to ``seaborn.lineplot``.
    iteration : int, default 5
        Number of rounds; x ticks are placed at ``1 .. iteration``.
    bbox_to_anchor : tuple
        Legend anchor, forwarded to ``Axes.legend``.
    style_key : str, default 'combined'
        Column used for the line style; the style order is fixed to
        ``['combined', 'separate']``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new figure and axes are created when omitted.
    tabletitle, tableon, estimator, hue_order, palette
        Accepted for backward compatibility but NOT used: they are not
        forwarded to seaborn, so lines use seaborn's default estimator and
        hue ordering.
        NOTE(review): forwarding them would change existing figures; decide
        deliberately before wiring them through.

    Returns
    -------
    matplotlib.figure.Figure or None
        The newly created figure when ``ax`` was omitted, otherwise ``None``
        (the caller already owns the figure).
    """
    # Remember who owns the figure *before* ``ax`` is rebound; testing ``ax``
    # afterwards is always truthy, which made the original return unreachable.
    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots()
    else:
        # Previously ``fig`` was undefined on this path and tight_layout()
        # raised NameError.
        fig = ax.figure

    sns.lineplot(
        data=plot_df, x=x, y=y,
        hue=hue, err_style=err_style, ci=ci, markers=markers,
        style = style_key, style_order = ['combined', 'separate'],
        ax=ax
    )

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    # One tick per round; with the default iteration=5 this is 1..5 as before.
    ax.set_xticks(np.arange(1, iteration + 1, 1))
    ax.set_ylabel(ylabel)
    ax.set_ylim(*ylim)
    ax.legend(bbox_to_anchor=bbox_to_anchor)

    fig.tight_layout()

    if created_figure:
        return fig

# Plot

In [11]:
# Name of the evaluated model; also used as a directory name below.
model='roberta-large-mnli'
# NOTE: '__file__' is a string literal, not the __file__ variable, so
# abspath() resolves it against the current working directory and `repo`
# ends up being the parent of the working directory.
repo = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
eval_dir = os.path.join(repo, 'eval_summary', model)
sample_type = 'cross_eval'
# Last round to load / plot.
iteration = 5

# JSON-lines evaluation summaries for MNLI and ANLI.
mnli_summary = os.path.join(eval_dir, 'mnli_evals', 'eval_summaries.jsonl')
anli_summary = os.path.join(eval_dir, 'anli_evals', 'eval_summaries.jsonl')

# Output directory for figures; created on first run.
plots_dir = os.path.join(eval_dir, 'sample', sample_type, f'final', 'plots')
os.makedirs(plots_dir, exist_ok=True)

In [12]:
# Display labels; presumably used as axis/legend text by later cells -- TODO confirm.
acc_name = 'Performance'
diff_name = 'Over Baseline'

In [13]:
# Dict with keys 'collected', 'itereval', 'mnli', 'anli'; each value stacks
# the sampled results of rounds 1..iteration.
distributions = load_all_sampled(
    os.path.join(eval_dir, 'sample', sample_type), upto=iteration
)

### ANLI Breakdown

In [14]:
# The annotation file and the index dictionaries live one level above `repo`.
repo_up = os.path.dirname(repo)
anli_annot_fname = os.path.join(repo_up, 'anli_annot_v0.2_combined_A1A2')
pred2annot_dicts = os.path.join(repo_up, 'data', 'ANLI_dicts.p')

In [15]:
# SECURITY: joblib.load and pickle.load can execute arbitrary code while
# deserialising; only load files from a trusted source.
# anli_annot is used below like a DataFrame with a 'gold_label' column and
# one column per annotation category.
anli_annot = joblib.load(anli_annot_fname)

# 'idxo2idxa' is used below as pred2annot[prediction_position] -> row label
# of anli_annot.
with open(pred2annot_dicts, 'rb') as f:
    pred2annot = pickle.load(f)['idxo2idxa']

In [16]:
# Inspect the gold-label encoding: the output shows single-letter codes
# ('c', 'e', 'n'), which any int2pred mapping must produce to score correctly.
anli_annot['gold_label'].unique()

array(['c', 'e', 'n'], dtype=object)

In [17]:
# Annotation category columns of anli_annot.
breakdowns = [
    'Basic',
    'EventCoref',
    'Imperfection',
    'Numerical',
    'Reasoning',
    'Reference',
    'Tricky',
]

# Per category: number of missing values, then number of rows whose value is
# not the string 'none' (i.e. rows that carry this annotation).
for breakdown in breakdowns:
    print(breakdown, anli_annot[breakdown].isna().sum())
    print(breakdown, anli_annot[breakdown].ne('none').sum())

Basic 0
Basic 1327
EventCoref 0
EventCoref 66
Imperfection 0
Imperfection 453
Numerical 0
Numerical 1036
Reasoning 0
Reasoning 1977
Reference 0
Reference 868
Tricky 0
Tricky 893


In [18]:
def get_anli_breakdowns(
    preds,
    anli_annotated,
    pred2annot, 
    int2pred={0:'c', 1:'e', 2:'n'},
    breakdowns = [
        'combined',
        'Basic',
        'EventCoref',
        'Imperfection',
        'Numerical',
        'Reasoning',
        'Reference',
        'Tricky',
    ],
    verbose=False,
):
    """Compute ANLI accuracy once per annotation breakdown.

    Every entry of ``breakdowns`` is passed to ``get_anli_acc`` as
    ``subset_col`` together with the same predictions and mappings.

    Parameters
    ----------
    preds : iterable
        Predicted class ids, in prediction order.
    anli_annotated : pandas.DataFrame
        Annotated ANLI examples.
    pred2annot : mapping
        Maps a prediction position to a row label of ``anli_annotated``.
    int2pred : mapping
        Maps a predicted class id to the label string compared against
        ``gold_label``. Neither this default nor ``breakdowns`` is mutated.
    breakdowns : list of str
        Subset names; ``'combined'`` means all rows.
    verbose : bool
        Forwarded to ``get_anli_acc``.

    Returns
    -------
    dict
        ``{breakdown: accuracy}`` in the order of ``breakdowns``.
    """
    shared_kwargs = dict(verbose=verbose, int2pred=int2pred)

    return {
        subset_name: get_anli_acc(
            preds,
            anli_annotated,
            pred2annot,
            subset_col=subset_name,
            **shared_kwargs,
        )
        for subset_name in breakdowns
    }

def get_anli_acc(
    preds,
    anli_annotated,
    pred2annot, 
    int2pred={0:'contradiction', 1:'entailment', 2:'neutral'}, 
    subset_col=None,
    verbose=False,
):
    """Accuracy of ``preds`` against the gold labels of the annotated ANLI frame.

    The caller's frame is never modified: scoring happens on a copy, so the
    'pred' / 'correct' columns of one call cannot leak into the next call and
    mask rows that received no prediction.

    Parameters
    ----------
    preds : iterable
        Predicted class ids, in prediction order.
    anli_annotated : pandas.DataFrame
        Needs a ``'gold_label'`` column and, when ``subset_col`` is used, a
        column of that name whose value ``'none'`` marks rows outside the
        subset.
    pred2annot : mapping
        Maps a prediction position to a row label of ``anli_annotated``.
        Positions without an entry are skipped; labels that do not occur in
        the frame are ignored.
    int2pred : mapping
        Maps a class id to the label string compared with ``gold_label``.
        NOTE(review): the default uses full words, while the annotation frame
        loaded in this notebook uses 'c'/'e'/'n'; pass a matching mapping.
    subset_col : str, optional
        Annotation column to restrict the score to; ``None`` or
        ``'combined'`` scores all rows.
    verbose : bool
        Print a warning per skipped prediction and the number skipped.

    Returns
    -------
    float
        Fraction of (subset) rows whose prediction equals the gold label.

    Raises
    ------
    AssertionError
        If any row of ``anli_annotated`` received no prediction.
    """
    # Collect label -> predicted string first; a later prediction for the
    # same label overrides an earlier one, as sequential assignment would.
    label_to_pred = {}
    skipped = 0
    for idx, pred in enumerate(preds):
        try:
            label_to_pred[pred2annot[idx]] = int2pred[pred]
        except KeyError as ke:
            if verbose:
                print(f'Warning: {ke}')
            skipped += 1
    if verbose:
        print(skipped)

    # Work on a copy so the caller's frame keeps its columns and contents.
    scored = anli_annotated.copy()
    scored['pred'] = [label_to_pred.get(label) for label in scored.index]

    missing = scored['pred'].isna().sum()
    assert missing == 0, f'{missing} annotated rows received no prediction'

    scored['correct'] = scored['gold_label'] == scored['pred']

    if subset_col is not None and subset_col != 'combined':
        scored = scored.loc[scored[subset_col].ne('none'), :]

    return scored['correct'].sum()/scored.shape[0]

In [19]:
# Sampling fractions 0.1, 0.2, ..., 1.0 (formatted with one decimal when
# used as directory names below).
sample_partitions = np.linspace(0.1, 1, 10)
# Short treatment name -> prediction sub-directory.
treats = {
    'baseline': '1_Baseline_protocol',
    'LotS': '2_Ling_on_side_protocol',
    'LitL': '3_Ling_in_loop_protocol',
}
rounds = range(1, 6)
combineds = ['combined', 'separate']
# NOTE(review): same value as `sample_type` defined earlier; keep the two in sync.
sampling = 'cross_eval'

In [20]:
# Score every (treatment, round, combined/separate) prediction file, and each
# of its sampled variants, per ANLI annotation breakdown; collect one record
# per breakdown and optionally write the table to CSV.
save = True
verbose = False

pred_base = os.path.join(repo, 'predictions', model, 'anli_evals')

anli_accs = []
for treat, treat_dir in treats.items():
    for r in rounds:
        for combined in combineds:
            # Progress line; the round is printed with one decimal ('1.0').
            print(treat, f'{r:.1f}', combined)
            # Breakdown for the full collected data (sample_partition=None).
            ext_base = os.path.join(pred_base, treat_dir, f'r{r}', combined)
            # NOTE: this rebinds the module-level name `breakdowns` (a list of
            # column names in an earlier cell) to a {breakdown: accuracy} dict.
            # SECURITY: torch.load unpickles the file; only load trusted files.
            breakdowns = get_anli_breakdowns(
                torch.load(os.path.join(ext_base, 'val_preds.p'))['mnli']['preds'],
                anli_annot,
                pred2annot,
                verbose=verbose,
            )
            for breakdown, acc in breakdowns.items():
                anli_accs.append(
                    {
                        'treat': treat,
                        'iter': int(r),
                        'comb': combined,
                        'breakdown': breakdown,
                        'acc': acc,
                        'sample_partition': None,
                    }
                )
            
            # Breakdown for each sampled fraction of the collected data; the
            # predictions live in <ext_base>/<sampling>/<fraction>/val_preds.p.
            for sample_partition in sample_partitions:
                extext_base = os.path.join(ext_base, sampling, f'{sample_partition:.1f}')
                breakdowns = get_anli_breakdowns(
                    torch.load(os.path.join(extext_base, 'val_preds.p'))['mnli']['preds'],
                    anli_annot,
                    pred2annot,
                    verbose=verbose,
                )
                for breakdown, acc in breakdowns.items():
                    anli_accs.append(
                        {
                            'treat': treat,
                            'iter': int(r),
                            'comb': combined,
                            'breakdown': breakdown,
                            'acc': acc,
                            # Stored as the raw np.linspace float, not the
                            # one-decimal string used for the directory.
                            'sample_partition': sample_partition,
                        }
                    )
# Build the frame once from the list of records.
anli_accs_df = pd.DataFrame(anli_accs)
if save:
    os.makedirs(os.path.join(eval_dir, 'sample', sampling, 'final'), exist_ok=True)
    anli_accs_df.to_csv(os.path.join(eval_dir, 'sample', sampling, 'final', 'anli_by_annotation.csv'))

baseline 1.0 combined
baseline 1.0 separate
baseline 2.0 combined
baseline 2.0 separate
baseline 3.0 combined
baseline 3.0 separate
baseline 4.0 combined
baseline 4.0 separate
baseline 5.0 combined
baseline 5.0 separate
LotS 1.0 combined
LotS 1.0 separate
LotS 2.0 combined
LotS 2.0 separate
LotS 3.0 combined
LotS 3.0 separate
LotS 4.0 combined
LotS 4.0 separate
LotS 5.0 combined
LotS 5.0 separate
LitL 1.0 combined
LitL 1.0 separate
LitL 2.0 combined
LitL 2.0 separate
LitL 3.0 combined
LitL 3.0 separate
LitL 4.0 combined
LitL 4.0 separate
LitL 5.0 combined
LitL 5.0 separate
