In [6]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from itertools import chain
import json
from json import JSONDecodeError
import seaborn as sns
from scipy.stats import linregress
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel


def load_result(path):
    path = Path(path)
    try:
        with open(path / "results.json") as f:
            results = json.load(f)
        with open(path / "config.json") as f:
            config = json.load(f)
    except (FileNotFoundError, JSONDecodeError) as e:
        if isinstance(e, JSONDecodeError):
            print(f"Error decoding {path}: {e}")
        return
    
    if config["reporter"]["method"] == "ModularSftReporter":
        stages_cfg = config["reporter"]["stages"]
        weak_compute = sum(stage["num_weak_nonunique"] * stage["train_args"]["num_train_epochs"] for stage in stages_cfg)
        oracle_compute = sum(stage["num_oracle_nonunique"] * stage["train_args"]["num_train_epochs"] for stage in stages_cfg)
        total_compute = weak_compute + oracle_compute
    else:
        weak_compute, oracle_compute, total_compute = np.nan, np.nan, np.nan

    logodds = np.array(results["calibrated_logodds"])
    labels = np.array(results["gt_soft_labels"]) > 0.5
    thresh = np.quantile(logodds, labels.mean())
    calibrated_acc = ((logodds > thresh) == labels).mean()

    num_weak = results["reporter"]["num_weak"]
    num_oracle = results["reporter"]["num_oracle"]
    num_weak_nonunique = results["reporter"]["num_weak_nonunique"]
    num_oracle_nonunique = results["reporter"]["num_oracle_nonunique"]
    
    seed = int(path.name.split("_")[-1].split("s")[-1])
    sweep_name = "_".join(path.name.split("_")[3:-1])
    return {
        "auroc": results["auroc"],
        "acc": results["acc"],
        "calibrated_acc": calibrated_acc,
        "model_name": config["model"]["name"],
        "num_oracle": num_oracle,
        "num_weak": num_weak,
        "num_oracle_nonunique": num_oracle_nonunique,
        "num_weak_nonunique": num_weak_nonunique,
        "weak_compute": weak_compute,
        "oracle_compute": oracle_compute,
        "total_compute": total_compute,
        "seed": seed,
        "ds_name": path.parent.name,
        "sweep_name": sweep_name,
        "path": str(path),
    }


def get_results_df(ds_names=None, sweep_names=None):
    if sweep_names is not None:
        patterns = [f"nw=*_{sweep_name}_s*" for sweep_name in sweep_names]
    else:
        patterns = [f"nw=*_*_s*"]
    results = []
    if ds_names is None:
        ds_names = [d.name for d in Path("results").iterdir() if d.is_dir()]
    for ds_name in ds_names:
        for subdir in chain(*[Path(f"results/{ds_name}").glob(pattern) for pattern in patterns]):
            try:
                if result := load_result(subdir):
                    results.append(result)
            except Exception as e:
                print(e, subdir)
                continue
    results_df = pd.DataFrame(results)
    results_df.set_index(["ds_name", "model_name", "sweep_name"], inplace=True, drop=False)
    return results_df


def get_cur_df(results_df, run_name):
    if run_name[2] == "2weak_prompt_*_sft_estop":
        curr_df = results_df.xs((run_name[0], run_name[1]), level=['ds_name', 'model_name'])    
        curr_df = curr_df[curr_df.sweep_name.isin([
            "2weak_prompt_weak_sft_estop",
            "2weak_prompt_oracle_sft_estop",
        ])]
    else:
        curr_df = results_df.xs(run_name, level=['ds_name', 'model_name', 'sweep_name'])
    return curr_df


def find_result_by_n(results_df, n_weak, n_oracle, run_name, atol=0., rtol=0., verbose=True, metric="acc"):
    curr_df = get_cur_df(results_df, run_name)
    curr_df = curr_df[np.isclose(curr_df["num_oracle"], n_oracle, atol=atol, rtol=rtol)]
    curr_df = curr_df[np.isclose(curr_df["num_weak"], n_weak, atol=atol, rtol=rtol)]
    if len(curr_df) == 0:
        if verbose:
            print(f"WARNING: {run_name} has no results for ({n_weak}, {n_oracle})")
        return None, None, 0
    # print(curr_df["seed"].values.tolist())
    return curr_df[metric].mean(), curr_df[metric].std(ddof=1), len(curr_df)


def find_results_by_budget(results_df, budget, wmc, run_name):
    curr_df = get_cur_df(results_df, run_name).copy()
    curr_df["cost"] = curr_df["num_oracle"] + curr_df["num_weak"] * wmc
    curr_df = curr_df[curr_df["cost"] <= budget]
    return curr_df

In [50]:
salient_ds_name = "amazon_polarity_title_only_weak_amplified"
salient_results_df = get_results_df(ds_names=[salient_ds_name], sweep_names=None)

