In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import weightedtau
from tqdm import tqdm

# Make the repository root importable so the project-local packages below
# (configs, datasets, ibydmt) resolve. root_dir is relative to the working
# directory (assumed to be this notebook's folder) and is reused for figure paths.
root_dir = "../"
sys.path.append(root_dir)
# NOTE(review): `configs` and `datasets` are never referenced by name in this
# notebook; presumably imported for registration side effects — confirm.
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.data import get_dataset
from ibydmt.utils.result import TestingResults
# NOTE(review): LaBo is imported but not used in this notebook.
from ibydmt.utils.pcbm import PCBM, LaBo
from ibydmt.tester import sweep, get_test_classes

# Experiment selection: dataset config and the concept vocabulary to compare.
config_name = "awa2"
concept_type = "class"
# Keyword overrides forwarded to TestingResults.load (and used in figure file
# names); presumably they pick out which saved testing run to compare.
results_kw = {
    "ckde.scale": 4000,
    "testing.kernel_scale": 0.9,
    "testing.tau_max": 800,
}

config = get_config(config_name)
# Presumably one config per value of data.backbone, so results can be
# compared across backbones.
backbone_configs = sweep(config, sweep_keys=["data.backbone"])
dataset = get_dataset(config)
classes = dataset.classes

# Positions of the test classes within the dataset's full class list.
test_classes = get_test_classes(config)
test_classes_idx = list(map(classes.index, test_classes))

# NOTE(review): called without a style name, seaborn keeps the current axes
# style; pass e.g. "whitegrid" if a specific style was intended.
sns.set_style()
sns.set_context("paper")

In [None]:
# Load the global conditional testing results for every backbone and keep,
# per test class, the FDR-controlled concept ordering returned by `sort`.
# Layout: results[backbone][class_name] -> {"concepts", "rejected", "tau", "importance"}.
results = {}
for backbone_config in tqdm(backbone_configs):
    backbone = backbone_config.data.backbone
    backbone_results = TestingResults.load(
        backbone_config, "global_cond", concept_type, results_kw=results_kw
    )

    class_summaries = {}
    results[backbone] = class_summaries
    for class_name in test_classes:
        sorted_output = backbone_results.sort(
            class_name, fdr_control=True, with_importance=True
        )
        # The first element of the sorted tuple is not needed here.
        _, sorted_concepts, sorted_rejected, sorted_tau, sorted_importance = sorted_output
        class_summaries[class_name] = dict(
            concepts=sorted_concepts,
            rejected=sorted_rejected,
            tau=sorted_tau,
            importance=sorted_importance,
        )

In [None]:
# Cross-backbone agreement of the testing-based concept rankings: weighted
# Kendall's tau per test class, summarized as mean (± std over classes) in one
# heatmap cell per pair of backbones.
figure_dir = os.path.join(
    root_dir, "figures", config.name.lower(), "global_cond", "compare"
)
os.makedirs(figure_dir, exist_ok=True)

n_backbones = len(backbone_configs)
rank_agreement = np.zeros((len(test_classes), n_backbones, n_backbones))
per_backbone = list(results.values())

for row, row_results in enumerate(per_backbone):
    for col, col_results in enumerate(per_backbone):
        for k, class_name in enumerate(test_classes):
            concepts1 = row_results[class_name]["concepts"]
            concepts2 = col_results[class_name]["concepts"]
            assert set(concepts1) == set(concepts2)

            # Position of each concept in the row backbone's ordering
            # (first occurrence wins, matching list.index).
            position_in_1 = {}
            for position, concept in enumerate(concepts1):
                position_in_1.setdefault(concept, position)

            # Element m pairs concept concepts2[m]'s rank in concepts2 (m) with
            # its rank in concepts1.
            # NOTE(review): with rank=False weightedtau weighs elements by index,
            # i.e. by the column backbone's ordering here, so the matrix need not
            # be symmetric — confirm this is intended.
            idx1 = np.arange(len(concepts1))
            idx2 = np.array([position_in_1[concept] for concept in concepts2])
            rank_agreement[k, row, col] = weightedtau(idx1, idx2, rank=False).statistic

rank_agreement_mu = rank_agreement.mean(axis=0)
rank_agreement_std = rank_agreement.std(axis=0)
annot = np.array(
    [
        f"{mu:.2f}\n(±{std:.2f})"
        for mu, std in zip(rank_agreement_mu.ravel(), rank_agreement_std.ravel())
    ]
).reshape(rank_agreement_mu.shape)

