In [1]:
# references
# - https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py
# - https://github.com/eriklindernoren/Keras-GAN/blob/master/wgan_gp/wgan_gp.py
# - https://github.com/LynnHo/WGAN-GP-DRAGAN-Celeba-Pytorch/blob/master/models_64x64.py

from functools import partial

import keras
import keras_contrib

import keras.backend as K

import matplotlib.pyplot as plt

import sys

import numpy as np

import glob
import imageio
import lycon
from skimage import transform
from tqdm import tqdm_notebook as tqdm

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [5]:
# Image geometry: 64x64 images with 3 channels.
img_shape = [64, 64, 3]
img_rows, img_cols, channels = img_shape

# Length of the generator's input noise vector.
latent_dim = 100

# Training hyper-parameters.
# n_critic = 5 critic updates per generator update, as in the WGAN-GP paper.
# NOTE(review): the Adam settings below (lr=2e-4, beta_1=0.5) are the usual
# DCGAN choice; the WGAN-GP paper itself uses lr=1e-4, beta_1=0, beta_2=0.9.
# NOTE: this single optimizer instance is passed to both compiled models
# (critic_model and generator_model) further down.
n_critic = 5
lr = 2e-4
batch_size = 64
optimizer = keras.optimizers.Adam(lr=lr, beta_1=0.5, beta_2=0.999)

# Weight initialisers (std 0.02) and leaky-ReLU activation (slope 0.2).
tnorm = keras.initializers.truncated_normal(stddev=0.02)
rnorm = keras.initializers.random_normal(stddev=0.02)
lrelu = lambda x: keras.activations.relu(x, alpha=0.2)

In [None]:
# Show the Keras version this notebook runs with, for reproducibility.
# (Replaces a stray, incomplete `ker` expression that raised NameError.)
keras.__version__

In [2]:
# Load the pre-processed image array from disk (takes about 3.5 min).
# NOTE(review): the generator ends in tanh (outputs in [-1, 1]); confirm that
# celeba.npy is scaled to the same range.
X_train = np.load(file="celeba.npy")

In [6]:
# Sanity check: every image must match the models' input shape.
assert X_train.shape[1:] == tuple(img_shape), X_train.shape
X_train.shape

(202599, 64, 64, 3)

In [7]:
dim = 64
DROP = 0

# -------------------------------------------------------------------
# Generator: latent vector -> 4x4x(8*dim) feature map, then three
# stride-2 transposed convolutions (4 -> 8 -> 16 -> 32) with ReLU, batch
# norm and dropout, and a final tanh transposed convolution up to 64x64.
# NOTE(review): ReLU is applied inside each Dense/Conv2DTranspose, i.e.
# *before* BatchNormalization; most DCGAN-style generators normalise first.
# With DROP = 0 every Dropout layer below is a no-op.
# -------------------------------------------------------------------
model = keras.Sequential(
    # Projection from the latent space to a 4x4 feature map.
    [
        keras.layers.Dense(
            dim * 8 * 4 * 4,
            activation="relu",
            input_dim=latent_dim,
            kernel_initializer=rnorm,
            use_bias=False),
        keras.layers.BatchNormalization(momentum=0.8),
        keras.layers.Reshape((4, 4, dim * 8)),
        keras.layers.Dropout(DROP),
    ]
    # Upsampling stages; each doubles the spatial size and halves the filters.
    + [
        layer
        for n_filters in (dim * 4, dim * 2, dim)
        for layer in (
            keras.layers.Conv2DTranspose(
                n_filters,
                kernel_size=5,
                strides=2,
                padding="same",
                activation="relu",
                kernel_initializer=rnorm,
                use_bias=False),
            keras.layers.BatchNormalization(momentum=0.8),
            keras.layers.Dropout(DROP),
        )
    ]
    # Output layer: `channels`-deep image, tanh keeps values in [-1, 1].
    + [
        keras.layers.Conv2DTranspose(
            channels,
            kernel_size=5,
            strides=2,
            padding="same",
            kernel_initializer=rnorm,
            activation="tanh"),
    ]
)

