# Imports

In [67]:
%reload_ext autoreload
%autoreload 2


# import pickle
from plots import (
    MAX_TABLE_SIZE,
    make_table_avg,
    make_perf_table,
)
from utils import load_pickle
from generating_data.utils_for_notebooks import merge_methods
import pandas as pd
from IPython.display import display
import sys
import matplotlib.pyplot as plt
import os
import numpy as np
from scipy import stats


# Output directory for the generated paper figures. Only this top-level folder
# is created here; any subfolder used when saving must be created separately.
FIGURES_FOLDER = './tmp/NeurIPS_2025_Accelerated-Model-Evaluation-by-Using-Similarities-in-Prediction-Space/figures'
os.makedirs(FIGURES_FOLDER, exist_ok=True)

# Functions

In [68]:
def prepare_mae(mae_value):
    """Convert a fractional MAE into a percentage rounded to 2 decimals.

    Args:
        mae_value: MAE as a fraction (e.g. ``0.0123``), or a missing value
            (``None``, ``float('nan')``, ``numpy.nan`` or ``pd.NA``).

    Returns:
        ``round(mae_value * 100, 2)``, or ``float('nan')`` for a missing value.
    """
    # NaN never compares equal to itself, so ``x == float('nan')`` is always
    # False; ``pd.isna`` recognises None, float/NumPy NaN and pd.NA alike.
    if mae_value is None or pd.isna(mae_value):
        return float('nan')
    return round(mae_value * 100, 2)

def prepare_rank(rank_value):
    """Round a rank-correlation value to 3 decimals for display.

    Args:
        rank_value: rank correlation, or a missing value (``None``,
            ``float('nan')``, ``numpy.nan`` or ``pd.NA``).

    Returns:
        ``round(rank_value, 3)``, or ``float('nan')`` for a missing value.
    """
    # ``x == float('nan')`` is always False (NaN != NaN); use pd.isna instead.
    if rank_value is None or pd.isna(rank_value):
        return float('nan')
    return round(rank_value, 3)

def make_table_1(data_for_table_1):
    """Assemble the main results table (Table 1) and its LaTeX rendering.

    Args:
        data_for_table_1: exactly four ``(table, num_anchors)`` pairs, in the
            order MMLU-MAE, MMLU-rank, hellaswag-MAE, hellaswag-rank. Each
            table is indexed by condensation/PDS type (``"random"``,
            ``"anchor"``, ``"anchor-irt"``, ``"highest"``) and has one column
            per prediction method (``"naive"``, ``"gpirt"``, ``"KNN"``,
            ``"fit"``). A hellaswag table may be ``None``; it is then replaced
            by an all-NaN copy of the corresponding MMLU table. All four
            ``num_anchors`` values must be equal.

    Returns:
        ``(df, latex_str)``: the table as a DataFrame (two header rows
        followed by one row per approach) and the LaTeX source produced by
        ``make_table_1_latex``.
    """
    assert len(data_for_table_1) == 4 # mae and rank for mmlu and hellaswag

    ((mmlu_maes, n_mmlu_mae),
     (mmlu_ranks, n_mmlu_rank),
     (hs_maes, n_hs_mae),
     (hs_ranks, n_hs_rank)) = data_for_table_1
    assert n_mmlu_mae == n_mmlu_rank == n_hs_mae == n_hs_rank
    num_anchors = n_mmlu_mae

    def _all_nan_like(template):
        # Same shape/labels as ``template`` but every cell set to NaN.
        blank = template.copy()
        blank.loc[:, :] = float('nan')
        return blank

    if hs_maes is None:
        hs_maes = _all_nan_like(mmlu_maes)
    if hs_ranks is None:
        hs_ranks = _all_nan_like(mmlu_ranks)

    header_rows = [
        ["Approach", "Condensation", "Condensation", "Prediction",
         "MMLU", "MMLU", "hellaswag", "hellaswag"],
        ["", "type", "num_anchors", "type", "mae", "rank", "mae", "rank"],
    ]

    # (approach, condensation type, prediction type, table row, table column)
    row_specs = [
        ("Baseline", "Random", "Eval", "random", "naive"),
        ("tinyBenchmarks", "Random", "gp-IRT", "random", "gpirt"),
        ("tinyBenchmarks", "anchor-IRT", "gp-IRT", "anchor-irt", "gpirt"),
        ("tinyBenchmarks", "anchor-correctness", "gp-IRT", "anchor", "gpirt"),
        ("Baseline", "Random", "kNN", "random", "KNN"),
        ("Baseline", "Random", "fit", "random", "fit"),
        ("DISCO (ours)", "High PDS", "kNN", "highest", "KNN"),
        ("DISCO (ours)", "High PDS", "fit", "highest", "fit"),
    ]

    body_rows = [
        [
            approach,
            condensation,
            num_anchors,
            prediction,
            prepare_mae(mmlu_maes.loc[pds][method]),
            prepare_rank(mmlu_ranks.loc[pds][method]),
            prepare_mae(hs_maes.loc[pds][method]),
            prepare_rank(hs_ranks.loc[pds][method]),
        ]
        for approach, condensation, prediction, pds, method in row_specs
    ]

    df = pd.DataFrame(header_rows + body_rows)
    latex_str = make_table_1_latex(df)
    return df, latex_str


