This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
batch.py
185 lines (158 loc) · 9 KB
/
batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
A :class:`Batch` represents a collection of `Instance` s to be fed
through a model.
"""
import logging
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List, Union
import numpy
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import ensure_list
from allennlp.data.instance import Instance
from allennlp.data.vocabulary import Vocabulary
logger = logging.getLogger(__name__)
class Batch(Iterable):
    """
    A batch of Instances. In addition to containing the instances themselves,
    it contains helper functions for converting the data into tensors.

    A Batch just takes an iterable of instances in its constructor and hangs onto them
    in a list.
    """

    __slots__ = ["instances"]

    def __init__(self, instances: Iterable[Instance]) -> None:
        """
        Materialize `instances` with `ensure_list` and verify that they are homogeneous.

        # Parameters

        instances : `Iterable[Instance]`
            The instances that make up this batch.

        # Raises

        ConfigurationError
            If the instances do not all share the same field names and field types.
        """
        super().__init__()
        self.instances = ensure_list(instances)
        self._check_types()

    def _check_types(self) -> None:
        """
        Check that all the instances have the same types.

        Every instance is summarised as a `field name -> field class name` mapping, and each
        mapping is compared against the first one. An empty batch trivially passes.

        # Raises

        ConfigurationError
            If any instance's mapping differs from the first instance's mapping.
        """
        field_types_per_instance: List[Dict[str, str]] = [
            {name: field.__class__.__name__ for name, field in instance.fields.items()}
            for instance in self.instances
        ]
        if not field_types_per_instance:
            return
        reference = field_types_per_instance[0]
        # Check all the field names and Field types are the same for every instance.
        if any(field_types != reference for field_types in field_types_per_instance):
            raise ConfigurationError("You cannot construct a Batch with non-homogeneous Instances.")

    def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
        """
        Gets the maximum padding lengths from all `Instances` in this batch.  Each `Instance`
        has multiple `Fields`, and each `Field` could have multiple things that need padding.
        We look at all fields in all instances, and find the max values for each (field_name,
        padding_key) pair, returning them in a dictionary.

        This can then be used to convert this batch into arrays of consistent length, or to set
        model parameters, etc.

        # Returns

        `Dict[str, Dict[str, int]]`
            A plain dictionary keyed by field name, then by padding key. The padding keys
            considered for a field are those reported by the first instance; an instance that
            lacks one of them counts as 0 for that key. Fields that report no padding keys do
            not appear in the result.
        """
        # Group the per-instance lengths by field name.
        all_field_lengths: Dict[str, List[Dict[str, int]]] = defaultdict(list)
        for instance in self.instances:
            for field_name, instance_field_lengths in instance.get_padding_lengths().items():
                all_field_lengths[field_name].append(instance_field_lengths)

        # A field only gets an entry once one of its padding keys is written, which is why
        # fields without padding keys are absent from the result.
        padding_lengths: Dict[str, Dict[str, int]] = defaultdict(dict)
        for field_name, field_lengths in all_field_lengths.items():
            for padding_key in field_lengths[0]:
                padding_lengths[field_name][padding_key] = max(
                    lengths.get(padding_key, 0) for lengths in field_lengths
                )
        return dict(padding_lengths)

    def as_tensor_dict(
        self, padding_lengths: Dict[str, Dict[str, int]] = None, verbose: bool = False
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        # This complex return type is actually predefined elsewhere as a DataArray,
        # but we can't use it because mypy doesn't like it.
        """
        This method converts this `Batch` into a set of pytorch Tensors that can be passed
        through a model.  In order for the tensors to be valid tensors, all `Instances` in this
        batch need to be padded to the same lengths wherever padding is necessary, so we do that
        first, then we combine all of the tensors for each field in each instance into a set of
        batched tensors for each field.

        # Parameters

        padding_lengths : `Dict[str, Dict[str, int]]`
            If a key is present in this dictionary with a non-`None` value, we will pad to that
            length instead of the length calculated from the data.  This lets you, e.g., set a
            maximum value for sentence length if you want to throw out long sequences.

            Entries in this dictionary are keyed first by field name (e.g., "question"), then by
            padding key (e.g., "num_tokens"). Fields and padding keys that are missing fall back
            to the lengths computed from the data. The dictionary is only read, never modified.
        verbose : `bool`, optional (default=`False`)
            Should we output logging information when we're doing this padding?  If the batch is
            large, this is nice to have, because padding a large batch could take a long time.
            But if you're doing this inside of a data generator, having all of this output per
            batch is a bit obnoxious (and really slow).

        # Returns

        tensors : `Dict[str, DataArray]`
            A dictionary of tensors, keyed by field name, suitable for passing as input to a model.
            This is a `batch` of instances, so, e.g., if the instances have a "question" field and
            an "answer" field, the "question" fields for all of the instances will be grouped
            together into a single tensor, and the "answer" fields for all instances will be
            similarly grouped in a parallel set of tensors, for batched computation. Additionally,
            for complex `Fields`, the value of the dictionary key is not necessarily a single
            tensor.  For example, with the `TextField`, the output is a dictionary mapping
            `TokenIndexer` keys to tensors. The number of elements in this sub-dictionary
            therefore corresponds to the number of `TokenIndexers` used to index the
            `TextField`.  Each `Field` class is responsible for batching its own output.
            An empty batch yields an empty dictionary.
        """
        # With no instances there is nothing to pad and no first instance to take the field
        # classes from, so short-circuit instead of failing on `self.instances[0]`.
        if not self.instances:
            return {}

        # Only ever read from the caller's mapping (via `.get`), so a partial plain dict does
        # not raise `KeyError` and a `defaultdict` does not grow new empty entries.
        requested_lengths = padding_lengths or {}

        # First we need to decide _how much_ to pad.  To do that, we find the max length for all
        # relevant padding decisions from the instances themselves.  Then we check whether we were
        # given a max length for a particular field and padding key.  If we were, we use that
        # instead of the instance-based one.
        if verbose:
            logger.info(
                f"Padding batch of size {len(self.instances)} to lengths {requested_lengths}"
            )
            logger.info("Getting max lengths from instances")
        instance_padding_lengths = self.get_padding_lengths()
        if verbose:
            logger.info(f"Instance max lengths: {instance_padding_lengths}")

        # Kept as a `defaultdict` so that looking up a field without padding keys yields `{}`.
        lengths_to_use: Dict[str, Dict[str, int]] = defaultdict(dict)
        for field_name, instance_field_lengths in instance_padding_lengths.items():
            field_overrides = requested_lengths.get(field_name) or {}
            for padding_key, instance_length in instance_field_lengths.items():
                override = field_overrides.get(padding_key)
                # A missing or `None` override means "use the length found in the data".
                lengths_to_use[field_name][padding_key] = (
                    instance_length if override is None else override
                )

        # Now we actually pad the instances to tensors.
        field_tensors: Dict[str, list] = defaultdict(list)
        if verbose:
            logger.info(f"Now actually padding instances to length: {lengths_to_use}")
        for instance in self.instances:
            for field, tensors in instance.as_tensor_dict(lengths_to_use).items():
                field_tensors[field].append(tensors)

        # Finally, we combine the tensors that we got for each instance into one big tensor (or set
        # of tensors) per field.  The `Field` classes themselves have the logic for batching the
        # tensors together, so we grab a dictionary of field_name -> field class from the first
        # instance in the batch.
        field_classes = self.instances[0].fields
        return {
            field_name: field_classes[field_name].batch_tensors(field_tensor_list)
            for field_name, field_tensor_list in field_tensors.items()
        }

    def __iter__(self) -> Iterator[Instance]:
        """Iterate over the instances held by this batch."""
        return iter(self.instances)

    def index_instances(self, vocab: Vocabulary) -> None:
        """Call `index_fields(vocab)` on every instance in this batch."""
        for instance in self.instances:
            instance.index_fields(vocab)

    def print_statistics(self) -> None:
        """
        Print summary statistics of the padding lengths in this batch to stdout, followed by
        ten randomly chosen instances (drawn with replacement, so repeats are possible).
        For an empty batch only the headers are printed.

        # Raises

        ConfigurationError
            If any instance has not been indexed yet.
        """
        # Make sure if has been indexed first
        sequence_field_lengths: Dict[str, List] = defaultdict(list)
        for instance in self.instances:
            if not instance.indexed:
                raise ConfigurationError(
                    "Instances must be indexed with vocabulary "
                    "before asking to print dataset statistics."
                )
            for field, field_padding_lengths in instance.get_padding_lengths().items():
                for key, value in field_padding_lengths.items():
                    sequence_field_lengths[f"{field}.{key}"].append(value)

        print("\n\n----Dataset Statistics----\n")
        for name, lengths in sequence_field_lengths.items():
            print(f"Statistics for {name}:")
            print(
                f"\tLengths: Mean: {numpy.mean(lengths)}, Standard Dev: {numpy.std(lengths)}, "
                f"Max: {numpy.max(lengths)}, Min: {numpy.min(lengths)}"
            )

        print("\n10 Random instances:")
        # `numpy.random.randint(0, ...)` raises `ValueError`, so there is nothing to sample
        # from an empty batch.
        if not self.instances:
            return
        for i in numpy.random.randint(len(self.instances), size=10):
            print(f"Instance {i}:")
            print(f"\t{self.instances[i]}")