In [61]:
import os
from collections import defaultdict
from pprint import pprint

import numpy as np
import yaml
from easydict import EasyDict

import wandb

In [62]:
models = ["VGG19"]
datasets = ["CIFAR10"]
initializations = ["baseline", "from-adversarial"]
regularizations = ["no-regularisation", "dropout", "augmentation"]

In [63]:
api = wandb.Api()

with open("../configs/base_config.yaml", "r") as file:
    config = EasyDict(yaml.safe_load(file))

early_stopping = False

seeds = [43, 91, 17]
numbers = {init: {reg: defaultdict(list) for reg in regularizations} for init in initializations}

runs = api.runs(f"{config.wandb.entity}/{config.wandb.project}")

for init in initializations:
    for reg in regularizations:
            
        for seed in seeds:

            run_name = f"igor-VGG19-CIFAR10-{init}-with-{reg}-seed-{seed}"

            for run in runs:
                if run.name == run_name:
                    break
            else:
                raise Exception

            summary_metrics = run.summary_metrics
            logged_metrics = run.history(keys=["train/loss", "train/acc", "train/ECE", "eval/loss", "eval/acc", "eval/ECE"], pandas=False)

            if early_stopping:
                eval_accuracies = [res["eval/acc"] for res in logged_metrics]
                best_epoch = np.argmax(eval_accuracies)
                for key, value in logged_metrics[best_epoch].items():
                    if key.startswith("eval/") or key.startswith("train/"):
                        numbers[init][reg][key].append(value)
            else:
                for key, value in logged_metrics[-1].items():
                    if key.startswith("eval/") or key.startswith("train/"):
                        numbers[init][reg][key].append(value)

            for key, value in summary_metrics.items():
                if key.startswith("measures/"):
                    numbers[init][reg][key].append(value)
            

In [64]:
output = {init: {reg: {} for reg in regularizations} for init in initializations}

for init, reg_results in numbers.items():
    for reg, results in reg_results.items():

        for key, values in dict(results).items():
            mean = np.mean(values)
            std = np.std(values)
            output[init][reg][key] = (mean, std)

In [65]:
prefix = r"""\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
"""
suffix = r"""\bottomrule
\end{tabular}
\caption{Various measures for VGG-19 on CIFAR-10}
\label{tab:measures}
\end{table}
"""

In [66]:
table = prefix
measures = ['measures/pacbayes_flatness', 'measures/fisher_rao_norm', 'measures/relative_flatness', 'measures/hessian_trace', 'measures/max_hessian_eigenvalue', 'measures/squared_euclidean_norm']
for measure in measures:
    measure = measure.replace("_", " ").split("/")[1].capitalize()
    table += f"& {measure} "

In [67]:
for init in initializations:
    table += f"\\\\ \\midrule \n \multicolumn{{7}}{{c}}{{\\textbf{{{init.capitalize()}}}}} \\\\ \\midrule \n"
    for reg in regularizations:
        
        table += f"\\textit{{{reg.replace('-', ' ').capitalize()}}} "
        for measure in measures:
            if measure not in output[init][reg]:
                table += "& - "
            mean, std = output[init][reg][measure]
            table += f"& {round(mean, 2)}" + r"$_{\pm " + f"{round(std, 2)}" + r"}$ "
        table += "\\\\ \n"

In [68]:
table += suffix

In [69]:
print(table)

\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
& Pacbayes flatness & Fisher rao norm & Relative flatness & Hessian trace & Max hessian eigenvalue & Squared euclidean norm \\ \midrule 
 \multicolumn{7}{c}{\textbf{Baseline}} \\ \midrule 
\textit{No regularisation} & 0.25$_{\pm 0.0}$ & 13.54$_{\pm 0.6}$ & 17.86$_{\pm 1.23}$ & 19.37$_{\pm 1.83}$ & 2.64$_{\pm 0.62}$ & 7.11$_{\pm 0.15}$ \\ 
\textit{Dropout} & 0.25$_{\pm 0.0}$ & 14.41$_{\pm 0.35}$ & 13.47$_{\pm 0.16}$ & 15.88$_{\pm 0.59}$ & 2.45$_{\pm 0.43}$ & 7.58$_{\pm 0.09}$ \\ 
\textit{Augmentation} & 0.25$_{\pm 0.0}$ & 9.7$_{\pm 0.64}$ & 26.82$_{\pm 2.35}$ & 35.76$_{\pm 0.87}$ & 4.51$_{\pm 0.3}$ & 6.45$_{\pm 0.03}$ \\ 
\\ \midrule 
 \multicolumn{7}{c}{\textbf{From-adversarial}} \\ \midrule 
\textit{No regularisation} & 0.25$_

In [70]:
table = prefix
measures = []
measures = ["Pacbayes flatness", "Fisher rao norm", "Relative flatness", "Hessian trace", "Max hessian eigenvalue", "Squared euclidean norm"]

for measure in output["baseline"]["no-regularisation"]:
    if measure.startswith("eval/") or measure.startswith("train/"):
        measures.append(measure)
        table += f"& {measure} "

In [71]:
for init in initializations:
    table += f"\\\\ \\midrule \n \multicolumn{{7}}{{c}}{{\\textbf{{{init.capitalize()}}}}} \\\\ \\midrule \n"
    for reg in regularizations:
        table += f"\\\\ \n \\textit{{{reg.replace('-', ' ').capitalize()}}} "
        for measure in measures:
            mean, std = output[init][reg][measure]
            table += f"& {round(mean, 2)}" + r"$_{\pm " + f"{round(std, 2)}" + r"}$ "
        table += "\n"

KeyError: 'Pacbayes flatness'

In [None]:
table += suffix

In [None]:
print(table)

\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
& train/loss & train/acc & train/ECE & eval/loss & eval/acc & eval/ECE \\ \midrule 
 \multicolumn{7}{c}{\textbf{Baseline}} \\ \midrule 
\\ 
 \textit{No regularisation} & 0.51$_{\pm 0.35}$ & 0.8$_{\pm 0.1}$ & 0.01$_{\pm 0.0}$ & 0.98$_{\pm 0.09}$ & 0.79$_{\pm 0.03}$ & 0.09$_{\pm 0.06}$ 
\\ 
 \textit{Dropout} & 0.12$_{\pm 0.03}$ & 0.97$_{\pm 0.02}$ & 0.01$_{\pm 0.0}$ & 1.08$_{\pm 0.05}$ & 0.85$_{\pm 0.08}$ & 0.17$_{\pm 0.01}$ 
\\ 
 \textit{Augmentation} & 0.41$_{\pm 0.15}$ & 0.87$_{\pm 0.04}$ & 0.01$_{\pm 0.0}$ & 0.85$_{\pm 0.04}$ & 0.85$_{\pm 0.03}$ & 0.09$_{\pm 0.01}$ 
\\ \midrule 
 \multicolumn{7}{c}{\textbf{From-adversarial}} \\ \midrule 
\\ 
 \textit{No regularisation} & 0.02$_{\pm 0.02}$ & 1.0$_{\pm 0.01}$ & 0.0$_{\pm 0.0}$ 