In [None]:
from functools import partial
from typing import List

from dotenv import load_dotenv

from src.utils.wandb import get_runs

# Load environment variables from the `.env` file one directory above the notebook
# (presumably the W&B credentials/settings used below — TODO confirm).
load_dotenv("../.env")

In [None]:
import pandas as pd
import wandb


def lmap(*x):
    """Eager variant of the built-in ``map`` that returns a list.

    All positional arguments are forwarded unchanged to ``map``: the first one
    is the callable, the remaining ones are the iterables it is applied to.
    """
    lazy_results = map(*x)
    return [*lazy_results]


# Handle to the W&B API object; the same assignment is repeated in a later cell.
api = wandb.Api()

In [None]:
# Fetch the runs of the project through the project-local `get_runs` helper.
# NOTE(review): only the project name is passed here (no "<entity>/" prefix); how the
# entity is resolved depends on `get_runs` — TODO confirm.
runs = get_runs(project="robust-cifar100-resnet-moe")

In [None]:
# Inspect the summary of the run at index 1 to see which metric keys are logged.
runs[1].summary

In [None]:
# Gather, for every run: its summary metrics, its hyperparameters and its display name.
summary_list = []
config_list = []
name_list = []
for run in runs:
    # `.summary` holds the output keys/values (e.g. accuracies); `._json_dict` is
    # used so that large files are left out.
    summary_list.append(run.summary._json_dict)

    # `.config` holds the hyperparameters; keys starting with "_" are special
    # values and are dropped.
    hyperparameters = {}
    for key, value in run.config.items():
        if key.startswith("_"):
            continue
        hyperparameters[key] = value
    config_list.append(hyperparameters)

    # `.name` is the human-readable name of the run.
    name_list.append(run.name)

# One row per run; each of the three columns holds one of the lists built above.
runs_df = pd.DataFrame({"summary": summary_list, "config": config_list, "name": name_list})
# runs_df.to_csv("project.csv")

In [None]:
# Wrap the per-run summary dicts in a single-column DataFrame and display it.
summary = pd.DataFrame(runs_df["summary"])
summary

In [None]:
# Locate the run named "cifar100-resnet18" and turn its summary dict into a DataFrame.
# NOTE(review): this rebinds `summary`, which the previous cell used for the all-runs frame.
run_idx = name_list.index("cifar100-resnet18")
summary = pd.DataFrame(runs_df["summary"][run_idx])

In [None]:
# Raw "train/acc" entry of the selected run's summary dict.
runs_df["summary"][run_idx]["train/acc"]

In [None]:
# The "values" row of the "train/acc" column in the single-run summary frame.
summary["train/acc"]["values"]

In [None]:
# Names of the runs compared in the first table: the "cifar100-resnet18" run plus
# four "moe4" variants (block/conv x GALRN/CGARN, suffix "-1").
run_names = [
    "cifar100-resnet18",
    "cifar100-resnet18-block-moe4-GALRN-1",
    "cifar100-resnet18-block-moe4-CGARN-1",
    "cifar100-resnet18-conv-moe4-GALRN-1",
    "cifar100-resnet18-conv-moe4-CGARN-1",
]
# Row labels for the table. The order of the two replacements matters: the prefix
# "cifar100-resnet18-" is abbreviated to "*-" first, so only the bare
# "cifar100-resnet18" name is left to be renamed to "resnet18".
short_run_names = [
    rn.replace("cifar100-resnet18-", "*-").replace("cifar100-resnet18", "resnet18")
    for rn in run_names
]

# "free-adv" counterparts of `run_names`, listed in the same order so that the two
# lists can be zipped row by row.
free_adv_run_names = [
    "cifar100-resnet18-free-adv",
    "cifar100-resnet18-free-adv-train-block-moe4-GALRN-1",
    "cifar100-resnet18-free-adv-train-block-moe4-CGARN-1",
    "cifar100-resnet18-free-adv-train-conv-moe4-GALRN-1",
    "cifar100-resnet18-free-adv-train-conv-moe4-CGARN-1",
]


def _summary_for_name(run_name):
    return runs_df["summary"][name_list.index(run_name)]


