In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from drcomp.reducers import PCA, AutoEncoder
from drcomp.autoencoder import FullyConnectedAE
from drcomp.utils.notebooks import get_dataset
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
import torch
import torch.nn as nn
import numpy as np
from skorch.callbacks import EarlyStopping, LRScheduler
from drcomp.plotting import (
    compare_metrics,
    plot_reconstructions,
    visualize_2D_latent_space,
)
import matplotlib.pyplot as plt
import scienceplots
from matplotlib import offsetbox

In [None]:
# Load the raw MNIST samples (X) together with their class labels (y).
X, y = get_dataset("MNIST", root_dir="..")

# Fit a per-feature standardiser on the raw data and keep it around so the
# same transformation can be re-applied later.
preprocessor = StandardScaler()
preprocessor.fit(X)

# Standardised copy of the data used for fitting every reducer below.
X_train = preprocessor.transform(X)

In [None]:
# Dimensionality of the latent space; kept at 3 for visualization purposes.
intrinsic_dim = 3

# Image geometry of a single sample.
height, width, channels = 28, 28, 1
img_size = (height, width, channels)

# Length of one flattened image vector.
input_size = height * width * channels

In [None]:
def get_autoencoder(baseClass, lr=0.1, gamma=0.9):
    """Wrap an autoencoder module in the ``AutoEncoder`` reducer.

    The reducer is configured with an MSE criterion, the Adam optimizer,
    at most 100 epochs, a batch size of 100 and no contractive penalty.
    It runs on CUDA when available and on the CPU otherwise.

    Parameters
    ----------
    baseClass
        Autoencoder module handed to the reducer as ``AutoEncoderClass``.
    lr : float, default 0.1
        Learning rate passed to the reducer.
    gamma : float, default 0.9
        ``gamma`` forwarded to the ``ExponentialLR`` scheduler policy.

    Returns
    -------
    AutoEncoder
        The configured (not yet fitted) reducer.
    """
    # Stop once the validation loss has not improved for 10 epochs.
    early_stopping = EarlyStopping(patience=10, monitor="valid_loss")
    # Exponential learning-rate schedule parameterised by `gamma`.
    lr_schedule = LRScheduler(
        policy="ExponentialLR", gamma=gamma, monitor="valid_loss"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    return AutoEncoder(
        AutoEncoderClass=baseClass,
        criterion=nn.MSELoss,
        optimizer=torch.optim.Adam,
        lr=lr,
        contractive=False,
        callbacks=[early_stopping, lr_schedule],
        max_epochs=100,
        batch_size=100,
        device=device,
    )


def get_base_encoder(
    encoder_activations,
    hidden_layer_dims,
    tied_weights: bool = False,
    decoder_activations=None,
):
    """Build a ``FullyConnectedAE`` for the notebook-wide input/latent sizes.

    The module-level ``input_size`` and ``intrinsic_dim`` are used for the
    input and bottleneck dimensions; batch normalisation is always disabled.

    Parameters
    ----------
    encoder_activations
        Passed through as ``encoder_act_fn``.
    hidden_layer_dims
        Passed through as ``hidden_layer_dims``.
    tied_weights : bool, default False
        Passed through as ``tied_weights``.
    decoder_activations : optional
        Passed through as ``decoder_act_fn`` (``None`` by default).

    Returns
    -------
    FullyConnectedAE
        The constructed autoencoder module.
    """
    architecture = dict(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        hidden_layer_dims=hidden_layer_dims,
        encoder_act_fn=encoder_activations,
        decoder_act_fn=decoder_activations,
        include_batch_norm=False,
        tied_weights=tied_weights,
    )
    return FullyConnectedAE(**architecture)

In [None]:
# Seed the global torch / numpy RNGs before any autoencoder is constructed so
# that random weight initialisation (and, when the cells are run in order, the
# subsequent training) starts from the same state on every
# Restart & Run All.
# NOTE(review): GPU kernels may still introduce small non-determinism.
torch.manual_seed(0)
np.random.seed(0)

# Reducers under comparison, keyed by the name used in logs and plots.
# All autoencoders are shallow (no hidden layers) and differ only in the
# activation functions handed to the encoder / decoder.
models = {
    "PCA": PCA(intrinsic_dim=intrinsic_dim),
    "Linear shallow AE": get_autoencoder(
        get_base_encoder(nn.Identity, [], decoder_activations=nn.Identity)
    ),
    "Sigmoid-linear shallow AE": get_autoencoder(
        get_base_encoder(nn.Sigmoid, [], decoder_activations=nn.Identity)
    ),
    # decoder_activations is left at its default (None) here.
    "Nonlinear shallow AE": get_autoencoder(get_base_encoder(nn.Sigmoid, [])),
}

In [None]:
# Fit every reducer on the standardised data and keep the embedding that
# fit_transform returns, keyed by model name.
# Metric evaluation (model.evaluate(X_train, embeddings[name], max_K=100),
# collected into an `all_metrics` dict) is currently disabled.
embeddings = {}
for name in models:
    model = models[name]
    print(f"Training {name}")
    embeddings[name] = model.fit_transform(X_train)

In [None]:
def plot_latent_space(embedding, targets, labels, ax=None):
    """Scatter-plot an embedding with one colour and legend entry per label.

    Parameters
    ----------
    embedding : array of shape (n_samples, n_components)
        Low-dimensional representation to plot. A 3-column embedding is
        drawn in 3-D when the axes supports it; in every other case only the
        first two columns are plotted.
    targets : array of shape (n_samples,)
        Class label of every row in ``embedding``.
    labels : iterable
        The labels to draw; one ``scatter`` call is issued per label.
    ax : matplotlib axes, optional
        Axes to draw on. When omitted a new 6x6 inch figure is created, with
        a 3-D projection if the embedding has exactly three columns.

    Returns
    -------
    None
    """
    is_3d_embedding = embedding.shape[1] == 3
    if ax is None:
        subplot_kw = {"projection": "3d"} if is_3d_embedding else None
        _, ax = plt.subplots(figsize=(6, 6), subplot_kw=subplot_kw)

    # Only pass a z coordinate to an axes that can draw it. On a regular 2-D
    # axes the third positional argument of scatter() is the marker size, so
    # a 3-D embedding would silently be rendered with garbage sizes.
    draw_3d = is_3d_embedding and getattr(ax, "name", None) == "3d"
    n_plot_dims = 3 if draw_3d else 2

    for label in labels:
        idx = np.where(targets == label)[0]
        coords = [embedding[idx, dim] for dim in range(n_plot_dims)]
        # No `c=` here: a constant per-call colour array is normalised to a
        # single colormap value, giving every class the same colour. Letting
        # the property cycle pick the colour yields one colour per label.
        ax.scatter(*coords, label=label, alpha=0.8)

    # Attach the legend to the axes that was drawn on rather than to
    # whatever pyplot currently considers the active axes.
    ax.legend()

In [None]:
# One panel per model, stacked vertically, showing a stratified subsample of
# each model's embedding of the training data.
plt.style.use(["science", "scatter"])
fig = plt.figure(figsize=(4, 12))
labels = np.unique(y)
rows = len(models)
cols = 1
for i, (name, model) in enumerate(models.items()):
    Y = model.transform(X_train)
    # A 3-component embedding needs a 3-D axes: on a regular axes the third
    # coordinate handed to scatter() would be interpreted as marker size.
    projection = "3d" if Y.shape[1] == 3 else None
    ax = fig.add_subplot(rows, cols, i + 1, projection=projection)
    ax.set_title(name)  # identify which model each panel belongs to
    samples, targets = resample(
        Y, y, n_samples=5000, stratify=y, random_state=0
    )  # for plotting
    plot_latent_space(samples, targets, labels, ax=ax)
plt.tight_layout()

In [None]:
# https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html#sphx-glr-auto-examples-manifold-plot-lle-digits-py
def plot_embedding(Y, title):
    """Plot a 2-D embedding with digit-shaped markers and image thumbnails.

    Every sample is drawn at its embedding coordinates using its class label
    as the marker glyph. In addition, thumbnails of the original images
    (rows of the module-level ``X``, reshaped with ``img_size``) are placed
    at the embedding position of samples that are not too close to an
    already shown thumbnail.

    Parameters
    ----------
    Y : array of shape (n_samples, 2)
        Embedding of the samples in the module-level ``X`` / ``y``, in the
        same row order. The columns are unpacked as the x and y coordinates.
        NOTE(review): the start point ``[1, 1]`` and the ``4e-3`` distance
        threshold look tuned for an embedding rescaled to the unit square;
        no rescaling happens here -- confirm callers pass a scaled ``Y``.
    title : str
        Title of the axes.
    """
    _, ax = plt.subplots()

    for digit in np.unique(y):
        ax.scatter(
            *Y[y == digit].T,
            marker=f"${digit}$",
            s=60,
            color=plt.cm.Dark2(digit),
            alpha=0.425,
            zorder=2,
        )
    shown_images = np.array([[1.0, 1.0]])  # just something big
    for i in range(Y.shape[0]):
        # plot every digit on the embedding
        # show an annotation box for a group of digits
        dist = np.sum((Y[i] - shown_images) ** 2, 1)
        if np.min(dist) < 4e-3:
            # don't show points that are too close
            continue
        shown_images = np.concatenate([shown_images, [Y[i]]], axis=0)
        # Restore the image geometry from the notebook-wide img_size (the
        # previously hard-coded 48x48 cannot hold a 28*28-pixel row) and drop
        # the singleton channel axis so a plain 2-D greyscale image is drawn.
        thumbnail = np.squeeze(np.asarray(X[i]).reshape(img_size))
        imagebox = offsetbox.AnnotationBbox(
            offsetbox.OffsetImage(thumbnail, cmap=plt.cm.gray_r),
            Y[i],  # anchor the thumbnail at the sample's embedding position
        )
        imagebox.set(zorder=1)
        ax.add_artist(imagebox)

    ax.set_title(title)
    ax.axis("off")