In [None]:
import warnings
from functools import partial

import numpy as np
import pandas as pd

from src.utils.latex import bold_formatter
from src.utils.wandb import get_runs


def printnames(runs):
    """Print a ``"<n> runs:"`` header, then each run's ``name`` on its own line."""
    print(f"{len(runs)} runs:")
    names = (run.name for run in runs)
    for name in names:
        print(name)

In [None]:
def table_for_runs(runs, acc_key="test/acc", baseline_acc=0.7301, k_values=(1, 2, 3, 4)):
    """Print a LaTeX table of one summary metric per MoE configuration and k.

    Each run is mapped to a configuration key derived from its tags and name:

    * loss type    -- "switch" if "switch" is among the run's tags, else "entropy"
    * architecture -- "BlockMoE" if "block" occurs in the run name, else "ConvMoE"
    * gate type    -- "CGARN" if "CGARN" occurs in the run name, else "GALRN"
    * k            -- the last character of the run name parsed as an int
                      (run names are therefore assumed to end in a single digit)

    The table has one row per (loss type, architecture, gate type) combination
    that has a run for at least one k in ``k_values``, sorted lexicographically,
    and one column per k. Cells without a matching run are left blank.

    Args:
        runs: W&B run objects exposing ``name``, ``tags`` and a ``summary``
            mapping that contains ``acc_key``.
        acc_key: summary key to tabulate, e.g. "test/acc" or "attack/acc".
        baseline_acc: if truthy, every value column is rendered with
            ``bold_formatter(value=baseline_acc + 1e-4, num_decimals=4)``
            (presumably bolding entries that beat the baseline -- see
            ``src.utils.latex``). If ``None`` or 0, pandas' default formatting
            is used.
        k_values: k values to show as columns, in order. Defaults to 1..4.

    Returns:
        None. The LaTeX source is printed to stdout.

    Raises:
        KeyError: if a run's summary lacks ``acc_key``.
        ValueError / IndexError: if a run name does not end in a digit / is empty.
    """
    # Map (loss, arch, gate, k) -> metric value. If several runs share a key
    # (e.g. a re-run), the last one wins; warn so the clash is not hidden.
    accs = {}
    for run in runs:
        key = (
            "switch" if "switch" in run.tags else "entropy",
            "BlockMoE" if "block" in run.name else "ConvMoE",
            "CGARN" if "CGARN" in run.name else "GALRN",
            int(run.name[-1]),
        )
        if key in accs:
            warnings.warn(f"Multiple runs map to {key}; using the value from run {run.name!r}.")
        accs[key] = run.summary[acc_key]

    # One row per configuration that has at least one requested k. Building the
    # row keys from a set avoids duplicate rows for configurations with several
    # runs, and configurations lacking a k=1 run are no longer dropped.
    configs = sorted({key[:-1] for key in accs if key[-1] in k_values})
    data = [(*config, *(accs.get((*config, k)) for k in k_values)) for config in configs]

    k_columns = [f"Accuracy k={k}" if i == 0 else f"k={k}" for i, k in enumerate(k_values)]
    columns = ["Loss Type", "Architecture", "Gate Type", *k_columns]
    df = pd.DataFrame(data=data, columns=columns)

    formatters = None
    if baseline_acc:
        # Small offset on the threshold, kept from the original implementation.
        threshold = baseline_acc + 1e-4
        formatters = {
            column: partial(bold_formatter, value=threshold, num_decimals=4)
            for column in k_columns
        }

    print(df.to_latex(formatters=formatters, escape=False, na_rep="", index=False))

In [None]:
# Keep only finished runs with 4 experts and no expert-capacity limit.
# NOTE(review): expert_capacity is compared against the *string* "None";
# presumably that is how the config was logged (inspected in the last cell) — confirm.
runs = [
    run
    for run in get_runs(project="robust-cifar100-resnet-moe")
    if run.state == "finished"
    and run.config.get("model.model.num_experts") == 4
    and run.config.get("model.model.expert_capacity") == "None"
]

In [None]:
# Split runs by the model `_target_` recorded in their config: "Cross"
# (presumably standard cross-entropy training) vs. "pgd" (presumably PGD
# adversarial training). A missing or None `_target_` is treated as "" so such
# runs fall into neither group instead of raising a TypeError.
std_runs = [run for run in runs if "Cross" in (run.config.get("model._target_") or "")]
robust_runs = [run for run in runs if "pgd" in (run.config.get("model._target_") or "")]

In [None]:
# Standard-training runs: clean test accuracy, then accuracy under attack.
# TODO: document where the baseline values come from.
# Loss tables can be printed the same way, e.g. with acc_key "test/main_loss"
# or "attack/main_loss" and baseline_acc=None.
for metric_key, baseline in (("test/acc", 0.7301), ("attack/acc", 1e-4)):
    table_for_runs(std_runs, metric_key, baseline_acc=baseline)

In [None]:
# Adversarially trained runs: clean test accuracy, then accuracy under attack.
# TODO: document where the baseline values come from.
for metric_key, baseline in (("test/acc", 0.5232), ("attack/acc", 0.178)):
    table_for_runs(robust_runs, metric_key, baseline_acc=baseline)

In [None]:
# Raw (run name, attack/acc) pairs for every robust run.
# Use printnames(robust_runs) to list just the names.
robust_attack_accs = [(run.name, run.summary["attack/acc"]) for run in robust_runs]
print(robust_attack_accs)

In [None]:
# Show how expert_capacity is stored in a run config. The filtering cell
# compares it against the string "None". Raises IndexError if no runs survived the filter.
first_run_config = runs[0].config
first_run_config["model.model.expert_capacity"]