In [12]:
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Conv2DTranspose, Reshape, Flatten, Dropout, LeakyReLU
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model

from keras.datasets.mnist import load_data

import numpy as np
from matplotlib import pyplot

In [13]:
# define discriminator

def define_discriminator(in_shape=(28, 28, 1)):
    """Build and compile the CNN discriminator (real-vs-fake image classifier).

    Two stride-2 Conv2D blocks (64 then 128 filters, each followed by
    LeakyReLU(0.2) and Dropout(0.4)) downsample the image; the feature map is
    flattened and fed to a single sigmoid unit giving P(input is real).

    Args:
        in_shape: shape of one input image, (height, width, channels).

    Returns:
        A compiled ``Sequential`` model named 'Discriminator', using binary
        cross-entropy loss and reporting accuracy.
    """
    model = Sequential(name='Discriminator')

    # 28x28 -> 14x14 for the default in_shape
    model.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))

    # 14x14 -> 7x7
    model.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))

    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    # Use `learning_rate`, not the deprecated `lr` alias: newer Keras releases
    # either ignore `lr` (only logging a warning, silently falling back to the
    # default rate) or reject it outright.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

    return model

In [14]:
# define the standalone generator

def define_generator(latent_dim):
    """Build the (uncompiled) generator mapping latent vectors to 28x28x1 images.

    A Dense layer emits 7*7*128 activations that are reshaped into a 7x7x128
    feature map, upsampled twice by stride-2 transposed convolutions
    (7 -> 14 -> 28), and collapsed to one channel by a sigmoid Conv2D so pixel
    values lie in [0, 1] (the same range get_mnist scales the real images to).

    Args:
        latent_dim: length of each input latent vector.

    Returns:
        A ``Sequential`` model named 'Generator'. It is not compiled here; it is
        trained through the composite model built by define_gan.
    """
    base_shape = (7, 7, 128)
    return Sequential([
        # project the latent vector and fold it into a low-resolution feature map
        Dense(base_shape[0] * base_shape[1] * base_shape[2], input_dim=latent_dim),
        Reshape(base_shape),
        # 7x7 -> 14x14
        Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        # 14x14 -> 28x28
        Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        # single output channel
        Conv2D(1, (7, 7), activation='sigmoid', padding='same'),
    ], name='Generator')

In [15]:
# define gan

def define_gan(g_model, d_model):
    """Chain generator and discriminator into the composite model that trains the generator.

    ``d_model.trainable`` is set to False *before* the composite is compiled,
    so updates through the returned model change only the generator's weights.
    This assumes ``d_model`` was already compiled while trainable (as
    define_discriminator does): Keras reads ``trainable`` at compile time, so
    ``d_model.train_on_batch`` keeps updating the discriminator itself.

    Args:
        g_model: generator mapping latent vectors to images.
        d_model: compiled discriminator mapping images to P(real).

    Returns:
        A compiled ``Sequential`` model named 'GAN' (binary cross-entropy loss).

    Side effects:
        Sets ``d_model.trainable = False``.
    """
    # freeze the discriminator inside the composite model
    d_model.trainable = False

    gan = Sequential(name='GAN')
    gan.add(g_model)
    gan.add(d_model)

    # `learning_rate` instead of the deprecated `lr` alias, which newer Keras
    # releases ignore (with only a warning) or reject.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    gan.compile(loss='binary_crossentropy', optimizer=opt)

    return gan

In [16]:
# load mnist dataset

def get_mnist():
    """Load the MNIST training images as a float32 array scaled to [0, 1].

    Labels and the test split are discarded.

    Returns:
        float32 array of the training images with a trailing channel axis
        appended (shape (N, 28, 28, 1) for the standard MNIST release).
    """
    (pixels, _), (_, _) = load_data()
    # add the channel axis the Conv2D input expects, then convert to float
    pixels = np.expand_dims(pixels, axis=-1).astype('float32')
    # scale from [0, 255] to [0, 1]
    return pixels / 255.0

