In [None]:
# Re-import edited project modules (e.g. drcomp) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import math
import string
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scienceplots  # noqa: F401 -- registers the "science" matplotlib style
import torch
import torch.nn as nn
from matplotlib import offsetbox
from skorch.callbacks import EarlyStopping, LRScheduler
from sklearn.preprocessing import StandardScaler
from sklearn.utils import resample

from drcomp.autoencoder import FullyConnectedAE
from drcomp.plotting import (
    compare_metrics,
    plot_reconstructions,
    visualize_2D_latent_space,
)
from drcomp.reducers import PCA, AutoEncoder
from drcomp.utils.notebooks import get_dataset

# The "science" style is provided by the scienceplots package imported in this cell.
plt.style.use("science")

In [None]:
X, y = get_dataset("FER2013", root_dir=".")
# Standardise each pixel feature. The fitted scaler is kept because it is passed
# to plot_reconstructions further down. Note: X_train is the whole dataset -- no
# hold-out split is made in this notebook.
preprocessor = StandardScaler()
X_train = preprocessor.fit_transform(X)

In [None]:
# Shared experiment configuration.
SEED = 0
intrinsic_dim = 7  # latent dimensionality used by every reducer below
img_size = height, width, channels = (48, 48, 1)  # 48x48 single-channel images
input_size = channels * height * width  # length of one flattened image vector

# Seed the global PyTorch and NumPy RNGs so weight initialisation and NumPy-based
# sampling in later cells are repeatable under Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Class index -> English emotion name.
labels = {
    0: "angry",
    1: "disgust",
    2: "fear",
    3: "happy",
    4: "sad",
    5: "surprise",
    6: "neutral",
}
# The same classes with German names; these are used as figure titles.
labels_de = {
    0: "wütend",
    1: "angewidert",
    2: "ängstlich",
    3: "glücklich",
    4: "traurig",
    5: "überrascht",
    6: "neutral",
}
assert labels.keys() == labels_de.keys(), "label maps must cover the same classes"

In [None]:
# Plot one example image per emotion class, titled with its German label.
# The grid is sized from the number of classes: a fixed 2x3 grid cannot hold all
# 7 classes, and indexing it with `label - 1` drew two classes onto the same axes.
n_classes = len(labels)
fig, axs = plt.subplots(
    2, math.ceil(n_classes / 2), figsize=(5.91, 4), squeeze=False
)
for ax, label in zip(axs.flat, labels):
    # Second occurrence of each class: an arbitrary but fixed choice of example.
    idx = np.where(y == label)[0][1]
    ax.imshow(X[idx].reshape(height, width), cmap="gray")
    ax.set_title(labels_de[label], fontsize=11)
    ax.axis("off")
# Blank the grid cells left over when the classes do not fill the grid.
for ax in axs.flat[n_classes:]:
    ax.axis("off")
fig.savefig("../figures/fer2013-images.pgf", backend="pgf")

In [None]:
def get_autoencoder(
    baseClass,
    lr=0.1,
    gamma=0.95,
    max_epochs=100,
    batch_size=100,
    patience=10,
):
    """Wrap an autoencoder module in a drcomp ``AutoEncoder`` estimator.

    The estimator is configured as follows:

    * loss and optimiser: ``nn.MSELoss`` with ``torch.optim.Adam``, ``contractive=False``;
    * early stopping: stops once ``valid_loss`` has not improved for ``patience`` epochs;
    * learning rate: multiplied by ``gamma`` on every ``ExponentialLR`` step;
    * device: CUDA when available, otherwise the CPU.

    Args:
        baseClass: value passed as ``AutoEncoderClass``. Callers in this notebook
            pass an already-constructed module (the result of ``get_base_encoder``).
        lr: initial Adam learning rate.
        gamma: exponential learning-rate decay factor.
        max_epochs: upper bound on the number of training epochs.
        batch_size: mini-batch size.
        patience: early-stopping patience in epochs.

    Returns:
        An unfitted ``AutoEncoder``; call ``fit``/``fit_transform`` on it.
    """
    # Build fresh callback objects on every call so estimators never share
    # callback state.
    callbacks = [
        EarlyStopping(patience=patience, monitor="valid_loss"),
        LRScheduler(policy="ExponentialLR", gamma=gamma, monitor="valid_loss"),
    ]
    return AutoEncoder(
        AutoEncoderClass=baseClass,
        criterion=nn.MSELoss,
        optimizer=torch.optim.Adam,
        lr=lr,
        contractive=False,
        callbacks=callbacks,
        max_epochs=max_epochs,
        batch_size=batch_size,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )


def get_base_encoder(
    encoder_activations,
    hidden_layer_dims,
    tied_weights: bool = False,
    decoder_activations=None,
):
    """Build the fully connected autoencoder module for the flattened images.

    Despite the name, this returns the whole ``FullyConnectedAE`` (encoder and
    decoder). The input and latent sizes come from the notebook globals
    ``input_size`` and ``intrinsic_dim``, and batch norm is always disabled.

    Args:
        encoder_activations: activation class for the encoder, e.g. ``nn.ReLU``.
        hidden_layer_dims: hidden layer widths; this notebook passes ``[]`` for
            its "shallow" autoencoders.
        tied_weights: forwarded unchanged as ``tied_weights``.
        decoder_activations: activation class for the decoder. ``None`` is
            forwarded as-is; how ``FullyConnectedAE`` interprets it is defined
            there (presumably it reuses the encoder activation -- confirm).

    Returns:
        A new, untrained ``FullyConnectedAE`` instance.
    """
    settings = {
        "input_size": input_size,
        "intrinsic_dim": intrinsic_dim,
        "hidden_layer_dims": hidden_layer_dims,
        "encoder_act_fn": encoder_activations,
        "decoder_act_fn": decoder_activations,
        "tied_weights": tied_weights,
        "include_batch_norm": False,
    }
    return FullyConnectedAE(**settings)

In [None]:
# Baseline PCA plus shallow (no hidden layer) autoencoders that differ only in
# their activation functions. Each row of the table is
# (display name, encoder activation, decoder activation); a decoder activation
# of None is forwarded unchanged to get_base_encoder. Entries are built in table
# order, right after PCA.
models = {
    "PCA": PCA(intrinsic_dim=intrinsic_dim),
    **{
        model_name: get_autoencoder(
            get_base_encoder(enc_act, [], decoder_activations=dec_act)
        )
        for model_name, enc_act, dec_act in (
            ("Linear shallow AE", nn.Identity, None),
            ("Sigmoid-linear shallow AE", nn.Sigmoid, nn.Identity),
            ("Sigmoid shallow AE", nn.Sigmoid, None),
            ("ReLU-linear shallow AE", nn.ReLU, nn.Identity),
            ("ReLU shallow AE", nn.ReLU, None),
            ("Tanh-linear shallow AE", nn.Tanh, nn.Identity),
            ("Tanh shallow AE", nn.Tanh, None),
        )
    },
}

In [None]:
# Fit every reducer on the standardised data and keep its low-dimensional embedding.
embeddings = dict()
for model_name, reducer in models.items():
    print(f"Training {model_name}...")
    # Hand each reducer its own copy, presumably so that no reducer can modify
    # the shared X_train in place.
    embeddings[model_name] = reducer.fit_transform(X_train.copy())
    print(f"Training {model_name} done.")

In [None]:
def plot_weights(weights, intrinsic_dim, img_size, title=None, axs=None):
    """Draw every row of ``weights`` as a grayscale image, one row per axes.

    Args:
        weights: array-like of shape ``(intrinsic_dim, prod(img_size))``, e.g. the
            output of ``get_weights``; row ``k`` is reshaped to ``img_size``.
        intrinsic_dim: expected number of rows (latent dimensions).
        img_size: shape each row is reshaped to, e.g. ``(height, width, channels)``.
        title: optional suptitle for the figure or subfigure that owns ``axs``.
        axs: an Axes, or a list/array of Axes, to draw into. If ``None``, a new
            figure is created with a grid of up to 3 columns and enough rows
            for ``intrinsic_dim`` images.

    Returns:
        1-D object array of the axes used. Axes beyond ``intrinsic_dim`` are blanked.

    Raises:
        ValueError: if ``weights`` does not have shape ``(intrinsic_dim, prod(img_size))``.
    """
    weights = np.asarray(weights)
    # Validate against the img_size argument itself, not a notebook-global size.
    expected_shape = (intrinsic_dim, int(np.prod(img_size)))
    if weights.shape != expected_shape:
        raise ValueError(
            "Weights must be of shape (intrinsic_dim, np.prod(img_size)) = "
            f"{expected_shape}, but got {weights.shape}."
        )
    if axs is None:
        ncols = max(1, min(3, intrinsic_dim))
        nrows = max(1, math.ceil(intrinsic_dim / ncols))
        _, axs = plt.subplots(nrows, ncols, figsize=(6, 5), squeeze=False)
    # Accept a single Axes, a list, or an n-d array of Axes.
    axes = np.atleast_1d(axs).ravel()
    if axes.size < intrinsic_dim:
        # Otherwise zip() below would silently drop the trailing components.
        warnings.warn(
            f"Only {axes.size} axes given for {intrinsic_dim} weight images; "
            f"the last {intrinsic_dim - axes.size} will not be shown."
        )
    for ax, weight in zip(axes, weights):
        ax.imshow(weight.reshape(img_size), cmap="gray")
        ax.axis("off")
    # Blank padding axes beyond the number of images.
    for ax in axes[intrinsic_dim:]:
        ax.axis("off")
    if title is not None and axes.size:
        # Title the (sub)figure that owns these axes, not pyplot's "current" figure.
        axes[0].figure.suptitle(title)
    return axes

