# In [1]:
import numpy as np
import json
import os
import pickle as pkl
import matplotlib.pyplot as plt

from keras.models import Model, Sequential
from keras import backend as K
from keras.optimizers import Adam, RMSprop
from keras.utils import plot_model
from keras.initializers import RandomNormal

from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from keras.layers.merge import _Merge

# Kernel initializer shared by the Conv2D / Dense layers of both networks:
# weights are drawn from a normal distribution with mean 0 and stddev 0.02.
weight_init = RandomNormal(mean=0.0, stddev=0.02)

# Training history; train() appends one entry to each list per step.
d_losses, g_losses = [], []

# Training-step counter referenced by train().
epoch = 0

def discriminator():
    """Build the discriminator and publish it as the global ``discriminator``.

    The network maps a ``(28, 28, 1)`` image to a single sigmoid output:
    two 5x5 convolutions (64 filters with stride 2, then 128 filters with
    stride 1), ReLU, dropout of 0.4, flatten, and a one-unit dense layer.

    The rest of this module (``adversarial`` and ``train_discriminator``)
    uses the module-level name ``discriminator`` as a Keras model
    (``.compile``, ``.layers``, ``.train_on_batch``).  The original code only
    assigned the model to a local variable, so that name kept pointing at
    this builder function.  The model is therefore bound to the global name
    here, which replaces this function: call it exactly once.

    Returns:
        The uncompiled Keras ``Model``.
    """
    global discriminator

    discriminator_input = Input(shape=(28, 28, 1), name='discriminator_input')

    x = Conv2D(
        filters=64,
        kernel_size=5,
        strides=2,
        padding='same',
        name='discriminator_conv_1',
        kernel_initializer=weight_init,
    )(discriminator_input)
    # NOTE(review): there is no activation between conv_1 and conv_2, so the
    # two convolutions are stacked linearly -- confirm this is intended.
    x = Conv2D(
        filters=128,
        kernel_size=5,
        strides=1,
        padding='same',
        name='discriminator_conv_2',
        kernel_initializer=weight_init,
    )(x)
    x = Activation('relu')(x)
    x = Dropout(rate=0.4)(x)

    x = Flatten()(x)

    discriminator_output = Dense(
        1, activation='sigmoid', kernel_initializer=weight_init
    )(x)

    net = Model(discriminator_input, discriminator_output)
    # Rebind the module-level name so sibling functions see the model.
    discriminator = net
    return net


def generator():
    """Build the generator and publish it as the global ``generator``.

    The network maps a 100-dimensional latent vector to a ``(28, 28, 1)``
    image in ``[-1, 1]`` (tanh output):

    * dense -> batch norm -> ReLU -> reshape to ``(7, 7, 64)``
    * ``UpSampling2D``                         7x7  -> 14x14
    * ``Conv2DTranspose`` (128, stride 2)      14x14 -> 28x28
    * ``Conv2D`` (1 filter, stride 1) + tanh   28x28x1

    Fixes relative to the original block:

    * removed a stray ``)`` after ``Reshape`` that made the file unparsable;
    * the first convolution was a stride-2 ``Conv2D``, which undid the
      upsampling (14 -> 7) and produced 7x7 images that the 28x28
      discriminator cannot consume.  It is now a ``Conv2DTranspose`` with the
      same hyperparameters, which upsamples to 28x28 instead;
    * the model is bound to the module-level name ``generator`` (which
      ``adversarial`` and ``train_discriminator`` use as a Keras model) and is
      also returned.  This replaces the builder function: call it once.

    Returns:
        The Keras ``Model`` (it is only used for ``predict`` and as part of
        the combined model, so it is not compiled here).
    """
    global generator

    generator_input = Input(shape=(100,), name='generator_input')

    # int(...) keeps `units` a plain Python int rather than a NumPy scalar.
    x = Dense(int(np.prod((7, 7, 64))), kernel_initializer=weight_init)(generator_input)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)
    x = Reshape((7, 7, 64))(x)

    x = UpSampling2D()(x)  # 7x7 -> 14x14
    x = Conv2DTranspose(
        filters=128,
        kernel_size=5,
        padding='same',
        strides=2,  # 14x14 -> 28x28
        name='generator_conv_1',
        kernel_initializer=weight_init,
    )(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)

    x = Conv2D(
        filters=1,
        kernel_size=5,
        padding='same',
        name='generator_conv_2',
        kernel_initializer=weight_init,
    )(x)
    generator_output = Activation('tanh')(x)

    net = Model(generator_input, generator_output)
    # Rebind the module-level name so sibling functions see the model.
    generator = net
    return net

def _set_trainable(net, flag):
    """Set ``trainable`` on a model and on each of its layers."""
    net.trainable = flag
    for layer in net.layers:
        layer.trainable = flag


