-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Former-commit-id: c05c5f2
- Loading branch information
Showing
3 changed files
with
82 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,58 @@ | ||
# | ||
""" | ||
Adversarial losses. | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
from txtgen.modules.encoders.rnn_encoders import ForwardRNNEncoder | ||
from txtgen.core import optimization as opt | ||
from txtgen.context import is_train | ||
|
||
def adversarial_losses(true_data,
                       generated_data,
                       discriminator
                       ):
    """Builds the discriminator training op and the adversarial costs.

    NOTE(review): relies on the sibling helper ``discriminator_result``;
    label 1 appears to mean "real" and 0 "fake" -- confirm with the helper.

    Returns:
        (disc_train_op, disc_global_step, generate_cost, disc_cost).
    """
    # Discriminator should score real data as 1 and generated data as 0.
    cost_on_real = discriminator_result(discriminator, true_data, 1, False)
    cost_on_fake = discriminator_result(discriminator, generated_data, 0, True)
    # TODO(review): summed over the batch -- consider dividing by batch size.
    disc_cost = tf.reduce_sum(cost_on_real + cost_on_fake)
    disc_train_op, disc_global_step = opt.get_train_op(disc_cost)
    # Generator cost: fake samples scored against the "real" label (1).
    generate_cost = tf.reduce_sum(
        discriminator_result(discriminator, generated_data, 1, True))
    return disc_train_op, disc_global_step, generate_cost, disc_cost
|
||
def discriminator_result(discriminator, data, label, reuse=False): | ||
"""Loss for both generated data and true data | ||
|
||
|
||
def binary_adversarial_losses(real_data,
                              fake_data,
                              discriminator_fn,
                              mode="max_real"):
    """Computes adversarial loss of the real/fake binary classification game.

    Args:
        real_data (Tensor or array): Real data of shape
            `[num_real_examples, ...]`.
        fake_data (Tensor or array): Fake data of shape
            `[num_fake_examples, ...]`. `num_real_examples` does not
            necessarily equal `num_fake_examples`.
        discriminator_fn: A callable takes data (e.g., :attr:`real_data` and
            :attr:`fake_data`) and returns the logits of being real.
        mode (str): Mode of the generator loss. Either `max_real` or
            `min_fake`.
            If `max_real` (default), minimizing the generator loss is to
            maximize the probability of fake data being classified as real.
            If `min_fake`, minimizing the generator loss is to minimize the
            probability of fake data being classified as fake.

    Returns:
        (scalar Tensor, scalar Tensor): (generator_loss, discriminator_loss).

    Raises:
        ValueError: If :attr:`mode` is neither `max_real` nor `min_fake`.
    """
    # Discriminator wants real data classified as real (label 1) ...
    real_logits = discriminator_fn(real_data)
    real_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=real_logits, labels=tf.ones_like(real_logits))
    num_real_data = tf.shape(real_loss)[0]
    # Average real and fake terms separately so unequal numbers of real and
    # fake examples do not skew the combined discriminator loss.
    ave_real_loss = tf.reduce_sum(real_loss) / tf.to_float(num_real_data)
    # ... and fake data classified as fake (label 0).
    fake_logits = discriminator_fn(fake_data)
    fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=fake_logits, labels=tf.zeros_like(fake_logits))
    num_fake_data = tf.shape(fake_loss)[0]
    ave_fake_loss = tf.reduce_sum(fake_loss) / tf.to_float(num_fake_data)
    disc_loss = ave_real_loss + ave_fake_loss

    if mode == "min_fake":
        # Saturating generator loss: directly oppose the discriminator's
        # loss on fake data.
        gen_loss = - ave_fake_loss
    elif mode == "max_real":
        # Non-saturating generator loss: cross-entropy of fake data against
        # the "real" label.
        fake_loss_ = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=fake_logits, labels=tf.ones_like(fake_logits))
        gen_loss = tf.reduce_sum(fake_loss_) / tf.to_float(num_fake_data)
    else:
        # BUG FIX: the original never interpolated `mode` into the message,
        # so the error literally printed "%s".
        raise ValueError("Unknown mode: %s. Only 'min_fake' and 'max_real' "
                         "are allowed." % mode)
    return gen_loss, disc_loss
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,38 @@ | ||
from __future__ import absolute_import | ||
# | ||
""" | ||
Tests adversarial loss related functions. | ||
""" | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
import tensorflow as tf | ||
import numpy as np | ||
from txtgen import context | ||
from txtgen.modules.encoders.rnn_encoders import ForwardRNNEncoder | ||
|
||
from .adv_losses import adversarial_losses | ||
from txtgen.losses.adv_losses import binary_adversarial_losses | ||
|
||
class AdvLossesTest(tf.test.TestCase):
    """Tests adversarial losses.
    """

    def test_binary_adversarial_losses(self):
        """Tests :func:`~txtgen.losses.adv_losses.binary_adversarial_losses`.
        """
        # Fixed-size dummy batches; the data contents are irrelevant because
        # the discriminator below ignores its input.
        batch_size = 16
        data_dim = 64
        real_data = tf.zeros([batch_size, data_dim], dtype=tf.float32)
        fake_data = tf.ones([batch_size, data_dim], dtype=tf.float32)
        const_logits = tf.zeros([batch_size], dtype=tf.float32)
        # Use a dumb discriminator that always outputs logits=0.
        gen_loss, disc_loss = binary_adversarial_losses(
            real_data, fake_data, lambda x: const_logits)
        gen_loss_2, disc_loss_2 = binary_adversarial_losses(
            real_data, fake_data, lambda x: const_logits, mode="min_fake")

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            gen_loss_, disc_loss_ = sess.run([gen_loss, disc_loss])
            gen_loss_2_, disc_loss_2_ = sess.run([gen_loss_2, disc_loss_2])
            # With logits=0, the "max_real" generator loss equals the
            # negated "min_fake" loss, while the discriminator loss is the
            # same in both modes.
            self.assertAlmostEqual(gen_loss_, -gen_loss_2_)
            self.assertAlmostEqual(disc_loss_, disc_loss_2_)


if __name__ == "__main__":
    tf.test.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters