In [1]:
# Large amount of credit goes to:
# https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py
# which I've used as a reference for this implementation

from functools import partial

import keras

import keras.backend as K

import matplotlib.pyplot as plt

import sys

import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

# Following parameter and optimizer set as recommended in paper
n_critic = 5
optimizer = keras.optimizers.RMSprop(lr=0.00005)

In [3]:
# Build the generator and critic
model = keras.Sequential([
    keras.layers.Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim),
    keras.layers.Reshape((7, 7, 128)),
    keras.layers.UpSampling2D(),
    keras.layers.Conv2D(128, kernel_size=4, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.UpSampling2D(),
    keras.layers.Conv2D(64, kernel_size=4, padding="same", activation="elu"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.Conv2D(channels, kernel_size=4, padding="same", activation="tanh"),
])

model.summary()

noise = keras.Input(shape=(latent_dim,))
img = model(noise)

generator = keras.Model(noise, img)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       262272    
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        131136    
__________

In [4]:
model = keras.Sequential([
    keras.layers.Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same", activation="elu"),
    keras.layers.Dropout(0.25),
    keras.layers.Conv2D(32, kernel_size=3, strides=2, padding="same"),
    keras.layers.ZeroPadding2D(padding=((0,1),(0,1))),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),
    keras.layers.Conv2D(64, kernel_size=3, strides=2, padding="same"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),
    keras.layers.Conv2D(128, kernel_size=3, strides=1, padding="same"),
    keras.layers.BatchNormalization(momentum=0.8),
    keras.layers.LeakyReLU(alpha=.2),
    keras.layers.Dropout(0.25),
    keras.layers.Flatten(),
    keras.layers.Dense(1)
])

model.summary()

img = keras.Input(shape=img_shape)
validity = model(img)

critic =  keras.Model(img, validity)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 14, 14, 16)        160       
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 16)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 7, 7, 32)          4640      
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 8, 8, 32)          0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 8, 8, 32)          128       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 8, 8, 32)          0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 8, 8, 32)          0         
__________

