In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import math
import random
from tqdm import tqdm

pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)
# None = never truncate cell text. The old -1 sentinel is deprecated since
# pandas 1.0 and rejected by pandas 2.x.
pd.set_option("display.max_colwidth", None)
# Shared styling for the sns.boxplot "mean line" overlays used by the plots:
# a thick black line at the mean, drawn on top, and invisible whiskers.
meanprops = dict(linestyle="-", linewidth=6, color="k", alpha=1, zorder=99)
whiskerprops = dict(linestyle="-", linewidth=0)

In [None]:
# One results CSV per experiment; the "...vit" entries hold the ViT runs.
exp_names = [
    "svhn",
    "cifar10",
    "cifar100",
    "super_cifar100",
    "camelyon",
    "animals",
    "breeds",
    "svhnvit",
    "cifar10vit",
    "cifar100vit",
    "super_cifar100vit",
    "camelyonvit",
    "animalsvit",
    "breedsvit",
]

df_list = []
for exp in exp_names:
    in_path = os.path.join(
        "/home/tillb/Projects/failure-detection-benchmark/results", "{}.csv".format(exp)
    )
    df = pd.read_csv(in_path)
    df = df.dropna(subset=["name", "model"])
    df = df.drop_duplicates(subset=["name", "study", "model", "network", "confid"])
    # Drop the tinyimagenet "original"/"proposed" study variants.
    df = df[
        ~df.study.str.contains("tinyimagenet_original")
        & ~df.study.str.contains("tinyimagenet_proposed")
    ]
    if exp in ("cifar10", "cifar100", "super_cifar100"):
        # Only the VGG-13 runs are used for these CNN experiments.
        df = df[df.name.str.contains("vgg13")]
    if exp == "super_cifar100":
        # Super-class CIFAR-100 contributes only its iid study, relabelled as
        # the CIFAR-100 sub-class ("in_class") study.
        df = df[df.study == "iid_study"].copy()
        df["study"] = "cifar100_in_class_study_superclasses"
    elif exp == "super_cifar100vit":
        # TODO: DG/Devries/Confidnet for cifar100 not run yet
        df = df[
            (df.study == "iid_study") & ~df.old_name.str.contains("modeldevries")
        ].copy()
        df["study"] = "cifar100vit_in_class_study_superclasses"
    else:
        # Prefix studies with the experiment so all CSVs can share one frame.
        df = df.copy()
        df["study"] = exp + "_" + df["study"]
    # Number of distinct runs (run names) kept for this experiment.
    print(exp, df["name"].nunique())

    df_list.append(df)

# NOTE: no ignore_index, so index labels repeat across experiments.
df = pd.concat(df_list)
# The super_cifar100 results now live under the cifar100 study names above.
exp_names = [e for e in exp_names if not e.startswith("super_cifar100")]

In [None]:
# Spot check: every run of the dg_bbvgg13_do1 configuration for the
# dg_mcd_mcp confid on the CIFAR-100 iid study.
df.loc[
    df["name"].str.contains("dg_bbvgg13_do1")
    & df["study"].eq("cifar100_iid_study")
    & df["confid"].eq("dg_mcd_mcp")
]

In [None]:
# Non-null entries per column for each study (sanity check of the loaded data).
df.groupby(by="study").agg("count")

In [None]:
def _token_after(name, marker):
    """Return the text following the first `marker` in `name`, up to the next "_".

    Raises IndexError when `marker` does not occur in `name`.
    """
    return name.split(marker)[1].split("_")[0]


# Decode the hyper-parameters encoded in each run name. Map over the name
# column instead of a row-wise DataFrame.apply, which misbehaves on empty frames.
run_names = df["name"]
# Parsed only to check that every name carries a "bb" backbone marker; dropped below.
df["backbone"] = run_names.map(lambda n: _token_after(n, "bb"))
df["dropout"] = run_names.map(lambda n: _token_after(n, "do"))
# Model = first "_"-token of the name; ViT runs that name a "model<X>" become "vit_<X>".
df["model"] = [
    "vit_" + _token_after(n, "model")
    if (n.split("_")[0] == "vit" and "model" in n)
    else n.split("_")[0]
    for n in run_names
]
df["run"] = run_names.map(lambda n: _token_after(n, "run"))
df["rew"] = run_names.map(lambda n: _token_after(n, "_rew"))
# confid becomes "<model>_<confid>_<dropout>_<rew>"; model selection later
# strips the trailing dropout/rew tokens again.
df["confid"] = (
    df["model"] + "_" + df["confid"] + "_" + df["dropout"] + "_" + df["rew"]
)
df = df.drop(columns=["model", "dropout", "backbone"])
print(len(df))

In [None]:

# MODEL SELECTION


def select_func(row, selection_df, selection_column):
    """Return 1 if `row` uses the hyper-parameter value selected in `selection_df`, else 0.

    Args:
        row: a result row, indexable by "study", "confid" and `selection_column`.
        selection_df: one row per (study, confid) holding the selected value
            in `selection_column` (the output of the idxmin selection).
        selection_column: "rew" or "dropout".

    The row's experiment is the part of its study name before the first "_".
    Its confid is compared after stripping the trailing hyper-parameter
    token(s): one token ("_<rew>") for reward selection, two
    ("_<dropout>_<rew>") otherwise.

    Raises:
        IndexError: no selection row exists for this experiment/confid; the
            experiment prefix and confid are printed first.
    """
    name_splitter = -1 if selection_column == "rew" else -2
    row_exp = row["study"].split("_")[0] + "_"
    row_confid = "_".join(row["confid"].split("_")[:name_splitter])
    # Anchor on the experiment prefix: an unanchored (regex) substring match
    # can also hit a study of another experiment whose name contains it.
    matches = selection_df[
        selection_df.study.str.startswith(row_exp)
        & (selection_df.confid == row_confid)
    ]
    try:
        selected_value = matches[selection_column].iloc[0]
    except IndexError:
        print(row_exp, row_confid, len(matches))
        raise
    return 1 if row[selection_column] == selected_value else 0


ms_metric = "aurc"  # Careful, when changing consider changing idxmin -> idxmax

