In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Make the project root importable so the local ``configs`` and ``datasets``
# modules below can be found. The relative path is resolved against the
# current working directory, i.e. this assumes the notebook runs one level
# below the project root.
# NOTE(review): ``append`` puts the root at the END of sys.path, so any
# installed package with the same name (e.g. a pip-installed ``datasets``)
# would shadow the local module -- confirm, or use sys.path.insert(0, ...).
root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset

# Experiment configuration; also provides ``config.get_results()`` used later.
config_name = "cub_distribution"
config = get_config(config_name)

# Load both splits with per-image attributes (read later via
# ``.image_attribute``).
train_dataset = get_dataset(config, train=True, return_attribute=True)
test_dataset = get_dataset(config, train=False, return_attribute=True)
# Class indices are used interchangeably across splits further down, so both
# splits must share the same class list.
assert train_dataset.classes == test_dataset.classes
classes = train_dataset.classes

sns.set_theme()
sns.set_context("paper")

In [None]:
def get_class_attribute_similarity(image_attribute, label, num_classes=None):
    """Compute within-class pairwise Jaccard similarity of image attributes.

    Args:
        image_attribute: array-like of shape (num_images, num_attributes).
            Entries equal to -1 are treated as absent, exactly like 0; any
            other nonzero entry marks the attribute as present.
        label: array-like of shape (num_images,) holding the integer class
            index of each image.
        num_classes: number of class indices to iterate over
            (``0 .. num_classes - 1``). Defaults to ``len(classes)`` from the
            notebook's setup cell, which is what this function always used.

    Returns:
        A list with one entry per class index. Entry ``c`` is an
        ``(n_c, n_c)`` float array whose ``[i, j]`` element is
        ``|A_i intersect A_j| / |A_i union A_j|`` for the i-th and j-th images
        of class ``c`` (in order of appearance). Pairs whose attribute union is
        empty are NaN (the 0/0 case), but no RuntimeWarning is emitted.
    """
    # Accept lists as well as arrays: boolean-mask indexing and the
    # elementwise ``label == class_idx`` comparison both require ndarrays.
    image_attribute = np.asarray(image_attribute)
    label = np.asarray(label)
    if num_classes is None:
        num_classes = len(classes)

    def _class_similarity(class_attribute):
        # Same presence rule as mapping -1 -> 0 and then casting to bool.
        present = ((class_attribute != 0) & (class_attribute != -1)).astype(np.int64)
        # Intersection sizes for every pair via one matrix product, instead of
        # materialising an (n, n, num_attributes) broadcast array.
        intersection = present @ present.T
        counts = present.sum(axis=1)
        union = counts[:, None] + counts[None, :] - intersection
        # Leave NaN where the union is empty rather than dividing 0 by 0.
        similarity = np.full(union.shape, np.nan)
        np.divide(intersection, union, out=similarity, where=union > 0)
        return similarity

    return [
        _class_similarity(image_attribute[label == class_idx])
        for class_idx in range(num_classes)
    ]


def _mean_class_similarity(class_similarities):
    """Reduce each per-class pairwise-similarity matrix to a single mean.

    ``np.nanmean`` skips NaN pairs (image pairs with no attributes at all,
    whose Jaccard similarity is 0/0), so one such pair no longer turns the
    whole class mean into NaN.

    NOTE(review): the diagonal (each image compared with itself) is included
    in the mean, which pulls small classes towards 1 -- confirm intended.
    """
    return [np.nanmean(class_similarity) for class_similarity in class_similarities]


train_image_attribute = train_dataset.image_attribute
train_label = np.array([label for _, label in train_dataset.samples])
train_similarity = get_class_attribute_similarity(train_image_attribute, train_label)
train_mean_class_similarity = _mean_class_similarity(train_similarity)

test_image_attribute = test_dataset.image_attribute
test_label = np.array([label for _, label in test_dataset.samples])
test_similarity = get_class_attribute_similarity(test_image_attribute, test_label)
test_mean_class_similarity = _mean_class_similarity(test_similarity)

# Rebuild a binary (result row x attribute index) matrix: every index listed in
# a row's ``explanation`` is set to 1.
results = config.get_results()
explanation = np.array(results["explanation"].values.tolist()).astype(int)
# Result rows are paired with ``test_label`` by position, so there must be
# exactly one row per test image (assumed to be in the same order).
assert explanation.shape[0] == len(test_label), (
    f"{explanation.shape[0]} result rows vs {len(test_label)} test images"
)
results_image_attribute = np.zeros((explanation.shape[0], np.amax(explanation) + 1))
np.put_along_axis(results_image_attribute, explanation, 1, axis=-1)
# NOTE(review): the 3 highest-index columns are discarded, presumably because
# those explanation indices are not real attributes -- confirm. The width is
# derived from the largest index observed, not from the attribute vocabulary
# size, so this slice would drop real attributes if that maximum were smaller.
results_image_attribute = results_image_attribute[:, :-3]
results_similarity = get_class_attribute_similarity(results_image_attribute, test_label)
results_mean_class_similarity = _mean_class_similarity(results_similarity)

# One histogram of per-class mean similarity per source, drawn left to right.
# Lower values mean the images of a class share fewer attributes.
histogram_panels = [
    ("Train split", train_mean_class_similarity),
    ("Test split", test_mean_class_similarity),
    ("Results", results_mean_class_similarity),
]
_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.25})
for ax, (panel_title, mean_class_similarity) in zip(axes, histogram_panels):
    sns.histplot(mean_class_similarity, kde=True, ax=ax)
    ax.set_xlabel("Within-class mean attribute similarity")
    # Jaccard similarity never exceeds 1 (intersection <= union).
    ax.set_xlim(None, 1)
    ax.set_title(
        f"{panel_title}\n({np.mean(mean_class_similarity):.2f} +-"
        f" {np.std(mean_class_similarity):.2f})"
    )
plt.show()

# Per-class comparison between sources: one point per class, with the y = x
# diagonal as reference (points above it have higher similarity on the y-axis
# source than on the x-axis source).
scatter_panels = [
    (train_mean_class_similarity, test_mean_class_similarity, "Train split", "Test split"),
    (test_mean_class_similarity, results_mean_class_similarity, "Test split", "Results"),
]
_, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))
for ax, (x_values, y_values, x_label, y_label) in zip(axes, scatter_panels):
    sns.scatterplot(x=x_values, y=y_values, ax=ax)
    ax.plot([0, 1], [0, 1], color="black", linestyle="--")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title("Within-class mean attribute similarity")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
plt.show()

In [None]:
# Rank classes by within-class mean attribute similarity, ascending, so the
# lowest-similarity (most diverse) classes come first. Each source is ranked
# independently, so one printed row may name three different classes.
train_sorted_idx = np.argsort(train_mean_class_similarity)
test_sorted_idx = np.argsort(test_mean_class_similarity)
results_sorted_idx = np.argsort(results_mean_class_similarity)

m = 10  # number of classes listed per source
print("Most diverse classes:")
print(f"\t{'Train split':<30} {'Test split':<30} {'Results':<30}")
# Do not reuse the names train_similarity / test_similarity /
# results_similarity here: they hold the per-class similarity matrices
# computed in the previous cell.
for train_idx, test_idx, results_idx in zip(
    train_sorted_idx[:m], test_sorted_idx[:m], results_sorted_idx[:m]
):
    print(
        f"\t{classes[train_idx]:<30} {classes[test_idx]:<30} {classes[results_idx]:<30}"
    )