## Import Dependencies

In [None]:
import glob
import keras.backend as backend
import matplotlib.pyplot as plt
import numpy as np
import os

from IPython.display import clear_output
from keras.layers import Activation, Dense, Dropout, Flatten, \
    Input, Reshape
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model, Sequential
from keras.optimizers import RMSprop
from PIL import Image
from tqdm import tqdm

In [None]:
## Constants

In [None]:
# size of the images being used: (rows, columns, channels)
IMG_ROWS = 104
IMG_COL = 88
IMG_CHANNELS = 1
IMG_SHAPE = (IMG_ROWS, IMG_COL, IMG_CHANNELS)

# image format
IMG_EXT = '.png'
# training data location
DATA_DIR = 'img_align_celeba/'

# image output directory
IMAGE_DIR = 'images/'
# loss graph output directory
LOSS_DIR = 'loss/'
# saved models directory
MODEL_DIR = 'saved_models/'

# make the output directories if they don't exist yet; exist_ok only
# tolerates "already there", so genuine failures (e.g. missing permissions)
# surface here instead of much later when the first file is written
for _out_dir in (IMAGE_DIR, LOSS_DIR, MODEL_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# save the model every n epochs
SAVE_RATE = 500
# save the output and loss every n epochs
SAMPLE_RATE = 50
# sqrt of number of images to sample (samples are drawn as a
# SAMPLE_NUM x SAMPLE_NUM grid)
SAMPLE_NUM = 6

# epochs to train for
EPOCHS = 10000
# batch size of images
BATCH_SIZE = 128
# number of batches to train discriminator per epoch
DISCRIM_ITER = 5
# WGAN-recommended weight clipping value
WEIGHT_CLIP = 0.01
# dimensions of the input noise vector
LATENT_DIM = 100
# optimizer to use for each model
OPTIMIZER = RMSprop(lr=0.00005)

## Helper methods

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of labels and predictions.

    With labels of +1 / -1 this is the signed average of the critic's
    scores, which is what the WGAN models are compiled to minimize.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)

def load_data():
    """Load, shuffle and normalize every training image found in DATA_DIR.

    All files matching ``DATA_DIR + '*' + IMG_EXT`` are read into memory.

    Returns:
        A float64 array of shape (num_images, IMG_ROWS, IMG_COL,
        IMG_CHANNELS). Pixel values are mapped with (x - 127.5) / 127.5,
        i.e. into [-1, 1] assuming 8-bit input images.
    """
    paths = glob.glob(DATA_DIR + '*' + IMG_EXT)

    pixel_arrays = []
    for path in paths:
        # close each file as soon as its pixels have been copied out, so a
        # large dataset does not pile up open file handles
        with Image.open(path) as img:
            pixel_arrays.append(np.array(img))
    images = np.array(pixel_arrays)

    # shuffle along the first (image) axis
    np.random.shuffle(images)
    # add the explicit channel axis, using the shared size constants
    images = images.reshape(images.shape[0], IMG_ROWS, IMG_COL, IMG_CHANNELS)
    # np.float64 is what the removed np.float alias used to resolve to
    images = (images.astype(np.float64) - 127.5) / 127.5

    return images

def plot_loss(name, gen_loss, disc_loss, d_r_loss, d_f_loss):
    """Write two loss plots for the model called `name` into LOSS_DIR.

    The first figure shows the generator loss against the discriminator's
    total loss; the second shows the generator loss against the
    discriminator's real and fake losses separately.
    """
    def _render(curves, suffix):
        # one figure per call: plot every (values, label) pair, save, close
        plt.figure(figsize=(12, 8))
        for values, label in curves:
            plt.plot(values, label=label)
        plt.legend()
        plt.savefig(LOSS_DIR + name + suffix + IMG_EXT)
        plt.close()

    _render(
        [(gen_loss, 'Gen. Loss'),
         (disc_loss, 'Discrim. Total Loss')],
        '_total_loss_plot',
    )
    _render(
        [(gen_loss, 'Gen. Loss'),
         (d_r_loss, 'Discrim. Real Loss'),
         (d_f_loss, 'Discrim. Fake Loss')],
        '_real_fake_loss_plot',
    )
    
