Skip to content

Commit

Permalink
DataIterator deterministically picks the 1st dataset for creating TF …
Browse files Browse the repository at this point in the history
…iterators.

Fixing #23
  • Loading branch information
ZhitingHu committed Sep 14, 2018
1 parent 1fa6b4d commit 03ad846
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
5 changes: 4 additions & 1 deletion texar/data/data/data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def default_hparams():
"allow_smaller_final_batch" : bool
Whether to allow the final batch to be smaller if there are
insufficient elements left. If `False`, the final batch is
discarded if it is smaller than batch size.
discarded if it is smaller than batch size. Note that,
if `True`, `output_shapes` of the resulting dataset
will have a a **static** batch_size dimension equal to
"batch_size".
"shuffle" : bool
Whether to randomly shuffle the elements of the dataset.
Expand Down
10 changes: 5 additions & 5 deletions texar/data/data/data_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def __init__(self, datasets):

self._variable_scope = get_unique_named_variable_scope('data_iterator')
with tf.variable_scope(self._variable_scope):
arb_dataset = self._datasets[next(iter(self._datasets))]
first_dataset = self._datasets[sorted(self.dataset_names)[0]]
self._iterator = tf.data.Iterator.from_structure(
arb_dataset.output_types, arb_dataset.output_shapes)
first_dataset.output_types, first_dataset.output_shapes)
self._iterator_init_ops = {
name: self._iterator.make_initializer(d)
for name, d in self._datasets.items()
Expand Down Expand Up @@ -324,10 +324,10 @@ def __init__(self, datasets):
'feedable_data_iterator')
with tf.variable_scope(self._variable_scope):
self._handle = tf.placeholder(tf.string, shape=[], name='handle')
arb_dataset = self._datasets[next(iter(self._datasets))]
first_dataset = self._datasets[sorted(self.dataset_names)[0]]
self._iterator = tf.data.Iterator.from_string_handle(
self._handle, arb_dataset.output_types,
arb_dataset.output_shapes)
self._handle, first_dataset.output_types,
first_dataset.output_shapes)

self._dataset_iterators = {
name: dataset.make_initializable_iterator()
Expand Down

0 comments on commit 03ad846

Please sign in to comment.