From 8766b03eb31b4f16de8fbf5a6902378a2c1151e0 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 6 Mar 2017 08:55:13 -0800 Subject: [PATCH] Revert "Implement Single-Output ParDo as a composite" This reverts commit 6253abaac62979e8496a828c18c7d1aa7214be6a. The reverted commit breaks Dataflow DisplayData. The actual fix will include a Dataflow override for single-output ParDos. --- .../translation/ApexPipelineTranslator.java | 3 +- ...or.java => ParDoBoundMultiTranslator.java} | 4 +- .../translation/ParDoBoundTranslator.java | 95 +++++++++++++++ .../FlattenPCollectionTranslatorTest.java | 3 +- ...est.java => ParDoBoundTranslatorTest.java} | 8 +- .../beam/runners/direct/DirectRunner.java | 18 ++- .../ParDoSingleViaMultiOverrideFactory.java | 70 +++++++++++ ...arDoSingleViaMultiOverrideFactoryTest.java | 46 +++++++ .../flink/FlinkBatchTransformTranslators.java | 78 +++++++++++- .../FlinkStreamingTransformTranslators.java | 113 ++++++++++++++++- .../dataflow/DataflowPipelineTranslator.java | 29 +++++ .../DataflowPipelineTranslatorTest.java | 7 +- .../translation/TransformTranslator.java | 100 +++++++-------- .../StreamingTransformTranslator.java | 115 ++++++++---------- .../streaming/TrackStreamingSourcesTest.java | 4 +- .../org/apache/beam/sdk/transforms/ParDo.java | 8 +- 16 files changed, 556 insertions(+), 145 deletions(-) rename runners/apex/src/main/java/org/apache/beam/runners/apex/translation/{ParDoTranslator.java => ParDoBoundMultiTranslator.java} (99%) create mode 100644 runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java rename runners/apex/src/test/java/org/apache/beam/runners/apex/translation/{ParDoTranslatorTest.java => ParDoBoundTranslatorTest.java} (98%) create mode 100644 runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java create mode 100644 runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java index 7eb955123cbe..951a286fb3af 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ApexPipelineTranslator.java @@ -59,7 +59,8 @@ public class ApexPipelineTranslator implements Pipeline.PipelineVisitor { static { // register TransformTranslators - registerTransformTranslator(ParDo.BoundMulti.class, new ParDoTranslator<>()); + registerTransformTranslator(ParDo.Bound.class, new ParDoBoundTranslator()); + registerTransformTranslator(ParDo.BoundMulti.class, new ParDoBoundMultiTranslator<>()); registerTransformTranslator(Read.Unbounded.class, new ReadUnboundedTranslator()); registerTransformTranslator(Read.Bounded.class, new ReadBoundedTranslator()); registerTransformTranslator(GroupByKey.class, new GroupByKeyTranslator()); diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java similarity index 99% rename from runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java rename to runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java index 5ffc3c389a68..f55b48cd7253 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java @@ -46,10 +46,10 @@ /** * {@link ParDo.BoundMulti} is translated to {@link ApexParDoOperator} that wraps the {@link DoFn}. */ -class ParDoTranslator +class ParDoBoundMultiTranslator implements TransformTranslator> { private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(ParDoTranslator.class); + private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundMultiTranslator.class); @Override public void translate(ParDo.BoundMulti transform, TranslationContext context) { diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java new file mode 100644 index 000000000000..5195809bdbbf --- /dev/null +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.apex.translation; + +import java.util.List; +import org.apache.beam.runners.apex.ApexRunner; +import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; + +/** {@link ParDo.Bound} is translated to {link ApexParDoOperator} that wraps the {@link DoFn}. */ +class ParDoBoundTranslator + implements TransformTranslator> { + private static final long serialVersionUID = 1L; + + @Override + public void translate(ParDo.Bound transform, TranslationContext context) { + DoFn doFn = transform.getFn(); + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn)); + } + if (signature.stateDeclarations().size() > 0) { + throw new UnsupportedOperationException( + String.format( + "Found %s annotations on %s, but %s cannot yet be used with state in the %s.", + DoFn.StateId.class.getSimpleName(), + doFn.getClass().getName(), + DoFn.class.getSimpleName(), + ApexRunner.class.getSimpleName())); + } + + if (signature.timerDeclarations().size() > 0) { + throw new UnsupportedOperationException( + String.format( + "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.", + DoFn.TimerId.class.getSimpleName(), + doFn.getClass().getName(), + DoFn.class.getSimpleName(), + ApexRunner.class.getSimpleName())); + } + + PCollection output = (PCollection) context.getOutput(); + PCollection input = (PCollection) context.getInput(); + List> sideInputs = transform.getSideInputs(); + Coder inputCoder = input.getCoder(); + WindowedValueCoder wvInputCoder = + FullWindowedValueCoder.of( + inputCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); + + ApexParDoOperator operator = + new ApexParDoOperator<>( + context.getPipelineOptions(), + doFn, + new TupleTag(), + TupleTagList.empty().getAll() /*sideOutputTags*/, + output.getWindowingStrategy(), + sideInputs, + wvInputCoder, + context.stateInternalsFactory()); + context.addOperator(operator, operator.output); + context.addStream(context.getInput(), operator.input); + if (!sideInputs.isEmpty()) { + ParDoBoundMultiTranslator.addSideInputs(operator, sideInputs, context); + } + } +} diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java index 64ca0ee4fd07..b2e29b6de085 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/FlattenPCollectionTranslatorTest.java @@ -110,8 +110,7 @@ public void testFlattenSingleCollection() { PCollectionList.of(single).apply(Flatten.pCollections()) .apply(ParDo.of(new EmbeddedCollector())); translator.translate(p, dag); - Assert.assertNotNull( - dag.getOperatorMeta("ParDo(EmbeddedCollector)/ParMultiDo(EmbeddedCollector)")); + Assert.assertNotNull(dag.getOperatorMeta("ParDo(EmbeddedCollector)")); } } diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java similarity index 98% rename from runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java rename to runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java index 83e68f7822d4..2aa07208cb09 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslatorTest.java @@ -68,11 +68,11 @@ import org.slf4j.LoggerFactory; /** - * integration test for {@link ParDoTranslator}. + * integration test for {@link ParDoBoundTranslator}. */ @RunWith(JUnit4.class) -public class ParDoTranslatorTest { - private static final Logger LOG = LoggerFactory.getLogger(ParDoTranslatorTest.class); +public class ParDoBoundTranslatorTest { + private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslatorTest.class); private static final long SLEEP_MILLIS = 500; private static final long TIMEOUT_MILLIS = 30000; @@ -98,7 +98,7 @@ public void test() throws Exception { Assert.assertNotNull(om); Assert.assertEquals(om.getOperator().getClass(), ApexReadUnboundedInputOperator.class); - om = dag.getOperatorMeta("ParDo(Add)/ParMultiDo(Add)"); + om = dag.getOperatorMeta("ParDo(Add)"); Assert.assertNotNull(om); Assert.assertEquals(om.getOperator().getClass(), ApexParDoOperator.class); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 4601262ef261..f56d225f1f1b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -89,10 +89,24 @@ public class DirectRunner extends PipelineRunner { .put( PTransformMatchers.classEqualTo(TestStream.class), new DirectTestStreamFactory()) /* primitive */ + /* Single-output ParDos are implemented in terms of Multi-output ParDos. Any override + that is applied to a multi-output ParDo must first have all matching Single-output ParDos + converted to match. + */ + .put(PTransformMatchers.splittableParDoSingle(), new ParDoSingleViaMultiOverrideFactory()) + .put( + PTransformMatchers.stateOrTimerParDoSingle(), + new ParDoSingleViaMultiOverrideFactory()) + // SplittableParMultiDo is implemented in terms of nonsplittable single ParDos + .put(PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory()) + // state and timer pardos are implemented in terms of nonsplittable single ParDos + .put(PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory()) + .put( + PTransformMatchers.classEqualTo(ParDo.Bound.class), + new ParDoSingleViaMultiOverrideFactory()) /* returns a BoundMulti */ .put( PTransformMatchers.classEqualTo(BoundMulti.class), - /* returns one of two primitives; SplittableParDos and ParDos with state and timers - are replaced appropriately by the override factory. */ + /* returns one of two primitives; SplittableParDos are replaced above. */ new ParDoMultiOverrideFactory()) .put( PTransformMatchers.classEqualTo(GBKIntoKeyedWorkItems.class), diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java new file mode 100644 index 000000000000..f8597299217b --- /dev/null +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.direct; + +import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.Bound; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; + +/** + * A {@link PTransformOverrideFactory} that overrides single-output {@link ParDo} to implement + * it in terms of multi-output {@link ParDo}. + */ +class ParDoSingleViaMultiOverrideFactory + extends SingleInputOutputOverrideFactory< + PCollection, PCollection, Bound> { + @Override + public PTransform, PCollection> getReplacementTransform( + Bound transform) { + return new ParDoSingleViaMulti<>(transform); + } + + static class ParDoSingleViaMulti + extends PTransform, PCollection> { + private static final String MAIN_OUTPUT_TAG = "main"; + + private final ParDo.Bound underlyingParDo; + + public ParDoSingleViaMulti(ParDo.Bound underlyingParDo) { + this.underlyingParDo = underlyingParDo; + } + + @Override + public PCollection expand(PCollection input) { + + // Output tags for ParDo need only be unique up to applied transform + TupleTag mainOutputTag = new TupleTag(MAIN_OUTPUT_TAG); + + PCollectionTuple outputs = + input.apply( + ParDo.of(underlyingParDo.getFn()) + .withSideInputs(underlyingParDo.getSideInputs()) + .withOutputTags(mainOutputTag, TupleTagList.empty())); + PCollection output = outputs.get(mainOutputTag); + + output.setTypeDescriptor(underlyingParDo.getFn().getOutputTypeDescriptor()); + return output; + } + } +} diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java new file mode 100644 index 000000000000..59577a82b3b9 --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactoryTest.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.direct; + +import static org.junit.Assert.assertThat; + +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.hamcrest.Matchers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ParDoSingleViaMultiOverrideFactory}. + */ +@RunWith(JUnit4.class) +public class ParDoSingleViaMultiOverrideFactoryTest { + private ParDoSingleViaMultiOverrideFactory factory = + new ParDoSingleViaMultiOverrideFactory<>(); + + @Test + public void getInputSucceeds() { + TestPipeline p = TestPipeline.create(); + PCollection input = p.apply(Create.of(1, 2, 3)); + PCollection reconstructed = factory.getInput(input.expand(), p); + assertThat(reconstructed, Matchers.>equalTo(input)); + } +} diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index 31a6bdace118..f043c901391f 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -112,7 +112,8 @@ class FlinkBatchTransformTranslators { TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch()); - TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoTranslatorBatch()); + TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslatorBatch()); + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslatorBatch()); } @@ -497,7 +498,80 @@ private static void rejectSplittable(DoFn doFn) { } } - private static class ParDoTranslatorBatch + private static class ParDoBoundTranslatorBatch + implements FlinkBatchPipelineTranslator.BatchTransformTranslator< + ParDo.Bound> { + + @Override + @SuppressWarnings("unchecked") + public void translateNode( + ParDo.Bound transform, + + FlinkBatchTranslationContext context) { + DoFn doFn = transform.getFn(); + rejectSplittable(doFn); + + DataSet> inputDataSet = + context.getInputDataSet(context.getInput(transform)); + + TypeInformation> typeInformation = + context.getTypeInfo(context.getOutput(transform)); + + List> sideInputs = transform.getSideInputs(); + + // construct a map from side input to WindowingStrategy so that + // the DoFn runner can map main-input windows to side input windows + Map, WindowingStrategy> sideInputStrategies = new HashMap<>(); + for (PCollectionView sideInput: sideInputs) { + sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); + } + + WindowingStrategy windowingStrategy = + context.getOutput(transform).getWindowingStrategy(); + + SingleInputUdfOperator, WindowedValue, ?> outputDataSet; + DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); + if (signature.stateDeclarations().size() > 0 + || signature.timerDeclarations().size() > 0) { + + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + KvCoder inputCoder = + (KvCoder) context.getInput(transform).getCoder(); + + FlinkStatefulDoFnFunction doFnWrapper = new FlinkStatefulDoFnFunction<>( + (DoFn) doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(), + null, new TupleTag() + ); + + Grouping> grouping = + inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder())); + + outputDataSet = new GroupReduceOperator( + grouping, typeInformation, doFnWrapper, transform.getName()); + + } else { + FlinkDoFnFunction doFnWrapper = + new FlinkDoFnFunction<>( + doFn, + windowingStrategy, + sideInputStrategies, + context.getPipelineOptions(), + null, new TupleTag()); + + outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, + transform.getName()); + + } + + transformSideInputs(sideInputs, outputDataSet, context); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + + } + } + + private static class ParDoBoundMultiTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator< ParDo.BoundMulti> { diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 7227dceddbc1..c7df91dc1c1b 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -121,7 +121,8 @@ class FlinkStreamingTransformTranslators { TRANSLATORS.put(Write.class, new WriteSinkStreamingTranslator()); TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteBoundStreamingTranslator()); - TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoStreamingTranslator()); + TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundStreamingTranslator()); + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiStreamingTranslator()); TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslator()); TRANSLATORS.put(Flatten.PCollections.class, new FlattenPCollectionTranslator()); @@ -319,6 +320,114 @@ private static void rejectSplittable(DoFn doFn) { } } + private static class ParDoBoundStreamingTranslator + extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< + ParDo.Bound> { + + @Override + public void translateNode( + ParDo.Bound transform, + FlinkStreamingTranslationContext context) { + + DoFn doFn = transform.getFn(); + rejectSplittable(doFn); + + WindowingStrategy windowingStrategy = + context.getOutput(transform).getWindowingStrategy(); + + TypeInformation> typeInfo = + context.getTypeInfo(context.getOutput(transform)); + + List> sideInputs = transform.getSideInputs(); + + @SuppressWarnings("unchecked") + PCollection inputPCollection = (PCollection) context.getInput(transform); + + Coder> inputCoder = context.getCoder(inputPCollection); + + DataStream> inputDataStream = + context.getInputDataStream(context.getInput(transform)); + Coder keyCoder = null; + boolean stateful = false; + DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); + if (signature.stateDeclarations().size() > 0 + || signature.timerDeclarations().size() > 0) { + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + keyCoder = ((KvCoder) inputPCollection.getCoder()).getKeyCoder(); + inputDataStream = inputDataStream.keyBy(new KvToByteBufferKeySelector(keyCoder)); + stateful = true; + } + + if (sideInputs.isEmpty()) { + DoFnOperator> doFnOperator = + new DoFnOperator<>( + transform.getFn(), + inputCoder, + new TupleTag("main output"), + Collections.>emptyList(), + new DoFnOperator.DefaultOutputManagerFactory>(), + windowingStrategy, + new HashMap>(), /* side-input mapping */ + Collections.>emptyList(), /* side inputs */ + context.getPipelineOptions(), + keyCoder); + + SingleOutputStreamOperator> outDataStream = inputDataStream + .transform(transform.getName(), typeInfo, doFnOperator); + + context.setOutputDataStream(context.getOutput(transform), outDataStream); + } else { + Tuple2>, DataStream> transformedSideInputs = + transformSideInputs(sideInputs, context); + + DoFnOperator> doFnOperator = + new DoFnOperator<>( + transform.getFn(), + inputCoder, + new TupleTag("main output"), + Collections.>emptyList(), + new DoFnOperator.DefaultOutputManagerFactory>(), + windowingStrategy, + transformedSideInputs.f0, + sideInputs, + context.getPipelineOptions(), + keyCoder); + + SingleOutputStreamOperator> outDataStream; + if (stateful) { + // we have to manually contruct the two-input transform because we're not + // allowed to have only one input keyed, normally. + KeyedStream keyedStream = (KeyedStream) inputDataStream; + TwoInputTransformation< + WindowedValue>, + RawUnionValue, + WindowedValue> rawFlinkTransform = new TwoInputTransformation<>( + keyedStream.getTransformation(), + transformedSideInputs.f1.broadcast().getTransformation(), + transform.getName(), + (TwoInputStreamOperator) doFnOperator, + typeInfo, + keyedStream.getParallelism()); + + rawFlinkTransform.setStateKeyType(keyedStream.getKeyType()); + rawFlinkTransform.setStateKeySelectors(keyedStream.getKeySelector(), null); + + outDataStream = new SingleOutputStreamOperator( + keyedStream.getExecutionEnvironment(), + rawFlinkTransform) {}; // we have to cheat around the ctor being protected + + keyedStream.getExecutionEnvironment().addOperator(rawFlinkTransform); + } else { + outDataStream = inputDataStream + .connect(transformedSideInputs.f1.broadcast()) + .transform(transform.getName(), typeInfo, doFnOperator); + } + context.setOutputDataStream(context.getOutput(transform), outDataStream); + } + } + } + /** * Wraps each element in a {@link RawUnionValue} with the given tag id. */ @@ -396,7 +505,7 @@ public RawUnionValue map(T o) throws Exception { } - private static class ParDoStreamingTranslator + private static class ParDoBoundMultiStreamingTranslator extends FlinkStreamingPipelineTranslator.StreamTransformTranslator< ParDo.BoundMulti> { diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index ab4cb9c67b04..06e50483440d 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -45,6 +45,7 @@ import com.google.common.base.Supplier; import com.google.common.collect.BiMap; import com.google.common.collect.ImmutableBiMap; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import java.io.IOException; import java.util.ArrayList; @@ -846,6 +847,34 @@ private void translateMultiHelper( } }); + registerTransformTranslator( + ParDo.Bound.class, + new TransformTranslator() { + @Override + public void translate(ParDo.Bound transform, TranslationContext context) { + translateSingleHelper(transform, context); + } + + private void translateSingleHelper( + ParDo.Bound transform, TranslationContext context) { + + StepTranslationContext stepContext = context.addStep(transform, "ParallelDo"); + translateInputs( + stepContext, context.getInput(transform), transform.getSideInputs(), context); + long mainOutput = stepContext.addOutput(context.getOutput(transform)); + translateFn( + stepContext, + transform.getFn(), + context.getInput(transform).getWindowingStrategy(), + transform.getSideInputs(), + context.getInput(transform).getCoder(), + context, + mainOutput, + ImmutableMap.>of( + mainOutput, new TupleTag<>(PropertyNames.OUTPUT))); + } + }); + registerTransformTranslator( Window.Assign.class, new TransformTranslator() { diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java index ccb185cbea93..d4271e52f7c7 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java @@ -525,8 +525,7 @@ private static Step createPredefinedStep() throws Exception { assertEquals(13, job.getSteps().size()); Step step = job.getSteps().get(1); - assertEquals( - stepName + "/ParMultiDo(NoOp)", getString(step.getProperties(), PropertyNames.USER_NAME)); + assertEquals(stepName, getString(step.getProperties(), PropertyNames.USER_NAME)); assertAllStepOutputsHaveUniqueIds(job); return step; } @@ -972,7 +971,7 @@ public void populateDisplayData(DisplayData.Builder builder) { .put("type", "JAVA_CLASS") .put("value", fn1.getClass().getName()) .put("shortValue", fn1.getClass().getSimpleName()) - .put("namespace", ParDo.BoundMulti.class.getName()) + .put("namespace", parDo1.getClass().getName()) .build(), ImmutableMap.builder() .put("key", "foo2") @@ -992,7 +991,7 @@ public void populateDisplayData(DisplayData.Builder builder) { .put("type", "JAVA_CLASS") .put("value", fn2.getClass().getName()) .put("shortValue", fn2.getClass().getSimpleName()) - .put("namespace", ParDo.BoundMulti.class.getName()) + .put("namespace", parDo2.getClass().getName()) .build(), ImmutableMap.builder() .put("key", "foo3") diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 0ae731328913..725d157b4659 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -27,7 +27,6 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable; import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; -import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; @@ -75,7 +74,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.hadoop.conf.Configuration; @@ -326,19 +324,38 @@ public Iterable> call( }; } - private static TransformEvaluator> parDo() { - return new TransformEvaluator>() { + private static TransformEvaluator> parDo() { + return new TransformEvaluator>() { @Override - public void evaluate(ParDo.BoundMulti transform, EvaluationContext context) { - if (transform.getSideOutputTags().size() == 0) { - evaluateSingle(transform, context); - } else { - evaluateMulti(transform, context); - } + public void evaluate(ParDo.Bound transform, EvaluationContext context) { + String stepName = context.getCurrentTransform().getFullName(); + DoFn doFn = transform.getFn(); + rejectSplittable(doFn); + rejectStateAndTimers(doFn); + @SuppressWarnings("unchecked") + JavaRDD> inRDD = + ((BoundedDataset) context.borrowDataset(transform)).getRDD(); + WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + JavaSparkContext jsc = context.getSparkContext(); + Accumulator aggAccum = + SparkAggregators.getNamedAggregators(jsc); + Accumulator metricsAccum = + MetricsAccumulator.getInstance(); + Map, KV, SideInputBroadcast>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); + context.putDataset(transform, + new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(aggAccum, metricsAccum, + stepName, doFn, context.getRuntimeContext(), sideInputs, windowingStrategy)))); } + }; + } - private void evaluateMulti( - ParDo.BoundMulti transform, EvaluationContext context) { + private static TransformEvaluator> + multiDo() { + return new TransformEvaluator>() { + @Override + public void evaluate(ParDo.BoundMulti transform, EvaluationContext context) { String stepName = context.getCurrentTransform().getFullName(); DoFn doFn = transform.getFn(); rejectSplittable(doFn); @@ -349,21 +366,16 @@ private void evaluateMulti( WindowingStrategy windowingStrategy = context.getInput(transform).getWindowingStrategy(); JavaSparkContext jsc = context.getSparkContext(); - Accumulator aggAccum = SparkAggregators.getNamedAggregators(jsc); - Accumulator metricsAccum = MetricsAccumulator.getInstance(); - JavaPairRDD, WindowedValue> all = - inRDD - .mapPartitionsToPair( - new MultiDoFnFunction<>( - aggAccum, - metricsAccum, - stepName, - doFn, - context.getRuntimeContext(), - transform.getMainOutputTag(), - TranslationUtils.getSideInputs(transform.getSideInputs(), context), - windowingStrategy)) - .cache(); + Accumulator aggAccum = + SparkAggregators.getNamedAggregators(jsc); + Accumulator metricsAccum = + MetricsAccumulator.getInstance(); + JavaPairRDD, WindowedValue> all = inRDD + .mapPartitionsToPair( + new MultiDoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, + context.getRuntimeContext(), transform.getMainOutputTag(), + TranslationUtils.getSideInputs(transform.getSideInputs(), context), + windowingStrategy)).cache(); List pct = context.getOutputs(transform); for (TaggedPValue e : pct) { @SuppressWarnings("unchecked") @@ -376,37 +388,6 @@ private void evaluateMulti( context.putDataset(e.getValue(), new BoundedDataset<>(values)); } } - - private void evaluateSingle( - ParDo.BoundMulti transform, EvaluationContext context) { - String stepName = context.getCurrentTransform().getFullName(); - DoFn doFn = transform.getFn(); - rejectSplittable(doFn); - rejectStateAndTimers(doFn); - @SuppressWarnings("unchecked") - JavaRDD> inRDD = - ((BoundedDataset) context.borrowDataset(transform)).getRDD(); - WindowingStrategy windowingStrategy = - context.getInput(transform).getWindowingStrategy(); - JavaSparkContext jsc = context.getSparkContext(); - Accumulator aggAccum = SparkAggregators.getNamedAggregators(jsc); - Accumulator metricsAccum = MetricsAccumulator.getInstance(); - Map, KV, SideInputBroadcast>> sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), context); - PValue onlyOutput = Iterables.getOnlyElement(context.getOutputs(transform)).getValue(); - context.putDataset( - onlyOutput, - new BoundedDataset<>( - inRDD.mapPartitions( - new DoFnFunction<>( - aggAccum, - metricsAccum, - stepName, - doFn, - context.getRuntimeContext(), - sideInputs, - windowingStrategy)))); - } }; } @@ -761,7 +742,8 @@ private static TransformEvaluator transform, }; } + private static TransformEvaluator> parDo() { + return new TransformEvaluator>() { + @Override + public void evaluate(final ParDo.Bound transform, + final EvaluationContext context) { + final DoFn doFn = transform.getFn(); + rejectSplittable(doFn); + rejectStateAndTimers(doFn); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); + final WindowingStrategy windowingStrategy = + context.getInput(transform).getWindowingStrategy(); + final SparkPCollectionView pviews = context.getPViews(); + + @SuppressWarnings("unchecked") + UnboundedDataset unboundedDataset = + ((UnboundedDataset) context.borrowDataset(transform)); + JavaDStream> dStream = unboundedDataset.getDStream(); + + final String stepName = context.getCurrentTransform().getFullName(); + + JavaDStream> outStream = + dStream.transform(new Function>, + JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) throws + Exception { + final JavaSparkContext jsc = new JavaSparkContext(rdd.context()); + final Accumulator aggAccum = + SparkAggregators.getNamedAggregators(jsc); + final Accumulator metricsAccum = + MetricsAccumulator.getInstance(); + final Map, KV, SideInputBroadcast>> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), + jsc, pviews); + return rdd.mapPartitions( + new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext, + sideInputs, windowingStrategy)); + } + }); + + context.putDataset(transform, + new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); + } + }; + } + private static TransformEvaluator> multiDo() { return new TransformEvaluator>() { - public void evaluate( - final ParDo.BoundMulti transform, final EvaluationContext context) { - if (transform.getSideOutputTags().size() == 0) { - evaluateSingle(transform, context); - } else { - evaluateMulti(transform, context); - } - } - - private void evaluateMulti( - final ParDo.BoundMulti transform, final EvaluationContext context) { + @Override + public void evaluate(final ParDo.BoundMulti transform, + final EvaluationContext context) { final DoFn doFn = transform.getFn(); rejectSplittable(doFn); rejectStateAndTimers(doFn); @@ -389,60 +427,10 @@ public JavaPairRDD, WindowedValue> call( JavaDStream> values = (JavaDStream>) (JavaDStream) TranslationUtils.dStreamValues(filtered); - context.putDataset( - e.getValue(), new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); + context.putDataset(e.getValue(), + new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); } } - - private void evaluateSingle( - final ParDo.BoundMulti transform, final EvaluationContext context) { - final DoFn doFn = transform.getFn(); - rejectSplittable(doFn); - rejectStateAndTimers(doFn); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final WindowingStrategy windowingStrategy = - context.getInput(transform).getWindowingStrategy(); - final SparkPCollectionView pviews = context.getPViews(); - - @SuppressWarnings("unchecked") - UnboundedDataset unboundedDataset = - ((UnboundedDataset) context.borrowDataset(transform)); - JavaDStream> dStream = unboundedDataset.getDStream(); - - final String stepName = context.getCurrentTransform().getFullName(); - - JavaDStream> outStream = - dStream.transform( - new Function>, JavaRDD>>() { - @Override - public JavaRDD> call(JavaRDD> rdd) - throws Exception { - final JavaSparkContext jsc = new JavaSparkContext(rdd.context()); - final Accumulator aggAccum = - SparkAggregators.getNamedAggregators(jsc); - final Accumulator metricsAccum = - MetricsAccumulator.getInstance(); - final Map, KV, SideInputBroadcast>> - sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), jsc, pviews); - return rdd.mapPartitions( - new DoFnFunction<>( - aggAccum, - metricsAccum, - stepName, - doFn, - runtimeContext, - sideInputs, - windowingStrategy)); - } - }); - - PCollection output = - (PCollection) - Iterables.getOnlyElement(context.getOutputs(transform)).getValue(); - context.putDataset( - output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); - } }; } @@ -487,6 +475,7 @@ public JavaRDD>> call( EVALUATORS.put(Read.Unbounded.class, readUnbounded()); EVALUATORS.put(GroupByKey.class, groupByKey()); EVALUATORS.put(Combine.GroupedValues.class, combineGrouped()); + EVALUATORS.put(ParDo.Bound.class, parDo()); EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); EVALUATORS.put(ConsoleIO.Write.Unbound.class, print()); EVALUATORS.put(CreateStream.class, createFromQueue()); diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java index d66633b4c49b..b181a042820c 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java @@ -83,7 +83,7 @@ public void testTrackSingle() { p.apply(emptyStream).apply(ParDo.of(new PassthroughFn<>())); - p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0)); + p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0)); assertThat(StreamingSourceTracker.numAssertions, equalTo(1)); } @@ -111,7 +111,7 @@ public void testTrackFlattened() { PCollectionList.of(pcol1).and(pcol2).apply(Flatten.pCollections()); flattened.apply(ParDo.of(new PassthroughFn<>())); - p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.BoundMulti.class, 0, 1)); + p.traverseTopologically(new StreamingSourceTracker(jssc, p, ParDo.Bound.class, 0, 1)); assertThat(StreamingSourceTracker.numAssertions, equalTo(1)); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 92252310f961..19c5a2d5b511 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -738,8 +738,12 @@ public BoundMulti withOutputTags( @Override public PCollection expand(PCollection input) { - TupleTag mainOutput = new TupleTag<>(); - return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput); + validateWindowType(input, fn); + return PCollection.createPrimitiveOutputInternal( + input.getPipeline(), + input.getWindowingStrategy(), + input.isBounded()) + .setTypeDescriptor(getFn().getOutputTypeDescriptor()); } @Override