In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision import transforms as t

root_dir = "../"
sys.path.append(root_dir)
import configs
import datasets
from ibydmt.utils.config import get_config
from ibydmt.utils.data import get_dataset
from ibydmt.tester import get_test_classes, get_local_test_idx

rng = np.random.default_rng()

config_name, use_local_test_images = "cub", True
config = get_config(config_name)
test_classes = get_test_classes(config)

# Every image is resized to 224x224 and converted to a tensor.
transform = t.Compose([t.Resize((224, 224)), t.ToTensor()])
train_dataset = get_dataset(config, train=True, transform=transform)
dataset = val_dataset = get_dataset(config, train=False, transform=transform)
print(len(train_dataset), len(val_dataset), len(train_dataset) + len(val_dataset))


def _subsample_per_class(class_idx, m, rng):
    """Draw the same number (at most ``m``) of distinct indices from every class.

    The per-class count is capped by the smallest class. The plotting cell
    stacks the per-class images along a new axis, which requires equal counts.
    Capping also keeps ``rng.choice(..., replace=False)`` from raising when a
    class has fewer than ``m`` candidates.

    Args:
        class_idx: mapping ``class_name -> sequence of dataset indices``.
        m: maximum number of indices to keep per class.
        rng: a ``numpy.random.Generator`` used for sampling.

    Returns:
        A new mapping with the same keys, each value an array of sampled indices.

    Raises:
        ValueError: if ``class_idx`` is empty or any class has no candidates.
    """
    if not class_idx:
        raise ValueError("No classes to sample images from.")
    empty = [k for k, v in class_idx.items() if len(v) == 0]
    if empty:
        raise ValueError(f"No candidate images for classes: {empty}")
    n = min(m, min(len(v) for v in class_idx.values()))
    return {k: rng.choice(v, n, replace=False) for k, v in class_idx.items()}


m = 4  # maximum number of images shown per class
if use_local_test_images:
    class_idx = get_local_test_idx(config)
else:
    # Gather every validation-set index whose label is one of the test classes.
    # NOTE(review): assumes dataset.samples holds (path, label) pairs with
    # integer labels indexing dataset.classes -- confirm against get_dataset.
    class_idx = {class_name: [] for class_name in test_classes}
    for idx, (_, label) in enumerate(dataset.samples):
        class_name = dataset.classes[label]
        if class_name in class_idx:  # O(1) dict lookup instead of list scan
            class_idx[class_name].append(idx)
class_idx = _subsample_per_class(class_idx, m, rng)
print(class_idx)

In [None]:
figure_dir = os.path.join(root_dir, "figures", config.name.lower())
os.makedirs(figure_dir, exist_ok=True)

# Load the sampled images of each class and stack them per class.
# Assumes dataset items are (image_tensor, label) with `transform` applied;
# confirm against get_dataset.
class_images = {
    k: torch.stack([dataset[idx][0] for idx in v]) for k, v in class_idx.items()
}
# Lay out and label the grid from the classes actually being plotted, not from
# `test_classes`. The two can differ in set or order, e.g. when class_idx comes
# from get_local_test_idx.
plotted_classes = list(class_images)
# Stacking on dim=1 interleaves classes. After flattening, consecutive images
# cycle through the classes, so with nrow=len(plotted_classes) each grid column
# holds a single class.
images = torch.stack(list(class_images.values()), dim=1).flatten(0, 1)

_, ax = plt.subplots(figsize=(16, 9))
im = make_grid(images, nrow=len(plotted_classes))
ax.imshow(im.permute(1, 2, 0))  # CHW -> HWC for matplotlib
ax.axis("off")
ax.set_title(" ".join(map(str, plotted_classes)))
plt.savefig(os.path.join(figure_dir, "images.pdf"), bbox_inches="tight")
plt.savefig(os.path.join(figure_dir, "images.png"), bbox_inches="tight")
plt.show()