In [17]:
# select real samples

def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random rows of ``dataset`` and label them real (1).

    Rows are chosen independently and uniformly, i.e. with replacement, so a
    batch may contain duplicates.

    Args:
        dataset: array indexed along its first axis, one sample per row.
        n_samples: number of samples to draw.

    Returns:
        (X, y): the selected rows and a float array of ones of shape (n_samples, 1).
    """
    picks = np.random.randint(0, dataset.shape[0], n_samples)
    labels = np.ones((n_samples, 1))
    return dataset[picks], labels

In [18]:
# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n_samples):
    """Sample a batch of generator inputs from a standard normal distribution.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of vectors in the batch.

    Returns:
        float64 array of shape (n_samples, latent_dim) drawn with np.random.randn.
    """
    # randn fills the requested shape in C order, so this draws the same values
    # as sampling a flat vector and reshaping it to (n_samples, latent_dim)
    return np.random.randn(n_samples, latent_dim)

In [19]:
# generate n fake samples with class labels

def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate ``n_samples`` images with the generator and label them fake (0).

    Args:
        g_model: generator model; its ``predict`` is called on a batch of
            latent points of shape (n_samples, latent_dim).
        latent_dim: length of each latent vector.
        n_samples: number of images to generate.

    Returns:
        (X, y): the generator output for the batch (presumably
        (n_samples, 28, 28, 1) for the generator in this notebook — confirm)
        and a float array of zeros of shape (n_samples, 1).
    """
    latent_points = generate_latent_points(latent_dim, n_samples)
    # map latent points to images. verbose=0 because this runs on every training
    # step, and Keras versions whose predict() defaults to a progress bar would
    # otherwise print one bar per call.
    X = g_model.predict(latent_points, verbose=0)
    # generate class labels for fake images
    y = np.zeros((n_samples, 1))

    return X, y

In [20]:
# summarize performance

def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """Print discriminator accuracy on fresh real and fake batches, then save the generator.

    Args:
        epoch: zero-based epoch index; the saved file name uses ``epoch + 1``.
        g_model: generator used to produce the fake batch; saved to
            ``generator_model_XXX.h5`` (XXX = ``epoch + 1`` zero-padded to 3 digits).
        d_model: compiled discriminator whose ``evaluate`` returns ``[loss, acc]``.
        dataset: real images, one per row along axis 0.
        latent_dim: length of each latent vector fed to the generator.
        n_samples: size of the real batch and of the fake batch.
    """
    def _accuracy(batch):
        # evaluate the discriminator on one (X, y) batch and keep only accuracy
        X_eval, y_eval = batch
        _, acc = d_model.evaluate(X_eval, y_eval, verbose=0)
        return acc

    acc_real = _accuracy(generate_real_samples(dataset, n_samples))
    acc_fake = _accuracy(generate_fake_samples(g_model, latent_dim, n_samples))

    print('\n>Acc for real sample: %.0f%%\nAcc for fake sample: %.0f%%' % (acc_real*100, acc_fake*100))

    # checkpoint the generator under a 1-based epoch number
    g_model.save('generator_model_%03d.h5' % (epoch + 1))

In [21]:
# train the GAN (alternating discriminator and generator updates)

