In [None]:
import sys

# Make the project root (two directories above this notebook) importable so
# that `vae_mnist` resolves. The path is relative to the kernel's working
# directory, which for Jupyter is normally the notebook's own directory.
PROJECT_ROOT = '../..'
if PROJECT_ROOT not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(PROJECT_ROOT)

## Load trained model

In [None]:
from vae_mnist import VariationalAutoEncoder

# Checkpoint written by the Lightning training run of the MNIST VAE.
checkpoint_path = '../../lightning_logs/vae_mnist/version_0/checkpoints/epoch=173-step=326249.ckpt'

# map_location='cpu': the rest of the notebook feeds CPU tensors to the model
# and calls `.numpy()` on its outputs, so the weights must be on the CPU even
# if the checkpoint was saved from a GPU run.
model = VariationalAutoEncoder.load_from_checkpoint(checkpoint_path, map_location='cpu')

# freeze() = eval() + requires_grad=False on every parameter, so the inference
# cells below do not build autograd graphs.
model.freeze()
model

In [None]:
# Display the structure of the model's encoder sub-module.
model.encoder

## Load test data

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Seed the global torch RNG. The shuffled DataLoaders in this notebook (and any
# torch-based sampling later on) then pick the same examples on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# MNIST test split, converted to tensors with ToTensor().
val_dataset = MNIST(root='../../data', train=False, transform=transforms.ToTensor(), download=True)
# shuffle=True: the first batch is a random sample of 10 test digits.
val_loader = DataLoader(dataset=val_dataset, batch_size=10, shuffle=True)

In [None]:
# Pull one shuffled batch from the test loader for the reconstruction demo.
first_batch = next(iter(val_loader))
imgs, labels = first_batch
imgs.shape, labels

## Reconstruction

In [None]:
# The encoder returns a pair, unpacked here as (mu, logvar).
encoder_out = model.encoder(imgs)
mu, logvar = encoder_out
mu.size()

In [None]:
# Turn (mu, logvar) into latent codes with the model's reparameterize method,
# then map those codes back to image space with the decoder.
z = model.reparameterize(mu, logvar)
recon_imgs = model.decoder(z)
recon_imgs.size()

In [None]:
import matplotlib.pyplot as plt

# Digits shown per row, capped by the batch size so a smaller batch does not
# raise IndexError.
n_show = min(10, len(imgs))

fig, axes = plt.subplots(2, n_show, figsize=(15, 3), squeeze=False)
fig.suptitle('Top: original test digits / Bottom: VAE reconstructions')

# Row 0 shows the inputs and row 1 the decoder outputs. Both go through the
# same conversion: detach() drops any autograd graph and cpu() makes the tensor
# NumPy-convertible regardless of the device it lives on.
for row_axes, batch in zip(axes, (imgs, recon_imgs)):
    for ax, img in zip(row_axes, batch[:n_show]):
        ax.imshow(img.squeeze().detach().cpu().numpy(), cmap='gray_r')
        ax.axis('off')

plt.show()

## Latent space

In [None]:
# One large shuffled batch of test images, used to visualise the latent space.
val_loader = DataLoader(dataset=val_dataset, shuffle=True, batch_size=5000)
batch_iter = iter(val_loader)
imgs, labels = next(batch_iter)
imgs.size(), labels.size()

In [None]:
# Encode the large batch and convert both encoder outputs to NumPy arrays
# for plotting.
mu_t, logvar_t = model.encoder(imgs)
mu, logvar = (t.detach().numpy() for t in (mu_t, logvar_t))
mu.shape, logvar.shape

In [None]:
# Scatter the first two columns of the encoder's mu output for each test image,
# coloured by digit label.
# NOTE(review): if the latent space has more than two dimensions, only
# dims 0 and 1 are shown here.
fig, ax = plt.subplots(figsize=(12, 12))
points = ax.scatter(mu[:, 0], mu[:, 1], c=labels, cmap='rainbow', alpha=0.5, s=2)
ax.set(
    title='Encoder means (mu) of MNIST test images',
    xlabel='mu[:, 0]',
    ylabel='mu[:, 1]',
)
fig.colorbar(points, ax=ax, label='digit label')
plt.show()