## GAN: a fully connected generative adversarial network trained on MNIST

In [13]:
from __future__ import print_function, division

import os
import sys

import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.initializers import random_normal

In [14]:
class GAN():
    """Fully connected GAN that learns to generate 28x28 single-channel MNIST digits.

    Three Keras models are built:

    * ``self.generator`` maps a noise vector of shape ``(latent_dim,)`` to an
      image of shape ``img_shape`` with values in [-1, 1] (tanh output).
    * ``self.discriminator`` maps an image to a single sigmoid score that the
      image is real.
    * ``self.combined`` chains generator -> frozen discriminator and is used
      only to update the generator.
    """

    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Adam(learning rate, beta_1); one instance is shared by both compiled models.
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator while it is still trainable, so
        # self.discriminator.train_on_batch updates its own weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # Freeze the discriminator only *after* compiling it above: Keras
        # applies the ``trainable`` flag at compile time, so the combined
        # model compiled below trains just the generator, while the
        # discriminator's own compiled model keeps training. Keras may warn
        # "Discrepancy between trainable weights and collected trainable"
        # because of this ordering; that is expected here.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator.
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Build the generator: noise of shape (latent_dim,) -> image of shape img_shape.

        Three Dense(256/512/1024) + LeakyReLU(0.2) + BatchNormalization stages
        are followed by a tanh Dense layer reshaped to ``img_shape``, so pixel
        outputs lie in [-1, 1] like the rescaled training images.

        Returns:
            A functional ``Model`` wrapping the Sequential stack.
        """
        model = Sequential()

        model.add(Dense(256, kernel_initializer=random_normal(), input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: image of shape img_shape -> one sigmoid score.

        The image is flattened and passed through Dense(1024/512/256) +
        LeakyReLU(0.2) stages, then a single-unit sigmoid Dense layer, so the
        output has shape ``(batch, 1)``.

        Returns:
            A functional ``Model`` wrapping the Sequential stack.
        """
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(1024, kernel_initializer=random_normal()))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates on MNIST.

        Each iteration draws one random mini-batch, so ``epochs`` counts
        mini-batch updates rather than full passes over the dataset.
        Iterations ``0 .. epochs`` (inclusive) are run.

        Args:
            epochs: index of the last training iteration.
            batch_size: number of real images and of generated images used in
                each discriminator update (also the generator batch size).
            sample_interval: save a sample grid every this many iterations,
                including iteration 0. Values <= 0 disable sampling.
        """
        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale pixel values from [0, 255] to [-1, 1] to match the
        # generator's tanh output. float32 avoids numpy's default float64,
        # halving the memory held by the training set.
        X_train = X_train.astype(np.float32) / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths, shaped (batch_size, 1) to match the
        # discriminator's Dense(1) output instead of relying on the framework
        # to reshape a 1-D target.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs+1):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images (sampled with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator; each call returns [loss, accuracy]
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if sample_interval > 0 and epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, sample_dir="images"):
        """Save a 5x5 grid of generated digits to ``<sample_dir>/<epoch>.png``.

        Args:
            epoch: training iteration, used as the file name.
            sample_dir: output directory, created if it does not exist.
                Defaults to "images", the path used by the original code.
        """
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images from [-1, 1] to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching the sample order.
        for k, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[k, :, :, 0], cmap='gray')
            ax.axis('off')

        # savefig does not create missing directories.
        if not os.path.isdir(sample_dir):
            os.makedirs(sample_dir)
        fig.savefig(os.path.join(sample_dir, "%d.png" % epoch))
        # Close this specific figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [7]:
# Scratch check: an int shape gives a flat label vector of shape (128,).
np.ones(shape=128)

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [8]:
# Scratch check: a (128, 1) column of ones, the same shape as the
# discriminator's Dense(1) output for a batch of 128 images.
np.ones(shape=(128, 1))

array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],

In [12]:
if __name__ == "__main__":
    # Build the three models, then run iterations 0..1000 on MNIST,
    # saving a sample grid every 200 iterations.
    gan = GAN()
    gan.train(sample_interval=200, batch_size=128, epochs=1000)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
____

  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.502631, acc.: 55.86%] [G loss: 0.608233]
2 [D loss: 0.378939, acc.: 72.66%] [G loss: 0.782753]
3 [D loss: 0.314956, acc.: 83.98%] [G loss: 0.929062]
4 [D loss: 0.266350, acc.: 93.36%] [G loss: 1.078284]
5 [D loss: 0.219549, acc.: 97.27%] [G loss: 1.225582]
6 [D loss: 0.193964, acc.: 98.44%] [G loss: 1.366628]
7 [D loss: 0.174427, acc.: 98.44%] [G loss: 1.511957]
8 [D loss: 0.146550, acc.: 99.61%] [G loss: 1.632651]
9 [D loss: 0.143761, acc.: 99.61%] [G loss: 1.704270]
10 [D loss: 0.125060, acc.: 99.61%] [G loss: 1.790219]
11 [D loss: 0.117557, acc.: 99.61%] [G loss: 1.878601]
12 [D loss: 0.104076, acc.: 100.00%] [G loss: 1.948069]
13 [D loss: 0.094227, acc.: 100.00%] [G loss: 2.061252]
14 [D loss: 0.085229, acc.: 100.00%] [G loss: 2.141508]
15 [D loss: 0.070948, acc.: 100.00%] [G loss: 2.227463]
16 [D loss: 0.072642, acc.: 100.00%] [G loss: 2.292915]
17 [D loss: 0.066963, acc.: 100.00%] [G loss: 2.346740]
18 [D loss: 0.062119, acc.: 100.00%] [G loss: 2.418702]
19 [D loss: 

149 [D loss: 0.234284, acc.: 92.19%] [G loss: 2.643116]
150 [D loss: 0.309858, acc.: 82.81%] [G loss: 3.535574]
151 [D loss: 1.477806, acc.: 39.45%] [G loss: 1.179516]
152 [D loss: 1.036940, acc.: 60.94%] [G loss: 1.196026]
153 [D loss: 0.434293, acc.: 70.31%] [G loss: 1.960806]
154 [D loss: 0.135506, acc.: 96.09%] [G loss: 3.063986]
155 [D loss: 0.225488, acc.: 96.88%] [G loss: 2.541502]
156 [D loss: 0.258233, acc.: 92.19%] [G loss: 2.127388]
157 [D loss: 0.219349, acc.: 93.75%] [G loss: 2.380096]
158 [D loss: 0.240224, acc.: 90.23%] [G loss: 2.270355]
159 [D loss: 0.242799, acc.: 91.80%] [G loss: 2.322944]
160 [D loss: 0.264515, acc.: 93.75%] [G loss: 2.216620]
161 [D loss: 0.313168, acc.: 85.55%] [G loss: 2.311587]
162 [D loss: 0.423127, acc.: 77.34%] [G loss: 2.009403]
163 [D loss: 0.305505, acc.: 85.94%] [G loss: 2.451000]
164 [D loss: 0.467811, acc.: 74.22%] [G loss: 2.084091]
165 [D loss: 0.224669, acc.: 94.14%] [G loss: 2.433859]
166 [D loss: 0.432774, acc.: 76.56%] [G loss: 1.

297 [D loss: 0.662853, acc.: 48.83%] [G loss: 0.628776]
298 [D loss: 0.664642, acc.: 49.61%] [G loss: 0.633471]
299 [D loss: 0.661709, acc.: 47.27%] [G loss: 0.633953]
300 [D loss: 0.668486, acc.: 48.05%] [G loss: 0.637163]
301 [D loss: 0.662295, acc.: 48.05%] [G loss: 0.637544]
302 [D loss: 0.653224, acc.: 50.39%] [G loss: 0.635807]
303 [D loss: 0.662181, acc.: 47.66%] [G loss: 0.631311]
304 [D loss: 0.651313, acc.: 49.61%] [G loss: 0.634278]
305 [D loss: 0.661334, acc.: 48.05%] [G loss: 0.630751]
306 [D loss: 0.659641, acc.: 50.00%] [G loss: 0.632068]
307 [D loss: 0.652520, acc.: 49.22%] [G loss: 0.634714]
308 [D loss: 0.652189, acc.: 48.83%] [G loss: 0.633669]
309 [D loss: 0.658552, acc.: 49.61%] [G loss: 0.636019]
310 [D loss: 0.656122, acc.: 49.61%] [G loss: 0.641348]
311 [D loss: 0.653933, acc.: 49.22%] [G loss: 0.643180]
312 [D loss: 0.656769, acc.: 48.83%] [G loss: 0.641327]
313 [D loss: 0.652302, acc.: 49.61%] [G loss: 0.645945]
314 [D loss: 0.654866, acc.: 48.44%] [G loss: 0.

444 [D loss: 0.639532, acc.: 58.20%] [G loss: 0.743047]
445 [D loss: 0.619885, acc.: 60.55%] [G loss: 0.747721]
446 [D loss: 0.632983, acc.: 61.33%] [G loss: 0.728128]
447 [D loss: 0.627236, acc.: 59.38%] [G loss: 0.718624]
448 [D loss: 0.637961, acc.: 55.86%] [G loss: 0.726679]
449 [D loss: 0.625116, acc.: 60.16%] [G loss: 0.735084]
450 [D loss: 0.628262, acc.: 61.33%] [G loss: 0.727777]
451 [D loss: 0.623910, acc.: 62.11%] [G loss: 0.724525]
452 [D loss: 0.610749, acc.: 66.41%] [G loss: 0.735028]
453 [D loss: 0.618964, acc.: 68.36%] [G loss: 0.725430]
454 [D loss: 0.616538, acc.: 69.14%] [G loss: 0.722293]
455 [D loss: 0.627449, acc.: 63.67%] [G loss: 0.721621]
456 [D loss: 0.624302, acc.: 67.58%] [G loss: 0.727432]
457 [D loss: 0.623203, acc.: 65.62%] [G loss: 0.735722]
458 [D loss: 0.621538, acc.: 63.67%] [G loss: 0.732979]
459 [D loss: 0.623751, acc.: 64.06%] [G loss: 0.741414]
460 [D loss: 0.632096, acc.: 61.72%] [G loss: 0.754096]
461 [D loss: 0.629811, acc.: 60.16%] [G loss: 0.

591 [D loss: 0.621795, acc.: 63.67%] [G loss: 0.751474]
592 [D loss: 0.617155, acc.: 68.36%] [G loss: 0.768209]
593 [D loss: 0.610699, acc.: 66.80%] [G loss: 0.779912]
594 [D loss: 0.611706, acc.: 67.97%] [G loss: 0.763199]
595 [D loss: 0.607320, acc.: 67.97%] [G loss: 0.777296]
596 [D loss: 0.619232, acc.: 62.89%] [G loss: 0.783575]
597 [D loss: 0.623017, acc.: 63.28%] [G loss: 0.779951]
598 [D loss: 0.611606, acc.: 66.41%] [G loss: 0.784342]
599 [D loss: 0.623934, acc.: 63.28%] [G loss: 0.788355]
600 [D loss: 0.624417, acc.: 64.84%] [G loss: 0.782001]
601 [D loss: 0.626042, acc.: 66.41%] [G loss: 0.776995]
602 [D loss: 0.633234, acc.: 60.55%] [G loss: 0.784090]
603 [D loss: 0.629794, acc.: 66.80%] [G loss: 0.779195]
604 [D loss: 0.635399, acc.: 63.67%] [G loss: 0.793558]
605 [D loss: 0.641308, acc.: 64.06%] [G loss: 0.789551]
606 [D loss: 0.637184, acc.: 63.67%] [G loss: 0.801161]
607 [D loss: 0.658972, acc.: 60.16%] [G loss: 0.798061]
608 [D loss: 0.629668, acc.: 66.80%] [G loss: 0.

738 [D loss: 0.625558, acc.: 65.62%] [G loss: 0.837945]
739 [D loss: 0.621596, acc.: 71.48%] [G loss: 0.824417]
740 [D loss: 0.623332, acc.: 66.80%] [G loss: 0.809942]
741 [D loss: 0.626137, acc.: 66.41%] [G loss: 0.805115]
742 [D loss: 0.629990, acc.: 63.67%] [G loss: 0.805090]
743 [D loss: 0.633897, acc.: 66.41%] [G loss: 0.810620]
744 [D loss: 0.622523, acc.: 64.84%] [G loss: 0.831880]
745 [D loss: 0.630401, acc.: 64.45%] [G loss: 0.801643]
746 [D loss: 0.613052, acc.: 69.14%] [G loss: 0.817512]
747 [D loss: 0.630739, acc.: 63.28%] [G loss: 0.810170]
748 [D loss: 0.620448, acc.: 62.50%] [G loss: 0.850345]
749 [D loss: 0.639448, acc.: 64.06%] [G loss: 0.818778]
750 [D loss: 0.637051, acc.: 66.02%] [G loss: 0.815640]
751 [D loss: 0.624559, acc.: 70.31%] [G loss: 0.795708]
752 [D loss: 0.604942, acc.: 73.83%] [G loss: 0.796513]
753 [D loss: 0.648019, acc.: 64.84%] [G loss: 0.800719]
754 [D loss: 0.590180, acc.: 74.61%] [G loss: 0.828948]
755 [D loss: 0.624725, acc.: 67.19%] [G loss: 0.

885 [D loss: 0.564549, acc.: 80.08%] [G loss: 0.916331]
886 [D loss: 0.583076, acc.: 75.39%] [G loss: 0.918519]
887 [D loss: 0.560488, acc.: 79.69%] [G loss: 0.915672]
888 [D loss: 0.573851, acc.: 76.17%] [G loss: 0.904014]
889 [D loss: 0.584193, acc.: 72.66%] [G loss: 0.908135]
890 [D loss: 0.552116, acc.: 80.47%] [G loss: 0.930791]
891 [D loss: 0.555113, acc.: 83.59%] [G loss: 0.924620]
892 [D loss: 0.563828, acc.: 73.83%] [G loss: 0.948047]
893 [D loss: 0.589851, acc.: 74.61%] [G loss: 0.921921]
894 [D loss: 0.576627, acc.: 73.05%] [G loss: 0.927448]
895 [D loss: 0.605593, acc.: 71.48%] [G loss: 0.933859]
896 [D loss: 0.570604, acc.: 76.56%] [G loss: 0.940138]
897 [D loss: 0.587621, acc.: 76.56%] [G loss: 0.914131]
898 [D loss: 0.563459, acc.: 77.73%] [G loss: 0.915940]
899 [D loss: 0.564318, acc.: 77.34%] [G loss: 0.904638]
900 [D loss: 0.577954, acc.: 75.00%] [G loss: 0.922927]
901 [D loss: 0.554828, acc.: 81.25%] [G loss: 0.926977]
902 [D loss: 0.561092, acc.: 80.47%] [G loss: 0.