In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import scienceplots  # imported for its side effect: registers the "science" style sheets
import torch
import torch.nn as nn
from skorch.callbacks import EarlyStopping, LRScheduler

from drcomp.autoencoder import FullyConnectedAE
from drcomp.reducers import AutoEncoder, PCA
from drcomp.utils.notebooks import get_dataset, get_preprocessor

# "science"/"scatter" are scienceplots style sheets; they are available because
# scienceplots is imported in this cell.
plt.style.use(["science", "scatter"])

# Seed the global NumPy and PyTorch RNGs so weight initialisation, batch shuffling and
# any randomized PCA solver repeat across Restart & Run All (on CPU; some CUDA kernels
# remain nondeterministic even when seeded).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## Load the MNIST dataset

In [None]:
# Dataset name and project root are shared by the data loader and the preprocessor;
# keep them in one place so the two can never disagree.
DATASET = "MNIST"
ROOT_DIR = ".."  # presumably the repository root relative to this notebook's folder — TODO confirm

X, y = get_dataset(DATASET, root_dir=ROOT_DIR)
assert len(X) == len(y), f"got {len(X)} samples but {len(y)} labels"

# from_pretrained=False: presumably builds a fresh (unfitted) preprocessor instead of
# loading a saved one — it is fitted on X before training below.
preprocessor = get_preprocessor(DATASET, root_dir=ROOT_DIR, from_pretrained=False)

In [None]:
# The fully connected autoencoders take a single flat feature size, so X must be
# a 2-D (n_samples, n_features) array; fail early instead of building a wrong network.
assert X.ndim == 2, f"expected 2-D (n_samples, n_features) data, got shape {X.shape}"
input_size = X.shape[1]

# A 2-D latent space can be scatter-plotted directly (the plot below uses columns 0 and 1).
intrinsic_dim = 2

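## Define the autoencoder architectures
Three fully connected autoencoders that share the input size, latent dimension and hidden layer sizes (`[128]`) and differ only in their encoder activation functions.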

In [None]:
hidden_layer_dims = [128]

# The variants differ only in the encoder activation(s); everything else is shared.
# NOTE(review): "linear-sigmoid" lists [Sigmoid, Identity]; if FullyConnectedAE applies
# the list in layer order (hidden layer first, then latent layer) the name reads
# backwards — TODO confirm the intended naming.
encoder_activations = {
    "linear": nn.Identity,
    "linear-sigmoid": [nn.Sigmoid, nn.Identity],
    "sigmoid-sigmoid": nn.Sigmoid,
}

# Built in the same order as the dict above, one autoencoder per variant.
modules = {
    variant: FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        hidden_layer_dims=hidden_layer_dims,
        encoder_act_fn=activation,
    )
    for variant, activation in encoder_activations.items()
}

## Train the autoencoders and PCA

In [None]:
def train(
    model,
    X_train,
    *,
    max_epochs=100,
    batch_size=128,
    lr=0.01,
    lr_gamma=0.98,
    patience=20,
    device=None,
):
    """Fit an ``AutoEncoder`` reducer around ``model`` on ``X_train`` and return it.

    All keyword arguments default to the values used throughout this notebook,
    so ``train(model, X_train)`` behaves as before.

    Parameters
    ----------
    model : torch.nn.Module
        Autoencoder module (e.g. a ``FullyConnectedAE``), passed to ``AutoEncoder`` as is.
    X_train : array-like
        Preprocessed training data.
    max_epochs : int
        Upper bound on the number of training epochs.
    batch_size : int
        Mini-batch size.
    lr : float
        Initial learning rate.
    lr_gamma : float
        Multiplicative decay factor of the ``ExponentialLR`` schedule.
    patience : int
        Epochs without improvement of the monitored loss before
        ``EarlyStopping`` halts training.
    device : str or None
        Torch device; ``None`` selects ``"cuda"`` when available, else ``"cpu"``.

    Returns
    -------
    AutoEncoder
        The fitted reducer.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    lr_scheduler = LRScheduler(policy="ExponentialLR", gamma=lr_gamma)
    early_stopping = EarlyStopping(patience=patience)
    reducer = AutoEncoder(
        model,
        max_epochs=max_epochs,
        batch_size=batch_size,
        lr=lr,
        callbacks=[lr_scheduler, early_stopping],
        device=device,
    )
    reducer.fit(X_train)
    return reducer

In [None]:
# Fit the preprocessor once (it does not depend on the model) and train every
# reducer, autoencoders and PCA alike, on the same preprocessed data.
X_train = preprocessor.fit_transform(X)

reducers = {}
for name, module in modules.items():
    print(f"Training {name}")
    # Train a copy: the module instances are created once in the cell above, so
    # training them in place would let a re-run of this cell start from weights
    # left behind by the previous run instead of the fresh initialisation.
    reducers[name] = train(copy.deepcopy(module), X_train)

# PCA serves as the linear baseline.
pca = PCA(intrinsic_dim).fit(X_train)
reducers["PCA"] = pca

## Plot the latent spaces

In [None]:
# Preprocess once; every reducer embeds the same input.
X_preprocessed = preprocessor.transform(X)

# One panel per reducer in a 2-column grid, sized so no reducer is silently
# dropped by zip() when there are more than four.
n_cols = 2
n_rows = -(-len(reducers) // n_cols)  # ceiling division
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
)
for ax, (name, reducer) in zip(axs.flat, reducers.items()):
    Y = reducer.transform(X_preprocessed)
    scatter = ax.scatter(Y[:, 0], Y[:, 1], c=y, cmap="tab10", s=1)
    ax.set_title(name)
    ax.set_xlabel("Latent dimension 1")
    ax.set_ylabel("Latent dimension 2")
# Hide leftover panels when the number of reducers is not a multiple of n_cols.
for ax in axs.flat[len(reducers):]:
    ax.set_visible(False)

# The scatters carry no artist labels, so get_legend_handles_labels() would be empty.
# Build the class legend from the colour mapping instead; every panel uses the same
# labels and colormap, so the last scatter is representative. num=None keeps every
# class rather than a locator-chosen subset.
handles, labels = scatter.legend_elements(num=None)
fig.legend(handles, labels, title="Class", loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.show()