# In [None]:
from __future__ import print_function, division

import sys
import os
import shutil
# Make the AutoGAN package importable from two directory levels above the
# working directory (assumes the notebook is run from its own folder).
sys.path.insert(0, '../..')

from AutoGAN import GAN
from AutoGAN.schemes.ACGAN_TrainingScheme import ACGAN_TrainingScheme

import keras

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

# TF1-style session for Keras: expose only GPU "1" to this process, allocate
# memory on demand, and cap usage at 95% of the device.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95
config.gpu_options.visible_device_list = "1"
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

def build_generator(latent_dim=100, num_classes=10, channels=1):
    """Build the class-conditional ACGAN generator.

    The integer class label is embedded into a ``latent_dim``-wide vector and
    multiplied element-wise with the noise vector. The product is decoded by
    a Dense + conv stack (7x7 -> 14x14 -> 28x28) into a ``tanh`` image.

    Args:
        latent_dim: Length of the noise vector; also the label-embedding width,
            which must match for the element-wise ``multiply``.
        num_classes: Number of distinct integer labels the embedding accepts.
        channels: Number of channels in the generated image.

    Returns:
        A ``keras.Model`` mapping ``[noise, label]`` to an image of shape
        ``(28, 28, channels)``. ``noise`` has shape ``(latent_dim,)`` and
        ``label`` has shape ``(1,)`` with dtype int32.
    """
    model = Sequential()

    # Project the conditioned noise to a 7x7x128 feature map.
    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization(momentum=0.8))
    # 7x7 -> 14x14
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    # 14x14 -> 28x28
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    # tanh keeps outputs in [-1, 1], the same range the real images are
    # rescaled to.
    model.add(Conv2D(channels, kernel_size=3, padding='same'))
    model.add(Activation("tanh"))

    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    # Embedding -> (batch, 1, latent_dim); Flatten -> (batch, latent_dim) so it
    # lines up with the noise vector.
    label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))

    model_input = multiply([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

def build_discriminator(img_shape=(28, 28, 1), num_classes=10):
    """Build the ACGAN discriminator with a validity head and a class head.

    Args:
        img_shape: Shape of one input image, ``(rows, cols, channels)``.
        num_classes: Number of real classes; the class head outputs
            ``num_classes + 1`` probabilities.

    Returns:
        A ``keras.Model`` mapping an image to ``[validity, label]``.
        ``validity`` is a single sigmoid unit; ``label`` is a softmax over
        ``num_classes + 1`` classes.
    """
    model = Sequential()

    model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
    # For a 28x28 input this pads the 7x7 map to 8x8 before the next
    # stride-2 conv.
    model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Flatten())

    img = Input(shape=img_shape)

    # Shared feature extractor feeding both heads.
    features = model(img)

    validity = Dense(1, activation="sigmoid")(features)
    # num_classes + 1 outputs: presumably an extra "fake" class used by the
    # training scheme. NOTE(review): confirm against ACGAN_TrainingScheme.
    label = Dense(num_classes + 1, activation="softmax")(features)

    return Model(img, [validity, label])
class save_images(keras.callbacks.Callback):
    """Keras callback that saves a grid of class-conditioned samples each epoch.

    At the end of every epoch it generates one sample for each label in
    ``0 .. rows*cols - 1`` and writes the grid to
    ``images/<name>/<epoch>.png``. The output directory must already exist.

    Args:
        model: Object exposing ``generator_model()``, which returns a Keras
            model accepting ``[noise, labels]`` (here, the AutoGAN ``GAN``).
        name: Sub-directory of ``images/`` the PNGs are written to.
        latent_dim: Length of each noise vector fed to the generator.
    """

    def __init__(self, model, name='gan', latent_dim=100):
        super(save_images, self).__init__()
        # Stored under a different name so it does not clash with the
        # Callback.model attribute Keras sets.
        self.full_model = model
        self.name = name
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        rows, cols = 2, 5
        noise = np.random.normal(0, 1, (rows * cols, self.latent_dim))
        # One sample per label 0..rows*cols-1 (0..9 for the 2x5 grid).
        sampled_labels = np.arange(rows * cols)
        gen_imgs = self.full_model.generator_model().predict([noise, sampled_labels])

        # Map generator output from [-1, 1] to [0, 1] for display
        # (assumes a tanh output layer).
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(rows, cols)
        try:
            # axs.flat walks the grid row-major, matching the label order.
            for ax, label, img in zip(axs.flat, sampled_labels, gen_imgs):
                ax.imshow(img[:, :, 0], cmap='gray')
                ax.set_title("Digit: %d" % label)
                ax.axis('off')
            fig.savefig("images/%s/%d.png" % (self.name, epoch))
        finally:
            # Always release this figure, even if plotting or saving fails,
            # so figures do not pile up across epochs.
            plt.close(fig)

def load_data():
    """Load the MNIST training split and prepare ACGAN inputs.

    The MNIST test split is discarded.

    Returns:
        A 4-tuple ``(noise, real_targets, real_label, label_onehot)``:
        - noise: ``(N, 100)`` array of standard-normal samples.
        - real_targets: training images rescaled to [-1, 1], with a trailing
          channel axis added.
        - real_label: the integer training labels as returned by Keras.
        - label_onehot: ``(N, 10)`` one-hot encoding of ``real_label``.
    """
    (real_targets, real_label), (_, _) = mnist.load_data()
    # Rescale -1 to 1 (assumes 8-bit pixels in [0, 255]).
    real_targets = real_targets / 127.5 - 1.
    # Append a channel axis so images are (N, rows, cols, 1).
    real_targets = np.expand_dims(real_targets, axis=3)
    noise = np.random.normal(0, 1, (real_targets.shape[0], 100))
    num_samples = real_label.shape[0]
    label_onehot = np.zeros((num_samples, 10))
    # Row i gets a 1 in column real_label[i]. Indexing by row AND column is
    # essential: indexing by the label alone would fill whole rows.
    label_onehot[np.arange(num_samples), real_label] = 1
    return noise, real_targets, real_label, label_onehot


def acgan():
    """Build and compile an AutoGAN ``GAN`` configured for ACGAN training.

    Also resets ``./images/acgan``, the folder the training cell's
    ``save_images(name='acgan')`` callback writes into, so each run starts
    with an empty image directory.

    Returns:
        The compiled ``GAN`` instance.

    Raises:
        OSError: If ``./images/acgan`` cannot be created.
    """
    model = GAN(generator=build_generator(), discriminator=build_discriminator())
    # Adam(lr=0.0002, beta_1=0.5); the same instance is passed to both sides.
    optimizer = Adam(0.0002, 0.5)

    image_dir = './images/acgan'
    # Drop images from a previous run; a missing directory is not an error.
    shutil.rmtree(image_dir, ignore_errors=True)
    try:
        os.makedirs(image_dir)
    except OSError:
        # Tolerate the directory already existing, but surface real failures
        # (e.g. permissions) now instead of at the first savefig.
        if not os.path.isdir(image_dir):
            raise

    # The (validity, class) loss pair is repeated twice. This is presumably one
    # pair per discriminator output set the scheme trains on (e.g. real and
    # fake); NOTE(review): confirm against ACGAN_TrainingScheme.
    discriminator_kwargs = {'loss': ['binary_crossentropy', 'sparse_categorical_crossentropy'] * 2,
                            'optimizer': optimizer, 'metrics': []}
    generator_kwargs = {'discriminator_loss': ['binary_crossentropy', 'sparse_categorical_crossentropy'],
                        'optimizer': optimizer}
    model.compile(training_scheme=ACGAN_TrainingScheme(),
                  generator_kwargs=generator_kwargs, discriminator_kwargs=discriminator_kwargs)
    return model

# Load MNIST once at notebook scope; training uses x, y and labels.
(x, y, labels, label_onehot) = load_data()

# In [None]:
model = acgan()
# Writes one sample grid per epoch to images/acgan/<epoch>.png.
sample_saver = save_images(model=model, name='acgan')
model.fit(x=[x, labels], y=[y, labels],
          epochs=25, steps_per_epoch=800, batch_size=32,
          generator_callbacks=[sample_saver], verbose=1)