In [None]:
import copy

import torch.nn as nn

from drcomp.autoencoder import FullyConnectedAE
from drcomp.plotting import compare_metrics
from drcomp.reducers import AutoEncoder
from drcomp.utils.notebooks import get_dataset, get_preprocessor

In [None]:
# Single source of truth for which dataset is used, so the data and its
# preprocessor can never be loaded for two different datasets by accident.
DATASET = "MNIST"
# Passed as `root_dir` to the drcomp loaders; presumably the repository root
# one level above this notebook's directory — confirm if the notebook moves.
ROOT_DIR = ".."

X, y = get_dataset(DATASET, root_dir=ROOT_DIR)
preprocessor = get_preprocessor(DATASET, root_dir=ROOT_DIR)

In [None]:
# `input_size` is taken from axis 1, which is only the per-sample feature count
# when X is 2-D (n_samples, n_features). Fail fast instead of silently building
# every autoencoder with the wrong input width.
assert X.ndim == 2, f"expected 2-D (n_samples, n_features) data, got shape {X.shape}"
input_size = X.shape[1]
# Width of the bottleneck (latent code) shared by every model below.
intrinsic_dim = 10

In [None]:
# Architecture grid: name -> (hidden_layer_dims, encoder_act_fn, decoder_act_fn).
# The names encode depth: no hidden layers = "Shallow", one = "5-layer",
# two = "7-layer". `decoder_act_fn` is given explicitly for every entry so the
# "linear" models are linear regardless of FullyConnectedAE's default decoder
# activation (previously those entries relied on that default).
AE_SPECS = {
    "Shallow linear AE": ([], nn.Identity, nn.Identity),
    "Shallow sigmoid AE": ([], nn.Sigmoid, nn.Identity),
    "5-layer linear AE": ([256], nn.Identity, nn.Identity),
    "7-layer linear AE": ([256, 128], nn.Identity, nn.Identity),
    "5-layer sigmoid AE": ([256], nn.Sigmoid, nn.Identity),
    "7-layer sigmoid AE": ([256, 128], nn.Sigmoid, nn.Identity),
}

# Build one FullyConnectedAE per spec. Dict order (and therefore construction
# order, which matters for random weight init) follows AE_SPECS.
models = {
    name: FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        hidden_layer_dims=list(hidden_dims),  # fresh list per model
        encoder_act_fn=encoder_act_fn,
        decoder_act_fn=decoder_act_fn,
    )
    for name, (hidden_dims, encoder_act_fn, decoder_act_fn) in AE_SPECS.items()
}

In [None]:
def train(model, X, y, preprocessor, *, n_epochs=1000, batch_size=250, lr=1e-3):
    """Wrap ``model`` in an ``AutoEncoder`` reducer and fit it on ``X``.

    Args:
        model: The autoencoder module to train (e.g. a ``FullyConnectedAE``).
        X: Training data, already preprocessed by the caller. The reducer is
            fitted on exactly this array; no transformation happens here.
        y: Unused; accepted so existing call sites keep working.
        preprocessor: Unused; accepted so existing call sites keep working.
            Apply it to the data *before* calling ``train`` and evaluate on the
            same transformed data. (A transform computed here and then thrown
            away would leave training on raw data while looking preprocessed.)
        n_epochs: Number of training epochs passed to ``AutoEncoder``.
        batch_size: Mini-batch size passed to ``AutoEncoder``.
        lr: Learning rate passed to ``AutoEncoder``.

    Returns:
        The fitted ``AutoEncoder`` reducer.
    """
    reducer = AutoEncoder(model, n_epochs=n_epochs, batch_size=batch_size, lr=lr)
    reducer.fit(X)
    return reducer

In [None]:
# Preprocess once, so every autoencoder is trained *and* evaluated on the same
# representation of the data.
# NOTE(review): assumes `preprocessor` is already fitted, as the `train` cell
# assumed by calling `.transform` on it — TODO confirm.
X_pre = preprocessor.transform(X)

reducers = {}
metrics = {}
for name, model in models.items():
    print(f"Training {name}")
    # Train a copy: `models` keeps the untrained modules, so re-running this cell
    # starts from the same initial weights rather than handing AutoEncoder a
    # module that a previous run already trained.
    reducer = train(copy.deepcopy(model), X_pre, y, preprocessor)
    reducers[name] = reducer
    metrics[name] = reducer.evaluate(X_pre, max_K=100)

In [None]:
# Compare the evaluation metrics gathered per autoencoder (dict: model name -> result of reducer.evaluate).
compare_metrics(metrics)