In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Relative location of the data directory (not referenced by the cells visible here).
DATA_PATH = "../data"

# Use the GPU whenever PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project-local logger stand-in; the instance is bound to `log` for use by later cells.
from src.log_mock import PrintLog
log = PrintLog()

import wandb

# Initialise wandb in "disabled" mode.
# NOTE(review): presumably this keeps the notebook from creating/uploading a run while
# still allowing read-only wandb.Api() queries below -- confirm against the wandb docs.
wandb.init(mode="disabled")

In [None]:
# Query the wandb public API for the runs stored under the "bayes/poverty" path.
# `runs` is consumed by the listing and aggregation cells below.
wapi = wandb.Api()
runs = wapi.runs("bayes/poverty")

In [None]:
# Print each fetched run's name next to the keys of its summary, to inspect
# which entries every run reported.
for run in runs:
    print(run.name, run.summary.keys())

In [None]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime

def create_plot_data_for_run(run):
    """Flatten one run into a single record of plot/table values.

    Args:
        run: Object exposing ``name`` (a string with at least one ``-``) and
            ``summary`` (a mapping with ``"ood"`` and ``"id"`` sub-mappings).

    Returns:
        dict: ``"model"`` (the first two ``-``-separated parts of the run name,
        re-joined with ``-``) followed by ``"<split> <metric>"`` entries for
        both splits, in the order ood then id.

    Raises:
        IndexError: If the run name contains no ``-``.
        KeyError: If a split or one of its metrics is missing from the summary.
    """
    name_parts = run.name.split("-")
    record = {"model": name_parts[0] + "-" + name_parts[1]}

    # (label used in the record, key used in the run summary)
    metric_keys = (
        ("mse", "mse"),
        ("pearson", "pearson"),
        ("ll", "avg_ll"),
        ("lml", "avg_lml"),
        ("qce", "qce"),
        ("sqce", "sqce"),
    )
    for split in ("ood", "id"):
        for label, summary_key in metric_keys:
            record[f"{split} {label}"] = run.summary[split][summary_key]
    return record

def plot(data, value):
    """Return a plotly box plot of ``value`` per model.

    Args:
        data: Table with a ``"model"`` column and a column named ``value``.
        value: Name of the column drawn on the y axis.

    Returns:
        The figure produced by ``px.box``, with one coloured box per model.
    """
    return px.box(data, x="model", y=value, color="model")

def pareto_plot(data, x, y):
    """Return a plotly scatter plot of two metrics with error bars on both axes.

    Args:
        data: Table holding the columns ``x``, ``y``, ``"<x>_std"``,
            ``"<y>_std"`` and ``"model"``.
        x: Column plotted on the horizontal axis.
        y: Column plotted on the vertical axis.

    Returns:
        The figure produced by ``px.scatter``, coloured by model.
    """
    error_bars = {"error_x": f"{x}_std", "error_y": f"{y}_std"}
    return px.scatter(data, x=x, y=y, color="model", **error_bars)

def build_data(runs):
    """Collect one record per usable run into a DataFrame.

    A run is used only when its state is ``"finished"``, it is not tagged
    ``"old"`` and its summary contains an ``"ood"`` entry. Old and failed runs
    are reported on stdout; unfinished runs are dropped silently.

    Args:
        runs: Iterable of run objects exposing ``state``, ``tags``, ``name``
            and ``summary``.

    Returns:
        pandas.DataFrame: One row per kept run, as produced by
        ``create_plot_data_for_run``.
    """
    records = []
    for candidate in runs:
        if candidate.state == "finished":
            if "old" in candidate.tags:
                print("Skipping old run " + candidate.name)
            elif "ood" not in candidate.summary:
                print("Skipping failed run " + candidate.name)
            else:
                records.append(create_plot_data_for_run(candidate))
    return pd.DataFrame.from_dict(records)

