In [1]:
import os
# Switch the kernel's working directory and echo it for confirmation.
# Presumably this makes the relative references used further down (the `src`
# package import, "../.env", "fixed_expert_plots/...") resolve from the project
# checkout — TODO confirm.
# NOTE(review): hard-coded absolute path; has to be adjusted per machine.
os.chdir('/disk/...')  
print(os.getcwd())

/disk/..

In [2]:
from functools import partial
from typing import List

from dotenv import load_dotenv
import numpy as np

from src.utils.wandb import get_runs

# Read environment variables from the `.env` file one level above the current
# working directory.
# NOTE(review): the recorded cell output is `False`, which suggests nothing was
# loaded from this path — verify the location relative to the chdir target.
load_dotenv("../.env")

False

In [3]:
import pandas as pd
import wandb


def lmap(*x):
    """Eager ``map``: apply ``x[0]`` across the iterables in ``x[1:]`` and return a list."""
    lazy_result = map(*x)
    return [*lazy_result]


# Client for the W&B public API; a later cell uses it to list the project's runs.
api = wandb.Api()

In [4]:
def match_fixed_expert_accuracies(num_experts_list, fe_accs_list):
    """Flatten per-run fixed-expert accuracies into ``(num_experts, accuracy)`` rows.

    Args:
        num_experts_list: One expert count per run.
        fe_accs_list: One iterable of fixed-expert accuracies per run, aligned
            position-by-position with ``num_experts_list`` (extra entries on
            either side are dropped by ``zip``).

    Returns:
        A float ``numpy`` array of shape ``(n, 2)``: column 0 is the expert
        count of the run an accuracy came from, column 1 the accuracy itself.
        The shape is ``(0, 2)`` when there are no accuracies at all, so callers
        can always index ``[:, 0]`` / ``[:, 1]``.
    """
    fe_data = [
        (num_experts, fe_acc)
        for num_experts, fe_accs in zip(num_experts_list, fe_accs_list)
        for fe_acc in fe_accs
    ]
    # Without the reshape an empty list becomes a 1-D array of shape (0,), which
    # breaks the column indexing done by the plotting code.
    return np.array(fe_data, dtype=float).reshape(-1, 2)

In [5]:
import seaborn as sns
import matplotlib.pyplot as plt


def _jittered_points(points, rng):
    """Return ``(x, y)`` for an ``(n, 2)`` array, with each x shifted by up to +/-10 %."""
    points = np.copy(points)
    noise = rng.uniform(-0.1, 0.1, size=points.shape[0])
    return points[:, 0] + points[:, 0] * noise, points[:, 1]


def plot_fixed_experts(
    num_experts_list,
    accs,
    fe_data_np,
    fe_data_brown_np=None,
    show_accs=True,
    show_fe_scatter=True,
    show_sota=False,
    ylabel="Adversarial Accuracy",
    baseline=0.178,
):
    """Draw accuracy against number of experts on the current matplotlib figure.

    Args:
        num_experts_list: x positions (expert counts) of the ``accs`` line;
            also used as the requested x ticks.
        accs: y values of the line, one per entry of ``num_experts_list``.
        fe_data_np: ``(n, 2)`` array of ``(num_experts, accuracy)`` points,
            scattered in brown and labelled "Fixed Expert (robust)".
        fe_data_brown_np: optional second ``(n, 2)`` array, scattered in the
            next default color and labelled "Fixed Expert".
        show_accs: keep the ``accs`` line (and its legend entry) when true.
        show_fe_scatter: draw the two scatter groups when true.
        show_sota: shade the 0.25-0.27 band labelled "ResNet18 SOTA".
        ylabel: y-axis label.
        baseline: y position of the dashed green "ResNet18 Baseline" line.

    The function only draws; saving or showing the figure is up to the caller.
    """
    # Private generator seeded like the previous `np.random.seed(1)` call: the
    # jitter is unchanged, but the global NumPy random state is no longer reset
    # as a side effect of plotting.
    rng = np.random.RandomState(1)

    # NOTE(review): the scale change at the end of this function may reset the
    # tick locator installed here — confirm the ticks render as intended.
    plt.xticks(num_experts_list)
    # plt.xlim(1.5,max(num_experts_list)+1)
    plt.ylabel(ylabel)
    plt.xlabel("Number of Experts")

    # Every artist carries its own label so the legend cannot get out of sync
    # with the order in which matplotlib enumerates lines, patches and
    # collections.
    plt.axhline(baseline, color="green", linestyle="--", label="ResNet18 Baseline")
    (accs_line,) = plt.plot(
        num_experts_list, accs, marker="x", label="ResNet18-BlockMoE; k=1"
    )
    if not show_accs:
        # Plot-then-remove (rather than skipping the plot) keeps the color
        # cycle position identical in both modes.
        accs_line.remove()

    if show_fe_scatter and len(fe_data_np) > 0:
        xs, ys = _jittered_points(fe_data_np, rng)
        sns.scatterplot(x=xs, y=ys, color="brown", label="Fixed Expert (robust)")

    if show_fe_scatter and fe_data_brown_np is not None and len(fe_data_brown_np) > 0:
        xs, ys = _jittered_points(fe_data_brown_np, rng)
        sns.scatterplot(x=xs, y=ys, label="Fixed Expert")

    if show_sota:
        plt.axhspan(
            0.25, 0.27, color="red", linestyle="--", alpha=0.3, label="ResNet18 SOTA"
        )

    plt.legend()
    plt.semilogx(base=2)
    # plt.ylim(0.05, 0.3)

