Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fisher GAN + Recurrent Highway Network + TTUR #3

Merged
merged 12 commits into from
Jul 9, 2017
105 changes: 105 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Repo data/logs/models
data/
logs/
pkl/


# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Code for training and evaluation of the model from ["Language Generation with Recurrent Generative Adversarial Networks without Pre-training"](https://arxiv.org/abs/1706.01399).


Additional Code for using Fisher GAN in Recurrent Generative Adversarial Networks

### Sample outputs (32 chars)

Expand All @@ -24,6 +24,15 @@ Then use the following command:
python curriculum_training.py
```

To train with fgan with recurrent highway cell:

```
python curriculum_training.py --GAN_TYPE fgan --CRITIC_ITERS 2 --GEN_ITERS 4 \
--PRINT_ITERATION 500 --ITERATIONS_PER_SEQ_LENGTH 60000 --RNN_CELL rhn
```

Please note that for fgan, there may be completely different hyperparameters that are more suitable for better convergence.

The following packages are required:

* Python 2.7
Expand Down Expand Up @@ -56,15 +65,26 @@ START_SEQ: Sequence length to start the curriculum learning with (defaults to 1)
END_SEQ: Sequence length to end the curriculum learning with (defaults to 32)
SAVE_CHECKPOINTS_EVERY: Save checkpoint every # steps (defaults to 25000)
LIMIT_BATCH: Boolean that indicates whether to limit the batch size (defaults to true)
GAN_TYPE: String Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively

```

Paramters can be set by either changing their value in the config file or by passing them in the terminal:
Parameters can be set by either changing their value in the config file or by passing them in the terminal:

```
python curriculum_training.py --START_SEQ=1 --END_SEQ=32
```

## Monitoring Convergence During Training

### Wasserstein GAN
In the wasserstein GAN, please monitor the disc_cost. It should be a negative number and approach zero. The disc_cost represents the negative wasserstein distance between gen and critic.

### Fisher GAN
To measure convergence, gen_cost should start at a positive number and decrease. The lower, the better.

Warning: in the very beginning of training, you may see the gen_cost rise. Please wait at least 5000 iterations and the gen_cost should start to lower. This phenomena is due to the critic finding the appropriate wasserstein distance and then the generator adjusting for it.

## Generating text

The `generate.py` script will generate `BATCH_SIZE` samples using a saved model. It should be run using the parameters used to train the model (if they are different than the default values). For example:
Expand Down
54 changes: 46 additions & 8 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import time

import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell
from highway_rnn_cell import RHNCell

tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags

Expand All @@ -12,20 +16,31 @@
flags.DEFINE_string('DATA_DIR', './data/1-billion-word-language-modeling-benchmark-r13output/', "")
flags.DEFINE_string('CKPT_PATH', "./ckpt/", "")
flags.DEFINE_integer('BATCH_SIZE', 64, '')
flags.DEFINE_integer('CRITIC_ITERS', 10, '')
flags.DEFINE_integer('CRITIC_ITERS', 10, """When training wgan, it is helpful to use
10 critic_iters, however, when training with fgan, 2 critic iters may be more suitable.""")
flags.DEFINE_integer('LAMBDA', 10, '')
flags.DEFINE_integer('MAX_N_EXAMPLES', 10000000, '')
flags.DEFINE_string('GENERATOR_MODEL', 'Generator_GRU_CL_VL_TH', '')
flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_GRU', '')
flags.DEFINE_string('GENERATOR_MODEL', 'Generator_RNN_CL_VL_TH', '')
flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_RNN', '')
flags.DEFINE_string('PICKLE_PATH', './pkl', '')
flags.DEFINE_integer('GEN_ITERS', 50, '')
flags.DEFINE_integer('GEN_ITERS', 50, """When training wgan, it is helpful to use
50 gen_iters, however, when training with fgan, 2 gen_iters may be more suitable.""")
flags.DEFINE_integer('ITERATIONS_PER_SEQ_LENGTH', 15000, '')
flags.DEFINE_float('NOISE_STDEV', 10.0, '')

flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '')

# RNN Settings
flags.DEFINE_integer('GEN_RNN_LAYERS', 1, '')
flags.DEFINE_integer('DISC_RNN_LAYERS', 1, '')
flags.DEFINE_integer('DISC_STATE_SIZE', 512, '')
flags.DEFINE_integer('GEN_STATE_SIZE', 512, '')
flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '')
flags.DEFINE_integer('GEN_GRU_LAYERS', 1, '')
flags.DEFINE_integer('DISC_GRU_LAYERS', 1, '')
flags.DEFINE_string('RNN_CELL', 'gru', """Choose between 'gru' or 'rhn'.
'gru' option refers to a vanilla gru implementation
'rhn' options refers to a multiplicative integration 2-layer highway rnn
with normalizing tanh activation
""")

flags.DEFINE_integer('START_SEQ', 1, '')
flags.DEFINE_integer('END_SEQ', 32, '')
flags.DEFINE_bool('PADDING_IS_SUFFIX', False, '')
Expand All @@ -36,6 +51,22 @@
flags.DEFINE_boolean('DYNAMIC_BATCH', False, '')
flags.DEFINE_string('SCHEDULE_SPEC', 'all', '')

# Print Options
flags.DEFINE_boolean('PRINT_EVERY_STEP', False, '')
flags.DEFINE_integer('PRINT_ITERATION', 100, '')


# Fisher GAN Flags
flags.DEFINE_string('GAN_TYPE', 'wgan', "Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively")
flags.DEFINE_float('FISHER_GAN_RHO', 1e-6, "Weight on the penalty term for (sigmas -1)**2")

# Learning Rates
flags.DEFINE_float('DISC_LR', 2e-4, """Disc learning rate -- should be different than generator
learning rate due to TTUR paper https://arxiv.org/abs/1706.08500""")
flags.DEFINE_float('GEN_LR', 1e-4, """Gen learning rate""")



# Only for inference mode

flags.DEFINE_string('INPUT_SAMPLE', './output/sample.txt', '')
Expand Down Expand Up @@ -84,4 +115,11 @@ def create_logs_dir():
CKPT_PATH = FLAGS.CKPT_PATH
GENERATOR_MODEL = FLAGS.GENERATOR_MODEL
DISCRIMINATOR_MODEL = FLAGS.DISCRIMINATOR_MODEL
GEN_ITERS = FLAGS.GEN_ITERS
GEN_ITERS = FLAGS.GEN_ITERS

if FLAGS.RNN_CELL.lower() == 'gru':
RNN_CELL = GRUCell
elif FLAGS.RNN_CELL.lower() == 'rhn':
RNN_CELL = RHNCell
else:
raise ValueError('improper rnn cell type selected')
70 changes: 70 additions & 0 deletions fisher_gan_objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import tensorflow as tf

class FisherGAN():
"""Implements fisher gan objective functions
Modeled off https://github.com/ethancaballero/FisherGAN/blob/master/main.py
Tried to keep variable names the same as much as possible

To measure convergence, gen_cost should start at a positive number and decrease
to zero. The lower, the better.

Warning: in the very beginning of training, you may see the gen_cost rise. Please
wait at least 5000 iterations and the gen_cost should start to lower. This
phenomena is due to the critic finding the appropriate wasserstein distance
and then the generator adjusting for it.

It is recommended that you use a critic iteration of 1 when using fisher gan
"""

def __init__(self, rho=1e-5):
tf.logging.warn("USING FISHER GAN OBJECTIVE FUNCTION")
self._rho = rho
# Initialize alpha (or in paper called lambda) with zero
# Throughout training alpha is trained with an independent sgd optimizer
# We use "alpha" instead of lambda because code we are modeling off of
# uses "alpha" instead of lambda
self._alpha = tf.get_variable("fisher_alpha", [], initializer=tf.zeros_initializer)