# REWARD: average the metric over runs per (study, confid, rew). On the
# val_tuning studies keep the reward with the lowest mean, then flag the
# matching rows of df. Only the metric column is averaged: "run" holds strings,
# which groupby-mean drops with a warning (pandas 1.x) or rejects (pandas 2.x).
non_agg_columns = ["study", "confid", "rew"]
ms_filter_metrics_df = df[["study", "confid", "run", "rew", ms_metric]]
df_ms = ms_filter_metrics_df.groupby(by=non_agg_columns)[ms_metric].mean().reset_index()
print(len(df_ms), len(ms_filter_metrics_df))
df_ms = df_ms[df_ms.study.str.contains("val_tuning")].copy()
# Strip the trailing "_<rew>" token so rewards of one confid compete.
df_ms["confid"] = df_ms["confid"].map(lambda c: "_".join(c.split("_")[:-1]))
df_ms = df_ms.loc[df_ms.groupby(["study", "confid"])[ms_metric].idxmin()]
print(len(df), len(df_ms))
df["select_rew"] = df.apply(lambda row: select_func(row, df_ms, "rew"), axis=1)
selected_df = df[df.select_rew == 1].copy()

# DROPOUT: same procedure on the reward-selected rows.
non_agg_columns = ["study", "confid", "dropout"]
selected_df["dropout"] = selected_df["name"].map(lambda n: n.split("do")[1].split("_")[0])
do_filter_metrics_df = selected_df[["study", "confid", "run", "dropout", ms_metric]]
df_do = do_filter_metrics_df.groupby(by=non_agg_columns)[ms_metric].mean().reset_index()
print(len(df_do), len(do_filter_metrics_df))
df_do = df_do[df_do.study.str.contains("val_tuning")].copy()
# Strip the trailing "_<dropout>_<rew>" tokens.
df_do["confid"] = df_do["confid"].map(lambda c: "_".join(c.split("_")[:-2]))
df_do = df_do.loc[df_do.groupby(["study", "confid"])[ms_metric].idxmin()]
print(len(df), len(selected_df), len(df_do))
selected_df["select_do"] = selected_df.apply(
    lambda row: select_func(row, df_do, "dropout"), axis=1
)
all_selected_df = selected_df[selected_df.select_do == 1]

In [None]:
pd.set_option("display.max_rows", 200)
# Row counts before selection, after reward selection, after dropout selection.
print(len(df), len(selected_df), len(all_selected_df))
# Selected reward/dropout per confid on the val_tuning studies (run "1" only).
all_selected_df[
    all_selected_df.study.str.contains("val_tuning") & (all_selected_df.run == "1")
][["study", "confid", "rew", "dropout", "aurc"]]

In [None]:
def rename_confids(in_confid):
    """Turn an internal confid identifier into its display label.

    Substitutions run in a fixed order: a lowercase pass, upper-casing, then a
    label pass. Order matters; for example, "_" only becomes "-" after the
    uppercase names are fixed, and before "VIT-Res" is matched.
    """
    lowercase_rules = (
        ("confidnet_", ""),
        ("_dg", "_res"),
        ("_det", ""),
        ("det_", ""),
        ("tcp", "confidnet"),
    )
    label_rules = (
        ("DEVRIES_DEVRIES", "DEVRIES"),
        ("VIT_VIT", "VIT"),
        ("DEVRIES", "Devries et al."),
        ("CONFIDNET", "ConfidNet"),
        ("RES", "Res"),
        ("_", "-"),
        ("MCP", "MSR"),
        ("VIT-Res", "VIT-DG-Res"),
    )
    label = in_confid
    for old, new in lowercase_rules:
        label = label.replace(old, new)
    label = label.upper()
    for old, new in label_rules:
        label = label.replace(old, new)
    return label


# FINAL CLEANING AND ASSIGNING OF DF
clean_df = all_selected_df.drop("dropout", axis=1)
clean_df = clean_df[~clean_df.confid.str.contains("waic")].copy()
# Strip the trailing "_<dropout>_<rew>" tokens appended while parsing run names.
clean_df["confid"] = clean_df["confid"].map(lambda c: "_".join(c.split("_")[:-2]))
# Confid variants that are not reported.
excluded_confids = clean_df.confid.str.contains("devries_mcd|devries_det|_sv|_mi")
clean_df = clean_df[~excluded_confids].copy()
# Prefix ViT runs with "vit_". This is built positionally rather than with a
# masked chained assignment, which is silently lost under Copy-on-Write and
# would have to align on the duplicate index labels left by the concat.
clean_df["confid"] = [
    "vit_" + confid if network == "vit" else confid
    for confid, network in zip(clean_df["confid"], clean_df["network"])
]
print(clean_df.confid.unique())
clean_df["confid"] = clean_df["confid"].map(rename_confids)
clean_df["study"] = clean_df.study.str.replace(
    "tinyimagenet_384", "tinyimagenet_resize", regex=False
)
# Remove the "vit" and "_384" markers so ViT and CNN runs share study names.
clean_df["study"] = clean_df.study.str.replace("vit", "", regex=False).str.replace(
    "_384", "", regex=False
)
df = clean_df
print(df.confid.unique())
print(df.study.unique())

In [None]:
# Aggregate over runs. Number TABLES. TODO GET RID OF REWARD FOR PROPER RANKING ACROSS STUDIES!

metric = "aurc"
non_agg_columns = ["study", "confid"]  # might need rew if no model selection
filter_metrics_df = df[non_agg_columns + ["run", metric]]
# Aggregate only the metric: "run" holds strings, which groupby mean/std drop
# with a warning (pandas 1.x) or reject with a TypeError (pandas 2.x).
grouped_metric = filter_metrics_df.groupby(by=non_agg_columns)[metric]
df_mean = grouped_metric.mean().reset_index().round(2)
df_std = grouped_metric.std().reset_index().round(2)

studies = df_mean.study.unique().tolist()
dff = pd.DataFrame({"confid": df.confid.unique()})
print(dff)
print("CHECK LEN DFF", len(dff), len(df_mean))
combine_and_str = False
if combine_and_str:
    # Render every metric cell as "mean ± std" strings.
    agg_mean_std = (
        lambda s1, s2: s1
        if (s1.name == "confid" or s1.name == "study" or s1.name == "rew")
        else s1.astype(str) + " ± " + s2.astype(str)
    )
    df_mean = df_mean.combine(df_std, agg_mean_std)

