In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scienceplots
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

from drcomp.autoencoder import FullyConnectedAE
from drcomp.reducers import PCA, AutoEncoder
from drcomp.utils.notebooks import get_dataset

# The "science" style sheet is provided by the `scienceplots` import above,
# which registers it with matplotlib on import.
plt.style.use(style="science")

## Utility Functions

In [None]:
def get_linear_autoencoder(
    input_size: int,
    intrinsic_dim: int,
    weight_decay: float,
    max_epochs: int = 50,
    lr: float = 0.001,
    batch_size: int = 128,
):
    """Build an untrained fully connected autoencoder used as a linear model.

    The network has no batch norm, untied encoder/decoder weights and an
    identity encoder activation, so its decoder weights can be compared with
    PCA loadings.

    Args:
        input_size: Number of input features (flattened pixels).
        intrinsic_dim: Size of the latent code.
        weight_decay: Weight-decay coefficient forwarded to ``AutoEncoder``
            (a float such as ``2e-4``; ``0`` disables it).
        max_epochs: Number of training epochs.
        lr: Learning rate.
        batch_size: Mini-batch size.

    Returns:
        An unfitted ``AutoEncoder`` reducer wrapping the network.
    """
    # NOTE(review): only the encoder activation is set to identity here; the
    # model is linear only if FullyConnectedAE's decoder activation defaults to
    # identity as well — confirm.
    base = FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        include_batch_norm=False,
        tied_weights=False,
        encoder_act_fn=nn.Identity,
    )

    return AutoEncoder(
        base,
        max_epochs=max_epochs,
        lr=lr,
        batch_size=batch_size,
        weight_decay=weight_decay,
    )


def plot_weights(weights, intrinsic_dim, img_size, title=None, axs=None):
    """Draw each row of ``weights`` as a grayscale image, one per axes.

    Args:
        weights: Array of shape ``(intrinsic_dim, prod(img_size))``; row ``i``
            is drawn in the ``i``-th axes.
        intrinsic_dim: Expected number of rows of ``weights``.
        img_size: Shape each row is reshaped to before plotting, e.g.
            ``(48, 48, 1)``.
        title: Optional title placed on the figure (or subfigure) that owns
            the axes.
        axs: Optional axes (array or single Axes) to draw into. If ``None``, a
            new figure with a near-square grid large enough for
            ``intrinsic_dim`` images is created (3x3 for ``intrinsic_dim=9``).
            If fewer axes than rows are given, the extra rows are not drawn.

    Returns:
        The axes that were drawn into.

    Raises:
        AssertionError: If ``weights`` does not have the expected shape.
    """
    assert weights.shape == (
        intrinsic_dim,
        np.prod(img_size),
    ), f"Weights must be of shape (intrinsic_dim, np.prod(img_size)), but got {weights.shape}."
    if axs is None:
        # Near-square grid; max(1, ...) keeps plt.subplots valid for tiny inputs.
        ncols = max(1, int(np.ceil(np.sqrt(intrinsic_dim))))
        nrows = max(1, int(np.ceil(intrinsic_dim / ncols)))
        _, axs = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(5.91, 5), squeeze=False
        )
    axes = np.asarray(axs).ravel()
    for idx, ax in enumerate(axes):
        if idx < len(weights):
            ax.imshow(weights[idx].reshape(img_size), cmap="gray")
        # Hide ticks/frames on every cell, including unused ones.
        ax.axis("off")
    if title is not None:
        # Title the figure/subfigure that owns these axes, not whatever figure
        # pyplot currently considers "current".
        owner = axes[0].figure if axes.size else plt.gcf()
        owner.suptitle(title)
    return axs

## Analytical PCA, linear autoencoder, and PCA recovered from the autoencoder weights via SVD

In [None]:
def analytical_pca(X, intrinsic_dim):
    """Fit PCA with ``intrinsic_dim`` components and project ``X`` onto it.

    Returns:
        tuple: ``(loadings, embedding)`` where ``loadings`` is the wrapped
        model's ``components_`` (presumably sklearn-style
        ``(intrinsic_dim, n_features)`` — confirm) and ``embedding`` is the
        fitted reducer's transform of ``X``.
    """
    fitted_reducer = PCA(intrinsic_dim).fit(X)
    components = fitted_reducer.pca.components_
    return components, fitted_reducer.transform(X)


def linear_autoencoder_weights(X, intrinsic_dim, weight_decay):
    """Train a linear autoencoder on ``X``; return decoder weights and codes.

    Args:
        X: Data matrix; ``X.shape[1]`` is used as the input size.
        intrinsic_dim: Size of the latent code.
        weight_decay: Weight-decay coefficient forwarded to the autoencoder.

    Returns:
        tuple: ``(weights, embedding)`` — the transposed weight matrix of the
        first decoder layer as a NumPy array, and the fitted autoencoder's
        transform of ``X``.
        NOTE(review): assumes ``decoder[0]`` is
        ``nn.Linear(intrinsic_dim, X.shape[1])``, so ``weights`` has shape
        ``(intrinsic_dim, X.shape[1])`` — confirm against FullyConnectedAE.
    """
    ae = get_linear_autoencoder(
        X.shape[1], intrinsic_dim, weight_decay=weight_decay
    ).fit(X)
    decoder_weight = ae.module_.decoder[0].weight
    # detach()/cpu() instead of the legacy .data so this also works when the
    # module lives on a GPU (Tensor.numpy() requires a CPU tensor).
    weights = decoder_weight.detach().cpu().numpy().T
    embedding = ae.transform(X)
    return weights, embedding


