In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config

# CUB
config_name = "cub"

# Named priors: each value is a six-entry weight vector, and the key is the
# label attached to the corresponding curve.
priors = dict(
    uniform=[1 / 6] * 6,
    topic=[0.0, 0.0, 1 / 3, 1 / 3, 1 / 3, 0.0],
)

# Overrides shared by every configuration built in this notebook. The list
# under "listener.temperature_scale" holds the values that are swept later.
base_config_dict = {
    "data.explanation_length": 6,
    "listener.gamma": 0.4,
    "listener.temperature_scale": [1.0, 2.0, 4.0, 8.0],
    "speaker.alpha": 0.2,
}
# CheXpert (placeholder: no CheXpert-specific settings are defined here)

# Shallow copy of the shared overrides, switched to the "claim" listener.
claim_config_dict = {**base_config_dict, "listener.type": "claim"}


def get_topic_config_dict(prior):
    """Build the override dictionary for a topic listener.

    Starts from a shallow copy of the module-level ``base_config_dict`` and
    adds the listener type and prior on top of it.

    Args:
        prior: Value stored under ``"listener.prior"``; it is stored as-is.

    Returns:
        A new dict; ``base_config_dict`` itself is left unchanged.
    """
    return {
        **base_config_dict,
        "listener.type": "topic",
        "listener.prior": prior,
    }


# Configuration object for the "claim" listener, built from the overrides in
# claim_config_dict.
claim_config = get_config(config_name, config_dict=claim_config_dict)

# Seaborn appearance applied to the figures drawn afterwards.
sns.set_style(style="white")
sns.set_context(context="paper")

In [None]:
# Output directory for the temperature figures, located under root_dir.
figure_dir = os.path.join(root_dir, "figures", "temperature")
# exist_ok=True creates the directory only when it is missing and avoids the
# race between a separate existence check and the creation call.
os.makedirs(figure_dir, exist_ok=True)


def get_kl_data(prior_name, prior):
    """Collect normalized KL divergence and listener accuracy per temperature.

    Evaluates the module-level ``claim_config`` first (recorded at temperature
    0.0 when its listener type is not ``"topic"``), followed by every
    configuration obtained by sweeping ``listener.temperature_scale`` on a
    topic-listener config built from ``prior``. For each configuration the
    mean over rows of ``sum(p * log((p + eps) / (prior + eps)))`` is computed
    from the ``explanation_topics`` column and divided by the most recent
    value measured at temperature 0.0.

    Args:
        prior_name: Label written to the ``"prior"`` entry of every record.
        prior: Reference weights for the KL term. Anything ``np.asarray``
            accepts works for the arithmetic; the value is forwarded
            unchanged to ``get_topic_config_dict``.

    Returns:
        dict with the keys ``"prior"``, ``"temperature"``, ``"kl"`` and
        ``"accuracy"``, each mapping to a list with one entry per evaluated
        configuration.
    """
    topic_config_dict = get_topic_config_dict(prior)
    # Array view used only for the KL arithmetic so that plain lists are
    # accepted as well as ndarrays.
    prior_weights = np.asarray(prior)

    topic_configs = get_config(config_name, config_dict=topic_config_dict)
    topic_configs = topic_configs.sweep(keys=["listener.temperature_scale"])

    data = {"prior": [], "temperature": [], "kl": [], "accuracy": []}
    # Set by the first configuration recorded at temperature 0.0. The claim
    # config is placed first in the loop so later entries can be normalized.
    t0_norm = None
    for config in [claim_config] + topic_configs:
        temperature = (
            config.listener.temperature_scale
            if config.listener.type == "topic"
            else 0.0
        )

        # Fetch the results a single time per configuration.
        results = config.get_results()
        # Kept as local Series so the frame returned by get_results() is not
        # modified.
        listener_prediction = results["action"].apply(lambda x: np.argmax(x))
        correct = results["prediction"] == listener_prediction
        # Assumes every row of "explanation_topics" has the same length as
        # `prior` -- TODO confirm against the results schema.
        explanation_topics = np.array(results["explanation_topics"].values.tolist())

        # The 1e-08 offsets keep the logarithm finite when an entry is zero.
        kl = np.sum(
            explanation_topics
            * np.log((explanation_topics + 1e-08) / (prior_weights + 1e-08)),
            axis=-1,
        )
        kl = np.mean(kl).item()
        accuracy = np.mean(correct).item()

        if temperature == 0.0:
            t0_norm = kl

        data["prior"].append(prior_name)
        data["temperature"].append(temperature)
        data["kl"].append(kl / t0_norm)
        data["accuracy"].append(accuracy)
    return data


# One table per prior, stacked into a single long-format DataFrame.
kl_data = pd.concat(
    [
        pd.DataFrame(get_kl_data(name, np.array(weights)))
        for name, weights in priors.items()
    ],
    ignore_index=True,
)

_, ax = plt.subplots(figsize=(16 / 6, 9 / 4))

# Arguments common to both curves drawn on the same axes.
shared_line_kwargs = dict(
    data=kl_data,
    x="temperature",
    hue="prior",
    marker="o",
    ax=ax,
)
# Accuracy: dashed and faded.
sns.lineplot(y="accuracy", linestyle="--", alpha=0.3, **shared_line_kwargs)
# Normalized KL divergence: solid line.
sns.lineplot(y="kl", linewidth=1.5, **shared_line_kwargs)

ax.set_xlabel(r"Temperature scale $\tau$")
# The y-axis label is left empty; the axes carry both accuracy and
# normalized KL divergence.
ax.set_ylabel("")
ax.set_xticks(kl_data["temperature"].unique())

# Save the figure under the lower-cased dataset name, as PNG and then as PDF.
dataset_name = claim_config.data.dataset.lower()
for extension in ("png", "pdf"):
    plt.savefig(
        os.path.join(figure_dir, f"{dataset_name}.{extension}"),
        bbox_inches="tight",
    )
plt.show()