This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

text2text.py data pipeline inefficiency #243

Closed
vsuthichai opened this issue Oct 1, 2018 · 3 comments

Comments

@vsuthichai
Contributor

vsuthichai commented Oct 1, 2018

There is an inefficiency in the build_graph method of text2text.py: shard is applied after cache, so every worker tokenizes and caches the full dataset before discarding everything outside its own shard. Rewriting the pipeline so that shard is called before map and cache gives a sizable decrease in time per step when training the big Transformer on WMT14 en-de. Happy to provide a pull request. On 4 nodes, 32 GPUs, batch size 128, iter size 16, mixed-precision training on 16 GB Voltas, time per step drops from 13.9 s without the fix to 6.6 s with the following change to the dataset pipeline.

  def build_graph(self):
    _sources = tf.data.TextLineDataset(self.source_file)

    # Shard the raw text lines first, so each worker only tokenizes,
    # filters, and caches its own 1/num_workers slice of the corpus.
    if self._num_workers > 1:
      _sources = _sources.shard(num_shards=self._num_workers,
                                index=self._worker_id)

    _sources = _sources.map(
        lambda line: tf.py_func(func=self._src_token_to_id, inp=[line],
                                Tout=[tf.int32], stateful=False),
        num_parallel_calls=self._map_parallel_calls,
    ).map(
        lambda tokens: (tokens, tf.size(tokens)),
        num_parallel_calls=self._map_parallel_calls,
    )

    _targets = tf.data.TextLineDataset(self.target_file)

    # Shard the target side with the same num_shards and index, so source
    # and target lines stay aligned after sharding.
    if self._num_workers > 1:
      _targets = _targets.shard(num_shards=self._num_workers,
                                index=self._worker_id)

    _targets = _targets.map(
        lambda line: tf.py_func(func=self._tgt_token_to_id, inp=[line],
                                Tout=[tf.int32], stateful=False),
        num_parallel_calls=self._map_parallel_calls,
    ).map(
        lambda tokens: (tokens, tf.size(tokens)),
        num_parallel_calls=self._map_parallel_calls,
    )

    # Cache only after sharding, tokenizing, and filtering, so the cache
    # holds just this worker's already-processed examples.
    _src_tgt_dataset = tf.data.Dataset.zip((_sources, _targets)).filter(
      lambda t1, t2: tf.logical_and(tf.less_equal(t1[1], self.max_len),
                                    tf.less_equal(t2[1], self.max_len))
    ).cache()

#204
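For anyone skimming, the gist of the change, stripped of the OpenSeq2Seq specifics, is just the order of the tf.data transformations. A minimal sketch with toy data (the hypothetical expensive_fn below stands in for the per-line py_func tokenization; it is not repo code):

  import tensorflow as tf

  num_workers, worker_id = 4, 0
  expensive_fn = lambda x: x * 2  # stand-in for the real per-line tokenization

  # Slow: every worker maps and caches the FULL dataset, then throws away
  # (num_workers - 1) / num_workers of that work when shard() runs last.
  slow = tf.data.Dataset.range(1000).map(expensive_fn).cache() \
                        .shard(num_workers, worker_id)

  # Fast: shard first, so each worker only maps and caches its own slice.
  fast = tf.data.Dataset.range(1000).shard(num_workers, worker_id) \
                        .map(expensive_fn).cache()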

@okuchaiev
Member

okuchaiev commented Oct 10, 2018

Thanks @vsuthichai! It does seem to make a big difference when Horovod isn't used.
Is this what you meant: #246?
My only concern is whether shard is deterministic (it seems to be), so that the mapping between src and tgt is preserved.
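(As a quick sanity check of that property, separate from the repo code: shard(num_shards, index) deterministically keeps every num_shards-th element starting at index, so applying it with the same arguments to both sides of the zip keeps source and target lines paired. The sketch below uses TF 2 eager style with toy in-memory data; the TF 1.x graph mode the repo uses would need an iterator/session, but shard's ordering behavior is the same.)

  import tensorflow as tf

  src = tf.data.Dataset.from_tensor_slices([b"s0", b"s1", b"s2", b"s3", b"s4", b"s5"])
  tgt = tf.data.Dataset.from_tensor_slices([b"t0", b"t1", b"t2", b"t3", b"t4", b"t5"])

  num_workers = 2
  for worker_id in range(num_workers):
      pairs = tf.data.Dataset.zip((src.shard(num_workers, worker_id),
                                   tgt.shard(num_workers, worker_id)))
      print(worker_id, [(s.numpy(), t.numpy()) for s, t in pairs])
  # worker 0 sees (s0, t0), (s2, t2), (s4, t4); worker 1 sees the odd-numbered pairs.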

@vsuthichai
Contributor Author

@okuchaiev Sorry, I didn't specify this earlier. All experiments I've been doing are using Horovod, but yes, that appears to be the correct fix 👍

@okuchaiev
Member

thx!

btw, we changed the license to Apache 2.0 so that it is easier for external people to send PRs: https://github.com/NVIDIA/OpenSeq2Seq/blob/18.11-dev/LICENSE

(it currently shows MIT but this is because we haven't "released" 18.11 yet)

okuchaiev added a commit that referenced this issue Oct 10, 2018