In [None]:
import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.utils import make_grid
from torchvision.transforms import Compose, CenterCrop, ToTensor
from functools import reduce
from scipy.special import softmax

# Make the repository root importable so the project-local modules below resolve.
root_dir = "../"
sys.path.append(root_dir)
from configs.utils import get_config
from concept_datasets import get_concept_dataset
from models.clip_classifier import CLIPClassifier

# --- Experiment configuration ---------------------------------------------------
config_name = "imagenette"
config = get_config(config_name)
results_dir = os.path.join(root_dir, "results", config.name.lower())

# --- Datasets ---------------------------------------------------------------------
# Training images are center-cropped to 224 and converted to tensors so later cells can
# plot them. `dataset` is an alias of the training split and is what the cells below use.
# NOTE(review): the validation split is built without `transform`; confirm this is intended.
transform = Compose([CenterCrop(224), ToTensor()])
train_dataset = get_concept_dataset(
    config, train=True, transform=transform, return_image=True
)
dataset = train_dataset
val_dataset = get_concept_dataset(config, train=False, return_image=True)

print(f"Train images: {len(train_dataset)}")
print(f"Validation images: {len(val_dataset)}")
print(f"Total images: {len(train_dataset) + len(val_dataset)}")

classes = dataset.classes
concepts = dataset.concepts

# --- Classifier -------------------------------------------------------------------
classifier = CLIPClassifier.load_or_train(config, root_dir)
classifier.predict(root_dir)  # return value is not used in this notebook

# --- Figure output directory and plotting style -----------------------------------
figure_dir = os.path.join(root_dir, "figures", config.name.lower(), "images")
os.makedirs(figure_dir, exist_ok=True)

sns.set_style("white")
sns.set_context("paper")

In [None]:
# Show a few random training images per class, one column per class.
N_SAMPLES_PER_CLASS = 5  # images drawn (without replacement) for each class column
SEED = 0  # fixed so the sampled images are identical after a kernel restart
rng = np.random.default_rng(SEED)

# Indices into `dataset` of every training image whose label is i, keyed by class name.
class_idx = {
    class_name: [idx for idx, y in enumerate(dataset.Y) if y == i]
    for i, class_name in enumerate(classes)
}
# Sample without replacement so no image appears twice in a column. (np.random.choice
# defaults to replace=True.) A class with fewer images than requested contributes all of them.
image_idx = {
    class_name: rng.choice(
        idx, min(N_SAMPLES_PER_CLASS, len(idx)), replace=False
    ).tolist()
    for class_name, idx in class_idx.items()
}

# squeeze=False keeps `axes` 2-D even when the config has a single class.
fig, axes = plt.subplots(
    1, len(classes), figsize=(16, 9), gridspec_kw={"wspace": 0.1}, squeeze=False
)
for ax, (class_name, idx) in zip(axes[0], image_idx.items()):
    if idx:  # a class with no training images gets an empty, titled column
        images = torch.stack([dataset[_idx][0] for _idx in idx])
        # nrow=1 places one image per grid row, producing a vertical strip.
        # Permute channels-first -> channels-last for imshow.
        ax.imshow(make_grid(images, nrow=1).permute(1, 2, 0))
    ax.axis("off")
    ax.set_title(class_name)

fig.savefig(os.path.join(figure_dir, "sample_images.pdf"), bbox_inches="tight")
fig.savefig(os.path.join(figure_dir, "sample_images.png"), bbox_inches="tight")
plt.show()

In [None]:
# Flatten the per-class samples into one list, preserving class order.
# A comprehension avoids reduce's quadratic list concatenation and also handles empty input.
_image_idx = [idx for class_samples in image_idx.values() for idx in class_samples]

# One figure per sampled image: the image next to its concepts ranked by score z.
for i, idx in enumerate(_image_idx):
    image, _h, z, label = dataset[idx]  # second element is not needed for this figure
    class_name = dataset.classes[label]

    # Rank concepts by score, highest first. `j` avoids shadowing the loop counter `i`.
    sorted_idx = np.argsort(z)[::-1]
    sorted_z = z[sorted_idx]
    sorted_concepts = [concepts[j] for j in sorted_idx]

    fig, (ax_image, ax_concepts) = plt.subplots(
        1, 2, figsize=(9 / 2, 16 / 4), gridspec_kw={"wspace": 0.7}
    )
    # Permute channels-first -> channels-last for imshow.
    ax_image.imshow(image.permute(1, 2, 0))
    ax_image.axis("off")
    ax_image.set_title(class_name)

    sns.barplot(x=sorted_z, y=sorted_concepts, ax=ax_concepts)
    ax_concepts.set_xlabel(r"$z$")
    # Start the axis just below the smallest score so differences between bars are visible.
    # With a non-positive minimum, 0.8 * min lands at or right of the shortest bar's end
    # and would clip it, so keep matplotlib's default limits in that case.
    z_min = sorted_z.min()
    if z_min > 0:
        ax_concepts.set_xlim(0.8 * z_min)

    fig.savefig(os.path.join(figure_dir, f"{class_name}_{i}.pdf"), bbox_inches="tight")
    fig.savefig(os.path.join(figure_dir, f"{class_name}_{i}.png"), bbox_inches="tight")
    plt.show()
    plt.close(fig)  # free each figure; this loop creates one per sampled image

In [None]:
# Per-class accuracy of the classifier, best class first.
# NOTE(review): `get_predictions` is not defined or imported anywhere in this notebook, so
# this cell raises NameError on a fresh kernel. Import it in the setup cell from the module
# that provides it (TODO confirm which one).
predictions = get_predictions(config, root_dir)
predictions["correct"] = predictions["class"] == predictions["prediction"]

# Mean of the boolean `correct` column per class is that class's accuracy.
class_accuracy = (
    predictions.groupby("class")["correct"]
    .mean()
    .reset_index()
    .sort_values("correct", ascending=False)
)
sorted_classes = class_accuracy["class"].values
sorted_accuracy = class_accuracy["correct"].values

fig, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
sns.barplot(x=sorted_accuracy, y=sorted_classes, ax=ax)
ax.set_xlabel("Accuracy")
# Zoom into the occupied range so per-class differences are visible. Accuracy is at most 1.
ax.set_xlim(0.8 * sorted_accuracy.min(), 1)

# Saved alongside the other figures, consistent with the rest of the notebook.
fig.savefig(os.path.join(figure_dir, "class_accuracy.pdf"), bbox_inches="tight")
fig.savefig(os.path.join(figure_dir, "class_accuracy.png"), bbox_inches="tight")
plt.show()