def lmap(*x):
    """Eager variant of the built-in ``map`` that returns a list.

    All positional arguments are forwarded unchanged to ``map``: the first one
    is the callable, the remaining ones are the iterables it is applied to.
    """
    # NOTE(review): an identical `lmap` is already defined in an earlier cell of this
    # notebook; this second definition simply rebinds the same name.
    lazy_results = map(*x)
    return [*lazy_results]


# Summary dicts of the selected runs, in the same order as the name lists.
summaries = lmap(_summary_for_name, run_names)
free_adv_summaries = lmap(_summary_for_name, free_adv_run_names)

# "test/acc" and "attack/acc" summary values of the `run_names` runs ...
test_accs = [summary["test/acc"] for summary in summaries]
adversarial_accs = [summary["attack/acc"] for summary in summaries]

# ... and the same two metrics for the "free-adv" runs.
fa_test_accs = [summary["test/acc"] for summary in free_adv_summaries]
fa_adversarial_accs = [summary["attack/acc"] for summary in free_adv_summaries]

In [None]:
# Accuracy table with one row per model. Column order follows the zip below:
# test acc of `run_names`, test acc of the free-adv runs, then attack acc of both.
df = pd.DataFrame(
    data=zip(test_accs, fa_test_accs, adversarial_accs, fa_adversarial_accs),
    index=short_run_names,
    columns=["Natural", "Free-Adv-Train", "PGD(20,8,2)", "Free-Adv-Train, PGD(20,8,2)"],
)
# One formatter per column, bound to that column's maximum (presumably to print the
# best value in bold — TODO confirm).
# NOTE(review): `bold_formatter` is neither defined nor imported anywhere in the visible
# notebook; on a fresh kernel this cell would raise NameError unless it comes from
# elsewhere — TODO confirm and add the definition/import.
fmts_max_4f = {
    column: partial(bold_formatter, value=df[column].max(), num_decimals=4)
    for column in df.columns
}
# escape=False keeps whatever markup the formatters emit unescaped in the LaTeX output.
print(df.to_latex(formatters=fmts_max_4f, escape=False))

In [None]:
# Run-name prefixes of the four "moe4" configurations; a "-<k>" suffix is appended
# to each prefix to form the actual run names.
cfgs = [
    "cifar100-resnet18-block-moe4-GALRN",
    "cifar100-resnet18-block-moe4-CGARN",
    "cifar100-resnet18-conv-moe4-GALRN",
    "cifar100-resnet18-conv-moe4-CGARN",
]

# Suffix values "1".."4" as strings; they double as column labels in the tables below.
ks = lmap(str, range(1, 5))


def names_for_cfg(cfg):
    """Build the run names ``"<cfg>-<k>"`` for every ``k`` in the notebook-level ``ks``."""
    # To look at the free-adv-train variants instead, rewrite the prefix first:
    # cfg = cfg.replace("18", "18-free-adv-train")
    names = []
    for k in ks:
        names.append(f"{cfg}-{k}")
    return names


def summaries_for_cfg(cfg):
    """Return the summary entries of all runs generated from ``cfg``.

    The run names come from ``names_for_cfg`` and each one is resolved with
    ``_summary_for_name``; the result keeps the order of the names.
    """
    return [_summary_for_name(run_name) for run_name in names_for_cfg(cfg)]


def accs_for_cfg(cfg, key="attack/acc"):
    """Collect one summary metric across all runs generated from ``cfg``.

    Args:
        cfg: Run-name prefix, expanded by ``summaries_for_cfg``.
        key: Summary key to read from every run (``"attack/acc"`` by default).

    Returns:
        List with the value stored under ``key`` for each run, in run order.
        A missing key propagates whatever lookup error the summary raises.
    """
    collected = []
    for run_summary in summaries_for_cfg(cfg):
        collected.append(run_summary[key])
    return collected


# One list per configuration in `cfgs`, each holding one value per entry of `ks`:
# first the "attack/acc" values (the default key), then the "test/acc" values.
accs = lmap(accs_for_cfg, cfgs)
accs_test = lmap(partial(accs_for_cfg, key="test/acc"), cfgs)

