# Vanilla GAN implementation using Keras

In [1]:
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, BatchNormalization

from keras.optimizers import SGD
from keras.datasets import mnist
from keras.regularizers import l1_l2

Using TensorFlow backend.


In [2]:
# Generate images from a fixed noise batch and save them as a grid of subplots.
def plot_generated(noise, generator_model, examples=9, plot_dim=(3,3), size=(7,7), epoch=None, img_shape=(28, 28)):
    """Render `examples` generator outputs as a grid and save it as a PNG.

    Args:
        noise: Batch passed to ``generator_model.predict``; must provide at
            least ``examples`` rows.
        generator_model: Model whose ``predict`` output rows are flattened images.
        examples: Number of samples to draw; must fit in ``plot_dim``.
        plot_dim: (rows, cols) of the subplot grid.
        size: Figure size forwarded to ``plt.figure(figsize=...)``.
        epoch: Tag used as the file name, written to
            ``generated_figures/<epoch>.png``.
        img_shape: Shape each flattened output row is reshaped to before display
            (default 28x28).

    Raises:
        ValueError: if ``examples`` exceeds the number of grid cells.
    """
    path = "generated_figures"
    os.makedirs(path, exist_ok=True)

    n_rows, n_cols = plot_dim[0], plot_dim[1]
    if examples > n_rows * n_cols:
        raise ValueError(
            "examples=%d does not fit in a %dx%d grid" % (examples, n_rows, n_cols))

    generated_images = generator_model.predict(noise)

    fig = plt.figure(figsize=size)
    for i in range(examples):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.imshow(generated_images[i].reshape(img_shape), cmap="gray")
        ax.axis("off")
    # Lay out once, after every subplot exists, rather than once per subplot.
    fig.tight_layout()
    fig.savefig(os.path.join(path, str(epoch) + ".png"))
    # Close this exact figure so repeated calls during training do not accumulate figures.
    plt.close(fig)

In [3]:
def plot_metrics(metrics, epoch=None):
    """Save the discriminator and generator loss histories as two PNG line plots.

    Args:
        metrics: Mapping with loss sequences under the keys "d" and "g".
        epoch: Tag appended to each file name inside the ``metrics`` directory.
    """
    out_dir = "metrics"
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # (metrics key, legend label, line colour, file-name prefix) for each figure.
    # The prefixes differ in style ("dloss" vs "g_loss") but are kept as-is so
    # output paths stay the same.
    series_specs = (
        ("d", "discriminative loss", "b", "dloss"),
        ("g", "generative loss", "r", "g_loss"),
    )
    for key, label, colour, prefix in series_specs:
        plt.figure(figsize=(10,8))
        plt.plot(metrics[key], label=label, color=colour)
        plt.legend()
        plt.savefig(os.path.join(out_dir, prefix + str(epoch) + ".png"))
        plt.close()

In [4]:
def Generator(z_dim=100, img_dim=28*28):
    """Build and compile the MLP generator: noise vector -> flattened image.

    Architecture: Dense 256 -> 512 -> 1024, each followed by
    BatchNormalization and LeakyReLU(alpha=0.2). These are followed by a
    Dense(img_dim) -> BatchNormalization -> sigmoid output. Every Dense kernel
    has an L1+L2 penalty of 1e-5.

    Args:
        z_dim: Length of the input noise vector (default 100).
        img_dim: Number of output units, i.e. flattened image size (default 784).

    Returns:
        A ``Sequential`` model compiled with SGD and binary cross-entropy.
    """
    Gen = Sequential()
    for i, units in enumerate((256, 512, 1024)):
        # Only the first layer declares its input size; Sequential infers the rest.
        input_kwargs = {"input_dim": z_dim} if i == 0 else {}
        Gen.add(Dense(units=units, kernel_regularizer=l1_l2(1e-5, 1e-5), **input_kwargs))
        # `mode` was a Keras 1 argument: Keras 2 accepts only the default and later
        # releases reject it, so the default (equivalent to mode=0) is used.
        Gen.add(BatchNormalization())
        # A fresh activation per position: re-adding one shared LeakyReLU instance
        # puts the same layer in the model several times (tf.keras rejects the
        # duplicate layer names).
        Gen.add(keras.layers.LeakyReLU(alpha=0.2))
    Gen.add(Dense(units=img_dim, kernel_regularizer=l1_l2(1e-5, 1e-5)))
    Gen.add(BatchNormalization())
    Gen.add(Activation("sigmoid"))

    # NOTE(review): `lr`/`decay` are legacy SGD arguments; newer Keras releases
    # use `learning_rate` and learning-rate schedules instead — confirm the
    # Keras version this notebook targets.
    generator_optimizer = SGD(lr=0.1, momentum=0.3, decay=1e-5)
    Gen.compile(loss="binary_crossentropy", optimizer=generator_optimizer)
    return Gen