def make_table_1_latex(df):
    """Render the Table 1 DataFrame as a LaTeX ``table`` environment.

    ``df`` must have 8 columns laid out as built by ``make_table_1``:
    approach, condensation type, number of samples, prediction type, then
    MMLU MAE/rank and hellaswag MAE/rank. Rows whose approach is
    ``"Approach"`` or ``""`` (the two header rows) are skipped. Consecutive
    rows with the same approach form a group. The approach name is printed
    only on a group's first row, and a ``\\midrule`` separates groups.

    Side effects: renames ``df.columns`` in place and stores the LaTeX source
    in ``df.attrs['latex_table']``.

    Returns:
        The LaTeX source as a string.
    """
    # The duplicated labels ("Type", "MAE", "Rank") are deliberate: selecting
    # one of them from a row yields both values, read via ``.values[0/1]``.
    df.columns = ["Approach", "Type", "# Samples", "Type", "MAE", "Rank", "MAE", "Rank"]

    def fmt(value, digits):
        # Missing results are rendered as a dash.
        return "-" if pd.isna(value) else f"{float(value):.{digits}f}"

    parts = [
        "\\begin{table}[H]\n",
        "\\centering\n\\small\n",
        "\\begin{tabular}{c|cc|c|cc|cc}\n",
        "\\toprule\n",
        "\\multicolumn{1}{c}{\\textbf{Approach}}&\\multicolumn{2}{c}{\\textbf{Condensation}} & \\multicolumn{1}{c}{\\textbf{Prediction}} & \\multicolumn{2}{c}{\\textbf{MMLU}}& \\multicolumn{2}{c}{\\textbf{hellaswag}} \\\\\n",
        "&Type & \\# \\negthinspace Samples & Type & {MAE}  &Rank& {MAE}  &Rank \\\\\n",
        "\\toprule\n",
    ]

    previous_approach = ""
    for _, row in df.iterrows():
        approach = row["Approach"]
        if approach == "Approach" or approach == "":
            continue

        starts_group = approach != previous_approach
        # Separate every group from the one before it; the first group sits
        # directly under \toprule and needs no extra rule.
        if starts_group and previous_approach != "":
            parts.append("\\midrule\n")
        approach_str = approach if starts_group else ""
        previous_approach = approach

        types = row["Type"].values
        maes = row["MAE"].values
        ranks = row["Rank"].values
        mae_mmlu = fmt(maes[0], 2)
        rank_mmlu = fmt(ranks[0], 3)
        mae_hellaswag = fmt(maes[1], 2)
        rank_hellaswag = fmt(ranks[1], 3)

        # Highlight DISCO's regressor-based row. Test the raw approach (not
        # approach_str, which is blank on a group's repeated rows) and the
        # "fit" prediction type that make_table_1 assigns to regressor rows.
        if approach == "DISCO (ours)" and types[1] == "fit":
            mae_mmlu = f"\\textbf{{{mae_mmlu}}}"
            rank_mmlu = f"\\textbf{{{rank_mmlu}}}"

        parts.append(
            f"{approach_str}&{types[0]} & {row['# Samples']} & {types[1]} & "
            f"{mae_mmlu} &{rank_mmlu} & {mae_hellaswag} &{rank_hellaswag} \\\\\n"
        )

    parts.append("\\bottomrule\n")
    parts.append("\\end{tabular}\n")
    parts.append("\\vspace{1em}\n")
    parts.append("\\caption{Mean Absolute Error (MAE) for different sampling and prediction strategies. For question answering task on MMLU dataset [FIX]. \\joon{Add computational complexity info\nAdd hellaswag results, add ranking metric, add method from hellaswag.\n}}\n")
    parts.append("\\label{tab:language-main}\n")
    parts.append("\\end{table}")

    latex_str = "".join(parts)
    # Keep the LaTeX source alongside the table for later retrieval.
    df.attrs['latex_table'] = latex_str
    return latex_str


