In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Load MNIST through the TensorFlow tutorial helper; one_hot=True asks for
# one-hot label vectors.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
# Mini-batch size for training batches and for the latent noise arrays.
mb_size = 64
# Dimensionality of the latent code z.
Z_dim = 100
# Width of one flattened input sample; presumably 784 (= 28 * 28), matching the
# reshape(28, 28) done when plotting -- TODO confirm.
X_dim = mnist.train.images.shape[1]
# Width of one label vector. Not referenced by any other visible cell.
y_dim = mnist.train.labels.shape[1]
# Learning rate handed to the Adam solver.
lr = 1e-3
# Scale of the Gaussian noise added to the inputs before they are encoded.
noise_factor = .25

In [None]:
# Q(z|X) Encoder
def Q(X, hidden=(128, 128)):
    """Encode a batch into the parameters of a diagonal Gaussian over z.

    Args:
        X: nnabla Variable holding the input batch; fed directly into the
            first affine layer.
        hidden: sizes of the ReLU hidden layers, applied in order.

    Returns:
        Tuple ``(z_mu, z_var)`` of Variables with ``Z_dim`` outputs per
        sample. ``z_var`` is the log-variance: the rest of the notebook
        passes it through ``F.exp`` before using it as a variance.
    """
    with nn.parameter_scope("Q"):  # Parameter scope can be nested
        h = X
        for hid, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(hid + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("Gaussian params"):
            # Each head needs its own scope. Parameters are looked up by scope
            # name, so two PF.affine calls in the same scope would reuse one
            # weight/bias pair and make mean and log-variance identical.
            with nn.parameter_scope("mu"):
                z_mu = PF.affine(h, Z_dim)
            with nn.parameter_scope("log_var"):
                z_var = PF.affine(h, Z_dim)
    return z_mu, z_var

In [None]:
def sample_z(mu, long_var):
    """Draw z = mu + sigma * eps using the reparameterization trick.

    Args:
        mu: Variable with the mean of q(z|X).
        long_var: Variable with the log-variance of q(z|X). The parameter name
            is kept as-is so existing callers keep working.

    Returns:
        Variable ``mu + exp(long_var / 2) * eps`` where ``eps`` is
        standard-normal noise of shape ``(batch, Z_dim)``.
    """
    # eps has to be standard normal: np.random.rand would draw from U[0, 1),
    # which is always positive and does not match the N(0, I) noise the decoder
    # is later sampled with. The batch size comes from mu, not a global.
    eps = nn.Variable.from_numpy_array(np.random.randn(mu.shape[0], Z_dim))
    return mu + F.exp(long_var / 2) * eps

In [None]:
# P(X|z) Decoder
def P(z, hidden=(128, 128)):
    """Decode latent codes back into input space.

    Args:
        z: nnabla Variable holding a batch of latent codes.
        hidden: sizes of the ReLU hidden layers, applied in order.

    Returns:
        Variable with ``X_dim`` sigmoid-activated outputs per sample, i.e.
        values in (0, 1).
    """
    with nn.parameter_scope("P"):  # Parameter scope can be nested
        h = z
        for hid, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(hid + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("last_layer"):
            X = F.sigmoid(PF.affine(h, X_dim))
    return X

In [None]:
solver = S.Adam(lr)
# nnabla creates parameters lazily, the first time a parametric function runs
# inside its scope. Defining Q and P creates none, so the graph is built once
# on a placeholder batch before registering. Otherwise the solver would be
# handed an empty parameter dict and update() would change nothing.
P(sample_z(*Q(nn.Variable((mb_size, X_dim)))))
solver.set_parameters(nn.get_parameters())

In [None]:
def show16(samples):
    """Draw each flattened 28x28 sample in its own cell of a 4x4 grid and show it."""
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, flat_image in enumerate(samples):
        # One borderless, tick-free panel per sample.
        panel = fig.add_subplot(grid[idx])
        panel.axis('off')
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        panel.set_aspect('equal')
        panel.imshow(flat_image.reshape(28, 28), cmap='Greys_r')

    plt.show()

In [None]:
# Denoising-VAE training loop: encode a noisy batch, decode it, and minimise
# reconstruction error against the clean batch plus the KL regulariser.
for it in range(100000):
    X, _ = mnist.train.next_batch(mb_size)
    X = nn.Variable.from_numpy_array(X)

    # Add noise
    X_noise = X.d + noise_factor * np.random.randn(X.shape[0], X.shape[1])
    X_noise = np.clip(X_noise, 0., 1.)
    X_noise = nn.Variable.from_numpy_array(X_noise)
    # Forward
    z_mu, z_var = Q(X_noise)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss
    recon_loss = F.sum(F.binary_cross_entropy(X_sample, X)) / mb_size
    kl_loss = F.mean(0.5 * F.sum(F.exp(z_var) + z_mu**2 - 1. - z_var, axis=1))
    # Both terms are costs to be minimised. Negating their sum would make the
    # solver maximise reconstruction error and KL divergence.
    loss = recon_loss + kl_loss

    # Parameters are created lazily by the first Q/P call, so register them
    # once the graph exists to make sure the solver has something to update.
    if it == 0:
        solver.set_parameters(nn.get_parameters())

    # backward() accumulates into the parameter gradients, so clear them
    # before the pass rather than after the update.
    solver.zero_grad()
    loss.forward()
    loss.backward()
    solver.update()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Step: {}, Loss: {:.4}'.format(it, loss.d))

        # Decode fresh prior samples. The graph must be run with forward()
        # before .d holds the generated images.
        z_prior = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
        generated = P(z_prior)
        generated.forward()
        show16(generated.d[:16])