In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import weightedtau
from tqdm import tqdm

root_dir = "../"
sys.path.append(root_dir)
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.data import get_dataset
from ibydmt.utils.result import TestingResults
from ibydmt.tester import sweep, get_local_test_idx

# Experiment selection: which config to analyse and which stored test results to load.
config_name = "cub"
results_kw = {"testing.kernel_scale": 0.5, "testing.tau_max": 200}

config = get_config(config_name)
# Presumably one config per backbone value of `data.backbone` in the sweep.
backbone_configs = sweep(config, sweep_keys=["data.backbone"])
dataset = get_dataset(config)
classes = dataset.classes

# test_idx maps each class name to the test-image indices used for local testing
# (iterated that way in the loading cell below).
test_idx = get_local_test_idx(config)
n_test_images = sum(map(len, test_idx.values()))

figure_parts = (root_dir, "figures", config.name.lower(), "local_cond", "compare")
figure_dir = os.path.join(*figure_parts)
os.makedirs(figure_dir, exist_ok=True)

# NOTE(review): with no arguments, sns.set_style() re-applies the current rcParams
# (effectively a no-op); pass a style name (e.g. "whitegrid") if one was intended.
sns.set_style()
sns.set_context("paper")

In [None]:
# Load the stored "local_cond" test results for every backbone and keep, per test
# image, the sorted concepts with their rejection flags, tau values and importance.
# Passed to TestingResults.sort; shown as "$s = ...$" in the figures below.
cardinality = 2

results = {}
for backbone_config in tqdm(backbone_configs, desc="backbones"):
    backbone = backbone_config.data.backbone
    # Results are keyed by backbone name; a repeated name would silently overwrite.
    assert backbone not in results, f"backbone {backbone!r} appears twice in the sweep"

    backbone_results = TestingResults.load(
        backbone_config, "local_cond", "image", results_kw=results_kw
    )

    results[backbone] = {}
    for class_name, class_test_idx in test_idx.items():
        for idx in class_test_idx:
            (
                _,
                sorted_concepts,
                sorted_rejected,
                sorted_tau,
                sorted_importance,
            ) = backbone_results.sort(
                class_name,
                idx=idx,
                cardinality=cardinality,
                fdr_control=True,
                with_importance=True,
            )
            # Entries are keyed by image index only (the class is dropped), so an
            # index listed under two classes would overwrite the first class's result
            # and leave empty rows in the agreement arrays built later.
            assert idx not in results[backbone], (
                f"test image {idx} is listed under more than one class"
            )
            results[backbone][idx] = {
                "concepts": sorted_concepts,
                "rejected": sorted_rejected,
                "tau": sorted_tau,
                "importance": sorted_importance,
            }

In [None]:
# Importance agreement: for each test image and each ordered pair of backbones,
# the fraction of concepts whose importance value is identical under both.
backbones = list(results)
assert backbones, "no results loaded"
# One shared, ordered list of test images so row k is the same image for every
# backbone pair. The array is sized from the images actually present in `results`
# (not `n_test_images`), so no all-zero rows can bias the mean/std.
test_images = list(results[backbones[0]])
for backbone, backbone_results in results.items():
    assert set(backbone_results) == set(test_images), f"{backbone}: test images differ"

importance_agreement = np.zeros((len(test_images), len(backbones), len(backbones)))

for i, results1 in enumerate(results.values()):
    for j, results2 in enumerate(results.values()):
        for k, idx in enumerate(test_images):
            concepts1 = results1[idx]["concepts"]
            importance1 = results1[idx]["importance"]

            concepts2 = results2[idx]["concepts"]
            importance2 = results2[idx]["importance"]
            assert set(concepts1) == set(concepts2)

            # Dict lookup instead of concepts2.index(...) (O(n) per concept).
            importance2_by_concept = dict(zip(concepts2, importance2))
            importance_agreement[k, i, j] = np.mean(
                [
                    imp1 == importance2_by_concept[concept]
                    for concept, imp1 in zip(concepts1, importance1)
                ]
            )

importance_agreement_mu = np.mean(importance_agreement, axis=0)
importance_agreement_std = np.std(importance_agreement, axis=0)
annot = np.array(
    [
        f"{mu:.2f}\n(±{std:.2f})"
        for mu, std in zip(
            importance_agreement_mu.flatten(), importance_agreement_std.flatten()
        )
    ]
).reshape(importance_agreement_mu.shape)

fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
    importance_agreement_mu,
    vmin=0.0,
    vmax=1.0,
    ax=ax,
    annot=annot,
    cmap="mako",
    fmt="",
    linecolor="black",
    linewidths=0.5,
    cbar_kws={"label": "Importance agreement"},
    annot_kws={"fontsize": 7},
)
ax.axis("on")
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.set_xlabel(r"$s = %d$" % cardinality)
ax.set_xticklabels(backbones, rotation=45, ha="left")
ax.set_yticklabels(backbones, rotation=0)
figure_stem = f"{results_kw['testing.tau_max']}_{cardinality}_importance"
for extension in ("pdf", "jpg"):
    fig.savefig(
        os.path.join(figure_dir, f"{figure_stem}.{extension}"), bbox_inches="tight"
    )
plt.show()

print("mean importance agreement (all pairs):", importance_agreement_mu.mean())
if len(backbones) > 1:
    # The diagonal compares each backbone with itself and inflates the overall mean.
    off_diagonal = ~np.eye(len(backbones), dtype=bool)
    print(
        "mean importance agreement (distinct backbones):",
        importance_agreement_mu[off_diagonal].mean(),
    )

In [None]:
# Rank agreement: for each test image and each ordered pair of backbones, the
# weighted Kendall's tau between the two backbones' concept orderings.
backbones = list(results)
assert backbones, "no results loaded"
# One shared, ordered list of test images so row k is the same image for every
# backbone pair. The array is sized from the images actually present in `results`
# (not `n_test_images`), so no all-zero rows can bias the mean/std.
test_images = list(results[backbones[0]])
for backbone, backbone_results in results.items():
    assert set(backbone_results) == set(test_images), f"{backbone}: test images differ"

rank_agreement = np.zeros((len(test_images), len(backbones), len(backbones)))

for i, results1 in enumerate(results.values()):
    for j, results2 in enumerate(results.values()):
        for k, idx in enumerate(test_images):
            concepts1 = results1[idx]["concepts"]
            concepts2 = results2[idx]["concepts"]
            assert set(concepts1) == set(concepts2)
            assert len(concepts1) == len(concepts2) == len(set(concepts1)), (
                "concept rankings must be permutations of each other"
            )

            position_in_1 = {concept: p for p, concept in enumerate(concepts1)}
            # Each concept c, taken in backbone 2's order, contributes the pair
            # (position of c under backbone 2, position of c under backbone 1).
            rank_under_2 = np.arange(len(concepts2))
            rank_under_1 = np.array([position_in_1[c] for c in concepts2])
            # rank=False: element indices are used directly as weighting ranks, so
            # the first concepts in backbone 2's ordering get the largest weights;
            # the resulting matrix therefore need not be symmetric.
            rank_agreement[k, i, j] = weightedtau(
                rank_under_2, rank_under_1, rank=False
            ).statistic

rank_agreement_mu = np.mean(rank_agreement, axis=0)
rank_agreement_std = np.std(rank_agreement, axis=0)
annot = np.array(
    [
        f"{mu:.2f}\n(±{std:.2f})"
        for mu, std in zip(rank_agreement_mu.flatten(), rank_agreement_std.flatten())
    ]
).reshape(rank_agreement_mu.shape)

fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
    rank_agreement_mu,
    vmin=-1.0,
    vmax=1.0,
    ax=ax,
    annot=annot,
    cmap="mako",
    fmt="",
    linecolor="black",
    linewidths=0.5,
    cbar_kws={"label": "Weighted Kendall's tau"},
    annot_kws={"fontsize": 7},
)
ax.axis("on")
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.set_xlabel(r"$s = %d$" % cardinality)
ax.set_xticklabels(backbones, rotation=45, ha="left")
ax.set_yticklabels(backbones, rotation=0)
figure_stem = f"{results_kw['testing.tau_max']}_{cardinality}_rank"
for extension in ("pdf", "jpg"):
    fig.savefig(
        os.path.join(figure_dir, f"{figure_stem}.{extension}"), bbox_inches="tight"
    )
plt.show()

print("mean rank agreement (all pairs):", rank_agreement_mu.mean())
if len(backbones) > 1:
    # The diagonal compares each backbone with itself (tau = 1) and inflates the mean.
    off_diagonal = ~np.eye(len(backbones), dtype=bool)
    print(
        "mean rank agreement (distinct backbones):",
        rank_agreement_mu[off_diagonal].mean(),
    )