In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from drcomp.autoencoder import FullyConnectedAE, MnistConvAE
from drcomp.reducers import AutoEncoder
from drcomp import estimate_intrinsic_dimension
import wandb
import torch
import torch.nn as nn
import pickle
import numpy as np
from torchvision import datasets, transforms
from skorch.callbacks import LRScheduler, WandbLogger, EarlyStopping, ProgressBar
from torchsummary import summary

In [None]:
# MNIST training split; downloaded into /storage/data when not already present.
mnist_train = datasets.MNIST(
    "/storage/data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# The test split is currently unused; kept here for reference.
# mnist_test = datasets.MNIST(
#    root="/storage/data", download=True, transform=transforms.ToTensor(), train=False
# )

In [None]:
# Height/width of a single MNIST image, reused when plotting.
image_size = (28, 28)

# Pull the raw image tensor out of the dataset as float32 and add a channel
# axis so each sample is laid out as (1, 28, 28).
# NOTE(review): `.data` is read directly, so the ToTensor transform given to the
# dataset is presumably not applied here (values likely stay in 0-255) —
# confirm this is the input range the autoencoder expects.
X_train = (
    mnist_train.data.numpy()
    .astype("float32")
    .reshape(-1, 1, *image_size)
)
# X_test = mnist_test.data.numpy().astype("float32")
n_samples = len(X_train)

In [None]:
# Latent dimensionality handed to the autoencoder (see MnistConvAE(intrinsic_dim=...)).
# The estimator call below is disabled and kept for reference only; its trailing
# note records the value 15 that it produced, while 16 is hard-coded instead.
# intrinsic_dim = estimate_intrinsic_dimension(X_train.reshape(n_samples, -1), K=10) # 15
intrinsic_dim = 16

In [None]:
# Training hyper-parameters; also sent to Weights & Biases as the run config.
config = {
    "max_epochs": 500,
    "batch_size": 100,
}
wandb_run = wandb.init(project="drcomp", group="MNIST", reinit=True, config=config)
# Bind the logger to its own name. Assigning it to `wandb` would replace the
# imported module, so re-running this cell (or any later `wandb.*` call) would
# fail with an AttributeError on the WandbLogger instance.
wandb_logger = WandbLogger(wandb_run)

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Learning-rate callback using the ReduceLROnPlateau policy.
lr_schedule = LRScheduler(policy="ReduceLROnPlateau")

# Alternative fully connected architecture (disabled; it needs flattened input):
# X_train = X_train.reshape(n_samples, -1)
# base = FullyConnectedAE(
#     input_size=784,
#     hidden_layer_dims=[256],
#     intrinsic_dim=intrinsic_dim,
#     act_fn=nn.ReLU,
# )

# Convolutional autoencoder wrapped in the project's AutoEncoder reducer, with
# LR scheduling, W&B logging, early stopping and a progress bar as callbacks.
model = AutoEncoder(
    MnistConvAE(intrinsic_dim=intrinsic_dim),
    device=device,
    max_epochs=config["max_epochs"],
    batch_size=config["batch_size"],
    callbacks=[
        lr_schedule,
        WandbLogger(wandb_run),
        EarlyStopping(patience=10),
        ProgressBar(),
    ],
)

# Initialize explicitly so `module_` exists before printing the layer summary.
model.initialize()
summary(model.module_, input_size=(1, 28, 28))

In [None]:
# Move the training data to the target device. torch.as_tensor accepts both the
# original NumPy array and an already-converted tensor, so re-running this cell
# does not raise the way torch.from_numpy does on a tensor input.
X_train = torch.as_tensor(X_train, device=device)

try:
    model.fit(X_train)
finally:
    # Close the W&B run even when training raises, so no run is left dangling.
    wandb_run.finish()

# Persist the fitted model for later cells / notebooks.
name = "mnist_fc_conv_ae"
with open(f"../models/{name}.pkl", "wb") as f:
    pickle.dump(model, f)

In [None]:
# Load a previously pickled model; the context manager closes the file handle
# (the bare `pickle.load(open(...))` form leaves it open).
# NOTE(review): the training cell writes "../models/mnist_fc_conv_ae.pkl" while
# this reads "../models/mnist_conv_ae.pkl" — confirm which file is intended.
# Only unpickle files you trust: pickle.load can execute arbitrary code.
with open("../models/mnist_conv_ae.pkl", "rb") as f:
    model = pickle.load(f)

In [None]:
# This model is quite big, so the data is pushed through it batch by batch.
# Batches are collected in lists and concatenated once at the end: np.append
# without `axis` flattens its inputs (leaving Y 1-D, which TSNE rejects) and
# re-copies the whole accumulated array on every iteration.
decoded_batches = []
encoded_batches = []
for decoded, encoded in model.forward_iter(X_train):
    decoded_batches.append(decoded.detach().cpu().numpy())
    encoded_batches.append(encoded.detach().cpu().numpy())

# One row per sample: X_hat holds the flattened reconstructions, Y the latent
# codes. Arrays keep the dtype produced by the model instead of float64.
X_hat = np.concatenate(decoded_batches)
X_hat = X_hat.reshape(len(X_hat), -1)
Y = np.concatenate(encoded_batches)
Y = Y.reshape(len(Y), -1)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Arrange the reconstructions as one flattened row per training sample.
X_hat = np.reshape(X_hat, (n_samples, -1))

In [None]:
# Show the first few training images next to their reconstructions.
n_images = 10

# After the training cell X_train is a torch tensor that may live on the GPU;
# matplotlib cannot convert CUDA tensors, so bring the needed slice back to a
# CPU NumPy array first (as_tensor also accepts a plain NumPy array).
originals = torch.as_tensor(X_train[:n_images]).detach().cpu().numpy()

for i in range(n_images):
    original = originals[i].reshape(image_size)
    reconstructed = X_hat[i].reshape(image_size)
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5))
    ax1.imshow(original, cmap="gray")
    ax1.set_title("Original")
    ax2.imshow(reconstructed, cmap="gray")
    ax2.set_title("Reconstructed")
    fig.suptitle("Original vs. Reconstructed by Convolutional AE")
    plt.show()

In [None]:
# Inspect the latent space Y with a 2-D t-SNE embedding.
from sklearn.manifold import TSNE

# TSNE expects a 2-D (n_samples, n_features) array; the latent codes may arrive
# flattened from the batch-accumulation cell, so restore one row per sample.
latent_codes = np.asarray(Y).reshape(n_samples, -1)

# A fixed random_state makes the embedding reproducible across runs.
Y_embedded = TSNE(n_components=2, random_state=0).fit_transform(latent_codes)  # takes ~3min

In [None]:
# Scatter the 2-D embedding, coloured by the digit label, then save and show it.
# NOTE(review): the output filename says "fc_ae" although the model built in this
# notebook is the convolutional AE — confirm the intended name.
fig, ax = plt.subplots()
ax.scatter(Y_embedded[:, 0], Y_embedded[:, 1], c=mnist_train.targets)
fig.savefig("figures/mnist_fc_ae_embedding_visualization.png")
plt.show()