# In [None]:
from __future__ import print_function, division

import sys
import os
import shutil
# Make the in-repo AutoGAN package importable; must run before the AutoGAN imports below.
sys.path.insert(0,'../..')

from AutoGAN import GAN
from AutoGAN.schemes.Base_TrainingScheme import GAN_TrainingScheme

import keras

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Configure the TensorFlow session Keras will use:
# - expose only the GPU listed as "1" to this process;
# - allocate GPU memory on demand (allow_growth);
# - never take more than 95% of that device's memory.
config = tf.ConfigProto()
gpu_options = config.gpu_options
gpu_options.visible_device_list = "1"
gpu_options.allow_growth = True
gpu_options.per_process_gpu_memory_fraction = 0.95
set_session(tf.Session(config=config))

def build_generator(self=None, img_shape=(28, 28, 1), num_classes=10, latent_dim=100):
    """Build the conditional generator ``[noise, label] -> [image, label]``.

    The integer label is embedded into a ``latent_dim`` vector and multiplied
    element-wise with the noise vector. The product is fed through a
    Dense/LeakyReLU/BatchNorm MLP whose ``tanh`` output is reshaped to
    ``img_shape``. The output range [-1, 1] matches the scaling done in
    ``load_data``. The label is also returned, reshaped to ``(1,)``, as a
    second output. Presumably this is so the generator's outputs line up with
    the discriminator's ``[image, label]`` inputs; confirm against the AutoGAN
    ``GAN`` wiring.

    Args:
        self: Unused. It is kept (now optional) only so existing positional
            calls such as ``build_generator('sigmoid')`` keep working. The
            value is ignored, and the output activation is always ``tanh``.
        img_shape: Shape of one generated image, ``(rows, cols, channels)``.
        num_classes: Number of distinct integer labels the embedding accepts.
        latent_dim: Length of the noise vector.

    Returns:
        A ``keras.models.Model`` with:
        - inputs ``[noise (latent_dim,), label (1,) int32]``;
        - outputs ``[image img_shape, label (1,)]``.
    """
    img_shape = tuple(img_shape)

    # Inner MLP: noise-sized vector -> flattened image -> image.
    model = Sequential()
    for i, units in enumerate((256, 512, 1024)):
        if i == 0:
            model.add(Dense(units, input_dim=latent_dim))
        else:
            model.add(Dense(units))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(img_shape))

    model.summary()

    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    label_out = Reshape((1,))(label)
    # Condition on the class by gating the noise with a learned per-class vector.
    label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))

    model_input = multiply([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], [img, label_out])

def build_discriminator(self=None, layer_activation='sigmoid', img_shape=(28, 28, 1), num_classes=10):
    """Build the conditional discriminator ``[image, label] -> validity``.

    The image is flattened and multiplied element-wise with a learned
    per-class embedding of the same length. The product goes through a
    3x512 LeakyReLU MLP (with dropout on the last two layers) to a single
    unit whose activation is ``layer_activation``.

    Args:
        self: Unused. It is kept (now optional) for backward compatibility.
            NOTE: because it is the first positional parameter, a call like
            ``build_discriminator('linear')`` binds ``'linear'`` here and
            leaves ``layer_activation`` at ``'sigmoid'``. Pass the activation
            by keyword.
        layer_activation: Keras activation for the single output unit, e.g.
            ``'sigmoid'`` for a probability or ``'linear'`` for an
            unbounded score.
        img_shape: Shape of one input image, ``(rows, cols, channels)``.
        num_classes: Number of distinct integer labels the embedding accepts.

    Returns:
        A ``keras.models.Model`` with:
        - inputs ``[image img_shape, label (1,) int32]``;
        - output shape ``(1,)``.
    """
    img_shape = tuple(img_shape)
    flat_dim = int(np.prod(img_shape))

    model = Sequential()
    model.add(Dense(512, input_dim=flat_dim))
    model.add(LeakyReLU(alpha=0.2))
    for _ in range(2):
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
    model.add(Dense(1, activation=layer_activation))
    model.summary()

    img = Input(shape=img_shape)
    label = Input(shape=(1,), dtype='int32')

    # One learned flat_dim-long mask per class, applied to the flattened image.
    label_embedding = Flatten()(Embedding(num_classes, flat_dim)(label))
    flat_img = Flatten()(img)

    model_input = multiply([flat_img, label_embedding])

    validity = model(model_input)

    return Model([img, label], validity)

class save_images(keras.callbacks.Callback):
    """Keras callback that saves a 2x5 grid of generated digits after each epoch.

    One sample is drawn for each label 0..9. The PNG is written to
    ``images/<name>/<epoch>.png``. That directory must already exist;
    ``gan()`` creates it for ``name='condgan'``.
    """

    def __init__(self, model, name='gan', latent_dim=100):
        """
        Args:
            model: The AutoGAN ``GAN`` being trained. Its ``generator_model()``
                is used for sampling.
            name: Sub-directory of ``images/`` to write into.
            latent_dim: Noise length expected by the generator. The default
                of 100 matches the original hard-coded value.
        """
        super(save_images, self).__init__()
        self.full_model = model
        self.name = name
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        r, c = 2, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)

        # The generator returns [images, labels]; keep only the images.
        gen_imgs = self.full_model.generator_model().predict([noise, sampled_labels])[0]

        # Rescale images from [-1, 1] (tanh) to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        try:
            for cnt, ax in enumerate(axs.flat):
                ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                # Index down to a scalar: %-formatting a size-1 ndarray relies on
                # array-to-scalar conversion, which NumPy deprecated in 1.25.
                ax.set_title("Digit: %d" % sampled_labels[cnt, 0])
                ax.axis('off')
            fig.savefig(os.path.join("images", self.name, "%d.png" % epoch))
        finally:
            # Close this specific figure, even if saving failed, so figures do
            # not accumulate across epochs.
            plt.close(fig)

