In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Reproducibility: minibatch shuffling and noise sampling use NumPy's global RNG.
SEED = 313
np.random.seed(SEED)

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# Data / model dimensions
mb_size = 64                          # minibatch size
Z_dim = 128                           # latent noise dimension
X_dim = mnist.train.images.shape[1]   # features per (flattened) image
y_dim = mnist.train.labels.shape[1]   # one-hot label width; not used by the model in this notebook

# Optimisation
lr = 1e-3          # Adamax learning rate (previously assigned twice)
d_step = 3         # NOTE(review): not referenced by any cell in this notebook
m = 5              # placeholder margin; recomputed from real-data energy before training
n_iter = 1000      # minibatches per epoch
n_epoch = 1000     # number of epochs
N = n_iter * mb_size  # N data per epoch

In [None]:
def G(z, hidden=(128, 128)):
    """Generator network: map latent noise ``z`` to an image-sized output.

    Args:
        z: nnabla Variable holding the noise batch (presumably shape
            (batch, Z_dim) -- TODO confirm for other callers).
        hidden: widths of the ReLU hidden layers, applied in order.

    Returns:
        Variable with ``X_dim`` features per sample, passed through a sigmoid
        so values lie in (0, 1).

    Parameters are looked up (or created) under the "G/affine<k>" and
    "G/last_layer" parameter scopes, so repeated calls share weights.
    """
    with nn.parameter_scope("G"):  # Parameter scope can be nested
        h = z
        # Scope names start at 1 ("affine1", "affine2", ...) to match saved weights.
        for layer_idx, hsize in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_idx)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("last_layer"):
            X = F.sigmoid(PF.affine(h, X_dim))
    return X

In [None]:
def D(X, hidden=(128,)):
    """Energy-based discriminator: an autoencoder scored by reconstruction error.

    Args:
        X: nnabla Variable with ``X_dim`` features per sample (presumably
            shape (batch, X_dim) -- TODO confirm for other callers).
        hidden: widths of the ReLU encoder layers, applied in order.

    Returns:
        Variable of per-sample energies: the squared reconstruction error
        summed over axis 1. Lower energy means "looks more like real data".

    Parameters are looked up (or created) under the "D/affine<k>" and
    "D/reconstruction" parameter scopes, so repeated calls share weights.
    """
    with nn.parameter_scope("D"):  # Parameter scope can be nested
        h = X
        # Scope names start at 1 ("affine1", ...) to match saved weights.
        for layer_idx, hsize in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_idx)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("reconstruction"):
            X_recon = PF.affine(h, X_dim)
    return F.sum((X - X_recon) ** 2, axis=1)

In [None]:
# nnabla creates parameters lazily: each PF.affine registers its weights the
# first time it runs. Build both networks once on a placeholder input so the
# "G" and "D" scopes are populated *before* the solvers collect parameters;
# otherwise nn.get_parameters() returns {} and the solvers update nothing.
_ = D(G(nn.Variable((mb_size, Z_dim))))

G_solver = S.Adamax(lr)
with nn.parameter_scope("G"):
    G_params = nn.get_parameters()
assert G_params, 'no parameters found under scope "G"'
G_solver.set_parameters(G_params)

D_solver = S.Adamax(lr)
with nn.parameter_scope("D"):
    D_params = nn.get_parameters()
assert D_params, 'no parameters found under scope "D"'
D_solver.set_parameters(D_params)

In [None]:
def reset_grad():
    """Clear the stored gradients of every generator and discriminator parameter."""
    for solver in (G_solver, D_solver):
        solver.zero_grad()

In [None]:
def show16(samples):
    """Display flattened 28x28 images on a tight, axis-free 4x4 grid.

    ``samples`` must contain at most 16 images (the grid has 16 cells).
    """
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, img in enumerate(samples):
        ax = fig.add_subplot(grid[idx])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(img.reshape(28, 28), cmap='Greys_r')

    plt.show()

In [None]:
# Pretrain discriminator on real data only: minimize the energy
# (reconstruction error) it assigns to real samples.
# The graph is built once; each step only refills the input buffer.
X = nn.Variable((mb_size, X_dim))
loss = F.mean(D(X))  # Minimize real samples energy

for it in range(2 * n_iter):
    batch, _ = mnist.train.next_batch(mb_size)
    X.d = batch

    D_solver.zero_grad()  # backward() accumulates into parameter grads
    loss.forward()
    loss.backward()
    D_solver.update()

    if it % 1000 == 0:
        print('Step: {}, Pretrained D loss: {:.4}'.format(it, float(loss.d)))

reset_grad()  # leave clean gradients for the following cells

In [None]:
# Initial margin, expected energy of real data.
# nnabla graphs are lazy: .d is only meaningful after forward(); reading it
# without forward() returns uninitialised memory.
real_energy = F.mean(D(nn.Variable.from_numpy_array(mnist.train.images)))
real_energy.forward(clear_buffer=True)  # free intermediates of this large batch
m = float(real_energy.d)

# Fake-energy sum of the previous epoch; +inf means the margin is not
# updated at the end of the first epoch.
s_z_before = np.inf

In [None]:
# Adversarial training with margin adaptation.
# Graphs are built once and their input buffers refilled each step, instead
# of rebuilding them 2 * n_iter times per epoch.
X = nn.Variable((mb_size, X_dim))      # real minibatch
z = nn.Variable((mb_size, Z_dim))      # noise for the discriminator step
z_G = nn.Variable((mb_size, Z_dim))    # fresh noise for the generator step

D_real = D(X)                          # per-sample energy of real data
D_fake_for_D = D(G(z))                 # energy of fakes seen by the D step
G_sample = G(z_G)                      # fakes for the generator step
D_fake = D(G_sample)
G_loss = F.mean(D_fake)                # G lowers the energy of its samples

for epoch in range(n_epoch):
    # The margin is fixed within an epoch, so only this term is rebuilt here.
    # float(m) keeps the left operand a plain Python scalar so the subtraction
    # is handled by nnabla. NOTE(review): a NumPy value on the left of `-` may
    # not dispatch to Variable.__rsub__.
    D_loss = F.mean(D_real) + F.relu(float(m) - F.mean(D_fake_for_D))

    s_x, s_z = 0.0, 0.0  # summed energies of real / fake samples this epoch
    for it in range(n_iter):
        # Discriminator update
        batch, _ = mnist.train.next_batch(mb_size)
        X.d = batch
        z.d = np.random.randn(mb_size, Z_dim)

        reset_grad()  # backward() accumulates into parameter grads
        D_loss.forward()
        D_loss.backward()
        D_solver.update()

        # Update real samples statistics (energies computed before this update)
        s_x += float(np.sum(D_real.d))

        # Generator update, driven by its own fresh noise
        z_G.d = np.random.randn(mb_size, Z_dim)

        reset_grad()
        G_loss.forward()
        G_loss.backward()
        G_solver.update()

        # Update fake samples statistics
        s_z += float(np.sum(D_fake.d))

    # Update margin
    Ex = s_x / N
    Ez = s_z / N
    if Ex < m and s_x < s_z and s_z_before < s_z:
        m = Ex
    s_z_before = s_z

    # Convergence measure
    L = Ex + np.abs(Ex - Ez)
    print('Step: {}, m = {}, L = {}'.format(epoch, m, L))

    # Generate and show samples from the latest generator noise, re-run with
    # the updated weights (reading .d without forward() would be stale).
    G_sample.forward()
    samples = G_sample.d[:16]
    show16(samples)