In [1]:

# references
# - https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py
# - https://github.com/eriklindernoren/Keras-GAN/blob/master/wgan_gp/wgan_gp.py
# - https://github.com/LynnHo/WGAN-GP-DRAGAN-Celeba-Pytorch/blob/master/models_64x64.py


from functools import partial

import keras
import keras_contrib

import keras.backend as K

import matplotlib.pyplot as plt

import sys

import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Native size of the aligned CelebA images: 218 rows x 178 columns (height, width)

In [3]:
178/(218/64)

52.25688073394495

In [4]:
# Number of training images to load and the target image geometry (rows, cols, channels)
num_images = 50000
img_shape = [64, 52, 3]
img_rows, img_cols, channels = img_shape

# Length of the noise vector fed to the generator
latent_dim = 100

# Critic updates per generator update, and the optimizer shared by both models.
# Original note: "set as recommended in paper".
# NOTE(review): RMSprop with lr=5e-5 looks like the original WGAN defaults rather
# than the Adam settings proposed for WGAN-GP — confirm which paper is meant.
n_critic = 5
optimizer = keras.optimizers.RMSprop(lr=0.00005)

In [5]:
import glob
import imageio
import lycon
from skimage import transform
from tqdm import tqdm, trange

In [6]:
img_shape[1]

52

In [7]:
# Collect the image paths. glob returns files in arbitrary filesystem order, so sort
# them to make the selected subset of `num_images` files reproducible across runs.
im_files = sorted(glob.glob("/home/ubuntu/downloads/img_align_celeba/*.jpg"))

# Fail immediately (instead of part-way through the loop) when the folder is too small.
if len(im_files) < num_images:
    raise ValueError(
        f"expected at least {num_images} images, found {len(im_files)}")

# float32 is sufficient for pixel data and halves the memory of the default float64 buffer.
X_train = np.zeros([num_images] + img_shape, dtype=np.float32)

for i in trange(num_images):
    # Load each file once and resize it to (img_shape[0] rows, img_shape[1] cols).
    # (The former second read via imageio.imread discarded its result and only
    # doubled the disk I/O, so it was removed.)
    X_train[i, ...] = lycon.resize(
        lycon.load(im_files[i]),
        width=img_shape[1],
        height=img_shape[0],
        interpolation=lycon.Interpolation.LINEAR)

100%|██████████| 50000/50000 [01:48<00:00, 459.64it/s]


In [8]:
X_train.min(), X_train.max()

(0.0, 255.0)

In [9]:
# Rescale pixel values from [0, 255] to [-1, 1].
# The guard keeps the cell idempotent: re-running it on already-rescaled data
# (max <= 1) is a no-op instead of silently rescaling a second time.
if X_train.max() > 1:
    X_train = X_train*(2/255) - 1

In [10]:
X_train.min(), X_train.max()

(-1.0, 1.0)

In [11]:
# Build the generator: maps a latent_dim noise vector to an image tensor.
# Shapes quoted below for the first layers are taken from this cell's model.summary() output.
model = keras.Sequential([
    # Project the noise to 128*4*4 = 2048 units and fold them into a 4x4 map with 128 channels
    keras.layers.Dense(128 * 4 * 4, activation="relu", input_dim=latent_dim),
    keras.layers.Reshape((4, 4, 128)),

    # Stride-2 transposed conv doubles the spatial size: 4x4 -> 8x8
    keras.layers.Conv2DTranspose(128, kernel_size=5, strides=2, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),

    # 8x8 -> 16x16
    keras.layers.Conv2DTranspose(128, kernel_size=5, strides=2, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),
    # Trim the width from 16 to 13 columns (1 on one side, 2 on the other) so that the
    # two remaining doublings end at 52 columns, matching img_shape[1]
    keras.layers.Cropping2D(((0, 0), (1, 2))),

    # 16x13 -> 32x26 (by the same stride-2 doubling)
    keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),

    # 32x26 -> 64x52
    # NOTE(review): this hidden layer uses "tanh" while the earlier upsampling layers
    # use "elu" — confirm this is intentional and not a leftover from an output layer.
    keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="same", activation="tanh"),
    keras.layers.BatchNormalization(momentum=0.8),

    # Output layer: reduce to `channels` feature maps; tanh keeps values in [-1, 1]
    keras.layers.Conv2D(channels, kernel_size=5, padding="same", activation="tanh"),
])

model.summary()

# Wrap the Sequential stack in a functional Model: noise in, generated image out
noise = keras.Input(shape=(latent_dim,))
img = model(noise)

