In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import statsmodels.api as sm

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from classifiers import get_classifier
from datasets import get_dataset

# Experiment configuration: the "cub" config with a few overrides.
# The speaker.alpha sweep must keep exactly two values, because the results
# step unpacks the sweep into one literal and one pragmatic run.
config_name = "cub"
config_dict = dict(
    [
        ("data.explanation_length", 6),
        ("listener.type", "claim"),
        ("listener.gamma", 0.4),
        ("speaker.alpha", [0.0, 0.2]),  # do not change this
    ]
)
config = get_config(config_name, config_dict=config_dict)

# Load the classifier on CPU and reuse its own preprocessing transform for the
# non-training split (train=False) of the dataset.
classifier = get_classifier(config, device="cpu")
preprocess = classifier.preprocess
dataset = get_dataset(
    config,
    train=False,
    transform=preprocess,
    return_attribute=False,
)
classes = dataset.classes

# Load the results of the two speaker.alpha runs.
# NOTE(review): assumes config.sweep yields runs in the order of the alpha list
# (0.0 -> literal, 0.2 -> pragmatic) — confirm against configs.sweep.
lit_results, prag_results = [c.get_results() for c in config.sweep(["speaker.alpha"])]
for r in [lit_results, prag_results]:
    # The listener's guess is the index of the largest entry of its action vector.
    r["listener_prediction"] = r["action"].apply(lambda x: np.argmax(x))
    # Classifier correctness is measured against the ground-truth label...
    r["cls_correct"] = r["label"] == r["prediction"]
    # ...whereas listener correctness is agreement with the classifier's prediction.
    r["listener_correct"] = r["listener_prediction"] == r["prediction"]

cls_accuracy = lit_results["cls_correct"].mean()
lit_accuracy = lit_results["listener_correct"].mean()
# Must read the pragmatic run; reading lit_results here made the pragmatic
# number an exact copy of the literal one.
prag_accuracy = prag_results["listener_correct"].mean()
print(f"Classifier accuracy: {cls_accuracy:.2%}")
print(f"Literal accuracy: {lit_accuracy:.2%}")
print(f"Pragmatic accuracy: {prag_accuracy:.2%}")

sns.set_theme()
sns.set_context("paper")

In [None]:
# Per-class analysis: regress each listener's per-class accuracy on the
# classifier's per-class accuracy, then plot both fits on one axis.
figure_dir = os.path.join(root_dir, "figures", config.data.dataset.lower())
os.makedirs(figure_dir, exist_ok=True)

# The mean of a boolean column within each label group is that class's accuracy.
lit_by_class = lit_results.groupby("label")
class_accuracy = lit_by_class["cls_correct"].mean()
lit_class_accuracy = lit_by_class["listener_correct"].mean()
prag_class_accuracy = prag_results.groupby("label")["listener_correct"].mean()

df = pd.DataFrame(
    {"class_accuracy": class_accuracy, "literal": lit_class_accuracy, "pragmatic": prag_class_accuracy}
)

# Design matrix: classifier accuracy plus a constant column for the intercept.
X_lit = sm.add_constant(df["class_accuracy"])
y_lit = df["literal"]
model_lit = sm.OLS(endog=y_lit, exog=X_lit).fit()
print("Literal model summary:", model_lit.summary(), sep="\n")

X_prag = sm.add_constant(df["class_accuracy"])
y_prag = df["pragmatic"]
model_prag = sm.OLS(endog=y_prag, exog=X_prag).fit()
print("\nPragmatic model summary:", model_prag.summary(), sep="\n")

fig, ax = plt.subplots(figsize=(3, 3))
for speaker_name, speaker_class_accuracy in (
    ("Literal", lit_class_accuracy),
    ("Pragmatic", prag_class_accuracy),
):
    sns.regplot(
        x=class_accuracy,
        y=speaker_class_accuracy,
        ax=ax,
        label=speaker_name,
        scatter_kws={"s": 5, "alpha": 0.3},
    )
