# In [3]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN(object):
    """Fully-connected GAN that learns to generate 28x28x1 MNIST digits.

    Three Keras models are built:

    * ``self.discriminator`` -- maps an image to a sigmoid "is real" score.
    * ``self.generator`` -- maps a latent noise vector to an image in [-1, 1]
      (tanh output).
    * ``self.combined`` -- generator followed by a frozen discriminator; it is
      used to train the generator to make the discriminator output "real".

    Args:
        latent_dim: Length of the noise vector fed to the generator.  Defaults
            to 100, the value that used to be hard-coded.
    """

    def __init__(self, latent_dim=100):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = latent_dim

        # Each compiled model gets its own optimizer instance: an optimizer keeps
        # per-instance state (e.g. the step counter Adam uses for bias correction),
        # which should not be shared between models updating different weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=self._make_optimizer(),
                                   metrics=['accuracy'])

        # train() never updates the generator directly (only through
        # self.combined); it is still compiled so it stays usable on its own.
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy',
                               optimizer=self._make_optimizer())

        # The generator takes noise as input and produces images.
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # Freeze the discriminator for the combined model only.  Keras reads the
        # `trainable` flag at compile time, so the discriminator compiled above
        # still updates its weights when trained directly.
        self.discriminator.trainable = False
        valid = self.discriminator(img)

        # noise => generated image => validity score
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy',
                              optimizer=self._make_optimizer())

    @staticmethod
    def _make_optimizer():
        """Return a fresh Adam optimizer (learning rate 0.0002, beta_1 0.5)."""
        return Adam(0.0002, 0.5)

    def build_generator(self):
        """Build the generator: latent vector -> (28, 28, 1) image in [-1, 1].

        Three Dense/LeakyReLU/BatchNorm stages (256, 512, 1024 units) are
        followed by a tanh Dense layer reshaped to ``self.img_shape``.

        Returns:
            A functional ``Model`` wrapping the Sequential stack.
        """
        noise_shape = (self.latent_dim,)

        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # tanh matches the [-1, 1] scaling applied to the real images in train().
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: image -> probability that it is real.

        The image is flattened and passed through Dense 512 and Dense 256
        (LeakyReLU) layers, then a single sigmoid unit.

        Returns:
            A functional ``Model`` wrapping the Sequential stack.
        """
        img_shape = self.img_shape

        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):
        """Alternately train the discriminator and the generator on MNIST.

        Each "epoch" here is one training iteration, not a full pass over the
        dataset. An iteration is one discriminator step on a half batch of real
        images plus a half batch of fakes, followed by one generator step.

        Args:
            epochs: Number of training iterations.
            batch_size: Generator batch size.  The discriminator sees
                ``batch_size // 2`` real and ``batch_size // 2`` fake images.
                Must be at least 2.
            save_interval: Save a sample grid whenever ``epoch`` is a multiple
                of this value (iteration 0 included).  0 disables saving.

        Raises:
            ValueError: If ``batch_size`` is smaller than 2, which would produce
                empty half batches.
        """
        if batch_size < 2:
            raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))

        # Load the dataset (labels and the test split are not needed).
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale pixel values to [-1, 1] to match the generator's tanh output.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        # Append a channel axis so images match self.img_shape.
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = batch_size // 2

        # Label arrays are loop-invariant; build them once, all shaped
        # (n, 1) to match the discriminator's single sigmoid output.
        real_y = np.ones((half_batch, 1))
        fake_y = np.zeros((half_batch, 1))
        # The generator wants the discriminator to label its samples as valid.
        valid_y = np.ones((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of real images.
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            # Generate a half batch of fake images.
            noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, real_y)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_y)
            # Element-wise mean of [loss, accuracy] over the real and fake halves.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress.
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Checking save_interval first avoids a modulo-by-zero when saving is disabled.
            if save_interval and epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch, out_dir="images"):
        """Save a 5x5 grid of generated digits to ``<out_dir>/mnist_<epoch>.png``.

        Args:
            epoch: Iteration number, used in the file name.
            out_dir: Output directory.  It is created if it does not exist.
                Defaults to ``"images"``, the previously hard-coded location.
        """
        import os

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale from the generator's [-1, 1] range to [0, 1] for imshow.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs is an (r, c) array of Axes; .flat walks it row by row.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')

        # savefig does not create missing directories; without this the very
        # first save (epoch 0) fails on a fresh checkout.
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        fig.savefig(os.path.join(out_dir, "mnist_%d.png" % epoch))
        # Close this specific figure so repeated calls don't leak open figures.
        plt.close(fig)


if __name__ == '__main__':
    # Build all three networks, then run 30k single-batch training iterations,
    # dumping a 5x5 grid of generated digits every 200 iterations.
    gan = GAN()
    gan.train(
        epochs=30000,
        batch_size=32,
        save_interval=200,
    )


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_16 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_17 (Dense)             (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
____

97 [D loss: 0.015965, acc.: 100.00%] [G loss: 4.766818]
98 [D loss: 0.013484, acc.: 100.00%] [G loss: 4.812083]
99 [D loss: 0.010489, acc.: 100.00%] [G loss: 4.746594]
100 [D loss: 0.008062, acc.: 100.00%] [G loss: 4.641602]
101 [D loss: 0.018125, acc.: 100.00%] [G loss: 4.998999]
102 [D loss: 0.012566, acc.: 100.00%] [G loss: 4.861823]
103 [D loss: 0.019380, acc.: 100.00%] [G loss: 5.045329]
104 [D loss: 0.010483, acc.: 100.00%] [G loss: 4.880792]
105 [D loss: 0.015628, acc.: 100.00%] [G loss: 5.004563]
106 [D loss: 0.013592, acc.: 100.00%] [G loss: 4.987072]
107 [D loss: 0.018321, acc.: 100.00%] [G loss: 4.843178]
108 [D loss: 0.016092, acc.: 100.00%] [G loss: 4.894298]
109 [D loss: 0.008290, acc.: 100.00%] [G loss: 5.199981]
110 [D loss: 0.015811, acc.: 100.00%] [G loss: 4.970630]
111 [D loss: 0.020273, acc.: 100.00%] [G loss: 5.276041]
112 [D loss: 0.017002, acc.: 100.00%] [G loss: 5.159397]
113 [D loss: 0.008253, acc.: 100.00%] [G loss: 4.901512]
114 [D loss: 0.019811, acc.: 100.0

244 [D loss: 0.676029, acc.: 56.25%] [G loss: 1.224958]
245 [D loss: 1.037067, acc.: 31.25%] [G loss: 0.765532]
246 [D loss: 0.673105, acc.: 50.00%] [G loss: 0.985473]
247 [D loss: 0.724925, acc.: 46.88%] [G loss: 1.077056]
248 [D loss: 0.769787, acc.: 40.62%] [G loss: 0.929680]
249 [D loss: 0.680495, acc.: 46.88%] [G loss: 1.033176]
250 [D loss: 0.733438, acc.: 50.00%] [G loss: 0.982123]
251 [D loss: 0.681748, acc.: 53.12%] [G loss: 0.946832]
252 [D loss: 0.809787, acc.: 37.50%] [G loss: 0.792056]
253 [D loss: 0.607527, acc.: 59.38%] [G loss: 1.004329]
254 [D loss: 0.849763, acc.: 31.25%] [G loss: 0.732867]
255 [D loss: 0.658774, acc.: 46.88%] [G loss: 0.773408]
256 [D loss: 0.684575, acc.: 43.75%] [G loss: 0.747482]
257 [D loss: 0.721245, acc.: 46.88%] [G loss: 0.816052]
258 [D loss: 0.752282, acc.: 43.75%] [G loss: 0.777599]
259 [D loss: 0.564403, acc.: 68.75%] [G loss: 1.112757]
260 [D loss: 0.827266, acc.: 37.50%] [G loss: 0.813205]
261 [D loss: 0.777882, acc.: 46.88%] [G loss: 0.

391 [D loss: 0.613777, acc.: 62.50%] [G loss: 0.706074]
392 [D loss: 0.639859, acc.: 65.62%] [G loss: 0.722366]
393 [D loss: 0.696026, acc.: 46.88%] [G loss: 0.671283]
394 [D loss: 0.659259, acc.: 56.25%] [G loss: 0.660230]
395 [D loss: 0.703017, acc.: 46.88%] [G loss: 0.661676]
396 [D loss: 0.663531, acc.: 46.88%] [G loss: 0.674665]
397 [D loss: 0.677893, acc.: 50.00%] [G loss: 0.667509]
398 [D loss: 0.668882, acc.: 50.00%] [G loss: 0.653827]
399 [D loss: 0.659246, acc.: 50.00%] [G loss: 0.664419]
400 [D loss: 0.685910, acc.: 46.88%] [G loss: 0.684111]
401 [D loss: 0.669259, acc.: 46.88%] [G loss: 0.652101]
402 [D loss: 0.647265, acc.: 53.12%] [G loss: 0.639825]
403 [D loss: 0.651917, acc.: 56.25%] [G loss: 0.653257]
404 [D loss: 0.650822, acc.: 59.38%] [G loss: 0.680625]
405 [D loss: 0.690994, acc.: 46.88%] [G loss: 0.663951]
406 [D loss: 0.680166, acc.: 46.88%] [G loss: 0.650104]
407 [D loss: 0.677611, acc.: 50.00%] [G loss: 0.650855]
408 [D loss: 0.688384, acc.: 50.00%] [G loss: 0.

540 [D loss: 0.705228, acc.: 56.25%] [G loss: 0.761503]
541 [D loss: 0.669871, acc.: 53.12%] [G loss: 0.774188]
542 [D loss: 0.631152, acc.: 75.00%] [G loss: 0.734368]
543 [D loss: 0.678093, acc.: 53.12%] [G loss: 0.714154]
544 [D loss: 0.675807, acc.: 46.88%] [G loss: 0.678501]
545 [D loss: 0.650694, acc.: 59.38%] [G loss: 0.689516]
546 [D loss: 0.674183, acc.: 46.88%] [G loss: 0.698618]
547 [D loss: 0.616891, acc.: 62.50%] [G loss: 0.713340]
548 [D loss: 0.637919, acc.: 68.75%] [G loss: 0.717936]
549 [D loss: 0.634680, acc.: 59.38%] [G loss: 0.749140]
550 [D loss: 0.654767, acc.: 56.25%] [G loss: 0.752324]
551 [D loss: 0.660820, acc.: 59.38%] [G loss: 0.739130]
552 [D loss: 0.635626, acc.: 65.62%] [G loss: 0.757651]
553 [D loss: 0.662622, acc.: 65.62%] [G loss: 0.736794]
554 [D loss: 0.676238, acc.: 53.12%] [G loss: 0.708562]
555 [D loss: 0.655115, acc.: 65.62%] [G loss: 0.694236]
556 [D loss: 0.662711, acc.: 59.38%] [G loss: 0.673975]
557 [D loss: 0.667223, acc.: 56.25%] [G loss: 0.

690 [D loss: 0.620612, acc.: 62.50%] [G loss: 0.818941]
691 [D loss: 0.678432, acc.: 53.12%] [G loss: 0.781900]
692 [D loss: 0.600666, acc.: 68.75%] [G loss: 0.813287]
693 [D loss: 0.631477, acc.: 65.62%] [G loss: 0.825718]
694 [D loss: 0.688216, acc.: 53.12%] [G loss: 0.766312]
695 [D loss: 0.621343, acc.: 59.38%] [G loss: 0.759763]
696 [D loss: 0.624545, acc.: 59.38%] [G loss: 0.777318]
697 [D loss: 0.634026, acc.: 65.62%] [G loss: 0.729448]
698 [D loss: 0.674647, acc.: 59.38%] [G loss: 0.719397]
699 [D loss: 0.625218, acc.: 56.25%] [G loss: 0.761912]
700 [D loss: 0.646619, acc.: 62.50%] [G loss: 0.795921]
701 [D loss: 0.586791, acc.: 78.12%] [G loss: 0.817861]
702 [D loss: 0.644920, acc.: 68.75%] [G loss: 0.734843]
703 [D loss: 0.667965, acc.: 59.38%] [G loss: 0.709831]
704 [D loss: 0.663844, acc.: 43.75%] [G loss: 0.717304]
705 [D loss: 0.678718, acc.: 50.00%] [G loss: 0.765432]
706 [D loss: 0.663916, acc.: 59.38%] [G loss: 0.764145]
707 [D loss: 0.633947, acc.: 59.38%] [G loss: 0.

837 [D loss: 0.618421, acc.: 71.88%] [G loss: 0.798161]
838 [D loss: 0.601476, acc.: 71.88%] [G loss: 0.751755]
839 [D loss: 0.609123, acc.: 65.62%] [G loss: 0.750642]
840 [D loss: 0.648579, acc.: 62.50%] [G loss: 0.721334]
841 [D loss: 0.624920, acc.: 62.50%] [G loss: 0.754077]
842 [D loss: 0.676951, acc.: 56.25%] [G loss: 0.736555]
843 [D loss: 0.614446, acc.: 65.62%] [G loss: 0.796273]
844 [D loss: 0.602010, acc.: 65.62%] [G loss: 0.820986]
845 [D loss: 0.641445, acc.: 68.75%] [G loss: 0.796329]
846 [D loss: 0.742968, acc.: 46.88%] [G loss: 0.799046]
847 [D loss: 0.656263, acc.: 56.25%] [G loss: 0.754691]
848 [D loss: 0.640105, acc.: 59.38%] [G loss: 0.795985]
849 [D loss: 0.713675, acc.: 50.00%] [G loss: 0.799579]
850 [D loss: 0.665909, acc.: 53.12%] [G loss: 0.840808]
851 [D loss: 0.658139, acc.: 56.25%] [G loss: 0.804728]
852 [D loss: 0.696178, acc.: 56.25%] [G loss: 0.808299]
853 [D loss: 0.621680, acc.: 56.25%] [G loss: 0.863379]
854 [D loss: 0.692817, acc.: 56.25%] [G loss: 0.

987 [D loss: 0.638605, acc.: 59.38%] [G loss: 0.843406]
988 [D loss: 0.611120, acc.: 68.75%] [G loss: 0.893775]
989 [D loss: 0.640271, acc.: 65.62%] [G loss: 0.850550]
990 [D loss: 0.594886, acc.: 68.75%] [G loss: 0.923480]
991 [D loss: 0.612628, acc.: 71.88%] [G loss: 0.890976]
992 [D loss: 0.596472, acc.: 81.25%] [G loss: 0.868062]
993 [D loss: 0.600325, acc.: 68.75%] [G loss: 0.857983]
994 [D loss: 0.588368, acc.: 71.88%] [G loss: 0.855970]
995 [D loss: 0.560931, acc.: 78.12%] [G loss: 0.774584]
996 [D loss: 0.600967, acc.: 68.75%] [G loss: 0.798013]
997 [D loss: 0.685447, acc.: 53.12%] [G loss: 0.833126]
998 [D loss: 0.625215, acc.: 62.50%] [G loss: 0.860195]
999 [D loss: 0.650963, acc.: 56.25%] [G loss: 0.833515]
1000 [D loss: 0.681837, acc.: 50.00%] [G loss: 0.864061]
1001 [D loss: 0.553079, acc.: 81.25%] [G loss: 0.889839]
1002 [D loss: 0.689588, acc.: 50.00%] [G loss: 0.825954]
1003 [D loss: 0.686582, acc.: 50.00%] [G loss: 0.763782]
1004 [D loss: 0.623317, acc.: 65.62%] [G los

1131 [D loss: 0.698449, acc.: 62.50%] [G loss: 0.801497]
1132 [D loss: 0.670707, acc.: 56.25%] [G loss: 0.862657]
1133 [D loss: 0.581995, acc.: 71.88%] [G loss: 0.863113]
1134 [D loss: 0.601727, acc.: 81.25%] [G loss: 0.832679]
1135 [D loss: 0.637979, acc.: 50.00%] [G loss: 0.863377]
1136 [D loss: 0.578908, acc.: 81.25%] [G loss: 0.919656]
1137 [D loss: 0.628228, acc.: 68.75%] [G loss: 0.886720]
1138 [D loss: 0.655071, acc.: 53.12%] [G loss: 0.774211]
1139 [D loss: 0.656370, acc.: 59.38%] [G loss: 0.900189]
1140 [D loss: 0.618070, acc.: 65.62%] [G loss: 0.884907]
1141 [D loss: 0.659460, acc.: 56.25%] [G loss: 0.826726]
1142 [D loss: 0.607455, acc.: 75.00%] [G loss: 0.847278]
1143 [D loss: 0.632393, acc.: 65.62%] [G loss: 0.821788]
1144 [D loss: 0.607456, acc.: 68.75%] [G loss: 0.857083]
1145 [D loss: 0.697625, acc.: 50.00%] [G loss: 0.773063]
1146 [D loss: 0.628299, acc.: 62.50%] [G loss: 0.802879]
1147 [D loss: 0.704164, acc.: 46.88%] [G loss: 0.769068]
1148 [D loss: 0.652830, acc.: 5

1276 [D loss: 0.620156, acc.: 68.75%] [G loss: 0.822175]
1277 [D loss: 0.570085, acc.: 84.38%] [G loss: 0.815933]
1278 [D loss: 0.558651, acc.: 87.50%] [G loss: 0.798788]
1279 [D loss: 0.579887, acc.: 78.12%] [G loss: 0.803300]
1280 [D loss: 0.683201, acc.: 62.50%] [G loss: 0.804431]
1281 [D loss: 0.619212, acc.: 59.38%] [G loss: 0.833463]
1282 [D loss: 0.591724, acc.: 71.88%] [G loss: 0.828547]
1283 [D loss: 0.617775, acc.: 59.38%] [G loss: 0.882289]
1284 [D loss: 0.586461, acc.: 71.88%] [G loss: 0.867380]
1285 [D loss: 0.569631, acc.: 84.38%] [G loss: 0.939691]
1286 [D loss: 0.672464, acc.: 50.00%] [G loss: 0.852761]
1287 [D loss: 0.575026, acc.: 71.88%] [G loss: 0.849691]
1288 [D loss: 0.592373, acc.: 81.25%] [G loss: 0.855561]
1289 [D loss: 0.595973, acc.: 75.00%] [G loss: 0.880761]
1290 [D loss: 0.613256, acc.: 68.75%] [G loss: 0.840073]
1291 [D loss: 0.604364, acc.: 71.88%] [G loss: 0.841256]
1292 [D loss: 0.587749, acc.: 68.75%] [G loss: 0.841201]
1293 [D loss: 0.566284, acc.: 7

KeyboardInterrupt: 