In [None]:
import sys
import os
import shutil
sys.path.insert(0,'../..')

from AutoGAN import GAN
from AutoGAN.schemes.IWGAN_TrainingScheme import IWGAN_TrainingScheme, wasserstein_loss

import keras
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# TensorFlow 1.x session setup for Keras (options as documented for
# tf.GPUOptions):
#   - visible_device_list: expose only physical GPU "1" to this process.
#   - allow_growth: allocate GPU memory on demand instead of all up front.
#   - per_process_gpu_memory_fraction: cap this process at 95% of that GPU.
config = tf.ConfigProto()
config.gpu_options.visible_device_list = "1"
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.95
set_session(tf.Session(config=config))

def build_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the MLP generator that maps latent noise vectors to images.

    The stack has three Dense(256 / 512 / 1024) -> LeakyReLU(0.2) ->
    BatchNormalization(momentum=0.8) stages. These are followed by a tanh
    Dense layer reshaped to ``img_shape``. The tanh keeps outputs in [-1, 1],
    which is the same range ``load_data`` rescales the real MNIST images into.

    Args:
        latent_dim: Length of the input noise vector. The default of 100
            matches the noise size used elsewhere in this notebook.
        img_shape: Shape of one generated image. The default is MNIST's
            ``(28, 28, 1)``.

    Returns:
        A functional ``keras.models.Model`` mapping ``(batch, latent_dim)``
        noise to ``(batch,) + img_shape`` images.
    """
    img_shape = tuple(img_shape)

    model = Sequential()

    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # Dense needs a plain int unit count; np.prod returns a NumPy scalar.
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(img_shape))

    model.summary()

    # Wrap the Sequential stack in a functional Model with an explicit Input.
    m_noise = Input(shape=(latent_dim,))
    img = model(m_noise)

    return Model(m_noise, img)


def build_discriminator(layer_activation='sigmoid', img_shape=(28, 28, 1)):
    """Build the MLP discriminator / critic that scores images.

    The image is flattened and passed through Dense(512) -> LeakyReLU(0.2) ->
    Dense(256) -> LeakyReLU(0.2). The final layer is a single Dense unit
    with ``layer_activation``.

    Args:
        layer_activation: Activation of the single output unit. ``'sigmoid'``
            (the default) squashes the score into (0, 1). ``iwgan()`` passes
            ``'linear'`` to get an unbounded critic score.
        img_shape: Shape of one input image. The default is MNIST's
            ``(28, 28, 1)``. It is appended as a keyword so existing calls
            keep working.

    Returns:
        A functional ``keras.models.Model`` mapping ``(batch,) + img_shape``
        images to ``(batch, 1)`` scores.
    """
    img_shape = tuple(img_shape)

    model = Sequential()

    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation=layer_activation))
    model.summary()

    # Wrap the Sequential stack in a functional Model with an explicit Input.
    img = Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)


class save_images(keras.callbacks.Callback):
    """Keras callback that saves a grid of generated samples at epoch end.

    Each grid is written to ``images/<name>/<epoch>.png``. That directory
    must already exist; for ``name='iwgan'``, ``iwgan()`` creates it.
    """

    def __init__(self, model, name='gan', rows=5, cols=5, latent_dim=100):
        """
        Args:
            model: GAN wrapper exposing ``generator_model()``, which returns
                something with a Keras-style ``predict``.
            name: Sub-directory of ``images/`` to write the PNGs into.
            rows: Number of sample rows in the saved grid (default 5).
            cols: Number of sample columns in the saved grid (default 5).
            latent_dim: Length of the noise vectors fed to the generator.
                It must match the generator's input size (default 100).
        """
        super(save_images, self).__init__()
        # NOTE(review): presumably stored as ``full_model`` rather than
        # ``model`` to avoid clashing with the ``self.model`` attribute that
        # Keras' Callback.set_model assigns — confirm.
        self.full_model = model
        self.name = name
        self.rows = rows
        self.cols = cols
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        """Sample ``rows * cols`` images and save them as a grayscale grid."""
        r, c = self.rows, self.cols
        local_noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.full_model.generator_model().predict(local_noise)

        # Rescale images from the generator's [-1, 1] (tanh) range to [0, 1].
        gen_imgs = 0.5 * gen_imgs + 0.5

        # squeeze=False keeps ``axs`` 2-D even for a 1x1 or 1xN grid.
        fig, axs = plt.subplots(r, c, squeeze=False)
        # axs.flat walks the grid row-major, pairing each axis with one sample.
        for ax, img in zip(axs.flat, gen_imgs):
            ax.imshow(img[:, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig("images/%s/%d.png" % (self.name, epoch))
        # Close this specific figure, not merely whichever one is "current",
        # so figures never accumulate across epochs.
        plt.close(fig)


def load_data(latent_dim=100):
    """Load the MNIST training images and pair them with Gaussian noise.

    Only the training images are used; the labels and the test split
    returned by ``mnist.load_data()`` are discarded.

    Args:
        latent_dim: Length of each noise vector. The default of 100 matches
            the generator's default input size.

    Returns:
        A tuple ``(noise, real_targets)``:
            - ``noise`` has shape ``(N, latent_dim)`` and is drawn from N(0, 1).
            - ``real_targets`` has shape ``(N, H, W, 1)``, with pixels rescaled
              from [0, 255] to [-1, 1].
        Here N is the number of training images.
    """
    (real_targets, _), _ = mnist.load_data()
    # Rescale [0, 255] -> [-1, 1] to match the generator's tanh output range.
    real_targets = real_targets / 127.5 - 1.
    # Append a single channel axis: (N, H, W) -> (N, H, W, 1).
    real_targets = np.expand_dims(real_targets, axis=-1)
    noise = np.random.normal(0, 1, (real_targets.shape[0], latent_dim))
    return noise, real_targets


def iwgan():
    """Build and compile the MNIST GAN with AutoGAN's IWGAN training scheme.

    The discriminator uses a linear output, so it acts as an unbounded
    critic. The generator is compiled with ``wasserstein_loss``. As a side
    effect, ``images/iwgan`` is wiped and recreated so each run starts with an
    empty sample directory.

    Returns:
        The compiled AutoGAN ``GAN`` instance.
    """
    model = GAN(generator=build_generator(),
                discriminator=build_discriminator('linear'))
    # Keras Adam positional args: lr=2e-4, beta_1=0.5.
    # NOTE(review): this single Adam instance is handed to both the generator
    # and the discriminator kwargs below. Whether AutoGAN is fine with the two
    # sharing optimizer state (e.g. the iteration counter) is not visible
    # here — confirm.
    optimizer = Adam(0.0002, 0.5)

    out_dir = os.path.join('images', 'iwgan')
    # Best-effort wipe of the previous run's samples. ignore_errors replaces
    # the old bare ``except: pass``, which also swallowed KeyboardInterrupt.
    shutil.rmtree(out_dir, ignore_errors=True)
    # Create the directory explicitly. A real failure (e.g. permissions) now
    # surfaces here instead of at the first epoch's savefig.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # NOTE(review): no 'loss' is given for the discriminator; presumably
    # IWGAN_TrainingScheme supplies the critic loss itself — confirm.
    discriminator_kwargs = {'optimizer': optimizer}
    generator_kwargs = {'loss': wasserstein_loss, 'optimizer': optimizer}
    model.compile(training_scheme=IWGAN_TrainingScheme(batch_size=32),
                  generator_kwargs=generator_kwargs,
                  discriminator_kwargs=discriminator_kwargs)
    return model

# x: N(0, 1) noise vectors; y: real MNIST images rescaled to [-1, 1].
x, y = load_data()

In [None]:
# Build the compiled IWGAN and train it. The callback writes a sample grid to
# images/iwgan/<epoch>.png at epoch end. With steps_per_epoch=2000 and
# batch_size=32, each epoch presumably sees ~64k samples (about one pass over
# MNIST) — confirm against AutoGAN's fit.
model = iwgan()
sample_saver = save_images(model=model, name='iwgan')
model.fit(
    x=x,
    y=y,
    epochs=30,
    steps_per_epoch=2000,
    batch_size=32,
    generator_callbacks=[sample_saver],
    verbose=1,
)