model.summary()

# Wrap the Sequential stack in a functional Model with an explicit noise input.
noise = keras.Input(shape=(latent_dim,))
img = model(noise)
generator = keras.Model(inputs=noise, outputs=img)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 8192)              819200    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8192)              32768     
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 512)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 4, 4, 512)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 8, 8, 256)         3276800   
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
dropout_2 (Dropout)          (None, 8, 8, 256)         0         
__________

In [8]:
import tensorflow as tf

def ln_func(x):
    """Layer normalisation over all non-batch axes, with no learned parameters.

    This replaces a call to ``tf.contrib.layers.layer_norm`` inside a Keras
    ``Lambda`` layer. That op creates fresh ``beta``/``gamma`` variables each
    time it is called. Keras does not track variables created inside a
    ``Lambda``, so the optimizer never updated them and they stayed at their
    initial values (beta = 0, gamma = 1). The explicit version below computes
    that same normalisation without creating untracked variables and without
    depending on ``tf.contrib``.

    Args:
        x: tensor of rank >= 2; axis 0 is treated as the batch axis.

    Returns:
        Tensor with the same shape as ``x``, normalised per sample to zero
        mean and unit variance.
    """
    # Normalise jointly over every axis except the batch axis (same as the
    # default begin_norm_axis=1 of tf.contrib.layers.layer_norm).
    axes = list(range(1, K.ndim(x)))
    mean = K.mean(x, axis=axes, keepdims=True)
    variance = K.var(x, axis=axes, keepdims=True)
    # 1e-12 matches the variance epsilon used by tf.contrib.layers.layer_norm.
    return (x - mean) / K.sqrt(variance + 1e-12)

# Factory so each call site gets its own (parameter-free) Lambda layer.
layer_norm = lambda : keras.layers.Lambda(ln_func)

In [9]:
# build critic

dim = 64
DROP = 0.1

# -------------------------------------------------------------------
# Critic: four stride-2 Conv2D + leaky-ReLU stages (64 -> 32 -> 16 -> 8 -> 4
# spatially), each followed by layer norm (not batch norm, as the WGAN-GP
# paper recommends for the critic) and dropout, then a linear score.
# NOTE(review): dropout makes the critic stochastic, so the real, fake and
# interpolated passes built later each see a different dropout mask.
# -------------------------------------------------------------------
model = keras.Sequential(
    [keras.layers.InputLayer(input_shape=img_shape)]
    # Downsampling stages; each halves the spatial size and doubles filters.
    # (Only the InputLayer needs a shape; the former `input_shape=` on the
    # second Conv2D was ignored by Keras and has been removed.)
    + [
        layer
        for n_filters in (dim, dim * 2, dim * 4, dim * 8)
        for layer in (
            keras.layers.Conv2D(
                n_filters,
                kernel_size=5,
                strides=2,
                padding="same",
                activation=lrelu,
                kernel_initializer=tnorm,
            ),
            layer_norm(),
            keras.layers.Dropout(DROP),
        )
    ]
    # Unbounded scalar score (no activation), as a Wasserstein critic needs.
    + [
        keras.layers.Flatten(),
        keras.layers.Dense(1),
    ]
)

model.summary()

img = keras.Input(shape=img_shape)
validity = model(img)

critic = keras.Model(img, validity)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 32, 32, 64)        4864      
_________________________________________________________________
lambda_1 (Lambda)            (None, 32, 32, 64)        0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 128)       204928    
_________________________________________________________________
lambda_2 (Lambda)            (None, 16, 16, 128)       0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 256)         819456    
__________