# One column per study with each confid's value (NaN where it was not run).
for s in studies:
    sdf = df_mean[df_mean.study == s]
    dff[s] = dff["confid"].map(sdf.set_index("confid")[metric])

In [None]:
# Triple results: per study and confid the cell reads "aurc / accuracy*100 / failauc*100".

non_agg_columns = ["study", "confid"]  # might need rew if no model selection
filter_metrics_df = df[non_agg_columns + ["run", metric]]
# Mean over runs of the three metrics. Only metric columns are aggregated,
# because the string "run" column is rejected by groupby-mean in pandas 2.x.
run_means = (
    df.groupby(by=non_agg_columns)[["aurc", "failauc", "accuracy"]]
    .mean()
    .reset_index()
)
df_aurc = run_means[non_agg_columns + ["aurc"]].round(2)
# The percentage values are stored under "aurc" so that DataFrame.combine
# below pairs them with the aurc column.
df_acc = run_means[non_agg_columns + ["accuracy"]].copy()
df_acc["aurc"] = df_acc["accuracy"] * 100
df_acc = df_acc.round(2)
df_auc = run_means[non_agg_columns + ["failauc"]].copy()
df_auc["aurc"] = df_auc["failauc"] * 100
df_auc = df_auc.round(2)
studies = df_aurc.study.unique().tolist()
tripple_dff = df_aurc[df_aurc.study == "cifar100_iid_study"][["confid"]].copy()
print("CHECK LEN DFF", len(tripple_dff), len(df_aurc))


agg_mean_std = (
    lambda s1, s2: s1
    if (s1.name == "confid" or s1.name == "study" or s1.name == "rew")
    else s1.astype(str) + " / " + s2.astype(str)
)
df_aurc = df_aurc.combine(df_acc, agg_mean_std)
df_aurc = df_aurc.combine(df_auc, agg_mean_std)
for s in studies:
    sdf = df_aurc[df_aurc.study == s]
    tripple_dff[s] = tripple_dff["confid"].map(sdf.set_index("confid")["aurc"])

In [None]:
# PLOT METRICS SELECTION: export the animals study columns of dff as a LaTeX table.
plot_dff = dff[["confid"] + [c for c in dff.columns if c.startswith("animals_")]]
columns = (
    ["confid"]
    + [c for c in plot_dff.columns if "iid" in c]
    + [c for c in plot_dff.columns if "ood" in c]
    + [c for c in plot_dff.columns if "proposed" in c]
)
# A study matching two of the patterns would otherwise appear twice in the table.
columns = list(dict.fromkeys(columns))
print(columns, plot_dff.columns)
# Same results root as the CSVs loaded at the top of the notebook.
plot_dff[columns].set_index("confid").to_latex(
    "/home/tillb/Projects/failure-detection-benchmark/results/animals"
)

In [None]:

# RANKING DF

# Rank each study column of dff in DESCENDING order. For aurc (lower is
# better) this deliberately gives the best confid the HIGHEST rank number, so
# it lands at the top of the y axis in the ranking plots. Do not reuse
# rank_df where a conventional ascending aurc rank is expected.
select_df = dff
rank_df = select_df.rank(
    na_option="keep", numeric_only=True, ascending=False
).assign(confid=dff.confid)
rank_df

In [None]:
# RANKING PLOTS: one panel per experiment, one line per confid across its studies.

scale = 10
sns.set_style("whitegrid")
# Study names lost their "vit" marker during cleaning, so ViT and CNN results
# share one experiment prefix. Plot each prefix once; "svhnvit_..." columns do
# not exist and would raise a KeyError.
plt_exps = list(dict.fromkeys(e.replace("vit", "") for e in exp_names))
sns.set_context("paper", font_scale=scale * 0.20)
f, axs = plt.subplots(
    nrows=len(plt_exps),
    ncols=1,
    figsize=(3 * scale, len(plt_exps) * scale * 2),
    squeeze=False,  # keep axs 2-D even for a single experiment
)
# todo ! supercifar has to be a part of cifar100 exp. check also weird observation regarding val_tuning
for ax_ix, exp in enumerate(plt_exps):
    ax = axs[ax_ix, 0]
    exp_cols = [c for c in rank_df.columns if c.startswith(exp + "_")]
    cols = (
        ["{}_val_tuning".format(exp), "{}_iid_study".format(exp)]
        + [c for c in exp_cols if "noise" in c]
        + [c for c in exp_cols if "in_class" in c]
        + [c for c in exp_cols if "proposed" in c]
    )
    # Skip studies this experiment does not have instead of raising KeyError.
    cols = [c for c in cols if c in rank_df.columns]
    if not cols:
        continue
    numeric_exp_df = rank_df[cols]
    # todo DROPNAN?
    x = range(len(numeric_exp_df.columns))
    for ix in range(len(numeric_exp_df)):
        ax.plot(x, numeric_exp_df.iloc[ix].values)
    ax.set_yticks(range(1, len(numeric_exp_df) + 1))
    # Tick k is labelled with the confid holding rank k in the first column.
    ax.set_yticklabels(rank_df.sort_values(by=numeric_exp_df.columns[0]).confid.tolist())
    ax.set_xticks(x)
    ax.set_xticklabels([c for c in numeric_exp_df.columns], rotation=90)
    ax.set_xlim(0, len(numeric_exp_df.columns) - 1)

plt.tight_layout()
plt.show()

In [None]:
# Base palette for the confids; it is cycled if there are more confids than colors.
colors = [
    "tab:blue",
    "green",
    "tab:purple",
    "orange",
    "red",
    "black",
    "pink",
    "olive",
    "grey",
    "brown",
    "tab:cyan",
    "blue",
    "limegreen",
    "darkmagenta",
    "salmon",
    "tab:blue",
    "green",
    "tab:purple",
    "orange",
]
base_confids = rank_df.confid.str.replace("VIT-", "", regex=False).unique().tolist()
print(len(base_confids))
print(len(colors))

# One color per confid with any "VIT-" removed; ViT variants reuse that color.
color_dict = {conf: colors[ix % len(colors)] for ix, conf in enumerate(base_confids)}
color_dict.update(
    {
        conf: color_dict[conf.replace("VIT-", "")]
        for conf in rank_df.confid[rank_df.confid.str.contains("VIT")].tolist()
    }
)
print(color_dict)