'reporter' results/amazon_polarity_title_only_weak_amplified/nw=1000_no=10_m=Meta-Llama-3-8B_seq_sft_both_estop_s5
'reporter' results/amazon_polarity_title_only_weak_amplified/nw=499_no=0_m=Qwen1.5-0.5B_seq_sft
'reporter' results/amazon_polarity_title_only_weak_amplified/nw=99_no=0_m=Meta-Llama-3-8B_seq_sft
'reporter' results/amazon_polarity_title_only_weak_amplified/nw=499_no=0_m=Meta-Llama-3-8B_seq_sft


In [71]:
salient_results_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,auroc,acc,calibrated_acc,model_name,num_oracle,num_weak,num_oracle_nonunique,num_weak_nonunique,weak_compute,oracle_compute,total_compute,seed,ds_name,sweep_name,path
ds_name,model_name,sweep_name,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
amazon_polarity_title_only_weak_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_weak,0.940515,0.856,0.856,meta-llama/Meta-Llama-3-8B,0,40963,0,40960,40960.0,0.0,40960.0,0,amazon_polarity_title_only_weak_amplified,seq_sft_both_estop_clean_disjoint_16shot_weak,results/amazon_polarity_title_only_weak_amplif...
amazon_polarity_title_only_weak_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_weak,0.96055,0.894,0.894,meta-llama/Meta-Llama-3-8B,0,176,0,160,30000.0,0.0,30000.0,0,amazon_polarity_title_only_weak_amplified,seq_sft_both_estop_clean_disjoint_16shot_weak,results/amazon_polarity_title_only_weak_amplif...
amazon_polarity_title_only_weak_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_weak,0.993921,0.94,0.956,meta-llama/Meta-Llama-3-8B,8192,16,8192,0,0.0,30000.0,30000.0,0,amazon_polarity_title_only_weak_amplified,seq_sft_both_estop_clean_disjoint_16shot_weak,results/amazon_polarity_title_only_weak_amplif...
amazon_polarity_title_only_weak_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_weak,0.945293,0.85,0.864,meta-llama/Meta-Llama-3-8B,0,656,0,640,30000.0,0.0,30000.0,0,amazon_polarity_title_only_weak_amplified,seq_sft_both_estop_clean_disjoint_16shot_weak,results/amazon_polarity_title_only_weak_amplif...
amazon_polarity_title_only_weak_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_weak,0.937011,0.851,0.85,meta-llama/Meta-Llama-3-8B,0,10252,0,10240,30000.0,0.0,30000.0,0,amazon_polarity_title_only_weak_amplified,seq_sft_both_estop_clean_disjoint_16shot_weak,results/amazon_polarity_title_only_weak_amplif...
amazon_polarity_title_only_weak_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_weak,0.94587,0.844,0.86,meta-llama/Meta-Llama-3-8B,0,2574,0,2560,30000.0,0.0,30000.0,0,amazon_polarity_title_only_weak_amplified,seq_sft_both_estop_clean_disjoint_16shot_weak,results/amazon_polarity_title_only_weak_amplif...


In [51]:
salient_results_df.iloc[1].path

'results/amazon_polarity_title_only_weak_amplified/nw=160_no=0_m=Meta-Llama-3-8B_seq_sft_both_estop_clean_disjoint_16shot_weak_s0'

In [62]:
import json
with open('results/amazon_polarity_title_only_weak_amplified/nw=160_no=0_m=Meta-Llama-3-8B_seq_sft_both_estop_clean_disjoint_16shot_weak_s0/results.json') as f:
    salient_results = json.load(f)
salient_results.keys()
salient_lo = np.array(salient_results["calibrated_logodds"])
salient_preds = salient_lo > 0
weak_labels = np.array(salient_results["weak_soft_labels"]) > 0.5
salient_df = pd.DataFrame({
    "ids": salient_results["ids"],
    "salient_lo": salient_lo,
    "salient_preds": salient_preds,
    "weak_labels": weak_labels,
})


In [63]:
ds_name = "amazon_polarity_title_only_neither_amplified"
results_df = get_results_df(ds_names=[ds_name], sweep_names=None)

'reporter' results/amazon_polarity_title_only_neither_amplified/nw=100_no=800_m=Meta-Llama-3-8B_seq_sft_both_estop_s5


In [64]:
results_df.iloc[2].path

'results/amazon_polarity_title_only_neither_amplified/nw=160_no=0_m=Meta-Llama-3-8B_seq_sft_both_estop_clean_disjoint_16shot_oracle_s0'

In [65]:
import json
with open('results/amazon_polarity_title_only_neither_amplified/nw=160_no=0_m=Meta-Llama-3-8B_seq_sft_both_estop_clean_disjoint_16shot_oracle_s0/results.json') as f:
    results = json.load(f)
results.keys()

dict_keys(['auroc', 'auroc_lo', 'auroc_hi', 'acc', 'acc_lo', 'acc_hi', 'auroc_against_weak', 'auroc_against_weak_lo', 'auroc_against_weak_hi', 'acc_against_weak', 'acc_against_weak_lo', 'acc_against_weak_hi', 'weak_soft_labels', 'oracle_ids', 'ids', 'calibrated_logodds', 'gt_soft_labels', 'reporter'])