def pca_by_autoencoder(X, weights):
    """Recover principal directions from linear-autoencoder weights via SVD.

    The left singular vectors of ``weights.T`` form an orthonormal basis of
    the subspace spanned by the rows of ``weights``, ordered by decreasing
    singular value.

    Args:
        X: Data matrix of shape ``(n_samples, n_features)``. It is projected
            as is, without centering, so it should already be centered (the
            notebook passes standardized data).
        weights: Array of shape ``(k, n_features)`` with ``k <= n_features``.

    Returns:
        tuple: ``(components, embedding)`` where ``components`` has shape
        ``(k, n_features)`` with orthonormal rows and
        ``embedding = X @ components.T`` has shape ``(n_samples, k)``.
    """
    # np.linalg.svd already returns singular values in descending order, so no
    # re-sorting is needed (an argsort here could only reshuffle tied values).
    left_vectors, _, _ = np.linalg.svd(weights.T, full_matrices=False)
    return left_vectors.T, X @ left_vectors

## Load the data

In [None]:
## Load FER dataset
X, y = get_dataset("FER2013", root_dir="..")
# Standardize every feature (pixel) to zero mean and unit variance. There is no
# train/test split in this notebook: "train" is the whole dataset.
X_train = StandardScaler(with_mean=True, with_std=True).fit_transform(X)

intrinsic_dim = 9  # number of components / latent units (drawn as a 3x3 grid)
img_size = (48, 48, 1)  # shape each flattened row is reshaped to for plotting

# Fail early if the flattened samples do not match img_size; plot_weights
# reshapes every weight row with it.
assert X_train.shape[1] == np.prod(img_size), (
    f"Expected {np.prod(img_size)} features per sample, got {X_train.shape[1]}."
)

## Train the methods

In [None]:
# Seed the RNGs before each stochastic fit so the notebook is reproducible
# under Restart & Run All. Re-seeding before each autoencoder is meant to give
# both runs the same initialization, so only weight_decay differs between them
# (presumably drcomp's AutoEncoder draws init/shuffling from torch's global
# RNG — TODO confirm).
SEED = 0

np.random.seed(SEED)
p_analytical, embedding_pca = analytical_pca(X_train, intrinsic_dim)

# unregularized autoencoder
torch.manual_seed(SEED)
weights_unreg, embedding_ae_unreg = linear_autoencoder_weights(
    X_train, intrinsic_dim, weight_decay=0
)
p_linear_unreg, embedding_pca_by_ae_unreg = pca_by_autoencoder(X_train, weights_unreg)

# svd of the weights of a regularized autoencoder
torch.manual_seed(SEED)
weights_reg, embedding_ae_reg = linear_autoencoder_weights(
    X_train, intrinsic_dim, weight_decay=2e-4
)
p_linear, embedding_pca_by_ae = pca_by_autoencoder(X_train, weights_reg)

In [None]:
# One embedding per figure panel, in panel order.
embeddings = [
    embedding_pca,  # (a) analytical PCA
    embedding_pca_by_ae_unreg,  # (b) projection onto SVD basis of unregularized AE weights
    embedding_pca_by_ae,  # (c) projection onto SVD basis of regularized AE weights
]

## Plot the Results

In [None]:
suptitles = ["(a)", "(b)", "(c)"]

# Plot the covariance matrices of the embeddings: a diagonal matrix means the
# embedding coordinates are uncorrelated, as they are for principal components.
fig, axes = plt.subplots(1, 3, figsize=(5.91, 2.8))
for ax, embedding, label in zip(axes, embeddings, suptitles):
    cov = np.cov(embedding, rowvar=False)
    ax.imshow(cov, cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])
    # Attach the panel label to this axes; plt.text would add it to the
    # current (last-created) axes instead.
    ax.text(0.5, -0.2, label, transform=ax.transAxes, ha="center", fontsize=11)
fig.savefig("../figures/covariance-matrices.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Plot the weights: one 3x3 image grid per subfigure, panels (a)-(c) left to right.
# NOTE(review): panel (b) shows the raw unregularized decoder weights
# (weights_unreg), whereas panel (b) of the covariance figure uses the SVD
# projection of those weights; p_linear_unreg is computed but never plotted —
# confirm this is intended.
fig = plt.figure(figsize=(5.91, 1.95))
sfigs = fig.subfigures(1, 3)

layout = (3, 3)
panel_weights = (p_analytical, weights_unreg, p_linear)
for sfig, weight_matrix, suptitle in zip(sfigs, panel_weights, suptitles):
    grid_axes = sfig.subplots(*layout)
    plot_weights(weight_matrix, intrinsic_dim, img_size, axs=grid_axes)
    sfig.supxlabel(suptitle, fontsize=11)
    sfig.subplots_adjust(wspace=0.1, hspace=0.1)

fig.savefig("../figures/weights-comparison.pdf", bbox_inches="tight")
plt.show()