Skip to content

Commit

Permalink
Updated data_base module
Browse files Browse the repository at this point in the history
Former-commit-id: 05e0204
  • Loading branch information
ZhitingHu committed Mar 21, 2018
1 parent 029586a commit 297047c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
49 changes: 49 additions & 0 deletions texar/data/data/data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from __future__ import print_function
from __future__ import unicode_literals

import tensorflow as tf

from texar.hyperparams import HParams
from texar.data.data import data_utils

__all__ = [
"DataBase"
Expand Down Expand Up @@ -41,6 +44,52 @@ def default_hparams():
"seed": None
}

@staticmethod
def _make_batch(dataset, hparams, padded_batch=False):
    """Batches (and repeats) a dataset per batching hyperparameters.

    Args:
        dataset: A ``tf.data.Dataset`` of examples to batch.
        hparams: Dataset hyperparameters; reads ``num_epochs``,
            ``batch_size``, and ``allow_smaller_final_batch``.
        padded_batch (bool): Whether examples are padded to the longest
            example of each batch.

    Returns:
        The batched dataset.
    """
    dataset = dataset.repeat(hparams.num_epochs)
    batch_size = hparams["batch_size"]
    if hparams["allow_smaller_final_batch"]:
        if padded_batch:
            dataset = dataset.padded_batch(
                batch_size, dataset.output_shapes)
        else:
            dataset = dataset.batch(batch_size)
    else:
        # Drop the remainder so that every batch has exactly `batch_size`
        # examples. Honor `padded_batch` here as well — the original code
        # unconditionally used the padded variant in this branch, which is
        # inconsistent with the branch above.
        if padded_batch:
            dataset = dataset.apply(
                tf.contrib.data.padded_batch_and_drop_remainder(
                    batch_size, dataset.output_shapes))
        else:
            dataset = dataset.apply(
                tf.contrib.data.batch_and_drop_remainder(batch_size))
    return dataset

@staticmethod
def _shuffle_dataset(dataset, hparams, dataset_files):
    """Shuffles (and optionally shards) the dataset.

    Args:
        dataset: A ``tf.data.Dataset`` to shuffle.
        hparams: Dataset hyperparameters; reads ``shuffle_buffer_size``,
            ``shard_and_shuffle``, ``shuffle``, and ``seed``.
        dataset_files: Filename(s) backing the dataset; used to count the
            number of examples when a full-size shuffle buffer is needed.

    Returns:
        A tuple ``(dataset, dataset_size)`` where ``dataset_size`` is the
        line count of `dataset_files` if it was computed, else `None`.

    Raises:
        ValueError: If ``shard_and_shuffle`` is `True` but
            ``shuffle_buffer_size`` is `None` or not smaller than the
            dataset size.
    """
    dataset_size = None
    shuffle_buffer_size = hparams["shuffle_buffer_size"]
    if hparams["shard_and_shuffle"]:
        if shuffle_buffer_size is None:
            raise ValueError(
                "Dataset hyperparameter 'shuffle_buffer_size' "
                "must not be `None` if 'shard_and_shuffle'=`True`.")
        dataset_size = data_utils.count_file_lines(dataset_files)
        if shuffle_buffer_size >= dataset_size:
            # BUG FIX: the message previously referenced
            # 'shuffle_and_shard', which is not a valid hyperparameter
            # name; the hyperparameter is 'shard_and_shuffle'.
            raise ValueError(
                "Dataset size (%d) <= shuffle_buffer_size (%d). Set "
                "shard_and_shuffle to `False`." %
                (dataset_size, shuffle_buffer_size))
        #TODO(zhiting): Use a different seed?
        dataset = dataset.apply(data_utils.random_shard_dataset(
            dataset_size, shuffle_buffer_size, hparams["seed"]))
        dataset = dataset.shuffle(shuffle_buffer_size + 16,  # add a margin
                                  seed=hparams["seed"])
    elif hparams["shuffle"]:
        if shuffle_buffer_size is None:
            # Shuffle over the entire dataset when no buffer size is given.
            dataset_size = data_utils.count_file_lines(dataset_files)
            shuffle_buffer_size = dataset_size
        dataset = dataset.shuffle(shuffle_buffer_size, seed=hparams["seed"])

    return dataset, dataset_size



@property
def num_epochs(self):
"""Number of epochs.
Expand Down
29 changes: 0 additions & 29 deletions texar/data/data/text_data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import tensorflow as tf

from texar.data.data.data_base import DataBase
from texar.data.data import data_utils


__all__ = [
Expand Down Expand Up @@ -58,31 +57,3 @@ def _make_batch(dataset, hparams, element_length_func):
element_length_func, bucket_boundaries, bucket_batch_size)
return dataset

@staticmethod
def _shuffle_dataset(dataset, hparams, dataset_files):
    """Applies the shuffle-related hyperparameters to the dataset.

    Returns a pair of the resulting dataset and the dataset size (line
    count of `dataset_files`) when it had to be computed, else `None`.
    """
    size = None
    buffer_size = hparams["shuffle_buffer_size"]
    seed = hparams["seed"]

    if hparams["shard_and_shuffle"]:
        # Sharded shuffling requires an explicit buffer size smaller than
        # the dataset.
        if buffer_size is None:
            raise ValueError(
                "Dataset hyperparameter 'shuffle_buffer_size' "
                "must not be `None` if 'shard_and_shuffle'=`True`.")
        size = data_utils.count_file_lines(dataset_files)
        if buffer_size >= size:
            raise ValueError(
                "Dataset size (%d) <= shuffle_buffer_size (%d). Set "
                "shuffle_and_shard to `False`." % (size, buffer_size))
        #TODO(zhiting): Use a different seed?
        sharder = data_utils.random_shard_dataset(size, buffer_size, seed)
        dataset = dataset.apply(sharder)
        # Shuffle with a small margin above the shard size.
        dataset = dataset.shuffle(buffer_size + 16, seed=seed)
    elif hparams["shuffle"]:
        if buffer_size is None:
            # No buffer size given: shuffle over the whole dataset.
            size = data_utils.count_file_lines(dataset_files)
            buffer_size = size
        dataset = dataset.shuffle(buffer_size, seed=seed)

    return dataset, size

0 comments on commit 297047c

Please sign in to comment.