# Copyright 2017--2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Code for inference/translation
"""
import copy
import itertools
import json
import logging
import os
import time
from collections import defaultdict
from functools import lru_cache, partial
from typing import Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union, Set, Any
import mxnet as mx
import numpy as np
from . import constants as C
from . import data_io
from . import lexical_constraints as constrained
from . import lexicon
from . import model
from . import utils
from . import vocab
from .log import is_python34
logger = logging.getLogger(__name__)
class InferenceModel(model.SockeyeModel):
"""
    InferenceModel is a SockeyeModel that supports two operations used for inference/decoding:
(1) Encoder forward call: encode source sentence and return initial decoder states.
(2) Decoder forward call: single decoder step: predict next word.
:param config: Configuration object holding details about the model.
:param params_fname: File with model parameters.
:param context: MXNet context to bind modules to.
:param beam_size: Beam size.
:param softmax_temperature: Optional parameter to control steepness of softmax distribution.
:param max_output_length_num_stds: Number of standard deviations as safety margin for maximum output length.
:param decoder_return_logit_inputs: Decoder returns inputs to logit computation instead of softmax over target
vocabulary. Used when logits/softmax are handled separately.
:param cache_output_layer_w_b: Cache weights and biases for logit computation.
:param skip_softmax: If True, does not compute softmax for greedy decoding.
"""
def __init__(self,
config: model.ModelConfig,
params_fname: str,
context: mx.context.Context,
beam_size: int,
softmax_temperature: Optional[float] = None,
max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
decoder_return_logit_inputs: bool = False,
cache_output_layer_w_b: bool = False,
forced_max_output_len: Optional[int] = None,
skip_softmax: bool = False) -> None:
super().__init__(config)
self.params_fname = params_fname
self.context = context
self.beam_size = beam_size
utils.check_condition(beam_size < self.config.vocab_target_size,
'The beam size must be smaller than the target vocabulary size.')
if skip_softmax:
assert beam_size == 1, 'Skipping softmax does not have any effect for beam size > 1'
self.skip_softmax = skip_softmax
self.softmax_temperature = softmax_temperature
self.max_input_length, self.get_max_output_length = models_max_input_output_length([self],
max_output_length_num_stds,
forced_max_output_len=forced_max_output_len)
self.max_batch_size = None # type: Optional[int]
self.encoder_module = None # type: Optional[mx.mod.BucketingModule]
self.encoder_default_bucket_key = None # type: Optional[int]
self.decoder_module = None # type: Optional[mx.mod.BucketingModule]
self.decoder_default_bucket_key = None # type: Optional[Tuple[int, int]]
self.decoder_return_logit_inputs = decoder_return_logit_inputs
self.cache_output_layer_w_b = cache_output_layer_w_b
self.output_layer_w = None # type: Optional[mx.nd.NDArray]
self.output_layer_b = None # type: Optional[mx.nd.NDArray]
@property
def num_source_factors(self) -> int:
"""
Returns the number of source factors of this InferenceModel (at least 1).
"""
return self.config.config_data.num_source_factors
def initialize(self, max_batch_size: int, max_input_length: int, get_max_output_length_function: Callable):
"""
Delayed construction of modules to ensure multiple Inference models can agree on computing a common
maximum output length.
:param max_batch_size: Maximum batch size.
:param max_input_length: Maximum input length.
:param get_max_output_length_function: Callable to compute maximum output length.
"""
self.max_batch_size = max_batch_size
self.max_input_length = max_input_length
if self.max_input_length > self.training_max_seq_len_source:
logger.warning("Model was only trained with sentences up to a length of %d, "
"but a max_input_len of %d is used.",
self.training_max_seq_len_source, self.max_input_length)
self.get_max_output_length = get_max_output_length_function
# check the maximum supported length of the encoder & decoder:
if self.max_supported_seq_len_source is not None:
utils.check_condition(self.max_input_length <= self.max_supported_seq_len_source,
"Encoder only supports a maximum length of %d" % self.max_supported_seq_len_source)
if self.max_supported_seq_len_target is not None:
decoder_max_len = self.get_max_output_length(max_input_length)
utils.check_condition(decoder_max_len <= self.max_supported_seq_len_target,
"Decoder only supports a maximum length of %d, but %d was requested. Note that the "
"maximum output length depends on the input length and the source/target length "
"ratio observed during training." % (self.max_supported_seq_len_target,
decoder_max_len))
self.encoder_module, self.encoder_default_bucket_key = self._get_encoder_module()
self.decoder_module, self.decoder_default_bucket_key = self._get_decoder_module()
max_encoder_data_shapes = self._get_encoder_data_shapes(self.encoder_default_bucket_key,
self.max_batch_size)
max_decoder_data_shapes = self._get_decoder_data_shapes(self.decoder_default_bucket_key,
self.max_batch_size * self.beam_size)
self.encoder_module.bind(data_shapes=max_encoder_data_shapes, for_training=False, grad_req="null")
self.decoder_module.bind(data_shapes=max_decoder_data_shapes, for_training=False, grad_req="null")
self.load_params_from_file(self.params_fname)
self.encoder_module.init_params(arg_params=self.params, aux_params=self.aux_params, allow_missing=False)
self.decoder_module.init_params(arg_params=self.params, aux_params=self.aux_params, allow_missing=False)
if self.cache_output_layer_w_b:
if self.output_layer.weight_normalization:
# precompute normalized output layer weight imperatively
assert self.output_layer.weight_norm is not None
weight = self.params[self.output_layer.weight_norm.weight.name].as_in_context(self.context)
scale = self.params[self.output_layer.weight_norm.scale.name].as_in_context(self.context)
self.output_layer_w = self.output_layer.weight_norm(weight, scale)
else:
self.output_layer_w = self.params[self.output_layer.w.name].as_in_context(self.context)
self.output_layer_b = self.params[self.output_layer.b.name].as_in_context(self.context)
def _get_encoder_module(self) -> Tuple[mx.mod.BucketingModule, int]:
"""
Returns a BucketingModule for the encoder. Given a source sequence, it returns
the initial decoder states of the model.
The bucket key for this module is the length of the source sequence.
:return: Tuple of encoder module and default bucket key.
"""
def sym_gen(source_seq_len: int):
source = mx.sym.Variable(C.SOURCE_NAME)
source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
source_length = utils.compute_lengths(source_words)
# source embedding
(source_embed,
source_embed_length,
source_embed_seq_len) = self.embedding_source.encode(source, source_length, source_seq_len)
# encoder
# source_encoded: (source_encoded_length, batch_size, encoder_depth)
(source_encoded,
source_encoded_length,
source_encoded_seq_len) = self.encoder.encode(source_embed,
source_embed_length,
source_embed_seq_len)
# initial decoder states
decoder_init_states = self.decoder.init_states(source_encoded,
source_encoded_length,
source_encoded_seq_len)
data_names = [C.SOURCE_NAME]
label_names = [] # type: List[str]
return mx.sym.Group(decoder_init_states), data_names, label_names
default_bucket_key = self.max_input_length
module = mx.mod.BucketingModule(sym_gen=sym_gen,
default_bucket_key=default_bucket_key,
context=self.context)
return module, default_bucket_key
def _get_decoder_module(self) -> Tuple[mx.mod.BucketingModule, Tuple[int, int]]:
"""
Returns a BucketingModule for a single decoder step.
Given previously predicted word and previous decoder states, it returns
a distribution over the next predicted word and the next decoder states.
The bucket key for this module is the length of the source sequence
and the current time-step in the inference procedure (e.g. beam search).
The latter corresponds to the current length of the target sequences.
:return: Tuple of decoder module and default bucket key.
"""
def sym_gen(bucket_key: Tuple[int, int]):
"""
Returns either softmax output (probs over target vocabulary) or inputs to logit
computation, controlled by decoder_return_logit_inputs
"""
source_seq_len, decode_step = bucket_key
source_embed_seq_len = self.embedding_source.get_encoded_seq_len(source_seq_len)
source_encoded_seq_len = self.encoder.get_encoded_seq_len(source_embed_seq_len)
self.decoder.reset()
target_prev = mx.sym.Variable(C.TARGET_NAME)
states = self.decoder.state_variables(decode_step)
state_names = [state.name for state in states]
# embedding for previous word
# (batch_size, num_embed)
target_embed_prev, _, _ = self.embedding_target.encode(data=target_prev, data_length=None, seq_len=1)
# decoder
# target_decoded: (batch_size, decoder_depth)
(target_decoded,
attention_probs,
states) = self.decoder.decode_step(decode_step,
target_embed_prev,
source_encoded_seq_len,
*states)
if self.decoder_return_logit_inputs:
# skip output layer in graph
outputs = mx.sym.identity(target_decoded, name=C.LOGIT_INPUTS_NAME)
else:
# logits: (batch_size, target_vocab_size)
logits = self.output_layer(target_decoded)
if self.softmax_temperature is not None:
logits = logits / self.softmax_temperature
if self.skip_softmax:
# skip softmax for greedy decoding
outputs = logits
else:
outputs = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)
data_names = [C.TARGET_NAME] + state_names
label_names = [] # type: List[str]
return mx.sym.Group([outputs, attention_probs] + states), data_names, label_names
# pylint: disable=not-callable
default_bucket_key = (self.max_input_length, self.get_max_output_length(self.max_input_length))
module = mx.mod.BucketingModule(sym_gen=sym_gen,
default_bucket_key=default_bucket_key,
context=self.context)
return module, default_bucket_key
def _get_encoder_data_shapes(self, bucket_key: int, batch_size: int) -> List[mx.io.DataDesc]:
"""
Returns data shapes of the encoder module.
        :param bucket_key: Maximum input length.
        :param batch_size: Batch size.
        :return: List of data descriptions.
"""
return [mx.io.DataDesc(name=C.SOURCE_NAME,
shape=(batch_size, bucket_key, self.num_source_factors),
layout=C.BATCH_MAJOR)]
@lru_cache(maxsize=None)
def _get_decoder_data_shapes(self, bucket_key: Tuple[int, int], batch_beam_size: int) -> List[mx.io.DataDesc]:
"""
Returns data shapes of the decoder module.
:param bucket_key: Tuple of (maximum input length, maximum target length).
:param batch_beam_size: Batch size * beam size.
:return: List of data descriptions.
"""
source_max_length, target_max_length = bucket_key
return [mx.io.DataDesc(name=C.TARGET_NAME, shape=(batch_beam_size,),
layout="NT")] + self.decoder.state_shapes(batch_beam_size,
target_max_length,
self.encoder.get_encoded_seq_len(
source_max_length),
self.encoder.get_num_hidden())
def run_encoder(self,
source: mx.nd.NDArray,
source_max_length: int) -> 'ModelState':
"""
Runs forward pass of the encoder.
Encodes source given source length and bucket key.
        Returns the initial decoder states (which include the encoder representation of the source),
        tiled to beam size and wrapped in a ModelState.
:param source: Integer-coded input tokens. Shape (batch_size, source length, num_source_factors).
:param source_max_length: Bucket key.
:return: Initial model state.
"""
batch_size = source.shape[0]
batch = mx.io.DataBatch(data=[source],
label=None,
bucket_key=source_max_length,
provide_data=self._get_encoder_data_shapes(source_max_length, batch_size))
self.encoder_module.forward(data_batch=batch, is_train=False)
decoder_states = self.encoder_module.get_outputs()
# replicate encoder/init module results beam size times
decoder_states = [mx.nd.repeat(s, repeats=self.beam_size, axis=0) for s in decoder_states]
return ModelState(decoder_states)
def run_decoder(self,
prev_word: mx.nd.NDArray,
bucket_key: Tuple[int, int],
model_state: 'ModelState') -> Tuple[mx.nd.NDArray, mx.nd.NDArray, 'ModelState']:
"""
Runs forward pass of the single-step decoder.
:param prev_word: Previous word ids. Shape: (batch*beam,).
:param bucket_key: Bucket key.
:param model_state: Model states.
:return: Decoder stack output (logit inputs or probability distribution), attention scores, updated model state.
"""
batch_beam_size = prev_word.shape[0]
batch = mx.io.DataBatch(
data=[prev_word.as_in_context(self.context)] + model_state.states,
label=None,
bucket_key=bucket_key,
provide_data=self._get_decoder_data_shapes(bucket_key, batch_beam_size))
self.decoder_module.forward(data_batch=batch, is_train=False)
out, attention_probs, *model_state.states = self.decoder_module.get_outputs()
return out, attention_probs, model_state
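    # A minimal, hypothetical sketch of how run_encoder()/run_decoder() compose into a
    # greedy decoding loop (beam_size=1); the surrounding Translator normally drives this,
    # and 'source', 'bos_id' and the bucket keys are assumed to be prepared by the caller:
    #
    #     state = model.run_encoder(source, source_max_length)
    #     prev_word = mx.nd.full((batch_size,), val=bos_id, ctx=model.context)
    #     for t in range(1, max_output_length + 1):
    #         out, attn, state = model.run_decoder(prev_word, (source_max_length, t), state)
    #         prev_word = mx.nd.argmax(out, axis=1)  # greedy pick; stop on EOS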
@property
def training_max_seq_len_source(self) -> int:
""" The maximum sequence length on the source side during training. """
return self.config.config_data.data_statistics.max_observed_len_source
@property
def training_max_seq_len_target(self) -> int:
""" The maximum sequence length on the target side during training. """
return self.config.config_data.data_statistics.max_observed_len_target
@property
def max_supported_seq_len_source(self) -> Optional[int]:
""" If not None this is the maximally supported source length during inference (hard constraint). """
return self.encoder.get_max_seq_len()
@property
def max_supported_seq_len_target(self) -> Optional[int]:
""" If not None this is the maximally supported target length during inference (hard constraint). """
return self.decoder.get_max_seq_len()
@property
def length_ratio_mean(self) -> float:
return self.config.config_data.data_statistics.length_ratio_mean
@property
def length_ratio_std(self) -> float:
return self.config.config_data.data_statistics.length_ratio_std
@property
def source_with_eos(self) -> bool:
return self.config.config_data.source_with_eos
def load_models(context: mx.context.Context,
max_input_len: Optional[int],
beam_size: int,
batch_size: int,
model_folders: List[str],
checkpoints: Optional[List[int]] = None,
softmax_temperature: Optional[float] = None,
max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
decoder_return_logit_inputs: bool = False,
cache_output_layer_w_b: bool = False,
forced_max_output_len: Optional[int] = None,
override_dtype: Optional[str] = None,
output_scores: bool = False,
sampling: bool = False) -> Tuple[List[InferenceModel],
List[vocab.Vocab],
vocab.Vocab]:
"""
Loads a list of models for inference.
:param context: MXNet context to bind modules to.
:param max_input_len: Maximum input length.
:param beam_size: Beam size.
:param batch_size: Batch size.
:param model_folders: List of model folders to load models from.
:param checkpoints: List of checkpoints to use for each model in model_folders. Use None to load best checkpoint.
:param softmax_temperature: Optional parameter to control steepness of softmax distribution.
:param max_output_length_num_stds: Number of standard deviations to add to mean target-source length ratio
to compute maximum output length.
:param decoder_return_logit_inputs: Model decoders return inputs to logit computation instead of softmax over target
vocabulary. Used when logits/softmax are handled separately.
:param cache_output_layer_w_b: Models cache weights and biases for logit computation as NumPy arrays (used with
restrict lexicon).
:param forced_max_output_len: An optional overwrite of the maximum output length.
:param override_dtype: Overrides dtype of encoder and decoder defined at training time to a different one.
    :param output_scores: Whether the scores will be needed as outputs. If True, scores will be normalized
           negative log probabilities. If False, scores will be negative raw logit activations when decoding
           with beam size 1 and a single model.
    :param sampling: True if the model is sampling instead of doing normal topk().
    :return: List of models, source vocabularies, target vocabulary.
"""
logger.info("Loading %d model(s) from %s ...", len(model_folders), model_folders)
load_time_start = time.time()
models = [] # type: List[InferenceModel]
source_vocabs = [] # type: List[List[vocab.Vocab]]
target_vocabs = [] # type: List[vocab.Vocab]
if checkpoints is None:
checkpoints = [None] * len(model_folders)
else:
utils.check_condition(len(checkpoints) == len(model_folders), "Must provide checkpoints for each model")
skip_softmax = False
# performance tweak: skip softmax for a single model, decoding with beam size 1, when not sampling and no scores are required in output.
if len(model_folders) == 1 and beam_size == 1 and not output_scores and not sampling:
skip_softmax = True
logger.info("Enabled skipping softmax for a single model and greedy decoding.")
for model_folder, checkpoint in zip(model_folders, checkpoints):
model_source_vocabs = vocab.load_source_vocabs(model_folder)
model_target_vocab = vocab.load_target_vocab(model_folder)
source_vocabs.append(model_source_vocabs)
target_vocabs.append(model_target_vocab)
model_version = utils.load_version(os.path.join(model_folder, C.VERSION_NAME))
logger.info("Model version: %s", model_version)
utils.check_version(model_version)
model_config = model.SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))
logger.info("Disabling dropout layers for performance reasons")
model_config.disable_dropout()
if override_dtype is not None:
model_config.config_encoder.dtype = override_dtype
model_config.config_decoder.dtype = override_dtype
if override_dtype == C.DTYPE_FP16:
logger.warning('Experimental feature \'override_dtype=float16\' has been used. '
'This feature may be removed or change its behaviour in future. '
'DO NOT USE IT IN PRODUCTION!')
if checkpoint is None:
params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
else:
params_fname = os.path.join(model_folder, C.PARAMS_NAME % checkpoint)
inference_model = InferenceModel(config=model_config,
params_fname=params_fname,
context=context,
beam_size=beam_size,
softmax_temperature=softmax_temperature,
decoder_return_logit_inputs=decoder_return_logit_inputs,
cache_output_layer_w_b=cache_output_layer_w_b,
skip_softmax=skip_softmax)
utils.check_condition(inference_model.num_source_factors == len(model_source_vocabs),
"Number of loaded source vocabularies (%d) does not match "
"number of source factors for model '%s' (%d)" % (len(model_source_vocabs), model_folder,
inference_model.num_source_factors))
models.append(inference_model)
utils.check_condition(vocab.are_identical(*target_vocabs), "Target vocabulary ids do not match")
first_model_vocabs = source_vocabs[0]
for fi in range(len(first_model_vocabs)):
utils.check_condition(vocab.are_identical(*[source_vocabs[i][fi] for i in range(len(source_vocabs))]),
"Source vocabulary ids do not match. Factor %d" % fi)
source_with_eos = models[0].source_with_eos
utils.check_condition(all(source_with_eos == m.source_with_eos for m in models),
"All models must agree on using source-side EOS symbols or not. "
"Did you try combining models trained with different versions?")
# set a common max_output length for all models.
max_input_len, get_max_output_length = models_max_input_output_length(models,
max_output_length_num_stds,
max_input_len,
forced_max_output_len=forced_max_output_len)
for inference_model in models:
inference_model.initialize(batch_size, max_input_len, get_max_output_length)
load_time = time.time() - load_time_start
logger.info("%d model(s) loaded in %.4fs", len(models), load_time)
return models, source_vocabs[0], target_vocabs[0]
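# Hypothetical usage sketch (the folder name and sizes are placeholders, not part of the API):
#
#     models, source_vocabs, target_vocab = load_models(context=mx.cpu(),
#                                                       max_input_len=None,
#                                                       beam_size=5,
#                                                       batch_size=16,
#                                                       model_folders=['my_model_dir'])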
def models_max_input_output_length(models: List[InferenceModel],
num_stds: int,
forced_max_input_len: Optional[int] = None,
forced_max_output_len: Optional[int] = None) -> Tuple[int, Callable]:
"""
Returns a function to compute maximum output length given a fixed number of standard deviations as a
safety margin, and the current input length.
Mean and std are taken from the model with the largest values to allow proper ensembling of models
trained on different data sets.
:param models: List of models.
:param num_stds: Number of standard deviations to add as a safety margin. If -1, returned maximum output lengths
will always be 2 * input_length.
:param forced_max_input_len: An optional overwrite of the maximum input length.
:param forced_max_output_len: An optional overwrite of the maximum output length.
:return: The maximum input length and a function to get the output length given the input length.
"""
max_mean = max(model.length_ratio_mean for model in models)
max_std = max(model.length_ratio_std for model in models)
supported_max_seq_len_source = min((model.max_supported_seq_len_source for model in models
if model.max_supported_seq_len_source is not None),
default=None)
supported_max_seq_len_target = min((model.max_supported_seq_len_target for model in models
if model.max_supported_seq_len_target is not None),
default=None)
training_max_seq_len_source = min(model.training_max_seq_len_source for model in models)
return get_max_input_output_length(supported_max_seq_len_source,
supported_max_seq_len_target,
training_max_seq_len_source,
length_ratio_mean=max_mean,
length_ratio_std=max_std,
num_stds=num_stds,
forced_max_input_len=forced_max_input_len,
forced_max_output_len=forced_max_output_len)
def get_max_input_output_length(supported_max_seq_len_source: Optional[int],
supported_max_seq_len_target: Optional[int],
training_max_seq_len_source: Optional[int],
length_ratio_mean: float,
length_ratio_std: float,
num_stds: int,
forced_max_input_len: Optional[int] = None,
forced_max_output_len: Optional[int] = None) -> Tuple[int, Callable]:
"""
Returns a function to compute maximum output length given a fixed number of standard deviations as a
safety margin, and the current input length. It takes into account optional maximum source and target lengths.
:param supported_max_seq_len_source: The maximum source length supported by the models.
:param supported_max_seq_len_target: The maximum target length supported by the models.
:param training_max_seq_len_source: The maximum source length observed during training.
:param length_ratio_mean: The mean of the length ratio that was calculated on the raw sequences with special
symbols such as EOS or BOS.
:param length_ratio_std: The standard deviation of the length ratio.
:param num_stds: The number of standard deviations the target length may exceed the mean target length (as long as
the supported maximum length allows for this).
:param forced_max_input_len: An optional overwrite of the maximum input length.
    :param forced_max_output_len: An optional overwrite of the maximum output length.
:return: The maximum input length and a function to get the output length given the input length.
"""
space_for_bos = 1
space_for_eos = 1
if num_stds < 0:
factor = C.TARGET_MAX_LENGTH_FACTOR # type: float
else:
factor = length_ratio_mean + (length_ratio_std * num_stds)
if forced_max_input_len is None:
# Make sure that if there is a hard constraint on the maximum source or target length we never exceed this
# constraint. This is for example the case for learned positional embeddings, which are only defined for the
# maximum source and target sequence length observed during training.
if supported_max_seq_len_source is not None and supported_max_seq_len_target is None:
max_input_len = supported_max_seq_len_source
elif supported_max_seq_len_source is None and supported_max_seq_len_target is not None:
max_output_len = supported_max_seq_len_target - space_for_bos - space_for_eos
if np.ceil(factor * training_max_seq_len_source) > max_output_len:
max_input_len = int(np.floor(max_output_len / factor))
else:
max_input_len = training_max_seq_len_source
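        # Both encoder and decoder impose a hard length limit; the one-sided cases were handled above.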
elif supported_max_seq_len_source is not None or supported_max_seq_len_target is not None:
max_output_len = supported_max_seq_len_target - space_for_bos - space_for_eos
if np.ceil(factor * supported_max_seq_len_source) > max_output_len:
max_input_len = int(np.floor(max_output_len / factor))
else:
max_input_len = supported_max_seq_len_source
else:
# Any source/target length is supported and max_input_len was not manually set, therefore we use the
# maximum length from training.
max_input_len = training_max_seq_len_source
else:
max_input_len = forced_max_input_len
def get_max_output_length(input_length: int):
"""
Returns the maximum output length for inference given the input length.
        Explicitly includes space for BOS and EOS sentence symbols in the target sequence, because we assume
        that the mean length ratio computed on the training data does not include these special symbols
        (see data_io.analyze_sequence_lengths).
"""
if forced_max_output_len is not None:
return forced_max_output_len
else:
return int(np.ceil(factor * input_length)) + space_for_bos + space_for_eos
return max_input_len, get_max_output_length
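# Worked example (illustrative numbers only): with length_ratio_mean=1.1, length_ratio_std=0.2
# and num_stds=2, the factor is 1.1 + 0.2 * 2 = 1.5, so an input of 10 tokens yields a maximum
# output length of ceil(1.5 * 10) + 2 = 17 (the +2 reserves space for BOS and EOS).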
BeamHistory = Dict[str, List]
Tokens = List[str]
SentenceId = Union[int, str]
class TranslatorInput:
"""
Object required by Translator.translate().
If not None, `pass_through_dict` is an arbitrary dictionary instantiated from a JSON object
via `make_input_from_dict()`, and it contains extra fields found in an input JSON object.
If `--output-type json` is selected, all such fields that are not fields used or changed by
Sockeye will be included in the output JSON object. This provides a mechanism for passing
fields through the call to Sockeye.
:param sentence_id: Sentence id.
:param tokens: List of input tokens.
:param factors: Optional list of additional factor sequences.
:param constraints: Optional list of target-side constraints.
:param pass_through_dict: Optional raw dictionary of arbitrary input data.
"""
__slots__ = ('sentence_id', 'tokens', 'factors', 'constraints', 'avoid_list', 'pass_through_dict')
def __init__(self,
sentence_id: SentenceId,
tokens: Tokens,
factors: Optional[List[Tokens]] = None,
constraints: Optional[List[Tokens]] = None,
avoid_list: Optional[List[Tokens]] = None,
pass_through_dict: Optional[Dict] = None) -> None:
self.sentence_id = sentence_id
self.tokens = tokens
self.factors = factors
self.constraints = constraints
self.avoid_list = avoid_list
self.pass_through_dict = pass_through_dict
def __str__(self):
return 'TranslatorInput(%s, %s, factors=%s, constraints=%s, avoid=%s)' \
% (self.sentence_id, self.tokens, self.factors, self.constraints, self.avoid_list)
def __len__(self):
return len(self.tokens)
@property
def num_factors(self) -> int:
"""
Returns the number of factors of this instance.
"""
return 1 + (0 if not self.factors else len(self.factors))
def chunks(self, chunk_size: int) -> Generator['TranslatorInput', None, None]:
"""
Takes a TranslatorInput (itself) and yields TranslatorInputs for chunks of size chunk_size.
:param chunk_size: The maximum size of a chunk.
:return: A generator of TranslatorInputs, one for each chunk created.
"""
if len(self.tokens) > chunk_size and self.constraints is not None:
logger.warning(
'Input %s has length (%d) that exceeds max input length (%d), '
'triggering internal splitting. Placing all target-side constraints '
'with the first chunk, which is probably wrong.',
self.sentence_id, len(self.tokens), chunk_size)
for chunk_id, i in enumerate(range(0, len(self), chunk_size)):
factors = [factor[i:i + chunk_size] for factor in self.factors] if self.factors is not None else None
# Constrained decoding is not supported for chunked TranslatorInputs. As a fall-back, constraints are
# assigned to the first chunk
constraints = self.constraints if chunk_id == 0 else None
pass_through_dict = self.pass_through_dict if chunk_id == 0 else None
yield TranslatorInput(sentence_id=self.sentence_id,
tokens=self.tokens[i:i + chunk_size],
factors=factors,
constraints=constraints,
avoid_list=self.avoid_list,
pass_through_dict=pass_through_dict)
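    # Illustrative sketch (hypothetical values): a 5-token input chunked with chunk_size=2
    # yields three TranslatorInputs covering tokens [0:2], [2:4] and [4:5], all sharing the
    # same sentence_id:
    #
    #     inp = TranslatorInput(sentence_id=0, tokens=['a', 'b', 'c', 'd', 'e'])
    #     chunks = list(inp.chunks(chunk_size=2))  # len(chunks) == 3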
def with_eos(self) -> 'TranslatorInput':
"""
:return: A new translator input with EOS appended to the tokens and factors.
"""
return TranslatorInput(sentence_id=self.sentence_id,
tokens=self.tokens + [C.EOS_SYMBOL],
factors=[factor + [C.EOS_SYMBOL] for factor in
self.factors] if self.factors is not None else None,
constraints=self.constraints,
avoid_list=self.avoid_list,
pass_through_dict=self.pass_through_dict)
class BadTranslatorInput(TranslatorInput):
def __init__(self, sentence_id: SentenceId, tokens: Tokens) -> None:
super().__init__(sentence_id=sentence_id, tokens=tokens, factors=None)
def _bad_input(sentence_id: SentenceId, reason: str = '') -> BadTranslatorInput:
logger.warning("Bad input (%s): '%s'. Will return empty output.", sentence_id, reason.strip())
return BadTranslatorInput(sentence_id=sentence_id, tokens=[])
def make_input_from_plain_string(sentence_id: SentenceId, string: str) -> TranslatorInput:
"""
Returns a TranslatorInput object from a plain string.
:param sentence_id: Sentence id.
:param string: An input string.
:return: A TranslatorInput.
"""
return TranslatorInput(sentence_id, tokens=list(data_io.get_tokens(string)), factors=None)
def make_input_from_json_string(sentence_id: SentenceId, json_string: str) -> TranslatorInput:
"""
Returns a TranslatorInput object from a JSON object, serialized as a string.
:param sentence_id: Sentence id.
:param json_string: A JSON object serialized as a string that must contain a key "text", mapping to the input text,
and optionally a key "factors" that maps to a list of strings, each of which representing a factor sequence
for the input text. Constraints and an avoid list can also be added through the "constraints" and "avoid"
keys.
:return: A TranslatorInput.
"""
try:
jobj = json.loads(json_string, encoding=C.JSON_ENCODING)
return make_input_from_dict(sentence_id, jobj)
except Exception as e:
logger.exception(e, exc_info=True) if not is_python34() else logger.error(e) # type: ignore
return _bad_input(sentence_id, reason=json_string)
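# Hypothetical input sketch using the keys described above (values are made up):
#
#     make_input_from_json_string(0, '{"text": "the cat", "factors": ["DT NN"], '
#                                    '"constraints": ["le chat"], "avoid": ["la chatte"]}')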
def make_input_from_dict(sentence_id: SentenceId, input_dict: Dict) -> TranslatorInput:
"""
    Returns a TranslatorInput object from a dictionary.
:param sentence_id: Sentence id.
:param input_dict: A dict that must contain a key "text", mapping to the input text, and optionally a key "factors"
that maps to a list of strings, each of which representing a factor sequence for the input text.
Constraints and an avoid list can also be added through the "constraints" and "avoid" keys.
:return: A TranslatorInput.
"""
try:
tokens = input_dict[C.JSON_TEXT_KEY]
tokens = list(data_io.get_tokens(tokens))
factors = input_dict.get(C.JSON_FACTORS_KEY)
if isinstance(factors, list):
factors = [list(data_io.get_tokens(factor)) for factor in factors]
lengths = [len(f) for f in factors]
if not all(length == len(tokens) for length in lengths):
logger.error("Factors have different length than input text: %d vs. %s", len(tokens), str(lengths))
return _bad_input(sentence_id, reason=str(input_dict))
        # List of phrases to prevent from occurring in the output
avoid_list = input_dict.get(C.JSON_AVOID_KEY)
# List of phrases that must appear in the output
constraints = input_dict.get(C.JSON_CONSTRAINTS_KEY)
# If there is overlap between positive and negative constraints, assume the user wanted
# the words, and so remove them from the avoid_list (negative constraints)
if constraints is not None and avoid_list is not None:
avoid_set = set(avoid_list)
overlap = set(constraints).intersection(avoid_set)
if len(overlap) > 0:
logger.warning("Overlap between constraints and avoid set, dropping the overlapping avoids")
avoid_list = list(avoid_set.difference(overlap))
# Convert to a list of tokens
if isinstance(avoid_list, list):
avoid_list = [list(data_io.get_tokens(phrase)) for phrase in avoid_list]
if isinstance(constraints, list):
constraints = [list(data_io.get_tokens(constraint)) for constraint in constraints]
return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors,
constraints=constraints, avoid_list=avoid_list, pass_through_dict=input_dict)
except Exception as e:
logger.exception(e, exc_info=True) if not is_python34() else logger.error(e) # type: ignore
return _bad_input(sentence_id, reason=str(input_dict))
def make_input_from_factored_string(sentence_id: SentenceId,
factored_string: str,
translator: 'Translator',
delimiter: str = C.DEFAULT_FACTOR_DELIMITER) -> TranslatorInput:
"""
Returns a TranslatorInput object from a string with factor annotations on a token level, separated by delimiter.
If translator does not require any source factors, the string is parsed as a plain token string.
:param sentence_id: Sentence id.
:param factored_string: An input string with additional factors per token, separated by delimiter.
:param translator: A translator object.
:param delimiter: A factor delimiter. Default: '|'.
:return: A TranslatorInput.
"""
utils.check_condition(bool(delimiter) and not delimiter.isspace(),
"Factor delimiter can not be whitespace or empty.")
model_num_source_factors = translator.num_source_factors
if model_num_source_factors == 1:
return make_input_from_plain_string(sentence_id=sentence_id, string=factored_string)
tokens = [] # type: Tokens
factors = [[] for _ in range(model_num_source_factors - 1)] # type: List[Tokens]
for token_id, token in enumerate(data_io.get_tokens(factored_string)):
pieces = token.split(delimiter)
if not all(pieces) or len(pieces) != model_num_source_factors:
logger.error("Failed to parse %d factors at position %d ('%s') in '%s'" % (model_num_source_factors,
token_id, token,
factored_string.strip()))
return _bad_input(sentence_id, reason=factored_string)
tokens.append(pieces[0])
for i, factor in enumerate(factors):
factors[i].append(pieces[i + 1])
return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors)
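# Illustrative sketch for a model with two source factors (surface form plus one factor);
# 'translator' is assumed to report num_source_factors == 2:
#
#     make_input_from_factored_string(0, 'the|DT cat|NN', translator)
#     # -> tokens=['the', 'cat'], factors=[['DT', 'NN']]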
def make_input_from_multiple_strings(sentence_id: SentenceId, strings: List[str]) -> TranslatorInput:
"""
Returns a TranslatorInput object from multiple strings, where the first element corresponds to the surface tokens
and the remaining elements to additional factors. All strings must parse into token sequences of the same length.
:param sentence_id: Sentence id.
:param strings: A list of strings representing a factored input sequence.
:return: A TranslatorInput.
"""
if not bool(strings):
return TranslatorInput(sentence_id=sentence_id, tokens=[], factors=None)
tokens = list(data_io.get_tokens(strings[0]))
factors = [list(data_io.get_tokens(factor)) for factor in strings[1:]]
if not all(len(factor) == len(tokens) for factor in factors):
logger.error("Length of string sequences do not match: '%s'", strings)
return _bad_input(sentence_id, reason=str(strings))
return TranslatorInput(sentence_id=sentence_id, tokens=tokens, factors=factors)
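# Equivalent sketch with pre-split factor streams of equal token length (hypothetical values):
#
#     make_input_from_multiple_strings(0, ['the cat', 'DT NN'])
#     # -> tokens=['the', 'cat'], factors=[['DT', 'NN']]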
class TranslatorOutput:
"""
Output structure from Translator.
:param sentence_id: Sentence id.
:param translation: Translation string without sentence boundary tokens.
:param tokens: List of translated tokens.
:param attention_matrix: Attention matrix. Shape: (target_length, source_length).
:param score: Negative log probability of generated translation.
:param pass_through_dict: Dictionary of key/value pairs to pass through when working with JSON.
:param beam_histories: List of beam histories. The list will contain more than one
history if it was split due to exceeding max_length.
:param nbest_translations: List of nbest translations as strings.
:param nbest_tokens: List of nbest translations as lists of tokens.
:param nbest_attention_matrices: List of attention matrices, one for each nbest translation.
:param nbest_scores: List of nbest scores, one for each nbest translation.
"""
__slots__ = ('sentence_id',
'translation',
'tokens',
'attention_matrix',
'score',
'pass_through_dict',
'beam_histories',
'nbest_translations',
'nbest_tokens',
'nbest_attention_matrices',
'nbest_scores')
def __init__(self,
sentence_id: SentenceId,
translation: str,
tokens: Tokens,
attention_matrix: np.ndarray,
score: float,
pass_through_dict: Optional[Dict[str,Any]] = None,
beam_histories: Optional[List[BeamHistory]] = None,
nbest_translations: Optional[List[str]] = None,
nbest_tokens: Optional[List[Tokens]] = None,
nbest_attention_matrices: Optional[List[np.ndarray]] = None,
nbest_scores: Optional[List[float]] = None) -> None:
self.sentence_id = sentence_id
self.translation = translation
self.tokens = tokens
self.attention_matrix = attention_matrix
self.score = score
self.pass_through_dict = copy.deepcopy(pass_through_dict) if pass_through_dict else {}
self.beam_histories = beam_histories
self.nbest_translations = nbest_translations
self.nbest_tokens = nbest_tokens
self.nbest_attention_matrices = nbest_attention_matrices
self.nbest_scores = nbest_scores
def json(self, align_threshold: float = 0.0) -> Dict:
"""
Returns a dictionary suitable for json.dumps() representing all
the information in the class. It is initialized with any keys
present in the corresponding `TranslatorInput` object's pass_through_dict.
Keys from here that are not overwritten by Sockeye will thus be passed
through to the output.
:param align_threshold: If alignments are defined, only print ones over this threshold.
:return: A dictionary.
"""
_d = self.pass_through_dict # type: Dict[str, Any]
_d['sentence_id'] = self.sentence_id
_d['translation'] = self.translation
_d['score'] = self.score
if self.nbest_translations is not None and len(self.nbest_translations) > 1:
_d['translations'] = self.nbest_translations
_d['scores'] = self.nbest_scores
if self.nbest_attention_matrices:
extracted_alignments = []
for alignment_matrix in self.nbest_attention_matrices:
extracted_alignments.append(list(utils.get_alignments(alignment_matrix, threshold=align_threshold)))
_d['alignments'] = extracted_alignments
return _d
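# Illustrative shape of the dictionary returned by json() (values are made up; the nbest keys
# 'translations', 'scores' and 'alignments' appear only when nbest output is present):
#
#     {'sentence_id': 0, 'translation': 'le chat', 'score': 0.42}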
TokenIds = List[int]
class NBestTranslations:
__slots__ = ('target_ids_list',
'attention_matrices',
'scores')
def __init__(self,
target_ids_list: List[TokenIds],
attention_matrices: List[np.ndarray],
scores: List[float]) -> None:
self.target_ids_list = target_ids_list
self.attention_matrices = attention_matrices
self.scores = scores
class Translation:
__slots__ = ('target_ids',
'attention_matrix',
'score',
'beam_histories',
'nbest_translations')
def __init__(self,
target_ids: TokenIds,
attention_matrix: np.ndarray,
score: float,
                 beam_histories: Optional[List[BeamHistory]] = None,
                 nbest_translations: Optional[NBestTranslations] = None) -> None:
self.target_ids = target_ids
self.attention_matrix = attention_matrix
self.score = score
self.beam_histories = beam_histories if beam_histories is not None else []
self.nbest_translations = nbest_translations
def empty_translation(add_nbest: bool = False) -> Translation:
"""
Return an empty translation.
:param add_nbest: Include (empty) nbest_translations in the translation object.
"""
return Translation(target_ids=[],
attention_matrix=np.asarray([[0]]),
score=-np.inf,
nbest_translations=NBestTranslations([], [], []) if add_nbest else None
)
IndexedTranslatorInput = NamedTuple('IndexedTranslatorInput', [
('input_idx', int),
('chunk_idx', int),
('translator_input', TranslatorInput)
])
"""
Translation of a chunk of a sentence.
:param input_idx: Internal index of translation requests to keep track of the correct order of translations.
:param chunk_idx: The index of the chunk. Used when TranslatorInputs get split across multiple chunks.
:param input: The translator input.
"""
IndexedTranslation = NamedTuple('IndexedTranslation', [
('input_idx', int),
('chunk_idx', int),
('translation', Translation)