This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

PassThroughIterator #3015

Merged: 9 commits into allenai:master, Jun 28, 2019

Conversation

rloganiv (Contributor):

This PR adds a new PassThroughIterator, which tensorizes Instances one at a time and returns them in the exact order they are created in the DatasetReader. It generalizes the LanguageModelIterator written by @nelson-liu in #2414, which is designed specifically for the task of contiguous language modeling. Since this approach seems generally useful for problems that apply stateful models to encode long sequences (see discussion #2828), and #2414 is currently blocked by the tangential (and rather thorny) issue #2373, I think it is worth adding this separately.

The only non-trivial aspect of this iterator is that it needs to remove the batch dimension introduced by calling Batch.as_tensor_dict() (since this iterator is intended for situations where batching is performed ahead of time within the DatasetReader). To do this, I've written a function which recursively squeezes the first dimension of tensors in a TensorDict (see here). While I think the function behaves sensibly on tensors and dictionaries, I am not so sure about non-tensor fields like MetadataField or ProductionRuleField. Right now I am returning them exactly as-is, but maybe if they are singleton lists it would be better to return the single element?
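
For concreteness, here is a small illustration of the leading batch dimension that Batch.as_tensor_dict() introduces and that this iterator strips (the field names and sizes here are made up):

import torch

# What Batch.as_tensor_dict() yields for a single instance: note the
# leading batch dimension of size 1.
batched = {"source": {"tokens": torch.zeros(1, 512, dtype=torch.long)}}

# What the iterator should yield instead: the same structure with the
# singleton batch dimension squeezed away.
squeezed = {"source": {"tokens": batched["source"]["tokens"].squeeze(0)}}
assert squeezed["source"]["tokens"].shape == (512,)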

@matt-gardner (Contributor) left a comment:

Largely looks good. You still need to add this to the docs; let me know if you don't know how to do that.

if isinstance(singleton, dict):
    return {key: _remove_batch_dim(value) for key, value in singleton.items()}  # type: ignore
elif isinstance(singleton, torch.Tensor):
    return singleton.squeeze(0)
# TODO(rloganiv): Not sure if this is appropriate for Fields whose as_tensor and batch_tensor
# methods produce non-tensor output (e.g. MetadataField).
else:
    return singleton
Contributor:

I think the easiest thing to do about this would be to have a test that checks that a reasonable thing happens for a MetadataField (don't worry too much about the ProductionRuleField).
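
A minimal sketch of such a test, assuming the iterator yields one tensor dict per instance (the field name and metadata value here are made up):

from allennlp.data import Instance
from allennlp.data.fields import MetadataField
from allennlp.data.iterators import PassThroughIterator

def test_metadata_is_passed_through():
    # A single instance whose only field holds non-tensor metadata.
    instance = Instance({"metadata": MetadataField({"id": 0})})
    iterator = PassThroughIterator()
    tensor_dict = next(iterator([instance], num_epochs=1))
    # MetadataField.as_tensor returns the metadata itself, so the value
    # should come through unwrapped rather than as a singleton list.
    assert tensor_dict["metadata"] == {"id": 0}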

instances are effectively passed 'straight through' the iterator.

This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
to 1, and maximum sampled per batch disabled. The only difference is that this iterator
Contributor:

s/maximum sampled/maximum samples/?

This is essentially the same as a BasicIterator with shuffling disabled, the batch size set
to 1, and maximum sampled per batch disabled. The only difference is that this iterator
removes the batch dimension. This can be useful for rare situations where batching is best
performed within the dataset reader (e.g. for continguous language modeling, or for other
Contributor:

s/continguous/continuous/

instances_per_epoch : ``int``, optional, (default = None)
If specified, each epoch will consist of precisely this many instances.
If not specified, each epoch will consist of a single pass through the dataset.
max_instances_in_memory : ``int``, optional, (default = None)
Contributor:

This is mainly for bucket iterators, so we can get more instances into memory before sorting them by size, making it more likely that batches are consistently sized. I don't think you need this parameter here, because you're handling this in the DatasetReader, not the base DataIterator.

logger.warning("PassThroughIterator does not shuffle instances. If shuffling is "
"required, please implement in your DatasetReader.")
shuffle = False
for tensor_dict in super().__call__(instances, num_epochs, shuffle):
Contributor:

You could make this even simpler and remove the need for _remove_batch_dim by just doing:

def __call__(self, instances, num_epochs, shuffle):
    for epoch in range(num_epochs):  # handle num_epochs == None here
        for instance in instances:
            yield instance.as_tensor_dict()

This means you don't get caching or epoch tracking, but it simplifies a lot of other things. I'm not sure whether we should do it this way or not, just something to think about.

rloganiv (Contributor, Author):

I like this approach. It handles the issue _remove_batch_dim has with non-tensor inputs. Also, caching will probably not be needed in most use cases, since the dataset reader is expected to perform actions like shuffling, perturbing sequence lengths, etc.
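
As a rough sketch of what reader-side shuffling could look like (a hypothetical reader; the shuffle call is the only point):

import random
from typing import Iterable

from allennlp.data import Instance
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import MetadataField

class ShufflingReader(DatasetReader):
    """Hypothetical reader that shuffles inside _read, since the
    PassThroughIterator deliberately does not."""
    def _read(self, file_path: str) -> Iterable[Instance]:
        with open(file_path) as data_file:
            lines = data_file.readlines()
        random.shuffle(lines)  # shuffle here rather than in the iterator
        for line in lines:
            yield Instance({"metadata": MetadataField(line.strip())})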

@matt-gardner (Contributor) left a comment:

LGTM!

def __init__(self):
super().__init__(batch_size=1)

def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
Contributor:

Keeping an @overrides tag here would be nice. The point is that it makes it obvious to the reader why this method exists.
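
For reference, a minimal sketch of the decorator in place (imports and class body abbreviated to the relevant method):

from typing import Iterable

from overrides import overrides

from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator

class PassThroughIterator(DataIterator):
    @overrides  # signals (and checks) that this overrides the base class method
    def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
        ...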

@matt-gardner merged commit 15a9cbe into allenai:master on Jun 28, 2019
@nelson-liu (Contributor):

@rloganiv thanks for this! Would love your thoughts on the (thorny indeed) #2373 / unblocking #2414, if you have any.

reiyw pushed a commit to reiyw/allennlp that referenced this pull request on Nov 12, 2019:
* Added PassThroughIterator

* Added test for PassThroughIterator

* Add @OVERRIDES and appease mypy.

* Appease pylint and mypy.

* Added new iterator to docs (I think...)

* Opted for simplified implementation

* Appease pylint

* Typo

* Added back in overrides decorator