In [None]:
# LaTeX table of "attack/acc": first row is the "cifar100-resnet18" run (value in the
# "baseline" column only), followed by one row per configuration with one value per k.
baseline = _summary_for_name("cifar100-resnet18")["attack/acc"]
columns = ["baseline", *ks]
# The baseline row fills only the first column; configuration rows leave it empty.
# NOTE(review): the four `None`s hard-code len(ks) == 4.
data = [[baseline, None, None, None, None], *([None, *d] for d in accs)]
index = ["cifar100-resnet18", *cfgs]
# Same abbreviation scheme as `short_run_names`: prefix -> "*-", bare name -> "resnet18".
index = [
    rn.replace("cifar100-resnet18-", "*-").replace("cifar100-resnet18", "resnet18") for rn in index
]

df = pd.DataFrame(
    data=data,
    index=index,
    columns=columns,
)
# Every column gets the same formatter, bound to a value slightly above the baseline
# (presumably so that only entries beating the baseline are highlighted — TODO confirm).
# NOTE(review): `bold_formatter` is neither defined nor imported in the visible notebook.
fmts_max_4f = {
    column: partial(bold_formatter, value=baseline + 1e-4, num_decimals=4) for column in df.columns
}
# na_rep="" renders the missing cells as empty table entries.
print(df.to_latex(formatters=fmts_max_4f, escape=False, na_rep=""))

In [None]:
# Same table as in the previous cell, but for "test/acc" (uses `accs_test`).
# NOTE(review): apart from the metric key and the data list this cell is a copy of the
# previous one — a candidate for a shared helper.
baseline = _summary_for_name("cifar100-resnet18")["test/acc"]
columns = ["baseline", *ks]
# The baseline row fills only the first column; configuration rows leave it empty.
data = [[baseline, None, None, None, None], *([None, *d] for d in accs_test)]
index = ["cifar100-resnet18", *cfgs]
# Same abbreviation scheme as `short_run_names`: prefix -> "*-", bare name -> "resnet18".
index = [
    rn.replace("cifar100-resnet18-", "*-").replace("cifar100-resnet18", "resnet18") for rn in index
]

df = pd.DataFrame(
    data=data,
    index=index,
    columns=columns,
)
# NOTE(review): `bold_formatter` is neither defined nor imported in the visible notebook.
fmts_max_4f = {
    column: partial(bold_formatter, value=baseline + 1e-4, num_decimals=4) for column in df.columns
}
print(df.to_latex(formatters=fmts_max_4f, escape=False, na_rep=""))

In [None]:
import matplotlib.pyplot as plt

# Plot the transposed table: one line per table row, with the table columns on the x-axis.
# NOTE(review): `df` is whichever table cell ran last (the "test/acc" one when run in order).
df.T.plot()
plt.savefig("ks_plot.png")

In [None]:
def run_to_pred_table(run):
    """Fetch version ``v0`` of the ``prediction_table`` artifact logged by ``run``.

    The artifact name is derived from ``run.id`` as
    ``run-<id>-prediction_table:v0`` and passed to ``wandb.use_artifact``.
    """
    # NOTE(review): presumably requires an active `wandb.init()` run — TODO confirm.
    artifact_name = f"run-{run.id}-prediction_table:v0"
    return wandb.use_artifact(artifact_name)

In [None]:
# Try the helper on the run at index 1.
# NOTE(review): `wandb.init()` is only called in a later cell — confirm this works when
# the notebook is executed top to bottom.
run_to_pred_table(runs[1])

In [None]:
# Fixed experts

In [None]:
# Select the evaluation run of the PGD-adversarially-trained block-moe16 model by name
# and turn its summary dict into a DataFrame (rebinds `run_idx` and `summary`).
run_idx = name_list.index("evaluate-cifar100-resnet18-pgd-adv-train-block-moe16-CGARN-1")
summary = pd.DataFrame(runs_df["summary"][run_idx])

In [None]:
# Quick look at the "performance_plot_PGD-20-8-2_table" column of the summary frame.
summary["performance_plot_PGD-20-8-2_table"].plot()

In [None]:
# Same summary entry, read from the run object itself instead of the DataFrame.
# NOTE(review): relies on `runs` and `name_list` sharing the same order.
pp_table = runs[run_idx].summary["performance_plot_PGD-20-8-2_table"]

In [None]:
# Materialise the entry as a plain dict to inspect its keys and values.
dict(pp_table.items())

In [None]:
# Load the exported fixed-expert CSV for the 16-expert model.
# NOTE(review): this rebinds `df`, and the frame is not read again in the visible
# notebook before `df` is reassigned — possibly a leftover.
df = pd.read_csv("./cifar100_fixed_experts_robust/wandb_fixed_expert_pgd_perf_16.csv")


