Merge pull request #3 from NickShahML/master

Fisher GAN + Recurrent Highway Network + TTUR

Showing 9 changed files with 493 additions and 90 deletions.
.gitignore (new file)
@@ -0,0 +1,105 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Repo data/logs/models
data/
logs/
pkl/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
Fisher GAN objective (new file)
@@ -0,0 +1,70 @@
import tensorflow as tf


class FisherGAN(object):
    """Implements the Fisher GAN objective functions.

    Modeled on https://github.com/ethancaballero/FisherGAN/blob/master/main.py;
    variable names are kept the same where possible.

    To measure convergence, gen_cost should start at a positive number and
    decrease toward zero; the lower, the better. Warning: at the very
    beginning of training you may see gen_cost rise. Wait at least 5000
    iterations and gen_cost should start to fall. This phenomenon occurs
    because the critic first finds the appropriate Wasserstein distance and
    the generator then adjusts to it.

    A critic iteration of 1 is recommended when using Fisher GAN.
    """

    def __init__(self, rho=1e-5):
        tf.logging.warn("USING FISHER GAN OBJECTIVE FUNCTION")
        self._rho = rho
        # Initialize alpha (called lambda in the paper) to zero. Throughout
        # training, alpha is updated by an independent SGD optimizer. We say
        # "alpha" rather than "lambda" to match the reference implementation.
        self._alpha = tf.get_variable("fisher_alpha", [],
                                      initializer=tf.zeros_initializer)

    def _optimize_alpha(self, disc_cost):
        """Optimize alpha via plain SGD with learning rate rho.

        This update should run every time the discriminator is optimized,
        because the same batch is used.

        Crucial point: we minimize the NEGATIVE disc_cost with respect to
        alpha. This enforces the Lipschitz constraint. If we minimized the
        positive disc_cost instead, the discriminator loss would drop to a
        very low negative number and the constraint would not hold.
        """
        # Gradient-descend alpha on the negative disc_cost.
        self._alpha_optimizer = tf.train.GradientDescentOptimizer(self._rho)
        self.alpha_optimizer_op = self._alpha_optimizer.minimize(
            -disc_cost, var_list=[self._alpha])

    def loss_d_g(self, disc_fake, disc_real, fake_inputs, real_inputs,
                 charmap, seq_length, Discriminator):
        # As in WGAN, the generator cost is unchanged in Fisher GAN.
        gen_cost = -tf.reduce_mean(disc_fake)

        # Quantities for the Fisher (Lipschitz-like) constraint.
        # E_P and E_Q denote expectations over real and fake data.
        E_Q_f = tf.reduce_mean(disc_fake)
        E_P_f = tf.reduce_mean(disc_real)
        E_Q_f2 = tf.reduce_mean(disc_fake**2)
        E_P_f2 = tf.reduce_mean(disc_real**2)

        constraint = 1 - (0.5 * E_P_f2 + 0.5 * E_Q_f2)

        # See Equation (9) in the Fisher GAN paper. The original
        # implementation backpropagates with mone (minus one); in TensorFlow
        # we simply multiply the objective by minus one.
        disc_cost = -1.0 * (E_P_f - E_Q_f
                            + self._alpha * constraint
                            - self._rho / 2 * constraint**2)

        # Build the optimization op for alpha on this same cost.
        self._optimize_alpha(disc_cost)

        return disc_cost, gen_cost
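For reference (my transcription, not part of the diff): disc_cost above is the negative of the augmented Lagrangian in Equation (9) of the Fisher GAN paper, with alpha playing the role of the multiplier \lambda and rho the quadratic penalty weight:

\mathcal{L} = \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{\tilde{x} \sim Q}[f(\tilde{x})] + \lambda\,(1 - \hat{\Omega}) - \frac{\rho}{2}\,(1 - \hat{\Omega})^2, \qquad \hat{\Omega} = \tfrac{1}{2}\,\mathbb{E}_{P}[f^2] + \tfrac{1}{2}\,\mathbb{E}_{Q}[f^2]

The critic maximizes \mathcal{L} (hence the sign flip in the code), the generator minimizes -\mathbb{E}_{Q}[f] as in WGAN, and alpha minimizes \mathcal{L} itself, giving the SGD step \lambda \leftarrow \lambda - \rho\,(1 - \hat{\Omega}) implemented by _optimize_alpha.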
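A minimal sketch of how the objective might be wired into a TF1 training loop. This is not part of the commit: the one-parameter toy critic, the learning rates, and the Adam betas are all assumptions for illustration. The two different rates follow the TTUR idea named in the PR title (each network gets its own time scale; the generator would get an analogous optimizer at a different rate), and the alpha SGD op is grouped with the critic step because _optimize_alpha must run on the same batch as the critic update.

import tensorflow as tf

# Toy stand-ins so the graph builds: a one-parameter "critic" scoring random
# batches. A real model supplies disc_real/disc_fake and the other arguments.
w = tf.get_variable("toy_critic_w", [], initializer=tf.ones_initializer())
disc_real = w * tf.random_normal([64])
disc_fake = w * tf.random_normal([64])

fisher = FisherGAN(rho=1e-5)
# The remaining arguments are unused by the objective itself; pass the real
# model's values in actual code (None here only for the sketch).
disc_cost, gen_cost = fisher.loss_d_g(disc_fake, disc_real,
                                      None, None, None, None, None)

# TTUR-style update for the critic (this learning rate is an assumption).
disc_train_op = tf.train.AdamOptimizer(3e-4, 0.5, 0.9).minimize(
    disc_cost, var_list=[w])
# Group the alpha (lambda) SGD step with the critic step, since both must
# see the same batch.
disc_train_op = tf.group(disc_train_op, fisher.alpha_optimizer_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        _, d, g = sess.run([disc_train_op, disc_cost, gen_cost])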
Recurrent Highway Network cell (new file)
@@ -0,0 +1,78 @@
import tensorflow as tf
from multiplicative_integration import multiplicative_integration
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import RNNCell


def ntanh(_x, name="normalizing_tanh"):
    """Normalizing tanh activation.

    Inspired by the self-normalizing networks paper, we rescale tanh to
    encourage activations with mean 0 and variance 1. The scale constant
    1.5925374197228312 comes from discussion of the paper on Reddit.
    """
    scale = 1.5925374197228312
    return scale * tf.nn.tanh(_x, name=name)
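As a quick sanity check (mine, not part of the commit): the constant is approximately 1 / sqrt(E[tanh(x)^2]) for x ~ N(0, 1), so on standard-normal input the activation has near-zero mean and near-unit variance:

import numpy as np

# Monte Carlo check of the normalizing constant (illustration only).
x = np.random.randn(1000000)
y = 1.5925374197228312 * np.tanh(x)
print(y.mean(), y.var())  # approximately 0.0 and 1.0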


class RHNCell(RNNCell):
    """Recurrent Highway Cell.

    Reference: https://arxiv.org/abs/1607.03474
    """

    def __init__(self, num_units, depth=2, forget_bias=-2.0, activation=ntanh):
        """The forget bias is initialized to -2 so that the transform gates
        start mostly closed and the highway layers initially carry the state
        through rather than activating.
        """
        assert activation.__name__ == "ntanh"
        self._num_units = num_units
        self._in_size = num_units
        self._depth = depth
        self._forget_bias = forget_bias
        self._activation = activation

        tf.logging.info(
            "Building Recurrent Highway Cell with {} activation, depth {}, "
            "and forget bias {}".format(
                self._activation.__name__, self._depth, self._forget_bias))

    @property
    def input_size(self):
        return self._in_size

    @property
    def output_size(self):
        return self._num_units

    @property
    def state_size(self):
        return self._num_units

    def __call__(self, inputs, state, timestep=0, scope=None):
        current_state = state

        for i in range(self._depth):
            # Candidate h: the input enters only at the first micro-step.
            with tf.variable_scope('h_' + str(i)):
                if i == 0:
                    h = self._activation(
                        multiplicative_integration(
                            [inputs, current_state], self._num_units))
                else:
                    h = tf.layers.dense(
                        current_state, self._num_units, self._activation,
                        bias_initializer=tf.zeros_initializer())

            # Transform gate t; the carry gate is coupled as (1 - t).
            with tf.variable_scope('gate_' + str(i)):
                if i == 0:
                    t = tf.sigmoid(
                        multiplicative_integration(
                            [inputs, current_state], self._num_units,
                            self._forget_bias))
                else:
                    t = tf.layers.dense(
                        current_state, self._num_units, tf.sigmoid,
                        bias_initializer=tf.constant_initializer(
                            self._forget_bias))

            # Highway update: s <- t * h + (1 - t) * s.
            current_state = (h - current_state) * t + current_state

        return current_state, current_state
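In equations (my transcription of the loop above): at each micro-step l = 1, ..., depth, with the input x entering only at l = 1,

s_l = t_l \odot h_l + (1 - t_l) \odot s_{l-1}

where h_l is the ntanh candidate, t_l = \sigma(\cdot + b_T) is the transform gate with bias b_T = -2, and the carry gate is coupled as 1 - t_l (the coupled-gates variant described in the RHN paper). The code writes this update equivalently as (h - s) * t + s.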
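A minimal, hypothetical usage sketch (not from the commit): dropping the cell into tf.nn.dynamic_rnn. The batch shape and hyperparameters are assumptions, and multiplicative_integration must be importable with the signature used in the file above.

import numpy as np
import tensorflow as tf

# Toy batch: 4 sequences, 10 steps, 32 features (all shapes assumed).
inputs = tf.constant(np.random.randn(4, 10, 32).astype(np.float32))

cell = RHNCell(num_units=32, depth=2, forget_bias=-2.0)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs).shape)  # (4, 10, 32)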