def load_data(latent_dim=100):
    """Load the MNIST training split plus one latent noise vector per image.

    Args:
        latent_dim: Length of each noise vector. The default of 100 matches
            the generator's default.

    Returns:
        A ``(noise, images, labels)`` tuple:
        - ``noise``: float32 array ``(N, latent_dim)`` sampled from N(0, 1);
        - ``images``: float32 training images with a trailing channel axis
          added, rescaled from [0, 255] to [-1, 1];
        - ``labels``: the training labels exactly as returned by
          ``mnist.load_data()``.
    """
    (real_targets, real_label), _ = mnist.load_data()
    # Rescale -1 to 1. Use float32 (what Keras feeds the model anyway) instead
    # of NumPy's default float64, which halves the memory footprint.
    real_targets = real_targets.astype(np.float32) / np.float32(127.5) - np.float32(1.0)
    real_targets = np.expand_dims(real_targets, axis=3)
    noise = np.random.normal(0, 1, (real_targets.shape[0], latent_dim)).astype(np.float32)
    return noise, real_targets, real_label


def iwgan():
    """Build and compile the ``iwgan`` variant of the conditional GAN.

    The name presumably means "improved WGAN"; confirm against the AutoGAN
    scheme. The discriminator gets a *linear* output unit (an unbounded critic
    score), and ``./images/iwgan`` is recreated empty.

    Returns:
        The compiled AutoGAN ``GAN`` model.
    """
    # Pass the activation by keyword: the first positional parameter of
    # build_discriminator is an unused ``self``. The previous call
    # build_discriminator('linear') therefore still produced a sigmoid output.
    model = GAN(generator=build_generator(None),
                discriminator=build_discriminator(None, layer_activation='linear'))
    optimizer = Adam(0.0002, 0.5)

    out_dir = os.path.join('.', 'images', 'iwgan')
    # Start from an empty directory. A missing directory is fine, but do not
    # hide real failures (or Ctrl-C) the way a bare ``except: pass`` did.
    shutil.rmtree(out_dir, ignore_errors=True)
    try:
        os.makedirs(out_dir)
    except OSError:
        if not os.path.isdir(out_dir):
            raise

    discriminator_kwargs = {'optimizer': optimizer}
    # NOTE(review): neither ``wasserstein_loss`` nor ``IWGAN_TrainingScheme`` is
    # defined or imported in this notebook (only GAN_TrainingScheme is
    # imported), so calling iwgan() raises NameError until they are brought
    # into scope.
    generator_kwargs = {'loss': wasserstein_loss, 'optimizer': optimizer}
    model.compile(training_scheme=IWGAN_TrainingScheme(batch_size=32),
                  generator_kwargs=generator_kwargs, discriminator_kwargs=discriminator_kwargs)
    return model


def gan():
    """Build and compile the conditional GAN with binary cross-entropy losses.

    The discriminator ends in a sigmoid unit. ``./images/condgan`` (where the
    ``save_images(name='condgan')`` callback writes) is recreated empty.

    Returns:
        The compiled AutoGAN ``GAN`` model, using ``GAN_TrainingScheme``.
    """
    # Pass the activation by keyword: build_discriminator's first positional
    # parameter is an unused ``self``. The old call build_discriminator('sigmoid')
    # only worked because 'sigmoid' is also the default.
    model = GAN(generator=build_generator(None),
                discriminator=build_discriminator(None, layer_activation='sigmoid'))
    optimizer = Adam(0.0002, 0.5)

    out_dir = os.path.join('.', 'images', 'condgan')
    # Start from an empty directory. A missing directory is fine, but do not
    # hide real failures (or Ctrl-C) the way a bare ``except: pass`` did.
    shutil.rmtree(out_dir, ignore_errors=True)
    try:
        os.makedirs(out_dir)
    except OSError:
        if not os.path.isdir(out_dir):
            raise

    discriminator_kwargs = {'loss': 'binary_crossentropy', 'optimizer': optimizer, 'metrics': ['accuracy']}
    generator_kwargs = {'loss': 'binary_crossentropy', 'optimizer': optimizer}
    model.compile(training_scheme=GAN_TrainingScheme(),
                  generator_kwargs=generator_kwargs, discriminator_kwargs=discriminator_kwargs)
    return model

# x: latent noise, y: MNIST images scaled to [-1, 1], labels: digit labels.
x, y, labels = load_data()

# In [None]:
# Build the conditional GAN and train it. A sample grid is written to
# images/condgan/ after each epoch.
model = gan()
image_saver = save_images(model=model, name='condgan')
model.fit(x=[x, labels],
          y=[y, labels],
          epochs=25,
          steps_per_epoch=800,
          batch_size=32,
          generator_callbacks=[image_saver],
          verbose=1)