In [10]:
class RandomWeightedAverage(keras.layers.merge._Merge):
    """Random per-sample interpolation between real and generated samples.

    For each sample i it draws alpha_i ~ U(0, 1) and returns
    ``alpha_i * inputs[0][i] + (1 - alpha_i) * inputs[1][i]``. The WGAN-GP
    gradient penalty is evaluated at these interpolated points.
    """

    def _merge_function(self, inputs):
        real, fake = inputs[0], inputs[1]
        # Size alpha from the runtime batch dimension instead of the global
        # `batch_size`, so any batch size passed to train() works.
        # Shape is (batch, 1, ..., 1): one alpha per sample, broadcast over
        # all remaining axes.
        input_shape = K.shape(real)
        alpha_shape = K.concatenate([input_shape[:1], K.ones_like(input_shape[1:])])
        alpha = K.random_uniform(alpha_shape)
        return (alpha * real) + ((1 - alpha) * fake)

In [11]:
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    """
    Gradient-penalty term of the WGAN-GP critic loss.

    Computes mean_i (1 - ||d y_pred_i / d averaged_samples_i||_2)^2 over the
    batch. The penalty coefficient lambda is NOT applied here; it is supplied
    through ``loss_weights`` when the critic model is compiled.

    Args:
        y_true: unused; present only because Keras losses take (y_true, y_pred).
        y_pred: critic scores computed on ``averaged_samples``.
        averaged_samples: the interpolated inputs the scores were computed on.

    Returns:
        Scalar tensor: batch mean of the per-sample penalty.
    """
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # Squared L2 norm per sample: sum over every axis except the batch axis.
    gradients_sqr_sum = K.sum(K.square(gradients),
                              axis=list(range(1, K.ndim(gradients))))
    # Without the epsilon, an exactly-zero gradient gives d sqrt(s)/ds = inf at
    # s = 0, and inf * 0 turns the whole update into NaN.
    gradient_l2_norm = K.sqrt(gradients_sqr_sum + K.epsilon())
    # Per-sample penalty (1 - ||grad||)^2 ...
    gradient_penalty = K.square(1 - gradient_l2_norm)
    # ... averaged over the batch.
    return K.mean(gradient_penalty)