def extract_data_for_table_1(source_df, num_anchors, lower_better, key="PDS type"):
    """Collapse one results table to the best score per ``key`` group.

    Args:
        source_df: mapping from number of anchor points to a results
            DataFrame (one row per configuration, one column per method).
        num_anchors: which entry of ``source_df`` to process.
        lower_better: if True keep the minimum (e.g. MAE), otherwise the
            maximum (e.g. rank correlation), both within groups and across
            regressor columns.
        key: column to group by. Rows where it is NaN are appended ungrouped.
            For ``"PDS type"`` the group labels become the index; for any
            other key they remain a regular column.

    Returns:
        ``(grouped_df, num_anchors)``. Only the Random Forest regressor is
        kept; its score is exposed as a ``"fit"`` column. The other regressor
        columns are dropped, as are the bookkeeping columns ``stratified``,
        ``#guiding_models``, ``cirt`` and ``pirt`` when present (except when
        one of them is ``key``).
    """
    reduce_name = "min" if lower_better else "max"
    table = source_df[num_anchors]

    # Unlabelled rows pass through untouched; labelled rows are collapsed
    # to the best value per group.
    labelled = table[key].notna()
    grouped = table[labelled].groupby(key, as_index=(key == "PDS type"))
    grouped_df = pd.concat([getattr(grouped, reduce_name)(), table[~labelled]])

    regressors = ['MLP3_e700_lr0.001', 'Ridge_10', 'Lasso_e-4', 'RandomForestRegressor_100', 'GradientBoostingRegressor_200']
    # Discard every regressor except Random Forest.
    discarded = ['MLP3_e700_lr0.001', 'Ridge_10', 'Lasso_e-4', "GradientBoostingRegressor_200"]
    grouped_df = grouped_df.drop(columns=discarded)
    kept = [name for name in regressors if name not in discarded]

    # The best remaining regressor score per row becomes the "fit" column.
    grouped_df['fit'] = getattr(grouped_df[kept], reduce_name)(axis=1)
    grouped_df = grouped_df.drop(columns=kept)

    # Per-configuration columns are meaningless after grouping; keep the key.
    leftovers = [
        column
        for column in ('stratified', '#guiding_models', 'cirt', 'pirt')
        if column != key and column in grouped_df.columns
    ]
    grouped_df = grouped_df.drop(columns=leftovers)

    return grouped_df, num_anchors


def make_df_with_results(table_avg, table_std, bench, split):
    """Build the per-anchor-count result tables for one benchmark/split.

    Args:
        table_avg: nested mapping indexed ``[bench][split]`` whose keys are
            the method names passed to ``make_perf_table``.
        table_std: nested mapping indexed ``[bench][split]``, passed to
            ``make_perf_table`` alongside ``table_avg``.
        bench: benchmark key.
        split: split key.

    Returns:
        The mapping returned by ``make_perf_table`` (keyed by number of
        samples). Each DataFrame in it is modified as follows:
        - the ``#guiding_models``, ``PDS type`` and ``stratified`` columns are
          moved to the front;
        - ``'all'`` in ``#guiding_models`` is replaced by 382;
        - rows are sorted by (``PDS type``, ``stratified``,
          ``#guiding_models``).

    Side effect: sets pandas' global display limits to ``MAX_TABLE_SIZE``.
    """
    avg_for_split = table_avg[bench][split]
    tables = make_perf_table(
        avg_for_split,
        table_std[bench][split],
        methods=avg_for_split.keys(),
    )

    for option in ('display.max_rows', 'display.max_columns', "display.max_colwidth"):
        pd.set_option(option, MAX_TABLE_SIZE)

    leading = ['#guiding_models', 'PDS type', 'stratified']
    for num_samples in tables.keys():
        table = tables[num_samples]
        trailing = [col for col in table.columns.tolist() if col not in leading]
        table = table[leading + trailing]

        # 'all' source models is stored numerically (382) so the column sorts.
        table.loc[table['#guiding_models'] == 'all', '#guiding_models'] = 382

        tables[num_samples] = table.sort_values(['PDS type', 'stratified', '#guiding_models'])

    return tables

# Read data

In [None]:
def process_table_data(
    bench,
    split,
    filename_suffix,
    data,
    scenarios_to_skip,
    ordered,
    agg_type,
    table_avg_base,
    table_std_base,
    model_perf_base
):
    """Aggregate one loaded results pickle and merge it into running tables.

    The arguments up to ``agg_type`` are forwarded to ``make_table_avg``
    (with ``return_perf_table=True``). Each of the three ``*_base`` values
    (which may be ``None`` on the first call) is combined with the matching
    freshly computed table via ``merge_methods``.

    Returns:
        ``(table_avg, table_std, model_perf)`` after merging.
    """
    avg_new, std_new, perf_new = make_table_avg(
        bench,
        split,
        filename_suffix,
        data,
        scenarios_to_skip=scenarios_to_skip,
        ordered=ordered,
        return_perf_table=True,
        agg_type=agg_type,
    )
    merged = [
        merge_methods(base, new)
        for base, new in (
            (table_avg_base, avg_new),
            (table_std_base, std_new),
            (model_perf_base, perf_new),
        )
    ]
    return tuple(merged)


