# How to Implement Wasserstein Loss for Generative Adversarial Networks

[Theory](https://machinelearningmastery.com/how-to-implement-wasserstein-loss-for-generative-adversarial-networks/)




The Wasserstein Generative Adversarial Network, or Wasserstein GAN, is an extension to the generative adversarial network that both improves the __stability__ when training the model and provides a __loss function that correlates with the quality of generated images__.

It is an important extension to the GAN model and requires a __conceptual shift away from a discriminator__ that predicts the probability of a generated image being “real” and toward the idea of a critic model that scores the “realness” of a given image.

This conceptual shift is motivated mathematically: the GAN is trained using the __earth mover distance__, or __Wasserstein distance__, which measures the ___distance between the data distribution observed in the training dataset and the distribution observed in the generated examples___.

In this post, you will discover how to implement Wasserstein loss for Generative Adversarial Networks.

After reading this post, you will know:

- The conceptual shift in the WGAN from discriminator predicting a probability to a critic predicting a score.
- The implementation details for the WGAN as minor changes to the standard deep convolutional GAN.
- The intuition behind the Wasserstein loss function and how to implement it from scratch.

# Overview
This tutorial is divided into five parts; they are:

1. GAN Stability and the Discriminator
2. What is a Wasserstein GAN?
3. Implementation Details of the Wasserstein GAN
4. How to Implement Wasserstein Loss
5. Common Point of Confusion With Expected Labels

# GAN Stability and the Discriminator
Generative Adversarial Networks, or GANs, are challenging to train.

The discriminator model must classify a given input image as real (from the dataset) or fake (generated), and the generator model must generate new and plausible images.

The reason GANs are __difficult__ to train is that the architecture involves the __simultaneous training of a generator and a discriminator model in a zero-sum game__. Stable training requires finding and maintaining an __equilibrium__ between the capabilities of the two models.

The __discriminator__ model is a neural network that learns a binary classification problem, using a __sigmoid__ activation function in the output layer, and is fit using a __binary cross entropy__ loss function. As such, the model predicts a probability that a given input is real (or fake as 1 minus the predicted) as a value between 0 and 1.

The loss function has the effect of __penalizing__ the model ___proportionally to how far the predicted probability distribution differs from the expected probability distribution for a given image___. This provides the basis for the error that is back propagated through the discriminator and the generator in order to perform better on the next batch.
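
To make the penalty concrete, here is a minimal NumPy sketch of binary cross entropy for a single real image. The helper name `binary_cross_entropy`, the `eps` value, and the example probabilities are illustrative assumptions, not part of the original tutorial.

```python
import numpy as np

def binary_cross_entropy(y_true, p_pred, eps=1e-7):
    # clip probabilities to avoid log(0); eps is an illustrative choice
    p = np.clip(p_pred, eps, 1.0 - eps)
    return float(-np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p)))

# a real image (label 1) predicted with decreasing confidence
y_real = np.array([1.0])
for p in [0.9, 0.5, 0.1]:
    print(p, binary_cross_entropy(y_real, np.array([p])))
```

The loss grows as the predicted probability moves away from the expected label, which is the error signal passed back to both models.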

The WGAN relaxes the role of the discriminator when training a GAN and proposes the alternative of a ___critic___.

# What is a Wasserstein GAN?
The Wasserstein GAN, or WGAN for short, was introduced by Martin Arjovsky, et al. in their 2017 paper titled “Wasserstein GAN.”

It is an extension of the GAN that seeks an alternate way of training the generator model to better approximate the distribution of data observed in a given training dataset.

Instead of using a discriminator to classify or predict the probability of generated images as being real or fake, the WGAN changes or replaces the discriminator model with a ___critic___ that ___scores the realness or fakeness of a given image___.

This change is motivated by a mathematical argument that training the ___generator___ should seek a __minimization__ of the __distance__ between the distribution of the data observed in the ___training___ dataset and the distribution observed in ___generated___ examples. The argument contrasts different distribution distance measures, such as ___Kullback-Leibler (KL) divergence, Jensen-Shannon (JS) divergence, and the Earth-Mover (EM) distance, referred to as Wasserstein distance.___

_The most fundamental difference between such distances is their impact on the convergence of sequences of probability distributions._

— Wasserstein GAN, 2017.

They demonstrate that a __critic__ neural network can be trained to __approximate the Wasserstein distance__, and, in turn, used to effectively train a generator model.

_… we define a form of GAN called Wasserstein-GAN that minimizes a reasonable and efficient approximation of the EM distance, and we theoretically show that the corresponding optimization problem is sound._

— Wasserstein GAN, 2017.

Importantly, the __Wasserstein distance__ has the properties that it is __continuous__ and __differentiable__ and _continues to provide a linear gradient, even after the critic is well trained._

_The fact that the EM distance is continuous and differentiable a.e. means that we can (and should) __train the critic till optimality.__ […] the more we train the critic, the more reliable gradient of the Wasserstein we get, which is actually useful by the fact that Wasserstein is differentiable almost everywhere._

— Wasserstein GAN, 2017.

___This is unlike the discriminator model that, once trained, may fail to provide useful gradient information for updating the generator model.___

_The __discriminator__ learns ___very quickly___ to distinguish between fake and real, and as expected ___provides no reliable gradient information___. The __critic__, however, ___can’t saturate___ and converges to a linear function that gives remarkably clean gradients everywhere._

— Wasserstein GAN, 2017.

The benefit of the __WGAN__ is that the training process is more stable and ___less sensitive to model architecture and choice of hyperparameter configurations.___

_… training WGANs ___does not___ require maintaining a ___careful balance___ in training of the discriminator and the generator, and does not require a careful design of the network architecture either. The __mode dropping phenomenon__ that is typical in GANs is also ___drastically reduced___._

— Wasserstein GAN, 2017.

Perhaps most importantly, the ___loss of the critic appears to relate to the quality of images created by the generator.___

Specifically, ___the lower___ the __loss of the critic__ when evaluating generated images, ___the higher___ the expected __quality of the generated images__. This is important as __unlike other GANs__ that seek stability in terms of finding an __equilibrium__ between two models, the __WGAN__ seeks __convergence__, lowering generator loss.

__Way to evaluate__

_To our knowledge, this is the first time in GAN literature that such a property is shown, where the __loss of the GAN shows properties of convergence__. This property is extremely useful when doing research in adversarial networks as __one does not need to stare at the generated samples to figure out failure modes and to gain information on which models are doing better over others__._

— Wasserstein GAN, 2017.

# Implementation Details of the Wasserstein GAN
Although the theoretical grounding for the WGAN is dense, the implementation of a WGAN requires a _few minor_ changes to the standard deep convolutional GAN, or DCGAN.

Those changes are as follows:

- Use a linear activation function in the output layer of the critic model (instead of sigmoid).

- Use __Wasserstein loss__ to train the critic and generator models, which promotes a larger difference between scores for real and generated images.

- Constrain critic model weights to a limited range after each mini batch update (e.g. [-0.01,0.01]).

_In order to have parameters w lie in a compact space, something simple we can do is clamp the weights to a fixed box (say W = [−0.01, 0.01]) after each gradient update._

— Wasserstein GAN, 2017.

- Update the critic model more times than the generator each iteration (e.g. 5).

- Use the RMSProp version of gradient descent with small learning rate and no momentum (e.g. 0.00005).

_… we report that WGAN training becomes unstable at times when one uses a momentum based optimizer such as Adam […] We therefore switched to RMSProp …_

— Wasserstein GAN, 2017.

The image below provides a summary of the main training loop for training a WGAN, taken from the paper. Note the listing of recommended hyperparameters used in the model.

![WGAN](https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/05/Algorithm-for-the-Wasserstein-Generative-Adversarial-Networks.png)

# How to Implement Wasserstein Loss
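The Wasserstein loss function seeks to increase the gap between the critic's scores for real and generated images. The critic is trained to widen this gap, while the generator is trained to close it; in both cases the objective is written as a loss to be minimized.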

We can summarize the function as it is described in the paper as follows:

- Critic Loss = [average critic score on real images] - [average critic score on fake images]. This difference tracks the distance between the real and fake distributions, which training seeks to minimize.
- Generator Loss = -[average critic score on fake images]. The negative sign turns maximizing the critic score on fake images into a minimization problem.

Where the average scores are calculated across a mini-batch of samples.

With a discriminator, we fed generated samples labeled as real (1), so that the generator was penalized whenever the discriminator classified them as fake (0); this encouraged the generator to fool the discriminator. The same idea applies here: we feed generated samples to the critic and, because of the negative sign, the generator loss falls as the critic score on those samples rises. The generator is therefore updated so that the critic finds it harder to separate its samples from real ones (a smaller distance means the generated distribution is closer to the real one).
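
The two formulas above can be sketched directly in NumPy. The function names and the example scores below are illustrative assumptions; the arithmetic follows the summary as written.

```python
import numpy as np

def critic_loss(real_scores, fake_scores):
    # [average critic score on real images] - [average critic score on fake images]
    return float(np.mean(real_scores) - np.mean(fake_scores))

def generator_loss(fake_scores):
    # -[average critic score on fake images]
    return float(-np.mean(fake_scores))

# hypothetical critic scores for one mini-batch of each kind
real_scores = np.array([18.0, 22.0, 20.0])  # average 20
fake_scores = np.array([45.0, 55.0, 50.0])  # average 50

print(critic_loss(real_scores, fake_scores))  # -30.0
print(generator_loss(fake_scores))            # -50.0
```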

This is precisely how the loss is implemented for graph-based deep learning frameworks such as PyTorch and TensorFlow.

The calculations are straightforward to interpret once we recall that stochastic gradient descent seeks to minimize loss.

In the case of the generator, a larger score from the critic will result in a smaller loss for the generator, encouraging the generator to produce images that the critic scores more highly. For example, an average score of 10 becomes -10, an average score of 50 becomes -50, which is smaller, and so on.

In the case of the critic, a larger score for real images results in a larger resulting loss for the critic, penalizing the model. This encourages the critic to output smaller scores for real images. For example, an average score of 20 for real images and 50 for fake images results in a loss of -30; an average score of 10 for real images and 50 for fake images results in a loss of -40, which is better, and so on.

The sign of the loss does not matter in this case, as long as loss for real images is a small number and the loss for fake images is a large number. The Wasserstein loss encourages the critic to separate these numbers.

We can also reverse the situation and encourage the critic to output a large score for real images and a small score for fake images and achieve the same result. Some implementations make this change.

# Keras as high level API framework
In the Keras deep learning library (and some others), we cannot implement the Wasserstein loss function directly as described in the paper and as implemented in PyTorch and TensorFlow. Instead, we can achieve the same effect without having the calculation of the loss for the critic dependent upon the loss calculated for real and fake images.

A good way to think about this is a __negative score for real images__ and a __positive score for fake images__, although this negative/positive split of scores learned during training is not required; __just larger and smaller is sufficient.__

- Small Critic Score (e.g. < 0): Real
- Large Critic Score (e.g. > 0): Fake

We can multiply the __average predicted score by -1__ in the case of __fake__ images so that larger averages become smaller averages and the gradient is in the correct direction, i.e. minimizing loss. For example, average scores on fake images of [0.5, 0.8, and 1.0] across three batches of fake images would become [-0.5, -0.8, and -1.0] when calculating weight updates.

- Loss For Fake Images = -1 * Average Critic Score

_No change is needed for the case of real scores, as we want to encourage smaller average scores for real images._

- Loss For Real Images = Average Critic Score

This can be implemented consistently by assigning an ___expected outcome target of -1 for fake images and 1 for real images___ and implementing the __loss__ function as the ___expected label multiplied by the average score___. The -1 label will be multiplied by the average score for fake images and encourage a larger predicted average, and the +1 label will be multiplied by the average score for real images and have no effect, encouraging a smaller predicted average.

- Wasserstein Loss = Label * Average Critic Score

Or:

- Wasserstein Loss(Real Images) = 1 * Average Predicted Score
- Wasserstein Loss(Fake Images) = -1 * Average Predicted Score

__We can implement this in Keras by assigning the expected labels of -1 and 1 for fake and real images respectively.__ 

The __inverse labels could be used to the same effect__, e.g. -1 for real and +1 for fake to encourage small scores for fake images and large scores for real images. Some developers do implement the WGAN in this alternate way, which is just as correct.

___The loss function can be implemented by multiplying the expected label for each sample by the predicted score (element wise), then calculating the mean.___

```
def wasserstein_loss(y_true, y_pred):
	return mean(y_true * y_pred)
```

The above function is the elegant way to implement the loss function; an alternative, less-elegant implementation that might be more intuitive, and that gives the same value when every sample in the batch carries the same label, is as follows:

```
def wasserstein_loss(y_true, y_pred):
	return mean(y_true) * mean(y_pred)
```
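
As a quick check of how the two forms relate, the sketch below evaluates both in NumPy. The function names and scores are illustrative assumptions; labels follow this post's convention of -1 for fake and +1 for real.

```python
import numpy as np

def loss_elementwise(y_true, y_pred):
    # multiply each label by its score, then average
    return float(np.mean(y_true * y_pred))

def loss_product_of_means(y_true, y_pred):
    # average labels and scores separately, then multiply
    return float(np.mean(y_true) * np.mean(y_pred))

scores = np.array([0.5, 0.8, 1.0, 0.1])

# batch of fake images only: every label is -1, so both forms agree
fake_labels = -np.ones(4)
print(loss_elementwise(fake_labels, scores), loss_product_of_means(fake_labels, scores))

# mixed batch of two real (+1) and two fake (-1) images: the forms differ
mixed_labels = np.array([1.0, 1.0, -1.0, -1.0])
print(loss_elementwise(mixed_labels, scores), loss_product_of_means(mixed_labels, scores))
```

The agreement on single-label batches is why the second form works when real and fake images are fed to the critic in separate batches, as they are later in this tutorial.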

In Keras, the mean function can be implemented using the Keras __backend__ API to ensure the mean is calculated across samples in the provided __tensors__; for example:

```
from keras import backend

# implementation of wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)
```

To understand this better, recall the WGAN algorithm above: in the critic loss, the score for each real input enters as a positive term and the score for each fake input as a negative term. With +1 and -1 labels, the loss is therefore just the mean of the signed critic predictions.

Now that we know how to implement the Wasserstein loss function in Keras, let’s clarify one common point of misunderstanding.

# Common Point of Confusion With Expected Labels
Recall we are using the expected labels of -1 for fake images and +1 for real images.

A common point of confusion is that a perfect critic model will output -1 for every fake image and +1 for every real image.

___This is incorrect.___

Again, recall we are using stochastic gradient descent to find the set of weights in the critic (and generator) models that minimize the loss function.

We have established that we want the critic model to output larger scores on average for fake images and smaller scores on average for real images. We then designed a loss function to encourage this outcome.

This is the key point about loss functions used to train neural network models. They encourage a desired model behavior, and they do not have to achieve this by providing the expected outcomes. In this case, we defined our Wasserstein loss function to interpret the average score predicted by the critic model and used labels for the real and fake cases to help with this interpretation.

__So what is a good loss for real and fake images under Wasserstein loss?__

Wasserstein is not an absolute and comparable loss for comparing across GAN models. Instead, it is relative and depends on your model configuration and dataset. What is important is that it is consistent for a given critic model and ___convergence of the generator (better loss) does correlate with better generated image quality.___

It could be negative scores for real images and positive scores for fake images, but this is not required. All scores could be positive or all scores could be negative.

_The loss function only encourages a separation between scores for fake and real images as larger and smaller, not necessarily positive and negative._
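
A small NumPy sketch illustrates why only the separation matters. It assumes a balanced batch with this post's labels (+1 real, -1 fake); the scores are made up for illustration.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    return float(np.mean(y_true * y_pred))

labels = np.array([1.0, 1.0, -1.0, -1.0])  # +1 real, -1 fake
scores = np.array([-2.0, -1.0, 3.0, 4.0])  # negative for real, positive for fake
shifted = scores + 10.0                     # same separation, all scores positive

print(wasserstein_loss(labels, scores))   # -2.5
print(wasserstein_loss(labels, shifted))  # -2.5
```

Adding a constant to every score leaves the loss unchanged, so the critic has no incentive to centre its scores on zero or to hit the label values themselves.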

# How to Develop a Wasserstein Generative Adversarial Network (WGAN) From Scratch
[Source](https://machinelearningmastery.com/how-to-code-a-wasserstein-generative-adversarial-network-wgan-from-scratch/)

To summarize, the differences in implementation for the WGAN are as follows:

- Use a linear activation function in the output layer of the critic model (instead of sigmoid).
- Use -1 labels for real images and 1 labels for fake images (instead of 1 and 0).
- Use Wasserstein loss to train the critic and generator models.
- Constrain critic model weights to a limited range after each mini batch update (e.g. [-0.01,0.01]).
- Update the critic model more times than the generator each iteration (e.g. 5).
- Use the RMSProp version of gradient descent with a small learning rate and no momentum (e.g. 0.00005).

Using the standard DCGAN model as a starting point, let’s take a look at each of these implementation details in turn.

## Linear Activation in Critic Output Layer
The DCGAN uses the sigmoid activation function in the output layer of the discriminator to predict the likelihood of a given image being real.

In the WGAN, the critic model requires a linear activation to predict the score of “realness” for a given image.

This can be achieved by setting the `activation` argument to `'linear'` in the output layer of the critic model.
```
# define output layer of the critic model
...
model.add(Dense(1, activation='linear'))
```

The linear activation is the default activation for a layer, so we can, in fact, leave the activation unspecified to achieve the same result.

```
# define output layer of the critic model
...
model.add(Dense(1))
```
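
To see why the output activation matters, the following NumPy sketch compares the gradient of a sigmoid output with that of a linear output. The pre-activation values are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

pre_activations = np.array([-20.0, -2.0, 0.0, 2.0, 20.0])

# sigmoid output and its derivative: the gradient vanishes at the extremes
sig = sigmoid(pre_activations)
sig_grad = sig * (1.0 - sig)

# linear output is the pre-activation itself, with a constant gradient of 1
lin = pre_activations
lin_grad = np.ones_like(pre_activations)

print(sig_grad)
print(lin_grad)
```

A confident sigmoid discriminator passes back almost no gradient, whereas the linear critic output never saturates.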

## Class Labels for Real and Fake Images
The DCGAN uses the class 0 for fake images and class 1 for real images, and these class labels are used to train the GAN.

In the DCGAN, these are precise labels that the discriminator is expected to achieve. 
___The WGAN does not have precise labels for the critic.___ 

__Instead, it encourages the critic to output scores that are different for real and fake images.__

This is achieved via the Wasserstein loss function that cleverly makes use of positive and negative class labels.

The WGAN can be implemented where -1 class labels are used for real images and +1 class labels are used for fake or generated images.

This can be achieved using the ones() NumPy function.

For example:

```
...
# generate class labels, -1 for 'real'
y = -ones((n_samples, 1))
...
# create class labels with 1.0 for 'fake'
y = ones((n_samples, 1))
```

## Wasserstein Loss Function
The DCGAN trains the discriminator as a binary classification model to predict the probability that a given image is real.

To train this model, the discriminator is optimized using the binary cross entropy loss function. The same loss function is used to update the generator model.

The primary contribution of the WGAN model is the use of a new loss function that encourages the discriminator to predict a score of ___how real or fake a given input looks___. This transforms the role of the discriminator from a classifier into a critic for scoring the realness or fakeness of images, where the difference between the scores is as large as possible.

We can implement the Wasserstein loss as a custom function in Keras that calculates the average score for real or fake images.

The score is maximized for real examples and minimized for fake examples. Given that stochastic gradient descent is a minimization algorithm, we can multiply the class label by the mean score (e.g. -1 for real and 1 for fake, which has no effect), which ensures that minimizing the loss moves the scores for both real and fake images in the desired direction.

An efficient implementation of this loss function for Keras is listed below.

```
from keras import backend

# implementation of wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)
```


This loss function can be used to train a Keras model by specifying the function name when compiling the model.

For example:
```
...
# compile the model
model.compile(loss=wasserstein_loss, ...)
```

## Critic Weight Clipping
The DCGAN does not use any weight clipping, although the WGAN requires weight clipping for the critic model.

We can implement weight clipping as a __Keras constraint.__

This is a class that must extend the `Constraint` class and define an implementation of the `__call__()` function for applying the operation and the `get_config()` function for returning any configuration.

We can also define an `__init__()` function to set the configuration, in this case, the symmetrical size of the bounding box for the weight hypercube, e.g. 0.01.

The ClipConstraint class is defined below.

```
from keras import backend
from keras.constraints import Constraint

# clip model weights to a given hypercube
class ClipConstraint(Constraint):
	# set clip value when initialized
	def __init__(self, clip_value):
		self.clip_value = clip_value

	# clip model weights to hypercube
	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)

	# get the config
	def get_config(self):
		return {'clip_value': self.clip_value}
```

To use the constraint, the class can be constructed, then used in a layer by setting the `kernel_constraint` argument; for example:

```
...
# define the constraint
const = ClipConstraint(0.01)
...
# use the constraint in a layer
model.add(Conv2D(..., kernel_constraint=const))
```

The constraint is only required when updating the critic model.
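
The effect of the constraint on a set of weights can be sketched in plain NumPy. The weight values below are made up for illustration; `np.clip` stands in for the backend clip call.

```python
import numpy as np

clip_value = 0.01
# hypothetical critic weights after a gradient update
weights = np.array([-0.5, -0.01, -0.003, 0.0, 0.007, 0.02])

# values outside [-0.01, 0.01] are moved to the nearest bound
clipped = np.clip(weights, -clip_value, clip_value)
print(clipped)
```

Weights already inside the box are left untouched; only the out-of-range values are pulled back to the boundary.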

## Update Critic More Than Generator
In the DCGAN, the generator and the discriminator model must be updated in equal amounts.

Specifically, the discriminator is updated with a half batch of real and a half batch of fake samples each iteration, whereas the generator is updated with a single batch of generated samples.

For example:
```
...
# main gan training loop
for i in range(n_steps):

	# update the discriminator

	# get randomly selected 'real' samples
	X_real, y_real = generate_real_samples(dataset, half_batch)
	# update critic model weights
	c_loss1 = c_model.train_on_batch(X_real, y_real)
	# generate 'fake' examples
	X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
	# update critic model weights
	c_loss2 = c_model.train_on_batch(X_fake, y_fake)

	# update generator

	# prepare points in latent space as input for the generator
	X_gan = generate_latent_points(latent_dim, n_batch)
	# create inverted labels for the fake samples
	y_gan = ones((n_batch, 1))
	# update the generator via the critic's error
	g_loss = gan_model.train_on_batch(X_gan, y_gan)
```

In the WGAN model, the critic model must be updated more than the generator model.

Specifically, a new hyperparameter is defined to control the number of times that the critic is updated for each update to the generator model, called n_critic, and is set to 5.

This can be implemented as a new loop within the main GAN update loop; for example:
```
...
# main gan training loop
for i in range(n_steps):

	# update the critic
	for _ in range(n_critic):
		# get randomly selected 'real' samples
		X_real, y_real = generate_real_samples(dataset, half_batch)
		# update critic model weights
		c_loss1 = c_model.train_on_batch(X_real, y_real)
		# generate 'fake' examples
		X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
		# update critic model weights
		c_loss2 = c_model.train_on_batch(X_fake, y_fake)

	# update generator

	# prepare points in latent space as input for the generator
	X_gan = generate_latent_points(latent_dim, n_batch)
	# create inverted labels for the fake samples (-1, the 'real' label)
	y_gan = -ones((n_batch, 1))
	# update the generator via the critic's error
	g_loss = gan_model.train_on_batch(X_gan, y_gan)
```
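
To watch this loop structure run end to end without Keras, here is a toy one-dimensional sketch. Everything in it is an illustrative assumption: the critic is a single weight `w`, the generator is a single offset `b`, plain gradient descent replaces RMSProp, and the clip value is 1.0 rather than 0.01 so that the toy converges in a few hundred steps. Labels follow this tutorial's convention of -1 for real and +1 for fake.

```python
import numpy as np

rng = np.random.default_rng(1)

# toy setup: real data ~ N(3, 1), generator output = z + b with z ~ N(0, 1)
real_mean = 3.0
n_steps, n_critic, n_batch = 500, 5, 64
half_batch = n_batch // 2
lr_critic, lr_gen, clip_value = 0.1, 0.05, 1.0

w = 0.0  # critic: score(x) = w * x
b = 0.0  # generator offset

for i in range(n_steps):
    # update the critic n_critic times per generator update
    for _ in range(n_critic):
        x_real = rng.normal(real_mean, 1.0, half_batch)
        x_fake = rng.normal(0.0, 1.0, half_batch) + b
        # critic loss = mean(-1 * w * x_real) + mean(+1 * w * x_fake)
        grad_w = -np.mean(x_real) + np.mean(x_fake)
        # gradient step followed by weight clipping
        w = float(np.clip(w - lr_critic * grad_w, -clip_value, clip_value))
    # update the generator once, using inverted (-1) labels for its samples:
    # generator loss = mean(-1 * w * (z + b)), so d(loss)/db = -w
    b = b - lr_gen * (-w)

print('critic weight: %.3f, generator offset: %.3f' % (w, b))
```

In this toy the generator offset moves toward the real mean and then hovers around it, while the clipped critic weight changes sign as the generated samples overshoot or undershoot.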

## Use RMSProp Stochastic Gradient Descent
The DCGAN uses the Adam version of stochastic gradient descent with a small learning rate and modest momentum.

The WGAN recommends the use of RMSProp instead, with a small learning rate of 0.00005.

This can be implemented in Keras when the model is compiled. For example:
```
...
# compile model
opt = RMSprop(lr=0.00005)
model.compile(loss=wasserstein_loss, optimizer=opt)
```
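
For intuition about what the optimizer does, the sketch below applies one RMSProp-style update in NumPy. The `rho` and `eps` values are assumed typical settings rather than values from the tutorial, and the Keras implementation may differ in detail.

```python
import numpy as np

def rmsprop_step(w, grad, avg_sq, lr=0.00005, rho=0.9, eps=1e-7):
    # running average of squared gradients
    avg_sq = rho * avg_sq + (1.0 - rho) * grad ** 2
    # scale the step by the root of that average; there is no momentum term
    w = w - lr * grad / (np.sqrt(avg_sq) + eps)
    return w, avg_sq

w = np.array([0.5, -0.5])
avg_sq = np.zeros_like(w)
grad = np.array([10.0, -0.001])  # gradients of very different magnitude

w_new, avg_sq = rmsprop_step(w, grad, avg_sq)
steps = np.abs(w_new - w)
print(steps)
```

Both weights move by nearly the same small amount despite the large difference in gradient size, because each step is normalized by the recent gradient magnitude.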

# How to Train a Wasserstein GAN Model
Now that we know the specific implementation details for the WGAN, we can implement the model for image generation.

## MNIST
In this section, we will develop a WGAN to generate a single handwritten digit (‘7’) from the MNIST dataset. This is a good test problem for the WGAN as it is a small dataset requiring a modest model that is quick to train.

The first step is to define the models.

## Critic
The critic model takes as input one 28×28 grayscale image and outputs a score for the realness or fakeness of the image. It is implemented as a modest convolutional neural network using best practices for DCGAN design such as using the LeakyReLU activation function with a slope of 0.2, batch normalization, and using a 2×2 stride to downsample.

The critic model makes use of the new `ClipConstraint` weight constraint to clip model weights after mini-batch updates and is optimized using the custom `wasserstein_loss()` function and the RMSProp version of stochastic gradient descent with a learning rate of 0.00005.

The `define_critic()` function below implements this, defining and compiling the critic model and returning it. The input shape of the image is parameterized as a default function argument to make it clear.

```
# define the standalone critic model
def define_critic(in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# weight constraint
	const = ClipConstraint(0.01)
	# define model
	model = Sequential()
	# downsample to 14x14
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# downsample to 7x7
	model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# scoring, linear activation
	model.add(Flatten())
	model.add(Dense(1))
	# compile model
	opt = RMSprop(lr=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model
```

## Generator
The generator model takes as input a point in the latent space and outputs a single 28×28 grayscale image.

This is achieved by using a fully connected layer to interpret the point in the latent space and provide sufficient activations that can be reshaped into many copies (in this case, 128) of a low-resolution version of the output image (e.g. 7×7). This is then upsampled two times, doubling the size and quadrupling the area of the activations each time using transpose convolutional layers.
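
The shape arithmetic in this paragraph can be checked with a few lines of Python. This is only a sketch of the sizes involved, assuming `padding='same'` and a stride of 2 for each transpose convolution; it builds no layers.

```python
# activations from the fully connected layer: 128 copies of a 7x7 image
n_maps, size = 128, 7
n_nodes = n_maps * size * size
print(n_nodes)  # 6272

# each stride-2 transpose convolution doubles each side and quadruples the area
for _ in range(2):
    new_size = size * 2
    area_factor = (new_size * new_size) // (size * size)
    print('%dx%d -> %dx%d, area x%d' % (size, size, new_size, new_size, area_factor))
    size = new_size
```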

The model uses best practices such as the LeakyReLU activation, a kernel size that is a multiple of the stride size, and a hyperbolic tangent (tanh) activation function in the output layer.

The `define_generator()` function below defines the __generator model__ but intentionally __does not compile__ it as it is not trained directly, then returns the model. The size of the latent space is parameterized as a function argument.

```
# define the standalone generator model
def define_generator(latent_dim):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# define model
	model = Sequential()
	# foundation for 7x7 image
	n_nodes = 128 * 7 * 7
	model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
	model.add(LeakyReLU(alpha=0.2))
	model.add(Reshape((7, 7, 128)))
	# upsample to 14x14
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# upsample to 28x28
	model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
	model.add(BatchNormalization())
	model.add(LeakyReLU(alpha=0.2))
	# output 28x28x1
	model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
	return model
```

## WGAN
Next, a GAN model can be defined that combines both the generator model and the critic model into one larger model.

This larger model will be used to train the model weights in the generator, using the output and error calculated by the critic model. The __critic model__ is trained separately, and as such, the model weights are marked as __not trainable__ in this larger GAN model to ensure that only the weights of the generator model are updated. This change to the trainability of the critic weights only has an effect when training the combined GAN model, not when training the critic standalone.

This larger GAN model takes as input a point in the latent space, uses the generator model to generate an image, and feeds that image to the critic model, which outputs a score for how real or fake it is. The model is fit using RMSProp with the custom `wasserstein_loss()` function.

The `define_gan()` function below implements this, taking the already defined generator and critic models as input.

```
# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic):
	# make weights in the critic not trainable
	critic.trainable = False
	# connect them
	model = Sequential()
	# add generator
	model.add(generator)
	# add the critic
	model.add(critic)
	# compile model
	opt = RMSprop(lr=0.00005)
	model.compile(loss=wasserstein_loss, optimizer=opt)
	return model
```

Now that we have defined the GAN model, we need to train it. But, before we can train the model, we require input data.

## Data preparation
The first step is to load and scale the MNIST dataset. The whole dataset is loaded via a call to the `load_data()` Keras function, then a __subset__ of the images is selected (about 6,000) that belongs to __class 7__, i.e. handwritten depictions of the number seven. Then the pixel values must be scaled to the range [-1,1] to match the output of the generator model.

The `load_real_samples()` function below implements this, returning the loaded and scaled subset of the MNIST training dataset ready for modeling.

```
# load images
def load_real_samples():
	# load dataset
	(trainX, trainy), (_, _) = load_data()
	# select all of the examples for a given class
	selected_ix = trainy == 7
	X = trainX[selected_ix]
	# expand to 3d, e.g. add channels
	X = expand_dims(X, axis=-1)
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	return X
```
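
The scaling step can be verified on a handful of pixel values. The sample values below are illustrative; the inverse mapping back to [0,1] is the one used later when plotting generated images.

```python
import numpy as np

# a few example pixel intensities in the original [0,255] range
pixels = np.array([0, 64, 127, 128, 255], dtype='uint8')

# scale from [0,255] to [-1,1], matching the generator's tanh output range
scaled = (pixels.astype('float32') - 127.5) / 127.5

# scale from [-1,1] to [0,1] for plotting
restored = (scaled + 1) / 2.0

print(scaled)
print(restored)
```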

We will require one batch (or half batch) of real images from the dataset each update to the GAN model. A simple way to achieve this is to __select a random sample of images from the dataset each time__.

The `generate_real_samples()` function below implements this, taking the prepared dataset as an argument, selecting and returning a random sample of images and their corresponding label for the critic, specifically target=-1 indicating that they are real images.

```
# select real samples
def generate_real_samples(dataset, n_samples):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# select images
	X = dataset[ix]
	# generate class labels, -1 for 'real'
	y = -ones((n_samples, 1))
	return X, y
```

Next, we need inputs for the generator model. These are random points from the latent space, specifically Gaussian distributed random variables.

The `generate_latent_points()` function implements this, taking the size of the latent space as an argument and the number of points required, and returning them as a batch of input samples for the generator model.

In [0]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
	# generate points in the latent space
	x_input = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	x_input = x_input.reshape(n_samples, latent_dim)
	return x_input
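
The sketch below repeats the function so that it runs on its own and checks the shape of the result. It also compares the output with a direct call to `randn(n_samples, latent_dim)`, which should produce the same values for the same seed. The batch size of 32 is an arbitrary choice for the example.

```python
from numpy.random import randn, seed

def generate_latent_points(latent_dim, n_samples):
    # draw a flat vector of Gaussian values, then reshape into a batch
    x_input = randn(latent_dim * n_samples)
    return x_input.reshape(n_samples, latent_dim)

seed(1)
a = generate_latent_points(50, 32)

seed(1)
b = randn(32, 50)  # same draw, requested directly in the target shape

print(a.shape, (a == b).all())
```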

Next, we need to use the points in the latent space as input to the generator in order to generate new images.

The `generate_fake_samples()` function below implements this, taking the generator model and size of the latent space as arguments, then generating points in the latent space and using them as input to the generator model.

The function returns the generated images and their __corresponding label for the critic model__, specifically target=1 to indicate they are fake or generated.

In [0]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
	# generate points in latent space
	x_input = generate_latent_points(latent_dim, n_samples)
	# predict outputs
	X = generator.predict(x_input)
	# create class labels with 1.0 for 'fake'
	y = ones((n_samples, 1))
	return X, y

We need to record the performance of the model. Perhaps the most reliable way to evaluate the performance of a GAN is to use the generator to generate images, and then review and subjectively evaluate them.

The `summarize_performance()` function below takes the generator model at a given point during training and uses it to generate 100 images, which are plotted in a 10×10 grid and saved to file. The model is also saved to file at this time, in case we would like to use it later to generate more images.

In [0]:
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, latent_dim, n_samples=100):
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(10 * 10):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

In addition to image quality, it is a good idea to keep track of the loss of the models over time.

The loss for the critic for real and fake samples can be tracked for each model update, as can the loss for the generator for each update. These can then be used to create line plots of loss at the end of the training run. The `plot_history()` function below implements this and saves the results to file.

In [0]:
# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist):
	# plot history
	pyplot.plot(d1_hist, label='crit_real')
	pyplot.plot(d2_hist, label='crit_fake')
	pyplot.plot(g_hist, label='gen')
	pyplot.legend()
	pyplot.savefig('plot_line_plot_loss.png')
	pyplot.close()

We are now ready to fit the GAN model.

The model is fit for 10 training epochs, which is arbitrary, as the model begins generating plausible number-7 digits after perhaps the first few epochs. A batch size of 64 samples is used, and each training epoch involves 6,265/64, or about 97, batches of real and fake samples and updates to the model. The model is therefore trained for 10 epochs of 97 batches, or 970 iterations.
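
The arithmetic above can be reproduced in a few lines. The sketch mirrors the calculations made at the top of the `train()` function, and the `critic_updates` name is introduced here only for illustration.

```python
# values taken from the text: dataset size, batch size, epochs, critic repeats
n_images, n_batch, n_epochs, n_critic = 6265, 64, 10, 5

# batches per epoch, rounded down as in train()
bat_per_epo = int(n_images / n_batch)
# total number of training iterations
n_steps = bat_per_epo * n_epochs
# size of half a batch
half_batch = int(n_batch / 2)
# two train_on_batch() calls (real, then fake) per critic repeat
critic_updates = n_steps * n_critic * 2

print(bat_per_epo, n_steps, half_batch, critic_updates)
# prints: 97 970 32 9700
```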

First, the critic model is updated for a half batch of real samples, then a half batch of fake samples, together forming one batch of weight updates. This is then repeated n_critic (5) times as required by the WGAN algorithm.

__Inverse labels for the composite model__

The generator is then updated via the composite GAN model. __Importantly, the target label is set to -1 or real for the generated samples__. 

___This has the effect of updating the generator toward getting better at generating real samples on the next batch.___

___Note that___
This might seem to be the opposite of what is stated in the WGAN algorithm above, since when updating the generator there we used the "fake" sign (-1 in that case) and not the "real" one. But take care: the algorithm above maximizes the objective (hence the minus sign in the gradient update), while here, in the `train` function, we minimize a loss.
The point is that, in both cases, we want to update the generator so that it is penalized when the critic is NOT fooled. So we feed the critic the generated samples and mark them as "real" (-1 in our case). If the resulting loss is low, the critic has been fooled and the generator is barely penalized; if the loss is high, the critic has not been fooled and the generator is pushed to improve.
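
A small numeric sketch may make the sign convention concrete. It reimplements the loss with NumPy (the tutorial itself uses the Keras backend), and the critic scores are made-up values chosen only for illustration.

```python
import numpy as np

def wasserstein_loss_np(y_true, y_pred):
    # NumPy stand-in for backend.mean(y_true * y_pred)
    return float(np.mean(y_true * y_pred))

# made-up critic scores for four generated images
fooled = np.array([[2.0], [3.0], [2.5], [1.5]])          # scored highly, like real images
not_fooled = np.array([[-2.0], [-3.0], [-2.5], [-1.5]])  # scored low, like fakes

# generated images are labeled as 'real' (-1) when updating the generator
y_gan = -np.ones((4, 1))

print(wasserstein_loss_np(y_gan, fooled))      # -2.25 (low loss: critic fooled)
print(wasserstein_loss_np(y_gan, not_fooled))  # 2.25 (high loss: critic not fooled)
```

Minimizing this loss therefore pushes the critic's scores for generated images up, toward the scores it assigns to real images.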

The `train()` function below implements this, taking the defined models, dataset, and size of the latent dimension as arguments and parameterizing the number of epochs and batch size with default arguments. The generator model is saved at the end of training.

The performance of the critic and generator models is reported each iteration. Sample images are generated and saved every epoch, and line plots of model performance are created and saved at the end of the run.

In [0]:
# train the generator and critic
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5):
	# calculate the number of batches per training epoch
	bat_per_epo = int(dataset.shape[0] / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# calculate the size of half a batch of samples
	half_batch = int(n_batch / 2)
	# lists for keeping track of loss
	c1_hist, c2_hist, g_hist = list(), list(), list()
	# manually enumerate training steps (batches across all epochs)
	for i in range(n_steps):
		# update the critic more than the generator
		c1_tmp, c2_tmp = list(), list()
		for _ in range(n_critic):
			# get randomly selected 'real' samples
			X_real, y_real = generate_real_samples(dataset, half_batch)
			# update critic model weights
			c_loss1 = c_model.train_on_batch(X_real, y_real)
			c1_tmp.append(c_loss1)
			# generate 'fake' examples
			X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
			# update critic model weights
			c_loss2 = c_model.train_on_batch(X_fake, y_fake)
			c2_tmp.append(c_loss2)
		# store critic loss
		c1_hist.append(mean(c1_tmp))
		c2_hist.append(mean(c2_tmp))
		# prepare points in latent space as input for the generator
		X_gan = generate_latent_points(latent_dim, n_batch)
		# create inverted labels for the fake samples
		y_gan = -ones((n_batch, 1))
		# update the generator via the critic's error
		g_loss = gan_model.train_on_batch(X_gan, y_gan)
		g_hist.append(g_loss)
		# summarize loss on this batch
		print('>%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))
		# evaluate the model performance every 'epoch'
		if (i+1) % bat_per_epo == 0:
			summarize_performance(i, g_model, latent_dim)
	# line plots of loss
	plot_history(c1_hist, c2_hist, g_hist)

Now that all of the functions have been defined, we can create the models, load the dataset, and begin the training process.

In [16]:
# example of a wgan for generating handwritten digits
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from matplotlib import pyplot

# clip model weights to a given hypercube
class ClipConstraint(Constraint):
	# set clip value when initialized
	def __init__(self, clip_value):
		self.clip_value = clip_value

	# clip model weights to hypercube
	def __call__(self, weights):
		return backend.clip(weights, -self.clip_value, self.clip_value)

	# get the config
	def get_config(self):
		return {'clip_value': self.clip_value}

# calculate wasserstein loss
def wasserstein_loss(y_true, y_pred):
	return backend.mean(y_true * y_pred)

# size of the latent space
latent_dim = 50
# create the critic
critic = define_critic()
# create the generator
generator = define_generator(latent_dim)
# create the gan
gan_model = define_gan(generator, critic)
# load image data
dataset = load_real_samples()
print(dataset.shape)
# train model
train(generator, critic, gan_model, dataset, latent_dim)

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
(6265, 28, 28, 1)


>1, c1=-3.189, c2=-0.014 g=0.140
>2, c1=-7.878, c2=0.014 g=0.971
>3, c1=-10.536, c2=-0.585 g=1.019
>4, c1=-11.809, c2=-0.310 g=0.948
>5, c1=-13.267, c2=-0.282 g=0.711
>6, c1=-14.437, c2=-0.132 g=0.629
>7, c1=-15.796, c2=-0.484 g=0.368
>8, c1=-16.586, c2=-0.259 g=0.398
>9, c1=-17.957, c2=-0.335 g=-0.049
>10, c1=-18.284, c2=-0.352 g=-0.024
>11, c1=-18.366, c2=-0.345 g=-0.529
>12, c1=-20.076, c2=-0.185 g=-1.085
>13, c1=-19.624, c2=-0.138 g=-1.957
>14, c1=-20.588, c2=-0.266 g=-2.530
>15, c1=-21.053, c2=0.053 g=-3.142
>16, c1=-21.228, c2=-0.123 g=-4.056
>17, c1=-21.307, c2=-0.393 g=-4.903
>18, c1=-21.727, c2=-0.039 g=-5.082
>19, c1=-21.183, c2=-0.038 g=-5.261
>20, c1=-21.732, c2=-0.112 g=-4.768
>21, c1=-21.059, c2=-0.701 g=-4.229
>22, c1=-21.265, c2=-0.786 g=-3.548
>23, c1=-21.614, c2=-1.351 g=-2.527
>24, c1=-21.805, c2=-1.977 g=-1.012
>25, c1=-21.255, c2=-2.822 g=-0.010
>26, c1=-21.056, c2=-3.216 g=1.462
>27, c1=-21.540, c2=-4.062 g=2.640
>28, c1=-21.432, c2=-4.982 g=4.141
>29, c1=-20.803, c2=-5.984 g=5.275
>30, c1=-21.647, c2=-6.434 g=6.460
>31, c1=-21

Your specific results will vary given the stochastic nature of the learning algorithm. Nevertheless, the general structure of training should be very similar.

First, the loss of the critic and generator models is reported to the console on each iteration of the training loop. Specifically, c1 is the loss of the critic on real examples, c2 is the loss of the critic on generated samples, and g is the loss of the generator trained via the critic.

The c1 scores are inverted as part of the loss function; this means if they are reported as negative, then they are really positive, and if they are reported as positive, they are really negative. The sign of the c2 scores is unchanged.
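
The sign handling can be written out explicitly. The sketch below converts reported losses back into average critic scores, using the values printed for iteration 30 in the log above. The helper name `scores_from_losses` is hypothetical.

```python
def scores_from_losses(c1_loss, c2_loss, g_loss):
    # hypothetical helper: undo the sign applied by the target labels
    real_score = -c1_loss  # real target is -1, so loss = -mean(score)
    fake_score = c2_loss   # fake target is +1, so loss = mean(score)
    gen_score = -g_loss    # generator update uses target -1 on fakes
    return real_score, fake_score, gen_score

# losses reported at iteration 30 of the run shown above
real_score, fake_score, gen_score = scores_from_losses(-21.647, -6.434, 6.460)

print('real score: %.3f' % real_score)
print('fake score: %.3f' % fake_score)
print('fake score via generator update: %.3f' % gen_score)
print('separation: %.3f' % (real_score - fake_score))
```

Once the sign is undone, the generator loss sits close to the critic's score for fake samples, as expected, since both are computed on generated images.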

Recall that the Wasserstein loss seeks scores for real and fake that are more different during training. We can see this towards the end of the run, such as the final epoch where the c1 loss for real examples is 5.338 (really -5.338) and the c2 loss for fake examples is -14.260, and this separation of about 10 units is consistent at least for the prior few iterations.

We can also see that in this case, the model is scoring the loss of the generator at around 20. Again, recall that we update the generator via the critic model and treat the generated examples as real with the target of -1, therefore the score can be interpreted as a value around -20, close to the loss for fake samples.

Line plots for loss are created and saved at the end of the run.

The plot shows the loss for the critic on real samples (blue), the loss for the critic on fake samples (orange), and the loss for the critic when updating the generator with fake samples (green).

There is one important factor when reviewing learning curves for the WGAN and that is the trend.

The benefit of the WGAN is that the loss correlates with generated image quality. Lower loss means better quality images, for a stable training process.

In this case, lower loss specifically refers to lower Wasserstein loss for generated images as reported by the critic (orange line). The sign of this loss is not inverted by the target label (e.g. the target label is +1.0); therefore, a well-performing WGAN should show this line trending down as the quality of the generated images increases.

![learn_curves](https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/05/Line-Plots-of-Loss-and-Accuracy-for-a-Wasserstein-Generative-Adversarial-Network.png)
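
Because the raw loss is noisy, it can help to smooth it before judging the trend. The sketch below applies a simple moving average to the `crit_fake` (c2) losses from iterations 21 to 30 of the log above. The `moving_average` helper and the window size of 3 are assumptions for illustration, and a short downward run like this one is not by itself evidence of good image quality.

```python
import numpy as np

def moving_average(values, window=3):
    # hypothetical helper: simple unweighted moving average
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')

# crit_fake (c2) losses copied from iterations 21-30 of the log above
c2 = [-0.701, -0.786, -1.351, -1.977, -2.822,
      -3.216, -4.062, -4.982, -5.984, -6.434]

smoothed = moving_average(c2, window=3)
print(smoothed.round(3))
# True when every smoothed value is lower than the one before it
print('trending down:', bool(np.all(np.diff(smoothed) < 0)))
```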