In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from drcomp.utils.notebooks import get_dataset, get_model_for_dataset, get_preprocessor
from drcomp import DimensionalityReducer
import matplotlib.pyplot as plt
import numpy as np
import scienceplots
from sklearn.utils import resample

plt.style.use(["science", "notebook"])

In [None]:
# MNIST samples, plus the preprocessor and PCA model loaded with from_pretrained=True, all under "..".
X, _ = get_dataset("MNIST", root_dir="..")
preprocessor = get_preprocessor(
    "MNIST",
    from_pretrained=True,
    root_dir="..",
)
model = get_model_for_dataset(
    "MNIST",
    "PCA",
    from_pretrained=True,
    root_dir="..",
)
# NOTE(review): both entries reference the very same model object — presumably a placeholder
# to exercise multi-row plots; swap in a second model for a real comparison.
models = dict(PCA=model, PCA2=model)

In [None]:
n_images: int = 10  # number of sampled images, i.e. columns in each reconstruction grid

In [None]:
# Quick grid: first row is the ground truth, then one row per model in `models`.
# NOTE(review): this cell treats rows of X as already preprocessed (it inverse-transforms the
# ground truth), whereas plot_reconstructions applies preprocessor.transform first — confirm
# which matches what get_dataset returns.
fig, axs = plt.subplots(len(models) + 1, n_images, figsize=(10, 2), squeeze=False)
images = resample(X, n_samples=n_images, random_state=0)  # fixed seed -> reproducible sample
grid_rows = [("Ground truth", preprocessor.inverse_transform(images))]
# One batched reconstruct call per model; using a distinct loop name avoids clobbering the global `model`.
grid_rows += [
    (name, preprocessor.inverse_transform(reconstructor.reconstruct(images)))
    for name, reconstructor in models.items()
]
for row, (label, row_images) in enumerate(grid_rows):
    # ax.text stays visible after axis("off"), unlike a y-label.
    axs[row, 0].text(-0.1, 0.5, label, ha="right", va="center", transform=axs[row, 0].transAxes)
    for col, img in enumerate(row_images.reshape(-1, 28, 28)):
        axs[row, col].imshow(img, cmap="gray")
        axs[row, col].axis("off")
plt.tight_layout()

In [None]:
def _as_display_images(flat_images, width, height, channels):
    """Reshape flat images into the array layout ``imshow`` expects.

    Assumes each flat image is a row-major ``height x width x channels`` buffer
    (TODO confirm the channel layout for multi-channel datasets). Returns an array of
    shape ``(n, height, width)`` for grayscale and ``(n, height, width, channels)`` otherwise.
    """
    stacked = np.asarray(flat_images).reshape(-1, height, width, channels)
    # Drop the singleton channel axis so ``cmap`` is applied to grayscale images.
    return stacked[..., 0] if channels == 1 else stacked


def plot_reconstructions(
    models: dict[str, DimensionalityReducer],
    images,
    preprocessor,
    width,
    height,
    channels,
    cmap="gray",
    flatten=True,
    figsize=(10, 2),
):
    """Plot the reconstructions of the samples by the given models compared to the original image.

    The first row shows the original images; each following row shows one model's
    reconstructions, in the iteration order of ``models``. Rows are labelled on the left.

    Args:
        models: Mapping from display name to a model exposing ``reconstruct``.
        images: Array of shape ``(n_images, width * height * channels)`` holding the
            flattened images before preprocessing.
        preprocessor: Object whose ``transform`` maps images to model-input space and
            whose ``inverse_transform`` maps reconstructions back for display.
        width: Image width in pixels.
        height: Image height in pixels.
        channels: Number of channels per pixel.
        cmap: Colormap passed to ``imshow``.
        flatten: If True, models receive flat vectors of length ``width * height * channels``;
            otherwise an array reshaped to ``(n, channels, width, height)``.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        ``(fig, axs)`` where ``axs`` is always 2-D with shape ``(len(models) + 1, n_images)``.

    Raises:
        ValueError: If ``images`` is empty.
        AssertionError: If ``images`` does not have shape ``(n_images, width * height * channels)``.
    """
    images = np.asarray(images)
    flattened_size = width * height * channels
    n_images = len(images)
    # Validate before creating the figure so a bad call does not leave an empty figure behind.
    if n_images == 0:
        raise ValueError("plot_reconstructions needs at least one image")
    assert images.shape == (n_images, flattened_size), (
        f"expected images of shape {(n_images, flattened_size)}, got {images.shape}"
    )

    processed_images = preprocessor.transform(images)
    if flatten:
        model_inputs = processed_images.reshape(-1, flattened_size)
    else:
        # NOTE(review): assumes the models were trained on (channels, width, height) inputs;
        # PyTorch convention would be (channels, height, width) — confirm for non-square images.
        model_inputs = processed_images.reshape(-1, channels, width, height)

    rows = [("Ground truth", _as_display_images(images, width, height, channels))]
    for name, model in models.items():
        reconstructed = preprocessor.inverse_transform(model.reconstruct(model_inputs))
        rows.append((name, _as_display_images(reconstructed, width, height, channels)))

    # squeeze=False keeps axs 2-D even for a single image column or an empty `models`.
    fig, axs = plt.subplots(len(rows), n_images, figsize=figsize, squeeze=False)
    for row, (label, row_images) in enumerate(rows):
        # ax.text stays visible after axis("off"), unlike a y-label.
        axs[row, 0].text(-0.1, 0.5, label, ha="right", va="center", transform=axs[row, 0].transAxes)
        for col in range(n_images):
            axs[row, col].imshow(row_images[col], cmap=cmap)
            axs[row, col].axis("off")
    fig.tight_layout()
    return fig, axs

In [None]:
# Reuse n_images for the column count and fix the seed so the plotted sample is reproducible.
images = resample(X, n_samples=n_images, random_state=0)
fig, axs = plot_reconstructions(
    models, images, preprocessor, width=28, height=28, channels=1, flatten=True
)