def load_fixed_expert_file(fp):
    """Read the ``"Metric"`` column of a CSV file and split off its last entry.

    Args:
        fp: Path (or anything ``pandas.read_csv`` accepts) of the CSV file.

    Returns:
        Tuple ``(all_experts_acc, fe_accs)``: the last value of the column and
        an array with all preceding values. Presumably the last row holds the
        accuracy with all experts active — TODO confirm against the export.

    Raises:
        IndexError: If the column has no rows.
    """
    metric_values = pd.read_csv(fp)["Metric"].values
    return metric_values[-1], metric_values[:-1]


def load_expert_files(
    num_experts: list, prefix="cifar100_fixed_experts_robust/wandb_fixed_expert_pgd_perf"
):
    """Load one fixed-expert CSV per expert count.

    For every entry ``ne`` of ``num_experts`` the file ``./<prefix>_<ne>.csv`` is
    read with ``load_fixed_expert_file``.

    Returns:
        Tuple ``(accs, fe_accs_list)`` of two lists aligned with ``num_experts``:
        the last metric of each file, and the remaining metrics of each file.
    """
    per_file = [load_fixed_expert_file(f"./{prefix}_{ne}.csv") for ne in num_experts]
    accs = [all_experts_acc for all_experts_acc, _ in per_file]
    fe_accs_list = [fe_accs for _, fe_accs in per_file]
    return accs, fe_accs_list


def match_fixed_expert_accuracies(num_experts_list, fe_accs_list):
    """Pair every fixed-expert accuracy with the expert count of its model.

    Args:
        num_experts_list: Expert counts, one per model.
        fe_accs_list: For each model (same order as ``num_experts_list``), an
            iterable with the accuracies of its individual fixed experts.

    Returns:
        Float array of shape ``(n, 2)`` with one row per fixed expert: column 0
        is the expert count of the model, column 1 the accuracy. The array is
        2-D even when there are no accuracies at all, so callers can always
        slice it with ``[:, 0]`` / ``[:, 1]``.
    """
    fe_data = [
        (num_experts, fe_acc)
        for num_experts, fe_accs in zip(num_experts_list, fe_accs_list)
        for fe_acc in fe_accs
    ]
    # Without the reshape an empty `fe_data` would become a 1-D array of shape (0,),
    # which breaks the column slicing done by the callers.
    return np.array(fe_data, dtype=float).reshape(-1, 2)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt


def plot_fixed_experts(
    num_experts_list,
    accs,
    fe_data_np,
    fe_data_brown_np=None,
    show_accs=True,
    show_fe_scatter=True,
    show_sota=False,
    ylabel="Adversarial Accuracy",
    baseline=0.178,
):
    """Draw model accuracy against the number of experts on the current axes.

    Args:
        num_experts_list: Expert counts; used as x-ticks and as x-values of the
            accuracy line.
        accs: Accuracy of the full model for each entry of ``num_experts_list``.
        fe_data_np: ``(n, 2)`` array of fixed-expert points (column 0: expert
            count, column 1: accuracy). Drawn in brown and labelled
            "Fixed Expert (robust)".
        fe_data_brown_np: Optional second ``(n, 2)`` array, drawn in the default
            colour and labelled "Fixed Expert". NOTE(review): despite the
            parameter name, the brown colour is used for ``fe_data_np``.
        show_accs: Keep the accuracy line and list it in the legend.
        show_fe_scatter: Draw the fixed-expert scatter points.
        show_sota: Shade the band between 0.25 and 0.27 as "ResNet18 SOTA".
        ylabel: Label of the y-axis.
        baseline: Height of the dashed green baseline.

    The x-positions of the scatter points are jittered by up to +/-10% so that
    points of the same model do not overlap. The jitter is deterministic.
    """
    # Private generator with a fixed seed: it yields the same jitter as seeding the
    # global NumPy RNG with 1, but leaves the global random state untouched.
    rng = np.random.RandomState(1)

    def _jittered_scatter(points, **scatter_kwargs):
        """Scatter a copy of ``points`` with jittered x and return the new artist."""
        jittered = np.copy(points)
        jittered[:, 0] += jittered[:, 0] * rng.uniform(-0.1, 0.1, size=jittered.shape[0])
        ax = sns.scatterplot(x=jittered[:, 0], y=jittered[:, 1], **scatter_kwargs)
        # `scatterplot` returns the Axes; the collection it just added is the last one.
        return ax.collections[-1]

    plt.xticks(num_experts_list)
    plt.ylabel(ylabel)
    plt.xlabel("Number of Experts")

    # Legend entries are collected as explicit (artist, label) pairs so that every
    # label is attached to the artist it describes, independent of the order in
    # which matplotlib enumerates lines, patches and collections.
    handles = [plt.axhline(baseline, color="green", linestyle="--")]
    labels = ["ResNet18 Baseline"]

    # The line is always drawn first (it influences the data limits) and removed
    # again when it should not be shown.
    (accs_line,) = plt.plot(num_experts_list, accs, marker="x")
    if show_accs:
        handles.append(accs_line)
        labels.append("ResNet18-BlockMoE; k=1")
    else:
        accs_line.remove()

    if show_fe_scatter and len(fe_data_np) > 0:
        handles.append(_jittered_scatter(fe_data_np, color="brown"))
        labels.append("Fixed Expert (robust)")

    if show_fe_scatter and fe_data_brown_np is not None and len(fe_data_brown_np) > 0:
        handles.append(_jittered_scatter(fe_data_brown_np))
        labels.append("Fixed Expert")

    if show_sota:
        handles.append(plt.axhspan(0.25, 0.27, color="red", linestyle="--", alpha=0.3))
        labels.append("ResNet18 SOTA")

    plt.legend(handles, labels)
    # No data is passed: this only switches the x-axis to a base-2 log scale.
    plt.semilogx(base=2)

