In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.datasets import CIFAR100
from torchvision.utils import make_grid
from sklearn.metrics import confusion_matrix

# Project layout: "../" is resolved against the kernel's working directory
# (normally the folder containing this notebook).
root_dir = "../"
data_dir, results_dir = (os.path.join(root_dir, sub) for sub in ("data", "results"))

# CIFAR-100 test split; ToTensor converts each image to a float tensor.
# No download is requested, so the data must already exist under data_dir.
dataset = CIFAR100(data_dir, train=False, transform=T.ToTensor())

# Model predictions produced elsewhere and stored as CSV.
predictions_csv = os.path.join(results_dir, "cifar100_predictions.csv")
results = pd.read_csv(predictions_csv)

# Seaborn default theme, then the smaller "paper" context for figure text/lines.
sns.set_theme()
sns.set_context("paper")

In [None]:
# Show a fixed random sample of test images, first unchanged and then
# solarized at increasing thresholds; every grid is also saved as a PDF.
figure_dir = os.path.join(root_dir, "figures", "cifar")
os.makedirs(figure_dir, exist_ok=True)

SEED = 0  # fixed so the sampled images (and the saved PDFs) are reproducible
N_IMAGES = 6
GRID_NROW = 2
GRID_FIGSIZE = (16 / 4, 9 / 2)
SOLARIZE_THRESHOLDS = [0, 0.2, 0.4, 0.6, 0.8, 1.0]

rng = np.random.default_rng(SEED)
idx = rng.choice(len(dataset), N_IMAGES, replace=False)
images = torch.stack([dataset[i][0] for i in idx])


def plot_image_grid(batch, title, filename, nrow=GRID_NROW):
    """Draw a batch of CHW image tensors as one grid, save it, and display it.

    Args:
        batch: stacked image tensor as accepted by ``torchvision.utils.make_grid``.
        title: axes title.
        filename: file name written inside ``figure_dir``.
        nrow: passed to ``make_grid`` (images per row).
    """
    fig, ax = plt.subplots(figsize=GRID_FIGSIZE)
    grid = make_grid(batch, nrow=nrow)
    ax.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    ax.axis("off")
    ax.set_title(title)
    # Save via the explicit figure handle rather than pyplot's "current figure".
    fig.savefig(os.path.join(figure_dir, filename))
    plt.show()
    plt.close(fig)  # free the figure; avoids accumulating figures in the loop below


plot_image_grid(images, "Original", "original.pdf")

# torchvision's solarize inverts pixels whose value reaches the threshold,
# so 0 inverts the whole image and larger thresholds invert progressively less.
for shift in SOLARIZE_THRESHOLDS:
    shifted_images = T.functional.solarize(images, shift)
    plot_image_grid(
        shifted_images,
        f"Threshold = {shift:.1f}",
        f"solarize_{shift:.1f}.pdf",
    )

In [None]:
# Row-normalized confusion matrix of the saved predictions against the true labels.
# NOTE(review): assumes "prediction" and "label" use the same class encoding
# (e.g. integer class indices) — confirm against the script that wrote the CSV.
prediction, label = results["prediction"].values, results["label"].values
cm = confusion_matrix(label, prediction, normalize="true")

# Plain accuracy: the fraction of samples predicted correctly. The mean of the
# row-normalized diagonal is the mean per-class recall (balanced accuracy),
# which equals accuracy only when every class has the same number of samples.
accuracy = np.mean(prediction == label)

_, ax = plt.subplots(figsize=(5, 5))
# square=True keeps the cells square; ax.axis("equal") would instead pad the
# heatmap with empty data range.
sns.heatmap(cm, cmap="viridis", square=True, ax=ax)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
ax.set_title(f"Confusion matrix for CLIP ViT-L/14\naccuracy = {accuracy:.2%}")
plt.show()