In [None]:
# SUM RANKING PLOTS
# Sum each confid's per-study ranks within a shift type and re-rank those sums
# (higher = better, following rank_df's convention). Draw one line per confid
# across the shift types, plus an aggregated rank.
import numpy as np

select_columns = [c for c in rank_df.columns]
iid_columns = [c for c in select_columns if "iid" in c]
print("IID", iid_columns)
in_class_columns = [c for c in select_columns if "in_class" in c]
print("SUB CLASS", in_class_columns)
new_class_columns = [
    c for c in select_columns if ("new_class" in c and "proposed" in c)
]
# Columns naming both cifar10_ and cifar100_ (presumably the CIFAR-10/100
# cross new-class studies) count as semantic shift.
sem_new_class_columns = [
    c for c in new_class_columns if ("cifar10_" in c and "cifar100_" in c)
]
print("SEMANTIC NEW CLASS", sem_new_class_columns)
nonsem_new_class_columns = [
    c for c in new_class_columns if c not in sem_new_class_columns
]
print("NON-SEMANTIC NEW CLASS", nonsem_new_class_columns)
noise_columns = [c for c in select_columns if "noise" in c]
print("NOISE", noise_columns)


def _rank_sum(columns):
    """Per-confid sum of the rank_df ranks in `columns` (NaN if any is missing)."""
    return rank_df[columns].sum(axis=1, numeric_only=True, skipna=False)


sum_rank_df = rank_df[["confid"]].copy()
sum_rank_df["iid"] = _rank_sum(iid_columns)
sum_rank_df["corruption-shift"] = _rank_sum(noise_columns)
if len(in_class_columns) > 0:
    sum_rank_df["sub-class-shift"] = _rank_sum(in_class_columns)
sum_rank_df["sem.-new-class-shift"] = _rank_sum(sem_new_class_columns)
sum_rank_df["non-sem.-new-class-shift"] = _rank_sum(nonsem_new_class_columns)
sum_rank_df = sum_rank_df.rank(na_option="keep", numeric_only=True, ascending=True)
sum_rank_df["confid"] = rank_df.confid
sum_rank_df["aggregated"] = sum_rank_df.sum(
    axis=1, numeric_only=True, skipna=False
).rank(na_option="keep", ascending=True)

scale = 10
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=scale * 0.50)
f, axs = plt.subplots(nrows=1, ncols=1, figsize=(3 * scale, 1.5 * scale * 1.2))
# todo ! supercifar has to be a part of cifar100 exp. check also weird observation regarding val_tuning

show_columns = [
    "iid",
    "corruption-shift",
    "sub-class-shift",
    "sem.-new-class-shift",
    "non-sem.-new-class-shift",
    "aggregated",
]
# "sub-class-shift" only exists when some study is an in_class study.
cols = [c for c in show_columns if c in sum_rank_df.columns]
numeric_exp_df = sum_rank_df[cols]
print(numeric_exp_df)
# todo DROPNAN?
confids_list = sum_rank_df.confid.tolist()
x = range(len(numeric_exp_df.columns))
seen = [[] for _ in x]  # ranks already labelled in each column
for ix in range(len(numeric_exp_df)):
    y = numeric_exp_df.iloc[ix].values
    # Both the line color and the point labels must name row ix's own confid,
    # not the ix-th entry of a list sorted by another column.
    conf_label = confids_list[ix]
    axs.plot(x, y, linewidth=3.1, color=color_dict[conf_label], alpha=0.3)
    for x_, y_ in zip(x, y):
        if np.isnan(y_):
            continue
        if y_ in seen[x_]:
            y_ -= 0.45  # nudge a tied label so it does not overprint the first
        else:
            seen[x_].append(y_)
        axs.text(x_, y_, conf_label, fontsize=16, horizontalalignment="center")

axs.set_yticks(range(1, len(numeric_exp_df) + 1))
axs.set_yticklabels([])
axs.set_xticks(x)
axs.set_xticklabels([c for c in numeric_exp_df.columns], fontsize=18, fontweight="bold")
axs.set_xlim(0, len(numeric_exp_df.columns) - 1)
axs.xaxis.tick_top()
# Arrow with captions right of the axes: higher rank numbers (better) sit on top.
axs.annotate(
    "",
    xy=(1.05, 0),
    xytext=(1.05, 1),
    arrowprops=dict(width=3, headwidth=8, headlength=8, color="grey"),
    xycoords="axes fraction",
)
axs.annotate(
    "best\nrank",
    xy=(1.054, 1),
    xytext=(1.054, 1),
    xycoords="axes fraction",
    fontsize=14,
    horizontalalignment="left",
    verticalalignment="top",
)
axs.annotate(
    "worst\nrank",
    xy=(1.054, 0),
    xytext=(1.054, 0),
    xycoords="axes fraction",
    fontsize=14,
    horizontalalignment="left",
    verticalalignment="bottom",
)
plt.tight_layout()
plt.savefig("/home/tillb/Projects/failure-detection-benchmark/results/ranking.png")
plt.show()

In [None]:
# Study columns of the per-confid aggregate table.
dff.keys()

In [None]:
# STRIP PLOTS of the aggregated dff values: one panel per study of the experiment.


scale = 10
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=scale * 0.35)
for exp in ["cifar100"]:
    exp_cols = [c for c in rank_df.columns if c.startswith(exp + "_")]
    cols = (
        ["{}_iid_study".format(exp)]
        + [c for c in exp_cols if "noise" in c]
        + [c for c in exp_cols if "in_class" in c]
        + [c for c in exp_cols if "proposed" in c]
    )
    # squeeze=False keeps axs 2-D even when there is a single study column.
    f, axs = plt.subplots(
        nrows=1, ncols=len(cols), figsize=(scale * len(cols), scale * 2), squeeze=False
    )

    for ix, c in enumerate(cols):
        ax = axs[0, ix]
        sns.stripplot(ax=ax, x="confid", y=c, data=dff)
        # Fixed 4-unit padding around this study's data range.
        ax.set_ylim(dff[c].min() - 4, dff[c].max() + 4)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
        # Shorten the study column name into a readable panel title.
        title = c.replace(exp + "_", "")
        title = title.replace("_proposed_mode", "")
        title = title.replace("_", "-")
        title = title.replace("-study-", "-shift-")
        title = title.replace("in-class", "sub-class")
        title = title.replace("-resize", "")
        ax.set_title(title)
        ax.set_ylabel("")
        ax.set_xlabel("")

    plt.tight_layout()
    plt.show()

