In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

DATA_PATH = "../data"

# Prefer the CUDA device when PyTorch reports one as available; otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# PrintLog comes from the project-local src.log_mock module (not visible here);
# presumably a print-based stand-in logger -- confirm against that module.
from src.log_mock import PrintLog
log = PrintLog()

import wandb

# Initialise wandb in "disabled" mode so this analysis notebook does not start a
# tracked run of its own.
wandb.init(mode="disabled")

In [None]:
# Fetch the runs stored under the "bayes/fmow" path through the W&B public API.
wapi = wandb.Api()
runs = wapi.runs("bayes/fmow")

In [None]:
# Quick inspection: list every run's name together with the keys of its summary.
for run in runs:
    print(run.name, run.summary.keys())

In [None]:
import plotly.express as px
import pandas as pd
import dateutil
import datetime

def create_plot_data_for_run(run):
    """Flatten the test results of one run into a single row of metrics.

    The model name is derived from ``run.name``: when the name has more than
    two ``-``-separated parts the first two parts are kept (joined by ``-``),
    otherwise only the first part is used.

    The "worst" group is the entry of ``run.summary["test_results"]`` whose key
    contains ``"region"`` (excluding the ``"worst_region_acc"`` entry itself)
    and whose ``"accuracy"`` is lowest; ties keep the first such entry.

    Args:
        run: Object exposing ``name`` (str) and ``summary`` (mapping) where
            ``summary["test_results"]`` maps group names to metric mappings
            with the keys ``accuracy``, ``log_likelihood``, ``ece`` and ``sece``.

    Returns:
        dict: One row with the model name, the logged ``worst_region_acc``,
        the metrics of the ``"all"`` group and the metrics of the worst group.

    Raises:
        KeyError: If an expected key is missing, or if no per-region entry with
            a comparable accuracy exists.
    """
    name_parts = run.name.split("-")
    if len(name_parts) > 2:
        model_name = name_parts[0] + "-" + name_parts[1]
    else:
        model_name = name_parts[0]

    test_results = run.summary["test_results"]

    # Start from +inf rather than 1 so that a run whose regions all reach an
    # accuracy of 1.0 (or more) still selects a worst group.
    worst_acc = float("inf")
    worst_acc_group = None
    for group_name, results in test_results.items():
        if "region" in group_name and group_name != "worst_region_acc":
            if results["accuracy"] < worst_acc:
                worst_acc = results["accuracy"]
                worst_acc_group = group_name

    if worst_acc_group is None:
        raise KeyError(
            f"run {run.name!r} has no per-region entry in its test results"
        )

    overall = test_results["all"]
    worst = test_results[worst_acc_group]

    return {
        "model": model_name,
        "worst_region_acc": test_results["worst_region_acc"],
        "all accuracy": overall["accuracy"],
        "all log likelihood": overall["log_likelihood"],
        "all ece": overall["ece"],
        "all sece": overall["sece"],
        "worst_acc accuracy": worst["accuracy"],
        "worst_acc sece": worst["sece"],
        "worst_acc ece": worst["ece"],
        "worst_acc log_likelihood": worst["log_likelihood"],
    }

def plot(data, value):
    """Return a plotly box plot of ``value`` per model.

    Args:
        data: Frame passed to ``px.box``; must provide a ``"model"`` column.
        value: Name of the column drawn on the y axis.

    Returns:
        The figure created by ``px.box``, with one box (and colour) per model.
    """
    return px.box(data, x="model", y=value, color="model")

def pareto_plot(data, x, y):
    """Return a plotly scatter plot of metric ``y`` against metric ``x``.

    Error bars are read from the columns ``f"{x}_std"`` and ``f"{y}_std"``,
    so ``data`` must contain those next to ``x``, ``y`` and ``"model"``.

    Args:
        data: Frame passed to ``px.scatter``.
        x: Column name used for the x axis.
        y: Column name used for the y axis.

    Returns:
        The figure created by ``px.scatter``, coloured by model.
    """
    x_error_column = f"{x}_std"
    y_error_column = f"{y}_std"
    return px.scatter(
        data,
        x=x,
        error_x=x_error_column,
        y=y,
        error_y=y_error_column,
        color="model",
    )

