In [6]:
from pathlib import Path

import wandb
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import re
from matplotlib import pyplot as plt

# Global seaborn theme applied to every figure drawn in this notebook.
sns.set_style("whitegrid")

# Output directory for the rendered PDFs and the W&B project variant to read.
out = Path("plots")
data_category = "all"
project_path = f"grains-polito/NeSy24PascalPart_{data_category.upper()}"

# W&B handles: the runs feed the history download; reports are fetched alongside.
api = wandb.Api()
runs = api.runs(project_path)
reports = api.reports(project_path)

# One row per tracked metric:
#   (W&B history key, plot title, whether the best value is the "min" or the "max").
metric_table = [
    ("test/reasonable_false_pos", "# Mereological Violations (Reasonable)", "min"),
    ("test/type_acc", "Type Accuracy", "max"),
    ("test/partof_pr_auc", "PartOf PR AUC", "max"),
    ("test/false_neg", "False Negatives", "min"),
    ("test/unreasonable_false_pos", "# Mereological Violations", "min"),
    ("test/type_balanced_acc", "Type Balanced Accuracy", "max"),
    ("test/partof_roc_auc", "PartOf ROC AUC", "max"),
]

# History key -> human-readable plot title.
exp_keys = {key: title for key, title, _ in metric_table}

# History key -> "min"/"max", used to pick the best epoch of each method.
sort_types = {key: direction for key, _, direction in metric_table}

# Raw `group_name` config value -> label shown in plot legends.
legend_map = dict(
    log_ltn="logLTN",
    prod_rl="LTN-Prod",
    stable_rl_2="LTN-Stable (p: 2)",
    stable_rl_6="LTN-Stable (p: 6)",
    focal_ltn_6="Focal LTN (gamma: 6)",
    focal_ltn="Focal LTN (gamma: 2)",
    focal_ltn_1="Focal LTN (gamma: 1)",
    focal_ltn_0="Focal LTN (gamma: 0)",
    focal_log_ltn_6="Focal logLTN (gamma: 6)",
    focal_log_ltn="Focal logLTN (gamma: 2)",
    focal_log_ltn_1="Focal logLTN (gamma: 1)",
    focal_log_ltn_0="Focal logLTN (gamma: 0)",
)

In [7]:
# Build one long DataFrame holding the sampled test metrics of every kept run,
# with the run's (scalar) config values broadcast onto each row.
run_frames = []
discarded_keys = set()
for run in tqdm(runs):
    # Decide whether to skip *before* downloading the history: groups whose raw
    # name contains "_0" are excluded from the analysis, so fetching them only
    # costs a network round trip. A missing group_name is treated as "keep".
    raw_group = run.config.get("group_name") or ""
    if "_0" in raw_group:
        continue

    history = run.history(
        samples=10,
        keys=list(exp_keys),
        x_axis="_step",
        pandas=True,
        stream="default",
    ).dropna(axis=1, how="any")

    # Replace the raw step values by their rank (0, 1, 2, ...) within this run
    # so that all runs share the same x axis.
    step_rank = {
        step: rank
        for rank, step in enumerate(sorted(history["_step"].unique().tolist()))
    }
    history["_step"] = history["_step"].map(step_rank)

    for k, v in run.config.items():
        if k.startswith("_"):
            continue
        if k == "group_name":
            # Use the legend label; fall back to the raw name for unknown groups
            # instead of aborting the whole download with a KeyError.
            v = legend_map.get(v, v)
        try:
            history[k] = v
        except ValueError:
            # Values that cannot be broadcast to a column (e.g. lists whose
            # length differs from the number of rows) are recorded and dropped.
            discarded_keys.add(k)

    history["name"] = run.name
    run_frames.append(history)

# Concatenate once instead of growing the frame inside the loop.
runs_df = pd.concat(run_frames) if run_frames else pd.DataFrame()
print(discarded_keys)

100%|██████████| 47/47 [00:33<00:00,  1.39it/s]

{'partof_hidden_layer_sizes', 'types_hidden_layer_sizes'}





In [8]:
# Narrow the combined frame to the run-identifying columns plus the tracked metrics.
id_columns = ["name", "_step", "random_seed", "group_name", "ltn_config", "gamma"]
tmp = runs_df[id_columns + list(exp_keys)]

In [None]:
out  = pd.DataFrame()
for k, title in exp_keys.items():
    t = tmp[(tmp["gamma"].isna())].groupby(["group_name","_step"]).agg(mean= (k,"mean"), std= (k,"std"))
    t = t.sort_values('mean', ascending=sort_types[k] == "min").reset_index().groupby(["group_name"]).nth(0)
    t["metric"] = exp_keys[k]
    t = t.reset_index()
    out = pd.concat([out, t[["group_name","metric","mean","std"]]], axis=0)
    # display(t)

In [None]:
out.set_index(["group_name", "metric"])
# order_legend = list(legend_map.values())
font_size = 12
out.mkdir(exist_ok=True, parents=True)
for k, title in exp_keys.items():
    ax = sns.lineplot(y=k, x="_step", data=tmp[(tmp["gamma"].isna())], hue="group_name")
    ax.set_title(title, weight='bold').set_fontsize(font_size)
    ax.set(xlabel="Epoch", ylabel="Value")
    # handles, labels = ax.get_legend_handles_labels()
    # labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: order_legend.index(t[0])))
    # ax.legend(handles, labels, title="Method")
    ax.legend(*zip(*sorted(zip(*ax.get_legend_handles_labels()), key = lambda s: [int(t) if t.isdigit() else t.lower() for t in re.split('(\d+)', s[1])])))
    ax.ticklabel_format(useMathText=True, useOffset=False)
    sns.despine()
    plt.tight_layout()
    # display(ax)
    name = k.replace("\\", "_").replace("/", "_")
    ax.get_figure().savefig(out / f"{name}.pdf")
    plt.clf()

In [14]:
# Gamma ablation: for each focal configuration, one PDF per metric comparing
# the groups that share that `ltn_config`.
for ltn_config in ["focal_ltn", "focal_log_ltn"]:
    config_rows = tmp[tmp["ltn_config"] == ltn_config]
    for k, title in exp_keys.items():
        # Explicit figure/axes instead of the implicit current figure, so each
        # plot starts clean and can be closed afterwards.
        fig, ax = plt.subplots()
        sns.lineplot(y=k, x="_step", data=config_rows, hue="group_name", ax=ax)
        ax.set_title(title, weight="bold", fontsize=font_size)
        ax.set(xlabel="Epoch", ylabel="Value")
        # Order legend entries naturally: digit runs compare as integers, the
        # rest case-insensitively. Raw string avoids the invalid "\d" escape.
        handles, labels = ax.get_legend_handles_labels()
        ordered = sorted(
            zip(handles, labels),
            key=lambda pair: [
                int(tok) if tok.isdigit() else tok.lower()
                for tok in re.split(r"(\d+)", pair[1])
            ],
        )
        ax.legend(*zip(*ordered))
        ax.ticklabel_format(useMathText=True, useOffset=False)
        sns.despine(ax=ax)
        fig.tight_layout()
        # Metric keys contain "/", which is not allowed inside a file name.
        name = k.replace("\\", "_").replace("/", "_")
        fig.savefig(out / f"gamma_{ltn_config}_{name}.pdf")
        # Closing (rather than clearing) avoids a leftover empty figure output.
        plt.close(fig)

<Figure size 640x480 with 0 Axes>