In [64]:
from __future__ import print_function, division
from tqdm import tqdm
tqdm.monitor_interval = 0

from keras.layers import Input, Dense, Flatten, Reshape
from keras.datasets import mnist
from keras.layers import BatchNormalization, ZeroPadding2D, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import UpSampling2D, Conv2D, MaxPooling2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import sys
import numpy as np

In [65]:
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
X_train = X_train.astype('float32')
X_train = (X_train - 127.5) / 127.5
X_train.shape

(60000, 28, 28, 1)

In [66]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows,img_cols,channels)
optimizer = Adam(0.0002,0.5)

In [67]:
def save_imgs(generator,epoch):
        r,c = 5,5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("mnist_%d.png" % epoch)
        plt.close()

In [68]:
def build_generator():
    noise_shape = (100,)
    model = Sequential()
    model.add(Dense(128*7*7,input_shape=noise_shape))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Reshape((7,7,128),input_shape=(128*7*7,)))
    model.add(UpSampling2D(size=(2,2)))
    model.add(Conv2D(64,(5,5),padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(UpSampling2D(size=(2,2)))
    model.add(Conv2D(32,(5,5),padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Conv2D(1,(5,5),padding='same'))
    model.summary()

    noise = Input(shape=noise_shape)
    img = model(noise)
    return Model(noise, img)

In [69]:
def build_discriminator():
    img_shape = (img_rows,img_cols, channels)
    model = Sequential()
    model.add(Conv2D(32,(5,5),strides=(2,2),padding='same',input_shape=img_shape))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(164,(5,5),strides=(2,2),padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(128,(5,5),strides=(2,2),padding='same'))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=img_shape)
    validity = model(img)
    return Model(img, validity)

In [70]:
generator = build_generator()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_10 (Dense)             (None, 6272)              633472    
_________________________________________________________________
batch_normalization_28 (Batc (None, 6272)              25088     
_________________________________________________________________
activation_20 (Activation)   (None, 6272)              0         
_________________________________________________________________
reshape_6 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_11 (UpSampling (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_28 (Conv2D)           (None, 14, 14, 64)        204864    
_________________________________________________________________
batch_normalization_29 (Batc (None, 14, 14, 64)        256       
__________

In [71]:
discriminator = build_discriminator()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_31 (Conv2D)           (None, 14, 14, 32)        832       
_________________________________________________________________
batch_normalization_31 (Batc (None, 14, 14, 32)        128       
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 7, 7, 164)         131364    
_________________________________________________________________
batch_normalization_32 (Batc (None, 7, 7, 164)         656       
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 7, 7, 164)         0         
_________________________________________________________________
conv2d_33 (Conv2D)           (None, 4, 4, 128)         524928    
__________

In [72]:
generator.compile(loss='binary_crossentropy',optimizer=optimizer)
discriminator.compile(loss='binary_crossentropy',optimizer=optimizer)

In [73]:
def build_gan(gen,dis):
    dis.trainable = False
    inp = Input(shape=(100,))
    out = dis(gen(inp))
    return Model(inp,out)

In [74]:
gan = build_gan(generator,discriminator)
gan.summary()
gan.compile(loss='binary_crossentropy',optimizer=optimizer)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_18 (InputLayer)        (None, 100)               0         
_________________________________________________________________
model_15 (Model)             (None, 28, 28, 1)         915841    
_________________________________________________________________
model_16 (Model)             (None, 1)                 660469    
Total params: 1,576,310
Trainable params: 903,105
Non-trainable params: 673,205
_________________________________________________________________


In [75]:
def train(epochs=10, batch_size=128):
    batch_count = X_train.shape[0] // batch_size
    
    for i in range(epochs):
        for j in tqdm(range(batch_count)):
            # Input for the generator
            noise_input = np.random.rand(batch_size, 100)
            
            # getting random images from X_train of size=batch_size 
            # these are the real images that will be fed to the discriminator
            image_batch = X_train[np.random.randint(0, X_train.shape[0], size=batch_size)]
            
            # these are the predicted images from the generator
            predictions = generator.predict(noise_input, batch_size=batch_size)
            
            # the discriminator takes in the real images and the generated images
            X = np.concatenate([predictions, image_batch])
            
            # labels for the discriminator
            y_discriminator = [0]*batch_size + [1]*batch_size
            
            # Let's train the discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_discriminator)
            
            # Let's train the generator
            noise_input = np.random.rand(batch_size, 100)
            y_generator = [1]*batch_size
            discriminator.trainable = False
            gan.train_on_batch(noise_input, y_generator)
        save_imgs(generator,i)

In [None]:
train(30, 128)

100%|██████████| 468/468 [22:55<00:00,  2.94s/it]
100%|██████████| 468/468 [23:46<00:00,  3.05s/it]
100%|██████████| 468/468 [23:03<00:00,  2.96s/it]
100%|██████████| 468/468 [22:34<00:00,  2.89s/it]
100%|██████████| 468/468 [22:51<00:00,  2.93s/it]
100%|██████████| 468/468 [21:44<00:00,  2.79s/it]
100%|██████████| 468/468 [23:33<00:00,  3.02s/it]
100%|██████████| 468/468 [23:25<00:00,  3.00s/it]
100%|██████████| 468/468 [23:20<00:00,  2.99s/it]
100%|██████████| 468/468 [23:25<00:00,  3.00s/it]
100%|██████████| 468/468 [23:29<00:00,  3.01s/it]
100%|██████████| 468/468 [23:30<00:00,  3.01s/it]
100%|██████████| 468/468 [23:31<00:00,  3.02s/it]
100%|██████████| 468/468 [23:28<00:00,  3.01s/it]
100%|██████████| 468/468 [23:13<00:00,  2.98s/it]
100%|██████████| 468/468 [23:29<00:00,  3.01s/it]
100%|██████████| 468/468 [23:28<00:00,  3.01s/it]
100%|██████████| 468/468 [23:30<00:00,  3.01s/it]
100%|██████████| 468/468 [23:16<00:00,  2.98s/it]
100%|██████████| 468/468 [22:45<00:00,  2.92s/it]
