In [None]:
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

root_dir = "../"
sys.path.append(root_dir)
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.compare import rank_agreement

# Experiment selection: the "awa2" configuration with class-level concepts.
config_name = "awa2"
concept_type = "class"

# Overrides passed to rank_agreement as `results_kw`. Presumably these select
# which stored test results are loaded (TODO confirm in ibydmt.utils.compare).
global_kw = {
    "testing.kernel_scale": 0.9,
    "testing.tau_max": 400,
}
global_cond_kw = {
    "ckde.scale": 2000,
    "testing.kernel_scale": 0.9,
    "testing.tau_max": 800,
}

config = get_config(config_name)

# rank_agreement returns a pair; only its second element is kept for plotting.
_, global_rank_agreement = rank_agreement(
    config,
    "global",
    concept_type,
    results_kw=global_kw,
)
_, global_cond_rank_agreement = rank_agreement(
    config,
    "global_cond",
    concept_type,
    results_kw=global_cond_kw,
)
# PCBM baseline: same "global_cond" test type, no results overrides.
_, pcbm_rank_agreement = rank_agreement(
    config,
    "global_cond",
    concept_type,
    pcbm=True,
)

# Plot styling for the figures produced below.
sns.set_theme()
sns.set_context("paper")

In [None]:
# Box plot of cross-backbone rank agreement per method, saved as PDF and PNG.
figure_dir = os.path.join(root_dir, "figures", "awa2")
os.makedirs(figure_dir, exist_ok=True)

# Display name -> rank-agreement array computed in the previous cell.
compare_dict = {
    "SKIT": global_rank_agreement,
    "c-SKIT": global_cond_rank_agreement,
    "PCBM": pcbm_rank_agreement,
}

# Flatten every off-diagonal (i != j) backbone pair into a long-form table.
# NOTE(review): agreement[:, i, j] assumes the last two axes index backbones,
# presumably (n_classes, n_backbone, n_backbone) -- confirm against rank_agreement.
df_method = []
df_agreement = []
for method_name, agreement in compare_dict.items():
    n_backbone = agreement.shape[1]
    for i in range(n_backbone):
        for j in range(n_backbone):
            if i == j:
                # A backbone compared with itself carries no information.
                continue
            class_agreement = agreement[:, i, j].flatten()
            df_method.extend([method_name] * len(class_agreement))
            df_agreement.extend(class_agreement)

df = pd.DataFrame({"method": df_method, "agreement": df_agreement})

# Tick order and per-method means come from the plotted data itself. The
# diagonal is therefore excluded from the annotated mean, just as it is from
# the boxes. The previous agreement.mean() included self-comparisons.
method_order = list(compare_dict)
method_mean = df.groupby("method")["agreement"].mean()

_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
sns.boxplot(
    data=df, x="method", y="agreement", hue="method", order=method_order, ax=ax
)
ax.set_xlabel("")
ax.set_ylabel("Rank agreement")
# Pin the tick positions before relabelling; set_xticklabels on an
# unpinned locator triggers matplotlib's FixedFormatter warning.
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(
    [
        f"{method_name}\n({method_mean.get(method_name, float('nan')):.2f})"
        for method_name in method_order
    ]
)
plt.savefig(os.path.join(figure_dir, "rank_agreement_dist.pdf"), bbox_inches="tight")
plt.savefig(os.path.join(figure_dir, "rank_agreement_dist.png"), bbox_inches="tight")
plt.show()