In [None]:
import torch
import torchvision
from torchvision.utils import make_grid, save_image
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as wdg
import os,sys,inspect
# Make the repository root (the parent of the notebook's directory) importable so
# that `common.setup` can be found. The working directory is used rather than
# `inspect.getfile(inspect.currentframe())`: under IPython the current frame's file
# is a synthetic cell name (e.g. `<ipython-input-…>` or an ipykernel temp path), not
# the notebook. The relative paths used later (`../data`, `../trained_models`)
# already assume the working directory is the notebook's directory.
currentdir = os.path.abspath(os.getcwd())
parentdir = os.path.dirname(currentdir)
if parentdir not in sys.path:  # idempotent when the cell is re-run
    sys.path.insert(0, parentdir)
import common.setup as setup

In [None]:
# Load the trained VAE, its training configuration, and the MNIST test loader.
model_name = 'VAE'
model_path = '../trained_models/VAE_MNIST_2019-03-23_12:57/final_model'
config_path = os.path.join(os.path.dirname(model_path), 'settings.config')
config = setup.parse_config(config_path)
loader = setup.create_test_loader(data='MNIST', directory='../data')
model = setup.create_model(config, model_name)
# Inference only: put layers such as dropout / batch-norm (if any) into eval mode.
model.eval()
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the model is moved to the GPU afterwards when one is available.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Detect a GPU once; the cells below consult `cuda` before moving tensors.
# Assumes `model` is an nn.Module, whose .cuda() moves it in place and returns it.
cuda = torch.cuda.is_available()
if cuda:
    model = model.cuda()

## Reconstruction of test samples

Each row of the figure shows a test image (left) next to its VAE reconstruction (right).

In [None]:
# Show the first n test digits next to their reconstructions: each grid row is
# (original, reconstruction). Assumes the test loader yields batches of a single
# image (hence squeeze(0)) -- TODO confirm in common.setup.create_test_loader.
n = 15
imgs = []
with torch.no_grad():  # inference only: do not build an autograd graph
    for i, (data, _target) in enumerate(loader):
        if i >= n:
            break
        imgs.append(data.squeeze(0))
        if cuda:
            # Respect the device check instead of assuming a GPU is present.
            data = data.cuda()
        mu, log_var = model.encode(data)
        z = model.sample_z(mu, log_var)
        rec = model.decode(z).cpu()
        imgs.append(rec.squeeze(0))

imgs = make_grid(imgs, nrow=2)
imgs = np.moveaxis(imgs.numpy(), 0, -1)  # channels-first -> channels-last for imshow
plt.figure(figsize=(2.0, n))
plt.imshow(imgs)
plt.show()

## Generate new samples from random latent vectors

In [None]:
# Decode one latent vector drawn from a standard normal.
# `hidden_dim_size` is a comma-separated list in the config; its second entry
# sizes the vector passed to model.decode.
hidden_size = tuple(map(int, config['hidden_dim_size'].split(',')))
eps = torch.randn(hidden_size[1])
if cuda:
    eps = eps.cuda()
with torch.no_grad():  # inference only: do not build an autograd graph
    img = model.decode(eps)
img = img.squeeze().cpu().numpy()
img.shape

In [None]:
# Display the generated sample from the previous cell in grayscale.
fig, ax = plt.subplots()
ax.imshow(img, cmap=plt.cm.gray)
plt.show()

### Variety of generated samples

In [None]:
# Decode square**2 independent standard-normal latent vectors and tile the
# decoded images into a square x square grid.
square = 10
imgs = []
with torch.no_grad():  # inference only: do not build an autograd graph
    for _ in range(square ** 2):
        eps = torch.randn(hidden_size[1])
        if cuda:
            eps = eps.cuda()
        img = model.decode(eps).squeeze(0).cpu()
        imgs.append(img)
imgs = make_grid(imgs, nrow=square)
plt.figure(figsize=(square, square))
plt.imshow(np.moveaxis(imgs.numpy(), 0, -1), cmap=plt.cm.gray)
plt.show()

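## Regularity of the latent space

Decoding points chosen interactively or along a regular grid of 2-D latent coordinates shows how the generated images vary across the latent space.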

In [None]:
def plot_img(x, y):
    """Decode the latent point (x, y) and draw it as a grayscale image.

    Used as the callback for the interactive sliders. Builds a two-element
    latent vector, so it assumes the model's latent space is 2-D -- TODO
    confirm against `hidden_dim_size` in the config.

    Args:
        x: first latent coordinate.
        y: second latent coordinate.

    Returns:
        None; the image is drawn on the current matplotlib axes.
    """
    # float() keeps the default float dtype, matching the legacy torch.Tensor([...]).
    z = torch.tensor([float(x), float(y)])
    if cuda:
        z = z.cuda()
    with torch.no_grad():  # inference only: do not build an autograd graph
        img = model.decode(z)
    plt.imshow(img.squeeze().cpu(), cmap=plt.cm.gray)

In [None]:
# Interactive explorer: drag the sliders to decode any point of the latent
# square [-borders, borders]^2. `interact` returns the wrapped function, so the
# trailing ';' keeps its repr out of the cell output.
borders = 5.0
wdg.interact(
    plot_img,
    x=wdg.FloatSlider(min=-borders, max=borders, step=0.1),
    y=wdg.FloatSlider(min=-borders, max=borders, step=0.1),
);

In [None]:
# Decode an evenly spaced grid of 2-D latent points covering [-borders, borders]^2.
# x is the outer loop and make_grid fills rows left to right, so each grid row
# holds one x value (increasing downward) and each column one y value
# (increasing rightward).
borders = 3.0
spacing = 0.25

# round() instead of truncation: 2*borders/spacing can land just below an integer
# in floating point (e.g. 0.3 / 0.1 == 2.9999999999999996), which int() would cut.
num = int(round((borders * 2) / spacing)) + 1
print('Num per dim: %d' % num)
X = np.linspace(-borders, borders, num=num)
Y = np.linspace(-borders, borders, num=num)
imgs = []
with torch.no_grad():  # inference only: do not build an autograd graph
    for x in X:
        for y in Y:
            # float() keeps the default float dtype, like the legacy torch.Tensor([...]).
            z = torch.tensor([float(x), float(y)])
            if cuda:
                z = z.cuda()
            img = model.decode(z).squeeze(0).cpu()
            imgs.append(img)
imgs = make_grid(imgs, nrow=num)
plt.figure(figsize=(num / 2.0, num / 2.0))
plt.imshow(np.moveaxis(imgs.numpy(), 0, -1), cmap=plt.cm.gray)
# Create the output directory if needed so savefig does not fail on a fresh checkout.
figure_path = '../result_figures/VAE_generation.png'
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path)
plt.show()