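# Run spatialVAE on the rotated MNIST dataset

Train a spatial VAE on MNIST digits with random rotations, then visualise its reconstructions, its latent space and latent-space interpolations.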

In [1]:
import sys
sys.path.append('../models')

import hydra
#from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from hydra.experimental import compose, initialize
import logging
import os
import numpy as np
import os.path as osp
import pytorch_lightning as pl
import torch
import torchvision.transforms.functional as TF

from dm_mnist import MnistDataModule
import matplotlib.pyplot as plt
from vae import VAE
from spatialVAE import SpatialVAE
from omegaconf import OmegaConf
# Module-level logger for this notebook (named after the running module)
logger = logging.getLogger(name=__name__)

  rank_zero_deprecation(


## Load settings

In [2]:
# Compose the experiment configuration from ../configs/spatialVAE_mnist.yaml
with initialize(config_path="../configs"):
    cfg = compose(config_name="spatialVAE_mnist.yaml")
    print(cfg)

# Directory that later cells save model weights into; create it up front so
# torch.save does not fail on a fresh checkout.
out_dir = osp.join("..", "output")
os.makedirs(out_dir, exist_ok=True)

# To ensure reproducibility, seed from the config instead of duplicating the value
pl.seed_everything(cfg.seed)

Global seed set to 123


{'dataset': 'mnist', 'likelihood': 'bernoulli', 'z_dim': 8, 'hidden_dim': 512, 'num_layers': 2, 'activation': 'tanh', 'modify': 1, 'kl_coef': 1.0, 'lr': 0.0001, 'step_size': 25, 'batch_size': 64, 'num_workers': 2, 'num_epochs': 100, 'seed': 123, 'fast_dev_run': False, 'dx_scale': 0.1, 'theta_prior': 0.7854}


123

In [3]:
# Dataset
# Dataset: cfg.modify is forwarded to the data module
# (presumably it selects the rotated variant — TODO confirm in dm_mnist)
data_root = osp.join('..', 'data')
dm = MnistDataModule(data_dir=data_root, dataset_name="MNIST", modify=cfg.modify)

# Model: dm.size() is passed as the model's data_dim
spatialvae_model = SpatialVAE(cfg, data_dim=dm.size())

# Train on GPU 0 when CUDA is available, otherwise on CPU; checkpointing is disabled
accelerator_gpus = [0] if torch.cuda.is_available() else 0
trainer = pl.Trainer(
    checkpoint_callback=False,
    max_epochs=cfg.num_epochs,
    fast_dev_run=cfg.fast_dev_run,
    gpus=accelerator_gpus,
)
trainer.fit(model=spatialvae_model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

  | Name  | Type           | Params
-----------------------------------------
0 | p_net | SpatialDecoder | 268 K 
1 | q_net | Encoder        | 673 K 
-----------------------------------------
942 K     Trainable params
0         Non-trainable params
942 K     Total params
3.770     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
Global seed set to 123
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')


In [None]:
# Persist the trained encoder (q_net) and decoder (p_net) weights separately.
# The model trained above is `spatialvae_model`. The file names carry the
# "spatialvae" prefix so they cannot overwrite weights saved by the plain-VAE run.
os.makedirs(out_dir, exist_ok=True)
for part, net in (('encoder', spatialvae_model.q_net), ('decoder', spatialvae_model.p_net)):
    save_file = osp.join(out_dir, f'spatialvae_mnist_{part}.pth')
    torch.save(net.state_dict(), save_file)

## Visualise example data

In [None]:
plt.rcParams['figure.figsize'] = [8, 8]

# Visualize rotated dataset: the first ten training samples as stored in the data module
fig, axs = plt.subplots(2, 5)
for ax, sample_idx in zip(axs.ravel(), np.arange(10)):
    rotated_img, _label, _angle = dm.train_set[sample_idx]
    ax.imshow(rotated_img.squeeze())
    ax.set_axis_off()
fig.suptitle("MNIST Rotated Visualization")
plt.tight_layout()
plt.show()

In [None]:
# Visualize original dataset: undo each sample's rotation by rotating it back
fig, axs = plt.subplots(2, 5)
for ax, sample_idx in zip(axs.ravel(), np.arange(10)):
    rotated_img, _label, angle = dm.train_set[sample_idx]
    # The stored angle is treated as radians; TF.rotate expects degrees
    undo_deg = -(angle.item() * 180 / np.pi)
    upright_img = TF.rotate(img=rotated_img, angle=undo_deg)
    ax.imshow(upright_img.squeeze())
    ax.set_axis_off()
fig.suptitle("MNIST Original Visualization")
plt.tight_layout()
plt.show()

In [None]:
# spatialVAE reconstructions of the first ten training samples.
# Uses `spatialvae_model` (the model trained above). Inference runs in eval mode
# without gradient tracking. Inputs go to the model's device and outputs come back
# to CPU so this also works after GPU training.
spatialvae_model.eval()
model_device = next(spatialvae_model.parameters()).device
fig, axs = plt.subplots(2, 5)
with torch.no_grad():
    for ax, i in zip(axs.flatten(), np.arange(0, 10)):
        img, _, angle = dm.train_set[i]
        # First element of the forward output is y_hat: the reconstructed input
        img_recon = spatialvae_model(img.to(model_device))[0].cpu().numpy()
        ax.imshow(img_recon.squeeze())  # cmap='gray'
        ax.set_axis_off()
fig.suptitle("MNIST Reconstructed Visualization")
plt.tight_layout()
plt.show()

## Latent space visualisation

In [None]:
def plot_latent(vae, data, num_batches=100):
    """Scatter the first two latent coordinates of samples from ``data``, coloured by label.

    Iterates over ``data`` and stops after ``num_batches + 1`` items, or earlier if
    ``data`` is exhausted. The second element of ``vae(img)`` is used as the latent code.

    Args:
        vae: model whose forward output has the latent code at index 1.
        data: iterable of ``(img, target, angle)`` samples, e.g. ``dm.train_set``.
        num_batches: index of the last item to include (``num_batches + 1`` items total).

    NOTE: only the first two latent dimensions are plotted. The config sets
    ``z_dim: 8``, so this is a projection, not the full latent space.
    """
    device = next(vae.parameters()).device
    latents, targets = [], []
    with torch.no_grad():
        for i, (img, t, _angle) in enumerate(data):
            # Use the model passed in (not a notebook global), flattened to one row
            z_i = vae(img.to(device))[1]
            latents.append(z_i.cpu().numpy().reshape(-1))
            targets.append(t)
            if i >= num_batches:
                break
    if not latents:
        return  # nothing to plot for an empty dataset
    # Stack only what was collected, so rows and labels always line up
    z = np.stack(latents)
    plt.scatter(z[:, 0], z[:, 1], c=targets, cmap='tab10')
    plt.colorbar()
plot_latent(spatialvae_model, dm.train_set, num_batches=15000)

In [None]:
def plot_reconstructed(vae, r0=(-10, 10), r1=(-5, 10), n=15, w=28):
    """Decode an ``n x n`` grid of 2-D latent points and tile the decoded images.

    Latent ``x`` sweeps ``r0`` and latent ``y`` sweeps ``r1``. Each point is decoded
    with ``vae.p_net`` and reshaped to ``w x w``. The mosaic is drawn with
    ``extent=[*r0, *r1]`` so the axes show latent coordinates.

    Args:
        vae: model exposing a ``p_net`` decoder.
        r0: (min, max) range of the first latent coordinate.
        r1: (min, max) range of the second latent coordinate.
        n: number of grid points per axis.
        w: side length of each decoded square image (28 for MNIST).

    NOTE(review): this feeds 2-D codes to ``vae.p_net`` while the config sets
    ``z_dim: 8``. Confirm that SpatialDecoder accepts this input before relying on the plot.
    """
    device = next(vae.parameters()).device
    img = np.zeros((n * w, n * w))
    with torch.no_grad():
        for i, y in enumerate(np.linspace(*r1, n)):
            for j, x in enumerate(np.linspace(*r0, n)):
                z = torch.Tensor([[x, y]]).to(device)
                x_hat = vae.p_net(z).reshape(w, w).cpu().numpy()
                # Largest y goes in the top row so the mosaic matches imshow's extent
                img[(n - 1 - i) * w:(n - i) * w, j * w:(j + 1) * w] = x_hat
    plt.imshow(img, extent=[*r0, *r1], cmap='gray')
plot_reconstructed(spatialvae_model)

## Interpolation

In [None]:
plt.rcParams['figure.figsize'] = [15, 20]
def interpolate(vae, x1, x2, n=12, w=28):
    """Decode ``n`` evenly spaced latent points between the codes of ``x1`` and ``x2``.

    How it works:
    1. Each input is encoded with ``vae.q_net`` into ``(mu, logstd)``.
    2. One latent sample is drawn per input via ``vae.reparameterize``.
    3. The points ``z1 + (z2 - z1) * t`` for ``t`` in ``[0, 1]`` are decoded with ``vae.p_net``.
    4. The decoded images are tiled left to right with ``plt.imshow``.

    Args:
        vae: model exposing ``q_net``, ``p_net`` and ``reparameterize``.
        x1, x2: single input images (assumed shaped like one batch row — TODO
            confirm against q_net's expected input).
        n: number of interpolation steps, including both endpoints.
        w: side length of each decoded square image (28 for MNIST).
    """
    device = next(vae.parameters()).device
    with torch.no_grad():
        mu1, logstd1 = vae.q_net(x1.to(device))
        mu2, logstd2 = vae.q_net(x2.to(device))
        z1, _ = vae.reparameterize(mu=mu1, logstd=logstd1)
        z2, _ = vae.reparameterize(mu=mu2, logstd=logstd2)

        z = torch.stack([z1 + (z2 - z1) * t for t in np.linspace(0, 1, n)])
        decoded = vae.p_net(z).cpu().numpy()

    img = np.zeros((w, n * w))
    for i, x_hat in enumerate(decoded):
        img[:, i * w:(i + 1) * w] = x_hat.reshape(w, w)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

data = dm.train_dataloader()
# Grab the first batch. next(iter(...)) is the Python 3 idiom; iterators have no .next().
x, y, k = next(iter(data))
x_1 = x[y == 1][1].to('cpu')  # second "1" in the batch
x_2 = x[y == 0][1].to('cpu')  # second "0" in the batch

interpolate(spatialvae_model, x_1, x_2, n=20)