In [None]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Make the project root importable so the shared `configs` module resolves.
# Guarded so that re-running this cell does not append the same entry again.
root_dir = "../"
if root_dir not in sys.path:
    sys.path.append(root_dir)
from configs import get_config

# Experiment configuration; its dataset name picks which results/ subfolder is read.
config_name = "cub"
config = get_config(config_name)

# Inputs: claim-classifier accuracy files and the per-dataset sweep CSVs.
claim_cls_dir = os.path.join(root_dir, "results", "claim_cls")
sweeps_dir = os.path.join(root_dir, "results", config.data.dataset.lower(), "sweeps")

# Output directory for the figures this notebook saves (created if missing).
figure_dir = os.path.join(root_dir, "figures", "sweeps")
os.makedirs(figure_dir, exist_ok=True)

# Global seaborn look: plain white background, "paper"-sized fonts and lines.
sns.set_style("white")
sns.set_context("paper")

In [None]:
# Literal-listener sweep: results over the true-negative weight gamma, one curve
# per utterance length.
sweep_literal_path = os.path.join(sweeps_dir, "sweep_all_literal.csv")
sweep_literal_df = pd.read_csv(sweep_literal_path)

# Reference accuracy of the strongly supervised claim classifier, stored as a
# single number in a per-dataset text file.
dataset_key = config.data.dataset.lower()
claim_cls_path = os.path.join(claim_cls_dir, f"{dataset_key}.txt")
with open(claim_cls_path, "r") as handle:
    claim_cls_accuracy = float(handle.read().strip())

# --- Figure 1: fraction of positive claims as a function of gamma -------------
fig_gamma, ax_gamma = plt.subplots(figsize=(16 / 7, 9 / 4))
sns.lineplot(
    data=sweep_literal_df,
    x="listener.gamma",
    y="val/explanation sentiment",
    hue="data.explanation_length",
    palette="tab10",
    ax=ax_gamma,
)
ax_gamma.set_xlabel(r"True negative weight $\gamma$")
ax_gamma.set_ylabel("Fraction of positive claims")
ax_gamma.legend(title="Utterance length")
ax_gamma.set_ylim(0, 1.05)
# One tick per swept gamma value.
ax_gamma.set_xticks(sweep_literal_df["listener.gamma"].unique())
for ext in ("pdf", "png"):
    fig_gamma.savefig(os.path.join(figure_dir, f"sweep_gamma.{ext}"), bbox_inches="tight")
plt.show()

# --- Figure 2: utterance accuracy vs. fraction of positive claims -------------
fig_acc, ax_acc = plt.subplots(figsize=(16 / 7, 9 / 4))

sns.lineplot(
    data=sweep_literal_df,
    x="val/explanation sentiment",
    y="val/explanation accuracy",
    hue="data.explanation_length",
    palette="tab10",
    ax=ax_acc,
)
# Dash-dot reference line at the claim classifier's accuracy, labelled just above it.
ax_acc.axhline(claim_cls_accuracy, linestyle="-.", color="black")
ax_acc.text(
    1.02,
    claim_cls_accuracy + 0.01,
    "Strong supervision",
    color="black",
    ha="right",
    va="bottom",
)
ax_acc.set_xlabel("Fraction of positive claims")
ax_acc.set_ylabel("Utterance accuracy")
ax_acc.set_ylim(None, 0.98)
ax_acc.legend(title="Utterance length")
for ext in ("pdf", "png"):
    fig_acc.savefig(os.path.join(figure_dir, f"sweep_accuracy.{ext}"), bbox_inches="tight")
plt.show()

In [None]:
# DPO sweep: listener accuracy vs. the speaker's DPO regularizer beta, with one
# curve per number of preference pairs.
sweep_beta_k_path = os.path.join(sweeps_dir, "sweep_beta_k.csv")
sweep_beta_k_results = pd.read_csv(sweep_beta_k_path)

