Skip to content

Commit

Permalink
Merge pull request #3 from NickShahML/master
Browse files Browse the repository at this point in the history
Fisher GAN + Recurrent Highway Network + TTUR
  • Loading branch information
amirbar committed Jul 9, 2017
2 parents ef1481e + da52d9f commit 1ac0e09
Show file tree
Hide file tree
Showing 9 changed files with 493 additions and 90 deletions.
105 changes: 105 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Repo data/logs/models
data/
logs/
pkl/


# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Code for training and evaluation of the model from ["Language Generation with Recurrent Generative Adversarial Networks without Pre-training"](https://arxiv.org/abs/1706.01399).


Additional Code for using Fisher GAN in Recurrent Generative Adversarial Networks

### Sample outputs (32 chars)

Expand All @@ -24,6 +24,15 @@ Then use the following command:
python curriculum_training.py
```

To train with fgan with recurrent highway cell:

```
python curriculum_training.py --GAN_TYPE fgan --CRITIC_ITERS 2 --GEN_ITERS 4 \
--PRINT_ITERATION 500 --ITERATIONS_PER_SEQ_LENGTH 60000 --RNN_CELL rhn
```

Please note that for fgan, there may be completely different hyperparameters that are more suitable for better convergence.

The following packages are required:

* Python 2.7
Expand Down Expand Up @@ -56,15 +65,26 @@ START_SEQ: Sequence length to start the curriculum learning with (defaults to 1)
END_SEQ: Sequence length to end the curriculum learning with (defaults to 32)
SAVE_CHECKPOINTS_EVERY: Save checkpoint every # steps (defaults to 25000)
LIMIT_BATCH: Boolean that indicates whether to limit the batch size (defaults to true)
GAN_TYPE: String Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively
```

Paramters can be set by either changing their value in the config file or by passing them in the terminal:
Parameters can be set by either changing their value in the config file or by passing them in the terminal:

```
python curriculum_training.py --START_SEQ=1 --END_SEQ=32
```

## Monitoring Convergence During Training

### Wasserstein GAN
In the wasserstein GAN, please monitor the disc_cost. It should be a negative number and approach zero. The disc_cost represents the negative wasserstein distance between gen and critic.

### Fisher GAN
To measure convergence, gen_cost should start at a positive number and decrease. The lower, the better.

Warning: in the very beginning of training, you may see the gen_cost rise. Please wait at least 5000 iterations and the gen_cost should start to lower. This phenomena is due to the critic finding the appropriate wasserstein distance and then the generator adjusting for it.

## Generating text

The `generate.py` script will generate `BATCH_SIZE` samples using a saved model. It should be run using the parameters used to train the model (if they are different than the default values). For example:
Expand Down
54 changes: 46 additions & 8 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import time

import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell
from highway_rnn_cell import RHNCell

tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags

Expand All @@ -12,20 +16,31 @@
flags.DEFINE_string('DATA_DIR', './data/1-billion-word-language-modeling-benchmark-r13output/', "")
flags.DEFINE_string('CKPT_PATH', "./ckpt/", "")
flags.DEFINE_integer('BATCH_SIZE', 64, '')
flags.DEFINE_integer('CRITIC_ITERS', 10, '')
flags.DEFINE_integer('CRITIC_ITERS', 10, """When training wgan, it is helpful to use
10 critic_iters, however, when training with fgan, 2 critic iters may be more suitable.""")
flags.DEFINE_integer('LAMBDA', 10, '')
flags.DEFINE_integer('MAX_N_EXAMPLES', 10000000, '')
flags.DEFINE_string('GENERATOR_MODEL', 'Generator_GRU_CL_VL_TH', '')
flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_GRU', '')
flags.DEFINE_string('GENERATOR_MODEL', 'Generator_RNN_CL_VL_TH', '')
flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_RNN', '')
flags.DEFINE_string('PICKLE_PATH', './pkl', '')
flags.DEFINE_integer('GEN_ITERS', 50, '')
flags.DEFINE_integer('GEN_ITERS', 50, """When training wgan, it is helpful to use
50 gen_iters, however, when training with fgan, 2 gen_iters may be more suitable.""")
flags.DEFINE_integer('ITERATIONS_PER_SEQ_LENGTH', 15000, '')
flags.DEFINE_float('NOISE_STDEV', 10.0, '')

flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '')

# RNN Settings
flags.DEFINE_integer('GEN_RNN_LAYERS', 1, '')
flags.DEFINE_integer('DISC_RNN_LAYERS', 1, '')
flags.DEFINE_integer('DISC_STATE_SIZE', 512, '')
flags.DEFINE_integer('GEN_STATE_SIZE', 512, '')
flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '')
flags.DEFINE_integer('GEN_GRU_LAYERS', 1, '')
flags.DEFINE_integer('DISC_GRU_LAYERS', 1, '')
flags.DEFINE_string('RNN_CELL', 'gru', """Choose between 'gru' or 'rhn'.
'gru' option refers to a vanilla gru implementation
'rhn' options refers to a multiplicative integration 2-layer highway rnn
with normalizing tanh activation
""")

flags.DEFINE_integer('START_SEQ', 1, '')
flags.DEFINE_integer('END_SEQ', 32, '')
flags.DEFINE_bool('PADDING_IS_SUFFIX', False, '')
Expand All @@ -36,6 +51,22 @@
flags.DEFINE_boolean('DYNAMIC_BATCH', False, '')
flags.DEFINE_string('SCHEDULE_SPEC', 'all', '')

# Print Options
flags.DEFINE_boolean('PRINT_EVERY_STEP', False, '')
flags.DEFINE_integer('PRINT_ITERATION', 100, '')


# Fisher GAN Flags
flags.DEFINE_string('GAN_TYPE', 'wgan', "Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively")
flags.DEFINE_float('FISHER_GAN_RHO', 1e-6, "Weight on the penalty term for (sigmas -1)**2")

# Learning Rates
flags.DEFINE_float('DISC_LR', 2e-4, """Disc learning rate -- should be different than generator
learning rate due to TTUR paper https://arxiv.org/abs/1706.08500""")
flags.DEFINE_float('GEN_LR', 1e-4, """Gen learning rate""")



# Only for inference mode

flags.DEFINE_string('INPUT_SAMPLE', './output/sample.txt', '')
Expand Down Expand Up @@ -84,4 +115,11 @@ def create_logs_dir():
CKPT_PATH = FLAGS.CKPT_PATH
GENERATOR_MODEL = FLAGS.GENERATOR_MODEL
DISCRIMINATOR_MODEL = FLAGS.DISCRIMINATOR_MODEL
GEN_ITERS = FLAGS.GEN_ITERS
GEN_ITERS = FLAGS.GEN_ITERS

if FLAGS.RNN_CELL.lower() == 'gru':
RNN_CELL = GRUCell
elif FLAGS.RNN_CELL.lower() == 'rhn':
RNN_CELL = RHNCell
else:
raise ValueError('improper rnn cell type selected')
70 changes: 70 additions & 0 deletions fisher_gan_objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import tensorflow as tf

class FisherGAN():
"""Implements fisher gan objective functions
Modeled off https://github.com/ethancaballero/FisherGAN/blob/master/main.py
Tried to keep variable names the same as much as possible
To measure convergence, gen_cost should start at a positive number and decrease
to zero. The lower, the better.
Warning: in the very beginning of training, you may see the gen_cost rise. Please
wait at least 5000 iterations and the gen_cost should start to lower. This
phenomena is due to the critic finding the appropriate wasserstein distance
and then the generator adjusting for it.
It is recommended that you use a critic iteration of 1 when using fisher gan
"""

def __init__(self, rho=1e-5):
tf.logging.warn("USING FISHER GAN OBJECTIVE FUNCTION")
self._rho = rho
# Initialize alpha (or in paper called lambda) with zero
# Throughout training alpha is trained with an independent sgd optimizer
# We use "alpha" instead of lambda because code we are modeling off of
# uses "alpha" instead of lambda
self._alpha = tf.get_variable("fisher_alpha", [], initializer=tf.zeros_initializer)