def trainGAN(g_model, d_model, gan, dataset, latent_dim, n_epochs=100, n_batch=256,
             log_every=1, eval_every=10):
    """Train the GAN by alternating discriminator and generator updates.

    Each training step:

    1. Updates ``d_model`` on one stacked batch of ``n_batch // 2`` real images
       (label 1) and ``n_batch // 2`` generated images (label 0).
    2. Updates the generator through ``gan`` on ``n_batch`` latent points all
       labelled 1, so the generator is pushed toward images the discriminator
       classifies as real.

    Args:
        g_model: generator model.
        d_model: compiled discriminator (``train_on_batch`` returns ``[loss, acc]``).
        gan: compiled composite generator -> discriminator model.
        dataset: real images, one per row along axis 0.
        latent_dim: length of each latent vector.
        n_epochs: number of epochs to train.
        n_batch: batch size of the generator update; the discriminator batch is
            half real, half fake, each ``n_batch // 2`` samples.
        log_every: print the losses every ``log_every`` batches. The default 1
            prints every batch; use a larger value (or 0/None to disable) to
            avoid flooding the notebook output.
        eval_every: call ``summarize_performance`` (which also saves the
            generator) every ``eval_every`` epochs and after the final epoch;
            0/None disables it.
    """
    # At least one step per epoch even when dataset has fewer rows than n_batch;
    # real samples are drawn with replacement, so small datasets still work.
    batch_per_epoch = max(1, dataset.shape[0] // n_batch)
    half_batch = n_batch // 2

    for i in range(n_epochs):
        for j in range(batch_per_epoch):
            # discriminator step: half real, half generated
            X_real, y_real = generate_real_samples(dataset, half_batch)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)

            # generator step: latent points labelled "real" (1)
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = np.ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan.train_on_batch(X_gan, y_gan)

            if log_every and (j + 1) % log_every == 0:
                print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, j+1, batch_per_epoch, d_loss, g_loss))

        # Evaluate periodically, and always after the last epoch so the final
        # generator is saved even when n_epochs is not a multiple of eval_every.
        epoch = i + 1
        if eval_every and (epoch % eval_every == 0 or epoch == n_epochs):
            summarize_performance(i, g_model, d_model, dataset, latent_dim)

In [22]:
latent_dim = 100

# Build both networks, then chain them into the composite GAN.
# define_gan sets d_model.trainable = False before compiling the composite.
d_model = define_discriminator()
g_model = define_generator(latent_dim)
gan = define_gan(g_model, d_model)

# MNIST training images, scaled to [0, 1]
dataset = get_mnist()

# print the architecture of each model: discriminator, generator, composite
for model in (d_model, g_model, gan):
    model.summary()

Model: "Discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            (None, 14, 14, 64)        640       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 6272)            

In [None]:
# train GAN: run trainGAN with its default n_epochs (100) and n_batch (256) on the
# models built above and the scaled MNIST training images.
trainGAN(g_model, d_model, gan, dataset, latent_dim)

>1, 1/234, d=0.700, g=0.710
>1, 2/234, d=0.688, g=0.739
>1, 3/234, d=0.678, g=0.765
>1, 4/234, d=0.667, g=0.787
>1, 5/234, d=0.656, g=0.816
>1, 6/234, d=0.646, g=0.843
>1, 7/234, d=0.633, g=0.871
>1, 8/234, d=0.626, g=0.901
>1, 9/234, d=0.614, g=0.941
>1, 10/234, d=0.604, g=0.975
>1, 11/234, d=0.592, g=1.012
>1, 12/234, d=0.577, g=1.046
>1, 13/234, d=0.570, g=1.087
>1, 14/234, d=0.555, g=1.121
>1, 15/234, d=0.544, g=1.155
>1, 16/234, d=0.528, g=1.190
>1, 17/234, d=0.522, g=1.223
>1, 18/234, d=0.506, g=1.248
>1, 19/234, d=0.495, g=1.267
>1, 20/234, d=0.479, g=1.290
>1, 21/234, d=0.468, g=1.301
>1, 22/234, d=0.456, g=1.314
>1, 23/234, d=0.446, g=1.318
>1, 24/234, d=0.433, g=1.320
>1, 25/234, d=0.417, g=1.310
>1, 26/234, d=0.408, g=1.285
>1, 27/234, d=0.403, g=1.266
>1, 28/234, d=0.399, g=1.260
>1, 29/234, d=0.385, g=1.218
>1, 30/234, d=0.384, g=1.187
>1, 31/234, d=0.379, g=1.148
>1, 32/234, d=0.370, g=1.123
>1, 33/234, d=0.385, g=1.083
>1, 34/234, d=0.373, g=1.032
>1, 35/234, d=0.373, g=

