In [None]:
import os
import sys
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Global seaborn look for every figure in this notebook.
sns.set_style("white")
sns.set_context("paper")

# All paths are resolved relative to the repository root, one level above
# the notebook's working directory.
root_dir = "../"
results_dir = os.path.join(root_dir, *("results", "cub"))

In [None]:
# Containers filled while walking through every participant's trial log.
#   descriptions    : species -> list of free-text answers (response["Q0"])
#   demographics    : one entry per survey step
#   completion_data : one entry per participant (largest `time_elapsed` seen)
#   results_data    : one entry per trial step
descriptions = {}
demographics = {"age": [], "gender": [], "experience": []}
completion_data = {"pid": [], "time": []}
results_data = {
    "pid": [],
    "condition": [],
    "target": [],
    "answer": [],
}

# Raw condition identifiers -> names used in the figures. Identifiers that are
# not listed here are kept unchanged.
condition_names = {
    "random": "random speaker",
    "speaker:literal": "literal speaker",
    "speaker:pragmatic": "pragmatic speaker",
    "speaker:topic": "topic speaker",
    "vip": "V-IP",
}

responses = pd.read_csv(os.path.join(results_dir, "user_study.tsv"), sep="\t")
trials = responses["Trials"]
for trial in trials:
    # Each cell of the "Trials" column holds a JSON-encoded list of steps.
    data = json.loads(trial)
    if not data:
        # An empty log carries no participant id; recording it would either
        # fail (first row) or reuse the id of the previous participant.
        continue

    # Completion time is the largest elapsed time over all steps, floored at 0.
    completion_time = max([0] + [step["time_elapsed"] for step in data])

    # Steps are visited last-to-first, so after the loop `pid` holds the id
    # stored on the first logged step.
    for step in data[::-1]:
        phase = step["phase"]
        pid = step["prolific_pid"]

        if phase == "survey":
            response = step["response"]
            for field in ("age", "gender", "experience"):
                demographics[field].append(response[field])

        elif phase == "trial":
            condition = step["condition"]
            results_data["pid"].append(pid)
            results_data["condition"].append(condition_names.get(condition, condition))
            results_data["target"].append(step["target"])
            results_data["answer"].append(step["answer"])

        # NOTE(review): the original code matched only the misspelled
        # "spcies_examples". The correctly spelled name is accepted as well so
        # the descriptions are collected either way -- confirm which spelling
        # the experiment actually logs.
        elif phase in ("spcies_examples", "species_examples"):
            response = step["response"]
            descriptions.setdefault(step["species"], []).append(response["Q0"])

    completion_data["pid"].append(pid)
    completion_data["time"].append(completion_time)

In [None]:
# Every figure of this notebook is written below <root>/figures/user_study.
figure_dir = os.path.join(root_dir, "figures", "user_study")
os.makedirs(figure_dir, exist_ok=True)

palette = sns.color_palette("tab10")

# Survey answers as categoricals with a fixed category list, which fixes the
# order of the bars on the x axis.
age = pd.Categorical(
    demographics["age"],
    categories=["18-24", "25-34", "35-44", "45-54", "55+"],
    ordered=True,
)
gender = pd.Categorical(
    demographics["gender"],
    categories=["Female", "Male", "Non-binary", "Prefer not to say"],
    ordered=False,
)
experience = pd.Categorical(
    demographics["experience"],
    categories=[
        "No experience",
        "Some experience (I birdwatch occasionally)",
        "Moderate experience (I birdwatch regularly)",
        "High experience (I am a professional birdwatcher)",
    ],
    ordered=True,
)
demographics_df = pd.DataFrame({"age": age, "gender": gender, "experience": experience})

completion_df = pd.DataFrame(completion_data)
# presumably `time` is in milliseconds (60000 ms per minute) -- TODO confirm
completion_df["time_min"] = completion_df["time"] / 60000

_, axes = plt.subplots(
    1, 4, figsize=(16 / 1.5, 9 / 4), gridspec_kw={"hspace": 0.5, "wspace": 0.5}
)

# The first three panels are count plots that differ only in the column shown,
# the x label and the (optional) hand-picked y ticks.
count_panels = [
    ("age", "Age", [0, 5, 10, 15]),
    ("gender", "Gender", None),
    ("experience", "Experience level", [0, 5, 10, 15, 20]),
]
for ax, (column, xlabel, yticks) in zip(axes, count_panels):
    sns.countplot(data=demographics_df, x=column, color=palette[0], ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Count")
    if yticks is not None:
        ax.set_yticks(yticks)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.tick_params(axis="x", which="both", direction="out", length=4, color="black")
    ax.xaxis.set_ticks_position("bottom")

# Last panel: histogram of completion times with the median marked in red.
ax = axes[3]
median_time = completion_df["time_min"].median()
sns.histplot(data=completion_df, x="time_min", color=palette[0], alpha=1.0, ax=ax)
ax.axvline(median_time, color="red", linestyle="--")
ax.text(x=median_time + 3, y=12, s=f"Median\n({median_time:.0f} min)", color="red")
ax.set_xlabel("Completion time (min)")
ax.set_yticks([0, 5, 10, 15])

for extension in ("pdf", "png"):
    plt.savefig(os.path.join(figure_dir, f"cohort.{extension}"), bbox_inches="tight")
plt.show()

In [None]:
# One row per trial response; `condition` becomes an ordered categorical so
# that the explainers always appear in this fixed order.
results_df = pd.DataFrame(results_data)
results_df["condition"] = pd.Categorical(
    results_df["condition"],
    categories=[
        "random speaker",
        "V-IP",
        "literal speaker",
        "pragmatic speaker",
        "topic speaker",
    ],
    ordered=True,
)

# Number of responses per (condition, target) pair.
# `observed=False` is passed explicitly: it keeps every category of
# `condition` in the result and does not depend on the pandas default, which
# is deprecated and changing between versions.
condition_df = results_df.groupby(["condition", "target"], observed=False)[
    "pid"
].count()

# Pivot to a conditions x targets table. Pairs without any response become 0
# instead of silently breaking a fixed-size reshape.
count_matrix = condition_df.unstack("target", fill_value=0)
conditions = count_matrix.index
targets = count_matrix.columns

_, ax = plt.subplots(figsize=(3, 3))
sns.heatmap(
    count_matrix.to_numpy(),
    annot=True,
    fmt="d",
    ax=ax,
    cbar=False,
    xticklabels=True,
    yticklabels=True,
)
# Replace the default positional tick labels by the species / explainer names.
ax.set_xticklabels(targets, rotation=45, ha="right")
ax.set_yticklabels(conditions, rotation=0)
ax.set_xlabel("Species")
ax.set_ylabel("Explainer")
ax.set_title("Number of responses")
ax.tick_params(
    axis="both",
    which="both",
    direction="out",
    length=4,
    width=1,
    color="black",
    bottom=True,
    left=True,
)
plt.savefig(os.path.join(figure_dir, "responses.pdf"), bbox_inches="tight")
plt.savefig(os.path.join(figure_dir, "responses.png"), bbox_inches="tight")
plt.show()

In [None]:
# A response counts as correct when the chosen answer equals the target.
results_df["correct"] = results_df["target"].eq(results_df["answer"])

_, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))
overall_ax, species_ax = axes

# Left panel: accuracy per explainer.
sns.barplot(
    data=results_df,
    x="condition",
    y="correct",
    hue="condition",
    palette="tab10",
    ax=overall_ax,
)
# Right panel: accuracy per bird species, split by explainer.
sns.barplot(data=results_df, x="target", y="correct", hue="condition", ax=species_ax)

# Both panels share the y label and the rotated tick labels.
for panel, xlabel in ((overall_ax, "Explainer"), (species_ax, "Bird species")):
    panel.set_xlabel(xlabel)
    panel.set_ylabel("Accuracy")
    panel.set_xticklabels(panel.get_xticklabels(), rotation=45, ha="right")

# Place the legend of the right panel outside the axes.
species_ax.legend(title="explainer", loc="upper left", bbox_to_anchor=(1, 1))

for extension in ("pdf", "png"):
    plt.savefig(os.path.join(figure_dir, f"accuracy.{extension}"), bbox_inches="tight")
plt.show()

In [None]:
# Print at most five collected free-text descriptions for every species.
header = ["Example descriptions:", "==================================="]
print("\n".join(header))
for species in descriptions:
    print(f"species: {species:<30}")
    for text in descriptions[species][:5]:
        print(f"\t `{text}`")

In [None]:
# One row-normalised confusion matrix (true species x answered species) per
# explainer, saved as <figure_dir>/confusion/<condition>.{pdf,png}.
confusion_dir = os.path.join(figure_dir, "confusion")
os.makedirs(confusion_dir, exist_ok=True)

# A single label list shared by all conditions. Passing it to
# `confusion_matrix` fixes the shape and order of every matrix, so rows and
# columns always line up with the tick labels, even when a condition lacks
# some species or an answer never occurs as a target. It also removes the
# dependency on `targets` computed in an earlier cell.
species_labels = sorted(set(results_df["target"]) | set(results_df["answer"]))

# Grouping by the bare column name yields scalar keys regardless of the pandas
# version; `observed=True` skips categories without any response, for which no
# confusion matrix can be computed.
condition_df = results_df.groupby("condition", observed=True)
for condition, group in condition_df:
    # File-system friendly variant of the condition name.
    condition_safe = condition.lower().replace("-", "_").replace(" ", "_")

    c = confusion_matrix(
        group["target"],
        group["answer"],
        labels=species_labels,
        normalize="true",
    )

    fig, ax = plt.subplots(figsize=(3, 3))
    sns.heatmap(
        c,
        annot=True,
        fmt=".2f",
        ax=ax,
        cbar=False,
        xticklabels=species_labels,
        yticklabels=species_labels,
    )
    ax.tick_params(
        axis="both",
        which="both",
        direction="out",
        length=4,
        width=1,
        color="black",
        bottom=True,
        left=True,
    )
    ax.set_title(condition)
    plt.savefig(
        os.path.join(confusion_dir, f"{condition_safe}.pdf"),
        bbox_inches="tight",
    )
    plt.savefig(
        os.path.join(confusion_dir, f"{condition_safe}.png"),
        bbox_inches="tight",
    )
    plt.show()
    # Release the figure so that figures do not pile up across iterations.
    plt.close(fig)