In [1]:
import os
from collections import defaultdict
from pprint import pprint

import numpy as np
import yaml
from easydict import EasyDict

import wandb

In [146]:
# igor-VGG19-CIFAR10-baseline-with-augmentation-seed-43

models = ["VGG19"]
datasets = ["CIFAR10"]
initializations = ["baseline", "from-adversarial"]
regularizations = ["no-regularisation", "dropout", "augmentation"]

In [147]:
api = wandb.Api()

with open("../configs/base_config.yaml", "r") as file:
    config = EasyDict(yaml.safe_load(file))



seeds = [43, 91, 17]
numbers = {init: {reg: defaultdict(list) for reg in regularizations} for init in initializations}

runs = api.runs(f"{config.wandb.entity}/{config.wandb.project}")

for init in initializations:
    for reg in regularizations:
            
        for seed in seeds:

            run_name = f"igor-VGG19-CIFAR10-{init}-with-{reg}-seed-{seed}"
            print(run_name)
            # run_name = run_name_template + str(seed)

            for run in runs:
                if run.name == run_name:
                    break
            else:
                continue

            summary_metrics = run.summary_metrics

            logged_metrics = run.history(keys=["train/loss", "train/acc", "train/ECE", "eval/loss", "eval/acc", "eval/ECE"], pandas=False)

            for key, value in logged_metrics[-1].items():
                if key.startswith("eval/") or key.startswith("train/"):
                    numbers[init][reg][key].append(value)

            for key, value in summary_metrics.items():
                if key.startswith("measures/"):
                    numbers[init][reg][key].append(value)

pprint(numbers)

igor-VGG19-CIFAR10-baseline-with-no-regularisation-seed-43
igor-VGG19-CIFAR10-baseline-with-no-regularisation-seed-91
igor-VGG19-CIFAR10-baseline-with-no-regularisation-seed-17
igor-VGG19-CIFAR10-baseline-with-dropout-seed-43
igor-VGG19-CIFAR10-baseline-with-dropout-seed-91
igor-VGG19-CIFAR10-baseline-with-dropout-seed-17
igor-VGG19-CIFAR10-baseline-with-augmentation-seed-43
igor-VGG19-CIFAR10-baseline-with-augmentation-seed-91
igor-VGG19-CIFAR10-baseline-with-augmentation-seed-17
igor-VGG19-CIFAR10-from-adversarial-with-no-regularisation-seed-43
igor-VGG19-CIFAR10-from-adversarial-with-no-regularisation-seed-91
igor-VGG19-CIFAR10-from-adversarial-with-no-regularisation-seed-17
igor-VGG19-CIFAR10-from-adversarial-with-dropout-seed-43
igor-VGG19-CIFAR10-from-adversarial-with-dropout-seed-91
igor-VGG19-CIFAR10-from-adversarial-with-dropout-seed-17
igor-VGG19-CIFAR10-from-adversarial-with-augmentation-seed-43
igor-VGG19-CIFAR10-from-adversarial-with-augmentation-seed-91
igor-VGG19-CIFAR10

In [148]:
output = {init: {reg: {} for reg in regularizations} for init in initializations}

for init, reg_results in numbers.items():
    for reg, results in reg_results.items():
        print("-----")
        print(reg)
        for key, values in dict(results).items():
            mean = np.mean(values)
            std = np.std(values)
            output[init][reg][key] = (mean, std)
            print(f"{key}: {round(mean, 2)} +- {round(std, 2)}")

-----
no-regularisation
train/loss: 0.0 +- 0.0
train/acc: 1.0 +- 0.0
train/ECE: 0.0 +- 0.0
eval/loss: 1.05 +- 0.01
eval/acc: 0.75 +- 0.05
eval/ECE: 0.15 +- 0.0
measures/pacbayes_flatness: 0.25 +- 0.0
measures/fisher_rao_norm: 13.54 +- 0.6
measures/relative_flatness: 17.86 +- 1.23
measures/hessian_trace: 19.37 +- 1.83
measures/max_hessian_eigenvalue: 2.64 +- 0.62
measures/squared_euclidean_norm: 7.11 +- 0.15
-----
dropout
train/loss: 0.0 +- 0.0
train/acc: 1.0 +- 0.0
train/ECE: 0.0 +- 0.0
eval/loss: 1.06 +- 0.04
eval/acc: 0.73 +- 0.12
eval/ECE: 0.16 +- 0.01
measures/squared_euclidean_norm: 7.58 +- 0.09
measures/max_hessian_eigenvalue: 2.45 +- 0.43
measures/fisher_rao_norm: 14.41 +- 0.35
measures/pacbayes_flatness: 0.25 +- 0.0
measures/relative_flatness: 13.47 +- 0.16
measures/hessian_trace: 15.88 +- 0.59
-----
augmentation
train/loss: 0.27 +- 0.0
train/acc: 0.89 +- 0.04
train/ECE: 0.0 +- 0.0
eval/loss: 0.82 +- 0.04
eval/acc: 0.81 +- 0.05
eval/ECE: 0.1 +- 0.01
measures/pacbayes_flatness: 

In [149]:
prefix = r"""\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
"""
suffix = r"""\bottomrule
\end{tabular}
\caption{Various measures for VGG-19 on CIFAR-10}
\label{tab:measures}
\end{table}
"""

In [155]:
table = prefix
measures = []

for measure in output["baseline"]["no-regularisation"]:
    if measure.startswith("measures/"):
        measures.append(measure)
        measure = measure.replace("_", " ").split("/")[1].capitalize()
        table += f"& {measure} "

