In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import sys

# local imports
%load_ext autoreload
%autoreload 2
sys.path.append('../src')
from models import VariationalAutoencoder, ImportanceWeightedAutoencoder

Create datasets using torchvision

In [None]:
import torchvision

# Folder handed to torchvision as the dataset root (download=True is requested below).
path_to_datasets = '~/datasets'

# Both splits use the same image -> tensor transform.
_to_tensor = torchvision.transforms.ToTensor()
mnist_train_data = torchvision.datasets.MNIST(
    path_to_datasets, train=True, download=True, transform=_to_tensor)
mnist_test_data = torchvision.datasets.MNIST(
    path_to_datasets, train=False, download=True, transform=_to_tensor)

# Preview: the first ten training images in one row, each titled with its label.
fig, ax = plt.subplots(nrows=1, ncols=10, figsize=(6, 1), tight_layout=True)
for i, panel in enumerate(ax):
    image, label = mnist_train_data[i]
    # Index 0 on the leading axis leaves a 2-D array that imshow can draw.
    panel.imshow(image.numpy()[0], cmap=plt.cm.Greys_r)
    panel.set_title(label)
    panel.axis('off')

Prepare dataloaders

In [None]:
from torch.utils.data import DataLoader, SubsetRandomSampler

# Seed NumPy's global RNG; it only influences the split if the shuffle below is re-enabled.
np.random.seed(0)

# Only the first 10000 training images are used (capped at the dataset size),
# split 70% / 30% into training and validation indices.
n_used = min(10000, len(mnist_train_data))
idx = list(range(n_used))
#np.random.shuffle(idx)
split = int(0.7*len(idx))

# Training batches: indices idx[:split], drawn in random order by the sampler.
train_loader = DataLoader(mnist_train_data, batch_size=32, drop_last=True,
                          sampler=SubsetRandomSampler(idx[:split]))

# Validation batches: the remaining indices idx[split:] of the same dataset object.
valid_loader = DataLoader(mnist_train_data, batch_size=256, drop_last=True,
                          sampler=SubsetRandomSampler(idx[split:]))

# Test batches are served in dataset order (shuffle=False) and none is dropped,
# so concatenated outputs line up one-to-one with the test dataset's labels.
test_loader = DataLoader(mnist_test_data, batch_size=256, drop_last=False, shuffle=False)

Train the model (or skip this cell and load a previously saved checkpoint in the next one)

In [None]:
# Seed torch's global RNG before the model is created so that runs are repeatable.
torch.manual_seed(1234)

# Two latent dimensions; inputs are flattened 28x28 images.
# Alternative model: ImportanceWeightedAutoencoder(latent_dim=2, data_dim=28*28)
model = VariationalAutoencoder(data_dim=28*28, latent_dim=2)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def normalize_data(x, eps=1e-8):
    """Flatten each image to a 784-vector and standardize it per sample.

    Every row is shifted to zero mean and divided by its own standard
    deviation (as computed by ``torch.std`` along dim 1).

    Parameters
    ----------
    x : torch.Tensor
        Floating-point tensor whose elements can be reshaped to (-1, 28*28).
    eps : float, optional
        Lower bound applied to the per-sample standard deviation. It only
        takes effect for (near-)constant images, where the plain division
        would produce NaN/inf; all other inputs give the same result as
        dividing by the unclamped standard deviation.

    Returns
    -------
    torch.Tensor
        Tensor of shape (N, 784) holding the standardized samples.
    """
    x = x.reshape(-1, 28*28)
    centered = x - torch.mean(x, dim=1, keepdim=True)
    # Guard against division by zero for constant images (e.g. an all-black input).
    scale = torch.std(x, dim=1, keepdim=True).clamp_min(eps)
    return centered / scale

import os

from tqdm.notebook import tqdm

n_epochs = 100
checkpoint_every = 10
checkpoint_path = '../checkpoints/mnist_vae_last.pt'
# torch.save does not create missing folders, so make sure the target exists.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Put the model in training mode explicitly: a later cell calls model.eval(),
# and that state would otherwise persist if this cell is re-run afterwards.
model.train()
for epoch in tqdm(range(n_epochs)):
    epoch_loss = 0.0  # sum of the per-batch losses over this epoch
    for x, label in train_loader:
        x = normalize_data(x)
        optimizer.zero_grad()
        # The first element returned by negELBO is used as the scalar loss.
        # mc_samples is presumably the number of Monte Carlo samples -- confirm in models.py.
        loss = model.negELBO(x, mc_samples=10)[0]
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Log and checkpoint every `checkpoint_every` epochs, and always after the
    # final epoch so that the "last" checkpoint holds the fully trained weights.
    is_final_epoch = epoch == n_epochs - 1
    if epoch % checkpoint_every == 0 or is_final_epoch:
        print(f"{epoch} {epoch_loss}")
        torch.save({'current_epoch': epoch, 'model_state_dict': model.state_dict()},
                   checkpoint_path)

