diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index c5b3318ae2dd7..dabab00acdb9a 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -48,14 +48,17 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms import WindowInto from apache_beam.transforms import window from apache_beam.transforms.core import _GroupByKeyOnly from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.ptransform import PTransform +from apache_beam.transforms.window import TimestampedValue from apache_beam.typehints import with_input_types from apache_beam.typehints import with_output_types from apache_beam.typehints.typehints_test import TypeHintTestCase +from apache_beam.utils.timestamp import Timestamp from apache_beam.utils.windowed_value import WindowedValue # Disable frequent lint warning due to pipe operator for chaining transforms. @@ -310,6 +313,7 @@ def incorrect_par_do_fn(x): expected_error_prefix = 'FlatMap and ParDo must return an iterable.' self.assertStartswith(cm.exception.args[0], expected_error_prefix) + @attr('ValidatesRunner') def test_do_fn_with_finish(self): class MyDoFn(beam.DoFn): def process(self, element): @@ -332,6 +336,33 @@ def match(actual): assert_that(result, matcher()) pipeline.run() + @attr('ValidatesRunner') + def test_do_fn_with_windowing_in_finish_bundle(self): + windowfn = window.FixedWindows(2) + + class MyDoFn(beam.DoFn): + def process(self, element): + yield TimestampedValue('process'+ str(element), 5) + + def finish_bundle(self): + yield WindowedValue('finish', 1, [windowfn]) + + pipeline = TestPipeline() + result = (pipeline + | 'Start' >> beam.Create([x for x in range(3)]) + | beam.ParDo(MyDoFn()) + | WindowInto(windowfn) + | 'create tuple' >> beam.Map( + lambda v, t=beam.DoFn.TimestampParam, w=beam.DoFn.WindowParam: + (v, t, w.start, w.end))) + expected_process = [('process'+ str(x), Timestamp(5), Timestamp(4), + Timestamp(6)) for x in range(3)] + expected_finish = [('finish', Timestamp(1), Timestamp(0), Timestamp(2))] + + assert_that(result, equal_to(expected_process + expected_finish)) + pipeline.run() + + @attr('ValidatesRunner') def test_do_fn_with_start(self): class MyDoFn(beam.DoFn): def __init__(self):