# Using adaptive retrieval on the synthetic dataset

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.stats.proportion import proportion_confint
from scipy.stats import pearsonr
import os

seed = 633
np.random.seed(seed)
import random
random.seed(seed)

In [16]:
# take the results of parametric and nonparametric-augmented systems 
# and compute how well adaptive retrieval would perform
do_plot = False
n_boot = 100
parametric_path = ""  # ADD PATH TO VANILLA RESULTS
nonparametric_path = ""  # ADD PATH TO RETRIEVAL-AUGMENTED RESULTS
def clean(df):
    return df[~df["s_pop"].isna() & (df["s_pop"] >= 0)]
sample = clean(pd.read_csv(parametric_path))
sample_ret = clean(pd.read_csv(nonparametric_path))
sample = sample.sort_values("question").reset_index(drop=True)
sample_ret = sample_ret.sort_values("question").reset_index(drop=True)

In [17]:
# suppress warnings
import warnings
warnings.filterwarnings("ignore")
props = sample.prop.unique()

test_ret_accs_all = []
test_param_accs_all = []
test_hybrid_accs_all = []
test_count_rets = []
test_count_params = []
train_hybrid_accs_all = []

for boot in range(1 if do_plot else n_boot):
    split_proportion = 0.75
    split_mask = ["train"] * round(len(sample) * split_proportion) + ["test"] * round(len(sample) * (1 - split_proportion))
    np.random.shuffle(split_mask)
    sample["split"] = split_mask
    sample_ret["split"] = split_mask

    test_ret_accs = []
    test_param_accs = []
    test_hybrid_accs = []
    train_hybrid_accs = []
    test_count_rets.append(0)
    test_count_params.append(0)
    test_sizes = []
    train_sizes = []
    train_threshs = dict()
    plot_title = ""
    if do_plot:
        plt.figure(dpi=200, figsize=(30, 30))
    for i, prop in enumerate(props):

        if do_plot:
            plt.subplot(4, 4, i+1)
        cluster_sample = sample[sample.prop == prop].copy()
        cluster_sample_ret_with_pop = sample_ret[sample_ret.prop == prop].copy()

        log_pop = np.log(cluster_sample["s_pop"].values)
        cluster_sample["log_pop"] = log_pop
        cluster_sample_ret_with_pop["log_pop"] = log_pop
        ser = log_pop
        _, bin_edges = np.histogram(ser)

        c = cluster_sample.is_correct.values
        counts_c, _ = np.histogram(ser[c], bins=bin_edges)
        counts_inc, _ = np.histogram(ser[~c], bins=bin_edges)
        total = counts_c + counts_inc

        c_ret = cluster_sample_ret_with_pop.is_correct.values
        counts_c_ret, _ = np.histogram(ser[c_ret], bins=bin_edges)
        counts_inc_ret, _ = np.histogram(ser[~c_ret], bins=bin_edges)
        total_ret = counts_c_ret + counts_inc_ret

        width = 0.4*(bin_edges[1] - bin_edges[0])
        thresh_idx = np.argmax(list((sum(counts_c_ret[:i]) + sum(counts_c[i:])) / sum(total_ret) for i in range(len(total_ret) + 1)))

        # plt.bar(bin_edges[:-1] - 0.5 * width, counts_c / total, width=width, alpha=0.9, label="parametric knowledge (vanilla)", align='edge')
        # plt.bar(bin_edges[:-1] + 0.5 * width, counts_c_ret / total_ret, width=width, alpha=0.9, label="retrieval augmented (BM25)", align='edge')
        # lo, hi = proportion_confint(counts_c, total, alpha=0.05, method='wilson')
        # plt.errorbar(bin_edges[:-1], counts_c / total, yerr=[counts_c / total - lo, hi - counts_c / total], fmt='none', ecolor='black', elinewidth=1, capsize=2)
        # lo, hi = proportion_confint(counts_c_ret, total_ret, alpha=0.05, method='wilson')
        # plt.errorbar(bin_edges[:-1] + width, counts_c_ret / total_ret, yerr=[counts_c_ret / total_ret - lo, hi - counts_c_ret / total_ret], fmt='none', ecolor='black', elinewidth=1, capsize=2)
        # # plt.errorbar(bin_edges[:-1], counts_c / total, , fmt='none', ecolor='black', capsize=2)
        # # plt.errorbar(bin_edges[:-1], counts_c_ret / total_ret, yerr=wilson(counts_c_ret / total_ret, total_ret), fmt='none', ecolor='black', capsize=2)
        # plt.axvline(x=bin_edges[thresh_idx] + 0.7 * (bin_edges[1] - bin_edges[0]), color='red', linestyle='--', label="threshold")

        param_acc = sum(counts_c) / sum(total)
        ret_acc = sum(counts_c_ret) / sum(total_ret)
        hybrid_acc = (sum(counts_c_ret[:thresh_idx]) + sum(counts_c[thresh_idx:])) / (sum(total_ret[:thresh_idx]) + sum(total[thresh_idx:]))
        ret_acc_gain = hybrid_acc - ret_acc
        param_acc_gain = hybrid_acc - param_acc
        
        # let the optimal threshold be the one that maximizes the hybrid accuracy
        train_idxs = cluster_sample.split.values == "train"
        test_idxs = cluster_sample.split.values == "test"
        train_ser = ser[train_idxs]
        test_ser = ser[test_idxs]
        train_c = c[train_idxs]
        test_c = c[test_idxs]
        train_c_ret = c_ret[train_idxs]
        test_c_ret = c_ret[test_idxs]
        train_counts_c, _ = np.histogram(train_ser[train_c], bins=bin_edges)
        train_counts_inc, _ = np.histogram(train_ser[~train_c], bins=bin_edges)
        train_total = train_counts_c + train_counts_inc
        train_counts_c_ret, _ = np.histogram(train_ser[train_c_ret], bins=bin_edges)
        train_counts_inc_ret, _ = np.histogram(train_ser[~train_c_ret], bins=bin_edges)
        train_total_ret = train_counts_c_ret + train_counts_inc_ret
        test_counts_c, _ = np.histogram(test_ser[test_c], bins=bin_edges)
        test_counts_inc, _ = np.histogram(test_ser[~test_c], bins=bin_edges)
        test_total = test_counts_c + test_counts_inc
        test_counts_c_ret, _ = np.histogram(test_ser[test_c_ret], bins=bin_edges)
        test_counts_inc_ret, _ = np.histogram(test_ser[~test_c_ret], bins=bin_edges)
        test_total_ret = test_counts_c_ret + test_counts_inc_ret
        
        # find the optimal threshold
        train_thresh_idx = np.argmax(list((sum(train_counts_c_ret[:i]) + sum(train_counts_c[i:])) / (sum(train_total_ret[:i]) + sum(train_total[i:])) for i in range(len(train_total) + 1)))
        train_thresh = bin_edges[train_thresh_idx] - 0.5 * (bin_edges[1] - bin_edges[0])
        train_threshs[prop] = train_thresh

        # calculate the accuracy on the test set
        test_param_acc = sum(test_counts_c) / sum(test_total)
        test_ret_acc = sum(test_counts_c_ret) / sum(test_total_ret)
        test_hybrid_acc = (sum(test_counts_c_ret[:train_thresh_idx]) + sum(test_counts_c[train_thresh_idx:])) / (sum(test_total_ret[:train_thresh_idx]) + sum(test_total[train_thresh_idx:]))
        test_sizes.append(sum(test_total))
        test_count_rets[-1] += sum(test_total_ret[:train_thresh_idx])
        test_count_params[-1] += sum(test_total[train_thresh_idx:])
        test_ret_accs.append(test_ret_acc)
        test_param_accs.append(test_param_acc)
        test_hybrid_accs.append(test_hybrid_acc)
        train_hybrid_accs.append((sum(train_counts_c_ret[:train_thresh_idx]) + sum(train_counts_c[train_thresh_idx:])) / (sum(train_total_ret[:train_thresh_idx]) + sum(train_total[train_thresh_idx:])))
        train_sizes.append(sum(train_total))

        if do_plot:
            plt.bar(bin_edges[:-1] - 0.5 * width, test_counts_c / test_total, width=width, alpha=0.9, label="parametric knowledge (vanilla)", align='edge', hatch='//')
            plt.bar(bin_edges[:-1] + 0.5 * width, test_counts_c_ret / test_total_ret, width=width, alpha=0.9, label="retrieval augmented (BM25)", align='edge', hatch='//')
            lo, hi = proportion_confint(test_counts_c, test_total, alpha=0.05, method='wilson')
            plt.errorbar(bin_edges[:-1], test_counts_c / test_total, yerr=[test_counts_c / test_total - lo, hi - test_counts_c / test_total], fmt='none', ecolor='black', elinewidth=1, capsize=2)
            lo, hi = proportion_confint(test_counts_c_ret, test_total_ret, alpha=0.05, method='wilson')
            plt.errorbar(bin_edges[:-1] + width, test_counts_c_ret / test_total_ret, yerr=[test_counts_c_ret / test_total_ret - lo, hi - test_counts_c_ret / test_total_ret], fmt='none', ecolor='black', elinewidth=1, capsize=2)
            plt.axvline(x=bin_edges[train_thresh_idx] - 0.8 * width, color='red', linestyle='--', label="threshold")

            print(f"Threshold for {prop}:", train_thresh)
            print("Parametric knowledge only:", param_acc)
            print("Retrieval augmented accuracy:", ret_acc)
            print(f"New accuracy with thresh={thresh_idx}:", hybrid_acc)
            print()
            plt.title(f"{prop}")
            plt.ylim([0,1.01])
    if do_plot:
        plt.xlabel("log(s_pop)")
        plt.ylabel("proportion correct")
        plt.legend()
        plt.tight_layout()
        plt.show()

    # take the weighted mean by test_size
    test_ret_accs_all.append(np.average(test_ret_accs, weights=test_sizes))
    test_param_accs_all.append(np.average(test_param_accs, weights=test_sizes))
    test_hybrid_accs_all.append(np.average(test_hybrid_accs, weights=test_sizes))
    train_hybrid_accs_all.append(np.average(train_hybrid_accs, weights=train_sizes))

