In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Seed NumPy's global RNG so the latent samples drawn with np.random.randn
# in the training loop are repeatable after Restart & Run All.
# NOTE(review): the MNIST reader's batch shuffling and nnabla's parameter
# initialisation may keep their own RNG state -- confirm before relying on
# bit-exact reproducibility.
SEED = 0
np.random.seed(SEED)

# MNIST with one-hot labels, read from a directory relative to the notebook.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

mb_size = 64                          # minibatch size used by every update step
Z_dim = 100                           # size of the latent code z
X_dim = mnist.train.images.shape[1]   # second axis of the training image array
y_dim = mnist.train.labels.shape[1]   # second axis of the one-hot label array
h_dim = 128                           # not referenced by the cells visible here
lr = 1e-4                             # learning rate handed to every Adam solver

lamb = 3                              # weight of the L1 reconstruction term in Q_loss / P_loss

In [None]:
# Encoder q(z|X)
def Q(X, hidden=[128, 128]):
    """Encoder q(z|X): push an input batch through affine+ReLU layers to a code.

    Args:
        X: Input variable fed to the first affine layer.
        hidden: Output sizes of the ReLU-activated affine layers, in order.
            The default list is only iterated, never mutated.

    Returns:
        The output of the final affine layer (``Z_dim`` units); no activation
        is applied to it.
    """
    h = X
    with nn.parameter_scope("Q"):  # Parameter scope can be nested
        # Layer scopes are numbered from 1: "affine1", "affine2", ...
        for layer_no, hsize in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("last_affine"):
            z = PF.affine(h, Z_dim)
    return z

In [None]:
# Decoder p(X|z)
def P(z, hidden=[128, 128]):
    """Decoder p(X|z): map a latent code back to input space.

    Args:
        z: Latent variable fed to the first affine layer.
        hidden: Output sizes of the ReLU-activated affine layers, in order.
            The default list is only iterated, never mutated.

    Returns:
        The sigmoid of a final affine layer with ``X_dim`` units.
    """
    h = z
    with nn.parameter_scope("P"):  # Parameter scope can be nested
        # Layer scopes are numbered from 1: "affine1", "affine2", ...
        for layer_no, hsize in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("sigmoid_layer"):
            X = F.sigmoid(PF.affine(h, X_dim))
    return X

In [None]:
# Discriminator of X
def D(X, hidden=[128, 128]):
    """Discriminator over X: score an input-space batch with one sigmoid unit.

    Args:
        X: Input-space variable fed to the first affine layer.
        hidden: Output sizes of the ReLU-activated affine layers, in order.
            The default list is only iterated, never mutated.

    Returns:
        The sigmoid of a final single-unit affine layer.
    """
    h = X
    with nn.parameter_scope("D"):  # Parameter scope can be nested
        # Layer scopes are numbered from 1: "affine1", "affine2", ...
        for layer_no, hsize in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("classifier"):
            y = F.sigmoid(PF.affine(h, 1))
    return y

In [None]:
# Discriminator of z
def C(z, hidden=[128, 128]):
    """Discriminator over z: score a latent batch with one sigmoid unit.

    Args:
        z: Latent variable fed to the first affine layer.
        hidden: Output sizes of the ReLU-activated affine layers, in order.
            The default list is only iterated, never mutated.

    Returns:
        The sigmoid of a final single-unit affine layer.
    """
    h = z
    with nn.parameter_scope("C"):
        # Layer scopes are numbered from 1: "affine1", "affine2", ...
        for layer_no, hsize in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("classifier"):
            y = F.sigmoid(PF.affine(h, 1))
    return y

In [None]:
def _make_solver(scope):
    """Create an Adam solver (learning rate ``lr``) over the parameters in ``scope``.

    Raises:
        RuntimeError: if the scope holds no parameters, which would make
            every later ``update()`` call a silent no-op.
    """
    solver = S.Adam(lr)
    with nn.parameter_scope(scope):
        params = nn.get_parameters()
        if not params:
            raise RuntimeError(
                "parameter scope {!r} is empty; build the network before "
                "creating its solver".format(scope))
        solver.set_parameters(params)
    return solver


# nnabla creates parameters lazily, the first time a parametric function
# (PF.affine) runs inside its scope. Build every network once on placeholder
# inputs so the scopes are populated before the solvers collect them;
# otherwise, on a fresh kernel, each solver would be handed an empty set.
_x_init = nn.Variable((mb_size, X_dim))
_z_init = nn.Variable((mb_size, Z_dim))
Q(_x_init)
P(_z_init)
D(_x_init)
C(_z_init)

Q_solver = _make_solver("Q")
P_solver = _make_solver("P")
D_solver = _make_solver("D")
C_solver = _make_solver("C")

In [None]:
def reset_grad():
    """Zero the gradients held by the Q, P, D and C solvers, in that order."""
    for solver in (Q_solver, P_solver, D_solver, C_solver):
        solver.zero_grad()

