In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Move two directory levels up from where the notebook was started, so the
# relative paths used further down are resolved from there.
# The starting directory is remembered on the first execution: re-running this
# cell then goes back to the same target instead of climbing two more levels.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
print(f"Old working dir {os.getcwd()}")
os.chdir(os.path.join(NOTEBOOK_START_DIR, "..", ".."))
print(f"New working dir {os.getcwd()}")

In [None]:
from pathlib import Path
# Output folder for the figures saved by this notebook; created if it is missing.
plots_dir = Path("./conformal_plots/")
plots_dir.mkdir(parents=True, exist_ok=True)

In [None]:
from typing import List
from pathlib import Path

import numpy as np
import pandas as pd
import torch

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [None]:
# Folder holding the result files; the commented line keeps the alternative location.
# NOTE(review): the loader calls further down pass their paths as string literals
# rather than using this variable — confirm whether it is still needed.
results_dir = Path('./conformal_results_u/')
#results_dir = Path('./conformal_results/')

In [None]:
from conformal.real_datasets.process_raw import datasets

In [None]:
# Number of target variables per dataset, kept both as a dict and as a
# two-column lookup table that can be merged on "dataset_name".
n_targets = dict(rf1=8, rf2=8, scm1d=16, scm20d=16, sgemm=4, bio=2, blog=2)
df_n_targets = pd.DataFrame(
    list(n_targets.items()),
    columns=["dataset_name", "n_targets"],
)
df_n_targets

In [None]:
# Fetch matplotlib's "tab20c" colormap; the bare name on the last line displays it.
t20c = matplotlib.colormaps["tab20c"]
t20c

In [None]:
# Colour per method name, taken from fixed slots of the tab20c colormap
# (slot 3 is left unused).
palette = {
    method: t20c(slot)
    for method, slot in (
        ("OT-CP-Global", 0),
        ("OT-CP-Local", 1),
        ("Ell-Local", 2),
        ("PB", 4),
        ("RPB", 5),
        ("HPD", 6),
        ("Quantile", 7),
        ("PB (CPFlow)", 8),
        ("RPB (CPFlow)", 9),
        ("HPD (CPFlow)", 10),
        ("Quantile (CPFlow)", 11),
        ("PB (Y)", 12),
        ("RPB (Y)", 13),
        ("HPD (Y)", 14),
        ("Quantile (Y)", 15),
    )
}

In [None]:
# Fetch matplotlib's "tab20" colormap (used for selected_palette); the bare name displays it.
colormap = matplotlib.colormaps["tab20"]
colormap

In [None]:
# Colour per method name for the selected-results figures, taken from fixed
# slots of the tab20 colormap (slots 10 and 11 are left unused).
selected_palette = {
    method: colormap(slot)
    for method, slot in (
        ("OT-CP-Global", 0),
        ("OT-CP-Local", 1),
        ("PB", 2),
        ("RPB", 3),
        ("PB (RF)", 4),
        ("RPB (RF)", 5),
        ("PB (Y, RF)", 6),
        ("RPB (Y, RF)", 7),
        ("PB (Y)", 8),
        ("RPB (Y)", 9),
        ("Ell-Local", 12),
    )
}

In [None]:
# Where to load each method/metric from?


In [None]:
from typing import Literal

def load_methods_from(method_names: List[str], path: str | Path, seeds=range(10), extention: Literal["feather", "csv"] = "feather") -> pd.DataFrame:
    dataframes = []
    #seeds = range(10)
    for seed in seeds:
        for dataset_name in datasets:
            fn = Path(path) / dataset_name / str(seed) / f"metrics_all.{extention}"
            if fn.is_file():
                if extention == "feather":
                    dataframes.append(pd.read_feather(fn))
                else:
                    dataframes.append(pd.read_csv(fn))
            else:
                print(f"Error: dataset {dataset_name}, seed {seed} not found.")
    df = pd.concat(dataframes).merge(df_n_targets, on="dataset_name")
    if "volume" in df.columns:
        df["log_vol_d"] = np.log(df["volume"]) / df["n_targets"]
    if len(method_names) > 0:
        df = df[df["method_name"].isin(method_names)]
    return df


In [None]:
# All methods (empty filter) from ./conformal_results_u/, default seeds, feather files.
df0 = load_methods_from(method_names=[], path="./conformal_results_u/")