>2, 47/234, d=0.516, g=1.104
>2, 48/234, d=0.519, g=1.117
>2, 49/234, d=0.458, g=1.059
>2, 50/234, d=0.480, g=1.114
>2, 51/234, d=0.524, g=1.069
>2, 52/234, d=0.520, g=0.960
>2, 53/234, d=0.563, g=1.032
>2, 54/234, d=0.480, g=1.259
>2, 55/234, d=0.553, g=0.846
>2, 56/234, d=0.519, g=1.208
>2, 57/234, d=0.539, g=1.125
>2, 58/234, d=0.512, g=0.888
>2, 59/234, d=0.509, g=1.162
>2, 60/234, d=0.538, g=1.075
>2, 61/234, d=0.517, g=1.164
>2, 62/234, d=0.523, g=1.079
>2, 63/234, d=0.508, g=1.087
>2, 64/234, d=0.528, g=1.029
>2, 65/234, d=0.507, g=1.022
>2, 66/234, d=0.516, g=1.129
>2, 67/234, d=0.523, g=1.016
>2, 68/234, d=0.559, g=0.984
>2, 69/234, d=0.542, g=1.003
>2, 70/234, d=0.557, g=1.052
>2, 71/234, d=0.550, g=0.930
>2, 72/234, d=0.568, g=1.084
>2, 73/234, d=0.538, g=1.102
>2, 74/234, d=0.580, g=1.042
>2, 75/234, d=0.575, g=0.991
>2, 76/234, d=0.597, g=0.967
>2, 77/234, d=0.550, g=1.062
>2, 78/234, d=0.546, g=1.039
>2, 79/234, d=0.543, g=0.925
>2, 80/234, d=0.578, g=1.024
>2, 81/234, d=

>3, 93/234, d=0.665, g=0.789
>3, 94/234, d=0.665, g=0.777
>3, 95/234, d=0.640, g=0.823
>3, 96/234, d=0.661, g=0.839
>3, 97/234, d=0.647, g=0.831
>3, 98/234, d=0.670, g=0.791
>3, 99/234, d=0.643, g=0.816
>3, 100/234, d=0.655, g=0.866
>3, 101/234, d=0.663, g=0.871
>3, 102/234, d=0.642, g=0.829
>3, 103/234, d=0.610, g=0.846
>3, 104/234, d=0.607, g=0.837
>3, 105/234, d=0.637, g=0.841
>3, 106/234, d=0.649, g=0.875
>3, 107/234, d=0.623, g=0.863
>3, 108/234, d=0.642, g=0.873
>3, 109/234, d=0.622, g=0.896
>3, 110/234, d=0.609, g=0.843
>3, 111/234, d=0.606, g=0.848
>3, 112/234, d=0.611, g=0.826
>3, 113/234, d=0.611, g=0.797
>3, 114/234, d=0.615, g=0.870
>3, 115/234, d=0.619, g=0.886
>3, 116/234, d=0.615, g=0.925
>3, 117/234, d=0.631, g=0.850
>3, 118/234, d=0.625, g=0.807
>3, 119/234, d=0.606, g=0.832
>3, 120/234, d=0.633, g=0.920
>3, 121/234, d=0.623, g=0.910
>3, 122/234, d=0.621, g=0.823
>3, 123/234, d=0.635, g=0.818
>3, 124/234, d=0.623, g=0.827
>3, 125/234, d=0.605, g=0.867
>3, 126/234, d=0.