def save_output(epoch, model):
    """Save a SAMPLE_NUM x SAMPLE_NUM grid of generated images to IMAGE_DIR.

    Args:
        epoch: epoch number, zero-padded to 5 digits in the file name.
        model: object exposing `generator` (with `predict`) and `name`.
    """
    # generate SAMPLE_NUM**2 images from fresh standard-normal noise
    noise = np.random.normal(0, 1, (SAMPLE_NUM**2, LATENT_DIM))
    images = model.generator.predict(noise)
    # rescale to [0, 1]; assumes the generator output lies in [-1, 1]
    images = 0.5 * images + 0.5
    # squeeze=False keeps axs two-dimensional even when SAMPLE_NUM == 1,
    # where plt.subplots would otherwise return a single Axes object
    fig, axs = plt.subplots(SAMPLE_NUM, SAMPLE_NUM, figsize=(10, 10),
                            squeeze=False)
    # fill the grid column by column (same order as before)
    for img_idx in range(SAMPLE_NUM**2):
        y, x = divmod(img_idx, SAMPLE_NUM)
        axs[x, y].imshow(images[img_idx, :, :, 0], cmap='gray')
        axs[x, y].axis('off')
    # pad the output names to make things cleaner
    fig.savefig(IMAGE_DIR + model.name + '_sample_' + str(epoch).zfill(5) + IMG_EXT)
    # close this specific figure so repeated sampling does not leak figures
    plt.close(fig)
    
def save_loss(name, g_loss, d_loss, d_r_loss, d_f_loss):
    """Store the four loss histories of model `name` with np.save.

    Each history is saved to its own file under LOSS_DIR, with the model
    name as the file-name prefix.
    """
    prefix = LOSS_DIR + name
    histories = (
        ('_gen_loss', g_loss),
        ('_total_discrim_loss', d_loss),
        ('_discrim_fake_loss', d_f_loss),
        ('_discrim_real_loss', d_r_loss),
    )
    for suffix, history in histories:
        np.save(prefix + suffix, history)
    
def save_model(epoch, model):
    """Save the discriminator and the combined model of `model` as .h5 files.

    Files are written to MODEL_DIR and named after the model, the
    sub-model and the epoch zero-padded to 5 digits.
    """
    epoch_tag = str(epoch).zfill(5)
    base_path = MODEL_DIR + model.name
    model.discriminator.save(base_path + '_discrim_' + epoch_tag + '.h5')
    model.combined.save(base_path + '_combined_' + epoch_tag + '.h5')
    
def _clip_critic_weights(critic, clip_value):
    """Clamp every weight of `critic` into [-clip_value, clip_value]."""
    for layer in critic.layers:
        clipped = [np.clip(w, -clip_value, clip_value)
                   for w in layer.get_weights()]
        layer.set_weights(clipped)


