#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Pipeline transformations for the FnApiRunner.
"""
# pytype: skip-file
# mypy: check-untyped-defs
import collections
import copy
import functools
import itertools
import logging
import operator
from builtins import object
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import Container
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
from apache_beam import coders
from apache_beam.internal import pickler
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.utils import proto_utils
from apache_beam.utils import timestamp
if TYPE_CHECKING:
from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
from apache_beam.runners.portability.fn_api_runner.execution import PartitionableBuffer
T = TypeVar('T')
# This module is experimental. No backwards-compatibility guarantees.
_LOGGER = logging.getLogger(__name__)
KNOWN_COMPOSITES = frozenset([
common_urns.primitives.GROUP_BY_KEY.urn,
common_urns.composites.COMBINE_PER_KEY.urn,
common_urns.primitives.PAR_DO.urn, # After SDF expansion.
])
COMBINE_URNS = frozenset([
common_urns.composites.COMBINE_PER_KEY.urn,
])
PAR_DO_URNS = frozenset([
common_urns.primitives.PAR_DO.urn,
common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
])
IMPULSE_BUFFER = b'impulse'
# TimerFamilyId is identified by transform name + timer family
# TODO(pabloem): Rename this type to express this name is unique per pipeline.
TimerFamilyId = Tuple[str, str]
BufferId = bytes
# SideInputId is identified by a consumer ParDo + tag.
SideInputId = Tuple[str, str]
SideInputAccessPattern = beam_runner_api_pb2.FunctionSpec
# A map from a PCollection coder ID to a Safe Coder ID
# A safe coder is a coder that can be used on the runner-side of the FnApi.
# A safe coder receives a byte string, and returns a type that can be
# understood by the runner when deserializing.
SafeCoderMapping = Dict[str, str]
# DataSideInput maps SideInputIds to a tuple of the encoded bytes of the side
# input content, and a payload specification regarding the type of side input
# (MultiMap / Iterable).
DataSideInput = Dict[SideInputId, Tuple[bytes, SideInputAccessPattern]]
DataOutput = Dict[str, BufferId]
# A map of [Transform ID, Timer Family ID] to [Buffer ID, Time Domain for timer]
# The time domain comes from beam_runner_api_pb2.TimeDomain. It may be
# EVENT_TIME or PROCESSING_TIME.
OutputTimers = MutableMapping[TimerFamilyId, Tuple[BufferId, Any]]
# A map of [Transform ID, Timer Family ID] to [Buffer CONTENTS, Timestamp]
OutputTimerData = MutableMapping[TimerFamilyId,
Tuple['PartitionableBuffer',
timestamp.Timestamp]]
BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse,
List[beam_fn_api_pb2.ProcessBundleSplitResponse]]
# TODO(pabloem): Change the name to a more representative one
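# Input to a bundle: `data` maps transform names to element buffers, and
# `timers` maps (transform, timer family) ids to timer buffers.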
class DataInput(NamedTuple):
data: MutableMapping[str, 'PartitionableBuffer']
timers: MutableMapping[TimerFamilyId, 'PartitionableBuffer']
class Stage(object):
"""A set of Transforms that can be sent to the worker for processing."""
def __init__(
self,
name, # type: str
transforms, # type: List[beam_runner_api_pb2.PTransform]
downstream_side_inputs=None, # type: Optional[FrozenSet[str]]
must_follow=frozenset(), # type: FrozenSet[Stage]
parent=None, # type: Optional[str]
environment=None, # type: Optional[str]
forced_root=False):
self.name = name
self.transforms = transforms
self.downstream_side_inputs = downstream_side_inputs
self.must_follow = must_follow
self.timers = set() # type: Set[TimerFamilyId]
self.parent = parent
if environment is None:
environment = functools.reduce(
self._merge_environments,
(self._extract_environment(t) for t in transforms))
self.environment = environment
self.forced_root = forced_root
def __repr__(self):
must_follow = ', '.join(prev.name for prev in self.must_follow)
if self.downstream_side_inputs is None:
downstream_side_inputs = '<unknown>'
else:
downstream_side_inputs = ', '.join(
str(si) for si in self.downstream_side_inputs)
return "%s\n %s\n must follow: %s\n downstream_side_inputs: %s" % (
self.name,
'\n'.join([
"%s:%s" % (transform.unique_name, transform.spec.urn)
for transform in self.transforms
]),
must_follow,
downstream_side_inputs)
@staticmethod
def _extract_environment(transform):
# type: (beam_runner_api_pb2.PTransform) -> Optional[str]
environment = transform.environment_id
return environment if environment else None
@staticmethod
def _merge_environments(env1, env2):
# type: (Optional[str], Optional[str]) -> Optional[str]
if env1 is None:
return env2
elif env2 is None:
return env1
else:
if env1 != env2:
raise ValueError(
"Incompatible environments: '%s' != '%s'" %
(str(env1).replace('\n', ' '), str(env2).replace('\n', ' ')))
return env1
def can_fuse(self, consumer, context):
# type: (Stage, TransformContext) -> bool
try:
self._merge_environments(self.environment, consumer.environment)
except ValueError:
return False
def no_overlap(a, b):
return not a or not b or not a.intersection(b)
return (
not consumer.forced_root and self not in consumer.must_follow and
self.is_all_sdk_urns(context) and consumer.is_all_sdk_urns(context) and
no_overlap(self.downstream_side_inputs, consumer.side_inputs()))
def fuse(self, other, context):
# type: (Stage, TransformContext) -> Stage
return Stage(
"(%s)+(%s)" % (self.name, other.name),
self.transforms + other.transforms,
union(self.downstream_side_inputs, other.downstream_side_inputs),
union(self.must_follow, other.must_follow),
environment=self._merge_environments(
self.environment, other.environment),
parent=_parent_for_fused_stages([self, other], context),
forced_root=self.forced_root or other.forced_root)
def is_runner_urn(self, context):
# type: (TransformContext) -> bool
return any(
transform.spec.urn in context.known_runner_urns
for transform in self.transforms)
def is_all_sdk_urns(self, context):
def is_sdk_transform(transform):
# Execute multi-input flattens in the runner.
if transform.spec.urn == common_urns.primitives.FLATTEN.urn and len(
transform.inputs) > 1:
return False
else:
return transform.spec.urn not in context.runner_only_urns
return all(is_sdk_transform(transform) for transform in self.transforms)
def is_stateful(self):
for transform in self.transforms:
if transform.spec.urn in PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if payload.state_specs or payload.timer_family_specs:
return True
return False
def side_inputs(self):
# type: () -> Iterator[str]
for transform in self.transforms:
yield from side_inputs(transform).values()
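# Returns True if the given PCollection is consumed as a main (i.e. non-side)
# input by any transform in this stage.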
def has_as_main_input(self, pcoll):
for transform in self.transforms:
if transform.spec.urn in PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
local_side_inputs = payload.side_inputs
else:
local_side_inputs = {} # type: ignore[assignment]
for local_id, pipeline_id in transform.inputs.items():
if pcoll == pipeline_id and local_id not in local_side_inputs:
return True
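# Removes duplicate data-input transforms that read the same PCollection,
# keeping only the first occurrence.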
def deduplicate_read(self):
# type: () -> None
seen_pcolls = set() # type: Set[str]
new_transforms = []
for transform in self.transforms:
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
pcoll = only_element(list(transform.outputs.items()))[1]
if pcoll in seen_pcolls:
continue
seen_pcolls.add(pcoll)
new_transforms.append(transform)
self.transforms = new_transforms
def executable_stage_transform(
self,
known_runner_urns, # type: FrozenSet[str]
all_consumers,
components # type: beam_runner_api_pb2.Components
):
# type: (...) -> beam_runner_api_pb2.PTransform
if (len(self.transforms) == 1 and
self.transforms[0].spec.urn in known_runner_urns):
result = copy.copy(self.transforms[0])
del result.subtransforms[:]
return result
else:
all_inputs = set(
pcoll for t in self.transforms for pcoll in t.inputs.values())
all_outputs = set(
pcoll for t in self.transforms for pcoll in t.outputs.values())
internal_transforms = set(id(t) for t in self.transforms)
external_outputs = [
pcoll for pcoll in all_outputs
if all_consumers[pcoll] - internal_transforms
]
stage_components = beam_runner_api_pb2.Components()
stage_components.CopyFrom(components)
# Only keep the PCollections referenced in this stage.
stage_components.pcollections.clear()
for pcoll_id in all_inputs.union(all_outputs):
stage_components.pcollections[pcoll_id].CopyFrom(
components.pcollections[pcoll_id])
# Only keep the transforms in this stage.
# Also gather up payload data as we iterate over the transforms.
stage_components.transforms.clear()
main_inputs = set() # type: Set[str]
side_inputs = []
user_states = []
timers = []
for ix, transform in enumerate(self.transforms):
transform_id = 'transform_%d' % ix
if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for tag in payload.side_inputs.keys():
side_inputs.append(
beam_runner_api_pb2.ExecutableStagePayload.SideInputId(
transform_id=transform_id, local_name=tag))
for tag in payload.state_specs.keys():
user_states.append(
beam_runner_api_pb2.ExecutableStagePayload.UserStateId(
transform_id=transform_id, local_name=tag))
for tag in payload.timer_family_specs.keys():
timers.append(
beam_runner_api_pb2.ExecutableStagePayload.TimerId(
transform_id=transform_id, local_name=tag))
main_inputs.update(
pcoll_id for tag,
pcoll_id in transform.inputs.items()
if tag not in payload.side_inputs)
else:
main_inputs.update(transform.inputs.values())
stage_components.transforms[transform_id].CopyFrom(transform)
main_input_id = only_element(main_inputs - all_outputs)
named_inputs = dict({
'%s:%s' % (side.transform_id, side.local_name):
stage_components.transforms[side.transform_id].inputs[side.local_name]
for side in side_inputs
},
main_input=main_input_id)
# at this point we should have resolved an environment, as the key of
# components.environments cannot be None.
assert self.environment is not None
exec_payload = beam_runner_api_pb2.ExecutableStagePayload(
environment=components.environments[self.environment],
input=main_input_id,
outputs=external_outputs,
transforms=stage_components.transforms.keys(),
components=stage_components,
side_inputs=side_inputs,
user_states=user_states,
timers=timers)
return beam_runner_api_pb2.PTransform(
unique_name=unique_name(None, self.name),
spec=beam_runner_api_pb2.FunctionSpec(
urn='beam:runner:executable_stage:v1',
payload=exec_payload.SerializeToString()),
inputs=named_inputs,
outputs={
'output_%d' % ix: pcoll
for ix,
pcoll in enumerate(external_outputs)
},
)
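# Caches a method's return values on the instance itself, keyed by the call's
# positional arguments.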
def memoize_on_instance(f):
missing = object()
def wrapper(self, *args):
try:
cache = getattr(self, '_cache_%s' % f.__name__)
except AttributeError:
cache = {}
setattr(self, '_cache_%s' % f.__name__, cache)
result = cache.get(args, missing)
if result is missing:
result = cache[args] = f(self, *args)
return result
return wrapper
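# Mutable context shared by the optimization phases: the pipeline Components
# being rewritten plus coder bookkeeping (known coder urns, length-prefixed
# coders and their runner-side 'safe' equivalents, and data channel coders).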
class TransformContext(object):
_COMMON_CODER_URNS = set(
value.urn for (key, value) in common_urns.coders.__dict__.items()
if not key.startswith('_')
# Length prefix Rows rather than re-coding them.
) - set([common_urns.coders.ROW.urn])
_REQUIRED_CODER_URNS = set([
common_urns.coders.WINDOWED_VALUE.urn,
# For impulse.
common_urns.coders.BYTES.urn,
common_urns.coders.GLOBAL_WINDOW.urn,
# For GBK.
common_urns.coders.KV.urn,
common_urns.coders.ITERABLE.urn,
# For SDF.
common_urns.coders.DOUBLE.urn,
# For timers.
common_urns.coders.TIMER.urn,
# For everything else.
common_urns.coders.LENGTH_PREFIX.urn,
common_urns.coders.CUSTOM_WINDOW.urn,
])
def __init__(
self,
components, # type: beam_runner_api_pb2.Components
known_runner_urns, # type: FrozenSet[str]
use_state_iterables=False,
is_drain=False):
self.components = components
self.known_runner_urns = known_runner_urns
self.runner_only_urns = known_runner_urns - frozenset(
[common_urns.primitives.FLATTEN.urn])
self._known_coder_urns = set.union(
# Those which are required.
self._REQUIRED_CODER_URNS,
# Those common coders which are understood by all environments.
self._COMMON_CODER_URNS.intersection(
*(
set(env.capabilities)
for env in self.components.environments.values())))
self.use_state_iterables = use_state_iterables
self.is_drain = is_drain
# ok to pass None for context because BytesCoder has no components
coder_proto = coders.BytesCoder().to_runner_api(
None) # type: ignore[arg-type]
self.bytes_coder_id = self.add_or_get_coder_id(coder_proto, 'bytes_coder')
self.safe_coders: SafeCoderMapping = {
self.bytes_coder_id: self.bytes_coder_id
}
# A map of PCollection ID to Coder ID.
self.data_channel_coders = {} # type: Dict[str, str]
def add_or_get_coder_id(
self,
coder_proto, # type: beam_runner_api_pb2.Coder
coder_prefix='coder'):
# type: (...) -> str
for coder_id, coder in self.components.coders.items():
if coder == coder_proto:
return coder_id
new_coder_id = unique_name(self.components.coders, coder_prefix)
self.components.coders[new_coder_id].CopyFrom(coder_proto)
return new_coder_id
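# Registers the coder used on the data channel for the given PCollection:
# its element coder wrapped in a windowed-value coder (with the window coder
# of its windowing strategy), length-prefixed where necessary.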
def add_data_channel_coder(self, pcoll_id):
pcoll = self.components.pcollections[pcoll_id]
proto = beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.WINDOWED_VALUE.urn),
component_coder_ids=[
pcoll.coder_id,
self.components.windowing_strategies[
pcoll.windowing_strategy_id].window_coder_id
])
self.data_channel_coders[pcoll_id] = self.maybe_length_prefixed_coder(
self.add_or_get_coder_id(proto, pcoll.coder_id + '_windowed'))
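# Recursively rewrites a coder so that iterable components use the
# state-backed iterable coder, returning the id of the (possibly new) coder.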
@memoize_on_instance
def with_state_iterables(self, coder_id):
# type: (str) -> str
coder = self.components.coders[coder_id]
if coder.spec.urn == common_urns.coders.ITERABLE.urn:
new_coder_id = unique_name(
self.components.coders, coder_id + '_state_backed')
new_coder = self.components.coders[new_coder_id]
new_coder.CopyFrom(coder)
new_coder.spec.urn = common_urns.coders.STATE_BACKED_ITERABLE.urn
new_coder.spec.payload = b'1'
new_coder.component_coder_ids[0] = self.with_state_iterables(
coder.component_coder_ids[0])
return new_coder_id
else:
new_component_ids = [
self.with_state_iterables(c) for c in coder.component_coder_ids
]
if new_component_ids == coder.component_coder_ids:
return coder_id
else:
new_coder_id = unique_name(
self.components.coders, coder_id + '_state_backed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=coder.spec, component_coder_ids=new_component_ids))
return new_coder_id
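# Returns the id of an equivalent coder the runner can handle (unknown
# components are length-prefixed) and records its runner-side 'safe'
# counterpart in self.safe_coders.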
@memoize_on_instance
def maybe_length_prefixed_coder(self, coder_id):
# type: (str) -> str
if coder_id in self.safe_coders:
return coder_id
(maybe_length_prefixed_id,
safe_id) = self.maybe_length_prefixed_and_safe_coder(coder_id)
self.safe_coders[maybe_length_prefixed_id] = safe_id
return maybe_length_prefixed_id
@memoize_on_instance
def maybe_length_prefixed_and_safe_coder(self, coder_id):
# type: (str) -> Tuple[str, str]
coder = self.components.coders[coder_id]
if coder.spec.urn == common_urns.coders.LENGTH_PREFIX.urn:
return coder_id, self.bytes_coder_id
elif coder.spec.urn in self._known_coder_urns:
new_component_ids = [
self.maybe_length_prefixed_coder(c) for c in coder.component_coder_ids
]
if new_component_ids == coder.component_coder_ids:
new_coder_id = coder_id
else:
new_coder_id = unique_name(
self.components.coders, coder_id + '_length_prefixed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=coder.spec, component_coder_ids=new_component_ids))
safe_component_ids = [self.safe_coders[c] for c in new_component_ids]
if safe_component_ids == coder.component_coder_ids:
safe_coder_id = coder_id
else:
safe_coder_id = unique_name(self.components.coders, coder_id + '_safe')
self.components.coders[safe_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=coder.spec, component_coder_ids=safe_component_ids))
return new_coder_id, safe_coder_id
else:
new_coder_id = unique_name(
self.components.coders, coder_id + '_length_prefixed')
self.components.coders[new_coder_id].CopyFrom(
beam_runner_api_pb2.Coder(
spec=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.coders.LENGTH_PREFIX.urn),
component_coder_ids=[coder_id]))
return new_coder_id, self.bytes_coder_id
def length_prefix_pcoll_coders(self, pcoll_id):
# type: (str) -> None
self.components.pcollections[pcoll_id].coder_id = (
self.maybe_length_prefixed_coder(
self.components.pcollections[pcoll_id].coder_id))
@memoize_on_instance
def parents_map(self):
return {
child: parent
for (parent, transform) in self.components.transforms.items()
for child in transform.subtransforms
}
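# Walks the transform hierarchy from the given roots, yielding a singleton
# Stage for each known composite, and for each leaf transform that produces
# at least one new output.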
def leaf_transform_stages(
root_ids, # type: Iterable[str]
components, # type: beam_runner_api_pb2.Components
parent=None, # type: Optional[str]
known_composites=KNOWN_COMPOSITES # type: FrozenSet[str]
):
# type: (...) -> Iterator[Stage]
for root_id in root_ids:
root = components.transforms[root_id]
if root.spec.urn in known_composites:
yield Stage(root_id, [root], parent=parent)
elif not root.subtransforms:
# Make sure its outputs are not a subset of its inputs.
if set(root.outputs.values()) - set(root.inputs.values()):
yield Stage(root_id, [root], parent=parent)
else:
for stage in leaf_transform_stages(root.subtransforms,
components,
root_id,
known_composites):
yield stage
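# Converts the optimized stages back into a full Pipeline proto, rebuilding
# the composite hierarchy and root transform ids. If `partial` is set, the
# original transforms (and their subtransforms) are copied through instead of
# being converted to executable stage transforms.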
def pipeline_from_stages(
pipeline_proto, # type: beam_runner_api_pb2.Pipeline
stages, # type: Iterable[Stage]
known_runner_urns, # type: FrozenSet[str]
partial # type: bool
):
# type: (...) -> beam_runner_api_pb2.Pipeline
# In case it was a generator that mutates components as it
# produces outputs (as is the case with most transformations).
stages = list(stages)
new_proto = beam_runner_api_pb2.Pipeline()
new_proto.CopyFrom(pipeline_proto)
components = new_proto.components
components.transforms.clear()
components.pcollections.clear()
roots = set()
parents = {
child: parent
for parent,
proto in pipeline_proto.components.transforms.items()
for child in proto.subtransforms
}
def copy_output_pcollections(transform):
for pcoll_id in transform.outputs.values():
components.pcollections[pcoll_id].CopyFrom(
pipeline_proto.components.pcollections[pcoll_id])
def add_parent(child, parent):
if parent is None:
roots.add(child)
else:
if (parent not in components.transforms and
parent in pipeline_proto.components.transforms):
components.transforms[parent].CopyFrom(
pipeline_proto.components.transforms[parent])
copy_output_pcollections(components.transforms[parent])
del components.transforms[parent].subtransforms[:]
# Ensure that child is the last item in the parent's subtransforms.
# If the stages were previously sorted into topological order using
# sort_stages, this ensures that the parent transforms are also
# added in topological order.
if child in components.transforms[parent].subtransforms:
components.transforms[parent].subtransforms.remove(child)
components.transforms[parent].subtransforms.append(child)
add_parent(parent, parents.get(parent))
def copy_subtransforms(transform):
for subtransform_id in transform.subtransforms:
if subtransform_id not in pipeline_proto.components.transforms:
raise RuntimeError(
'Could not find subtransform to copy: ' + subtransform_id)
subtransform = pipeline_proto.components.transforms[subtransform_id]
components.transforms[subtransform_id].CopyFrom(subtransform)
copy_output_pcollections(components.transforms[subtransform_id])
copy_subtransforms(subtransform)
all_consumers = collections.defaultdict(
set) # type: DefaultDict[str, Set[int]]
for stage in stages:
for transform in stage.transforms:
for pcoll in transform.inputs.values():
all_consumers[pcoll].add(id(transform))
for stage in stages:
if partial:
transform = only_element(stage.transforms)
copy_subtransforms(transform)
else:
transform = stage.executable_stage_transform(
known_runner_urns, all_consumers, pipeline_proto.components)
transform_id = unique_name(components.transforms, stage.name)
components.transforms[transform_id].CopyFrom(transform)
copy_output_pcollections(transform)
add_parent(transform_id, stage.parent)
del new_proto.root_transform_ids[:]
new_proto.root_transform_ids.extend(roots)
return new_proto
def create_and_optimize_stages(
pipeline_proto, # type: beam_runner_api_pb2.Pipeline
phases,
known_runner_urns, # type: FrozenSet[str]
use_state_iterables=False,
is_drain=False):
# type: (...) -> Tuple[TransformContext, List[Stage]]
"""Create a set of stages given a pipeline proto, and set of optimizations.
Args:
pipeline_proto (beam_runner_api_pb2.Pipeline): A pipeline defined by a user.
phases (callable): Each phase identifies a specific transformation to be
applied to the pipeline graph. Existing phases are defined in this file,
and receive a list of stages, and a pipeline context. Some available
transformations are ``lift_combiners``, ``expand_sdf``, ``expand_gbk``,
etc.
Returns:
A tuple with a pipeline context, and a list of stages (i.e. an optimized
graph).
"""
pipeline_context = TransformContext(
pipeline_proto.components,
known_runner_urns,
use_state_iterables=use_state_iterables,
is_drain=is_drain)
# Initial set of stages are singleton leaf transforms.
stages = list(
leaf_transform_stages(
pipeline_proto.root_transform_ids,
pipeline_proto.components,
known_composites=union(known_runner_urns, KNOWN_COMPOSITES)))
# Apply each phase in order.
for phase in phases:
_LOGGER.info('%s %s %s', '=' * 20, phase, '=' * 20)
stages = list(phase(stages, pipeline_context))
_LOGGER.debug('%s %s', len(stages), [len(s.transforms) for s in stages])
_LOGGER.debug('Stages: %s', [str(s) for s in stages])
# Return the (possibly mutated) context and ordered set of stages.
return pipeline_context, stages
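# Convenience wrapper: builds and optimizes the stages for the given pipeline
# proto, then converts the result back into a Pipeline proto.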
def optimize_pipeline(
pipeline_proto, # type: beam_runner_api_pb2.Pipeline
phases,
known_runner_urns, # type: FrozenSet[str]
partial=False,
**kwargs):
unused_context, stages = create_and_optimize_stages(
pipeline_proto,
phases,
known_runner_urns,
**kwargs)
return pipeline_from_stages(
pipeline_proto, stages, known_runner_urns, partial)
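# A minimal usage sketch (illustrative only; runners supply their own ordered
# list of phase functions from this module, and the set of known runner urns
# depends on the runner):
#
#   optimized_proto = optimize_pipeline(
#       pipeline_proto,
#       phases=[
#           annotate_downstream_side_inputs,
#           fix_side_input_pcoll_coders,
#           lift_combiners,
#           expand_sdf,
#       ],
#       known_runner_urns=frozenset(
#           [common_urns.primitives.GROUP_BY_KEY.urn]))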
# Optimization stages.
def annotate_downstream_side_inputs(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
"""Annotate each stage with fusion-prohibiting information.
Each stage is annotated with the (transitive) set of pcollections that
depend on this stage that are also used later in the pipeline as a
side input.
While theoretically this could result in O(n^2) annotations, the size of
each set is bounded by the number of side inputs (typically much smaller
than the number of total nodes) and the number of *distinct* side-input
sets is also generally small (and shared due to the use of union
defined above).
This representation is also amenable to simple recomputation on fusion.
"""
consumers = collections.defaultdict(
list) # type: DefaultDict[str, List[Stage]]
def get_all_side_inputs():
# type: () -> Set[str]
all_side_inputs = set() # type: Set[str]
for stage in stages:
for transform in stage.transforms:
for input in transform.inputs.values():
consumers[input].append(stage)
for si in stage.side_inputs():
all_side_inputs.add(si)
return all_side_inputs
all_side_inputs = frozenset(get_all_side_inputs())
downstream_side_inputs_by_stage = {} # type: Dict[Stage, FrozenSet[str]]
def compute_downstream_side_inputs(stage):
# type: (Stage) -> FrozenSet[str]
if stage not in downstream_side_inputs_by_stage:
downstream_side_inputs = frozenset() # type: FrozenSet[str]
for transform in stage.transforms:
for output in transform.outputs.values():
if output in all_side_inputs:
downstream_side_inputs = union(
downstream_side_inputs, frozenset([output]))
for consumer in consumers[output]:
downstream_side_inputs = union(
downstream_side_inputs,
compute_downstream_side_inputs(consumer))
downstream_side_inputs_by_stage[stage] = downstream_side_inputs
return downstream_side_inputs_by_stage[stage]
for stage in stages:
stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
return stages
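# Marks stages containing a stateful ParDo (user state or timers) as forced
# roots so that fusion will not merge them into an upstream stage (see
# Stage.can_fuse).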
def annotate_stateful_dofns_as_roots(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
for stage in stages:
for transform in stage.transforms:
if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
pardo_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if pardo_payload.state_specs or pardo_payload.timer_family_specs:
stage.forced_root = True
yield stage
def fix_side_input_pcoll_coders(stages, pipeline_context):
# type: (Iterable[Stage], TransformContext) -> Iterable[Stage]
"""Length prefix side input PCollection coders.
"""
for stage in stages:
for si in stage.side_inputs():
pipeline_context.length_prefix_pcoll_coders(si)
return stages
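# Splits stages into a dict of lists keyed by get_stage_key(stage); stages
# for which the key is None are returned separately.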
def _group_stages_by_key(stages, get_stage_key):
grouped_stages = collections.defaultdict(list)
stages_with_none_key = []
for stage in stages:
stage_key = get_stage_key(stage)
if stage_key is None:
stages_with_none_key.append(stage)
else:
grouped_stages[stage_key].append(stage)
return (grouped_stages, stages_with_none_key)
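# Greedily batches stages, processing them in order of increasing limit, so
# that no batch grows larger than the limit of its first (smallest-limit)
# member.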
def _group_stages_with_limit(stages, get_limit):
# type: (Iterable[Stage], Callable[[str], int]) -> Iterable[Collection[Stage]]
stages_with_limit = [(stage, get_limit(stage.name)) for stage in stages]
group: List[Stage] = []
group_limit = 0
for stage, limit in sorted(stages_with_limit, key=operator.itemgetter(1)):
if limit < 1:
raise Exception(
'expected get_limit to return an integer >= 1, '
'instead got: %d for stage: %s' % (limit, stage))
if not group:
group_limit = limit
assert len(group) < group_limit
group.append(stage)
if len(group) >= group_limit:
yield group
group = []
if group:
yield group
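# Rewrites a transform's inputs in place, substituting any PCollection id
# found in pcoll_id_remap with its replacement.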
def _remap_input_pcolls(transform, pcoll_id_remap):
for input_key in list(transform.inputs.keys()):
if transform.inputs[input_key] in pcoll_id_remap:
transform.inputs[input_key] = pcoll_id_remap[transform.inputs[input_key]]
def _make_pack_name(names):
"""Return the packed Transform or Stage name.
The output name will contain the input names' common prefix, the infix
'/Packed', and the input names' suffixes in square brackets.
For example, if the input names are 'a/b/c1/d1' and 'a/b/c2/d2', then
the output name is 'a/b/Packed[c1_d1, c2_d2]'.
"""
assert names
tokens_in_names = [name.split('/') for name in names]
common_prefix_tokens = []
# Find the longest common prefix of tokens.
while True:
first_token_in_names = set()
for tokens in tokens_in_names:
if not tokens:
break
first_token_in_names.add(tokens[0])
if len(first_token_in_names) != 1:
break
common_prefix_tokens.append(next(iter(first_token_in_names)))
for tokens in tokens_in_names:
tokens.pop(0)
common_prefix_tokens.append('Packed')
common_prefix = '/'.join(common_prefix_tokens)
suffixes = ['_'.join(tokens) for tokens in tokens_in_names]
return '%s[%s]' % (common_prefix, ', '.join(suffixes))
def _eliminate_common_key_with_none(stages, context, can_pack=lambda s: True):
# type: (Iterable[Stage], TransformContext, Callable[[str], Union[bool, int]]) -> Iterable[Stage]
"""Runs common subexpression elimination for sibling KeyWithNone stages.
If multiple KeyWithNone stages share a common input, then all but one stage
will be eliminated along with their output PCollections. Transforms that
originally read input from the output PCollection of the eliminated
KeyWithNone stages will be remapped to read input from the output PCollection
of the remaining KeyWithNone stage.
"""
# Partition stages by whether they are eligible for common KeyWithNone
# elimination, and group eligible KeyWithNone stages by parent and
# environment.
def get_stage_key(stage):
if len(stage.transforms) == 1 and can_pack(stage.name):
transform = only_transform(stage.transforms)
if (transform.spec.urn == common_urns.primitives.PAR_DO.urn and
len(transform.inputs) == 1 and len(transform.outputs) == 1):
pardo_payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
if pardo_payload.do_fn.urn == python_urns.KEY_WITH_NONE_DOFN:
return (only_element(transform.inputs.values()), stage.environment)
return None
grouped_eligible_stages, ineligible_stages = _group_stages_by_key(
stages, get_stage_key)
# Eliminate stages and build the PCollection remapping dictionary.
pcoll_id_remap = {}
remaining_stages = []
for sibling_stages in grouped_eligible_stages.values():
if len(sibling_stages) > 1:
output_pcoll_ids = [
only_element(stage.transforms[0].outputs.values())
for stage in sibling_stages
]
parent = _parent_for_fused_stages(sibling_stages, context)
for to_delete_pcoll_id in output_pcoll_ids[1:]:
pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
del context.components.pcollections[to_delete_pcoll_id]
sibling_stages[0].parent = parent
sibling_stages[0].name = _make_pack_name(
stage.name for stage in sibling_stages)
only_transform(
sibling_stages[0].transforms).unique_name = _make_pack_name(
only_transform(stage.transforms).unique_name
for stage in sibling_stages)
remaining_stages.append(sibling_stages[0])
# Remap all transforms in components.
for transform in context.components.transforms.values():
_remap_input_pcolls(transform, pcoll_id_remap)
# Yield stages while remapping input PCollections if needed.
stages_to_yield = itertools.chain(ineligible_stages, remaining_stages)
for stage in stages_to_yield:
transform = only_transform(stage.transforms)
_remap_input_pcolls(transform, pcoll_id_remap)
yield stage
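# Default limit on the number of sibling combiners packed into a single
# stage (see pack_per_key_combiners below).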
_DEFAULT_PACK_COMBINERS_LIMIT = 128
def pack_per_key_combiners(stages, context, can_pack=lambda s: True):
# type: (Iterable[Stage], TransformContext, Callable[[str], Union[bool, int]]) -> Iterator[Stage]
"""Packs sibling CombinePerKey stages into a single CombinePerKey.
If CombinePerKey stages have a common input, one input each, and one output
each, pack the stages into a single stage that runs all CombinePerKeys and
outputs resulting tuples to a new PCollection. A subsequent stage unpacks
tuples from this PCollection and sends them to the original output
PCollections.
"""
class _UnpackFn(core.DoFn):
"""A DoFn that unpacks a packed to multiple tagged outputs.
Example:
tags = (T1, T2, ...)
input = (K, (V1, V2, ...))
output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V2)), ...
"""
def __init__(self, tags):
self._tags = tags
def process(self, element):
key, values = element
return [
core.pvalue.TaggedOutput(tag, (key, value)) for tag,
value in zip(self._tags, values)
]
def _get_fallback_coder_id():
return context.add_or_get_coder_id(
# passing None works here because there are no component coders
coders.registry.get_coder(object).to_runner_api(None)) # type: ignore[arg-type]
def _get_component_coder_id_from_kv_coder(coder, index):
assert index < 2
if coder.spec.urn == common_urns.coders.KV.urn and len(
coder.component_coder_ids) == 2:
return coder.component_coder_ids[index]
return _get_fallback_coder_id()
def _get_key_coder_id_from_kv_coder(coder):
return _get_component_coder_id_from_kv_coder(coder, 0)
def _get_value_coder_id_from_kv_coder(coder):
return _get_component_coder_id_from_kv_coder(coder, 1)
def _try_fuse_stages(a, b):