In [None]:
import os
from torch.distributions import Categorical
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
import os

# Root directory of the evaluation artifacts. Fail fast when BLOBDIR is unset:
# os.environ.get() would otherwise return None and the f-strings below would
# silently build paths starting with the literal text "None".
blob_dir = os.environ.get("BLOBDIR")
if not blob_dir:
    raise RuntimeError("The BLOBDIR environment variable must be set to locate the eval outputs.")


def _test_dirs_by_dataset(run_dir):
    """Map each dataset key to the ``test`` sub-directory of its entry in ``run_dir``.

    The key is the second "-"-separated field of the entry name, so every
    entry of ``run_dir`` is expected to contain at least one "-".
    """
    return {entry.split("-")[1]: f"{run_dir}/{entry}/test" for entry in os.listdir(run_dir)}


base_dir = f"{blob_dir}/evals/eval-mmlu_oe-base-j7bezfxq"
base_map = _test_dirs_by_dataset(base_dir)

ct_dir = f"{blob_dir}/evals/eval-mmlu_oe-ct-f99ghkak"
ct_map = _test_dirs_by_dataset(ct_dir)

# Datasets left out of the comparison (the reason is not recorded in this notebook).
excluded_datasets = {"college_physics", "professional_law"}

# Only datasets evaluated in BOTH runs can be compared; sorted for a stable order.
ds_keys = sorted((base_map.keys() & ct_map.keys()) - excluded_datasets)
len(ds_keys)

In [None]:
import pandas as pd
import torch
from llm.models import get_model
from llm.datasets import get_token_vec

# Tokenizer object obtained from the project helper get_model() under the
# "llama2_13b_chat_tokenizer" name; it is only used to build token_vec below.
tokenizer = get_model("llama2_13b_chat_tokenizer")
# Module-level index used by load_metrics() to select entries along the second
# axis of q_logits before the softmax.
# NOTE(review): presumably the token ids of the answer-choice tokens for the
# "roman_choice" format — confirm against llm.datasets.get_token_vec.
token_vec = get_token_vec(tokenizer, format="roman_choice")

def load_metrics(path, query_key="fuzzy_gpt-3.5-turbo-1106"):
    """Load the CSV row shards and the ``.pt`` tensor shards stored in ``path``.

    Args:
        path: Directory containing one or more ``.csv`` files and one or more
            ``.pt`` files.
        query_key: Top-level key of each ``.pt`` payload under which the
            ``"q_labels"`` and ``"q_logits"`` tensors are stored. Defaults to
            the key this notebook has always used.

    Returns:
        A ``(rows, q_labels, q_p)`` tuple: the concatenated CSV rows, the
        concatenated ``q_labels`` tensor, and the softmax over the
        ``token_vec`` entries of the concatenated ``q_logits``.

    Raises:
        ValueError: If ``path`` contains no ``.csv`` or no ``.pt`` files.
        AssertionError: If the number of rows differs from the number of
            labels or logits.
    """
    # os.listdir() returns entries in arbitrary order. The rows and tensors are
    # concatenated shard by shard and later combined position-wise, so both file
    # lists must be traversed in the same deterministic order.
    # NOTE(review): assumes CSV and .pt shards share a naming scheme so that
    # sorted order pairs them up — confirm against the eval writer.
    entries = sorted(os.listdir(path))
    csv_names = [name for name in entries if name.endswith(".csv")]
    pt_names = [name for name in entries if name.endswith(".pt")]
    if not csv_names or not pt_names:
        raise ValueError(f"Expected both .csv and .pt shards in {path!r}.")

    rows = pd.concat(
        [pd.read_csv(os.path.join(path, name)) for name in csv_names],
        ignore_index=True,
    )

    # NOTE(review): torch.load unpickles the file contents; only point this at
    # trusted evaluation outputs.
    shards = [torch.load(os.path.join(path, name), map_location="cpu") for name in pt_names]

    q_labels = torch.cat([shard[query_key]["q_labels"] for shard in shards], dim=0)
    q_logits = torch.cat([shard[query_key]["q_logits"] for shard in shards], dim=0)

    assert len(rows) == len(q_labels), f"{len(rows)} rows vs {len(q_labels)} labels in {path!r}"
    assert len(rows) == len(q_logits), f"{len(rows)} rows vs {len(q_logits)} logits in {path!r}"

    # Restrict the logits to the token_vec entries, then normalise over them.
    q_p = q_logits[:, token_vec].softmax(dim=-1)

    return rows, q_labels, q_p

def prep_dataset(dataset):
    """Build one long-format frame with per-row statistics for both models.

    For ``dataset`` the metrics are loaded first from ``base_map`` and then
    from ``ct_map``. Each load is summarised into the columns
    ``target_lengths``, ``entropy``, ``p_yes``, ``correct``, ``model`` and
    ``dataset``, and the two summaries are stacked with a fresh index.
    """

    def summarize(metrics_dir, model_name):
        """Summarise the metrics stored in ``metrics_dir`` for one model."""
        rows, labels, probs = load_metrics(metrics_dir)
        # Prediction is the arg-max over the last axis of the probabilities.
        predicted = probs.argmax(dim=-1)
        return pd.DataFrame({
            # Length of each entry in the CSV "target" column.
            "target_lengths": rows.target.apply(len).values,
            "entropy": Categorical(probs=probs).entropy().numpy(),
            # Probability stored at column index 1.
            "p_yes": probs[:, 1].numpy(),
            # 1 where the prediction equals the label, 0 otherwise.
            "correct": (labels == predicted).long().numpy(),
            "model": model_name,
            "dataset": dataset,
        })

    frames = [
        summarize(base_map[dataset], "base"),
        summarize(ct_map[dataset], "ct"),
    ]
    return pd.concat(frames, ignore_index=True)

In [None]:
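# NOTE: disabled exploratory cell (2-D KDE of target length vs. p_yes per model).
# It filters on `df.labels`, a column that prep_dataset() does not produce, so it
# must be updated before being re-enabled.
# sns.set(font_scale=1.5, style="whitegrid")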

# for ds in ds_keys[:1]:
#     df = prep_dataset(ds)

#     fig, axes = plt.subplots(1, 2, figsize=(7, 4), sharex=True, sharey=True)

#     for i, m in enumerate(["base", "ct"]):
#         sns.kdeplot(
#             data=df[(df.labels.isin([0,1])) & (df.model == m)],
#             x="target_lengths", y="p_yes",
#             # fill=True,
#             thresh=0,
#             levels=15,
#             ax=axes[i],
#             cmap=sns.color_palette("coolwarm", as_cmap=True),
#         )

#     fig.suptitle(ds)

#     fig.tight_layout()
#     fig.show()

In [None]:

sns.set(font_scale=1.5, style="whitegrid")

# The marker/line colour is the same for every figure, so look it up once.
joint_color = sns.color_palette("tab20")[8]

# One regression joint plot (target length vs. p_yes) per dataset, "ct" rows only.
for ds in tqdm(ds_keys):
    df = prep_dataset(ds)
    keep = df["model"].eq("ct") & df["correct"].isin([0, 1])
    df = df.loc[keep]

    g = sns.jointplot(
        data=df,
        x="target_lengths",
        y="p_yes",
        kind="reg",
        color=joint_color,
        height=4.5,
    )

    joint_ax = g.ax_joint
    # Anchor the left spine at x = 0.
    joint_ax.spines["left"].set_position("zero")
    joint_ax.set(xlabel="Target Length", ylabel=r"Confidence")

    figure = g.fig
    figure.suptitle(ds)
    figure.tight_layout()
    # g.fig.savefig(f"length_figs/{ds}.pdf", bbox_inches="tight")