def train(model, training_data):
    """Run the adversarial training loop for EPOCHS epochs.

    Every epoch trains the discriminator for DISCRIM_ITER steps (one real
    and one generated batch per step) and then trains the generator once
    through `model.combined`. Samples, loss histories and loss plots are
    written every SAMPLE_RATE epochs, model checkpoints every SAVE_RATE
    epochs, and both on the final epoch.

    Args:
        model: object exposing `generator`, `discriminator`, `combined`,
            `name` and the `wasserstein` flag (e.g. a GAN instance).
        training_data: array of training images, indexed along its first
            axis.
    """
    # real samples are labelled 1 in both modes; generated samples are
    # labelled 0 for the cross-entropy GAN and -1 for the WGAN
    r_labels = np.ones((BATCH_SIZE, 1))
    if model.wasserstein:
        f_labels = -np.ones((BATCH_SIZE, 1))
    else:
        f_labels = np.zeros((BATCH_SIZE, 1))

    # lists to store the losses
    all_d_r_loss = []
    all_d_f_loss = []
    all_g_loss = []
    all_d_loss = []

    for epoch in range(EPOCHS):

        for _ in tqdm(range(DISCRIM_ITER)):

            # sample a (pseudo-)random batch of images (with replacement)
            r_imgs = training_data[np.random.randint(
                0, training_data.shape[0], BATCH_SIZE)]

            # generate batch of images
            noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIM))
            f_imgs = model.generator.predict(noise)

            # train discrim on real and generated separately
            d_r_loss = model.discriminator.train_on_batch(r_imgs, r_labels)
            d_f_loss = model.discriminator.train_on_batch(f_imgs, f_labels)
            d_loss = 0.5 * np.add(d_f_loss, d_r_loss)

            # weight clipping belongs to the WGAN critic only; applying it
            # to the cross-entropy discriminator would clamp its weights to
            # +/-WEIGHT_CLIP as well
            if model.wasserstein:
                _clip_critic_weights(model.discriminator, WEIGHT_CLIP)

        # train the generator (via combined model) to be scored as "real"
        noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIM))
        g_loss = model.combined.train_on_batch(noise, r_labels)

        # clear the output every 10 iterations
        if epoch % 10 == 0:
            clear_output()

        # only the losses of the last discriminator step are recorded
        all_d_r_loss.append(d_r_loss)
        all_d_f_loss.append(d_f_loss)
        all_g_loss.append(g_loss)
        all_d_loss.append(d_loss)
        print("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss, g_loss))

        # log at the sampling interval
        if epoch % SAMPLE_RATE == 0 or epoch == EPOCHS - 1:
            save_output(epoch, model)
            save_loss(model.name, np.array(all_g_loss), np.array(all_d_loss),
                      np.array(all_d_r_loss), np.array(all_d_f_loss))
            plot_loss(model.name, all_g_loss, all_d_loss, all_d_r_loss, all_d_f_loss)

        # save at the saving interval
        if epoch % SAVE_RATE == 0 or epoch == EPOCHS - 1:
            save_model(epoch, model)


## Declare the model class, switchable between a WGAN and a standard GAN

In [None]:
class GAN():
    """Generator / discriminator pair that can be built as a GAN or a WGAN.

    Attributes set by the constructor:
        name: 'WGAN' or 'GAN', used as a prefix for output files.
        wasserstein: bool flag selecting the WGAN variant.
        discriminator: compiled discriminator model.
        generator: generator model (not compiled on its own).
        combined: compiled generator -> discriminator stack used to train
            the generator.
    """

    def __init__(self, wasserstein=False):
        """Build and compile the models.

        Args:
            wasserstein: truthy to build a WGAN (wasserstein_loss, linear
                discriminator output); falsy to build a standard GAN
                (binary cross-entropy, sigmoid discriminator output).
        """
        # normalize once so every later check sees a real bool; comparing a
        # non-bool value with `== True` here and `== False` elsewhere could
        # otherwise pick the GAN loss but skip the sigmoid
        self.wasserstein = bool(wasserstein)
        if self.wasserstein:
            loss = wasserstein_loss
            self.name = 'WGAN'
        else:
            loss = 'binary_crossentropy'
            self.name = 'GAN'

        # get and compile the discriminator
        self.discriminator = self.get_discriminator()
        self.discriminator.compile(loss=loss, optimizer=self._new_optimizer())
        # set not trainable for combined version (compiled is still trainable)
        # NOTE(review): relies on Keras fixing `trainable` at compile time -
        # confirm for the Keras version in use
        self.discriminator.trainable = False

        # get the generator but don't compile it
        self.generator = self.get_generator()

        # build all the necessary in's and out's
        g_in = Input(shape=(LATENT_DIM, ))
        g_out = self.generator(g_in)
        d_out = self.discriminator(g_out)

        # construct and compile the actual combined model
        self.combined = Model(g_in, d_out)
        self.combined.compile(loss=loss, optimizer=self._new_optimizer())

    @staticmethod
    def _new_optimizer():
        """Return a fresh optimizer configured like the shared OPTIMIZER.

        Each compiled model gets its own instance so that optimizer state
        is not shared between the discriminator, the combined model and
        other GAN instances.
        """
        return type(OPTIMIZER).from_config(OPTIMIZER.get_config())

    def get_discriminator(self):
        """Build the convolutional discriminator for IMG_SHAPE inputs.

        Returns:
            An uncompiled Model mapping an image to a single score; the
            score passes through a sigmoid unless `self.wasserstein` is set.
        """
        discrim = Sequential()

        # the first convolution declares the input shape
        discrim.add(Conv2D(64, kernel_size=(12, 10), strides=1, padding='same',
                           input_shape=IMG_SHAPE))
        discrim.add(Activation('selu'))
        discrim.add(Dropout(0.2))

        # remaining conv blocks: (filters, kernel size, strides, dropout);
        # the three stride-2 blocks each halve the spatial size
        conv_blocks = (
            (64, (4, 4), 1, 0.2),
            (64, (6, 5), 2, 0.2),
            (128, (6, 5), 2, 0.3),
            (256, (6, 5), 2, 0.2),
        )
        for filters, kernel, stride, drop_rate in conv_blocks:
            discrim.add(Conv2D(filters, kernel_size=kernel, strides=stride,
                               padding='same'))
            discrim.add(Activation('selu'))
            discrim.add(Dropout(drop_rate))

        discrim.add(Flatten())
        discrim.add(Dense(18))
        discrim.add(Activation('selu'))
        discrim.add(Dropout(0.2))
        discrim.add(Dense(1))

        # add sigmoid to create a GAN instead of a WGAN
        if not self.wasserstein:
            discrim.add(Activation('sigmoid'))

        d_in = Input(shape=IMG_SHAPE)
        d_out = discrim(d_in)

        discrim.summary()

        return Model(d_in, d_out)

    def get_generator(self):
        """Build the generator mapping a LATENT_DIM noise vector to an image.

        Returns:
            An uncompiled Model whose output passes through tanh.
        """
        gen = Sequential()

        # project the noise vector and reshape it into a 13 x 11 feature map
        gen.add(Dense(128 * 13 * 11, input_dim=LATENT_DIM))
        gen.add(Activation('selu'))
        gen.add(Reshape((13, 11, 128)))

        # upsample-then-convolve blocks: (filters, strides); the last block
        # uses a stride-2 convolution after its upsampling step
        up_blocks = (
            (128, 1),
            (64, 1),
            (32, 1),
            (32, 2),
        )
        for filters, stride in up_blocks:
            gen.add(UpSampling2D())
            gen.add(Conv2D(filters, (6, 5), strides=stride, padding='same'))
            gen.add(Activation('selu'))

        # collapse to a single channel
        gen.add(Conv2D(1, kernel_size=(6, 5), strides=1, padding='same'))
        gen.add(Activation('tanh'))

        g_in = Input(shape=(LATENT_DIM,))
        g_out = gen(g_in)

        gen.summary()

        return Model(g_in, g_out)

## Load the dataset

In [None]:
# read every training image into memory once; both training runs reuse it
training_set = load_data()

## Build and train the models

In [None]:
# build the Wasserstein variant (model name 'WGAN', wasserstein_loss)
wgan = GAN(wasserstein=True)

In [None]:
# train the WGAN; samples, losses and checkpoints are written during training
train(wgan, training_set)

In [None]:
# build the standard variant (model name 'GAN', binary cross-entropy loss)
gan = GAN(wasserstein=False)

In [None]:
# train the standard GAN on the same data for comparison with the WGAN run
train(gan, training_set)