In [10]:
from keras.layers import Input, Conv2D, BatchNormalization,\
Activation, Dropout, Flatten, Dense, Reshape, UpSampling2D
from keras.optimizers import RMSprop
from keras.models import Model
from keras.initializers import RandomNormal
import numpy as np

In [5]:
# hyper params
# Per-layer settings are parallel lists (index i configures conv layer i); the layer
# counts are derived from those lists so they cannot drift apart.
input_dims = (28, 28, 1)
disc_conv_fils = [64, 64, 128, 128]
disc_conv_kernel_size = [5, 5, 5, 5]
disc_conv_strides = [2, 2, 2, 1]
num_disc_layers = len(disc_conv_fils)
disc_batch_norm_momentum = None  # None -> no BatchNormalization in the discriminator
disc_dropout_rate = 0.4          # None/0 -> no Dropout

z_dims = (100,)                  # shape of the generator's input vector
shape_after_dense = (7, 7, 64)   # shape the generator's Dense output is reshaped to
upsamp_layers = [True, True, False, False]  # UpSampling2D before generator conv layer i?
gen_batch_norm_momentum = 0.9
gen_dropout_rate = None          # None/0 -> no Dropout
gen_conv_fils = [128, 64, 64, 1]
gen_conv_kernel_size = [5, 5, 5, 5]
num_gen_layers = len(gen_conv_fils)

# Fail fast on inconsistent settings instead of with an IndexError (or a silently
# ignored extra list entry) while the models are being built.
assert len(disc_conv_kernel_size) == len(disc_conv_strides) == num_disc_layers, \
    "discriminator per-layer lists must all have the same length"
assert len(gen_conv_kernel_size) == len(upsamp_layers) == num_gen_layers, \
    "generator per-layer lists must all have the same length"
# Each default UpSampling2D doubles height and width, and generated images are fed to
# the discriminator, so the generator's output shape must equal input_dims.
assert (tuple(d * 2 ** sum(upsamp_layers) for d in shape_after_dense[:2])
        + (gen_conv_fils[-1],)) == input_dims, \
    "generator output shape would not match input_dims"

## Discriminator

In [3]:
# Discriminator: num_disc_layers conv blocks
# (Conv2D -> optional BatchNormalization -> ReLU -> optional Dropout),
# then Flatten and a single sigmoid unit as the output score.
disc_input = Input(shape=input_dims, name="discriminator_input")
x = disc_input

for i in range(num_disc_layers):
    x = Conv2D(filters=disc_conv_fils[i],
               kernel_size=disc_conv_kernel_size[i],
               strides=disc_conv_strides[i],
               padding="same",
               name="disc_conv_" + str(i))(x)

    # Compare against None rather than relying on truthiness: a momentum of 0.0 is a
    # valid setting and must not silently disable batch norm. Skipped for layer 0.
    if disc_batch_norm_momentum is not None and i > 0:
        x = BatchNormalization(momentum=disc_batch_norm_momentum)(x)

    x = Activation("relu")(x)

    # None or 0 adds no Dropout layer (a zero rate would be a no-op anyway).
    if disc_dropout_rate:
        x = Dropout(disc_dropout_rate)(x)

x = Flatten()(x)
disc_output = Dense(1, activation="sigmoid",
                    kernel_initializer=RandomNormal(mean=0., stddev=0.02))(x)
disc_model = Model(disc_input, disc_output)

In [4]:
# Print each discriminator layer with its output shape and parameter count.
disc_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa (None, 28, 28, 1)         0         
_________________________________________________________________
disc_conv_0 (Conv2D)         (None, 14, 14, 64)        1664      
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
disc_conv_1 (Conv2D)         (None, 7, 7, 64)          102464    
_________________________________________________________________
activation_2 (Activation)    (None, 7, 7, 64)          0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 7, 7, 64)          0         
__________

## Generator

In [8]:
# Generator: Dense projection of the input vector (optional BatchNorm, ReLU), reshape
# to shape_after_dense, optional Dropout, then num_gen_layers conv layers, each
# optionally preceded by UpSampling2D. Hidden conv layers use optional BatchNorm + ReLU;
# the last one uses tanh, so outputs lie in [-1, 1].
gen_input = Input(shape=z_dims, name="gen_input")
x = gen_input
x = Dense(np.prod(shape_after_dense))(x)

# Compare against None: a momentum of 0.0 must not be mistaken for "no batch norm".
if gen_batch_norm_momentum is not None:
    x = BatchNormalization(momentum=gen_batch_norm_momentum)(x)

x = Activation("relu")(x)
x = Reshape(shape_after_dense)(x)

# None or 0 adds no Dropout layer.
if gen_dropout_rate:
    x = Dropout(rate=gen_dropout_rate)(x)

for i in range(num_gen_layers):
    if upsamp_layers[i]:
        x = UpSampling2D()(x)

    x = Conv2D(gen_conv_fils[i],
               gen_conv_kernel_size[i],
               padding="same",
               name="gen_conv_" + str(i))(x)

    if i < num_gen_layers - 1:
        if gen_batch_norm_momentum is not None:
            x = BatchNormalization(momentum=gen_batch_norm_momentum)(x)
        x = Activation("relu")(x)
    else:
        x = Activation("tanh")(x)

gen_output = x
gen_model = Model(gen_input, gen_output)

# The generator's output is fed straight into disc_model, so its per-sample shape
# must equal the discriminator's input_dims.
assert tuple(gen_model.output_shape[1:]) == tuple(input_dims), (
    "generator output %s does not match discriminator input %s"
    % (gen_model.output_shape[1:], input_dims))

In [9]:
# Print each generator layer with its output shape and parameter count.
gen_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
gen_input (InputLayer)       (None, 100)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 3136)              316736    
_________________________________________________________________
batch_normalization_5 (Batch (None, 3136)              12544     
_________________________________________________________________
activation_10 (Activation)   (None, 3136)              0         
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
up_sampling2d_5 (UpSampling2 (None, 14, 14, 64)        0         
_________________________________________________________________
gen_conv_0 (Conv2D)          (None, 14, 14, 128)       204928    
__________

## Train the GAN

### Compile the discriminator

In [11]:
# Compile the discriminator on its own while all of its layers are still trainable.
# The learning rate is passed positionally (as for the combined model below): Keras < 2.3
# only accepts `lr=`, while newer releases deprecate and then reject `lr` in favour of
# `learning_rate`.
disc_model.compile(optimizer=RMSprop(0.0008),
                   loss="binary_crossentropy",
                   metrics=["accuracy"])

In [17]:
# Freeze the discriminator before wiring it into the combined model.
# NOTE(review): per the Keras FAQ, `trainable` is captured when a model is compiled.
# disc_model (compiled above) should therefore keep updating its weights, and the
# combined model compiled next should only train the generator. Confirm this for the
# installed Keras version.
disc_model.trainable = False

# input vector -> generator -> discriminator score
model_input = Input(shape=z_dims, name="model_input")
fake_images = gen_model(model_input)
model_output = disc_model(fake_images)
model = Model(model_input, model_output)

### Compile the combined model used to train the generator

In [20]:
# Compile the combined generator -> discriminator model. disc_model.trainable was set to
# False before this call, so training `model` is intended to update only the generator.
gen_optimizer = RMSprop(0.0004)  # learning rate passed positionally
model.compile(loss="binary_crossentropy",
              metrics=["accuracy"],
              optimizer=gen_optimizer)