In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Pin the notebook to one GPU. setdefault keeps a device chosen at launch
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter lab`); in either case the variable only
# takes effect if CUDA has not been initialised yet in this kernel.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

root_dir = "../"
sys.path.append(root_dir)  # make the project-local modules below importable
from datasets import CUB
from classifiers import CLIPClassifier
from speaker_model import ClaimSpeaker

# Seed numpy (image selection) and torch (any sampling inside the models) so
# that "Restart & Run All" reproduces the same figures and explanations.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_dir = os.path.join(root_dir, "data")

backbone = "ViT-L/14"
classifier = CLIPClassifier(backbone, device=device)

# Test split of CUB; each item is (preprocessed image, attribute annotations).
dataset = CUB(
    data_dir, train=False, transform=classifier.preprocess, return_attribute=True
)
classes = dataset.classes
claims = dataset.claims

class_prompts = [f"A photo of a {class_name}" for class_name in classes]


def load_speaker(state_path):
    """Build a ClaimSpeaker for `classifier`, load weights, and set eval mode.

    Args:
        state_path: path to a saved ``state_dict`` for ClaimSpeaker.

    Returns:
        The ClaimSpeaker instance in eval mode.

    ``map_location=device`` lets checkpoints that were saved on a GPU load on a
    CPU-only machine, or when the original device index is not visible.
    """
    model = ClaimSpeaker(classifier, len(classes), claims, device=device)
    model.load_state_dict(torch.load(state_path, map_location=device))
    model.eval()
    return model


beta = 0.4
# The two checkpoints differ only in the second filename field (0.0 vs 0.05);
# NOTE(review): presumably that field is the pragmatic-loss weight and 10 / 8
# are other training hyper-parameters — confirm against the training script.
speaker_state_path = os.path.join(root_dir, "weights", f"speaker_10_0.0_{beta}_8.pt")
speaker = load_speaker(speaker_state_path)

pragmatic_speaker_state_path = os.path.join(
    root_dir, "weights", f"speaker_10_0.05_{beta}_8.pt"
)
pragmatic_speaker = load_speaker(pragmatic_speaker_state_path)

sns.set_theme()
sns.set_context("paper")

In [None]:
# Visualise explanations for a few random test images, comparing the speaker
# loaded from the 0.0 checkpoint (`speaker`) with the pragmatic one.
m = 8  # number of random test images to show
b = 10  # number of copies of the image features along dim 1 passed to `explain`
k = 4  # third argument to `explain`; presumably the claim budget — TODO confirm


def print_explanations(title, explanations, attribute, claims):
    """Print `title`, one line per explanation, then a separator.

    Each explanation is iterated as a sequence of claim indices and printed as
    a list of (claim, ground-truth attribute) pairs. Indices >= len(claims) are
    skipped — presumably special/stop tokens; TODO confirm against ClaimSpeaker.
    """
    print(title)
    for candidate in explanations:
        print([(claims[i], attribute[i]) for i in candidate if i < len(claims)])
    print("==========")


idx = rng.choice(len(dataset), m, replace=False)

for _idx in idx:
    image, attribute = dataset[_idx]
    image_path, _ = dataset.samples[_idx]

    image = image.to(device).unsqueeze(0)  # add a batch dimension

    # Everything below is inference only; no_grad avoids building autograd
    # graphs for the classifier and both speakers.
    with torch.no_grad():
        cls_output = classifier(image, class_prompts)
        image_features = cls_output["image_features"]
        logits = cls_output["logits"]
        prediction_idx = torch.argmax(logits).item()
        prediction_name = classes[prediction_idx]

        # Repeat the image features b times along a new dim 1, as float32.
        image_features = image_features.unsqueeze(1).expand(-1, b, -1).float()
        prediction = torch.tensor([prediction_idx], dtype=torch.long, device=device)

        explanation, explanation_logp = speaker.explain(image_features, prediction, k)
        pragmatic_explanation, pragmatic_explanation_logp = pragmatic_speaker.explain(
            image_features, prediction, k
        )

    # squeeze() drops every size-1 dim; the printing below assumes a 2-D result
    # (explanations x claim indices). NOTE(review): with b == 1 or k == 1 this
    # would collapse further — confirm the shape returned by `explain`.
    explanation = explanation.squeeze()
    pragmatic_explanation = pragmatic_explanation.squeeze()

    fig, ax = plt.subplots(figsize=(3, 3))
    with Image.open(image_path) as image_file:
        image_raw = image_file.convert("RGB")
    ax.imshow(image_raw)
    ax.axis("off")
    ax.set_title(f"Prediction: {prediction_name}")
    plt.show()
    plt.close(fig)  # release the figure; this cell creates one per image

    print_explanations("Not pragmatic speaker", explanation, attribute, claims)
    print_explanations("Pragmatic speaker", pragmatic_explanation, attribute, claims)