In [1]:
# Large amount of credit goes to:
# https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py
# which I've used as a reference for this implementation

from functools import partial

import keras

import keras.backend as K

import matplotlib.pyplot as plt

import sys

import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
img_rows = 32
img_cols = 32
channels = 3
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

# Following parameter and optimizer set as recommended in paper
n_critic = 5
optimizer = keras.optimizers.RMSprop(lr=0.00005)

In [14]:
# Build the generator and critic
model = keras.Sequential([
    keras.layers.Dense(128 * 4 * 4, activation="relu", input_dim=latent_dim),
    keras.layers.Reshape((4, 4, 128)),
    keras.layers.UpSampling2D(),
    keras.layers.Conv2D(128, kernel_size=3, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.UpSampling2D(),
    keras.layers.Conv2D(64, kernel_size=4, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.UpSampling2D(),
    keras.layers.Conv2D(64, kernel_size=5, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.Conv2D(channels, kernel_size=5, padding="same", activation="tanh"),
])

model.summary()

noise = keras.Input(shape=(latent_dim,))
img = model(noise)

generator = keras.Model(noise, img)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 2048)              206848    
_________________________________________________________________
reshape_3 (Reshape)          (None, 4, 4, 128)         0         
_________________________________________________________________
up_sampling2d_7 (UpSampling2 (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 8, 8, 128)         147584    
_________________________________________________________________
batch_normalization_10 (Batc (None, 8, 8, 128)         512       
_________________________________________________________________
up_sampling2d_8 (UpSampling2 (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 16, 16, 64)        131136    
__________

In [15]:
model = keras.Sequential([
    keras.layers.Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same", activation="elu"),
    keras.layers.Dropout(0.25),
    keras.layers.Conv2D(32, kernel_size=3, strides=2, padding="same"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),
    keras.layers.Conv2D(64, kernel_size=3, strides=2, padding="same"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),
    keras.layers.Conv2D(128, kernel_size=3, strides=1, padding="same"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),
    keras.layers.Flatten(),
    keras.layers.Dense(1)
])

model.summary()

img = keras.Input(shape=img_shape)
validity = model(img)

critic =  keras.Model(img, validity)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_17 (Conv2D)           (None, 16, 16, 16)        448       
_________________________________________________________________
dropout_5 (Dropout)          (None, 16, 16, 16)        0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 8, 8, 32)          4640      
_________________________________________________________________
batch_normalization_13 (Batc (None, 8, 8, 32)          128       
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 8, 8, 32)          0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 8, 8, 32)          0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 4, 4, 64)          18496     
__________

In [16]:
class RandomWeightedAverage(keras.layers.merge._Merge):
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        alpha = K.random_uniform((32, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [17]:
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    """
    Computes gradient penalty based on prediction and weighted real / fake samples
    """
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # compute the euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    #   ... summing over the rows ...
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    #   ... and sqrt
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # compute lambda * (1 - ||grad||)^2 still for each single sample
    gradient_penalty = K.square(1 - gradient_l2_norm)
    # return the mean as loss over all the batch samples
    return K.mean(gradient_penalty)

In [18]:
def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

In [19]:
#-------------------------------
# Construct Computational Graph
#       for the Critic
#-------------------------------

# Freeze generator's layers while training critic
generator.trainable = False

# Image input (real sample)
real_img = keras.Input(shape=img_shape)

# Noise input
z_disc = keras.Input(shape=(latent_dim,))
# Generate image based of noise (fake sample)
fake_img = generator(z_disc)

# Discriminator determines validity of the real and fake images
fake = critic(fake_img)
valid = critic(real_img)

# Construct weighted average between real and fake images
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
# Determine validity of weighted sample
validity_interpolated = critic(interpolated_img)

In [20]:
# Use Python partial to provide loss function with additional
# 'averaged_samples' argument
partial_gp_loss = partial(gradient_penalty_loss,
                  averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

critic_model = keras.Model(inputs=[real_img, z_disc],
                    outputs=[valid, fake, validity_interpolated])
critic_model.compile(loss=[wasserstein_loss,
                                      wasserstein_loss,
                                      partial_gp_loss],
                                optimizer=optimizer,
                                loss_weights=[1, 1, 10])

In [21]:
#-------------------------------
# Construct Computational Graph
#         for Generator
#-------------------------------

# For the generator we freeze the critic's layers
critic.trainable = False
generator.trainable = True

# Sampled noise for input to generator
z_gen = keras.Input(shape=(latent_dim,))
# Generate images based of noise
img = generator(z_gen)
# Discriminator determines validity
valid = critic(img)
# Defines generator model
generator_model = keras.Model(z_gen, valid)
generator_model.compile(loss=wasserstein_loss, optimizer=optimizer)

In [22]:
import shutil, os

def train(epochs, batch_size, sample_interval=50):

    shutil.rmtree("cifar_images", ignore_errors=True)
    os.makedirs("cifar_images", exist_ok=True)
    
    # Load the dataset
    (X_train, _), (_, _) = keras.datasets.cifar10.load_data()

    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
#     X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = -np.ones((batch_size, 1))
    fake =  np.ones((batch_size, 1))
    dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty
    for epoch in range(epochs):

        for _ in range(n_critic):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            # Sample generator input
            noise = np.random.normal(0, 1, (batch_size, latent_dim))
            # Train the critic
            d_loss = critic_model.train_on_batch([imgs, noise],
                                                            [valid, fake, dummy])

        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = generator_model.train_on_batch(noise, valid)
        

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
            sample_images(epoch)

def sample_images(epoch):
    # make a video with 
    # >ffmpeg -y -framerate 4 -pattern_type glob -i 'im_*.png' -pix_fmt yuv420p -vf scale=500:-1 output.mp4
    
    
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = generator.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + .5

    fig, axs = plt.subplots(r, c, sharex=True, sharey=True, frameon=False, figsize=(5,5))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt,...], aspect="auto")
            axs[i,j].axis('off')
            cnt += 1
    plt.tight_layout(h_pad=0, w_pad=0)
    plt.suptitle(f'epoch: {epoch}', backgroundcolor="white")
    fig.savefig(f"cifar_images/im_{epoch:05d}.png", dpi=200)
    plt.close()

In [None]:
train(epochs=30000, batch_size=32, sample_interval=100)

0 [D loss: 22.549532] [G loss: -0.085128]
100 [D loss: -1.179994] [G loss: 2.338360]
200 [D loss: -0.348127] [G loss: 3.588444]
300 [D loss: -1.086858] [G loss: 5.133800]
400 [D loss: -0.876872] [G loss: 4.884992]
500 [D loss: -0.609653] [G loss: 4.347772]
600 [D loss: -0.334168] [G loss: 2.548889]
700 [D loss: -0.581471] [G loss: 1.469270]
800 [D loss: -0.043264] [G loss: 0.461900]
900 [D loss: -0.759276] [G loss: 0.855498]
1000 [D loss: -0.437620] [G loss: 0.523317]
1100 [D loss: -0.167538] [G loss: 0.852615]
1200 [D loss: -0.268820] [G loss: 1.019385]
1300 [D loss: 0.026647] [G loss: 0.963118]
1400 [D loss: -0.514463] [G loss: 1.097671]
1500 [D loss: -0.521503] [G loss: 0.904237]
1600 [D loss: -0.245003] [G loss: 1.037712]
1700 [D loss: -0.331149] [G loss: 1.138600]
1800 [D loss: -0.096037] [G loss: 0.656050]
1900 [D loss: -0.331608] [G loss: 0.802226]
2000 [D loss: -0.213479] [G loss: 1.055718]
2100 [D loss: -0.133520] [G loss: 1.378542]
2200 [D loss: -0.270195] [G loss: 1.208834]


18300 [D loss: -0.417379] [G loss: -3.728946]
18400 [D loss: -0.009635] [G loss: -3.070096]
18500 [D loss: 0.343299] [G loss: -3.370257]
18600 [D loss: -0.417341] [G loss: -3.487167]
18700 [D loss: -0.160984] [G loss: -3.644227]
18800 [D loss: 0.687193] [G loss: -2.934487]
18900 [D loss: -0.032827] [G loss: -3.306953]
19000 [D loss: -0.375450] [G loss: -3.747408]
19100 [D loss: 0.561327] [G loss: -3.402855]
19200 [D loss: -0.049396] [G loss: -2.947413]
19300 [D loss: -0.023339] [G loss: -2.929947]
19400 [D loss: -0.381016] [G loss: -2.465288]
19500 [D loss: -0.491207] [G loss: -3.180260]
19600 [D loss: -0.382616] [G loss: -3.309305]
19700 [D loss: 0.112468] [G loss: -3.150632]
19800 [D loss: 0.086436] [G loss: -3.324967]
19900 [D loss: -0.077478] [G loss: -3.217413]
20000 [D loss: 0.116025] [G loss: -3.367551]
20100 [D loss: -0.119196] [G loss: -3.543710]
20200 [D loss: 0.008402] [G loss: -3.135237]
20300 [D loss: 0.149330] [G loss: -3.026728]
20400 [D loss: -0.202665] [G loss: -3.0871

In [None]:
noise = np.random.normal(0, 1, (10, latent_dim))
gen_imgs = generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + .5

In [None]:
gen_imgs.shape

In [None]:
gen_imgs.max()

In [None]:
gen_imgs.min()