This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
qanet.py
294 lines (256 loc) · 13.6 KB
/
qanet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
from typing import Any, Dict, List, Optional
import torch
from torch.nn.functional import nll_loss
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Highway
from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
from allennlp.nn.util import masked_softmax
from allennlp_models.rc.models.utils import (
get_best_span,
replace_masked_values_with_big_negative_number,
)
from allennlp_models.rc.metrics import SquadEmAndF1
@Model.register("qanet")
@Model.register("rc-qanet")
class QaNet(Model):
    """
    This class implements Adams Wei Yu's `QANet Model <https://openreview.net/forum?id=B14TlG-RW>`_
    for machine reading comprehension published at ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main difference is that QANet
    replaces the RNN encoder with CNN + self-attention. There are also some minor differences in the
    modeling layer and output layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the passage-question attention. The same encoder (shared weights) is applied to
        both the question and the passage.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end. It is applied three times in a row with
        shared weights, each time on its own previous output, so its input and output
        dimensions must be equal.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between layers.
        Otherwise an identity module is used in place of dropout.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        # Map raw embeddings to the phrase layer's input width, then refine them with a
        # highway network. Both are shared between question and passage in ``forward``.
        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim, encoding_in_dim, bias=False)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        # ``forward`` concatenates four ``encoding_out_dim``-wide attention features.
        self._modeling_proj_layer = torch.nn.Linear(
            encoding_out_dim * 4, modeling_in_dim, bias=False
        )
        self._modeling_layer = modeling_layer

        # Each predictor sees the concatenation of two modeling-layer outputs.
        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()

        # Use ``torch.nn.Identity`` instead of a local lambda when dropout is disabled: a lambda
        # stored on the module makes the whole model unpicklable (e.g. ``torch.save(model)``).
        # Identity has no parameters, so ``state_dict`` and parameter order are unaffected.
        self._dropout = (
            torch.nn.Dropout(p=dropout_prob) if dropout_prob > 0 else torch.nn.Identity()
        )

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``. The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``. This is one of the things we are trying to predict - the
            beginning position of the answer within the passage. This is an `inclusive` token
            index. If this is given, we will compute a loss that gets included in the output
            dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``. This is one of the things we are trying to predict - the
            ending position of the answer within the passage. This is an `inclusive` token index.
            It must be provided whenever ``span_start`` is provided.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch. The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
            An optional ``answer_texts`` key (list of gold answer strings) enables the SQuAD
            EM/F1 metric for that instance.

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            Masked softmax of the passage-question similarity over question positions, shape
            ``(batch_size, passage_length, question_length)``.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position. Padded positions hold a large negative value.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive). Padded positions hold a large
            negative value.
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised: the sum of the start and end negative
            log-likelihoods. Only present when ``span_start`` is given.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        question_tokens : List
            Only present with ``metadata``: each instance's ``question_tokens`` entry.
        passage_tokens : List
            Only present with ``metadata``: each instance's ``passage_tokens`` entry.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # Shared embedder -> dropout -> shared projection -> shared highway for both inputs.
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        # The same phrase layer (shared weights) encodes the question and the passage.
        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask)
        )
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask)
        )

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )
        # Compose passage->question with question->passage attention to obtain
        # passage->passage weights routed through the question.
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            )
        )

        # Run the weight-shared modeling layer three times, each on the previous output.
        # The list ends up as [projected input, output 1, output 2, output 3].
        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask)
            )
            modeled_passage_list.append(modeled_passage)

        # Start uses modeling outputs 1 and 2; end uses outputs 1 and 3.
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Replace padded positions with a big negative number so they receive
        # (near-)zero probability after the softmax below.
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, passage_mask
        )
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, passage_mask
        )

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)
            )
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)
            )
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                instance_metadata = metadata[i]
                question_tokens.append(instance_metadata["question_tokens"])
                passage_tokens.append(instance_metadata["passage_tokens"])
                passage_str = instance_metadata["original_passage"]
                offsets = instance_metadata["token_offsets"]
                # Map the predicted inclusive token span back to character offsets.
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = instance_metadata.get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the span accuracies and SQuAD exact-match / F1 scores.

        Parameters
        ----------
        reset : ``bool``, optional (default=False)
            Forwarded unchanged to every metric's ``get_metric`` call.

        Returns
        -------
        A dictionary with keys ``start_acc``, ``end_acc``, ``span_acc``, ``em`` and ``f1``.
        """
        exact_match, f1_score = self._metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    # Registered name of the predictor used for this model by default
    # (presumably consumed by AllenNLP's predictor lookup — confirm).
    default_predictor = "reading_comprehension"