In [51]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines


# Path to the process-of-elimination results (relative to the working directory).
poe_path = "../process_of_elimination.csv"

poe_raw_df = pd.read_csv(poe_path)

# One raw frame per masking-strategy family.
lowest_raw_df = poe_raw_df[poe_raw_df["mask_strategy"] == "lowest"]
below_average_raw_df = poe_raw_df[poe_raw_df["mask_strategy"] == "below_average"]
# All strategies whose name starts with "min_k" (e.g. min_k_2, min_k_3).
# na=False treats a missing mask_strategy as "no match" instead of producing a
# NaN in the boolean mask, which would make the indexing raise.
min_k_raw_df = poe_raw_df[poe_raw_df["mask_strategy"].str.startswith("min_k", na=False)]

In [45]:
def process(df, drop_columns=None):
    """Clean a raw results frame for the mask-strategy comparison.

    Drops exact duplicate rows and removes the run-metadata columns
    ``model_family``, ``seed``, ``batch_size``, ``loading_precision`` and
    ``sample``. Any extra columns listed in ``drop_columns`` are removed too.
    ``checkpoint`` values are shortened to their last ``/``-separated
    component.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw results. Must contain every column being dropped and a string
        ``checkpoint`` column.
    drop_columns : list, optional
        Additional column names to drop. Only a ``list`` is honoured; any other
        value is ignored. The caller's list is never modified.

    Returns
    -------
    pandas.DataFrame
        A new cleaned frame; the input frame is left unchanged.
    """
    base_columns = ["model_family", "seed", "batch_size", "loading_precision", "sample"]
    # Build a fresh list: `drop_columns += ...` would extend the caller's list
    # in place and leak the base columns back into the caller's variable.
    extra_columns = list(drop_columns) if isinstance(drop_columns, list) else []
    columns_to_drop = extra_columns + base_columns

    return (
        df.drop_duplicates()
        .drop(columns=columns_to_drop)
        # shorten checkpoint names, e.g. "org/model-7b" -> "model-7b"
        .assign(checkpoint=lambda d: d["checkpoint"].str.split("/").str[-1])
    )

def process_v2(df, datasets=None):
    """Average accuracy per (dataset, checkpoint, mask_strategy) for selected datasets.

    Rows are restricted to ``datasets``. They are then grouped by
    ``dataset``, ``checkpoint`` and ``mask_strategy``, and every numeric column
    is averaged. The ``checkpoint`` key is dropped from the result, and
    ``accuracy`` is rounded to 3 decimal places.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``dataset``, ``checkpoint``, ``mask_strategy`` and a
        numeric ``accuracy`` column.
    datasets : iterable of str, optional
        Datasets to keep. Defaults to the eight datasets used in this analysis.

    Returns
    -------
    pandas.DataFrame
        One row per group, with columns ``dataset``, ``mask_strategy``, then
        the averaged numeric columns.
    """
    if datasets is None:
        datasets = (
            "anli cqa siqa logical_deduction_five_objects disambiguation_qa "
            "conceptual_combinations strange_stories symbol_interpretation"
        ).split()

    subset = df[df["dataset"].isin(datasets)]
    return (
        subset.groupby(["dataset", "checkpoint", "mask_strategy"])
        # numeric_only=True: pandas >= 2.0 raises on non-numeric leftover
        # columns instead of silently skipping them as older versions did.
        .mean(numeric_only=True)
        .reset_index()
        .drop(columns=["checkpoint"])
        # accuracy: 3 decimal places
        .assign(accuracy=lambda d: d["accuracy"].round(3))
    )

In [68]:
# Columns dropped in addition to the fixed run-metadata set removed inside `process`.
EXTRA_DROP_COLUMNS = ["n_shot", "prompting_method", "scoring_method", "mask_accuracy", "method"]

# Raw frames keyed by strategy family; insertion order fixes the concat order below.
raw_by_strategy = {
    "lowest": lowest_raw_df,
    "min_k": min_k_raw_df,
    "below_average": below_average_raw_df,
}

# Run the same two-stage cleaning on every family. `list(...)` hands each call
# its own copy of the column list, so nothing is shared between calls even if
# `process` extends its argument in place.
processed_by_strategy = {
    family: process_v2(process(raw, drop_columns=list(EXTRA_DROP_COLUMNS)))
    for family, raw in raw_by_strategy.items()
}
lowest_df = processed_by_strategy["lowest"]
min_k_df = processed_by_strategy["min_k"]
below_average_df = processed_by_strategy["below_average"]

# Stack all families into one table, sort by dataset then strategy, and
# replace the index with a fresh 0..n-1 range.
df = (
    pd.concat(list(processed_by_strategy.values()))
    .sort_values(by=["dataset", "mask_strategy"])
    .reset_index(drop=True)
)

# For each dataset, the row with the highest accuracy (idxmax picks the first on ties).
max_accuracy_indices = df.groupby("dataset")["accuracy"].idxmax()
rows_with_highest_accuracy = df.loc[max_accuracy_indices]

# Save both tables to CSV (paths relative to the current working directory).
df.to_csv("min_k_mask.csv", index=False)
rows_with_highest_accuracy.to_csv("min_k_mask_max_accuracy.csv", index=False)

In [69]:
# Display the combined dataset / mask_strategy / accuracy table built in the previous cell.
df

Unnamed: 0,dataset,mask_strategy,accuracy
0,anli,below_average,0.55
1,anli,lowest,0.556
2,anli,min_k_2,0.578
3,conceptual_combinations,below_average,0.722
4,conceptual_combinations,lowest,0.76
5,conceptual_combinations,min_k_2,0.742
6,conceptual_combinations,min_k_3,0.604
7,cqa,below_average,0.892
8,cqa,lowest,0.895
9,cqa,min_k_2,0.884


In [70]:
# Display the highest-accuracy row for each dataset (one row per dataset).
rows_with_highest_accuracy

Unnamed: 0,dataset,mask_strategy,accuracy
2,anli,min_k_2,0.578
4,conceptual_combinations,lowest,0.76
8,cqa,lowest,0.895
13,disambiguation_qa,lowest,0.678
16,logical_deduction_five_objects,lowest,0.56
20,siqa,below_average,0.82
23,strange_stories,below_average,0.766
30,symbol_interpretation,min_k_3,0.274


In [72]:
# Redo the per-dataset comparison with the below_average strategy removed.
# NOTE: this rebinds `df` to the filtered table, so the table displayed earlier
# no longer reflects what `df` holds.
df = (
    df.loc[df["mask_strategy"].ne("below_average")]
    .reset_index(drop=True)
)

# Index label of each dataset's best-accuracy row, then the rows themselves.
max_accuracy_indices = df.groupby("dataset").accuracy.idxmax()
rows_with_highest_accuracy = df.loc[max_accuracy_indices]
rows_with_highest_accuracy

Unnamed: 0,dataset,mask_strategy,accuracy
1,anli,min_k_2,0.578
2,conceptual_combinations,lowest,0.76
5,cqa,lowest,0.895
9,disambiguation_qa,lowest,0.678
11,logical_deduction_five_objects,lowest,0.56
15,siqa,lowest,0.817
18,strange_stories,min_k_2,0.766
22,symbol_interpretation,min_k_3,0.274
