In [None]:
import numpy as np

import keras
from keras.models import Model
import keras.layers as L
import keras.backend as K

# Spatial size of the real images fed to the discriminator. NOTE: the
# generator's 512x512 output is fixed by its layer schedule, not derived
# from these values.
img_width = img_height = 512

# Architecture adapted from
# https://github.com/ctmakro/hellotensor/blob/master/lets_gan_clean.py

def build_gen_network(seed_shape):
    """Build the generator: latent vector -> single-channel image.

    Args:
        seed_shape: length of the 1-D latent (noise) vector taken as input
            (an int, despite the name).

    Returns:
        An uncompiled Keras ``Model`` mapping ``(batch, seed_shape)`` to
        ``(batch, 512, 512, 1)``. The spatial size comes from the fixed
        up-sampling schedule below (1 -> 4 -> 16 -> 64 -> 256 -> 512 -> 512)
        and does not follow ``img_width``/``img_height``.
    """
    # Named `noise` rather than `input` to avoid shadowing the builtin.
    noise = L.Input(shape=(seed_shape,))
    # View the latent vector as a 1x1 "image" with seed_shape channels so it
    # can be up-sampled by transposed convolutions.
    x = L.Reshape((1, 1, seed_shape))(noise)

    def deconv(layer, num_filters, kernel_size, strides=4, normalize=True, padding='same'):
        """Conv2DTranspose, optionally followed by BatchNorm + LeakyReLU(0.2).

        With ``normalize=False`` the result is left linear (no activation).
        """
        layer = L.Conv2DTranspose(
            num_filters, kernel_size,
            padding=padding,
            strides=strides)(layer)
        if normalize:
            layer = L.BatchNormalization()(layer)
            layer = L.LeakyReLU(0.2)(layer)
        return layer

    x = deconv(x, 256, 4, padding='valid', strides=1)  # 1x1     -> 4x4
    x = deconv(x, 128, 4, padding='same')              # 4x4     -> 16x16
    x = deconv(x, 64, 4, padding='same')               # 16x16   -> 64x64
    x = deconv(x, 32, 4, padding='same')               # 64x64   -> 256x256
    x = deconv(x, 16, 4, padding='same', strides=2)    # 256x256 -> 512x512
    # Output layer: one channel, no BatchNorm and no activation, so pixel
    # values are unbounded. NOTE(review): confirm this matches the value
    # range of the real training images.
    x = deconv(x, 1, 4, normalize=False, strides=1)    # 512x512 -> 512x512

    # `inputs=`/`outputs=` are the Keras 2 keyword names; the legacy
    # `input=`/`output=` spelling is deprecated and rejected by newer Keras.
    return Model(inputs=[noise], outputs=[x])

NOISE_DIM = 100  # length of the latent vector fed to the generator

gen_network = build_gen_network(seed_shape=NOISE_DIM)
gen_network.compile(optimizer='adam', loss='binary_crossentropy')
gen_network.summary()

In [None]:
def build_discriminator_network():
    """Build the discriminator: single-channel image -> probability it is real.

    Returns:
        An uncompiled Keras ``Model`` taking ``(batch, img_height, img_width, 1)``
        images and returning ``(batch, 1)`` sigmoid scores.
    """
    # Keras image tensors are (rows, cols, channels) = (height, width,
    # channels), assuming the default channels_last data format. Both sizes
    # are 512 here, but keep the order right in case they ever differ.
    image = L.Input(shape=(img_height, img_width, 1))

    def conv(layer, num_filters, kernel_size, strides, normalize=True, padding='same'):
        """Conv2D, optionally followed by BatchNorm + LeakyReLU(0.2)."""
        layer = L.Conv2D(
            num_filters, kernel_size,
            padding=padding,
            strides=strides)(layer)
        if normalize:
            layer = L.BatchNormalization()(layer)
            layer = L.LeakyReLU(0.2)(layer)
        return layer

    # Six stride-2 'same' convolutions: each halves the spatial size
    # (512 -> 8 for a 512x512 input) while widening the channels.
    x = conv(image, 32, kernel_size=4, strides=2)
    x = conv(x, 64, kernel_size=4, strides=2)
    x = conv(x, 128, kernel_size=4, strides=2)
    x = conv(x, 256, kernel_size=4, strides=2)
    x = conv(x, 512, kernel_size=4, strides=2)
    x = conv(x, 1024, kernel_size=4, strides=2)
    x = L.Flatten()(x)
    x = L.Dense(units=1, activation='sigmoid')(x)

    # `inputs=`/`outputs=` are the Keras 2 keyword names; the legacy
    # `input=`/`output=` spelling is deprecated and rejected by newer Keras.
    return Model(inputs=[image], outputs=[x])

