In [24]:
from pathlib import Path

import wandb
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import re
from matplotlib import pyplot as plt

# Global plot styling for every figure produced by this notebook.
sns.set_style("whitegrid")

# Output directory for figures and the dataset split selecting the W&B project.
out = Path("plots")
data_category = "all"

# Both the runs and the reports live in the same W&B project; build its path once.
project_path = f"grains-polito/NeSy24PascalPart_{data_category.upper()}"
api = wandb.Api()
runs = api.runs(project_path)
reports = api.reports(project_path)

# Logged W&B metric key -> human readable title used in tables and figures.
exp_keys = {
    "test/type_acc": "Type Accuracy",
    "test/type_balanced_acc": "Type Balanced Accuracy",
    "test/partof_pr_auc": "PartOf PR AUC",
    "test/partof_roc_auc": "PartOf ROC AUC",
    "test/unreasonable_false_pos": "# Mereological Violations",
    "test/reasonable_false_pos": "# Mereological Violations (Reasonable)",
    "test/false_neg": "False Negatives",
}

# Metrics where a smaller value is better; every other metric is maximised.
lower_is_better = {
    "test/unreasonable_false_pos",
    "test/reasonable_false_pos",
    "test/false_neg",
}

# Metric key -> "min"/"max", i.e. which direction counts as the best value.
sort_types = {
    metric_key: "min" if metric_key in lower_is_better else "max"
    for metric_key in exp_keys
}

# Raw run group name -> label shown in legends and result tables.
legend_map = {
    # baselines
    "log_ltn": "logLTN",
    "prod_rl": "LTN-Prod",
    "stable_rl_2": "LTN-Stable (p: 2)",
    "stable_rl_6": "LTN-Stable (p: 6)",
    # focal LTN variants
    "focal_ltn_6": "Focal LTN (gamma: 6)",
    "focal_ltn": "Focal LTN (gamma: 2)",
    "focal_ltn_sum": "Focal LTN (Sum)",
    "focal_ltn_1": "Focal LTN (gamma: 1)",
    "focal_ltn_0": "Focal LTN (gamma: 0)",
    # focal logLTN variants
    "focal_log_ltn_6": "Focal logLTN (gamma: 6)",
    "focal_log_ltn": "Focal logLTN (gamma: 2)",
    "focal_log_ltn_sum": "Focal logLTN (Sum)",
    "focal_log_ltn_1": "Focal logLTN (gamma: 1)",
    "focal_log_ltn_0": "Focal logLTN (gamma: 0)",
}

# Row order of the result tables, expressed through the raw group names so the
# labels cannot drift away from `legend_map`.
sort_order = [
    legend_map[group]
    for group in (
        "stable_rl_2",
        "stable_rl_6",
        "prod_rl",
        "log_ltn",
        "focal_ltn",
        "focal_ltn_sum",
        "focal_log_ltn",
        "focal_log_ltn_sum",
    )
]


In [33]:
def _display_group_name(group_name):
    """Translate a raw W&B ``group_name`` into the label used in tables and plots.

    * A name whose last ``_``-separated token contains a dot (e.g. ``..._0.5``)
      is rendered as ``"<label of the prefix> (<token>)"``.
    * A name containing ``"sum"`` has every ``"_1"`` removed before the lookup.
    * Names missing from ``legend_map`` are returned unchanged instead of
      raising ``KeyError``, so one unexpected run cannot abort the download.
    """
    parts = group_name.split("_")
    if "." in parts[-1]:
        base = "_".join(parts[:-1])
        return f"{legend_map.get(base, base)} ({parts[-1]})"
    if "sum" in group_name:
        group_name = group_name.replace("_1", "")
    return legend_map.get(group_name, group_name)