ax.set(xlabel="Classifier accuracy", ylabel="Listener accuracy", title="Class-wise accuracy")
# Legend sits outside the axes, to the right, so it never covers the points.
ax.legend(title="speaker", loc="upper left", bbox_to_anchor=(1, 1))
for figure_name in ("class_accuracy.pdf", "class_accuracy.png"):
    fig.savefig(os.path.join(figure_dir, figure_name), bbox_inches="tight")
plt.show()

In [None]:
def softmax(x):
    """Numerically stable softmax along the last axis.

    Args:
        x: Array-like of scores. It may be 1-D, or N-D with the scores to
            normalize laid out along the last axis.

    Returns:
        np.ndarray with the same shape as ``x``. Along the last axis, the
        entries are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtract the max of each row (not the global max). This keeps exp() from
    # overflowing and prevents a low-scoring row from underflowing to all zeros,
    # which would make the per-row normalization below divide 0 by 0 (NaN).
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / e_x.sum(axis=-1, keepdims=True)


# Per-example analysis: compare the classifier's confidence with each listener's
# confidence. Despite the `*_logit` variable and file names, every value here is
# a max softmax probability, not a raw logit.
results_dir = os.path.join(root_dir, "results", config.data.dataset.lower())

# ':' and '/' in the classifier name are replaced so it can be used as a file name.
classifier_safe = config.data.classifier.lower().replace(":", "_").replace("/", "_")
classifier_results_path = os.path.join(results_dir, f"{classifier_safe}.pkl")
# NOTE(review): pd.read_pickle can execute arbitrary code; only load result
# files produced by this project.
classifier_results = pd.read_pickle(classifier_results_path)
cls_logit = classifier_results["logits"].apply(lambda x: np.max(softmax(x)))

for r in [lit_results, prag_results]:
    r["listener_logit"] = r["action"].apply(lambda x: np.max(softmax(x)))

lit_logit = lit_results["listener_logit"]
prag_logit = prag_results["listener_logit"]

# NOTE(review): the frame aligns the three Series by index. This assumes
# classifier_results and the listener results share the same per-example index
# — confirm.
df = pd.DataFrame(
    {
        "cls_logit": cls_logit,
        "literal": lit_logit,
        "pragmatic": prag_logit,
    }
)

# Regress each listener's confidence on the classifier's confidence, with an intercept.
X_lit = sm.add_constant(df["cls_logit"])  # Adds intercept term
y_lit = df["literal"]
model_lit = sm.OLS(y_lit, X_lit).fit()
print("Literal model summary:")
print(model_lit.summary())

X_prag = sm.add_constant(df["cls_logit"])
y_prag = df["pragmatic"]
model_prag = sm.OLS(y_prag, X_prag).fit()
print("\nPragmatic model summary:")
print(model_prag.summary())

fig, ax = plt.subplots(figsize=(3, 3))
sns.regplot(
    x=cls_logit,
    y=lit_logit,
    ax=ax,
    label="Literal",
    scatter_kws={"s": 5, "alpha": 0.3},
)
sns.regplot(
    x=cls_logit,
    y=prag_logit,
    ax=ax,
    label="Pragmatic",
    scatter_kws={"s": 5, "alpha": 0.3},
)
# The plotted values are max softmax probabilities, so label them as confidence.
ax.set_xlabel("Classifier confidence")
ax.set_ylabel("Listener confidence")
ax.legend(title="speaker", loc="upper left", bbox_to_anchor=(1, 1))
ax.set_title("Example-wise confidence")
# File names are kept unchanged so existing references to logit.pdf/png still resolve.
fig.savefig(os.path.join(figure_dir, "logit.pdf"), bbox_inches="tight")
fig.savefig(os.path.join(figure_dir, "logit.png"), bbox_inches="tight")
plt.show()