<a href="https://colab.research.google.com/github/WatanabeRyusuke/automated-build/blob/master/DCGAN_mnist_with_Keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import numpy as np
import math
import matplotlib.pyplot as plt

from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.layers import Activation, BatchNormalization
from keras.layers import LeakyReLU
from keras.models import Model
from keras.optimizers import RMSprop
from keras.datasets import mnist
from keras.utils import plot_model
from keras import backend as K

In [0]:
def build_generator(inputs, image_size):
    """Build the generator model.

    The input is projected by a Dense layer and reshaped into a square
    feature map whose side is image_size // 4, with 128 channels. It then
    passes through four BatchNormalization -> ReLU -> Conv2DTranspose
    stages with 128, 64, 32 and 1 filters. The first two stages use
    stride 2, the last two use stride 1. A sigmoid is applied to the
    output of the last stage.

    # Arguments
        inputs: Keras input tensor of the generator (the z-vector),
            as returned by Input(...)
        image_size (int): Side length of the target square image

    # Returns
        Model: The generator model, named 'generator'
    """
    layer_filters = [128, 64, 32, 1]
    kernel_size = 5
    base_size = image_size // 4
    base_filters = layer_filters[0]

    # Project the latent vector, then fold it into a small feature map.
    net = Dense(base_size * base_size * base_filters)(inputs)
    net = Reshape((base_size, base_size, base_filters))(net)

    # Stages whose filter count exceeds this value are built with stride 2.
    stride_cutoff = layer_filters[-2]
    for n_filters in layer_filters:
        stage_strides = 2 if n_filters > stride_cutoff else 1
        net = BatchNormalization()(net)
        net = Activation('relu')(net)
        net = Conv2DTranspose(
            filters=n_filters,
            kernel_size=kernel_size,
            strides=stage_strides,
            padding='same',
        )(net)

    # Squash the single-channel output with a sigmoid.
    net = Activation('sigmoid')(net)
    return Model(inputs, net, name='generator')

In [0]:
def build_discriminator(inputs):
    """Build the discriminator model.

    The input passes through four LeakyReLU(alpha=0.2) -> Conv2D stages
    with 32, 64, 128 and 256 filters. Every stage uses stride 2 except
    the last one, which uses stride 1. The result is flattened and fed
    to a single-unit Dense layer followed by a sigmoid.

    # Arguments
        inputs: Keras input tensor of the discriminator (the image),
            as returned by Input(...)

    # Returns
        Model: The discriminator model, named 'discriminator'
    """
    layer_filters = [32, 64, 128, 256]
    kernel_size = 5
    last_filters = layer_filters[-1]

    net = inputs
    for n_filters in layer_filters:
        # Only the stage matching the last filter count keeps stride 1.
        stage_strides = 1 if n_filters == last_filters else 2
        net = LeakyReLU(alpha=0.2)(net)
        net = Conv2D(
            filters=n_filters,
            kernel_size=kernel_size,
            strides=stage_strides,
            padding='same',
        )(net)

    # Classification head: one sigmoid unit.
    net = Flatten()(net)
    net = Dense(1)(net)
    net = Activation('sigmoid')(net)
    return Model(inputs, net, name='discriminator')