In [None]:
#df.head(20)
# df1: all methods from ./conformal_results_u/ (feather) — same arguments as the df0 load.
# df2: all methods from ./conformal_results_250923/, stored as CSV.
df1 = load_methods_from(method_names=[], path="./conformal_results_u/")
df2 = load_methods_from(method_names=[], path="./conformal_results_250923/", extention="csv")

In [None]:
# Method names present in each of the two result sets.
df1["method_name"].unique(), df2["method_name"].unique()

In [None]:
# Display the first result set.
df1

In [None]:
# Display the second result set.
df2

In [None]:
# Stack both result sets row-wise under a fresh 0..n-1 index.
df = pd.concat([df1, df2], ignore_index=True)

In [None]:
#palette = blended_palette(df["base_model_name"], df["conformalizer"], paletteA="Set1", paletteB="Set2")

In [None]:
#pd.DataFrame.from_dict(palette, orient="index")

In [None]:
# Compare the frame's shape with the number of distinct
# (dataset_name, alpha, method_name, seed) combinations; equal row counts mean
# no combination occurs more than once.
df.shape, df[['dataset_name', 'alpha', 'method_name', 'seed']].drop_duplicates().shape

In [None]:
# Method names present in the combined frame.
df["method_name"].unique()

In [None]:
# Count the df0 rows of OT-CP methods whose volume is exactly zero.
is_ot_cp = df0["method_name"].str.contains("OT-CP")
df0.loc[is_ot_cp, "volume"].eq(0).sum()

In [None]:
# Marginal coverage for two datasets, one panel per alpha; the dashed line in
# each panel marks 1 - alpha.
coverage_subset = df[df["dataset_name"].isin(["sgemm", "scm20d"])]
g_cov = sns.catplot(
    data=coverage_subset,
    x="dataset_name",
    y="marginal_coverage",
    col="alpha",
    hue="method_name",
    sharey=False,
)
g_cov.set_axis_labels("Dataset", "Marginal coverage")
for alpha, ax in g_cov.axes_dict.items():
    ax.axhline(1 - alpha, ls="--", c="k", alpha=0.5)

In [None]:
# Volume per dataset from df0, one panel per alpha (this rebinds the name g_cov).
g_cov = sns.catplot(
    data=df0,
    x="dataset_name",
    y="volume",
    col="alpha",
    hue="method_name",
    sharey=False,
)

In [None]:
#g_cov.axes_dict

In [None]:
#df

In [None]:
# Reshape to long format: one row per (run, metric), with the metric name in
# "metric" and its number in "value"; every other column is kept as an identifier.
metrics_columns = ["marginal_coverage", "worst_slab_coverage", "volume", "log_vol_d"]
id_vars = df.columns.difference(metrics_columns).tolist()
df_melted = df.melt(
    id_vars=id_vars,
    value_vars=metrics_columns,
    var_name="metric",
    value_name="value",
)

In [None]:
# Number of rows in the long-format frame.
len(df_melted)

In [None]:
# Overview grid of box plots: one row per metric, one column per alpha.
g_all = sns.catplot(
    data=df_melted,
    kind="box",
    x="dataset_name",
    y="value",
    col="alpha",
    row="metric",
    hue="method_name",
    sharey="row",
    showfliers=False,
)
g_all.set_axis_labels("", "")
for (metric_name, alpha), ax in g_all.axes_dict.items():
    # Coverage panels get the 1 - alpha reference line.
    if "coverage" in metric_name:
        ax.axhline(1 - alpha, ls="--", c="k", alpha=0.5)
    # Only the metric literally named "volume" matches here ("log_vol_d" does not).
    if "volume" in metric_name:
        ax.set_yscale("log")
    # Show x tick labels on every panel, not just the bottom row.
    ax.tick_params(labelbottom=True)

In [None]:
# Rows for the appendix figure: the four listed datasets, no CPFlow variants, and
# only runs whose worst-slab coverage is non-zero.
# The comparison is parenthesised on purpose: `&` binds tighter than `!=`, so
# without the parentheses `!= 0` would apply to the whole conjunction instead of
# to the coverage column.
df_main = df[
    df["dataset_name"].isin(["sgemm", "scm20d", "bio", "blog"])
    & ~df["method_name"].str.contains("CPFlow")
    & (df["worst_slab_coverage"] != 0)
].copy()
def get_hatch(name):
    """Return the hatch pattern associated with a method name.

    The family is recognised by substring: "Quantile" -> "/", "RPB" -> "x",
    "PB" -> "\\\\", "HPD" -> "-". Names matching none of them give ``None``.

    Args:
        name: Method name, e.g. ``"RPB (Y)"``.

    Returns:
        The hatch string for the method's family, or ``None``.
    """
    # "RPB" contains "PB", so it has to be tested first; otherwise every RPB
    # method would receive the PB hatch and the "x" branch could never be reached.
    if "Quantile" in name:
        return "/"
    elif "RPB" in name:
        return "x"
    elif "PB" in name:
        return "\\"
    elif "HPD" in name:
        return "-"
    else:
        return None