# Load the results needed by the tables and figures below.
#
# results_suffixes maps bench -> split -> method ("ours" / "irt") -> filename
# suffix; each suffix selects the pickle
#   results/accs_{bench}_split-{split}_iterations-5{suffix}.pickle
# For the sensitivity/ablation entries ("num_models", "prediction_strategy")
# the inner key is a factor value instead of a split, and the pickle is always
# read from mmlu_fields with the noniid split (see the loop below).
results_suffixes = {
    "mmlu_fields": {

        "iid": {
            "ours": "_disagreement_best_47",
            "irt": "_disagreement_compare_with_irt43"
        },
        # "noniid": {
        #     "ours": "_disagreement_best_48",
        #     "irt": "_disagreement_compare_with_irt44"
        # } # Old
        "noniid": {
            "ours": "_disagreement_best_68",
            "irt": "_disagreement_compare_with_irt70"
        } # more n anchors (till 2000)
    },
    "hellaswag": {
        # NOTE(review): "ours" and "irt" name the same pickle for both splits,
        # so the loop loads it twice and merges it with itself through
        # merge_methods — confirm this is intended.

        "iid": {
            "ours": "_disagreement_best_63",
            "irt": "_disagreement_best_63"
        },
        "noniid": {
            "ours": "_disagreement_best_64",
            "irt": "_disagreement_best_64"
        }
    },
    # Keys presumably are the number of source models (382 matches the value
    # make_df_with_results substitutes for 'all') — confirm.
    "num_models": {
        3: "_disagreement_best_57",
        10: "_disagreement_best_56",
        30: "_disagreement_best_55",
        100: "_disagreement_best_54",
        300: "_disagreement_best_53",
        382: "_disagreement_best_68"
    },
    # "umap_pca": {
    #     "no_pca": "_disagreement_before_pca_15",
    #     "pca64": "_disagreement_umap_pca21",
    #     "pca128": "_disagreement_umap_pca23",
    #     "pca256": "_disagreement_umap_pca25",
    #     "umap64": "_disagreement_umap_pca27",
    #     "umap128": "_disagreement_umap_pca29",
    #     "umap256": "_disagreement_umap_pca31"
    # },
    "prediction_strategy": {
        # "linear": "_disagreement_best_68",
        # "mlps": "_disagreement_best_68",
        "other": "_disagreement_best_68",
    }
}
# NOTE(review): leftover debug reminder, printed on every run.
print("DEBUG: uncomment all except for num_models")
scenarios_to_skip = []
# Globals consumed by later cells:
#   table_1_data / table_1_data_iid: (table, num_anchors) pairs at 100 anchors
#       for the noniid / iid split, appended in the order mmlu_fields mae,
#       mmlu_fields rank, hellaswag mae, hellaswag rank (make_table_1 relies
#       on this order).
#   figure_n_anchors_data: num_anchors -> list of such pairs (noniid split).
table_1_data = []
table_1_data_iid = []
figure_n_anchors_data = {}

# Aggregated tables, nested as bench -> agg_type -> split (or factor).
table_avg_dict = {}
table_std_dict = {}
model_perf_dict = {}