>4, 136/234, d=0.642, g=0.810
>4, 137/234, d=0.635, g=0.792
>4, 138/234, d=0.637, g=0.800
>4, 139/234, d=0.634, g=0.754
>4, 140/234, d=0.653, g=0.770
>4, 141/234, d=0.615, g=0.776
>4, 142/234, d=0.658, g=0.813
>4, 143/234, d=0.645, g=0.835
>4, 144/234, d=0.634, g=0.802
>4, 145/234, d=0.627, g=0.781
>4, 146/234, d=0.624, g=0.748
>4, 147/234, d=0.643, g=0.779
>4, 148/234, d=0.630, g=0.817
>4, 149/234, d=0.651, g=0.811
>4, 150/234, d=0.637, g=0.793
>4, 151/234, d=0.668, g=0.769
>4, 152/234, d=0.647, g=0.744
>4, 153/234, d=0.658, g=0.779
>4, 154/234, d=0.659, g=0.823
>4, 155/234, d=0.641, g=0.772
>4, 156/234, d=0.654, g=0.743
>4, 157/234, d=0.669, g=0.708
>4, 158/234, d=0.670, g=0.755
>4, 159/234, d=0.660, g=0.774
>4, 160/234, d=0.684, g=0.823
>4, 161/234, d=0.670, g=0.784
>4, 162/234, d=0.662, g=0.769
>4, 163/234, d=0.664, g=0.773
>4, 164/234, d=0.654, g=0.777
>4, 165/234, d=0.671, g=0.718
>4, 166/234, d=0.677, g=0.733
>4, 167/234, d=0.672, g=0.762
>4, 168/234, d=0.669, g=0.819
>4, 169/23

>5, 180/234, d=0.673, g=0.757
>5, 181/234, d=0.673, g=0.805
>5, 182/234, d=0.681, g=0.816
>5, 183/234, d=0.657, g=0.796
>5, 184/234, d=0.684, g=0.787
>5, 185/234, d=0.656, g=0.751
>5, 186/234, d=0.655, g=0.767
>5, 187/234, d=0.667, g=0.745
>5, 188/234, d=0.646, g=0.797
>5, 189/234, d=0.646, g=0.819
>5, 190/234, d=0.647, g=0.823
>5, 191/234, d=0.651, g=0.787
>5, 192/234, d=0.661, g=0.758
>5, 193/234, d=0.648, g=0.776
>5, 194/234, d=0.656, g=0.780
>5, 195/234, d=0.657, g=0.759
>5, 196/234, d=0.680, g=0.748
>5, 197/234, d=0.664, g=0.736
>5, 198/234, d=0.660, g=0.729
>5, 199/234, d=0.669, g=0.711
>5, 200/234, d=0.647, g=0.756
>5, 201/234, d=0.664, g=0.737
>5, 202/234, d=0.679, g=0.727
>5, 203/234, d=0.665, g=0.750
>5, 204/234, d=0.684, g=0.744
>5, 205/234, d=0.682, g=0.685
>5, 206/234, d=0.672, g=0.682
>5, 207/234, d=0.664, g=0.714
>5, 208/234, d=0.676, g=0.739
>5, 209/234, d=0.662, g=0.768
>5, 210/234, d=0.654, g=0.745
>5, 211/234, d=0.658, g=0.734
>5, 212/234, d=0.664, g=0.767
>5, 213/23

