In [1]:
import os
from collections import defaultdict
from pprint import pprint

import numpy as np
import yaml
from easydict import EasyDict
from tqdm import tqdm

import wandb

In [21]:
models = ["VGG19"]
datasets = ["CIFAR10"]
initializations = ["baseline", "from-adversarial"]
regularizations = ["no-regularisation", "dropout", "augmentation", "weight_decay"]

In [56]:
api = wandb.Api()

with open("../configs/base_config.yaml", "r") as file:
    config = EasyDict(yaml.safe_load(file))

early_stopping = True

seeds = [43, 91, 17]

numbers = {init: {reg: defaultdict(list) for reg in regularizations} for init in initializations}

In [57]:
def get_statistics(run):
    summary_metrics = run.summary_metrics
    logged_metrics = run.history(keys=["train/loss", "train/acc", "train/ECE", "eval/loss", "eval/acc", "eval/ECE"], pandas=False)

    output = {}

    # if early_stopping:
    #     eval_accuracies = [res["eval/acc"] for res in logged_metrics]
    #     best_epoch = np.argmax(eval_accuracies)
    #     for key, value in logged_metrics[best_epoch].items():
    #         if key.startswith("eval/") or key.startswith("train/"):
    #             output[key] = value
    # else:
    #     for key, value in logged_metrics[-1].items():
    #         if key.startswith("eval/") or key.startswith("train/"):
    #             output[key] = value


    for key, value in summary_metrics.items():
        if early_stopping and key.startswith("measures_early_stopping/"):
            output[key] = value
        elif not early_stopping and key.startswith("measures/"):
            output[key] = value
    return output
            

In [59]:
for init in initializations:
    numbers[init]["dropout"] = {}

    for p in [0.1, 0.5]:
        numbers[init]["dropout"][p] = defaultdict(list)

    numbers[init]["weight_decay"] = {}
    for wd in [0.05, 0.1, 0.15]:
        numbers[init]["weight_decay"][wd] = defaultdict(list)

runs = api.runs(f"{config.wandb.entity}/{config.wandb.project}")

for init in initializations:
    for reg in tqdm(regularizations):
            
        for seed in seeds:

            run_name = f"igor-VGG19-CIFAR10-{init}-with-{reg}-seed-{seed}"

            if reg == "weight_decay":

                for wd in numbers[init][reg]:
                    print(wd)
                    run_name_1 = f"igor-VGG19-CIFAR10-{init}-with-no-regularisation-seed-{seed}-v2"
                    run_name_2 = f"igor-VGG19-CIFAR10-{init}-with-{reg}{str(wd)}-seed-{seed}-v2"
                    for run in runs:                
                        if (run.name == run_name_1 or run.name == run_name_2) and run.config["weight_decay"] == wd:
                            stats = get_statistics(run)

                            for key, value in stats.items():
                                numbers[init][reg][wd][key].append(value)
                            break
                    else:
                        raise Exception(f"{init} {reg} {wd} {seed}")
            elif reg == "dropout":
                for p in numbers[init][reg]:
                    run_name_1 = f"igor-VGG19-CIFAR10-{init}-with-{reg}-seed-{seed}-v2"
                    run_name_2 = f"igor-VGG19-CIFAR10-{init}-with-{reg}_{str(p)}-seed-{seed}-v2"
                    for run in runs:
                        if (run.name == run_name_1 or run.name == run_name_2) and run.config["dropout"] == p:
                            stats = get_statistics(run)

                            for key, value in stats.items():
                                numbers[init][reg][p][key].append(value)
                            break
                    else:
                        raise Exception(f"{init} {reg} {p} {seed}")
            else:
                run_name = f"igor-VGG19-CIFAR10-{init}-with-{reg}-seed-{seed}-v2"
                for run in runs:
                    if run.name == run_name and (run.config["weight_decay"] == 0 or "weight_decay" not in run.config) and (run.config["dropout"] == 0.0 or "dropout" not in run.config):
                        stats = get_statistics(run)
                        if reg == "no-regularisation":
                            print(run.name)
                            print(stats['measures_early_stopping/relative_flatness'])
                        for key, value in stats.items():
                            numbers[init][reg][key].append(value)
                        break
                else:
                    raise Exception(f"{init} {reg} {seed}")

  0%|          | 0/4 [00:00<?, ?it/s]