def adversarial():
    """Compile the discriminator and build the combined generator->discriminator model.

    Expects the module-level names ``discriminator`` and ``generator`` to
    refer to Keras models.  The combined model is stored in the module-level
    name ``model`` because ``train_generator`` reads it from there; the
    original code kept it in a local variable, so ``train_generator`` raised
    ``NameError``.

    Order matters:

    1. the discriminator is compiled while trainable, for its own updates;
    2. it is frozen and the combined model is compiled, so that generator
       steps are meant to leave the discriminator weights untouched;
    3. the flags are switched back on afterwards.

    NOTE(review): this relies on Keras capturing the trainable state at
    ``compile`` time -- confirm for the Keras version in use.

    Returns:
        The compiled combined ``Model``.
    """
    global model

    discriminator.compile(
        optimizer=RMSprop(0.0008),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    _set_trainable(discriminator, False)

    model_input = Input(shape=(100,), name='model_input')
    model_output = discriminator(generator(model_input))
    model = Model(model_input, model_output)

    model.compile(
        optimizer=RMSprop(0.0004),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    _set_trainable(discriminator, True)

    return model

def train_discriminator(x_train, batch_size, using_generator):
    """Run one discriminator update on a real batch and a generated batch.

    Args:
        x_train: Either an indexable array of training images (rows are
            sampled at random, with replacement) or, when ``using_generator``
            is truthy, an iterator whose items carry the image batch at
            index 0.
        batch_size: Number of real and of generated images per update.
        using_generator: Selects how ``x_train`` is read (see above).

    Returns:
        ``[d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake]``
        where the un-suffixed values are the means of the real/fake pair.
    """
    if not using_generator:
        picks = np.random.randint(0, x_train.shape[0], batch_size)
        real_batch = x_train[picks]
    else:
        real_batch = next(x_train)[0]
        # A short batch is discarded and the following batch is used instead.
        if real_batch.shape[0] != batch_size:
            real_batch = next(x_train)[0]

    latent = np.random.normal(0, 1, (batch_size, 100))
    fake_batch = generator.predict(latent)

    # Targets: 1 for real images, 0 for generated ones.
    real_targets = np.ones((batch_size, 1))
    fake_targets = np.zeros((batch_size, 1))

    loss_real, acc_real = discriminator.train_on_batch(real_batch, real_targets)
    loss_fake, acc_fake = discriminator.train_on_batch(fake_batch, fake_targets)

    mean_loss = 0.5 * (loss_real + loss_fake)
    mean_acc = 0.5 * (acc_real + acc_fake)

    return [mean_loss, loss_real, loss_fake, mean_acc, acc_real, acc_fake]

def train_generator(batch_size):
    """Run one update of the combined model on fresh noise.

    The noise batch is labelled 1 ("valid"), i.e. the combined model is
    trained towards the discriminator classifying generated images as real.

    Args:
        batch_size: Number of latent vectors (each of length 100) to sample.

    Returns:
        Whatever ``model.train_on_batch`` returns for the combined model.
    """
    latent = np.random.normal(0, 1, (batch_size, 100))
    targets = np.ones((batch_size, 1))
    return model.train_on_batch(latent, targets)


def train(x_train, batch_size, epochs, using_generator = False):
    """Alternate discriminator and generator updates for ``epochs`` steps.

    Step numbering continues from the module-level ``epoch`` counter, which
    is advanced as steps complete, so repeated calls keep counting up.  The
    original code used ``epoch`` as the loop variable without declaring it
    global, which made it a local name and raised ``UnboundLocalError`` on
    the first line of the loop.

    Args:
        x_train: Training data, passed through to ``train_discriminator``.
        batch_size: Batch size for both updates.
        epochs: Number of steps to run (each step is one batch, not a full
            pass over the data).
        using_generator: Passed through to ``train_discriminator``.

    Side effects:
        Prints one progress line per step and appends the per-step results
        to the module-level ``d_losses`` and ``g_losses`` lists.
    """
    global epoch

    first_step = epoch
    for step in range(first_step, first_step + epochs):
        d = train_discriminator(x_train, batch_size, using_generator)
        g = train_generator(batch_size)

        print("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (step, d[0], d[1], d[2], d[3], d[4], d[5], g[0], g[1]))

        d_losses.append(d)
        g_losses.append(g)

        # Persist progress after every completed step.
        epoch = step + 1

# Driver: invoke the builders in dependency order (both networks first, then
# the combined model that wires them together), then start training.
discriminator()
generator()
adversarial()
# TODO: load the training data here, e.g. (x_train, y_train) = <load data>.
# `x_train` is not defined anywhere in the code shown, so the call below
# raises NameError until it is.
# NOTE(review): presumably x_train should be an array of shape (N, 28, 28, 1)
# scaled to [-1, 1] to match the discriminator input and the generator's tanh
# output -- confirm against the chosen dataset.
train(x_train, batch_size = 64, epochs = 10)