diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 000eada815022..26f0ade7ec9bc 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -385,55 +385,53 @@ public void evaluate(
         JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
 
         final String stepName = context.getCurrentTransform().getFullName();
 
-        if (transform.getAdditionalOutputTags().size() == 0) {
-          JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
-              dStream.transformToPair(
-                  new Function<
-                      JavaRDD<WindowedValue<InputT>>,
-                      JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
-                    @Override
-                    public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
-                        JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
-                      final Accumulator<NamedAggregators> aggAccum =
-                          AggregatorsAccumulator.getInstance();
-                      final Accumulator<SparkMetricsContainer> metricsAccum =
-                          MetricsAccumulator.getInstance();
-                      final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
-                          sideInputs =
-                              TranslationUtils.getSideInputs(
-                                  transform.getSideInputs(),
-                                  JavaSparkContext.fromSparkContext(rdd.context()),
-                                  pviews);
-                      return rdd.mapPartitionsToPair(
-                          new MultiDoFnFunction<>(
-                              aggAccum,
-                              metricsAccum,
-                              stepName,
-                              doFn,
-                              runtimeContext,
-                              transform.getMainOutputTag(),
-                              sideInputs,
-                              windowingStrategy));
-                    }
-                  });
-          Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
-          if (outputs.size() > 1) {
-            // cache the DStream if we're going to filter it more than once.
-            all.cache();
-          }
-          for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
-            @SuppressWarnings("unchecked")
-            JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
-                all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
-            @SuppressWarnings("unchecked")
-            // Object is the best we can do since different outputs can have different tags
-            JavaDStream<WindowedValue<Object>> values =
-                (JavaDStream<WindowedValue<Object>>)
-                    (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
-            context.putDataset(
-                output.getValue(),
-                new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
-          }
+        JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+            dStream.transformToPair(
+                new Function<
+                    JavaRDD<WindowedValue<InputT>>,
+                    JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
+                  @Override
+                  public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+                      JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
+                    final Accumulator<NamedAggregators> aggAccum =
+                        AggregatorsAccumulator.getInstance();
+                    final Accumulator<SparkMetricsContainer> metricsAccum =
+                        MetricsAccumulator.getInstance();
+                    final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+                        sideInputs =
+                            TranslationUtils.getSideInputs(
+                                transform.getSideInputs(),
+                                JavaSparkContext.fromSparkContext(rdd.context()),
+                                pviews);
+                    return rdd.mapPartitionsToPair(
+                        new MultiDoFnFunction<>(
+                            aggAccum,
+                            metricsAccum,
+                            stepName,
+                            doFn,
+                            runtimeContext,
+                            transform.getMainOutputTag(),
+                            sideInputs,
+                            windowingStrategy));
+                  }
+                });
+        Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
+        if (outputs.size() > 1) {
+          // cache the DStream if we're going to filter it more than once.
+          all.cache();
+        }
+        for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
+          @SuppressWarnings("unchecked")
+          JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
+              all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
+          @SuppressWarnings("unchecked")
+          // Object is the best we can do since different outputs can have different tags
+          JavaDStream<WindowedValue<Object>> values =
+              (JavaDStream<WindowedValue<Object>>)
+                  (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
+          context.putDataset(
+              output.getValue(),
+              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
         }
       }
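The change above removes the `transform.getAdditionalOutputTags().size() == 0` guard; under the old code a ParDo with additional outputs was never translated at all, which is what surfaced as the NullPointerException in BEAM-2029. Every streaming ParDo now takes the same tagged path: a single pass over each RDD emits (TupleTag, WindowedValue) pairs for all outputs, the pair stream is cached when more than one output will be read, and each logical output is recovered by filtering on its tag. For orientation, a tag filter of the kind `TranslationUtils.TupleTagFilter` supplies can be sketched as below; this is a simplified stand-in (name and exact signature are illustrative, not the runner's actual class):

    import org.apache.beam.sdk.util.WindowedValue;
    import org.apache.beam.sdk.values.TupleTag;
    import org.apache.spark.api.java.function.Function;
    import scala.Tuple2;

    // Keeps only the pairs whose key equals the requested output tag; used
    // with JavaPairDStream.filter(...) to split the single tagged stream
    // into one stream per declared output.
    final class TupleTagFilterSketch<V>
        implements Function<Tuple2<TupleTag<V>, WindowedValue<?>>, Boolean> {
      private final TupleTag<V> tag;

      TupleTagFilterSketch(TupleTag<V> tag) {
        this.tag = tag;
      }

      @Override
      public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) {
        return tag.equals(input._1());
      }
    }

Caching before the per-tag filters matters because each filter re-reads the same parent stream; without `all.cache()` Spark would recompute the DoFn pass once per output.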
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
index 78b8039231e95..dd52c05985eb4 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
@@ -33,8 +33,10 @@
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
@@ -51,7 +53,10 @@
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.TimestampedValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
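The new imports (DoFn, ParDo, PCollectionTuple, TupleTag, TupleTagList) all serve the test added in the next hunk. One remark on the idiom used there: tags created with `new TupleTag<>()` lose their element type to erasure, which is presumably why the test calls `setCoder(VarIntCoder.of())` on both outputs. A sketch of the alternative (a stylistic variant, not part of this patch; the class name is illustrative):

    import org.apache.beam.sdk.values.TupleTag;

    class TagDeclarations {
      // The trailing {} creates anonymous subclasses, so the Integer type
      // argument survives erasure and Beam can infer a coder for each
      // output without explicit setCoder calls.
      final TupleTag<Integer> mainTag = new TupleTag<Integer>() {};
      final TupleTag<Integer> additionalTag = new TupleTag<Integer>() {};
    }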
@@ -344,6 +349,51 @@ public void testFlattenedWithWatermarkHold() throws IOException {
     p.run();
   }
 
+  /**
+   * Test multiple output {@link ParDo} in streaming pipelines.
+   * This is needed as a test for https://issues.apache.org/jira/browse/BEAM-2029 since
+   * {@link org.apache.beam.sdk.testing.ValidatesRunner} tests do not currently run for the Spark
+   * runner in streaming mode.
+   */
+  @Test
+  public void testMultiOutputParDo() throws IOException {
+    Pipeline p = pipelineRule.createPipeline();
+    Instant instant = new Instant(0);
+    CreateStream<Integer> source1 =
+        CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration())
+            .emptyBatch()
+            .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5)))
+            .nextBatch(
+                TimestampedValue.of(1, instant),
+                TimestampedValue.of(2, instant),
+                TimestampedValue.of(3, instant))
+            .advanceNextBatchWatermarkToInfinity();
+
+    PCollection<Integer> inputs = p.apply(source1);
+
+    final TupleTag<Integer> mainTag = new TupleTag<>();
+    final TupleTag<Integer> additionalTag = new TupleTag<>();
+
+    PCollectionTuple outputs = inputs.apply(ParDo.of(new DoFn<Integer, Integer>() {
+
+      @SuppressWarnings("unused")
+      @ProcessElement
+      public void process(ProcessContext context) {
+        Integer element = context.element();
+        context.output(element);
+        context.output(additionalTag, element + 1);
+      }
+    }).withOutputTags(mainTag, TupleTagList.of(additionalTag)));
+
+    PCollection<Integer> output1 = outputs.get(mainTag).setCoder(VarIntCoder.of());
+    PCollection<Integer> output2 = outputs.get(additionalTag).setCoder(VarIntCoder.of());
+
+    PAssert.that(output1).containsInAnyOrder(1, 2, 3);
+    PAssert.that(output2).containsInAnyOrder(2, 3, 4);
+
+    p.run();
+  }
+
   @Test
   public void testElementAtPositiveInfinityThrows() {
     CreateStream<Integer> source =
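To make the test's expected values concrete, here is a self-contained plain-Java model of the tag-then-filter fan-out the fixed translator performs (no Spark or Beam involved; class and tag names are illustrative). It mirrors the test's DoFn, so the main output sees 1, 2, 3 and the additional output sees 2, 3, 4:

    import java.util.AbstractMap.SimpleEntry;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;

    public class TagFanOutModel {
      public static void main(String[] args) {
        List<Integer> input = Arrays.asList(1, 2, 3);
        String mainTag = "main";
        String additionalTag = "additional";

        // One pass over the input emits a tagged pair per output, the way
        // MultiDoFnFunction emits (TupleTag, WindowedValue) pairs.
        List<Map.Entry<String, Integer>> all = new ArrayList<>();
        for (int element : input) {
          all.add(new SimpleEntry<>(mainTag, element));           // context.output(element)
          all.add(new SimpleEntry<>(additionalTag, element + 1)); // context.output(additionalTag, element + 1)
        }

        // Each logical output is a filtered view of the tagged stream.
        for (String tag : Arrays.asList(mainTag, additionalTag)) {
          List<Integer> values = new ArrayList<>();
          for (Map.Entry<String, Integer> kv : all) {
            if (kv.getKey().equals(tag)) {
              values.add(kv.getValue());
            }
          }
          System.out.println(tag + " -> " + values);
        }
        // Prints main -> [1, 2, 3] and additional -> [2, 3, 4], matching
        // the PAssert expectations in testMultiOutputParDo.
      }
    }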