Commit e5ddf2e: tested bucketing
Former-commit-id: 969b5a8
ZhitingHu committed Apr 4, 2018
1 parent f8df49f commit e5ddf2e
Showing 2 changed files with 82 additions and 46 deletions.
43 changes: 36 additions & 7 deletions texar/data/data/mono_text_data_test.py
@@ -117,13 +117,42 @@ def test_batching(self):
hparams.update({"allow_smaller_final_batch": False})
self._run_and_test(hparams, test_batch_size=True)

## bucketing
#raise TexarError("Only usible on TF-1.7")
#hparams = copy.copy(self._hparams)
#hparams.update({
# "bucket_boundaries": [2, 4, 6],
# "bucket_batch_sizes": [6, 4, 2]})
#self._run_and_test(hparams, test_batch_size=True)
+    def test_bucketing(self):
+        """Tests bucketing.
+        """
+        hparams = copy.copy(self._hparams)
+        hparams.update({
+            "bucket_boundaries": [7],
+            "bucket_batch_sizes": [6, 4]})
+
+        text_data = tx.data.MonoTextData(hparams)
+        iterator = text_data.dataset.make_initializable_iterator()
+        text_data_batch = iterator.get_next()
+
+        with self.test_session() as sess:
+            sess.run(tf.global_variables_initializer())
+            sess.run(tf.local_variables_initializer())
+            sess.run(tf.tables_initializer())
+            sess.run(iterator.initializer)
+
+            while True:
+                try:
+                    # Fetch a batch and check its size against its bucket
+                    data_batch_ = sess.run(text_data_batch)
+                    length_ = data_batch_['length'][0]
+                    if length_ < 7:
+                        last_batch_size = hparams['num_epochs'] % 6
+                        self.assertTrue(
+                            len(data_batch_['text']) == 6 or
+                            len(data_batch_['text']) == last_batch_size)
+                    else:
+                        last_batch_size = hparams['num_epochs'] % 4
+                        self.assertTrue(
+                            len(data_batch_['text']) == 4 or
+                            len(data_batch_['text']) == last_batch_size)
+                except tf.errors.OutOfRangeError:
+                    print('Done -- epoch limit reached')
+                    break

     def test_shuffle(self):
         """Tests different shuffle strategies.
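For context on what the new test exercises: with a single bucket boundary at length 7, examples shorter than 7 tokens are drawn in batches of 6 and longer ones in batches of 4, which is what the assertions above check (allowing a smaller final batch per bucket). Below is a minimal sketch of how such hparams typically map onto the bucketing transform that shipped with TF 1.7, the version the removed comment was gated on. It is an illustration, not Texar's actual wiring: the helper name bucket_by_length is ours, and it assumes a tf.data pipeline whose elements carry a "length" field, as the test's batches do.

import tensorflow as tf

def bucket_by_length(dataset):
    # Sketch only, not Texar's internals. One boundary at 7 yields two
    # buckets, so bucket_batch_sizes needs len(boundaries) + 1 entries:
    # here, 6 for lengths < 7 and 4 for the rest.
    return dataset.apply(
        tf.contrib.data.bucket_by_sequence_length(
            element_length_func=lambda x: tf.cast(x["length"], tf.int32),
            bucket_boundaries=[7],
            bucket_batch_sizes=[6, 4]))

Batches never mix buckets, which is why inspecting only the first element's length, as the test does, is enough to tell which batch size to expect.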
85 changes: 46 additions & 39 deletions texar/data/data/paired_text_data_test.py
@@ -130,6 +130,13 @@ def test_default_setting(self):
"""
self._run_and_test(self._hparams)

def test_shuffle(self):
"""Tests toggling shuffle.
"""
hparams = copy.copy(self._hparams)
hparams["shuffle"] = False
self._run_and_test(hparams)

def test_processing_share(self):
"""Tests sharing processing.
"""
@@ -160,45 +167,45 @@ def test_length_filter(self):
"length_filter_mode": "discard"})
self._run_and_test(hparams, discard_src=True)

def test_sequence_length(self):
hparams = {
"batch_size": 64,
"num_epochs": 1,
"shuffle": False,
"allow_smaller_final_batch": False,
"source_dataset": {
"files": "../../../data/yelp/sentiment.dev.sort.0",
"vocab_file": "../../../data/yelp/vocab",
"bos_token": SpecialTokens.BOS,
"eos_token": SpecialTokens.EOS,
},
"target_dataset": {
"files": "../../../data/yelp/sentiment.dev.sort.1",
"vocab_share": True,
},
}
data = tx.data.PairedTextData(hparams)

iterator = tx.data.TrainTestDataIterator(val=data)
text_data_batch = iterator.get_next()

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
iterator.switch_to_val_data(sess)

while True:
try:
data_batch_ = sess.run(text_data_batch)
src = data_batch_["source_text_ids"]
src_len = data_batch_["source_length"]
self.assertEqual(src.shape[1], np.max(src_len))
tgt = data_batch_["target_text_ids"]
tgt_len = data_batch_["target_length"]
self.assertEqual(tgt.shape[1], np.max(tgt_len))
except tf.errors.OutOfRangeError:
break
+    #def test_sequence_length(self):
+    #    hparams = {
+    #        "batch_size": 64,
+    #        "num_epochs": 1,
+    #        "shuffle": False,
+    #        "allow_smaller_final_batch": False,
+    #        "source_dataset": {
+    #            "files": "../../../data/yelp/sentiment.dev.sort.0",
+    #            "vocab_file": "../../../data/yelp/vocab",
+    #            "bos_token": SpecialTokens.BOS,
+    #            "eos_token": SpecialTokens.EOS,
+    #        },
+    #        "target_dataset": {
+    #            "files": "../../../data/yelp/sentiment.dev.sort.1",
+    #            "vocab_share": True,
+    #        },
+    #    }
+    #    data = tx.data.PairedTextData(hparams)
+
+    #    iterator = tx.data.TrainTestDataIterator(val=data)
+    #    text_data_batch = iterator.get_next()
+
+    #    with self.test_session() as sess:
+    #        sess.run(tf.global_variables_initializer())
+    #        sess.run(tf.local_variables_initializer())
+    #        sess.run(tf.tables_initializer())
+    #        iterator.switch_to_val_data(sess)
+
+    #        while True:
+    #            try:
+    #                data_batch_ = sess.run(text_data_batch)
+    #                src = data_batch_["source_text_ids"]
+    #                src_len = data_batch_["source_length"]
+    #                self.assertEqual(src.shape[1], np.max(src_len))
+    #                tgt = data_batch_["target_text_ids"]
+    #                tgt_len = data_batch_["target_length"]
+    #                self.assertEqual(tgt.shape[1], np.max(tgt_len))
+    #            except tf.errors.OutOfRangeError:
+    #                break

 if __name__ == "__main__":
     tf.test.main()
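The disabled test's key invariant is that each batch's padded width equals the longest sequence within that batch, on both the source and target sides. A self-contained toy sketch of that property with plain tf.data follows; the inline data is hypothetical, not the Yelp files referenced above.

import tensorflow as tf

# Two variable-length examples; padded_batch pads each batch only to its
# own longest element, so the padded width equals max(lengths in batch).
ds = tf.data.Dataset.from_generator(
    lambda: [[1, 2], [1, 2, 3, 4]],
    tf.int32, tf.TensorShape([None]))
ds = ds.padded_batch(2, padded_shapes=[None])
batch = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    ids = sess.run(batch)
    assert ids.shape == (2, 4)  # width 4 == longest sequence in the batch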
