In [2]:
from __future__ import print_function, division

import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

class DCGAN():
    """Deep Convolutional GAN trained on MNIST digits.

    Attributes (set in ``__init__``):
        generator: maps a ``latent_dim``-dim noise vector to an image of shape
            ``img_shape`` with values in [-1, 1] (tanh output).
        discriminator: maps an image of shape ``img_shape`` to a single
            sigmoid probability that the image is real; compiled with
            binary cross-entropy and an accuracy metric.
        combined: generator followed by the (frozen) discriminator; training
            it updates only the generator.
    """

    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Each compiled model gets its own Adam instance (lr=0.0002,
        # beta_1=0.5). Sharing one optimizer object between two models couples
        # their internal state (e.g. the step counter used for bias
        # correction), and newer tf.keras optimizers may refuse to update a
        # second, different set of variables.
        d_optimizer = Adam(0.0002, 0.5)
        g_optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=d_optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator. Keras
        # reads `trainable` when a model is compiled, so the discriminator
        # compiled above still updates its own weights when trained directly.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=g_optimizer)

    def build_generator(self):
        """Build the generator: noise (latent_dim,) -> image in [-1, 1].

        Dense projection to a 7x7x128 map, then two UpSampling2D + Conv2D +
        BatchNorm + ReLU stages (7 -> 14 -> 28), and a final Conv2D to
        ``self.channels`` with tanh. The hard-coded 7x7 start assumes
        28x28 images.

        Returns:
            A functional ``Model`` wrapping the Sequential stack.
        """
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: image ``img_shape`` -> P(real) via sigmoid.

        Four Conv2D blocks (32/64/128/256 filters; the first three strided by 2)
        with LeakyReLU(0.2) and Dropout(0.25), BatchNorm on all but the first.
        An asymmetric ZeroPadding2D pads the 7x7 map to 8x8 so that the next
        stride-2 conv divides evenly.

        Returns:
            A functional ``Model`` wrapping the Sequential stack.
        """
        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    @staticmethod
    def _preprocess(images):
        """Scale pixel values 0..255 to float32 in [-1, 1] and append a channel axis.

        Args:
            images: array-like of shape (N, rows, cols), e.g. the uint8 MNIST
                training images.

        Returns:
            float32 array of shape (N, rows, cols, 1). float32 is what Keras
            computes in by default, and it uses half the memory of float64.
        """
        scaled = np.asarray(images, dtype=np.float32) / 127.5 - 1.0
        return np.expand_dims(scaled, axis=-1)

    def train(self, epochs, batch_size=128, save_interval=50):
        """Alternately train the discriminator and the generator.

        Note: an "epoch" here is ONE training iteration on a single randomly
        sampled mini-batch, not a full pass over the dataset.

        Args:
            epochs: number of training iterations.
            batch_size: real images (and, separately, generated images) per
                discriminator update; noise vectors per generator update.
            save_interval: a sample grid is saved whenever
                ``epoch % save_interval == 0`` (including iteration 0).
        """
        # Load the dataset (labels and test split are unused)
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale to [-1, 1] to match the generator's tanh output range
        X_train = self._preprocess(X_train)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Draw a random mini-batch of real images (indices sampled with
            # replacement, so duplicates are possible)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise and generate a batch of new images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator (real classified as ones and generated as zeros).
            # Each call returns [loss, accuracy] because of metrics=['accuracy'].
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator (wants discriminator to mistake images as real)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch, sample_dir="images"):
        """Save a 5x5 grid of generator samples to ``<sample_dir>/mnist_<epoch>.png``.

        Args:
            epoch: iteration number, used in the output file name.
            sample_dir: output directory; created if it does not exist.
        """
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale from the generator's tanh range [-1, 1] to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching the sample order
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')

        # Previously the save failed with FileNotFoundError if the directory was missing
        if not os.path.isdir(sample_dir):
            os.makedirs(sample_dir)
        fig.savefig(os.path.join(sample_dir, "mnist_%d.png" % epoch))
        # Close this specific figure (not just the "current" one) to free memory
        plt.close(fig)


if __name__ == "__main__":
    # Build the generator/discriminator/combined models, then run 4000
    # single-batch training iterations, writing a sample grid every 50.
    dcgan = DCGAN()
    dcgan.train(4000, save_interval=50, batch_size=32)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_8 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
zero_padding2d_2 (ZeroPaddin (None, 8, 8, 64)          0         
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 64)          256       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 64)          0         
__________

70 [D loss: 0.762781, acc.: 59.38%] [G loss: 0.906338]
71 [D loss: 0.662751, acc.: 60.94%] [G loss: 0.872865]
72 [D loss: 0.888864, acc.: 46.88%] [G loss: 1.079451]
73 [D loss: 0.823328, acc.: 46.88%] [G loss: 1.297596]
74 [D loss: 1.055216, acc.: 42.19%] [G loss: 1.263545]
75 [D loss: 1.096108, acc.: 35.94%] [G loss: 1.246214]
76 [D loss: 0.782548, acc.: 60.94%] [G loss: 1.361449]
77 [D loss: 0.984988, acc.: 39.06%] [G loss: 1.064546]
78 [D loss: 0.897339, acc.: 48.44%] [G loss: 1.256221]
79 [D loss: 0.979739, acc.: 42.19%] [G loss: 1.131291]
80 [D loss: 0.851675, acc.: 50.00%] [G loss: 1.143041]
81 [D loss: 0.887380, acc.: 51.56%] [G loss: 1.206262]
82 [D loss: 0.899133, acc.: 48.44%] [G loss: 1.073696]
83 [D loss: 0.780976, acc.: 53.12%] [G loss: 1.160697]
84 [D loss: 0.887362, acc.: 37.50%] [G loss: 1.176769]
85 [D loss: 0.993293, acc.: 43.75%] [G loss: 1.185020]
86 [D loss: 1.119791, acc.: 34.38%] [G loss: 0.928324]
87 [D loss: 0.983354, acc.: 42.19%] [G loss: 1.059914]
88 [D loss

219 [D loss: 0.853937, acc.: 40.62%] [G loss: 0.961140]
220 [D loss: 0.776743, acc.: 48.44%] [G loss: 1.002175]
221 [D loss: 0.852411, acc.: 35.94%] [G loss: 0.904047]
222 [D loss: 0.793900, acc.: 51.56%] [G loss: 1.063722]
223 [D loss: 0.922010, acc.: 45.31%] [G loss: 1.088577]
224 [D loss: 0.922135, acc.: 46.88%] [G loss: 0.864376]
225 [D loss: 0.760819, acc.: 42.19%] [G loss: 1.101758]
226 [D loss: 0.930544, acc.: 42.19%] [G loss: 0.958565]
227 [D loss: 0.844191, acc.: 50.00%] [G loss: 1.092203]
228 [D loss: 0.858289, acc.: 46.88%] [G loss: 0.988899]
229 [D loss: 0.806957, acc.: 51.56%] [G loss: 0.954975]
230 [D loss: 0.905625, acc.: 45.31%] [G loss: 0.959502]
231 [D loss: 0.861151, acc.: 43.75%] [G loss: 1.057131]
232 [D loss: 0.813060, acc.: 46.88%] [G loss: 0.983158]
233 [D loss: 0.696446, acc.: 64.06%] [G loss: 1.092681]
234 [D loss: 0.760346, acc.: 54.69%] [G loss: 1.074688]
235 [D loss: 0.827270, acc.: 42.19%] [G loss: 0.930024]
236 [D loss: 0.779669, acc.: 51.56%] [G loss: 1.

372 [D loss: 0.760131, acc.: 42.19%] [G loss: 1.031789]
373 [D loss: 0.762250, acc.: 56.25%] [G loss: 1.088920]
374 [D loss: 0.762112, acc.: 46.88%] [G loss: 0.960469]
375 [D loss: 0.694089, acc.: 53.12%] [G loss: 0.975161]
376 [D loss: 0.919875, acc.: 40.62%] [G loss: 0.924743]
377 [D loss: 0.802324, acc.: 48.44%] [G loss: 0.904683]
378 [D loss: 0.784866, acc.: 45.31%] [G loss: 0.930586]
379 [D loss: 0.815278, acc.: 46.88%] [G loss: 1.028180]
380 [D loss: 0.881044, acc.: 43.75%] [G loss: 0.941486]
381 [D loss: 0.882279, acc.: 40.62%] [G loss: 1.095813]
382 [D loss: 0.700932, acc.: 59.38%] [G loss: 0.908998]
383 [D loss: 0.776947, acc.: 45.31%] [G loss: 0.985399]
384 [D loss: 0.816980, acc.: 37.50%] [G loss: 0.940297]
385 [D loss: 0.772488, acc.: 50.00%] [G loss: 0.890104]
386 [D loss: 0.712681, acc.: 53.12%] [G loss: 0.863937]
387 [D loss: 0.826732, acc.: 46.88%] [G loss: 0.909253]
388 [D loss: 0.910226, acc.: 32.81%] [G loss: 0.843665]
389 [D loss: 0.763310, acc.: 50.00%] [G loss: 0.

519 [D loss: 0.756301, acc.: 50.00%] [G loss: 0.877605]
520 [D loss: 0.744924, acc.: 54.69%] [G loss: 0.954007]
521 [D loss: 0.796599, acc.: 46.88%] [G loss: 1.097318]
522 [D loss: 0.901040, acc.: 42.19%] [G loss: 0.841578]
523 [D loss: 0.786325, acc.: 53.12%] [G loss: 0.989040]
524 [D loss: 0.735413, acc.: 53.12%] [G loss: 0.934489]
525 [D loss: 0.686691, acc.: 56.25%] [G loss: 1.091615]
526 [D loss: 0.829497, acc.: 37.50%] [G loss: 0.743876]
527 [D loss: 0.754539, acc.: 45.31%] [G loss: 1.072405]
528 [D loss: 0.686444, acc.: 56.25%] [G loss: 0.924680]
529 [D loss: 0.867660, acc.: 42.19%] [G loss: 0.927123]
530 [D loss: 0.825425, acc.: 40.62%] [G loss: 0.999630]
531 [D loss: 0.818684, acc.: 46.88%] [G loss: 1.063950]
532 [D loss: 0.794444, acc.: 43.75%] [G loss: 0.934129]
533 [D loss: 0.823258, acc.: 40.62%] [G loss: 0.995812]
534 [D loss: 0.797544, acc.: 42.19%] [G loss: 1.004571]
535 [D loss: 0.690372, acc.: 65.62%] [G loss: 0.903375]
536 [D loss: 0.812494, acc.: 45.31%] [G loss: 0.

667 [D loss: 0.718055, acc.: 54.69%] [G loss: 1.167697]
668 [D loss: 0.704033, acc.: 51.56%] [G loss: 0.963759]
669 [D loss: 0.803827, acc.: 48.44%] [G loss: 1.018965]
670 [D loss: 0.765557, acc.: 50.00%] [G loss: 0.775333]
671 [D loss: 0.702239, acc.: 51.56%] [G loss: 1.121590]
672 [D loss: 0.788842, acc.: 48.44%] [G loss: 1.000139]
673 [D loss: 0.700563, acc.: 50.00%] [G loss: 0.948287]
674 [D loss: 0.799315, acc.: 45.31%] [G loss: 0.958538]
675 [D loss: 0.728971, acc.: 46.88%] [G loss: 0.887369]
676 [D loss: 0.760038, acc.: 51.56%] [G loss: 0.888198]
677 [D loss: 0.690623, acc.: 56.25%] [G loss: 1.021244]
678 [D loss: 0.824115, acc.: 48.44%] [G loss: 0.946718]
679 [D loss: 0.661356, acc.: 56.25%] [G loss: 1.008405]
680 [D loss: 0.753410, acc.: 46.88%] [G loss: 1.037211]
681 [D loss: 0.737129, acc.: 53.12%] [G loss: 0.947881]
682 [D loss: 0.761025, acc.: 54.69%] [G loss: 0.910312]
683 [D loss: 0.807082, acc.: 45.31%] [G loss: 0.788077]
684 [D loss: 0.875679, acc.: 32.81%] [G loss: 0.

819 [D loss: 0.641937, acc.: 67.19%] [G loss: 0.903646]
820 [D loss: 0.660612, acc.: 59.38%] [G loss: 1.136652]
821 [D loss: 0.834312, acc.: 48.44%] [G loss: 0.906367]
822 [D loss: 0.727652, acc.: 56.25%] [G loss: 0.949559]
823 [D loss: 0.709825, acc.: 54.69%] [G loss: 0.963095]
824 [D loss: 0.842757, acc.: 43.75%] [G loss: 0.852526]
825 [D loss: 0.702861, acc.: 60.94%] [G loss: 1.088045]
826 [D loss: 0.727600, acc.: 56.25%] [G loss: 0.934490]
827 [D loss: 0.736174, acc.: 53.12%] [G loss: 0.975659]
828 [D loss: 0.737452, acc.: 54.69%] [G loss: 0.949612]
829 [D loss: 0.760531, acc.: 48.44%] [G loss: 0.915557]
830 [D loss: 0.765722, acc.: 50.00%] [G loss: 0.874049]
831 [D loss: 0.668007, acc.: 65.62%] [G loss: 1.177634]
832 [D loss: 0.630142, acc.: 57.81%] [G loss: 0.984288]
833 [D loss: 0.652585, acc.: 59.38%] [G loss: 0.973755]
834 [D loss: 0.697317, acc.: 50.00%] [G loss: 0.964671]
835 [D loss: 0.762977, acc.: 43.75%] [G loss: 0.967145]
836 [D loss: 0.817258, acc.: 39.06%] [G loss: 1.

971 [D loss: 0.855766, acc.: 53.12%] [G loss: 1.148401]
972 [D loss: 0.714060, acc.: 54.69%] [G loss: 0.891527]
973 [D loss: 0.756385, acc.: 53.12%] [G loss: 0.889993]
974 [D loss: 0.670347, acc.: 54.69%] [G loss: 0.971903]
975 [D loss: 0.661932, acc.: 59.38%] [G loss: 0.806111]
976 [D loss: 0.711199, acc.: 51.56%] [G loss: 1.013825]
977 [D loss: 0.715567, acc.: 56.25%] [G loss: 0.929060]
978 [D loss: 0.776022, acc.: 48.44%] [G loss: 0.860867]
979 [D loss: 0.759736, acc.: 42.19%] [G loss: 0.893738]
980 [D loss: 0.646381, acc.: 54.69%] [G loss: 0.980956]
981 [D loss: 0.715525, acc.: 60.94%] [G loss: 1.045639]
982 [D loss: 0.754872, acc.: 51.56%] [G loss: 1.081443]
983 [D loss: 0.797078, acc.: 40.62%] [G loss: 0.842014]
984 [D loss: 0.790355, acc.: 50.00%] [G loss: 0.925009]
985 [D loss: 0.773299, acc.: 48.44%] [G loss: 1.043961]
986 [D loss: 0.686424, acc.: 57.81%] [G loss: 0.997265]
987 [D loss: 0.717744, acc.: 57.81%] [G loss: 0.930494]
988 [D loss: 0.749049, acc.: 50.00%] [G loss: 0.

KeyboardInterrupt: 