In [156]:
for init in initializations:
    table += f"\\\\ \\midrule \n \multicolumn{{7}}{{c}}{{\\textbf{{{init.capitalize()}}}}} \\\\ \\midrule \n"
    for reg in regularizations:
        
        table += f"\\textit{{{reg.replace('-', ' ').capitalize()}}} "
        for measure in measures:
            if measure not in output[init][reg]:
                table += "& - "
            mean, std = output[init][reg][measure]
            table += f"& {round(mean, 2)}" + r"$_{\pm " + f"{round(std, 2)}" + r"}$ "
        table += "\\\\ \n"

KeyError: 'measures/pacbayes_flatness'

In [None]:
table += suffix

In [None]:
print(table)

\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
& Pacbayes flatness & Max hessian eigenvalue & Fisher rao norm & Hessian trace & Relative flatness & Squared euclidean norm \\ \midrule 
 \multicolumn{7}{c}{\textbf{Baseline}} \\ \midrule 
\textit{No regularisation} & 0.25$_{\pm 0.0}$ & 2.64$_{\pm 0.62}$ & 13.54$_{\pm 0.6}$ & 19.37$_{\pm 1.83}$ & 17.86$_{\pm 1.23}$ & 7.11$_{\pm 0.15}$ \\ 
\textit{Dropout} & 0.25$_{\pm 0.0}$ & 2.45$_{\pm 0.43}$ & 14.41$_{\pm 0.35}$ & 15.88$_{\pm 0.59}$ & 13.47$_{\pm 0.16}$ & 7.58$_{\pm 0.09}$ \\ 
\textit{Augmentation} & 0.25$_{\pm 0.0}$ & 4.51$_{\pm 0.3}$ & 9.7$_{\pm 0.64}$ & 35.76$_{\pm 0.87}$ & 26.82$_{\pm 2.35}$ & 6.45$_{\pm 0.03}$ \\ 
\bottomrule
\end{tabular}
\caption{Various measures for VGG-19 on CIFAR-10}
\label{tab:measures}
\end{table}

In [127]:
table = prefix
measures = []

for measure in output["baseline"]["no-regularisation"]:
    if measure.startswith("eval/") or measure.startswith("train/"):
        measures.append(measure)
        table += f"& {measure} "

In [128]:
for init in initializations:
    table += f"\\\\ \\midrule \n \multicolumn{{7}}{{c}}{{\\textbf{{{init.capitalize()}}}}} \\\\ \\midrule \n"
    for reg in regularizations:
        table += f"\\\\ \n \\textit{{{reg.replace('-', ' ').capitalize()}}} "
        for measure in measures:
            mean, std = output[init][reg][measure]
            table += f"& {round(mean, 2)}" + r"$_{\pm " + f"{round(std, 2)}" + r"}$ "
        table += "\n"

In [129]:
table += suffix

In [130]:
print(table)

\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
& train/loss & train/acc & train/ECE & eval/loss & eval/acc & eval/ECE \\ \midrule 
 \multicolumn{7}{c}{\textbf{Baseline}} \\ \midrule 
\\ 
 \textit{No regularisation} & 0.0$_{\pm 0.0}$ & 1.0$_{\pm 0.0}$ & 0.0$_{\pm 0.0}$ & 1.05$_{\pm 0.01}$ & 0.75$_{\pm 0.05}$ & 0.15$_{\pm 0.0}$ 
\\ 
 \textit{Dropout} & 0.0$_{\pm 0.0}$ & 1.0$_{\pm 0.0}$ & 0.0$_{\pm 0.0}$ & 1.06$_{\pm 0.04}$ & 0.73$_{\pm 0.12}$ & 0.16$_{\pm 0.01}$ 
\\ 
 \textit{Augmentation} & 0.27$_{\pm 0.0}$ & 0.89$_{\pm 0.04}$ & 0.0$_{\pm 0.0}$ & 0.82$_{\pm 0.04}$ & 0.81$_{\pm 0.05}$ & 0.1$_{\pm 0.01}$ 
\bottomrule
\end{tabular}
\caption{Various measures for VGG-19 on CIFAR-10}
\label{tab:measures}
\end{table}



In [121]:
output

{'baseline': {'no-regularisation': {'train/loss': (0.0003856863826711275,
    2.973792937898298e-05),
   'train/acc': (1.0, 0.0),
   'train/ECE': (0.00022782506130170077, 4.021206277513025e-05),
   'eval/loss': (1.0518516411383947, 0.010936258214849937),
   'eval/acc': (0.75, 0.05103103630798288),
   'eval/ECE': (0.15408625702063242, 0.0026177185062572048),
   'measures/pacbayes_flatness': (0.25001525948758285, 0.0),
   'measures/max_hessian_eigenvalue': (2.6366789738337197, 0.6248315679666351),
   'measures/fisher_rao_norm': (13.542055804056735, 0.6032366984390466),
   'measures/hessian_trace': (19.365692138671875, 1.8322403969558578),
   'measures/relative_flatness': (17.860666518747166, 1.227669159773781),
   'measures/squared_euclidean_norm': (7.108161631403131,
    0.14733285545739494)},
  'dropout': {'train/loss': (0.0004698470386131635, 1.4019319707461249e-05),
   'train/acc': (1.0, 0.0),
   'train/ECE': (0.00030279486478927237, 1.6366618901949782e-05),
   'eval/loss': (1.062583