In [None]:
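# Reference: Radford, Metz & Chintala, "Unsupervised Representation Learning with
# Deep Convolutional Generative Adversarial Networks" (DCGAN), https://arxiv.org/pdf/1511.06434.pdf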

In [1]:
from __future__ import print_function, division

import os
import sys

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, MaxPooling2D, BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [8]:
class DCGAN():
    """Deep Convolutional GAN for 28x28 single-channel MNIST digits.

    Builds three Keras models:
      * ``generator``: latent vector of size ``latent_dim`` -> image in [-1, 1].
      * ``discriminator``: image -> probability (sigmoid) that the image is real.
      * ``combined_model``: generator followed by a *frozen* discriminator, so
        training it updates only the generator's weights.

    Loosely follows Radford et al. (arXiv:1511.06434), but this variant uses
    tanh hidden activations, UpSampling2D and MaxPooling2D instead of the
    paper's ReLU/LeakyReLU and (fractionally-)strided convolutions.
    """

    def __init__(self):
        self.latent_dim = 100
        self.rows = 28
        self.cols = 28
        self.channels = 1

        # Not used by either model; kept because it is a public attribute.
        self.sgd = SGD(lr=0.0005, decay=2e-5, momentum=0.9, nesterov=True)

        # Each compiled model gets its OWN optimizer instance. A Keras
        # optimizer keeps per-instance state: the `iterations` counter that
        # drives `decay` and Adam's bias correction, plus the moment slots.
        # If one instance were shared, both models' updates would advance
        # that state.
        self.adam = self._make_adam()           # discriminator optimizer
        self.combined_adam = self._make_adam()  # combined (generator) optimizer

        self.discriminator = self.discriminator_model()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=self.adam, metrics=["accuracy"])

        self.generator = self.generator_model()

        generator_input = Input(shape=(self.latent_dim,))
        generator_output = self.generator(generator_input)
        # Keras applies `trainable` at compile time. Freezing here, after the
        # standalone discriminator was compiled, freezes it only inside the
        # combined model, where it should act as a fixed critic.
        self.discriminator.trainable = False
        discriminator_output = self.discriminator(generator_output)

        self.combined_model = Model(generator_input, discriminator_output)
        self.combined_model.compile(loss='binary_crossentropy', optimizer=self.combined_adam, metrics=["accuracy"])

    def _make_adam(self):
        """Return a fresh Adam optimizer with the hyperparameters used by this GAN."""
        return Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=5e-4)

    def generator_model(self):
        """Build the generator: (latent_dim,) noise -> (28, 28, channels) image.

        Two Dense layers expand the noise into a 7x7x128 feature map. Two
        2x UpSampling2D + Conv2D stages then grow it to 28x28. The output
        size is hard-wired by 7 * 2 * 2 and does not read ``self.rows`` /
        ``self.cols``. The final tanh keeps pixels in [-1, 1], the same range
        ``train`` scales the real images to.
        """
        gen_input = Input(shape=(self.latent_dim,))
        dense_1 = Dense(1024, activation="tanh")(gen_input)
        dense_1 = Dense(128 * 7 * 7, activation="tanh")(dense_1)
        dense_1 = BatchNormalization()(dense_1)
        dense_reshape = Reshape((7, 7, 128))(dense_1)

        upsample_1 = UpSampling2D(size=2)(dense_reshape)    # 7x7 -> 14x14
        conv_1 = Conv2D(64, kernel_size=5, padding="same", activation='tanh')(upsample_1)
        conv_1 = BatchNormalization()(conv_1)

        upsample_2 = UpSampling2D(size=2)(conv_1)           # 14x14 -> 28x28
        out = Conv2D(self.channels, kernel_size=5, padding="same", activation='tanh')(upsample_2)

        model = Model(gen_input, out)

        # summary() prints by itself and returns None; wrapping it in print()
        # would emit a stray "None".
        model.summary()

        return model

    def discriminator_model(self):
        """Build the discriminator: (rows, cols, channels) image -> P(real).

        Two Conv2D(tanh) + 2x2 MaxPooling stages (28 -> 14 -> 7 for MNIST),
        then a 1024-unit tanh Dense layer and a single sigmoid output unit.
        """
        dis_input = Input(shape=(self.rows, self.cols, self.channels))

        conv_1 = Conv2D(64, kernel_size=5, padding="same", activation="tanh")(dis_input)
        pool_1 = MaxPooling2D(pool_size=(2, 2))(conv_1)

        conv_2 = Conv2D(128, kernel_size=5, padding="same", activation="tanh")(pool_1)
        pool_2 = MaxPooling2D(pool_size=(2, 2))(conv_2)

        flatten = Flatten()(pool_2)
        flatten = Dense(1024, activation='tanh')(flatten)

        out = Dense(1, activation="sigmoid")(flatten)

        model = Model(dis_input, out)

        model.summary()
        return model

    def train(self, epochs, batch_size=128, save_interval=100, model_path='combined_model.h5'):
        """Adversarially train the generator and discriminator on MNIST.

        Each epoch runs ``len(X_train) // batch_size`` iterations. Every
        iteration draws a random batch *with replacement*, so an epoch is not
        a strict single pass over the data. At the end of each epoch a sample
        image grid is saved and the combined model is written to disk.

        Args:
            epochs: number of epochs to run.
            batch_size: real (and fake) images per discriminator update.
            save_interval: print losses every ``save_interval`` iterations
                within an epoch. Values <= 0 disable this logging instead of
                raising ZeroDivisionError.
            model_path: file the combined model is saved to after each epoch.
        """
        (X_train, _), (_, _) = mnist.load_data()
        # Map uint8 pixels [0, 255] to [-1, 1], matching the generator's tanh
        # output. float32 halves memory compared with the float64 default.
        X_train = X_train.astype(np.float32) / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)   # add the channel axis
        print(X_train.shape)

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for i in range(epochs):
            for iter_count in range(X_train.shape[0] // batch_size):
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise)

                # Discriminator step: real images labelled 1, generated ones 0.
                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # Generator step: push the frozen discriminator towards "real".
                g_loss = self.combined_model.train_on_batch(noise, valid)

                if save_interval > 0 and iter_count % save_interval == 0:
                    print("%d %d [D loss: %f, acc.: %.2f%%] [G loss: %f, acc.: %.2f%%]" % (i, iter_count, d_loss[0], 100*d_loss[1], g_loss[0], 100*g_loss[1]))

            self.save_imgs(i)
            self.combined_model.save(model_path)

    def save_imgs(self, epoch, image_dir="images"):
        """Save a 5x5 grid of generated digits as ``<image_dir>/mnist_<epoch>.png``.

        Creates ``image_dir`` if it does not exist.
        """
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images from the generator's [-1, 1] range to [0, 1] for imshow.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching sample order.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')

        # savefig does not create missing directories.
        if not os.path.isdir(image_dir):
            os.makedirs(image_dir)
        fig.savefig(os.path.join(image_dir, "mnist_%d.png" % epoch))
        # Close this specific figure so repeated calls do not accumulate figures.
        plt.close(fig)
        

In [9]:
if __name__ == '__main__':
    # Training hyperparameters for the MNIST run.
    run_settings = {'epochs': 80, 'batch_size': 128, 'save_interval': 50}

    dcgan = DCGAN()
    dcgan.train(**run_settings)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_10 (InputLayer)        (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 28, 28, 64)        1664      
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 14, 14, 128)       204928    
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_13 (Dense)             (None, 1024)              6423552   
__________

7 150 [D loss: 0.371119, acc.: 84.77%] [G loss: 1.634735, acc.: 10.16%]
7 200 [D loss: 0.298690, acc.: 91.41%] [G loss: 1.952257, acc.: 4.69%]
7 250 [D loss: 0.376679, acc.: 83.98%] [G loss: 1.655778, acc.: 12.50%]
7 300 [D loss: 0.390392, acc.: 82.81%] [G loss: 1.654531, acc.: 7.81%]
7 350 [D loss: 0.364750, acc.: 84.38%] [G loss: 1.729358, acc.: 10.16%]
7 400 [D loss: 0.348464, acc.: 86.33%] [G loss: 1.760110, acc.: 6.25%]
7 450 [D loss: 0.412649, acc.: 82.81%] [G loss: 1.623608, acc.: 10.94%]
8 0 [D loss: 0.361467, acc.: 83.59%] [G loss: 1.721581, acc.: 8.59%]
8 50 [D loss: 0.340875, acc.: 88.28%] [G loss: 1.756008, acc.: 6.25%]
8 100 [D loss: 0.357502, acc.: 85.55%] [G loss: 1.708528, acc.: 8.59%]
8 150 [D loss: 0.406959, acc.: 82.03%] [G loss: 1.545867, acc.: 13.28%]
8 200 [D loss: 0.332079, acc.: 90.23%] [G loss: 1.606299, acc.: 7.81%]
8 250 [D loss: 0.392856, acc.: 82.42%] [G loss: 1.687578, acc.: 10.16%]
8 300 [D loss: 0.359290, acc.: 86.72%] [G loss: 1.738871, acc.: 8.59%]
8 3

18 350 [D loss: 0.513635, acc.: 73.05%] [G loss: 1.214139, acc.: 21.09%]
18 400 [D loss: 0.488209, acc.: 78.12%] [G loss: 1.251646, acc.: 18.75%]
18 450 [D loss: 0.484658, acc.: 77.34%] [G loss: 1.306176, acc.: 13.28%]
19 0 [D loss: 0.495206, acc.: 76.17%] [G loss: 1.229553, acc.: 23.44%]
19 50 [D loss: 0.509863, acc.: 76.95%] [G loss: 1.306238, acc.: 14.06%]
19 100 [D loss: 0.524494, acc.: 76.17%] [G loss: 1.251524, acc.: 14.84%]
19 150 [D loss: 0.496396, acc.: 77.73%] [G loss: 1.369751, acc.: 9.38%]
19 200 [D loss: 0.465988, acc.: 76.95%] [G loss: 1.272873, acc.: 17.97%]
19 250 [D loss: 0.555810, acc.: 71.88%] [G loss: 1.206947, acc.: 19.53%]
19 300 [D loss: 0.499247, acc.: 76.56%] [G loss: 1.240422, acc.: 19.53%]
19 350 [D loss: 0.521939, acc.: 76.95%] [G loss: 1.166323, acc.: 17.97%]
19 400 [D loss: 0.492757, acc.: 75.78%] [G loss: 1.223638, acc.: 21.88%]
19 450 [D loss: 0.454394, acc.: 79.30%] [G loss: 1.206277, acc.: 14.06%]
20 0 [D loss: 0.491519, acc.: 77.73%] [G loss: 1.333728

30 0 [D loss: 0.492449, acc.: 74.61%] [G loss: 1.204600, acc.: 21.09%]
30 50 [D loss: 0.487481, acc.: 76.56%] [G loss: 1.173996, acc.: 21.88%]
30 100 [D loss: 0.529163, acc.: 76.17%] [G loss: 1.178397, acc.: 15.62%]
30 150 [D loss: 0.528704, acc.: 75.39%] [G loss: 1.186449, acc.: 19.53%]
30 200 [D loss: 0.530522, acc.: 71.88%] [G loss: 1.115368, acc.: 24.22%]
30 250 [D loss: 0.505788, acc.: 76.95%] [G loss: 1.255973, acc.: 10.94%]
30 300 [D loss: 0.554526, acc.: 71.48%] [G loss: 1.049640, acc.: 28.12%]
30 350 [D loss: 0.513671, acc.: 76.95%] [G loss: 1.142732, acc.: 20.31%]
30 400 [D loss: 0.476049, acc.: 77.34%] [G loss: 1.224875, acc.: 15.62%]
30 450 [D loss: 0.515460, acc.: 74.22%] [G loss: 1.246697, acc.: 11.72%]
31 0 [D loss: 0.532045, acc.: 73.05%] [G loss: 1.195916, acc.: 14.84%]
31 50 [D loss: 0.503788, acc.: 75.78%] [G loss: 1.191438, acc.: 21.88%]
31 100 [D loss: 0.500402, acc.: 74.61%] [G loss: 1.138549, acc.: 19.53%]
31 150 [D loss: 0.462133, acc.: 79.30%] [G loss: 1.301240

41 150 [D loss: 0.537245, acc.: 73.05%] [G loss: 1.161621, acc.: 17.97%]
41 200 [D loss: 0.500524, acc.: 75.00%] [G loss: 1.158519, acc.: 17.97%]
41 250 [D loss: 0.536505, acc.: 73.05%] [G loss: 1.123056, acc.: 23.44%]
41 300 [D loss: 0.482826, acc.: 75.78%] [G loss: 1.193088, acc.: 20.31%]
41 350 [D loss: 0.526603, acc.: 78.91%] [G loss: 1.086971, acc.: 24.22%]
41 400 [D loss: 0.504180, acc.: 76.56%] [G loss: 1.152435, acc.: 15.62%]
41 450 [D loss: 0.552327, acc.: 73.44%] [G loss: 1.106067, acc.: 26.56%]
42 0 [D loss: 0.511043, acc.: 75.39%] [G loss: 1.089278, acc.: 18.75%]
42 50 [D loss: 0.534058, acc.: 74.61%] [G loss: 1.123866, acc.: 21.09%]
42 100 [D loss: 0.528888, acc.: 71.88%] [G loss: 1.130122, acc.: 21.88%]
42 150 [D loss: 0.498924, acc.: 77.34%] [G loss: 1.204607, acc.: 14.84%]
42 200 [D loss: 0.484959, acc.: 78.91%] [G loss: 1.197028, acc.: 17.19%]
42 250 [D loss: 0.528579, acc.: 73.44%] [G loss: 1.138817, acc.: 17.97%]
42 300 [D loss: 0.535607, acc.: 73.44%] [G loss: 1.174

52 300 [D loss: 0.515929, acc.: 73.44%] [G loss: 1.188671, acc.: 18.75%]
52 350 [D loss: 0.475070, acc.: 77.73%] [G loss: 1.186671, acc.: 20.31%]
52 400 [D loss: 0.503523, acc.: 73.44%] [G loss: 1.095903, acc.: 23.44%]
52 450 [D loss: 0.526235, acc.: 73.83%] [G loss: 1.161411, acc.: 18.75%]
53 0 [D loss: 0.501170, acc.: 74.22%] [G loss: 1.228190, acc.: 16.41%]
53 50 [D loss: 0.511831, acc.: 76.17%] [G loss: 1.097444, acc.: 26.56%]
53 100 [D loss: 0.535668, acc.: 75.39%] [G loss: 1.121582, acc.: 21.09%]
53 150 [D loss: 0.514975, acc.: 76.56%] [G loss: 1.197478, acc.: 21.09%]
53 200 [D loss: 0.515652, acc.: 76.56%] [G loss: 1.142121, acc.: 17.97%]
53 250 [D loss: 0.475955, acc.: 79.69%] [G loss: 1.141556, acc.: 18.75%]
53 300 [D loss: 0.521538, acc.: 74.22%] [G loss: 1.176307, acc.: 19.53%]
53 350 [D loss: 0.593003, acc.: 69.14%] [G loss: 1.090773, acc.: 25.78%]
53 400 [D loss: 0.495460, acc.: 76.95%] [G loss: 1.168342, acc.: 21.88%]
53 450 [D loss: 0.515881, acc.: 76.95%] [G loss: 1.172

63 450 [D loss: 0.558838, acc.: 71.09%] [G loss: 1.082309, acc.: 17.97%]
64 0 [D loss: 0.523711, acc.: 74.61%] [G loss: 1.116993, acc.: 18.75%]
64 50 [D loss: 0.513154, acc.: 73.05%] [G loss: 1.083929, acc.: 26.56%]
64 100 [D loss: 0.515332, acc.: 76.56%] [G loss: 1.107011, acc.: 22.66%]
64 150 [D loss: 0.551476, acc.: 71.09%] [G loss: 1.075144, acc.: 22.66%]
64 200 [D loss: 0.556912, acc.: 69.53%] [G loss: 1.138586, acc.: 17.97%]
64 250 [D loss: 0.573630, acc.: 70.31%] [G loss: 1.106667, acc.: 27.34%]
64 300 [D loss: 0.536184, acc.: 73.05%] [G loss: 1.138114, acc.: 16.41%]
64 350 [D loss: 0.521057, acc.: 75.00%] [G loss: 1.110577, acc.: 23.44%]
64 400 [D loss: 0.559045, acc.: 75.00%] [G loss: 1.060055, acc.: 28.12%]
64 450 [D loss: 0.557193, acc.: 72.66%] [G loss: 1.082734, acc.: 23.44%]
65 0 [D loss: 0.496901, acc.: 77.73%] [G loss: 1.156622, acc.: 21.09%]
65 50 [D loss: 0.542533, acc.: 72.66%] [G loss: 1.044188, acc.: 25.00%]
65 100 [D loss: 0.559922, acc.: 66.80%] [G loss: 1.127904

75 100 [D loss: 0.540562, acc.: 73.05%] [G loss: 1.102075, acc.: 21.09%]
75 150 [D loss: 0.560738, acc.: 73.05%] [G loss: 1.134668, acc.: 19.53%]
75 200 [D loss: 0.524836, acc.: 74.61%] [G loss: 1.152559, acc.: 22.66%]
75 250 [D loss: 0.536612, acc.: 75.39%] [G loss: 1.047552, acc.: 27.34%]
75 300 [D loss: 0.510277, acc.: 75.39%] [G loss: 1.122161, acc.: 22.66%]
75 350 [D loss: 0.511582, acc.: 76.56%] [G loss: 1.139919, acc.: 19.53%]
75 400 [D loss: 0.544493, acc.: 71.09%] [G loss: 1.009753, acc.: 28.91%]
75 450 [D loss: 0.507205, acc.: 75.39%] [G loss: 1.090768, acc.: 22.66%]
76 0 [D loss: 0.606349, acc.: 72.27%] [G loss: 1.041965, acc.: 23.44%]
76 50 [D loss: 0.538095, acc.: 71.48%] [G loss: 1.131898, acc.: 25.00%]
76 100 [D loss: 0.537540, acc.: 73.44%] [G loss: 1.168484, acc.: 19.53%]
76 150 [D loss: 0.504379, acc.: 77.34%] [G loss: 1.075953, acc.: 23.44%]
76 200 [D loss: 0.510217, acc.: 72.27%] [G loss: 1.180113, acc.: 21.88%]
76 250 [D loss: 0.509418, acc.: 74.61%] [G loss: 1.158