In [None]:
import os
import sys
import torch
import numpy as np
import ml_collections
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.transforms import Compose, CenterCrop, ToTensor
from torchvision.utils import make_grid

root_dir = "../"
sys.path.append(root_dir)
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.concept_data import get_dataset_with_concepts
from ibydmt.classifiers import ZeroShotClassifier
from ibydmt.samplers import cKDE
from models.clip_classifier import CLIPClassifier
from notebooks.viz_ckde_utils import viz_cond_pdf, viz_local_dist

# --- Configuration -----------------------------------------------------------
# Seed the global RNGs first so everything downstream (including any training
# done by ``load_or_train`` and the random image/neighbor sampling in later
# cells) is reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

config_name = "imagenette"
config = get_config(config_name)

# Replace the cKDE section of the config: keep the configured metric but
# override how the kernel scale is chosen.
# NOTE(review): the meaning of scale_method="neff" and scale=2000 is defined
# inside ibydmt.samplers.cKDE — confirm there before changing them.
ckde = ml_collections.ConfigDict()
ckde.metric = config.ckde.metric
ckde.scale_method = "neff"
ckde.scale = 2000

config.ckde = ckde

# NOTE(review): ``transform`` is not used by any cell visible in this notebook.
transform = Compose([CenterCrop(224), ToTensor()])
# Non-training split (train=False) with concept annotations attached.
dataset = get_dataset_with_concepts(config, train=False)
classes, concepts = dataset.classes, dataset.concepts

# cKDE sampler built from the modified config above.
model = cKDE(config)
# Presumably loads a cached classifier, or trains one if none exists — see
# CLIPClassifier.load_or_train.
classifier = CLIPClassifier.load_or_train(config, root_dir)

# All figures from this notebook go under figures/<config name>/ckde.
figure_dir = os.path.join(root_dir, "figures", config.name.lower(), "ckde")
os.makedirs(figure_dir, exist_ok=True)

sns.set_style("white")
sns.set_context("paper")

In [None]:
# Hand-picked sample indices into ``dataset`` (two per class). For each one the
# image is saved to disk and its local distribution is plotted with
# ``viz_local_dist``.
class_idx = {
    "tench": [330, 117],
    "English springer": [715, 571],
    "cassette player": [1041, 981],
    "chainsaw": [1207, 1289],
    "church": [1831, 1599],
    "French horn": [2306, 2180],
    "garbage truck": [2425, 2376],
    "gas pump": [3016, 2889],
    "golf ball": [3439, 3320],
    "parachute": [3696, 3656],
}

# Set to True to also plot conditional densities with ``viz_cond_pdf``.
PLOT_COND_PDF = False
# Each image figure is written once per format listed here.
FIGURE_FORMATS = ("pdf", "png")

# NOTE(review): ``Z`` is not used by any cell visible in this notebook.
Z = dataset.Z


def _make_class_scorer(class_index):
    """Return a one-argument callable that selects column ``class_index`` of ``classifier(h)``.

    The index is bound through this factory rather than by closing over the
    loop variable ``k``. That keeps the scorer tied to its class even if the
    global ``k`` is later rebound, for example by another cell that reuses the name.
    """

    def _classifier(h):
        return classifier(h)[:, class_index]

    return _classifier


for class_name, idx in class_idx.items():
    k = classes.index(class_name)
    _classifier = _make_class_scorer(k)

    for _idx in idx:
        idx_figure_dir = os.path.join(figure_dir, f"{class_name}_{_idx}")
        os.makedirs(idx_figure_dir, exist_ok=True)

        image, _, z, y = dataset[_idx]

        fig, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
        # permute(1, 2, 0) moves the first axis last — presumably CHW -> HWC
        # for imshow; confirm against the dataset's image format.
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(f"Class: {class_name}")
        ax.axis("off")
        for ext in FIGURE_FORMATS:
            fig.savefig(
                os.path.join(idx_figure_dir, f"image_{model.scale}.{ext}"),
                bbox_inches="tight",
            )
        plt.show()
        # Release the figure: this loop opens one per sample.
        plt.close(fig)

        if PLOT_COND_PDF:
            viz_cond_pdf(model, z, class_name, concepts, idx_figure_dir)
        viz_local_dist(model, _classifier, z, class_name, concepts, idx_figure_dir)

In [None]:
def _neighbor_entropy(nn_idx):
    """Shannon entropy, in nats, of the empirical distribution of values in ``nn_idx``.

    Each distinct value is weighted by how often it occurs. The result is 0
    when every entry is the same index.
    """
    _, counts = np.unique(nn_idx, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log(p))