In [None]:
import numpy as np

# File-name prefixes of the exported fixed-expert CSVs (PGD and natural accuracy);
# the expert count and ".csv" are appended by `load_expert_files`.
prefix = "cifar100_fixed_experts_robust/wandb_fixed_expert_pgd_perf"
prefix_natural = "cifar100_fixed_experts_robust/wandb_fixed_expert_natural_perf"
# Prefix of the figure files written by the plotting cell below.
figure_prefix = "cifar_robust"

# Expert counts for which a CSV export is expected to exist.
num_experts_list = [2, 4, 8, 16, 32]
# NOTE(review): `accs` is rebound here; earlier cells used it for the per-configuration
# "attack/acc" lists.
accs, fe_accs_list = load_expert_files(num_experts_list, prefix=prefix)
accs_natural, fe_accs_natural_list = load_expert_files(num_experts_list, prefix=prefix_natural)

# Flatten to (expert count, accuracy) rows, one per fixed expert.
fe_data_np = match_fixed_expert_accuracies(num_experts_list, fe_accs_list)

fe_data_natural_np = match_fixed_expert_accuracies(num_experts_list, fe_accs_natural_list)

In [None]:
# For every fixed-expert row, look up the accuracy of the full model with the same
# expert count (column 0 holds the count as a float; `list.index` still matches the ints).
accs_np = np.array([accs[num_experts_list.index(x)] for x in fe_data_np[:, 0]])
# True where the fixed expert is at least as accurate as the full model on the PGD data.
accs_up = fe_data_np[:, 1] >= accs_np
fe_data_up = fe_data_np[accs_up]
fe_data_down = fe_data_np[~accs_up]

# The mask computed on the PGD data is reused to split the natural-accuracy rows.
# NOTE(review): assumes both arrays list the same fixed experts in the same order — TODO confirm.
fe_data_natural_up = fe_data_natural_np[accs_up]
fe_data_natural_down = fe_data_natural_np[~accs_up]

In [None]:
# Figure 1: accuracy under attack (default y-label and baseline of `plot_fixed_experts`).
# The "up" rows are passed as the first scatter set, the "down" rows as the second.
plot_fixed_experts(
    num_experts_list,
    accs,
    fe_data_up,
    fe_data_brown_np=fe_data_down,
    show_fe_scatter=True,
    show_sota=False,
)
# Save before `plt.show()` so the file is written from the figure that was just drawn.
plt.savefig(f"{figure_prefix}_adv_fixed_expert_plot.png")
plt.show()