# Add a "hatch" column holding get_hatch's result for each row's method name.
df_main["hatch"] = df_main["method_name"].apply(get_hatch)

In [None]:
# Hatch pattern looked up by an integer key when styling box patches and legend
# entries: negative keys map to no hatch, keys 0-3 to four distinct patterns.
boxplot_hatches = {
    **dict.fromkeys((-3, -2, -1), ""),
    0: "//",
    1: "xx",
    2: "--",
    3: "o",
}

In [None]:
# Peek at the first few assigned hatches.
df_main["hatch"].head()

In [None]:
# Display the tab20c colormap again for reference while choosing slots below.
t20c

In [None]:
# Colour per method name for the appendix figure, taken from fixed slots of the
# tab20c colormap (slot 3 is left unused).
palette_appendix = {
    method: t20c(slot)
    for method, slot in (
        ("OT-CP-Global", 0),
        ("OT-CP-Local", 1),
        ("Ell-Local", 2),
        ("PB", 4),
        ("RPB", 5),
        ("HPD", 6),
        ("Quantile", 7),
        ("PB (Y)", 8),
        ("RPB (Y)", 9),
        ("HPD (Y)", 10),
        ("Quantile (Y)", 11),
        ("PB (RF)", 12),
        ("RPB (RF)", 13),
        ("HPD (RF)", 14),
        ("Quantile (RF)", 15),
        ("PB (Y, RF)", 16),
        ("RPB (Y, RF)", 17),
        ("HPD (Y, RF)", 18),
        ("Quantile (Y, RF)", 19),
    )
}

In [None]:
#n_facets_to_plot = new_ugly_filter_wsc_df["dataset_name"].nunique()
#print(n_facets_to_plot)
# Appendix figure: worst-slab coverage of df_main as box plots, one row per
# dataset and one column per alpha, with hatched boxes; saved as PDF and PNG.
iclr_width = 5.50107  # presumably the paper's text width in inches — TODO confirm
plot_aspect_wide = 16 / 9
# Not passed to this catplot: the height= argument below is commented out.
plot_height = iclr_width / plot_aspect_wide
sns.set_style({'axes.grid' : True})
g_wsc = sns.catplot(
    data=df_main,
    kind="box",
    y="worst_slab_coverage",
    col="alpha",
    row="dataset_name",
    #col="dataset_name",
    hue="method_name", #_mathtext",
    palette=palette_appendix,
    sharey="row",
    showfliers=False,
    #height=plot_height,
)
g_wsc.set_axis_labels("", "Worst slab coverage")
g_wsc.set_xticklabels([])
g_wsc.despine(bottom=False, top=False, right=False)
# Title each panel with its dataset and alpha, and draw the 1 - alpha reference line.
for (dataset_name, alpha), ax in g_wsc.axes_dict.items():
    ax.set_title(rf"$\mathtt{{{dataset_name}}}$, $\alpha={alpha:.1f}$")
    ax.axhline(1 - alpha, xmax=1, ls="--", c="k", alpha=0.9)
for ax in g_wsc.axes.flatten():
    ax.tick_params(left=False, bottom=False)
    for i, patch in enumerate(ax.patches):
        # (i - 3) % 4 is always in 0..3, so the hatch cycles through the four
        # non-empty patterns with the patch index.
        # NOTE(review): which patch belongs to which method is not visible here —
        # check the rendered figure against the legend.
        patch.set_hatch(boxplot_hatches[(i - 3) % 4])
# Legend entries cycle through the same four patterns, starting at key 0.
for j, legend_patch in enumerate(g_wsc.legend.get_patches()):
    legend_patch.set_hatch(boxplot_hatches[j % 4])

# NOTE(review): ncol is derived from `palette`, while the plot itself is coloured
# with `palette_appendix` — confirm this is the intended column count.
sns.move_legend(g_wsc, "lower center", bbox_to_anchor=(0.45, 1), ncol=len(palette) // 2, title=None,
                )
