In [39]:
import keras
from keras.layers import Conv2D, Input, Activation, Dropout, Flatten, Dense, BatchNormalization, Reshape, UpSampling2D
from keras.models import Model
from keras.initializers import RandomNormal
import keras.backend as K
import numpy as np

In [36]:
class Gan():
    """Holds the discriminator and the (work-in-progress) generator of a GAN.

    Parameters
    ----------
    input_size : tuple
        Shape of one discriminator input, passed straight to ``keras.layers.Input``
        (e.g. ``(28, 28, 1)``).

    Attributes
    ----------
    input_size : tuple
        The shape given at construction.
    z_dim : int
        Length of the latent vector the generator consumes.
    dis_lr : float
        Discriminator learning rate. Not consumed yet: this class has no compile step.
    discriminator : keras.Model
        Maps an input of shape ``input_size`` to a single sigmoid unit.
    """

    def __init__(self, input_size):
        self.input_size = input_size
        self.z_dim = 100
        self.dis_lr = 0.0008

        self.discriminator = self._build_discriminator()

        # TODO: the generator is unfinished. It currently stops at the (7, 7, 64)
        # reshape and is not yet wrapped in a Model or stored on `self`.
        self._build_generator()

    def _build_discriminator(self):
        """Build the conv discriminator: input of shape ``self.input_size`` -> one sigmoid unit."""
        dis_conv_filters = [64, 64, 128, 128]
        dis_conv_kernel = [5, 5, 5, 5]
        dis_conv_strides = [2, 2, 2, 1]
        dis_dropout = 0.4
        dis_momentum = 0.9

        dis_input = Input(shape=self.input_size)
        x = dis_input

        # Each stage: Conv2D -> BatchNorm -> ReLU -> Dropout.
        for i, (f, k, s) in enumerate(zip(dis_conv_filters, dis_conv_kernel, dis_conv_strides)):
            x = Conv2D(f, kernel_size=k, strides=s, padding="same", name=f"disc_conv_{i}")(x)
            x = BatchNormalization(momentum=dis_momentum)(x)
            x = Activation('relu')(x)
            x = Dropout(rate=dis_dropout)(x)

        x = Flatten()(x)
        dis_output = Dense(1, activation='sigmoid', kernel_initializer=RandomNormal(mean=0., stddev=0.02))(x)
        return Model(dis_input, dis_output)

    def _build_generator(self):
        """Build the generator front end: latent vector -> Dense -> BatchNorm -> ReLU -> (7, 7, 64).

        Returns the ``(input, output)`` tensors so the remaining layers can be appended later.
        """
        gen_initial_dense_size = (7, 7, 64)
        gen_momentum = 0.9

        gen_input = Input(shape=(self.z_dim,))
        # Dense width must equal the element count of the target reshape (7*7*64 = 3136).
        # int() converts numpy's integer result into a plain Python int for `units`.
        x = Dense(int(np.prod(gen_initial_dense_size)))(gen_input)
        x = BatchNormalization(momentum=gen_momentum)(x)
        # The layer must be *applied* to x. The original assigned the bare Activation
        # layer, which made the following Reshape receive a Layer instead of a tensor.
        x = Activation('relu')(x)
        x = Reshape(gen_initial_dense_size)(x)
        return gen_input, x
        
        
        

In [37]:
# Build the networks for inputs of shape (28, 28, 1).
gan = Gan(input_size=(28, 28, 1))

In [38]:
# Element count of the generator's initial (7, 7, 64) reshape, i.e. the first Dense layer's width.
np.array((7, 7, 64)).prod()

3136