
Commit 17b12eb

Merge pull request #162 from Kipok/dev0.4

Fix typo with AUTOTUNE
Kipok committed Jun 28, 2018
2 parents ea50859 + 479f1ce commit 17b12eb
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions open_seq2seq/data/image2label/image2label.py
@@ -129,7 +129,7 @@ def build_graph(self):
 )

 dataset = dataset.batch(self.params['batch_size'])
-dataset = dataset.prefetch(tf.contrib.AUTOTUNE)
+dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

 self._iterator = dataset.make_initializable_iterator()
 inputs, labels = self.iterator.get_next()
@@ -231,7 +231,7 @@ def build_graph(self):
 )

 dataset = dataset.batch(self.params['batch_size'])
-dataset = dataset.prefetch(tf.contrib.AUTOTUNE)
+dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

 self._iterator = dataset.make_initializable_iterator()
 inputs, labels = self.iterator.get_next()
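For context, the changed call sits inside the data layer's build_graph(): the pipeline batches, prefetches with the autotuned buffer size, and then creates an initializable iterator. Below is a minimal, self-contained sketch of the corrected pattern, assuming TensorFlow 1.x where tf.contrib.data.AUTOTUNE is available; the toy image/label arrays are illustrative and not from the repo.

import numpy as np
import tensorflow as tf  # TF 1.x; tf.contrib.data.AUTOTUNE assumed available

# Toy in-memory data standing in for the real image/label sources.
images = np.random.rand(100, 28, 28, 1).astype(np.float32)
labels = np.random.randint(0, 10, size=(100,)).astype(np.int32)

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(16)
# The fix in this commit: AUTOTUNE lives under tf.contrib.data, not tf.contrib.
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

iterator = dataset.make_initializable_iterator()
inputs, batch_labels = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  x, y = sess.run([inputs, batch_labels])
  print(x.shape, y.shape)  # (16, 28, 28, 1) (16,)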
3 changes: 2 additions & 1 deletion open_seq2seq/data/speech2text/speech2text.py
@@ -145,7 +145,8 @@ def build_graph(self):
 padded_shapes=([None, self.params['num_audio_features']], 1, 1)
 )

-self._iterator = self._dataset.prefetch(tf.contrib.AUTOTUNE).make_initializable_iterator()
+self._iterator = self._dataset.prefetch(tf.contrib.data.AUTOTUNE)\
+    .make_initializable_iterator()

 if self.params['mode'] != 'infer':
 x, x_length, y, y_length = self._iterator.get_next()
3 changes: 2 additions & 1 deletion open_seq2seq/data/text2text/text2text.py
@@ -77,7 +77,8 @@ def __init__(self, params, model, num_workers=1, worker_id=0):
 self._delimiter = self.params.get('delimiter', ' ')
 self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
 self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
-self._prefetch_buffer_size = self.params.get('prefetch_buffer_size', tf.contrib.AUTOTUNE)
+self._prefetch_buffer_size = self.params.get('prefetch_buffer_size',
+                                             tf.contrib.data.AUTOTUNE)
 self._num_workers = num_workers
 self._worker_id = worker_id
 if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
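In text2text.py the constant is only the default for a user-configurable prefetch buffer size rather than a hard-coded argument. A short sketch of that pattern, under the same TF 1.x assumption; the params dict and the pipeline below are illustrative, only the parameter name and its default come from the diff.

import tensorflow as tf  # TF 1.x; tf.contrib.data.AUTOTUNE assumed available

params = {}  # e.g. parsed from a model config; empty here so the default applies

# As in the diff: fall back to autotuned prefetching unless the user overrides it.
prefetch_buffer_size = params.get('prefetch_buffer_size', tf.contrib.data.AUTOTUNE)

dataset = tf.data.Dataset.from_tensor_slices(tf.range(1000))
dataset = dataset.batch(32)
dataset = dataset.prefetch(prefetch_buffer_size)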

0 comments on commit 17b12eb
