In [None]:
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from tqdm import tqdm

# Weights & Biases entity (team/user) and project whose runs are analysed.
entity_name = "geometric-governance"
project_name = "learn_voting_rules_mono"
# wandb API client; used further down to list the project's runs.
api = wandb.Api()

In [None]:
# Load monotonicity training runs and group their metric histories by whether
# the monotonicity loss was enabled during training.
metric_keys = ["train/monotonicity_loss", "val/accuracy"]
runs = api.runs(path=f"{entity_name}/{project_name}")
runs_dict = defaultdict(list)

for run in tqdm(runs):
    cfg = run.config
    # NOTE(review): run.history() returns *sampled* rows (wandb's documented
    # default is samples=500), so very long runs may be subsampled -- confirm
    # that the runs in this project are short enough for this not to matter.
    history = run.history(keys=metric_keys)

    # A run that never logged these metrics (e.g. it crashed at start-up) comes
    # back without the columns; skip it rather than abort the whole load.
    missing = [key for key in metric_keys if key not in history.columns]
    if missing or history.empty:
        # tqdm.write keeps the message from garbling the progress bar.
        tqdm.write(f"Skipping run {run.name}: no data for {missing or metric_keys}")
        continue

    history_dict = {key: history[key].to_numpy() for key in metric_keys}

    # Strict lookup on purpose: a run without this flag cannot be classified.
    method = "monotonic" if cfg["monotonicity_loss_train"] else "standard"
    runs_dict[method].append(history_dict)

In [None]:
# Apply seaborn's "whitegrid" theme before any axes are created.
sns.set_theme(style="whitegrid")

# One row with two side-by-side panels.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))


def aggregate_histories(key, histories):
    """Compute the element-wise mean and standard deviation of one metric.

    Args:
        key: Name of the metric to read from every history dict.
        histories: Non-empty sequence of dicts mapping metric names to
            array-like series (one dict per run).

    Returns:
        A ``(mean, std)`` tuple of arrays, both reduced over the run axis.
        If the series do not all share the same shape, each one is first
        truncated to the length of the shortest so they can be stacked.

    Raises:
        ValueError: If ``histories`` is empty.
    """
    series = [np.asarray(h[key]) for h in histories]
    if not series:
        raise ValueError(f"no histories to aggregate for {key!r}")

    # np.stack requires identical shapes. Runs can differ in length (e.g. one
    # stopped early), so fall back to the prefix that every run has.
    if len({s.shape for s in series}) > 1:
        common_len = min(len(s) for s in series)
        series = [s[:common_len] for s in series]

    stacked = np.stack(series)
    return stacked.mean(axis=0), stacked.std(axis=0)


# One entry per panel: (axes, history key, title, y-axis label, log-scaled y?).
# Both panels are drawn the same way, so they share a single loop.
panels = [
    (axs[0], "val/accuracy", "Validation Accuracy", "Accuracy", False),
    (axs[1], "train/monotonicity_loss", "Train Monotonicity Loss", "Loss", True),
]

for ax, key, title, ylabel, log_y in panels:
    # One mean curve per training method with a +/- one-std band around it.
    for method, histories in runs_dict.items():
        mean, std = aggregate_histories(key, histories)
        x = np.arange(len(mean))
        ax.plot(x, mean, label=method)
        ax.fill_between(x, mean - std, mean + std, alpha=0.2)

    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    if log_y:
        # NOTE(review): where std >= mean the lower band edge is <= 0, which a
        # log axis cannot represent -- check how the band renders there.
        ax.set_yscale("log")

fig.suptitle("Learning STV with Monotonicity Loss", fontsize=16)

# Use the figure object directly rather than pyplot's implicit current figure.
fig.tight_layout()
fig.savefig("monotonicity_loss_stv.png", dpi=300)
plt.show()