In [12]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: batch mean of label * score.

    With labels of -1 / +1, minimising this pushes the scores of
    -1-labelled samples up and the scores of +1-labelled samples down.
    """
    signed_scores = y_pred * y_true
    return K.mean(signed_scores)

In [13]:
#-------------------------------
# Critic training graph
#-------------------------------

# Freeze the generator so that only the critic is updated by critic_model.
generator.trainable = False

# Real samples come straight from the dataset ...
real_img = keras.Input(shape=img_shape)

# ... fake samples are produced by the (frozen) generator from noise.
z_disc = keras.Input(shape=(latent_dim,))
fake_img = generator(z_disc)

# Critic scores for both kinds of sample.
fake = critic(fake_img)
valid = critic(real_img)

# Random interpolates between real and fake images, plus their critic
# scores; the gradient penalty is evaluated on these.
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
validity_interpolated = critic(interpolated_img)

In [14]:
# Bind the interpolated images into the gradient-penalty loss so it fits
# Keras' two-argument (y_true, y_pred) loss signature.
partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=interpolated_img)
# functools.partial objects have no __name__, which Keras needs for loss names.
partial_gp_loss.__name__ = 'gradient_penalty'

# Outputs: score on real images, score on fake images, score on interpolates.
# The last loss weight (10) is the gradient-penalty coefficient lambda.
critic_model = keras.Model(
    inputs=[real_img, z_disc],
    outputs=[valid, fake, validity_interpolated],
)
critic_model.compile(
    loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
    optimizer=optimizer,
    loss_weights=[1, 1, 10],
)

In [15]:
#-------------------------------
# Generator training graph
#-------------------------------

# Swap the freeze: the critic is fixed, the generator is trainable again.
critic.trainable = False
generator.trainable = True

# noise -> generator -> critic score
z_gen = keras.Input(shape=(latent_dim,))
img = generator(z_gen)
valid = critic(img)

generator_model = keras.Model(inputs=z_gen, outputs=valid)
generator_model.compile(optimizer=optimizer, loss=wasserstein_loss)

In [16]:
import shutil, os

def train(steps, batch_size, sample_interval=50):
    """Train the WGAN-GP for ``steps`` generator updates.

    Each step runs ``n_critic`` critic updates on random real batches with
    fresh noise, followed by one generator update on fresh noise. Every
    ``sample_interval`` steps the losses are printed and a grid of samples is
    saved to ``celeba_images/``. That directory is wiped at the start.

    Args:
        steps: number of generator updates to perform.
        batch_size: number of images per batch.
        sample_interval: how often (in steps) to print losses and save samples.

    Returns:
        Tuple ``(d_loss_list, g_loss_list)``:
        - ``d_loss_list`` holds one entry per critic update, each
          ``[total, real, fake, gp]``.
        - ``g_loss_list`` holds one generator loss per step.
    """
    d_loss_list = []
    g_loss_list = []

    shutil.rmtree("celeba_images", ignore_errors=True)
    os.makedirs("celeba_images", exist_ok=True)

    # Adversarial ground truths for wasserstein_loss: -1 = real, +1 = fake.
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))
    dummy = np.zeros((batch_size, 1))  # ignored target for the gradient-penalty output

    for step in tqdm(range(1, steps + 1)):

        # ---------------------
        #  Train critic
        # ---------------------
        for _ in range(n_critic):
            # Random batch of real images and matching noise.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            noise = np.random.normal(0, 1, (batch_size, latent_dim))
            # critic_model has three outputs and no metrics, so this returns
            # [total_weighted_loss, loss_real, loss_fake, loss_gp].
            d_loss = critic_model.train_on_batch([imgs, noise], [valid, fake, dummy])
            d_loss_list.append(d_loss)

        # ---------------------
        #  Train generator
        # ---------------------
        # Fresh noise: do not reuse the batch the critic was just trained on.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = generator_model.train_on_batch(noise, valid)
        g_loss_list.append(g_loss)

        # At each sample interval, report losses and save a sample grid.
        if step % sample_interval == 0:
            d_total, d_real, d_fake, d_gp = d_loss
            print_string = (f"step: {step}, D-real: {d_real:g}, D-fake: {d_fake:g}, "
                            f"D-GP: {d_gp:g}, D: {d_total:g}, G: {g_loss:g}")
            print(print_string)
            sample_images(print_string, step)

    return d_loss_list, g_loss_list

def sample_images(print_string, step):
    """Save a 5x5 grid of generator samples to ``celeba_images/im_<step>.png``.

    Args:
        print_string: caption drawn above the grid (the current loss summary).
        step: training step; used zero-padded to 5 digits in the file name.
    """
    # make a video with
    # >ffmpeg -y -framerate 4 -pattern_type glob -i 'im_*.png' -pix_fmt yuv420p -vf scale=500:-1 output.mp4

    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = generator.predict(noise)

    # Rescale from the tanh range [-1, 1] to [0, 1] for imshow.
    gen_imgs = 0.5 * gen_imgs + .5

    fig, axs = plt.subplots(r, c, sharex=True, sharey=True, frameon=False, figsize=(5, 5))
    # axs.flat walks the grid row by row, matching the sample order.
    for ax, gen_img in zip(axs.flat, gen_imgs):
        ax.imshow(gen_img, aspect="auto", interpolation="spline16")
        ax.axis('off')
    fig.suptitle(print_string, backgroundcolor="white", fontsize=7)
    # tight_layout may not account for the suptitle (older matplotlib), so
    # keep the top strip of the figure free for the caption.
    fig.tight_layout(h_pad=0, w_pad=0, rect=[0, 0, 1, 0.95])
    fig.savefig(f"celeba_images/im_{step:05d}.png", dpi=150)
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
# 30k generator steps; print losses and save a sample grid every 100 steps.
train(batch_size=batch_size, steps=30000, sample_interval=100)

HBox(children=(IntProgress(value=0, max=30000), HTML(value='')))

  'Discrepancy between trainable weights and collected trainable'