# NOTE(review): ``results`` is not created by any cell in this notebook, so a
# fresh Restart & Run All fails here. From its use below (and in the next cell)
# it maps (class_name, concept) -> {sample_idx: neighbor indices}.
if "results" not in globals():
    raise NameError(
        "`results` is undefined: compute the (class_name, concept) -> "
        "{sample_idx: neighbor indices} mapping before running this cell."
    )

# entropy[i, j]: mean, over the samples of class i, of the entropy of the
# neighbor indices obtained for concept j.
entropy = np.zeros((len(classes), len(concepts)))
for i, class_name in enumerate(classes):
    for j, concept in enumerate(concepts):
        per_sample = [
            _neighbor_entropy(nn_idx)
            for nn_idx in results[(class_name, concept)].values()
        ]
        # NaN when the pair has no samples, without a mean-of-empty warning.
        entropy[i, j] = np.mean(per_sample) if per_sample else np.nan

fig, ax = plt.subplots(figsize=(16, 9 / 16 * len(classes)))
sns.heatmap(
    entropy, annot=True, fmt=".2f", ax=ax, xticklabels=concepts, yticklabels=classes
)
ax.set(
    title="Mean entropy of neighbor indices", xlabel="Concept", ylabel="Class"
)
plt.show()
plt.close(fig)

In [None]:
# Figure layout: images sampled per class, concepts shown per class, and
# neighbors shown per (image, concept). Distinct names avoid clobbering the
# class index ``k`` and the ``class_idx`` dict defined in earlier cells.
n_images, top_k, n_neighbors = 5, 4, 4

# For each class, concept indices sorted by ascending mean entropy. NaN entries
# (pairs without samples) are mapped to -inf so they are never ranked "top".
sorted_concepts_idx = np.argsort(
    np.where(np.isnan(entropy), -np.inf, entropy), axis=-1
)


def _neighbor_images(class_name, concept, sample_idx, n_neighbors):
    """Stack ``n_neighbors`` neighbor images per sample, taken from ``model.dataset``.

    Neighbor indices come from ``results[(class_name, concept)][sample]``.
    Distinct neighbors are drawn when at least ``n_neighbors`` exist; otherwise
    sampling falls back to drawing with replacement. Images are grouped per
    sample, so ``make_grid(..., nrow=n_neighbors)`` shows one sample per row.
    """
    images = []
    for _idx in sample_idx:
        unique_nn_idx = np.unique(results[(class_name, concept)][_idx])
        chosen = np.random.choice(
            unique_nn_idx, n_neighbors, replace=len(unique_nn_idx) < n_neighbors
        )
        images.extend(model.dataset[nn][0] for nn in chosen)
    # Stacking (instead of filling a fixed-size buffer) adapts to the actual
    # image size returned by model.dataset.
    return torch.stack(images)


for i, class_name in enumerate(classes):
    # Every sample index that has neighbor results for this class under any
    # concept. Sorted so the random draw below is reproducible given a seed.
    class_sample_idx = sorted(
        {_idx for concept in concepts for _idx in results[(class_name, concept)]}
    )
    if not class_sample_idx:
        continue

    idx = np.random.choice(
        class_sample_idx, min(n_images, len(class_sample_idx)), replace=False
    )
    image = torch.stack([dataset[_idx][0] for _idx in idx])

    # Highest-entropy concepts first.
    n_shown = min(top_k, len(concepts))
    top_concepts_idx = sorted_concepts_idx[i, -n_shown:]
    top_concepts = [concepts[j] for j in top_concepts_idx[::-1]]

    fig, axes = plt.subplots(
        1,
        n_shown + 1,
        figsize=(16, 9),
        width_ratios=[1] + n_shown * [n_neighbors],
        gridspec_kw={"wspace": 0.05},
    )
    # First column: the sampled images of this class, one per row.
    ax = axes[0]
    ax.imshow(make_grid(image, nrow=1).permute(1, 2, 0))
    ax.axis("off")
    ax.set_title(f"{class_name}")

    # One column per top concept: the neighbors of each sampled image.
    for ax, concept in zip(axes[1:], top_concepts):
        neighbors = _neighbor_images(class_name, concept, idx, n_neighbors)
        ax.imshow(make_grid(neighbors, nrow=n_neighbors).permute(1, 2, 0))
        ax.axis("off")
        ax.set_title(concept)

    # One figure per class; the file name records the class and the number of
    # concepts shown.
    figure_path = os.path.join(figure_dir, f"{class_name}_top{n_shown}_concepts")
    for ext in ("pdf", "png"):
        fig.savefig(f"{figure_path}.{ext}", bbox_inches="tight")
    plt.show()
    plt.close(fig)