In [None]:
from drcomp.reducers import PCA, AutoEncoder
from drcomp.autoencoder import FullyConnectedAE
from drcomp.utils.notebooks import get_dataset
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
import torch
import torch.nn as nn
import numpy as np
from skorch.callbacks import EarlyStopping, LRScheduler
from drcomp.plotting import (
    compare_metrics,
    plot_reconstructions,
    visualize_2D_latent_space,
)
import matplotlib.pyplot as plt
import scienceplots
from matplotlib import offsetbox

# "science" is presumably registered by the `import scienceplots` above.
plt.style.use(style="science")

In [None]:
X, y = get_dataset("MNIST", root_dir="..")
# Subtract the per-feature mean only; with_std=False disables unit-variance scaling.
preprocessor = StandardScaler(with_std=False)
X_centered = preprocessor.fit_transform(X)

In [None]:
def get_linear_autoencoder(
    input_size: int,
    intrinsic_dim: int,
    max_epochs: int = 50,
    lr: float = 0.001,
    batch_size: int = 128,
):
    """Build an unfitted fully-connected autoencoder with an identity encoder activation.

    Batch norm and weight tying are disabled and the encoder activation is
    ``nn.Identity``.
    NOTE(review): only the encoder activation is set here; confirm that
    FullyConnectedAE's decoder is also linear by default.

    Args:
        input_size: Number of input features.
        intrinsic_dim: Size of the latent (bottleneck) layer.
        max_epochs: Number of training epochs forwarded to ``AutoEncoder``.
        lr: Learning rate forwarded to ``AutoEncoder``.
        batch_size: Mini-batch size forwarded to ``AutoEncoder``.

    Returns:
        An unfitted ``AutoEncoder`` wrapping the ``FullyConnectedAE`` module.
    """
    base = FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        include_batch_norm=False,
        tied_weights=False,
        encoder_act_fn=nn.Identity,
    )
    return AutoEncoder(base, max_epochs=max_epochs, lr=lr, batch_size=batch_size)

In [None]:
def get_weights(model):
    """Extract the linear weight matrix associated with a fitted ``model``.

    - ``AutoEncoder``: the transposed ``weight`` of the first decoder layer,
      detached to a NumPy array on the CPU.
    - ``PCA``: the fitted ``model.pca.components_``.

    Raises:
        ValueError: if ``model`` is neither an ``AutoEncoder`` nor a ``PCA``.
    """
    if isinstance(model, AutoEncoder):
        first_decoder_layer = model.module_.decoder[0]
        return first_decoder_layer.weight.data.cpu().numpy().T
    if isinstance(model, PCA):
        return model.pca.components_
    raise ValueError(f"Unknown model type {type(model)}")

In [None]:
def linearAE(data, intrinsic_dim):
    """Fit a linear autoencoder on ``data`` and return its weights together with the model.

    Args:
        data: Training array; ``data.shape[1]`` is used as the autoencoder's input
            size. For the PCA comparison this should be mean-centred data.
        intrinsic_dim: Size of the latent layer.

    Returns:
        ``(weights, ae)``: the matrix returned by ``get_weights`` for the fitted
        autoencoder, and the fitted autoencoder itself.
    """
    ae = get_linear_autoencoder(data.shape[1], intrinsic_dim)
    ae.fit(data)
    # The orthonormalisation (SVD) step is intentionally not done here; callers
    # apply it separately to the returned weights.
    return get_weights(ae), ae

In [None]:
# Latent dimensionality shared by PCA and both autoencoder-based embeddings.
intrinsic_dim: int = 16

In [None]:
def PCA_by_autoencoder(weights):
    """Orthonormalise the rows of ``weights`` via a thin SVD.

    Takes the left singular vectors of ``weights.T`` (reduced SVD, so one per row
    of ``weights``) and returns them as rows, i.e. an array with the same shape as
    ``weights`` whose rows are orthonormal and span the same subspace.
    """
    left_singular_vectors = np.linalg.svd(weights.T, full_matrices=False)[0]
    return left_singular_vectors.T

In [None]:
# Baseline: principal axes from ordinary PCA on the centred data.
P = (
    PCA(intrinsic_dim=intrinsic_dim)
    .fit(X_centered)
    .pca
    .components_
)
print(P.shape)

In [None]:
# PCA by autoencoder: train the linear AE on the same centred data that PCA and the
# embeddings below use, then orthonormalise its weights to recover the principal axes.
weights, autoencoder = linearAE(X_centered, intrinsic_dim)
U = PCA_by_autoencoder(weights)
print(U.shape)

In [None]:
# Regular autoencoder: use the learned weights as-is, without orthonormalisation.
W = weights
print(np.shape(W))

In [None]:
# Project the centred data onto each basis (P, U, W) in turn.
embedding_pca, embedding_method, embedding_ae = (
    basis @ X_centered.T for basis in (P, U, W)
)
print(embedding_pca.shape)

In [None]:
# np.cov treats each row as a variable (rowvar=True), i.e. one row per latent dimension.
cov_pca, cov_method, cov_ae = map(
    np.cov, (embedding_pca, embedding_method, embedding_ae)
)

In [None]:
# Covariance matrices of the three embeddings side by side; off-diagonal structure
# means correlated latent dimensions. No shared vmin/vmax is passed, so each panel
# is colour-scaled independently: compare patterns, not absolute grey levels.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))
for ax, cov_matrix, title in (
    (ax1, cov_pca, "PCA"),
    (ax2, cov_method, "PCA-Autoencoder Method"),
    (ax3, cov_ae, "Autoencoder"),
):
    ax.matshow(cov_matrix, cmap="gray")
    ax.set_title(title)

plt.show()

In [None]:
# Strictly-upper-triangular (off-diagonal) entries; all embeddings have the same
# number of latent dimensions, so one index set serves all three.
indices = np.triu_indices_from(cov_method, k=1)

# Use Pearson correlation (scale-free) rather than covariance, and average absolute
# values so positive and negative off-diagonal entries cannot cancel out.
# A value near 0 means the latent dimensions are decorrelated.
corr_pca, corr_method, corr_ae = (
    np.corrcoef(embedding)
    for embedding in (embedding_pca, embedding_method, embedding_ae)
)
mean_corr_pca = np.mean(np.abs(corr_pca[indices]))
mean_corr_method = np.mean(np.abs(corr_method[indices]))
mean_corr_ae = np.mean(np.abs(corr_ae[indices]))

print(f"Mean |correlation| of the embedding by PCA: {mean_corr_pca:.4f}")
print(
    f"Mean |correlation| of the embedding by PCA-Autoencoder Method: {mean_corr_method:.4f}"
)
print(f"Mean |correlation| of the embedding by regular Autoencoder: {mean_corr_ae:.4f}")