In [7]:
import wandb

# Open a W&B run in the project. `get_table` below fetches artifacts through
# `wandb.use_artifact`, which presumably needs an active run — TODO confirm.
run = wandb.init(project="robust-cifar100-resnet-moe")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfanhaixi[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import pandas as pd
import torch
import json





# def get_table(artifact_run, table_names):
#     if isinstance(table_names, str):
#         table_names = [table_names]
#     id = artifact_run.id
#     for table_name in table_names:
#         try:
#             short_table_name = table_name.replace("-", "")
#             my_table = wandb.use_artifact(f"run-{id}-{short_table_name}:v0").get(
#                 f"{table_name}.table.json"
#             )
#             return my_table
#         except Exception as e:
#             print(f"Ignoring error: {e}")
#     raise ValueError("None of the given tables could be found!")
class SimpleTable:
    """Bare-bones table: a list of column names plus a list of row sequences."""

    def __init__(self, columns, data):
        """Store the column names and the rows exactly as given (no copying)."""
        self.columns, self.data = columns, data

    def get_column(self, col_name):
        """Return the values under ``col_name``, one per row, in row order.

        Raises:
            ValueError: if ``col_name`` is not one of ``self.columns``.
        """
        position = self.columns.index(col_name)
        values = []
        for record in self.data:
            values.append(record[position])
        return values

def get_table(artifact_run, table_names):
    """Download the first available logged table of a run as a ``SimpleTable``.

    Args:
        artifact_run: object whose ``id`` attribute identifies the W&B run.
        table_names: one table name or an iterable of candidate names, tried
            in order.

    Returns:
        ``SimpleTable`` built from the ``"columns"`` and ``"data"`` entries of
        the downloaded ``<table_name>.table.json`` file.

    Raises:
        ValueError: if none of the candidates could be loaded; the last
            underlying error is attached as ``__cause__``.
    """
    if isinstance(table_names, str):
        table_names = [table_names]
    run_id = artifact_run.id

    last_error = None
    for table_name in table_names:
        # The artifact name is tried with dashes stripped first, which is what
        # the earlier version of this helper used; W&B presumably sanitizes the
        # logged key that way — TODO confirm. The raw name stays as a fallback.
        # `dict.fromkeys` de-duplicates while keeping the order.
        short_name = table_name.replace("-", "")
        for artifact_name in dict.fromkeys((short_name, table_name)):
            artifact_ref = f"run-{run_id}-{artifact_name}:latest"
            try:
                artifact = wandb.use_artifact(artifact_ref)
                local_dir = artifact.download()
                # The JSON file inside the artifact keeps the original name.
                json_path = os.path.join(local_dir, f"{table_name}.table.json")
                with open(json_path, "r") as f:
                    raw = json.load(f)
                return SimpleTable(columns=raw["columns"], data=raw["data"])
            except Exception as e:
                # Best effort: report the failure and move on to the next candidate.
                last_error = e
                print(f"Ignoring error: {e}")

    raise ValueError("None of the given tables could be found!") from last_error


def load_expert_accs(runs: list, table_names="loss_plot_PGD-20-8-2_table", column="Metric"):
    """Collect the all-experts and fixed-expert accuracies for each run.

    For every run, the table is fetched with ``get_table`` and handed to
    ``load_fixed_expert_table`` (defined outside this cell), whose two return
    values are accumulated separately.

    Returns:
        ``(accs, fe_accs_list)``: two lists with one entry per run, in the
        order of ``runs``.
    """
    accs, fe_accs_list = [], []
    for current_run in runs:
        overall_acc, per_expert_accs = load_fixed_expert_table(
            get_table(current_run, table_names), column=column
        )
        accs.append(overall_acc)
        fe_accs_list.append(per_expert_accs)
        # Progress trace: prints the accumulated list after every run.
        print(f"fe_accs_list:", fe_accs_list)
    return accs, fe_accs_list


# def fixed_expert_performance_plots(runs, figure_prefix, baseline_natural, baseline_attacked):
#     accs_list, fe_accs_list = load_expert_accs(
#         runs, table_names=("performance_plot_natural_table")
#     )
#     accs_robust_list, fe_accs_robust_list = load_expert_accs(
#         runs, table_names=("performance_plot_PGD-20-8-2_table")
#     )
#     #%%
#     fe_data_np = match_fixed_expert_accuracies(num_experts_list, fe_accs_list)
#     fe_data_robust_np = match_fixed_expert_accuracies(num_experts_list, fe_accs_robust_list)
#     #%%
#     accs_robust_np = np.array(
#         [accs_robust_list[num_experts_list.index(x)] for x in fe_data_robust_np[:, 0]]
#     )
#     accs_robust_up = fe_data_robust_np[:, 1] >= accs_robust_np
#     fe_data_robust_up = fe_data_robust_np[accs_robust_up]
#     fe_data_robust_down = fe_data_robust_np[~accs_robust_up]

#     fe_data_up = fe_data_np[accs_robust_up]
#     fe_data_down = fe_data_np[~accs_robust_up]

#     #%%
#     plot_fixed_experts(
#         num_experts_list,
#         accs_robust_list,
#         fe_data_robust_up,
#         fe_data_brown_np=fe_data_robust_down,
#         show_fe_scatter=True,
#         show_sota=False,
#         baseline=baseline_attacked,
#     )
#     plt.savefig(f"{figure_prefix}_adv_fixed_expert_plot.png")
#     plt.show()

#     plot_fixed_experts(
#         num_experts_list,
#         accs_list,
#         fe_data_up,
#         fe_data_brown_np=fe_data_down,
#         show_fe_scatter=True,
#         show_sota=False,
#         ylabel="Accuracy",
#         baseline=baseline_natural,
#     )
#     plt.savefig(f"{figure_prefix}_natural_fixed_expert_plot.png")
#     plt.show()
def fixed_expert_performance_plots(
    runs, figure_prefix, baseline_natural, baseline_attacked, num_experts=None
):
    """Plot and save adversarial and natural fixed-expert accuracy figures.

    Args:
        runs: W&B runs, one per expert count and in the same order as the
            expert counts.
        figure_prefix: path prefix of the two PNG files; its directory part,
            if any, is created on demand.
        baseline_natural: baseline line for the natural-accuracy figure.
        baseline_attacked: baseline line for the adversarial-accuracy figure.
        num_experts: expert count of each run. When omitted, the notebook-level
            ``num_experts_list`` is used, as before.

    Writes ``<figure_prefix>_adv_fixed_expert_plot.png`` and
    ``<figure_prefix>_natural_fixed_expert_plot.png`` and shows both figures.
    """
    if num_experts is None:
        # Backward-compatible fallback to the notebook-level variable.
        num_experts = num_experts_list
    num_experts = list(num_experts)

    accs_list, fe_accs_list = load_expert_accs(
        runs, table_names="performance_plot_natural_table"
    )
    accs_robust_list, fe_accs_robust_list = load_expert_accs(
        runs, table_names="performance_plot_PGD-20-8-2_table"
    )

    fe_data_np = match_fixed_expert_accuracies(num_experts, fe_accs_list)
    fe_data_robust_np = match_fixed_expert_accuracies(num_experts, fe_accs_robust_list)

    # All-experts robust accuracy of the run each fixed-expert point belongs to.
    accs_robust_np = np.array(
        [accs_robust_list[num_experts.index(x)] for x in fe_data_robust_np[:, 0]]
    )
    # Split the points by whether the fixed expert reaches the all-experts
    # robust accuracy of its run.
    accs_robust_up = fe_data_robust_np[:, 1] >= accs_robust_np
    fe_data_robust_up = fe_data_robust_np[accs_robust_up]
    fe_data_robust_down = fe_data_robust_np[~accs_robust_up]

    # The same robust-based split is applied to the natural points; this
    # assumes both tables list the same fixed experts in the same order.
    fe_data_up = fe_data_np[accs_robust_up]
    fe_data_down = fe_data_np[~accs_robust_up]

    adv_path = f"{figure_prefix}_adv_fixed_expert_plot.png"
    natural_path = f"{figure_prefix}_natural_fixed_expert_plot.png"

    # Both files share one directory. A prefix without a directory part yields
    # "" here, which os.makedirs rejects, so only create real directories.
    out_dir = os.path.dirname(adv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    plot_fixed_experts(
        num_experts,
        accs_robust_list,
        fe_data_robust_up,
        fe_data_brown_np=fe_data_robust_down,
        show_fe_scatter=True,
        show_sota=False,
        baseline=baseline_attacked,
    )
    plt.savefig(adv_path)
    plt.show()

    plot_fixed_experts(
        num_experts,
        accs_list,
        fe_data_up,
        fe_data_brown_np=fe_data_down,
        show_fe_scatter=True,
        show_sota=False,
        ylabel="Accuracy",
        baseline=baseline_natural,
    )
    plt.savefig(natural_path)
    plt.show()

In [9]:
# Public-API client used by the next cell to list runs (the same assignment
# also appears in an earlier cell of this notebook).
api = wandb.Api()

In [10]:
# Runs of the project as returned by the public API; the cells below filter
# this collection by run name and tag.
# NOTE(review): `api.runs` presumably returns the public-API runs collection
# rather than a list of SDK `Run` objects — confirm the annotation.
all_runs: List[wandb.wandb_sdk.wandb_run.Run] = api.runs("robust-cifar100-resnet-moe")

In [11]:
def select(names, tag, run):
    """Return True when ``run.name`` is in ``names`` and ``tag`` is in ``run.tags``."""
    if run.name not in names:
        # Name mismatch: the tags are not inspected at all.
        return False
    return tag in run.tags


def filter_sorted(names, tag):
    """Return the runs in ``all_runs`` matching ``names`` and ``tag``, in ``names`` order.

    When several matching runs share a name, the one seen last in ``all_runs``
    is kept. Raises ``KeyError`` if a name in ``names`` has no matching run.
    """
    matched_by_name = {}
    for candidate in all_runs:
        if select(names, tag, candidate):
            matched_by_name[candidate.name] = candidate
    return [matched_by_name[wanted] for wanted in names]



In [12]:
# num_experts_list = [2,8,16,32]
# names = [f"evaluate-cifar100-resnet18-block-moe{ne}-GALRN-1-switch" for ne in num_experts_list]
# for run in all_runs:
#     if run.name in names:
#         print(run.name, run.tags)


In [128]:
# wandb.run=="evaluate-cifar100-resnet18-block-moe2-GALRN-1-switch"
# print(wandb.run.id)
# table_names=("performance_plot_natural_table", "loss_plot_natural_table")
# table = get_table(run, table_names)
# print(table)

In [6]:
# Requires `all_runs` from the cell that calls `api.runs(...)`; run that cell first.
num_experts_list = [2,4]
run_names = [f"evaluate-cifar100-resnet18-block-moe{ne}-GALRN-1-switch" for ne in num_experts_list]
# for tag in ("entropy", "switch"):
#     natural_runs = filter_sorted(run_names, tag)
#     figure_prefix = f"fixed_expert_plots/cifar_{tag}"
#     fixed_expert_performance_plots(
#         natural_runs, figure_prefix, baseline_natural=0.7301, baseline_attacked=1e-4
#     )

# The plotting helper pairs runs with `num_experts_list` by position, so the
# runs must come out in `run_names` order (one per name), not in whatever order
# the API lists them. If several runs share a name, the last one listed wins.
runs_by_name = {run.name: run for run in all_runs if run.name in run_names}
missing_run_names = [name for name in run_names if name not in runs_by_name]
if missing_run_names:
    raise ValueError(f"No W&B run found for: {missing_run_names}")
natural_runs = [runs_by_name[name] for name in run_names]
print([run.id for run in natural_runs])
figure_prefix = "fixed_expert_plots/cifar-switch"
fixed_expert_performance_plots(
    natural_runs,
    figure_prefix,
    baseline_natural=0.7301,
    baseline_attacked=1e-4
)

NameError: name 'all_runs' is not defined

In [None]:
# Same plots for a second family of runs (PGD adversarial training, judging by
# the run names), one run name per entry of `num_experts_list`.
robust_run_names = [
    f"evaluate-cifar100-resnet18-pgd-adv-train-block-moe{ne}-GALRN-1-switch" for ne in num_experts_list
]
# One pair of figures per W&B tag; `filter_sorted` returns the matching runs in
# `robust_run_names` order, i.e. aligned with `num_experts_list`.
for tag in ("entropy", "switch"):
    robust_runs = filter_sorted(robust_run_names, tag)
    figure_prefix = f"fixed_expert_plots/cifar_robust_{tag}"
    fixed_expert_performance_plots(
        robust_runs, figure_prefix, baseline_natural=0.5232, baseline_attacked=0.178
    )