def build_data(runs, min_created_at=datetime.datetime(2023, 3, 10, 10, 0)):
    """Collect one metrics row per usable run into a DataFrame.

    A run is skipped when any of the following holds:

    * it was created before ``min_created_at``;
    * its ``state`` is not ``"finished"``;
    * it carries the ``"old"`` tag;
    * its summary has no ``"test_results"`` entry.

    Args:
        runs: Iterable of run objects exposing ``created_at`` (a timestamp
            string), ``state``, ``tags``, ``name`` and ``summary``.
        min_created_at: Earliest creation time to keep. Defaults to the cutoff
            that was previously hard-coded (2023-03-10 10:00).

    Returns:
        pandas.DataFrame: One row per kept run, as produced by
        ``create_plot_data_for_run``. Empty when no run qualifies.
    """
    # Import the submodule explicitly: a bare ``import dateutil`` does not
    # guarantee that ``dateutil.parser`` has been loaded.
    from dateutil import parser as date_parser

    def to_naive_utc(moment):
        # Timezone-aware values cannot be compared with naive ones, so aware
        # timestamps are converted to UTC and stripped of their tzinfo.
        # NOTE(review): naive timestamps are compared as-is, i.e. assumed to be
        # in the same zone as the cutoff -- confirm against what W&B returns.
        if moment.tzinfo is None:
            return moment
        return moment.astimezone(datetime.timezone.utc).replace(tzinfo=None)

    cutoff = to_naive_utc(min_created_at)

    rows = []
    for run in runs:
        if to_naive_utc(date_parser.parse(run.created_at)) < cutoff:
            continue
        if run.state != "finished":
            continue
        if "old" in run.tags:
            print("Skipping old run " + run.name)
            continue
        if "test_results" not in run.summary:
            print("Skipping crashed run " + run.name)
            continue
        rows.append(create_plot_data_for_run(run))
    return pd.DataFrame(rows)

def aggregate_data(data):
    """Aggregate per-run metric rows into one row per model.

    For every metric column the mean and the standard error of the mean are
    computed per model. The mean keeps the metric's name; the standard error
    is stored under ``"<metric>_std"`` after being multiplied by 2.0, so the
    ``_std`` columns hold twice the standard error.

    Args:
        data: Frame with a ``"model"`` column and the metric columns listed
            below, one row per run.

    Returns:
        pandas.DataFrame: Indexed by model, with a ``"model"`` column followed
        by ``<metric>`` / ``<metric>_std`` column pairs.
    """
    metric_columns = [
        "worst_region_acc",
        "all accuracy",
        "all log likelihood",
        "all sece",
        "all ece",
        "worst_acc accuracy",
        "worst_acc sece",
        "worst_acc ece",
        "worst_acc log_likelihood",
    ]

    # Keep the model name as a regular column and summarise each metric.
    aggregation = {"model": "first"}
    for metric in metric_columns:
        aggregation[metric] = ["mean", "sem"]
    summary = data.groupby(["model"]).agg(aggregation)

    # Flatten the (column, statistic) pairs into plain column names.
    flat_names = []
    for column, statistic in summary.columns.to_flat_index():
        if statistic == "sem":
            flat_names.append(column + "_std")
        else:
            flat_names.append(column)
    summary.columns = flat_names

    # Widen every error column to two standard errors.
    for metric in metric_columns:
        summary[metric + "_std"] *= 2.0
    return summary

In [None]:
# Filter the fetched runs, flatten them into rows and aggregate them per model.
data = aggregate_data(build_data(runs))

In [None]:
# Display the aggregated per-model table.
data

In [None]:
# Accuracy vs. sECE on the worst-accuracy group, with error bars from the *_std columns.
pareto_plot(data, "worst_acc accuracy", "worst_acc sece")

In [None]:
# Accuracy vs. sECE on the "all" group, with error bars from the *_std columns.
pareto_plot(data, "all accuracy", "all sece")

In [None]:
# No path is given, so to_csv returns the CSV text (shown as the cell output)
# instead of writing a file.
data.to_csv(sep=",", header=True)

In [None]:
# (model key, display name) pairs in the order the rows are printed by create_table.
# The first element is matched against the "model" column of the aggregated data;
# the second is the label written into the LaTeX table.
algo_names = [
    ("map-1", "MAP"),
    ("map-5", "Deep Ensemble"),
    ("mcd_p0.1-1", "MCD"),
    ("mcd-5", "MultiMCD"),
    ("swag-1", "SWAG"),
    ("swag-5", "MultiSWAG"),
    ("swag_ll-1", "LL SWAG"),
    ("laplace-1", "LL Laplace"),
    ("laplace-5", "LL MultiLaplace"),
    ("bbb-1", "LL BBB"),
    ("bbb-5", "LL MultiBBB"),
    ("rank1-1", "Rank-1 VI"),
    ("ll_ivon-1", "LL iVON"),
    ("ll_ivon-5", "LL MultiiVON"),
    ("svgd-1", "SVGD"),
    ("sngp", "SNGP"),
]