# Download the sampled test-metric history of every run and stack the per-run
# frames into `runs_df` (one row per run and evaluation step).
run_frames = []
discarded_keys = set()  # config keys that could not be stored as a column
for run in tqdm(runs):
    raw_group = run.config.get("group_name")
    # Runs whose group ends in "_0" are excluded from the analysis. Decide this
    # before fetching the history so no request is spent on a discarded run.
    if isinstance(raw_group, str) and raw_group.endswith("_0"):
        continue

    history = run.history(
        samples=10, keys=list(exp_keys),
        x_axis="_step", pandas=True, stream="default",
    ).dropna(axis=1, how="any")
    if "_step" not in history.columns:
        # Nothing was logged for the requested keys; there is nothing to align.
        continue

    # Replace the raw step values by their rank (0, 1, 2, ...) so that runs
    # logged at different absolute steps share a common x axis.
    step_rank = {step: rank for rank, step in enumerate(sorted(history["_step"].unique()))}
    history["_step"] = history["_step"].map(step_rank)

    # Attach every public config entry as a constant column.
    for key, value in run.config.items():
        if key.startswith("_"):
            continue
        if key == "group_name" and isinstance(value, str):
            value = _display_group_name(value)
        try:
            history[key] = value
        except ValueError:
            # pandas refused the value as a column (presumably list-valued
            # settings such as the hidden layer sizes); remember the key.
            discarded_keys.add(key)

    history["name"] = run.name
    # Runs that do not log these settings are treated as using the default of 1.
    if "data_ratio" not in history.columns:
        history["data_ratio"] = 1
    if "gamma" not in history.columns:
        history["gamma"] = 1
    run_frames.append(history)

# Concatenate once: growing a frame inside the loop is quadratic in the number of runs.
runs_df = pd.concat(run_frames) if run_frames else pd.DataFrame()
print(discarded_keys)

100it [01:08,  1.46it/s]                       

{'types_hidden_layer_sizes', 'partof_hidden_layer_sizes'}





In [45]:
# Restrict the collected runs to the identifying columns plus the tracked test metrics.
id_columns = ["name", "_step", "random_seed", "group_name", "ltn_config", "gamma", "data_ratio"]
tmp = runs_df[[*id_columns, *exp_keys]]

In [57]:
# Result table ("mean ± std" over the runs of each group) for one data ratio,
# reported at the step where each group reaches its best mean type accuracy.
data_ratio = 0.5  # training-data fraction to report (runs with gamma == 1 only)
selected = tmp[(tmp["gamma"] == 1) & (tmp["data_ratio"] == data_ratio)]
df = selected.groupby(["group_name", "_step"])

# Model selection: per group, keep the step with the best mean type accuracy.
# sort + drop_duplicates keeps `group_name` as a regular column on every pandas
# version (GroupBy.nth only does so since pandas 2.0); the stable sort makes
# ties resolve to the earliest step.
selection_key = "test/type_acc"
best_step = (
    df[selection_key].mean()
    .reset_index(name="mean")
    .sort_values("mean", ascending=sort_types[selection_key] == "min", kind="stable")
    .drop_duplicates("group_name", keep="first")[["group_name", "_step"]]
)

metric_frames = []
for k, title in exp_keys.items():
    t = df.agg(mean=(k, "mean"), std=(k, "std")).reset_index()
    # Inner join keeps exactly the (group, step) pairs selected above.
    t = t.merge(best_step, on=["group_name", "_step"], how="inner")
    t["metric"] = title
    t["res"] = t["mean"].map("{:.3f}".format) + " ± " + t["std"].map("{:.3f}".format)
    metric_frames.append(t[["group_name", "metric", "res", "_step"]])
out = pd.concat(metric_frames, axis=0)

# One row per group, one column per metric, in the order metrics are declared in `exp_keys`.
tmp2 = out.pivot(index="group_name", columns="metric", values="res")[list(exp_keys.values())]
if data_ratio != 1:
    # Group labels carry a " (<ratio>)" suffix for reduced-data runs; strip it
    # so the rows line up with `sort_order`.
    tmp2 = tmp2.rename(index=lambda label: label.replace(f" ({data_ratio})", ""))
    tmp2 = tmp2.dropna()
