In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Make the repository root importable so the project-local `configs` package resolves.
root_dir = "../"
sys.path.append(root_dir)
from configs import get_config

# Base experiment definition. List-valued entries (temperature scale, speaker
# alpha) are presumably sweep axes; they are expanded below via `.sweep(...)`.
config_name = "cub"
config_dict = {
    "data.explanation_length": 12,
    "listener.type": "topic",
    "listener.gamma": 0.4,
    "listener.prior": 6 * [1 / 6],  # uniform prior over 6 entries
    "listener.temperature_scale": [1.0, 2.0, 4.0, 8.0],
    "speaker.alpha": [0.0, 0.2],
}
config = get_config(config_name, config_dict=config_dict)

# Split on speaker alpha into the literal and pragmatic configurations
# (assumes `sweep` yields them in the listed order: 0.0, then 0.2 — TODO confirm).
lit_configs, prag_configs = config.sweep(["speaker.alpha"])


def _results_per_temperature(speaker_configs):
    """Load results for every listener temperature scale, in sweep order."""
    temperature_sweep = speaker_configs.sweep(["listener.temperature_scale"])
    return [swept.get_results() for swept in temperature_sweep]


# One results frame per temperature scale for each speaker type.
lit_results = _results_per_temperature(lit_configs)
prag_results = _results_per_temperature(prag_configs)

# Global plotting style for every figure in this notebook.
sns.set_theme()
sns.set_context("paper")

In [None]:
def topic_kl(results, prior=None, eps=1e-08):
    """Per-row KL divergence KL(explanation topics || prior).

    Args:
        results: DataFrame whose ``explanation_topics`` column holds one topic
            distribution per row (assumed to be equal-length sequences whose
            length matches ``prior`` — TODO confirm for all result sets).
        prior: Reference topic distribution. When ``None`` (the default), the
            notebook-level ``config.listener.prior`` is used, which matches the
            original behaviour of this function.
        eps: Smoothing constant added to both distributions inside the log,
            guarding against log(0) and division by zero.

    Returns:
        1-D numpy array with one KL value per row of ``results``.
    """
    if prior is None:
        # Backward-compatible fallback to the globally configured prior.
        prior = config.listener.prior
    prior = np.asarray(prior, dtype=float)
    explanation_topics = np.array(results["explanation_topics"].values.tolist())

    # Zero-probability topics contribute 0 * log(eps / ...) = 0, the usual
    # 0 * log 0 = 0 convention for KL divergence.
    return np.sum(
        explanation_topics * np.log((explanation_topics + eps) / (prior + eps)),
        axis=-1,
    )


# Per-explanation topic KL for every temperature scale (one array per scale).
lit_kl = [topic_kl(r) for r in lit_results]
prag_kl = [topic_kl(r) for r in prag_results]

# Average each array on its own rather than np.mean(list, axis=-1): identical
# values when every temperature has the same number of rows, and it does not
# fail if the per-temperature result sets differ in length.
mu_lit_kl = np.array([np.mean(kl) for kl in lit_kl])
mu_prag_kl = np.array([np.mean(kl) for kl in prag_kl])

_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
for speaker_configs, mu_kl, label in (
    (lit_configs, mu_lit_kl, "Literal"),
    (prag_configs, mu_prag_kl, "Pragmatic"),
):
    ax.plot(speaker_configs.listener.temperature_scale, mu_kl, label=label, marker="o")
ax.set_xlabel("Listener temperature scale")
ax.set_ylabel("Topic KL divergence")
ax.set_title("Mean topic KL vs. listener temperature")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# Claim-listener counterpart of the pragmatic (alpha = 0.2) topic setting.
prag_claim_config_dict = config_dict.copy()
prag_claim_config_dict["listener.type"] = "claim"
prag_claim_config_dict["speaker.alpha"] = 0.20
# NOTE(review): "listener.temperature_scale" is still the full sweep list
# inherited from config_dict, and no sweep is applied before get_results().
# Confirm this is intended (or pin it to best_temperature below).
prag_claim_config = get_config(config_name, config_dict=prag_claim_config_dict)
prag_claim_results = prag_claim_config.get_results()

# "Best" temperature = the scale with the lowest mean pragmatic topic KL.
best_temperature_idx = int(np.argmin(mu_prag_kl))
best_temperature = prag_configs.listener.temperature_scale[best_temperature_idx]
print(f"Best temperature: index {best_temperature_idx}, scale {best_temperature}")

_lit_results = lit_results[best_temperature_idx]
_prag_results = prag_results[best_temperature_idx]
# The columns are added in place, so the frames stored in lit_results /
# prag_results gain them too. Re-running recomputes the same values.
for label, r in (("Literal", _lit_results), ("Pragmatic", _prag_results)):
    # Listener's guess = argmax over its action vector.
    r["listener_prediction"] = r["action"].apply(lambda x: np.argmax(x))
    r["correct"] = r["prediction"] == r["listener_prediction"]
    print(f"{label} listener accuracy: {r['correct'].mean():.2%}")

# Rows where both listeners are correct (boolean Series combined by index
# alignment; assumes both frames share the same index — TODO confirm).
mask = _lit_results["correct"] & _prag_results["correct"]
_lit_kl = lit_kl[best_temperature_idx]
_prag_kl = prag_kl[best_temperature_idx]

In [None]:
# Mean topic distribution of the explanations at the best temperature.
# lit_results / prag_results are *lists* with one results frame per temperature
# scale, so indexing them with a column name raises TypeError. Use the frames
# selected for the best temperature instead.
explanation_regions = np.array(_lit_results["explanation_topics"].values.tolist())
prag_explanation_regions = np.array(_prag_results["explanation_topics"].values.tolist())

region_dist = np.mean(explanation_regions, axis=0)
prag_region_dist = np.mean(prag_explanation_regions, axis=0)

# Shared fixed y-limit so the literal and pragmatic panels are directly comparable.
REGION_DIST_YMAX = 0.55

_, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))
for ax, dist, title in zip(axes, (region_dist, prag_region_dist), ("Literal", "Pragmatic")):
    ax.bar(range(len(dist)), dist)
    ax.set_ylim([0, REGION_DIST_YMAX])
    ax.set_title(title)
    ax.set_xlabel("Topic")
axes[0].set_ylabel("Mean explanation topic weight")
plt.show()

In [None]:
# Same per-topic distributions as the previous cell, but with matplotlib's
# automatic y-limits. Draw on a fresh figure: with the default inline backend,
# the previous figure is closed by its plt.show(). Re-using its `axes` here
# would render nothing, and each re-run would stack duplicate bars on the old axes.
_, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))
for ax, dist, title in zip(axes, (region_dist, prag_region_dist), ("Literal", "Pragmatic")):
    ax.bar(range(len(dist)), dist)
    ax.set_title(title)
    ax.set_xlabel("Topic")
axes[0].set_ylabel("Mean explanation topic weight")
plt.show()