This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
decomposable_attention.py
226 lines (190 loc) · 9.94 KB
/
decomposable_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from typing import Dict, Optional, List, Any
import torch
from allennlp.common.checks import check_dimensions_match
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward
from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.nn import InitializerApplicator
from allennlp.nn.util import get_text_field_mask, masked_softmax, weighted_sum
from allennlp.training.metrics import CategoricalAccuracy
@Model.register("decomposable_attention")
class DecomposableAttention(Model):
    """
    This `Model` implements the Decomposable Attention model described in [A Decomposable
    Attention Model for Natural Language Inference](https://api.semanticscholar.org/CorpusID:8495258)
    by Parikh et al., 2016, with some optional enhancements before the decomposable attention
    actually happens. Parikh's original model allowed for computing an "intra-sentence" attention
    before doing the decomposable entailment step. We generalize this to any
    [`Seq2SeqEncoder`](../modules/seq2seq_encoders/seq2seq_encoder.md) that can be applied to
    the premise and/or the hypothesis before computing entailment.

    The basic outline of this model is to get an embedded representation of each word in the
    premise and hypothesis, align words between the two, compare the aligned phrases, and make a
    final entailment decision based on this aggregated comparison. Each step in this process uses
    a feedforward network to modify the representation.

    Registered as a `Model` with name "decomposable_attention".

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the `premise` and `hypothesis` `TextFields` we get as input to the
        model.
    attend_feedforward : `FeedForward`
        This feedforward network is applied to the encoded sentence representations before the
        similarity matrix is computed between words in the premise and words in the hypothesis.
        Its input dimension must equal the dimension of what it is actually fed: the encoder
        output dimension when an encoder is configured for that side, otherwise the text field
        embedding dimension.
    matrix_attention : `MatrixAttention`
        This is the attention function used when computing the similarity matrix between words in
        the premise and words in the hypothesis.
    compare_feedforward : `FeedForward`
        This feedforward network is applied to the aligned premise and hypothesis representations,
        individually.
    aggregate_feedforward : `FeedForward`
        This final feedforward network is applied to the concatenated, summed result of the
        `compare_feedforward` network, and its output is used as the entailment class logits.
    premise_encoder : `Seq2SeqEncoder`, optional (default=`None`)
        After embedding the premise, we can optionally apply an encoder. If this is `None`, we
        will do nothing.
    hypothesis_encoder : `Seq2SeqEncoder`, optional (default=`None`)
        After embedding the hypothesis, we can optionally apply an encoder. If this is `None`,
        we will use the `premise_encoder` for the encoding (doing nothing if `premise_encoder`
        is also `None`).
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        attend_feedforward: FeedForward,
        matrix_attention: MatrixAttention,
        compare_feedforward: FeedForward,
        aggregate_feedforward: FeedForward,
        premise_encoder: Optional[Seq2SeqEncoder] = None,
        hypothesis_encoder: Optional[Seq2SeqEncoder] = None,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = text_field_embedder
        self._attend_feedforward = TimeDistributed(attend_feedforward)
        self._matrix_attention = matrix_attention
        self._compare_feedforward = TimeDistributed(compare_feedforward)
        self._aggregate_feedforward = aggregate_feedforward
        self._premise_encoder = premise_encoder
        # Explicit `is not None`: module truthiness can be False for modules defining __len__.
        self._hypothesis_encoder = (
            hypothesis_encoder if hypothesis_encoder is not None else premise_encoder
        )

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # `forward` feeds the attend feedforward the *encoded* sequence whenever an encoder is
        # configured, so validate each side against the dimension it really produces.
        for side, encoder in (
            ("premise", self._premise_encoder),
            ("hypothesis", self._hypothesis_encoder),
        ):
            if encoder is None:
                feature_dim = text_field_embedder.get_output_dim()
                feature_description = "text field embedding dim"
            else:
                feature_dim = encoder.get_output_dim()
                feature_description = f"{side} encoder output dim"
            check_dimensions_match(
                feature_dim,
                attend_feedforward.get_input_dim(),
                feature_description,
                "attend feedforward input dim",
            )
        check_dimensions_match(
            aggregate_feedforward.get_output_dim(),
            self._num_labels,
            "final output dimension",
            "number of labels",
        )

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: Optional[torch.IntTensor] = None,
        metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        premise : `TextFieldTensors`
            From a `TextField`
        hypothesis : `TextFieldTensors`
            From a `TextField`
        label : `torch.IntTensor`, optional (default = `None`)
            From a `LabelField`
        metadata : `List[Dict[str, Any]]`, optional (default = `None`)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        # Returns

        An output dictionary consisting of:

        label_logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        h2p_attention : `torch.FloatTensor`
            Hypothesis-to-premise attention weights, shape
            `(batch_size, hypothesis_length, premise_length)`.
        p2h_attention : `torch.FloatTensor`
            Premise-to-hypothesis attention weights, shape
            `(batch_size, premise_length, hypothesis_length)`.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised. Only present when `label` is given.
        premise_tokens, hypothesis_tokens : `List`, optional
            Copied from `metadata` when it is given.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        if self._premise_encoder is not None:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder is not None:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, encoding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, encoding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        # Shape: (batch_size, compare_dim)
        compared_premise = self._compare_and_sum(premise_compare_input, premise_mask)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = self._compare_and_sum(hypothesis_compare_input, hypothesis_mask)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict

    def _compare_and_sum(self, compare_input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Apply the compare feedforward at every timestep, zero out padded positions using
        `mask`, and sum over the sequence dimension (dim 1).
        """
        compared = self._compare_feedforward(compare_input)
        # Feedforward outputs at padded positions are generally non-zero (e.g. from biases),
        # so they must be masked before summing.
        compared = compared * mask.unsqueeze(-1)
        return compared.sum(dim=1)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running categorical accuracy, optionally resetting it."""
        return {"accuracy": self._accuracy.get_metric(reset)}

    def make_output_human_readable(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        adds a `"label"` key to the dictionary with the result.

        Labels whose index is missing from the "labels" vocabulary fall back to the string
        form of the index.
        """
        predictions = output_dict["label_probs"]
        # A 2-D tensor is a batch: iterating it yields one row per instance.
        # Anything else is treated as a single prediction.
        predictions_list = list(predictions) if predictions.dim() == 2 else [predictions]

        # Fetch the index->label mapping once rather than once per prediction.
        index_to_label = self.vocab.get_index_to_token_vocabulary("labels")
        classes = []
        for prediction in predictions_list:
            label_idx = prediction.argmax(dim=-1).item()
            classes.append(index_to_label.get(label_idx, str(label_idx)))
        output_dict["label"] = classes
        return output_dict

    default_predictor = "textual_entailment"