In [None]:
import nnabla as nn

import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.monitor import tile_images

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from itertools import chain
import scipy.ndimage.interpolation

%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Data: MNIST via the TensorFlow tutorial loader, labels one-hot encoded.
# NOTE(review): tensorflow.examples.tutorials is a TF1-era module -- confirm the pinned TF version.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# Shapes
mb_size = 64    # minibatch size used for every sample_X / noise draw
Z_dim = 100     # length of the noise vector concatenated to each generator input
X_dim, y_dim = mnist.train.images.shape[1], mnist.train.labels.shape[1]

# Optimisation
lr = 1e-4       # RMSprop learning rate shared by all solvers

n_critics = 3   # critic (D1/D2) updates per generator update
lam1 = lam2 = 100  # weights of the two L1 reconstruction terms in G_loss

In [None]:
def G1(X, z, hidden=(128, 128)):
    """Generator G1 (used as X1 -> X2): an MLP on [X, z] with sigmoid output.

    Args:
        X: batch of input images, concatenated with ``z`` before the MLP.
        z: batch of noise vectors.
        hidden: sizes of the ReLU hidden layers, applied in order.

    Returns:
        nn.Variable with ``X_dim`` features per sample, values in (0, 1).
    """
    with nn.parameter_scope("G1"):  # Parameter scope can be nested
        # nnabla concatenates along the last axis by default.
        h = F.concatenate(X, z)
        for layer_idx, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(layer_idx + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("last_layer"):
            out = F.sigmoid(PF.affine(h, X_dim))
    return out

In [None]:
def G2(X, z, hidden=(128, 128)):
    """Generator G2 (used as X2 -> X1): an MLP on [X, z] with sigmoid output.

    Args:
        X: batch of input images, concatenated with ``z`` before the MLP.
        z: batch of noise vectors.
        hidden: sizes of the ReLU hidden layers, applied in order.

    Returns:
        nn.Variable with ``X_dim`` features per sample, values in (0, 1).
    """
    with nn.parameter_scope("G2"):  # Parameter scope can be nested
        # nnabla concatenates along the last axis by default.
        h = F.concatenate(X, z)
        for layer_idx, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(layer_idx + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("last_layer"):
            out = F.sigmoid(PF.affine(h, X_dim))
    return out

In [None]:
def D1(X, hidden=(128, 128)):
    """Critic D1: an MLP mapping each sample to one unbounded score.

    The output has no activation, i.e. it is a WGAN-style critic score rather
    than a probability.

    Args:
        X: batch of images to score.
        hidden: sizes of the ReLU hidden layers, applied in order.

    Returns:
        nn.Variable of shape (batch, 1).
    """
    with nn.parameter_scope("D1"):  # Parameter scope can be nested
        h = X
        for layer_idx, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(layer_idx + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("classifier"):
            score = PF.affine(h, 1)
    return score

In [None]:
def D2(X, hidden=(128, 128)):
    """Critic D2: an MLP mapping each sample to one unbounded score.

    The output has no activation, i.e. it is a WGAN-style critic score rather
    than a probability.

    Args:
        X: batch of images to score.
        hidden: sizes of the ReLU hidden layers, applied in order.

    Returns:
        nn.Variable of shape (batch, 1).
    """
    with nn.parameter_scope("D2"):  # Parameter scope can be nested
        h = X
        for layer_idx, hsize in enumerate(hidden):
            with nn.parameter_scope("affine{}".format(layer_idx + 1)):
                h = F.relu(PF.affine(h, hsize))
        with nn.parameter_scope("classifier"):
            score = PF.affine(h, 1)
    return score

In [None]:
# PF.* layers create their parameters on first call, and nn.get_parameters()
# only returns parameters that already exist. Build every network once on
# placeholder inputs so the solvers below actually receive parameters.
_x_placeholder = nn.Variable((mb_size, X_dim))
_z_placeholder = nn.Variable((mb_size, Z_dim))
G1(_x_placeholder, _z_placeholder)
G2(_x_placeholder, _z_placeholder)
D1(_x_placeholder)
D2(_x_placeholder)


def _params_under(*scopes):
    """Return {full_name: param} for all parameters under the given top-level scopes."""
    prefixes = tuple(scope + "/" for scope in scopes)
    return {name: param for name, param in nn.get_parameters().items()
            if name.startswith(prefixes)}


# One solver for both generators (they are trained jointly on G_loss).
# Nesting parameter_scope("G1") and parameter_scope("G2") would select the
# empty scope "G1/G2", so the two scopes are merged explicitly instead.
G_solver = S.RMSprop(lr)
G_solver.set_parameters(_params_under("G1", "G2"))

D1_solver = S.RMSprop(lr)
D1_solver.set_parameters(_params_under("D1"))

D2_solver = S.RMSprop(lr)
D2_solver.set_parameters(_params_under("D2"))

In [None]:
def reset_grad():
    """Zero the gradients held by every solver: generators first, then D1, then D2."""
    for solver in (G_solver, D1_solver, D2_solver):
        solver.zero_grad()

In [None]:
def show16(samples):
    """Draw flattened 28x28 images on a tight 4x4 grid and display the figure.

    samples: sequence of at most 16 arrays, each reshapeable to (28, 28);
    image k goes into grid cell k (row-major).
    """
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, image in enumerate(samples):
        ax = plt.subplot(grid[idx])
        plt.axis('off')
        ax.set_aspect('equal')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        plt.imshow(image.reshape(28, 28), cmap='Greys_r')

    plt.show()

In [None]:
X_train = mnist.train.images
half = X_train.shape[0] // 2

# Domain 1: the first half of the training images, unchanged.
X_train1 = X_train[:half]
# Domain 2: the *other* half, each digit rotated by 90 degrees. Using disjoint
# halves keeps the two domains unpaired.
X_train2 = X_train[half:].reshape(-1, 28, 28)
# scipy.ndimage.rotate is the public API; the scipy.ndimage.interpolation
# namespace is deprecated.
X_train2 = scipy.ndimage.rotate(X_train2, 90, axes=(1, 2))
X_train2 = X_train2.reshape(-1, 28 * 28)

# Rotating a square image keeps its size, so both domains share the same layout.
assert X_train2.shape == (X_train.shape[0] - half, X_train1.shape[1])

# Cleanup
del X_train

In [None]:
def sample_X(X, size):
    """Return ``size`` consecutive rows of X, from a random start, as an nn.Variable.

    Every start offset in [0, len(X) - size] is equally likely.

    Raises:
        ValueError: if X has fewer than ``size`` rows.
    """
    n_rows = X.shape[0]
    if size > n_rows:
        raise ValueError(
            "cannot sample {} rows from an array with {} rows".format(size, n_rows))
    # randint's upper bound is exclusive; +1 makes the final window reachable
    # (and allows size == n_rows).
    start_idx = np.random.randint(0, n_rows - size + 1)
    return nn.Variable.from_numpy_array(X[start_idx:start_idx + size])

In [None]:
# Training budget and logging cadence.
n_iters = 1000000
log_every = 1000

# WGAN critics must stay (approximately) Lipschitz. As in the original WGAN /
# DualGAN recipe, this is enforced by clipping critic weights after each update.
clip_value = 0.01


def clip_weights(scope, limit=clip_value):
    """Clamp every parameter under top-level ``scope`` to [-limit, limit]."""
    with nn.parameter_scope(scope):
        for param in nn.get_parameters().values():
            param.d = np.clip(param.d, -limit, limit)


for it in range(n_iters):
    for _ in range(n_critics):
        # ---- Critic (D1, D2) update ----
        z1 = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
        z2 = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
        X1 = sample_X(X_train1, mb_size)
        X2 = sample_X(X_train2, mb_size)

        # D1 scores domain 2: real X2 vs. G1's translation of X1.
        X2_sample = G1(X1, z1)  # G1: X1 -> X2
        # Scalar WGAN critic loss: mean over the whole batch.
        D1_loss = -(F.mean(D1(X2)) - F.mean(D1(X2_sample)))

        reset_grad()  # nnabla accumulates gradients; clear them before backward
        D1_loss.forward()
        D1_loss.backward()
        D1_solver.update()
        clip_weights("D1")

        # D2 scores domain 1: real X1 vs. G2's translation of X2.
        X1_sample = G2(X2, z2)  # G2: X2 -> X1
        D2_loss = -(F.mean(D2(X1)) - F.mean(D2(X1_sample)))

        reset_grad()
        D2_loss.forward()
        D2_loss.backward()
        D2_solver.update()
        clip_weights("D2")

    # ---- Generator (G1, G2) update ----
    z1 = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
    z2 = nn.Variable.from_numpy_array(np.random.randn(mb_size, Z_dim))
    X1 = sample_X(X_train1, mb_size)
    X2 = sample_X(X_train2, mb_size)

    X1_sample = G2(X2, z2)  # X2 -> X1
    X2_sample = G1(X1, z1)  # X1 -> X2

    # Round trips back to the source domain, for the reconstruction terms.
    X1_recon = G2(X2_sample, z2)  # X1 -> X2 -> X1
    X2_recon = G1(X1_sample, z1)  # X2 -> X1 -> X2

    # Adversarial terms use *this* step's samples. Each is scored by the
    # critic that was trained on that sample's target domain.
    G_adv = -F.mean(D1(X2_sample)) - F.mean(D2(X1_sample))
    reg1 = lam1 * F.mean(F.sum(F.abs(X1_recon - X1), axis=1))
    reg2 = lam2 * F.mean(F.sum(F.abs(X2_recon - X2), axis=1))
    G_loss = G_adv + reg1 + reg2

    # Zero all grads first: this backward also reaches D1/D2, and the critic
    # steps left gradients on the generator parameters.
    reset_grad()
    G_loss.forward()
    G_loss.backward()
    G_solver.update()

    # Generate and Show Samples
    if it % log_every == 0:
        print('Step: {}, D_loss: {}, G_loss: {}'.format(
            it, D1_loss.d.item() + D2_loss.d.item(), G_loss.d.item()))

        real1 = X1.d[:4]
        real2 = X2.d[:4]
        samples1 = X1_sample.d[:4]
        samples2 = X2_sample.d[:4]
        # Rows: real X2, G2(X2), real X1, G1(X1).
        samples = np.vstack([real2, samples1, real1, samples2])

        show16(samples)