In [8]:
from __future__ import print_function, division

In [2]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
import matplotlib.pyplot as plt

In [4]:
import sys

In [5]:
import numpy as np

In [6]:
class GAN():
    """Fully-connected GAN over 28x28x1 images; ``train`` fits it on MNIST.

    Attributes set by ``__init__``:
        img_rows, img_cols, channels, img_shape: geometry of one image.
        latent_dim: length of the noise vector fed to the generator.
        discriminator: compiled model mapping an image to a validity score.
        generator: model mapping a noise vector to an image.
        combined: generator followed by the discriminator, compiled; used
            to train the generator.
    """

    def __init__(self, latent_dim=100):
        """Build and compile the discriminator, generator and stacked model.

        Args:
            latent_dim: length of the generator's noise input. Defaults to
                100, the value previously hard-coded throughout the class.
        """
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = latent_dim

        # Positional args; presumably learning rate and beta_1 -- confirm
        # against the installed Keras version.
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs.
        # The input width must match the generator's own input (latent_dim).
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator.
        # NOTE(review): this flag is flipped after the discriminator was
        # compiled above and before `combined` is compiled below; it is
        # presumably the source of the "Discrepancy between trainable
        # weights" warning seen in the notebook output -- confirm against
        # the Keras version in use.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a Model mapping a ``(latent_dim,)`` noise vector to an image.

        Three Dense -> LeakyReLU -> BatchNormalization stages (256, 512,
        1024 units) followed by a tanh Dense layer reshaped to
        ``self.img_shape``. Prints the layer summary as a side effect.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # int(...) so the unit count is a plain Python int rather than a
        # NumPy integer scalar.
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a Model mapping an image of ``self.img_shape`` to one sigmoid output.

        Flatten -> Dense(512) -> LeakyReLU -> Dense(256) -> LeakyReLU ->
        Dense(1, sigmoid). Prints the layer summary as a side effect.
        """
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates on MNIST.

        Args:
            epochs: number of training iterations. Each iteration draws ONE
                random batch (with replacement); it is not a full pass over
                the dataset.
            batch_size: number of real images, and of generated images,
                used per iteration.
            sample_interval: call ``sample_images`` whenever
                ``epoch % sample_interval == 0`` (so also at epoch 0).

        Prints one progress line per iteration.
        """
        # Load the dataset (training images only; labels are unused)
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1 to match the generator's tanh output range
        X_train = X_train / 127.5 - 1.
        # Append a trailing channel axis
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator on real and generated batches
            # separately, then average the two [loss, accuracy] results.
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated images to ``images/<epoch>.png``.

        The ``images`` directory is created if it does not exist yet.

        Args:
            epoch: iteration number, used as the output file name.
        """
        # Function-scope import so this block does not depend on an import
        # cell elsewhere in the notebook.
        import os

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1

        # savefig does not create missing directories; make sure the target
        # exists (isdir check keeps this Python 2 compatible).
        if not os.path.isdir("images"):
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

In [7]:
if __name__ == '__main__':
    # Build the networks, then run 30000 single-batch training iterations,
    # writing a grid of generated samples every 200 iterations.
    gan = GAN()
    gan.train(30000, batch_size=32, sample_interval=200)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
____

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 1.039234, acc.: 50.00%] [G loss: 0.790982]
1 [D loss: 0.447138, acc.: 78.12%] [G loss: 0.819692]
2 [D loss: 0.352639, acc.: 84.38%] [G loss: 0.895983]
3 [D loss: 0.288668, acc.: 89.06%] [G loss: 1.006849]
4 [D loss: 0.259122, acc.: 93.75%] [G loss: 1.144986]
5 [D loss: 0.227110, acc.: 96.88%] [G loss: 1.248038]
6 [D loss: 0.185942, acc.: 96.88%] [G loss: 1.342166]
7 [D loss: 0.161515, acc.: 100.00%] [G loss: 1.426543]
8 [D loss: 0.164785, acc.: 96.88%] [G loss: 1.501548]
9 [D loss: 0.157062, acc.: 100.00%] [G loss: 1.667943]
10 [D loss: 0.119932, acc.: 100.00%] [G loss: 1.643316]
11 [D loss: 0.122005, acc.: 100.00%] [G loss: 1.800172]
12 [D loss: 0.119157, acc.: 98.44%] [G loss: 1.827316]
13 [D loss: 0.101439, acc.: 100.00%] [G loss: 2.033730]
14 [D loss: 0.098787, acc.: 100.00%] [G loss: 1.959269]
15 [D loss: 0.086246, acc.: 100.00%] [G loss: 2.032908]
16 [D loss: 0.079014, acc.: 100.00%] [G loss: 2.088560]
17 [D loss: 0.076932, acc.: 100.00%] [G loss: 2.216519]
18 [D loss:

148 [D loss: 0.231609, acc.: 87.50%] [G loss: 3.312974]
149 [D loss: 0.154273, acc.: 95.31%] [G loss: 3.889272]
150 [D loss: 0.349878, acc.: 82.81%] [G loss: 3.270880]
151 [D loss: 0.143616, acc.: 96.88%] [G loss: 4.284002]
152 [D loss: 0.706647, acc.: 75.00%] [G loss: 2.936396]
153 [D loss: 0.094626, acc.: 98.44%] [G loss: 3.507365]
154 [D loss: 0.098106, acc.: 98.44%] [G loss: 3.526769]
155 [D loss: 0.218619, acc.: 89.06%] [G loss: 3.582727]
156 [D loss: 0.452143, acc.: 73.44%] [G loss: 3.973213]
157 [D loss: 0.269161, acc.: 89.06%] [G loss: 3.347941]
158 [D loss: 0.170253, acc.: 96.88%] [G loss: 3.496710]
159 [D loss: 0.350018, acc.: 87.50%] [G loss: 2.878919]
160 [D loss: 0.143355, acc.: 96.88%] [G loss: 3.438194]
161 [D loss: 0.882504, acc.: 67.19%] [G loss: 2.110136]
162 [D loss: 0.178013, acc.: 90.62%] [G loss: 2.852975]
163 [D loss: 0.303628, acc.: 87.50%] [G loss: 2.841530]
164 [D loss: 0.246786, acc.: 89.06%] [G loss: 3.014622]
165 [D loss: 0.202978, acc.: 95.31%] [G loss: 2.

297 [D loss: 0.843540, acc.: 39.06%] [G loss: 0.995038]
298 [D loss: 0.622171, acc.: 60.94%] [G loss: 1.513192]
299 [D loss: 0.787911, acc.: 37.50%] [G loss: 1.079866]
300 [D loss: 0.630210, acc.: 56.25%] [G loss: 1.397316]
301 [D loss: 0.687848, acc.: 56.25%] [G loss: 1.179470]
302 [D loss: 0.579277, acc.: 71.88%] [G loss: 1.473758]
303 [D loss: 0.647292, acc.: 60.94%] [G loss: 1.195024]
304 [D loss: 0.685287, acc.: 57.81%] [G loss: 1.283339]
305 [D loss: 0.638916, acc.: 60.94%] [G loss: 1.271784]
306 [D loss: 0.668066, acc.: 56.25%] [G loss: 1.174040]
307 [D loss: 0.676945, acc.: 60.94%] [G loss: 1.188412]
308 [D loss: 0.585641, acc.: 64.06%] [G loss: 1.065813]
309 [D loss: 0.852167, acc.: 43.75%] [G loss: 0.836108]
310 [D loss: 0.677417, acc.: 54.69%] [G loss: 1.012174]
311 [D loss: 0.691543, acc.: 46.88%] [G loss: 1.175586]
312 [D loss: 0.934452, acc.: 25.00%] [G loss: 0.741453]
313 [D loss: 0.690563, acc.: 53.12%] [G loss: 0.960199]
314 [D loss: 0.649495, acc.: 62.50%] [G loss: 1.

445 [D loss: 0.677230, acc.: 45.31%] [G loss: 0.694951]
446 [D loss: 0.653687, acc.: 50.00%] [G loss: 0.701994]
447 [D loss: 0.651185, acc.: 46.88%] [G loss: 0.721748]
448 [D loss: 0.649712, acc.: 50.00%] [G loss: 0.719996]
449 [D loss: 0.658161, acc.: 50.00%] [G loss: 0.719654]
450 [D loss: 0.642728, acc.: 51.56%] [G loss: 0.746638]
451 [D loss: 0.626432, acc.: 53.12%] [G loss: 0.736306]
452 [D loss: 0.615599, acc.: 56.25%] [G loss: 0.746634]
453 [D loss: 0.631092, acc.: 54.69%] [G loss: 0.754494]
454 [D loss: 0.630078, acc.: 62.50%] [G loss: 0.747873]
455 [D loss: 0.603085, acc.: 64.06%] [G loss: 0.741145]
456 [D loss: 0.605007, acc.: 64.06%] [G loss: 0.753742]
457 [D loss: 0.621507, acc.: 59.38%] [G loss: 0.775674]
458 [D loss: 0.660868, acc.: 51.56%] [G loss: 0.741692]
459 [D loss: 0.641177, acc.: 50.00%] [G loss: 0.729628]
460 [D loss: 0.633574, acc.: 54.69%] [G loss: 0.734864]
461 [D loss: 0.676517, acc.: 51.56%] [G loss: 0.710771]
462 [D loss: 0.617529, acc.: 56.25%] [G loss: 0.

592 [D loss: 0.652238, acc.: 57.81%] [G loss: 0.753224]
593 [D loss: 0.670662, acc.: 50.00%] [G loss: 0.726719]
594 [D loss: 0.643540, acc.: 62.50%] [G loss: 0.735327]
595 [D loss: 0.629363, acc.: 60.94%] [G loss: 0.764542]
596 [D loss: 0.649915, acc.: 59.38%] [G loss: 0.739950]
597 [D loss: 0.661535, acc.: 59.38%] [G loss: 0.770043]
598 [D loss: 0.622566, acc.: 62.50%] [G loss: 0.784754]
599 [D loss: 0.644223, acc.: 56.25%] [G loss: 0.771602]
600 [D loss: 0.637295, acc.: 60.94%] [G loss: 0.782759]
601 [D loss: 0.661038, acc.: 51.56%] [G loss: 0.775782]
602 [D loss: 0.611352, acc.: 71.88%] [G loss: 0.775686]
603 [D loss: 0.643964, acc.: 68.75%] [G loss: 0.761788]
604 [D loss: 0.646055, acc.: 57.81%] [G loss: 0.767645]
605 [D loss: 0.649958, acc.: 54.69%] [G loss: 0.776400]
606 [D loss: 0.635963, acc.: 60.94%] [G loss: 0.798469]
607 [D loss: 0.682827, acc.: 46.88%] [G loss: 0.762610]
608 [D loss: 0.637136, acc.: 59.38%] [G loss: 0.757759]
609 [D loss: 0.632424, acc.: 56.25%] [G loss: 0.

740 [D loss: 0.627298, acc.: 65.62%] [G loss: 0.821568]
741 [D loss: 0.598128, acc.: 71.88%] [G loss: 0.813590]
742 [D loss: 0.609303, acc.: 68.75%] [G loss: 0.822433]
743 [D loss: 0.608781, acc.: 65.62%] [G loss: 0.822203]
744 [D loss: 0.618708, acc.: 67.19%] [G loss: 0.833003]
745 [D loss: 0.639489, acc.: 60.94%] [G loss: 0.831568]
746 [D loss: 0.606726, acc.: 67.19%] [G loss: 0.834782]
747 [D loss: 0.611519, acc.: 71.88%] [G loss: 0.814656]
748 [D loss: 0.565677, acc.: 81.25%] [G loss: 0.811943]
749 [D loss: 0.623799, acc.: 59.38%] [G loss: 0.812000]
750 [D loss: 0.610264, acc.: 65.62%] [G loss: 0.805086]
751 [D loss: 0.624873, acc.: 57.81%] [G loss: 0.822593]
752 [D loss: 0.634506, acc.: 64.06%] [G loss: 0.827058]
753 [D loss: 0.617995, acc.: 71.88%] [G loss: 0.860390]
754 [D loss: 0.653347, acc.: 57.81%] [G loss: 0.844799]
755 [D loss: 0.594134, acc.: 73.44%] [G loss: 0.824087]
756 [D loss: 0.617424, acc.: 68.75%] [G loss: 0.808748]
757 [D loss: 0.617346, acc.: 60.94%] [G loss: 0.

887 [D loss: 0.600051, acc.: 70.31%] [G loss: 0.787597]
888 [D loss: 0.607244, acc.: 67.19%] [G loss: 0.793436]
889 [D loss: 0.627368, acc.: 59.38%] [G loss: 0.818939]
890 [D loss: 0.608815, acc.: 70.31%] [G loss: 0.793801]
891 [D loss: 0.626025, acc.: 71.88%] [G loss: 0.802948]
892 [D loss: 0.637582, acc.: 64.06%] [G loss: 0.782724]
893 [D loss: 0.614354, acc.: 64.06%] [G loss: 0.815331]
894 [D loss: 0.606608, acc.: 64.06%] [G loss: 0.842756]
895 [D loss: 0.599406, acc.: 73.44%] [G loss: 0.804134]
896 [D loss: 0.655301, acc.: 57.81%] [G loss: 0.789476]
897 [D loss: 0.605803, acc.: 70.31%] [G loss: 0.787119]
898 [D loss: 0.598228, acc.: 67.19%] [G loss: 0.803959]
899 [D loss: 0.606782, acc.: 70.31%] [G loss: 0.825329]
900 [D loss: 0.635665, acc.: 62.50%] [G loss: 0.821648]
901 [D loss: 0.649948, acc.: 57.81%] [G loss: 0.847337]
902 [D loss: 0.598033, acc.: 73.44%] [G loss: 0.872109]
903 [D loss: 0.615018, acc.: 67.19%] [G loss: 0.885124]
904 [D loss: 0.625888, acc.: 73.44%] [G loss: 0.

1034 [D loss: 0.537844, acc.: 85.94%] [G loss: 0.926093]
1035 [D loss: 0.566874, acc.: 82.81%] [G loss: 0.936471]
1036 [D loss: 0.600901, acc.: 73.44%] [G loss: 0.873445]
1037 [D loss: 0.576367, acc.: 79.69%] [G loss: 0.908608]
1038 [D loss: 0.633784, acc.: 70.31%] [G loss: 0.875882]
1039 [D loss: 0.604270, acc.: 71.88%] [G loss: 0.917000]
1040 [D loss: 0.611861, acc.: 73.44%] [G loss: 0.877870]
1041 [D loss: 0.590397, acc.: 73.44%] [G loss: 0.870021]
1042 [D loss: 0.575261, acc.: 75.00%] [G loss: 0.847110]
1043 [D loss: 0.595783, acc.: 79.69%] [G loss: 0.869876]
1044 [D loss: 0.581152, acc.: 73.44%] [G loss: 0.879406]
1045 [D loss: 0.605585, acc.: 64.06%] [G loss: 0.845664]
1046 [D loss: 0.595882, acc.: 78.12%] [G loss: 0.876356]
1047 [D loss: 0.595475, acc.: 71.88%] [G loss: 0.882981]
1048 [D loss: 0.561725, acc.: 78.12%] [G loss: 0.900892]
1049 [D loss: 0.616193, acc.: 71.88%] [G loss: 0.862519]
1050 [D loss: 0.573112, acc.: 82.81%] [G loss: 0.900053]
1051 [D loss: 0.585173, acc.: 7

1178 [D loss: 0.606752, acc.: 68.75%] [G loss: 0.924847]
1179 [D loss: 0.625308, acc.: 62.50%] [G loss: 0.966372]
1180 [D loss: 0.572664, acc.: 79.69%] [G loss: 0.923102]
1181 [D loss: 0.607483, acc.: 70.31%] [G loss: 0.899372]
1182 [D loss: 0.580021, acc.: 71.88%] [G loss: 0.877638]
1183 [D loss: 0.560993, acc.: 73.44%] [G loss: 0.909315]
1184 [D loss: 0.561012, acc.: 81.25%] [G loss: 0.881297]
1185 [D loss: 0.604904, acc.: 67.19%] [G loss: 0.921420]
1186 [D loss: 0.577526, acc.: 75.00%] [G loss: 0.928809]
1187 [D loss: 0.594941, acc.: 75.00%] [G loss: 0.906286]
1188 [D loss: 0.545067, acc.: 76.56%] [G loss: 0.968898]
1189 [D loss: 0.695269, acc.: 48.44%] [G loss: 0.869926]
1190 [D loss: 0.544547, acc.: 81.25%] [G loss: 0.993310]
1191 [D loss: 0.526085, acc.: 87.50%] [G loss: 1.000656]
1192 [D loss: 0.580419, acc.: 76.56%] [G loss: 0.929203]
1193 [D loss: 0.552430, acc.: 82.81%] [G loss: 0.902621]
1194 [D loss: 0.563710, acc.: 68.75%] [G loss: 0.903166]
1195 [D loss: 0.618345, acc.: 6

1322 [D loss: 0.612187, acc.: 65.62%] [G loss: 0.907871]
1323 [D loss: 0.602788, acc.: 68.75%] [G loss: 0.887438]
1324 [D loss: 0.616627, acc.: 70.31%] [G loss: 0.858368]
1325 [D loss: 0.571182, acc.: 70.31%] [G loss: 0.836175]
1326 [D loss: 0.629228, acc.: 62.50%] [G loss: 0.954596]
1327 [D loss: 0.583036, acc.: 75.00%] [G loss: 0.953842]
1328 [D loss: 0.639980, acc.: 65.62%] [G loss: 0.919710]
1329 [D loss: 0.642002, acc.: 59.38%] [G loss: 0.892883]
1330 [D loss: 0.567804, acc.: 73.44%] [G loss: 0.943448]
1331 [D loss: 0.646589, acc.: 60.94%] [G loss: 0.901417]
1332 [D loss: 0.616387, acc.: 65.62%] [G loss: 0.968925]
1333 [D loss: 0.540743, acc.: 81.25%] [G loss: 0.978859]
1334 [D loss: 0.585187, acc.: 67.19%] [G loss: 1.005361]
1335 [D loss: 0.608524, acc.: 76.56%] [G loss: 0.891405]
1336 [D loss: 0.580431, acc.: 71.88%] [G loss: 0.915802]
1337 [D loss: 0.611278, acc.: 64.06%] [G loss: 0.911230]
1338 [D loss: 0.577546, acc.: 70.31%] [G loss: 0.950059]
1339 [D loss: 0.618587, acc.: 6

KeyboardInterrupt: 