From 155bbbd887225720e0956522901827dd18f30608 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Mon, 8 May 2017 13:33:15 -0700 Subject: [PATCH 1/2] [BEAM-1925] validate DoFn at pipeline creation time --- sdks/python/apache_beam/runners/common.py | 25 ++++---- .../python/apache_beam/runners/common_test.py | 58 +++++++++++++++++++ sdks/python/apache_beam/transforms/core.py | 9 ++- .../apache_beam/transforms/ptransform_test.py | 1 - 4 files changed, 78 insertions(+), 15 deletions(-) create mode 100644 sdks/python/apache_beam/runners/common_test.py diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 045c1093836f..74c61abc34bc 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -95,17 +95,20 @@ def __init__(self, do_fn): self._validate() def _validate(self): + self._validate_process() self._validate_bundle_method(self.start_bundle_method) self._validate_bundle_method(self.finish_bundle_method) - def _validate_bundle_method(self, method_wrapper): - # Here we use the fact that every DoFn parameter defined in core.DoFn has - # the value that is the same as the name of the parameter and ends with - # string 'Param'. - unsupported_dofn_params = [i for i in core.DoFn.__dict__ if - i.endswith('Param')] + def _validate_process(self): + """Validate that none of the DoFnParameters are repeated in the function + """ + for param in core.DoFn.DoFnParams: + assert self.process_method.defaults.count(param) <= 1 - for param in unsupported_dofn_params: + def _validate_bundle_method(self, method_wrapper): + """Validate that none of the DoFnParameters are used in the function + """ + for param in core.DoFn.DoFnParams: assert param not in method_wrapper.defaults @@ -156,18 +159,14 @@ def invoke_process(self, windowed_value): def invoke_start_bundle(self): """Invokes the DoFn.start_bundle() method. """ - args_for_start_bundle = self.signature.start_bundle_method.defaults self.output_processor.start_bundle_outputs( - self.signature.start_bundle_method.method_value( - *args_for_start_bundle)) + self.signature.start_bundle_method.method_value()) def invoke_finish_bundle(self): """Invokes the DoFn.finish_bundle() method. """ - args_for_finish_bundle = self.signature.finish_bundle_method.defaults self.output_processor.finish_bundle_outputs( - self.signature.finish_bundle_method.method_value( - *args_for_finish_bundle)) + self.signature.finish_bundle_method.method_value()) class SimpleInvoker(DoFnInvoker): diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py new file mode 100644 index 000000000000..62a6955f6ce4 --- /dev/null +++ b/sdks/python/apache_beam/runners/common_test.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from apache_beam.transforms.core import DoFn +from apache_beam.runners.common import DoFnSignature + + +class DoFnSignatureTest(unittest.TestCase): + + def test_dofn_validate_process_error(self): + class MyDoFn(DoFn): + def process(self, element, w1=DoFn.WindowParam, w2=DoFn.WindowParam): + pass + + with self.assertRaises(AssertionError): + DoFnSignature(MyDoFn()) + + def test_dofn_validate_start_bundle_error(self): + class MyDoFn(DoFn): + def process(self, element): + pass + + def start_bundle(self, w1=DoFn.WindowParam): + pass + + with self.assertRaises(AssertionError): + DoFnSignature(MyDoFn()) + + def test_dofn_validate_finish_bundle_error(self): + class MyDoFn(DoFn): + def process(self, element): + pass + + def finish_bundle(self, w1=DoFn.WindowParam): + pass + + with self.assertRaises(AssertionError): + DoFnSignature(MyDoFn()) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 7ca1632c08e7..e37a387e6d18 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -29,7 +29,8 @@ from apache_beam.internal import util from apache_beam.runners.api import beam_runner_api_pb2 from apache_beam.transforms import ptransform -from apache_beam.transforms.display import HasDisplayData, DisplayDataItem +from apache_beam.transforms.display import DisplayDataItem +from apache_beam.transforms.display import HasDisplayData from apache_beam.transforms.ptransform import PTransform from apache_beam.transforms.ptransform import PTransformWithSideInputs from apache_beam.transforms.window import MIN_TIMESTAMP @@ -131,6 +132,8 @@ class DoFn(WithTypeHints, HasDisplayData): TimestampParam = 'TimestampParam' WindowParam = 'WindowParam' + DoFnParams = [ElementParam, SideInputParam, TimestampParam, WindowParam] + @staticmethod def from_callable(fn): return CallableWrapperDoFn(fn) @@ -596,6 +599,10 @@ def __init__(self, fn, *args, **kwargs): if not isinstance(self.fn, DoFn): raise TypeError('ParDo must be called with a DoFn instance.') + # Validate the DoFn by creating a DoFnSignature + from apache_beam.runners.common import DoFnSignature + DoFnSignature(self.fn) + def default_type_hints(self): return self.fn.get_type_hints() diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index b8b0733622e4..e7126610935b 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -303,7 +303,6 @@ def __init__(self): def start_bundle(self): self.state = 'started' - return None def process(self, element): if self.state == 'started': From dd6a12db7c11b38329354b181f825815fa3a9b6d Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Mon, 1 May 2017 15:00:24 -0700 Subject: [PATCH 2/2] [BEAM-1283] Finish bundle should only emit windowed values --- .../apache_beam/examples/snippets/snippets_test.py | 10 ++++++++-- sdks/python/apache_beam/io/iobase.py | 4 +++- sdks/python/apache_beam/runners/common.py | 13 ++----------- .../apache_beam/transforms/ptransform_test.py | 8 +++++--- 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index 014809628ee1..da0a9625cfda 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -33,6 +33,7 @@ from apache_beam.transforms.util import equal_to from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.examples.snippets import snippets +from apache_beam.utils.windowed_value import WindowedValue # pylint: disable=expression-not-assigned from apache_beam.test_pipeline import TestPipeline @@ -366,6 +367,7 @@ def parse_player_and_score(csv): class SnippetsTest(unittest.TestCase): # Replacing text read/write transforms with dummy transforms for testing. + class DummyReadTransform(beam.PTransform): """A transform that will replace iobase.ReadFromText. @@ -387,16 +389,20 @@ def process(self, element): pass def finish_bundle(self): + from apache_beam.transforms import window + assert self.file_to_read for file_name in glob.glob(self.file_to_read): if self.compression_type is None: with open(file_name) as file: for record in file: - yield self.coder.decode(record.rstrip('\n')) + value = self.coder.decode(record.rstrip('\n')) + yield WindowedValue(value, -1, [window.GlobalWindow()]) else: with gzip.open(file_name, 'r') as file: for record in file: - yield self.coder.decode(record.rstrip('\n')) + value = self.coder.decode(record.rstrip('\n')) + yield WindowedValue(value, -1, [window.GlobalWindow()]) def expand(self, pcoll): return pcoll | beam.Create([None]) | 'DummyReadForTesting' >> beam.ParDo( diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index 312542ad5d31..d47ef5b19cb3 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -44,6 +44,7 @@ from apache_beam.transforms import window from apache_beam.transforms.display import HasDisplayData from apache_beam.transforms.display import DisplayDataItem +from apache_beam.utils.windowed_value import WindowedValue # Encapsulates information about a bundle of a source generated when method @@ -931,7 +932,8 @@ def process(self, element, init_result): def finish_bundle(self): if self.writer is not None: - yield window.TimestampedValue(self.writer.close(), window.MAX_TIMESTAMP) + yield WindowedValue(self.writer.close(), window.MAX_TIMESTAMP, + [window.GlobalWindow()]) class _WriteKeyedBundleDoFn(core.DoFn): diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 74c61abc34bc..ec1f5dc47c77 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -487,18 +487,9 @@ def finish_bundle_outputs(self, results): if isinstance(result, WindowedValue): windowed_value = result - elif isinstance(result, TimestampedValue): - value = result.value - timestamp = result.timestamp - assign_context = NoContext(value, timestamp) - windowed_value = WindowedValue( - value, timestamp, self.window_fn.assign(assign_context)) else: - value = result - timestamp = -1 - assign_context = NoContext(value) - windowed_value = WindowedValue( - value, timestamp, self.window_fn.assign(assign_context)) + raise RuntimeError('Finish Bundle should only output WindowedValue ' +\ + 'type but got %s' % type(result)) if tag is None: self.main_receivers.receive(windowed_value) diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index e7126610935b..5948460579de 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -30,8 +30,10 @@ from apache_beam.metrics import Metrics from apache_beam.metrics.metric import MetricsFilter from apache_beam.io.iobase import Read -from apache_beam.test_pipeline import TestPipeline +from apache_beam.options.pipeline_options import TypeOptions import apache_beam.pvalue as pvalue +from apache_beam.test_pipeline import TestPipeline +from apache_beam.transforms import window import apache_beam.transforms.combiners as combine from apache_beam.transforms.display import DisplayData, DisplayDataItem from apache_beam.transforms.ptransform import PTransform @@ -40,7 +42,7 @@ from apache_beam.typehints import with_input_types from apache_beam.typehints import with_output_types from apache_beam.typehints.typehints_test import TypeHintTestCase -from apache_beam.options.pipeline_options import TypeOptions +from apache_beam.utils.windowed_value import WindowedValue # Disable frequent lint warning due to pipe operator for chaining transforms. @@ -280,7 +282,7 @@ def process(self, element): pass def finish_bundle(self): - yield 'finish' + yield WindowedValue('finish', -1, [window.GlobalWindow()]) pipeline = TestPipeline() pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])