diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index aa381f6cfcb8e..683f52db4055d 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -171,6 +171,14 @@ def items(self):
 class FnApiRunner(runner.PipelineRunner):
 
   def __init__(self, use_grpc=False, sdk_harness_factory=None):
+    """Creates a new Fn API Runner.
+
+    Args:
+      use_grpc: whether to use gRPC or simply make in-process calls;
+          defaults to False
+      sdk_harness_factory: callable used to instantiate customized SDK
+          harnesses; typically not set by users
+    """
     super(FnApiRunner, self).__init__()
     self._last_uid = -1
     self._use_grpc = use_grpc
@@ -277,6 +285,150 @@ def deduplicate_read(self):
 
     safe_coders = {}
 
+    def lift_combiners(stages):
+      """Expands CombinePerKey into pre- and post-grouping stages.
+
+      ... -> CombinePerKey -> ...
+
+      becomes
+
+      ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
+      """
+      def add_or_get_coder_id(coder_proto):
+        for coder_id, coder in pipeline_components.coders.items():
+          if coder == coder_proto:
+            return coder_id
+        new_coder_id = unique_name(pipeline_components.coders, 'coder')
+        pipeline_components.coders[new_coder_id].CopyFrom(coder_proto)
+        return new_coder_id
+
+      def windowed_coder_id(coder_id):
+        proto = beam_runner_api_pb2.Coder(
+            spec=beam_runner_api_pb2.SdkFunctionSpec(
+                spec=beam_runner_api_pb2.FunctionSpec(
+                    urn=urns.WINDOWED_VALUE_CODER)),
+            component_coder_ids=[coder_id, window_coder_id])
+        return add_or_get_coder_id(proto)
+
+      for stage in stages:
+        assert len(stage.transforms) == 1
+        transform = stage.transforms[0]
+        if transform.spec.urn == urns.COMBINE_PER_KEY_TRANSFORM:
+          combine_payload = proto_utils.parse_Bytes(
+              transform.spec.payload, beam_runner_api_pb2.CombinePayload)
+
+          input_pcoll = pipeline_components.pcollections[only_element(
+              transform.inputs.values())]
+          output_pcoll = pipeline_components.pcollections[only_element(
+              transform.outputs.values())]
+
+          windowed_input_coder = pipeline_components.coders[
+              input_pcoll.coder_id]
+          element_coder_id, window_coder_id = (
+              windowed_input_coder.component_coder_ids)
+          element_coder = pipeline_components.coders[element_coder_id]
+          key_coder_id, _ = element_coder.component_coder_ids
+          accumulator_coder_id = combine_payload.accumulator_coder_id
+
+          key_accumulator_coder = beam_runner_api_pb2.Coder(
+              spec=beam_runner_api_pb2.SdkFunctionSpec(
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.KV_CODER)),
+              component_coder_ids=[key_coder_id, accumulator_coder_id])
+          key_accumulator_coder_id = add_or_get_coder_id(key_accumulator_coder)
+
+          accumulator_iter_coder = beam_runner_api_pb2.Coder(
+              spec=beam_runner_api_pb2.SdkFunctionSpec(
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.ITERABLE_CODER)),
+              component_coder_ids=[accumulator_coder_id])
+          accumulator_iter_coder_id = add_or_get_coder_id(
+              accumulator_iter_coder)
+
+          key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
+              spec=beam_runner_api_pb2.SdkFunctionSpec(
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.KV_CODER)),
+              component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
+          key_accumulator_iter_coder_id = add_or_get_coder_id(
+              key_accumulator_iter_coder)
+
+          precombined_pcoll_id = unique_name(
+              pipeline_components.pcollections, 'pcollection')
+          pipeline_components.pcollections[precombined_pcoll_id].CopyFrom(
+              beam_runner_api_pb2.PCollection(
+                  unique_name=transform.unique_name + '/Precombine.out',
+                  coder_id=windowed_coder_id(key_accumulator_coder_id),
+                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
+                  is_bounded=input_pcoll.is_bounded))
+
+          grouped_pcoll_id = unique_name(
+              pipeline_components.pcollections, 'pcollection')
+          pipeline_components.pcollections[grouped_pcoll_id].CopyFrom(
+              beam_runner_api_pb2.PCollection(
+                  unique_name=transform.unique_name + '/Group.out',
+                  coder_id=windowed_coder_id(key_accumulator_iter_coder_id),
+                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
+                  is_bounded=output_pcoll.is_bounded))
+
+          merged_pcoll_id = unique_name(
+              pipeline_components.pcollections, 'pcollection')
+          pipeline_components.pcollections[merged_pcoll_id].CopyFrom(
+              beam_runner_api_pb2.PCollection(
+                  unique_name=transform.unique_name + '/Merge.out',
+                  coder_id=windowed_coder_id(key_accumulator_coder_id),
+                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
+                  is_bounded=output_pcoll.is_bounded))
+
+          def make_stage(base_stage, transform):
+            return Stage(
+                transform.unique_name,
+                [transform],
+                downstream_side_inputs=base_stage.downstream_side_inputs,
+                must_follow=base_stage.must_follow)
+
+          yield make_stage(
+              stage,
+              beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/Precombine',
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.PRECOMBINE_TRANSFORM,
+                      payload=transform.spec.payload),
+                  inputs=transform.inputs,
+                  outputs={'out': precombined_pcoll_id}))
+
+          yield make_stage(
+              stage,
+              beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/Group',
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.GROUP_BY_KEY_TRANSFORM),
+                  inputs={'in': precombined_pcoll_id},
+                  outputs={'out': grouped_pcoll_id}))
+
+          yield make_stage(
+              stage,
+              beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/Merge',
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.MERGE_ACCUMULATORS_TRANSFORM,
+                      payload=transform.spec.payload),
+                  inputs={'in': grouped_pcoll_id},
+                  outputs={'out': merged_pcoll_id}))
+
+          yield make_stage(
+              stage,
+              beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/ExtractOutputs',
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=urns.EXTRACT_OUTPUTS_TRANSFORM,
+                      payload=transform.spec.payload),
+                  inputs={'in': merged_pcoll_id},
+                  outputs=transform.outputs))
+
+        else:
+          yield stage
+
     def expand_gbk(stages):
       """Transforms each GBK into a write followed by a read.
       """
@@ -351,6 +503,8 @@ def fix_pcoll_coder(pcoll):
 
           # This is used later to correlate the read and write.
          param = str("group:%s" % stage.name)
+          if stage.name not in pipeline_components.transforms:
+            pipeline_components.transforms[stage.name].CopyFrom(transform)
           gbk_write = Stage(
               transform.unique_name + '/Write',
               [beam_runner_api_pb2.PTransform(
@@ -613,7 +767,8 @@ def process(stage):
       pcoll.coder_id = coders.get_id(coder)
     coders.populate_map(pipeline_components.coders)
 
-    known_composites = set([urns.GROUP_BY_KEY_TRANSFORM])
+    known_composites = set(
+        [urns.GROUP_BY_KEY_TRANSFORM, urns.COMBINE_PER_KEY_TRANSFORM])
 
     def leaf_transforms(root_ids):
       for root_id in root_ids:
@@ -631,8 +786,8 @@ def leaf_transforms(root_ids):
 
     # Apply each phase in order.
     for phase in [
-        annotate_downstream_side_inputs, expand_gbk, sink_flattens,
-        greedily_fuse, sort_stages]:
+        annotate_downstream_side_inputs, lift_combiners, expand_gbk,
+        sink_flattens, greedily_fuse, sort_stages]:
       logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
       stages = list(phase(stages))
       logging.debug('Stages: %s', [str(s) for s in stages])
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 94dca8b242a63..136f22d0903c1 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -553,3 +553,57 @@ def create(factory, transform_id, transform_proto, unused_parameter, consumers):
           factory.state_sampler),
       transform_proto.unique_name,
       consumers)
+
+
+@BeamTransformFactory.register_urn(
+    urns.PRECOMBINE_TRANSFORM, beam_runner_api_pb2.CombinePayload)
+def create(factory, transform_id, transform_proto, payload, consumers):
+  # TODO: Combine side inputs.
+  serialized_combine_fn = pickler.dumps(
+      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
+       [], {}))
+  return factory.augment_oldstyle_op(
+      operations.PGBKCVOperation(
+          transform_proto.unique_name,
+          operation_specs.WorkerPartialGroupByKey(
+              serialized_combine_fn,
+              None,
+              [factory.get_only_output_coder(transform_proto)]),
+          factory.counter_factory,
+          factory.state_sampler),
+      transform_proto.unique_name,
+      consumers)
+
+
+@BeamTransformFactory.register_urn(
+    urns.MERGE_ACCUMULATORS_TRANSFORM, beam_runner_api_pb2.CombinePayload)
+def create(factory, transform_id, transform_proto, payload, consumers):
+  return _create_combine_phase_operation(
+      factory, transform_proto, payload, consumers, 'merge')
+
+
+@BeamTransformFactory.register_urn(
+    urns.EXTRACT_OUTPUTS_TRANSFORM, beam_runner_api_pb2.CombinePayload)
+def create(factory, transform_id, transform_proto, payload, consumers):
+  return _create_combine_phase_operation(
+      factory, transform_proto, payload, consumers, 'extract')
+
+
+def _create_combine_phase_operation(
+    factory, transform_proto, payload, consumers, phase):
+  # This is where support for combine fn side inputs would go.
+  serialized_combine_fn = pickler.dumps(
+      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
+       [], {}))
+  return factory.augment_oldstyle_op(
+      operations.CombineOperation(
+          transform_proto.unique_name,
+          operation_specs.WorkerCombineFn(
+              serialized_combine_fn,
+              phase,
+              None,
+              [factory.get_only_output_coder(transform_proto)]),
+          factory.counter_factory,
+          factory.state_sampler),
+      transform_proto.unique_name,
+      consumers)
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index c245655b84788..8098a63f3c783 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -398,6 +398,7 @@ def __init__(self, operation_name, spec, counter_factory, state_sampler):
     fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
     self.phased_combine_fn = (
         PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
+    self.scoped_metrics_container = ScopedMetricsContainer()
 
   def finish(self):
     logging.debug('Finishing %s', self)
diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py
index 6bdc29fbc6fb7..bd02fe1dfb784 100644
--- a/sdks/python/apache_beam/utils/urns.py
+++ b/sdks/python/apache_beam/utils/urns.py
@@ -42,6 +42,9 @@
 GROUP_ALSO_BY_WINDOW_TRANSFORM = "beam:ptransform:group_also_by_window:v0.1"
 COMBINE_PER_KEY_TRANSFORM = "beam:ptransform:combine_per_key:v0.1"
 COMBINE_GROUPED_VALUES_TRANSFORM = "beam:ptransform:combine_grouped_values:v0.1"
+PRECOMBINE_TRANSFORM = "beam:ptransform:combine_pre:v0.1"
+MERGE_ACCUMULATORS_TRANSFORM = "beam:ptransform:combine_merge_accumulators:v0.1"
+EXTRACT_OUTPUTS_TRANSFORM = "beam:ptransform:combine_extract_outputs:v0.1"
 FLATTEN_TRANSFORM = "beam:ptransform:flatten:v0.1"
 READ_TRANSFORM = "beam:ptransform:read:v0.1"
 RESHUFFLE_TRANSFORM = "beam:ptransform:reshuffle:v0.1"
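
Reviewer note: the sketch below is not part of the patch; it simulates in plain Python what the lifted expansion computes, to make the four phases concrete. MeanCombineFn and lifted_combine_per_key are hypothetical names invented for illustration. In the runner itself, the Precombine phase executes as a PGBKCVOperation and the Merge/Extract phases as CombineOperations, per the registrations in bundle_processor.py above.

class MeanCombineFn(object):
  """Toy combiner for the sketch: the accumulator is a (sum, count) pair."""

  def create_accumulator(self):
    return (0.0, 0)

  def add_input(self, accumulator, value):
    total, count = accumulator
    return (total + value, count + 1)

  def merge_accumulators(self, accumulators):
    return (sum(a[0] for a in accumulators), sum(a[1] for a in accumulators))

  def extract_output(self, accumulator):
    total, count = accumulator
    return total / count if count else float('nan')


def lifted_combine_per_key(fn, bundles):
  """Hypothetical single-process simulation of the lifted CombinePerKey."""
  # Phase 1 (PRECOMBINE_TRANSFORM): combine within each bundle before the
  # shuffle, emitting one (key, accumulator) pair per key per bundle.
  partials = []
  for bundle in bundles:
    accumulators = {}
    for key, value in bundle:
      acc = accumulators.get(key, fn.create_accumulator())
      accumulators[key] = fn.add_input(acc, value)
    partials.extend(accumulators.items())
  # Phase 2 (GROUP_BY_KEY_TRANSFORM): the shuffle now groups compact
  # accumulators rather than every raw input element.
  grouped = {}
  for key, acc in partials:
    grouped.setdefault(key, []).append(acc)
  # Phase 3 (MERGE_ACCUMULATORS_TRANSFORM): one accumulator per key.
  merged = {k: fn.merge_accumulators(accs) for k, accs in grouped.items()}
  # Phase 4 (EXTRACT_OUTPUTS_TRANSFORM): final output values.
  return {k: fn.extract_output(acc) for k, acc in merged.items()}


print(lifted_combine_per_key(
    MeanCombineFn(),
    [[('a', 1), ('a', 3)], [('a', 5), ('b', 7)]]))
# {'a': 3.0, 'b': 7.0}

The payoff is the precombine step: only one (key, accumulator) pair per key per bundle crosses the GroupByKey, instead of one record per input element, which is why lift_combiners runs before expand_gbk in the phase list.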