g_wsc.savefig(plots_dir / "results_worst_slab_coverage_250925_hatch.pdf", bbox_inches="tight")
g_wsc.savefig(plots_dir / "results_worst_slab_coverage_250925_hatch.png", bbox_inches="tight")

# Selected results for main part

In [None]:
# Dataset names present in the combined frame.
df['dataset_name'].unique()

In [None]:
#df

In [None]:
#g_all.axes_dict
#df[df['dataset_name'] == 'sgemm']

In [None]:
# Subset for the main-text coverage figure: alpha = 0.1 only, without the
# Quantile / HPD / CPFlow methods and any method whose name contains "Y", and
# without the datasets whose name contains "rf" or "scm1d".
method_col = df["method_name"]
dataset_col = df["dataset_name"]
is_excluded_method = (
    method_col.str.contains("Quantile")
    | method_col.str.contains("HPD")
    | method_col.str.contains("CPFlow")
    | method_col.str.contains("Y")
)
is_excluded_dataset = dataset_col.str.contains("rf") | dataset_col.str.contains("scm1d")
new_ugly_filter_wsc_df = df[
    (df["alpha"] == 0.1) & ~is_excluded_method & ~is_excluded_dataset
].copy()
new_ugly_filter_wsc_df.columns

In [None]:
# Log of the absolute gap between worst-slab coverage and the nominal level 1 - alpha.
# A gap of exactly zero yields -inf (np.log(0)).
new_ugly_filter_wsc_df['worst_slab_coverage_error'] = np.log((new_ugly_filter_wsc_df['worst_slab_coverage'] - (1 - new_ugly_filter_wsc_df['alpha'])).abs())

In [None]:
# Earlier, longer set of mathtext legend labels (with U / Y subscripts).
# NOTE(review): not referenced by the plotting cells visible in this notebook.
labels_main_part_old = [r"$\mathtt{OT}$-$\mathtt{CP}$", r"$\mathtt{OT}$-$\mathtt{CP}$+", 
                        r"$\mathrm{ELL}$",
                        r"$\mathrm{PB}_{U}$", r"$\mathrm{RPB}_{U}$",
                        r"$\mathrm{PB}_{Y}$", r"$\mathrm{RPB}_{Y}$",
                        r"$\mathrm{PBS}_{U}$", r"$\mathrm{RPBS}_{U}$",
                        r"$\mathrm{PBS}_{Y}$", r"$\mathrm{RPBS}_{Y}$",]

In [None]:
# Mathtext legend labels passed to sns.move_legend in the main-text figures:
# two typewriter-style OT-CP entries followed by five roman-style abbreviations.
labels_main_part = [
    r"$\mathtt{OT}$-$\mathtt{CP}$",
    r"$\mathtt{OT}$-$\mathtt{CP}$+",
    *(rf"$\mathrm{{{abbrev}}}$" for abbrev in ("ELL", "PB", "RPB", "PBS", "RPBS")),
]

In [None]:
# Main-text figure: worst-slab coverage for the alpha = 0.1 subset, one box-plot
# panel per dataset.
n_facets_to_plot = new_ugly_filter_wsc_df["dataset_name"].nunique()
print(n_facets_to_plot)

# Panel height derived from a fixed width and a 16:9 aspect ratio.
iclr_width = 5.50107
plot_aspect_wide = 16 / 9
plot_height = iclr_width / plot_aspect_wide

sns.set_style({"axes.grid": True})
wsc_box_kwargs = dict(
    kind="box",
    y="worst_slab_coverage",
    col="dataset_name",
    hue="method_name",
    palette=selected_palette,
    sharey=True,
    showfliers=False,
    height=plot_height,
)
g_wsc = sns.catplot(data=new_ugly_filter_wsc_df, **wsc_box_kwargs)
sns.move_legend(
    g_wsc,
    "lower center",
    bbox_to_anchor=(0.45, 1),
    ncol=len(palette),
    title=None,
    labels=labels_main_part,
)
g_wsc.set_axis_labels("", "Worst slab coverage")
g_wsc.set_xticklabels([])
g_wsc.despine(bottom=False, top=False, right=False)
for dataset_name, ax in g_wsc.axes_dict.items():
    ax.set_title(rf"$\mathtt{{{dataset_name}}}$")
    # Reference line at 1 - alpha for the alpha = 0.1 subset plotted here.
    ax.axhline(1 - 0.1, xmax=1, ls="--", c="k", alpha=0.9)
    ax.tick_params(left=False, bottom=False)
    ax.set_ylim(0.65, 0.95)