generator = keras.Model(noise, img)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 2048)              206848    
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 8, 8, 128)         409728    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 128)         512       
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 16, 16, 128)       409728    
_________________________________________________________________
batch_normalization_2 (Batch (None, 16, 16, 128)       512       
_________________________________________________________________
cropping2d_1 (Cropping2D)    (None, 16, 13, 128)       0         
__________

In [12]:
# Build the critic: maps an image of shape img_shape to a single score.
# The final Dense(1) has no activation, so the score is unbounded.
model = keras.Sequential([
    # Block 1: stride-2 conv halves the spatial size (output 32x26 per model.summary()).
    # NOTE(review): this conv applies "elu" and is then followed by LeakyReLU as well,
    # whereas the later conv layers have no built-in activation — confirm this is intended.
    keras.layers.Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same", activation="elu"),
    # InstanceNormalization is used in blocks 1-3 (a BatchNormalization(momentum=0.8)
    # alternative was left disabled). model.summary() reports only 2 parameters
    # for each of these normalization layers.
    keras_contrib.layers.InstanceNormalization(),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),

    # Block 2: output 16x13 per model.summary()
    keras.layers.Conv2D(32, kernel_size=3, strides=2, padding="same"),
    keras_contrib.layers.InstanceNormalization(),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),

    # Block 3
    keras.layers.Conv2D(64, kernel_size=3, strides=2, padding="same"),
    keras_contrib.layers.InstanceNormalization(),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),

    # Block 4: no normalization and no dropout before the output head.
    # (A fifth stride-2 conv block was tried here and is disabled.)
    keras.layers.Conv2D(128, kernel_size=3, strides=2, padding="same"),
    keras.layers.LeakyReLU(alpha=.2),

    # Output head: flatten the feature maps and produce one scalar per image
    keras.layers.Flatten(),
    keras.layers.Dense(1)
])

model.summary()

# Wrap the Sequential stack in a functional Model: image in, critic score out
img = keras.Input(shape=img_shape)
validity = model(img)

critic =  keras.Model(img, validity)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            (None, 32, 26, 16)        448       
_________________________________________________________________
instance_normalization_1 (In (None, 32, 26, 16)        2         
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 26, 16)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 32, 26, 16)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 13, 32)        4640      
_________________________________________________________________
instance_normalization_2 (In (None, 16, 13, 32)        2         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 13, 32)        0         
__________

