In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset

# Experiment selection. The dotted keys are passed to `get_config` as
# overrides (presumably nested-field paths -- confirm in `configs.get_config`).
config_name = "cub"
config_dict = {
    "data.explanation_length": 6,
    "speaker.alpha": 0.2,
    "listener.type": "claim",
}
config = get_config(config_name, config_dict=config_dict)

# Build the train split first, then the validation split, both with attributes.
train_dataset, val_dataset = (
    get_dataset(config, train=is_train, return_attribute=True)
    for is_train in (True, False)
)
# The class list of the training split is used as the reference everywhere below.
classes = train_dataset.classes

# NOTE(review): presumably loads previously saved results for this config;
# later cells treat it as a pandas DataFrame -- confirm in `get_results`.
results = config.get_results()

# Global plotting style for every figure in the notebook.
sns.set_theme()
sns.set_context("paper")

In [None]:
def get_class_attribute_similarity(image_attribute, label, num_classes=None):
    """Mean within-class Jaccard similarity of negative and positive attributes.

    For every class index ``k`` in ``range(num_classes)`` the rows of
    ``image_attribute`` whose ``label`` equals ``k`` are selected. Each row is
    viewed as two sets of attribute indices: the entries equal to 0
    ("negative") and the entries equal to 1 ("positive"). For each of the two
    kinds, the Jaccard similarity ``|A & B| / |A | B|`` is averaged over all
    unordered pairs of distinct rows of the class.

    Args:
        image_attribute: 2-D array-like with one row per image and one column
            per attribute.
        label: 1-D array-like of class indices, one entry per row of
            ``image_attribute``.
        num_classes: number of class indices to evaluate. When ``None``
            (the default, matching the previous behaviour) it falls back to
            ``len(classes)``, the notebook-level class list.

    Returns:
        A list of two tuples ``[negative, positive]``; each tuple holds one
        mean similarity per class index. An entry is NaN when the class has
        fewer than two images, or when any pair of its images has an empty
        union for that attribute kind (the similarity is then undefined).
    """
    image_attribute = np.asarray(image_attribute)
    label = np.asarray(label)
    if num_classes is None:
        # Backward-compatible default: use the notebook-level class list.
        num_classes = len(classes)

    def _mean_pairwise_jaccard(mask):
        """Average Jaccard similarity over all distinct row pairs of ``mask``."""
        pair_rows, pair_cols = np.triu_indices(len(mask), k=1)
        if pair_rows.size == 0:
            # Fewer than two images: there is no pair to average over.
            return np.nan

        counts = mask.astype(np.int64)
        # |A & B| for every pair via a matrix product; this avoids building
        # an (n, n, n_attributes) intermediate tensor.
        intersection = counts @ counts.T
        set_size = counts.sum(axis=1)
        # Inclusion-exclusion: |A | B| = |A| + |B| - |A & B|.
        union = set_size[:, None] + set_size[None, :] - intersection

        # Pairs with an empty union stay NaN instead of triggering a 0/0 warning.
        jaccard = np.divide(
            intersection,
            union,
            out=np.full(union.shape, np.nan),
            where=union > 0,
        )
        return np.mean(jaccard[pair_rows, pair_cols])

    per_class = []
    for class_idx in range(num_classes):
        class_attribute = image_attribute[label == class_idx]
        per_class.append(
            [_mean_pairwise_jaccard(class_attribute == value) for value in (0, 1)]
        )
    # Transpose [[neg, pos], ...] into [(neg per class...), (pos per class...)].
    return list(zip(*per_class))


# --- Ground-truth attributes of the training split ---------------------------
train_image_attribute = train_dataset.image_attribute
train_label = np.array([target for _path, target in train_dataset.samples])
train_similarity = get_class_attribute_similarity(train_image_attribute, train_label)
train_neg_similarity, train_pos_similarity = train_similarity

# --- Attributes reconstructed from the utterances stored in `results` --------
explanation = np.stack(results["explanation"].tolist())
# The first position of every explanation is dropped
# (presumably a start token -- TODO confirm).
explanation = explanation[:, 1:, :]
attribute_index, attribute_value = explanation[..., 0], explanation[..., 1]

# Start from an all-zero matrix with one column per training attribute and
# write each value of an utterance at its attribute index.
results_image_attribute = np.zeros((len(results), train_image_attribute.shape[1]))
np.put_along_axis(results_image_attribute, attribute_index, attribute_value, axis=1)

# NOTE(review): rows of `results` are assumed to follow the order of
# `val_dataset.samples` -- confirm.
val_label = np.array([target for _path, target in val_dataset.samples])
utterance_similarity = get_class_attribute_similarity(results_image_attribute, val_label)
results_neg_similarity, results_pos_similarity = utterance_similarity

# Per-class average of the negative and positive similarities.
results_similarity = np.stack(
    [results_neg_similarity, results_pos_similarity], axis=-1
).mean(axis=-1)

# --- Figure: two histograms and a train-vs-utterance scatter ----------------
_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})

histogram_panels = (
    (axes[0], train_neg_similarity, train_pos_similarity, "Train data"),
    (axes[1], results_neg_similarity, results_pos_similarity, "Utterance data"),
)
for ax, neg_similarity, pos_similarity, panel_title in histogram_panels:
    sns.histplot(neg_similarity, kde=True, label="negative", ax=ax)
    sns.histplot(pos_similarity, kde=True, label="positive", ax=ax)
    ax.set_xlabel("Within-class mean attribute similarity")
    ax.set_xlim(None, 1)
    ax.set_title(panel_title)

ax = axes[2]
scatter_series = (
    ("negative", train_neg_similarity, results_neg_similarity),
    ("positive", train_pos_similarity, results_pos_similarity),
)
for attribute_kind, train_values, utterance_values in scatter_series:
    sns.scatterplot(x=train_values, y=utterance_values, label=attribute_kind, ax=ax)
# Dashed diagonal: points on it have equal similarity in both data sources.
ax.plot([0, 1], [0, 1], color="black", linestyle="--")
ax.set_xlabel("Train data")
ax.set_ylabel("Utterance data")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_title("Within-class mean attribute similarity")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), title="attributes")
plt.show()

In [None]:
# The listener's predicted class is the arg-max of each row's `action` entry.
results["listener_prediction"] = results["action"].apply(np.argmax)
# "Correct" means agreeing with the `prediction` column (not with `label`).
results["listener_correct"] = (
    results["listener_prediction"] == results["prediction"]
)

# Fraction of correct listener predictions per value of `label`.
speaker_class_accuracy = results.groupby("label")["listener_correct"].mean()

# NOTE(review): x has one entry per element of `classes`, y one entry per
# label present in `results`; this assumes every class occurs in `results`
# and that labels sort in class-index order -- confirm.
_, ax = plt.subplots(figsize=(3, 3))
sns.regplot(
    x=results_similarity,
    y=speaker_class_accuracy,
    scatter_kws=dict(s=8, alpha=0.5),
    ax=ax,
)
ax.set(
    xlabel="Within-class mean attribute similarity",
    ylabel="Listener accuracy",
    title="Utterance data",
)
plt.show()