In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from drcomp.reducers import PCA, KernelPCA, AutoEncoder
from drcomp.autoencoder import FullyConnectedAE
from drcomp import estimate_intrinsic_dimension, DimensionalityReducer
from sklearn.datasets import fetch_lfw_people
import matplotlib.pyplot as plt
import random

In [None]:
# Load Labeled Faces in the Wild (downloaded into ../data/raw if missing).
# min_faces_per_person=0 keeps every person, i.e. the full dataset.
lfw_people = fetch_lfw_people(
    min_faces_per_person=0,
    color=True,
    data_home="../data/raw",
    download_if_missing=True,
)
X = lfw_people.data
# Per-image shape, derived from the dataset instead of hard-coded: with
# color=True each image carries a trailing RGB channel axis, so a row of X holds
# H * W * 3 values and cannot be reshaped to a plain (H, W) grid.
image_shape = lfw_people.images.shape[1:]
assert X.shape[1] == lfw_people.images[0].size, "each row of X must flatten one image"
X.shape

In [None]:
# Estimate the intrinsic dimensionality of the faces; the value is bound to
# `intrinsic_dim` (reused as the target dimension below) and displayed.
(intrinsic_dim := estimate_intrinsic_dimension(X))

In [None]:
# All reducers target the estimated intrinsic dimension.

# Linear baseline.
pca = PCA(intrinsic_dim)

# RBF kernel PCA. fit_inverse_transform=True is presumably what makes
# inverse_transform available for the reconstruction plots (as in sklearn's
# KernelPCA) — confirm for this wrapper.
kpca = KernelPCA(
    intrinsic_dim,
    kernel="rbf",
    fit_inverse_transform=True,
)

# Fully connected autoencoder operating on flattened images; hidden widths
# 1024 -> 256 (decoder layout presumably mirrors this — confirm in drcomp).
base = FullyConnectedAE(
    input_size=X.shape[1],
    intrinsic_dim=intrinsic_dim,
    hidden_layer_dims=[1024, 256],
)
AE = AutoEncoder(
    base,
    batch_size=16,
    max_epochs=10,
    learning_rate=1e-3,
)

In [None]:
# (display name, reducer) pairs, in the order they are plotted.
reducers: list[tuple[str, DimensionalityReducer]] = list(
    zip(
        ("PCA", "Kernel PCA", "AutoEncoder"),
        (pca, kpca, AE),
    )
)

In [None]:
# Fit every reducer on the full dataset and compare its reconstruction of one
# face with the original. The reducers fitted here are reused by the
# reconstruction cell further down.

# Derive the image shape from the dataset: with color=True every row of X holds
# H * W * 3 values, so a hard-coded (62, 47) does not fit (and `image_shape`
# must exist in this cell on a fresh kernel).
image_shape = lfw_people.images.shape[1:]
image_idx = 214  # arbitrary example face
# Shared intensity range so every panel is mapped into [0, 1] the same way;
# matplotlib clips RGB float images outside [0, 1], and a fixed range keeps
# grayscale panels comparable.
pixel_min, pixel_max = X.min(), X.max()


def to_image(flat_pixels):
    """Reshape one flattened sample to `image_shape` and rescale it into [0, 1]."""
    image = flat_pixels.reshape(image_shape)
    return ((image - pixel_min) / (pixel_max - pixel_min)).clip(0.0, 1.0)


n_panels = len(reducers) + 1
fig, axes = plt.subplots(ncols=n_panels, figsize=(4 * n_panels, 4))
axes[0].imshow(to_image(X[image_idx]), cmap="gray", vmin=0.0, vmax=1.0)
axes[0].set_title("Original Image")
for ax, (name, reducer) in zip(axes[1:], reducers):
    reducer.fit(X, X)
    Y = reducer.transform(X)
    X_hat = reducer.inverse_transform(Y)
    ax.imshow(to_image(X_hat[image_idx]), cmap="gray", vmin=0.0, vmax=1.0)
    ax.set_title(f"Reconstructed by {name}")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
def show_reconstruction(
    image_idx: int,
    model: DimensionalityReducer,
    image_shape: "tuple[int, ...] | None" = None,
    name: "str | None" = None,
):
    """Plot one face from ``X`` next to ``model``'s reconstruction of it.

    Reads the notebook-level ``X`` and ``lfw_people``; ``model`` must already be
    fitted and provide ``transform`` and ``inverse_transform``.

    Args:
        image_idx: Row of ``X`` to reconstruct.
        model: Fitted dimensionality reducer.
        image_shape: Shape one row of ``X`` is reshaped into. Defaults to the
            per-image shape of ``lfw_people.images`` (``(H, W)`` for grayscale,
            ``(H, W, 3)`` when loaded with ``color=True``).
        name: Label used in the figure title. Defaults to the model's class name.
    """
    if image_shape is None:
        # A hard-coded (62, 47) breaks with color=True, where each row of X
        # holds H * W * 3 values.
        image_shape = lfw_people.images.shape[1:]
    if name is None:
        name = model.__class__.__name__

    sample = X[image_idx].reshape(1, -1)
    reconstructed = model.inverse_transform(model.transform(sample))

    # Map both images into [0, 1] using the dataset's range: matplotlib clips
    # RGB float images outside [0, 1], and a shared range keeps grayscale
    # panels directly comparable.
    pixel_min, pixel_max = X.min(), X.max()

    def to_image(flat_pixels):
        image = flat_pixels.reshape(image_shape)
        return ((image - pixel_min) / (pixel_max - pixel_min)).clip(0.0, 1.0)

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5))
    ax1.imshow(to_image(sample), cmap="gray", vmin=0.0, vmax=1.0)
    ax1.set_title("Original")
    ax2.imshow(to_image(reconstructed), cmap="gray", vmin=0.0, vmax=1.0)
    ax2.set_title("Reconstructed")
    for ax in (ax1, ax2):
        ax.axis("off")
    fig.suptitle(f"Original vs. Reconstructed by {name}")
    plt.show()

In [None]:
# Show each reducer's reconstruction of the same randomly chosen face. A
# dedicated, seeded generator makes the choice reproducible on
# Restart & Run All without touching the global `random` state.
RANDOM_SEED = 0
image_idx = random.Random(RANDOM_SEED).randrange(len(X))
for _name, reducer in reducers:
    show_reconstruction(image_idx, reducer)