# Figure 2: natural accuracy, split by the same mask, with its own baseline value.
plot_fixed_experts(
    num_experts_list,
    accs_natural,
    fe_data_natural_up,
    fe_data_brown_np=fe_data_natural_down,
    show_fe_scatter=True,
    show_sota=False,
    ylabel="Accuracy",
    baseline=0.5232,
)
plt.savefig(f"{figure_prefix}_natural_fixed_expert_plot.png")

In [None]:
# Read the "performance_plot_PGD-20-8-2_table" summary entry of the run at index 1
# (rebinds `pp_table`).
pp_table = runs[1].summary["performance_plot_PGD-20-8-2_table"]

In [None]:
import wandb

# Start a W&B run in the project; the `wandb.use_artifact` calls in the following
# cells presumably need an active run — TODO confirm.
# NOTE(review): no matching `finish()` call is visible in this notebook.
run = wandb.init(project="robust-cifar100-resnet-moe")

In [None]:
# Fetch one logged table by hand: the artifact name uses the table name without dashes
# ("PGD2082"), while the file inside the artifact keeps them ("PGD-20-8-2").
# NOTE(review): the run id "2meisvnp" is hard-coded.
my_table = wandb.use_artifact("run-2meisvnp-loss_plot_PGD2082_table:v0").get(
    "loss_plot_PGD-20-8-2_table.table.json"
)

In [None]:
# Display the fetched table object.
my_table

In [None]:
import pandas as pd


def load_fixed_expert_table(table, column="Metric"):
    """Split one column of a table into its last entry and everything before it.

    Args:
        table: Object exposing ``get_column(name)`` that returns an indexable
            sequence.
        column: Name of the column to read.

    Returns:
        Tuple ``(all_experts_acc, fe_accs)``: the last value of the column and
        the slice of all preceding values. Presumably the last row is the
        accuracy with all experts active — TODO confirm.
    """
    column_values = table.get_column(column)
    return column_values[-1], column_values[:-1]


def get_table(artifact_run, table_names):
    """Fetch the first available logged table of a run.

    Args:
        artifact_run: Run object; only its ``id`` attribute is read.
        table_names: A single table name or an iterable of candidate names,
            tried in order.

    Returns:
        The object returned by ``Artifact.get`` for the first candidate whose
        lookup succeeds.

    Raises:
        ValueError: If every candidate failed. The exception of the last failed
            attempt is attached as ``__cause__``.
    """
    if isinstance(table_names, str):
        table_names = [table_names]
    run_id = artifact_run.id
    last_error = None
    for table_name in table_names:
        # The artifact name is the table name with all dashes removed, while the
        # file inside the artifact keeps the original table name.
        short_table_name = table_name.replace("-", "")
        try:
            return wandb.use_artifact(f"run-{run_id}-{short_table_name}:v0").get(
                f"{table_name}.table.json"
            )
        except Exception as e:
            # Best effort: any failure just moves on to the next candidate name.
            last_error = e
            print(f"Ignoring error: {e}")
    raise ValueError("None of the given tables could be found!") from last_error


def load_expert_accs(runs: list, table_names="loss_plot_PGD-20-8-2_table", column="Metric"):
    """Load the fixed-expert metrics of several runs from their logged tables.

    Args:
        runs: Runs to read; each one is passed to ``get_table``.
        table_names: Table name, or candidate names, forwarded to ``get_table``.
        column: Table column forwarded to ``load_fixed_expert_table``.

    Returns:
        Tuple ``(accs, fe_accs_list)`` of two lists aligned with ``runs``: the
        last column value of each table, and the remaining values of each table.
    """
    per_run = [
        load_fixed_expert_table(get_table(run, table_names), column=column) for run in runs
    ]
    accs = [all_experts_acc for all_experts_acc, _ in per_run]
    fe_accs_list = [fe_accs for _, fe_accs in per_run]
    return accs, fe_accs_list


