In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

from torchvision import datasets, transforms

from drcomp import estimate_intrinsic_dimension
from drcomp.autoencoder import FullyConnectedAutoencoder, MnistConvolutionalAE
from drcomp.reducers import AutoEncoder

In [None]:
# Download (if needed) and load the MNIST train and test splits; both
# datasets share the same root directory and ToTensor transform.
mnist_train, mnist_test = (
    datasets.MNIST(
        root="../data/raw",
        train=is_train,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train in (True, False)
)

In [None]:
# Build model inputs from the raw MNIST tensors.
# `.data` returns the raw uint8 pixel tensor and bypasses the dataset's
# ToTensor transform, so apply the same [0, 255] -> [0, 1] scaling here.
# Both splits get the same (N, 1, H, W) layout (one grayscale channel) so
# train and test data can be fed to the convolutional AE interchangeably.
image_size = (28, 28)
X_train = (mnist_train.data.numpy().astype("float32") / 255.0).reshape(-1, 1, *image_size)
X_test = (mnist_test.data.numpy().astype("float32") / 255.0).reshape(-1, 1, *image_size)
n_samples = X_train.shape[0]

In [None]:
# Latent size for the autoencoder. Previously estimated with
# estimate_intrinsic_dimension(X_train.reshape(n_samples, -1), K=10), which
# returned 15; the value is hard-coded so the estimate need not be rerun.
intrinsic_dim: int = 15

In [None]:
# Train the convolutional autoencoder on the training images (~3min).
# The latent size comes from `intrinsic_dim`; fit() is chained so that
# `conv_ae` holds the object returned by fit().
conv_ae = AutoEncoder(
    MnistConvolutionalAE(intrinsic_dim),
    batch_size=128,
    max_epochs=10,
    lr=1e-3,
).fit(X_train)

In [None]:
# Encode the training images, then decode the codes back to image space.
Y = conv_ae.transform(X_train)  # latent codes
X_hat = conv_ae.inverse_transform(Y)  # reconstructions of X_train

In [None]:
# Evaluation is not run here because it takes ~12min:
#   conv_ae.evaluate(X_train.reshape(n_samples, -1), Y=Y.reshape(-1, intrinsic_dim), K=10)
# Result from a previous run: T = 0.9994

In [None]:
import matplotlib.pyplot as plt

In [None]:
# First training image and its reconstruction, reshaped to 2-D for imshow.
original, reconstructed = (
    images[0].reshape(image_size) for images in (X_train, X_hat)
)

In [None]:
# Side-by-side comparison of the first training image and its reconstruction.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5))
ax1.imshow(original, cmap="gray")
ax1.set_title("Original")
ax2.imshow(reconstructed, cmap="gray")
ax2.set_title("Reconstructed")
for ax in (ax1, ax2):
    ax.axis("off")  # pixel-index ticks carry no information for an image
# Plain string: the original f-prefix had no placeholders.
fig.suptitle("Original vs. Reconstructed by Convolutional AE")
plt.show()

In [None]:
# Project the latent codes to 2-D with t-SNE for visualisation.
from sklearn.manifold import TSNE

# Flatten each sample's latent code so TSNE always receives a 2-D
# (n_samples, n_features) array; this is a no-op when Y is already 2-D.
Y_flat = Y.reshape(len(Y), -1)
# Fixed random_state so the embedding (and the saved figure) is reproducible.
Y_embedded = TSNE(n_components=2, random_state=0).fit_transform(Y_flat)  # takes ~3min

In [None]:
# Scatter the t-SNE embedding coloured by digit label, save it, and show it.
LATENT_FIGURE_PATH = Path("../figures/mnist_conv_ae_latent_space.png")
# savefig raises if the target directory does not exist, so create it first.
LATENT_FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 8))
points = ax.scatter(
    Y_embedded[:, 0],
    Y_embedded[:, 1],
    c=mnist_train.targets.numpy(),
    cmap="tab10",  # one distinct colour per digit class
    s=1,  # small markers reduce overplotting with many points
)
fig.colorbar(points, ax=ax, label="digit")
ax.set(
    title="t-SNE of convolutional AE latent space",
    xlabel="t-SNE 1",
    ylabel="t-SNE 2",
)
fig.savefig(LATENT_FIGURE_PATH)
plt.show()