In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Put the parent directory on the import path before importing `configs`,
# so the module can be found from there.
# NOTE(review): presumably the notebook lives one level below the repo root — confirm.
root_dir = "../"
sys.path.append(root_dir)
import configs

# Global seaborn styling applied to every figure drawn afterwards.
sns.set_theme()
sns.set_context("paper")

In [None]:
def get_listener_confusion_matrix(results):
    """Confusion matrix between the "prediction" column and the listener's choice.

    Args:
        results: table-like object indexable by column name. It must provide
            "label", "prediction" and "listener_action"; the last two must
            expose ``.values`` (e.g. pandas Series). Each "listener_action"
            entry is treated as a vector of per-class scores.

    Returns:
        The matrix produced by ``sklearn.metrics.confusion_matrix`` with
        ``normalize="true"``. Rows follow the "prediction" values, columns
        follow the argmax of "listener_action", and both axes are ordered by
        the sorted unique values of "label".
    """
    class_ids = np.unique(results["label"])
    reference = np.array(results["prediction"].values.tolist())
    listener_scores = np.array(results["listener_action"].values.tolist())
    # The listener's choice is the highest-scoring entry along the last axis.
    listener_choice = listener_scores.argmax(axis=-1)
    matrix = confusion_matrix(
        reference,
        listener_choice,
        labels=class_ids,
        normalize="true",
    )
    return matrix


def accuracy_vs_consistency(results):
    """Per-class classifier accuracy paired with per-class explanation consistency.

    Args:
        results: table-like object indexable by column name, providing
            "label", "prediction" and "explanation_consistency". The columns
            are expected to be pandas Series so that boolean-mask indexing and
            ``.mean()`` work.

    Returns:
        A tuple ``(cls_accuracy, class_consistency)`` of 1-D arrays with one
        entry per class id, in sorted order over the ids that occur in either
        "label" or "prediction". ``cls_accuracy[k]`` is the diagonal of the
        row-normalised confusion matrix for the k-th class id.
        ``class_consistency[k]`` is the mean "explanation_consistency" over
        samples whose label equals that id (NaN when no sample has it).
    """
    label = results["label"]
    prediction = results["prediction"]

    # Pin the class order explicitly (sorted union of observed ids) so that
    # position k means the same class in both returned arrays. Masking with
    # `label == k` for k in range(K) was only correct when the ids were
    # exactly 0..K-1.
    classes = np.union1d(label, prediction)
    cls_confusion_matrix = confusion_matrix(
        label, prediction, labels=classes, normalize="true"
    )
    cls_accuracy = np.diag(cls_confusion_matrix)

    explanation_consistency = results["explanation_consistency"]
    class_consistency = np.array(
        [explanation_consistency[label == class_id].mean() for class_id in classes],
        dtype=float,
    )
    return cls_accuracy, class_consistency


# Instantiate the two configurations compared in this notebook.
# NOTE(review): the class names suggest a claim speaker and its pragmatic
# variant on CUB — confirm against the `configs` module.
speaker_config = configs.CUBClaimSpeaker()
pragmatic_speaker_config = configs.CUBPragmaticClaimSpeaker()

# Fetch each configuration's results, passing `root_dir` as the working directory.
# NOTE(review): presumably a pandas DataFrame with "label", "prediction",
# "listener_action" and "explanation_consistency" columns (those are the
# columns the helper functions read) — confirm in `get_results`.
results = speaker_config.get_results(workdir=root_dir)
pragmatic_results = pragmatic_speaker_config.get_results(workdir=root_dir)

# Confusion matrices between the "prediction" column and the listener's argmax choice.
listener_confusion = get_listener_confusion_matrix(results)
pragmatic_listener_confusion = get_listener_confusion_matrix(pragmatic_results)

# The diagonals give one agreement rate per class; `difference` is the
# per-class change from the base to the pragmatic configuration.
listener_class_accuracy = np.diag(listener_confusion)
pragmatic_listener_class_accuracy = np.diag(pragmatic_listener_confusion)
difference = pragmatic_listener_class_accuracy - listener_class_accuracy

# Figure 1: both listener confusion matrices as heatmaps, then a per-class
# scatter of their diagonals against each other.
_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.25})

heatmap_panels = (
    ("Listener", listener_confusion),
    ("Pragmatic listener", pragmatic_listener_confusion),
)
for ax, (panel_name, panel_confusion) in zip(axes, heatmap_panels):
    sns.heatmap(panel_confusion, ax=ax, cmap="viridis")
    # The title reports the mean of the matrix diagonal.
    ax.set_title(f"{panel_name}\n(accuracy = {np.diag(panel_confusion).mean():.2f})")

ax = axes[2]
# Dashed y = x reference: points above it are classes where the pragmatic
# listener's diagonal entry is higher.
ax.plot([0, 1], [0, 1], color="black", linestyle="--")
ax.scatter(listener_class_accuracy, pragmatic_listener_class_accuracy, s=5)
ax.set(xlabel="Listener accuracy", ylabel="Pragmatic listener accuracy")
plt.show()

# Figure 2: per-class classifier accuracy vs. explanation consistency for each
# configuration, then a direct class-wise comparison of the two consistencies.
_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.25})
ax = axes[0]
cls_accuracy, class_consistency = accuracy_vs_consistency(results)
ax.scatter(cls_accuracy, class_consistency, s=5)
ax.set_xlabel("Classifier accuracy")
ax.set_ylabel("Explanation consistency")
ax.set_title("Speaker")

ax = axes[1]
cls_accuracy, pragmatic_class_consistency = accuracy_vs_consistency(pragmatic_results)
# Plot the pragmatic results' own consistency in this panel. It previously
# re-plotted `class_consistency`, so the panel showed the base speaker's
# y-values under the "Pragmatic Speaker" title.
ax.scatter(cls_accuracy, pragmatic_class_consistency, s=5)
ax.set_xlabel("Classifier accuracy")
ax.set_ylabel("Explanation consistency")
ax.set_title("Pragmatic Speaker")

ax = axes[2]
# Dashed y = x reference line, drawn from 0.2 to 1 on both axes.
ax.plot([0.2, 1], [0.2, 1], color="black", linestyle="--")
ax.scatter(class_consistency, pragmatic_class_consistency, s=5)
ax.set_xlabel("Speaker")
ax.set_ylabel("Pragmatic speaker")
ax.set_title("Class-wise consistency")
plt.show()