In [None]:
from drcomp.autoencoder import FullyConnectedAE
from drcomp.reducers import AutoEncoder, PCA
from drcomp.utils.notebooks import get_dataset, get_preprocessor
import torch.nn as nn
from skorch.callbacks import EarlyStopping, LRScheduler
import torch
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
from sklearn.utils import resample

# Apply the "science" and "scatter" matplotlib styles to every figure below.
# NOTE(review): these style names are presumably registered by the otherwise
# unused `scienceplots` import above, so that import must stay.
plt.style.use(["science", "scatter"])

## Load MNIST

In [None]:
# Load the MNIST samples (X) and their labels (y) through the project helper.
X, y = get_dataset("MNIST", root_dir="..")
# Preprocessor for MNIST; it is fitted on X further down via fit_transform.
# NOTE(review): from_pretrained=False presumably returns an unfitted
# preprocessor instead of loading a stored one -- confirm against the helper.
preprocessor = get_preprocessor("MNIST", root_dir="..", from_pretrained=False)

In [None]:
# Size of the second axis of X; passed to every autoencoder as its input size.
input_size = X.shape[1]
# Dimension of the latent space. Two, because the embeddings are drawn as
# 2-D scatter plots (columns 0 and 1) at the end of the notebook.
intrinsic_dim = 2

## Define Autoencoder Architectures

In [None]:
# Every autoencoder below gets the same (empty) list of hidden layer sizes and
# has batch norm switched off; the variants differ only in their activations.
hidden_layer_dims = []

# Panel label -> (encoder activation class, decoder activation class).
activation_pairs = {
    "(b) Linear-Linear": (nn.Identity, nn.Identity),
    "(c) Sigmoid-Linear": (nn.Sigmoid, nn.Identity),
    "(d) Sigmoid-Sigmoid": (nn.Sigmoid, nn.Sigmoid),
}

# One untrained FullyConnectedAE per label, built in the order listed above.
modules = {
    label: FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        hidden_layer_dims=hidden_layer_dims,
        encoder_act_fn=encoder_act,
        decoder_act_fn=decoder_act,
        include_batch_norm=False,
    )
    for label, (encoder_act, decoder_act) in activation_pairs.items()
}

## Train the autoencoders and PCA

In [None]:
def train(model, X_train, **kwargs):
    """Wrap ``model`` in an ``AutoEncoder`` reducer, fit it and return it.

    Args:
        model: The autoencoder module handed to ``AutoEncoder``.
        X_train: Training data passed to ``fit``.
        **kwargs: Extra keyword arguments forwarded to ``AutoEncoder``
            (the notebook passes ``weight_decay``).

    Returns:
        The fitted ``AutoEncoder`` reducer.
    """
    # Exponential learning-rate schedule (gamma=0.9) plus early stopping
    # with a patience of 20.
    callbacks = [
        LRScheduler(policy="ExponentialLR", gamma=0.9),
        EarlyStopping(patience=20),
    ]
    # Use the GPU when torch reports one, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    reducer = AutoEncoder(
        model,
        max_epochs=100,
        batch_size=64,
        lr=0.01,
        callbacks=callbacks,
        device=device,
        **kwargs
    )
    reducer.fit(X_train)
    return reducer

In [None]:
# Fit the preprocessor and transform X exactly once. The result does not
# depend on which reducer is being trained, so PCA and all autoencoders
# share the same preprocessed training data instead of refitting per model.
X_train = preprocessor.fit_transform(X)

# PCA baseline, shown as panel (a).
pca = PCA(intrinsic_dim).fit(X_train)
reducers = {"(a) PCA": pca}

# Train one autoencoder per architecture; weight_decay is forwarded to the
# AutoEncoder reducer through train()'s **kwargs.
for name, module in modules.items():
    print(f"Training {name}")
    reducers[name] = train(module, X_train, weight_decay=1e-6)

## Plot the latent spaces

In [None]:
# Preprocess once: the input is the same for every reducer, so there is no
# need to re-run the preprocessor inside the loop.
X_processed = preprocessor.transform(X)

fig, axs = plt.subplots(2, 2, figsize=(5.7, 4.8))
for ax, (name, reducer) in zip(axs.flatten(), reducers.items()):
    Y = reducer.transform(X_processed)
    # Stratified subsample of 10,000 points with a fixed seed to keep the
    # scatter plots readable and reproducible.
    Y, y_sampled = resample(Y, y, stratify=y, random_state=0, n_samples=10_000)
    scatter = ax.scatter(
        Y[:, 0], Y[:, 1], c=y_sampled, s=3, cmap="tab10", alpha=0.9, label=name
    )
    # Caption below the panel. Use ax.text (not plt.text) so the label is
    # owned by the axes it describes rather than by pyplot's current axes.
    ax.text(0.5, -0.25, name, fontsize=11, transform=ax.transAxes, ha="center")
# One shared legend for the whole figure, built from the last panel's
# scatter; all panels are colored by the same labels.
fig.legend(
    *scatter.legend_elements(),
    bbox_to_anchor=(0.975, 0.5),
    loc="center left",
    fontsize=11,
)
plt.tight_layout()
plt.subplots_adjust(wspace=0.25, hspace=0.4)
fig.savefig("../figures/autoencoders-nonlinearity.pdf", bbox_inches="tight")
plt.show()