def _optimize_alpha(self, disc_cost):
""" In the optimization of alpha, we optimize via regular sgd with a learning rate
of rho.

This optimization should occur every time the discriminator is optimized because
the same batch is used.

Very crucial point --> We minimize the NEGATIVE disc_cost with our alpha parameter.
This is done to enforce the Lipchitz constraint. If we minimized the positive disc_cost
then our discriminator loss would drop to a very low negative number and the Lipchitz
constraint would not hold.
"""

# Find gradient of alpha with respect to negative disc_cost
self._alpha_optimizer = tf.train.GradientDescentOptimizer(self._rho)
self.alpha_optimizer_op = self._alpha_optimizer.minimize(-disc_cost, var_list=[self._alpha])
return

def loss_d_g(self, disc_fake, disc_real, fake_inputs, real_inputs, charmap, seq_length, Discriminator):

# Compared to WGAN, generator cost remains the same in fisher GAN
gen_cost = -tf.reduce_mean(disc_fake)

# Calculate Lipchitz Constraint
# E_P and E_Q refer to Expectation over real and fake.

E_Q_f = tf.reduce_mean(disc_fake)
E_P_f = tf.reduce_mean(disc_real)
E_Q_f2 = tf.reduce_mean(disc_fake**2)
E_P_f2 = tf.reduce_mean(disc_real**2)

constraint = (1 - (0.5*E_P_f2 + 0.5*E_Q_f2))

# See Equation (9) in Fisher GAN paper
# In the original implementation, they use a backward computation with mone (minus one)
# To implement this in tensorflow, we simply multiply the objective
# cost function by minus one.
disc_cost = -1.0 * (E_P_f - E_Q_f + self._alpha * constraint - self._rho/2 * constraint**2)

# calculate optimization op for alpha
self._optimize_alpha(disc_cost)

return disc_cost, gen_cost
78 changes: 78 additions & 0 deletions highway_rnn_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import tensorflow as tf
from multiplicative_integration import multiplicative_integration
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import RNNCell


def ntanh(_x, name="normalizing_tanh"):
"""
Inspired by self normalizing networks paper, we adjust scale on tanh
function to encourage mean of 0 and variance of 1 in activations

From comments on reddit, the normalizing tanh function:
1.5925374197228312
"""
scale = 1.5925374197228312
return scale*tf.nn.tanh(_x, name=name)


class RHNCell(RNNCell):
"""
Recurrent Highway Cell
Reference: https://arxiv.org/abs/1607.03474
"""

def __init__(self, num_units, depth=2, forget_bias=-2.0, activation=ntanh):

"""We initialize forget bias to negative two so that highway layers don't activate
"""


assert activation.__name__ == "ntanh"
self._num_units = num_units
self._in_size = num_units
self._depth = depth
self._forget_bias = forget_bias
self._activation = activation

tf.logging.info("""Building Recurrent Highway Cell with {} Activation of depth {}
and forget bias of {}""".format(
self._activation.__name__, self._depth, self._forget_bias))


@property
def input_size(self):
return self._in_size

@property
def output_size(self):
return self._num_units

@property
def state_size(self):
return self._num_units

def __call__(self, inputs, state, timestep=0, scope=None):
current_state = state

for i in range(self._depth):
with tf.variable_scope('h_'+str(i)):
if i == 0:
h = self._activation(
multiplicative_integration([inputs,current_state], self._num_units))
else:
h = tf.layers.dense(current_state, self._num_units, self._activation,
bias_initializer=tf.zeros_initializer())

with tf.variable_scope('gate_'+str(i)):
if i == 0:
t = tf.sigmoid(
multiplicative_integration([inputs,current_state], self._num_units,
self._forget_bias))

else:
t = tf.layers.dense(current_state, self._num_units, tf.sigmoid,
bias_initializer=tf.constant_initializer(self._forget_bias))

current_state = (h - current_state)* t + current_state

return current_state, current_state