In [None]:
from operator import itemgetter

import numpy as np
from matplotlib import pyplot as plt

from src.utils.wandb import get_runs


def printnames(runs):
    """Print a "<count> runs:" header followed by each run's name on its own line."""
    lines = [f"{len(runs)} runs:"]
    lines.extend(run.name for run in runs)
    print("\n".join(lines))

In [None]:
# Every run in the project that finished and has config "model.model.k" == 1
# (presumably top-1 expert routing — confirm against the training config).
runs = [
    run
    for run in get_runs(project="robust-cifar100-resnet-moe")
    if run.state == "finished" and run.config.get("model.model.k") == 1
]

In [None]:
def plots_for_runs(runs, figname, acc_key="test/acc", baseline_acc=0.7301, ylabel="Accuracy"):
    """Plot a metric against the number of experts, one line per loss type.

    Runs tagged ``"switch"`` form the "Switch" series; all other runs form the
    "Entropy" series. Each series is sorted by number of experts and drawn on a
    log2 x-axis, together with a dashed green reference line at ``baseline_acc``.

    Args:
        runs: run objects as returned by ``get_runs`` (``.tags``, ``.config`` and
            ``.summary`` are read).
        figname: output path passed to ``savefig``; a missing parent directory is
            created first.
        acc_key: key in ``run.summary`` holding the value to plot.
        baseline_acc: y-value of the "ResNet18 Baseline" reference line.
        ylabel: y-axis label.
    """
    import os

    runs = list(runs)
    loss_types = ["switch" if "switch" in run.tags else "entropy" for run in runs]

    # Group (num_experts, metric) points by loss type directly. Selecting them with
    # itemgetter(*indices) breaks for a series with exactly one run (itemgetter then
    # returns the bare point, so sorted() would scramble its coordinates) and raises
    # TypeError for a series with no runs at all.
    series = {"entropy": [], "switch": []}
    for loss, run in zip(loss_types, runs):
        # int() keeps the ordering numeric even if the config stores the value as a string.
        n_experts = int(run.config["model.model.num_experts"])
        series[loss].append((n_experts, run.summary[acc_key]))

    # Nothing in this function draws random numbers; the global seed is retained
    # only because notebook code may rely on the seeded RNG state.
    np.random.seed(1)
    print(loss_types)

    # Explicit figure/axes so consecutive calls never draw onto each other's figure.
    fig, ax = plt.subplots()
    ax.axhline(baseline_acc, color="green", linestyle="--", label="ResNet18 Baseline")
    # Fixed colours keep each loss type's colour stable across figures, even when
    # one of the series is empty and skipped.
    for loss, label, color in (("entropy", "Entropy", "C0"), ("switch", "Switch", "C1")):
        points = sorted(series[loss], key=lambda point: point[0])
        if points:
            ax.plot(*zip(*points), marker="x", label=label, color=color)

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Number of Experts")
    ax.set_ylabel(ylabel)
    ax.legend()

    out_dir = os.path.dirname(figname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(figname)
    plt.show()

In [None]:
def apply_filters(fs, candidates=None):
    """Return the runs that satisfy every predicate in ``fs``.

    Args:
        fs: iterable of predicates ``run -> bool``. An empty iterable keeps every run.
        candidates: runs to filter. Defaults to the notebook-global ``runs`` so
            existing calls keep working; pass a list explicitly to avoid depending
            on kernel state.

    Returns:
        A new list of the matching runs, in the order of ``candidates``.
    """
    # Materialize once: a one-shot iterator of predicates would otherwise be
    # exhausted by the first run, and every later run would pass all([]) == True.
    fs = list(fs)
    pool = runs if candidates is None else candidates
    return [run for run in pool if all(f(run) for f in fs)]


# Selection criteria shared by the standard and PGD run sets below.
filters = [
    lambda run: "block" in run.name,
    lambda run: "CGARN" in run.name,
    lambda run: run.config.get("model.optimizer.lr") == 0.1,
    lambda run: run.config.get("trainer.max_epochs") == 200,
    # NOTE(review): this run is excluded by hand and the reason is not recorded — document it.
    lambda run: run.id != "1jsf0owy",
]
# `or ""` guards runs whose config has no (or a null) "model._target_", which would
# otherwise make the `in` test raise TypeError.
std_runs = apply_filters(
    filters + [lambda run: "Cross" in (run.config.get("model._target_") or "")]
)
pgd_runs = apply_filters(
    filters + [lambda run: "pgd" in (run.config.get("model._target_") or "")]
)
printnames(std_runs)
printnames(pgd_runs)

In [None]:
# Clean and adversarial accuracy versus number of experts for the standard runs.
for plot_kwargs in (
    dict(figname="num_expert_plots/std_acc.png", acc_key="test/acc", baseline_acc=0.7301),
    dict(
        figname="num_expert_plots/std_adv_acc.png",
        acc_key="attack/acc",
        baseline_acc=1e-4,
        ylabel="Adversarial Accuracy",
    ),
):
    plots_for_runs(std_runs, **plot_kwargs)

In [None]:
# Clean and adversarial accuracy versus number of experts for the PGD runs.
for plot_kwargs in (
    dict(figname="num_expert_plots/robust_acc.png", acc_key="test/acc", baseline_acc=0.53),
    dict(
        figname="num_expert_plots/robust_adv_acc.png",
        acc_key="attack/acc",
        baseline_acc=0.178,
        ylabel="Adversarial Accuracy",
    ),
):
    plots_for_runs(pgd_runs, **plot_kwargs)

In [None]:
# Raw per-run series for the standard runs: loss type (from tags), expert count
# and clean test accuracy. Apparently kept for ad-hoc inspection.
loss_types = ["entropy" if "switch" not in r.tags else "switch" for r in std_runs]
num_experts = [cfg["model.model.num_experts"] for cfg in (r.config for r in std_runs)]
accs = [r.summary["test/acc"] for r in std_runs]

In [None]:
# Display the id of the first run in pgd_runs (IndexError if the list is empty).
pgd_runs[0].id

In [None]:
# NOTE: this rebinds std_runs to *all* finished k == 1 runs whose target contains
# "Cross", without the name/lr/epoch filters used where std_runs was first built.
# Re-running the plotting cells after this one will plot a different set of runs.
# `or ""` guards runs whose config has no (or a null) "model._target_".
std_runs = [run for run in runs if "Cross" in (run.config.get("model._target_") or "")]
robust_runs = [run for run in runs if "pgd" in (run.config.get("model._target_") or "")]

In [None]:
# NOTE(review): table_for_runs is defined in a later cell; on a fresh kernel run
# that cell first, or this one raises NameError.
for metric, baseline in (
    ("test/acc", 0.7301),
    ("test/main_loss", None),
    ("attack/acc", 1e-4),
    ("attack/main_loss", None),
):
    table_for_runs(std_runs, metric, baseline_acc=baseline)

In [None]:
# NOTE(review): the clean-accuracy table uses baseline 0.7301 (table_for_runs'
# default, also used for the standard runs), whereas the PGD plots use 0.53 — confirm.
# table_for_runs is defined in a later cell and must be run first on a fresh kernel.
table_for_runs(robust_runs, acc_key="test/acc", baseline_acc=0.7301)
table_for_runs(robust_runs, acc_key="attack/acc", baseline_acc=None)

In [None]:
# List the robust runs, then the loss type each one's tags imply.
printnames(robust_runs)
print(["entropy" if "switch" not in run.tags else "switch" for run in robust_runs])

In [None]:
# Display the "model.model.num_experts" config value of the first robust run.
robust_runs[0].config["model.model.num_experts"]

In [None]:
import pandas as pd
from src.utils.latex import bold_formatter
from functools import partial


def table_for_runs(runs, acc_key="test/acc", baseline_acc=0.7301, expert_counts=(2, 4, 8, 16, 32)):
    """Print a LaTeX table of ``acc_key`` per configuration and number of experts.

    Each row is one (loss type, training, gate type) configuration. Each column after
    the first three holds the ``run.summary[acc_key]`` value for one entry of
    ``expert_counts``, left blank when no matching run exists.

    Args:
        runs: run objects as returned by ``get_runs`` (``.tags``, ``.name``,
            ``.config`` and ``.summary`` are read).
        acc_key: summary key to tabulate.
        baseline_acc: threshold handed to ``bold_formatter`` (as ``baseline_acc + 1e-4``).
            When falsy (``None`` or 0), no custom formatters are used.
        expert_counts: expert counts to show as columns, in order.
    """
    runs = list(runs)  # iterated more than once below

    # (loss type, training procedure, gate type, num_experts) for each run.
    keys = [
        (
            "switch" if "switch" in run.tags else "entropy",
            "PGD" if "pgd" in run.config["model._target_"] else "SGD",
            "CGARN" if "CGARN" in run.name else "GALRN",
            int(run.config["model.model.num_experts"]),
        )
        for run in runs
    ]
    # NOTE(review): if two runs share the same key, the later one silently wins.
    accs = {key: run.summary[acc_key] for key, run in zip(keys, runs)}

    # One row per distinct configuration that has a run at any tabulated expert
    # count. Deriving rows only from n=2 runs would drop configurations lacking an
    # n=2 run and repeat a row whenever several n=2 runs share a configuration.
    wanted = set(expert_counts)
    configs = sorted({key[:-1] for key in keys if key[-1] in wanted})
    data = [(*config, *(accs.get((*config, n)) for n in expert_counts)) for config in configs]

    columns = [
        "Loss Type",
        # NOTE(review): this column holds the training procedure (PGD/SGD), not an
        # architecture — consider renaming the header.
        "Architecture",
        "Gate Type",
    ] + [("Experts " if i == 0 else "") + f"n={n}" for i, n in enumerate(expert_counts)]
    df = pd.DataFrame(data=data, columns=columns)

    formatters = None
    if baseline_acc:
        formatters = {
            column: partial(bold_formatter, value=baseline_acc + 1e-4, num_decimals=4)
            for column in columns[3:]
        }
    print(df.to_latex(formatters=formatters, escape=False, na_rep="", index=False))

In [None]:
# All runs tagged for the number-of-experts ablation; unlike `runs`, this cell
# applies no finished-state or k filtering.
num_expert_runs = get_runs(tags=("num_expert_ablation",), project="robust-cifar100-resnet-moe")

In [None]:
# Clean test accuracy for the ablation runs (defaults of table_for_runs spelled out).
table_for_runs(num_expert_runs, acc_key="test/acc", baseline_acc=0.7301)

In [None]:
# NOTE(review): adversarial accuracy is compared against 0.7301, the clean-accuracy
# baseline inherited from table_for_runs' default; elsewhere attack/acc tables use
# 1e-4 or None — confirm which baseline is intended here.
table_for_runs(num_expert_runs, acc_key="attack/acc", baseline_acc=0.7301)