# Built and compiled while all of its layers are still trainable; this is the
# model whose train_on_batch performs the discriminator updates.
discriminator_network = build_discriminator_network()
discriminator_network.compile(optimizer='adam', loss='binary_crossentropy')
discriminator_network.summary()

In [None]:
# Combined model: noise -> generator -> discriminator. Training it with
# "real" labels pushes the generator to fool the discriminator.
#
# Freeze the discriminator *before* compiling: Keras fixes a model's
# trainable weights at compile time, so compiling GAN with a trainable
# discriminator would let generator steps update discriminator weights too.
# discriminator_network itself was already compiled while trainable, so its
# own train_on_batch still updates its weights.
discriminator_network.trainable = False

gan_input = L.Input(gen_network.input_shape[1:])
GAN = Model(gan_input,
            discriminator_network(gen_network(gan_input)))
GAN.compile(loss='binary_crossentropy', optimizer='adam')
GAN.summary()

In [None]:
BATCH_SIZE = 8  # real images per discriminator step (the same number of fakes is added)


def make_trainable(net, val):
    """Set ``trainable = val`` on ``net`` itself and on each of ``net.layers``."""
    for target in [net] + list(net.layers):
        target.trainable = val

N_STEPS = 10  # alternating discriminator / generator updates to run

# Latent size read from the generator itself so it cannot drift from the
# value passed to build_gen_network.
noise_dim = gen_network.input_shape[1]

# NOTE(review): all_images is not defined in the visible notebook; presumably
# a NumPy array of real images loaded elsewhere — fail early with a clear
# message if its per-image shape does not match the generator output.
assert all_images.shape[1:] == tuple(gen_network.output_shape[1:]), (
    all_images.shape, gen_network.output_shape)

d_loss = []
g_loss = []

for i in range(N_STEPS):
    print('\r', i, flush=True, end='')

    # Discriminator step: BATCH_SIZE real images labelled 1 followed by
    # BATCH_SIZE generated images labelled 0.
    noise_gen = np.random.uniform(0, 1, size=(BATCH_SIZE, noise_dim))
    image_batch = all_images[np.random.randint(0, all_images.shape[0], size=BATCH_SIZE)]
    generated = gen_network.predict(noise_gen)
    X = np.concatenate((image_batch, generated))
    Y = np.concatenate((np.ones(BATCH_SIZE), np.zeros(BATCH_SIZE)))
    # NOTE(review): per the Keras docs, a change to `trainable` only affects
    # models compiled after the change; these toggles are kept so the intent
    # of each step is explicit.
    make_trainable(discriminator_network, True)
    d_loss.append(discriminator_network.train_on_batch(X, Y))

    # Generator step: label fresh fakes as "real" so the combined GAN model
    # pushes the generator towards fooling the discriminator.
    noise_tr = np.random.uniform(0, 1, size=(BATCH_SIZE, noise_dim))
    y2 = np.ones(BATCH_SIZE)
    make_trainable(discriminator_network, False)
    g_loss.append(GAN.train_on_batch(noise_tr, y2))