for bench, per_bench in results_suffixes.items():
    if bench not in table_avg_dict:
        table_avg_dict[bench] = {}
        table_std_dict[bench] = {}
        model_perf_dict[bench] = {}
    # ordered = bench in ["mmlu_fields", "hellaswag", "num_models", "umap_pca"]
    ordered = True
    for agg_type in ["mae", "rank"]:

        if agg_type not in table_avg_dict[bench]:
            table_avg_dict[bench][agg_type] = {}
            table_std_dict[bench][agg_type] = {}
            model_perf_dict[bench][agg_type] = {}

        # Sensitivity/ablation benches: rank metric only, always computed
        # from the mmlu_fields noniid results.
        if bench in ["num_models", "umap_pca", "prediction_strategy"]:
            if agg_type == "mae":
                continue

            split = "noniid"
            real_bench = "mmlu_fields"

            factor_list = []
            for factor, filename_suffix in per_bench.items():

                table_avg_base = None
                table_std_base = None
                model_perf_base = None

                # NOTE(review): these bench-level entries are never filled;
                # results are stored under [bench][agg_type][factor] below.
                if factor not in table_avg_dict[bench]:
                    table_avg_dict[bench][factor] = {}
                    table_std_dict[bench][factor] = {}
                    model_perf_dict[bench][factor] = {}

                results_path = f'results/accs_{real_bench}_split-{split}_iterations-5{filename_suffix}.pickle'

                data = load_pickle(results_path)

                table_avg_base, table_std_base, model_perf_base = process_table_data(
                    real_bench,
                    split,
                    filename_suffix,
                    data,
                    scenarios_to_skip,
                    ordered,
                    agg_type,
                    table_avg_base,
                    table_std_base,
                    model_perf_base
                )

                table_avg_dict[bench][agg_type][factor] = table_avg_base
                table_std_dict[bench][agg_type][factor] = table_std_base
                model_perf_dict[bench][agg_type][factor] = model_perf_base

                df = make_df_with_results(table_avg_base, table_std_base, real_bench, split)

                if bench == "prediction_strategy":
                    # One row labelled "highest": column-wise max over the
                    # highest-PDS configurations at 100 anchors (max is the
                    # right direction because only the rank metric gets here).

                    filtered_df = df[100]
                    # NOTE(review): debug output.
                    display(filtered_df)
                    print("DEBUG:")
                    filtered_df = filtered_df[filtered_df['PDS type'] == 'highest']
                    filtered_df = pd.DataFrame([filtered_df.max()], index=['highest'])
                    filtered_df.drop(columns=['PDS type', 'stratified', '#guiding_models'], inplace=True)
                    factor_list.append(filtered_df)
                else:
                    # Keep only the aggregated "fit" score, one column per factor.
                    grouped_df, _ = extract_data_for_table_1(df, num_anchors=100, lower_better=(agg_type == "mae"))
                    grouped_df = grouped_df[["fit"]].rename(columns={"fit": f"fit-{factor}"})
                    factor_list.append(grouped_df)

            if len(factor_list) > 0:
                # Only the rank metric reaches this point (mae is skipped
                # above), so factor_df is always assigned here.
                if agg_type == "rank":
                    factor_df = pd.concat(factor_list, axis=1)

                if bench == "num_models":
                    num_models_df = factor_df

                if bench == "prediction_strategy":
                    prediction_strategy_df = factor_df
        else:
            # ordered = bench in ["mmlu_fields", "hellaswag"]
            for split, per_split in per_bench.items():
                if split not in table_avg_dict[bench][agg_type]:
                    table_avg_dict[bench][agg_type][split] = {}
                    table_std_dict[bench][agg_type][split] = {}
                    model_perf_dict[bench][agg_type][split] = {}
                # for agg_type in ["mae", "rank"]:
                table_avg_base = None
                table_std_base = None
                model_perf_base = None
                # Merge the "ours" and "irt" result files into one set of tables.
                for method in [
                    "ours",
                    "irt"
                ]:
                    # our_results_path = f'results/accs_{bench}_split-{split}_iterations-5{per_split["ours"]}.pickle'

                    # data_ours = load_pickle(our_results_path)
                    # irt_results_path = f'results/accs_{bench}_split-{split}_iterations-5{per_split["irt"]}.pickle'

                    # data_irt = load_pickle(irt_results_path)
                    filename_suffix = per_split[method]
                    results_path = f'results/accs_{bench}_split-{split}_iterations-5{filename_suffix}.pickle'
                    data = load_pickle(results_path)

                    # current_table_avg, current_table_std, current_model_perf = make_table_avg(
                    #     bench,
                    #     split,
                    #     filename_suffix,
                    #     data,
                    #     scenarios_to_skip=scenarios_to_skip,
                    #     ordered=ordered,
                    #     return_perf_table=True,
                    #     agg_type=agg_type
                    # )
                    # table_avg_base = merge_methods(table_avg_base, current_table_avg)
                    # table_std_base = merge_methods(table_std_base, current_table_std)
                    # model_perf_base = merge_methods(model_perf_base, current_model_perf)
                    table_avg_base, table_std_base, model_perf_base = process_table_data(
                        bench,
                        split,
                        filename_suffix,
                        data,
                        scenarios_to_skip,
                        ordered,
                        agg_type,
                        table_avg_base,
                        table_std_base,
                        model_perf_base
                    )

                # if agg_type not in table_avg_dict[bench]:
                #     table_avg_dict[bench][agg_type] = {}
                #     table_std_dict[bench][agg_type] = {}
                #     model_perf_dict[bench][agg_type] = {}
                # if split not in table_avg_dict[bench][agg_type]:
                #     table_avg_dict[bench][agg_type][split] = {}
                #     table_std_dict[bench][agg_type][split] = {}
                #     model_perf_dict[bench][agg_type][split] = {}
                table_avg_dict[bench][agg_type][split] = table_avg_base
                table_std_dict[bench][agg_type][split] = table_std_base
                model_perf_dict[bench][agg_type][split] = model_perf_base
                # print(model_perf_dict)

                # sys.exit(0)
                # noniid results feed Table 1, the n_anchors figure and the
                # stratification ablation.
                if split == "noniid":
                # if split == "iid":
                    df = make_df_with_results(table_avg_base, table_std_base, bench, split)
                    table_1_data.append(extract_data_for_table_1(df, num_anchors=100, lower_better=(agg_type == "mae")))
                    # NOTE(review): entries for both benchmarks are appended;
                    # make_figure_n_anchors reads [0]/[1] as MMLU, which holds
                    # only if mmlu_fields has every num_anchors key hellaswag has.
                    for num_anchors in df.keys():
                        if num_anchors not in figure_n_anchors_data:
                            figure_n_anchors_data[num_anchors] = []
                        figure_n_anchors_data[num_anchors].append(extract_data_for_table_1(df, num_anchors=num_anchors, lower_better=(agg_type == "mae")))
                    if agg_type == "rank":
                        # NOTE(review): reassigned for every bench handled in
                        # this branch, so the value left for the ablation cell
                        # comes from the last one (hellaswag), not
                        # mmlu_fields — confirm which benchmark is intended.
                        ablation_strat, _ = extract_data_for_table_1(df, num_anchors=100, lower_better=(agg_type == "mae"), key="stratified")
                # iid results feed the iid variant of Table 1.
                if split == "iid":
                    df_iid = make_df_with_results(table_avg_base, table_std_base, bench, split)
                    table_1_data_iid.append(extract_data_for_table_1(df_iid, num_anchors=100, lower_better=(agg_type == "mae")))
                    # print("DEBUG", df[100])

