This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
bidaf.py
380 lines (343 loc) · 18.2 KB
/
bidaf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
import logging
from typing import Any, Dict, List, Optional
import torch
from torch.nn.functional import nll_loss
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Highway
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
from allennlp.modules.matrix_attention import MatrixAttention
from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
from allennlp_models.rc.metrics import SquadEmAndF1
from allennlp_models.rc.models.utils import (
get_best_span,
replace_masked_values_with_big_negative_number,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
@Model.register("bidaf")
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's [Bidirectional Attention Flow model]
    (https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d)
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words with the configured text field embedder
    (typically word embeddings plus a character-level encoder), pass the word representations
    through a bi-LSTM/GRU, use a matrix of attentions to put question information into the
    passage word representations (this is the only part that is at all non-standard), pass this
    through another few layers of bi-LSTMs/GRUs, and do a softmax over span start and span end.

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : `int`
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : `Seq2SeqEncoder`
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention. Its input dim must equal the embedder's output dim.
    matrix_attention : `MatrixAttention`
        The attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : `Seq2SeqEncoder`
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end. Its input dim must be
        ``4 * phrase_layer.get_output_dim()``.
    span_end_encoder : `Seq2SeqEncoder`
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end. Its input dim must be
        ``4 * phrase_layer.get_output_dim() + 3 * modeling_layer.get_output_dim()``.
    dropout : `float`, optional (default=`0.2`)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer) and on the span predictor inputs.
    mask_lstms : `bool`, optional (default=`True`)
        If ``False``, we will skip passing the mask to the LSTM layers. This gives a ~2x speedup,
        with only a slight performance decrease, if any. We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times. We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    regularizer : `RegularizerApplicator`, optional (default=`None`)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        dropout: float = 0.2,
        mask_lstms: bool = True,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers)
        )
        self._phrase_layer = phrase_layer
        self._matrix_attention = matrix_attention
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()

        # The span start predictor sees the merged attention features plus the modeled passage.
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        # The span end predictor sees the merged attention features plus the span end encoding.
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(
            modeling_layer.get_input_dim(),
            4 * encoding_dim,
            "modeling layer input dim",
            "4 * encoding dim",
        )
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            phrase_layer.get_input_dim(),
            "text field embedder output dim",
            "phrase layer input dim",
        )
        check_dimensions_match(
            span_end_encoder.get_input_dim(),
            4 * encoding_dim + 3 * modeling_dim,
            "span end encoder input dim",
            "4 * encoding dim + 3 * modeling dim",
        )

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        # When dropout is disabled, use a parameter-free no-op module rather than a lambda: a
        # lambda stored on an nn.Module cannot be pickled, which breaks e.g. ``torch.save(model)``.
        self._dropout: torch.nn.Module = (
            torch.nn.Dropout(p=dropout) if dropout > 0 else torch.nn.Identity()
        )
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: Optional[torch.IntTensor] = None,
        span_end: Optional[torch.IntTensor] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question : `Dict[str, torch.LongTensor]`
            From a ``TextField``.
        passage : `Dict[str, torch.LongTensor]`
            From a ``TextField``. The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : `torch.IntTensor`, optional
            From an ``IndexField``. This is one of the things we are trying to predict - the
            beginning position of the answer within the passage. This is an `inclusive` token index.
            If this is given (together with ``span_end``), we will compute a loss that gets
            included in the output dictionary.
        span_end : `torch.IntTensor`, optional
            From an ``IndexField``. This is one of the things we are trying to predict - the
            ending position of the answer within the passage. This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch. The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
            An optional ``answer_texts`` key, when non-empty, is used to update the SQuAD metrics.

        # Returns

        An output dictionary consisting of:

        passage_question_attention : `torch.FloatTensor`
            Shape ``(batch_size, passage_length, question_length)``; the passage-to-question
            attention weights.
        span_start_logits : `torch.FloatTensor`
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : `torch.FloatTensor`
            The result of ``softmax(span_start_logits)``.
        span_end_logits : `torch.FloatTensor`
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : `torch.FloatTensor`
            The result of ``softmax(span_end_logits)``.
        best_span : `torch.IntTensor`
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        best_span_str : `List[str]`
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question. ``question_tokens``, ``passage_tokens`` and ``token_offsets`` are then also
            copied from the metadata into the output.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)
        # Masks are always used for softmaxes; they are only passed to the encoders if requested.
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = replace_masked_values_with_big_negative_number(
            passage_question_similarity, question_mask.unsqueeze(1)
        )
        # Shape: (batch_size, passage_length)
        # ``max(dim=-1)`` already removes the question axis. No extra ``squeeze(-1)``: it would
        # also remove the passage axis whenever passage_length == 1.
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
            batch_size, passage_length, encoding_dim
        )

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask)
        )
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(
            batch_size, passage_length, modeling_dim
        )

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat(
            [
                final_merged_passage,
                modeled_passage,
                tiled_start_representation,
                modeled_passage * tiled_start_representation,
            ],
            dim=-1,
        )
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask)
        )
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        # Shape: (batch_size, passage_length)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Replace the masked values with a very negative constant, so padding positions can
        # never be chosen by the span search below.
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, passage_mask
        )
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, passage_mask
        )
        # Module-level helper imported from allennlp_models.rc.models.utils.
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)
            )
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)
            )
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            token_offsets = []
            # Transfer the predicted spans to the CPU once for the whole batch instead of once
            # per instance.
            predicted_spans = best_span.detach().cpu().tolist()
            for i, (predicted_start, predicted_end) in enumerate(predicted_spans):
                instance_metadata = metadata[i]
                offsets = instance_metadata["token_offsets"]
                question_tokens.append(instance_metadata["question_tokens"])
                passage_tokens.append(instance_metadata["passage_tokens"])
                token_offsets.append(offsets)

                passage_str = instance_metadata["original_passage"]
                # Token offsets are (start_char, end_char) pairs into the original passage.
                start_offset = offsets[predicted_start][0]
                end_offset = offsets[predicted_end][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)

                answer_texts = instance_metadata.get("answer_texts", [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
            output_dict["token_offsets"] = token_offsets
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return span start/end/full-span accuracies and SQuAD EM/F1, optionally resetting them.
        """
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    @staticmethod
    def get_best_span(
        span_start_logits: torch.Tensor, span_end_logits: torch.Tensor
    ) -> torch.Tensor:
        """
        Find the highest-scoring span ``(start, end)`` with ``start <= end`` for each batch row.

        Note that ``forward`` calls the module-level ``get_best_span`` imported from
        ``allennlp_models.rc.models.utils``, not this static method; this method is kept as part
        of the class's public interface.

        # Parameters

        span_start_logits : `torch.Tensor`
            Shape ``(batch_size, passage_length)``.
        span_end_logits : `torch.Tensor`
            Shape ``(batch_size, passage_length)``.

        # Returns

        A tensor of shape ``(batch_size, 2)`` holding inclusive start and end token indices.

        # Raises

        ValueError
            If either input is not 2-dimensional.
        """
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities. A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts. log(1) = 0 keeps valid entries, log(0) = -inf drops the rest.
        span_log_mask = (
            torch.triu(torch.ones((passage_length, passage_length), device=device))
            .log()
            .unsqueeze(0)
        )
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax. We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)

    default_predictor = "reading_comprehension"