igor-VGG19-CIFAR10-baseline-with-no-regularisation-seed-43-v2
61.091068246692885
igor-VGG19-CIFAR10-baseline-with-no-regularisation-seed-91-v2
195.50387235102244


 25%|██▌       | 1/4 [00:00<00:01,  1.60it/s]

igor-VGG19-CIFAR10-baseline-with-no-regularisation-seed-17-v2
120.02682561606343


 75%|███████▌  | 3/4 [00:02<00:00,  1.31it/s]

0.05
0.1
0.15
0.05
0.1
0.15
0.05
0.1


100%|██████████| 4/4 [00:04<00:00,  1.03s/it]


0.15


  0%|          | 0/4 [00:00<?, ?it/s]

igor-VGG19-CIFAR10-from-adversarial-with-no-regularisation-seed-43-v2
0.7096637757824737


 25%|██▌       | 1/4 [00:00<00:01,  1.77it/s]

igor-VGG19-CIFAR10-from-adversarial-with-no-regularisation-seed-91-v2
3.65132278756937
igor-VGG19-CIFAR10-from-adversarial-with-no-regularisation-seed-17-v2
26.620046064374037


 75%|███████▌  | 3/4 [00:02<00:00,  1.18it/s]

0.05
0.1
0.15
0.05
0.1
0.15
0.05
0.1
0.15


100%|██████████| 4/4 [00:04<00:00,  1.14s/it]


In [42]:
for init in initializations:
    for reg in regularizations:
        if reg == "weight_decay":
            if len(numbers[init][reg]) == 0:
                del numbers[init][reg]
            for wd in numbers[init][reg]:
                numbers[init][reg+"_"+str(wd)] = numbers[init][reg][wd]
            del numbers[init][reg]
        elif reg == "dropout":
            if len(numbers[init][reg]) == 0:
                del numbers[init][reg]
            for p in numbers[init][reg]:
                numbers[init][reg+"_"+str(p)] = numbers[init][reg][p]
            del numbers[init][reg]

In [43]:
output = {init: {reg: {} for reg in numbers["baseline"]} for init in numbers}

for init, reg_results in numbers.items():
    for reg, results in reg_results.items():

        for key, values in dict(results).items():
            mean = np.mean(values)
            std = np.std(values)
            output[init][reg][key] = (mean, std)

In [44]:
prefix = r"""\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
"""
suffix = r"""\bottomrule
\end{tabular}
\caption{Various measures for VGG-19 on CIFAR-10}
\label{tab:measures}
\end{table}
"""

In [45]:
table = prefix
measures = ['measures/squared_euclidean_norm', 'measures/fisher_rao_norm', 'measures/relative_flatness', 'measures/hessian_trace', 'measures/max_hessian_eigenvalue']
if early_stopping:
    measures = [m.replace("measures/", "measures_early_stopping/") for m in measures]
    
for measure in measures:
    measure = measure.replace("_", " ").split("/")[1].capitalize()
    table += f"& {measure} "

In [46]:
measures

['measures/squared_euclidean_norm',
 'measures/fisher_rao_norm',
 'measures/relative_flatness',
 'measures/hessian_trace',
 'measures/max_hessian_eigenvalue']