In [None]:
# OVERVIEW PLOTS
# Per-run strip plots with mean lines. Default mode draws one figure per
# (experiment, metric); cross mode draws one figure per metric comparing the
# iid studies of all experiments.

metrics = ["aurc", "accuracy", "failauc"]
plot_exps = ["animals"]  # exp_names
cross_mode = False
scale = 8
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=scale * 0.35)
dims = ["confid"]
# Same results root as the CSVs loaded at the top of the notebook.
overview_out_dir = "/home/tillb/Projects/failure-detection-benchmark/results"


def _strip_with_means(ax, dim, metric, data, marker_size):
    """Draw per-run `metric` values per `dim` category on `ax` plus a mean overlay.

    The boxplot overlay hides box, caps and fliers (whiskers are hidden via
    ``whiskerprops``) and draws each category mean as a line styled by ``meanprops``.
    """
    sns.stripplot(ax=ax, x=dim, y=metric, data=data, s=marker_size)
    sns.boxplot(
        ax=ax,
        x=dim,
        y=metric,
        data=data,
        saturation=0,
        showbox=False,
        showcaps=False,
        showfliers=False,
        whiskerprops=whiskerprops,
        showmeans=True,
        meanprops=meanprops,
        meanline=True,
    )
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)


for metric in metrics:
    if not cross_mode:
        for exp in plot_exps:
            plot_data = df[df.study.str.startswith(exp + "_")][
                ["study", "confid", "run", metric]
            ]
            studies = plot_data.study.unique().tolist()
            print(studies, plot_data.columns)
            # squeeze=False: axs is always (len(dims), len(studies)).
            f, axs = plt.subplots(
                nrows=len(dims),
                ncols=len(studies),
                figsize=(len(studies) * scale * 1.2, len(dims) * scale * 1.2),
                squeeze=False,
            )
            for xix, dim in enumerate(dims):
                for yix, study in enumerate(studies):
                    _strip_with_means(
                        axs[xix, yix],
                        dim,
                        metric,
                        plot_data[plot_data.study == study],
                        scale * 0.8,
                    )
            plt.tight_layout()
            plt.savefig(
                os.path.join(overview_out_dir, "paper_{}_{}.png".format(exp, metric))
            )
    else:
        # Study names lost their "vit" marker during cleaning, so plot each
        # experiment prefix once.
        cross_exps = list(dict.fromkeys(e.replace("vit", "") for e in exp_names))
        plot_data = df[df.study.str.contains("iid_study")][
            ["study", "confid", "run", "rew", metric]
        ]
        print(plot_data.study.unique().tolist(), plot_data.columns)
        f, axs = plt.subplots(
            nrows=len(dims),
            ncols=len(cross_exps),
            figsize=(len(cross_exps) * scale, len(dims) * scale * 1.2),
            squeeze=False,
        )
        for xix, dim in enumerate(dims):
            for yix, exp in enumerate(cross_exps):
                ax = axs[xix, yix]
                _strip_with_means(
                    ax,
                    dim,
                    metric,
                    plot_data[plot_data.study == "{}_iid_study".format(exp)],
                    scale * 0.8,
                )
                ax.set_xlabel(exp)
        plt.tight_layout()
        plt.savefig(os.path.join(overview_out_dir, "paper_iid_{}.png".format(metric)))

In [None]:
# FINAL STRIP PLOTS
def final_strip_plots():
    """Save one strip-plot figure per experiment (rows: metrics, columns: studies).

    Reads the module-level ``df`` (one row per run) and ``color_dict``
    (confid -> color). Each panel shows every run's value of one metric per
    confid plus the per-confid mean as a line. The figure is written to
    ``final_paper_<exp>.png`` in the results directory and then closed.
    Experiments without any plottable study are skipped.
    """
    metrics = ["accuracy", "aurc", "failauc", "ece", "fail-NLL"]
    plot_exps = [
        "cifar10",
        "cifar100",
        "svhn",
        "breeds",
        "animals",
        "camelyon",
    ]  # exp_names
    scale = 15
    dim = "confid"
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=scale * 0.35)

    # Canonical study names in panel order (left to right).
    studies = [
        'iid-study',
        'sub-class-shift',
        'corruption-shift-1',
        'corruption-shift-2',
        'corruption-shift-3',
        'corruption-shift-4',
        'corruption-shift-5',
        'new-class-shift-cifar10',
        'new-class-shift-cifar10-original-mode',
        'new-class-shift-cifar100',
        'new-class-shift-cifar100-original-mode',
        'new-class-shift-svhn',
        'new-class-shift-svhn-original-mode',
        'new-class-shift-tinyimagenet',
        'new-class-shift-tinyimagenet-original-mode'
    ]

    for exp in plot_exps:
        print(f"Creating plots for {exp}...")
        exp_df = df[df.study.str.startswith(exp + "_")]

        def fix_studies(n):
            """Map a raw study name of `exp` to its canonical display name."""
            n = n.replace(exp + "_", "")
            n = n.replace("_proposed_mode", "")
            n = n.replace("_", "-")
            n = n.replace("-study-", "-shift-")
            n = n.replace("in-class", "sub-class")
            n = n.replace("noise", "corruption")
            n = n.replace("-resize", "")
            n = n.replace("-wilds-ood-test", "")
            n = n.replace("-ood-test", "")
            n = n.replace("-superclasses", "")
            return n

        # Panels = canonical studies that this experiment actually has
        # (val_tuning excluded), so no empty axes are allocated.
        present = {
            fix_studies(s) for s in exp_df.study.unique() if "val_tuning" not in s
        }
        panel_studies = [s for s in studies if s in present]
        if not panel_studies:
            print(f"No plottable studies for {exp}, skipping.")
            continue
        # Canonical study of every row, computed once. Kept as an ndarray so
        # the mask is applied positionally (the index may hold duplicate labels).
        canonical = exp_df.study.apply(fix_studies).to_numpy()

        nrows, ncols = len(metrics), len(panel_studies)
        f, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(ncols * scale * 1.2, nrows * scale * 1.2),
            squeeze=False
        )

        for mix, metric in enumerate(metrics):
            for yix, study in enumerate(panel_studies):
                ax = axs[mix, yix]
                data = exp_df.loc[
                    canonical == study, ["study", "confid", "run", metric]
                ].sort_values(by="confid")
                # data is sorted by confid; the palette follows that same order.
                plot_colors = [
                    color_dict[conf] for conf in data.confid.unique().tolist()
                ]
                sns.set_palette(sns.color_palette(plot_colors))
                sns.stripplot(
                    ax=ax,
                    x=dim,
                    y=metric,
                    data=data,
                    s=scale * 1.6,
                    label=dim,
                    alpha=0.5
                )
                sns.boxplot(
                    ax=ax,
                    x=dim,
                    y=metric,
                    data=data,
                    medianprops=dict(alpha=0),
                    saturation=0,
                    showbox=False,
                    showcaps=False,
                    showfliers=False,
                    whiskerprops=whiskerprops,
                    showmeans=True,
                    meanprops=meanprops,
                    meanline=True,
                )
                ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
                ax.set_title(fix_studies(study), pad=35)
                ax.set_xlabel("")
                # Only the leftmost panel of each row names the metric.
                ax.set_ylabel(metric if yix == 0 else "")
        plt.tight_layout()
        plt.savefig(
            "/home/tillb/Projects/failure-detection-benchmark/results/final_paper_{}.png".format(
                exp
            )
        )
        plt.close(f)


