In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Load per-model evaluation results, keep only the test split, and tag each
# row with the model it came from. Each CSV must provide the `split`,
# `dataset`, `unc_acc` and `unc_ece` columns read below.
RESULT_CSVS = {
    "base": "results/eval_mmlu-13b_chat-base.csv",
    "ct": "results/eval_mmlu-13b_chat-ct.csv",
}

# One filtered frame per model; `.loc[mask, cols]` avoids chained indexing.
test_frames = {
    model: (
        pd.read_csv(path)
        .pipe(lambda raw: raw.loc[raw["split"] == "test", ["dataset", "unc_acc", "unc_ece"]])
        .assign(model=model)
    )
    for model, path in RESULT_CSVS.items()
}
base, ct = test_frames["base"], test_frames["ct"]

# Long-format table: one row per (dataset, model), base rows first.
results = pd.concat([base, ct], ignore_index=True)
results

In [None]:
# Give every dataset a compact 1-based task ID in alphabetical order, so the
# bar charts can use short integer x-tick labels instead of dataset names.
dname_map = {name: task_id
             for task_id, name in enumerate(sorted(results.dataset.unique()), start=1)}

# Attach the task ID plus percentage-scaled copies of the metrics, then order
# rows by task ID.
results = (
    results
    .assign(
        task_idx=results["dataset"].map(dname_map),
        unc_ece_100=results["unc_ece"] * 100.,
        unc_acc_100=results["unc_acc"] * 100.,
    )
    .sort_values("task_idx")
)

In [None]:
sns.set_theme(font_scale=2., style="whitegrid")

# Pair base vs. CT ECE by dataset rather than by row position. `pivot`
# raises on duplicate (dataset, model) rows. Tasks missing for either model
# are dropped instead of silently shifting the pairing. The pivot index is
# sorted by dataset name, i.e. the same order as `task_idx`.
ece_by_model = (
    results
    .pivot(index="dataset", columns="model", values="unc_ece_100")
    .dropna(subset=["base", "ct"])
)
ref = ece_by_model["base"].values
new = ece_by_model["ct"].values

# Relative ECE reduction of CT w.r.t. the base model, in percent
# (positive = CT has lower ECE).
# NOTE(review): a task whose base ECE is exactly 0 would yield +/-inf here.
rel_imp = pd.DataFrame({"rel_imp": -((new - ref) / ref) * 100.})

g = sns.displot(rel_imp,
                kde=True, stat="count", kde_kws={"bw_adjust": .5, "cut": 2}, binwidth=10,
                height=4, aspect=5/3, legend=False,
                palette=sns.color_palette("tab20")[8:9])

# Thicken the KDE curve drawn by `kde=True` (the histogram bars are patches,
# so the first Line2D on the axes is the KDE).
g.ax.get_lines()[0].set(linewidth=5)

g.set(xlabel="ECE Improvement (%)", title="CT v/s Unc. Query")

# `figure` is the current name of the deprecated `FacetGrid.fig` alias.
g.figure.tight_layout()
# plt.show() works on every backend; Figure.show() warns on non-GUI/inline ones.
plt.show()
# g.figure.savefig("mmlu_rel_imp_ct_uncq.pdf", bbox_inches="tight")

In [None]:
sns.set_theme(font_scale=6., style="whitegrid")

# Very large canvas: one bar pair per MMLU task, sized for a paper figure.
fig, ax = plt.subplots(figsize=(200, 40))

# Per-task ECE (%) for CT vs. the base model; hue_order draws CT first.
sns.barplot(ax=ax, data=results,
            x="task_idx", y="unc_ece_100", hue="model",
            width=0.68,
            hue_order=["ct", "base"],
            palette=sns.color_palette("tab20")[8:10])

ax.set_ylabel(r'$\mathrm{ECE} (\%)$', fontsize=250, labelpad=100)
ax.set_xlabel('MMLU Task ID', fontsize=200, labelpad=120)

# Enlarge tick labels in place. Re-setting them from get_*ticklabels() would
# install a FixedFormatter without a FixedLocator (matplotlib warns) and can
# freeze stale or empty label text.
ax.tick_params(axis="both", labelsize=150)

# Rename legend entries to display names; unknown labels pass through as-is.
label_map = {"base": "Base", "ct": "CT"}
handles, labels = ax.get_legend_handles_labels()
labels = [label_map.get(label, label) for label in labels]

ax.legend(handles=handles, labels=labels, loc='best',
          title='', title_fontsize=200,
          prop=dict(size=180))

fig.tight_layout()
# plt.show() works on every backend; Figure.show() warns on non-GUI/inline ones.
plt.show()
# fig.savefig("mmlu_13b_chat_oe_ct_ece.pdf", bbox_inches="tight")

In [None]:
sns.set_theme(font_scale=6., style="whitegrid")

# Very large canvas: one bar pair per MMLU task, sized for a paper figure.
fig, ax = plt.subplots(figsize=(200, 40))

# Per-task query accuracy (%) for CT vs. the base model; hue_order draws CT first.
sns.barplot(ax=ax, data=results,
            x="task_idx", y="unc_acc_100", hue="model",
            width=0.68,
            hue_order=["ct", "base"],
            palette=sns.color_palette("tab20")[8:10])

ax.set_ylabel(r'Query Acc. $(\%)$', fontsize=250, labelpad=100)
ax.set_xlabel('MMLU Task ID', fontsize=200, labelpad=120)

# Enlarge tick labels in place. Re-setting them from get_*ticklabels() would
# install a FixedFormatter without a FixedLocator (matplotlib warns) and can
# freeze stale or empty label text.
ax.tick_params(axis="both", labelsize=150)

# Rename legend entries to display names; unknown labels pass through as-is.
label_map = {"base": "Base", "ct": "CT"}
handles, labels = ax.get_legend_handles_labels()
labels = [label_map.get(label, label) for label in labels]

ax.legend(handles=handles, labels=labels, loc='best',
          title='', title_fontsize=200,
          prop=dict(size=180))

fig.tight_layout()
# plt.show() works on every backend; Figure.show() warns on non-GUI/inline ones.
plt.show()
# fig.savefig("mmlu_13b_chat_oe_ct_qacc.pdf", bbox_inches="tight")