In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from configs.utils import get_config
from test_utils import TestingResults
from datasets import get_dataset
from concept_lib import get_concepts

# Experiment selection: which config to load and which concept family to test.
config_name = "imagenette"
concept_type = "dataset"
config = get_config(config_name)

# Plot styling shared by every figure in this notebook.
sns.set_style(style="white")
sns.set_context(context="paper")

In [None]:
# For every class, plot how each concept's rank changes across the tau_max
# values in config.testing.tau_max; one figure is produced per kernel scale.
#
# Within a given tau_max, rank 1 is the concept with the smallest tau returned
# by `TestingResults.get`. Lines are coloured by each concept's rank at the
# last tau_max in the list, and labelled with the concept name to the right of
# the panel.
kernel_scale = config.testing.kernel_scale
tau_max = config.testing.tau_max

dataset = get_dataset(config)
classes = dataset.classes

# The concept list depends only on `config`, so load it (and the matching
# colour palette) once instead of once per class and per kernel scale.
_, concepts = get_concepts(config)
n_concepts = len(concepts)
palette = sns.color_palette("ch:s=.25,rot=-.25", n_colors=n_concepts).as_hex()[::-1]

# Size the grid from the number of classes instead of hard-coding 2x5, which
# raised IndexError for datasets with more than 10 classes.
n_cols = 5
n_rows = max(1, (len(classes) + n_cols - 1) // n_cols)

for scale in kernel_scale:
    # One loaded result object per tau_max value; insertion order follows
    # `tau_max`, which fixes the row order of `class_rank` below.
    results = {
        _tau_max: TestingResults.load(
            config,
            root_dir,
            "global",
            concept_type=concept_type,
            kernel_scale=scale,
            tau_max=_tau_max,
        )
        for _tau_max in tau_max
    }

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(16, 4.5 * n_rows),
        gridspec_kw={"wspace": 0.9, "hspace": 0.7},
        squeeze=False,  # keep `axes` 2-D even for a single row
    )
    fig.suptitle(f"kernel scale = {scale}")

    for class_idx, class_name in enumerate(classes):
        # class_rank[i, c] is the 1-based rank of concept c at tau_max[i].
        class_rank = np.empty((len(tau_max), n_concepts), dtype=int)
        for i, tau_results in enumerate(results.values()):
            _, tau = tau_results.get(class_name, concepts=concepts)
            # argsort of argsort is the inverse permutation: each concept's
            # 0-based position in ascending-tau order.
            class_rank[i] = np.argsort(np.argsort(tau)) + 1

        ax = axes.flat[class_idx]
        # Concept indices ordered by their rank at the last tau_max, so the
        # darkest colour goes to the final rank-1 concept.
        final_order = np.argsort(class_rank[-1])
        for color, concept_idx in zip(palette, final_order):
            # Flip so rank 1 is drawn at the top; the y tick labels undo the flip.
            y = n_concepts - class_rank[:, concept_idx] + 1
            ax.plot(range(len(tau_max)), y, color=color)
            ax.annotate(
                concepts[concept_idx],
                (1.03, y[-1]),
                xycoords=("axes fraction", "data"),
                va="center",
            )
        # \mathrm is supported by every mathtext version; \text is not.
        ax.set_xlabel(r"$\tau^{\mathrm{max}}$")
        ax.set_ylabel("Rank")
        ax.set_xlim(0, len(tau_max) - 1)
        ax.set_xticks(range(len(tau_max)))
        ax.set_xticklabels(tau_max)
        ax.set_yticks(range(1, n_concepts + 1))
        # Tick at height p shows rank n_concepts - p + 1 (rank 1 at the top).
        ax.set_yticklabels(list(range(n_concepts, 0, -1)))
        ax.set_title(class_name)

    # Hide grid cells left over when len(classes) is not a multiple of n_cols.
    for unused_ax in axes.flat[len(classes):]:
        unused_ax.set_visible(False)

    plt.show()
    plt.close(fig)