In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project source are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from drcomp.reducers import PCA, AutoEncoder
from drcomp.autoencoder import FullyConnectedAE
from drcomp.utils.notebooks import get_dataset
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample
import torch
import torch.nn as nn
import numpy as np
from skorch.callbacks import EarlyStopping, LRScheduler
from drcomp.plotting import (
    compare_metrics,
    plot_reconstructions,
    visualize_2D_latent_space,
)
import matplotlib.pyplot as plt
import scienceplots
from matplotlib import offsetbox

In [None]:
# Load FER2013 and standardize it with a scaler fitted on that same data.
# The fitted scaler stays available as `preprocessor`.
X, y = get_dataset("FER2013", root_dir="..")
preprocessor = StandardScaler()
X_train = preprocessor.fit_transform(X)

In [None]:
# Fix the global RNG state of torch and numpy so that a fresh
# "Restart Kernel -> Run All" starts from the same random state every time.
# The value itself is arbitrary; it only has to stay constant.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Target dimensionality handed to every reducer built in this notebook.
intrinsic_dim = 6
# Image geometry; `img_size` keeps the full tuple, the unpacked names are
# used to derive the flattened input length.
img_size = height, width, channels = (48, 48, 1)
input_size = channels * height * width

In [None]:
# Integer id for each named label: the id is the label's position in the tuple.
labels = {
    label: index
    for index, label in enumerate(
        ("angry", "disgust", "fear", "happy", "neutral", "sad", "surprise")
    )
}

In [None]:
def get_autoencoder(baseClass, lr=0.1, gamma=0.9):
    """Build an ``AutoEncoder`` reducer around the given autoencoder module.

    The reducer is configured with ``nn.MSELoss`` as criterion, Adam as
    optimizer, a batch size of 32, at most 100 epochs and
    ``contractive=False``. Two callbacks are attached: early stopping with a
    patience of 10 on ``valid_loss`` and an ``ExponentialLR`` learning-rate
    scheduler. CUDA is used when ``torch.cuda.is_available()`` reports it,
    otherwise the CPU.

    Args:
        baseClass: Forwarded unchanged as ``AutoEncoderClass``. In this
            notebook it is the object returned by ``get_base_encoder``.
        lr: Learning rate forwarded to the reducer.
        gamma: ``gamma`` argument forwarded to the ``ExponentialLR`` scheduler.

    Returns:
        The configured ``AutoEncoder`` instance.
    """
    stop_early = EarlyStopping(patience=10, monitor="valid_loss")
    decay_lr = LRScheduler(policy="ExponentialLR", gamma=gamma, monitor="valid_loss")
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    settings = dict(
        AutoEncoderClass=baseClass,
        criterion=nn.MSELoss,
        optimizer=torch.optim.Adam,
        lr=lr,
        contractive=False,
        callbacks=[stop_early, decay_lr],
        max_epochs=100,
        batch_size=32,
        device=target_device,
    )
    return AutoEncoder(**settings)


def get_base_encoder(
    encoder_activations,
    hidden_layer_dims,
    tied_weights: bool = False,
    decoder_activations=None,
):
    """Create a ``FullyConnectedAE`` for the notebook's image size and latent size.

    ``input_size`` and ``intrinsic_dim`` are read from the notebook-level
    configuration; batch normalization is always disabled.

    Args:
        encoder_activations: Forwarded as ``encoder_act_fn``.
        hidden_layer_dims: Forwarded as ``hidden_layer_dims``; the notebook
            passes an empty list for its shallow models.
        tied_weights: Forwarded as ``tied_weights``.
        decoder_activations: Forwarded as ``decoder_act_fn``; ``None`` is
            passed through as-is.

    Returns:
        The constructed ``FullyConnectedAE``.
    """
    architecture = {
        "input_size": input_size,
        "intrinsic_dim": intrinsic_dim,
        "hidden_layer_dims": hidden_layer_dims,
        "encoder_act_fn": encoder_activations,
        "decoder_act_fn": decoder_activations,
        "include_batch_norm": False,
        "tied_weights": tied_weights,
    }
    return FullyConnectedAE(**architecture)

In [None]:
# Baselines first: PCA and the purely linear autoencoder.
models = {
    "PCA": PCA(intrinsic_dim=intrinsic_dim),
    "Linear shallow AE": get_autoencoder(get_base_encoder(nn.Identity, [])),
}
# For every non-linearity add two shallow variants, in this order:
#   "<act>-linear": the activation on the encoder, nn.Identity on the decoder
#   "<act>":        decoder_activations left at its default
for act_name, act_fn in (("Sigmoid", nn.Sigmoid), ("ReLU", nn.ReLU), ("Tanh", nn.Tanh)):
    models[f"{act_name}-linear shallow AE"] = get_autoencoder(
        get_base_encoder(act_fn, [], decoder_activations=nn.Identity)
    )
    models[f"{act_name} shallow AE"] = get_autoencoder(get_base_encoder(act_fn, []))

In [None]:
# Fit every model on the standardized data and keep the transformed output
# it returns, keyed by the model's display name.
embeddings = {}
for name in models:
    model = models[name]
    print(f"Training {name}...")
    embeddings[name] = model.fit_transform(X_train)
    print(f"Training {name} done.")

In [None]:
def plot_weights(weights, intrinsic_dim, title, img_size):
    """Show every row of ``weights`` as a grayscale image in a single figure.

    The grid holds at most three images per row and grows by rows so that all
    ``intrinsic_dim`` weight vectors get their own axes. For
    ``intrinsic_dim == 6`` this is the 2x3 grid with a (6, 5) figure size.

    Args:
        weights: Array of shape ``(intrinsic_dim, np.prod(img_size))``; each
            row is reshaped to ``img_size`` and drawn with ``imshow``.
        intrinsic_dim: Number of weight vectors, i.e. expected row count.
        title: Figure-level title.
        img_size: Shape each row is reshaped to before plotting.

    Raises:
        AssertionError: If ``weights`` does not have the expected shape.
    """
    assert weights.shape == (
        intrinsic_dim,
        np.prod(img_size),
    ), f"Weights must be of shape (intrinsic_dim, np.prod(img_size)), but got {weights.shape}."
    # Derive the grid from intrinsic_dim instead of hard-coding 2x3, so no
    # component is silently dropped when there are more than six.
    n_cols = max(1, min(3, intrinsic_dim))
    n_rows = max(1, -(-intrinsic_dim // n_cols))  # ceiling division
    fig, axs = plt.subplots(
        n_rows,
        n_cols,
        figsize=(2 * n_cols, 2.5 * n_rows),
        squeeze=False,  # always a 2D array of axes, even for a single row/column
    )
    for ax, weight in zip(axs.flat, weights):
        ax.imshow(weight.reshape(img_size), cmap="gray")
    # Hide the frame on every axes, including unused trailing cells of the grid.
    for ax in axs.flat:
        ax.axis("off")
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Plot the weight images of every model in `models`.
for name, model in models.items():
    # Reject anything that is neither of the two supported reducer types.
    if not isinstance(model, (AutoEncoder, PCA)):
        raise ValueError(f"Unknown model type {type(model)}")
    # Autoencoders: weight of the first decoder entry, transposed because
    # plot_weights asserts one row per latent dimension.
    # PCA: the components stored on the wrapped `pca` attribute.
    weights = (
        model.module_.decoder[0].weight.data.cpu().numpy().T
        if isinstance(model, AutoEncoder)
        else model.pca.components_
    )
    plot_weights(weights, intrinsic_dim, title=f"{name} weights", img_size=img_size)