In [None]:
def sample_X(size, include_y=False):
    """Draw the next minibatch from the MNIST training split.

    Args:
        size: Number of examples requested from ``mnist.train.next_batch``.
        include_y: When true, also return the labels as integer class
            indices (argmax over axis 1 of the label array).

    Returns:
        ``X`` wrapped as an ``nn.Variable``, or the pair ``(X, y)`` of
        variables when ``include_y`` is true.
    """
    X, y = mnist.train.next_batch(size)
    X = nn.Variable.from_numpy_array(X)

    if include_y:
        # `np.int` was only an alias for the builtin `int` and no longer
        # exists in NumPy >= 1.24; the builtin gives the same dtype.
        y = np.argmax(y, axis=1).astype(int)
        y = nn.Variable.from_numpy_array(y)
        return X, y

    return X

In [None]:
def show16(samples):
    """Plot each sample as a 28x28 greyscale tile on a tight 4x4 grid.

    Args:
        samples: Iterable of arrays, each reshapeable to (28, 28). The grid
            has 16 cells, so at most 16 samples fit.
    """
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for slot, image in enumerate(samples):
        tile = fig.add_subplot(grid[slot])
        # Hide axes, ticks and labels so only the image itself is visible.
        tile.axis('off')
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        tile.imshow(image.reshape(28, 28), cmap='Greys_r')

    plt.show()

In [None]:
def L1_loss(X0, X1):
    """Sum of absolute differences between ``X0`` and ``X1`` along axis 1.

    The reduced axis is kept with size 1 (``keepdims=True``).
    """
    return F.sum(F.abs(X0 - X1), axis=1, keepdims=True)

In [None]:
# Alternating training: each iteration performs one step for the encoder Q,
# the decoder P, the input-space discriminator D and the latent-space
# discriminator C, in that order. Every step draws a fresh minibatch, rebuilds
# its loss graph, steps only its own solver, and then zeroes all gradients
# because backward() also writes gradients for the networks not being stepped.
for it in range(1000000):
    # Q (Encoder) Update
    # Weighted L1 reconstruction plus an adversarial term that is minimised
    # when C scores the encoded batch close to 1.
    X = sample_X(mb_size)
    z_sample = Q(X)
    X_recon = P(z_sample)
    C_fake = C(z_sample)

    # The 1e-8 inside each log keeps log(0) from producing -inf.
    Q_loss = F.mean(lamb * L1_loss(X_recon, X) - F.log(C_fake + 1e-8))

    Q_loss.forward()
    Q_loss.backward()
    Q_solver.update()
    reset_grad()

    # P (Decoder) Update
    # Weighted L1 reconstruction plus adversarial terms minimised when D
    # scores both reconstructions and prior samples close to 1.
    X = sample_X(mb_size)
    z_sample = Q(X)
    X_recon = P(z_sample)
    z = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
    X_sample = P(z)

    D_recon = D(X_recon)
    D_sample = D(X_sample)

    P_loss = F.mean(lamb * L1_loss(X_recon, X) - F.log(D_recon + 1e-8) - F.log(D_sample + 1e-8))

    P_loss.forward()
    P_loss.backward()
    P_solver.update()
    reset_grad()

    # D (Discriminator of X) Update
    # Targets: real data -> 1, reconstructions and prior samples -> 0.
    X = sample_X(mb_size)
    z_sample = Q(X)
    X_recon = P(z_sample)
    z = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
    X_sample = P(z)

    D_recon = D(X_recon)
    D_sample = D(X_sample)
    D_real = D(X)

    D_loss = F.mean(-F.log(D_real + 1e-8) - F.log(1 - D_recon + 1e-8) - F.log(1 - D_sample + 1e-8))

    D_loss.forward()
    D_loss.backward()
    D_solver.update()
    reset_grad()

    # C (Discriminator of z) Update
    # Targets: codes drawn from the N(0, I) prior -> 1, encoded codes -> 0.
    # This is the opposite of what Q_loss rewards (C_fake -> 1), so C and Q
    # play an adversarial game, mirroring the real/generated convention of
    # D_loss. The earlier form pushed C_fake -> 1 as well, so C never opposed Q.
    X = sample_X(mb_size)
    z_fake = Q(X)
    z_real = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))

    C_real = C(z_real)
    C_fake = C(z_fake)

    C_loss = F.mean(-F.log(C_real + 1e-8) - F.log(1 - C_fake + 1e-8))

    C_loss.forward()
    C_loss.backward()
    C_solver.update()
    reset_grad()

    # Generate and Show Samples
    if it % 1000 == 0:
        print('Step: {}, P_loss: {}, Q_loss: {}, D_loss: {}, C_loss: {}'.format(it, P_loss.d, Q_loss.d, D_loss.d, C_loss.d))

        # Calling P(...) only builds the graph; forward() must run before
        # .d holds the decoded images.
        X_gen = P(z_real)
        X_gen.forward()
        samples = X_gen.d[:16]

        show16(samples)