In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scienceplots  # noqa: F401  (imported for its side effect: registers the "science" style)
from sklearn.utils import resample

from drcomp.plotting import save_fig
from drcomp.utils.notebooks import get_dataset

In [None]:
# Load each dataset once, keyed by its name.
dataset_names = ["MNIST", "FER2013", "LfwPeople", "OlivettiFaces"]
Xys_dict = dict((name, get_dataset(name, root_dir="..")) for name in dataset_names)

In [None]:
# Shape (height, width, channels) each flattened sample is reshaped to before
# plotting; one entry per dataset, in the same order as `dataset_names`.
img_sizes = [
    (28, 28, 1),  # MNIST
    (48, 48, 1),  # FER2013
    (62, 47, 1),  # LfwPeople (not square)
    (64, 64, 1),  # OlivettiFaces
]

In [None]:
def plot_dataset_example(
    samples, img_size: tuple[int, int, int], axs=None, figsize=(4, 4)
):
    """Draw one grayscale image per axis in a grid.

    Args:
        samples: Sequence of images. Each element is reshaped to ``img_size``
            before plotting, so flattened pixel vectors are accepted.
        img_size: Target shape ``(height, width, channels)`` for each sample.
        axs: Array of matplotlib Axes to draw into. If None, a new 2x2 grid
            is created with ``plt.subplots``.
        figsize: Figure size, used only when ``axs`` is None.

    Returns:
        The Axes array that was drawn into (``axs`` or the newly created grid).

    Raises:
        AssertionError: If the number of samples differs from the number of
            axes (4 when ``axs`` is None).
    """
    # Validate before creating a figure so a bad call does not leave a stray plot.
    n_axes = 4 if axs is None else np.size(axs)
    assert len(samples) == n_axes, f"expected {n_axes} samples, got {len(samples)}"
    if axs is None:
        _, axs = plt.subplots(2, 2, figsize=figsize)
    for ax, sample in zip(np.asarray(axs).ravel(), samples):
        ax.imshow(np.reshape(sample, img_size), cmap="gray")
        ax.axis("off")
    return axs

In [None]:
# Panel titles shown in the figures, keyed by dataset name
# ("Zahlen" is German for "digits").
plot_names = dict(
    MNIST="MNIST Zahlen",
    FER2013="Facial Emotion Recognition",
    LfwPeople="Labeled Faces in the Wild",
    OlivettiFaces="Olivetti Faces",
)

In [None]:
# Show one example image per dataset in a 2x2 grid and export it for LaTeX.
# The "science" style is registered by the `scienceplots` import.
plt.style.use("science")

# Build the per-dataset example here so the cell also runs on a fresh kernel.
# NOTE(review): assumes get_dataset returns (X, y) with one flattened image per
# row of X (suggested by the name `Xys_dict`) — confirm.
one_sample_each = {
    name: np.reshape(np.asarray(Xys_dict[name][0])[0], img_size[:2])
    for name, img_size in zip(dataset_names, img_sizes)
}

fig, axs = plt.subplots(2, 2, figsize=(6, 6))
for ax, (name, sample) in zip(axs.flat, one_sample_each.items()):
    ax.imshow(sample, cmap="gray")
    ax.axis("off")
    ax.set_title(plot_names[name])
fig.tight_layout()
save_fig("../figures", fig, "dataset_samples", latex=True, height=6)

In [None]:
# Plot a 2x2 grid of random samples for each of the four datasets, one labelled
# subfigure per dataset, and export the result as PGF for LaTeX.
plt.style.use("science")

layout = (2, 2)  # image grid inside each subfigure
n_samples_per_dataset = layout[0] * layout[1]
sample_seed = 0  # fixed so the exported figure is reproducible
img_size_by_dataset = dict(zip(dataset_names, img_sizes))

# Build the samples here so the cell also runs on a fresh kernel.
# NOTE(review): assumes get_dataset returns (X, y) with one flattened image per
# row of X (suggested by the name `Xys_dict`) — confirm.
samples = {
    name: np.asarray(
        resample(
            Xys_dict[name][0],
            n_samples=n_samples_per_dataset,
            replace=False,
            random_state=sample_seed,
        )
    )
    for name in dataset_names
}

fig = plt.figure(figsize=(5.9, 7))
sfigs = fig.subfigures(2, 2)
panel_labels = ["(a)", "(b)", "(c)", "(d)"]

for (row, col), name, label in zip(np.ndindex(sfigs.shape), dataset_names, panel_labels):
    sfig = sfigs[row, col]
    # Set hspace on each subfigure's own grid. Calling the figure-wide
    # subplots_adjust once per dataset meant only the last call took effect.
    hspace = 0.1 if name == "LfwPeople" else 0.05
    axs = sfig.subplots(*layout, gridspec_kw={"hspace": hspace})
    plot_dataset_example(samples[name], img_size_by_dataset[name], axs)
    # Bottom-row labels sit slightly lower than top-row ones.
    sfig.text(0.52, 0.05 if row == 0 else 0.02, label, ha="center")

fig.tight_layout()
# Presumably keeps the images clear of the panel labels drawn near the bottom edge.
fig.subplots_adjust(bottom=0.1, wspace=0.1)
# NOTE(review): the earlier cell also exports a figure named "dataset_samples"
# via save_fig — confirm which of the two outputs is meant to be kept.
fig.savefig("../figures/dataset_samples.pgf", backend="pgf")
plt.show()