def fixed_expert_performance_plots(
    runs, figure_prefix, baseline_natural, baseline_attacked, num_experts=None
):
    """Create and save the fixed-expert plots (under attack and natural) for a set of runs.

    Args:
        runs: Evaluation runs, one per expert count and in the same order as the
            expert counts.
        figure_prefix: Path prefix of the two PNG files that are written
            (``<prefix>_adv_fixed_expert_plot.png`` and
            ``<prefix>_natural_fixed_expert_plot.png``). Missing parent
            directories are created.
        baseline_natural: Baseline height for the natural-accuracy plot.
        baseline_attacked: Baseline height for the under-attack plot.
        num_experts: Expert count of each run. Defaults to the notebook-level
            ``num_experts_list``, which is what this function always used before.

    Raises:
        ValueError: If the number of expert counts differs from the number of runs.
    """
    from pathlib import Path

    if num_experts is None:
        # Backward-compatible default: fall back to the notebook-level list.
        num_experts = num_experts_list
    num_experts = list(num_experts)
    # Fail early and clearly; a mismatch would otherwise pair metrics with the wrong
    # expert count or fail later inside the plotting code.
    if len(num_experts) != len(runs):
        raise ValueError(
            f"Got {len(runs)} runs but {len(num_experts)} expert counts; they must match."
        )

    # For each metric the "performance_plot" table is tried first, then the "loss_plot" one.
    accs_list, fe_accs_list = load_expert_accs(
        runs, table_names=("performance_plot_natural_table", "loss_plot_natural_table")
    )
    accs_robust_list, fe_accs_robust_list = load_expert_accs(
        runs, table_names=("performance_plot_PGD-20-8-2_table", "loss_plot_PGD-20-8-2_table")
    )

    # Flatten to (expert count, accuracy) rows, one per fixed expert.
    fe_data_np = match_fixed_expert_accuracies(num_experts, fe_accs_list)
    fe_data_robust_np = match_fixed_expert_accuracies(num_experts, fe_accs_robust_list)

    # Full-model accuracy under attack, repeated for every fixed-expert row of that model.
    accs_robust_np = np.array(
        [accs_robust_list[num_experts.index(x)] for x in fe_data_robust_np[:, 0]]
    )
    # True where a fixed expert is at least as accurate under attack as the full model.
    accs_robust_up = fe_data_robust_np[:, 1] >= accs_robust_np
    fe_data_robust_up = fe_data_robust_np[accs_robust_up]
    fe_data_robust_down = fe_data_robust_np[~accs_robust_up]

    # The natural-accuracy rows are split with the mask computed on the attacked data.
    # NOTE(review): assumes both tables list the fixed experts in the same order.
    fe_data_up = fe_data_np[accs_robust_up]
    fe_data_down = fe_data_np[~accs_robust_up]

    # `savefig` does not create directories, so make sure the target folder exists.
    Path(figure_prefix).parent.mkdir(parents=True, exist_ok=True)

    plot_fixed_experts(
        num_experts,
        accs_robust_list,
        fe_data_robust_up,
        fe_data_brown_np=fe_data_robust_down,
        show_fe_scatter=True,
        show_sota=False,
        baseline=baseline_attacked,
    )
    plt.savefig(f"{figure_prefix}_adv_fixed_expert_plot.png")
    plt.show()

    plot_fixed_experts(
        num_experts,
        accs_list,
        fe_data_up,
        fe_data_brown_np=fe_data_down,
        show_fe_scatter=True,
        show_sota=False,
        ylabel="Accuracy",
        baseline=baseline_natural,
    )
    plt.savefig(f"{figure_prefix}_natural_fixed_expert_plot.png")
    plt.show()

In [None]:
# Handle to the W&B API object, used in the next cell to list the project's runs.
# NOTE(review): the same assignment already appears in an earlier cell.
api = wandb.Api()

In [None]:
# All runs of the project, addressed as "<entity>/<project>"; `filter_sorted` reads this global.
# NOTE(review): the entity is hard-coded, and the annotation names the SDK run class while
# the value comes from the API object — the element type may differ, TODO confirm.
all_runs: List[wandb.wandb_sdk.wandb_run.Run] = api.runs("ditschuk/robust-cifar100-resnet-moe")

In [None]:
def select(names, tag, run):
    """Tell whether ``run`` is one of ``names`` and carries ``tag``.

    The tags are only inspected when the name matches.
    """
    if run.name not in names:
        return False
    return tag in run.tags


def filter_sorted(names, tag):
    """Return the runs called ``names`` that carry ``tag``, ordered like ``names``.

    Candidates are taken from the notebook-level ``all_runs`` and tested with
    ``select``. If several matching runs share a name, the one encountered last
    wins. A name without a matching run raises ``KeyError``.
    """
    matching_by_name = {}
    for candidate in all_runs:
        if select(names, tag, candidate):
            matching_by_name[candidate.name] = candidate
    return [matching_by_name[name] for name in names]