def aggregate_data(data):
    """Aggregate per-run metrics into one row per model.

    For every ``"<split> <metric>"`` column the group mean is kept under the
    original name, and twice the standard error of the mean is stored under
    ``"<split> <metric>_std"`` (note: despite the suffix this is 2 x SEM, not a
    standard deviation).

    Args:
        data: Per-run table with a ``"model"`` column and the twelve
            ood/id metric columns.

    Returns:
        pandas.DataFrame: Indexed by model, with a ``"model"`` column followed
        by a mean and a ``_std`` column for each metric.
    """
    metric_columns = [
        f"{split} {metric}"
        for split in ("ood", "id")
        for metric in ("mse", "pearson", "ll", "lml", "qce", "sqce")
    ]

    spec = {"model": "first"}
    for column in metric_columns:
        spec[column] = ["mean", "sem"]
    summary = data.groupby(["model"]).agg(spec)

    # Flatten the (column, statistic) pairs: "sem" columns get the "_std" suffix,
    # every other statistic keeps the plain column name.
    summary.columns = [
        column + "_std" if statistic == "sem" else column
        for column, statistic in summary.columns.to_flat_index()
    ]

    # Widen the standard error to two standard errors.
    for column in metric_columns:
        summary[column + "_std"] *= 2.0
    return summary

In [None]:
# Build the per-run table from the fetched runs and collapse it to one row per model
# (mean and 2 x SEM per metric). Every later cell reads this `data` frame.
data = aggregate_data(build_data(runs))

In [None]:
# Display the aggregated per-model table.
data

In [None]:
# Out-of-distribution split: Pearson correlation (x) against sQCE (y), one point per model.
pareto_plot(data, "ood pearson", "ood sqce")

In [None]:
# In-distribution split: Pearson correlation (x) against sQCE (y), one point per model.
pareto_plot(data, "id pearson", "id sqce")

In [None]:
# Out-of-distribution split: the "ood lml" column (x) against the "ood ll" column (y).
pareto_plot(data, "ood lml", "ood ll")

In [None]:
# No path is given, so to_csv returns the CSV text instead of writing a file;
# as the last expression of the cell it is shown as the cell output.
data.to_csv(sep=",", header=True)

In [None]:
# (model key, display name) pairs, in the order the rows appear in the LaTeX table.
# The key is compared against the "model" column of the aggregated data.
algo_names = [
    ("map-1", "MAP"),
    ("map-5", "Deep Ensemble"),
    ("mcd-1", "MCD"),
    ("mcd-5", "MultiMCD"),
    ("swag-1", "SWAG"),
    ("swag-5", "MultiSWAG"),
    ("laplace-1", "LL Laplace"),
    ("laplace-5", "LL MultiLaplace"),
    ("bbb-1", "BBB"),
    ("bbb-5", "MultiBBB"),
    ("rank1-1", "Rank-1 VI"),
    ("ivon_p500-1", "iVON"),
    ("svgd-1", "SVGD"),
    # NOTE(review): create_plot_data_for_run always builds keys of the form "<a>-<b>",
    # so a key without "-" cannot match any row -- confirm the intended key (e.g. "sngp-1"?).
    ("sngp", "SNGP"),
]

def num(value, std, best=None, ty=None):
    """Format ``value`` and ``std`` as an inline LaTeX math string, optionally bold.

    Args:
        value: Central value; converted with ``float``.
        std: Spread; converted with ``float``. NaN is rendered as ``-`` and
            disables bolding.
        best: Threshold the value is compared against, or ``None`` for no bolding.
        ty: Comparison rule: ``"max"`` bolds when ``value >= best``, ``"min"``
            when ``value <= best``, ``"zero"`` when ``abs(value) <= best``.
            ``None`` or any other value never bolds.

    Returns:
        str: ``"$<value> \\pm <std>$"`` with three decimals, wrapped in
        ``\\bm{...}`` when the rule is satisfied.
    """
    value, std = float(value), float(std)
    std_missing = math.isnan(std)

    spread = "-" if std_missing else f"{std:.3f}"
    body = f"{value:.3f} \\pm {spread}"

    # Bolding needs a threshold, a rule and a usable spread.
    if not (best is None or ty is None or std_missing):
        if ty == "max":
            highlight = value >= best
        elif ty == "min":
            highlight = value <= best
        elif ty == "zero":
            highlight = abs(value) <= best
        else:
            highlight = False
        if highlight:
            body = f"\\bm{{{body}}}"

    return f"${body}$"

