Commit 59a3b2c

updated adv-losses

Former-commit-id: c05c5f2
ZhitingHu committed Oct 15, 2017
1 parent 71b16cf commit 59a3b2c
Showing 3 changed files with 82 additions and 57 deletions.
73 changes: 51 additions & 22 deletions txtgen/losses/adv_losses.py
@@ -1,29 +1,58 @@
 #
 """
 Adversarial losses.
 """
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import tensorflow as tf
-from txtgen.modules.encoders.rnn_encoders import ForwardRNNEncoder
-from txtgen.core import optimization as opt
-from txtgen.context import is_train
-
-def adversarial_losses(true_data,
-                       generated_data,
-                       discriminator
-                       ):
-    true_cost = discriminator_result(discriminator, true_data, 1, False)
-    fake_cost = discriminator_result(discriminator, generated_data, 0, True)
-    disc_cost = tf.reduce_sum(true_cost + fake_cost) #divide by batch size?
-    disc_train_op, disc_global_step = opt.get_train_op(disc_cost)
-    generate_cost = tf.reduce_sum(discriminator_result(discriminator, generated_data, 1, True))
-    return disc_train_op, disc_global_step, generate_cost, disc_cost
-
-def discriminator_result(discriminator, data, label, reuse=False):
-    """Loss for both generated data and true data
-    """
-    _, state = discriminator(data)
-    with tf.variable_scope('discriminator', reuse=reuse):
-        logits = tf.layers.dense(state[0], 1)
-        cost = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * label)
-    return cost
+
+def binary_adversarial_losses(real_data,
+                              fake_data,
+                              discriminator_fn,
+                              mode="max_real"):
+    """Computes adversarial loss of the real/fake binary classification game.
+
+    Args:
+        real_data (Tensor or array): Real data of shape
+            `[num_real_examples, ...]`.
+        fake_data (Tensor or array): Fake data of shape
+            `[num_fake_examples, ...]`. `num_real_examples` does not
+            necessarily equal `num_fake_examples`.
+        discriminator_fn: A callable that takes data (e.g., :attr:`real_data`
+            and :attr:`fake_data`) and returns the logits of being real.
+        mode (str): Mode of the generator loss. Either `max_real` or
+            `min_fake`. If `max_real` (default), minimizing the generator
+            loss maximizes the probability of fake data being classified
+            as real. If `min_fake`, minimizing the generator loss minimizes
+            the probability of fake data being classified as fake.
+
+    Returns:
+        (scalar Tensor, scalar Tensor): (generator_loss, discriminator_loss).
+    """
+    real_logits = discriminator_fn(real_data)
+    real_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+        logits=real_logits, labels=tf.ones_like(real_logits))
+    num_real_data = tf.shape(real_loss)[0]
+    ave_real_loss = tf.reduce_sum(real_loss) / tf.to_float(num_real_data)
+
+    fake_logits = discriminator_fn(fake_data)
+    fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+        logits=fake_logits, labels=tf.zeros_like(fake_logits))
+    num_fake_data = tf.shape(fake_loss)[0]
+    ave_fake_loss = tf.reduce_sum(fake_loss) / tf.to_float(num_fake_data)
+
+    disc_loss = ave_real_loss + ave_fake_loss
+
+    if mode == "min_fake":
+        gen_loss = - ave_fake_loss
+    elif mode == "max_real":
+        fake_loss_ = tf.nn.sigmoid_cross_entropy_with_logits(
+            logits=fake_logits, labels=tf.ones_like(fake_logits))
+        gen_loss = tf.reduce_sum(fake_loss_) / tf.to_float(num_fake_data)
+    else:
+        raise ValueError("Unknown mode: %s. Only 'min_fake' and 'max_real' "
+                         "are allowed." % mode)
+
+    return gen_loss, disc_loss

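A minimal usage sketch of the new function (not part of the commit): the `discriminator_fn` below is a hypothetical two-layer network added for illustration; any callable mapping data to real/fake logits works. Since `binary_adversarial_losses` calls the discriminator twice (once on real data, once on fake), the sketch reuses variables on the second call so both batches are scored by the same network. Assumes TF 1.x, matching the code above.

    import tensorflow as tf
    from txtgen.losses.adv_losses import binary_adversarial_losses

    def discriminator_fn(data):
        # Hypothetical discriminator: reuse variables on the second
        # call so real and fake data share one set of weights.
        reuse = len(tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope="disc")) > 0
        with tf.variable_scope("disc", reuse=reuse):
            hidden = tf.layers.dense(data, 32, activation=tf.nn.relu)
            return tf.squeeze(tf.layers.dense(hidden, 1), -1)  # [batch]

    real_data = tf.random_normal([16, 64])
    fake_data = tf.random_normal([16, 64])
    gen_loss, disc_loss = binary_adversarial_losses(
        real_data, fake_data, discriminator_fn)

The two returned scalars would then feed two separate train ops, each restricted to the generator's or the discriminator's variables respectively.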
54 changes: 25 additions & 29 deletions txtgen/losses/adv_losses_test.py
@@ -1,42 +1,38 @@
+#
+"""
+Tests adversarial loss related functions.
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

 import tensorflow as tf
-import numpy as np
-from txtgen import context
-from txtgen.modules.encoders.rnn_encoders import ForwardRNNEncoder

-from .adv_losses import adversarial_losses
+from txtgen.losses.adv_losses import binary_adversarial_losses

 class AdvLossesTest(tf.test.TestCase):
+    """Tests adversarial losses.
+    """

-    def test_adv_losses(self):
-        vocab_size = 4
-        max_time = 8
+    def test_binary_adversarial_losses(self):
+        """Tests :meth:`~txtgen.losses.adv_losses.binary_adversarial_losses`.
+        """
         batch_size = 16
-        #true_inputs = tf.random_uniform([batch_size, max_time],
-        #                                maxval=vocab_size,
-        #                                dtype=tf.int32)
-        true_inputs = np.zeros([batch_size, max_time], dtype="int32")
-        #generate_inputs = tf.random_uniform([batch_size, max_time + 3],
-        #                                    maxval=vocab_size,
-        #                                    dtype=tf.int32)
-        generate_inputs = np.ones([batch_size, max_time], dtype="int32")
-        true_inputs_ph = tf.placeholder(tf.int32, [batch_size, max_time])
-        generate_inputs_ph = tf.placeholder(tf.int32, [batch_size, max_time])
-        embedding = None
-        discriminator = ForwardRNNEncoder(embedding=embedding, vocab_size=vocab_size)
-        disc_train_op, disc_global_step, generator_loss, disc_loss = adversarial_losses(true_inputs_ph, generate_inputs_ph, discriminator)
+        data_dim = 64
+        real_data = tf.zeros([batch_size, data_dim], dtype=tf.float32)
+        fake_data = tf.ones([batch_size, data_dim], dtype=tf.float32)
+        const_logits = tf.zeros([batch_size], dtype=tf.float32)
+        # Use a dumb discriminator that always outputs logits=0.
+        gen_loss, disc_loss = binary_adversarial_losses(
+            real_data, fake_data, lambda x: const_logits)
+        gen_loss_2, disc_loss_2 = binary_adversarial_losses(
+            real_data, fake_data, lambda x: const_logits, mode="min_fake")

         with self.test_session() as sess:
-            sess.run(tf.global_variables_initializer())
-            sess.run(tf.local_variables_initializer())
-            for i in range(10000):
-                g_loss, d_loss, _, _ = sess.run([generator_loss, disc_loss, disc_global_step, disc_train_op], feed_dict={context.is_train(): True, true_inputs_ph: true_inputs, generate_inputs_ph: generate_inputs})
-                print("generator_loss", g_loss)
-                print("disc_loss", d_loss)
-            #print("true inputs", true_inputs)
+            gen_loss_, disc_loss_ = sess.run([gen_loss, disc_loss])
+            gen_loss_2_, disc_loss_2_ = sess.run([gen_loss_2, disc_loss_2])
+            self.assertAlmostEqual(gen_loss_, -gen_loss_2_)
+            self.assertAlmostEqual(disc_loss_, disc_loss_2_)


 if __name__ == "__main__":
     tf.test.main()
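Why the two assertions hold: with the discriminator's logits pinned to 0, sigmoid(0) = 0.5, so every sigmoid cross-entropy term equals log 2 regardless of the label. A quick arithmetic check (plain NumPy, outside the graph):

    import numpy as np

    log2 = np.log(2.0)         # sigmoid CE at logit 0, for either label
    disc_loss = log2 + log2    # ave_real_loss + ave_fake_loss, both modes
    gen_loss_max_real = log2   # fake data labeled as real
    gen_loss_min_fake = -log2  # negated fake-as-fake loss
    assert np.isclose(gen_loss_max_real, -gen_loss_min_fake)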
12 changes: 6 additions & 6 deletions txtgen/losses/rl_losses.py
@@ -8,7 +8,6 @@
 from __future__ import print_function

 import tensorflow as tf
-from tensorflow.python.ops import rnn # pylint: disable=E0611

 from txtgen.losses.mle_losses import _mask_sequences

@@ -17,7 +16,7 @@ def reinforce_loss(sample_fn,
                    global_reward_fn,
                    local_reward_fn=None,
                    num_samples=1):
-    """Compute REINFORCE loss with global and local rewards.
+    """Computes REINFORCE loss with global and local rewards.

     Args:
         sample_fn: A callable that takes :attr:`num_samples` and returns

@@ -48,7 +47,7 @@

     # shape = [batch, length]
     sequences, probs, seq_lens = sample_fn(num_samples)
-    batch, length = tf.shape(sequences)
+    batch, _ = tf.shape(sequences)
     rewards_local = tf.constant(0., dtype=probs.dtype, shape=probs.shape)
     if local_reward_fn is not None:
         rewards_local = local_reward_fn(sequences, seq_lens)

@@ -60,15 +59,16 @@

     eps = 1e-12
     log_probs = _mask_sequences(tf.log(probs + eps), seq_lens)
-    loss = -tf.reduce_mean(tf.reduce_sum(log_probs * rewards, axis=1) / seq_lens)
+    loss = - tf.reduce_mean(
+        tf.reduce_sum(log_probs * rewards, axis=1) / seq_lens)
     return loss


-def reinforce_loss_with_MCtree(sample_fn,
+def reinforce_loss_with_MCtree(sample_fn, # pylint: disable=invalid-name
                                global_reward_fn,
                                local_reward_fn=None,
                                num_samples=1):
-    """Compute REINFORCE loss with Monte Carlo tree search.
+    """Computes REINFORCE loss with Monte Carlo tree search.

     Args:
         sample_fn: A callable that takes :attr:`num_samples` and returns
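For reference, the quantity assembled in the hunks above is the length-normalized REINFORCE estimator: loss = -mean over the batch of (sum over steps of log p * reward) / sequence length. A self-contained numeric sketch with toy tensors; `tf.sequence_mask` stands in for the repo's `_mask_sequences` helper (an assumption about that helper's behavior):

    import tensorflow as tf

    # Toy stand-ins for sample_fn outputs: token probs and rewards of
    # shape [batch, length], plus per-example sequence lengths.
    probs = tf.constant([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4]])
    rewards = tf.constant([[1.0, 1.0, 0.5], [0.2, 0.0, 0.0]])
    seq_lens = tf.constant([3.0, 1.0])

    eps = 1e-12
    # Zero out log-probs past each sequence's end before summing.
    mask = tf.sequence_mask(tf.to_int32(seq_lens), 3, dtype=probs.dtype)
    log_probs = tf.log(probs + eps) * mask
    loss = - tf.reduce_mean(
        tf.reduce_sum(log_probs * rewards, axis=1) / seq_lens)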