In [5]:
def Discriminator(input_dim=784):
    """Build and compile the MLP discriminator: flattened image -> score in (0, 1).

    Architecture: Dense 1024 -> 512 -> 256, each followed by
    LeakyReLU(alpha=0.2), then Dense(1) with a sigmoid. Every Dense kernel has
    an L1+L2 penalty of 1e-5.

    Args:
        input_dim: Length of each flattened input image (default 784 = 28*28).

    Returns:
        A ``Sequential`` model compiled with SGD and binary cross-entropy.
    """
    Dis = Sequential()
    for i, units in enumerate((1024, 512, 256)):
        # Only the first layer declares its input size; Sequential infers the rest.
        input_kwargs = {"input_dim": input_dim} if i == 0 else {}
        Dis.add(Dense(units=units, kernel_regularizer=l1_l2(1e-5, 1e-5), **input_kwargs))
        # A fresh activation per position instead of one shared instance added
        # repeatedly (tf.keras rejects duplicate layer names in a Sequential model).
        Dis.add(keras.layers.LeakyReLU(alpha=0.2))
    Dis.add(Dense(units=1, kernel_regularizer=l1_l2(1e-5, 1e-5)))
    Dis.add(Activation("sigmoid"))

    # NOTE(review): `lr`/`decay` are legacy SGD arguments in newer Keras releases.
    discriminator_optimizer = SGD(lr=0.1, momentum=0.1, decay=1e-5)
    Dis.compile(loss="binary_crossentropy", optimizer=discriminator_optimizer)
    return Dis

In [6]:
def Generative_Adversarial_Network(generator_model, discriminator_model):
    """Chain generator -> discriminator into one compiled model for generator updates.

    The discriminator is flagged non-trainable before the stacked model is
    compiled. Training the returned model is therefore intended to update only
    the generator's weights. Both input models are added as-is (not copied).

    Returns:
        A ``Sequential`` model compiled with SGD(0.1, momentum=0.3) and
        binary cross-entropy.
    """
    # Freeze before compiling the stacked model; the flag set here is what the
    # compile below sees.
    discriminator_model.trainable = False
    stacked = Sequential([generator_model, discriminator_model])
    stacked.compile(loss="binary_crossentropy", optimizer=SGD(0.1, momentum=0.3))
    return stacked

In [7]:
# Load MNIST and flatten each 28x28 image into a 784-vector scaled to [0, 1].
# NOTE(review): this cell's recorded output is `BadZipFile: File is not a zip file`,
# which suggests the cached dataset download (Keras caches it under
# ~/.keras/datasets/ by default) is corrupt; removing that cached file so it is
# re-downloaded is the usual remedy — confirm locally.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# `astype` returns a new array rather than converting in place, so its result
# must be assigned; casting before dividing keeps the data float32 instead of float64.
X_train = X_train.reshape(X_train.shape[0], 28*28).astype('float32') / 255
X_test = X_test.reshape(X_test.shape[0], 28*28).astype('float32') / 255

print("X_train shape", X_train.shape)
print(X_train.shape[0], "train samples")
print(X_test.shape[0], "test samples")

BadZipFile: File is not a zip file