backbone_names = list(results.keys())
fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
    rank_agreement_mu,
    ax=ax,
    annot=annot,
    fmt="",
    annot_kws={"fontsize": 7},
    vmin=-1.0,
    vmax=1.0,
    cmap="mako",
    linecolor="black",
    linewidths=0.5,
    cbar_kws={"label": "Weighted Kendall's tau"},
)
ax.axis("on")
# Backbone labels go along the top edge instead of the bottom.
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.set_xticklabels(backbone_names, rotation=45, ha="left")
ax.set_yticklabels(backbone_names, rotation=0)

figure_stem = os.path.join(
    figure_dir,
    f"{results_kw['testing.tau_max']}_{results_kw['ckde.scale']}_rank",
)
for extension in ("pdf", "jpg"):
    fig.savefig(f"{figure_stem}.{extension}", bbox_inches="tight")
plt.show()
# Mean over the whole matrix, diagonal (self-agreement) included.
print(rank_agreement_mu.mean())

In [None]:
# Baseline rankings from a post-hoc concept bottleneck model (PCBM): for each
# backbone and test class, order the concepts by the linear weight PCBM assigns
# to them for that class.
# Layout: pcbm_results[backbone][class_name] -> {"concepts", "weights"}.
pcbm_results = {}
for backbone_config in backbone_configs:
    backbone = backbone_config.data.backbone

    pcbm_results[backbone] = {}
    for class_idx, class_name in zip(test_classes_idx, test_classes):
        # With class-specific concepts the model is requested per class
        # (presumably a dedicated PCBM each); otherwise no class name is passed.
        concept_class_name = class_name if concept_type == "class" else None

        pcbm = PCBM.load_or_train(
            backbone_config, concept_class_name=concept_class_name
        )
        concepts = pcbm.concepts

        weights = pcbm.weights()
        # Read this class's row without modifying it: an in-place `/=` would
        # write the normalization back into the array returned by weights().
        class_weights = np.asarray(weights[class_idx])
        # Normalize by the L1 norm rather than the signed sum. Linear weights can
        # be negative, and dividing by a negative (or ~0) sum would reverse (or
        # blow up) the ordering computed below; a positive norm leaves it intact.
        weight_norm = np.sum(np.abs(class_weights))
        if weight_norm > 0:
            class_weights = class_weights / weight_norm

        # Highest weight first.
        sorted_idx = np.argsort(class_weights)[::-1]
        sorted_concepts = [concepts[i] for i in sorted_idx]
        sorted_weights = class_weights[sorted_idx]

        pcbm_results[backbone][class_name] = {
            "concepts": sorted_concepts,
            "weights": sorted_weights,
        }

In [None]:
# Cross-backbone agreement of the PCBM weight-based concept rankings: weighted
# Kendall's tau per test class, summarized as mean (± std over classes) in one
# heatmap cell per pair of backbones.
figure_dir = os.path.join(
    root_dir, "figures", config.name.lower(), "global_cond", "compare"
)
os.makedirs(figure_dir, exist_ok=True)

n_backbones = len(backbone_configs)
pcbm_rank_agreement = np.zeros((len(test_classes), n_backbones, n_backbones))
pcbm_per_backbone = list(pcbm_results.values())

for row, row_results in enumerate(pcbm_per_backbone):
    for col, col_results in enumerate(pcbm_per_backbone):
        for k, class_name in enumerate(test_classes):
            concepts1 = row_results[class_name]["concepts"]
            concepts2 = col_results[class_name]["concepts"]
            assert set(concepts1) == set(concepts2)

            # Position of each concept in the row backbone's ordering
            # (first occurrence wins, matching list.index).
            position_in_1 = {}
            for position, concept in enumerate(concepts1):
                position_in_1.setdefault(concept, position)

            # NOTE(review): with rank=False weightedtau weighs elements by index,
            # i.e. by the column backbone's ordering, so the matrix need not be
            # symmetric — confirm this is intended.
            idx1 = np.arange(len(concepts1))
            idx2 = np.array([position_in_1[concept] for concept in concepts2])
            pcbm_rank_agreement[k, row, col] = weightedtau(
                idx1, idx2, rank=False
            ).statistic

