In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.special import softmax

# Make the project root importable so the local `datasets` and `configs`
# modules resolve from this notebook's directory.
root_dir = "../"
sys.path.append(root_dir)
from datasets import CUB
import configs

# Directory layout under the project root.
data_dir, results_dir = (os.path.join(root_dir, name) for name in ("data", "results"))

# Non-training split (train=False); each item also carries its attributes.
dataset = CUB(data_dir, train=False, return_attribute=True)
classes, claims = dataset.classes, dataset.claims

# Global plotting style for every figure below.
sns.set_theme()
sns.set_context("paper")

In [None]:
def sort_action(action, class_names=None):
    """Rank classes by listener probability, most probable first.

    Args:
        action: 1-D array of per-class listener scores. Softmax is applied
            here, so these are presumably unnormalized logits.
        class_names: Optional sequence of class names indexed like ``action``.
            Defaults to the module-level ``classes`` loaded from the dataset.

    Returns:
        Tuple ``(sorted_probs, sorted_classes)``: the softmax probabilities in
        descending order, and the class names in the same order.
    """
    if class_names is None:
        # Backward-compatible fallback to the notebook-global class list.
        class_names = classes
    probs = softmax(action)
    order = np.argsort(probs)[::-1]
    return probs[order], [class_names[i] for i in order]


speaker_config = configs.CUBClaimSpeaker()
pragmatic_speaker_config = configs.CUBPragmaticClaimSpeaker()

results = speaker_config.get_results(workdir=root_dir)
pragmatic_results = pragmatic_speaker_config.get_results(workdir=root_dir)


def _column_array(frame, column, dtype=None):
    """Stack the per-row values of ``frame[column]`` into a NumPy array.

    If ``dtype`` is given, the stacked array is cast to it.
    """
    values = np.array(frame[column].values.tolist())
    return values if dtype is None else values.astype(dtype)


# Both runs must cover the same examples, with the same labels and the same
# classifier predictions, for the speaker comparison below to be meaningful.
image_idx = _column_array(results, "idx", int)
pragmatic_image_idx = _column_array(pragmatic_results, "idx", int)
assert np.array_equal(image_idx, pragmatic_image_idx)

label = _column_array(results, "label", int)
pragmatic_label = _column_array(pragmatic_results, "label", int)
assert np.array_equal(label, pragmatic_label)

prediction = _column_array(results, "prediction", int)
pragmatic_prediction = _column_array(pragmatic_results, "prediction", int)
assert np.array_equal(prediction, pragmatic_prediction)

explanation = _column_array(results, "explanation", int)
pragmatic_explanation = _column_array(pragmatic_results, "explanation", int)

listener_action = _column_array(results, "listener_action")
pragmatic_listener_action = _column_array(pragmatic_results, "listener_action")

listener_prediction = np.argmax(listener_action, axis=-1)
pragmatic_listener_prediction = np.argmax(pragmatic_listener_action, axis=-1)

# Keep rows where the classifier is right, the speaker's listener disagrees
# with the classifier, and the pragmatic speaker's listener agrees with it.
# (`==` here: the previous chained `=` rebound `prediction` to `label`.)
correct_classifier = prediction == label
wrong_listener = listener_prediction != prediction
correct_pragmatic_listener = pragmatic_listener_prediction == prediction
candidate_idx = np.flatnonzero(
    correct_classifier & wrong_listener & correct_pragmatic_listener
)

m = 10  # number of examples to visualize
SEED = 0  # fixed so the sampled examples are reproducible across runs
rng = np.random.default_rng(SEED)
# Cap the sample size so sampling without replacement cannot fail when fewer
# than `m` rows satisfy the filter.
idx = rng.choice(candidate_idx, size=min(m, len(candidate_idx)), replace=False)


TOP_K = 5  # number of most probable classes shown per listener panel


def _decode_claims(claim_ids):
    """Return the claim strings for ``claim_ids``, preserving their order.

    Indices ``>= len(claims)`` are skipped. Presumably these are padding or
    special tokens in the stored explanations (TODO confirm with the speaker).
    """
    return [claims[i] for i in claim_ids if i < len(claims)]


def _plot_listener(ax, action, title, k=TOP_K):
    """Draw the ``k`` most probable classes of a listener action on ``ax``."""
    probs, class_names = sort_action(action)
    k = min(k, len(probs))  # keep x and y the same length for tiny class sets
    sns.barplot(x=np.arange(k), y=probs[:k], ax=ax)
    ax.set_ylabel("Class probability")
    ax.set_xticks(np.arange(k))
    ax.set_xticklabels(class_names[:k], rotation=45, ha="right")
    ax.set_title(title)


for _idx in idx:
    # NOTE(review): assumes the results "idx" column stores dataset indices,
    # so row _idx maps to dataset item image_idx[_idx]. The label check below
    # fails loudly if the row -> image mapping is misaligned.
    _image, _label, _ = dataset[image_idx[_idx]]
    assert int(_label) == label[_idx], f"row {_idx}: dataset/results label mismatch"
    _prediction = prediction[_idx]

    fig, axes = plt.subplots(
        1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3}
    )
    ax = axes[0]
    ax.imshow(_image)
    ax.axis("off")
    ax.set_title(f"Label: {classes[_label]}\nPrediction: {classes[_prediction]}")
    _plot_listener(axes[1], listener_action[_idx], "Listener")
    _plot_listener(axes[2], pragmatic_listener_action[_idx], "Pragmatic Listener")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures on non-inline backends

    _explanation = explanation[_idx]
    _pragmatic_explanation = pragmatic_explanation[_idx]
    # Claim ids used only by the pragmatic speaker (np.setdiff1d returns them
    # sorted and de-duplicated).
    pragmatic_diff = np.setdiff1d(_pragmatic_explanation, _explanation)

    print(f"Speaker: {_decode_claims(_explanation)}")
    print(f"Pragmatic Speaker: {_decode_claims(_pragmatic_explanation)}")
    print(f"Uniquely pragmatic claims: {_decode_claims(pragmatic_diff)}")