In [None]:
# Fixed-expert plots for the evaluation runs without "adv-train" in their name,
# once per tag ("entropy", "switch").
# `num_experts_list` is also read as a global inside `fixed_expert_performance_plots`.
num_experts_list = [2, 4, 8, 16, 32]
# NOTE(review): rebinds `run_names`, which an earlier cell used for the table rows.
run_names = [f"evaluate-cifar100-resnet18-block-moe{ne}-CGARN-1" for ne in num_experts_list]
for tag in ("entropy", "switch"):
    # Runs come back in the order of `run_names`, i.e. aligned with `num_experts_list`.
    natural_runs = filter_sorted(run_names, tag)
    # NOTE(review): the figures go into the "fixed_expert_plots/" folder — confirm it exists.
    figure_prefix = f"fixed_expert_plots/cifar_{tag}"
    fixed_expert_performance_plots(
        natural_runs, figure_prefix, baseline_natural=0.7301, baseline_attacked=1e-4
    )

In [None]:
# Same plots for the "pgd-adv-train" evaluation runs; the baseline values match the
# defaults/constants used in the CSV-based plotting cell above.
robust_run_names = [
    f"evaluate-cifar100-resnet18-pgd-adv-train-block-moe{ne}-CGARN-1" for ne in num_experts_list
]
for tag in ("entropy", "switch"):
    # Runs come back in the order of `robust_run_names`, i.e. aligned with `num_experts_list`.
    robust_runs = filter_sorted(robust_run_names, tag)
    figure_prefix = f"fixed_expert_plots/cifar_robust_{tag}"
    fixed_expert_performance_plots(
        robust_runs, figure_prefix, baseline_natural=0.5232, baseline_attacked=0.178
    )

In [None]:
# Restrict to the runs tagged "adv-train-ablation" (exact tag semantics depend on
# `get_runs` — TODO confirm). NOTE(review): this rebinds `runs`; `runs_df`/`name_list`
# built earlier still describe the previous, unfiltered list.
runs = get_runs(project="robust-cifar100-resnet-moe", tags={"adv-train-ablation"})

In [None]:
def parse_name(name):
    """Map a run name to the label of its adversarial training method.

    The substrings are checked in a fixed order and the first hit wins:
    ``"pgd"`` -> ``"PGD-7"``, ``"fast"`` -> ``"Fast"``, ``"free"`` -> ``"Free"``.
    Returns ``None`` when the name contains none of them.
    """
    markers = (("pgd", "PGD-7"), ("fast", "Fast"), ("free", "Free"))
    for marker, method in markers:
        if marker in name:
            return method
    return None


# Per-run columns of the ablation table, all in the order of `runs`.
names = [run.name for run in runs]
# NOTE(review): `parse_name` yields None for names without "pgd"/"fast"/"free".
adversarial_methods = [parse_name(name) for name in names]
# A run counts as the MoE architecture when its config has a "model.model.num_experts" key.
architectures = [
    "ResNetBlockMoe" if "model.model.num_experts" in run.config else "ResNet-18" for run in runs
]
# NOTE(review): rebinds `accs`, which earlier cells used for other data.
accs = [run.summary["test/acc"] for run in runs]
adv_accs = [run.summary["attack/acc"] for run in runs]

In [None]:
# One row per run: (method, architecture, test accuracy, attack accuracy).
data = zip(adversarial_methods, architectures, accs, adv_accs)
# Two stable sorts in a row: the result is ordered by architecture first and, within
# the same architecture, by training method.
data = sorted(data, key=lambda o: o[0])
data = sorted(data, key=lambda o: o[1])
columns = ["Adversarial Training Method", "Architecture", "Accuracy", "Accuracy against PGD"]
df = pd.DataFrame(
    data=data,
    # index=index,
    columns=columns,
)
# index=False drops the row index from the LaTeX output.
print(df.to_latex(escape=False, na_rep="", index=False))

# print(df.to_latex(escape=False, na_rep="", index=False))

In [None]:
# Re-fetch the runs without the tag filter (rebinds `runs` to the unfiltered project list).
runs = get_runs(project="robust-cifar100-resnet-moe")