pcbm_rank_agreement_mu = pcbm_rank_agreement.mean(axis=0)
pcbm_rank_agreement_std = pcbm_rank_agreement.std(axis=0)
annot = np.array(
    [
        f"{mu:.2f}\n(±{std:.2f})"
        for mu, std in zip(
            pcbm_rank_agreement_mu.ravel(), pcbm_rank_agreement_std.ravel()
        )
    ]
).reshape(pcbm_rank_agreement_mu.shape)

pcbm_backbone_names = list(pcbm_results.keys())
fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
    pcbm_rank_agreement_mu,
    ax=ax,
    annot=annot,
    fmt="",
    annot_kws={"fontsize": 7},
    vmin=-1.0,
    vmax=1.0,
    cmap="mako",
    linecolor="black",
    linewidths=0.5,
    cbar_kws={"label": "Weighted Kendall's tau"},
)
ax.axis("on")
# Backbone labels go along the top edge instead of the bottom.
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.set_xticklabels(pcbm_backbone_names, rotation=45, ha="left")
ax.set_yticklabels(pcbm_backbone_names, rotation=0)

pcbm_figure_stem = os.path.join(figure_dir, f"pcbm_{concept_type}_rank")
for extension in ("pdf", "jpg"):
    fig.savefig(f"{pcbm_figure_stem}.{extension}", bbox_inches="tight")
plt.show()
# Mean over the whole matrix, diagonal (self-agreement) included.
print(pcbm_rank_agreement_mu.mean())

In [None]:
# Agreement of the per-concept "importance" outputs between backbones: for each
# pair of backbones and each test class, the fraction of concepts whose
# importance value is equal under both; then mean (± std) over classes.
# `results[backbone]` only holds entries for `test_classes` (see the loading
# cell), so the class loop runs over `test_classes`; iterating the full
# `classes` list raised KeyError for every non-test class.
figure_dir = os.path.join(
    root_dir, "figures", config.name.lower(), "global_cond", "compare"
)
os.makedirs(figure_dir, exist_ok=True)

importance_agreement = np.zeros(
    (len(test_classes), len(backbone_configs), len(backbone_configs))
)

for i, (backbone1, results1) in enumerate(results.items()):
    for j, (backbone2, results2) in enumerate(results.items()):
        for k, class_name in enumerate(test_classes):
            concepts1 = results1[class_name]["concepts"]
            concepts2 = results2[class_name]["concepts"]
            assert set(concepts1) == set(concepts2)

            importance1 = results1[class_name]["importance"]
            importance2 = results2[class_name]["importance"]

            # concept -> importance under backbone2 (first occurrence wins, like
            # list.index), replacing an O(n) list search per concept.
            importance2_by_concept = {}
            for concept, importance in zip(concepts2, importance2):
                importance2_by_concept.setdefault(concept, importance)

            importance_agreement[k, i, j] = np.mean(
                [
                    i1 == importance2_by_concept[c1]
                    for c1, i1 in zip(concepts1, importance1)
                ]
            )

importance_agreement_mu = np.mean(importance_agreement, axis=0)
importance_agreement_std = np.std(importance_agreement, axis=0)
annot = np.array(
    [
        f"{mu:.2f}\n(±{std:.2f})"
        for mu, std in zip(
            importance_agreement_mu.flatten(), importance_agreement_std.flatten()
        )
    ]
).reshape(importance_agreement_mu.shape)

_, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
    importance_agreement_mu,
    vmin=0.0,
    vmax=1.0,
    ax=ax,
    annot=annot,
    cmap="mako",
    fmt="",
    linecolor="black",
    linewidths=0.5,
    cbar_kws={"label": "Importance agreement"},
    annot_kws={"fontsize": 7},
)
ax.axis("on")
# Backbone labels go along the top edge instead of the bottom.
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.set_xticklabels(list(results.keys()), rotation=45, ha="left")
ax.set_yticklabels(list(results.keys()), rotation=0)
plt.savefig(
    os.path.join(
        figure_dir,
        f"{results_kw['testing.tau_max']}_{results_kw['ckde.scale']}_importance.pdf",
    ),
    bbox_inches="tight",
)
plt.savefig(
    os.path.join(
        figure_dir,
        f"{results_kw['testing.tau_max']}_{results_kw['ckde.scale']}_importance.jpg",
    ),
    bbox_inches="tight",
)
plt.show()