In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Add the parent directory to the import path so that the project-local
# `configs` and `datasets` modules imported just below can be resolved.
root_dir = "../"
sys.path.append(root_dir)
from configs import get_config, Config
from datasets import get_dataset

# Name of the base configuration handed to `get_config`.
config_name = "imagenet"

# Sweep presets, one per dataset. Only the ImageNet preset is active; the CUB
# and CheXpert presets are kept commented out for reference.
# NOTE(review): switching presets presumably also requires changing
# `config_name` above -- confirm before re-enabling one of them.
# CUB
# config_dict = {
#     "data.explanation_length": [6, 12, 18, 24, 30],
#     "speaker.alpha": [0.0, 0.2],
#     "listener.listener_type": "claim",
#     "listener.gamma": [0.0, 0.4, 0.8],
# }
# explanation_length = 6
# vip_max_queries = 311
# CheXpert
# config_dict = {
#     "data.explanation_length": [4, 6, 8],
#     "speaker.alpha": [0.0, 0.2],
#     "listener.listener_type": "claim",
#     "listener.gamma": [0.0, 0.2, 0.4],
# }
# explanation_length = 4
# ImageNet
config_dict = {
    "data.explanation_length": [12],
    "speaker.alpha": [0.0, 0.2],
    "listener.listener_type": "claim",
    "listener.gamma": [0.4],
}
# Explanation length used for the literal vs. pragmatic listener accuracy
# summary printed in the accuracy cell further down.
explanation_length = 12
# Query budget that appears in the V-IP results file name loaded in the last cell.
vip_max_queries = 399
# The keys "speaker.alpha", "data.explanation_length" and "listener.gamma" are
# expanded later through `config.sweep(keys=...)`.
config = get_config(config_name, config_dict=config_dict)

# Loaded with train=False; in this notebook only `dataset.classes` is used,
# to count the classes and to name them in the class-wise accuracy listing.
dataset = get_dataset(config, train=False, workdir=root_dir)
classes = dataset.classes
print(f"Number of classes: {len(classes)}")

# All result pickles are read from results/<lower-cased dataset name>/ under root_dir.
results_dir = os.path.join(root_dir, "results", config.data.dataset.lower())

# Global seaborn styling applied to every figure produced below.
sns.set_style("white")
sns.set_context("paper")

In [None]:
# Classifier results are stored under a file-system-safe version of the
# classifier name (":" and "/" replaced by "_").
classifier_safe = config.data.classifier.lower().replace(":", "_").replace("/", "_")
results_path = os.path.join(results_dir, f"{classifier_safe}.pkl")

# NOTE: unpickling can execute arbitrary code -- only load result files that
# were produced locally by this project.
results = pd.read_pickle(results_path)
label = results["label"]
prediction = results["prediction"]

accuracy = (label == prediction).mean()
print(f"{classifier_safe} accuracy: {accuracy:.2%}")

# Pin the label set to every class index. By default scikit-learn builds the
# matrix only over labels that occur in `label` or `prediction`, so a class
# missing from both would shift all following rows and `classes[idx]` below
# would name the wrong class. Rows of classes without samples normalize to 0.
# Assumes labels are integer indices into `classes`, which the `classes[idx]`
# lookup below already relies on.
confusion = confusion_matrix(
    label, prediction, labels=np.arange(len(classes)), normalize="true"
)
# With normalize="true" each row sums to 1, so the diagonal is per-class accuracy.
accuracy_per_class = np.diag(confusion)

# Sort classes from highest to lowest accuracy.
sorted_class_idx = np.argsort(accuracy_per_class)[::-1]
sorted_classes = [classes[idx] for idx in sorted_class_idx]
sorted_accuracy = accuracy_per_class[sorted_class_idx]

print("Class-wise accuracy:")
for class_name, acc in zip(sorted_classes, sorted_accuracy):
    print(f"\t{class_name}: {acc:.2%}")

In [None]:
# Output directory for the accuracy figures saved by the cells below.
# exist_ok=True makes this safe to re-run when the directory already exists.
figure_dir = os.path.join(root_dir, "figures", "accuracy")
os.makedirs(figure_dir, exist_ok=True)


def get_accuracy(config: "Config"):
    """Compute explanation and listener accuracy for one configuration.

    Args:
        config: Object exposing ``get_results()``, which must return a
            DataFrame with the columns ``"action"``, ``"prediction"`` and
            ``"explanation_accuracy"``. The annotation is a forward reference
            so that defining this function does not require ``Config`` to be
            importable.

    Returns:
        A tuple ``(explanation_accuracy, listener_accuracy)``: the mean of the
        ``"explanation_accuracy"`` column, and the fraction of rows where the
        argmax of ``"action"`` equals ``"prediction"``.
    """
    results = config.get_results()
    # Keep the listener's choice in a local Series rather than writing a new
    # column into `results`: the frame belongs to `config` and may be shared.
    listener_prediction = results["action"].map(np.argmax)
    explanation_accuracy = results["explanation_accuracy"].mean()
    listener_accuracy = (listener_prediction == results["prediction"]).mean()
    return explanation_accuracy, listener_accuracy


