In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset

# Experiment selection: which named config to load, and the alpha it is built with.
config_name = "cub_claim"
alpha = 0.20
config = get_config(config_name, alpha=alpha)

# Load the train split first, then the validation split, both with per-image
# attribute annotations attached (return_attribute=True).
train_dataset, val_dataset = (
    get_dataset(config, train=is_train, return_attribute=True)
    for is_train in (True, False)
)
classes = train_dataset.classes

# Results associated with this config, as returned by config.get_results().
results = config.get_results()

# Seaborn's default theme, with font sizes / line widths scaled for paper figures.
sns.set_theme(context="paper")

In [None]:
def get_class_attribute_similarity(image_attribute, label, num_classes=None):
    """Mean pairwise attribute overlap between images of the same class.

    For every class index ``idx`` in ``range(num_classes)``, the rows of
    ``image_attribute`` whose ``label == idx`` are compared pairwise. Each pair
    scores ``sum(a * b) / count(a + b > 0)``; for 0/1 attribute vectors this is
    the Jaccard index. The class score is the mean over all unordered pairs.

    Entries equal to -1 are treated as 0 (absent) before comparison. The input
    array is not modified.

    Args:
        image_attribute: 2-D array with one attribute row per image.
        label: 1-D array of integer class labels, aligned with the rows of
            ``image_attribute``.
        num_classes: number of classes to report. Defaults to
            ``len(classes)`` (the notebook-level class list) for backward
            compatibility.

    Returns:
        A list with one score per class index. A class with fewer than two
        images has no pairs and gets ``nan``. A pair of images that both have
        no positive attributes counts as identical (similarity 1.0) instead of
        producing a 0/0 nan that would poison the whole class mean.
    """

    def _class_similarity(class_attribute):
        # Map -1 to 0 without writing into the caller's array.
        binary = np.where(class_attribute == -1, 0, class_attribute)
        n_images = len(binary)
        if n_images < 2:
            # No pairs to compare; return nan explicitly instead of np.mean([]).
            return float("nan")

        intersection = np.sum(binary[:, None, :] * binary[None, :, :], axis=-1)
        union = np.sum((binary[:, None, :] + binary[None, :, :]) > 0, axis=-1)

        # Two empty attribute sets are identical: similarity 1.0 where union == 0
        # (same convention as scipy's Jaccard for two all-zero vectors).
        similarity = np.divide(
            intersection,
            union,
            out=np.ones(union.shape, dtype=float),
            where=union > 0,
        )
        # Strict upper triangle = each unordered pair once, excluding self-pairs.
        return np.mean(similarity[np.triu_indices(n_images, k=1)])

    if num_classes is None:
        num_classes = len(classes)
    return [
        _class_similarity(image_attribute[label == idx]) for idx in range(num_classes)
    ]


# Per-class similarity of the ground-truth training attribute annotations.
train_image_attribute = train_dataset.image_attribute
train_label = np.array([label for _, label in train_dataset.samples])
train_similarity = get_class_attribute_similarity(train_image_attribute, train_label)

# Rebuild a dense per-image attribute matrix from the explanations in `results`:
# explanation[..., 0] is used as an attribute column index and explanation[..., 1]
# as the value written there; unlisted attributes stay 0.
# NOTE(review): the first entry of every explanation is dropped — confirm what it holds.
explanation = np.stack(results["explanation"].to_list())
explanation = explanation[:, 1:, :]
results_image_attribute = np.zeros((len(results), train_image_attribute.shape[1]))
# put_along_axis requires integer indices; the cast is a no-op if they already are.
np.put_along_axis(
    results_image_attribute,
    explanation[..., 0].astype(np.intp),
    explanation[..., 1],
    axis=1,
)

val_label = np.array([label for _, label in val_dataset.samples])
# Results rows are grouped by class using validation labels matched by position.
assert len(results_image_attribute) == len(val_label), (
    "results and the validation split must have the same number of rows"
)
results_similarity = get_class_attribute_similarity(results_image_attribute, val_label)


# Left/middle: distribution of within-class similarity for the training
# annotations and for the results. Right: per-class comparison of the two.
fig, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})

# The two histogram panels differ only in their data and title prefix.
for ax, panel_similarity, panel_title in zip(
    axes[:2],
    (train_similarity, results_similarity),
    ("Train split", "Results"),
):
    sns.histplot(panel_similarity, kde=True, ax=ax)
    ax.set_xlabel("Within-class mean attribute similarity")
    ax.set_xlim(None, 1)
    # nan-aware summary: a class with an undefined similarity is nan, which would
    # otherwise turn the whole title into "nan +- nan".
    ax.set_title(
        f"{panel_title} ({np.nanmean(panel_similarity):.2f} +- {np.nanstd(panel_similarity):.2f})"
    )

ax = axes[2]
sns.scatterplot(x=train_similarity, y=results_similarity, ax=ax)
# Identity line: points above it are classes whose results are more
# self-similar than the training annotations.
ax.plot([0, 1], [0, 1], color="black", linestyle="--")
ax.set_xlabel("Train split")
ax.set_ylabel("Results")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_title("Within-class mean attribute similarity")
plt.show()