#g_wsc.savefig(plots_dir / "selected_results_worst_slab_coverage_250924.pdf", bbox_inches="tight")
#g_wsc.savefig(plots_dir / "selected_results_worst_slab_coverage_250924.png", bbox_inches="tight")

In [None]:
# Display the computed panel height.
plot_height

In [None]:
# Scratch arithmetic; the result is displayed but not stored.
12 / 5 / 4

In [None]:
# Subset for the main-text volume figure: alpha = 0.1, the four listed datasets,
# and no Quantile / HPD / CPFlow methods nor any method whose name contains "Y".
method_col = df["method_name"]
is_excluded_method = (
    method_col.str.contains("Quantile")
    | method_col.str.contains("HPD")
    | method_col.str.contains("CPFlow")
    | method_col.str.contains("Y")
)
is_selected_dataset = df["dataset_name"].isin(["scm20d", "sgemm", "bio", "blog"])
new_ugly_filter_volume_df = df[
    (df["alpha"] == 0.1) & ~is_excluded_method & is_selected_dataset
].copy()
new_ugly_filter_volume_df.columns

In [None]:
#new_ugly_filter_volume_df.query("dataset_name == 'sgemm' and method_name == 'OT-CP-Local'")["volume"]

In [None]:
#new_ugly_filter_volume_df.query("dataset_name == 'sgemm' and method_name == 'OT-CP-Global'")["volume"]

In [None]:
#new_ugly_filter_volume_df.query("dataset_name == 'sgemm' and method_name == 'PB'")["volume"]

In [None]:
# Load four_volumes.csv, indexed by (dataset_name, seed); its "mean" column is
# what the correction cell below writes into log_vol_d for the "PB" method.
df_four_volumes = pd.read_csv("four_volumes.csv").set_index(["dataset_name", "seed"])
df_four_volumes

In [None]:
# Load four_volumes_rf.csv, indexed by (dataset_name, seed); its "mean" column is
# what the correction cell below writes into log_vol_d for the "PB (RF)" method.
df_four_volumes_rf = pd.read_csv("four_volumes_rf.csv").set_index(["dataset_name", "seed"])
df_four_volumes_rf

In [None]:
import itertools


# Overwrite log_vol_d of the "PB" and "PB (RF)" rows with the "mean" values from
# the two four_volumes tables. Those two methods are kept only for seeds 0, 1
# and 3; every other method keeps all of its seeds.
corrected_seeds = [0, 1, 3]
is_pb_method = new_ugly_filter_volume_df["method_name"].isin(["PB", "PB (RF)"])
has_corrected_seed = new_ugly_filter_volume_df["seed"].isin(corrected_seeds)
new_ugly_filter_volume_corrected_df = new_ugly_filter_volume_df[has_corrected_seed | ~is_pb_method].copy()

corrected = new_ugly_filter_volume_corrected_df  # alias of the same object, for brevity
for dataset_name, seed in itertools.product(["scm20d", "sgemm", "bio", "blog"], corrected_seeds):
    in_this_run = (corrected["dataset_name"] == dataset_name) & (corrected["seed"] == seed)
    for method_name, volume_table in (("PB", df_four_volumes), ("PB (RF)", df_four_volumes_rf)):
        target_rows = (corrected["method_name"] == method_name) & in_this_run
        corrected.loc[target_rows, "log_vol_d"] = volume_table.loc[(dataset_name, seed), "mean"]
new_ugly_filter_volume_corrected_df.query("method_name == 'PB (RF)'")[["dataset_name", "seed", "log_vol_d"]]

In [None]:
# Main-text figure: median log_vol_d per method as bars, one panel per dataset
# with independent y axes; saved as PDF and PNG under plots_dir.
g_logvold = sns.catplot(
    data=new_ugly_filter_volume_corrected_df,#.query("dataset_name == 'bio' or dataset_name == 'blog'"),
    kind="bar",
    y="log_vol_d",
    #col="alpha",
    #row="dataset_name",
    col="dataset_name",
    hue="method_name", #_mathtext",
    estimator="median",
    palette=selected_palette,
    sharey=False,
    #showfliers=False,
    facet_kws={
        "despine": False,
    },
    height=plot_height,
    linewidth=0.9,
    edgecolor="k",
    # NOTE(review): a numeric dodge is unusual (seaborn documents a bool or
    # "auto") — presumably it is only treated as truthy; confirm.
    dodge=2.6,
    gap=0.1,
)
# NOTE(review): ncol comes from `palette`, while the bars are coloured with
# `selected_palette` — confirm this is the intended column count.
sns.move_legend(g_logvold, "lower center", bbox_to_anchor=(0.45, 1), ncol=len(palette), title=None, 
                labels=labels_main_part)
