In [50]:
from keras.layers import Activation, Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.layers import concatenate
from keras.optimizers import RMSprop
from keras.models import Model
from keras.datasets import mnist
from keras.utils import to_categorical

import numpy as np
import math
import matplotlib.pyplot as plt
import os

In [51]:
def build_generator(latent_size, image_size):
    """Build the fully connected generator network.

    The latent vector passes through three Dense -> BatchNormalization ->
    ReLU stages (256, 512 and 1024 units), then a sigmoid-activated Dense
    layer with ``image_size * image_size`` units whose output is reshaped
    to ``(image_size, image_size, 1)``.

    Args:
        latent_size: Length of the input noise vector.
        image_size: Side length of the square, single-channel output image.

    Returns:
        An uncompiled Keras ``Model`` named ``'generator'``.
    """
    hidden_units = (256, 512, 1024)
    pixel_count = image_size * image_size

    z = Input(shape=(latent_size,), name='z_input')

    # Assemble the whole stack first, then thread the tensor through it.
    stack = []
    for units in hidden_units:
        stack.extend([Dense(units), BatchNormalization(), Activation('relu')])
    stack.extend([
        Dense(pixel_count),
        Activation('sigmoid'),  # squashes every pixel value into (0, 1)
        Reshape((image_size, image_size, 1)),
    ])

    tensor = z
    for layer in stack:
        tensor = layer(tensor)

    return Model(z, tensor, name='generator')


In [52]:
def build_discriminator(image_size):
    """Build the fully connected discriminator network.

    The input image is flattened and passed through three
    Dense -> LeakyReLU(alpha=0.2) stages (1024, 512 and 256 units),
    followed by a single sigmoid-activated output unit.

    Args:
        image_size: Side length of the square, single-channel input image.

    Returns:
        An uncompiled Keras ``Model`` named ``'discriminator'``.
    """
    hidden_units = (1024, 512, 256)

    image = Input(shape=(image_size, image_size, 1), name='discriminator_input')

    # Assemble the whole stack first, then thread the tensor through it.
    stack = [Flatten()]
    for units in hidden_units:
        stack.append(Dense(units))
        stack.append(LeakyReLU(alpha=0.2))
    stack.append(Dense(1))
    stack.append(Activation('sigmoid'))

    tensor = image
    for layer in stack:
        tensor = layer(tensor)

    return Model(image, tensor, name='discriminator')


In [53]:
def train(models, data, params):
    """Alternately train the discriminator and the adversarial model.

    Every step:
      1. draws ``batch_size`` random real images and generates the same
         number of fake images from uniform noise in [-1, 1];
      2. trains the discriminator on the combined batch, with real images
         labelled 1 and fake images labelled 0;
      3. trains the adversarial model on fresh noise with all-ones labels.

    A log line is printed every 100 steps and a sample image grid is saved
    through ``plot_images`` every 500 steps. After the last step the
    generator is saved to ``model_name + ".h5"``.

    Args:
        models: Tuple ``(generator, discriminator, adversarial)``.
        data: Array of real training images, indexed along axis 0.
        params: Tuple ``(batch_size, latent_size, train_steps, model_name)``.
    """
    generator, discriminator, adversarial = models
    x_train = data
    batch_size, latent_size, train_steps, model_name = params
    log_interval = 100
    save_interval = 500
    train_size = x_train.shape[0]

    # One fixed noise batch reused for every snapshot, so the saved grids
    # show how the same latent points evolve instead of new samples each time.
    sample_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])

    # The label arrays never change between steps, so build them once.
    # Discriminator batch layout: first half real (1.0), second half fake (0.0).
    disc_labels = np.ones([2 * batch_size, 1])
    disc_labels[batch_size:, :] = 0.0
    # The adversarial step labels generated images as real.
    adv_labels = np.ones([batch_size, 1])

    for i in range(train_steps):
        # --- discriminator step ---
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        # verbose=0: predict() otherwise may emit a progress bar on every
        # step, flooding the notebook output.
        fake_images = generator.predict(noise, verbose=0)
        x = np.concatenate((real_images, fake_images))
        d_loss, d_acc = discriminator.train_on_batch(x, disc_labels)

        # --- adversarial step ---
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        a_loss, a_acc = adversarial.train_on_batch(noise, adv_labels)

        if (i + 1) % log_interval == 0:
            log = "%d: [discriminator loss: %f, acc: %f]" % (i, d_loss, d_acc)
            log = "%s [adversarial loss: %f, acc: %f]" % (log, a_loss, a_acc)
            print(log)
        if (i + 1) % save_interval == 0:
            plot_images(generator,
                        noise_input=sample_noise,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)
    generator.save(model_name + ".h5")



