In [None]:
# Automatically reload edited project modules (e.g. drcomp) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils import resample

from drcomp.utils.notebooks import get_dataset

In [None]:
# Load each dataset once, keyed by its name. Each value is indexed later as
# Xys[0] (the images that get plotted) and Xys[1] (passed to `stratify`, so
# presumably the class labels).
dataset_names = ["MNIST", "FER2013", "FashionMNIST", "OlivettiFaces"]
Xys_dict = dict(
    zip(
        dataset_names,
        (get_dataset(dataset, root_dir="..") for dataset in dataset_names),
    )
)

In [None]:
# Image shape (rows, cols, channels) for each dataset, keyed by name. The list
# is then ordered by `dataset_names`, so it cannot silently drift out of
# alignment with the dataset list. A dataset without an entry raises KeyError
# here instead of being paired with the wrong shape.
img_size_by_name = {
    "MNIST": (28, 28, 1),
    "FER2013": (48, 48, 1),
    "FashionMNIST": (28, 28, 1),
    "OlivettiFaces": (64, 64, 1),
}
img_sizes = [img_size_by_name[name] for name in dataset_names]

In [None]:
def plot_dataset_example(
    samples, img_size: tuple[int, int, int], axs=None, figsize=(4, 4)
):
    """Draw image samples in grayscale, one image per axes.

    Parameters
    ----------
    samples
        Sequence or array of flattened images; each element must support
        ``.reshape(*img_size)``.
    img_size
        Target image shape ``(rows, cols, channels)``. A trailing channel
        dimension of size 1 is dropped before plotting, so the image is drawn
        as a 2-D array through the gray colormap.
    axs
        Axes to draw into: any array-like of Axes (it is flattened), or a
        single Axes. If ``None``, a new near-square grid is created
        (2x2 for 4 samples).
    figsize
        Figure size, used only when ``axs`` is ``None``.

    Returns
    -------
    numpy.ndarray
        Flat object array of all axes drawn into or blanked.

    Raises
    ------
    ValueError
        If ``samples`` is empty or there are more samples than axes.
    """
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("samples must contain at least one image")

    if axs is None:
        # Near-square layout; for 4 samples this is the original 2x2 grid.
        ncols = int(np.ceil(np.sqrt(n_samples)))
        nrows = int(np.ceil(n_samples / ncols))
        _, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    flat_axes = np.asarray(axs, dtype=object).ravel()
    if n_samples > flat_axes.size:
        raise ValueError(
            f"got {n_samples} samples but only {flat_axes.size} axes to draw on"
        )

    for ax, sample in zip(flat_axes, samples):
        image = sample.reshape(*img_size)
        if image.ndim == 3 and image.shape[-1] == 1:
            # Single channel: plot as a 2-D array so the colormap is applied.
            image = image[..., 0]
        ax.imshow(image, cmap="gray")
        ax.axis("off")

    # Blank any spare axes so they do not show empty frames/ticks.
    for ax in flat_axes[n_samples:]:
        ax.axis("off")
    return flat_axes

In [None]:
# Panel label printed under each dataset's column in the sample figure.
plot_names = dict(
    MNIST="(a)",
    FER2013="(b)",
    FashionMNIST="(c)",
    OlivettiFaces="(d)",
)

In [None]:
# Draw 4 images per dataset (one column of the figure). The draw is:
# - stratified by Xys[1], so the picks are spread across classes;
# - without replacement, so the same image cannot appear twice in a column;
# - seeded, so the saved figure is identical across Restart & Run All.
SAMPLE_SEED = 42
samples = {
    name: resample(
        Xys[0],
        stratify=Xys[1],
        n_samples=4,
        replace=False,
        random_state=SAMPLE_SEED,
    )
    for name, Xys in Xys_dict.items()
}

In [None]:
# Figure: one column per dataset, one row per sampled image, with a label
# "(a)".."(d)" centred under each column. Written to ../figures/.
figure_path = Path("../figures/dataset_samples.pdf")
# Look image shapes up by dataset name rather than by loop position.
img_size_for = dict(zip(dataset_names, img_sizes))
n_rows = max(len(sample) for sample in samples.values())
n_cols = len(samples)

# 5.91 in is about 15 cm wide, presumably the target text width (confirm).
fig, axs = plt.subplots(
    nrows=n_rows, ncols=n_cols, figsize=(5.91, 4.8), squeeze=False
)
for col, (name, sample) in enumerate(samples.items()):
    column_axes = axs[:, col]
    plot_dataset_example(sample, img_size_for[name], column_axes)
    # Attach the label to the column's bottom axes itself. plt.text would add
    # it to the current axes, typically the last subplot created. The position
    # (0.5, -0.3) is in that axes' fraction coordinates, so the label sits
    # centred just below the image.
    bottom_ax = column_axes[-1]
    bottom_ax.text(
        0.5,
        -0.3,
        plot_names[name],
        transform=bottom_ax.transAxes,
        ha="center",
        va="center",
        fontsize=11,
    )
fig.tight_layout()
# Tighten the gaps between images and leave room at the bottom for the labels.
fig.subplots_adjust(wspace=0.1, hspace=0.05, bottom=0.2)
# savefig does not create missing directories.
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, bbox_inches="tight")
plt.show()