Minor tweaks in the data module (#87)
* Minor tweaks in the data module

- Change Vocab classes to use plain dict instead of defaultdict. A
  defaultdict holds a lambda as its default factory, which cannot be
  pickled when used with a multi-processing data loader (see the
  pickling sketch after this change list).
- Change DataBase behavior when `num_parallel_calls` == 1. Previously,
  no worker processes were spawned when `num_parallel_calls` was 1; now
  the number of worker processes equals `num_parallel_calls`. As a
  result, the default value is changed to 0.
- Add comments describing the internals of DataBase, including the big
  tables describing behaviors under different lazy/caching mode
  combinations.

* Add device argument in all DataBases

* Remove unused `type: ignore` comment

* Add device change in collate

* Fix type annotations

* Fix typo in data module docstrings
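
A note on the Vocab change above: pickling is the motivation. PyTorch's
multi-processing `DataLoader` pickles dataset state when spawning worker
processes, and a `defaultdict` whose default factory is a lambda cannot be
pickled. A minimal sketch of the failure (standard library only; names are
illustrative):

    import pickle
    from collections import defaultdict

    unk_id = 3
    token_map = defaultdict(lambda: unk_id)   # lambda as default factory
    token_map["hello"] = 0

    try:
        pickle.dumps(token_map)               # lambdas cannot be pickled
    except (pickle.PicklingError, AttributeError, TypeError) as err:
        print("defaultdict with lambda:", err)

    plain_map = dict(token_map)               # a plain dict round-trips fine
    assert pickle.loads(pickle.dumps(plain_map)) == {"hello": 0}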
huzecong authored and AvinashBukkittu committed Jul 3, 2019
1 parent d4da27a commit fc99e0d
Showing 9 changed files with 116 additions and 29 deletions.
88 changes: 79 additions & 9 deletions texar/data/data/data_base.py
@@ -256,6 +256,80 @@ class DataBase(Dataset, Generic[RawExample, Example], ABC):
r"""Base class inherited by all data classes.
"""

# pylint: disable=line-too-long

# The `DataBase` is used in combination with Texar `DataIterator`, which internally uses the PyTorch `DataLoader`
# for multi-processing support.
#
# We divide the entire data pipeline into three stages, namely *load*, *process*, and *batch*:
# - **Load** refers to loading data from the data source (e.g., a file, a Python list or iterator). In Texar,
# loading is handled by `DataSource` classes.
# - **Process** refers to preprocessing routines for each data example (e.g., vocabulary mapping, tokenization). In
# Texar, this is the `process` function of each `DataBase` class.
# - **Batch** refers to combining multiple examples to form a batch, which typically includes padding and moving
# data across devices. In Texar, this is the `collate` function of each `DataBase` class.
#
# The PyTorch `DataLoader` performs only batching, and since multi-processing is used, the entire dataset must be
# in memory before iteration begins, i.e. loading and processing cannot be lazy. The `DataBase` class is carefully
# crafted to provide laziness and caching options at all possible stages.
#
# To support laziness, we pass data examples (either raw or processed, depending on whether processing is lazy) to
# the worker processes. To avoid modifying the underlying `DataLoader` implementation, we hack the PyTorch
# `Sampler` classes (responsible for sampling the next data example from the dataset, and returning its index) to
# also return data examples. To support caching, the worker may also need to return the processed examples through
# pipes.
#
# The following table describes the intended behavior of each combination of lazy/caching modes, and the exact
# behaviors of the sampler and workers. `<X>` means the mode combination does not make sense (e.g. with `Lazy.None`,
# processed data examples are effectively cached, so `Cache.None` makes no sense). Parts in `*[blah]*` hold true
# only for the first epoch.
#
# +---------------+-------------------------------+-------------------------------+-------------------------------+
# | | Cache.None | Cache.Loaded | Cache.Processed |
# | | no caching | only cache loaded examples | only cache processed examples |
# +===============+===============================+===============================+===============================+
# | Lazy.None | <X> | <X> | Sampler returns indices. |
# | eager load, | | | Worker only does batching. |
# | eager process | | | Worker returns batch. |
# +---------------+-------------------------------+-------------------------------+-------------------------------+
# | Lazy.Process | <X> | Sampler returns indices. | Sampler returns indices. |
# | eager load, | | Worker does batching and | Worker does batching |
# | lazy process | | processing. | *[and processing]*. |
# | | | Worker returns batch. | Worker returns batch |
# | | | | *[and processed examples]*. |
# +---------------+-------------------------------+-------------------------------+-------------------------------+
# | Lazy.All | Sampler returns indices and | Sampler returns indices | Sampler returns indices |
# | lazy load, | data examples. | *[and data examples]*. | *[and data examples]*. |
# | lazy process | Worker does batching and | Worker does batching and | Worker does batching |
# | | processing. | processing. | *[and processing]*. |
# | | Worker returns batch. | Worker returns batch. | Worker returns batch |
# | | | | *[and processed examples]*. |
# +---------------+-------------------------------+-------------------------------+-------------------------------+
#
# Note that the above table assumes `parallelize_processing` to be True. In rare cases this may not be desired,
# for instance, when `process` depends on some shared variable that must be modified during iteration, e.g. a
# vocabulary constructed on-the-fly. When `parallelize_processing` is False, behaviors are as shown in the
# following (much simpler) table. Note, however, that this often results in worse performance than the cases above.
#
# +---------------+-------------------------------+-------------------------------+-------------------------------+
# | | Cache.None | Cache.Loaded | Cache.Processed |
# | | no caching | only cache loaded examples | only cache processed examples |
# +===============+===============================+===============================+===============================+
# | Lazy.None | <X> | <X> | Sampler returns indices. |
# | eager load, | | | Worker only does batching. |
# | eager process | | | Worker returns batch. |
# +---------------+-------------------------------+-------------------------------+-------------------------------+
# | Lazy.Process | <X> | Sampler returns indices and processed examples. |
# | eager load, | | Worker only does batching. |
# | lazy process | | Worker returns batch. |
# +---------------+-------------------------------+---------------------------------------------------------------+
# | Lazy.All | Sampler returns indices and processed examples. |
# | lazy load, | Worker only does batching. |
# | lazy process | Worker returns batch. |
# +---------------+-----------------------------------------------------------------------------------------------+

# pylint: enable=line-too-long

_source: DataSource[RawExample]
_dataset_size: Optional[int]

@@ -286,7 +360,7 @@ def __init__(self, source: DataSource[RawExample], hparams=None,
f"strategy. This will be equivalent to 'loaded' cache "
f"strategy.")
self._cache_strategy = _CacheStrategy.LOADED
- self._uses_multi_processing = self._hparams.num_parallel_calls > 1
+ self._uses_multi_processing = self._hparams.num_parallel_calls > 0
self._parallelize_processing = self._hparams.parallelize_processing

self._processed_cache: List[Example] = []
@@ -437,10 +511,9 @@ def default_hparams():
`"num_parallel_calls"`: int
Number of elements from the datasets to process in parallel.
When ``"num_parallel_calls"`` equals 1, no worker threads will
be created; however, when the value is greater than 1, the
number of worker threads will be equal to
``"num_parallel_calls"``.
When ``"num_parallel_calls"`` equals 0, no worker processes will
be created; when the value is greater than 0, the number of worker
processes will be equal to ``"num_parallel_calls"``.
`"prefetch_buffer_size"`: int
The maximum number of elements that will be buffered when
@@ -514,9 +587,6 @@ def default_hparams():
`none`. If `lazy_strategy` is `none`, processing will be
performed on a single process regardless of this value.
"""
- # TODO: Find a way to embed that big table here. Also change the table
- #  to include different behaviors when `parallelize_processing` is
- #  `False`.
# TODO: Sharding not yet supported.
# TODO: `seed` is not yet applied.
# TODO: `prefetch_buffer_size` will not be supported, but could remain
@@ -529,7 +599,7 @@ def default_hparams():
"shuffle": True,
"shuffle_buffer_size": None,
"shard_and_shuffle": False,
"num_parallel_calls": 1,
"num_parallel_calls": 0,
"prefetch_buffer_size": 0,
"max_dataset_size": -1,
"seed": None,
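For context on the lazy/cache tables added above: the strategies are selected
through hparams. A hedged sketch of a configuration exercising them (the
`num_parallel_calls`, `lazy_strategy`, and `cache_strategy` keys follow this
file's docstring; the concrete data class, import path, and file names are
assumptions):

    import torch
    from texar.data import MonoTextData

    hparams = {
        "dataset": {"files": "train.txt", "vocab_file": "vocab.txt"},
        "num_parallel_calls": 4,        # four worker processes; 0 = main process only
        "lazy_strategy": "process",     # eager load, lazy process
        "cache_strategy": "processed",  # cache processed examples
    }
    data = MonoTextData(hparams, device=torch.device("cuda:0"))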
3 changes: 1 addition & 2 deletions texar/data/data/data_iterators.py
@@ -461,9 +461,8 @@ def __init__(self, dataset: DataBase,
else:
sampler = SequentialSampler(dataset)

- num_parallel_calls = dataset.hparams.num_parallel_calls
+ num_workers = dataset.hparams.num_parallel_calls
collate_fn = dataset._collate_and_maybe_return
- num_workers = (0 if num_parallel_calls == 1 else num_parallel_calls)

if batching_strategy is not None:
batch_sampler = DynamicBatchSampler(
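The simplification above works because the hparam now matches PyTorch's own
`num_workers` semantics, where 0 means loading in the main process. A sketch of
the resulting wiring (assuming `dataset` is a constructed `DataBase`; the real
`DataIterator` also installs custom samplers, elided here):

    from torch.utils.data import DataLoader

    # Pass the hparam straight through: 0 means no worker processes,
    # any N > 0 means exactly N worker processes.
    num_workers = dataset.hparams.num_parallel_calls
    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers,
                        collate_fn=dataset._collate_and_maybe_return)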
2 changes: 2 additions & 0 deletions texar/data/data/mono_text_data.py
@@ -77,6 +77,8 @@ class MonoTextData(TextDataBase[str, List[str]]):
Args:
hparams: A `dict` or instance of :class:`~texar.HParams` containing
hyperparameters. See :meth:`default_hparams` for the defaults.
+ device: The device of the produced batches. For GPU training, set to
+     current CUDA device.
By default, the processor reads raw data files, performs tokenization,
batching and other pre-processing steps, and results in a Dataset
7 changes: 7 additions & 0 deletions texar/data/data/multi_aligned_data.py
@@ -96,6 +96,8 @@ class MultiAlignedData(
Args:
hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
defaults.
+ device: The device of the produced batches. For GPU training, set to
+     current CUDA device.
The processor can read any number of parallel fields as specified in
the "datasets" list of :attr:`hparams`, and result in a Dataset whose
@@ -377,6 +379,11 @@ def default_hparams():
hparams["datasets"] = []
return hparams

+ def to(self, device: torch.device):
+     for dataset in self._databases:
+         dataset.to(device)
+     return super().to(device)

@staticmethod
def _raise_sharing_error(err_data, share_data, hparam_name):
raise ValueError(
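The new `to` override mirrors the `device` constructor argument added across
the data classes: it forwards the device change to every sub-dataset before
delegating to the base class. A usage sketch (construction details elided;
`hparams` is assumed to be a valid configuration):

    import torch

    data = MultiAlignedData(hparams)   # batches produced on the default device
    data.to(torch.device("cuda:0"))    # recursively moves every sub-database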
2 changes: 2 additions & 0 deletions texar/data/data/paired_text_data.py
@@ -67,6 +67,8 @@ class PairedTextData(TextDataBase[Tuple[str, str],
Args:
hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
defaults.
+ device: The device of the produced batches. For GPU training, set to
+     current CUDA device.
By default, the processor reads raw data files, performs tokenization,
batching and other pre-processing steps, and results in a Dataset
26 changes: 17 additions & 9 deletions texar/data/data/record_data.py
@@ -116,9 +116,12 @@ def __iter__(self):
break


+ TransformFn = Callable[[bytes], torch.ByteTensor]


def _create_image_transform(height: Optional[int], width: Optional[int],
resize_method: Union[str, int] = 'bilinear') \
- -> Callable[[bytes], torch.ByteTensor]:
+ -> TransformFn:
r"""Create a function based on `Pillow image transforms
<https://pillow.readthedocs.io/en/3.1.x/reference/Image.html#PIL.Image.Image.resize>`
that performs resizing with desired resize method (interpolation).
@@ -236,6 +239,8 @@ class RecordData(DataBase[Dict[str, Any], Dict[str, Any]]):
Args:
hparams (dict): Hyperparameters. See :meth:`default_hparams`
for the defaults.
+ device: The device of the produced batches. For GPU training, set to
+     current CUDA device.
The module reads and restores data from TFRecord files and
results in a TF Dataset whose element is a Python `dict` that maps feature
@@ -317,7 +322,7 @@ class RecordData(DataBase[Dict[str, Any], Dict[str, Any]]):
"""

- def __init__(self, hparams):
+ def __init__(self, hparams=None, device: Optional[torch.device] = None):
self._hparams = HParams(hparams, self.default_hparams())

feature_types = self._hparams.dataset.feature_original_types
@@ -332,7 +337,7 @@ def __init__(self, hparams):
image_options = self._hparams.dataset.image_options
if isinstance(image_options, HParams):
image_options = [image_options]
- self._image_transforms = {}
+ self._image_transforms: Dict[str, TransformFn] = {}
for options in image_options:
key = options.get('image_feature_name')
if key is None or key not in self._features:
@@ -343,9 +348,10 @@ def __init__(self, hparams):

self._other_transforms = self._hparams.dataset.other_transformations

- data_source = PickleDataSource(self._hparams.dataset.files)
+ data_source = PickleDataSource[Dict[str, Any]](
+     self._hparams.dataset.files)

- super().__init__(data_source, hparams)
+ super().__init__(data_source, hparams, device)

@classmethod
def _construct(cls, hparams):
@@ -357,9 +363,9 @@ def _construct(cls, hparams):

convert_types = record_data._hparams.dataset.feature_convert_types
record_data._convert_types = {key: get_numpy_dtype(value)
- for key, value in convert_types.items()}
+ for key, value in convert_types.items()}
for key, dtype in record_data._convert_types.items():
- record_data._features[key] = record_data._features[key].\
+ record_data._features[key] = record_data._features[key]. \
_replace(dtype=dtype)

image_options = record_data._hparams.dataset.image_options
@@ -625,9 +631,11 @@ def collate(self, examples: List[Dict[str, Any]]) -> Batch:
values, _ = padded_batch(values)
else:
values = np.stack(values, axis=0)
- if (not torch.is_tensor(values) and
+ if (not isinstance(values, torch.Tensor) and
descriptor.dtype not in [np.str_, np.bytes_]):
- values = torch.from_numpy(values)
+ values = torch.from_numpy(values).to(device=self.device)
+ elif isinstance(values, torch.Tensor):
+ values = values.to(device=self.device)
else:
# VarLenFeature, just put everything in a Python list.
pass
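The `collate` change above guarantees that every tensor field of the batch
lands on `self.device`, whether it starts as a NumPy array or is already a
tensor (e.g. from `padded_batch`). A standalone sketch of the pattern (an
illustrative helper, not RecordData's actual API):

    import numpy as np
    import torch

    def move_to_device(values, device):
        # Wrap NumPy arrays first; move existing tensors directly.
        if isinstance(values, np.ndarray):
            return torch.from_numpy(values).to(device=device)
        if isinstance(values, torch.Tensor):
            return values.to(device=device)
        return values  # e.g. string features stay as Python lists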
2 changes: 2 additions & 0 deletions texar/data/data/scalar_data.py
@@ -55,6 +55,8 @@ class ScalarData(DataBase[str, Union[int, float]]):
Args:
hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
defaults.
+ device: The device of the produced batches. For GPU training, set to
+     current CUDA device.
The processor reads and processes raw data and results in a dataset
whose element is a python `dict` including one field. The field name is
9 changes: 3 additions & 6 deletions texar/data/vocabulary.py
@@ -139,17 +139,14 @@ def load(self, filename: str) \
vocab = [self._pad_token, self._bos_token, self._eos_token,
self._unk_token] + vocab
# Must make sure this is consistent with the above line
- unk_token_idx = 3
vocab_size = len(vocab)
vocab_idx = np.arange(vocab_size)

# Creates python maps to interface with python code
- id_to_token_map_py = _make_defaultdict(vocab_idx, vocab,
-                                        self._unk_token)
- token_to_id_map_py = _make_defaultdict(vocab, vocab_idx,
-                                        unk_token_idx)
+ id_to_token_map_py = dict(zip(vocab_idx, vocab))
+ token_to_id_map_py = dict(zip(vocab, vocab_idx))

- return id_to_token_map_py, token_to_id_map_py  # type: ignore
+ return id_to_token_map_py, token_to_id_map_py

def map_ids_to_tokens_py(self, ids: np.ndarray) -> np.ndarray:
r"""Maps ids into text tokens.
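With plain dicts, the UNK fallback that `_make_defaultdict` used to provide
has to happen at lookup time instead. A minimal sketch of such a lookup (the
real `map_tokens_to_ids_py` used in the updated test below may differ):

    import numpy as np

    def map_tokens_to_ids_py(token_to_id_map, tokens, unk_token_id):
        # dict.get supplies the UNK fallback a defaultdict used to provide.
        return np.array([token_to_id_map.get(t, unk_token_id) for t in tokens])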
6 changes: 3 additions & 3 deletions texar/data/vocabulary_test.py
@@ -43,9 +43,9 @@ def test_vocab_construction(self):
# import pdb
# pdb.set_trace()
# Tests UNK token
- unk_token_id = vocab.token_to_id_map_py['new']
- unk_token_text = vocab.id_to_token_map_py[unk_token_id]
- self.assertEqual(unk_token_text, vocab.unk_token)
+ unk_token_id = vocab.map_tokens_to_ids_py(['new'])
+ unk_token_text = vocab.map_ids_to_tokens_py(unk_token_id)
+ self.assertEqual(unk_token_text[0], vocab.unk_token)


if __name__ == "__main__":