>6, 224/234, d=0.638, g=0.836
>6, 225/234, d=0.642, g=0.786
>6, 226/234, d=0.638, g=0.734
>6, 227/234, d=0.653, g=0.751
>6, 228/234, d=0.660, g=0.827
>6, 229/234, d=0.627, g=0.870
>6, 230/234, d=0.652, g=0.777
>6, 231/234, d=0.650, g=0.742
>6, 232/234, d=0.641, g=0.782
>6, 233/234, d=0.639, g=0.822
>6, 234/234, d=0.645, g=0.806
>7, 1/234, d=0.638, g=0.763
>7, 2/234, d=0.648, g=0.773
>7, 3/234, d=0.647, g=0.742
>7, 4/234, d=0.661, g=0.786
>7, 5/234, d=0.662, g=0.805
>7, 6/234, d=0.636, g=0.789
>7, 7/234, d=0.663, g=0.754
>7, 8/234, d=0.662, g=0.792
>7, 9/234, d=0.639, g=0.810
>7, 10/234, d=0.654, g=0.739
>7, 11/234, d=0.658, g=0.767
>7, 12/234, d=0.645, g=0.783
>7, 13/234, d=0.651, g=0.809
>7, 14/234, d=0.671, g=0.771
>7, 15/234, d=0.656, g=0.755
>7, 16/234, d=0.658, g=0.788
>7, 17/234, d=0.682, g=0.749
>7, 18/234, d=0.665, g=0.760
>7, 19/234, d=0.637, g=0.745
>7, 20/234, d=0.646, g=0.762
>7, 21/234, d=0.659, g=0.784
>7, 22/234, d=0.647, g=0.744
>7, 23/234, d=0.658, g=0.784
>7, 24/234, 

>8, 36/234, d=0.640, g=0.798
>8, 37/234, d=0.638, g=0.781
>8, 38/234, d=0.643, g=0.776
>8, 39/234, d=0.649, g=0.803
>8, 40/234, d=0.626, g=0.768
>8, 41/234, d=0.639, g=0.777
>8, 42/234, d=0.648, g=0.756
>8, 43/234, d=0.647, g=0.780
>8, 44/234, d=0.646, g=0.801
>8, 45/234, d=0.627, g=0.795
>8, 46/234, d=0.640, g=0.843
>8, 47/234, d=0.638, g=0.757
>8, 48/234, d=0.667, g=0.769
>8, 49/234, d=0.639, g=0.870
>8, 50/234, d=0.651, g=0.768
>8, 51/234, d=0.641, g=0.830
>8, 52/234, d=0.634, g=0.809
>8, 53/234, d=0.640, g=0.800
>8, 54/234, d=0.647, g=0.785
>8, 55/234, d=0.657, g=0.831
>8, 56/234, d=0.637, g=0.825
>8, 57/234, d=0.645, g=0.757
>8, 58/234, d=0.639, g=0.782
>8, 59/234, d=0.664, g=0.757
>8, 60/234, d=0.651, g=0.780
>8, 61/234, d=0.641, g=0.784
>8, 62/234, d=0.653, g=0.815
>8, 63/234, d=0.648, g=0.719
>8, 64/234, d=0.680, g=0.805
>8, 65/234, d=0.638, g=0.786
>8, 66/234, d=0.631, g=0.770
>8, 67/234, d=0.647, g=0.771
>8, 68/234, d=0.652, g=0.748
>8, 69/234, d=0.659, g=0.751
>8, 70/234, d=

>9, 81/234, d=0.658, g=0.816
>9, 82/234, d=0.643, g=0.781
>9, 83/234, d=0.649, g=0.742
>9, 84/234, d=0.658, g=0.871
>9, 85/234, d=0.641, g=0.853
>9, 86/234, d=0.640, g=0.848
>9, 87/234, d=0.619, g=0.798
>9, 88/234, d=0.629, g=0.885
>9, 89/234, d=0.604, g=0.828
>9, 90/234, d=0.618, g=0.829
>9, 91/234, d=0.613, g=0.875
>9, 92/234, d=0.641, g=0.817
>9, 93/234, d=0.648, g=0.808
>9, 94/234, d=0.669, g=0.776
>9, 95/234, d=0.620, g=0.738
>9, 96/234, d=0.643, g=0.773
>9, 97/234, d=0.641, g=0.751
>9, 98/234, d=0.660, g=0.732
>9, 99/234, d=0.650, g=0.807
>9, 100/234, d=0.659, g=0.844
>9, 101/234, d=0.659, g=0.720
>9, 102/234, d=0.654, g=0.773
>9, 103/234, d=0.688, g=0.871
>9, 104/234, d=0.673, g=0.810
>9, 105/234, d=0.656, g=0.747
>9, 106/234, d=0.688, g=0.773
>9, 107/234, d=0.676, g=0.764
>9, 108/234, d=0.653, g=0.814
>9, 109/234, d=0.654, g=0.774
>9, 110/234, d=0.639, g=0.791
>9, 111/234, d=0.639, g=0.867
>9, 112/234, d=0.628, g=0.800
>9, 113/234, d=0.648, g=0.779
>9, 114/234, d=0.606, g=0.822

