In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# `root_dir` is resolved relative to the kernel's working directory; putting it
# on sys.path makes the project-local `_configs` module importable.
root_dir = "../"
sys.path.append(root_dir)
from _configs import CUBClaimConfig, CUBTopicConfig, CUBDistributionConfig

# Dataset layout: <root_dir>/data/CUB/attributes/
data_dir = os.path.join(root_dir, "data")
cub_dir = os.path.join(data_dir, "CUB")
attribute_dir = os.path.join(cub_dir, "attributes")

# Every line of topics.txt must split into exactly two whitespace-separated
# fields; the second one is the topic name, kept (underscores shown as spaces)
# for use as plot tick labels.
with open(os.path.join(attribute_dir, "topics.txt"), "r") as topics_file:
    topic_rows = [row.strip().split() for row in topics_file.read().splitlines()]
topics = [name.replace("_", " ") for _, name in topic_rows]

# Shared `beta` passed to every experiment config in this notebook.
beta = 0.4

# Seaborn defaults, scaled for paper-sized figures.
sns.set_theme()
sns.set_context("paper")

In [None]:
def get_explanation_length(results):
    """Return the mean number of topic entries per explanation.

    Each entry of ``results["explanation_topics"]`` is a sequence of topic ids
    (all rows must have the same length). Ids greater than -1 are counted;
    -1 appears to be used as padding (NOTE(review): confirm with the producer
    of these results).
    """
    topic_ids = np.array(results["explanation_topics"].values.tolist())
    per_row_length = (topic_ids > -1).sum(axis=-1)
    return per_row_length.mean()


def get_explanation_consistency(results):
    """Return the mean of ``results["explanation_consistency"]`` over all elements."""
    consistency = np.array(results["explanation_consistency"].values.tolist())
    return consistency.mean()


def get_listener_accuracy(results):
    """Return the fraction of rows where the listener's argmax matches ``prediction``.

    ``results["listener_action"]`` holds one score vector per row; the index of
    its largest entry is compared with the row's ``results["prediction"]``.
    """
    speaker_prediction = np.array(results["prediction"].values.tolist())
    action_scores = np.array(results["listener_action"].values.tolist())
    listener_choice = action_scores.argmax(axis=-1)
    return np.mean(speaker_prediction == listener_choice)


def get_topic_distribution(results, n_topics=None):
    """Return the average per-explanation topic distribution.

    For every row of ``results["explanation_topics"]`` (equal-length sequences
    of topic ids) the occurrences of each id ``1..n_topics`` are counted and
    normalised to sum to one. These per-row distributions are then averaged.

    Rows that contain none of the ids ``1..n_topics`` (e.g. an all ``-1`` row)
    have nothing to normalise and would otherwise produce a 0/0 = NaN row that
    makes the whole average NaN; they are left out of the average instead. If
    no row mentions any topic, an all-zero vector is returned.

    Args:
        results: table with an ``explanation_topics`` column.
        n_topics: number of topics; defaults to ``len(topics)``, the global
            list of topic names read from ``topics.txt``.

    Returns:
        ``np.ndarray`` of shape ``(n_topics,)``.
    """
    if n_topics is None:
        n_topics = len(topics)
    explanation_topic = np.array(results["explanation_topics"].values.tolist())

    # Ids are matched against 1..n_topics, i.e. topic ids are assumed 1-based
    # (NOTE(review): confirm against the first column of topics.txt).
    topic_ids = np.arange(1, n_topics + 1)
    counts = np.sum(explanation_topic[..., None] == topic_ids, axis=-2, dtype=float)
    totals = np.sum(counts, axis=-1)

    nonempty = totals > 0
    if not np.any(nonempty):
        return np.zeros(n_topics)
    per_row = counts[nonempty] / totals[nonempty, None]
    return np.mean(per_row, axis=0)

# Base Claim Listener

In [None]:
# Reference run: alpha = 0.0, gamma = 0.0.
config = CUBClaimConfig(beta=beta, gamma=0.0, alpha=0.0)
results = config.get_results(workdir=root_dir)

# alpha = 0.1 runs, one per gamma value (gamma is also the x-axis below).
gamma = [0.0, 0.01, 0.02]
pragmatic_configs = [CUBClaimConfig(beta=beta, gamma=g, alpha=0.1) for g in gamma]
pragmatic_results = [cfg.get_results(workdir=root_dir) for cfg in pragmatic_configs]

# Summaries of the reference run ...
explanation_length = get_explanation_length(results)
explanation_consistency = get_explanation_consistency(results)
listener_accuracy = get_listener_accuracy(results)
topic_distribution = get_topic_distribution(results)

# ... and one summary per gamma for the alpha = 0.1 runs.
pragmatic_explanation_length = list(map(get_explanation_length, pragmatic_results))
pragmatic_explanation_consistency = list(
    map(get_explanation_consistency, pragmatic_results)
)
pragmatic_listener_accuracy = list(map(get_listener_accuracy, pragmatic_results))
pragmatic_topic_distribution = list(map(get_topic_distribution, pragmatic_results))

# Summary metrics versus gamma. In each panel the alpha = 0.0 reference run is a
# dashed horizontal line and the alpha = 0.1 runs are plotted over gamma.
metric_panels = (
    ("Explanation length", explanation_length, pragmatic_explanation_length),
    (
        "Explanation consistency",
        explanation_consistency,
        pragmatic_explanation_consistency,
    ),
    ("Listener accuracy", listener_accuracy, pragmatic_listener_accuracy),
)

_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})
for ax, (metric_name, reference_value, sweep_values) in zip(axes, metric_panels):
    ax.axhline(reference_value, color="black", linestyle="--", label=r"$\alpha = 0.0$")
    ax.plot(gamma, sweep_values, marker="o", label=r"$\alpha = 0.1$")
    ax.set_xlabel(r"$\gamma$")
    ax.set_ylabel(metric_name)
    ax.set_xticks(gamma)
    ax.legend()
plt.show()

# Topic-frequency bar charts: the alpha = 0.0 reference run first, then one
# panel per gamma value for the alpha = 0.1 runs. The number of panels follows
# len(gamma), and the loop uses its own variable names so the global
# `topic_distribution` (the reference run) is not overwritten.
topic_panels = [(0.0, 0.0, topic_distribution)] + [
    (0.1, g, g_distribution)
    for g, g_distribution in zip(gamma, pragmatic_topic_distribution)
]
# squeeze=False keeps `axes` 2-D even if there is only a single panel.
_, axes = plt.subplots(1, len(topic_panels), figsize=(16, 9 / 4), squeeze=False)
for ax, (panel_alpha, panel_gamma, panel_distribution) in zip(axes[0], topic_panels):
    sns.barplot(panel_distribution, ax=ax)
    ax.set_ylabel("Topic frequency")
    ax.set_xticks(range(len(topics)))
    ax.set_xticklabels(topics, rotation=45, ha="right")
    # Shared y-range across panels; frequencies above 0.6 fall outside the axis.
    ax.set_ylim(0, 0.6)
    ax.set_title(r"$\alpha = %.2f, \gamma = %.2f$" % (panel_alpha, panel_gamma))
plt.show()

# Topic Listener (no coloration)

In [None]:
# Reference run: alpha = 0.0, gamma = 0.0.
config = CUBTopicConfig(beta=beta, gamma=0.0, alpha=0.0)
results = config.get_results(workdir=root_dir)

# alpha = 0.1 runs, one per gamma value (gamma is also the x-axis below).
gamma = [0.0, 0.01, 0.02]
pragmatic_configs = [CUBTopicConfig(beta=beta, gamma=g, alpha=0.1) for g in gamma]
pragmatic_results = [cfg.get_results(workdir=root_dir) for cfg in pragmatic_configs]

# Summaries of the reference run ...
explanation_length = get_explanation_length(results)
explanation_consistency = get_explanation_consistency(results)
listener_accuracy = get_listener_accuracy(results)
topic_distribution = get_topic_distribution(results)

# ... and one summary per gamma for the alpha = 0.1 runs.
pragmatic_explanation_length = list(map(get_explanation_length, pragmatic_results))
pragmatic_explanation_consistency = list(
    map(get_explanation_consistency, pragmatic_results)
)
pragmatic_listener_accuracy = list(map(get_listener_accuracy, pragmatic_results))
pragmatic_topic_distribution = list(map(get_topic_distribution, pragmatic_results))

# Summary metrics versus gamma. In each panel the alpha = 0.0 reference run is a
# dashed horizontal line and the alpha = 0.1 runs are plotted over gamma.
metric_panels = (
    ("Explanation length", explanation_length, pragmatic_explanation_length),
    (
        "Explanation consistency",
        explanation_consistency,
        pragmatic_explanation_consistency,
    ),
    ("Listener accuracy", listener_accuracy, pragmatic_listener_accuracy),
)

_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})
for ax, (metric_name, reference_value, sweep_values) in zip(axes, metric_panels):
    ax.axhline(reference_value, color="black", linestyle="--", label=r"$\alpha = 0.0$")
    ax.plot(gamma, sweep_values, marker="o", label=r"$\alpha = 0.1$")
    ax.set_xlabel(r"$\gamma$")
    ax.set_ylabel(metric_name)
    ax.set_xticks(gamma)
    ax.legend()
plt.show()

