# Copyright 2017, 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Implements data iterators and I/O related functions for sequence-to-sequence models.
"""
import bisect
import logging
import math
import os
import pickle
import random
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import ExitStack
from typing import Any, cast, Dict, Iterator, Iterable, List, Optional, Sequence, Sized, Tuple
import mxnet as mx
import numpy as np
from . import align
from . import config
from . import constants as C
from . import vocab
from .utils import check_condition, smart_open, get_tokens, OnlineMeanAndVariance
logger = logging.getLogger(__name__)
def define_buckets(max_seq_len: int, step=10) -> List[int]:
"""
Returns a list of integers defining bucket boundaries.
Bucket boundaries are created according to the following policy:
We generate buckets with a step size of step until the final bucket fits max_seq_len.
We then limit that bucket to max_seq_len (difference between semi-final and final bucket may be less than step).
:param max_seq_len: Maximum bucket size.
:param step: Distance between buckets.
:return: List of bucket sizes.
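    Example (illustrative): define_buckets(23, step=10) returns [10, 20, 23]; regular steps
    of 10, with the final bucket clamped to max_seq_len.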
"""
buckets = [bucket_len for bucket_len in range(step, max_seq_len + step, step)]
buckets[-1] = max_seq_len
return buckets
def define_parallel_buckets(max_seq_len_source: int,
max_seq_len_target: int,
bucket_width: int = 10,
length_ratio: float = 1.0) -> List[Tuple[int, int]]:
"""
    Returns (source, target) buckets up to (max_seq_len_source, max_seq_len_target). The longer side of the data uses
    steps of bucket_width, while the shorter side uses steps scaled down by the average target/source length ratio. If
    one side reaches its max_seq_len before the other, the extra buckets on that side are fixed to that max_seq_len.
:param max_seq_len_source: Maximum source bucket size.
:param max_seq_len_target: Maximum target bucket size.
:param bucket_width: Width of buckets on longer side.
:param length_ratio: Length ratio of data (target/source).
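    Example (illustrative): define_parallel_buckets(10, 10, bucket_width=10, length_ratio=0.5)
    returns [(10, 5), (10, 10)]; the source side is longer, so the target step is scaled down
    to 5 and the shorter source bucket list is padded with its maximum length.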
"""
source_step_size = bucket_width
target_step_size = bucket_width
if length_ratio >= 1.0:
# target side is longer -> scale source
source_step_size = max(1, int(round(bucket_width / length_ratio)))
else:
        # source side is longer -> scale target
target_step_size = max(1, int(round(bucket_width * length_ratio)))
source_buckets = define_buckets(max_seq_len_source, step=source_step_size)
target_buckets = define_buckets(max_seq_len_target, step=target_step_size)
# Extra buckets
if len(source_buckets) < len(target_buckets):
source_buckets += [source_buckets[-1] for _ in range(len(target_buckets) - len(source_buckets))]
elif len(target_buckets) < len(source_buckets):
target_buckets += [target_buckets[-1] for _ in range(len(source_buckets) - len(target_buckets))]
# minimum bucket size is 2 (as we add BOS symbol to target side)
source_buckets = [max(2, b) for b in source_buckets]
target_buckets = [max(2, b) for b in target_buckets]
parallel_buckets = list(zip(source_buckets, target_buckets))
# deduplicate for return
buckets = list(OrderedDict.fromkeys(parallel_buckets))
buckets.sort()
return buckets
def define_empty_source_parallel_buckets(max_seq_len_target: int,
bucket_width: int = 10) -> List[Tuple[int, int]]:
"""
    Returns (source, target) buckets up to (0, max_seq_len_target). The source
    side is always 0, since it is not expected to contain data that can be bucketed.
The target is used as reference to create the buckets.
:param max_seq_len_target: Maximum target bucket size.
:param bucket_width: Width of buckets on longer side.
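    Example (illustrative): define_empty_source_parallel_buckets(20, bucket_width=10)
    returns [(0, 10), (0, 20)].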
"""
target_step_size = max(1, bucket_width)
target_buckets = define_buckets(max_seq_len_target, step=target_step_size)
# source buckets are always 0 since there is no text
source_buckets = [0 for b in target_buckets]
target_buckets = [max(2, b) for b in target_buckets]
parallel_buckets = list(zip(source_buckets, target_buckets))
# deduplicate for return
buckets = list(OrderedDict.fromkeys(parallel_buckets))
buckets.sort()
return buckets
def get_bucket(seq_len: int, buckets: List[int]) -> Optional[int]:
"""
Given sequence length and a list of buckets, return corresponding bucket.
:param seq_len: Sequence length.
:param buckets: List of buckets.
:return: Chosen bucket.
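    Example (illustrative): get_bucket(15, [10, 20, 30]) returns 20, the smallest bucket that
    fits the sequence; get_bucket(31, [10, 20, 30]) returns None.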
"""
bucket_idx = bisect.bisect_left(buckets, seq_len)
if bucket_idx == len(buckets):
return None
return buckets[bucket_idx]
class BucketBatchSize:
"""
:param bucket: The corresponding bucket.
:param batch_size: Number of sequences in each batch.
:param average_words_per_batch: Approximate number of non-padding tokens in each batch.
"""
def __init__(self, bucket: Tuple[int, int], batch_size: int, average_words_per_batch: float) -> None:
self.bucket = bucket
self.batch_size = batch_size
self.average_words_per_batch = average_words_per_batch
def define_bucket_batch_sizes(buckets: List[Tuple[int, int]],
batch_size: int,
batch_by_words: bool,
batch_num_devices: int,
data_target_average_len: List[Optional[float]]) -> List[BucketBatchSize]:
"""
Computes bucket-specific batch sizes (sentences, average_words).
    If sentence-based batching: the number of sentences is the same for each batch and determines
    the number of words; hence batch sizes are equal across buckets.
    If word-based batching: the number of sentences for each batch is set to the multiple of the
    number of devices that produces the number of words closest to the target batch size. The
    average target sentence length (non-padding symbols) is used for word-count calculations.
:param buckets: Bucket list.
:param batch_size: Batch size.
:param batch_by_words: Batch by words.
:param batch_num_devices: Number of devices.
:param data_target_average_len: Optional average target length for each bucket.
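    Example (illustrative): with batch_by_words=True, batch_size=400 (target words),
    batch_num_devices=2 and an average target length of 20 in some bucket, the sentence batch
    size for that bucket is 2 * round((400 / 20) / 2) = 20 sentences, i.e. ~400 words per batch.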
"""
check_condition(len(data_target_average_len) == len(buckets),
"Must provide None or average target length for each bucket")
data_target_average_len = list(data_target_average_len)
bucket_batch_sizes = [] # type: List[BucketBatchSize]
largest_total_num_words = 0
for buck_idx, bucket in enumerate(buckets):
# Target/label length with padding
padded_seq_len = bucket[1]
# Average target/label length excluding padding
if data_target_average_len[buck_idx] is None:
data_target_average_len[buck_idx] = padded_seq_len
average_seq_len = data_target_average_len[buck_idx]
# Word-based: num words determines num sentences
# Sentence-based: num sentences determines num words
if batch_by_words:
check_condition(padded_seq_len <= batch_size, "Word batch size must cover sequence lengths for all"
" buckets: (%d > %d)" % (padded_seq_len, batch_size))
# Multiple of number of devices (int) closest to target number of words, assuming each sentence is of
# average length
batch_size_seq = batch_num_devices * max(1, round((batch_size / average_seq_len) / batch_num_devices))
batch_size_word = batch_size_seq * average_seq_len
else:
batch_size_seq = batch_size
batch_size_word = batch_size_seq * average_seq_len
bucket_batch_sizes.append(BucketBatchSize(bucket, batch_size_seq, batch_size_word))
# Track largest number of source or target word samples in a batch
largest_total_num_words = max(largest_total_num_words, batch_size_seq * max(*bucket))
# Final step: guarantee that largest bucket by sequence length also has a batch size so that it covers any
# (batch_size, len_source) and (batch_size, len_target) matrix from the data iterator to allow for memory sharing.
# When batching by sentences, this will already be the case.
if batch_by_words:
padded_seq_len = max(*buckets[-1])
average_seq_len = data_target_average_len[-1]
while bucket_batch_sizes[-1].batch_size * padded_seq_len < largest_total_num_words:
bucket_batch_sizes[-1] = BucketBatchSize(
bucket_batch_sizes[-1].bucket,
bucket_batch_sizes[-1].batch_size + batch_num_devices,
bucket_batch_sizes[-1].average_words_per_batch + batch_num_devices * average_seq_len)
return bucket_batch_sizes
def calculate_length_statistics(source_iterables: Sequence[Iterable[Any]],
target_iterable: Iterable[Any],
max_seq_len_source: int,
max_seq_len_target: int) -> 'LengthStatistics':
"""
    Returns the mean and standard deviation of target-to-source length ratios of a parallel corpus.
:param source_iterables: Source sequence readers.
:param target_iterable: Target sequence reader.
:param max_seq_len_source: Maximum source sequence length.
:param max_seq_len_target: Maximum target sequence length.
:return: The number of sentences as well as the mean and standard deviation of target to source length ratios.
"""
mean_and_variance = OnlineMeanAndVariance()
for sources, target in parallel_iter(source_iterables, target_iterable):
source_len = len(sources[0])
target_len = len(target)
if source_len > max_seq_len_source or target_len > max_seq_len_target:
continue
length_ratio = target_len / source_len
mean_and_variance.update(length_ratio)
num_sents = mean_and_variance.count
mean = mean_and_variance.mean
if not math.isnan(mean_and_variance.variance):
std = math.sqrt(mean_and_variance.variance)
else:
std = 0.0
return LengthStatistics(num_sents, mean, std)
def analyze_sequence_lengths(sources: List[str],
target: str,
vocab_sources: List[vocab.Vocab],
vocab_target: vocab.Vocab,
max_seq_len_source: int,
max_seq_len_target: int,
use_pointer_nets: bool,
max_oov_words: int,
pointer_nets_type: str) -> 'LengthStatistics':
train_sources_sentences, train_target_sentences = create_sequence_readers(sources, target, vocab_sources,
vocab_target, use_pointer_nets,
max_oov_words, pointer_nets_type)
length_statistics = calculate_length_statistics(train_sources_sentences, train_target_sentences,
max_seq_len_source,
max_seq_len_target)
logger.info("%d sequences of maximum length (%d, %d) in '%s' and '%s'.",
length_statistics.num_sents, max_seq_len_source, max_seq_len_target, sources[0], target)
logger.info("Mean training target/source length ratio: %.2f (+-%.2f)",
length_statistics.length_ratio_mean,
length_statistics.length_ratio_std)
return length_statistics
def are_token_parallel(sequences: Sequence[Sized]) -> bool:
"""
Returns True if all sequences in the list have the same length.
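    Example (illustrative): are_token_parallel([[1, 2], [3, 4]]) is True, while
    are_token_parallel([[1, 2], [3]]) is False.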
"""
if not sequences or len(sequences) == 1:
return True
return all(len(s) == len(sequences[0]) for s in sequences)
class DataStatisticsAccumulator:
def __init__(self,
buckets: List[Tuple[int, int]],
vocab_source: Optional[Dict[str, int]],
vocab_target: Dict[str, int],
length_ratio_mean: float,
length_ratio_std: float) -> None:
self.buckets = buckets
num_buckets = len(buckets)
self.length_ratio_mean = length_ratio_mean
self.length_ratio_std = length_ratio_std
if vocab_source is not None:
self.unk_id_source = vocab_source[C.UNK_SYMBOL]
self.size_vocab_source = len(vocab_source)
else:
self.unk_id_source = None
self.size_vocab_source = 0
self.unk_id_target = vocab_target[C.UNK_SYMBOL]
self.size_vocab_target = len(vocab_target)
self.num_sents = 0
self.num_discarded = 0
self.num_tokens_source = 0
self.num_tokens_target = 0
self.num_unks_source = 0
self.num_unks_target = 0
self.max_observed_len_source = 0
self.max_observed_len_target = 0
self._mean_len_target_per_bucket = [OnlineMeanAndVariance() for _ in range(num_buckets)]
def sequence_pair(self,
source: List[int],
target: List[int],
bucket_idx: Optional[int]):
if bucket_idx is None:
self.num_discarded += 1
return
source_len = len(source)
target_len = len(target)
self._mean_len_target_per_bucket[bucket_idx].update(target_len)
self.num_sents += 1
self.num_tokens_source += source_len
self.num_tokens_target += target_len
self.max_observed_len_source = max(source_len, self.max_observed_len_source)
self.max_observed_len_target = max(target_len, self.max_observed_len_target)
if self.unk_id_source is not None:
self.num_unks_source += source.count(self.unk_id_source)
self.num_unks_target += target.count(self.unk_id_target)
@property
def mean_len_target_per_bucket(self) -> List[Optional[float]]:
return [mean_and_variance.mean if mean_and_variance.count > 0 else None
for mean_and_variance in self._mean_len_target_per_bucket]
@property
def statistics(self):
num_sents_per_bucket = [mean_and_variance.count for mean_and_variance in self._mean_len_target_per_bucket]
return DataStatistics(num_sents=self.num_sents,
num_discarded=self.num_discarded,
num_tokens_source=self.num_tokens_source,
num_tokens_target=self.num_tokens_target,
num_unks_source=self.num_unks_source,
num_unks_target=self.num_unks_target,
max_observed_len_source=self.max_observed_len_source,
max_observed_len_target=self.max_observed_len_target,
size_vocab_source=self.size_vocab_source,
size_vocab_target=self.size_vocab_target,
length_ratio_mean=self.length_ratio_mean,
length_ratio_std=self.length_ratio_std,
buckets=self.buckets,
num_sents_per_bucket=num_sents_per_bucket,
mean_len_target_per_bucket=self.mean_len_target_per_bucket)
def shard_data(source_fnames: List[str],
target_fname: str,
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab,
num_shards: int,
buckets: List[Tuple[int, int]],
length_ratio_mean: float,
length_ratio_std: float,
output_prefix: str,
use_pointer_nets: bool,
max_oov_words: int,
pointer_nets_type: str) -> Tuple[List[Tuple[List[str], str, 'DataStatistics']], 'DataStatistics']:
"""
Assign int-coded source/target sentence pairs to shards at random.
    :param source_fnames: Paths to the source text (and optional token-parallel factor files).
:param target_fname: The file name of the target file.
:param source_vocabs: Source vocabulary (and optional source factor vocabularies).
:param target_vocab: Target vocabulary.
:param num_shards: The total number of shards.
:param buckets: Bucket list.
:param length_ratio_mean: Mean length ratio.
:param length_ratio_std: Standard deviation of length ratios.
:param output_prefix: The prefix under which the shard files will be created.
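    :param use_pointer_nets: Flag to indicate whether pointer networks are enabled.
    :param max_oov_words: Maximum number of out-of-vocabulary words to consider if pointer-nets-summary is used.
    :param pointer_nets_type: Pointer Networks implementation to use.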
:return: Tuple of source (and source factor) file names, target file names and statistics for each shard,
as well as global statistics.
"""
os.makedirs(output_prefix, exist_ok=True)
sources_shard_fnames = [[os.path.join(output_prefix, C.SHARD_SOURCE % i) + ".%d" % f for i in range(num_shards)]
for f in range(len(source_fnames))]
target_shard_fnames = [os.path.join(output_prefix, C.SHARD_TARGET % i)
for i in range(num_shards)] # type: List[str]
data_stats_accumulator = DataStatisticsAccumulator(buckets, source_vocabs[0], target_vocab,
length_ratio_mean, length_ratio_std)
per_shard_stat_accumulators = [DataStatisticsAccumulator(buckets, source_vocabs[0], target_vocab, length_ratio_mean,
length_ratio_std) for shard_idx in range(num_shards)]
with ExitStack() as exit_stack:
sources_shards = [[exit_stack.enter_context(smart_open(f, mode="wt")) for f in sources_shard_fnames[i]] for i in
range(len(source_fnames))]
target_shards = [exit_stack.enter_context(smart_open(f, mode="wt")) for f in target_shard_fnames]
source_readers, target_reader = create_sequence_readers(source_fnames, target_fname, source_vocabs,
target_vocab, use_pointer_nets, max_oov_words,
pointer_nets_type)
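        # Two-argument iter() produces an endless stream of random shard indices: the lambda is
        # called repeatedly until it returns the sentinel (None), which random.randrange never does.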
random_shard_iter = iter(lambda: random.randrange(num_shards), None)
for (sources, target), random_shard_index in zip(parallel_iter(source_readers, target_reader),
random_shard_iter):
random_shard_index = cast(int, random_shard_index)
source_len = len(sources[0])
target_len = len(target)
buck_idx, buck = get_parallel_bucket(buckets, source_len, target_len)
data_stats_accumulator.sequence_pair(sources[0], target, buck_idx)
per_shard_stat_accumulators[random_shard_index].sequence_pair(sources[0], target, buck_idx)
if buck is None:
continue
for i, line in enumerate(sources):
sources_shards[i][random_shard_index].write(ids2strids(line) + "\n")
target_shards[random_shard_index].write(ids2strids(target) + "\n")
per_shard_stats = [shard_stat_accumulator.statistics for shard_stat_accumulator in per_shard_stat_accumulators]
sources_shard_fnames_by_shards = zip(*sources_shard_fnames) # type: List[List[str]]
return list(
zip(sources_shard_fnames_by_shards, target_shard_fnames, per_shard_stats)), data_stats_accumulator.statistics
class RawParallelDatasetLoader:
"""
Loads a data set of variable-length parallel source/target sequences into buckets of NDArrays.
:param buckets: Bucket list.
:param eos_id: End-of-sentence id.
:param pad_id: Padding id.
    :param target_vocab_size: Size of the target vocabulary.
    :param aligner: Optional aligner used to compute pointer-network labels.
:param dtype: Data type.
"""
def __init__(self,
buckets: List[Tuple[int, int]],
eos_id: int,
pad_id: int,
target_vocab_size: int,
aligner: Optional[align.Aligner] = None,
dtype: str = 'float32') -> None:
self.buckets = buckets
self.eos_id = eos_id
self.pad_id = pad_id
self.dtype = dtype
self.target_vocab_size = target_vocab_size
self.aligner = aligner
def load(self,
source_iterables: Sequence[Iterable],
target_iterable: Iterable,
num_samples_per_bucket: List[int],
use_pointer_nets: bool,
pointer_nets_type: str) -> 'ParallelDataSet':
assert len(num_samples_per_bucket) == len(self.buckets)
num_factors = len(source_iterables)
data_source = [np.full((num_samples, source_len, num_factors), self.pad_id, dtype=self.dtype)
for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)]
data_target = [np.full((num_samples, target_len), self.pad_id, dtype=self.dtype)
for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)]
data_label = [np.full((num_samples, target_len), self.pad_id, dtype=self.dtype)
for (source_len, target_len), num_samples in zip(self.buckets, num_samples_per_bucket)]
bucket_sample_index = [0 for _ in self.buckets]
# track amount of padding introduced through bucketing
num_tokens_source = 0
num_tokens_target = 0
num_pad_source = 0
num_pad_target = 0
# Bucket sentences as padded np arrays
for sources, target in parallel_iter(source_iterables, target_iterable):
source_len = len(sources[0])
target_len = len(target)
buck_index, buck = get_parallel_bucket(self.buckets, source_len, target_len)
if buck is None:
continue # skip this sentence pair
num_tokens_source += buck[0]
num_tokens_target += buck[1]
num_pad_source += buck[0] - source_len
num_pad_target += buck[1] - target_len
sample_index = bucket_sample_index[buck_index]
for i, s in enumerate(sources):
data_source[buck_index][sample_index, 0:source_len, i] = s
data_target[buck_index][sample_index, :target_len] = target
# NOTE(fhieber): while this is wasteful w.r.t memory, we need to explicitly create the label sequence
# with the EOS symbol here sentence-wise and not per-batch due to variable sequence length within a batch.
# Once MXNet allows item assignments given a list of indices (probably MXNet 1.0): e.g a[[0,1,5,2]] = x,
# we can try again to compute the label sequence on the fly in next().
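            # The label sequence is the target shifted left by one with EOS appended, e.g.
            # (illustrative) target [<bos>, w1, w2] -> label [w1, w2, <eos>].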
            labels = target[1:] + [self.eos_id]
            if pointer_nets_type != C.POINTER_NET_SUMMARY and self.aligner is not None:
                labels = self.aligner.get_labels(sources[0], target, labels)
            data_label[buck_index][sample_index, :target_len] = labels
bucket_sample_index[buck_index] += 1
for i in range(len(data_source)):
data_source[i] = mx.nd.array(data_source[i], dtype=self.dtype)
data_target[i] = mx.nd.array(data_target[i], dtype=self.dtype)
data_label[i] = mx.nd.array(data_label[i], dtype=self.dtype)
if num_tokens_source > 0 and num_tokens_target > 0:
logger.info("Created bucketed parallel data set. Introduced padding: source=%.1f%% target=%.1f%%)",
num_pad_source / num_tokens_source * 100,
num_pad_target / num_tokens_target * 100)
return ParallelDataSet(data_source, data_target, data_label)
def get_num_shards(num_samples: int, samples_per_shard: int, min_num_shards: int) -> int:
"""
Returns the number of shards.
:param num_samples: Number of training data samples.
:param samples_per_shard: Samples per shard.
:param min_num_shards: Minimum number of shards.
:return: Number of shards.
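    Example (illustrative): get_num_shards(1000000, samples_per_shard=300000, min_num_shards=2)
    returns 4, since ceil(1000000 / 300000) = 4 exceeds the minimum of 2.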
"""
return max(int(math.ceil(num_samples / samples_per_shard)), min_num_shards)
def prepare_data(source_fnames: List[str],
target_fname: str,
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab,
source_vocab_paths: List[Optional[str]],
target_vocab_path: Optional[str],
shared_vocab: bool,
max_seq_len_source: int,
max_seq_len_target: int,
bucketing: bool,
bucket_width: int,
samples_per_shard: int,
min_num_shards: int,
output_prefix: str,
use_pointer_nets: bool,
max_oov_words: int,
pointer_nets_type: str,
aligner: Optional[align.Aligner] = None,
keep_tmp_shard_files: bool = False):
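    """
    Prepares a training data folder: writes vocabularies, computes length statistics and buckets,
    randomly shards the data, and serializes each shard as bucketed ndarrays together with data
    info, data config and version files.
    """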
logger.info("Preparing data.")
# write vocabularies to data folder
vocab.save_source_vocabs(source_vocabs, output_prefix)
vocab.save_target_vocab(target_vocab, output_prefix)
# Pass 1: get target/source length ratios.
length_statistics = analyze_sequence_lengths(source_fnames, target_fname, source_vocabs, target_vocab,
max_seq_len_source, max_seq_len_target,
use_pointer_nets=use_pointer_nets, max_oov_words=max_oov_words,
pointer_nets_type=pointer_nets_type)
# define buckets
buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width,
length_statistics.length_ratio_mean) if bucketing else [
(max_seq_len_source, max_seq_len_target)]
logger.info("Buckets: %s", buckets)
# Pass 2: Randomly assign data to data shards
# no pre-processing yet, just write the sentences to different files
num_shards = get_num_shards(length_statistics.num_sents, samples_per_shard, min_num_shards)
logger.info("%d samples will be split into %d shard(s) (requested samples/shard=%d, min_num_shards=%d)."
% (length_statistics.num_sents, num_shards, samples_per_shard, min_num_shards))
shards, data_statistics = shard_data(source_fnames=source_fnames,
target_fname=target_fname,
source_vocabs=source_vocabs,
target_vocab=target_vocab,
num_shards=num_shards,
buckets=buckets,
length_ratio_mean=length_statistics.length_ratio_mean,
length_ratio_std=length_statistics.length_ratio_std,
output_prefix=output_prefix,
use_pointer_nets=use_pointer_nets,
max_oov_words=max_oov_words,
pointer_nets_type=pointer_nets_type)
data_statistics.log()
data_loader = RawParallelDatasetLoader(buckets=buckets,
eos_id=target_vocab[C.EOS_SYMBOL],
pad_id=C.PAD_ID,
target_vocab_size=len(target_vocab),
aligner=aligner)
    # Pass 3: convert each shard to serialized ndarrays
for shard_idx, (shard_sources, shard_target, shard_stats) in enumerate(shards):
sources_sentences = [SequenceReader(s, use_pointer_nets=use_pointer_nets, max_oov_words=max_oov_words,
pointer_nets_type=pointer_nets_type)
for s in shard_sources]
target_sentences = SequenceReader(shard_target, use_pointer_nets=use_pointer_nets, max_oov_words=max_oov_words,
pointer_nets_type=pointer_nets_type)
dataset = data_loader.load(sources_sentences, target_sentences, shard_stats.num_sents_per_bucket,
use_pointer_nets, pointer_nets_type)
shard_fname = os.path.join(output_prefix, C.SHARD_NAME % shard_idx)
shard_stats.log()
logger.info("Writing '%s'", shard_fname)
dataset.save(shard_fname)
if not keep_tmp_shard_files:
for f in shard_sources:
os.remove(f)
os.remove(shard_target)
data_info = DataInfo(sources=[os.path.abspath(fname) for fname in source_fnames],
target=os.path.abspath(target_fname),
source_vocabs=source_vocab_paths,
target_vocab=target_vocab_path,
shared_vocab=shared_vocab,
num_shards=num_shards)
data_info_fname = os.path.join(output_prefix, C.DATA_INFO)
logger.info("Writing data info to '%s'", data_info_fname)
data_info.save(data_info_fname)
config_data = DataConfig(data_statistics=data_statistics,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
num_source_factors=len(source_fnames),
use_pointer_nets=aligner is not None,
pointer_nets_type=pointer_nets_type,
source_with_eos=True)
config_data_fname = os.path.join(output_prefix, C.DATA_CONFIG)
logger.info("Writing data config to '%s'", config_data_fname)
config_data.save(config_data_fname)
version_file = os.path.join(output_prefix, C.PREPARED_DATA_VERSION_FILE)
with open(version_file, "w") as version_out:
version_out.write(str(C.PREPARED_DATA_VERSION))
if aligner is not None and pointer_nets_type != C.POINTER_NET_SUMMARY:
logger.info("Pointed to %d / %d source words (%.2f%%)",
aligner.num_pointed, aligner.num_total, 100 * aligner.num_pointed / aligner.num_total)
def get_data_statistics(source_readers: Sequence[Iterable],
target_reader: Iterable,
buckets: List[Tuple[int, int]],
length_ratio_mean: float,
length_ratio_std: float,
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab) -> 'DataStatistics':
data_stats_accumulator = DataStatisticsAccumulator(buckets, source_vocabs[0], target_vocab,
length_ratio_mean, length_ratio_std)
if source_readers is not None and target_reader is not None:
for sources, target in parallel_iter(source_readers, target_reader):
buck_idx, buck = get_parallel_bucket(buckets, len(sources[0]), len(target))
data_stats_accumulator.sequence_pair(sources[0], target, buck_idx)
else: # Allow stats for target only data
for target in target_reader:
buck_idx, buck = get_target_bucket(buckets, len(target))
data_stats_accumulator.sequence_pair([], target, buck_idx)
return data_stats_accumulator.statistics
def get_validation_data_iter(data_loader: RawParallelDatasetLoader,
validation_sources: List[str],
validation_target: str,
buckets: List[Tuple[int, int]],
bucket_batch_sizes: List[BucketBatchSize],
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab,
max_seq_len_source: int,
max_seq_len_target: int,
batch_size: int,
fill_up: str,
use_pointer_nets: bool,
max_oov_words: int,
pointer_nets_type: str) -> 'ParallelSampleIter':
"""
Returns a ParallelSampleIter for the validation data.
"""
logger.info("=================================")
logger.info("Creating validation data iterator")
logger.info("=================================")
validation_length_statistics = analyze_sequence_lengths(validation_sources, validation_target,
source_vocabs, target_vocab,
max_seq_len_source, max_seq_len_target,
use_pointer_nets,
max_oov_words,
pointer_nets_type)
validation_sources_sentences, validation_target_sentences = create_sequence_readers(validation_sources,
validation_target,
source_vocabs, target_vocab,
use_pointer_nets, max_oov_words,
pointer_nets_type)
validation_data_statistics = get_data_statistics(validation_sources_sentences,
validation_target_sentences,
buckets,
validation_length_statistics.length_ratio_mean,
validation_length_statistics.length_ratio_std,
source_vocabs, target_vocab)
validation_data_statistics.log(bucket_batch_sizes)
validation_data = data_loader.load(validation_sources_sentences, validation_target_sentences,
validation_data_statistics.num_sents_per_bucket,
use_pointer_nets, pointer_nets_type).fill_up(
bucket_batch_sizes,
fill_up)
return ParallelSampleIter(data=validation_data,
buckets=buckets,
batch_size=batch_size,
bucket_batch_sizes=bucket_batch_sizes,
num_factors=len(validation_sources))
def get_prepared_data_iters(prepared_data_dir: str,
validation_sources: List[str],
validation_target: str,
shared_vocab: bool,
batch_size: int,
batch_by_words: bool,
batch_num_devices: int,
fill_up: str,
use_pointer_nets: bool,
max_oov_words: int,
pointer_nets_type: str) -> Tuple['BaseParallelSampleIter',
'BaseParallelSampleIter',
'DataConfig', List[vocab.Vocab], vocab.Vocab]:
logger.info("===============================")
logger.info("Creating training data iterator")
logger.info("===============================")
version_file = os.path.join(prepared_data_dir, C.PREPARED_DATA_VERSION_FILE)
with open(version_file) as version_in:
version = int(version_in.read())
check_condition(version == C.PREPARED_DATA_VERSION,
"The dataset %s was written in an old and incompatible format. Please rerun data "
"preparation with a current version of Sockeye." % prepared_data_dir)
info_file = os.path.join(prepared_data_dir, C.DATA_INFO)
check_condition(os.path.exists(info_file),
"Could not find data info %s. Are you sure %s is a directory created with "
"python -m sockeye.prepare_data?" % (info_file, prepared_data_dir))
data_info = cast(DataInfo, DataInfo.load(info_file))
config_file = os.path.join(prepared_data_dir, C.DATA_CONFIG)
check_condition(os.path.exists(config_file),
"Could not find data config %s. Are you sure %s is a directory created with "
"python -m sockeye.prepare_data?" % (config_file, prepared_data_dir))
config_data = cast(DataConfig, DataConfig.load(config_file))
shard_fnames = [os.path.join(prepared_data_dir,
C.SHARD_NAME % shard_idx) for shard_idx in range(data_info.num_shards)]
for shard_fname in shard_fnames:
check_condition(os.path.exists(shard_fname), "Shard %s does not exist." % shard_fname)
    check_condition(shared_vocab == data_info.shared_vocab,
                    "Shared vocabulary settings of this run and the prepared data do not match (relevant "
                    "e.g. for weight tying). Use --shared-vocab consistently when preparing the data and "
                    "when training.")
source_vocabs = vocab.load_source_vocabs(prepared_data_dir)
target_vocab = vocab.load_target_vocab(prepared_data_dir)
check_condition(len(source_vocabs) == len(data_info.sources),
"Wrong number of source vocabularies. Found %d, need %d." % (len(source_vocabs),
len(data_info.sources)))
buckets = config_data.data_statistics.buckets
max_seq_len_source = config_data.max_seq_len_source
max_seq_len_target = config_data.max_seq_len_target
bucket_batch_sizes = define_bucket_batch_sizes(buckets,
batch_size,
batch_by_words,
batch_num_devices,
config_data.data_statistics.average_len_target_per_bucket)
config_data.data_statistics.log(bucket_batch_sizes)
train_iter = ShardedParallelSampleIter(shard_fnames,
buckets,
batch_size,
bucket_batch_sizes,
fill_up,
num_factors=len(data_info.sources))
data_loader = RawParallelDatasetLoader(buckets=buckets,
eos_id=target_vocab[C.EOS_SYMBOL],
pad_id=C.PAD_ID,
target_vocab_size=len(target_vocab))
validation_iter = get_validation_data_iter(data_loader=data_loader,
validation_sources=validation_sources,
validation_target=validation_target,
buckets=buckets,
bucket_batch_sizes=bucket_batch_sizes,
source_vocabs=source_vocabs,
target_vocab=target_vocab,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
batch_size=batch_size,
fill_up=fill_up,
use_pointer_nets=use_pointer_nets,
max_oov_words=max_oov_words,
pointer_nets_type=pointer_nets_type)
return train_iter, validation_iter, config_data, source_vocabs, target_vocab
def get_training_data_iters(sources: List[str],
target: str,
validation_sources: List[str],
validation_target: str,
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab,
source_vocab_paths: List[Optional[str]],
target_vocab_path: Optional[str],
shared_vocab: bool,
batch_size: int,
batch_by_words: bool,
batch_num_devices: int,
fill_up: str,
max_seq_len_source: int,
max_seq_len_target: int,
bucketing: bool,
bucket_width: int,
use_pointer_nets: bool,
max_oov_words: int,
pointer_nets_type: str,
aligner: Optional[align.Aligner] = None) -> Tuple['BaseParallelSampleIter',
'BaseParallelSampleIter',
'DataConfig', 'DataInfo']:
"""
Returns data iterators for training and validation data.
:param sources: Path to source training data (with optional factor data paths).
:param target: Path to target training data.
:param validation_sources: Path to source validation data (with optional factor data paths).
:param validation_target: Path to target validation data.
:param source_vocabs: Source vocabulary and optional factor vocabularies.
:param target_vocab: Target vocabulary.
:param source_vocab_paths: Path to source vocabulary.
:param target_vocab_path: Path to target vocabulary.
:param shared_vocab: Whether the vocabularies are shared.
:param batch_size: Batch size.
:param batch_by_words: Size batches by words rather than sentences.
:param batch_num_devices: Number of devices batches will be parallelized across.
:param fill_up: Fill-up strategy for buckets.
:param max_seq_len_source: Maximum source sequence length.
:param max_seq_len_target: Maximum target sequence length.
:param bucketing: Whether to use bucketing.
:param bucket_width: Size of buckets.
    :param use_pointer_nets: Flag to indicate whether pointer networks are enabled.
    :param max_oov_words: Maximum number of out-of-vocabulary words to consider if pointer-nets-summary is used.
    :param pointer_nets_type: Pointer Networks implementation to use.
    :param aligner: The aligner to use if pointer networks are enabled.
    :return: Tuple of (training data iterator, validation data iterator, data config, data info).
"""
logger.info("===============================")
logger.info("Creating training data iterator")
logger.info("===============================")
# Pass 1: get target/source length ratios.
length_statistics = analyze_sequence_lengths(sources, target, source_vocabs, target_vocab,
max_seq_len_source, max_seq_len_target,
use_pointer_nets, max_oov_words,
pointer_nets_type)
# define buckets
buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width,
length_statistics.length_ratio_mean) if bucketing else [
(max_seq_len_source, max_seq_len_target)]
sources_sentences, target_sentences = create_sequence_readers(sources, target, source_vocabs, target_vocab,
use_pointer_nets, max_oov_words,
pointer_nets_type)
    # Pass 2: get data statistics
data_statistics = get_data_statistics(sources_sentences, target_sentences, buckets,
length_statistics.length_ratio_mean, length_statistics.length_ratio_std,
source_vocabs, target_vocab)
bucket_batch_sizes = define_bucket_batch_sizes(buckets,
batch_size,
batch_by_words,
batch_num_devices,
data_statistics.average_len_target_per_bucket)
data_statistics.log(bucket_batch_sizes)
data_loader = RawParallelDatasetLoader(buckets=buckets,
eos_id=target_vocab[C.EOS_SYMBOL],
pad_id=C.PAD_ID,
target_vocab_size=len(target_vocab),
aligner=aligner)
training_data = data_loader.load(sources_sentences, target_sentences,
data_statistics.num_sents_per_bucket, use_pointer_nets,
pointer_nets_type).fill_up(bucket_batch_sizes,
fill_up)
data_info = DataInfo(sources=sources,
target=target,
source_vocabs=source_vocab_paths,
target_vocab=target_vocab_path,
shared_vocab=shared_vocab,
num_shards=1)
config_data = DataConfig(data_statistics=data_statistics,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
num_source_factors=len(sources),
source_with_eos=True)
train_iter = ParallelSampleIter(data=training_data,
buckets=buckets,
batch_size=batch_size,
bucket_batch_sizes=bucket_batch_sizes,
num_factors=len(sources))
validation_iter = get_validation_data_iter(data_loader=data_loader,
validation_sources=validation_sources,
validation_target=validation_target,
buckets=buckets,
bucket_batch_sizes=bucket_batch_sizes,
source_vocabs=source_vocabs,
target_vocab=target_vocab,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
batch_size=batch_size,
fill_up=fill_up,
use_pointer_nets=use_pointer_nets,
max_oov_words=max_oov_words,
pointer_nets_type=pointer_nets_type)
return train_iter, validation_iter, config_data, data_info
class LengthStatistics(config.Config):
def __init__(self,
num_sents: int,
length_ratio_mean: float,
length_ratio_std: float) -> None:
super().__init__()
self.num_sents = num_sents
self.length_ratio_mean = length_ratio_mean
self.length_ratio_std = length_ratio_std
class DataStatistics(config.Config):
def __init__(self,
num_sents: int,
num_discarded,
num_tokens_source,
num_tokens_target,
num_unks_source,
num_unks_target,
max_observed_len_source,
max_observed_len_target,
size_vocab_source,
size_vocab_target,
length_ratio_mean,
length_ratio_std,
buckets: List[Tuple[int, int]],
num_sents_per_bucket: List[int],
mean_len_target_per_bucket: List[Optional[float]]) -> None:
super().__init__()
self.num_sents = num_sents
self.num_discarded = num_discarded
self.num_tokens_source = num_tokens_source
self.num_tokens_target = num_tokens_target
self.num_unks_source = num_unks_source
self.num_unks_target = num_unks_target
self.max_observed_len_source = max_observed_len_source
self.max_observed_len_target = max_observed_len_target
self.size_vocab_source = size_vocab_source
self.size_vocab_target = size_vocab_target
self.length_ratio_mean = length_ratio_mean
self.length_ratio_std = length_ratio_std
self.buckets = buckets
self.num_sents_per_bucket = num_sents_per_bucket
self.average_len_target_per_bucket = mean_len_target_per_bucket
def log(self, bucket_batch_sizes: Optional[List[BucketBatchSize]] = None):
logger.info("Tokens: source %d target %d", self.num_tokens_source, self.num_tokens_target)
if self.num_tokens_source > 0 and self.num_tokens_target > 0:
logger.info("Vocabulary coverage: source %.0f%% target %.0f%%",
(1 - self.num_unks_source / self.num_tokens_source) * 100,
(1 - self.num_unks_target / self.num_tokens_target) * 100)
logger.info("%d sequences across %d buckets", self.num_sents, len(self.num_sents_per_bucket))
logger.info("%d sequences did not fit into buckets and were discarded", self.num_discarded)
if bucket_batch_sizes is not None:
describe_data_and_buckets(self, bucket_batch_sizes)