In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# --- Hyperparameters --------------------------------------------------------
mb_size = 64   # minibatch size requested from the MNIST iterator
Z_dim = 100    # length of the noise vector fed to the generator
h_dim = 128    # hidden width; defined here, though G and D use their own defaults
lr = 1e-1      # learning rate handed to both Adam solvers
               # NOTE(review): this looks large for Adam -- confirm it is intended

# --- Data --------------------------------------------------------------------
# Labels are loaded one-hot so they can double as the conditioning input.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
X_dim = mnist.train.images.shape[1]   # width of one flattened training image
y_dim = mnist.train.labels.shape[1]   # width of one one-hot label row

In [None]:
def G(z, c, hidden=(128,)):
    """Build the conditional generator graph.

    The noise ``z`` and the conditioning input ``c`` are concatenated, passed
    through one ReLU-activated affine layer per entry of ``hidden``, and
    finished with a sigmoid-activated affine layer of ``X_dim`` units.

    Args:
        z: noise variable.
        c: conditioning variable, joined to ``z`` with ``F.concatenate``.
        hidden: iterable of hidden-layer widths. Layer ``i`` (1-based) keeps
            its parameters under the nested scope ``G/affine<i>``.

    Returns:
        The output variable of the final sigmoid layer.
    """
    with nn.parameter_scope("G"):  # Parameter scope can be nested
        h = F.concatenate(z, c)
        for layer_idx, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_idx)):
                h = F.relu(PF.affine(h, width))
        with nn.parameter_scope("sigmoid_layer"):
            X = F.sigmoid(PF.affine(h, X_dim))
    return X

In [None]:
def D(X, hidden=(128,)):
    """Build the discriminator graph with a real/fake head and a class head.

    ``X`` is passed through one ReLU-activated affine layer per entry of
    ``hidden``; both heads branch off the last hidden activation.

    Args:
        X: input variable (a real batch or a generator output).
        hidden: iterable of hidden-layer widths. Layer ``i`` (1-based) keeps
            its parameters under the nested scope ``D/affine<i>``.

    Returns:
        A tuple ``(y_GAN, y_AUX)``:
        ``y_GAN`` is the sigmoid output of a single-unit affine layer and
        ``y_AUX`` is the softmax output of an affine layer with ``y_dim``
        units. Note that ``y_AUX`` holds probabilities, not logits.
    """
    with nn.parameter_scope("D"):
        h = X
        for layer_idx, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_idx)):
                h = F.relu(PF.affine(h, width))

        with nn.parameter_scope("D_GAN"):
            y_GAN = F.sigmoid(PF.affine(h, 1))

        with nn.parameter_scope("D_AUX"):
            y_AUX = F.softmax(PF.affine(h, y_dim))

    return y_GAN, y_AUX

In [None]:
# Parametric functions (PF.affine) create their parameters lazily, the first
# time a graph that uses them is built.  Build both networks once on
# placeholder variables so the "G" and "D" scopes are populated before the
# solvers read them; otherwise, on a fresh kernel, each solver would be handed
# an empty parameter dict and `update()` would never change any weight.
_z_init = nn.Variable((mb_size, Z_dim))
_c_init = nn.Variable((mb_size, y_dim))
D(G(_z_init, _c_init))

# One Adam solver per network, each restricted to that network's scope.
G_solver = S.Adam(lr)
with nn.parameter_scope("G"):
    G_solver.set_parameters(nn.get_parameters())

D_solver = S.Adam(lr)
with nn.parameter_scope("D"):
    D_solver.set_parameters(nn.get_parameters())

In [None]:
def reset_grad():
    """Zero the gradients tracked by the generator and discriminator solvers."""
    for solver in (G_solver, D_solver):
        solver.zero_grad()

In [None]:
def show16(samples):
    """Display samples as greyscale 28x28 tiles on a tight 4x4 grid.

    Args:
        samples: iterable of arrays, each reshapeable to ``(28, 28)``. The
            grid has 16 cells, so at most 16 samples can be placed.
    """
    plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, image in enumerate(samples):
        tile = plt.subplot(grid[cell])
        # Strip every axis decoration so only the image itself is visible.
        tile.axis('off')
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        tile.imshow(image.reshape(28, 28), cmap='Greys_r')

    plt.show()

In [None]:
def _class_log_likelihood(probs, onehot):
    """Mean log-probability assigned to the labelled class, as a (1, 1) variable.

    ``probs`` is expected to hold class probabilities (D's softmax head) and
    ``onehot`` the matching one-hot labels.  This is the negative categorical
    cross-entropy, i.e. a quantity to maximize.
    """
    per_sample = F.sum(onehot * F.log(probs + 1e-8), axis=1, keepdims=True)
    return F.mean(per_sample, keepdims=True)


for it in range(1000000):
    # Sample data: a minibatch of images and their one-hot labels.
    X_batch, y_batch = mnist.train.next_batch(mb_size)
    X = nn.Variable.from_numpy_array(X_batch)
    # The one-hot labels serve both as generator condition and as class target.
    c = nn.Variable.from_numpy_array(y_batch.astype('float32'))
    z = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))

    # ---- Discriminator update ----
    G_sample = G(z, c)
    D_real, C_real = D(X)
    D_fake, C_fake = D(G_sample)

    D_loss = F.mean(F.log(D_real + 1e-8) + F.log(1 - D_fake + 1e-8), keepdims=True)
    # D's auxiliary head already applies softmax, so score it with a
    # categorical log-likelihood instead of feeding the probabilities to a
    # sigmoid cross-entropy as if they were logits.
    C_loss = _class_log_likelihood(C_real, c) + _class_log_likelihood(C_fake, c)

    # Maximize
    DC_loss = -(D_loss + C_loss)

    # Clear gradients before backward so nothing left in the parameter
    # gradient buffers leaks into this step.
    reset_grad()
    DC_loss.forward()
    DC_loss.backward()
    D_solver.update()

    # ---- Generator update ----
    G_sample = G(z, c)
    D_fake, C_fake = D(G_sample)
    # C_real does not depend on G's parameters; it only shifts the loss value.
    _, C_real = D(X)

    G_loss = F.mean(F.log(D_fake + 1e-8), keepdims=True)
    C_loss = _class_log_likelihood(C_real, c) + _class_log_likelihood(C_fake, c)

    # Maximize
    GC_loss = -(G_loss + C_loss)

    reset_grad()
    GC_loss.forward()
    GC_loss.backward()
    G_solver.update()

    # ---- Generate and show samples ----
    if it % 1000 == 0:
        idx = np.random.randint(0, y_dim)
        print('Step: {}, D_loss: {}, G_loss: {}, Idx: {}'.format(it, -D_loss.d[0], -G_loss.d[0], idx))

        # Condition all 16 samples on the same randomly chosen class.
        c_onehot = np.zeros([16, y_dim], dtype='float32')
        c_onehot[:, idx] = 1
        c_sample = nn.Variable.from_numpy_array(c_onehot)
        z_sample = nn.Variable.from_numpy_array(np.random.randn(16, Z_dim))

        generated = G(z_sample, c_sample)
        # The graph must be executed before `.d` holds the generator output.
        generated.forward()

        show16(generated.d)