# Topic-frequency bar charts: the alpha = 0.0 reference run first, then one
# panel per gamma value for the alpha = 0.1 runs. The number of panels follows
# len(gamma), and the loop uses its own variable names so the global
# `topic_distribution` (the reference run) is not overwritten.
topic_panels = [(0.0, 0.0, topic_distribution)] + [
    (0.1, g, g_distribution)
    for g, g_distribution in zip(gamma, pragmatic_topic_distribution)
]
# squeeze=False keeps `axes` 2-D even if there is only a single panel.
_, axes = plt.subplots(1, len(topic_panels), figsize=(16, 9 / 4), squeeze=False)
for ax, (panel_alpha, panel_gamma, panel_distribution) in zip(axes[0], topic_panels):
    sns.barplot(panel_distribution, ax=ax)
    ax.set_ylabel("Topic frequency")
    ax.set_xticks(range(len(topics)))
    ax.set_xticklabels(topics, rotation=45, ha="right")
    # Shared y-range across panels; frequencies above 0.6 fall outside the axis.
    ax.set_ylim(0, 0.6)
    ax.set_title(r"$\alpha = %.2f, \gamma = %.2f$" % (panel_alpha, panel_gamma))
plt.show()

# Distribution Listener (uniform prior)

In [None]:
# Sweep of the `temperature_scale` config parameter with gamma fixed at 0.0:
# one set of runs with alpha = 0.0 and one with alpha = 0.1.
temperature_scale = [1.0, 2.0, 4.0, 8.0]

configs = [
    CUBDistributionConfig(beta=beta, gamma=0.0, alpha=0.0, temperature_scale=t)
    for t in temperature_scale
]
results = [cfg.get_results(workdir=root_dir) for cfg in configs]

pragmatic_configs = [
    CUBDistributionConfig(beta=beta, gamma=0.0, alpha=0.1, temperature_scale=t)
    for t in temperature_scale
]
pragmatic_results = [cfg.get_results(workdir=root_dir) for cfg in pragmatic_configs]

# Per-temperature summaries; list position i corresponds to temperature_scale[i].
explanation_length = list(map(get_explanation_length, results))
explanation_consistency = list(map(get_explanation_consistency, results))
listener_accuracy = list(map(get_listener_accuracy, results))
topic_distribution = list(map(get_topic_distribution, results))

pragmatic_explanation_length = list(map(get_explanation_length, pragmatic_results))
pragmatic_explanation_consistency = list(
    map(get_explanation_consistency, pragmatic_results)
)
pragmatic_listener_accuracy = list(map(get_listener_accuracy, pragmatic_results))
pragmatic_topic_distribution = list(map(get_topic_distribution, pragmatic_results))

# Summary metrics versus temperature scale, one line per alpha in each panel.
metric_panels = (
    ("Explanation length", explanation_length, pragmatic_explanation_length),
    (
        "Explanation consistency",
        explanation_consistency,
        pragmatic_explanation_consistency,
    ),
    ("Listener accuracy", listener_accuracy, pragmatic_listener_accuracy),
)

_, axes = plt.subplots(1, 3, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.3})
for ax, (metric_name, alpha0_values, alpha1_values) in zip(axes, metric_panels):
    ax.plot(temperature_scale, alpha0_values, marker="o", label=r"$\alpha = 0.0$")
    ax.plot(temperature_scale, alpha1_values, marker="o", label=r"$\alpha = 0.1$")
    ax.set_xlabel("Temperature scale")
    ax.set_ylabel(metric_name)
    ax.set_xticks(temperature_scale)
    ax.legend()
plt.show()

# Topic-frequency bar charts for the temperature sweep: one figure per alpha,
# one panel per temperature scale. Each title also reports
# KL(uniform || observed) = sum_k u_k * log(u_k / p_k), i.e. the divergence of
# the observed topic frequencies from a uniform distribution over topics.
uniform_prior = np.ones(len(topics)) / len(topics)  # loop-invariant, built once
kl_eps = 1e-08  # keeps log() finite for topics that never occur

for distribution, alpha in [
    (topic_distribution, 0.0),
    (pragmatic_topic_distribution, 0.1),
]:
    # squeeze=False keeps `axes` 2-D even for a single temperature value.
    _, axes = plt.subplots(
        1, len(temperature_scale), figsize=(16, 9 / 4), squeeze=False
    )
    for ax, t, _distribution in zip(axes[0], temperature_scale, distribution):
        kl = np.sum(uniform_prior * np.log(uniform_prior / (_distribution + kl_eps)))

        sns.barplot(_distribution, ax=ax)
        ax.set_ylabel("Topic frequency")
        ax.set_xticks(range(len(topics)))
        ax.set_xticklabels(topics, rotation=45, ha="right")
        ax.set_ylim(0, 0.6)
        # \mathrm is parsed by all Matplotlib mathtext versions, unlike \text.
        ax.set_title(
            r"$\alpha = %.2f, \tau = %.1f$" % (alpha, t)
            + "\n"
            + r"$\mathrm{KL} = %.2f$" % kl
        )
    plt.show()