
Wasserstein Generative Adversarial Networks with Gradient Penalty (WGAN-GP)

Paper: Improved Training of Wasserstein GANs

Description

Generative adversarial networks (GANs) have shown great results in many generative tasks, replicating rich real-world content such as images, human language, and music. The framework is inspired by game theory: two models, a generator and a discriminator, compete with each other while making each other stronger. However, training a GAN model is rather challenging, with common issues such as training instability and failure to converge.

The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but it can still generate poor samples or fail to converge. The authors show that these problems are often due to the weight clipping used in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior.

In this paper, they propose an alternative to clipping weights: penalizing the norm of the gradient of the critic with respect to its input. The proposed method performs better than the standard WGAN (you can find my WGAN repository here) and enables stable training of a wide variety of GAN architectures with almost no hyperparameter tuning.

Weight clipping enforces the Lipschitz constraint on the critic that is needed to estimate the Wasserstein distance, but it is not a good way to do so. It acts as a weight regularizer: it reduces the capacity of the model and limits its ability to represent complex functions. If the clipping parameter is large, it can take a long time for the weights to reach their limit, making it harder to train the critic to optimality. If the clipping parameter is small, it can easily lead to vanishing gradients when the number of layers is large or when batch normalization is not used (such as in RNNs).
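For contrast, here is a minimal sketch of the weight-clipping step from the original WGAN (the threshold c = 0.01 is the value used in the WGAN paper; the function and variable names are illustrative, not taken from this repository):

```python
import tensorflow as tf

CLIP_VALUE = 0.01  # clipping threshold c used in the WGAN paper

def clip_critic_weights(critic, clip_value=CLIP_VALUE):
    """After every critic update, clamp each critic weight into [-c, c]
    to (crudely) enforce the Lipschitz constraint."""
    for w in critic.trainable_variables:
        w.assign(tf.clip_by_value(w, -clip_value, clip_value))
```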

WGAN-GP uses a gradient penalty instead of weight clipping to enforce the Lipschitz constraint.

Gradient penalty

A differentiable function f is 1-Lipschitz if and only if its gradient has norm at most 1 everywhere. So, instead of clipping weights, WGAN-GP penalizes the model whenever the gradient norm of the critic moves away from its target value of 1.

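The critic loss from the paper is L = E_{x̃~P_g}[D(x̃)] − E_{x~P_r}[D(x)] + λ · E_{x̂}[(‖∇_{x̂} D(x̂)‖₂ − 1)²], where x̂ is sampled uniformly along straight lines between real and generated samples and λ = 10. Below is a minimal sketch of the penalty term in TensorFlow 2, assuming a critic that maps a batch of images to a batch of scalar scores (the function and argument names are illustrative, not necessarily those used in this repository):

```python
import tensorflow as tf

def gradient_penalty(critic, real_images, fake_images, gp_weight=10.0):
    """Penalize the critic's gradient norm at random interpolates
    between real and generated samples."""
    batch_size = tf.shape(real_images)[0]
    # One random interpolation coefficient per sample.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolates = alpha * real_images + (1.0 - alpha) * fake_images

    with tf.GradientTape() as tape:
        tape.watch(interpolates)
        scores = critic(interpolates, training=True)

    grads = tape.gradient(scores, interpolates)
    # L2 norm of the gradient with respect to each interpolated image.
    norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return gp_weight * tf.reduce_mean(tf.square(norms - 1.0))
```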

Experiments

The major advantage of WGAN-GP is its convergence. It makes training more stable and therefore easier. Because WGAN-GP helps models converge better, we can use more complex architectures, such as a deep ResNet, for both the generator and the discriminator.


Algorithm

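The training procedure follows Algorithm 1 of the paper: for every generator update, the critic is trained for n_critic = 5 steps with the gradient-penalized Wasserstein loss, using Adam (α = 1e-4, β1 = 0, β2 = 0.9) for both networks. Below is a rough TensorFlow 2 sketch of one training iteration, assuming the generator and critic models and their optimizers are defined elsewhere and reusing the gradient_penalty helper sketched above; for brevity the same real batch is reused across the critic steps, whereas the paper samples a fresh batch for each one:

```python
import tensorflow as tf

N_CRITIC = 5      # critic updates per generator update (paper default)
LATENT_DIM = 128  # illustrative latent dimension

@tf.function
def train_step(real_images, generator, critic, g_opt, c_opt):
    batch_size = tf.shape(real_images)[0]

    # 1) Update the critic N_CRITIC times.
    for _ in range(N_CRITIC):
        noise = tf.random.normal([batch_size, LATENT_DIM])
        with tf.GradientTape() as tape:
            fake_images = generator(noise, training=True)
            real_scores = critic(real_images, training=True)
            fake_scores = critic(fake_images, training=True)
            # Wasserstein critic loss plus gradient penalty.
            c_loss = (tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)
                      + gradient_penalty(critic, real_images, fake_images))
        c_grads = tape.gradient(c_loss, critic.trainable_variables)
        c_opt.apply_gradients(zip(c_grads, critic.trainable_variables))

    # 2) Update the generator once.
    noise = tf.random.normal([batch_size, LATENT_DIM])
    with tf.GradientTape() as tape:
        fake_images = generator(noise, training=True)
        g_loss = -tf.reduce_mean(critic(fake_images, training=True))
    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

    return c_loss, g_loss
```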

Tensorflow2 Implementation

I used three different datasets for WGAN-GP: MNIST, the horse class of CIFAR-10, and Fashion-MNIST.
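As an illustration, the three datasets can be loaded with tf.keras.datasets and scaled to [-1, 1] (assuming a tanh generator output); the helper name and batch size below are illustrative, not the ones used in this repository:

```python
import numpy as np
import tensorflow as tf

def load_images(name="fashion_mnist", batch_size=64):
    """Load one of the three datasets as a tf.data pipeline, scaled to [-1, 1]."""
    if name == "mnist":
        (x, _), _ = tf.keras.datasets.mnist.load_data()
    elif name == "fashion_mnist":
        (x, _), _ = tf.keras.datasets.fashion_mnist.load_data()
    elif name == "cifar10_horse":
        (x, y), _ = tf.keras.datasets.cifar10.load_data()
        x = x[y.flatten() == 7]  # class index 7 is "horse" in CIFAR-10
    else:
        raise ValueError(f"unknown dataset: {name}")
    x = x.astype("float32") / 127.5 - 1.0
    if x.ndim == 3:               # MNIST-style data: add a channel axis
        x = x[..., np.newaxis]
    return tf.data.Dataset.from_tensor_slices(x).shuffle(10_000).batch(batch_size)
```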

Results

Fashion-MNIST samples during training · CIFAR-10 horse samples during training

Generated samples for Fashion-MNIST during training

epoch 10 epoch 60
epoch 150 epoch 350

Generated samples for horse images during training

epoch 1 epoch 600

Generator's outputs after 60 epochs for MNIST

You can find more generated samples here.

References

  1. Improved Training of Wasserstein GANs
  2. Wasserstein GAN