In [13]:
class RandomWeightedAverage(keras.layers.merge._Merge):
    """Provides a (random) weighted average between real and generated image samples.

    Called on a list ``[real_images, generated_images]`` of equally shaped 4-D
    tensors. For every sample one weight ``alpha`` is drawn with
    ``K.random_uniform`` and the layer returns
    ``alpha * real + (1 - alpha) * generated``.
    """
    def _merge_function(self, inputs):
        # Read the batch size from the incoming tensor at run time. It used to be
        # hard-coded to 32, which broke training with any other batch size.
        batch_size = K.shape(inputs[0])[0]
        # One weight per sample, broadcast over the three remaining axes
        alpha = K.random_uniform((batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [14]:
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    """
    Gradient penalty term computed from the critic output on interpolated samples.

    y_true is unused; it is only present so the function fits the
    (y_true, y_pred) loss signature. y_pred is differentiated with respect to
    averaged_samples, the per-sample L2 norm of that gradient is taken over
    all non-batch axes, and the batch mean of (1 - norm)^2 is returned.
    The penalty weight is not applied here.
    """
    grads = K.gradients(y_pred, averaged_samples)[0]
    squared = K.square(grads)
    # Reduce over every axis except the leading (batch) axis
    non_batch_axes = np.arange(1, len(squared.shape))
    per_sample_norm = K.sqrt(K.sum(squared, axis=non_batch_axes))
    # Squared deviation of each sample's gradient norm from 1, averaged over the batch
    return K.mean(K.square(1 - per_sample_norm))

In [15]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of targets and predictions."""
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [16]:
#-------------------------------
# Construct Computational Graph
#       for the Critic
#-------------------------------

# Freeze generator's layers while training critic, and make sure the critic
# itself is trainable. A later cell sets critic.trainable = False for the
# generator graph; without resetting it here, re-running this cell and the
# compile step afterwards would build a critic model whose critic is frozen.
critic.trainable = True
generator.trainable = False

# Image input (real sample)
real_img = keras.Input(shape=img_shape)

# Noise input
z_disc = keras.Input(shape=(latent_dim,))
# Generate image based of noise (fake sample)
fake_img = generator(z_disc)

# Critic scores for the fake and the real images
fake = critic(fake_img)
valid = critic(real_img)

# Construct weighted average between real and fake images
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
# Critic score of the interpolated sample (input to the gradient penalty)
validity_interpolated = critic(interpolated_img)

In [17]:
# Bind the interpolated-image tensor into the penalty so that the resulting
# callable has the two-argument (y_true, y_pred) form used by the other loss
partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

# The critic model has three outputs: score of the real image, score of the
# generated image, and score of the interpolated image
critic_model = keras.Model(
    inputs=[real_img, z_disc],
    outputs=[valid, fake, validity_interpolated],
)
# One loss per output; the gradient penalty is weighted 10x relative to the two
# Wasserstein terms
critic_model.compile(
    optimizer=optimizer,
    loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
    loss_weights=[1, 1, 10],
)

In [18]:
#-------------------------------
# Construct Computational Graph
#         for Generator
#-------------------------------

# Only the generator is trainable in this graph; the critic's layers are frozen
generator.trainable = True
critic.trainable = False

# Noise input -> generated image -> critic score
z_gen = keras.Input(shape=(latent_dim,))
img = generator(z_gen)
valid = critic(img)

# The generator model maps noise straight to the critic's score of the generated image
generator_model = keras.Model(inputs=z_gen, outputs=valid)
generator_model.compile(optimizer=optimizer, loss=wasserstein_loss)

In [19]:
import shutil, os

def train(epochs, batch_size, sample_interval=50):
    """Alternate critic and generator updates and periodically save sample grids.

    Each of the `epochs` iterations runs `n_critic` critic updates on randomly
    drawn batches of `X_train`, followed by one generator update that reuses
    the noise of the last critic update. Every `sample_interval` iterations
    (including iteration 0) the losses are printed and `sample_images` is called.

    The "celeba_images" directory is deleted and recreated at the start.
    """
    # Start from an empty output directory
    shutil.rmtree("celeba_images", ignore_errors=True)
    os.makedirs("celeba_images", exist_ok=True)

    # Targets for the three critic outputs: -1 for real, +1 for generated,
    # and zeros as a placeholder for the gradient penalty output
    target_shape = (batch_size, 1)
    real_targets = -np.ones(target_shape)
    fake_targets = np.ones(target_shape)
    gp_targets = np.zeros(target_shape)

    for epoch in range(epochs):

        # ---- critic updates ----
        for _ in range(n_critic):
            # Random batch of real images (indices drawn with replacement)
            batch_idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_batch = X_train[batch_idx]
            # Generator input
            noise = np.random.normal(0, 1, (batch_size, latent_dim))
            d_loss = critic_model.train_on_batch(
                [real_batch, noise],
                [real_targets, fake_targets, gp_targets])

        # ---- generator update ----
        g_loss = generator_model.train_on_batch(noise, real_targets)

        # Report and save samples only on the sampling iterations
        if epoch % sample_interval != 0:
            continue
        print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
        sample_images(epoch)

def sample_images(epoch):
    """Generate a 5x5 grid of images and save it as celeba_images/im_<epoch>.png.

    The saved frames can be turned into a video with
    >ffmpeg -y -framerate 4 -pattern_type glob -i 'im_*.png' -pix_fmt yuv420p -vf scale=500:-1 output.mp4
    """
    rows, cols = 5, 5
    latent = np.random.normal(0, 1, (rows * cols, latent_dim))

    # Shift and scale the generator output so that [-1, 1] maps onto [0, 1]
    samples = 0.5 * generator.predict(latent) + .5

    fig, axs = plt.subplots(rows, cols, sharex=True, sharey=True, frameon=False, figsize=(5,5))
    # axs.flat walks the grid row by row, pairing each axis with the next sample
    for ax, sample in zip(axs.flat, samples):
        ax.imshow(sample, aspect="auto")
        ax.axis('off')
    plt.tight_layout(h_pad=0, w_pad=0)
    plt.suptitle(f'epoch: {epoch}', backgroundcolor="white")
    fig.savefig(f"celeba_images/im_{epoch:05d}.png", dpi=200)
    plt.close()

In [None]:
train(epochs=30000, batch_size=32, sample_interval=100)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.071010] [G loss: 0.399737]
100 [D loss: -15.906043] [G loss: 11.872585]
200 [D loss: -10.462564] [G loss: 12.331842]
300 [D loss: -6.009324] [G loss: 6.556452]
400 [D loss: -2.258909] [G loss: -5.384147]
500 [D loss: -0.662623] [G loss: -8.165950]
600 [D loss: -1.386973] [G loss: -5.359030]
700 [D loss: -0.288987] [G loss: -0.037723]
800 [D loss: -2.307323] [G loss: 5.657558]
900 [D loss: -2.037921] [G loss: -8.884901]
1000 [D loss: -2.822394] [G loss: -8.537423]
1100 [D loss: -0.489803] [G loss: -15.163243]
1200 [D loss: -1.426459] [G loss: -18.684834]
1300 [D loss: -1.630782] [G loss: -21.376253]
1400 [D loss: -2.592243] [G loss: -22.582027]
1500 [D loss: -1.676471] [G loss: -29.899513]
1600 [D loss: -2.051211] [G loss: -25.263687]
1700 [D loss: -2.363373] [G loss: -26.964220]
1800 [D loss: -2.178939] [G loss: -26.493341]
1900 [D loss: -2.008708] [G loss: -28.711365]
2000 [D loss: -2.075954] [G loss: -33.315414]
2100 [D loss: -1.926997] [G loss: -29.601543]
2200 [D loss:

In [None]:
noise = np.random.normal(0, 1, (10, latent_dim))
gen_imgs = generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + .5

In [None]:
gen_imgs.shape

In [None]:
gen_imgs.max()

In [None]:
gen_imgs.min()