tmp2 = tmp2.reindex(sort_order)
tmp2.to_clipboard()
tmp2

here


metric,Type Accuracy,Type Balanced Accuracy,PartOf PR AUC,PartOf ROC AUC,# Mereological Violations,# Mereological Violations (Reasonable),False Negatives
group_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
LTN-Stable (p: 2),0.084 ± 0.040,0.027 ± 0.016,0.028 ± 0.001,0.449 ± 0.011,28581.400 ± 61665.956,762.200 ± 1640.254,17815.000 ± 10354.137
LTN-Stable (p: 6),0.512 ± 0.016,0.344 ± 0.011,0.033 ± 0.001,0.495 ± 0.012,10165.400 ± 12879.185,478.000 ± 648.668,18022.600 ± 10479.708
LTN-Prod,0.598 ± 0.008,0.500 ± 0.030,0.032 ± 0.002,0.479 ± 0.017,5095.600 ± 8040.543,228.000 ± 377.017,14238.800 ± 10618.381
logLTN,0.584 ± 0.032,0.511 ± 0.024,0.030 ± 0.003,0.466 ± 0.024,2360.600 ± 1725.000,50.400 ± 30.672,14318.600 ± 10698.127
Focal LTN (gamma: 2),0.589 ± 0.013,0.521 ± 0.013,0.033 ± 0.002,0.488 ± 0.021,3040.800 ± 4848.993,176.200 ± 266.254,26126.600 ± 155.330
Focal LTN (Sum),0.584 ± 0.022,0.496 ± 0.014,0.033 ± 0.000,0.492 ± 0.006,2526.500 ± 3395.043,80.750 ± 135.202,26169.000 ± 101.348
Focal logLTN (gamma: 2),0.616 ± 0.013,0.538 ± 0.012,0.031 ± 0.002,0.479 ± 0.015,3962.000 ± 3286.371,203.600 ± 186.499,14272.600 ± 10740.735
Focal logLTN (Sum),0.633 ± 0.006,0.550 ± 0.011,0.032 ± 0.002,0.479 ± 0.016,8676.333 ± 6900.485,478.333 ± 352.614,25882.000 ± 288.111


In [40]:
# Inspect the raw pivot (group x metric) before any renaming or reordering.
out.pivot(values="res", columns="metric", index="group_name")

metric,# Mereological Violations,# Mereological Violations (Reasonable),False Negatives,PartOf PR AUC,PartOf ROC AUC,Type Accuracy,Type Balanced Accuracy
group_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
Focal LTN (Sum) (0.5),2526.500 ± 3395.043,80.750 ± 135.202,26169.000 ± 101.348,0.033 ± 0.000,0.492 ± 0.006,0.584 ± 0.022,0.496 ± 0.014
Focal LTN (gamma: 2) (0.5),3040.800 ± 4848.993,176.200 ± 266.254,26126.600 ± 155.330,0.033 ± 0.002,0.488 ± 0.021,0.589 ± 0.013,0.521 ± 0.013
Focal logLTN (Sum) (0.5),8676.333 ± 6900.485,478.333 ± 352.614,25882.000 ± 288.111,0.032 ± 0.002,0.479 ± 0.016,0.633 ± 0.006,0.550 ± 0.011
Focal logLTN (gamma: 2) (0.5),3962.000 ± 3286.371,203.600 ± 186.499,14272.600 ± 10740.735,0.031 ± 0.002,0.479 ± 0.015,0.616 ± 0.013,0.538 ± 0.012
LTN-Prod (0.5),5095.600 ± 8040.543,228.000 ± 377.017,14238.800 ± 10618.381,0.032 ± 0.002,0.479 ± 0.017,0.598 ± 0.008,0.500 ± 0.030
LTN-Stable (p: 2) (0.5),28581.400 ± 61665.956,762.200 ± 1640.254,17815.000 ± 10354.137,0.028 ± 0.001,0.449 ± 0.011,0.084 ± 0.040,0.027 ± 0.016
LTN-Stable (p: 6) (0.5),10165.400 ± 12879.185,478.000 ± 648.668,18022.600 ± 10479.708,0.033 ± 0.001,0.495 ± 0.012,0.512 ± 0.016,0.344 ± 0.011
logLTN (0.5),2360.600 ± 1725.000,50.400 ± 30.672,14318.600 ± 10698.127,0.030 ± 0.003,0.466 ± 0.024,0.584 ± 0.032,0.511 ± 0.024


