In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import scipy.ndimage.interpolation

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# --- Training hyper-parameters -------------------------------------------
mb_size = 64   # rows drawn per mini-batch by the training loop
Z_dim = 100    # not referenced by the cells visible in this notebook
lr = 1e-4      # step size handed to the Adam solvers

# --- Dataset -------------------------------------------------------------
# MNIST via the TensorFlow tutorial helper, read from (or downloaded to)
# the given directory, with one-hot encoded labels.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# Per-sample widths taken from the loaded arrays.
X_dim = mnist.train.images.shape[1]   # also used as the generators' output width
y_dim = mnist.train.labels.shape[1]   # not referenced by the visible cells

In [None]:
def G_AB(X, hidden=[128, 128]):
    """Generator translating domain A into domain B.

    Stacks one ReLU-activated affine layer per entry of ``hidden`` and ends
    with a sigmoid-activated affine layer of width ``X_dim``. All parameters
    are registered under the ``G_AB`` parameter scope.

    Args:
        X: Input variable fed to the first affine layer.
        hidden: Output sizes of the hidden affine layers, in order.

    Returns:
        The sigmoid output of the last layer.
    """
    with nn.parameter_scope("G_AB"):
        out = X
        # Layer scopes are numbered from 1: affine1, affine2, ...
        for layer_no, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                out = F.relu(PF.affine(out, width))
        with nn.parameter_scope("last_layer"):
            return F.sigmoid(PF.affine(out, X_dim))

In [None]:
def G_BA(X, hidden=[128, 128]):
    """Generator translating domain B into domain A.

    Stacks one ReLU-activated affine layer per entry of ``hidden`` and ends
    with a sigmoid-activated affine layer of width ``X_dim``. All parameters
    are registered under the ``G_BA`` parameter scope.

    Args:
        X: Input variable fed to the first affine layer.
        hidden: Output sizes of the hidden affine layers, in order.

    Returns:
        The sigmoid output of the last layer.
    """
    with nn.parameter_scope("G_BA"):
        out = X
        # Layer scopes are numbered from 1: affine1, affine2, ...
        for layer_no, width in enumerate(hidden, start=1):
            with nn.parameter_scope("affine{}".format(layer_no)):
                out = F.relu(PF.affine(out, width))
        with nn.parameter_scope("last_layer"):
            return F.sigmoid(PF.affine(out, X_dim))