def num(value, std, best=None, ty=None):
    """Format ``value \\pm std`` as an inline LaTeX math string.

    Both numbers are converted with ``float`` and printed with three decimals.
    When ``best`` and ``ty`` are both given, the text is wrapped in ``\\bm{...}``
    if the value passes the threshold:

    * ``"max"``:  ``value >= best``
    * ``"min"``:  ``value <= best``
    * ``"zero"``: ``abs(value) <= best``

    Any other ``ty`` leaves the text unhighlighted.

    Args:
        value: Central value (anything ``float`` accepts).
        std: Spread shown after ``\\pm`` (anything ``float`` accepts).
        best: Threshold for highlighting, or ``None`` to disable it.
        ty: Comparison kind, or ``None`` to disable highlighting.

    Returns:
        str: The formatted number enclosed in ``$...$``.
    """
    value, std = float(value), float(std)
    text = f"{value:.3f} \\pm {std:.3f}"

    highlight = False
    if best is not None and ty is not None:
        if ty == "max":
            highlight = value >= best
        elif ty == "min":
            highlight = value <= best
        elif ty == "zero":
            highlight = abs(value) <= best

    if highlight:
        text = f"\\bm{{{text}}}"
    return f"${text}$"

def col_name(name, align):
    """Return a single-column ``\\multicolumn`` header cell.

    Args:
        name: Text shown in the cell.
        align: LaTeX column alignment specifier placed in the second argument.

    Returns:
        str: ``\\multicolumn{1}{<align>}{<name>}``.
    """
    return "\\multicolumn{{1}}{{{}}}{{{}}}".format(align, name)

def create_table(data, prefix):
    """Print a LaTeX ``tabular`` comparing all models listed in ``algo_names``.

    Each model contributes one row with six cells (accuracy, ECE and sECE on
    the worst-accuracy group, then on the "all" group), formatted by ``num``.

    Highlighting: for every metric the best model is determined (highest
    accuracy, lowest ECE, sECE closest to zero). The highlight threshold is the
    best value relaxed by that model's own ``_std`` value, and every model
    whose value reaches the threshold is printed in bold.

    Models from ``algo_names`` that have no row in ``data`` are skipped and
    reported as a LaTeX comment line inside the table instead of aborting.

    Args:
        data: Aggregated frame with a ``"model"`` column plus, for every metric,
            the columns ``prefix + metric`` and ``prefix + metric + "_std"``.
        prefix: String prepended to every metric column name.

    Returns:
        None. The table is written to standard output.
    """
    # (metric column, comparison kind) in the order the cells are printed.
    metrics = [
        ("worst_acc accuracy", "max"),
        ("worst_acc ece", "min"),
        ("worst_acc sece", "zero"),
        ("all accuracy", "max"),
        ("all ece", "min"),
        ("all sece", "zero"),
    ]

    def scalar(row, column):
        # Take the first matching entry explicitly; calling float() on a
        # one-element Series is deprecated in pandas.
        return float(row[column].iloc[0])

    print("\\begin{tabular}{l|rrrrrr}")
    print(f"    {col_name('Model', 'l')} & {col_name('WR Accuracy', 'c')} & {col_name('WR ECE', 'c')} & {col_name('WR sECE', 'c')} & {col_name('Avg Accuracy', 'c')} & {col_name('Avg ECE', 'c')} & {col_name('Avg sECE', 'c')} \\\\")
    print("    \\hline")

    # Resolve each model's row once; models without results are skipped.
    table_rows = []
    for algo, name in algo_names:
        row = data[data["model"] == algo]
        if row.empty:
            print(f"    % skipped {name}: no results for model '{algo}'")
            continue
        table_rows.append((name, row))

    # Highlight threshold per metric: best value relaxed by its own error.
    thresholds = {}
    for metric, kind in metrics:
        # Same sentinels as before: 0 for maximised, 1000 for minimised metrics.
        best, best_std = (0, 0) if kind == "max" else (1000, 0)
        for _, row in table_rows:
            candidate = scalar(row, prefix + metric)
            if kind == "zero":
                candidate = abs(candidate)
            is_better = candidate > best if kind == "max" else candidate < best
            if is_better:
                best = candidate
                best_std = scalar(row, prefix + metric + "_std")
        thresholds[metric] = best - best_std if kind == "max" else best + best_std

    for name, row in table_rows:
        cells = " & ".join(
            num(
                scalar(row, prefix + metric),
                scalar(row, prefix + metric + "_std"),
                thresholds[metric],
                kind,
            )
            for metric, kind in metrics
        )
        print(f"    {name} & {cells} \\\\")
    print("\\end{tabular}")
# Print the LaTeX table for the aggregated results; the metric columns carry no prefix.
create_table(data, "")