In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from drcomp.utils.notebooks import get_dataset
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils import resample

In [None]:
# Draw four distinct example images per dataset for the overview grids below.
# `get_dataset(...)[0]` is used as the image array (presumably the feature matrix
# X with one flattened image per row -- TODO confirm against drcomp).
# replace=False: sklearn's `resample` samples *with* replacement by default,
# which could show the same image twice in one grid.
# random_state is fixed so the exported figures are reproducible across runs.
dataset_names = ["MNIST", "FER2013", "LfwPeople", "OlivettiFaces"]
samples_dict = {
    name: resample(
        get_dataset(name, root_dir="..")[0],
        n_samples=4,
        replace=False,
        random_state=0,
    )
    for name in dataset_names
}

In [None]:
# Image shape (height, width, channels) per dataset; the order must match
# `dataset_names`, because later cells pair the two lists by position.
img_sizes = [
    (28, 28, 1),  # MNIST
    (48, 48, 1),  # FER2013
    (62, 47, 1),  # LfwPeople
    (64, 64, 1),  # OlivettiFaces
]

In [None]:
def plot_dataset_example(
    samples, img_size: tuple[int, int, int], axs=None, figsize=(4, 4)
):
    """Plot four flattened images of one dataset as grayscale images.

    Args:
        samples: Exactly four flattened images (e.g. rows of a 2-D array);
            each must be reshapeable to ``img_size``.
        img_size: Target shape passed to ``reshape``, e.g.
            ``(height, width, channels)``.
        axs: Optional existing Axes to draw into -- an array such as the one
            returned by ``plt.subplots(2, 2)``, or a plain list of Axes.
            When None, a new 2x2 figure is created.
        figsize: Figure size in inches; only used when ``axs`` is None.

    Returns:
        ``(fig, axs)``: the Figure that owns the axes and the axes that were
        drawn into, so callers can save or further adjust the plot.
    """
    assert len(samples) == 4
    if axs is None:
        fig, axs = plt.subplots(2, 2, figsize=figsize)
    # np.ravel accepts an ndarray of Axes as well as a plain list of them.
    axes_flat = np.ravel(axs)
    # Take the figure from the axes themselves: plt.gcf() returns whichever
    # figure is "current", which need not be the one owning the given axes.
    fig = axes_flat[0].figure
    for ax, sample in zip(axes_flat, samples):
        ax.imshow(sample.reshape(*img_size), cmap="gray")
        ax.axis("off")
    return fig, axs

In [None]:
def plot_dataset_examples(
    samples_dict: dict[str, np.ndarray],
    img_sizes: list[tuple[int, int, int]],
    figsize=(15, 15),
    n_samples: int = 4,
):
    """Show a grid of example images: one column per dataset, one row per sample.

    Args:
        samples_dict: Mapping from dataset name (used as the column title) to
            an array of flattened images, one per row.
        img_sizes: Shape each flattened image is reshaped to, one entry per
            dataset, in the same order as ``samples_dict`` iterates.
        figsize: Figure size in inches.
        n_samples: Number of images (rows) to show per dataset; defaults to 4,
            the previous fixed layout.

    Raises:
        ValueError: If ``samples_dict`` is empty, ``img_sizes`` does not have
            exactly one entry per dataset, or a dataset has fewer than
            ``n_samples`` images.
    """
    n_datasets = len(samples_dict)
    if n_datasets == 0:
        raise ValueError("samples_dict is empty; nothing to plot")
    if len(img_sizes) != n_datasets:
        raise ValueError(
            f"expected {n_datasets} image sizes (one per dataset), "
            f"got {len(img_sizes)}"
        )
    # Validate before creating the figure so a bad input leaves no stray figure.
    for name, samples in samples_dict.items():
        if len(samples) < n_samples:
            raise ValueError(
                f"dataset {name!r} has {len(samples)} samples, need {n_samples}"
            )

    # squeeze=False keeps `axs` 2-D even for a single dataset or a single row.
    fig, axs = plt.subplots(n_samples, n_datasets, figsize=figsize, squeeze=False)
    for col, ((name, samples), img_size) in enumerate(
        zip(samples_dict.items(), img_sizes)
    ):
        for row in range(n_samples):
            ax = axs[row, col]
            ax.imshow(samples[row].reshape(*img_size), cmap="gray")
            ax.axis("off")
        axs[0, col].set_title(name)
    fig.tight_layout()
    plt.show()

In [None]:
# Overview grid of the sampled images: one column per dataset.
plot_dataset_examples(
    samples_dict=samples_dict,
    img_sizes=img_sizes,
    figsize=(5.9, 10),
)

In [None]:
# Pick one of the already-drawn example images per dataset and reshape it to
# its image shape for the single-image overview figure.
# random_state is fixed so the exported `dataset_samples` figure is reproducible.
# `img_sizes` is paired with `dataset_names` by position, so their lengths must agree.
assert len(img_sizes) == len(dataset_names), "img_sizes must match dataset_names"
one_sample_each = {
    name: resample(samples_dict[name], n_samples=1, random_state=0).reshape(*img_size)
    for name, img_size in zip(dataset_names, img_sizes)
}

In [None]:
# Figure titles keyed by dataset name (the MNIST title is German: "Zahlen").
plot_names = dict(
    MNIST="MNIST Zahlen",
    FER2013="Facial Emotion Recognition",
    LfwPeople="Labeled Faces in the Wild",
    OlivettiFaces="Olivetti Faces",
)

In [None]:
import scienceplots
from drcomp.plotting import save_fig

# Apply the SciencePlots "science" style only while building and saving this
# figure. plt.style.use would change the global rcParams for every cell run
# afterwards, which is hidden state on re-runs.
with plt.style.context("science"):
    fig, axs = plt.subplots(2, 2, figsize=(6, 6))
    for ax, (name, sample) in zip(axs.flat, one_sample_each.items()):
        ax.imshow(sample, cmap="gray")
        ax.axis("off")
        ax.set_title(plot_names[name])
    fig.tight_layout()
    # Save inside the context so settings read at draw/save time still come
    # from the style.
    save_fig("../figures", fig, "dataset_samples", latex=True, height=6)