In [None]:
import os
import sys
import shutil
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision.transforms as T
from torchvision.utils import make_grid

# Make the repository root importable so the project packages resolve.
root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset

# Experiment under analysis: the "cub" config with the overrides below.
# "listener.type" holds two values; the sweep further down unpacks into one
# config per value (claim first, then topic).
config_name = "cub"
config_dict = {
    "data.explanation_length": 6,
    "speaker.alpha": 0.2,
    "listener.type": ["claim", "topic"],
    "listener.gamma": 0.4,
    "listener.temperature_scale": 4.0,
}
config = get_config(config_name, config_dict=config_dict)

# Evaluation split (train=False), resized and center-cropped to 224x224 tensors.
preprocessing = [T.Resize(224), T.CenterCrop((224, 224)), T.ToTensor()]
transform = T.Compose(preprocessing)
dataset = get_dataset(
    config,
    train=False,
    transform=transform,
    return_attribute=False,
)
classes = dataset.classes

# One results table per listener type.
claim_config, topic_config = config.sweep(keys=["listener.type"])
results = {
    "claim": claim_config.get_results(),
    "topic": topic_config.get_results(),
}
claim_results, topic_results = results["claim"], results["topic"]

# Derived columns, added in place to each results table:
#   listener_prediction - argmax of the listener's "action" entry
#   classifier_correct  - classifier prediction equals the ground-truth label
#   listener_correct    - listener prediction equals the classifier prediction
for frame in results.values():
    listener_choice = frame["action"].apply(lambda scores: np.argmax(scores))
    frame["listener_prediction"] = listener_choice
    frame["classifier_correct"] = frame["prediction"] == frame["label"]
    frame["listener_correct"] = frame["listener_prediction"] == frame["prediction"]

# Plot styling shared by the rest of the notebook.
sns.set_style("white")
sns.set_context("paper")

In [None]:
# Rank classes by accuracy averaged over all speakers and print the top `m`.
#
# For each speaker, a class's score is the mean of its classifier accuracy and
# its listener accuracy (the `classifier_correct` / `listener_correct` columns
# of the results tables). Each printed line shows the class name, its index,
# and the per-speaker scores in the iteration order of `results`.
m = 45  # number of top-ranked classes to print

speaker_mean = []
for r in results.values():
    acc = r.groupby("label")[["classifier_correct", "listener_correct"]].mean()
    # Align scores to class indices: position j must refer to classes[j] even
    # if a class has no rows for this speaker (it then gets NaN). Without this,
    # the positional lookups below would silently point at the wrong class.
    # Assumes `label` holds integer indices into `classes`.
    speaker_mean.append(acc.mean(axis=1).reindex(range(len(classes))))

speaker_mean = np.array(speaker_mean)  # rows: speakers, columns: class indices
mu = np.mean(speaker_mean, axis=0)
sorted_idx = np.argsort(mu)[::-1]
# NaN scores sort to the end of argsort, i.e. to the front after reversal;
# drop them so classes without data never appear among the "best" ones.
sorted_idx = sorted_idx[~np.isnan(mu[sorted_idx])]
for idx in sorted_idx[:m]:
    print(f"{classes[idx]:<30} {idx}: {speaker_mean[:, idx]}")

# NOTE: no figure is created here because the scatter plot below is disabled.
# If it is re-enabled, restore `_, ax = plt.subplots(figsize=(3, 3))` first.

# m = 25
# for idx, (speaker, r) in enumerate(results.items()):
#     gr = r.groupby(["label"])
#     acc = gr[["classifier_correct", "listener_correct"]].mean()
#     acc["mean"] = acc.mean(axis=1)
#     acc["count"] = gr.size()
#     acc.sort_values("mean", ascending=False, inplace=True)

#     prev = ""

#     print(f"{speaker} speaker summary:")
#     print("label, classifier accuracy, listener accuracy, count")

#     selected_classes = []
#     for label, row in acc.iterrows():
#         class_name = classes[label]

#         # if any([chunk in prev.split() for chunk in class_name.split()]):
#         #     continue
#         # prev += f" {class_name}"
#         selected_classes.append(label)

#         if (idx + 1) > m:
#             break
#         print(
#             f"\t{class_name:<30}: {row['classifier_correct']:<5.2%}"
#             f"\t{row['listener_correct']:<5.2%}"
#             f"\t({row['count']:.0f})"
#         )

#         if len(selected_classes) == m:
#             break

#     ax.scatter(acc["classifier_correct"], acc["listener_correct"], label=speaker)
#     print(selected_classes)

# ax.set_xlabel("Classifier accuracy")
# ax.set_ylabel("Listener accuracy")
# ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
# plt.show()

In [None]:
# Export every dataset image of the selected classes as an individual PNG
# (used as examples for the human evaluation) and display it inline.
figure_dir = os.path.join(root_dir, "figures", "cub", "human_evaluation", "examples")
# Start from an empty directory so images from a previous run do not linger.
if os.path.isdir(figure_dir):
    shutil.rmtree(figure_dir)
os.makedirs(figure_dir, exist_ok=True)

# Class index of every sample, in dataset order.
label = np.array([target for _, target in dataset.samples])
selected_classes = [179, 187, 138, 46, 158]
# selected_classes = [100, 139, 105, 98, 179, 52, 80, 76, 74, 46]

for class_idx in selected_classes:
    class_name = classes[class_idx]
    print(f"Class: {class_name}")
    # File name prefix shared by all images of this class.
    file_stem = class_name.lower().replace(" ", "_")

    image_idx = np.where(label == class_idx)[0]
    # To export only a random subset of m images per class:
    # image_idx = np.random.permutation(image_idx)[:m]

    for i, idx in enumerate(image_idx):
        fig, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
        image = dataset[idx][0]
        # Channels-first tensor -> channels-last array, as imshow expects.
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.axis("off")
        fig.savefig(
            os.path.join(figure_dir, f"{file_stem}_{i}.png"),
            bbox_inches="tight",
            dpi=300,
        )
        # Show and release each figure immediately; keeping one open figure per
        # image until the end of the class accumulates memory and triggers
        # matplotlib's "too many open figures" warning.
        plt.show()
        plt.close(fig)