In [None]:
# Standard-library and third-party dependencies used throughout the notebook.
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import weightedtau
from tqdm import tqdm

# Add the parent of the working directory to sys.path so that modules located
# there can be imported. This has to run before the project imports below.
root_dir = "../"
sys.path.append(root_dir)
# NOTE(review): `configs` and `datasets` are never referenced by name in this
# notebook -- presumably they are imported for their import-time side effects
# (e.g. registering configs/datasets); confirm before removing them.
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.data import get_dataset
from ibydmt.utils.result import TestingResults
from ibydmt.tester import sweep

# Experiment selection. `results_kw` is forwarded to TestingResults.load when
# the results are read and its "testing.tau_max" entry is reused in the figure
# label and in the output file names.
config_name = "imagenette"
results_kw = {"testing.kernel_scale": 0.9, "testing.tau_max": 1600}

config = get_config(config_name)
# Presumably one config per value of "data.backbone" -- TODO confirm against
# ibydmt.tester.sweep.
backbone_configs = sweep(config, sweep_keys=["data.backbone"])
dataset = get_dataset(config)
classes = dataset.classes

# NOTE(review): set_style() is called without a style name -- confirm whether a
# specific seaborn style was intended here.
sns.set_style()
sns.set_context("paper")

In [None]:
# Load the "global" / "dataset" testing results of every backbone and collect
# them in a nested dict:
#
#     results[backbone][class_name] = {
#         "concepts": ...,  # concepts in the order returned by sort()
#         "rejected": ...,  # matching rejection values returned by sort()
#         "tau": ...,       # matching tau values returned by sort()
#     }
#
# The insertion order of `results` follows `backbone_configs`; the comparison
# cell below relies on that order for the rows/columns of its matrix.
results = {}
for backbone_config in tqdm(backbone_configs):
    backbone = backbone_config.data.backbone
    backbone_results = TestingResults.load(
        backbone_config, "global", "dataset", results_kw=results_kw
    )

    results[backbone] = {}
    for class_name in classes:
        # The first value returned by sort() is not needed in this notebook.
        _, sorted_concepts, sorted_rejected, sorted_tau = backbone_results.sort(
            class_name, fdr_control=True
        )
        results[backbone][class_name] = {
            "concepts": sorted_concepts,
            "rejected": sorted_rejected,
            "tau": sorted_tau,
        }

In [None]:
# Output directory for the comparison figures:
# <root_dir>/figures/<lower-cased config name>/global/compare
figure_dir = os.path.join(
    root_dir, "figures", config.name.lower(), "global", "compare"
)
os.makedirs(figure_dir, exist_ok=True)

# rank_agreement[c, i, j] holds the weighted Kendall's tau between the concept
# orderings of backbone i and backbone j for class c.
n_backbones = len(backbone_configs)
rank_agreement = np.zeros((len(classes), n_backbones, n_backbones))

for row, per_class_row in enumerate(results.values()):
    for col, per_class_col in enumerate(results.values()):
        for class_idx, class_name in enumerate(classes):
            concepts_row = per_class_row[class_name]["concepts"]
            concepts_col = per_class_col[class_name]["concepts"]
            # Both backbones must have ordered exactly the same concepts.
            assert set(concepts_row) == set(concepts_col)

            # Element k stands for the k-th concept of the column backbone:
            # its own position is k, and `ranks_in_row` is where the row
            # backbone placed that same concept.
            identity_ranks = np.arange(len(concepts_row))
            ranks_in_row = [concepts_row.index(name) for name in concepts_col]
            # rank=False: per the SciPy docs the element indices are used
            # directly as ranks, so pairs involving the column backbone's
            # earliest concepts weigh the most. The resulting matrix is
            # therefore not necessarily symmetric.
            tau_result = weightedtau(
                identity_ranks, np.array(ranks_in_row), rank=False
            )
            rank_agreement[class_idx, row, col] = tau_result.statistic

# Aggregate over classes (axis 0) and build the "mean (±std)" cell labels.
rank_agreement_mu = rank_agreement.mean(axis=0)
rank_agreement_std = rank_agreement.std(axis=0)

cell_labels = []
for mu, std in zip(rank_agreement_mu.ravel(), rank_agreement_std.ravel()):
    cell_labels.append(f"{mu:.2f}\n(±{std:.2f})")
annot = np.array(cell_labels).reshape(rank_agreement_mu.shape)

# Heatmap of the mean agreement, annotated with the labels built above.
_, ax = plt.subplots(figsize=(5, 5))
heatmap_kwargs = {
    "vmin": -1.0,
    "vmax": 1.0,
    "ax": ax,
    "annot": annot,
    "cmap": "mako",
    "fmt": "",  # annotations are preformatted strings
    "linecolor": "black",
    "linewidths": 0.5,
    "cbar_kws": {"label": "Weighted Kendall's tau"},
    "annot_kws": {"fontsize": 7},
}
sns.heatmap(rank_agreement_mu, **heatmap_kwargs)

ax.axis("on")
# Move the x tick labels to the top edge of the heatmap.
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.set_xlabel(r"$\tau^{\max} = %d$" % results_kw["testing.tau_max"])

# Label rows and columns with the backbone names, in `results` order.
backbone_names = list(results.keys())
ax.set_xticklabels(backbone_names, rotation=45, ha="left")
ax.set_yticklabels(backbone_names, rotation=0)

# Save the figure as both PDF and JPG, then display it.
pdf_path = os.path.join(figure_dir, f"{results_kw['testing.tau_max']}_kendall.pdf")
jpg_path = os.path.join(figure_dir, f"{results_kw['testing.tau_max']}_kendall.jpg")
plt.savefig(pdf_path, bbox_inches="tight")
plt.savefig(jpg_path, bbox_inches="tight")
plt.show()