Improve handling of empty ListFields. #2697

Merged · 24 commits · May 6, 2019
5 changes: 5 additions & 0 deletions allennlp/data/fields/list_field.py
@@ -68,6 +68,11 @@ def get_padding_lengths(self) -> Dict[str, int]:
# when we construct the dictionary from the list of fields, we add something to the
# name, and we remove it when padding the list of fields.
padding_lengths['list_' + key] = max(x[key] if key in x else 0 for x in field_lengths)

# Set minimum padding length to handle empty list fields.
for padding_key in padding_lengths:
padding_lengths[padding_key] = max(padding_lengths[padding_key], 1)

return padding_lengths

@overrides
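Not part of the diff, but a minimal sketch of the behavior the clamp above enables, assuming the 0.8-era `Field` API (`index` / `get_padding_lengths` / `as_tensor`); the exact padding-length key names below are illustrative:

```python
from allennlp.data import Token, Vocabulary
from allennlp.data.fields import ListField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer

indexers = {"words": SingleIdTokenIndexer("words")}

# A ListField whose only element is an empty TextField: previously this
# reported a padding length of 0 and produced a zero-sized tensor.
empty_list = ListField([TextField([Token("foo")], indexers).empty_field()])
empty_list.index(Vocabulary())

lengths = empty_list.get_padding_lengths()
# With the clamp above, every length is at least 1,
# e.g. something like {'num_fields': 1, 'list_num_tokens': 1}.
assert all(length >= 1 for length in lengths.values())

# Tensorization now yields a minimally-sized padded tensor instead of an
# empty one that breaks downstream operations.
tensors = empty_list.as_tensor(lengths)
```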
79 changes: 77 additions & 2 deletions allennlp/tests/data/fields/list_field_test.py
@@ -1,11 +1,39 @@
# pylint: disable=no-self-use,invalid-name
# pylint: disable=no-self-use,invalid-name,arguments-differ
from typing import Dict

import numpy
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Token, Vocabulary
from allennlp.data import Token, Vocabulary, Instance
from allennlp.data.fields import TextField, LabelField, ListField, IndexField, SequenceLabelField
from allennlp.data.iterators import BasicIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenCharactersIndexer
from allennlp.data.tokenizers import WordTokenizer
from allennlp.models import Model
from allennlp.modules import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder


class DummyModel(Model):
"""
Performs a common operation (embedding) that won't work on an empty tensor.
Returns an arbitrary loss.
"""
def __init__(self, vocab: Vocabulary) -> None:
super().__init__(vocab)
weight = torch.ones(vocab.get_vocab_size(), 10)
token_embedding = Embedding(
num_embeddings=vocab.get_vocab_size(),
embedding_dim=10,
weight=weight,
trainable=False)
self.embedder = BasicTextFieldEmbedder({"words": token_embedding})

def forward(self, # type: ignore
list_tensor: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
self.embedder(list_tensor)
return {"loss": 1.0}

class TestListField(AllenNlpTestCase):
def setUp(self):
@@ -39,6 +67,17 @@ def setUp(self):
self.sequence_label_field = SequenceLabelField([1, 1, 0, 1], self.field1)
self.empty_sequence_label_field = self.sequence_label_field.empty_field()

tokenizer = WordTokenizer()
tokens = tokenizer.tokenize("Foo")
text_field = TextField(tokens, self.word_indexer)
empty_list_field = ListField([text_field.empty_field()])
empty_fields = {'list_tensor': empty_list_field}
self.empty_instance = Instance(empty_fields)

non_empty_list_field = ListField([text_field])
non_empty_fields = {'list_tensor': non_empty_list_field}
self.non_empty_instance = Instance(non_empty_fields)

super(TestListField, self).setUp()

def test_get_padding_lengths(self):
@@ -189,3 +228,39 @@ def test_sequence_methods(self):
assert len(list_field) == 3
assert list_field[1] == self.field2
assert [f for f in list_field] == [self.field1, self.field2, self.field3]

def test_empty_list_can_be_tensorized(self):
tokenizer = WordTokenizer()
tokens = tokenizer.tokenize("Foo")
text_field = TextField(tokens, self.word_indexer)
list_field = ListField([text_field.empty_field()])
fields = {'list': list_field, 'bar': TextField(tokenizer.tokenize("BAR"), self.word_indexer)}
instance = Instance(fields)
instance.index_fields(self.vocab)
instance.as_tensor_dict()

def test_batch_with_some_empty_lists_works(self):
dataset = [self.empty_instance, self.non_empty_instance]

model = DummyModel(self.vocab)
model.eval()
iterator = BasicIterator(batch_size=2)
iterator.index_with(self.vocab)
batch = next(iterator(dataset, shuffle=False))
model.forward(**batch)

# This use case may seem a bit peculiar. It's intended for situations where
# you have sparse inputs that are used as additional features for some
# prediction, and they are sparse enough that they can be empty for some
# cases. It would be silly to try to handle these as None in your model; it
# makes a whole lot more sense to just have a minimally-sized tensor that
# gets entirely masked and has no effect on the rest of the model.
def test_batch_of_entirely_empty_lists_works(self):
dataset = [self.empty_instance, self.empty_instance]

model = DummyModel(self.vocab)
model.eval()
iterator = BasicIterator(batch_size=2)
iterator.index_with(self.vocab)
batch = next(iterator(dataset, shuffle=False))
model.forward(**batch)
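
To illustrate the comment above (not part of the PR): a hedged sketch, assuming the 0.8-era `Batch` and `allennlp.nn.util.get_text_field_mask` APIs, showing that the minimally-sized tensor an empty ListField produces is entirely padding, so its mask is all zeros and it has no effect on the rest of the model.

```python
from allennlp.data import Instance, Token, Vocabulary
from allennlp.data.dataset import Batch
from allennlp.data.fields import ListField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.nn import util

indexers = {"words": SingleIdTokenIndexer("words")}
empty_text = TextField([Token("foo")], indexers).empty_field()
instance = Instance({"list_tensor": ListField([empty_text])})

batch = Batch([instance, instance])
batch.index_instances(Vocabulary())
tensors = batch.as_tensor_dict(batch.get_padding_lengths())

# Both the list and token dimensions are padded to 1, and every entry is the
# padding index, so the mask over the wrapped TextFields is all zeros: any
# masked operation (attention, pooling, etc.) ignores the placeholders.
mask = util.get_text_field_mask(tensors["list_tensor"], num_wrapping_dims=1)
assert mask.sum().item() == 0
```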