# [UE Biometrics] TP GAN
ZHANG Yuancheng
3704091

In [35]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, GlobalAveragePooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os 

import numpy as np

In [36]:
class ConvGAN():
    """Convolutional GAN producing 28 x 28 single-channel (MNIST-sized) images.

    Attributes set by ``__init__``:
        img_rows, img_cols, channels, img_shape: geometry of the images.
        latent_dim: length of the random vector fed to the generator.
        discriminator: compiled model mapping an image to a validity score.
        generator: model mapping a latent vector to an image.
        combined: generator followed by the discriminator; compiled with the
            discriminator marked non-trainable and used to train the
            generator.
    """

    def __init__(self):
        """Build the generator, the discriminator and the stacked model."""
        # shape of input images (for MNIST, 28 x 28, one channel)
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # dimension of the random latent vector (input of the generator)
        self.latent_dim = 32

        # discriminator model
        # Each compiled model gets its own Adam instance: an optimizer keeps
        # per-variable state, and sharing a single instance between two
        # models is rejected by recent Keras optimizer implementations.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=Adam(0.0002, 0.5),
                                   metrics=['accuracy'])

        # generator model
        self.generator = self.build_generator()

        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # when training the generator (through the combined model), the
        # discriminator must not be updated
        self.discriminator.trainable = False

        # the discriminator scores the generated image
        valid = self.discriminator(img)

        # combined model: latent vector -> generator -> discriminator score
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy',
                              optimizer=Adam(0.0002, 0.5))

    @staticmethod
    def normalize_images(images):
        """Scale 8-bit pixel values to [-1, 1] and append a channel axis.

        Args:
            images: array-like of rank 3 holding pixel values in [0, 255]
                (``train`` passes the MNIST training images).

        Returns:
            A float array with one extra trailing axis of size 1, where
            0 maps to -1 and 255 maps to 1.
        """
        scaled = np.asarray(images) / 127.5 - 1.
        return np.expand_dims(scaled, axis=3)

    def build_generator(self):
        """Create the generator: latent vector -> 28 x 28 x channels image.

        Returns:
            A Keras ``Model`` taking a ``(latent_dim,)`` input. The final
            layer uses ``tanh`` so the output stays in [-1, 1].
        """
        model = Sequential()
        # first dense layer: latent vector -> 14*14*128 values
        model.add(Dense(14 * 14 * 128, use_bias=False, input_shape=(self.latent_dim, )))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        # reshape to a 14 x 14 feature map with 128 channels
        model.add(Reshape((14, 14, 128)))

        # 14*14*128 -> 14*14*128
        model.add(Conv2DTranspose(128, (5, 5), use_bias=False, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # 14*14*128 -> 14*14*256
        model.add(Conv2DTranspose(256, (5, 5), use_bias=False, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # 14*14*256 -> 28*28*256 (upsampling) -> 28*28*128
        model.add(UpSampling2D())
        model.add(Conv2DTranspose(128, (5, 5), use_bias=False, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # 28*28*128 -> 28*28*channels
        # tanh bounds the output to [-1, 1], the range the real images are
        # scaled to in train() and the range show_and_save_imgs() assumes
        # when mapping pixels back to [0, 1]. An unbounded activation here
        # would let generated pixels leave the range of the real data.
        model.add(Conv2DTranspose(self.channels, (5, 5), use_bias=False,
                                  padding="same", activation='tanh'))

        model.summary()

        # input: random latent vector
        noise = Input(shape=(self.latent_dim,))
        # output: generated image
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Create the discriminator: image -> validity score.

        Returns:
            A Keras ``Model`` taking an ``img_shape`` input and producing a
            single sigmoid output (trained with label 1 for real images and
            0 for generated ones).
        """
        model = Sequential()

        # three 5x5 convolutions with decreasing filter counts; only the
        # first one is followed by batch normalization
        model.add(Conv2D(256, (5, 5), input_shape=self.img_shape))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        model.add(Conv2D(128, (5, 5)))
        model.add(LeakyReLU(alpha=0.2))

        model.add(Conv2D(64, (5, 5)))
        model.add(LeakyReLU(alpha=0.2))

        # flatten the feature maps and reduce them to one probability
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        # input: an image (real or generated)
        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, show_save_frequency=10, save_dir=None):
        """Alternate discriminator and generator updates on MNIST.

        Note that one "epoch" here is a single update on one randomly drawn
        batch, not a full pass over the dataset.

        Args:
            epochs: number of training iterations.
            batch_size: number of real and of generated images per iteration.
            show_save_frequency: preview generated images every this many
                iterations; a falsy value (0 or None) disables previews.
            save_dir: optional directory; when given, every preview is also
                written there as a PNG file.
        """
        # load the MNIST training images (labels are not used)
        (X_train, _), (_, _) = mnist.load_data()

        # scale to [-1, 1] and add the channel axis
        X_train = self.normalize_images(X_train)

        # ground truths: 1 for real images, 0 for generated ones
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ------------------------ #
            #  train the discriminator #
            # random batch of real images (sampled with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # one update on real images, one on generated images; the
            # reported loss/accuracy is the mean of the two
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------------ #
            #  train the generator  #
            # the generator is rewarded when the discriminator labels its
            # images as valid
            g_loss = self.combined.train_on_batch(noise, valid)

            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # periodic preview; the truthiness guard avoids a
            # ZeroDivisionError when show_save_frequency is 0
            if show_save_frequency and epoch % show_save_frequency == 0:
                self.show_and_save_imgs(epoch, save_dir=save_dir)

    def show_and_save_imgs(self, epoch, save_dir=None):
        """Display a 2 x 5 grid of freshly generated images.

        Args:
            epoch: iteration number, used in the saved file name.
            save_dir: optional directory; when given, the grid is saved as
                ``mnist_<epoch>.png`` inside it (the directory is created if
                it does not exist).

        The figure is closed after being shown so that repeated calls do not
        accumulate open figures. With a blocking (non-notebook) backend,
        ``plt.show()`` waits until the window is closed.
        """
        row, col = 2, 5
        noise = np.random.normal(0, 1, (row * col, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # map generator output from [-1, 1] to [0, 1] for display
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(row, col)
        # axs.flat walks the grid row by row, one generated image per cell
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')

        if save_dir is not None:
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)
            fig.savefig(os.path.join(save_dir, "mnist_%d.png" % epoch))

        plt.show()
        plt.close(fig)

In [None]:
if __name__ == '__main__':
    # Script entry point: build the GAN, then run 300 training iterations on
    # batches of 256 images, previewing generated samples every 10 iterations.
    # Generated samples are only displayed, not written to disk. If you
    # enable saving to "./images", make sure that directory exists first
    # (e.g. os.makedirs("./images")).
    convgan = ConvGAN()
    convgan.train(300, batch_size=256, show_save_frequency=10)