g_logvold.set_axis_labels("", r"$(\log V) / d$")#Worst slab coverage")
#g_logvold.set_axis_labels("", "Volume")
g_logvold.set_xticklabels([])
#g_logvold.despine(bottom=True)
# Title each panel with its dataset name in typewriter mathtext.
for dataset_name, ax in g_logvold.axes_dict.items():
    ax.set_title(rf"$\mathtt{{{dataset_name}}}$")
for ax in g_logvold.axes.flatten():
    ax.tick_params(bottom=False)
    # Horizontal grid lines only, drawn behind the bars.
    ax.grid(visible=True, which="both", axis="y")
    ax.set_axisbelow(True)
    #ax.set_ylim(None, 2.5)

g_logvold.savefig(plots_dir / "selected_results_volume_250925.pdf", bbox_inches="tight")
g_logvold.savefig(plots_dir / "selected_results_volume_250925.png", bbox_inches="tight")

In [None]:
# Scratch check: log of a single value divided by 16, the same log(volume) / n_targets
# form used for log_vol_d (16 is the n_targets of scm1d / scm20d).
np.log(6427081) / 16

In [None]:
# Smallest volume reported for the Ell-Local method in the volume subset.
is_ell_local = new_ugly_filter_volume_df["method_name"] == "Ell-Local"
new_ugly_filter_volume_df.loc[is_ell_local, "volume"].min()

In [None]:
# Display the current working directory (relative paths below depend on it).
os.getcwd()

In [None]:
# Re-fetch and display the "tab20" colormap (same assignment as earlier in the notebook).
colormap = matplotlib.colormaps["tab20"]
colormap

In [None]:
# For each dataset, read its tuning table, print and remember the row with the
# smallest "error", then stack all tables with a "dataset_name" column added.
tuned_params = {}
dfs = []
for dataset_name in ("rf1", "rf2", "scm1d", "scm20d"):
    df_tuning = pd.read_feather(f"./conformal_results_slurm/{dataset_name}/53/tuning.feather")
    best_row = df_tuning.loc[df_tuning["error"].idxmin()]
    print(best_row)
    tuned_params[dataset_name] = best_row.to_dict()
    dfs.append(df_tuning.assign(dataset_name=dataset_name))
df_tuning = pd.concat(dfs)

In [None]:
# Show the lowest-error tuning row kept for each dataset.
print(tuned_params)

In [None]:
# Tuning error against n_epochs, one line per dataset.
sns.pointplot(df_tuning, x="n_epochs", y="error", hue="dataset_name")

In [None]:
# Load two further result sets (seeds 10-14, feather files) from the
# "sgemm_no_areas" and "sgemm_areas" folders.
df3 = load_methods_from(method_names=[], path="./conformal_results_sgemm_no_areas/", seeds=range(10, 15), extention="feather")
df4 = load_methods_from(method_names=[], path="./conformal_results_sgemm_areas/", seeds=range(10, 15),extention="feather")

In [None]:
# Combine the two sgemm result sets. With no `on=` / `how=` given, pd.merge does
# an inner join on every column name the two frames share.
# NOTE(review): df3 and df4 come from the same loader, so they probably share all
# columns and only rows identical in both would survive — confirm that a join,
# rather than a row-wise pd.concat as used for df1/df2, is intended.
df_sgemm = pd.merge(df3, df4,)

In [None]:
# Display the combined sgemm frame.
df_sgemm

In [None]:
# Median volume per method for the sgemm runs, drawn as outlined bars with the
# default colours, one panel per dataset name.
sgemm_bar_kwargs = dict(
    kind="bar",
    y="volume",
    col="dataset_name",
    hue="method_name",
    estimator="median",
    sharey=True,
    facet_kws={"despine": False},
    height=plot_height,
    linewidth=0.9,
    edgecolor="k",
    dodge=2.6,
    gap=0.1,
)
g_sgemm_vold = sns.catplot(data=df_sgemm, **sgemm_bar_kwargs)
# Clamp the y range of the current axes.
plt.ylim(0, 2.5)

In [None]:
# Histogram plot of the sgemm frame, with "volume" passed as the first positional argument.
df_sgemm.plot("volume", kind="hist")

In [None]:
# Number of rows in the sgemm frame with a strictly positive volume.
(df_sgemm["volume"] > 0).sum()