# generate table_avg, perf_avg and etc
# extract max across sampling methods
# for table in table_1_data:
#     print("DEBUG")
#     display(table)

# Table 1

In [None]:
# Main results table (Table 1) from the noniid results. To render it with the
# hellaswag columns left blank, pass table_1_data[:2] + [(None, 100), (None, 100)]
# instead: make_table_1 requires exactly four entries and fills None with NaN.
table_1, latex_str = make_table_1(table_1_data)
display(table_1)
print(latex_str)

In [None]:
# Same table built from the iid-split results.
table_1_iid, latex_str_iid = make_table_1(table_1_data_iid)
display(table_1_iid)
print(latex_str_iid)

# Sensitivity: number of source models

In [None]:
# DataFrame.applymap is deprecated since pandas 2.1 in favour of the
# equivalent DataFrame.map; use map when available, applymap otherwise.
if hasattr(pd.DataFrame, "map"):
    display(num_models_df.map(prepare_rank))
else:
    display(num_models_df.applymap(prepare_rank))

# Ablation: prediction strategy

In [None]:
# Single row labelled "highest": column-wise best rank scores of the
# highest-PDS configurations at 100 anchors, with the columns of each
# prediction_strategy factor placed side by side.
display(prediction_strategy_df)

# Ablation: Stratification

In [None]:
# Rows with a concrete stratification setting, showing the aggregated "fit"
# score next to that setting.
display(
    ablation_strat.loc[ablation_strat['stratified'].notna(), ['fit', 'stratified']]
)

# Figure: correlation with ground-truth performance