In [None]:
def D_A(X, hidden=[128, 128]):
    """Discriminator for domain A.

    Stacks one ReLU-activated affine layer per entry of ``hidden`` followed
    by a single-unit affine "classifier" layer squashed by a sigmoid. All
    parameters are registered under the ``D_A`` parameter scope.

    Args:
        X: Input variable fed to the first affine layer.
        hidden: Output sizes of the hidden affine layers, in order.

    Returns:
        A variable with one value in (0, 1) per input row.
    """
    with nn.parameter_scope("D_A"):  # Parameter scope can be nested
        h = X
        for hid, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(hid + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("classifier"):
            # The training losses take log(D(x)) and log(1 - D(x)); a raw
            # affine output is unbounded and would make those logs NaN, so
            # squash it into (0, 1).
            y = F.sigmoid(PF.affine(h, 1))
    return y

In [None]:
def D_B(X, hidden=[128, 128]):
    """Discriminator for domain B.

    Stacks one ReLU-activated affine layer per entry of ``hidden`` followed
    by a single-unit affine "classifier" layer squashed by a sigmoid. All
    parameters are registered under the ``D_B`` parameter scope.

    Args:
        X: Input variable fed to the first affine layer.
        hidden: Output sizes of the hidden affine layers, in order.

    Returns:
        A variable with one value in (0, 1) per input row.
    """
    with nn.parameter_scope("D_B"):  # Parameter scope can be nested
        h = X
        for hid, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(hid + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("classifier"):
            # The training losses take log(D(x)) and log(1 - D(x)); a raw
            # affine output is unbounded and would make those logs NaN, so
            # squash it into (0, 1).
            y = F.sigmoid(PF.affine(h, 1))
    return y

In [None]:
# Parametric functions register their weights the first time a network is
# built, so trace every network once on a placeholder input; otherwise the
# parameter registry is still empty when the solvers are configured.
_placeholder = nn.Variable((mb_size, X_dim))
for _build_net in (G_AB, G_BA, D_A, D_B):
    _build_net(_placeholder)


def _params_under(*scopes):
    """Return the registered parameters whose name starts with one of ``scopes``.

    Parameters are looked up by their full registry name (e.g.
    ``"G_AB/affine1/affine/W"``), so same-named layers in different networks
    do not collide.
    """
    prefixes = tuple("{}/".format(scope) for scope in scopes)
    return {
        name: param
        for name, param in nn.get_parameters().items()
        if name.startswith(prefixes)
    }


# One Adam solver drives both generators ...
G_solver = S.Adam(lr)
G_solver.set_parameters(_params_under("G_AB", "G_BA"))

# ... and one drives both discriminators (their scopes are "D_A" and "D_B").
D_solver = S.Adam(lr)
D_solver.set_parameters(_params_under("D_A", "D_B"))

In [None]:
def reset_grad():
    """Call ``zero_grad()`` on the generator solver, then the discriminator solver."""
    for solver in (G_solver, D_solver):
        solver.zero_grad()

In [None]:
# Gather training data: domain1 <- real MNIST img, domain2 <- rotated MNIST img
X_train = mnist.train.images
half = X_train.shape[0] // 2
# Domain 1: first half of the training images, unchanged
X_train1 = X_train[:half]
# Domain 2: remaining images, viewed as 28x28 and rotated by 90 degrees in
# the image plane (axes 1 and 2), then flattened back to one row per image.
X_train2 = X_train[half:].reshape(-1, 28, 28)
# `scipy.ndimage.rotate` is the public entry point; the
# `scipy.ndimage.interpolation` namespace is deprecated.
X_train2 = scipy.ndimage.rotate(X_train2, 90, axes=(1, 2))
X_train2 = X_train2.reshape(-1, 28*28)
# Cleanup: drop the extra name for the full array now that it has been split
del X_train

In [None]:
def sample_X(X, size):
    """Draw a random contiguous mini-batch of ``size`` rows from ``X``.

    Args:
        X: Array indexed by sample along its first axis.
        size: Number of consecutive rows to return.

    Returns:
        An ``nn.Variable`` wrapping ``X[start:start + size]`` for a uniformly
        chosen ``start``.

    Raises:
        ValueError: If ``size`` exceeds the number of rows in ``X``.
    """
    n_rows = X.shape[0]
    if size > n_rows:
        raise ValueError(
            "cannot sample {} rows from an array with {} rows".format(size, n_rows))
    # randint's upper bound is exclusive: the +1 makes the last window
    # reachable and allows size == n_rows.
    start_idx = np.random.randint(0, n_rows - size + 1)
    return nn.Variable.from_numpy_array(X[start_idx:start_idx + size])

In [None]:
def show16(samples):
    """Display ``samples`` as tiles of a 4x4 grid of 28x28 grayscale images.

    Args:
        samples: Iterable of arrays, each reshapeable to (28, 28). Tiles are
            filled in iteration order; the grid has 16 cells.
    """
    figure = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, image in enumerate(samples):
        tile = figure.add_subplot(grid[cell])
        # Bare image tiles: no frame, no tick labels, square pixels.
        tile.axis('off')
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        tile.imshow(image.reshape(28, 28), cmap='Greys_r')

    plt.show()

In [None]:
# Adversarial training of the two generators and two discriminators.
# Each iteration performs one discriminator step followed by one generator
# step (adversarial + reconstruction losses in both directions); every 1000
# iterations the losses are printed and a 4x4 grid of samples is shown.
eps = 1e-8  # added inside every F.log(...) to keep it away from log(0)

for it in range(1000000):
    # Sample data from both domains
    X_A = sample_X(X_train1, mb_size)
    X_B = sample_X(X_train2, mb_size)

    # ---------------- Discriminator update ----------------
    # NOTE: the log terms below require discriminator outputs in (0, 1).

    # Discriminator A
    X_BA = G_BA(X_B)
    D_A_real = D_A(X_A)
    D_A_fake = D_A(X_BA)

    L_D_A = - F.mean(F.log(D_A_real + eps) + F.log(1 - D_A_fake + eps))

    # Discriminator B
    X_AB = G_AB(X_A)
    D_B_real = D_B(X_B)
    D_B_fake = D_B(X_AB)

    L_D_B = - F.mean(F.log(D_B_real + eps) + F.log(1 - D_B_fake + eps))

    # Total discriminator loss
    D_loss = L_D_A + L_D_B

    # Clear gradients before backward so every step (including the very
    # first one) starts from zero.
    reset_grad()
    D_loss.forward()
    D_loss.backward()
    D_solver.update()

    # ---------------- Generator update ----------------

    # Generator AB: fool D_B and reconstruct A through A -> B -> A
    X_AB = G_AB(X_A)
    D_B_fake = D_B(X_AB)
    X_ABA = G_BA(X_AB)

    L_adv_B = - F.mean(F.log(D_B_fake + eps))
    L_recon_A = F.mean(F.sum((X_A - X_ABA)**2, axis=1))
    L_G_AB = L_adv_B + L_recon_A

    # Generator BA: fool D_A and reconstruct B through B -> A -> B
    X_BA = G_BA(X_B)
    D_A_fake = D_A(X_BA)
    X_BAB = G_AB(X_BA)

    # Same epsilon guard as L_adv_B
    L_adv_A = - F.mean(F.log(D_A_fake + eps))
    L_recon_B = F.mean(F.sum((X_B - X_BAB)**2, axis=1))
    L_G_BA = L_adv_A + L_recon_B

    # Total generator loss
    G_loss = L_G_AB + L_G_BA

    # The discriminator backward pass also wrote gradients into the
    # generators (and vice versa), so clear them again before this step.
    reset_grad()
    G_loss.forward()
    G_loss.backward()
    G_solver.update()

    # Generate and Show Samples
    if it % 1000 == 0:
        print('Step: {}, D_loss: {}, G_loss: {}'.format(it, D_loss.d, G_loss.d))

        input_A = sample_X(X_train1, size=4)
        input_B = sample_X(X_train2, size=4)

        samples_A = G_BA(input_B)
        samples_B = G_AB(input_A)
        # Building the graph does not compute it: run forward() so that .d
        # holds the translated images.
        samples_A.forward()
        samples_B.forward()

        # The resulting image sample would be in 4 rows:
        # row 1: real data from domain A, row 2 is its domain B translation
        # row 3: real data from domain B, row 4 is its domain A translation
        samples = np.vstack([input_A.d, samples_B.d, input_B.d, samples_A.d])
        show16(samples)