final_strip_plots()

In [None]:
# FINAL STRIP PLOTS (single-column variant: CNN and ViT runs share one column per confid)

def final_strip_plots_sc(
    metrics=None,
    plot_exps=None,
    results_dir="/home/tillb/Projects/failure-detection-benchmark/results",
):
    """Save one strip-plot grid per experiment with CNN and ViT runs overlaid.

    Each figure has one row per metric and one column per study. Columns
    follow the canonical order of ``studies`` below. Only studies that have
    data for the experiment get a column, so the figure contains no blank
    panels.

    Within a panel the x-axis is the confidence scoring function (``confid``):

    - CNN runs use seaborn's default markers and the global ``meanprops``
      mean line.
    - ViT runs (confid prefixed with ``"VIT-"``) use ``X`` markers and a
      dotted mean line. They are placed at the x position of the matching
      un-prefixed confid.

    Reads the notebook globals ``df``, ``color_dict``, ``meanprops`` and
    ``whiskerprops``. Changes seaborn's global style, context and palette.

    Args:
        metrics: Metric columns of ``df`` to plot, one figure row each.
            Defaults to accuracy, aurc, failauc, ece and fail-NLL.
        plot_exps: Experiment prefixes of ``df.study`` (``"<exp>_..."``) to
            plot, one figure each. Defaults to the six non-ViT experiments.
        results_dir: Output directory. Figures are written as
            ``final_paper_<exp>_single_column.png``.
    """
    if metrics is None:
        metrics = ["accuracy", "aurc", "failauc", "ece", "fail-NLL"]
    if plot_exps is None:
        plot_exps = ["cifar10", "cifar100", "svhn", "breeds", "animals", "camelyon"]
    scale = 15  # shared size multiplier for figure size, markers and fonts
    dim = "confid"  # categorical x-axis column
    vit_prefix = "VIT-"  # confid prefix marking ViT runs

    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=scale * 0.35)

    # Canonical column order of the normalized study names; other studies are not plotted.
    studies = [
        'iid-study',
        'sub-class-shift',
        'corruption-shift-1',
        'corruption-shift-2',
        'corruption-shift-3',
        'corruption-shift-4',
        'corruption-shift-5',
        'new-class-shift-cifar10',
        'new-class-shift-cifar10-original-mode',
        'new-class-shift-cifar100',
        'new-class-shift-cifar100-original-mode',
        'new-class-shift-svhn',
        'new-class-shift-svhn-original-mode',
        'new-class-shift-tinyimagenet',
        'new-class-shift-tinyimagenet-original-mode'
    ]

    # Ordered substring replacements turning raw study names into display names.
    study_renames = (
        ("_proposed_mode", ""),
        ("_", "-"),
        ("-study-", "-shift-"),
        ("in-class", "sub-class"),
        ("noise", "corruption"),
        ("-resize", ""),
        ("-wilds-ood-test", ""),
        ("-ood-test", ""),
        ("-superclasses", ""),
    )

    def _normalize_study(name, exp):
        """Map a raw study name to its display name, e.g. ``cifar10_noise_study_3`` -> ``corruption-shift-3``."""
        name = name.replace(exp + "_", "")
        for old, new in study_renames:
            name = name.replace(old, new)
        return name

    strip_kws = dict(s=scale * 1.6, label=dim, alpha=0.5)
    # Boxplot reduced to its mean line: box, caps, fliers and median hidden, whiskers zero-width.
    mean_line_kws = dict(
        medianprops=dict(alpha=0),
        saturation=0,
        showbox=False,
        showcaps=False,
        showfliers=False,
        whiskerprops=whiskerprops,
        showmeans=True,
        meanline=True,
    )
    # Dotted ViT mean line. (1, 1) is the same on/off pattern as the former
    # odd-length (1,) * 7, which matplotlib simply repeats.
    vit_meanprops = dict(linewidth=6, alpha=1, zorder=99, dashes=(1, 1))

    for exp in plot_exps:
        print(f"Creating plots for {exp}...")
        exp_df = df[df.study.str.startswith(exp + "_")]
        # Normalize study names once per experiment instead of once per panel.
        exp_df = exp_df.assign(
            display_study=exp_df.study.map(lambda s: _normalize_study(s, exp))
        )
        present = set(exp_df.display_study)
        shown = [s for s in studies if s in present]
        if not shown:
            # plt.subplots(ncols=0) would raise; there is nothing to draw.
            print(f"No plottable studies for {exp}, skipping.")
            continue

        nrows, ncols = len(metrics), len(shown)
        f, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(ncols * scale * 1.2, nrows * scale * 1.2),
            squeeze=False,
        )

        for row, metric in enumerate(metrics):
            for col, study in enumerate(shown):
                ax = axs[row, col]
                data = exp_df[exp_df.display_study == study].sort_values(by=dim)

                # One color per sorted confid of this panel; seaborn assigns the
                # current palette to the categories by position in `order`.
                sns.set_palette(
                    sns.color_palette([color_dict[c] for c in data[dim].unique()])
                )

                is_vit = data[dim].str.startswith(vit_prefix)
                # ViT confids share the x position of their un-prefixed name.
                labels = data[dim].str.replace("^" + vit_prefix, "", regex=True)
                order = labels.sort_values().unique()

                # (row mask, extra stripplot kwargs, mean-line props); empty groups are skipped.
                groups = [
                    (mask, marker_kws, group_meanprops)
                    for mask, marker_kws, group_meanprops in (
                        (~is_vit, {}, meanprops),
                        (is_vit, {"marker": "X"}, vit_meanprops),
                    )
                    if mask.any()
                ]
                for mask, marker_kws, _ in groups:
                    sns.stripplot(
                        ax=ax,
                        x=labels[mask],
                        y=metric,
                        data=data[mask],
                        order=order,
                        **strip_kws,
                        **marker_kws,
                    )
                for mask, _, group_meanprops in groups:
                    sns.boxplot(
                        ax=ax,
                        x=labels[mask],
                        y=metric,
                        data=data[mask],
                        meanprops=group_meanprops,
                        order=order,
                        **mean_line_kws,
                    )

                ax.tick_params(axis="x", labelrotation=90)
                ax.set_title(study, pad=35)
                ax.set_xlabel("")
                # Only the leftmost column carries the metric name.
                ax.set_ylabel(metric if col == 0 else "")

        f.tight_layout()
        f.savefig(
            os.path.join(results_dir, "final_paper_{}_single_column.png".format(exp))
        )
        plt.close(f)

