In [None]:
import os
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Directory layout, relative to the notebook's working directory.
root_dir = "../"
results_dir, data_dir = (os.path.join(root_dir, sub) for sub in ("results", "data"))
# CUB attribute/topic metadata lives under the data directory.
attribute_dir = os.path.join(data_dir, "CUB", "attributes")

def _read_whitespace_pairs(path):
    """Read a two-column, whitespace-separated text file into a list of string pairs.

    Blank lines (e.g. a trailing empty line) are skipped. A non-blank line that does
    not have exactly two fields raises ValueError, as before.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue
            key, value = fields
            pairs.append((key, value))
    return pairs


# topics.txt: "<topic id> <topic_name>" per line. Only the names are kept, in file
# order, with underscores turned into spaces. The next cell matches topic ids
# 1..len(topics) positionally, so this assumes ids run 1..N in file order — TODO confirm.
topics = [
    name.replace("_", " ")
    for _, name in _read_whitespace_pairs(os.path.join(attribute_dir, "topics.txt"))
]

# attribute_topic.txt: "<attribute> <topic id>" per line -> {attribute: int topic id}.
attribute_topic = {
    attribute: int(topic)
    for attribute, topic in _read_whitespace_pairs(
        os.path.join(attribute_dir, "attribute_topic.txt")
    )
}

# Seaborn's default theme, using the "paper" plotting context for figure sizing.
sns.set_theme(context="paper")

In [None]:
# Uniform reference distribution over all topics.
prior = np.full(len(topics), 1 / len(topics))

run_name = "cub_claim_10_0.0_0.4_8"
df = pd.read_parquet(os.path.join(results_dir, f"{run_name}.parquet"))

# Floor added to the explanation distribution so log(prior / q) stays finite for
# topics an explanation never mentions.
KL_EPS = 1e-08

# Topic ids per explanation. np.array(...) yields a rectangular
# (num_rows, num_claims) int array only if every row has the same length — TODO confirm.
explanation_topic = np.array(df["explanation_topics"].values.tolist()).astype(int)
# Topic ids are matched against 1..len(topics); ids outside that range are not counted.
topic_mask = np.arange(len(topics)) + 1
explanation_topic_mask = explanation_topic[..., None] == topic_mask
# Per-row count of each topic, normalized into a distribution over topics.
topic_counts = np.sum(explanation_topic_mask, axis=-2).astype(float)
topic_totals = np.sum(topic_counts, axis=-1, keepdims=True)
# A row with no in-range topic would be 0/0 = NaN (and poison its KL and temperature).
# Leave such rows all-zero instead; the KL below then treats every topic as missing.
explanation_topic_distribution = np.divide(
    topic_counts,
    topic_totals,
    out=np.zeros_like(topic_counts),
    where=topic_totals > 0,
)

# KL(prior || explanation topic distribution), one value per row.
topic_kl = np.sum(
    prior * np.log(prior / (explanation_topic_distribution + KL_EPS)), axis=-1
)
# KL is >= 0 up to the epsilon, so log(KL + e) is roughly >= 1 and the temperature
# lies in roughly (0, 1], shrinking as an explanation's topics diverge from uniform.
temperature = torch.tensor(1 / np.log(topic_kl + np.e))
print(topic_kl.min(), topic_kl.max())

# Listener scores per row; the softmax below treats the last axis as the action axis.
listener_action = torch.from_numpy(np.array(df["listener_action"].values.tolist()))
num_actions = listener_action.shape[-1]

# NOTE: despite the "_logit" names (kept in case later cells use them), these hold the
# listener's confidence, i.e. the largest softmax probability per row — not logits.
pre_temperature_logit = torch.softmax(listener_action, dim=-1).amax(dim=-1)
# `temperature` multiplies the scores, so values below 1 flatten the distribution.
post_temperature_logit = torch.softmax(
    temperature[:, None] * listener_action, dim=-1
).amax(dim=-1)

fig, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4), gridspec_kw={"wspace": 0.25})

ax = axes[0]
sns.histplot(topic_kl, ax=ax)
ax.set_xlabel("Topic KL divergence")

ax = axes[1]
# A max softmax probability is never below 1 / num_actions, so draw y = x over [1/K, 1].
min_confidence = 1 / num_actions
ax.plot([min_confidence, 1], [min_confidence, 1], color="black", linestyle="--")
# Hand seaborn plain NumPy arrays rather than relying on it to coerce torch tensors.
sns.scatterplot(
    x=pre_temperature_logit.numpy(),
    y=post_temperature_logit.numpy(),
    hue=topic_kl,
    ax=ax,
)
ax.set_xlabel("Pre-temperature max probability")
ax.set_ylabel("Post-temperature max probability")
legend = ax.get_legend()
if legend is not None:
    legend.set_title("Topic KL")
plt.show()