In [None]:
import os
import sys
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
# Make the repository root importable so the project-local `configs` module
# resolves. Guarded so that re-running this cell does not keep appending the
# same entry to sys.path.
if root_dir not in sys.path:
    sys.path.append(root_dir)
import configs

results_dir = os.path.join(root_dir, "results")
data_dir = os.path.join(root_dir, "data")
attribute_dir = os.path.join(data_dir, "CUB", "attributes")

# topics.txt: two whitespace-separated columns per line; only the second column
# (the topic name, with underscores standing in for spaces) is kept.
# The encoding is pinned so parsing does not depend on the machine's locale, and
# blank lines (e.g. a trailing newline at the end of the file) are skipped
# instead of breaking the two-column unpacking.
with open(os.path.join(attribute_dir, "topics.txt"), "r", encoding="utf-8") as f:
    lines = [line.split() for line in f if line.strip()]
    topics = [topic.replace("_", " ") for _, topic in lines]

# attribute_topic.txt: two whitespace-separated columns per line; the second
# column is parsed as an integer topic id, one entry per line in file order.
# NOTE(review): presumably the first column is the attribute id and the topic
# ids are 1-based indices into `topics` — confirm against the data files.
with open(
    os.path.join(attribute_dir, "attribute_topic.txt"), "r", encoding="utf-8"
) as f:
    lines = [line.split() for line in f if line.strip()]
    attribute_topic = [int(topic) for _, topic in lines]

# Global plot styling for every figure in this notebook.
sns.set_theme()
sns.set_context("paper")

In [None]:
def viz_topic_distribution(config, ax):
    """Draw the average per-explanation topic frequency as a bar plot on `ax`.

    For every explanation, the positive topic ids it contains are counted and
    normalised into a distribution over `topics`; the bars show the mean of
    these distributions across explanations.

    Parameters
    ----------
    config
        Object exposing ``get_results(workdir=...)``. The returned results must
        support ``results["explanation_topics"].values.tolist()`` and yield
        equally long sequences of integer topic ids (presumably a pandas
        DataFrame column of lists — confirm against ``configs``).
    ax
        Matplotlib axes to draw on; it is modified in place.

    Notes
    -----
    * Topic id ``k > 0`` is counted in column ``k - 1``, i.e. ids are treated as
      1-based indices into the module-level ``topics`` list. Ids ``<= 0`` are
      ignored (presumably padding — confirm with the code that writes results).
    * Explanations without any positive topic id are left out of the average.
      Normalising such a row would divide by zero and the resulting NaNs would
      turn every bar of the panel into NaN.
    * If no explanation has a positive topic id, all bars are zero.
    """
    results = config.get_results(workdir=root_dir)

    explanation_topic = np.array(
        results["explanation_topics"].values.tolist(), dtype=int
    )
    num_topics = len(topics)

    # Per-explanation counts of each topic, one row per explanation.
    topic_counts = np.zeros((explanation_topic.shape[0], num_topics))
    for idx, _explanation_topic in enumerate(explanation_topic):
        topic_ids, counts = np.unique(_explanation_topic, return_counts=True)
        counted = topic_ids > 0
        # np.unique returns distinct ids, so this fancy-index assignment writes
        # each column at most once.
        topic_counts[idx, topic_ids[counted] - 1] = counts[counted]

    # Normalise only rows that actually contain a topic, then average them.
    row_totals = topic_counts.sum(axis=-1, keepdims=True)
    has_topics = row_totals[:, 0] > 0
    if has_topics.any():
        topic_distribution = np.mean(
            topic_counts[has_topics] / row_totals[has_topics], axis=0
        )
    else:
        topic_distribution = np.zeros(num_topics)

    # NOTE(review): passing a flat vector positionally relies on seaborn drawing
    # one bar per entry; this appears to depend on the installed seaborn
    # version — confirm if the environment is changed.
    sns.barplot(topic_distribution, ax=ax)
    ax.set_ylabel("Topic frequency")
    ax.set_xticks(range(num_topics))
    ax.set_xticklabels(topics, rotation=45, ha="right")
    # Same fixed y-range on every call, so panels share one scale.
    ax.set_ylim(0, 0.6)


# Hyper-parameters of the speakers being compared; they are also interpolated
# into the panel titles below.
gamma = 0.01
temperature_scale = 4.0

claim_speaker = configs.CUBPragmaticClaimConfig()
topic_speaker = configs.CUBPragmaticTopicConfig(gamma=0.0)
topic_speaker_with_penalty = configs.CUBPragmaticTopicConfig(gamma=gamma)
distribution_speaker = configs.CUBPragmaticDistributionConfig(
    temperature_scale=temperature_scale
)

# Each panel is a (speaker config, title) pair. The claim speaker and the
# penalised topic speaker appear in both figures, so they are defined once.
claim_panel = (claim_speaker, "Pragmatic claim speaker")
topic_panel = (
    topic_speaker,
    "Pragmatic topic speaker\n" + r"(no coloration, $\gamma = 0$)",
)
topic_with_penalty_panel = (
    topic_speaker_with_penalty,
    "Pragmatic topic speaker\n" + r"(no coloration, $\gamma = %.2f$)" % gamma,
)
distribution_panel = (
    distribution_speaker,
    # %g instead of %d: temperature_scale is a float, and %d would silently
    # truncate a non-integer value (e.g. 0.5 -> "0") in the title.
    "Pragmatic distribution speaker\n"
    + r"(uniform prior, $\tau = %g$)" % temperature_scale,
)

# One figure per row of panels, drawn left to right.
figure_panels = [
    [claim_panel, topic_panel, topic_with_penalty_panel],
    [claim_panel, topic_with_penalty_panel, distribution_panel],
]

for panels in figure_panels:
    fig, axes = plt.subplots(
        1, len(panels), figsize=(16 / 1.5, 9 / 4), gridspec_kw={"wspace": 0.25}
    )
    for ax, (speaker, title) in zip(axes, panels):
        viz_topic_distribution(speaker, ax)
        ax.set_title(title)
    plt.show()