Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
PassThroughIterator (#3015)
Browse files Browse the repository at this point in the history
* Added PassThroughIterator

* Added test for PassThroughIterator

* Add @OVERRIDES and appease mypy.

* Appease pylint and mypy.

* Added new iterator to docs (I think...)

* Opted for simplified implementation

* Appease pylint

* Typo

* Added back in overrides decorator
  • Loading branch information
rloganiv authored and matt-gardner committed Jun 28, 2019
1 parent 70fa4aa commit 15a9cbe
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 2 deletions.
1 change: 1 addition & 0 deletions allennlp/data/iterators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator
from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator
from allennlp.data.iterators.pass_through_iterator import PassThroughIterator
from allennlp.data.iterators.same_language_iterator import SameLanguageIterator
51 changes: 51 additions & 0 deletions allennlp/data/iterators/pass_through_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Iterable, Iterator
import itertools
import logging

from overrides import overrides

from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@DataIterator.register("pass_through")
class PassThroughIterator(DataIterator):
"""
An iterator which performs no batching or shuffling of instances, only tensorization. E.g,
instances are effectively passed 'straight through' the iterator.
This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
to 1, and maximum samples per batch disabled. The only difference is that this iterator
removes the batch dimension. This can be useful for rare situations where batching is best
performed within the dataset reader (e.g. for contiguous language modeling, or for other
problems where state is shared across batches).
"""
def __init__(self):
super().__init__(batch_size=1)

@overrides
def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
raise RuntimeError("PassThroughIterator doesn't use create_batches")

def __call__(self,
instances: Iterable[Instance],
num_epochs: int = None,
shuffle: bool = False) -> Iterator[TensorDict]:
# Warn users that this iterator does not do anything for you.
if shuffle:
logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
"required, please implement in your DatasetReader.")

if num_epochs is None:
epochs: Iterable[int] = itertools.count()
else:
epochs = range(num_epochs)

for _ in epochs:
for instance in instances:
if self.vocab is not None:
instance.index_fields(self.vocab)
yield instance.as_tensor_dict()
30 changes: 30 additions & 0 deletions allennlp/tests/data/iterators/pass_through_iterator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# pylint: disable=no-self-use,invalid-name
from allennlp.data.iterators.pass_through_iterator import PassThroughIterator, logger
from allennlp.tests.data.iterators.basic_iterator_test import IteratorTest


class TestPassThroughIterator(IteratorTest):
def test_get_num_batches(self):
# Since batching is assumed to be performed in the DatasetReader, the number of batches
# (according to the iterator) should always equal the number of instances.
self.assertEqual(PassThroughIterator().get_num_batches(self.instances),
len(self.instances))

def test_enabling_shuffling_raises_warning(self):
iterator = PassThroughIterator()
iterator.index_with(self.vocab)
generator = iterator(self.instances, shuffle=True)
with self.assertLogs(logger, level='INFO') as context_manager:
next(generator)
self.assertIn('WARNING', context_manager.output[0])

def test_batch_dim_is_removed(self):
# Ensure that PassThroughIterator does not add a batch dimension to tensors.

# First instance is a sequence of four tokens. Thus the expected output is a dict
# containing a single tensor with shape (4,).
iterator = PassThroughIterator()
iterator.index_with(self.vocab)
generator = iterator(self.instances)
tensor_dict = next(generator)
self.assertEqual(tensor_dict['text']['tokens'].size(), (4,))
10 changes: 8 additions & 2 deletions doc/api/allennlp.data.iterators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ allennlp.data.iterators
* :ref:`MultiprocessIterator<multiprocess-iterator>`
* :ref:`HomogeneousBatchIterator<homogeneous-batch-iterator>`
* :ref:`SameLanguageIterator<same-language-iterator>`
* :ref:`PassThroughIterator<pass-through-iterator>`

.. _data-iterator:
.. automodule:: allennlp.data.iterators.data_iterator
Expand Down Expand Up @@ -42,10 +43,15 @@ allennlp.data.iterators
:members:
:undoc-members:
:show-inheritance:

.. _same-language-iterator:
.. automodule:: allennlp.data.iterators.same_language_iterator
:members:
:undoc-members:
:show-inheritance:


.. _pass-through-iterator:
.. automodule:: allennlp.data.iterators.pass_through_iterator
:members:
:undoc-members:
:show-inheritance:

0 comments on commit 15a9cbe

Please sign in to comment.