In [23]:
def plot_correlation_with_gt_performance(model_perf_dict):
    """Scatter estimated vs. ground-truth model performance on MMLU (noniid).

    The plot shows one point per model:
    - x is the ground-truth performance;
    - y is the estimate obtained with ``number_item`` anchor points, taken
      from column ``iteration``.

    Each method's legend entry reports Spearman's rank correlation. The
    figure is saved to ``FIGURES_FOLDER/performance_correlation_mmlu.pdf``
    and shown.

    Args:
        model_perf_dict: nested results indexed
            ``model_perf_dict[bench]['mae'][split]``, or ``None`` to plot the
            hard-coded values below instead.
    """
    split = 'noniid'
    iteration = 1  # which column of the per-iteration estimates to plot
    number_item = 100  # number of anchor points
    color_mappings = {}

    alphas = {'random_naive':.4,'anchor_naive':.4,'anchor-irt_naive':.4,'anchor-irt_gpirt':.8}
    # Lookup table of marker sizes. It uses a distinct name so the per-method
    # value chosen in the loop does not overwrite it.
    marker_sizes = {'random_naive':7,'anchor_naive':5,'anchor-irt_naive':5,'anchor-irt_gpirt':5}
    names = {
        'random_naive':'random',
        'anchor-irt_naive':'IRT ',
        'anchor_naive':'correctness',
        'anchor-irt_gpirt':'IRT++',
        'high-disagreement@100+nonstratified_GradientBoostingRegressor_200': 'High PDS/Linear',
        'high-disagreement@100+nonstratified_RandomForestRegressor_100': 'High PDS/Random Forest'
    }
    plt.figure(figsize=(1.2*3.5,1.2*3))

    for bench in ['mmlu_fields']:
        for method in [
            # Alternatives: 'random_KNN',
            # 'high-disagreement@100+nonstratified_GradientBoostingRegressor_200'
            'high-disagreement@100+nonstratified_RandomForestRegressor_100'
        ]:
            if model_perf_dict is None:
                print("Using hardcoded data")
                x = [
                    0.74505765, 0.57627382, 0.63675495, 0.65292241, 0.38084539, 0.64716621,
                    0.60905773, 0.64266456, 0.77400765, 0.65963511, 0.45349487, 0.53712371,
                    0.54949706, 0.65399154, 0.66097766, 0.70370242, 0.60520878, 0.56062885,
                    0.62193137, 0.63557209, 0.39394311, 0.457429, 0.33628442, 0.60314194,
                    0.63755535, 0.62489059, 0.62572778, 0.76651098, 0.61419194, 0.64899013,
                    0.65274401, 0.50838601, 0.77287346, 0.70602432, 0.54623595, 0.74655344,
                    0.71878393, 0.63617809, 0.76029052, 0.75110232
                ]
                y = [
                    0.75731513, 0.60630879, 0.63863409, 0.66208239, 0.38429752, 0.64862158,
                    0.6141397, 0.63641627, 0.72950503, 0.65494508, 0.39609827, 0.54465079,
                    0.55145108, 0.64995122, 0.6629164, 0.7023732, 0.62522846, 0.54741339,
                    0.63129268, 0.6264134, 0.41973213, 0.51171946, 0.35283553, 0.60904086,
                    0.63357007, 0.62456882, 0.62493765, 0.76117544, 0.62270945, 0.63739011,
                    0.64783929, 0.49335064, 0.75879076, 0.72244666, 0.56023766, 0.74885069,
                    0.7073013, 0.6313948, 0.70250782, 0.75130385
                ]
            else:
                model_perf = model_perf_dict[bench]['mae'][split]
                x, y = model_perf[bench][split]['truth'], model_perf[bench][split][method][number_item][:, iteration]
                print(x)
                print(y)
            method_name = names.get(method, method)
            label = "{:} ($r_S$={:.2f})".format(method_name, stats.spearmanr(x, y).statistic)
            markersize = marker_sizes.get(method, 5)
            alpha = alphas.get(method, 0.5)
            color = color_mappings.get(method, 'black')
            plt.plot(x, y, 'o', label=label, markersize=markersize, alpha=alpha, color=color)

        plt.legend(fontsize=10, framealpha=.9)
        # Identity line: a perfect estimator would put every point on it.
        plt.plot([0,1],[0,1],'--r',lw=.5)
        plt.grid(alpha=.2)
        plt.xlabel('Ground truth performance', size=12)
        plt.ylabel('Estimated performance', size=12)
        plt.xlim(0,1)
        plt.ylim(0,1)
        tick_label_size = 11
        plt.tick_params(axis='x', labelsize=tick_label_size)
        plt.tick_params(axis='y', labelsize=tick_label_size)

    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_FOLDER, 'performance_correlation_mmlu.pdf'), bbox_inches='tight', dpi=400, transparent=True)
    plt.show()

In [None]:
# Plot from the results loaded in the "Read data" cell.
plot_correlation_with_gt_performance(model_perf_dict)

In [None]:
# Plot from the values hard-coded inside the function (no loaded results needed).
plot_correlation_with_gt_performance(None)

# Figure: n_anchors

In [75]:
def _plot_metric_vs_n_anchors(num_anchors_axis, series, ylabel, title):
    """Draw one line per method on a new figure with a log-scaled x-axis."""
    fig = plt.figure(figsize=(8, 5))
    for method, values in series.items():
        plt.plot(num_anchors_axis, values, marker='o', label=method)
    plt.xlabel('Number of Anchors')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.title(title)
    # Set the log scale before the ticks: changing an axis scale installs that
    # scale's default locator/formatter, which would discard custom ticks.
    plt.xscale('log')
    plt.xticks(num_anchors_axis, [str(n) for n in num_anchors_axis])
    return fig