In [0]:
def plot_images(generator,
                noise_input,
                show=False,
                step=0,
                model_name="gan"):
    """Generate fake images and plot them

    For visualization purposes, generate fake images
    then plot them in a square grid. The figure is saved as
    "<model_name>/<step as 5 digits>.png"; the directory is created
    if it does not exist.

    The grid side is the square root of the number of generated images,
    rounded up, so every image gets a cell even when the count is not a
    perfect square (unused cells are simply left empty).

    # Arguments
        generator (Model): The Generator Model for fake images generation
        noise_input (ndarray): Array of z-vectors
        show (bool): Whether to show plot or not
        step (int): Appended to filename of the save images
        model_name (string): Model name
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    images = generator.predict(noise_input)
    num_images = images.shape[0]
    image_size = images.shape[1]
    # Round up: with a rounded-down side, subplot indices beyond
    # side * side would be rejected by plt.subplot.
    grid_side = int(math.ceil(math.sqrt(num_images)))
    plt.figure(figsize=(2.2, 2.2))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        # Drop the channel axis so imshow receives a 2-D array.
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')

In [0]:
def train(models, x_train, params):
    """Alternately train the discriminator and the adversarial model.

    Each step first trains the discriminator on one batch made of
    randomly picked real images (label 1.0) followed by freshly generated
    fake images (label 0.0), then trains the adversarial model on a new
    noise batch labelled 1.0. Losses and accuracies are printed every
    step. Every 500 steps the generator output for a fixed set of 16
    noise vectors is plotted with plot_images; the plot is shown only
    when that step is the last one. After the loop the generator is
    saved to "<model_name>.h5".

    # Arguments
        models (tuple): (generator, discriminator, adversarial) models
        x_train (ndarray): Training images
        params (tuple): (batch_size, latent_size, train_steps, model_name)
    """
    generator, discriminator, adversarial = models
    batch_size, latent_size, train_steps, model_name = params
    save_interval = 500
    noise_shape = [batch_size, latent_size]

    # Fixed z-vectors, reused for every snapshot plot.
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
    train_size = x_train.shape[0]

    for step in range(1, train_steps + 1):
        # --- discriminator update: real batch + fake batch ---
        picked = np.random.randint(0, train_size, size=batch_size)
        z = np.random.uniform(-1.0, 1.0, size=noise_shape)
        fake_images = generator.predict(z)
        batch = np.concatenate((x_train[picked], fake_images))
        # First half (real) is labelled 1.0, second half (fake) 0.0.
        labels = np.concatenate((np.ones([batch_size, 1]),
                                 np.zeros([batch_size, 1])))
        d_loss, d_acc = discriminator.train_on_batch(batch, labels)
        log = '%d: [discriminator loss: %f, acc: %f]' % (step - 1,
                                                         d_loss,
                                                         d_acc)

        # --- adversarial update: fake images labelled 1.0 ---
        z = np.random.uniform(-1.0, 1.0, size=noise_shape)
        a_loss, a_acc = adversarial.train_on_batch(
            z, np.ones([batch_size, 1]))
        log = '%s [adversarial loss: %f, acc: %f]' % (log, a_loss, a_acc)
        print(log)

        if step % save_interval != 0:
            continue
        # Snapshot of the generator output; displayed only on the last step.
        plot_images(generator,
                    noise_input=noise_input,
                    show=(step == train_steps),
                    step=step,
                    model_name=model_name)

    generator.save(model_name + '.h5')

In [0]:
def build_and_train_model():
    """Load MNIST, build the DCGAN models and train them.

    The training images are reshaped to (N, side, side, 1) and scaled
    by 1/255 as float32. The discriminator is built and compiled with
    RMSprop; the generator is built; the adversarial model stacks the
    discriminator on top of the generator and is compiled with RMSprop
    at half the learning rate and half the decay. A summary of each
    model is printed, then train() is called.
    """
    # Hyperparameters
    model_name = 'dcgan_mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8

    # Data: only the training images are used.
    (x_train, _), (_, _) = mnist.load_data()
    image_size = x_train.shape[1]
    x_train = x_train.reshape(-1, image_size, image_size, 1)
    x_train = x_train.astype('float32') / 255

    # Discriminator
    image_input = Input(shape=(image_size, image_size, 1),
                        name='discriminator_input')
    discriminator = build_discriminator(image_input)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=RMSprop(lr=lr, decay=decay),
                          metrics=['accuracy'])
    discriminator.summary()

    # Generator
    z_input = Input(shape=(latent_size, ), name='z_input')
    generator = build_generator(z_input, image_size)
    generator.summary()

    # Adversarial model = generator followed by discriminator.
    adversarial_optimizer = RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    # Flag flipped after the discriminator was compiled and before the
    # adversarial model is compiled; presumably so that adversarial
    # training leaves the discriminator weights untouched (verify against
    # the Keras version in use).
    discriminator.trainable = False
    adversarial = Model(z_input,
                        discriminator(generator(z_input)),
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=adversarial_optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    train((generator, discriminator, adversarial),
          x_train,
          (batch_size, latent_size, train_steps, model_name))

In [0]:
# Entry point of the notebook: loads MNIST, builds the models and runs training.
build_and_train_model()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa (None, 28, 28, 1)         0         
_________________________________________________________________
leaky_re_lu_17 (LeakyReLU)   (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 14, 14, 32)        832       
_________________________________________________________________
leaky_re_lu_18 (LeakyReLU)   (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 7, 7, 64)          51264     
_________________________________________________________________
leaky_re_lu_19 (LeakyReLU)   (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 4, 4, 128)         204928    
__________

  'Discrepancy between trainable weights and collected trainable'


0: [discriminator loss: 0.696685, acc: 0.281250] [adversarial loss: 0.894508, acc: 0.000000]
1: [discriminator loss: 0.653487, acc: 0.625000] [adversarial loss: 0.944343, acc: 0.000000]
2: [discriminator loss: 0.560565, acc: 0.984375] [adversarial loss: 1.203620, acc: 0.000000]
3: [discriminator loss: 0.463943, acc: 0.992188] [adversarial loss: 1.780011, acc: 0.000000]


  'Discrepancy between trainable weights and collected trainable'


4: [discriminator loss: 0.333574, acc: 0.976562] [adversarial loss: 1.612168, acc: 0.000000]
5: [discriminator loss: 0.262300, acc: 0.992188] [adversarial loss: 4.222385, acc: 0.000000]
6: [discriminator loss: 0.392653, acc: 0.789062] [adversarial loss: 1.479958, acc: 0.000000]
7: [discriminator loss: 0.225467, acc: 0.992188] [adversarial loss: 2.595415, acc: 0.000000]
8: [discriminator loss: 0.162209, acc: 0.976562] [adversarial loss: 1.867355, acc: 0.000000]
9: [discriminator loss: 0.107775, acc: 0.992188] [adversarial loss: 1.744715, acc: 0.000000]
10: [discriminator loss: 0.092075, acc: 0.984375] [adversarial loss: 1.392382, acc: 0.000000]
11: [discriminator loss: 0.067613, acc: 0.992188] [adversarial loss: 0.829658, acc: 0.218750]
12: [discriminator loss: 0.060404, acc: 1.000000] [adversarial loss: 1.219656, acc: 0.015625]
13: [discriminator loss: 0.038648, acc: 1.000000] [adversarial loss: 0.764080, acc: 0.343750]
14: [discriminator loss: 0.050911, acc: 0.992188] [adversarial los