In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torchvision.utils import make_grid
from sklearn.metrics import pairwise_distances

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from classifiers import get_classifier
from datasets import get_dataset
from notebooks.utils import viz_explanation

# Experiment configuration. The sweep over `speaker.alpha` yields one run per
# listed value; the unpacking further down assumes the order [0.0, 0.2] maps to
# (literal, pragmatic) speakers — presumably the sweep preserves list order.
config_name = "cub"
config_dict = {
    "data.explanation_length": 12,
    "listener.type": "claim",
    "listener.gamma": 0.4,
    "speaker.alpha": [0.0, 0.2],  # do not change this
}
config = get_config(config_name, config_dict=config_dict)

classifier = get_classifier(config, device="cpu")
preprocess = classifier.preprocess

dataset = get_dataset(config, train=False, transform=preprocess, return_attribute=True)
classes = dataset.classes

lit_results, prag_results = [c.get_results() for c in config.sweep(["speaker.alpha"])]
for r in (lit_results, prag_results):
    # Listener's decoded class: argmax over its `action` vector (assumed to be
    # one score per class — TODO confirm).
    r["listener_prediction"] = r["action"].apply(lambda x: np.argmax(x))
    # "correct" means the listener recovers the classifier's prediction,
    # not the ground-truth label.
    r["correct"] = r["listener_prediction"] == r["prediction"]

# Examples where the classifier is right, the literal listener is wrong and the
# pragmatic listener is right. `idx` holds positional indices, valid for both
# `dataset[...]` and `results.iloc[...]`.
is_pragmatic_win = (
    (lit_results["label"] == lit_results["prediction"])
    & ~lit_results["correct"]
    & prag_results["correct"]
)
idx = np.flatnonzero(is_pragmatic_win.to_numpy())

sns.set_theme()
sns.set_context("paper")

In [None]:
# Export side-by-side literal vs. pragmatic explanations for a sample of the
# examples selected in `idx` (classifier right, literal listener wrong,
# pragmatic listener right). Each figure is written as both PDF and PNG.
figure_dir = os.path.join(
    root_dir, "figures", config.data.dataset.lower(), "explanations"
)
os.makedirs(figure_dir, exist_ok=True)

