import logging
import warnings
from typing import Any, Dict, List, Tuple, Union
import numpy
import torch
from allennlp.common.lazy import Lazy
from allennlp.common.util import END_SYMBOL, START_SYMBOL
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Attention, Seq2SeqEncoder, TextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn import InitializerApplicator, util
from allennlp.nn.beam_search import BeamSearch
from allennlp.training.metrics import BLEU, Metric
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTMCell
logger = logging.getLogger(__name__)
@Model.register("copynet_seq2seq")
class CopyNetSeq2Seq(Model):
"""
This is an implementation of [CopyNet](https://api.semanticscholar.org/CorpusID:8174613).
CopyNet is a sequence-to-sequence encoder-decoder model with a copying mechanism
that can copy tokens from the source sentence into the target sentence instead of
generating all target tokens only from the target vocabulary.
It is very similar to a typical seq2seq model used in neural machine translation
tasks, for example, except that in addition to providing a "generation" score at each timestep
for the tokens in the target vocabulary, it also provides a "copy" score for each
token that appears in the source sentence. In other words, you can think of CopyNet
as a seq2seq model with a dynamic target vocabulary that changes based on the tokens
in the source sentence, allowing it to predict tokens that are out-of-vocabulary (OOV)
with respect to the actual target vocab.
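Concretely, if the target vocabulary has `V` entries, a token copied from position `j`
of the source sentence is represented by index `V + j` in this extended vocabulary,
so the model can emit it even when it has no entry in the target vocab.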
# Parameters
vocab : `Vocabulary`, required
Vocabulary containing source and target vocabularies.
source_embedder : `TextFieldEmbedder`, required
Embedder for source side sequences
encoder : `Seq2SeqEncoder`, required
The encoder of the "encoder/decoder" model
attention : `Attention`, required
This is used to get a dynamic summary of encoder outputs at each timestep
when producing the "generation" scores for the target vocab.
label_smoothing : `float`, optional (default = `None`)
The amount of label smoothing to apply to the cross-entropy loss, if any.
For example, with a label smoothing value of 0.2, a 4 class classification
target would look like `[0.05, 0.05, 0.85, 0.05]` if the 3rd class was
the correct label.
beam_search : `BeamSearch`, optional (default = `Lazy(BeamSearch)`)
This is used during inference to select the tokens of the decoded output sequence.
target_embedding_dim : `int`, optional (default = `30`)
The size of the embeddings for the target vocabulary.
scheduled_sampling_ratio : `float`, optional (default = `0.`)
At each timestep during training, we sample a random number between 0 and 1, and if it is
not less than this value, we use the ground truth labels for the whole batch. Else, we use
the predictions from the previous time step for the whole batch. If this value is 0.0
(default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
using target side ground truth labels. See the following paper for more information:
[Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
2015](https://arxiv.org/abs/1506.03099).
copy_token : `str`, optional (default = `'@COPY@'`)
The token used to indicate that a target token was copied from the source.
If this token is not already in your target vocabulary, it will be added.
target_namespace : `str`, optional (default = `'target_tokens'`)
The namespace for the target vocabulary.
tensor_based_metric : `Metric`, optional (default = `BLEU`)
A metric to track on validation data that takes raw tensors when it is called.
This metric must accept two arguments when called: a batched tensor
of predicted token indices, and a batched tensor of gold token indices.
token_based_metric : `Metric`, optional (default = `None`)
A metric to track on validation data that takes lists of lists of tokens
as input. This metric must accept two arguments when called, both
of type `List[List[str]]`. The first is a predicted sequence for each item
in the batch and the second is a gold sequence for each item in the batch.
initializer : `InitializerApplicator`, optional
An initialization strategy for the model weights.
"""
def __init__(
self,
vocab: Vocabulary,
source_embedder: TextFieldEmbedder,
encoder: Seq2SeqEncoder,
attention: Attention,
label_smoothing: float = None,
beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
target_embedding_dim: int = 30,
scheduled_sampling_ratio: float = 0.0,
copy_token: str = "@COPY@",
target_namespace: str = "target_tokens",
tensor_based_metric: Metric = None,
token_based_metric: Metric = None,
initializer: InitializerApplicator = InitializerApplicator(),
**kwargs
) -> None:
super().__init__(vocab)
self._target_namespace = target_namespace
self._scheduled_sampling_ratio = scheduled_sampling_ratio
self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
self._oov_index = self.vocab.get_token_index(self.vocab._oov_token, self._target_namespace)
self._pad_index = self.vocab.get_token_index(
self.vocab._padding_token, self._target_namespace
)
self._copy_index = self.vocab.add_token_to_namespace(copy_token, self._target_namespace)
self._tensor_based_metric = tensor_based_metric or BLEU(
exclude_indices={self._pad_index, self._end_index, self._start_index}
)
self._token_based_metric = token_based_metric
self._target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)
# Encoding modules.
self._source_embedder = source_embedder
self._encoder = encoder
# Decoder output dim needs to be the same as the encoder output dim since we initialize the
# hidden state of the decoder with the final hidden state of the encoder.
# We arbitrarily set the decoder's input dimension to be the same as the output dimension.
self.encoder_output_dim = self._encoder.get_output_dim()
self.decoder_output_dim = self.encoder_output_dim
self.decoder_input_dim = self.decoder_output_dim
# The decoder input will be a function of the embedding of the previous predicted token,
# an attended encoder hidden state called the "attentive read", and another
# weighted sum of the encoder hidden state called the "selective read".
# While the weights for the attentive read are calculated by an `Attention` module,
# the weights for the selective read are simply the predicted probabilities
# corresponding to each token in the source sentence that matches the target
# token from the previous timestep.
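# Hence the (pre-projection) decoder input has dimension
# `target_embedding_dim + 2 * encoder_output_dim`, which `self._input_projection_layer`
# below projects to `decoder_input_dim`.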
self._target_embedder = Embedding(
num_embeddings=self._target_vocab_size, embedding_dim=target_embedding_dim
)
self._attention = attention
self._label_smoothing = label_smoothing
self._input_projection_layer = Linear(
target_embedding_dim + self.encoder_output_dim * 2, self.decoder_input_dim
)
# We then run the projected decoder input through an LSTM cell to produce
# the next hidden state.
self._decoder_cell = LSTMCell(self.decoder_input_dim, self.decoder_output_dim)
# We create a "generation" score for each token in the target vocab
# with a linear projection of the decoder hidden state.
self._output_generation_layer = Linear(self.decoder_output_dim, self._target_vocab_size)
# We create a "copying" score for each source token by applying a non-linearity
# (tanh) to a linear projection of the encoded hidden state for that token,
# and then taking the dot product of the result with the decoder hidden state.
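# Schematically, for source position j:
#     copy_score_j = tanh(W_copy h_j) . h_dec
# where h_j is the encoder output at position j, h_dec is the decoder hidden state,
# and `W_copy` is this linear layer (see `self._get_copy_scores` below).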
self._output_copying_layer = Linear(self.encoder_output_dim, self.decoder_output_dim)
# At prediction time, we'll use a beam search to find the best target sequence.
# For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
# kwargs. If so, pass them along to the BeamSearch constructor and emit a DeprecationWarning.
deprecation_warning = (
"The parameter {} has been deprecated."
" Provide this parameter as argument to beam_search instead."
)
beam_search_extras = {}
if "beam_size" in kwargs:
beam_search_extras["beam_size"] = kwargs["beam_size"]
warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
if "max_decoding_steps" in kwargs:
beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
self._beam_search = beam_search.construct(
end_index=self._end_index, vocab=self.vocab, **beam_search_extras
)
initializer(self)
def forward(
self, # type: ignore
source_tokens: TextFieldTensors,
source_token_ids: torch.Tensor,
source_to_target: torch.Tensor,
metadata: List[Dict[str, Any]],
target_tokens: TextFieldTensors = None,
target_token_ids: torch.Tensor = None,
weight: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
"""
Make a forward pass with decoder logic for producing the entire target sequence.
# Parameters
source_tokens : `TextFieldTensors`, required
The output of `TextField.as_array()` applied on the source `TextField`. This will be
passed through a `TextFieldEmbedder` and then through an encoder.
source_token_ids : `torch.Tensor`, required
Tensor containing IDs that indicate which source tokens match each other.
Has shape: `(batch_size, source_sequence_length)`.
source_to_target : `torch.Tensor`, required
Tensor containing vocab index of each source token with respect to the
target vocab namespace. Shape: `(batch_size, source_sequence_length)`.
metadata : `List[Dict[str, Any]]`, required
Metadata field that contains the original source tokens with key 'source_tokens'
and any other meta fields. When 'target_tokens' is also passed, the metadata
should also contain the original target tokens with key 'target_tokens'.
target_tokens : `TextFieldTensors`, optional (default = `None`)
Output of `TextField.as_array()` applied on target `TextField`. We assume that the
target tokens are also represented as a `TextField` which must contain a "tokens"
key that uses single ids.
target_token_ids : `torch.Tensor`, optional (default = `None`)
A tensor of shape `(batch_size, target_sequence_length)` which indicates which
tokens in the target sequence match tokens in the source sequence.
weight : `torch.Tensor`, optional (default = `None`)
An optional tensor of shape `(batch_size,)` or `(batch_size, 1)` which determines
how to weight each instance in the batch when calculating the loss.
The default of `None` is equivalent to `weight = torch.tensor([1.0] * batch_size)`.
The final loss is calculated as `-(weight * log_likelihood).sum() / batch_size`,
where `log_likelihood` is a tensor of shape `(batch_size,)`, representing the overall
log likelihood of the gold target sequence.
# Returns
`Dict[str, torch.Tensor]`
"""
state = self._encode(source_tokens)
state["source_token_ids"] = source_token_ids
state["source_to_target"] = source_to_target
if target_tokens:
state = self._init_decoder_state(state)
output_dict = self._forward_loss(target_tokens, target_token_ids, state, weight=weight)
else:
output_dict = {}
output_dict["metadata"] = metadata
if not self.training:
state = self._init_decoder_state(state)
predictions = self._forward_beam_search(state)
output_dict.update(predictions)
if target_tokens:
if self._tensor_based_metric is not None:
# shape: (batch_size, beam_size, max_sequence_length)
top_k_predictions = output_dict["predictions"]
# shape: (batch_size, max_predicted_sequence_length)
best_predictions = top_k_predictions[:, 0, :]
# shape: (batch_size, target_sequence_length)
gold_tokens = self._gather_extended_gold_tokens(
target_tokens["tokens"]["tokens"], source_token_ids, target_token_ids
)
self._tensor_based_metric(best_predictions, gold_tokens) # type: ignore
if self._token_based_metric is not None:
predicted_tokens = self._get_predicted_tokens(
output_dict["predictions"], metadata, n_best=1
)
self._token_based_metric( # type: ignore
predicted_tokens, [x["target_tokens"] for x in metadata]
)
return output_dict
def _gather_extended_gold_tokens(
self,
target_tokens: torch.Tensor,
source_token_ids: torch.Tensor,
target_token_ids: torch.Tensor,
) -> torch.LongTensor:
"""
Modify the gold target tokens relative to the extended vocabulary.
For gold targets that are OOV but were copied from the source, the OOV index
will be changed to the index of the first occurrence in the source sentence,
offset by the size of the target vocabulary.
# Parameters
target_tokens : `torch.Tensor`
Shape: `(batch_size, target_sequence_length)`.
source_token_ids : `torch.Tensor`
Shape: `(batch_size, source_sequence_length)`.
target_token_ids : `torch.Tensor`
Shape: `(batch_size, target_sequence_length)`.
# Returns
torch.Tensor
Modified `target_tokens` with OOV indices replaced by offset index
of first match in source sentence.
"""
batch_size, target_sequence_length = target_tokens.size()
source_sequence_length = source_token_ids.size(1)
# Only change indices for tokens that were OOV in target vocab but copied from source.
# shape: (batch_size, target_sequence_length)
oov = target_tokens == self._oov_index
# shape: (batch_size, target_sequence_length, source_sequence_length)
expanded_source_token_ids = source_token_ids.unsqueeze(1).expand(
batch_size, target_sequence_length, source_sequence_length
)
# shape: (batch_size, target_sequence_length, source_sequence_length)
expanded_target_token_ids = target_token_ids.unsqueeze(-1).expand(
batch_size, target_sequence_length, source_sequence_length
)
# shape: (batch_size, target_sequence_length, source_sequence_length)
matches = expanded_source_token_ids == expanded_target_token_ids
# shape: (batch_size, target_sequence_length)
copied = matches.sum(-1) > 0
# shape: (batch_size, target_sequence_length)
mask = oov & copied
# shape: (batch_size, target_sequence_length)
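# Trick: `matches.cumsum(-1) == 1` is True at positions where the running match count is 1,
# and AND-ing with `matches` keeps only the first matching position. For example,
# matches = [0, 1, 0, 1] -> cumsum = [0, 1, 1, 2] -> first-match indicator = [0, 1, 0, 0],
# so argmax returns index 1 (or 0 when there is no match, which `mask` zeroes out below).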
first_match = ((matches.cumsum(-1) == 1) & matches).to(torch.uint8).argmax(-1)
# shape: (batch_size, target_sequence_length)
new_target_tokens = (
target_tokens * ~mask + (first_match.long() + self._target_vocab_size) * mask
)
return new_target_tokens
def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Initialize the encoded state to be passed to the first decoding time step.
"""
batch_size, _ = state["source_mask"].size()
# Initialize the decoder hidden state with the final output of the encoder,
# and the decoder context with zeros.
# shape: (batch_size, encoder_output_dim)
final_encoder_output = util.get_final_encoder_states(
state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()
)
# shape: (batch_size, decoder_output_dim)
state["decoder_hidden"] = final_encoder_output
# shape: (batch_size, decoder_output_dim)
state["decoder_context"] = state["encoder_outputs"].new_zeros(
batch_size, self.decoder_output_dim
)
return state
def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Encode source input sentences.
"""
# shape: (batch_size, source_sequence_length, encoder_input_dim)
embedded_input = self._source_embedder(source_tokens)
# shape: (batch_size, source_sequence_length)
source_mask = util.get_text_field_mask(source_tokens)
# shape: (batch_size, source_sequence_length, encoder_output_dim)
encoder_outputs = self._encoder(embedded_input, source_mask)
return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}
def _decoder_step(
self,
last_predictions: torch.Tensor,
selective_weights: torch.Tensor,
state: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
# shape: (group_size, source_sequence_length)
encoder_outputs_mask = state["source_mask"]
# shape: (group_size, target_embedding_dim)
embedded_input = self._target_embedder(last_predictions)
# shape: (group_size, source_sequence_length)
attentive_weights = self._attention(
state["decoder_hidden"], state["encoder_outputs"], encoder_outputs_mask
)
# shape: (group_size, encoder_output_dim)
attentive_read = util.weighted_sum(state["encoder_outputs"], attentive_weights)
# shape: (group_size, encoder_output_dim)
selective_read = util.weighted_sum(state["encoder_outputs"], selective_weights)
# shape: (group_size, target_embedding_dim + encoder_output_dim * 2)
decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1)
# shape: (group_size, decoder_input_dim)
projected_decoder_input = self._input_projection_layer(decoder_input)
state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
projected_decoder_input.float(),
(state["decoder_hidden"].float(), state["decoder_context"].float()),
)
return state
def _get_generation_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
return self._output_generation_layer(state["decoder_hidden"])
def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
# shape: (batch_size, source_sequence_length, encoder_output_dim)
encoder_outputs = state["encoder_outputs"]
# shape: (batch_size, source_sequence_length, decoder_output_dim)
copy_projection = self._output_copying_layer(encoder_outputs)
# shape: (batch_size, source_sequence_length, decoder_output_dim)
copy_projection = torch.tanh(copy_projection)
# shape: (batch_size, source_sequence_length)
copy_scores = copy_projection.bmm(state["decoder_hidden"].unsqueeze(-1)).squeeze(-1)
return copy_scores
def _get_ll_contrib(
self,
generation_scores: torch.Tensor,
generation_scores_mask: torch.BoolTensor,
copy_scores: torch.Tensor,
target_tokens: torch.Tensor,
target_to_source: torch.Tensor,
source_mask: torch.BoolTensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get the log-likelihood contribution from a single timestep.
# Parameters
generation_scores : `torch.Tensor`
Shape: `(batch_size, target_vocab_size)`
generation_scores_mask : `torch.BoolTensor`
Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
copy_scores : `torch.Tensor`
Shape: `(batch_size, source_sequence_length)`
target_tokens : `torch.Tensor`
Shape: `(batch_size,)`
target_to_source : `torch.Tensor`
Shape: `(batch_size, source_sequence_length)`
source_mask : `torch.BoolTensor`
Shape: `(batch_size, source_sequence_length)`
# Returns
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
The step log-likelihood with shape `(batch_size,)`, the selective weights with shape
`(batch_size, source_sequence_length)`, and the normalized log probabilities with shape
`(batch_size, target_vocab_size + source_sequence_length)`.
"""
_, target_size = generation_scores.size()
# The point of this mask is to just mask out all source token scores
# that just represent padding. We apply the mask to the concatenation
# of the generation scores and the copy scores to normalize the scores
# correctly during the softmax.
# shape: (batch_size, target_vocab_size + source_sequence_length)
mask = torch.cat((generation_scores_mask, source_mask), dim=-1)
# shape: (batch_size, target_vocab_size + source_sequence_length)
all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
# Normalize generation and copy scores.
# shape: (batch_size, target_vocab_size + source_sequence_length)
log_probs = util.masked_log_softmax(all_scores, mask)
# Calculate the log probability (`copy_log_probs`) for each token in the source sentence
# that matches the current target token. We use the sum of these copy probabilities
# for matching tokens in the source sentence to get the total probability
# for the target token. We also need to normalize the individual copy probabilities
# to create `selective_weights`, which are used in the next timestep to create
# a selective read state.
# shape: (batch_size, source_sequence_length)
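# The `(indicator + tiny).log()` idiom below converts a 0/1 match indicator into a
# log-space mask: matching positions get roughly log(1) = 0 added, while non-matching
# positions get a very large negative value (the log of a tiny number), effectively
# assigning them zero probability. The same pattern recurs throughout this module.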
copy_log_probs = (
log_probs[:, target_size:]
+ (
target_to_source.to(log_probs.dtype) + util.tiny_value_of_dtype(log_probs.dtype)
).log()
)
# Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
# we use a non-log softmax to get the normalized non-log copy probabilities.
selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
# This mask ensures that each item in the batch has a non-zero generation probability
# for this timestep only when the gold target token is not OOV or there are no
# matching tokens in the source sentence.
# shape: (batch_size, 1)
gen_mask = (target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)
log_gen_mask = (gen_mask + util.tiny_value_of_dtype(log_probs.dtype)).log().unsqueeze(-1)
# Now we get the generation score for the gold target token.
if self._label_smoothing is not None and self._label_smoothing > 0.0:
smoothing_value = self._label_smoothing / target_size
# Create the smoothed targets to be multiplied with `log_probs`
# shape: (batch_size, target_vocab_size + source_sequence_length)
smoothed_targets = torch.full_like(log_probs, smoothing_value).scatter_(
1, target_tokens.unsqueeze(1), 1.0 - self._label_smoothing + smoothing_value
)
generation_log_probs = log_probs * smoothed_targets
# shape: (batch_size, 1)
generation_log_probs = generation_log_probs.sum(-1, keepdim=True)
else:
# Contribution to the negative log likelihood only comes from the exact indices
# of the targets, as the target distributions are one-hot. Here we use torch.gather
# to extract the indices which contribute to the loss.
# shape: (batch_size, 1)
generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1))
generation_log_probs = generation_log_probs + log_gen_mask
# ... and add the copy score to get the step log likelihood.
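# Since everything is in log space, "adding" the probabilities means taking a logsumexp:
#     step_log_likelihood = log(p_gen(target) + sum_j p_copy(j)),
# where the sum runs over source positions j that match the gold target token.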
# shape: (batch_size, 1 + source_sequence_length)
combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
# shape: (batch_size,)
step_log_likelihood = util.logsumexp(combined_gen_and_copy)
return step_log_likelihood, selective_weights, log_probs
def _forward_loss(
self,
target_tokens: TextFieldTensors,
target_token_ids: torch.Tensor,
state: Dict[str, torch.Tensor],
weight: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:
"""
Calculate the loss against gold targets.
"""
batch_size, target_sequence_length = target_tokens["tokens"]["tokens"].size()
# shape: (batch_size, source_sequence_length)
source_mask = state["source_mask"]
# We have a decoding step for every target token except the "START" token.
num_decoding_steps = target_sequence_length - 1
# We use this to fill in the copy index when the previous input was copied.
# shape: (batch_size,)
copy_input_choices = source_mask.new_full(
(batch_size,), fill_value=self._copy_index, dtype=torch.long
)
# We need to keep track of the probabilities assigned to tokens in the source
# sentence that were copied during the previous timestep, since we use
# those probabilities as weights when calculating the "selective read".
# shape: (batch_size, source_sequence_length)
selective_weights = state["decoder_hidden"].new_zeros(source_mask.size())
# Indicates which tokens in the source sentence match the current target token.
# shape: (batch_size, source_sequence_length)
target_to_source = state["source_token_ids"].new_zeros(source_mask.size())
# This is just a tensor of ones which we use repeatedly in `self._get_ll_contrib`,
# so we create it once here to avoid doing it over-and-over.
generation_scores_mask = state["decoder_hidden"].new_full(
(batch_size, self._target_vocab_size), fill_value=1.0, dtype=torch.bool
)
# Initialize target predictions with the start index.
# shape: (batch_size,)
last_predictions = source_mask.new_full(
(batch_size,), fill_value=self._start_index, dtype=torch.long
)
step_log_likelihoods = []
for timestep in range(num_decoding_steps):
if (
self.training
and self._scheduled_sampling_ratio > 0.0
and torch.rand(1).item() < self._scheduled_sampling_ratio
):
# Scheduled sampling: with probability `self._scheduled_sampling_ratio`, use the
# model's own predictions from the previous timestep instead of the gold tokens.
# shape: (batch_size,)
input_choices = last_predictions
# Get mask tensor indicating which instances were copied.
# shape: (batch_size,)
copied = (input_choices >= self._target_vocab_size).long()
else:
# shape: (batch_size,)
input_choices = target_tokens["tokens"]["tokens"][:, timestep]
# shape: (batch_size,)
copied = (
(input_choices == self._oov_index) & (target_to_source.sum(-1) > 0)
).long()
# shape: (batch_size,)
input_choices = input_choices * (1 - copied) + copy_input_choices * copied
# shape: (batch_size, source_sequence_length)
target_to_source = state["source_token_ids"] == target_token_ids[
:, timestep + 1
].unsqueeze(-1)
# Update the decoder state by taking a step through the RNN.
state = self._decoder_step(input_choices, selective_weights, state)
# Get generation scores for each token in the target vocab.
# shape: (batch_size, target_vocab_size)
generation_scores = self._get_generation_scores(state)
# Get copy scores for each token in the source sentence, excluding the start
# and end tokens.
# shape: (batch_size, source_sequence_length)
copy_scores = self._get_copy_scores(state)
# shape: (batch_size,)
step_target_tokens = target_tokens["tokens"]["tokens"][:, timestep + 1]
step_log_likelihood, selective_weights, log_probs = self._get_ll_contrib(
generation_scores,
generation_scores_mask,
copy_scores,
step_target_tokens,
target_to_source,
source_mask,
)
step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))
# shape (predicted_classes): (batch_size,)
_, last_predictions = torch.max(log_probs, 1)
# Gather step log-likelihoods.
# shape: (batch_size, num_decoding_steps = target_sequence_length - 1)
log_likelihoods = torch.cat(step_log_likelihoods, 1)
# Get target mask to exclude likelihood contributions from timesteps after
# the END token.
# shape: (batch_size, target_sequence_length)
target_mask = util.get_text_field_mask(target_tokens)
# The first timestep is just the START token, which is not included in the likelihoods.
# shape: (batch_size, num_decoding_steps)
target_mask = target_mask[:, 1:]
# Sum of step log-likelihoods.
# shape: (batch_size,)
log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)
# The loss is the negative log-likelihood, averaged over the batch.
if weight is None:
loss = -log_likelihood.sum() / batch_size
else:
# shape: (batch_size,)
if len(weight.shape) > 1:
weight = weight.squeeze()
loss = -(weight * log_likelihood).sum() / batch_size
return {"loss": loss}
def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
batch_size, source_sequence_length = state["source_mask"].size()
# Initialize the copy log-probs to represent (near-)zero copy probability.
state["copy_log_probs"] = (
state["decoder_hidden"].new_zeros((batch_size, source_sequence_length))
+ util.tiny_value_of_dtype(state["decoder_hidden"].dtype)
).log()
# shape: (batch_size,)
start_predictions = state["source_mask"].new_full(
(batch_size,), fill_value=self._start_index, dtype=torch.long
)
# shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
# shape (log_probabilities): (batch_size, beam_size)
all_top_k_predictions, log_probabilities = self._beam_search.search(
start_predictions, state, self.take_search_step
)
return {"predicted_log_probs": log_probabilities, "predictions": all_top_k_predictions}
def _get_input_and_selective_weights(
self, last_predictions: torch.LongTensor, state: Dict[str, torch.Tensor]
) -> Tuple[torch.LongTensor, torch.Tensor]:
"""
Get input choices for the decoder and the selective copy weights.
The decoder input choices are simply the `last_predictions`, except for
target OOV predictions that were copied from source tokens, in which case
the prediction will be changed to the COPY symbol in the target namespace.
The selective weights are just the probabilities assigned to source
tokens that were copied, normalized to sum to 1. If no source tokens were copied,
the weights will be all zeros.
# Parameters
last_predictions : `torch.LongTensor`
Shape: `(group_size,)`
state : `Dict[str, torch.Tensor]`
# Returns
Tuple[torch.LongTensor, torch.Tensor]
`input_choices` (shape `(group_size,)`) and `selective_weights`
(shape `(group_size, source_sequence_length)`).
"""
group_size, source_sequence_length = state["source_to_target"].size()
# This is a mask indicating which last predictions were copied from the
# source AND are not in the target vocabulary (OOV).
# (group_size,)
only_copied_mask = last_predictions >= self._target_vocab_size
# If the last prediction was in the target vocab or OOV but not copied,
# we use that as input, otherwise we use the COPY token.
# shape: (group_size,)
copy_input_choices = only_copied_mask.new_full(
(group_size,), fill_value=self._copy_index, dtype=torch.long
)
input_choices = last_predictions * ~only_copied_mask + copy_input_choices * only_copied_mask
# In order to get the `selective_weights`, we need to find out which predictions
# were copied or copied AND generated, which is the case when a prediction appears
# in both the source sentence and the target vocab. But whenever a prediction
# is in the target vocab (even if it also appeared in the source sentence),
# its index will be the corresponding target vocab index, not its index in
# the source sentence offset by the target vocab size. So we first
# use `state["source_to_target"]` to get an indicator of every source token
# that matches the predicted target token.
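# For example, if the last prediction is the target-vocab index of "york" and "york"
# also appears at source position 2, then position 2 of `source_copied_and_generated`
# (computed below) will be True. ("york" is just an illustrative token here.)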
# shape: (group_size, source_sequence_length)
expanded_last_predictions = last_predictions.unsqueeze(-1).expand(
group_size, source_sequence_length
)
# shape: (group_size, source_sequence_length)
source_copied_and_generated = state["source_to_target"] == expanded_last_predictions
# In order to get indicators for copied source tokens that are OOV with respect
# to the target vocab, we'll make use of `state["source_token_ids"]`.
# First we adjust predictions relative to the start of the source tokens.
# This makes sense because predictions for copied tokens are given by the index of the copied
# token in the source sentence, offset by the size of the target vocabulary.
# shape: (group_size,)
adjusted_predictions = last_predictions - self._target_vocab_size
# The adjusted indices for items that were not copied will be negative numbers,
# and therefore invalid. So we zero them out.
adjusted_predictions = adjusted_predictions * only_copied_mask
# shape: (group_size, source_sequence_length)
source_token_ids = state["source_token_ids"]
# shape: (group_size, source_sequence_length)
adjusted_prediction_ids = source_token_ids.gather(-1, adjusted_predictions.unsqueeze(-1))
# This mask will contain indicators for source tokens that were copied
# during the last timestep.
# shape: (group_size, source_sequence_length)
source_only_copied = source_token_ids == adjusted_prediction_ids
# Since we zeroed out indices for predictions that were not copied,
# we need to zero out all entries of this mask corresponding to those predictions.
source_only_copied = source_only_copied & only_copied_mask.unsqueeze(-1)
# shape: (group_size, source_sequence_length)
mask = source_only_copied | source_copied_and_generated
# shape: (group_size, source_sequence_length)
selective_weights = util.masked_softmax(state["copy_log_probs"], mask)
return input_choices, selective_weights
def _gather_final_log_probs(
self,
generation_log_probs: torch.Tensor,
copy_log_probs: torch.Tensor,
state: Dict[str, torch.Tensor],
) -> torch.Tensor:
"""
Combine copy probabilities with generation probabilities for matching tokens.
# Parameters
generation_log_probs : `torch.Tensor`
Shape: `(group_size, target_vocab_size)`
copy_log_probs : `torch.Tensor`
Shape: `(group_size, source_sequence_length)`
state : `Dict[str, torch.Tensor]`
# Returns
torch.Tensor
Shape: `(group_size, target_vocab_size + source_sequence_length)`.
"""
_, source_sequence_length = state["source_to_target"].size()
source_token_ids = state["source_token_ids"]
# shape: [(batch_size, *)]
modified_log_probs_list: List[torch.Tensor] = []
for i in range(source_sequence_length):
# shape: (group_size,)
copy_log_probs_slice = copy_log_probs[:, i]
# `source_to_target` is a matrix of shape (group_size, source_sequence_length)
# where element (i, j) is the vocab index of the target token that matches the jth
# source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
# We'll use this to add copy scores to corresponding generation scores.
# shape: (group_size,)
source_to_target_slice = state["source_to_target"][:, i]
# The OOV index in the source_to_target_slice indicates that the source
# token is not in the target vocab, so we don't want to add that copy score
# to the OOV token.
copy_log_probs_to_add_mask = source_to_target_slice != self._oov_index
copy_log_probs_to_add = (
copy_log_probs_slice
+ (
copy_log_probs_to_add_mask
+ util.tiny_value_of_dtype(copy_log_probs_slice.dtype)
).log()
)
# shape: (group_size, 1)
copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
# shape: (group_size, 1)
selected_generation_log_probs = generation_log_probs.gather(
1, source_to_target_slice.unsqueeze(-1)
)
combined_scores = util.logsumexp(
torch.cat((selected_generation_log_probs, copy_log_probs_to_add), dim=1)
)
generation_log_probs = generation_log_probs.scatter(
-1, source_to_target_slice.unsqueeze(-1), combined_scores.unsqueeze(-1)
)
# We have to combine copy scores for duplicate source tokens so that
# we can find the overall most likely source token. So, if this is the first
# occurrence of this particular source token, we add the log_probs from all other
# occurrences, otherwise we zero it out since it was already accounted for.
if i < (source_sequence_length - 1):
# Sum copy scores from future occurrences of this source token.
# shape: (group_size, source_sequence_length - i)
source_future_occurences = source_token_ids[:, (i + 1) :] == source_token_ids[
:, i
].unsqueeze(-1)
# shape: (group_size, source_sequence_length - i)
future_copy_log_probs = (
copy_log_probs[:, (i + 1) :]
+ (
source_future_occurences + util.tiny_value_of_dtype(copy_log_probs.dtype)
).log()
)
# shape: (group_size, 1 + source_sequence_length - i)
combined = torch.cat(
(copy_log_probs_slice.unsqueeze(-1), future_copy_log_probs), dim=-1
)
# shape: (group_size,)
copy_log_probs_slice = util.logsumexp(combined)
if i > 0:
# Remove copy log_probs that we have already accounted for.
# shape: (group_size, i)
source_previous_occurences = source_token_ids[:, 0:i] == source_token_ids[
:, i
].unsqueeze(-1)
# shape: (group_size,)
duplicate_mask = source_previous_occurences.sum(dim=-1) == 0
copy_log_probs_slice = (
copy_log_probs_slice
+ (duplicate_mask + util.tiny_value_of_dtype(copy_log_probs_slice.dtype)).log()
)
# Finally, we zero-out copy scores that we added to the generation scores
# above so that we don't double-count them.
# shape: (group_size,)
left_over_copy_log_probs = (
copy_log_probs_slice
+ (
~copy_log_probs_to_add_mask
+ util.tiny_value_of_dtype(copy_log_probs_slice.dtype)
).log()
)
modified_log_probs_list.append(left_over_copy_log_probs.unsqueeze(-1))
modified_log_probs_list.insert(0, generation_log_probs)
# shape: (group_size, target_vocab_size + source_sequence_length)
modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)
return modified_log_probs
def take_search_step(
self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Take step during beam search.
This function is what gets passed to the `BeamSearch.search` method. It takes
predictions from the last timestep and the current state and outputs
the log probabilities assigned to tokens for the next timestep, as well as the updated
state.
Since we are predicting tokens out of the extended vocab (target vocab + all unique
tokens from the source sentence), this is a little more complicated than just
making a forward pass through the model. The output log probs will have
shape `(group_size, target_vocab_size + source_sequence_length)` so that each
token in the target vocab and source sentence is assigned a probability.
Note that copy scores are assigned to each source token based on its position, not its unique value.
So if a token appears more than once in the source sentence, it will have more than one score.
Further, if a source token is also part of the target vocab, its final score
will be the sum of the generation and copy scores. Therefore, in order to
get the score for all tokens in the extended vocab at this step,
we have to combine copy scores for re-occurring source tokens and potentially
add them to the generation scores for the matching token in the target vocab, if
there is one.
So we can break down the final log probs output as the concatenation of two
matrices, A: `(group_size, target_vocab_size)`, and B: `(group_size, source_sequence_length)`.
Matrix A contains the sum of the generation score and copy scores (possibly 0)
for each target token. Matrix B contains left-over copy scores for source tokens
that do NOT appear in the target vocab, with zeros everywhere else. But since
a source token may appear more than once in the source sentence, we also have to
sum the scores for each appearance of each unique source token. So matrix B
actually only has non-zero values at the first occurrence of each source token
that is not in the target vocab.
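As an illustrative example, suppose the source sentence is "the cat saw the dog", where
"the" and "cat" are in the target vocab while "dog" is not. The copy scores of both
occurrences of "the" end up added (via logsumexp) to the generation score for "the" in
matrix A, and their entries in matrix B are zeroed out. The copy score for "dog", which
has no generation score, appears in matrix B at its (first and only) source position.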
# Parameters
last_predictions : `torch.Tensor`
Shape: `(group_size,)`
state : `Dict[str, torch.Tensor]`
Contains all state tensors necessary to produce generation and copy scores
for next step.
step : `int`
The time step in beam search decoding.
# Notes
`group_size != batch_size`. In fact, `group_size = batch_size * beam_size`.
"""
_, source_sequence_length = state["source_to_target"].size()
# Get input to the decoder RNN and the selective weights. `input_choices`
# is the result of replacing target OOV tokens in `last_predictions` with the
# copy symbol. `selective_weights` consist of the normalized copy probabilities
# assigned to the source tokens that were copied. If no tokens were copied,
# the weights will be all zeros.
# shape: (group_size,), (group_size, source_sequence_length)
input_choices, selective_weights = self._get_input_and_selective_weights(
last_predictions, state
)
# Update the decoder state by taking a step through the RNN.
state = self._decoder_step(input_choices, selective_weights, state)
# Get the un-normalized generation scores for each token in the target vocab.
# shape: (group_size, target_vocab_size)
generation_scores = self._get_generation_scores(state)
# Get the un-normalized copy scores for each token in the source sentence,
# excluding the start and end tokens.
# shape: (group_size, source_sequence_length)
copy_scores = self._get_copy_scores(state)
# Concat un-normalized generation and copy scores.
# shape: (group_size, target_vocab_size + source_sequence_length)
all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
# shape: (group_size, source_sequence_length)
source_mask = state["source_mask"]
# shape: (group_size, target_vocab_size + source_sequence_length)
mask = torch.cat(
(
generation_scores.new_full(generation_scores.size(), True, dtype=torch.bool),
source_mask,
),
dim=-1,
)
# Normalize generation and copy scores.
# shape: (group_size, target_vocab_size + source_sequence_length)
log_probs = util.masked_log_softmax(all_scores, mask)
# shape: (group_size, target_vocab_size), (group_size, source_sequence_length)
generation_log_probs, copy_log_probs = log_probs.split(
[self._target_vocab_size, source_sequence_length], dim=-1
)
# Update copy_probs needed for getting the `selective_weights` at the next timestep.
state["copy_log_probs"] = copy_log_probs
# We now have normalized generation and copy scores, but to produce the final
# score for each token in the extended vocab, we have to go through and add
# the copy scores to the generation scores of matching target tokens, and sum
# the copy scores of duplicate source tokens.
# shape: (group_size, target_vocab_size + source_sequence_length)
final_log_probs = self._gather_final_log_probs(generation_log_probs, copy_log_probs, state)
return final_log_probs, state
def _get_predicted_tokens(
self,
predicted_indices: Union[torch.Tensor, numpy.ndarray],
batch_metadata: List[Any],
n_best: int = None,
) -> List[Union[List[List[str]], List[str]]]:
"""
Convert predicted indices into tokens.
If `n_best = 1`, the result type will be `List[List[str]]`. Otherwise the result
type will be `List[List[List[str]]]`.
"""
if not isinstance(predicted_indices, numpy.ndarray):
predicted_indices = predicted_indices.detach().cpu().numpy()
predicted_tokens: List[Union[List[List[str]], List[str]]] = []
for top_k_predictions, metadata in zip(predicted_indices, batch_metadata):
batch_predicted_tokens: List[List[str]] = []
for indices in top_k_predictions[:n_best]:
tokens: List[str] = []
indices = list(indices)
if self._end_index in indices:
indices = indices[: indices.index(self._end_index)]
for index in indices:
if index >= self._target_vocab_size:
adjusted_index = index - self._target_vocab_size
token = metadata["source_tokens"][adjusted_index]
else:
token = self.vocab.get_token_from_index(index, self._target_namespace)
tokens.append(token)
batch_predicted_tokens.append(tokens)
if n_best == 1:
predicted_tokens.append(batch_predicted_tokens[0])
else:
predicted_tokens.append(batch_predicted_tokens)
return predicted_tokens
def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
"""
Finalize predictions.
After a beam search, the predicted indices correspond to tokens in the target vocabulary
OR tokens in source sentence. Here we gather the actual tokens corresponding to
the indices.
"""
predicted_tokens = self._get_predicted_tokens(
output_dict["predictions"], output_dict["metadata"]
)
output_dict["predicted_tokens"] = predicted_tokens
return output_dict
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
all_metrics: Dict[str, float] = {}
if not self.training:
if self._tensor_based_metric is not None:
all_metrics.update(
self._tensor_based_metric.get_metric(reset=reset) # type: ignore
)
if self._token_based_metric is not None:
all_metrics.update(self._token_based_metric.get_metric(reset=reset)) # type: ignore
return all_metrics
default_predictor = "seq2seq"