In [31]:
# Inspect the rows of the runs trained with a data ratio of 0.5.
half_data_mask = tmp["data_ratio"] == 0.5
tmp[half_data_mask]

Unnamed: 0,name,_step,random_seed,group_name,ltn_config,gamma,data_ratio,test/type_acc,test/type_balanced_acc,test/partof_pr_auc,test/partof_roc_auc,test/unreasonable_false_pos,test/reasonable_false_pos,test/false_neg
0,focal_log_ltn_sum_0.5_1303,0,1303,Focal logLTN (Sum) (0.5),focal_log_ltn_sum,,0.5,0.467990,0.343411,0.029381,0.463037,157762,5270,22323
1,focal_log_ltn_sum_0.5_1303,1,1303,Focal logLTN (Sum) (0.5),focal_log_ltn_sum,,0.5,0.573113,0.457380,0.032140,0.477709,377,8,26201
2,focal_log_ltn_sum_0.5_1303,2,1303,Focal logLTN (Sum) (0.5),focal_log_ltn_sum,,0.5,0.592131,0.482952,0.031291,0.467379,6116,251,25976
3,focal_log_ltn_sum_0.5_1303,3,1303,Focal logLTN (Sum) (0.5),focal_log_ltn_sum,,0.5,0.603477,0.513427,0.032592,0.475195,18307,914,25509
0,focal_ltn_sum_0.5_1301,0,1301,Focal LTN (Sum) (0.5),focal_ltn_sum,,0.5,0.447246,0.361140,0.031416,0.480855,83155,2365,24107
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5,log_ltn_0.5_1300,5,1300,logLTN (0.5),log_ltn,,0.5,0.507447,0.484805,0.031069,0.475942,949,56,6528
6,log_ltn_0.5_1300,6,1300,logLTN (0.5),log_ltn,,0.5,0.548041,0.471347,0.029947,0.463349,432,26,6543
7,log_ltn_0.5_1300,7,1300,logLTN (0.5),log_ltn,,0.5,0.545484,0.493303,0.030313,0.466277,5686,230,6357
8,log_ltn_0.5_1300,8,1300,logLTN (0.5),log_ltn,,0.5,0.542032,0.481716,0.029865,0.467164,1726,38,6519


In [None]:
# Result table ("mean ± std" over the runs of each group) for the focal
# configurations trained on the full data, reported at the step where each
# group reaches its best mean type accuracy.
selected = tmp[(tmp["ltn_config"].isin(["focal_ltn", "focal_log_ltn"])) & (tmp["data_ratio"] == 1)]
df = selected.groupby(["group_name", "_step"])

# Model selection: per group, keep the step with the best mean type accuracy.
# sort + drop_duplicates keeps `group_name` as a regular column on every pandas
# version (GroupBy.nth only does so since pandas 2.0); the stable sort makes
# ties resolve to the earliest step.
selection_key = "test/type_acc"
best_step = (
    df[selection_key].mean()
    .reset_index(name="mean")
    .sort_values("mean", ascending=sort_types[selection_key] == "min", kind="stable")
    .drop_duplicates("group_name", keep="first")[["group_name", "_step"]]
)

metric_frames = []
for k, title in exp_keys.items():
    t = df.agg(mean=(k, "mean"), std=(k, "std")).reset_index()
    # Inner join keeps exactly the (group, step) pairs selected above.
    t = t.merge(best_step, on=["group_name", "_step"], how="inner")
    t["metric"] = title
    t["res"] = t["mean"].map("{:.3f}".format) + " ± " + t["std"].map("{:.3f}".format)
    metric_frames.append(t[["group_name", "metric", "res", "_step"]])
