This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
list_field.py
120 lines (96 loc) · 5.1 KB
/
list_field.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Dict, List, Iterator
from overrides import overrides
from allennlp.data.fields.field import DataArray, Field
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.common.util import pad_sequence_to_length
class ListField(SequenceField[DataArray]):
"""
A `ListField` is a list of other fields. You would use this to represent, e.g., a list of
answer options that are themselves `TextFields`.
This field will get converted into a tensor that has one more mode than the items in the list.
If this is a list of `TextFields` that have shape (num_words, num_characters), this
`ListField` will output a tensor of shape (num_sentences, num_words, num_characters).
# Parameters
field_list : `List[Field]`
A list of `Field` objects to be concatenated into a single input tensor. All of the
contained `Field` objects must be of the same type.
"""
__slots__ = ["field_list"]
def __init__(self, field_list: List[Field]) -> None:
field_class_set = {field.__class__ for field in field_list}
assert (
len(field_class_set) == 1
), "ListFields must contain a single field type, found " + str(field_class_set)
# Not sure why mypy has a hard time with this type...
self.field_list: List[Field] = field_list
# Sequence[Field] methods
def __iter__(self) -> Iterator[Field]:
return iter(self.field_list)
def __getitem__(self, idx: int) -> Field:
return self.field_list[idx]
def __len__(self) -> int:
return len(self.field_list)
@overrides
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
for field in self.field_list:
field.count_vocab_items(counter)
@overrides
def index(self, vocab: Vocabulary):
for field in self.field_list:
field.index(vocab)
@overrides
def get_padding_lengths(self) -> Dict[str, int]:
field_lengths = [field.get_padding_lengths() for field in self.field_list]
padding_lengths = {"num_fields": len(self.field_list)}
# We take the set of all possible padding keys for all fields, rather
# than just a random key, because it is possible for fields to be empty
# when we pad ListFields.
possible_padding_keys = [
key for field_length in field_lengths for key in list(field_length.keys())
]
for key in set(possible_padding_keys):
# In order to be able to nest ListFields, we need to scope the padding length keys
# appropriately, so that nested ListFields don't all use the same "num_fields" key. So
# when we construct the dictionary from the list of fields, we add something to the
# name, and we remove it when padding the list of fields.
padding_lengths["list_" + key] = max(x[key] if key in x else 0 for x in field_lengths)
# Set minimum padding length to handle empty list fields.
for padding_key in padding_lengths:
padding_lengths[padding_key] = max(padding_lengths[padding_key], 1)
return padding_lengths
@overrides
def sequence_length(self) -> int:
return len(self.field_list)
@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray:
padded_field_list = pad_sequence_to_length(
self.field_list, padding_lengths["num_fields"], self.field_list[0].empty_field
)
# Here we're removing the scoping on the padding length keys that we added in
# `get_padding_lengths`; see the note there for more detail.
child_padding_lengths = {
key.replace("list_", "", 1): value
for key, value in padding_lengths.items()
if key.startswith("list_")
}
padded_fields = [field.as_tensor(child_padding_lengths) for field in padded_field_list]
return self.field_list[0].batch_tensors(padded_fields)
@overrides
def empty_field(self):
# Our "empty" list field will actually have a single field in the list, so that we can
# correctly construct nested lists. For example, if we have a type that is
# `ListField[ListField[LabelField]]`, we need the top-level `ListField` to know to
# construct a `ListField[LabelField]` when it's padding, and the nested `ListField` needs
# to know that it's empty objects are `LabelFields`. Having an "empty" list actually have
# length one makes this all work out, and we'll always be padding to at least length 1,
# anyway.
return ListField([self.field_list[0].empty_field()])
@overrides
def batch_tensors(self, tensor_list: List[DataArray]) -> DataArray:
# We defer to the class we're wrapping in a list to handle the batching.
return self.field_list[0].batch_tensors(tensor_list)
def __str__(self) -> str:
field_class = self.field_list[0].__class__.__name__
base_string = f"ListField of {len(self.field_list)} {field_class}s : \n"
return " ".join([base_string] + [f"\t {field} \n" for field in self.field_list])