# n_pref = C(k, 2) = k * (k - 1) / 2 pairs formed from speaker.k items
# (presumably the speaker's candidate utterances per input - confirm against
# the training code).
speaker_k = sweep_beta_k_results["speaker.k"].astype(int)
sweep_beta_k_results["n_pref"] = speaker_k * (speaker_k - 1) // 2

fig_beta, ax = plt.subplots(figsize=(16 / 7, 9 / 4))
sns.lineplot(
    data=sweep_beta_k_results,
    x="speaker.beta",
    y="val/listener accuracy",
    hue="n_pref",
    palette="tab10",
    ax=ax,
)
ax.set_xlabel(r"DPO regularizer $\beta$")
ax.set_ylabel("Listener accuracy")
ax.set_xticks(sweep_beta_k_results["speaker.beta"].unique())
ax.set_title(f"Utterance length: {config.data.explanation_length}")
# \mathrm renders an upright subscript in every matplotlib mathtext version;
# \text is rejected by older releases.
ax.legend(title=r"$n_{\mathrm{pref}}$")
for ext in ("pdf", "png"):
    fig_beta.savefig(os.path.join(figure_dir, f"sweep_beta.{ext}"), bbox_inches="tight")
plt.show()

In [None]:
# Listener sweep: listener accuracy vs. listener.k (labelled n_expl on the axis).
sweep_listener_k_path = os.path.join(sweeps_dir, "sweep_listener_k.csv")
sweep_listener_k_results = pd.read_csv(sweep_listener_k_path)

# NOTE(review): the utterance length shown in the title is hard-coded here, not
# read from `config` - confirm it matches the run that produced sweep_listener_k.csv.
LISTENER_K_UTTERANCE_LENGTH = 6

fig_nexpl, ax = plt.subplots(figsize=(16 / 7, 9 / 4))
sns.lineplot(
    data=sweep_listener_k_results,
    x="listener.k",
    y="val/listener accuracy",
    ax=ax,
)
# \mathrm instead of \text: supported by every matplotlib mathtext version.
ax.set_xlabel(r"$n_{\mathrm{expl}}$")
ax.set_ylabel("Listener accuracy")
ax.set_xticks(sweep_listener_k_results["listener.k"].unique())
ax.set_title(f"Utterance length: {LISTENER_K_UTTERANCE_LENGTH}")
for ext in ("pdf", "png"):
    fig_nexpl.savefig(os.path.join(figure_dir, f"sweep_nexpl.{ext}"), bbox_inches="tight")
plt.show()

In [None]:
# Pragmatic-strength sweep: listener accuracy vs. the speaker's alpha, with one
# curve per true-negative weight gamma.
sweep_alpha_path = os.path.join(sweeps_dir, "sweep_alpha.csv")
sweep_alpha_results = pd.read_csv(sweep_alpha_path)

# This CSV already uses dotted config keys (e.g. "speaker.alpha"), and the other
# sweeps name gamma "listener.gamma". Use a plain "gamma" column when present
# (the original behaviour) and otherwise fall back to the dotted key instead of
# failing on a missing column.
gamma_col = "gamma" if "gamma" in sweep_alpha_results.columns else "listener.gamma"

fig_alpha, ax = plt.subplots(figsize=(16 / 7, 9 / 4))
sns.lineplot(
    data=sweep_alpha_results,
    x="speaker.alpha",
    y="val/listener accuracy",
    hue=gamma_col,
    palette="tab10",
    ax=ax,
)
ax.set_xlabel(r"Pragmatic strength $\alpha$")
ax.set_ylabel("Listener accuracy")
ax.set_xticks(sweep_alpha_results["speaker.alpha"].unique())
ax.legend(title=r"TN weight $\gamma$")
ax.set_title("Utterance length: 6")
for ext in ("pdf", "png"):
    fig_alpha.savefig(os.path.join(figure_dir, f"sweep_alpha.{ext}"), bbox_inches="tight")
plt.show()