In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

# Restrict CUDA to the device with index 1. This assignment runs before the
# first CUDA query below (`torch.cuda.is_available()`).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# The project modules are imported from the parent directory, so it has to be
# on `sys.path` before the imports that follow.
root_dir = "../"
sys.path.append(root_dir)
from datasets import CUB
from classifiers import CLIPClassifier
from speaker_model import ClaimSpeaker

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_dir = os.path.join(root_dir, "data")
weights_dir = os.path.join(root_dir, "weights")

backbone = "ViT-L/14"
classifier = CLIPClassifier(backbone, device=device)

# Held-out split of CUB; its class names and claims parameterise both speakers.
dataset = CUB(data_dir, train=False)
classes = dataset.classes
claims = dataset.claims


def load_speaker(weights_name):
    """Build a `ClaimSpeaker` and restore its weights from a checkpoint.

    Args:
        weights_name: File name of the checkpoint inside `weights_dir`.

    Returns:
        A `ClaimSpeaker` built on the shared `classifier`, with the state
        stored under the checkpoint's "speaker" key loaded into it.
    """
    model = ClaimSpeaker(classifier, len(classes), claims, device=device)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(
        os.path.join(weights_dir, weights_name), map_location=device
    )
    model.load_state_dict(checkpoint["speaker"])
    return model


# The two speakers share the architecture and differ only in the checkpoint.
speaker = load_speaker("cub_10_0.0_0.4_8.pt")
pragmatic_speaker = load_speaker("cub_10_0.1_0.4_8.pt")

sns.set_theme()
sns.set_context("paper")

In [None]:
def unit_normalize(embeddings):
    """Return a copy of `embeddings` with each row scaled to unit L2 norm.

    The division is out of place on purpose: the inputs used below are the
    speakers' `class_embedding.weight` parameters, and an in-place `/=` would
    overwrite the loaded model weights.

    Args:
        embeddings: Tensor whose last dimension is the embedding dimension.

    Returns:
        A new tensor of the same shape; the input is left untouched.
    """
    return embeddings / torch.norm(embeddings, dim=-1, keepdim=True)


def similarity_matrix(embeddings):
    """Return the pairwise dot products of the rows of `embeddings` as a NumPy array.

    Args:
        embeddings: 2-D tensor with one embedding per row.

    Returns:
        A square `numpy.ndarray` with entry (i, j) equal to the dot product of
        rows i and j (the cosine similarity when the rows have unit norm).
    """
    return torch.matmul(embeddings, embeddings.T).cpu().numpy()


with torch.no_grad():
    class_prompts = [f"This is a picture of a {c}" for c in classes]
    # NOTE(review): the CLIP embeddings are used as returned, without explicit
    # normalisation - presumably `encode_text` already returns unit-norm
    # vectors; confirm against `CLIPClassifier`.
    clip_class_embeddings = classifier.encode_text(class_prompts)
    clip_class_corr = similarity_matrix(clip_class_embeddings)

    speaker_class_embeddings = unit_normalize(speaker.class_embedding.weight)
    speaker_class_corr = similarity_matrix(speaker_class_embeddings)

    pragmatic_speaker_class_embeddings = unit_normalize(
        pragmatic_speaker.class_embedding.weight
    )
    pragmatic_speaker_class_corr = similarity_matrix(
        pragmatic_speaker_class_embeddings
    )

# Disabled: optional reordering of the classes by reverse Cuthill-McKee on the
# CLIP similarity graph, applied consistently to all three matrices.
# clip_class_graph = csr_matrix(clip_class_corr - np.eye(len(classes)))
# class_order = reverse_cuthill_mckee(clip_class_graph)

# ordered_clip_class_corr = clip_class_corr[class_order][:, class_order]
# ordered_speaker_class_corr = speaker_class_corr[class_order][:, class_order]
# ordered_pragmatic_speaker_class_corr = pragmatic_speaker_class_corr[class_order][
#     :, class_order
# ]

# One heatmap per embedding source, side by side, without ticks or colour bars.
panels = [
    ("CLIP class similarities", clip_class_corr),
    ("Speaker class similarities", speaker_class_corr),
    ("Pragmatic speaker class similarities", pragmatic_speaker_class_corr),
]
_, axes = plt.subplots(1, len(panels), figsize=(16 / 2, 9 / 4))
for ax, (title, corr) in zip(axes, panels):
    sns.heatmap(corr, ax=ax, cbar=False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
plt.show()