"""
A `TextField` represents a string of text, the kind that you might want to represent with
standard word vectors, or pass through an LSTM.
"""
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Iterator
import textwrap
from overrides import overrides
from spacy.tokens import Token as SpacyToken
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
from allennlp.data.vocabulary import Vocabulary
from allennlp.nn import util
# Type alias for the tensor dictionaries a `TextField` turns into.
# There are two levels of dictionaries here: the top level is for the *key*, which aligns
# TokenIndexers with their corresponding TokenEmbedders. The bottom level is for the *objects*
# produced by a given TokenIndexer, which will be input to a particular TokenEmbedder's forward()
# method. We label these as tensors, because that's what they typically are, though they could in
# reality have arbitrary type.
TextFieldTensors = Dict[str, Dict[str, torch.Tensor]]
class TextField(SequenceField[TextFieldTensors]):
    """
    This `Field` represents a list of string tokens. Before constructing this object, you need
    to tokenize raw strings using a :class:`~allennlp.data.tokenizers.tokenizer.Tokenizer`.

    Because string tokens can be represented as indexed arrays in a number of ways, we also take a
    dictionary of :class:`~allennlp.data.token_indexers.token_indexer.TokenIndexer`
    objects that will be used to convert the tokens into indices.
    Each `TokenIndexer` could represent each token as a single ID, or a list of character IDs, or
    something else.

    This field will get converted into a dictionary of arrays, one for each `TokenIndexer`. A
    `SingleIdTokenIndexer` produces an array of shape (num_tokens,), while a
    `TokenCharactersIndexer` produces an array of shape (num_tokens, num_characters).
    """

    __slots__ = ["tokens", "_token_indexers", "_indexed_tokens"]

    def __init__(
        self, tokens: List[Token], token_indexers: Optional[Dict[str, TokenIndexer]] = None
    ) -> None:
        self.tokens = tokens
        # Indexers may be attached later (see the `token_indexers` setter), so `None` is allowed.
        self._token_indexers = token_indexers
        # Populated by `index()`; `None` means the field has not been indexed yet.
        self._indexed_tokens: Optional[Dict[str, IndexedTokenList]] = None

        if not all(isinstance(x, (Token, SpacyToken)) for x in tokens):
            raise ConfigurationError(
                "TextFields must be passed Tokens. "
                "Found: {} with types {}.".format(tokens, [type(x) for x in tokens])
            )

    @property
    def token_indexers(self) -> Dict[str, TokenIndexer]:
        # Fail loudly with guidance if indexers were never attached, since every
        # downstream method (`index`, `as_tensor`, ...) depends on them.
        if self._token_indexers is None:
            raise ValueError(
                "TextField's token_indexers have not been set.\n"
                "Did you forget to call DatasetReader.apply_token_indexers(instance) "
                "on your instance?\n"
                "If apply_token_indexers() is being called but "
                "you're still seeing this error, it may not be implemented correctly."
            )
        return self._token_indexers

    @token_indexers.setter
    def token_indexers(self, token_indexers: Dict[str, TokenIndexer]) -> None:
        self._token_indexers = token_indexers

    @overrides
    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        # Let every indexer tally the vocabulary items it cares about for each token.
        for indexer in self.token_indexers.values():
            for token in self.tokens:
                indexer.count_vocab_items(token, counter)

    @overrides
    def index(self, vocab: Vocabulary):
        # Convert tokens to indices once per indexer, keyed by indexer name.
        self._indexed_tokens = {}
        for indexer_name, indexer in self.token_indexers.items():
            self._indexed_tokens[indexer_name] = indexer.tokens_to_indices(self.tokens, vocab)

    @overrides
    def get_padding_lengths(self) -> Dict[str, int]:
        """
        The `TextField` has a list of `Tokens`, and each `Token` gets converted into arrays by
        (potentially) several `TokenIndexers`. This method gets the max length (over tokens)
        associated with each of these arrays.
        """
        if self._indexed_tokens is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before determining padding lengths."
            )

        padding_lengths = {}
        for indexer_name, indexer in self.token_indexers.items():
            indexer_lengths = indexer.get_padding_lengths(self._indexed_tokens[indexer_name])
            for key, length in indexer_lengths.items():
                # Namespace each padding key with its indexer name so keys from
                # different indexers cannot collide; `as_tensor` splits on "___"
                # to reverse this encoding.
                padding_lengths[f"{indexer_name}___{key}"] = length
        return padding_lengths

    @overrides
    def sequence_length(self) -> int:
        return len(self.tokens)

    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> TextFieldTensors:
        if self._indexed_tokens is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before calling .as_tensor()"
            )

        tensors = {}

        # Undo the "{indexer_name}___{key}" encoding from `get_padding_lengths`,
        # regrouping the flat mapping into one sub-dict per indexer.
        indexer_lengths: Dict[str, Dict[str, int]] = defaultdict(dict)
        for key, value in padding_lengths.items():
            # We want this to crash if the split fails. Should never happen, so I'm not
            # putting in a check, but if you fail on this line, open a github issue.
            indexer_name, padding_key = key.split("___")
            indexer_lengths[indexer_name][padding_key] = value

        for indexer_name, indexer in self.token_indexers.items():
            tensors[indexer_name] = indexer.as_padded_tensor_dict(
                self._indexed_tokens[indexer_name], indexer_lengths[indexer_name]
            )
        return tensors

    @overrides
    def empty_field(self):
        # An empty field shares this field's indexers (possibly still None) and is
        # marked as already indexed, with each indexer's empty token list.
        text_field = TextField([], self._token_indexers)
        text_field._indexed_tokens = {}
        if self._token_indexers is not None:
            for indexer_name, indexer in self.token_indexers.items():
                text_field._indexed_tokens[indexer_name] = indexer.get_empty_token_list()
        return text_field

    @overrides
    def batch_tensors(self, tensor_list: List[TextFieldTensors]) -> TextFieldTensors:
        # This is creating a dict of {token_indexer_name: {token_indexer_outputs: batched_tensor}}
        # for each token indexer used to index this field.
        indexer_lists: Dict[str, List[Dict[str, torch.Tensor]]] = defaultdict(list)
        for tensor_dict in tensor_list:
            for indexer_name, indexer_output in tensor_dict.items():
                indexer_lists[indexer_name].append(indexer_output)
        batched_tensors = {
            # NOTE(mattg): if an indexer has its own nested structure, rather than one tensor per
            # argument, then this will break. If that ever happens, we should move this to an
            # `indexer.batch_tensors` method, with this logic as the default implementation in the
            # base class.
            indexer_name: util.batch_tensor_dicts(indexer_outputs)
            for indexer_name, indexer_outputs in indexer_lists.items()
        }
        return batched_tensors

    def __str__(self) -> str:
        # Double tab to indent under the header.
        formatted_text = "".join(
            "\t\t" + text + "\n" for text in textwrap.wrap(repr(self.tokens), 100)
        )
        if self._token_indexers is not None:
            indexers = {
                name: indexer.__class__.__name__ for name, indexer in self._token_indexers.items()
            }
            return (
                f"TextField of length {self.sequence_length()} with "
                f"text: \n {formatted_text} \t\tand TokenIndexers : {indexers}"
            )
        else:
            return f"TextField of length {self.sequence_length()} with text: \n {formatted_text}"

    # Sequence[Token] methods
    def __iter__(self) -> Iterator[Token]:
        return iter(self.tokens)

    def __getitem__(self, idx: int) -> Token:
        return self.tokens[idx]

    def __len__(self) -> int:
        return len(self.tokens)

    @overrides
    def duplicate(self):
        """
        Overrides the behavior of `duplicate` so that `self._token_indexers` won't
        actually be deep-copied.

        Not only would it be extremely inefficient to deep-copy the token indexers,
        but it also fails in many cases since some tokenizers (like those used in
        the 'transformers' lib) cannot actually be deep-copied.
        """
        if self._token_indexers is not None:
            # `dict(...)` is the idiomatic shallow copy (replaces a redundant identity
            # dict comprehension): the new field gets its own mapping while the
            # indexer objects themselves are shared, not deep-copied.
            new = TextField(deepcopy(self.tokens), dict(self._token_indexers))
        else:
            new = TextField(deepcopy(self.tokens))
        new._indexed_tokens = deepcopy(self._indexed_tokens)
        return new

    @overrides
    def human_readable_repr(self) -> List[str]:
        return [str(t) for t in self.tokens]