from typing import Dict, Tuple, Any, List
import logging
import copy
from overrides import overrides
import torch
from torch.nn.modules import Dropout
import numpy
from allennlp.common.checks import check_dimensions_match, ConfigurationError
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder, Embedding, InputVariationalDropout
from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention
from allennlp.modules import FeedForward
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, Activation
from allennlp.nn.util import min_value_of_dtype
from allennlp.nn.util import get_text_field_mask
from allennlp.nn.util import get_lengths_from_binary_sequence_mask
from allennlp.training.metrics import F1Measure
logger = logging.getLogger(__name__)
@Model.register("graph_parser")
@Model.register("sp-graph-parser")
class GraphParser(Model):
"""
A Parser for arbitrary graph structures.
Registered as a `Model` with name "graph_parser".
# Parameters
vocab : `Vocabulary`, required
A Vocabulary, required in order to compute sizes for input/output projections.
text_field_embedder : `TextFieldEmbedder`, required
Used to embed the `tokens` `TextField` we get as input to the model.
encoder : `Seq2SeqEncoder`
The encoder (with its own internal stacking) that we will use to generate representations
of tokens.
tag_representation_dim : `int`, required.
The dimension of the MLPs used for arc tag prediction.
arc_representation_dim : `int`, required.
The dimension of the MLPs used for arc prediction.
tag_feedforward : `FeedForward`, optional, (default = `None`).
The feedforward network used to produce tag representations.
By default, a 1 layer feedforward network with an elu activation is used.
arc_feedforward : `FeedForward`, optional, (default = `None`).
The feedforward network used to produce arc representations.
By default, a 1 layer feedforward network with an elu activation is used.
pos_tag_embedding : `Embedding`, optional.
Used to embed the `pos_tags` `SequenceLabelField` we get as input to the model.
dropout : `float`, optional, (default = `0.0`)
The variational dropout applied to the output of the encoder and MLP layers.
input_dropout : `float`, optional, (default = `0.0`)
The dropout applied to the embedded text input.
edge_prediction_threshold : `float`, optional (default = `0.5`)
The probability at which to consider a scored edge to be 'present'
in the decoded graph. Must be between 0 and 1.
initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
Used to initialize the model parameters.
"""
def __init__(
self,
vocab: Vocabulary,
text_field_embedder: TextFieldEmbedder,
encoder: Seq2SeqEncoder,
tag_representation_dim: int,
arc_representation_dim: int,
tag_feedforward: FeedForward = None,
arc_feedforward: FeedForward = None,
pos_tag_embedding: Embedding = None,
dropout: float = 0.0,
input_dropout: float = 0.0,
edge_prediction_threshold: float = 0.5,
initializer: InitializerApplicator = InitializerApplicator(),
**kwargs,
) -> None:
super().__init__(vocab, **kwargs)
self.text_field_embedder = text_field_embedder
self.encoder = encoder
self.edge_prediction_threshold = edge_prediction_threshold
if not 0 < edge_prediction_threshold < 1:
raise ConfigurationError(
f"edge_prediction_threshold must be between "
f"0 and 1 (exclusive) but found {edge_prediction_threshold}."
)
encoder_dim = encoder.get_output_dim()
self.head_arc_feedforward = arc_feedforward or FeedForward(
encoder_dim, 1, arc_representation_dim, Activation.by_name("elu")()
)
self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)
self.arc_attention = BilinearMatrixAttention(
arc_representation_dim, arc_representation_dim, use_input_biases=True
)
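# The bilinear attention produces a single scalar score for every pair of
# positions; use_input_biases=True appends a constant bias feature to both
# input matrices.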
num_labels = self.vocab.get_vocab_size("labels")
self.head_tag_feedforward = tag_feedforward or FeedForward(
encoder_dim, 1, tag_representation_dim, Activation.by_name("elu")()
)
self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)
self.tag_bilinear = BilinearMatrixAttention(
tag_representation_dim, tag_representation_dim, label_dim=num_labels
)
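# With label_dim set, this bilinear attention returns one score per label,
# i.e. a (batch_size, num_labels, sequence_length, sequence_length) tensor
# that forward() later permutes to put the label dimension last.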
self._pos_tag_embedding = pos_tag_embedding or None
self._dropout = InputVariationalDropout(dropout)
self._input_dropout = Dropout(input_dropout)
representation_dim = text_field_embedder.get_output_dim()
if pos_tag_embedding is not None:
representation_dim += pos_tag_embedding.get_output_dim()
check_dimensions_match(
representation_dim,
encoder.get_input_dim(),
"text field embedding dim",
"encoder input dim",
)
check_dimensions_match(
tag_representation_dim,
self.head_tag_feedforward.get_output_dim(),
"tag representation dim",
"tag feedforward output dim",
)
check_dimensions_match(
arc_representation_dim,
self.head_arc_feedforward.get_output_dim(),
"arc representation dim",
"arc feedforward output dim",
)
self._unlabelled_f1 = F1Measure(positive_label=1)
self._arc_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
self._tag_loss = torch.nn.CrossEntropyLoss(reduction="none")
initializer(self)
@overrides
def forward(
self, # type: ignore
tokens: TextFieldTensors,
pos_tags: torch.LongTensor = None,
metadata: List[Dict[str, Any]] = None,
arc_tags: torch.LongTensor = None,
) -> Dict[str, torch.Tensor]:
"""
# Parameters
tokens : `TextFieldTensors`, required
The output of `TextField.as_array()`.
pos_tags : `torch.LongTensor`, optional (default = `None`)
The output of a `SequenceLabelField` containing POS tags.
metadata : `List[Dict[str, Any]]`, optional (default = `None`)
A list of metadata dictionaries, one per batch element, each of which has the key:
tokens : `List[str]`, required.
The original string tokens in the sentence.
arc_tags : `torch.LongTensor`, optional (default = `None`)
A torch tensor of shape `(batch_size, sequence_length, sequence_length)` giving the integer
gold edge label for every pair of words, with -1 indicating that no edge is present.
# Returns
An output dictionary containing `arc_probs`, `arc_tag_probs` and the token `mask`, plus the
original `tokens` when metadata is passed and `loss`, `arc_loss` and `tag_loss` when gold
`arc_tags` are provided.
"""
embedded_text_input = self.text_field_embedder(tokens)
if pos_tags is not None and self._pos_tag_embedding is not None:
embedded_pos_tags = self._pos_tag_embedding(pos_tags)
embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
elif self._pos_tag_embedding is not None:
raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
mask = get_text_field_mask(tokens)
embedded_text_input = self._input_dropout(embedded_text_input)
encoded_text = self.encoder(embedded_text_input, mask)
encoded_text = self._dropout(encoded_text)
# shape (batch_size, sequence_length, arc_representation_dim)
head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))
# shape (batch_size, sequence_length, tag_representation_dim)
head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
# shape (batch_size, sequence_length, sequence_length)
arc_scores = self.arc_attention(head_arc_representation, child_arc_representation)
# shape (batch_size, num_tags, sequence_length, sequence_length)
arc_tag_logits = self.tag_bilinear(head_tag_representation, child_tag_representation)
# Switch to (batch_size, sequence_length, sequence_length, num_tags)
arc_tag_logits = arc_tag_logits.permute(0, 2, 3, 1).contiguous()
# Since this penalty is added along both axes below, using the dtype's full
# minimum value could go out of range, so we use min / 10.
minus_mask = ~mask * min_value_of_dtype(arc_scores.dtype) / 10
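# Broadcasting the penalty with unsqueeze(2) hits every cell whose row index
# is a padded position, and with unsqueeze(1) every cell whose column index is.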
arc_scores = arc_scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
arc_probs, arc_tag_probs = self._greedy_decode(arc_scores, arc_tag_logits, mask)
output_dict = {"arc_probs": arc_probs, "arc_tag_probs": arc_tag_probs, "mask": mask}
if metadata:
output_dict["tokens"] = [meta["tokens"] for meta in metadata]
if arc_tags is not None:
arc_nll, tag_nll = self._construct_loss(
arc_scores=arc_scores, arc_tag_logits=arc_tag_logits, arc_tags=arc_tags, mask=mask
)
output_dict["loss"] = arc_nll + tag_nll
output_dict["arc_loss"] = arc_nll
output_dict["tag_loss"] = tag_nll
# Build a binary indicator of which arcs are present in the gold graph
# (no edge is indicated with -1).
arc_indices = (arc_tags != -1).float()
tag_mask = mask.unsqueeze(1) & mask.unsqueeze(2)
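# Only score cells of the adjacency matrix where both positions are real
# (unpadded) tokens.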
one_minus_arc_probs = 1 - arc_probs
# We stack scores here because the f1 measure expects a
# distribution, rather than a single value.
self._unlabelled_f1(
torch.stack([one_minus_arc_probs, arc_probs], -1), arc_indices, tag_mask
)
return output_dict
@overrides
def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
arc_tag_probs = output_dict["arc_tag_probs"].cpu().detach().numpy()
arc_probs = output_dict["arc_probs"].cpu().detach().numpy()
mask = output_dict["mask"]
lengths = get_lengths_from_binary_sequence_mask(mask)
arcs = []
arc_tags = []
for instance_arc_probs, instance_arc_tag_probs, length in zip(
arc_probs, arc_tag_probs, lengths
):
arc_matrix = instance_arc_probs > self.edge_prediction_threshold
edges = []
edge_tags = []
for i in range(length):
for j in range(length):
if arc_matrix[i, j] == 1:
edges.append((i, j))
tag = instance_arc_tag_probs[i, j].argmax(-1)
edge_tags.append(self.vocab.get_token_from_index(tag, "labels"))
arcs.append(edges)
arc_tags.append(edge_tags)
output_dict["arcs"] = arcs
output_dict["arc_tags"] = arc_tags
return output_dict
def _construct_loss(
self,
arc_scores: torch.Tensor,
arc_tag_logits: torch.Tensor,
arc_tags: torch.Tensor,
mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the arc and tag loss for an adjacency matrix.
# Parameters
arc_scores : `torch.Tensor`, required.
A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a
binary classification decision for whether an edge is present between two words.
arc_tag_logits : `torch.Tensor`, required.
A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to generate
a distribution over edge tags for a given edge.
arc_tags : `torch.Tensor`, required.
A tensor of shape (batch_size, sequence_length, sequence_length).
The labels for every arc.
mask : `torch.BoolTensor`, required.
A mask of shape (batch_size, sequence_length), denoting unpadded
elements in the sequence.
# Returns
arc_nll : `torch.Tensor`, required.
The negative log likelihood from the arc loss.
tag_nll : `torch.Tensor`, required.
The negative log likelihood from the arc tag loss.
"""
arc_indices = (arc_tags != -1).float()
# Make the arc tags not have negative values anywhere
# (by default, no edge is indicated with -1).
arc_tags = arc_tags * arc_indices
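# Multiplying by arc_indices turns the -1 "no edge" entries into 0, a valid
# class index for the cross-entropy loss; tag_mask excludes those cells from
# the loss anyway.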
arc_nll = self._arc_loss(arc_scores, arc_indices) * mask.unsqueeze(1) * mask.unsqueeze(2)
# We want the mask for the tags to only include the unmasked words
# and we only care about the loss with respect to the gold arcs.
tag_mask = mask.unsqueeze(1) * mask.unsqueeze(2) * arc_indices
batch_size, sequence_length, _, num_tags = arc_tag_logits.size()
original_shape = [batch_size, sequence_length, sequence_length]
reshaped_logits = arc_tag_logits.view(-1, num_tags)
reshaped_tags = arc_tags.view(-1)
tag_nll = (
self._tag_loss(reshaped_logits, reshaped_tags.long()).view(original_shape) * tag_mask
)
valid_positions = tag_mask.sum()
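# Normalise both losses by the number of gold arcs (the 1-entries of tag_mask).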
arc_nll = arc_nll.sum() / valid_positions.float()
tag_nll = tag_nll.sum() / valid_positions.float()
return arc_nll, tag_nll
@staticmethod
def _greedy_decode(
arc_scores: torch.Tensor, arc_tag_logits: torch.Tensor, mask: torch.BoolTensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Decodes the arc and arc tag predictions independently for every pair of words:
a sigmoid over the masked arc scores gives the probability that an edge is
present, and a softmax over the tag logits gives a distribution over edge
labels for each potential edge.
# Parameters
arc_scores : `torch.Tensor`, required.
A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
a distribution over attachments of a given word to all other words.
arc_tag_logits : `torch.Tensor`, required.
A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to
generate a distribution over tags for each arc.
mask : `torch.BoolTensor`, required.
A mask of shape (batch_size, sequence_length).
# Returns
arc_probs : `torch.Tensor`
A tensor of shape (batch_size, sequence_length, sequence_length) representing the
probability of an arc being present for this edge.
arc_tag_probs : `torch.Tensor`
A tensor of shape (batch_size, sequence_length, sequence_length, num_tags)
representing the distribution over edge tags for a given edge.
"""
# Mask the diagonal, because we don't want self edges.
inf_diagonal_mask = torch.diag(arc_scores.new(mask.size(1)).fill_(-numpy.inf))
arc_scores = arc_scores + inf_diagonal_mask
# shape (batch_size, sequence_length, sequence_length, num_tags)
arc_tag_logits = arc_tag_logits + inf_diagonal_mask.unsqueeze(0).unsqueeze(-1)
# Mask padded tokens, because we only want to consider actual word -> word edges.
minus_mask = ~mask.unsqueeze(2)
arc_scores.masked_fill_(minus_mask, -numpy.inf)
arc_tag_logits.masked_fill_(minus_mask.unsqueeze(-1), -numpy.inf)
# shape (batch_size, sequence_length, sequence_length)
arc_probs = arc_scores.sigmoid()
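# The -inf entries become exactly 0 under the sigmoid, so self-edges and rows
# belonging to padded positions carry no arc probability.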
# shape (batch_size, sequence_length, sequence_length, num_tags)
arc_tag_probs = torch.nn.functional.softmax(arc_tag_logits, dim=-1)
return arc_probs, arc_tag_probs
@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
return self._unlabelled_f1.get_metric(reset)
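

# ----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It wires the
# parser up with toy components so the output tensor shapes can be inspected;
# the embedder/encoder choices, dimensions and example tokens are assumptions,
# and exact constructor signatures may differ across allennlp versions.
if __name__ == "__main__":
    from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    vocab = Vocabulary()
    for token in ["the", "cat", "sat"]:
        vocab.add_token_to_namespace(token, namespace="tokens")
    for label in ["ARG0", "ARG1"]:
        vocab.add_token_to_namespace(label, namespace="labels")

    embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(embedding_dim=16, num_embeddings=vocab.get_vocab_size("tokens"))}
    )
    encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(16, 8, batch_first=True, bidirectional=True)
    )
    parser = GraphParser(
        vocab=vocab,
        text_field_embedder=embedder,
        encoder=encoder,
        tag_representation_dim=10,
        arc_representation_dim=10,
    )
    # A single sentence of three tokens, using the ids assigned above
    # (0 and 1 are reserved for padding and OOV in the "tokens" namespace).
    token_ids = torch.tensor([[2, 3, 4]])
    outputs = parser({"tokens": {"tokens": token_ids}})
    print({key: value.shape for key, value in outputs.items() if isinstance(value, torch.Tensor)})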