This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
esim.py
267 lines (225 loc) · 10.1 KB
/
esim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
from typing import Dict, List, Any
from overrides import overrides
import torch
from allennlp.common.checks import check_dimensions_match
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.modules import FeedForward, InputVariationalDropout
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
from allennlp.nn import InitializerApplicator
from allennlp.nn.util import (
get_text_field_mask,
masked_softmax,
weighted_sum,
masked_max,
)
from allennlp.training.metrics import CategoricalAccuracy
@Model.register("esim")
class ESIM(Model):
    """
    This `Model` implements the ESIM sequence model described in [Enhanced LSTM for Natural Language Inference]
    (https://api.semanticscholar.org/CorpusID:34032948) by Chen et al., 2017.

    Registered as a `Model` with name "esim".

    # Parameters

    vocab : `Vocabulary`
    text_field_embedder : `TextFieldEmbedder`
        Used to embed the `premise` and `hypothesis` `TextFields` we get as input to the
        model. Its output dim must equal the input dim of `encoder`.
    encoder : `Seq2SeqEncoder`
        Used to encode the premise and hypothesis.
    matrix_attention : `MatrixAttention`
        This is the attention function used when computing the similarity matrix between encoded
        words in the premise and words in the hypothesis.
    projection_feedforward : `FeedForward`
        The feedforward network used to project down the encoded and enhanced premise and
        hypothesis. Its input dim must be 4 * the encoder output dim.
    inference_encoder : `Seq2SeqEncoder`
        Used to encode the projected premise and hypothesis for prediction. Its input dim must
        equal the output dim of `projection_feedforward`.
    output_feedforward : `FeedForward`
        Used to prepare the concatenated premise and hypothesis for prediction. Its input dim
        must be 4 * the inference encoder output dim (avg + max pooling of both sentences).
    output_logit : `FeedForward`
        This feedforward network computes the output logits. Its input dim must equal the
        output dim of `output_feedforward`.
    dropout : `float`, optional (default=`0.5`)
        Dropout probability. When falsy (e.g. `0.0`), no dropout layers are created.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        matrix_attention: MatrixAttention,
        projection_feedforward: FeedForward,
        inference_encoder: Seq2SeqEncoder,
        output_feedforward: FeedForward,
        output_logit: FeedForward,
        dropout: float = 0.5,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        # NOTE: submodules are registered in the same order as before so that parameter
        # naming/ordering (and hence what `initializer` sees) is unchanged.
        self._text_field_embedder = text_field_embedder
        self._encoder = encoder

        self._matrix_attention = matrix_attention
        self._projection_feedforward = projection_feedforward

        self._inference_encoder = inference_encoder

        if dropout:
            # Standard dropout for the final MLP input; variational (shared-mask) dropout
            # for the inputs of both sequence encoders.
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward
        self._output_logit = output_logit

        self._num_labels = vocab.get_vocab_size(namespace="labels")

        # Validate every dimension hand-off in the pipeline up front, so a misconfigured
        # model fails at construction time rather than deep inside `forward`.
        check_dimensions_match(
            text_field_embedder.get_output_dim(),
            encoder.get_input_dim(),
            "text field embedding dim",
            "encoder input dim",
        )
        check_dimensions_match(
            encoder.get_output_dim() * 4,
            projection_feedforward.get_input_dim(),
            "encoder output dim",
            "projection feedforward input",
        )
        check_dimensions_match(
            projection_feedforward.get_output_dim(),
            inference_encoder.get_input_dim(),
            "proj feedforward output dim",
            "inference lstm input dim",
        )
        # The pooled representation concatenates avg + max pooling of both sentences.
        check_dimensions_match(
            inference_encoder.get_output_dim() * 4,
            output_feedforward.get_input_dim(),
            "inference encoder output dim * 4",
            "output feedforward input dim",
        )
        check_dimensions_match(
            output_feedforward.get_output_dim(),
            output_logit.get_input_dim(),
            "output feedforward output dim",
            "output logit input dim",
        )

        self._accuracy = CategoricalAccuracy()
        self._loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @staticmethod
    def _masked_mean(tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Average `tensor` over dimension 1, counting only positions where `mask` is set.

        Assumes `tensor` is `(batch_size, seq_len, dim)` and `mask` is `(batch_size, seq_len)`.
        The per-row token count is clamped to at least 1, so an all-padding row yields a zero
        vector instead of a 0/0 NaN that would propagate into the loss for the whole batch.
        """
        summed = torch.sum(tensor * mask.unsqueeze(-1), dim=1)
        lengths = torch.sum(mask, 1, keepdim=True).clamp(min=1)
        return summed / lengths

    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        premise : `TextFieldTensors`
            From a `TextField`
        hypothesis : `TextFieldTensors`
            From a `TextField`
        label : `torch.IntTensor`, optional (default = `None`)
            From a `LabelField`. Flattened and cast to long before computing the loss.
        metadata : `List[Dict[str, Any]]`, optional (default = `None`)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
            Accepted for interface compatibility; not used by this model.

        # Returns

        An output dictionary consisting of:

        label_logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised. Only present when `label` is given.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # Variational dropout on the inputs of the context encoder.
        if self.rnn_input_dropout is not None:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, encoder_output_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, encoder_output_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # The "enhancement" layer: [a; a~; a - a~; a * a~], hence 4 * encoder_output_dim.
        premise_enhanced = torch.cat(
            [
                encoded_premise,
                attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis,
            ],
            dim=-1,
        )
        hypothesis_enhanced = torch.cat(
            [
                encoded_hypothesis,
                attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise,
            ],
            dim=-1,
        )

        # The projection layer down to the model dimension. Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer (with variational dropout on its inputs).
        if self.rnn_input_dropout is not None:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # Each: (batch_size, inference_encoder_output_dim)
        v_a_max = masked_max(v_ai, premise_mask.unsqueeze(-1), dim=1)
        v_b_max = masked_max(v_bi, hypothesis_mask.unsqueeze(-1), dim=1)

        v_a_avg = self._masked_mean(v_ai, premise_mask)
        v_b_avg = self._masked_mean(v_bi, hypothesis_mask)

        # Now concat
        # (batch_size, inference_encoder_output_dim * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout is not None:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running label accuracy, optionally resetting the accumulator."""
        return {"accuracy": self._accuracy.get_metric(reset)}

    @overrides
    def make_output_human_readable(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Does a simple argmax over the probabilities, converts index to string label, and
        adds a `"label"` key to the dictionary with the result.

        Handles both a batched `(batch_size, num_labels)` tensor and a single
        `(num_labels,)` tensor. Indices missing from the "labels" namespace fall back to
        their string form.
        """
        predictions = output_dict["label_probs"]
        if predictions.dim() == 2:
            predictions_list = [predictions[i] for i in range(predictions.shape[0])]
        else:
            predictions_list = [predictions]

        # The index -> label mapping is the same for every prediction; look it up once.
        index_to_label = self.vocab.get_index_to_token_vocabulary("labels")
        classes = []
        for prediction in predictions_list:
            label_idx = prediction.argmax(dim=-1).item()
            classes.append(index_to_label.get(label_idx, str(label_idx)))
        output_dict["label"] = classes
        return output_dict

    default_predictor = "textual_entailment"