In [None]:
def get_weights(model):
    """Return the weight matrix of a fitted reducer, for use with ``plot_weights``.

    Behaviour by model type:

    * ``AutoEncoder``: the weight of ``module_.decoder[0]``, moved to the CPU,
      converted to NumPy and transposed. For the shallow autoencoders here this
      is presumably the latent -> pixel linear layer, which gives one row per
      latent dimension -- confirm against ``FullyConnectedAE``.
    * ``PCA``: ``components_`` of the wrapped ``model.pca`` estimator.

    Raises:
        ValueError: for any other model type.
    """
    if isinstance(model, AutoEncoder):
        first_decoder_layer = model.module_.decoder[0]
        return first_decoder_layer.weight.data.cpu().numpy().T
    if isinstance(model, PCA):
        return model.pca.components_
    raise ValueError(f"Unknown model type {type(model)}")

In [None]:
# plot the weights of three selected models side by side, one subfigure each
plt.style.use("science")
show = ["PCA", "Linear shallow AE", "Sigmoid-linear shallow AE"]
fig = plt.figure(figsize=(5.9, 3))
sfigs = fig.subfigures(1, len(show), squeeze=False)[0]

# Two columns with enough rows for every latent dimension. A fixed 6-cell grid
# would silently drop every component beyond the sixth.
layout = (math.ceil(intrinsic_dim / 2), 2)
for panel, (sfig, model_name) in enumerate(zip(sfigs, show)):
    axs = sfig.subplots(*layout, squeeze=False)
    plot_weights(get_weights(models[model_name]), intrinsic_dim, img_size, axs=axs)
    # Hide the padding cell(s) when intrinsic_dim does not fill the grid.
    for ax in axs.flat[intrinsic_dim:]:
        ax.axis("off")
    # Panel label "(a)", "(b)", ... underneath each subfigure.
    sfig.text(0.52, 0.02, f"({string.ascii_lowercase[panel]})", ha="center")
# fig.savefig("../figures/weights-comparison.pgf", backend="pgf")
plt.show()

In [None]:
# plot the weights of some selected models: one column of latent components per model.
# `show` is also read by the correlation-matrix cell further down.
show = ["PCA", "Linear shallow AE", "Sigmoid-linear shallow AE", "Sigmoid shallow AE"]
models_to_show = {name: model for name, model in models.items() if name in show}
# Panel labels "(a)", "(b)", ..., one per shown model (in `models` order).
names = [f"({letter})" for letter in string.ascii_lowercase[: len(models_to_show)]]
plt.style.use("science")
fig = plt.figure(figsize=(5.9, 7))
sfigs = fig.subfigures(1, len(models_to_show), squeeze=False)[0]
for sfig, panel_label, model in zip(sfigs, names, models_to_show.values()):
    axs = sfig.subplots(intrinsic_dim, 1, squeeze=False)
    sfig.suptitle(panel_label)
    plot_weights(
        get_weights(model), intrinsic_dim=intrinsic_dim, img_size=img_size, axs=axs
    )
# NOTE: same file name as the three-model weight figure above; rename one of
# them before enabling both saves, or the later save overwrites the earlier.
# fig.savefig("../figures/weights-comparison.pgf", backend="pgf")
fig.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()

In [None]:
# display correlation matrices (np.corrcoef, i.e. normalised covariance) of every embedding
n_panels = len(embeddings)
ncols = 4
nrows = max(1, math.ceil(n_panels / ncols))
fig, axs = plt.subplots(nrows, ncols, figsize=(9, 3 * nrows), squeeze=False)
for ax, (name, embedding) in zip(axs.flat, embeddings.items()):
    # Assumes embeddings are (n_samples, n_latent), so the transpose puts one
    # latent dimension per row, as np.corrcoef expects -- confirm.
    ax.matshow(np.corrcoef(embedding.T), cmap="RdBu_r", vmin=-1, vmax=1)
    ax.set_title(name)
# Hide grid cells not needed for the current number of models.
for ax in axs.flat[n_panels:]:
    ax.axis("off")
plt.tight_layout()

In [None]:
# display correlation matrices of transformed data (only the models listed in `show`)
selected = [name for name in embeddings if name in show]
if not selected:
    raise ValueError("None of the models in `show` have an embedding.")
fig, axs = plt.subplots(1, len(selected), figsize=(5.9, 3), squeeze=False)
# Number panels by position among the *selected* models, not by their index in
# `embeddings`, so axes and labels always line up with `selected`.
for panel, (ax, name) in enumerate(zip(axs.flat, selected)):
    ax.matshow(np.corrcoef(embeddings[name].T), cmap="RdBu_r", vmin=-1, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.text(
        0.5,
        -0.2,
        f"({string.ascii_lowercase[panel]})",
        transform=ax.transAxes,
        ha="center",
    )
fig.colorbar(axs.flat[0].images[0], ax=list(axs.flat), location="right", shrink=0.6)
fig.savefig("../figures/correlation-matrices.pgf", backend="pgf")

In [None]:
# Draw 5 distinct raw images (fixed seed, so the figure is reproducible) and pass
# them, with the fitted scaler, to plot_reconstructions for all models.
images = resample(X, n_samples=5, replace=False, random_state=0)
plot_reconstructions(
    models,
    preprocessor=preprocessor,
    images=images,
    channels=channels,
    height=height,
    width=width,
)
plt.show()