out = pd.concat(metric_frames, axis=0)

# One row per group, one column per metric, in the order metrics are declared in `exp_keys`.
tmp2 = out.pivot(index="group_name", columns="metric", values="res")[list(exp_keys.values())]
tmp2.to_clipboard()
tmp2

In [None]:
# One line plot per metric (all groups overlaid), saved as PDF.
# The figure directory gets its own name: `out` is rebound to a results
# DataFrame by the table cells, so `out.mkdir(...)` / `out / ...` fail when the
# notebook is run top to bottom.
plot_dir = Path("plots")
plot_dir.mkdir(exist_ok=True, parents=True)
font_size = 12

# NOTE(review): the loading cell fills a missing `gamma` with 1, so this mask
# only keeps runs whose config stores an explicit null gamma - confirm that
# this is the intended selection.
curve_data = tmp[tmp["gamma"].isna()]

for k, title in exp_keys.items():
    fig, ax = plt.subplots()
    sns.lineplot(y=k, x="_step", data=curve_data, hue="group_name", ax=ax)
    ax.set_title(title, weight="bold").set_fontsize(font_size)
    ax.set(xlabel="Epoch", ylabel="Value")

    # Natural sort of the legend entries: digit runs compare as integers, the
    # rest case-insensitively (so "p: 2" comes before "p: 10").
    legend_entries = sorted(
        zip(*ax.get_legend_handles_labels()),
        key=lambda entry: [
            int(token) if token.isdigit() else token.lower()
            for token in re.split(r"(\d+)", entry[1])
        ],
    )
    ax.legend(*zip(*legend_entries))

    ax.ticklabel_format(useMathText=True, useOffset=False)
    sns.despine(ax=ax)
    fig.tight_layout()

    # Metric keys contain path separators; flatten them into a file name.
    name = k.replace("\\", "_").replace("/", "_")
    fig.savefig(plot_dir / f"{name}.pdf")
    # Close instead of clearing so figures do not accumulate or leak state.
    plt.close(fig)

In [14]:
# Per focal configuration: one line plot per metric comparing its groups,
# saved as "gamma_<config>_<metric>.pdf".
# The figure directory gets its own name: `out` is rebound to a results
# DataFrame by the table cells, so `out / ...` fails when the notebook is run
# top to bottom.
plot_dir = Path("plots")
plot_dir.mkdir(exist_ok=True, parents=True)

for ltn_config in ["focal_ltn", "focal_log_ltn"]:
    config_data = tmp[tmp["ltn_config"] == ltn_config]
    for k, title in exp_keys.items():
        fig, ax = plt.subplots()
        sns.lineplot(y=k, x="_step", data=config_data, hue="group_name", ax=ax)
        # `font_size` is defined by the preceding plotting cell.
        ax.set_title(title, weight="bold").set_fontsize(font_size)
        ax.set(xlabel="Epoch", ylabel="Value")

        # Natural sort of the legend entries: digit runs compare as integers,
        # the rest case-insensitively.
        legend_entries = sorted(
            zip(*ax.get_legend_handles_labels()),
            key=lambda entry: [
                int(token) if token.isdigit() else token.lower()
                for token in re.split(r"(\d+)", entry[1])
            ],
        )
        ax.legend(*zip(*legend_entries))

        ax.ticklabel_format(useMathText=True, useOffset=False)
        sns.despine(ax=ax)
        fig.tight_layout()

        # Metric keys contain path separators; flatten them into a file name.
        name = k.replace("\\", "_").replace("/", "_")
        fig.savefig(plot_dir / f"gamma_{ltn_config}_{name}.pdf")
        # Close instead of clearing so figures do not accumulate or leak state.
        plt.close(fig)

<Figure size 640x480 with 0 Axes>