final_strip_plots_sc()

In [None]:
# FINAL STRIP PLOTS (single-column box variant: mean + box per confid, CNN and ViT overlaid)

def final_strip_plots_box(
    metrics=None,
    plot_exps=None,
    results_dir="/home/tillb/Projects/failure-detection-benchmark/results",
):
    """Save one box-plot grid per experiment with CNN and ViT runs overlaid.

    Each figure has one row per metric and one column per study. Columns
    follow the canonical order of ``studies`` below. Only studies that have
    data for the experiment get a column, so the figure contains no blank
    panels.

    Within a panel the x-axis is the confidence scoring function (``confid``).
    For CNN runs and for ViT runs (confid prefixed with ``"VIT-"``, placed at
    the x position of the matching un-prefixed confid) a desaturated box
    with the global ``meanprops`` mean line is drawn. Caps, fliers and the
    median are hidden.

    NOTE(review): both groups use identical styling and the same x positions,
    so a ViT box lies exactly on top of its CNN counterpart. Confirm this is
    intended, or dodge them / give ViT distinct ``meanprops``.

    Reads the notebook globals ``df``, ``color_dict``, ``meanprops`` and
    ``whiskerprops``. Changes seaborn's global style, context and palette.

    Args:
        metrics: Metric columns of ``df`` to plot, one figure row each.
            Defaults to accuracy, aurc, failauc, ece and fail-NLL.
        plot_exps: Experiment prefixes of ``df.study`` (``"<exp>_..."``) to
            plot, one figure each. Defaults to the six non-ViT experiments.
        results_dir: Output directory. Figures are written as
            ``final_paper_<exp>_single_column_box.png``.
    """
    if metrics is None:
        metrics = ["accuracy", "aurc", "failauc", "ece", "fail-NLL"]
    if plot_exps is None:
        plot_exps = ["cifar10", "cifar100", "svhn", "breeds", "animals", "camelyon"]
    scale = 15  # shared size multiplier for figure size and fonts
    dim = "confid"  # categorical x-axis column
    vit_prefix = "VIT-"  # confid prefix marking ViT runs

    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=scale * 0.35)

    # Canonical column order of the normalized study names; other studies are not plotted.
    studies = [
        'iid-study',
        'sub-class-shift',
        'corruption-shift-1',
        'corruption-shift-2',
        'corruption-shift-3',
        'corruption-shift-4',
        'corruption-shift-5',
        'new-class-shift-cifar10',
        'new-class-shift-cifar10-original-mode',
        'new-class-shift-cifar100',
        'new-class-shift-cifar100-original-mode',
        'new-class-shift-svhn',
        'new-class-shift-svhn-original-mode',
        'new-class-shift-tinyimagenet',
        'new-class-shift-tinyimagenet-original-mode'
    ]

    # Ordered substring replacements turning raw study names into display names.
    study_renames = (
        ("_proposed_mode", ""),
        ("_", "-"),
        ("-study-", "-shift-"),
        ("in-class", "sub-class"),
        ("noise", "corruption"),
        ("-resize", ""),
        ("-wilds-ood-test", ""),
        ("-ood-test", ""),
        ("-superclasses", ""),
    )

    def _normalize_study(name, exp):
        """Map a raw study name to its display name, e.g. ``cifar10_noise_study_3`` -> ``corruption-shift-3``."""
        name = name.replace(exp + "_", "")
        for old, new in study_renames:
            name = name.replace(old, new)
        return name

    # Box + mean line; median invisible, whiskers zero-width, no caps or fliers.
    box_kws = dict(
        medianprops=dict(alpha=0),
        saturation=0,
        showbox=True,
        showcaps=False,
        showfliers=False,
        whiskerprops=whiskerprops,
        showmeans=True,
        meanprops=meanprops,
        meanline=True,
    )

    for exp in plot_exps:
        print(f"Creating plots for {exp}...")
        exp_df = df[df.study.str.startswith(exp + "_")]
        # Normalize study names once per experiment instead of once per panel.
        exp_df = exp_df.assign(
            display_study=exp_df.study.map(lambda s: _normalize_study(s, exp))
        )
        present = set(exp_df.display_study)
        shown = [s for s in studies if s in present]
        if not shown:
            # plt.subplots(ncols=0) would raise; there is nothing to draw.
            print(f"No plottable studies for {exp}, skipping.")
            continue

        nrows, ncols = len(metrics), len(shown)
        f, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(ncols * scale * 1.2, nrows * scale * 1.2),
            squeeze=False,
        )

        for row, metric in enumerate(metrics):
            for col, study in enumerate(shown):
                ax = axs[row, col]
                data = exp_df[exp_df.display_study == study].sort_values(by=dim)

                # One color per sorted confid of this panel; seaborn assigns the
                # current palette to the categories by position in `order`.
                sns.set_palette(
                    sns.color_palette([color_dict[c] for c in data[dim].unique()])
                )

                is_vit = data[dim].str.startswith(vit_prefix)
                # ViT confids share the x position of their un-prefixed name.
                labels = data[dim].str.replace("^" + vit_prefix, "", regex=True)
                order = labels.sort_values().unique()

                # CNN first, then ViT; a group without rows is skipped.
                for mask in (~is_vit, is_vit):
                    if not mask.any():
                        continue
                    sns.boxplot(
                        ax=ax,
                        x=labels[mask],
                        y=metric,
                        data=data[mask],
                        order=order,
                        **box_kws,
                    )

                ax.tick_params(axis="x", labelrotation=90)
                ax.set_title(study, pad=35)
                ax.set_xlabel("")
                # Only the leftmost column carries the metric name.
                ax.set_ylabel(metric if col == 0 else "")

        f.tight_layout()
        f.savefig(
            os.path.join(results_dir, "final_paper_{}_single_column_box.png".format(exp))
        )
        plt.close(f)