In [5]:
class RandomWeightedAverage(keras.layers.merge._Merge):
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        alpha = K.random_uniform((32, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

In [6]:
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    """
    Computes gradient penalty based on prediction and weighted real / fake samples
    """
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # compute the euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    #   ... summing over the rows ...
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    #   ... and sqrt
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # compute lambda * (1 - ||grad||)^2 still for each single sample
    gradient_penalty = K.square(1 - gradient_l2_norm)
    # return the mean as loss over all the batch samples
    return K.mean(gradient_penalty)

In [7]:
def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

In [8]:
#-------------------------------
# Construct Computational Graph
#       for the Critic
#-------------------------------

# Freeze generator's layers while training critic
generator.trainable = False

# Image input (real sample)
real_img = keras.Input(shape=img_shape)

# Noise input
z_disc = keras.Input(shape=(latent_dim,))
# Generate image based of noise (fake sample)
fake_img = generator(z_disc)

# Discriminator determines validity of the real and fake images
fake = critic(fake_img)
valid = critic(real_img)

# Construct weighted average between real and fake images
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
# Determine validity of weighted sample
validity_interpolated = critic(interpolated_img)

In [9]:
# Use Python partial to provide loss function with additional
# 'averaged_samples' argument
partial_gp_loss = partial(gradient_penalty_loss,
                  averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

critic_model = keras.Model(inputs=[real_img, z_disc],
                    outputs=[valid, fake, validity_interpolated])
critic_model.compile(loss=[wasserstein_loss,
                                      wasserstein_loss,
                                      partial_gp_loss],
                                optimizer=optimizer,
                                loss_weights=[1, 1, 10])

In [10]:
#-------------------------------
# Construct Computational Graph
#         for Generator
#-------------------------------

# For the generator we freeze the critic's layers
critic.trainable = False
generator.trainable = True

# Sampled noise for input to generator
z_gen = keras.Input(shape=(latent_dim,))
# Generate images based of noise
img = generator(z_gen)
# Discriminator determines validity
valid = critic(img)
# Defines generator model
generator_model = keras.Model(z_gen, valid)
generator_model.compile(loss=wasserstein_loss, optimizer=optimizer)

In [11]:
import shutil, os

def train(epochs, batch_size, sample_interval=50):

    shutil.rmtree("images", ignore_errors=True)
    os.makedirs(f"images", exist_ok=True)
    
    # Load the dataset
    (X_train, _), (_, _) = keras.datasets.mnist.load_data()

    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = -np.ones((batch_size, 1))
    fake =  np.ones((batch_size, 1))
    dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty
    for epoch in range(epochs):

        for _ in range(n_critic):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            # Sample generator input
            noise = np.random.normal(0, 1, (batch_size, latent_dim))
            # Train the critic
            d_loss = critic_model.train_on_batch([imgs, noise],
                                                            [valid, fake, dummy])

        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = generator_model.train_on_batch(noise, valid)
        

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
            sample_images(epoch)

def sample_images(epoch):
    # make a video with 
    # >ffmpeg -framerate 4 -pattern_type glob -i mnist_*.png -pix_fmt yuv420p output.mp4
    
    
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = generator.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 1

    fig, axs = plt.subplots(r, c, sharex=True, sharey=True, frameon=False, figsize=(5,5))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray', aspect="auto")
            axs[i,j].axis('off')
            cnt += 1
    plt.tight_layout(h_pad=0, w_pad=0)
    plt.suptitle(f'epoch: {epoch}', backgroundcolor="white")
    fig.savefig(f"images/mnist_{epoch:05d}.png", dpi=150)
    plt.close()

In [12]:
train(epochs=30000, batch_size=32, sample_interval=100)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 8.376334] [G loss: 0.348251]
100 [D loss: -4.494257] [G loss: -0.427195]
200 [D loss: -0.346024] [G loss: -0.547124]
300 [D loss: -0.567543] [G loss: -1.540657]
400 [D loss: -0.581732] [G loss: -1.897855]
500 [D loss: -0.048570] [G loss: -1.934634]
600 [D loss: -0.864462] [G loss: -0.926045]
700 [D loss: -0.500787] [G loss: -0.586527]
800 [D loss: -0.102128] [G loss: -1.145723]
900 [D loss: -0.626080] [G loss: -1.844454]
1000 [D loss: -1.185885] [G loss: -1.580364]
1100 [D loss: -0.884991] [G loss: -1.315495]
1200 [D loss: -1.657865] [G loss: -1.469141]
1300 [D loss: -1.016559] [G loss: -1.626222]
1400 [D loss: -1.173788] [G loss: 0.423203]
1500 [D loss: -1.480434] [G loss: -0.033803]
1600 [D loss: -2.240587] [G loss: 0.816050]
1700 [D loss: -1.345939] [G loss: 0.523644]
1800 [D loss: -1.398457] [G loss: 0.154413]
1900 [D loss: -0.983535] [G loss: 0.916479]
2000 [D loss: -0.712417] [G loss: 0.757155]
2100 [D loss: -1.197576] [G loss: 1.079333]
2200 [D loss: -1.535270] [G los

18500 [D loss: -0.120024] [G loss: 1.457533]
18600 [D loss: 0.160457] [G loss: 2.101882]
18700 [D loss: 0.463108] [G loss: 2.499120]
18800 [D loss: -0.467658] [G loss: 2.319057]
18900 [D loss: -1.764443] [G loss: 1.769130]
19000 [D loss: -0.143182] [G loss: 3.475824]
19100 [D loss: -0.693433] [G loss: 3.007635]
19200 [D loss: 0.071167] [G loss: 2.587565]
19300 [D loss: -0.624237] [G loss: 2.982646]
19400 [D loss: 0.010711] [G loss: 2.440885]
19500 [D loss: 0.010004] [G loss: 1.869667]
19600 [D loss: -0.401579] [G loss: 1.121923]
19700 [D loss: -0.636725] [G loss: 2.218659]
19800 [D loss: -0.410405] [G loss: 2.885381]
19900 [D loss: 0.205634] [G loss: 1.868466]
20000 [D loss: 0.293200] [G loss: 1.573007]
20100 [D loss: -1.170002] [G loss: 1.313784]
20200 [D loss: 0.299144] [G loss: 0.890134]
20300 [D loss: 0.286147] [G loss: 2.271142]
20400 [D loss: -0.759467] [G loss: 2.196149]
20500 [D loss: -0.078327] [G loss: 0.771504]
20600 [D loss: 0.419600] [G loss: 0.829716]
20700 [D loss: 0.518