In [47]:
for init in output:
    table += f"\\\\ \\midrule \n \multicolumn{{6}}{{c}}{{\\textbf{{{init.capitalize()}}}}} \\\\ \\midrule \n"
    for reg in output["baseline"]:
        table += f"\\textit{{{reg.replace('-', ' ').replace('_', ' ').capitalize()}}} "
        for measure in measures:
            if measure not in output[init][reg]:
                table += "& - "
            mean, std = output[init][reg][measure]


            if not "hessian" in measure.lower():
                table += f"& {mean:.2f}" + r"$_{\pm " + f"{std:.2f}" + r"}$ "
            else:
                mean = f"{mean:.2e}".replace("+", "").replace("e0", "e")
                std = f"{std:.2e}".replace("+", "").replace("e0", "e").replace("e-0", "e-").replace("e", r"\text{e}")
                table += f"& {mean}" + r"$_{\pm " + f"{std}" + r"}$ "
        table += "\\\\ \n"

In [48]:
table += suffix

In [49]:
print(table)

\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
& Squared euclidean norm & Fisher rao norm & Relative flatness & Hessian trace & Max hessian eigenvalue \\ \midrule 
 \multicolumn{6}{c}{\textbf{Baseline}} \\ \midrule 
\textit{No regularisation} & 7.11$_{\pm 0.15}$ & 0.02$_{\pm 0.00}$ & 0.53$_{\pm 0.01}$ & 2.67e2$_{\pm 1.46\text{e}1}$ & 8.97e0$_{\pm 6.28\text{e}-1}$ \\ 
\textit{Augmentation} & 6.45$_{\pm 0.03}$ & 3.47$_{\pm 0.20}$ & 67.69$_{\pm 1.54}$ & 4.44e4$_{\pm 1.94\text{e}3}$ & 1.78e3$_{\pm 3.32\text{e}2}$ \\ 
\textit{Dropout 0.1} & 7.58$_{\pm 0.09}$ & 0.02$_{\pm 0.00}$ & 0.49$_{\pm 0.01}$ & 2.04e2$_{\pm 1.05\text{e}1}$ & 7.65e0$_{\pm 2.64\text{e}-1}$ \\ 
\textit{Dropout 0.5} & 9.67$_{\pm 0.09}$ & 0.03$_{\pm 0.00}$ & 0.58$_{\pm 0.02}$ & 2.48e2$_{\pm 8.12\text{e}1}$ & 9.0

In [235]:
pprint(output)

{'baseline': {'augmentation': {},
              'dropout_0.1': {},
              'dropout_0.5': {'eval/ECE': (0.15715893109639487,
                                           0.0022362117710209646),
                              'eval/acc': (0.7083333333333334,
                                           0.02946278254943948),
                              'eval/loss': (1.0410934478044511,
                                            0.005464427414741201),
                              'train/ECE': (0.001004324876703322,
                                            0.0002477828758228618),
                              'train/acc': (1.0, 0.0),
                              'train/loss': (0.0012533368187802876,
                                             0.00034717178729339103)},
              'no-regularisation': {},
              'weight_decay_0.05': {'eval/ECE': (0.07135884712139766,
                                                 0.0028329768334999015),
                                 

In [17]:
table = prefix
measures = ["train/loss", "train/acc", "train/ECE", "eval/loss", "eval/acc", "eval/ECE"]

In [18]:
for init in output:
    table += f"\\\\ \\midrule \n \multicolumn{{7}}{{c}}{{\\textbf{{{init.capitalize()}}}}} \\\\ \\midrule \n"
    for reg in output["baseline"]:
        if len(output[init][reg]) == 0:
            continue
        table += f"\\\\ \n \\textit{{{reg.replace('-', ' ').capitalize()}}} "
        for measure in measures:
            mean, std = output[init][reg][measure]
            table += f"& {mean:.2f}" + r"$_{\pm " + f"{std:.2f}" + r"}$ "
        table += "\n"

KeyError: 'train/loss'

In [19]:
table += suffix

In [20]:
print(table)

\begin{table}
\centering
\begin{tabular}{l>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}>{\raggedleft\arraybackslash}p{2cm}}
\toprule
\\ \midrule 
 \multicolumn{7}{c}{\textbf{Baseline}} \\ \midrule 
\\ 
 \textit{No regularisation} \bottomrule
\end{tabular}
\caption{Various measures for VGG-19 on CIFAR-10}
\label{tab:measures}
\end{table}

