In [None]:
from drcomp.autoencoder import FullyConnectedAE
from drcomp.reducers import AutoEncoder
from drcomp.utils.notebooks import get_dataset, get_preprocessor
from drcomp.plotting import compare_metrics
import torch.nn as nn
from skorch.callbacks import EarlyStopping, LRScheduler
import json

In [None]:
# Load the MNIST data and its matching preprocessor; both helpers are pointed
# at the parent directory via root_dir="..".
# X is used below as a 2-D array (X.shape[1] is read); y is presumably the
# label vector — TODO confirm against drcomp.utils.notebooks.get_dataset.
X, y = get_dataset("MNIST", root_dir="..")
# NOTE(review): the preprocessor is only ever `.transform`-ed in this notebook,
# so it is assumed to come back already fitted — confirm.
preprocessor = get_preprocessor("MNIST", root_dir="..")

In [None]:
# Width of the second axis of X, passed as `input_size` to every autoencoder
# below (assumes X is laid out as (n_samples, n_features) — TODO confirm).
input_size = X.shape[1]
# Passed as `intrinsic_dim` to every FullyConnectedAE below (presumably the
# width of the bottleneck/code layer — verify against FullyConnectedAE).
intrinsic_dim = 10

In [None]:
# Untied autoencoder variants, keyed by the display name used for training
# logs and metric dictionaries. Every variant shares `input_size` and
# `intrinsic_dim`; only the hidden-layer widths and activation classes differ.
# NOTE(review): the "linear" variants do not pass `decoder_act_fn`, so they use
# FullyConnectedAE's default — confirm that default matches a linear decoder.
models_untied = {
    label: FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        **arch_kwargs,
    )
    for label, arch_kwargs in {
        "Shallow linear AE": dict(
            hidden_layer_dims=[],
            encoder_act_fn=nn.Identity,
        ),
        "Shallow sigmoid AE": dict(
            hidden_layer_dims=[],
            encoder_act_fn=nn.Sigmoid,
            decoder_act_fn=nn.Identity,
        ),
        "5-layer linear AE": dict(
            hidden_layer_dims=[256],
            encoder_act_fn=nn.Identity,
        ),
        "7-layer linear AE": dict(
            hidden_layer_dims=[256, 128],
            encoder_act_fn=nn.Identity,
        ),
        "5-layer sigmoid AE": dict(
            hidden_layer_dims=[256],
            encoder_act_fn=nn.Sigmoid,
            decoder_act_fn=nn.Identity,
        ),
        "7-layer sigmoid AE": dict(
            hidden_layer_dims=[256, 128],
            encoder_act_fn=nn.Sigmoid,
            decoder_act_fn=nn.Identity,
        ),
    }.items()
}

In [None]:
# Tied-weight counterparts of the untied variants: same display names, hidden
# widths and activation classes, but each model is built with
# `tied_weights=True`.
# NOTE(review): the "linear" variants do not pass `decoder_act_fn`, so they use
# FullyConnectedAE's default — confirm that default matches a linear decoder.
models_tied = {
    label: FullyConnectedAE(
        input_size=input_size,
        intrinsic_dim=intrinsic_dim,
        tied_weights=True,
        **arch_kwargs,
    )
    for label, arch_kwargs in {
        "Shallow linear AE": dict(
            hidden_layer_dims=[],
            encoder_act_fn=nn.Identity,
        ),
        "Shallow sigmoid AE": dict(
            hidden_layer_dims=[],
            encoder_act_fn=nn.Sigmoid,
            decoder_act_fn=nn.Identity,
        ),
        "5-layer linear AE": dict(
            hidden_layer_dims=[256],
            encoder_act_fn=nn.Identity,
        ),
        "7-layer linear AE": dict(
            hidden_layer_dims=[256, 128],
            encoder_act_fn=nn.Identity,
        ),
        "5-layer sigmoid AE": dict(
            hidden_layer_dims=[256],
            encoder_act_fn=nn.Sigmoid,
            decoder_act_fn=nn.Identity,
        ),
        "7-layer sigmoid AE": dict(
            hidden_layer_dims=[256, 128],
            encoder_act_fn=nn.Sigmoid,
            decoder_act_fn=nn.Identity,
        ),
    }.items()
}

In [None]:
def train(model, X, y, preprocessor):
    X_train = preprocessor.transform(X)
    lr_scheduler = LRScheduler(policy="ExponentialLR", gamma=0.98)
    early_stopping = EarlyStopping(patience=50)
    reducer = AutoEncoder(
        model,
        n_epochs=1000,
        batch_size=250,
        lr=0.1,
        callbacks=[lr_scheduler, early_stopping],
    )
    reducer.fit(X)
    return reducer

In [None]:
reducers = {}
metrics = {}
for name in models_untied:
    print(f"Training {name}")
    reducer = train(models_untied[name], X, y, preprocessor)
    reducers[name] = reducer
    metrics[name] = reducer.evaluate(X, max_K=100, as_builtin_list=True)
json.dump(metrics, open("metrics_untied.json", "w"))

In [None]:
reducers = {}
metrics = {}
for name in models_tied:
    print(f"Training {name}")
    reducer = train(models_tied[name], X, y, preprocessor)
    reducers[name] = reducer
    metrics[name] = reducer.evaluate(X, max_K=100, as_builtin_list=True)
json.dump(metrics, open("metrics_tied.json", "w"))

In [None]:
# `metrics` was last rebound by the tied-weights training cell, so only the
# tied models' metrics are passed here; the untied results live in
# metrics_untied.json.
compare_metrics(metrics)