From 1152b48e0488d0024589e1ccf9cbf17e0e7a2a82 Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Fri, 7 Apr 2017 19:22:27 -0700 Subject: [PATCH 1/2] Fix GroupByKeyInputVisitor for Direct Runner --- .../runners/direct/direct_runner.py | 2 + sdks/python/apache_beam/runners/runner.py | 70 +++++++++++-------- .../python/apache_beam/runners/runner_test.py | 43 +++++++++++- 3 files changed, 84 insertions(+), 31 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 1a5775f7f49a..83752942f93e 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -65,10 +65,12 @@ def run(self, pipeline): from apache_beam.runners.direct.executor import Executor from apache_beam.runners.direct.transform_evaluator import \ TransformEvaluatorRegistry + from apache_beam.runners.runner import group_by_key_input_visitor MetricsEnvironment.set_metrics_supported(True) logging.info('Running pipeline with DirectRunner.') self.visitor = ConsumerTrackingPipelineVisitor() + pipeline.visit(group_by_key_input_visitor()) pipeline.visit(self.visitor) evaluation_context = EvaluationContext( diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index 528b03f5fcb6..aa9b22484209 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -86,6 +86,46 @@ def create_runner(runner_name): runner_name, ', '.join(_ALL_KNOWN_RUNNERS))) +def group_by_key_input_visitor(): + from apache_beam.pipeline import PipelineVisitor + + class GroupByKeyInputVisitor(PipelineVisitor): + """A visitor that replaces `Any` element type for input `PCollection` of + a `GroupByKey` or `GroupByKeyOnly` with a `KV` type. + + TODO(BEAM-115): Once Python SDK is compatible with the new Runner API, + we could directly replace the coder instead of mutating the element type. 
+ """ + + def visit_transform(self, transform_node): + # Imported here to avoid circular dependencies. + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import GroupByKey, GroupByKeyOnly + from apache_beam import typehints + if (isinstance(transform_node.transform, GroupByKey) or + isinstance(transform_node.transform, GroupByKeyOnly)): + pcoll = transform_node.inputs[0] + input_type = pcoll.element_type + # If input_type is not specified, then treat it as `Any`. + if not input_type: + input_type = typehints.Any + + if not isinstance(input_type, typehints.TupleHint.TupleConstraint): + if isinstance(input_type, typehints.AnyTypeConstraint): + # `Any` type needs to be replaced with a KV[Any, Any] to + # force a KV coder as the main output coder for the pcollection + # preceding a GroupByKey. + pcoll.element_type = typehints.KV[typehints.Any, typehints.Any] + else: + # TODO: Handle other valid types, + # e.g. Union[KV[str, int], KV[str, float]] + raise ValueError( + "Input to GroupByKey must be of Tuple or Any type. " + "Found %s for %s" % (input_type, pcoll)) + + return GroupByKeyInputVisitor() + + class PipelineRunner(object): """A runner of a pipeline object. @@ -119,35 +159,7 @@ def visit_transform(self, transform_node): logging.error('Error while visiting %s', transform_node.full_label) raise - class GroupByKeyInputVisitor(PipelineVisitor): - """A visitor that replaces `Any` element type for input `PCollection` of - a `GroupByKey` with a `KV` type. - - TODO(BEAM-115): Once Python SDk is compatible with the new Runner API, - we could directly replace the coder instead of mutating the element type. - """ - def visit_transform(self, transform_node): - # Imported here to avoid circular dependencies. 
- # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam import GroupByKey - from apache_beam import typehints - if isinstance(transform_node.transform, GroupByKey): - pcoll = transform_node.inputs[0] - input_type = pcoll.element_type - if not isinstance(input_type, typehints.TupleHint.TupleConstraint): - if isinstance(input_type, typehints.AnyTypeConstraint): - # `Any` type needs to be replaced with a KV[Any, Any] to - # force a KV coder as the main output coder for the pcollection - # preceding a GroupByKey. - pcoll.element_type = typehints.KV[typehints.Any, typehints.Any] - else: - # TODO: Handle other valid types, - # e.g. Union[KV[str, int], KV[str, float]] - raise ValueError( - "Input to GroupByKey must be of Tuple or Any type. " - "Found %s for %s" % (input_type, pcoll)) - - pipeline.visit(GroupByKeyInputVisitor()) + pipeline.visit(group_by_key_input_visitor()) pipeline.visit(RunVisitor(self)) def clear(self, pipeline, node=None): diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py index b161cbbd49b1..e032a4bb95f3 100644 --- a/sdks/python/apache_beam/runners/runner_test.py +++ b/sdks/python/apache_beam/runners/runner_test.py @@ -28,14 +28,17 @@ import apache_beam as beam import apache_beam.transforms as ptransform +from apache_beam import typehints from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import DistributionResult from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metricbase import MetricName -from apache_beam.pipeline import Pipeline -from apache_beam.runners import DirectRunner +from apache_beam.pipeline import Pipeline, AppliedPTransform +from apache_beam.pvalue import PCollection +from apache_beam.runners import DirectRunner, runner from apache_beam.runners import create_runner +from apache_beam.test_pipeline import TestPipeline from 
apache_beam.transforms.util import assert_that from apache_beam.transforms.util import equal_to from apache_beam.utils.pipeline_options import PipelineOptions @@ -118,6 +121,42 @@ def process(self, element): DistributionResult(DistributionData(15, 5, 1, 5)), DistributionResult(DistributionData(15, 5, 1, 5))))) + def test_group_by_key_input_visitor_with_valid_inputs(self): + p = TestPipeline() + pcoll1 = PCollection(p) + pcoll2 = PCollection(p) + pcoll3 = PCollection(p) + for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]: + pcoll1.element_type = None + pcoll2.element_type = typehints.Any + pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any] + for pcoll in [pcoll1, pcoll2, pcoll3]: + runner.group_by_key_input_visitor().visit_transform( + AppliedPTransform(None, transform, "label", [pcoll])) + self.assertEqual(pcoll.element_type, + typehints.KV[typehints.Any, typehints.Any]) + + def test_group_by_key_input_visitor_with_invalid_inputs(self): + p = TestPipeline() + pcoll1 = PCollection(p) + pcoll2 = PCollection(p) + for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]: + pcoll1.element_type = typehints.TupleSequenceConstraint + pcoll2.element_type = typehints.Set + err_msg = "Input to GroupByKey must be of Tuple or Any type" + for pcoll in [pcoll1, pcoll2]: + with self.assertRaisesRegexp(ValueError, err_msg): + runner.group_by_key_input_visitor().visit_transform( + AppliedPTransform(None, transform, "label", [pcoll])) + + def test_group_by_key_input_visitor_for_non_gbk_transforms(self): + p = TestPipeline() + pcoll = PCollection(p) + for transform in [beam.Flatten(), beam.Map(lambda x: x)]: + pcoll.element_type = typehints.Any + runner.group_by_key_input_visitor().visit_transform( + AppliedPTransform(None, transform, "label", [pcoll])) + self.assertEqual(pcoll.element_type, typehints.Any) if __name__ == '__main__': unittest.main() From a0e5a07a9cbe5b1cbe0d0eef8b5700019b09421f Mon Sep 17 00:00:00 2001 From: Vikas Kedigehalli Date: Mon, 10 
Apr 2017 14:05:31 -0700 Subject: [PATCH 2/2] Address comments --- .../runners/direct/direct_runner.py | 18 +++++++++--------- sdks/python/apache_beam/runners/runner.py | 1 + sdks/python/apache_beam/runners/runner_test.py | 6 ++++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 83752942f93e..9b4e1acd237b 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -32,6 +32,7 @@ from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState from apache_beam.runners.runner import PValueCache +from apache_beam.runners.runner import group_by_key_input_visitor from apache_beam.utils.pipeline_options import DirectOptions from apache_beam.utils.value_provider import RuntimeValueProvider @@ -65,26 +66,25 @@ def run(self, pipeline): from apache_beam.runners.direct.executor import Executor from apache_beam.runners.direct.transform_evaluator import \ TransformEvaluatorRegistry - from apache_beam.runners.runner import group_by_key_input_visitor MetricsEnvironment.set_metrics_supported(True) logging.info('Running pipeline with DirectRunner.') - self.visitor = ConsumerTrackingPipelineVisitor() + self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor() pipeline.visit(group_by_key_input_visitor()) - pipeline.visit(self.visitor) + pipeline.visit(self.consumer_tracking_visitor) evaluation_context = EvaluationContext( pipeline.options, BundleFactory(stacked=pipeline.options.view_as(DirectOptions) .direct_runner_use_stacked_bundle), - self.visitor.root_transforms, - self.visitor.value_to_consumers, - self.visitor.step_names, - self.visitor.views) + self.consumer_tracking_visitor.root_transforms, + self.consumer_tracking_visitor.value_to_consumers, + self.consumer_tracking_visitor.step_names, + 
self.consumer_tracking_visitor.views) evaluation_context.use_pvalue_cache(self._cache) - executor = Executor(self.visitor.value_to_consumers, + executor = Executor(self.consumer_tracking_visitor.value_to_consumers, TransformEvaluatorRegistry(evaluation_context), evaluation_context) # Start the executor. This is a non-blocking call, it will start the @@ -92,7 +92,7 @@ def run(self, pipeline): if pipeline.options: RuntimeValueProvider.set_runtime_options(pipeline.options._options_id, {}) - executor.start(self.visitor.root_transforms) + executor.start(self.consumer_tracking_visitor.root_transforms) result = DirectPipelineResult(executor, evaluation_context) if self._cache: diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index aa9b22484209..de9c8928d291 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -87,6 +87,7 @@ def create_runner(runner_name): def group_by_key_input_visitor(): + # Imported here to avoid circular dependencies. 
from apache_beam.pipeline import PipelineVisitor class GroupByKeyInputVisitor(PipelineVisitor): diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py index e032a4bb95f3..0bebd665a2e4 100644 --- a/sdks/python/apache_beam/runners/runner_test.py +++ b/sdks/python/apache_beam/runners/runner_test.py @@ -34,9 +34,11 @@ from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricResult from apache_beam.metrics.metricbase import MetricName -from apache_beam.pipeline import Pipeline, AppliedPTransform +from apache_beam.pipeline import AppliedPTransform +from apache_beam.pipeline import Pipeline from apache_beam.pvalue import PCollection -from apache_beam.runners import DirectRunner, runner +from apache_beam.runners import DirectRunner +from apache_beam.runners import runner from apache_beam.runners import create_runner from apache_beam.test_pipeline import TestPipeline from apache_beam.transforms.util import assert_that