import logging
from typing import Dict, List, Iterable, Tuple, Any
from overrides import overrides
from transformers.models.bert.tokenization_bert import BertTokenizer
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp_models.common.ontonotes import Ontonotes, OntonotesSentence
logger = logging.getLogger(__name__)
def _convert_tags_to_wordpiece_tags(tags: List[str], offsets: List[int]) -> List[str]:
"""
Converts a series of BIO tags to account for a wordpiece tokenizer,
extending/modifying BIO tags where appropriate to deal with words which
are split into multiple wordpieces by the tokenizer.
This is only used if you pass a `bert_model_name` to the dataset reader below.
# Parameters
tags : `List[str]`
        The BIO formatted tags to convert to wordpiece-level BIO tags.
offsets : `List[int]`
The wordpiece offsets.
# Returns
The new BIO tags.
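    # Example
    A minimal sketch with hypothetical inputs: for the words
    ["she", "annotate", "it"], where "annotate" is split into two wordpieces,
    the end offsets are [1, 3, 4] and the tags expand as follows:
    >>> _convert_tags_to_wordpiece_tags(["B-ARG0", "B-V", "O"], [1, 3, 4])
    ['O', 'B-ARG0', 'B-V', 'I-V', 'O', 'O']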
"""
new_tags = []
j = 0
for i, offset in enumerate(offsets):
tag = tags[i]
is_o = tag == "O"
is_start = True
while j < offset:
if is_o:
new_tags.append("O")
elif tag.startswith("I"):
new_tags.append(tag)
elif is_start and tag.startswith("B"):
new_tags.append(tag)
is_start = False
elif tag.startswith("B"):
_, label = tag.split("-", 1)
new_tags.append("I-" + label)
j += 1
# Add O tags for cls and sep tokens.
return ["O"] + new_tags + ["O"]
def _convert_verb_indices_to_wordpiece_indices(verb_indices: List[int], offsets: List[int]):
"""
Converts binary verb indicators to account for a wordpiece tokenizer,
    repeating the indicators where appropriate to deal with words which
are split into multiple wordpieces by the tokenizer.
This is only used if you pass a `bert_model_name` to the dataset reader below.
# Parameters
verb_indices : `List[int]`
The binary verb indicators, 0 for not a verb, 1 for verb.
offsets : `List[int]`
The wordpiece offsets.
# Returns
The new verb indices.
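    # Example
    A minimal sketch using the same hypothetical offsets as above, where the
    second word is the verb and is split into two wordpieces:
    >>> _convert_verb_indices_to_wordpiece_indices([0, 1, 0], [1, 3, 4])
    [0, 0, 1, 1, 0, 0]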
"""
j = 0
new_verb_indices = []
for i, offset in enumerate(offsets):
indicator = verb_indices[i]
while j < offset:
new_verb_indices.append(indicator)
j += 1
# Add 0 indicators for cls and sep tokens.
return [0] + new_verb_indices + [0]
@DatasetReader.register("srl")
class SrlReader(DatasetReader):
"""
This DatasetReader is designed to read in the English OntoNotes v5.0 data
for semantic role labelling. It returns a dataset of instances with the
following fields:
tokens : `TextField`
The tokens in the sentence.
verb_indicator : `SequenceLabelField`
A sequence of binary indicators for whether the word is the verb for this frame.
tags : `SequenceLabelField`
        A sequence of PropBank tags for the given verb in a BIO format.
# Parameters
token_indexers : `Dict[str, TokenIndexer]`, optional
        We use this to define the input representation for the text. See :class:`TokenIndexer`.
Default is `{"tokens": SingleIdTokenIndexer()}`.
domain_identifier : `str`, (default = `None`)
A string denoting a sub-domain of the Ontonotes 5.0 dataset to use. If present, only
conll files under paths containing this domain identifier will be processed.
bert_model_name : `Optional[str]`, (default = `None`)
        The name of the BERT model to use. If you specify a `bert_model_name` here,
        we assume you want to use BERT throughout; we will use the BERT wordpiece
        tokenizer, and will expand your tags and verb indicators accordingly. If not,
        the tokens will be indexed as normal with the `token_indexers`.
# Returns
A `Dataset` of `Instances` for Semantic Role Labelling.
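    # Example
    A minimal usage sketch (the model name and data path are illustrative, not
    fixed by this reader):
        reader = SrlReader(bert_model_name="bert-base-uncased")
        instances = reader.read("/path/to/conll-formatted-ontonotes-5.0/train")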
"""
def __init__(
self,
token_indexers: Dict[str, TokenIndexer] = None,
domain_identifier: str = None,
bert_model_name: str = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if token_indexers is not None:
self._token_indexers = token_indexers
elif bert_model_name is not None:
from allennlp.data.token_indexers import PretrainedTransformerIndexer
self._token_indexers = {"tokens": PretrainedTransformerIndexer(bert_model_name)}
else:
self._token_indexers = {"tokens": SingleIdTokenIndexer()}
self._domain_identifier = domain_identifier
if bert_model_name is not None:
self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
self.lowercase_input = "uncased" in bert_model_name
else:
self.bert_tokenizer = None
self.lowercase_input = False
def _wordpiece_tokenize_input(
self, tokens: List[str]
) -> Tuple[List[str], List[int], List[int]]:
"""
Convert a list of tokens to wordpiece tokens and offsets, as well as adding
BERT CLS and SEP tokens to the beginning and end of the sentence.
A slight oddity with this function is that it also returns the wordpiece offsets
corresponding to the _start_ of words as well as the end.
        We need both of these offsets (or at least, it's easiest to use both), because we
        convert the labels to wordpiece-level tags using the end_offsets. However, when we
        are decoding a BIO sequence inside the SRL model itself, it's important that we use
        the start_offsets, because otherwise we might select an ill-formed BIO sequence from
        the wordpiece-level tags (this happens when a word is split into multiple wordpieces
        and we take the last tag of the word, which might correspond to, e.g., I-V, which
        would not be allowed as it is not preceded by a B tag).
For example:
`annotate` will be bert tokenized as ["anno", "##tate"].
        If this is tagged as [B-V, I-V], as it should be, we need to select the
        _first_ wordpiece label to be the label for the token, because otherwise
        we may end up with invalid tag sequences (we cannot start a new span with an I tag).
# Returns
wordpieces : `List[str]`
The BERT wordpieces from the words in the sentence.
end_offsets : `List[int]`
Indices into wordpieces such that `[wordpieces[i] for i in end_offsets]`
results in the end wordpiece of each word being chosen.
start_offsets : `List[int]`
Indices into wordpieces such that `[wordpieces[i] for i in start_offsets]`
results in the start wordpiece of each word being chosen.
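        # Example
        A hedged illustration, assuming "i" and "it" each map to a single wordpiece
        and "annotate" splits into ["anno", "##tate"]:
            wordpieces    == ["[CLS]", "i", "anno", "##tate", "it", "[SEP]"]
            end_offsets   == [1, 3, 4]
            start_offsets == [1, 2, 4]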
"""
word_piece_tokens: List[str] = []
end_offsets = []
start_offsets = []
cumulative = 0
for token in tokens:
if self.lowercase_input:
token = token.lower()
word_pieces = self.bert_tokenizer.wordpiece_tokenizer.tokenize(token)
start_offsets.append(cumulative + 1)
cumulative += len(word_pieces)
end_offsets.append(cumulative)
word_piece_tokens.extend(word_pieces)
wordpieces = ["[CLS]"] + word_piece_tokens + ["[SEP]"]
return wordpieces, end_offsets, start_offsets
@overrides
def _read(self, file_path: str):
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)
ontonotes_reader = Ontonotes()
logger.info("Reading SRL instances from dataset files at: %s", file_path)
if self._domain_identifier is not None:
logger.info(
"Filtering to only include file paths containing the %s domain",
self._domain_identifier,
)
for sentence in self._ontonotes_subset(
ontonotes_reader, file_path, self._domain_identifier
):
tokens = [Token(t) for t in sentence.words]
if not sentence.srl_frames:
# Sentence contains no predicates.
tags = ["O" for _ in tokens]
verb_label = [0 for _ in tokens]
yield self.text_to_instance(tokens, verb_label, tags)
else:
for (_, tags) in sentence.srl_frames:
verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
yield self.text_to_instance(tokens, verb_indicator, tags)
@staticmethod
def _ontonotes_subset(
ontonotes_reader: Ontonotes, file_path: str, domain_identifier: str
) -> Iterable[OntonotesSentence]:
"""
Iterates over the Ontonotes 5.0 dataset using an optional domain identifier.
If the domain identifier is present, only examples which contain the domain
identifier in the file path are yielded.
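        For example, with a hypothetical `domain_identifier` of `"bc"`, only conll
        files whose path contains `/bc/` are yielded.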
"""
for conll_file in ontonotes_reader.dataset_path_iterator(file_path):
if domain_identifier is None or f"/{domain_identifier}/" in conll_file:
yield from ontonotes_reader.sentence_iterator(conll_file)
def text_to_instance( # type: ignore
self, tokens: List[Token], verb_label: List[int], tags: List[str] = None
) -> Instance:
"""
We take `pre-tokenized` input here, along with a verb label. The verb label should be a
one-hot binary vector, the same length as the tokens, indicating the position of the verb
to find arguments for.
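        # Example
        A minimal sketch, assuming a reader constructed without `bert_model_name`
        (the sentence and frame below are illustrative):
        >>> reader = SrlReader()
        >>> tokens = [Token(t) for t in ["The", "cat", "sat", "."]]
        >>> instance = reader.text_to_instance(
        ...     tokens, [0, 0, 1, 0], ["B-ARG1", "I-ARG1", "B-V", "O"]
        ... )
        >>> sorted(instance.fields.keys())
        ['metadata', 'tags', 'tokens', 'verb_indicator']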
"""
metadata_dict: Dict[str, Any] = {}
if self.bert_tokenizer is not None:
wordpieces, offsets, start_offsets = self._wordpiece_tokenize_input(
[t.text for t in tokens]
)
new_verbs = _convert_verb_indices_to_wordpiece_indices(verb_label, offsets)
metadata_dict["offsets"] = start_offsets
# In order to override the indexing mechanism, we need to set the `text_id`
# attribute directly. This causes the indexing to use this id.
text_field = TextField(
[Token(t, text_id=self.bert_tokenizer.vocab[t]) for t in wordpieces],
token_indexers=self._token_indexers,
)
verb_indicator = SequenceLabelField(new_verbs, text_field)
else:
text_field = TextField(tokens, token_indexers=self._token_indexers)
verb_indicator = SequenceLabelField(verb_label, text_field)
fields: Dict[str, Field] = {}
fields["tokens"] = text_field
fields["verb_indicator"] = verb_indicator
if all(x == 0 for x in verb_label):
verb = None
verb_index = None
else:
verb_index = verb_label.index(1)
verb = tokens[verb_index].text
metadata_dict["words"] = [x.text for x in tokens]
metadata_dict["verb"] = verb
metadata_dict["verb_index"] = verb_index
if tags:
if self.bert_tokenizer is not None:
new_tags = _convert_tags_to_wordpiece_tags(tags, offsets)
fields["tags"] = SequenceLabelField(new_tags, text_field)
else:
fields["tags"] = SequenceLabelField(tags, text_field)
metadata_dict["gold_tags"] = tags
fields["metadata"] = MetadataField(metadata_dict)
return Instance(fields)