From 563fc9f3cd85734740f61960bc13b053e74d15a0 Mon Sep 17 00:00:00 2001 From: Ofir Press Date: Wed, 24 Jan 2018 13:05:57 +0200 Subject: [PATCH] restoring the original code --- .gitignore | 105 ---------------------------------- README.md | 27 +-------- config.py | 54 +++-------------- fisher_gan_objective.py | 70 ----------------------- highway_rnn_cell.py | 78 ------------------------- model.py | 68 ++++++++++++---------- multiplicative_integration.py | 42 -------------- objective.py | 53 ++++++----------- single_length_train.py | 89 +++++++++------------------- 9 files changed, 89 insertions(+), 497 deletions(-) delete mode 100644 .gitignore delete mode 100644 fisher_gan_objective.py delete mode 100644 highway_rnn_cell.py delete mode 100644 multiplicative_integration.py diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 8932d8d..0000000 --- a/.gitignore +++ /dev/null @@ -1,105 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# Repo data/logs/models -data/ -logs/ -pkl/ - - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ \ No newline at end of file diff --git a/README.md b/README.md index 98c4ea4..92a5116 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Code for training and evaluation of the model from ["Language Generation with Recurrent Generative Adversarial Networks without Pre-training"](https://arxiv.org/abs/1706.01399). -Additional Code for using Fisher GAN in Recurrent Generative Adversarial Networks + ### Sample outputs (32 chars) @@ -56,7 +56,6 @@ START_SEQ: Sequence length to start the curriculum learning with (defaults to 1) END_SEQ: Sequence length to end the curriculum learning with (defaults to 32) SAVE_CHECKPOINTS_EVERY: Save checkpoint every # steps (defaults to 25000) LIMIT_BATCH: Boolean that indicates whether to limit the batch size (defaults to true) -GAN_TYPE: String Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively ``` @@ -66,10 +65,6 @@ Parameters can be set by either changing their value in the config file or by pa python curriculum_training.py --START_SEQ=1 --END_SEQ=32 ``` -## Monitoring Convergence During Training - -In the wasserstein GAN, please monitor the disc_cost. It should be a negative number and approach zero. The disc_cost represents the negative wasserstein distance between gen and critic. - ## Generating text The `generate.py` script will generate `BATCH_SIZE` samples using a saved model. It should be run using the parameters used to train the model (if they are different than the default values). For example: @@ -87,24 +82,6 @@ python evaluate.py --INPUT_SAMPLE=/path/to/samples.txt ``` - -## Experimental Features (not mentioned in the paper) - -To train with fgan with recurrent highway cell: - -``` -python curriculum_training.py --GAN_TYPE fgan --CRITIC_ITERS 2 --GEN_ITERS 4 \ ---PRINT_ITERATION 500 --ITERATIONS_PER_SEQ_LENGTH 60000 --RNN_CELL rhn -``` - -Please note that for fgan, there may be completely different hyperparameters that are more suitable for better convergence. - -### Monitoring Convergence - -To measure fgan convergence, gen_cost should start at a positive number and decrease. The lower, the better. - -Warning: in the very beginning of training, you may see the gen_cost rise. Please wait at least 5000 iterations and the gen_cost should start to lower. This phenomena is due to the critic finding the appropriate wasserstein distance and then the generator adjusting for it. - ## Reference If you found this code useful, please cite the following paper: @@ -117,8 +94,6 @@ If you found this code useful, please cite the following paper: } ``` - ## Acknowledgments This repository is based on the code published in [Improved Training of Wasserstein GANs](https://github.com/igul222/improved_wgan_training). - diff --git a/config.py b/config.py index 6fa93d1..9119735 100644 --- a/config.py +++ b/config.py @@ -2,10 +2,6 @@ import time import tensorflow as tf -from tensorflow.contrib.rnn import GRUCell -from highway_rnn_cell import RHNCell - -tf.logging.set_verbosity(tf.logging.INFO) flags = tf.app.flags @@ -16,31 +12,20 @@ flags.DEFINE_string('DATA_DIR', './data/1-billion-word-language-modeling-benchmark-r13output/', "") flags.DEFINE_string('CKPT_PATH', "./ckpt/", "") flags.DEFINE_integer('BATCH_SIZE', 64, '') -flags.DEFINE_integer('CRITIC_ITERS', 10, """When training wgan, it is helpful to use - 10 critic_iters, however, when training with fgan, 2 critic iters may be more suitable.""") +flags.DEFINE_integer('CRITIC_ITERS', 10, '') flags.DEFINE_integer('LAMBDA', 10, '') flags.DEFINE_integer('MAX_N_EXAMPLES', 10000000, '') -flags.DEFINE_string('GENERATOR_MODEL', 'Generator_RNN_CL_VL_TH', '') -flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_RNN', '') +flags.DEFINE_string('GENERATOR_MODEL', 'Generator_GRU_CL_VL_TH', '') +flags.DEFINE_string('DISCRIMINATOR_MODEL', 'Discriminator_GRU', '') flags.DEFINE_string('PICKLE_PATH', './pkl', '') -flags.DEFINE_integer('GEN_ITERS', 50, """When training wgan, it is helpful to use - 50 gen_iters, however, when training with fgan, 2 gen_iters may be more suitable.""") +flags.DEFINE_integer('GEN_ITERS', 50, '') flags.DEFINE_integer('ITERATIONS_PER_SEQ_LENGTH', 15000, '') flags.DEFINE_float('NOISE_STDEV', 10.0, '') - -flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '') - -# RNN Settings -flags.DEFINE_integer('GEN_RNN_LAYERS', 1, '') -flags.DEFINE_integer('DISC_RNN_LAYERS', 1, '') flags.DEFINE_integer('DISC_STATE_SIZE', 512, '') flags.DEFINE_integer('GEN_STATE_SIZE', 512, '') -flags.DEFINE_string('RNN_CELL', 'gru', """Choose between 'gru' or 'rhn'. - 'gru' option refers to a vanilla gru implementation - 'rhn' options refers to a multiplicative integration 2-layer highway rnn - with normalizing tanh activation - """) - +flags.DEFINE_boolean('TRAIN_FROM_CKPT', False, '') +flags.DEFINE_integer('GEN_GRU_LAYERS', 1, '') +flags.DEFINE_integer('DISC_GRU_LAYERS', 1, '') flags.DEFINE_integer('START_SEQ', 1, '') flags.DEFINE_integer('END_SEQ', 32, '') flags.DEFINE_bool('PADDING_IS_SUFFIX', False, '') @@ -51,22 +36,6 @@ flags.DEFINE_boolean('DYNAMIC_BATCH', False, '') flags.DEFINE_string('SCHEDULE_SPEC', 'all', '') -# Print Options -flags.DEFINE_boolean('PRINT_EVERY_STEP', False, '') -flags.DEFINE_integer('PRINT_ITERATION', 100, '') - - -# Fisher GAN Flags -flags.DEFINE_string('GAN_TYPE', 'wgan', "Type of GAN to use. Choose between 'wgan' and 'fgan' for wasserstein and fisher respectively") -flags.DEFINE_float('FISHER_GAN_RHO', 1e-6, "Weight on the penalty term for (sigmas -1)**2") - -# Learning Rates -flags.DEFINE_float('DISC_LR', 2e-4, """Disc learning rate -- should be different than generator - learning rate due to TTUR paper https://arxiv.org/abs/1706.08500""") -flags.DEFINE_float('GEN_LR', 1e-4, """Gen learning rate""") - - - # Only for inference mode flags.DEFINE_string('INPUT_SAMPLE', './output/sample.txt', '') @@ -115,11 +84,4 @@ def create_logs_dir(): CKPT_PATH = FLAGS.CKPT_PATH GENERATOR_MODEL = FLAGS.GENERATOR_MODEL DISCRIMINATOR_MODEL = FLAGS.DISCRIMINATOR_MODEL -GEN_ITERS = FLAGS.GEN_ITERS - -if FLAGS.RNN_CELL.lower() == 'gru': - RNN_CELL = GRUCell -elif FLAGS.RNN_CELL.lower() == 'rhn': - RNN_CELL = RHNCell -else: - raise ValueError('improper rnn cell type selected') \ No newline at end of file +GEN_ITERS = FLAGS.GEN_ITERS \ No newline at end of file diff --git a/fisher_gan_objective.py b/fisher_gan_objective.py deleted file mode 100644 index 06458fc..0000000 --- a/fisher_gan_objective.py +++ /dev/null @@ -1,70 +0,0 @@ -import tensorflow as tf - -class FisherGAN(): - """Implements fisher gan objective functions - Modeled off https://github.com/ethancaballero/FisherGAN/blob/master/main.py - Tried to keep variable names the same as much as possible - - To measure convergence, gen_cost should start at a positive number and decrease - to zero. The lower, the better. - - Warning: in the very beginning of training, you may see the gen_cost rise. Please - wait at least 5000 iterations and the gen_cost should start to lower. This - phenomena is due to the critic finding the appropriate wasserstein distance - and then the generator adjusting for it. - - It is recommended that you use a critic iteration of 1 when using fisher gan - """ - - def __init__(self, rho=1e-5): - tf.logging.warn("USING FISHER GAN OBJECTIVE FUNCTION") - self._rho = rho - # Initialize alpha (or in paper called lambda) with zero - # Throughout training alpha is trained with an independent sgd optimizer - # We use "alpha" instead of lambda because code we are modeling off of - # uses "alpha" instead of lambda - self._alpha = tf.get_variable("fisher_alpha", [], initializer=tf.zeros_initializer) - - def _optimize_alpha(self, disc_cost): - """ In the optimization of alpha, we optimize via regular sgd with a learning rate - of rho. - - This optimization should occur every time the discriminator is optimized because - the same batch is used. - - Very crucial point --> We minimize the NEGATIVE disc_cost with our alpha parameter. - This is done to enforce the Lipchitz constraint. If we minimized the positive disc_cost - then our discriminator loss would drop to a very low negative number and the Lipchitz - constraint would not hold. - """ - - # Find gradient of alpha with respect to negative disc_cost - self._alpha_optimizer = tf.train.GradientDescentOptimizer(self._rho) - self.alpha_optimizer_op = self._alpha_optimizer.minimize(-disc_cost, var_list=[self._alpha]) - return - - def loss_d_g(self, disc_fake, disc_real, fake_inputs, real_inputs, charmap, seq_length, Discriminator): - - # Compared to WGAN, generator cost remains the same in fisher GAN - gen_cost = -tf.reduce_mean(disc_fake) - - # Calculate Lipchitz Constraint - # E_P and E_Q refer to Expectation over real and fake. - - E_Q_f = tf.reduce_mean(disc_fake) - E_P_f = tf.reduce_mean(disc_real) - E_Q_f2 = tf.reduce_mean(disc_fake**2) - E_P_f2 = tf.reduce_mean(disc_real**2) - - constraint = (1 - (0.5*E_P_f2 + 0.5*E_Q_f2)) - - # See Equation (9) in Fisher GAN paper - # In the original implementation, they use a backward computation with mone (minus one) - # To implement this in tensorflow, we simply multiply the objective - # cost function by minus one. - disc_cost = -1.0 * (E_P_f - E_Q_f + self._alpha * constraint - self._rho/2 * constraint**2) - - # calculate optimization op for alpha - self._optimize_alpha(disc_cost) - - return disc_cost, gen_cost \ No newline at end of file diff --git a/highway_rnn_cell.py b/highway_rnn_cell.py deleted file mode 100644 index 965d916..0000000 --- a/highway_rnn_cell.py +++ /dev/null @@ -1,78 +0,0 @@ -import tensorflow as tf -from multiplicative_integration import multiplicative_integration -from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import RNNCell - - -def ntanh(_x, name="normalizing_tanh"): - """ - Inspired by self normalizing networks paper, we adjust scale on tanh - function to encourage mean of 0 and variance of 1 in activations - - From comments on reddit, the normalizing tanh function: - 1.5925374197228312 - """ - scale = 1.5925374197228312 - return scale*tf.nn.tanh(_x, name=name) - - -class RHNCell(RNNCell): - """ - Recurrent Highway Cell - Reference: https://arxiv.org/abs/1607.03474 - """ - - def __init__(self, num_units, depth=2, forget_bias=-2.0, activation=ntanh): - - """We initialize forget bias to negative two so that highway layers don't activate - """ - - - assert activation.__name__ == "ntanh" - self._num_units = num_units - self._in_size = num_units - self._depth = depth - self._forget_bias = forget_bias - self._activation = activation - - tf.logging.info("""Building Recurrent Highway Cell with {} Activation of depth {} - and forget bias of {}""".format( - self._activation.__name__, self._depth, self._forget_bias)) - - - @property - def input_size(self): - return self._in_size - - @property - def output_size(self): - return self._num_units - - @property - def state_size(self): - return self._num_units - - def __call__(self, inputs, state, timestep=0, scope=None): - current_state = state - - for i in range(self._depth): - with tf.variable_scope('h_'+str(i)): - if i == 0: - h = self._activation( - multiplicative_integration([inputs,current_state], self._num_units)) - else: - h = tf.layers.dense(current_state, self._num_units, self._activation, - bias_initializer=tf.zeros_initializer()) - - with tf.variable_scope('gate_'+str(i)): - if i == 0: - t = tf.sigmoid( - multiplicative_integration([inputs,current_state], self._num_units, - self._forget_bias)) - - else: - t = tf.layers.dense(current_state, self._num_units, tf.sigmoid, - bias_initializer=tf.constant_initializer(self._forget_bias)) - - current_state = (h - current_state)* t + current_state - - return current_state, current_state \ No newline at end of file diff --git a/model.py b/model.py index 963f3b8..c3ad13d 100644 --- a/model.py +++ b/model.py @@ -1,19 +1,21 @@ import tensorflow as tf +from tensorflow.contrib.rnn import GRUCell + from config import * -def Discriminator_RNN(inputs, charmap_len, seq_len, reuse=False, rnn_cell=None): +def Discriminator_GRU(inputs, charmap_len, seq_len, reuse=False): with tf.variable_scope("Discriminator", reuse=reuse): num_neurons = FLAGS.DISC_STATE_SIZE weight = tf.get_variable("embedding", shape=[charmap_len, num_neurons], - initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1)) + initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1)) - # backwards compatibility - if FLAGS.DISC_RNN_LAYERS == 1: - cell = rnn_cell(num_neurons) + # backwards compatability + if FLAGS.DISC_GRU_LAYERS == 1: + cell = GRUCell(num_neurons) else: - cell = tf.contrib.rnn.MultiRNNCell([rnn_cell(num_neurons) for _ in range(FLAGS.DISC_RNN_LAYERS)], state_is_tuple=True) + cell = tf.contrib.rnn.MultiRNNCell([GRUCell(num_neurons) for _ in range(FLAGS.DISC_GRU_LAYERS)], state_is_tuple=True) flat_inputs = tf.reshape(inputs, [-1, charmap_len]) @@ -41,14 +43,14 @@ def Discriminator_RNN(inputs, charmap_len, seq_len, reuse=False, rnn_cell=None): return prediction -def Generator_RNN_CL_VL_TH(n_samples, charmap_len, seq_len=None, gt=None, rnn_cell=None): +def Generator_GRU_CL_VL_TH(n_samples, charmap_len, seq_len=None, gt=None): with tf.variable_scope("Generator"): noise, noise_shape = get_noise() num_neurons = FLAGS.GEN_STATE_SIZE cells = [] - for l in range(FLAGS.GEN_RNN_LAYERS): - cells.append(rnn_cell(num_neurons)) + for l in range(FLAGS.GEN_GRU_LAYERS): + cells.append(GRUCell(num_neurons)) # this is separate to decouple train and test train_initial_states = create_initial_states(noise) @@ -66,14 +68,15 @@ def Generator_RNN_CL_VL_TH(n_samples, charmap_len, seq_len=None, gt=None, rnn_ce seq_len = tf.placeholder(tf.int32, None, name="ground_truth_sequence_length") if gt is not None: #if no GT, we are training - train_pred = get_train_op(cells, char_input, charmap_len, embedding, gt, n_samples, num_neurons, seq_len, sm_bias, sm_weight, train_initial_states) + train_pred = get_train_op(cells, char_input, charmap_len, embedding, gt, n_samples, num_neurons, seq_len, + sm_bias, sm_weight, train_initial_states) inference_op = get_inference_op(cells, char_input, embedding, seq_len, sm_bias, sm_weight, inference_initial_states, - num_neurons, - charmap_len, reuse=True) + num_neurons, + charmap_len, reuse=True) else: inference_op = get_inference_op(cells, char_input, embedding, seq_len, sm_bias, sm_weight, inference_initial_states, - num_neurons, - charmap_len, reuse=False) + num_neurons, + charmap_len, reuse=False) train_pred = None return train_pred, inference_op @@ -81,22 +84,25 @@ def Generator_RNN_CL_VL_TH(n_samples, charmap_len, seq_len=None, gt=None, rnn_ce def create_initial_states(noise): states = [] - for l in range(FLAGS.GEN_RNN_LAYERS): + for l in range(FLAGS.GEN_GRU_LAYERS): states.append(noise) return states -def get_train_op(cells, char_input, charmap_len, embedding, gt, n_samples, num_neurons, seq_len, sm_bias, sm_weight, states): +def get_train_op(cells, char_input, charmap_len, embedding, gt, n_samples, num_neurons, seq_len, sm_bias, sm_weight, + states): gt_embedding = tf.reshape(gt, [n_samples * seq_len, charmap_len]) - gt_RNN_input = tf.matmul(gt_embedding, embedding) - gt_RNN_input = tf.reshape(gt_RNN_input, [n_samples, seq_len, num_neurons])[:, :-1] - gt_sentence_input = tf.concat([char_input, gt_RNN_input], axis=1) - RNN_output, _ = rnn_step_prediction(cells, charmap_len, gt_sentence_input, num_neurons, seq_len, sm_bias, sm_weight, states) + gt_GRU_input = tf.matmul(gt_embedding, embedding) + gt_GRU_input = tf.reshape(gt_GRU_input, [n_samples, seq_len, num_neurons])[:, :-1] + gt_sentence_input = tf.concat([char_input, gt_GRU_input], axis=1) + GRU_output, _ = rnn_step_prediction(cells, charmap_len, gt_sentence_input, num_neurons, seq_len, sm_bias, + sm_weight, + states) train_pred = [] # TODO: optimize loop for i in range(seq_len): train_pred.append( - tf.concat([tf.zeros([BATCH_SIZE, seq_len - i - 1, charmap_len]), gt[:, :i], RNN_output[:, i:i + 1, :]], + tf.concat([tf.zeros([BATCH_SIZE, seq_len - i - 1, charmap_len]), gt[:, :i], GRU_output[:, i:i + 1, :]], axis=1)) train_pred = tf.reshape(train_pred, [BATCH_SIZE*seq_len, seq_len, charmap_len]) @@ -111,14 +117,14 @@ def get_train_op(cells, char_input, charmap_len, embedding, gt, n_samples, num_n def rnn_step_prediction(cells, charmap_len, gt_sentence_input, num_neurons, seq_len, sm_bias, sm_weight, states, reuse=False): with tf.variable_scope("rnn", reuse=reuse): - RNN_output = gt_sentence_input - for l in range(FLAGS.GEN_RNN_LAYERS): - RNN_output, states[l] = tf.nn.dynamic_rnn(cells[l], RNN_output, dtype=tf.float32, - initial_state=states[l], scope="layer_%d" % (l + 1)) - RNN_output = tf.reshape(RNN_output, [-1, num_neurons]) - RNN_output = tf.nn.softmax(tf.matmul(RNN_output, sm_weight) + sm_bias) - RNN_output = tf.reshape(RNN_output, [BATCH_SIZE, -1, charmap_len]) - return RNN_output, states + GRU_output = gt_sentence_input + for l in range(FLAGS.GEN_GRU_LAYERS): + GRU_output, states[l] = tf.nn.dynamic_rnn(cells[l], GRU_output, dtype=tf.float32, + initial_state=states[l], scope="layer_%d" % (l + 1)) + GRU_output = tf.reshape(GRU_output, [-1, num_neurons]) + GRU_output = tf.nn.softmax(tf.matmul(GRU_output, sm_weight) + sm_bias) + GRU_output = tf.reshape(GRU_output, [BATCH_SIZE, -1, charmap_len]) + return GRU_output, states def get_inference_op(cells, char_input, embedding, seq_len, sm_bias, sm_weight, states, num_neurons, charmap_len, @@ -139,11 +145,11 @@ def get_inference_op(cells, char_input, embedding, seq_len, sm_bias, sm_weight, generators = { - "Generator_RNN_CL_VL_TH": Generator_RNN_CL_VL_TH, + "Generator_GRU_CL_VL_TH": Generator_GRU_CL_VL_TH, } discriminators = { - "Discriminator_RNN": Discriminator_RNN, + "Discriminator_GRU": Discriminator_GRU, } def get_noise(): diff --git a/multiplicative_integration.py b/multiplicative_integration.py deleted file mode 100644 index d96eb45..0000000 --- a/multiplicative_integration.py +++ /dev/null @@ -1,42 +0,0 @@ -import tensorflow as tf - -def multiplicative_integration(list_of_inputs, output_size, initial_bias_value=0.0, - weights_already_calculated=False, scope=None): - """Multiplicative Integration from https://arxiv.org/abs/1606.06630 - - expects len(2) for list of inputs and will perform integrative multiplication - - weights_already_calculated will treat the list of inputs as Wx and Uz and is useful for batch normed inputs - """ - with tf.variable_scope(scope or 'double_inputs_multiple_integration'): - if len(list_of_inputs) != 2: - raise ValueError('list of inputs must be 2, you have: {}'.format(len(list_of_inputs))) - - if weights_already_calculated: - Wx = list_of_inputs[0] - Uz = list_of_inputs[1] - - else: - with tf.variable_scope('Calculate_Wx_mulint'): - Wx = tf.layers.dense(list_of_inputs[0], output_size, use_bias=False) - - with tf.variable_scope("Calculate_Uz_mulint"): - Uz = tf.layers.dense(list_of_inputs[1], output_size, use_bias=False) - - with tf.variable_scope("multiplicative_integration"): - alpha = tf.get_variable('mulint_alpha', [output_size], - initializer = tf.truncated_normal_initializer(mean=1.0, stddev=0.1)) - - # For efficiency, we retrieve both beta parameters via tf split - beta1, beta2 = tf.split( - tf.get_variable('mulint_params_betas', [output_size*2], - initializer = tf.truncated_normal_initializer(mean=0.5, stddev=0.1)), - num_or_size_splits=2, - axis=0) - - original_bias = tf.get_variable('mulint_original_bias', [output_size], - initializer = tf.truncated_normal_initializer(mean=initial_bias_value, stddev=0.1)) - - final_output = alpha*Wx*Uz + beta1*Uz + beta2*Wx + original_bias - - return final_output \ No newline at end of file diff --git a/objective.py b/objective.py index 0504850..5ec1e4e 100644 --- a/objective.py +++ b/objective.py @@ -1,22 +1,18 @@ import tensorflow as tf from config import FLAGS, BATCH_SIZE, LAMBDA from model import get_generator, get_discriminator, params_with_name -from fisher_gan_objective import FisherGAN -def get_optimization_ops(disc_cost, gen_cost, global_step, gen_lr, disc_lr): + +def get_optimization_ops(disc_cost, gen_cost, global_step): gen_params = params_with_name('Generator') disc_params = params_with_name('Discriminator') print("Generator Params: %s" % gen_params) print("Disc Params: %s" % disc_params) - gen_train_op = tf.train.AdamOptimizer(learning_rate=gen_lr, beta1=0.5, beta2=0.9).minimize(gen_cost, - var_list=gen_params, - global_step=global_step) - - # Due to TTUR paper, the learning rate of the disc should be different than generator - # https://arxiv.org/abs/1706.08500 - # Therefore, we double disc learning rate - disc_train_op = tf.train.AdamOptimizer(learning_rate=disc_lr, beta1=0.5, beta2=0.9).minimize(disc_cost, - var_list=disc_params) + gen_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(gen_cost, + var_list=gen_params, + global_step=global_step) + disc_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost, + var_list=disc_params) return disc_train_op, gen_train_op @@ -37,38 +33,23 @@ def get_substrings_from_gt(real_inputs, seq_length, charmap_len): return all_sub_strings -def define_objective(charmap, real_inputs_discrete, seq_length, gan_type="wgan", rnn_cell=None): - assert gan_type in ["wgan", "fgan", "cgan"] - assert rnn_cell - other_ops = {} +def define_objective(charmap, real_inputs_discrete, seq_length): real_inputs = tf.one_hot(real_inputs_discrete, len(charmap)) Generator = get_generator(FLAGS.GENERATOR_MODEL) Discriminator = get_discriminator(FLAGS.DISCRIMINATOR_MODEL) - train_pred, inference_op = Generator(BATCH_SIZE, len(charmap), seq_len=seq_length, gt=real_inputs, rnn_cell=rnn_cell) + train_pred, inference_op = Generator(BATCH_SIZE, len(charmap), seq_len=seq_length, gt=real_inputs) real_inputs_substrings = get_substrings_from_gt(real_inputs, seq_length, len(charmap)) - disc_real = Discriminator(real_inputs_substrings, len(charmap), seq_length, reuse=False, - rnn_cell=rnn_cell) - disc_fake = Discriminator(train_pred, len(charmap), seq_length, reuse=True, - rnn_cell=rnn_cell) - disc_on_inference = Discriminator(inference_op, len(charmap), seq_length, reuse=True, - rnn_cell=rnn_cell) - - - if gan_type == "wgan": - disc_cost, gen_cost = loss_d_g(disc_fake, disc_real, train_pred, real_inputs_substrings, charmap, seq_length, Discriminator, rnn_cell) - elif gan_type == "fgan": - fgan = FisherGAN() - disc_cost, gen_cost = fgan.loss_d_g(disc_fake, disc_real, train_pred, real_inputs_substrings, charmap, seq_length, Discriminator) - other_ops["alpha_optimizer_op"] = fgan.alpha_optimizer_op - else: - raise NotImplementedError("Cramer GAN not implemented") + disc_real = Discriminator(real_inputs_substrings, len(charmap), seq_length, reuse=False) + disc_fake = Discriminator(train_pred, len(charmap), seq_length, reuse=True) + disc_on_inference = Discriminator(inference_op, len(charmap), seq_length, reuse=True) - return disc_cost, gen_cost, train_pred, disc_fake, disc_real, disc_on_inference, inference_op, other_ops + disc_cost, gen_cost = loss_d_g(disc_fake, disc_real, train_pred, real_inputs_substrings, charmap, seq_length, Discriminator) + return disc_cost, gen_cost, train_pred, disc_fake, disc_real, disc_on_inference, inference_op -def loss_d_g(disc_fake, disc_real, fake_inputs, real_inputs, charmap, seq_length, Discriminator, rnn_cell): +def loss_d_g(disc_fake, disc_real, fake_inputs, real_inputs, charmap, seq_length, Discriminator): disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real) gen_cost = -tf.reduce_mean(disc_fake) @@ -80,9 +61,9 @@ def loss_d_g(disc_fake, disc_real, fake_inputs, real_inputs, charmap, seq_length ) differences = fake_inputs - real_inputs interpolates = real_inputs + (alpha * differences) - gradients = tf.gradients(Discriminator(interpolates, len(charmap), seq_length, reuse=True, rnn_cell=rnn_cell), [interpolates])[0] + gradients = tf.gradients(Discriminator(interpolates, len(charmap), seq_length, reuse=True), [interpolates])[0] slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])) gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) disc_cost += LAMBDA * gradient_penalty - return disc_cost, gen_cost + return disc_cost, gen_cost \ No newline at end of file diff --git a/single_length_train.py b/single_length_train.py index 2ad1b7f..dd4fd4d 100644 --- a/single_length_train.py +++ b/single_length_train.py @@ -7,7 +7,6 @@ from objective import get_optimization_ops, define_objective from summaries import define_summaries, \ log_samples -import numpy as np sys.path.append(os.getcwd()) @@ -22,42 +21,33 @@ def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length) if len(DATA_DIR) == 0: raise Exception('Please specify path to data directory in single_length_train.py!') - lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False, n_examples=FLAGS.MAX_N_EXAMPLES) + lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False, b_inv_charmap=False, + n_examples=FLAGS.MAX_N_EXAMPLES) real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length]) global_step = tf.Variable(0, trainable=False) - disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op, other_ops = define_objective(charmap,real_inputs_discrete, seq_length, - gan_type=FLAGS.GAN_TYPE, rnn_cell=RNN_CELL) - - + disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(charmap, + real_inputs_discrete, + seq_length) merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length) - disc_train_op, gen_train_op = get_optimization_ops( - disc_cost, gen_cost, global_step, FLAGS.DISC_LR, FLAGS.GEN_LR) + disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step) saver = tf.train.Saver(tf.trainable_variables()) - # Use JIT XLA compilation to speed up calculations - config=tf.ConfigProto( - log_device_placement=False, allow_soft_placement=True) - config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 - - with tf.Session(config=config) as session: + with tf.Session() as session: session.run(tf.initialize_all_variables()) if not is_first: print("Loading previous checkpoint...") internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length) model_and_data_serialization.optimistic_restore(session, - latest_checkpoint(internal_checkpoint_dir, "checkpoint")) + latest_checkpoint(internal_checkpoint_dir, "checkpoint")) restore_config.set_restore_dir( load_from_curr_session=True) # global param, always load from curr session after finishing the first seq gen = inf_train_gen(lines, charmap) - _gen_cost_list = [] - _disc_cost_list = [] - _step_time_list = [] for iteration in range(iterations): start_time = time.time() @@ -65,71 +55,44 @@ def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length) # Train critic for i in range(CRITIC_ITERS): _data = next(gen) - - if FLAGS.GAN_TYPE.lower() == "fgan": - _disc_cost, _, real_scores, _ = session.run( - [disc_cost, disc_train_op, disc_real, - other_ops["alpha_optimizer_op"]], - feed_dict={real_inputs_discrete: _data} - ) - - elif FLAGS.GAN_TYPE.lower() == "wgan": - _disc_cost, _, real_scores = session.run( + _disc_cost, _, real_scores = session.run( [disc_cost, disc_train_op, disc_real], feed_dict={real_inputs_discrete: _data} - ) - - else: - raise ValueError( - "Appropriate gan type not selected: {}".format(FLAGS.GAN_TYPE)) - _disc_cost_list.append(_disc_cost) - - + ) # Train G for i in range(GEN_ITERS): _data = next(gen) - # in Fisher GAN, paper measures convergence by gen_cost instead of disc_cost - # To measure convergence, gen_cost should start at a positive number and decrease - # to zero. The lower, the better. - _gen_cost, _ = session.run([gen_cost, gen_train_op], feed_dict={real_inputs_discrete: _data}) - _gen_cost_list.append(_gen_cost) - - _step_time_list.append(time.time() - start_time) - - if FLAGS.PRINT_EVERY_STEP: - print("iteration %s/%s"%(iteration, iterations)) - print("disc cost {}"%_disc_cost) - print("gen cost {}".format(_gen_cost)) - print("total step time {}".format(time.time() - start_time)) + _ = session.run(gen_train_op, feed_dict={real_inputs_discrete: _data}) + print("iteration %s/%s"%(iteration, iterations)) + print("disc cost %f"%_disc_cost) # Summaries - if iteration % FLAGS.PRINT_ITERATION == FLAGS.PRINT_ITERATION-1: + if iteration % 100 == 99: _data = next(gen) summary_str = session.run( merged, feed_dict={real_inputs_discrete: _data} ) - tf.logging.warn("iteration %s/%s"%(iteration, iterations)) - tf.logging.warn("disc cost {} gen cost {} average step time {}".format( - np.mean(_disc_cost_list), np.mean(_gen_cost_list), np.mean(_step_time_list))) - _gen_cost_list, _disc_cost_list, _step_time_list = [], [], [] - train_writer.add_summary(summary_str, global_step=iteration) - fake_samples, samples_real_probabilites, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap, fake_inputs, disc_fake, gen, real_inputs_discrete,feed_gt=True) + fake_samples, samples_real_probabilites, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap, + fake_inputs, + disc_fake, + gen, + real_inputs_discrete, + feed_gt=True) log_samples(fake_samples, fake_scores, iteration, seq_length, "train") log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length, "gt") - test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, - inv_charmap, - inference_op, - disc_on_inference, - gen, - real_inputs_discrete, - feed_gt=False) + test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(session, inv_charmap, + inference_op, + disc_on_inference, + gen, + real_inputs_discrete, + feed_gt=False) # disc_on_inference, inference_op log_samples(test_samples, fake_scores, iteration, seq_length, "test")