In [19]:
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, Dropout, LeakyReLU
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model

from keras.datasets.mnist import load_data

import numpy as np

In [10]:
# define discriminator

def define_discriminator(in_shape=(28, 28, 1)):
    
    model = Sequential()
    model.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    
    model.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    
    model.add(Flatten())
    
    # compile model
    optimizerr = Adam(lr=0.0002, beta_1 = 0.5)
    model.compile(loss='binary_crossentropy', optimizer = optimizerr, metrics=['accuracy'])
    
    return model

In [11]:
# get discriminator
discriminator = define_discriminator()
discriminator.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_8 (Conv2D)            (None, 14, 14, 64)        640       
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_8 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_9 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 6272)             

In [23]:
# load mnist dataset

(images, _), (_, _) = load_data()

In [24]:
print(images.shape)

(60000, 28, 28)


In [25]:
images = np.expand_dims(images, axis=2)
print(images.shape)

(60000, 28, 1, 28)
