Skip to content

Commit

Permalink
update pg_losses_test
Browse files Browse the repository at this point in the history
  • Loading branch information
gpengzhi committed May 13, 2019
1 parent f702043 commit 7d15242
Showing 1 changed file with 57 additions and 28 deletions.
85 changes: 57 additions & 28 deletions texar/losses/pg_losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def setUp(self):
self._d2 = 32
self._d3 = 32
self._num_classes = 10

self._actions_batch = tf.ones([self._batch_size, self._max_time,
self._d1, self._d2, self._d3],
dtype=tf.int32)
Expand All @@ -39,51 +38,81 @@ def setUp(self):
self._max_time,
self._d1, self._d2,
self._d3])

self._actions_no_batch = tf.ones([self._batch_size, self._max_time,
self._d1, self._d2, self._d3],
dtype=tf.int32)
self._logits_no_batch = tf.random_uniform([self._batch_size,
self._max_time,
self._d1, self._d2, self._d3,
self._num_classes])
self._advantages_no_batch = tf.random_uniform([self._batch_size,
self._max_time,
self._d1, self._d2,
self._d3])
self._sequence_length = tf.random_uniform(
[self._batch_size], maxval=self._max_time, dtype=tf.int32)

def _test_sequence_loss(self, loss_fn, actions, logits, advantages, rank,
def _test_sequence_loss(self, loss_fn, actions, logits, advantages,
batched, sequence_length):
with self.test_session() as sess:
loss = loss_fn(actions, logits, advantages, rank, batched,
sequence_length)
loss = loss_fn(actions, logits, advantages, batched=batched,
sequence_length=sequence_length)
rank = sess.run(tf.rank(loss))
self.assertEqual(rank, 0)

loss = loss_fn(actions, logits, advantages, batched=batched,
sequence_length=sequence_length,
sum_over_timesteps=False)
rank = sess.run(tf.rank(loss))
self.assertEqual(rank, 1)
self.assertEqual(loss.shape, tf.TensorShape([self._max_time]))

loss = loss_fn(actions, logits, advantages, batched=batched,
sequence_length=sequence_length,
sum_over_timesteps=False,
average_across_timesteps=True,
average_across_batch=False)
rank = sess.run(tf.rank(loss))
self.assertEqual(rank, 1)
self.assertEqual(loss.shape, tf.TensorShape([self._batch_size]))

loss = loss_fn(actions, logits, advantages, batched=batched,
sequence_length=sequence_length,
sum_over_timesteps=False,
average_across_batch=False)
rank = sess.run(tf.rank(loss))
self.assertEqual(rank, 2)
self.assertEqual(loss.shape,
tf.TensorShape([self._batch_size, self._max_time]))

sequence_length_time = tf.random_uniform(
[self._max_time], maxval=self._max_time, dtype=tf.int32)
loss = loss_fn(actions, logits, advantages, batched=batched,
sequence_length=sequence_length_time,
sum_over_timesteps=False,
average_across_batch=False,
time_major=True)
self.assertEqual(loss.shape,
tf.TensorShape([self._batch_size, self._max_time]))

def test_pg_losses_with_logits(self):
"""Tests `texar.losses.pg_losses_with_logits`.
"""
self._test_sequence_loss(tx.losses.pg_loss_with_logits,
self._actions_batch,
self._logits_batch,
self._advantages_batch,
None,
True,
self._sequence_length)

def test_1(self):
self._test_sequence_loss(tx.losses.pg_loss_with_logits,
self._actions_no_batch,
self._logits_no_batch,
self._advantages_no_batch,
False,
self._sequence_length)


if __name__ == "__main__":
tf.test.main()






















0 comments on commit 7d15242

Please sign in to comment.