In [None]:
import os
import sys
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Which study to analyse, given as a path relative to the results folder.
study_name = "cub/human_evaluation/pilot"

# Repository root relative to this notebook's location.
root_dir = "../"

# Folder that is scanned for the per-participant response files.
results_dir = os.path.join(root_dir, "results", study_name)

In [None]:
# Parse every response file into three containers:
#   descriptions - free-text species descriptions, keyed by species
#   demographics - one entry per survey step
#   results_df   - one entry per trial step (column lists, not yet a DataFrame)
descriptions = {}
demographics = {"age": [], "gender": [], "experience": []}
results_df = {
    "pid": [],
    "experience_level": [],
    "condition": [],
    "target": [],
    "answer": [],
}

# Display names for the raw condition identifiers stored in trial steps.
# Identifiers that are not listed here are kept unchanged.
condition_labels = {
    "random": "random speaker",
    "speaker:literal": "literal speaker",
    "speaker:pragmatic": "pragmatic speaker",
    "speaker:topic": "topic speaker",
    "vip": "V-IP",
}

# Sorted so that row order (and which descriptions come first) does not
# depend on the arbitrary order returned by the filesystem.
responses = sorted(os.listdir(results_dir))
for response_file in responses:
    response_path = os.path.join(results_dir, response_file)
    if not os.path.isfile(response_path):
        # Sub-directories cannot be opened as response files; skip them.
        continue

    # Context manager so the file handle is closed after parsing.
    with open(response_path) as f:
        data = json.load(f)

    # Reset for every file: without this, a file with no (or an unrecognized)
    # survey answer would silently reuse the previous file's level.
    experience_level = float("nan")

    # Steps are walked in reverse so that the survey is seen before the trials
    # that record its experience level.
    # NOTE(review): assumes the survey step comes after the trials in the file - confirm.
    for step in data[::-1]:
        phase = step["phase"]
        pid = step["prolific_pid"]

        if phase == "survey":
            survey = step["response"]
            experience = survey["experience"]

            level = None
            if experience == "No experience":
                level = 0
            if "Some experience" in experience:
                level = 1
            if "Moderate experience" in experience:
                level = 2
            if level is None:
                # NOTE(review): the demographics plot labels ticks 0-3, so a
                # fourth answer option probably exists; its wording is not
                # known here, so such answers are recorded as missing (NaN).
                print(
                    f"Unrecognized experience answer {experience!r} "
                    f"in {response_file}; recorded as missing"
                )
                level = float("nan")
            experience_level = level

            demographics["age"].append(survey["age"])
            demographics["gender"].append(survey["gender"])
            demographics["experience"].append(experience_level)

        if phase == "trial":
            condition = step["condition"]
            condition = condition_labels.get(condition, condition)

            results_df["pid"].append(pid)
            results_df["experience_level"].append(experience_level)
            results_df["condition"].append(condition)
            results_df["target"].append(step["target"])
            results_df["answer"].append(step["answer"])

        # "spcies_examples" is the spelling this notebook has always matched;
        # the correctly spelled name is accepted as well.
        # NOTE(review): confirm which spelling the response files actually use.
        if phase in ("spcies_examples", "species_examples"):
            answers = step["response"]
            species = step["species"]
            descriptions.setdefault(species, []).append(answers["Q0"])

In [None]:
# Demographics overview: one histogram per survey question, side by side.
_, axes = plt.subplots(
    nrows=1,
    ncols=3,
    figsize=(16 / 2, 9 / 4),
    gridspec_kw={"hspace": 0.5, "wspace": 0.5},
)

# Panel 1: age. The brackets are declared as an ordered categorical so the
# axis follows bracket order instead of order of appearance in the data.
age_order = ["18-24", "25-34", "35-44", "45-54", "55+"]
age = demographics["age"]
df = pd.DataFrame(
    {"age": pd.Categorical(age, categories=age_order, ordered=True)}
)
ax = axes[0]
sns.histplot(data=df, x="age", ax=ax)
ax.set_xlabel("Age")

# Panel 2: gender, plotted straight from the collected answers.
ax = axes[1]
sns.histplot(demographics["gender"], ax=ax)
ax.set_xlabel("Gender")

# Panel 3: experience level; discrete bins because the levels are integer codes.
ax = axes[2]
sns.histplot(demographics["experience"], discrete=True, ax=ax)
ax.set_xticks([0, 1, 2, 3])
ax.set_xlabel("Experience level")
plt.show()

In [None]:
# Turn the per-trial column lists into a DataFrame and give the conditions a
# fixed display order (used by this heatmap and by any later plot of the frame).
results_df = pd.DataFrame(results_df)
results_df["condition"] = pd.Categorical(
    results_df["condition"],
    categories=[
        "random speaker",
        "V-IP",
        "literal speaker",
        "pragmatic speaker",
        "topic speaker",
    ],
    ordered=True,
)

# Number of responses for every (condition, species) pair.
# observed=False is passed explicitly so that every declared condition is
# reported regardless of the pandas version's default for categorical groupers.
condition_df = results_df.groupby(["condition", "target"], observed=False)[
    "pid"
].count()

# Pivot to a condition x species table; pairs without any trial become 0.
# This replaces a manual reshape that required the groupby result to contain
# the complete grid and failed (or misaligned) when a pair was missing.
count_matrix = condition_df.unstack(level="target", fill_value=0)
conditions = count_matrix.index
targets = count_matrix.columns

_, ax = plt.subplots(figsize=(3, 3))
sns.heatmap(
    count_matrix,
    annot=True,
    fmt="d",
    ax=ax,
    cbar=False,
    # Force one tick label per row/column so labels always match the cells.
    xticklabels=True,
    yticklabels=True,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
ax.set_xlabel("Species")
ax.set_ylabel("Condition")
ax.set_title("Number of responses per condition and species")
plt.show()

In [None]:
# A trial is marked correct when the given answer equals the target species.
results_df["correct"] = results_df["target"].eq(results_df["answer"])

# Figure 1: accuracy per explainer condition.
_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
sns.barplot(data=results_df, x="condition", y="correct", ax=ax)
ax.set(xlabel="Explainer", ylabel="Accuracy")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.show()

# Figure 2: accuracy per bird species, split by explainer condition.
_, ax = plt.subplots(figsize=(16 / 3, 9 / 4))
sns.barplot(
    data=results_df,
    x="target",
    y="correct",
    hue="condition",
    ax=ax,
)
ax.set(xlabel="Bird species", ylabel="Accuracy")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
# Anchor the legend's upper-left corner at the axes' upper-right corner.
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# Show up to five free-text descriptions for each species.
# All lines are collected first and written with a single print call.
report_lines = [
    "Example descriptions:",
    "===================================",
]
for species, species_descriptions in descriptions.items():
    report_lines.append(f"species: {species:<30}")
    report_lines.extend(
        f"\t `{description}`" for description in species_descriptions[:5]
    )
print("\n".join(report_lines))