In [54]:
def plot_images(generator,
                noise_input,
                show=False,
                step=0,
                model_name="gan"):
    """Render generator outputs as an image grid and save it as a PNG.

    The file is written to ``<model_name>/<step as 5 digits>.png``; the
    directory is created if it does not exist.

    Args:
        generator: Object whose ``predict(noise_input)`` returns an array of
            images indexed along axis 0, each reshapeable to a square of
            side ``images.shape[1]``.
        noise_input: Batch of latent vectors passed to ``generator.predict``.
        show: If true, display the figure with ``plt.show()``; otherwise the
            figure is closed after saving.
        step: Training step number, used for the file name.
        model_name: Output directory name.
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    images = generator.predict(noise_input)
    num_images = images.shape[0]
    image_size = images.shape[1]

    # Near-square grid with enough cells for every image. Rounding the side
    # up (rather than truncating) keeps subplot indices valid when the batch
    # size is not a perfect square; perfect squares still get rows == cols.
    rows = max(1, math.ceil(math.sqrt(num_images)))
    cols = max(1, math.ceil(num_images / rows))

    fig = plt.figure(figsize=(2.2, 2.2))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        # Close only the figure created here, leaving any others untouched.
        plt.close(fig)



In [55]:
def build_and_train_models(model_name="gan_mnist",
                           latent_size=100,
                           batch_size=64,
                           train_steps=4000,
                           lr=2e-4,
                           decay=6e-8):
    """Load MNIST, build the GAN networks, and run training.

    Only the MNIST training images are used (labels and the test split are
    discarded). Images are reshaped to ``(N, size, size, 1)`` and scaled to
    float32 values in [0, 1].

    Called with no arguments this uses the notebook's original settings.

    Args:
        model_name: Name of the adversarial model; also forwarded to
            ``train`` for naming its outputs.
        latent_size: Length of the generator's noise input.
        batch_size: Number of real (and fake) images per training step.
        train_steps: Number of training steps (the notebook's original
            comment notes 40000 as the longer alternative).
        lr: Discriminator learning rate; the adversarial model uses half.
        decay: Discriminator optimizer decay; the adversarial model uses half.
    """
    (x_train, _), (_, _) = mnist.load_data()

    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    # Discriminator: compiled on its own while still trainable.
    discriminator = build_discriminator(image_size)
    optimizer = RMSprop(learning_rate=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()

    generator = build_generator(latent_size, image_size)
    generator.summary()

    # Adversarial model = generator followed by the discriminator. The
    # discriminator is marked non-trainable before this model is compiled,
    # so the adversarial updates are meant to adjust the generator only.
    optimizer = RMSprop(learning_rate=lr * 0.5, decay=decay * 0.5)
    discriminator.trainable = False
    inputs = Input(shape=(latent_size,), name='z_input')
    adversarial = Model(inputs, discriminator(generator(inputs)), name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    train(models, x_train, params)


In [56]:
# Run the whole pipeline: load MNIST, build and compile the networks, print
# their summaries, and train.
build_and_train_models()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 discriminator_input (InputL  [(None, 28, 28, 1)]      0         
 ayer)                                                           
                                                                 
 flatten_4 (Flatten)         (None, 784)               0         
                                                                 
 dense_69 (Dense)            (None, 1024)              803840    
                                                                 
 leaky_re_lu_27 (LeakyReLU)  (None, 1024)              0         
                                                                 
 dense_70 (Dense)            (None, 512)               524800    
                                                                 
 leaky_re_lu_28 (LeakyReLU)  (None, 512)               0         
                                                     