In [None]:
from drcomp.autoencoder import FullyConnectedAE
from drcomp.reducers import AutoEncoder, PCA
from drcomp.utils.notebooks import get_dataset, get_preprocessor
from drcomp.plotting import compare_metrics
import torch.nn as nn
from skorch.callbacks import EarlyStopping, LRScheduler
import json
from sklearn.utils import resample
import torch
import numpy as np

In [None]:
# Dataset identifier and data directory shared by the loader and the
# preprocessor factory, so both calls are guaranteed to agree.
dataset_name = "MNIST"
data_root = "."

# Samples and labels as returned by the project loader.
X, y = get_dataset(dataset_name, root_dir=data_root)
# from_pretrained=False: request a preprocessor that is not loaded from a
# previously saved state (it is fitted later via fit_transform).
preprocessor = get_preprocessor(
    dataset_name,
    root_dir=data_root,
    from_pretrained=False,
)

In [None]:
# Width of one input vector: the second axis of X.
# NOTE(review): assumes X is a 2-D (samples x features) array — confirm
# against get_dataset.
input_size = X.shape[1]
# Bottleneck size passed as `intrinsic_dim` to every FullyConnectedAE below.
intrinsic_dim = 10

In [None]:
# One row per autoencoder variant: (display name, hidden layer widths,
# encoder activation class). Row order is the insertion order of `modules`.
# `decoder_act_fn` is deliberately not passed for any variant, so each model
# uses whatever default FullyConnectedAE defines for it.
_ae_variants = [
    ("Shallow linear AE", [], nn.Identity),
    ("Shallow sigmoid AE", [], nn.Sigmoid),
    ("5-layer linear AE", [128], nn.Identity),
    ("7-layer linear AE", [128, 64], nn.Identity),
    ("5-layer sigmoid AE", [128], nn.Sigmoid),
    ("7-layer sigmoid AE", [128, 64], nn.Sigmoid),
]

# Map each display name to a freshly constructed (untrained) module; all
# variants share the same input width and bottleneck size.
modules = {
    label: FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        hidden_layer_dims=hidden_dims,
        encoder_act_fn=activation,
    )
    for label, hidden_dims, activation in _ae_variants
}

In [None]:
def train(model, X_train):
    """Wrap ``model`` in an ``AutoEncoder`` reducer, fit it, and return it.

    Args:
        model: The autoencoder module handed to ``AutoEncoder`` as its first
            positional argument.
        X_train: Training data passed unchanged to ``reducer.fit``.

    Returns:
        The ``AutoEncoder`` instance after ``fit`` has been called on it.

    The reducer is configured with up to 1000 epochs, batch size 128,
    learning rate 0.01 and ``weight_decay=1e-3``. Two skorch callbacks are
    attached: an ``ExponentialLR`` learning-rate schedule (``gamma=0.98``)
    and early stopping with ``patience=20``.
    """
    # Train on the GPU when one is visible to torch, otherwise on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    callbacks = [
        LRScheduler(policy="ExponentialLR", gamma=0.98),
        EarlyStopping(patience=20),
    ]
    autoencoder = AutoEncoder(
        model,
        max_epochs=1000,
        batch_size=128,
        lr=0.01,
        callbacks=callbacks,
        weight_decay=1e-3,
        device=device,
    )
    autoencoder.fit(X_train)
    return autoencoder

In [None]:
# Preprocess once. The transformed data does not depend on which model is
# being trained, so refitting the preprocessor on every iteration (and again
# for PCA) only repeated the same work.
X_train = preprocessor.fit_transform(X)

# Draw a single evaluation subset with a fixed seed and score every reducer
# on it. Previously each reducer was evaluated on a different, unseeded
# random draw, so the metrics were neither directly comparable across models
# nor reproducible between runs.
# NOTE(review): sklearn's resample samples with replacement unless
# replace=False is passed, so X_eval may contain duplicate rows — confirm
# that this is intended for the evaluation metrics.
EVAL_SEED = 0
X_eval = resample(X_train, n_samples=5000, random_state=EVAL_SEED)

reducers = {}  # display name -> fitted reducer
metrics = {}  # display name -> result of reducer.evaluate (JSON-serialised below)
for name, ae_module in modules.items():
    print(f"Training {name}")
    reducer = train(ae_module, X_train)
    reducers[name] = reducer
    metrics[name] = reducer.evaluate(X_eval, max_K=100, as_builtin_list=True)

# PCA baseline with the same target dimensionality as the autoencoders
# (tied to intrinsic_dim rather than a separate hard-coded 10).
pca = PCA(n_components=intrinsic_dim).fit(X_train)
metrics_pca = pca.evaluate(X_eval, max_K=100, as_builtin_list=True)
metrics["PCA"] = metrics_pca
reducers["PCA"] = pca

# Use a context manager so the file is flushed and closed deterministically;
# the bare open() passed to json.dump was never closed.
with open("metrics-comp3.json", "w") as metrics_file:
    json.dump(metrics, metrics_file)

In [None]:
# Hand the per-reducer metrics dict (autoencoders plus "PCA") to the project's
# plotting helper for a side-by-side comparison.
compare_metrics(metrics)

In [None]:
from drcomp.plotting import plot_reconstructions

In [None]:
# Display the names of all fitted reducers (every autoencoder variant plus "PCA").
reducers.keys()

In [None]:
# Plot reconstructions of 10 raw samples, drawn stratified by label, for every
# fitted reducer. The draw is seeded so the figure (which is saved to disk in
# the next cell) shows the same samples on every run; without random_state the
# samples changed each time the cell was executed.
# width/height/channels describe the image layout passed to the plotting
# helper (28 x 28, single channel).
fig, axs = plot_reconstructions(
    reducers,
    resample(X, n_samples=10, stratify=y, random_state=0),
    preprocessor=preprocessor,
    width=28,
    height=28,
    channels=1,
)

In [None]:
# Write the reconstruction figure to disk at 300 dpi.
fig.savefig("reconstructions-comp3.png", dpi=300)