
Commit

restoring the original code
Ofir Press committed Jan 24, 2018
1 parent be22214 commit 563fc9f
Showing 9 changed files with 89 additions and 497 deletions.
105 changes: 0 additions & 105 deletions .gitignore

This file was deleted.

27 changes: 1 addition & 26 deletions README.md
@@ -2,7 +2,7 @@

Code for training and evaluation of the model from ["Language Generation with Recurrent Generative Adversarial Networks without Pre-training"](https://arxiv.org/abs/1706.01399).

-Additional Code for using Fisher GAN in Recurrent Generative Adversarial Networks
+

### Sample outputs (32 chars)

@@ -56,7 +56,6 @@ START_SEQ: Sequence length to start the curriculum learning with (defaults to 1)
END_SEQ: Sequence length to end the curriculum learning with (defaults to 32)
SAVE_CHECKPOINTS_EVERY: Save checkpoint every # steps (defaults to 25000)
LIMIT_BATCH: Boolean that indicates whether to limit the batch size (defaults to true)
-GAN_TYPE: String Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively
```

@@ -66,10 +65,6 @@ Parameters can be set by either changing their value in the config file or by passing them in the command line. For example:
python curriculum_training.py --START_SEQ=1 --END_SEQ=32
```

-## Monitoring Convergence During Training
-
-In the wasserstein GAN, please monitor the disc_cost. It should be a negative number and approach zero. The disc_cost represents the negative wasserstein distance between gen and critic.
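The removed note above describes disc_cost as the negative Wasserstein distance between generator and critic. A minimal NumPy sketch of that sign convention (the function name and scores here are hypothetical, not taken from this repo):

```python
import numpy as np

def wgan_disc_cost(critic_real, critic_fake):
    """WGAN critic cost in the usual convention: E[f(fake)] - E[f(real)].

    Minimizing this drives the two expectations apart, so a trained critic
    reports a negative cost whose magnitude estimates the Wasserstein
    distance between the data and generator distributions.
    """
    return np.mean(critic_fake) - np.mean(critic_real)

# Hypothetical critic scores: real samples scored higher than fakes.
real_scores = np.array([1.0, 2.0, 1.5])
fake_scores = np.array([-1.0, -0.5, 0.0])
cost = wgan_disc_cost(real_scores, fake_scores)
assert cost < 0  # negative, approaching zero as the generator improves
```

As the generator catches up, the critic can no longer separate the two distributions and the cost rises toward zero, which is why the removed note says to watch for it approaching zero.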

## Generating text

The `generate.py` script will generate `BATCH_SIZE` samples using a saved model. It should be run using the parameters used to train the model (if they are different than the default values). For example:
@@ -87,24 +82,6 @@ python evaluate.py --INPUT_SAMPLE=/path/to/samples.txt
```



-## Experimental Features (not mentioned in the paper)
-
-To train with fgan with recurrent highway cell:
-
-```
-python curriculum_training.py --GAN_TYPE fgan --CRITIC_ITERS 2 --GEN_ITERS 4 \
-    --PRINT_ITERATION 500 --ITERATIONS_PER_SEQ_LENGTH 60000 --RNN_CELL rhn
-```
-
-Please note that for fgan, there may be completely different hyperparameters that are more suitable for better convergence.
-
-### Monitoring Convergence
-
-To measure fgan convergence, gen_cost should start at a positive number and decrease. The lower, the better.
-
-Warning: in the very beginning of training, you may see the gen_cost rise. Please wait at least 5000 iterations and the gen_cost should start to lower. This phenomena is due to the critic finding the appropriate wasserstein distance and then the generator adjusting for it.
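For background on the removed fgan notes: Fisher GAN constrains the critic's second moment to 1 via an augmented Lagrangian, which is what the FISHER_GAN_RHO penalty weight in the deleted config flags refers to. A NumPy sketch of that critic objective (names and numbers are illustrative; the deleted fisher_gan_objective.py may have differed in detail):

```python
import numpy as np

def fisher_critic_objective(f_real, f_fake, lam, rho=1e-6):
    """Augmented-Lagrangian Fisher GAN critic objective (sketch).

    The critic maximizes E[f(real)] - E[f(fake)] subject to the
    second-moment constraint 0.5*(E[f(real)^2] + E[f(fake)^2]) = 1;
    `lam` is the Lagrange multiplier and `rho` weights the quadratic
    penalty on the constraint violation.
    """
    mean_gap = np.mean(f_real) - np.mean(f_fake)
    violation = 1.0 - 0.5 * (np.mean(f_real ** 2) + np.mean(f_fake ** 2))
    return mean_gap + lam * violation - (rho / 2.0) * violation ** 2

# Hypothetical critic outputs early in training.
obj = fisher_critic_objective(np.array([1.0, 0.5, 1.5]),
                              np.array([-0.5, 0.0, -1.0]),
                              lam=0.0)
assert np.isfinite(obj)
```

In one common sign convention the generator then minimizes the same mean gap, so gen_cost starts positive while the critic dominates and decreases as the generator catches up, consistent with the removed guidance.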

## Reference
If you found this code useful, please cite the following paper:

@@ -117,8 +94,6 @@ If you found this code useful, please cite the following paper:
}
```


## Acknowledgments

This repository is based on the code published in [Improved Training of Wasserstein GANs](https://github.com/igul222/improved_wgan_training).

54 changes: 8 additions & 46 deletions config.py
@@ -2,10 +2,6 @@
import time

import tensorflow as tf
-from tensorflow.contrib.rnn import GRUCell
-from highway_rnn_cell import RHNCell
-
-tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags

@@ -16,31 +12,20 @@
flags.DEFINE_string('DATA_DIR', './data/1-billion-word-language-modeling-benchmark-r13output/', "")
flags.DEFINE_string('CKPT_PATH', "./ckpt/", "")
flags.DEFINE_integer('BATCH_SIZE', 64, '')
-flags.DEFINE_integer('CRITIC_ITERS', 10, """When training wgan, it is helpful to use
-    10 critic_iters, however, when training with fgan, 2 critic iters may be more suitable.""")
+flags.DEFINE_integer('CRITIC_ITERS', 10, '')
flags.DEFINE_integer('LAMBDA', 10, '')
flags.DEFINE_integer('MAX_N_EXAMPLES', 10000000, '')
-flags.DEFINE_string('GENERATOR_MODEL', 'Generator_RNN_CL_VL_TH', '')
-flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_RNN', '')
+flags.DEFINE_string('GENERATOR_MODEL', 'Generator_GRU_CL_VL_TH', '')
+flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_GRU', '')
flags.DEFINE_string('PICKLE_PATH', './pkl', '')
-flags.DEFINE_integer('GEN_ITERS', 50, """When training wgan, it is helpful to use
-    50 gen_iters, however, when training with fgan, 2 gen_iters may be more suitable.""")
+flags.DEFINE_integer('GEN_ITERS', 50, '')
flags.DEFINE_integer('ITERATIONS_PER_SEQ_LENGTH', 15000, '')
flags.DEFINE_float('NOISE_STDEV', 10.0, '')

-flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '')
-
-# RNN Settings
-flags.DEFINE_integer('GEN_RNN_LAYERS', 1, '')
-flags.DEFINE_integer('DISC_RNN_LAYERS', 1, '')
-flags.DEFINE_integer('DISC_STATE_SIZE', 512, '')
-flags.DEFINE_integer('GEN_STATE_SIZE', 512, '')
-flags.DEFINE_string('RNN_CELL', 'gru', """Choose between 'gru' or 'rhn'.
-    'gru' option refers to a vanilla gru implementation
-    'rhn' options refers to a multiplicative integration 2-layer highway rnn
-    with normalizing tanh activation
-    """)

+flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '')
+flags.DEFINE_integer('GEN_GRU_LAYERS', 1, '')
+flags.DEFINE_integer('DISC_GRU_LAYERS', 1, '')
flags.DEFINE_integer('START_SEQ', 1, '')
flags.DEFINE_integer('END_SEQ', 32, '')
flags.DEFINE_bool('PADDING_IS_SUFFIX', False, '')
@@ -51,22 +36,6 @@
flags.DEFINE_boolean('DYNAMIC_BATCH', False, '')
flags.DEFINE_string('SCHEDULE_SPEC', 'all', '')

-# Print Options
-flags.DEFINE_boolean('PRINT_EVERY_STEP', False, '')
-flags.DEFINE_integer('PRINT_ITERATION', 100, '')
-
-
-# Fisher GAN Flags
-flags.DEFINE_string('GAN_TYPE', 'wgan', "Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively")
-flags.DEFINE_float('FISHER_GAN_RHO', 1e-6, "Weight on the penalty term for (sigmas -1)**2")
-
-# Learning Rates
-flags.DEFINE_float('DISC_LR', 2e-4, """Disc learning rate -- should be different than generator
-    learning rate due to TTUR paper https://arxiv.org/abs/1706.08500""")
-flags.DEFINE_float('GEN_LR', 1e-4, """Gen learning rate""")
-
-
-
# Only for inference mode

flags.DEFINE_string('INPUT_SAMPLE', './output/sample.txt', '')
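The removed DISC_LR/GEN_LR flags above point at the TTUR paper, whose recipe is simply to run the critic's optimizer at a larger step size than the generator's. A toy sketch of that idea, with plain SGD standing in for whatever optimizer the repo actually used (all names here are illustrative):

```python
import numpy as np

# Two-timescale update rule (TTUR): the critic/discriminator gets a larger
# learning rate (2e-4) than the generator (1e-4), matching the removed
# DISC_LR/GEN_LR defaults.
DISC_LR, GEN_LR = 2e-4, 1e-4

def sgd_step(params, grads, lr):
    # One vanilla gradient-descent update at the given learning rate.
    return params - lr * grads

disc_w = np.zeros(3)
gen_w = np.zeros(3)
grad = np.ones(3)  # identical hypothetical gradients, for comparison
disc_w = sgd_step(disc_w, grad, DISC_LR)
gen_w = sgd_step(gen_w, grad, GEN_LR)
assert disc_w[0] == 2 * gen_w[0]  # the critic moves twice as far per step
```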
@@ -115,11 +84,4 @@ def create_logs_dir():
CKPT_PATH = FLAGS.CKPT_PATH
GENERATOR_MODEL = FLAGS.GENERATOR_MODEL
DISCRIMINATOR_MODEL = FLAGS.DISCRIMINATOR_MODEL
-GEN_ITERS = FLAGS.GEN_ITERS
-
-if FLAGS.RNN_CELL.lower() == 'gru':
-    RNN_CELL = GRUCell
-elif FLAGS.RNN_CELL.lower() == 'rhn':
-    RNN_CELL = RHNCell
-else:
-    raise ValueError('improper rnn cell type selected')
+GEN_ITERS = FLAGS.GEN_ITERS
70 changes: 0 additions & 70 deletions fisher_gan_objective.py

This file was deleted.