final_strip_plots_box()

In [None]:
# ************************ RISK PLOTS *******************************
# Settings consumed by the risk strip-plot loop that follows in this cell.
import random

dims = ["confid"]  # categorical x-axis column(s)
metrics = ["ece"]  # one figure per metric and experiment
plot_exps = ["cifar100"]  # experiment prefixes of df.study (subset of exp_names)
cross_mode = False  # the visible loop below only handles cross_mode == False
scale = 15  # shared multiplier for figure size, marker size and fonts

# Global seaborn look: white grid, "paper" context scaled with `scale`.
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=scale * 0.5)


# For each metric and experiment: draw one strip plot (x = confid, one point per
# run) plus a mean line per selected study, clip the y-axis at mean + 1 std and
# save the figure as RISK_final_paper_<exp>_<metric>.png.
for metric in metrics:
    # Only the non-cross mode is visible here; any other branch lies outside this view.
    if not cross_mode:
        for exp in plot_exps:
            plot_data = df[df.study.str.startswith(exp + "_")][
                ["study", "confid", "run", metric]
            ]
            studies = plot_data.study.unique().tolist()
            studies = [
                c for c in studies if not "val_tuning" in c
            ]  # & (data["ne"].str.contains("250")) & (data["ap"]==False)]
            plot_studies = studies  # [c for c in studies if not ("noise" in c or "noise_study_1" in c)]
            cols = [c for c in plot_studies if exp + "_" in c]
            # Column order: iid first, then corruption ("noise"), sub-class
            # ("in_class") and new-class ("proposed") studies.
            # NOTE(review): the iid name is hard-coded and not checked against
            # `cols` (an empty panel if absent); a study matching several
            # substrings would appear more than once.
            plot_studies = (
                ["{}_iid_study".format(exp)]
                + [c for c in cols if "noise" in c]
                + [c for c in cols if "in_class" in c]
                + [c for c in cols if "proposed" in c]
            )
            # plot_studies = [c for c in cols if "noise" in c]
            print(studies, plot_data.columns)
            ncols = len(plot_studies)
            print("CHECK COLS", ncols, plot_studies)
            # NOTE(review): figure width is hard-coded for 6 columns regardless
            # of `ncols`; and without squeeze=False, `axs` is a single Axes when
            # ncols == 1, so `axs[yix]` below would fail.
            f, axs = plt.subplots(
                nrows=len(dims),
                ncols=ncols,
                figsize=(6 * scale * 1.2, len(dims) * scale * 1.2),
            )
            for xix, dim in enumerate(dims):
                for yix, study in enumerate(plot_studies):
                    y = metric
                    data = plot_data[plot_data.study == study].sort_values(by="confid")
                    # if not "noise" in study or "noise_study_3" in study:
                    print(study)
                    sns.stripplot(
                        ax=axs[yix],
                        x=dim,
                        y=metric,
                        data=data,
                        s=scale * 1.6,
                        label=dim,
                    )
                    # Boxplot reduced to its mean line (box, caps, fliers, median hidden).
                    sns.boxplot(
                        ax=axs[yix],
                        x=dim,
                        y=metric,
                        data=data,
                        medianprops=dict(alpha=0),
                        saturation=0,
                        showbox=False,
                        showcaps=False,
                        showfliers=False,
                        whiskerprops=whiskerprops,
                        showmeans=True,
                        meanprops=meanprops,
                        meanline=True,
                    )
                    # axs[yix].set_xticklabels("")
                    axs[yix].set_xticklabels(axs[yix].get_xticklabels(), rotation=90)
                    # Raw study name -> display title (same rules as the other plot cells).
                    title = study
                    title = title.replace(exp + "_", "")
                    title = title.replace("_proposed_mode", "")
                    title = title.replace("_", "-")
                    title = title.replace("-study-", "-shift-")
                    title = title.replace("in-class", "sub-class")
                    title = title.replace("noise", "corruption")
                    title = title.replace("-resize", "")
                    title = title.replace("-wilds-ood-test", "")
                    title = title.replace("-ood-test", "")
                    title = title.replace("-superclasses", "")
                    axs[yix].set_title(title, pad=35)
                    axs[yix].set_ylabel("")
                    axs[yix].set_xlabel("")
                    # Cap the y-axis at mean + 1 std of this panel, so larger
                    # values are cut off the plot.
                    lim = data[metric].mean() + data[metric].std()
                    axs[yix].set_ylim(axs[yix].get_ylim()[0], lim)
                    # The leftmost panel is labelled with the experiment name (not the metric).
                    if yix == 0:
                        axs[yix].set_ylabel(exp)

                    # if yix == 5:
                    #     axs[yix].axis("off")
                    #     axs[yix-1].legend()

                    # if "iid" in study and metric == "aurc":
                    #     axs[xix, yix].set_ylim(4, 8)
                    # if "iid" in study and metric == "failauc":
                    #     axs[xix, yix].set_ylim(0.90, 0.96)
            plt.tight_layout()
            # NOTE(review): this writes under /home/t974t/..., while the other cells
            # read and write /home/tillb/...; confirm the intended directory.
            # No plt.close(f) is visible here; if none follows, figures accumulate.
            plt.savefig(
                "/home/t974t/Projects/failure-detection-benchmark/results/RISK_final_paper_{}_{}.png".format(
                    exp, metric
                )
            )