diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 8506b8563dc0..fdb9a9dd865e 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -546,4 +546,5 @@ def from_runner_api(proto, context): if pc not in result.inputs: pc.producer = result pc.tag = tag + result.update_input_refcounts() return result diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 37ff2a8420b8..5889ab543bad 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -435,6 +435,12 @@ def test_flatten_no_pcollections(self): assert_that(result, equal_to([])) pipeline.run() + def test_flatten_same_pcollections(self): + pipeline = TestPipeline() + pc = pipeline | beam.Create(['a', 'b']) + assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3)) + pipeline.run() + def test_flatten_pcollections_in_iterable(self): pipeline = TestPipeline() pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])