In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import random

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# --- Hyperparameters ---------------------------------------------------------
mb_size = 64   # minibatch size used for both the real and the noise batch
Z_dim = 100    # width of the noise input handed to the generators
lr = 1e-4      # learning rate passed to every Adam solver
K = 100        # the discriminators are swapped between pairs every K iterations

# --- Data --------------------------------------------------------------------
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
X_dim = mnist.train.images.shape[1]   # size of the second axis of the image array
y_dim = mnist.train.labels.shape[1]   # size of the second axis of the one-hot labels

In [None]:
def G1_(z, hidden=(128, 128)):
    """Build generator 1: a stack of ReLU affine layers plus a sigmoid output.

    Every parameter is created under the "G1" parameter scope.

    Args:
        z: Input variable fed into the first affine layer.
        hidden: Sizes of the hidden affine layers, applied in order.

    Returns:
        The sigmoid-activated affine output with ``X_dim`` units.
    """
    with nn.parameter_scope("G1"):  # Parameter scope can be nested
        h = z
        for layer_no, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, width))
        # NOTE: "sigomid_layer" is a misspelling of "sigmoid_layer". It is kept
        # unchanged because the scope name is part of the parameter names, so
        # renaming it would make previously saved parameters unloadable.
        with nn.parameter_scope("sigomid_layer"):
            X = F.sigmoid(PF.affine(h, X_dim))
    return X

In [None]:
def G2_(z, hidden=(128, 128)):
    """Build generator 2: a stack of ReLU affine layers plus a sigmoid output.

    Every parameter is created under the "G2" parameter scope, so this network
    does not share weights with anything built under a different scope.

    Args:
        z: Input variable fed into the first affine layer.
        hidden: Sizes of the hidden affine layers, applied in order.

    Returns:
        The sigmoid-activated affine output with ``X_dim`` units.
    """
    with nn.parameter_scope("G2"):  # Parameter scope can be nested
        h = z
        for layer_no, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, width))
        with nn.parameter_scope("sigmoid_layer"):
            X = F.sigmoid(PF.affine(h, X_dim))
    return X

In [None]:
def D1_(X, hidden=(128, 128)):
    """Build discriminator 1: ReLU affine layers plus a 1-unit affine head.

    Every parameter is created under the "D1" parameter scope.

    Args:
        X: Input variable fed into the first affine layer.
        hidden: Sizes of the hidden affine layers, applied in order.

    Returns:
        The output of the final 1-unit affine layer. No activation is applied
        here, so the value is unbounded and may be negative; it is not a
        probability until a sigmoid is applied to it.
    """
    with nn.parameter_scope("D1"):  # Parameter scope can be nested
        h = X
        for layer_no, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, width))
        with nn.parameter_scope("classifier"):
            y = PF.affine(h, 1)
    return y

In [None]:
def D2_(X, hidden=(128, 128)):
    """Build discriminator 2: ReLU affine layers plus a 1-unit affine head.

    Every parameter is created under the "D2" parameter scope.

    Args:
        X: Input variable fed into the first affine layer.
        hidden: Sizes of the hidden affine layers, applied in order.

    Returns:
        The output of the final 1-unit affine layer. No activation is applied
        here, so the value is unbounded and may be negative; it is not a
        probability until a sigmoid is applied to it.
    """
    with nn.parameter_scope("D2"):  # Parameter scope can be nested
        h = X
        for layer_no, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                h = F.relu(PF.affine(h, width))
        with nn.parameter_scope("classifier"):
            y = PF.affine(h, 1)
    return y

In [None]:
# nnabla creates parameters lazily: they only come into existence the first
# time a parametric function (PF.affine) runs inside a scope. Build every
# network once on a placeholder input so the parameters exist before they are
# registered below. Without this, get_parameters() returns an empty dict and
# every later solver.update() / zero_grad() silently does nothing.
for _build_net, _in_dim in ((G1_, Z_dim), (G2_, Z_dim), (D1_, X_dim), (D2_, X_dim)):
    _build_net(nn.Variable((mb_size, _in_dim)))

