In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset

# Experiment selection: one configuration is requested per value in `alpha`.
config_name = "cub_claim"
alpha = [0.0, 0.20]
configs = get_config(config_name, alpha=alpha)

# The dataset is built from the first configuration only, with train=False
# and attributes returned alongside each sample.
dataset = get_dataset(configs[0], train=False, return_attribute=True)
classes = dataset.classes
claims = dataset.claims

# One results table per configuration, in the same order as `configs`.
results = [config.get_results() for config in configs]


def _listener_agreement(frame):
    """Fraction of rows where the argmax of "action" equals "prediction"."""
    listener_choice = frame["action"].apply(lambda scores: np.argmax(scores))
    return (frame["prediction"] == listener_choice).mean()


listener_accuracy = [_listener_agreement(frame) for frame in results]

summary_lines = [
    f"Speaker {i+1} (alpha = {value:.2f}): listener accuracy = {listener_accuracy[i]:.2%}"
    for i, value in enumerate(alpha)
]
print("\n".join(summary_lines))

# Plot styling shared by every figure below.
sns.set_theme()
sns.set_context("paper")

In [None]:
def class_accuracy(results):
    """Compute per-class accuracy for the classifier and for the listener.

    Both are measured against the "label" column. The classifier's output is
    the "prediction" column; the listener's output is the argmax of each entry
    in the "action" column. Per-class accuracy is the diagonal of a confusion
    matrix normalised over the true labels, built over class indices
    0..len(classes)-1.

    Relies on the module-level `classes` for the number of classes.

    Args:
        results: table with "label", "prediction" and "action" columns.

    Returns:
        Tuple `(cls_class_accuracy, listener_class_accuracy)` of 1-D arrays
        with one entry per class.
    """
    class_ids = range(len(classes))
    targets = results["label"].values.tolist()

    def diagonal_of_normalised_cm(predicted):
        # With normalize="true" each row is divided by the number of samples
        # of that true class, so the diagonal holds the per-class hit rate.
        matrix = confusion_matrix(
            targets, predicted, normalize="true", labels=class_ids
        )
        return np.diag(matrix)

    classifier_output = results["prediction"].values.tolist()
    listener_output = (
        results["action"].apply(lambda scores: np.argmax(scores)).values.tolist()
    )
    return (
        diagonal_of_normalised_cm(classifier_output),
        diagonal_of_normalised_cm(listener_output),
    )


# Evaluate every results table, then regroup the (classifier, listener) pairs
# into one tuple of classifier accuracies and one of listener accuracies.
per_speaker_accuracy = [class_accuracy(r) for r in results]
cls_class_accuracy, listener_class_accuracy = zip(*per_speaker_accuracy)

# Each point is one class: listener accuracy under the first alpha (x) against
# listener accuracy under the second alpha (y).
_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
first_speaker_accuracy = listener_class_accuracy[0]
second_speaker_accuracy = listener_class_accuracy[1]
sns.scatterplot(x=first_speaker_accuracy, y=second_speaker_accuracy, ax=ax)

# Dashed diagonal marks equal accuracy under both settings.
diagonal = [0, 1]
ax.plot(diagonal, diagonal, color="black", linestyle="--")

ax.set_xlabel(f"Listener accuracy\n(alpha={alpha[0]:.2f})")
ax.set_ylabel(f"Listener accuracy\n(alpha={alpha[1]:.2f})")
plt.show()

In [None]:
def _format_claim(explanation, row, image_attribute):
    """Render one explanation row as a padded claim name plus "(value/attribute)".

    Column 0 of the row indexes both `claims` and `image_attribute`; column 1
    is printed as an integer next to the attribute value.
    NOTE(review): column 1 is presumably the speaker's claimed value and
    `image_attribute` the ground truth -- confirm against the dataset.
    """
    claim_id = explanation[row, 0]
    claimed_value = explanation[row, 1]
    return (
        f"\t{claims[claim_id]:<40}"
        f" ({claimed_value:d}/{image_attribute[claim_id]:.0f})"
    )


# Number of examples to show; indices are drawn without replacement.
m = 10
example_idx = np.random.choice(len(dataset), m, replace=False)
for idx in example_idx:
    image, label, image_attribute = dataset[idx]

    # Row `idx` of every results table (one table per speaker configuration).
    rows = [table.iloc[idx] for table in results]

    # Every configuration must report the same classifier prediction.
    prediction = [row["prediction"].squeeze() for row in rows]
    assert len(np.unique(prediction)) == 1
    prediction = prediction[0]

    explanations = [row["explanation"].squeeze() for row in rows]
    listener_prediction = [np.argmax(row["action"]) for row in rows]

    _, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
    ax.imshow(image)
    ax.axis("off")
    ax.set_title(f"Label: {classes[label]}\nPrediction: {classes[prediction]}")
    plt.show()

    # Column headers, one per speaker.
    header_cells = [
        f"Speaker {position+1} (alpha = {cfg.speaker.alpha:.2f}):"
        for position, cfg in enumerate(configs)
    ]
    print("\t\t\t\t||\t".join(header_cells))

    # One printed line per claim, with the speakers side by side.
    for claim_idx in range(configs[0].data.explanation_length):
        # Explanation rows are read starting at index 1; index 0 is never shown.
        # NOTE(review): row 0 is presumably a start/placeholder entry -- confirm.
        claim_cells = [
            _format_claim(speaker_explanation, claim_idx + 1, image_attribute)
            for speaker_explanation in explanations
        ]
        print("\t||\t".join(claim_cells))

    listener_cells = [
        f"Listener prediction: {classes[guess]:<30}"
        for guess in listener_prediction
    ]
    print("\t||\t".join(listener_cells))