In [70]:
results_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,auroc,acc,calibrated_acc,model_name,num_oracle,num_weak,num_oracle_nonunique,num_weak_nonunique,weak_compute,oracle_compute,total_compute,seed,ds_name,sweep_name,path
ds_name,model_name,sweep_name,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
amazon_polarity_title_only_neither_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_oracle,0.994738,0.967,0.954,meta-llama/Meta-Llama-3-8B,8205,0,8192,0,0.0,30000.0,30000.0,0,amazon_polarity_title_only_neither_amplified,seq_sft_both_estop_clean_disjoint_16shot_oracle,results/amazon_polarity_title_only_neither_amp...
amazon_polarity_title_only_neither_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_oracle,0.947147,0.868,0.866,meta-llama/Meta-Llama-3-8B,16,40959,0,40960,40960.0,0.0,40960.0,0,amazon_polarity_title_only_neither_amplified,seq_sft_both_estop_clean_disjoint_16shot_oracle,results/amazon_polarity_title_only_neither_amp...
amazon_polarity_title_only_neither_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_oracle,0.979861,0.929,0.93,meta-llama/Meta-Llama-3-8B,16,160,0,160,30000.0,0.0,30000.0,0,amazon_polarity_title_only_neither_amplified,seq_sft_both_estop_clean_disjoint_16shot_oracle,results/amazon_polarity_title_only_neither_amp...
amazon_polarity_title_only_neither_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_oracle,0.96502,0.903,0.896,meta-llama/Meta-Llama-3-8B,16,640,0,640,30000.0,0.0,30000.0,0,amazon_polarity_title_only_neither_amplified,seq_sft_both_estop_clean_disjoint_16shot_oracle,results/amazon_polarity_title_only_neither_amp...
amazon_polarity_title_only_neither_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_oracle,0.945853,0.858,0.868,meta-llama/Meta-Llama-3-8B,16,10240,0,10240,30000.0,0.0,30000.0,0,amazon_polarity_title_only_neither_amplified,seq_sft_both_estop_clean_disjoint_16shot_oracle,results/amazon_polarity_title_only_neither_amp...
amazon_polarity_title_only_neither_amplified,meta-llama/Meta-Llama-3-8B,seq_sft_both_estop_clean_disjoint_16shot_oracle,0.949586,0.86,0.87,meta-llama/Meta-Llama-3-8B,16,2560,0,2560,30000.0,0.0,30000.0,0,amazon_polarity_title_only_neither_amplified,seq_sft_both_estop_clean_disjoint_16shot_oracle,results/amazon_polarity_title_only_neither_amp...


In [68]:
# assert salient_results["ids"] == results["ids"]
lo = np.array(results["calibrated_logodds"])
preds = lo > 0
labels = np.array(results["gt_soft_labels"]) > 0.5
df = pd.DataFrame({
    "ids": salient_results["ids"],
    "lo": lo,
    "preds": preds,
    "labels": labels,
})
both_df = pd.merge(salient_df, df, on="ids")
both_df

Unnamed: 0,ids,salient_lo,salient_preds,weak_labels,lo,preds,labels
0,31103764,2.929688,True,True,5.390625,True,True
1,92dbe902,1.976562,True,True,-0.062988,False,True
2,94f2476d,-1.257812,False,False,-2.093750,False,False
3,b8888357,-1.562500,False,False,-0.314453,False,False
4,648731c9,2.859375,True,True,1.453125,True,True
...,...,...,...,...,...,...,...
995,b562521f,-1.921875,False,False,-0.619141,False,False
996,daccc210,2.898438,True,True,4.718750,True,True
997,22e77115,-1.761719,False,False,-1.023438,False,False
998,a509abdd,-2.011719,False,True,-0.982422,False,False


In [69]:
from sklearn.metrics import roc_auc_score
disagree_mask = weak_labels != labels

auc_on_disagree = roc_auc_score(labels[disagree_mask], lo[disagree_mask])
salient_auc_on_disagree = roc_auc_score(labels[disagree_mask], salient_lo[disagree_mask])
acc_on_disagree = (labels[disagree_mask] == preds[disagree_mask]).mean()
salient_acc_on_disagree = (labels[disagree_mask] == salient_preds[disagree_mask]).mean()
auc = roc_auc_score(labels, lo)
salient_auc = roc_auc_score(labels, salient_lo)
acc = (labels == preds).mean()
salient_acc = (labels == salient_preds).mean()
print(f"{auc=:.2f} {salient_auc=:.2f} {len(labels)=}")
print(f"{acc=:.2f} {salient_acc=:.2f}")
print(f"{auc_on_disagree=:.2f} {salient_auc_on_disagree=:.2f} {sum(disagree_mask)=}")
print(f"{acc_on_disagree=:.2f} {salient_acc_on_disagree=:.2f}")


auc=0.98 salient_auc=0.96 len(labels)=1000
acc=0.93 salient_acc=0.89
auc_on_disagree=0.84 salient_auc_on_disagree=0.44 sum(disagree_mask)=157
acc_on_disagree=0.76 salient_acc_on_disagree=0.48