In [None]:
def main_train(z_input_size, generator_model, discriminator_model, gan_model, loss_dict, X_train, generated_figures=None, z_group=None, z_plot_freq=200, epoch=1000, plot_freq=25, batch=100, visualize_train=True):
    """Alternate discriminator and generator updates for `epoch` iterations.

    Each iteration trains the discriminator on `batch` real rows of `X_train`
    (label 0) plus `batch` generated rows (label 1). It then trains the
    generator through `gan_model` on fresh noise labelled 0, i.e. "real".

    Args:
        z_input_size: Length of each uniform-[0, 1) noise vector.
        generator_model: Model mapping noise to samples (``predict_on_batch``).
        discriminator_model: Model trained directly with ``train_on_batch``.
        gan_model: Stacked generator+discriminator model used for generator updates.
        loss_dict: Dict with lists under "d" and "g"; losses are appended in place.
        X_train: 2-D array of real samples, one per row.
        generated_figures: Unused; kept so existing calls still work.
        z_group: Fixed noise batch passed to ``plot_generated`` for snapshots.
        z_plot_freq: Save a sample grid every `z_plot_freq` iterations.
        epoch: Number of training iterations (one batch each, not full data passes).
        plot_freq: Save loss plots every `plot_freq` iterations.
        batch: Number of real (and of generated) samples per iteration.
        visualize_train: Whether to save sample grids (requires `z_group`).
    """
    for e in tqdm(range(epoch)):
        # --- discriminator step: real rows -> 0, generated rows -> 1 ---
        noise = np.random.uniform(0, 1, size=[batch, z_input_size])
        generated_images = generator_model.predict_on_batch(noise)

        rand_train_index = np.random.randint(0, X_train.shape[0], size=batch)
        image_batch = X_train[rand_train_index, :]

        X = np.vstack((image_batch, generated_images))
        y = np.zeros(2 * batch, dtype=int)
        y[batch:] = 1

        # Unfreeze only around the discriminator's own update, then restore the
        # frozen state that gan_model was built with.
        discriminator_model.trainable = True
        d_loss = discriminator_model.train_on_batch(x=X, y=y)
        discriminator_model.trainable = False

        # --- generator step: push the discriminator to label fresh fakes 0 ("real") ---
        noise = np.random.uniform(0, 1, size=[batch, z_input_size])
        y = np.zeros(batch, dtype=int)
        g_loss = gan_model.train_on_batch(x=noise, y=y)

        loss_dict["d"].append(d_loss)
        loss_dict["g"].append(g_loss)

        if e % plot_freq == plot_freq - 1:
            plot_metrics(loss_dict, e // plot_freq)

        # Sample grids need the fixed `z_group` noise; skip them when it is absent.
        if visualize_train and z_group is not None and e % z_plot_freq == z_plot_freq - 1:
            plot_generated(z_group, generator_model=generator_model, epoch=e // z_plot_freq)

In [None]:
# Build the models. The stacked GAN wraps the same Gen and Dis instances.
Gen = Generator()
Dis = Discriminator()
GAN = Generative_Adversarial_Network(Gen, Dis)
GAN.summary()

Gen.summary()
Dis.summary()

# Training configuration.
gan_losses = {"d":[], "g":[], "f":[]}
epoch = 300000        # training iterations (one batch each), not passes over the data
batch = 1000          # real samples and generated samples per iteration (each)
z_plot_freq = 1000    # save a sample grid every this many iterations
plot_freq = 10000     # save loss curves every this many iterations
z_input_vector = 100  # noise length; must match the generator's input size
n_train_samples = 60000  # not referenced by the visible cells
examples = 9          # samples per grid; plot_generated's default grid is 3x3

# Fixed noise reused for every snapshot so grids are comparable across training.
# Drawing with an explicit 2-D size yields the same values as drawing a flat
# vector and reshaping, but no longer hard-codes the row count.
z_group_matrix = np.random.uniform(0, 1, size=(examples, z_input_vector))
print(z_group_matrix.shape)

generated_figures = []

In [None]:
# Run training with the configuration above. The figures list is `generated_figures`;
# the noise size comes from `z_input_vector` rather than a repeated literal.
main_train(z_input_vector, Gen, Dis, GAN, loss_dict=gan_losses, X_train=X_train, generated_figures=generated_figures, z_group=z_group_matrix, z_plot_freq=z_plot_freq, epoch=epoch, plot_freq=plot_freq, batch=batch)