This repository has been archived by the owner on May 14, 2023. It is now read-only.

Tensorflow Wasserstein GAN implementation with TFRecord data format. WGAN, WGAN-GP (gradient penalty), DCGAN and showing example usage with CelebA dataset.


Wasserstein GAN (with Gradient Penalty)


TensorFlow implementation of Wasserstein GAN with gradient penalty (WGAN-GP), vanilla WGAN, and DCGAN. The main properties are summarized below.

How to use it?

Clone the repository

git clone https://github.com/asahi417/WassersteinGAN
cd WassersteinGAN
pip install .
mkdir datasets

The CelebA dataset needs to be downloaded from here and placed under the datasets directory, so the data directory should look like:

WassersteinGAN/datasets/celeba/img/img_align_celeba

TFRecord

To produce TFRecord file,

python bin/build_tfrecord.py --data celeba -r 64 -c 108

optional arguments:
    -c [CROP], --crop [CROP] Each image will be center-cropped to this size.
    -r [RESIZE], --resize [RESIZE] Each image will be resized to this size.
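The preprocessing behind `build_tfrecord.py` (center crop, then resize) can be sketched as follows. This is a minimal NumPy illustration of the crop/resize arithmetic only; the actual script additionally serializes the result into TFRecord, and the function names here are illustrative, not the script's real API.

```python
import numpy as np

def center_crop(img: np.ndarray, crop: int) -> np.ndarray:
    """Center-crop an HxWxC image to crop x crop pixels."""
    h, w = img.shape[:2]
    top, left = (h - crop) // 2, (w - crop) // 2
    return img[top:top + crop, left:left + crop]

def resize_nearest(img: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize to size x size (illustrative only;
    a real pipeline would use a proper image library)."""
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows][:, cols]

# Aligned CelebA images are 178x218; crop to 108 and resize to 64,
# matching the example command above.
img = np.zeros((218, 178, 3), dtype=np.uint8)
out = resize_nearest(center_crop(img, 108), 64)
print(out.shape)  # (64, 64, 3)
```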

Train Model

python bin/train.py -m [MODEL] -e [EPOCH] -c [CROP] -r [RESIZE] --data [DATA]

optional arguments:
  -m [MODEL], --model [MODEL] `dcgan` or `wgan`
  -e [EPOCH], --epoch [EPOCH] Number of training epochs.
  -c [CROP], --crop [CROP]
  -r [RESIZE], --resize [RESIZE]
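The key difference between the `wgan` model with gradient penalty and the others is the penalty term lambda * (||grad_xhat D(xhat)||_2 - 1)^2, evaluated at points interpolated between real and fake samples. Below is a minimal NumPy sketch of that computation, using a toy linear critic so the input gradient is known in closed form; it is not the repo's implementation, just an illustration of the formula.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear critic D(x) = w . x, whose gradient w.r.t. x is simply w.
w = rng.normal(size=8)

def critic_grad(x: np.ndarray) -> np.ndarray:
    """dD(x)/dx for the linear critic above (constant in x)."""
    return np.broadcast_to(w, x.shape)

def gradient_penalty(x_real: np.ndarray, x_fake: np.ndarray, lam: float = 10.0) -> float:
    """WGAN-GP penalty: lam * E[(||grad_xhat D(xhat)||_2 - 1)^2],
    with xhat sampled uniformly on lines between real and fake points."""
    eps = rng.uniform(size=(x_real.shape[0], 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake
    grad_norm = np.linalg.norm(critic_grad(x_hat), axis=1)
    return float(lam * np.mean((grad_norm - 1.0) ** 2))

x_real = rng.normal(size=(128, 8))
x_fake = rng.normal(size=(128, 8))
gp = gradient_penalty(x_real, x_fake, lam=10.0)  # lam matches "gradient_penalty": 10 below
print(gp >= 0.0)  # True: the penalty is always non-negative
```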

Hyperparameters are listed in the Tips section below.

Visualization

usage: generate_img.py -m [MODEL] -v [VERSION] -c [CROP] -r [RESIZE] --data [DATA]

optional arguments:
  -m [MODEL], --model [MODEL] `dcgan` or `wgan`
  -v [VERSION], --version [VERSION] version of checkpoint
  -c [CROP], --crop [CROP]
  -r [RESIZE], --resize [RESIZE]

Generated Images

Images are generated from random latent variables sampled from a normal distribution.
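Concretely, each image comes from a latent vector z ~ N(0, I); the hyperparameter blocks below use n_z = 100, so a batch of 16 images starts from a sample like this:

```python
import numpy as np

# 16 latent vectors of dimension n_z = 100, drawn from a standard normal,
# as fed to the generator (batch size here is illustrative).
z = np.random.default_rng(42).standard_normal(size=(16, 100))
print(z.shape)  # (16, 100)
```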

WGAN-GP (Gradient Penalty)


Fig 1: WGAN-GP with epoch 160

WGAN


Fig 2: WGAN with epoch 30

DCGAN


Fig 3: DCGAN with epoch 30

Discussions

To evaluate the stability of training, the following aspects are considered:

    1. mode collapse
    2. local minima
    3. overfitting
    4. quality of generated images

1. mode collapse

DCGAN easily suffers from mode collapse; a tendency toward mode collapse usually appears after 30-40 epochs.


Fig 4: Example of mode collapse (DCGAN with epoch 30)

In my experiments, I was not able to train a DCGAN that produces a variety of images. WGAN, with and without GP, is very good at avoiding mode collapse even after many epochs.
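One rough way to spot mode collapse numerically is to measure how different the generated samples are from each other; a collapsed generator emits near-identical outputs, so pairwise distances shrink toward zero. This diagnostic is my own illustration, not something the repo implements:

```python
import numpy as np

def mean_pairwise_distance(samples: np.ndarray) -> float:
    """Mean Euclidean distance between all pairs of flattened samples.
    A value near zero suggests the generator emits near-identical outputs."""
    flat = samples.reshape(len(samples), -1)
    diffs = flat[:, None, :] - flat[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    n = len(samples)
    return float(dists.sum() / (n * (n - 1)))

rng = np.random.default_rng(0)
diverse = rng.normal(size=(32, 64, 64, 3))        # varied outputs
collapsed = np.repeat(diverse[:1], 32, axis=0)    # one mode repeated 32 times
print(mean_pairwise_distance(collapsed) < mean_pairwise_distance(diverse))  # True
```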

2. trapped by local minima

DCGAN is often trapped in undesired local minima, in addition to mode collapse. Here are some examples of DCGAN trapped in local minima. Once a model gets trapped, it never improves again.


Fig 5: Examples of local minima (three separately trained DCGANs that ended up trapped in local minima)

WGAN, with and without GP, seems to have the capacity to escape such local minima; I have not seen a single case of WGAN trapped in a local minimum.

3. overfitting (too strong a critic?)

Vanilla WGAN (without GP) is likely to overfit after 40 epochs. Even when a model seems properly trained (such as the model that produces Fig 2 at epoch 30), it suddenly starts to generate crappy images and finishes training with something like Fig 6.


Fig 6: Examples of overfitting (WGAN with epoch 50)

It is somewhat similar to being trapped in a local minimum, but unlike that case, which happens right away, this phenomenon usually lets the generator train properly at first, and then the generator gets messed up after a while.

On the other hand, WGAN-GP has never caused this. From Fig 7, you can see that WGAN-GP can actually generate images of competitive quality even at epoch 50. WGAN-GP trains over many epochs without any crucial issues.


Fig 7: WGAN-GP with epoch 50


Fig 8: WGAN-GP with epoch 100

In the end, I'm not sure whether this can be called overfitting, so let me know if you have a proper name for this type of problem.

4. quality

I haven't tried any metrics to evaluate the quality of generated images, such as the Inception Score. Judging by the images generated by each model (Fig 1-3), WGAN-GP seems to produce the best-quality images.
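For reference, the Inception Score mentioned above is IS = exp(E_x[KL(p(y|x) || p(y))]), computed from per-image class probabilities (in practice taken from a pretrained Inception network). A minimal sketch of the formula itself, on made-up probability matrices:

```python
import numpy as np

def inception_score(probs: np.ndarray, eps: float = 1e-12) -> float:
    """IS = exp(E_x[KL(p(y|x) || p(y))]) for per-image class
    probabilities `probs` of shape (n_images, n_classes)."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))

# Confident, varied predictions score higher than uniform ones.
confident = np.eye(10)[np.arange(100) % 10]  # one-hot over 10 classes
uniform = np.full((100, 10), 0.1)
print(inception_score(uniform))                                # 1.0 (worst case)
print(inception_score(confident) > inception_score(uniform))   # True
```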

Conclusion

Suppose you train each model with the same hyperparameters five times:

  • DCGAN: three out of five models would fail (one would end up with mode collapse and two would be trapped in local minima).
  • WGAN: two out of five models would fail (both due to overfitting).
  • WGAN-GP: never fails!

WGAN with GP is the most stable model, and it also has the capacity to produce diverse images of relatively high quality compared with vanilla WGAN and DCGAN.

Tips

Here I list a few tips used in these implementations. While it is hard to train DCGAN and vanilla WGAN without them, WGAN-GP doesn't need any specific tips (it's quite friendly, isn't it?)
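One such tip for vanilla WGAN is enforcing the Lipschitz constraint by clipping every critic weight into a small range after each update (as in the original WGAN paper; the `"gradient_clip": 0.05` value in the WGAN config below is presumably used as this clipping range). A minimal sketch, assuming that interpretation:

```python
import numpy as np

def clip_critic_weights(weights, c=0.05):
    """Clip every critic parameter into [-c, c] after each update:
    the vanilla-WGAN way of (crudely) enforcing a Lipschitz bound."""
    return [np.clip(w, -c, c) for w in weights]

weights = [np.array([[0.3, -0.01], [0.07, -0.2]])]
clipped = clip_critic_weights(weights, c=0.05)
print(clipped[0].max() <= 0.05 and clipped[0].min() >= -0.05)  # True
```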

Hyperparameters

  • WGAN-GP
{
  "gradient_penalty": 10,
  "batch": 128,
  "optimizer": "adam",
  "learning_rate": 1e-05,
  "n_critic": 5,
  "initializer": "truncated_normal",
  "config": {
    "n_z": 100,
    "image_shape": [
      64,
      64,
      3
    ]
  },
  "config_critic": {
    "mode": "cnn",
    "parameter": {
      "batch_norm": false
    }
  },
  "config_generator": {
    "mode": "cnn",
    "parameter": {
      "batch_norm": true,
      "batch_norm_decay": 0.99,
      "batch_norm_scale": true
    }
  }
}
  • WGAN
{
  "gradient_clip": 0.05,
  "batch": 64,
  "optimizer": "rmsprop",
  "learning_rate": 5e-05,
  "n_critic": 5,
  "initializer": "truncated_normal",
  "overdose": true,
  "config": {
    "n_z": 100,
    "image_shape": [
      64,
      64,
      3
    ]
  },
  "config_critic": {
    "mode": "cnn",
    "parameter": {
      "batch_norm": true,
      "batch_norm_decay": 0.99,
      "batch_norm_scale": true
    }
  },
  "config_generator": {
    "mode": "cnn",
    "parameter": {
      "batch_norm": true,
      "batch_norm_decay": 0.99,
      "batch_norm_scale": true
    }
  }
}
  • DCGAN
{
  "batch": 64,
  "optimizer": "adam",
  "learning_rate": 2e-05,
  "generator_advantage": 2,
  "initializer": "truncated_normal",
  "config": {
    "n_z": 100,
    "image_shape": [
      64,
      64,
      3
    ]
  },
  "config_critic": {
    "mode": "cnn",
    "parameter": {
      "batch_norm": true,
      "batch_norm_decay": 0.99,
      "batch_norm_scale": true
    }
  },
  "config_generator": {
    "mode": "cnn",
    "parameter": {
      "batch_norm": true,
      "batch_norm_decay": 0.99,
      "batch_norm_scale": true
    }
  }
}
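The hyperparameter blocks above are plain JSON, so they can be loaded and inspected directly. A small illustration (the snippet embeds an abridged copy of the WGAN-GP block; the repo's actual loading code may differ):

```python
import json

# Abridged WGAN-GP hyperparameter block from above, parsed as plain JSON.
wgan_gp = json.loads("""
{
  "gradient_penalty": 10,
  "batch": 128,
  "optimizer": "adam",
  "learning_rate": 1e-05,
  "n_critic": 5,
  "config": {"n_z": 100, "image_shape": [64, 64, 3]}
}
""")
print(wgan_gp["n_critic"])       # 5
print(wgan_gp["config"]["n_z"])  # 100
```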
