In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np

In [None]:
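# ## Raw preprocessing
# Disabled one-off step, kept for provenance. When enabled it renames the
# `fuzzy_gpt-3.5-turbo-1106_*` metric columns to `acc` / `unc_acc` / `ece`,
# tags each row with `method` ("base" when `query_peft_dir` is missing, "ct"
# otherwise), and overwrites the same CSV in place.
# NOTE(review): presumably already applied to the stored CSV, since the cells
# below read `ece`, `acc` and `method` directly -- confirm before re-running.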
# results = pd.read_csv("results/eval_mmlu-13b_chat-ct-oe.csv")
# results = results.rename(columns={'fuzzy_gpt-3.5-turbo-1106_acc': 'acc',
#  'fuzzy_gpt-3.5-turbo-1106_unc_acc': 'unc_acc',
#  'fuzzy_gpt-3.5-turbo-1106_unc_ece': 'ece'})
# results.loc[results.query_peft_dir.isna(), "method"] = "base"
# results.loc[~results.query_peft_dir.isna(), "method"] = "ct"
# results.to_csv("results/eval_mmlu-13b_chat-ct-oe.csv", index=False)

In [None]:
# Load the evaluation results table that the cells below build their plots from.
RESULTS_CSV = "results/eval_mmlu-13b_chat-ct-oe.csv"
results = pd.read_csv(RESULTS_CSV)

In [None]:
# Number the datasets 1..N in order of first appearance in the full results
# table (all splits), giving each one a compact ID for the plots' x-axis.
dname_map = {d: idx + 1 for idx, d in enumerate(results.dataset.unique())}

# Build the plotting frame from the test split only. `.assign` returns a new
# frame, so the derived columns are never written into a slice of `results`
# (which is what raises pandas' SettingWithCopyWarning).
test_rows = results[results["split"] == "test"]
plt_results = (
    test_rows
    .assign(
        task_idx=test_rows["dataset"].map(dname_map),
        # Metrics rescaled by 100 so the plots can show them as percentages.
        ece_100=test_rows["ece"] * 100.,
        acc_100=test_rows["acc"] * 100.,
        unc_acc_100=test_rows["unc_acc"] * 100.,
    )
    # Stable sort: rows sharing a task_idx keep their original relative order,
    # so the result is deterministic across runs.
    .sort_values("task_idx", kind="stable")
)
plt_results

In [None]:
sns.set_theme(font_scale=1.5, style="whitegrid")

# Pair base and CT ECE per task explicitly (rows = task_idx, columns = method)
# instead of trusting two independently filtered arrays to share the same
# order and length. Repeated (task, method) rows are averaged by pivot_table.
ece_by_task = plt_results.pivot_table(index="task_idx", columns="method", values="ece_100")
ref = ece_by_task["base"].to_numpy()
new = ece_by_task["ct"].to_numpy()

# Relative ECE change in percent; positive means "ct" has lower ECE than "base".
# A zero base ECE (division by zero) or a task missing for one method yields a
# non-finite value; those are dropped so they never reach the histogram binning.
with np.errstate(divide="ignore", invalid="ignore"):
    rel_imp = -((new - ref) / ref) * 100.
rel_imp = rel_imp[np.isfinite(rel_imp)]

# The data is passed as a one-column wide-form frame; seaborn maps wide-form
# columns to hue, which is why a one-colour `palette` is given and the legend
# is switched off.
g = sns.displot(pd.DataFrame({ "rel_imp": rel_imp }),
                # kind="kde", fill=True, bw_adjust=.5,
                kde=True, stat="count", kde_kws={"bw_adjust": .5, "cut": 2}, binwidth=10,
                height=5, aspect=4/3, legend=False,
                palette=sns.color_palette("tab20")[8:9])

# Thicken the first line on the axes (the KDE curve requested via kde=True).
g.ax.get_lines()[0].set(linewidth=5)

g.set(xlabel="ECE Improvement over Untrained Query (%)")

# `figure` is the preferred accessor (`fig` is its deprecated alias).
g.figure.tight_layout()
# plt.show() renders on every backend; Figure.show() warns on the non-GUI
# inline backend used by notebooks.
plt.show()
# g.figure.savefig("mmlu_rel_imp_qa_oe.pdf", bbox_inches="tight")

In [None]:
# Grouped bar chart of ECE (%) per MMLU task, "ct" vs "base".
sns.set_theme(font_scale=6., style="whitegrid")

# Very large canvas; the font sizes below are scaled to match it.
fig, ax = plt.subplots(figsize=(100, 40))

ax = sns.barplot(ax=ax, data=plt_results,
                 x="task_idx", y="ece_100", hue="method",
                 width=0.68,
                 hue_order=["ct", "base"],
                 palette=sns.color_palette("tab20")[8:10])

ax.set_ylabel(r'$\mathrm{ECE} (\%)$', fontsize=250, labelpad=100)
# ax.set_yticks(np.arange(0, 60 + 1e-3, 10))
# Resize tick labels through tick_params: re-setting the label texts with
# set_yticklabels() would freeze them onto an automatic locator (and warn),
# so they could drift out of sync with the ticks after layout.
ax.tick_params(axis="y", labelsize=150)

ax.set_xlabel('MMLU Task ID', fontsize=250, labelpad=100)
ax.tick_params(axis="x", labelsize=150)

# Swap the raw `method` values for display names in the legend.
handles, labels = ax.get_legend_handles_labels()
label_map = { "base": "Base", "ct": "CT" }
labels = [label_map[name] for name in labels]

ax.legend(handles=handles, labels=labels, loc='best',
          title='', title_fontsize=200,
          prop=dict(size=180))#, bbox_to_anchor=(.91, .7, .1, .1))

fig.tight_layout()
# plt.show() renders on every backend; Figure.show() warns on the non-GUI
# inline backend used by notebooks.
plt.show()
# fig.savefig("mmlu_oe_ece_comparison.pdf", bbox_inches="tight")

In [None]:
# Grouped bar chart of query accuracy (%) per MMLU task, "ct" vs "base".
sns.set_theme(font_scale=6., style="whitegrid")

# Very large canvas; the font sizes below are scaled to match it.
fig, ax = plt.subplots(figsize=(100, 40))

ax = sns.barplot(ax=ax, data=plt_results,
                 x="task_idx", y="acc_100", hue="method",
                 width=0.68,
                 hue_order=["ct", "base"],
                 palette=sns.color_palette("tab20")[8:10])

ax.set_ylabel('Query Acc. (%)', fontsize=250, labelpad=100)
# Fixed 0-100 scale in steps of 20 (the 1e-3 keeps the endpoint in the range).
ax.set_yticks(np.arange(0, 100 + 1e-3, 20))
# Resize tick labels through tick_params rather than re-setting the label
# texts, so the labels always follow the tick locations.
ax.tick_params(axis="y", labelsize=150)

ax.set_xlabel('MMLU Task ID', fontsize=250, labelpad=100)
ax.tick_params(axis="x", labelsize=150)

# Swap the raw `method` values for display names in the legend.
handles, labels = ax.get_legend_handles_labels()
label_map = { "base": "Base", "ct": "CT" }
labels = [label_map[name] for name in labels]

ax.legend(handles=handles, labels=labels, loc='best',
          title='', title_fontsize=200,
          prop=dict(size=180))#, bbox_to_anchor=(.91, .7, .1, .1))

fig.tight_layout()
# plt.show() renders on every backend; Figure.show() warns on the non-GUI
# inline backend used by notebooks.
plt.show()
# fig.savefig("mmlu_oe_qacc_comparison.pdf", bbox_inches="tight")