Keras_minimal_gan_jpn
---
Kerasで記述された最小のGANコード

MNISTの手書き数字を生成<br />
潜在変数の次元数(z_dim)を100として生成

参考コード：GANs in Action: Deep Learning with Generative Adversarial Networks / Jakub Langr, Vladimir Bok<br />
(邦題：実践GAN 敵対的生成ネットワークによる深層学習、ISBN-10: 4839967717)より<br />
https://github.com/GANs-in-Action/gans-in-action/blob/master/chapter-3/Chapter_3_GAN.ipynb

In [0]:
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

In [0]:
# Training hyperparameters
iterations = 20000       # total number of training iterations
batch_size = 256         # samples drawn per training step
sample_interval = 1000   # report progress and plot samples every N iterations
# Dimensionality of the noise vector fed to the Generator
z_dim = 100

# Fetch MNIST; only the training images are kept, labels and test split are discarded
(X_train, _), (_, _) = mnist.load_data()
# Map [0, 255] grayscale pixel values to [-1, 1], then append a trailing channel axis
X_train = (X_train / 127.5 - 1.0)[..., np.newaxis]
# Shape of a single image, i.e. everything after the leading sample axis
img_shape = X_train.shape[1:]

In [0]:
def build_generator(img_shape, z_dim):
    """Build the Generator network.

    The model maps a noise vector of length ``z_dim`` through one hidden
    fully connected layer (128 units, LeakyReLU) to a tanh-activated output
    that is reshaped to ``img_shape``.

    Args:
        img_shape: Three-element (height, width, channels) shape of one image.
        z_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    height, width, channels = img_shape
    n_pixels = height * width * channels
    return Sequential([
        Dense(128, input_dim=z_dim),
        LeakyReLU(alpha=0.01),
        # tanh keeps every output value inside [-1, 1]
        Dense(n_pixels, activation='tanh'),
        Reshape(img_shape),
    ])

In [0]:
def build_discriminator(img_shape):
    """Build the Discriminator network.

    The model flattens an input of shape ``img_shape``, passes it through one
    hidden fully connected layer (128 units, LeakyReLU) and ends in a single
    sigmoid unit.

    Args:
        img_shape: Shape of one input image.

    Returns:
        An uncompiled ``Sequential`` model with one sigmoid output.
    """
    return Sequential([
        Flatten(input_shape=img_shape),
        Dense(128),
        LeakyReLU(alpha=0.01),
        Dense(1, activation='sigmoid'),
    ])

In [0]:
def build_gan(generator, discriminator):
    """Stack the Generator and the Discriminator into one model.

    Args:
        generator: Model placed first in the stack.
        discriminator: Model placed second; it receives the Generator output.

    Returns:
        An uncompiled ``Sequential`` model chaining ``generator`` into
        ``discriminator``.
    """
    return Sequential([generator, discriminator])

In [0]:
# Build the Discriminator and compile it as a standalone model with
# binary cross-entropy loss and an accuracy metric.
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
# Build the Generator; it is never compiled on its own, only as part of `gan` below.
generator = build_generator(img_shape, z_dim)

# Mark the Discriminator as non-trainable before stacking and compiling the
# combined model. The statement order here is deliberate.
# NOTE(review): this presumably relies on Keras fixing the trainable state at
# compile() time, so `discriminator` still learns when trained directly while
# staying frozen inside `gan` - confirm for the Keras version in use.
discriminator.trainable = False
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

In [0]:
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
    """Plot a grid of images produced by the Generator from random noise.

    Draws ``image_grid_rows * image_grid_columns`` standard-normal noise
    vectors of length ``z_dim`` (module-level setting), runs them through
    ``generator`` and shows channel 0 of each result in grayscale.

    Args:
        generator: Model whose ``predict`` turns noise vectors into images
            indexed as ``[sample, row, col, channel]``.
        image_grid_rows: Number of rows in the plotted grid.
        image_grid_columns: Number of columns in the plotted grid.
    """
    n_images = image_grid_rows * image_grid_columns
    z = np.random.normal(0, 1, (n_images, z_dim))
    gen_imgs = generator.predict(z)
    # Rescale from [-1, 1] to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps `axs` a 2-D array even for a single row or column;
    # with the default squeezing a 1xN or Nx1 grid would not be 2-D indexable.
    _, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(4, 4), squeeze=False)

    # `axs.flat` walks the grid row by row, matching the sample order
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')

In [0]:
# Target labels, reused every iteration: 1 for real images, 0 for generated ones
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for iteration in range(iterations):
    # ---- Train the Discriminator ----
    # Random batch of real images (indices drawn with replacement)
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]
    # Batch of generated images; the noise width follows z_dim rather than a
    # hard-coded 100, so changing z_dim keeps the shapes consistent
    z = np.random.normal(0, 1, (batch_size, z_dim))
    gen_imgs = generator.predict(z)
    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    # Each result holds (loss, accuracy); average the real and fake halves
    d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---- Train the Generator ----
    # Fresh noise, labelled "real", is pushed through the combined model
    z = np.random.normal(0, 1, (batch_size, z_dim))
    g_loss = gan.train_on_batch(z, real)

    # Periodically report progress and plot a grid of generated samples
    if (iteration + 1) % sample_interval == 0:
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (iteration + 1, d_loss, 100.0 * accuracy, g_loss))
        sample_images(generator)