# Split the base config along speaker.alpha; the two resulting configs are
# treated as the literal and the pragmatic speaker respectively.
lit_config, prag_config = config.sweep(keys=["speaker.alpha"])

# Expand each speaker config over explanation length and listener gamma.
sweep_keys = ("data.explanation_length", "listener.gamma")
lit_len_configs = lit_config.sweep(keys=list(sweep_keys))
prag_len_configs = prag_config.sweep(keys=list(sweep_keys))

# Collect one row per swept config, stored column-wise for the DataFrame.
result_columns = (
    "speaker",
    "gamma",
    "explanation_length",
    "explanation_accuracy",
    "listener_accuracy",
)
results_data = {column: [] for column in result_columns}
for _config in lit_len_configs + prag_len_configs:
    # alpha == 0.0 marks the literal speaker; any other value is pragmatic.
    speaker = "literal" if _config.speaker.alpha == 0.0 else "pragmatic"

    explanation_accuracy, listener_accuracy = get_accuracy(_config)
    row = (
        speaker,
        _config.listener.gamma,
        _config.data.explanation_length,
        explanation_accuracy,
        listener_accuracy,
    )
    for column, value in zip(result_columns, row):
        results_data[column].append(value)

results_df = pd.DataFrame(results_data)

# Listener accuracy of each speaker at the reference explanation length
# (first matching row).
at_reference_length = results_df["explanation_length"] == explanation_length
is_literal = results_df["speaker"] == "literal"
is_pragmatic = results_df["speaker"] == "pragmatic"
lit_accuracy = results_df.loc[
    is_literal & at_reference_length, "listener_accuracy"
].values[0]
prag_accuracy = results_df.loc[
    is_pragmatic & at_reference_length, "listener_accuracy"
].values[0]

print(f"Literal listener accuracy ({explanation_length} claims): {lit_accuracy:.2%}")
print(f"Pragmatic listener accuracy ({explanation_length} claims): {prag_accuracy:.2%}")

# Two side-by-side panels sharing the same encoding: explanation accuracy on
# the left, listener accuracy on the right.
_, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))
panels = (
    ("explanation_accuracy", "Explanation accuracy"),
    ("listener_accuracy", "Listener accuracy"),
)
for ax, (metric, ylabel) in zip(axes, panels):
    sns.lineplot(
        data=results_df,
        x="explanation_length",
        y=metric,
        hue="speaker",
        style="gamma",
        marker="o",
        ax=ax,
    )
    ax.set_xlabel("Utterance length")
    ax.set_ylabel(ylabel)

# `ax` is still the right-hand panel here; move its legend outside the axes.
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

# Save the figure as PDF first, then PNG, named after the dataset.
dataset_name = config.data.dataset.lower()
for extension in ("pdf", "png"):
    plt.savefig(
        os.path.join(figure_dir, f"{dataset_name}.{extension}"), bbox_inches="tight"
    )
plt.show()

In [None]:
# V-IP baseline results for the same dataset; the file name encodes the query budget.
vip_results_path = os.path.join(
    results_dir, f"{config.data.dataset.lower()}_vip_query{vip_max_queries}_sbiased.pkl"
)
# NOTE: unpickling can execute arbitrary code -- only load locally produced files.
vip_results = pd.read_pickle(vip_results_path)

prediction = np.array(vip_results["prediction"].values.tolist())
vip_logits = np.array(vip_results["logits"].values.tolist())

# The argmax over the last axis followed by the broadcast against
# `prediction[:, None]` treats `vip_logits` as (samples, queries, classes);
# averaging over axis 0 then yields one accuracy per number of queries.
# NOTE(review): the shape is inferred from this indexing -- confirm against
# the code that writes the pickle.
vip_prediction = np.argmax(vip_logits, axis=-1)
vip_accuracy = np.mean(vip_prediction == prediction[:, None], axis=0)
print("V-IP accuracy:")
# A separate loop name keeps the classifier `accuracy` from an earlier cell intact.
for n_queries, query_accuracy in enumerate(vip_accuracy, start=1):
    print(f"\t{n_queries} queries: {query_accuracy:.2%}")

# Bind the swept lengths to their own name. Reusing `explanation_length` here
# would replace the scalar set in the config cell with an array and break the
# summary in the previous cell if it is re-run afterwards.
explanation_lengths = np.sort(results_df["explanation_length"].unique())

_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
ax.plot(
    explanation_lengths,
    # Entry k - 1 holds the accuracy after k queries.
    vip_accuracy[explanation_lengths - 1],
    label="V-IP",
    color="black",
    linestyle="--",
)
sns.lineplot(
    data=results_df,
    x="explanation_length",
    y="listener_accuracy",
    hue="speaker",
    style="gamma",
    marker="o",
    ax=ax,
)
ax.set_xlabel("Utterance length")
ax.set_ylabel("Listener accuracy")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

# Save as PDF first, then PNG, next to the other accuracy figures.
vip_figure_name = f"{config.data.dataset.lower()}_vip"
for extension in ("pdf", "png"):
    plt.savefig(
        os.path.join(figure_dir, f"{vip_figure_name}.{extension}"),
        bbox_inches="tight",
    )
plt.show()