def _optimize_alpha(self, disc_cost):
""" In the optimization of alpha, we optimize via regular sgd with a learning rate
of rho.
This optimization should occur every time the discriminator is optimized because
the same batch is used.
Very crucial point --> We minimize the NEGATIVE disc_cost with our alpha parameter.
This is done to enforce the Lipchitz constraint. If we minimized the positive disc_cost
then our discriminator loss would drop to a very low negative number and the Lipchitz
constraint would not hold.
"""

# Find gradient of alpha with respect to negative disc_cost
self._alpha_optimizer = tf.train.GradientDescentOptimizer(self._rho)
self.alpha_optimizer_op = self._alpha_optimizer.minimize(-disc_cost, var_list=[self._alpha])
return

def loss_d_g(self, disc_fake, disc_real, fake_inputs, real_inputs, charmap, seq_length, Discriminator):

# Compared to WGAN, generator cost remains the same in fisher GAN
gen_cost = -tf.reduce_mean(disc_fake)

# Calculate Lipchitz Constraint
# E_P and E_Q refer to Expectation over real and fake.

E_Q_f = tf.reduce_mean(disc_fake)
E_P_f = tf.reduce_mean(disc_real)
E_Q_f2 = tf.reduce_mean(disc_fake**2)
E_P_f2 = tf.reduce_mean(disc_real**2)

constraint = (1 - (0.5*E_P_f2 + 0.5*E_Q_f2))

# See Equation (9) in Fisher GAN paper
# In the original implementation, they use a backward computation with mone (minus one)
# To implement this in tensorflow, we simply multiply the objective
# cost function by minus one.
disc_cost = -1.0 * (E_P_f - E_Q_f + self._alpha * constraint - self._rho/2 * constraint**2)

# calculate optimization op for alpha
self._optimize_alpha(disc_cost)

return disc_cost, gen_cost
78 changes: 78 additions & 0 deletions highway_rnn_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import tensorflow as tf
from multiplicative_integration import multiplicative_integration
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import RNNCell


def ntanh(_x, name="normalizing_tanh"):
"""
Inspired by self normalizing networks paper, we adjust scale on tanh
function to encourage mean of 0 and variance of 1 in activations
From comments on reddit, the normalizing tanh function:
1.5925374197228312
"""
scale = 1.5925374197228312
return scale*tf.nn.tanh(_x, name=name)


class RHNCell(RNNCell):
"""
Recurrent Highway Cell
Reference: https://arxiv.org/abs/1607.03474
"""

def __init__(self, num_units, depth=2, forget_bias=-2.0, activation=ntanh):

"""We initialize forget bias to negative two so that highway layers don't activate
"""


assert activation.__name__ == "ntanh"
self._num_units = num_units
self._in_size = num_units
self._depth = depth
self._forget_bias = forget_bias
self._activation = activation

tf.logging.info("""Building Recurrent Highway Cell with {} Activation of depth {}
and forget bias of {}""".format(
self._activation.__name__, self._depth, self._forget_bias))


@property
def input_size(self):
return self._in_size

@property
def output_size(self):
return self._num_units

@property
def state_size(self):
return self._num_units

def __call__(self, inputs, state, timestep=0, scope=None):
current_state = state

for i in range(self._depth):
with tf.variable_scope('h_'+str(i)):
if i == 0:
h = self._activation(
multiplicative_integration([inputs,current_state], self._num_units))
else:
h = tf.layers.dense(current_state, self._num_units, self._activation,
bias_initializer=tf.zeros_initializer())

with tf.variable_scope('gate_'+str(i)):
if i == 0:
t = tf.sigmoid(
multiplicative_integration([inputs,current_state], self._num_units,
self._forget_bias))

else:
t = tf.layers.dense(current_state, self._num_units, tf.sigmoid,
bias_initializer=tf.constant_initializer(self._forget_bias))

current_state = (h - current_state)* t + current_state

return current_state, current_state

0 comments on commit 1ac0e09

Please sign in to comment.