m = 10
rng = np.random.default_rng(0)  # fixed seed: Restart & Run All exports the same figures
# Cap the sample size so a selection smaller than `m` does not raise ValueError.
example_idx = rng.choice(idx, size=min(m, len(idx)), replace=False)
speakers = [("Literal speaker", lit_results), ("Pragmatic speaker", prag_results)]
for _idx in example_idx:
    image, label, _ = dataset[_idx]

    idx_lit_results = lit_results.iloc[_idx]
    idx_prag_results = prag_results.iloc[_idx]
    # Both speakers explain the same classifier, so its prediction must agree.
    assert idx_lit_results["prediction"] == idx_prag_results["prediction"]
    cls_prediction = idx_prag_results["prediction"]

    # Map back to [0, 1] for display (presumably undoing a mean=std=0.5
    # normalization in `preprocess` — TODO confirm).
    image = image * 0.5 + 0.5

    fig, axes = plt.subplots(
        1,
        3,
        figsize=(16 / 2, 9 / 4),
        gridspec_kw={"width_ratios": [3, 1, 1], "wspace": 1.0},
    )
    ax = axes[0]
    ax.imshow(image.permute(1, 2, 0))
    ax.axis("off")
    ax.set_title(f"Label: {classes[label]}\nPrediction: {classes[cls_prediction]}")

    # One explanation panel per speaker; prefix the helper's own title with the speaker.
    for ax, (speaker_name, speaker_results) in zip(axes[1:], speakers):
        viz_explanation(dataset, speaker_results, _idx, ax)
        ax.set_title(f"{speaker_name}\n{ax.get_title()}")

    for ext in ("pdf", "png"):
        fig.savefig(os.path.join(figure_dir, f"{_idx}.{ext}"), bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
def _viz_explanation(dataset, results, idx, ax):
    """Draw the [CLS] attention weights over one example's explanation on ``ax``.

    Appears to be a local draft of ``notebooks.utils.viz_explanation`` (same call
    signature as its use in the figure-export cell).

    Args:
        dataset: dataset where ``dataset[idx]`` returns
            ``(image, label, image_attribute)`` and which exposes ``claims``
            and ``classes``.
        results: results frame with one row per dataset item (addressed
            positionally) and columns ``listener_prediction``, ``explanation``
            and ``cls_attention``.
        idx: positional index of the example.
        ax: matplotlib Axes to draw on.

    Returns:
        ``ax``. It holds one horizontal bar per explanation token plus a final
        ``[CLS]`` bar, with x-label "Attention weights" and title
        "Action: <listener class>".
    """
    _, _, image_attribute = dataset[idx]

    claims = dataset.claims
    # Token ids >= len(claims) are the three special tokens appended after the
    # claim vocabulary.
    vocab_size = len(claims) + 3
    special_tokens = {
        vocab_size - 3: "[BOS]",
        vocab_size - 2: "[EOS]",
        vocab_size - 1: "[PAD]",
    }

    row = results.iloc[idx]
    listener_prediction = row["listener_prediction"]
    explanation = row["explanation"]
    # Assumed to hold one weight per explanation token plus one for [CLS] — TODO confirm.
    cls_attn_weights = row["cls_attention"]

    # Each explanation entry is a (token id, value) pair. For a real claim, the
    # label shows "<claim> (<value>/<image's attribute>)".
    # NOTE(review): attribute_to_human_readable is not defined or imported in
    # this notebook. It presumably lives in notebooks.utils; import it before
    # using this helper on explanations containing real claims.
    explanation_claims = [
        (
            f"{attribute_to_human_readable(claims[claim])} ({value}/{image_attribute[claim]:.0f})"
            if claim < len(claims)
            else special_tokens[claim]
        )
        for claim, value in explanation
    ]
    explanation_claims = explanation_claims + ["[CLS]"]

    # One bar per token position (not per label), so repeated tokens such as
    # "[PAD]" are drawn separately instead of merged; first token at the top.
    positions = list(range(len(explanation_claims)))
    ax.barh(positions, cls_attn_weights)
    ax.set_yticks(positions)
    ax.set_yticklabels(explanation_claims)
    ax.invert_yaxis()
    ax.set_xlabel("Attention weights")
    ax.set_title(f"Action: {dataset.classes[listener_prediction]}")
    return ax


# Quick look at the literal speaker on random test images (any outcome): the
# image with its label/prediction on the left, the explanation on the right.
n_examples = 10
rng = np.random.default_rng(0)  # fixed seed so Restart & Run All shows the same images
example_idx = rng.choice(len(dataset), size=min(n_examples, len(dataset)), replace=False)
# Use a dedicated loop variable: the global `idx` (indices selected in the first
# cell) is read by the figure-export cell and must not be overwritten here.
for sample_idx in example_idx:
    image, label, _ = dataset[sample_idx]
    # Map back to [0, 1] for display (presumably undoing a mean=std=0.5
    # normalization in `preprocess` — TODO confirm).
    image = image * 0.5 + 0.5
    prediction = lit_results.iloc[sample_idx]["prediction"]

    fig, axes = plt.subplots(
        1,
        2,
        figsize=(16 / 2, 9 / 4),
        gridspec_kw={"width_ratios": [3, 1], "wspace": 0.7},
    )
    ax = axes[0]
    ax.imshow(image.permute(1, 2, 0))
    ax.axis("off")
    ax.set_title(f"Label: {classes[label]}\nPrediction: {classes[prediction]}")

    # Same helper (and call signature) the figure-export cell uses for each
    # speaker. It replaces an inline copy that relied on names (claims,
    # special_tokens, attribute_to_human_readable) never defined at notebook level.
    viz_explanation(dataset, lit_results, sample_idx, axes[1])
    plt.show()
    plt.close(fig)  # this loop creates many figures; release each one