In [None]:
from elk_generalization.results.viz import get_result_dfs
from elk_generalization.utils import get_quirky_model_name

# Models whose experiment results are loaded in the cells below.
models = [
    "mistralai/Mistral-7B-v0.1",
]
# Approximate parameter count in billions (e.g. "pythia-410m" -> 0.41), keyed by
# model name without the org prefix.
model_scales = {
    "pythia-410m": 0.41,
    "pythia-1b": 1,
    "pythia-1.4b": 1.4,
    "pythia-2.8b": 2.8,
    "pythia-6.9b": 6.9,
    "pythia-12b": 12,
    "Llama-2-7b-hf": 7,
    "Mistral-7B-v0.1": 7,
}
# Display names for each method / reporter key.
# NOTE(review): "vincs" (the method plotted below) has no entry here — add one if
# method_titles is ever used to label it.
method_titles = {
    "lr": "LogR",
    "mean-diff": "Diff-in-means",
    "mean-diff-on-pair": "Diff-in-means on contrast pair",
    "lda": "LDA",
    "lr-on-pair": "LogR on contrast pair",
    "ccs": "CCS",
    "crc": "CRC",
}

# Dataset name -> short abbreviation used in plot legends. Insertion order defines
# the dataset order used throughout the notebook.
ds_abbrevs = {
    "capitals": "cap",
    "hemisphere": "hem",
    "population": "pop",
    "sciq": "sciq",
    "sentiment": "snt",
    "nli": "nli",
    "authors": "aut",
    "addition": "add",
    "subtraction": "sub",
    "multiplication": "mul",
    "modularaddition": "mod",
    "squaring": "sqr",
}
# Derived from ds_abbrevs so every dataset is guaranteed an abbreviation (the
# plotting cell looks each one up) and the two collections cannot drift apart.
ds_names = list(ds_abbrevs)

# Root directory of experiment outputs, relative to this notebook; passed as
# `root_dir` to get_result_dfs.
root = "../../experiments/"


# Qualitative differences

In [None]:
# Result-loading configuration for the layerwise plot below.
plot_models = models
# Distribution names for the transfer experiment; rendered as "{fr} -> {to}" in the
# plot title (presumably train -> eval — confirm against get_result_dfs).
fr, to = "AE", "AE"
filter_by = "all"
weak_only = False
metric = "auroc"
methods = ["vincs"]
# var, inv, cov, supervised weights
vincs_hparams = (0.0, 1.0, 0.0, 1.0)
use_leace = False
templatization_method = "random"
standardize_templates = False
full_finetuning = False
ensemble = "full"
label_col = "alice_label"

# Datasets to drop for this configuration. Filtering into a fresh list (instead of
# calling .remove on a copy) keeps the cell idempotent, preserves ds_names order,
# and never raises if a dataset is already absent.
excluded_ds = set()
if "H" in fr or "E" in fr:
    excluded_ds.add("population")  # difficulty is label
# The plotting cell keys its y-label on "disagree" while this check used to test
# "disagreements"; accept both — TODO confirm which spelling get_result_dfs expects.
if filter_by in ("disagree", "disagreements"):
    excluded_ds.add("authors")  # authors is only False for disagreements
plot_ds_names = [ds for ds in ds_names if ds not in excluded_ds]

# Keyword arguments shared by every reporter's get_result_dfs call.
result_kwargs = dict(
    label_col=label_col,
    ensemble=ensemble,
    filter_by=filter_by,
    metric=metric,
    root_dir=root,
    weak_only=weak_only,
    vincs_hparams=vincs_hparams,
    use_leace=use_leace,
    templatization_method=templatization_method,
    standardize_templates=standardize_templates,
    full_finetuning=full_finetuning,
)
# reporter name -> tuple of result frames returned by get_result_dfs.
rs = {
    reporter: get_result_dfs(plot_models, fr, to, plot_ds_names, reporter=reporter, **result_kwargs)
    for reporter in methods
}

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Layerwise plot, one line per dataset. Solid lines: the reporter's `metric` against
# `layer_frac`; dotted horizontal lines: the matching per-dataset LM result
# returned by get_result_dfs, drawn in the same color.
sns.set_style("whitegrid")
sns.set_context("paper")

FIGURES_DIR = "../../figures"

# y-axis suffix for each `filter_by` value. Both "disagree" and "disagreements" are
# accepted because the notebook has used both spellings — TODO confirm which one
# get_result_dfs expects. Unknown values fall back to the raw string, not a KeyError.
filter_labels = {
    "disagree": "$\\bf{disagreements}$",
    "disagreements": "$\\bf{disagreements}$",
    "agree": "$\\bf{agreements}$",
    "all": "$\\bf{all\\ examples}$",
}

fig, ax = plt.subplots(1, 1, figsize=(6, 3), dpi=200)

for method in methods:
    # Only the per-dataset probe and LM results are needed for this figure.
    _, per_ds_results_dfs, _, _, per_ds_lm_result_dfs, _ = rs[method]
    colors = sns.color_palette("tab20", len(per_ds_results_dfs))
    # zip pairs the two dicts positionally — presumably both are built in the same
    # dataset order; verify against get_result_dfs.
    for color, (ds_name, result_df), lm_result in zip(
        colors, per_ds_results_dfs.items(), per_ds_lm_result_dfs.values()
    ):
        # Prefix with the method only when several methods share this single axes,
        # so legend entries stay distinguishable.
        label = ds_abbrevs[ds_name] if len(methods) == 1 else f"{method} {ds_abbrevs[ds_name]}"
        ax.plot(result_df["layer_frac"], result_df[metric], alpha=0.9, color=color, linewidth=0.8, label=label)
        ax.hlines(lm_result, 0, 1, color=color, linewidth=1, linestyle=":")

# There is a single axes, so decorate it once after every line has been drawn;
# the legend then covers all methods.
ax.legend(loc=[1.01, 0.01])
ax.set_ylabel(f"{metric.upper()} on {filter_labels.get(filter_by, filter_by)}", fontsize=12)
ax.set_xlabel("Layer (fraction of max)", fontsize=12)
ax.set_xlim(0, 1)
ax.set_ylim(-0.01, 1.01)
ax.set_title(f"Layerwise {metric.upper()} for {fr}$\\to${to}" + (" weak only" if weak_only else ""), fontsize=14)
fig.tight_layout()

# Encode the metric (previously hardcoded as "auroc") and any non-default filter /
# weak-only setting in the filename, so different configurations don't overwrite
# each other. The default configuration keeps its original filename.
fig_name = f"layerwise_{metric}_qualitative_{fr}_{to}"
if filter_by != "all":
    fig_name += f"_{filter_by}"
if weak_only:
    fig_name += "_weak"
os.makedirs(FIGURES_DIR, exist_ok=True)
fig.savefig(os.path.join(FIGURES_DIR, f"{fig_name}.pdf"))
plt.show()