ChengBinJin/DiscoGAN-TensorFlow

DiscoGAN-TensorFlow

This repository is a TensorFlow implementation of DiscoGAN (ICML 2017).

  • All samples in README.md are generated by the neural network, except for the first image in each row.

Requirements

  • tensorflow 1.10.0
  • python 3.5.3
  • numpy 1.14.2
  • opencv 3.2.0
  • matplotlib 2.2.2
  • scipy 0.19.1
  • pillow 5.0.0

Applied GAN Structure

  1. Generator

  2. Discriminator

Generated Images

1. Toy Dataset

Results on 2-dimensional Gaussian mixture models (IPython notebook).
(A) Original GAN
(B) GAN with Reconstruction Loss
(C) Domain A to B of DiscoGAN
(D) Domain B to A of DiscoGAN
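For readers who want to reproduce a toy setting like the one above, here is a minimal sketch of sampling a 2-D Gaussian mixture. The ring layout, the number of modes, the radius, the standard deviation, and the function name sample_ring_gmm are all assumptions for illustration, not the settings used in the repository's notebook.

```python
import numpy as np


def sample_ring_gmm(n_samples, n_modes=8, radius=2.0, std=0.02, seed=0):
    """Draw 2-D points from a Gaussian mixture whose means lie on a circle.

    Hypothetical helper: the ring layout and all default values are
    illustrative assumptions, not the notebook's configuration.
    """
    rng = np.random.RandomState(seed)  # RandomState also works with numpy 1.14
    # Evenly spaced mode centers on a circle of the given radius.
    angles = 2.0 * np.pi * np.arange(n_modes) / n_modes
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    # Pick a mixture component uniformly for each sample, then add isotropic noise.
    labels = rng.randint(0, n_modes, size=n_samples)
    points = centers[labels] + std * rng.randn(n_samples, 2)
    return points, labels, centers


points, labels, centers = sample_ring_gmm(1000)
print(points.shape)  # (1000, 2)
```

A small std keeps the modes well separated, which makes mode collapse easy to spot: a collapsed generator covers only a few of the clusters.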

2. Handbags2Shoes Dataset

  • handbag -> shoe -> handbag

  • shoe -> handbag -> shoe

3. edges2shoes

  • edge -> shoe -> edge

  • shoe -> edge -> shoe

4. edges2handbags

  • edge -> handbag -> edge

  • handbag -> edge -> handbag

5. cityscapes

  • RGB image -> segmentation label -> RGB image

  • segmentation label -> RGB image -> segmentation label

6. facades

  • RGB image -> segmentation label -> RGB image

  • segmentation label -> RGB image -> segmentation label

7. maps

  • aerial photo -> map -> aerial photo

  • map -> aerial photo -> map

Documentation

Download Dataset

First, download the edges2shoes, edges2handbags, cityscapes, facades, and maps datasets from pix2pix. Use the following command to download the datasets, then copy each one into the corresponding folder as shown in the Directory Hierarchy section.

python download.py

Directory Hierarchy

.
│   DiscoGAN
│   ├── src
│   │   ├── dataset.py
│   │   ├── discogan.py
│   │   ├── download.py
│   │   ├── main.py
│   │   ├── reader.py
│   │   ├── solver.py
│   │   ├── tensorflow_utils.py
│   │   └── utils.py
│   Data
│   ├── cityscapes
│   ├── edges2handbags
│   ├── edges2shoes
│   ├── facades
│   └── maps

src: source code of DiscoGAN
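The pix2pix archives store each example as a single image with the two domains placed side by side (in edges2shoes, for instance, the edge map is on the left and the shoe photo on the right). A minimal sketch of splitting such an image into its two halves follows. The function name split_pair is hypothetical, and whether reader.py or dataset.py split the files exactly this way is an assumption; which half counts as domain A also depends on the dataset.

```python
import numpy as np


def split_pair(image):
    """Split a side-by-side paired image (H x 2W x C) into left and right halves.

    Hypothetical helper; assumes the pix2pix layout where both domains
    are concatenated along the width axis.
    """
    height, width = image.shape[:2]
    if width % 2 != 0:
        raise ValueError("expected an even width, got {}".format(width))
    half = width // 2
    return image[:, :half], image[:, half:]


# Dummy 256 x 512 RGB array standing in for an image read from disk.
pair = np.zeros((256, 512, 3), dtype=np.uint8)
pair[:, 256:] = 255
left, right = split_pair(pair)
print(left.shape, right.shape)  # (256, 256, 3) (256, 256, 3)
```

If the images are loaded with cv2.imread, note that OpenCV returns channels in BGR order, so convert with cv2.cvtColor when the rest of the pipeline expects RGB.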

Implementation Details

The implementation uses TensorFlow to train DiscoGAN, with the same generator and discriminator networks as described in the DiscoGAN paper. Following CycleGAN, we keep the learning rate at 2e-4 for the first 1e5 iterations and then decay it linearly to zero. This schedule helps to alleviate mode collapse.
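A minimal sketch of a constant-then-linear-decay schedule of this kind is shown below. The parameter names constant_iters and decay_iters are hypothetical, and the length of the decay phase (another 1e5 iterations here) is an assumption, since it is not stated above.

```python
def linear_decay_lr(step, base_lr=2e-4, constant_iters=100000, decay_iters=100000):
    """Learning rate that stays at base_lr, then decays linearly to zero.

    Hypothetical helper: constant_iters and decay_iters are assumed values,
    not flags of main.py.
    """
    if step < constant_iters:
        return base_lr
    # Fraction of the decay phase already completed, clipped to [0, 1].
    progress = min(float(step - constant_iters) / decay_iters, 1.0)
    return base_lr * (1.0 - progress)


for step in (0, 100000, 150000, 200000):
    print(step, linear_decay_lr(step))
```

In a TensorFlow 1.x graph, a value like this would typically be fed to the optimizer through a placeholder at each step, or rebuilt from global_step with graph ops.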

To stay consistent with the original DiscoGAN paper, we set the ratio between the GAN loss and the reconstruction loss to 1:1. As a result, DiscoGAN does not reconstruct A -> B -> A very faithfully. CycleGAN, in contrast, uses a ratio of 1:10, so its reconstructed images remain very similar to the input images.
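The effect of this ratio can be seen in a minimal NumPy sketch of one direction (A -> B -> A) of a DiscoGAN-style generator loss, where lambda_rec weights the reconstruction term. The function name is hypothetical, and using MSE as the reconstruction distance is an assumption (CycleGAN's weight of 10 applies to an L1 cycle loss, so this only illustrates the weighting). The full generator objective would add the symmetric B -> A -> B terms.

```python
import numpy as np


def generator_loss(d_fake, x, x_rec, lambda_rec=1.0, eps=1e-8):
    """GAN loss plus weighted reconstruction loss for one translation direction.

    d_fake: discriminator probabilities for translated images G_AB(x_A).
    x, x_rec: inputs x_A and their round trips G_BA(G_AB(x_A)).
    Hypothetical helper; MSE as the reconstruction distance is an assumption.
    """
    # Non-saturating GAN loss: push D(G(x)) toward 1.
    gan_loss = -np.mean(np.log(d_fake + eps))
    # Reconstruction loss between the input and its round trip.
    rec_loss = np.mean((x - x_rec) ** 2)
    return gan_loss + lambda_rec * rec_loss


d_fake = np.array([0.5, 0.5])
x = np.zeros((2, 4))
x_rec = np.full((2, 4), 0.1)
loss_1 = generator_loss(d_fake, x, x_rec, lambda_rec=1.0)    # 1:1, as in DiscoGAN
loss_10 = generator_loss(d_fake, x, x_rec, lambda_rec=10.0)  # 1:10, as in CycleGAN
print(loss_1, loss_10)
```

With a weight of 10, the same reconstruction error costs ten times more, which pushes the generators toward round trips that stay close to the input.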

The official DiscoGAN code, implemented in PyTorch, uses weight decay. As far as I know, TensorFlow does not support weight decay directly, so I added a regularization term to the loss instead. The performance may therefore differ slightly from the original implementation.
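When the PyTorch code applies weight decay through the weight_decay argument of torch.optim.Adam, that argument adds weight_decay * w to each gradient, which is the gradient of an L2 term (weight_decay / 2) * sum(w ** 2) in the loss. The sketch below checks this numerically. The helper names are hypothetical, and whether this repository scales its regularization term exactly this way is an assumption; remaining differences can come from the scale or from which variables are regularized.

```python
import numpy as np


def l2_term(weights, weight_decay):
    """Regularization term (weight_decay / 2) * sum(w ** 2).

    Equivalent to weight_decay * tf.nn.l2_loss(w) in TensorFlow 1.x,
    since tf.nn.l2_loss computes sum(w ** 2) / 2.
    """
    return 0.5 * weight_decay * np.sum(weights ** 2)


def numerical_grad(f, w, h=1e-6):
    """Central-difference gradient of a scalar function f at w."""
    grad = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step.flat[i] = h
        grad.flat[i] = (f(w + step) - f(w - step)) / (2.0 * h)
    return grad


w = np.array([0.5, -1.0, 2.0])
wd = 1e-4
g = numerical_grad(lambda v: l2_term(v, wd), w)
# The gradient of the L2 term equals wd * w, the same quantity that
# torch.optim.Adam(weight_decay=wd) adds to each parameter's gradient.
print(np.allclose(g, wd * w))  # True
```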

Training DiscoGAN

Use main.py to train a DiscoGAN network. Example usage:

python main.py
  • gpu_index: GPU index, default: 0

  • batch_size: batch size for one feed forward, default: 200

  • dataset: dataset name from [edges2handbags, edges2shoes, handbags2shoes, maps, cityscapes, facades], default: facades

  • is_train: training or inference mode, default: True

  • learning_rate: initial learning rate for Adam, default: 0.0002

  • beta1: exponential decay rate for Adam's first-moment (momentum) estimates, default: 0.5

  • beta2: exponential decay rate for Adam's second-moment estimates, default: 0.999

  • weight_decay: weight of the regularization term, default: 1e-4

  • iters: number of iterations, default: 100000

  • print_freq: print frequency for loss, default: 100

  • save_freq: save frequency for model, default: 10000

  • sample_freq: sample frequency for saving image, default: 500

  • sample_batch: number of sample images used to check generator quality, default: 200

  • load_model: folder of the saved model you wish to test (e.g., 20180907-1739), default: None
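Since saved-model folders are named by timestamp (as in the 20180907-1739 example), a small sketch for picking the most recent one to pass as load_model is shown below. It assumes the names follow a YYYYMMDD-HHMM pattern; the helper name is hypothetical, and the directory that holds these folders is not specified here, so in practice you would pass it the output of os.listdir on that directory.

```python
from datetime import datetime


def latest_model_folder(folder_names, fmt="%Y%m%d-%H%M"):
    """Return the most recent folder name that parses with the given timestamp format.

    Hypothetical helper; assumes names like 20180907-1739 (YYYYMMDD-HHMM).
    Names that do not parse are ignored.
    """
    stamped = []
    for name in folder_names:
        try:
            stamped.append((datetime.strptime(name, fmt), name))
        except ValueError:
            continue  # not a timestamped model folder
    if not stamped:
        return None
    return max(stamped)[1]


print(latest_model_folder(["20180907-1739", "20180926-1739", "notes"]))  # 20180926-1739
```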

Test DiscoGAN

Use main.py to test a DiscoGAN network. Example usage:

python main.py --is_train=false --load_model=20180926-1739

Replace 20180926-1739 with the name of the folder that holds the model you wish to test. See the argument list above for the other options.

Citation

  @misc{chengbinjin2018discogan,
    author = {Cheng-Bin Jin},
    title = {DiscoGAN-tensorflow},
    year = {2018},
    howpublished = {\url{https://github.com/ChengBinJin/DiscoGAN-TensorFlow}},
    note = {commit xxxxxxx}
  }

Attributions/Thanks

License

Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.

Related Projects
