[BEAM-2937] Basic PGBK combiner lifting. (#4290)
robertwb committed Dec 20, 2017
1 parent 2063781 commit e92f718
Showing 4 changed files with 216 additions and 3 deletions.
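
For context: combiner lifting lets the runner apply the CombineFn's add_input step before the shuffle, so the GroupByKey moves compact accumulators instead of raw values. A sketch of the three CombineFn phases the lifted stages exercise, using a hypothetical MeanFn (illustrative only, not code from this commit):

import apache_beam as beam

class MeanFn(beam.CombineFn):
  # Hypothetical CombineFn, for illustration only.
  def create_accumulator(self):
    return (0.0, 0)  # (running sum, element count)

  def add_input(self, accumulator, value):
    # Runs upstream of the shuffle once lifting applies (Precombine).
    s, n = accumulator
    return (s + value, n + 1)

  def merge_accumulators(self, accumulators):
    # Runs after the GroupByKey (MergeAccumulators).
    sums, counts = zip(*accumulators)
    return (sum(sums), sum(counts))

  def extract_output(self, accumulator):
    # Final phase (ExtractOutputs).
    s, n = accumulator
    return s / n if n else float('nan')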
161 changes: 158 additions & 3 deletions sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -171,6 +171,14 @@ def items(self):
class FnApiRunner(runner.PipelineRunner):

  def __init__(self, use_grpc=False, sdk_harness_factory=None):
    """Creates a new Fn API Runner.

    Args:
      use_grpc: whether to use grpc or simply make in-process calls;
          defaults to False
      sdk_harness_factory: callable used to instantiate customized sdk
          harnesses; typically not set by users
    """
    super(FnApiRunner, self).__init__()
    self._last_uid = -1
    self._use_grpc = use_grpc
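
A hypothetical end-to-end exercise of this runner (not part of the diff); the CombinePerKey below is exactly the transform that the new lift_combiners phase rewrites:

import apache_beam as beam
from apache_beam.runners.portability.fn_api_runner import FnApiRunner

# In-process execution, per the use_grpc default documented above.
with beam.Pipeline(runner=FnApiRunner(use_grpc=False)) as p:
  _ = (p
       | beam.Create([('a', 1), ('a', 2), ('b', 3)])
       | beam.CombinePerKey(sum))  # lifted into Precombine/Group/Merge/Extract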
@@ -277,6 +285,150 @@ def deduplicate_read(self):

    safe_coders = {}

    def lift_combiners(stages):
      """Expands CombinePerKey into pre- and post-grouping stages.

      ... -> CombinePerKey -> ...
      becomes
      ... -> PreCombine -> GBK -> MergeAccumulators -> ExtractOutput -> ...
      """
      def add_or_get_coder_id(coder_proto):
        for coder_id, coder in pipeline_components.coders.items():
          if coder == coder_proto:
            return coder_id
        new_coder_id = unique_name(pipeline_components.coders, 'coder')
        pipeline_components.coders[new_coder_id].CopyFrom(coder_proto)
        return new_coder_id

      def windowed_coder_id(coder_id):
        proto = beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.SdkFunctionSpec(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=urns.WINDOWED_VALUE_CODER)),
            component_coder_ids=[coder_id, window_coder_id])
        return add_or_get_coder_id(proto)

      for stage in stages:
        assert len(stage.transforms) == 1
        transform = stage.transforms[0]
        if transform.spec.urn == urns.COMBINE_PER_KEY_TRANSFORM:
          combine_payload = proto_utils.parse_Bytes(
              transform.spec.payload, beam_runner_api_pb2.CombinePayload)

          input_pcoll = pipeline_components.pcollections[only_element(
              transform.inputs.values())]
          output_pcoll = pipeline_components.pcollections[only_element(
              transform.outputs.values())]

          windowed_input_coder = pipeline_components.coders[
              input_pcoll.coder_id]
          element_coder_id, window_coder_id = (
              windowed_input_coder.component_coder_ids)
          element_coder = pipeline_components.coders[element_coder_id]
          key_coder_id, _ = element_coder.component_coder_ids
          accumulator_coder_id = combine_payload.accumulator_coder_id

          key_accumulator_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_coder_id])
          key_accumulator_coder_id = add_or_get_coder_id(key_accumulator_coder)

          accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.ITERABLE_CODER)),
              component_coder_ids=[accumulator_coder_id])
          accumulator_iter_coder_id = add_or_get_coder_id(
              accumulator_iter_coder)

          key_accumulator_iter_coder = beam_runner_api_pb2.Coder(
              spec=beam_runner_api_pb2.SdkFunctionSpec(
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.KV_CODER)),
              component_coder_ids=[key_coder_id, accumulator_iter_coder_id])
          key_accumulator_iter_coder_id = add_or_get_coder_id(
              key_accumulator_iter_coder)

          precombined_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[precombined_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Precombine.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
                  is_bounded=input_pcoll.is_bounded))

          grouped_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[grouped_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Group.out',
                  coder_id=windowed_coder_id(key_accumulator_iter_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          merged_pcoll_id = unique_name(
              pipeline_components.pcollections, 'pcollection')
          pipeline_components.pcollections[merged_pcoll_id].CopyFrom(
              beam_runner_api_pb2.PCollection(
                  unique_name=transform.unique_name + '/Merge.out',
                  coder_id=windowed_coder_id(key_accumulator_coder_id),
                  windowing_strategy_id=output_pcoll.windowing_strategy_id,
                  is_bounded=output_pcoll.is_bounded))

          def make_stage(base_stage, transform):
            return Stage(
                transform.unique_name,
                [transform],
                downstream_side_inputs=base_stage.downstream_side_inputs,
                must_follow=base_stage.must_follow)

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Precombine',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.PRECOMBINE_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs=transform.inputs,
                  outputs={'out': precombined_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Group',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.GROUP_BY_KEY_TRANSFORM),
                  inputs={'in': precombined_pcoll_id},
                  outputs={'out': grouped_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/Merge',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.MERGE_ACCUMULATORS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': grouped_pcoll_id},
                  outputs={'out': merged_pcoll_id}))

          yield make_stage(
              stage,
              beam_runner_api_pb2.PTransform(
                  unique_name=transform.unique_name + '/ExtractOutputs',
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=urns.EXTRACT_OUTPUTS_TRANSFORM,
                      payload=transform.spec.payload),
                  inputs={'in': merged_pcoll_id},
                  outputs=transform.outputs))

        else:
          yield stage

    def expand_gbk(stages):
      """Transforms each GBK into a write followed by a read.
      """
@@ -351,6 +503,8 @@ def fix_pcoll_coder(pcoll):

          # This is used later to correlate the read and write.
          param = str("group:%s" % stage.name)
          if stage.name not in pipeline_components.transforms:
            pipeline_components.transforms[stage.name].CopyFrom(transform)
          gbk_write = Stage(
              transform.unique_name + '/Write',
              [beam_runner_api_pb2.PTransform(
@@ -613,7 +767,8 @@ def process(stage):
      pcoll.coder_id = coders.get_id(coder)
    coders.populate_map(pipeline_components.coders)

-    known_composites = set([urns.GROUP_BY_KEY_TRANSFORM])
+    known_composites = set(
+        [urns.GROUP_BY_KEY_TRANSFORM, urns.COMBINE_PER_KEY_TRANSFORM])

    def leaf_transforms(root_ids):
      for root_id in root_ids:
@@ -631,8 +786,8 @@ def leaf_transforms(root_ids):

    # Apply each phase in order.
    for phase in [
-        annotate_downstream_side_inputs, expand_gbk, sink_flattens,
-        greedily_fuse, sort_stages]:
+        annotate_downstream_side_inputs, lift_combiners, expand_gbk,
+        sink_flattens, greedily_fuse, sort_stages]:
      logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
      stages = list(phase(stages))
      logging.debug('Stages: %s', [str(s) for s in stages])
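The net effect of lift_combiners above, sketched as plain data. The hypothetical helper below only illustrates the resulting stage chain; the real code builds beam_runner_api_pb2 protos and registers the windowed KV and iterable coders noted in the comments:

def lifted_stage_names(combine_name):
  # Suffixes mirror the unique_name values constructed in lift_combiners.
  return [
      combine_name + '/Precombine',      # KV<K, V> -> KV<K, Accumulator>
      combine_name + '/Group',           # GBK: KV<K, A> -> KV<K, Iterable<A>>
      combine_name + '/Merge',           # KV<K, Iterable<A>> -> KV<K, A>
      combine_name + '/ExtractOutputs',  # KV<K, A> -> KV<K, Output>
  ]

print(lifted_stage_names('MyCombine'))
# ['MyCombine/Precombine', 'MyCombine/Group',
#  'MyCombine/Merge', 'MyCombine/ExtractOutputs']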
54 changes: 54 additions & 0 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -553,3 +553,57 @@ def create(factory, transform_id, transform_proto, unused_parameter, consumers):
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)


@BeamTransformFactory.register_urn(
    urns.PRECOMBINE_TRANSFORM, beam_runner_api_pb2.CombinePayload)
def create(factory, transform_id, transform_proto, payload, consumers):
  # TODO: Combine side inputs.
  serialized_combine_fn = pickler.dumps(
      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
       [], {}))
  return factory.augment_oldstyle_op(
      operations.PGBKCVOperation(
          transform_proto.unique_name,
          operation_specs.WorkerPartialGroupByKey(
              serialized_combine_fn,
              None,
              [factory.get_only_output_coder(transform_proto)]),
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)


@BeamTransformFactory.register_urn(
    urns.MERGE_ACCUMULATORS_TRANSFORM, beam_runner_api_pb2.CombinePayload)
def create(factory, transform_id, transform_proto, payload, consumers):
  return _create_combine_phase_operation(
      factory, transform_proto, payload, consumers, 'merge')


@BeamTransformFactory.register_urn(
    urns.EXTRACT_OUTPUTS_TRANSFORM, beam_runner_api_pb2.CombinePayload)
def create(factory, transform_id, transform_proto, payload, consumers):
  return _create_combine_phase_operation(
      factory, transform_proto, payload, consumers, 'extract')


def _create_combine_phase_operation(
    factory, transform_proto, payload, consumers, phase):
  # This is where support for combine fn side inputs would go.
  serialized_combine_fn = pickler.dumps(
      (beam.CombineFn.from_runner_api(payload.combine_fn, factory.context),
       [], {}))
  return factory.augment_oldstyle_op(
      operations.CombineOperation(
          transform_proto.unique_name,
          operation_specs.WorkerCombineFn(
              serialized_combine_fn,
              phase,
              None,
              [factory.get_only_output_coder(transform_proto)]),
          factory.counter_factory,
          factory.state_sampler),
      transform_proto.unique_name,
      consumers)
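
_create_combine_phase_operation runs a single phase of the CombineFn per worker operation, dispatched by PhasedCombineFnExecutor in operations.py. A minimal stand-in for that dispatch, assuming a standard Beam-style CombineFn (sketch only, not the actual executor):

def make_phased_fn(phase, combine_fn):
  # Hypothetical dispatcher: each operation applies exactly one phase.
  if phase == 'merge':
    # (key, iterable of accumulators) -> (key, merged accumulator)
    return lambda kv: (kv[0], combine_fn.merge_accumulators(kv[1]))
  elif phase == 'extract':
    # (key, accumulator) -> (key, final output)
    return lambda kv: (kv[0], combine_fn.extract_output(kv[1]))
  else:
    raise ValueError('unknown combine phase: %r' % phase)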
1 change: 1 addition & 0 deletions sdks/python/apache_beam/runners/worker/operations.py
@@ -398,6 +398,7 @@ def __init__(self, operation_name, spec, counter_factory, state_sampler):
    fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
    self.phased_combine_fn = (
        PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
    self.scoped_metrics_container = ScopedMetricsContainer()

  def finish(self):
    logging.debug('Finishing %s', self)
3 changes: 3 additions & 0 deletions sdks/python/apache_beam/utils/urns.py
@@ -42,6 +42,9 @@
GROUP_ALSO_BY_WINDOW_TRANSFORM = "beam:ptransform:group_also_by_window:v0.1"
COMBINE_PER_KEY_TRANSFORM = "beam:ptransform:combine_per_key:v0.1"
COMBINE_GROUPED_VALUES_TRANSFORM = "beam:ptransform:combine_grouped_values:v0.1"
PRECOMBINE_TRANSFORM = "beam:ptransform:combine_pre:v0.1"
MERGE_ACCUMULATORS_TRANSFORM = "beam:ptransform:combine_merge_accumulators:v0.1"
EXTRACT_OUTPUTS_TRANSFORM = "beam:ptransform:combine_extract_outputs:v0.1"
FLATTEN_TRANSFORM = "beam:ptransform:flatten:v0.1"
READ_TRANSFORM = "beam:ptransform:read:v0.1"
RESHUFFLE_TRANSFORM = "beam:ptransform:reshuffle:v0.1"
