In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from classifiers import get_classifier
from datasets import get_dataset
from notebooks.utils import viz_explanation

# Experiment family whose base config is loaded by `get_config` below.
config_name = "cub"

# Overrides applied on top of the base config. The "topic" and "claim" listener
# variants share every setting except "listener.type".
topic_config_dict = {
    "data.explanation_length": 6,
    "listener.type": "topic",
    "listener.gamma": 0.4,
    # One weight per topic (paired index-by-index with `topics` when plotted);
    # only topics 2-4 (0-based) receive prior mass.
    "listener.prior": [0.0, 0.0, 1 / 3, 1 / 3, 1 / 3, 0.0],
    "listener.temperature_scale": 8.0,
    # Swept later via `sweep(["speaker.alpha"])`: one run per value.
    "speaker.alpha": [0.0, 0.2],
}
# Shallow copy with only the listener type swapped; the nested lists are shared
# with `topic_config_dict`, exactly as `dict.copy()` would share them.
claim_config_dict = {**topic_config_dict, "listener.type": "claim"}

# Load one base config per listener type: "claim" first, then "topic".
claim_configs, topic_configs = (
    get_config(config_name, config_dict=overrides)
    for overrides in (claim_config_dict, topic_config_dict)
)

# `sweep(["speaker.alpha"])` appears to return one config per speaker.alpha value
# ([0.0, 0.2]); presumably alpha=0.0 is the literal speaker and alpha=0.2 the
# pragmatic one -- TODO confirm against `configs`.
lit_claim_config, prag_claim_config = claim_configs.sweep(["speaker.alpha"])
_, prag_topic_config = topic_configs.sweep(["speaker.alpha"])  # literal topic run unused

# Results of each run, fetched in the same order as the configs above.
lit_claim_results, prag_claim_results, prag_topic_results = (
    run_config.get_results()
    for run_config in (lit_claim_config, prag_claim_config, prag_topic_config)
)

# Classifier on CPU; its `preprocess` transform is reused for the non-training
# split so the displayed images go through the same pipeline.
classifier = get_classifier(topic_configs, device="cpu")
dataset = get_dataset(
    prag_topic_config, train=False, transform=classifier.preprocess, return_attribute=True
)
classes = dataset.classes

attribute_dir = os.path.join(root_dir, "data", "CUB", "attributes")

# Human-readable topic names, in file order. The second whitespace-separated
# token of each line is taken as the topic name (the first is presumably an id);
# underscores become spaces. Blank lines (e.g. a trailing empty line) are skipped
# instead of crashing `split()[1]` with an IndexError.
with open(os.path.join(attribute_dir, "topics.txt"), "r", encoding="utf-8") as f:
    lines = [line.strip() for line in f]
lines = [line for line in lines if line]
topics = [line.split()[1].replace("_", " ") for line in lines]

# Seaborn defaults, then the smaller "paper" context for publication figures.
sns.set_theme()
sns.set_context("paper")

In [None]:
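# NOTE(review): disabled cell (earlier temperature-sweep KL plot). It references
# names that are not defined in this notebook (`topic_config`, `prag_topic_configs`,
# `lit_topic_configs`) and would fail if uncommented; the live `topic_kl` cell
# below supersedes its KL computation.
# figure_dir = os.path.join(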
#     root_dir, "figures", topic_config.data.dataset.lower(), "topic_kl"
# )
# os.makedirs(figure_dir, exist_ok=True)


# def topic_kl(results):
#     explanation_topics = np.array(results["explanation_topics"].values.tolist())

#     return np.sum(
#         explanation_topics * np.log((explanation_topics + 1e-08) / (prior + 1e-08)),
#         axis=-1,
#     )


# # lit_topic_kl = [topic_kl(r) for r in lit_topic_results]
# prag_topic_kl = [topic_kl(r) for r in prag_topic_results]

# lit_claim_kl = topic_kl(lit_claim_results)
# prag_claim_kl = topic_kl(prag_claim_results)

# # mu_lit_topic_kl = np.mean(lit_topic_kl, axis=-1).tolist()
# mu_prag_topic_kl = np.mean(prag_topic_kl, axis=-1).tolist()

# mu_lit_claim_kl = np.mean(lit_claim_kl)
# mu_prag_claim_kl = np.mean(prag_claim_kl)

# _, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
# # ax.plot(
# #     [0] + lit_topic_configs.listener.temperature_scale,
# #     [mu_lit_claim_kl] + mu_lit_topic_kl,
# #     label="Literal",
# #     marker="o",
# # )
# ax.plot(
#     [0] + prag_topic_configs.listener.temperature_scale,
#     [mu_prag_claim_kl] + mu_prag_topic_kl,
#     label="Pragmatic",
#     marker="o",
# )
# ax.set_xlabel("Temperature scale")
# ax.set_ylabel("Topic KL divergence")
# ax.set_xticks([0] + lit_topic_configs.listener.temperature_scale)
# ax.legend(title="speaker", loc="upper left", bbox_to_anchor=(1, 1))
# plt.savefig(
#     os.path.join(
#         figure_dir,
#         f"temperature_{','.join(map(lambda x: f'{x:.2f}', prag_topic_configs.listener.prior))}.pdf",
#     ),
#     bbox_inches="tight",
# )
# plt.savefig(
#     os.path.join(
#         figure_dir,
#         f"temperature_{','.join(map(lambda x: f'{x:.2f}', prag_topic_configs.listener.prior))}.png",
#     ),
#     bbox_inches="tight",
# )
# plt.show()

In [None]:
def topic_kl(prior, results, eps=1e-08):
    """Per-row KL divergence KL(explanation topics || prior).

    Parameters
    ----------
    prior : array-like of float
        Topic prior, one weight per topic. Converted with ``np.asarray`` so a
        plain list (such as the ``listener.prior`` config value) works as well
        as an ndarray.
    results : pandas.DataFrame
        Must have an ``explanation_topics`` column whose entries are equal-length
        sequences of per-topic weights (same length as ``prior``).
    eps : float, optional
        Added to both numerator and denominator so zero entries give neither
        ``log(0)`` nor a division by zero. Defaults to ``1e-08``.

    Returns
    -------
    numpy.ndarray
        One KL value per row of ``results`` (summed over the last, topic, axis).
    """
    prior = np.asarray(prior, dtype=float)
    explanation_topics = np.array(results["explanation_topics"].values.tolist())

    return np.sum(
        explanation_topics * np.log((explanation_topics + eps) / (prior + eps)),
        axis=-1,
    )


# Topic prior shared by all runs; every KL below is measured against it.
prior = np.array(topic_configs.listener.prior)
prag_claim_kl = topic_kl(prior, prag_claim_results)
prag_topic_kl = topic_kl(prior, prag_topic_results)

# The three result frames are combined row-by-row below and later indexed
# positionally (`.iloc[idx]`, `dataset[idx]`), so they must have equal length.
assert len(lit_claim_results) == len(prag_claim_results) == len(prag_topic_results)

# Listener accuracy per run: a row is "correct" when the arg-max of `action`
# matches the classifier's `prediction` (assumes `action` holds one score per
# class -- TODO confirm). Adds two columns to each frame in place.
for run_name, run_results in [
    ("lit_claim", lit_claim_results),
    ("prag_claim", prag_claim_results),
    ("prag_topic", prag_topic_results),
]:
    run_results["listener_prediction"] = run_results["action"].map(np.argmax)
    run_results["correct"] = (
        run_results["prediction"] == run_results["listener_prediction"]
    )
    print(f"{run_name} listener accuracy: {run_results['correct'].mean():.2%}")

# Keep only examples on which every listener recovers the classifier prediction.
mask = (
    lit_claim_results["correct"]
    & prag_claim_results["correct"]
    & prag_topic_results["correct"]
    # & prag_topic_results["label"].isin([179, 187, 138, 46, 158])
).values

# Rank examples by how much more the pragmatic-claim explanation's topics
# diverge from the prior than the pragmatic-topic explanation's (largest gap
# first), then drop the examples excluded by `mask`.
kl_diff = prag_claim_kl - prag_topic_kl
sorted_idx = np.argsort(kl_diff)[::-1]
sorted_mask = mask[sorted_idx]
sorted_idx = sorted_idx[sorted_mask]

In [None]:
figure_dir = os.path.join(
    root_dir, "figures", prag_topic_config.data.dataset.lower(), "topics"
)
os.makedirs(figure_dir, exist_ok=True)

# Number of top-ranked examples (largest claim-vs-topic KL gap among the rows
# where every listener is correct) to plot.
m = 30

# (legend label, result frame) for each speaker/listener pairing, in the order
# used for both the explanation panels and the topic bar chart.
speaker_runs = [
    ("Literal/none", lit_claim_results),
    ("Pragmatic/Literal", prag_claim_results),
    ("Pragmatic/Topic", prag_topic_results),
]


def _save_figure(fig, stem):
    """Save ``fig`` into ``figure_dir`` as both ``<stem>.pdf`` and ``<stem>.png``."""
    for extension in ("pdf", "png"):
        fig.savefig(
            os.path.join(figure_dir, f"{stem}.{extension}"), bbox_inches="tight"
        )


def _topic_distribution_frame(speaker_rows, topic_names, topic_prior):
    """Long-format topic table for seaborn: one row per (speaker, topic).

    ``speaker_rows`` is a list of ``(label, results_row)`` pairs whose rows carry an
    ``explanation_topics`` sequence indexed like ``topic_names``; the prior is
    appended last under the label ``"Topic prior"``.
    """
    records = [
        {
            "topic": topic,
            "probability": row["explanation_topics"][i],
            "speaker": speaker,
        }
        for speaker, row in speaker_rows
        for i, topic in enumerate(topic_names)
    ]
    records += [
        {"topic": topic, "probability": topic_prior[i], "speaker": "Topic prior"}
        for i, topic in enumerate(topic_names)
    ]
    return pd.DataFrame(records, columns=["topic", "probability", "speaker"])


for rank, idx in enumerate(sorted_idx[:m]):
    image, label, image_attribute = dataset[idx]
    label_name = classes[label]
    print(f"rank={rank} idx={idx} label={label} kl_diff={kl_diff[idx]:.4f}")

    speaker_rows = [(speaker, results.iloc[idx]) for speaker, results in speaker_runs]
    predictions = [row["prediction"] for _label, row in speaker_rows]
    # Every run explains the same classifier output, so the predictions must agree.
    assert predictions[0] == predictions[1] == predictions[2], predictions
    cls_prediction = predictions[0]
    prediction_name = classes[cls_prediction]

    # Undo the preprocessing normalization for display; assumes mean = std = 0.5
    # per channel (TODO confirm against classifier.preprocess).
    image = (image * 1 / 2) + 1 / 2

    # Figure 1: the image plus one explanation panel per speaker/listener run.
    fig, axes = plt.subplots(
        1,
        4,
        figsize=(16, 9 / 4),
        gridspec_kw={"width_ratios": [3, 1, 1, 1], "wspace": 1.0},
    )
    ax = axes[0]
    ax.imshow(image.permute(1, 2, 0))  # channels-first -> channels-last (assumes CHW)
    ax.axis("off")
    ax.set_title(f"Label: {label_name}\nPrediction: {prediction_name}")
    for ax, (_label, results) in zip(axes[1:], speaker_runs):
        viz_explanation(dataset, results, idx, ax)
    _save_figure(fig, f"{rank}_{idx}_{label_name}_explanations")
    plt.show()
    plt.close(fig)

    # Figure 2: topic weights of each run's explanation next to the topic prior.
    topic_df = _topic_distribution_frame(speaker_rows, topics, prior)
    fig, ax = plt.subplots(figsize=(9 / 8, 16 / 8))
    sns.barplot(data=topic_df, y="topic", x="probability", hue="speaker", ax=ax)
    ax.set_xlabel("Frequency")
    # Rotate the existing tick labels in place; re-assigning them through
    # set_xticklabels without fixed ticks triggers a matplotlib warning.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.legend(title="Speaker/Listener", loc="upper left", bbox_to_anchor=(1, 1))
    _save_figure(fig, f"{rank}_{idx}_{label_name}_topics")
    plt.show()
    plt.close(fig)