>10, 121/234, d=0.646, g=0.811
>10, 122/234, d=0.657, g=0.844
>10, 123/234, d=0.657, g=0.807
>10, 124/234, d=0.644, g=0.769
>10, 125/234, d=0.626, g=0.903
>10, 126/234, d=0.640, g=0.857
>10, 127/234, d=0.640, g=0.755
>10, 128/234, d=0.639, g=0.781
>10, 129/234, d=0.640, g=0.815
>10, 130/234, d=0.639, g=0.790
>10, 131/234, d=0.620, g=0.773
>10, 132/234, d=0.621, g=0.802
>10, 133/234, d=0.640, g=0.825
>10, 134/234, d=0.636, g=0.779
>10, 135/234, d=0.623, g=0.764
>10, 136/234, d=0.645, g=0.761
>10, 137/234, d=0.628, g=0.802
>10, 138/234, d=0.658, g=0.822
>10, 139/234, d=0.652, g=0.781
>10, 140/234, d=0.658, g=0.794
>10, 141/234, d=0.666, g=0.774
>10, 142/234, d=0.662, g=0.820
>10, 143/234, d=0.686, g=0.806
>10, 144/234, d=0.677, g=0.722
>10, 145/234, d=0.666, g=0.827
>10, 146/234, d=0.674, g=0.763
>10, 147/234, d=0.664, g=0.776
>10, 148/234, d=0.667, g=0.689
>10, 149/234, d=0.660, g=0.757
>10, 150/234, d=0.664, g=0.728
>10, 151/234, d=0.675, g=0.794
>10, 152/234, d=0.650, g=0.785
>10, 153

>11, 154/234, d=0.674, g=0.770
>11, 155/234, d=0.659, g=0.753
>11, 156/234, d=0.641, g=0.773
>11, 157/234, d=0.687, g=0.801
>11, 158/234, d=0.656, g=0.767
>11, 159/234, d=0.642, g=0.727
>11, 160/234, d=0.663, g=0.843
>11, 161/234, d=0.672, g=0.775
>11, 162/234, d=0.659, g=0.721
>11, 163/234, d=0.667, g=0.717
>11, 164/234, d=0.659, g=0.733
>11, 165/234, d=0.673, g=0.736
>11, 166/234, d=0.651, g=0.754
>11, 167/234, d=0.660, g=0.779
>11, 168/234, d=0.660, g=0.791
>11, 169/234, d=0.657, g=0.710
>11, 170/234, d=0.676, g=0.765
>11, 171/234, d=0.643, g=0.817
>11, 172/234, d=0.649, g=0.794
>11, 173/234, d=0.652, g=0.781
>11, 174/234, d=0.662, g=0.825
>11, 175/234, d=0.627, g=0.774
>11, 176/234, d=0.662, g=0.744
>11, 177/234, d=0.675, g=0.749
>11, 178/234, d=0.670, g=0.708
>11, 179/234, d=0.645, g=0.772
>11, 180/234, d=0.683, g=0.788
>11, 181/234, d=0.653, g=0.655
>11, 182/234, d=0.689, g=0.765
>11, 183/234, d=0.655, g=0.837
>11, 184/234, d=0.650, g=0.760
>11, 185/234, d=0.667, g=0.801
>11, 186