def col_name(name, align):
    """Wrap a header label in a one-column LaTeX ``\\multicolumn`` cell.

    Args:
        name: Text of the header cell.
        align: Column alignment specifier (e.g. ``"l"`` or ``"c"``).

    Returns:
        str: ``\\multicolumn{1}{<align>}{<name>}``.
    """
    return "\\multicolumn{{1}}{{{}}}{{{}}}".format(align, name)

def _bold_threshold(stats, rule):
    """Return the cut-off a value must reach to be bolded as tied with the best entry.

    Args:
        stats: Iterable of ``(value, std)`` float pairs, one per algorithm.
        rule: ``"max"`` (larger is better), ``"min"`` (smaller is better) or
            ``"zero"`` (smaller magnitude is better).

    Returns:
        float: Best value moved towards "worse" by that entry's own spread.
        When there is no usable entry, a value that no number can satisfy.
    """
    best_value, best_std, best_key = None, 0.0, float("-inf")
    for value, std in stats:
        # Map every rule onto "largest key wins"; NaN never compares greater and is skipped.
        if rule == "max":
            key = value
        elif rule == "min":
            key = -value
        else:
            key = -abs(value)
        if key > best_key:
            best_value, best_std, best_key = value, std, key

    if best_value is None:
        return float("inf") if rule == "max" else float("-inf")
    if rule == "max":
        return best_value - best_std
    if rule == "min":
        return best_value + best_std
    return abs(best_value) + best_std


def create_table(data, prefix):
    """Print a LaTeX ``tabular`` with one row per entry of ``algo_names``.

    For each metric the best mean over all listed algorithms is determined
    (highest for pearson/ll/lml, lowest for mse/qce, smallest magnitude for
    sqce). Every cell that lies within the best entry's spread of that best
    value is passed to ``num`` with a threshold so that it is typeset in bold.

    Algorithms listed in ``algo_names`` that have no row in ``data`` are left
    out of the ranking and reported as a LaTeX comment line instead of raising.

    Args:
        data: Aggregated table with a ``"model"`` column and, for every metric,
            the columns ``"<prefix><metric>"`` and ``"<prefix><metric>_std"``.
        prefix: Column prefix selecting the split, e.g. ``"ood "`` or ``"id "``.

    Returns:
        None. The table is written to stdout. Relies on the notebook-level
        names ``algo_names``, ``num`` and ``col_name``.
    """
    # (column suffix, selection rule) in the order the columns are printed.
    metrics = (
        ("pearson", "max"),
        ("ll", "max"),
        ("lml", "max"),
        ("mse", "min"),
        ("qce", "min"),
        ("sqce", "zero"),
    )
    headers = ("Worst U/R Pearson", "psLML", "LML", "MSE", "QCE", "sQCE")

    print("\\begin{tabular}{l|rrrrrr}")
    header_cells = [col_name("Model", "l")] + [col_name(title, "c") for title in headers]
    print("    " + " & ".join(header_cells) + " \\\\")
    print("    \\hline")

    # Pull plain floats out of the frame once; None marks an algorithm without data.
    entries = []
    for algo, name in algo_names:
        matching = data[data["model"] == algo]
        if len(matching) == 0:
            entries.append((algo, name, None))
            continue
        record = matching.iloc[0]
        stats = {
            metric: (float(record[prefix + metric]), float(record[prefix + metric + "_std"]))
            for metric, _ in metrics
        }
        entries.append((algo, name, stats))

    thresholds = {
        metric: _bold_threshold(
            [stats[metric] for _, _, stats in entries if stats is not None], rule
        )
        for metric, rule in metrics
    }

    for algo, name, stats in entries:
        if stats is None:
            print(f"    % {name}: no aggregated results for model '{algo}'")
            continue
        cells = []
        for metric, rule in metrics:
            value, std = stats[metric]
            cells.append(num(value, std, thresholds[metric], rule))
        print(f"    {name} & {' & '.join(cells)} \\\\")
    print("\\end{tabular}")
# Print the table for the columns whose names start with "ood " (note the trailing space).
create_table(data, "ood ")