test_size = split_mask.count("test")
prop_adaptive = (np.mean(test_count_rets) + np.mean(test_count_params)) / test_size
param_acc = np.mean(test_param_accs_all)
ret_acc = np.mean(test_ret_accs_all)
hybrid_acc = np.mean(test_hybrid_accs_all)
train_hybrid_acc = np.mean(train_hybrid_accs_all)
overall_hybrid_acc = np.mean(test_hybrid_accs_all) * prop_adaptive + max(np.mean(test_param_accs_all), np.mean(test_ret_accs_all)) * (1 - prop_adaptive)
sem_test_hybrid_acc = 2 * np.std(test_hybrid_accs_all) / np.sqrt(test_size)
print("prop adaptive:", prop_adaptive)
print("prop retrieval:", np.mean(test_count_rets) / test_size)
print("parametric knowledge only:", param_acc)
print("retrieval augmented accuracy:", ret_acc)
print("hybrid accuracy:", hybrid_acc, "pm", sem_test_hybrid_acc)
print("hybrid accuracy on train:", train_hybrid_acc)
print("overall accuracy gain:", overall_hybrid_acc - max(param_acc, ret_acc))
print("overall accuracy:", overall_hybrid_acc)

prop adaptive: 1.0
prop retrieval: 0.7981889543033361
parametric knowledge only: 0.17163442668909443
retrieval augmented accuracy: 0.2537706756377908
hybrid accuracy: 0.2706980656013457 pm 0.000232407510793975
hybrid accuracy on train: 0.2719317757009346
overall accuracy gain: 0.016927389963554862
overall accuracy: 0.2706980656013457