def make_figure_n_anchors(figure_n_anchors_data):
    """Plot MAE and rank correlation against the number of anchor points.

    Args:
        figure_n_anchors_data: mapping ``num_anchors -> list`` whose first two
            entries are the ``(table, num_anchors)`` pairs for MMLU MAE and
            MMLU rank, or ``None`` to plot the hard-coded values below.

    Returns:
        The rank-correlation figure (the MAE figure is created first). Both
        figures are saved as PDFs under ``FIGURES_FOLDER/num_samples``.
    """
    # Legend label -> (PDS-type row, prediction column) in the extracted tables.
    series_keys = {
        "DISCO + Random Forest (ours)": ("highest", "fit"),
        "DISCO + kNN (ours)": ("highest", "KNN"),
        "Random + Eval": ("random", "naive"),
        "Random + Random Forest": ("random", "fit"),
        "Anchor-correctness + gp-IRT": ("anchor", "gpirt"),
    }
    maes = {method: [] for method in series_keys}
    ranks = {method: [] for method in series_keys}
    num_anchors_axis = []
    if figure_n_anchors_data is None:
        print("Using hardcoded data")
        ranks = {
            'DISCO + Random Forest (ours)': [0.96, 0.977, 0.981, 0.987, 0.977, 0.985, 0.984, 0.979],
            'DISCO + kNN (ours)': [0.973, 0.975, 0.973, 0.972, 0.974, 0.957, 0.952, 0.95],
            'Random + Eval': [0.574, 0.721, 0.865, 0.916, 0.944, 0.972, 0.988, 0.992],
            'Random + Random Forest': [0.824, 0.866, 0.901, 0.933, 0.928, 0.932, 0.952, 0.959],
            'Anchor-correctness + gp-IRT': [0.845, 0.882, 0.925, 0.927, 0.956, 0.974, 0.991, 0.996]}
        maes = {
            'DISCO + Random Forest (ours)': [2.05, 1.2, 1.16, 1.07, 1.08, 1.02, 0.93, 1.15],
            'DISCO + kNN (ours)': [1.17, 1.29, 1.2, 1.31, 1.59, 1.53, 1.57, 1.56],
            'Random + Eval': [14.28, 9.05, 4.62, 3.45, 2.45, 1.4, 0.92, 0.73],
            'Random + Random Forest': [3.23, 2.57, 2.05, 1.81, 1.66, 1.59, 1.51, 1.55],
            'Anchor-correctness + gp-IRT': [3.4, 2.62, 2.24, 2.08, 1.67, 1.46, 1.14, 0.82]}
        num_anchors_axis = [10, 30, 60, 100, 200, 500, 1000, 2000]
    else:
        # Sort by anchor count so every line is drawn left to right.
        for num_anchors, data_for_table_1 in sorted(
            figure_n_anchors_data.items(), key=lambda item: item[0]
        ):
            assert len(data_for_table_1) >= 2  # MMLU mae and rank first; more entries may follow

            mmlu_maes, num_anchors_mmlu_maes = data_for_table_1[0]
            mmlu_ranks, num_anchors_mmlu_ranks = data_for_table_1[1]
            assert num_anchors_mmlu_maes == num_anchors_mmlu_ranks == num_anchors

            for method, (pds_type, prediction) in series_keys.items():
                maes[method].append(prepare_mae(mmlu_maes.loc[pds_type][prediction]))
                ranks[method].append(prepare_rank(mmlu_ranks.loc[pds_type][prediction]))
            num_anchors_axis.append(num_anchors)

    print(ranks)
    print(maes)

    fig_mae = _plot_metric_vs_n_anchors(
        num_anchors_axis, maes, 'Mean Absolute Error (MAE)', 'MAE vs Number of Anchors by Method'
    )
    plt.yscale('log')
    plt.yticks(
        [1.0, 1.1, 1.3, 1.5, 2.0, 3.0, 10.0],
        ['1.0', '1.1', '1.3', '1.5', '2.0', '3.0', '10.0']
    )

    fig_rank = _plot_metric_vs_n_anchors(
        num_anchors_axis, ranks, 'Rank Correlation', 'Rank Correlation vs Number of Anchors by Method'
    )
    plt.yticks(
        [0.60, 0.70, 0.80, 0.90, 0.92, 0.95, 0.98, 1.00],
        ['0.60', '0.70', '0.80', '0.90', '0.92', '0.95', '0.98', '1.00']
    )

    # savefig does not create directories, so make sure the subfolder exists.
    out_dir = os.path.join(FIGURES_FOLDER, 'num_samples')
    os.makedirs(out_dir, exist_ok=True)
    # Save through the figure handles rather than global figure numbers, which
    # would pick the wrong figures when other figures are already open.
    fig_mae.savefig(os.path.join(out_dir, 'mae_vs_n_anchors.pdf'), bbox_inches='tight')
    fig_rank.savefig(os.path.join(out_dir, 'rank_vs_n_anchors.pdf'), bbox_inches='tight')

    return fig_rank


In [None]:
# Build the n_anchors figures from the results loaded in the "Read data" cell.
fig_n_anchors = make_figure_n_anchors(figure_n_anchors_data)

In [None]:
# Same figures from the hard-coded values (rebinds fig_n_anchors).
fig_n_anchors = make_figure_n_anchors(None)