# One Adam solver per network, each owning only the parameters of its own scope.
G1_solver = S.Adam(lr)
with nn.parameter_scope("G1"):
    G1_solver.set_parameters(nn.get_parameters())

G2_solver = S.Adam(lr)
with nn.parameter_scope("G2"):
    G2_solver.set_parameters(nn.get_parameters())

D1_solver = S.Adam(lr)
with nn.parameter_scope("D1"):
    D1_solver.set_parameters(nn.get_parameters())

D2_solver = S.Adam(lr)
with nn.parameter_scope("D2"):
    D2_solver.set_parameters(nn.get_parameters())

In [None]:
def reset_grad():
    """Zero the gradients held by all four solvers (G1, G2, D1, D2)."""
    for solver in (G1_solver, G2_solver, D1_solver, D2_solver):
        solver.zero_grad()

In [None]:
# Bundle each network-building function with the solver that owns its
# parameters, so a (discriminator, generator) pair can be passed around and
# re-paired as plain dicts.
D1 = dict(model=D1_, solver=D1_solver)
G1 = dict(model=G1_, solver=G1_solver)
D2 = dict(model=D2_, solver=D2_solver)
G2 = dict(model=G2_, solver=G2_solver)

In [None]:
# Initial (discriminator, generator) pairing; the training loop rebinds this
# list when it swaps the discriminators.
GAN_pairs = [
    (D1, G1),
    (D2, G2),
]

In [None]:
def show16(samples):
    """Display up to 16 samples as 28x28 greyscale images in a 4x4 grid.

    Args:
        samples: Sequence of arrays, each reshapeable to (28, 28). Only the
            first 16 entries are drawn; any extra entries are ignored.
    """
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    # The grid has exactly 16 cells, so indexing it past the 16th entry raises
    # IndexError; cap the number of samples instead.
    for i, sample in enumerate(samples[:16]):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    plt.show()
    # This is called repeatedly from the training loop; release the figure so
    # open figures do not pile up in memory.
    plt.close(fig)

In [None]:
for it in range(1000000):
    # One noise batch and one real batch per iteration, shared by both pairs.
    z = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = nn.Variable.from_numpy_array(X)

    for D, G in GAN_pairs:
        # Discriminator update
        G_sample = G['model'](z)
        # The discriminator networks end in a plain affine layer, so their
        # output can be zero or negative. Squash it into (0, 1) before taking
        # the log, otherwise the loss becomes NaN / -inf.
        D_real = F.sigmoid(D['model'](X))
        D_fake = F.sigmoid(D['model'](G_sample))

        D_loss = - F.mean(F.log(D_real + 1e-8) + F.log(1 - D_fake + 1e-8))

        # backward() accumulates into the parameter gradients, so clear them
        # before every backward pass (this also drops the gradients the other
        # network received during the previous step).
        reset_grad()
        D_loss.forward()
        D_loss.backward()
        D['solver'].update()

        # Generator update
        G_sample = G['model'](z)
        D_fake = F.sigmoid(D['model'](G_sample))
        # Same epsilon as in D_loss, guarding against log(0).
        G_loss = - F.mean(F.log(D_fake + 1e-8))

        reset_grad()
        G_loss.forward()
        G_loss.backward()
        G['solver'].update()

    if it != 0 and it % K == 0:
        # Swap (D, G) pairs
        new_D1, new_D2 = GAN_pairs[1][0], GAN_pairs[0][0]
        GAN_pairs = [(new_D1, G1), (new_D2, G2)]

    # Generate and Show Samples
    if it % 1000 == 0:
        # The losses shown are those of the last pair processed above.
        print('Step: {}, D_loss: {}, G_loss: {}'.format(it, D_loss.d, G_loss.d))

        # Pick G randomly
        G_rand = random.choice([G1_, G2_])
        sample_var = G_rand(z)
        # Building the graph does not compute anything: .d only holds real
        # values once forward() has run (unless auto-forward mode is enabled).
        sample_var.forward()
        samples = sample_var.d[:16]
        show16(samples)