Load final model

In [None]:
# NOTE(review): the training cell builds a VariationalAutoencoder, while this cell
# restores the weights into an ImportanceWeightedAutoencoder -- confirm that this
# checkpoint was produced by (or is compatible with) that class.
model = ImportanceWeightedAutoencoder(latent_dim=2, data_dim=28*28)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; the remaining cells never move data to another device.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint = torch.load('../checkpoints/mnist_vae_spherical_100.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

Latent space

In [None]:
fig, ax = plt.subplots(tight_layout=True)

# Encode the whole test set. Per-batch results are collected in lists and
# concatenated once, instead of growing a tensor inside the loop.
q_mu_batches, q_std_batches = [], []
model.eval()
with torch.no_grad():
    for x, label in test_loader:
        x = normalize_data(x)
        q_mu, q_logvar = model.encode(x)
        q_mu_batches.append(q_mu.cpu())
        # std = exp(0.5 * log-variance)
        q_std_batches.append((0.5*q_logvar).exp().cpu())
all_q_mu = torch.cat(q_mu_batches).numpy()
all_q_std = torch.cat(q_std_batches).numpy()

# test_loader iterates in dataset order (shuffle=False, drop_last=False), so the
# rows of all_q_mu / all_q_std line up with the dataset's targets.
test_labels = np.asarray(mnist_test_data.targets)
for digit in range(10):
    mask = test_labels == digit
    # Pass the error bars by keyword: positionally, errorbar expects yerr before
    # xerr, which would attach the z_1 spread to the vertical axis.
    ax.errorbar(all_q_mu[mask, 0], all_q_mu[mask, 1],
                xerr=all_q_std[mask, 0], yerr=all_q_std[mask, 1], fmt='none',
                alpha=0.5, color=plt.cm.tab10(digit), label=str(digit))
ax.set_xlabel(r'$z_1$')
ax.set_ylabel(r'$z_2$')
ax.set_title('MNIST VAE latent space visualization')
ax.legend()

Reconstructions

In [None]:
fig, ax = plt.subplots(6, 10, figsize=(8, 4), tight_layout=True, sharey=True)
for ax_ in ax.ravel():
    ax_.get_xaxis().set_ticks([])
    ax_.get_yaxis().set_ticks([])

# Row labels live on the first column only and do not depend on the sample,
# so they are set once here rather than on every loop iteration.
ax[0, 0].set_ylabel('Data')
ax[1, 0].set_ylabel(r'$\mu$')
ax[2, 0].set_ylabel(r'$\sigma$')
for j in range(3):
    # Raw f-string: keeps the LaTeX backslashes literal without invalid-escape warnings.
    ax[3+j, 0].set_ylabel(rf'$\mu + \sigma\epsilon_{j+1}$')

# Do not rely on an earlier cell having switched the model to eval mode.
model.eval()
for i in range(10):
    with torch.no_grad():
        # One test image, with a leading batch axis added, then standardized.
        x = normalize_data(mnist_test_data[i][0].unsqueeze(0))
        p_parameters, q_parameters, z = model.forward(x)
        xhat = model.sample_decoder(p_parameters, mc_samples=3).reshape(-1, 28, 28).cpu().numpy()
        dec_mu, dec_logvar = p_parameters
        dec_mu = dec_mu.reshape(28, 28).cpu().numpy()
        #dec_std = (0.5*dec_logvar).exp().reshape(-1, 28, 28) #diagonal
        # Spherical case: assumes dec_logvar holds a single value that is tiled
        # into a 28x28 image for display -- TODO confirm against the model.
        dec_std = (0.5*dec_logvar).exp().repeat(28, 28).cpu().numpy() #spherical

    ax[0, i].imshow(x.reshape(28, 28).cpu().numpy(), cmap=plt.cm.Greys_r) # Data
    ax[1, i].imshow(dec_mu, cmap=plt.cm.Greys_r) # Mean
    ax[2, i].imshow(dec_std, cmap=plt.cm.Greys_r) # Std
    for j in range(3): # 3 realizations N(Mean, Std^2)
        ax[3+j, i].imshow(xhat[j], cmap=plt.cm.Greys_r)
    