This closes #2746
dhalperi committed Apr 27, 2017
Merge commit a46eb1a (2 parents: c493695 + 1925a50)
Showing 2 changed files with 97 additions and 49 deletions.
@@ -385,55 +385,53 @@ public void evaluate(
 JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();
 
 final String stepName = context.getCurrentTransform().getFullName();
-if (transform.getAdditionalOutputTags().size() == 0) {
-  JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
-      dStream.transformToPair(
-          new Function<
-              JavaRDD<WindowedValue<InputT>>,
-              JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
-            @Override
-            public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
-                JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
-              final Accumulator<NamedAggregators> aggAccum =
-                  AggregatorsAccumulator.getInstance();
-              final Accumulator<SparkMetricsContainer> metricsAccum =
-                  MetricsAccumulator.getInstance();
-              final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
-                  sideInputs =
-                      TranslationUtils.getSideInputs(
-                          transform.getSideInputs(),
-                          JavaSparkContext.fromSparkContext(rdd.context()),
-                          pviews);
-              return rdd.mapPartitionsToPair(
-                  new MultiDoFnFunction<>(
-                      aggAccum,
-                      metricsAccum,
-                      stepName,
-                      doFn,
-                      runtimeContext,
-                      transform.getMainOutputTag(),
-                      sideInputs,
-                      windowingStrategy));
-            }
-          });
-  Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
-  if (outputs.size() > 1) {
-    // cache the DStream if we're going to filter it more than once.
-    all.cache();
-  }
-  for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
-    @SuppressWarnings("unchecked")
-    JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
-        all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
-    @SuppressWarnings("unchecked")
-    // Object is the best we can do since different outputs can have different tags
-    JavaDStream<WindowedValue<Object>> values =
-        (JavaDStream<WindowedValue<Object>>)
-            (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
-    context.putDataset(
-        output.getValue(),
-        new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
-  }
-}
+JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+    dStream.transformToPair(
+        new Function<
+            JavaRDD<WindowedValue<InputT>>,
+            JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
+          @Override
+          public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+              JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
+            final Accumulator<NamedAggregators> aggAccum =
+                AggregatorsAccumulator.getInstance();
+            final Accumulator<SparkMetricsContainer> metricsAccum =
+                MetricsAccumulator.getInstance();
+            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
+                sideInputs =
+                    TranslationUtils.getSideInputs(
+                        transform.getSideInputs(),
+                        JavaSparkContext.fromSparkContext(rdd.context()),
+                        pviews);
+            return rdd.mapPartitionsToPair(
+                new MultiDoFnFunction<>(
+                    aggAccum,
+                    metricsAccum,
+                    stepName,
+                    doFn,
+                    runtimeContext,
+                    transform.getMainOutputTag(),
+                    sideInputs,
+                    windowingStrategy));
+          }
+        });
+Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
+if (outputs.size() > 1) {
+  // cache the DStream if we're going to filter it more than once.
+  all.cache();
+}
+for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
+  @SuppressWarnings("unchecked")
+  JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
+      all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
+  @SuppressWarnings("unchecked")
+  // Object is the best we can do since different outputs can have different tags
+  JavaDStream<WindowedValue<Object>> values =
+      (JavaDStream<WindowedValue<Object>>)
+          (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
+  context.putDataset(
+      output.getValue(),
+      new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
+}
 }
 
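The net effect of the hunk above is that the Spark streaming translator always takes the multi-output code path, so a ParDo with additional output tags translates in streaming mode instead of being special-cased away. As a rough illustration of the user-facing pattern this enables, here is a minimal multi-output ParDo sketch; the class and tag names (MultiOutputSketch, mainTag, plusOneTag) are hypothetical, and the test added further down exercises the same shape:

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

class MultiOutputSketch {
  // Anonymous subclasses capture the element type; the tag names are illustrative only.
  static final TupleTag<Integer> mainTag = new TupleTag<Integer>() {};
  static final TupleTag<Integer> plusOneTag = new TupleTag<Integer>() {};

  static PCollectionTuple expand(PCollection<Integer> input) {
    return input.apply(
        ParDo.of(
                new DoFn<Integer, Integer>() {
                  @ProcessElement
                  public void process(ProcessContext c) {
                    c.output(c.element());                  // main output
                    c.output(plusOneTag, c.element() + 1);  // additional output
                  }
                })
            .withOutputTags(mainTag, TupleTagList.of(plusOneTag)));
  }
}
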
@@ -33,8 +33,10 @@
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
@@ -51,7 +53,10 @@
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.TimestampedValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -344,6 +349,51 @@ public void testFlattenedWithWatermarkHold() throws IOException {
     p.run();
   }
 
+  /**
+   * Tests a multi-output {@link ParDo} in streaming pipelines.
+   * This test is currently needed for https://issues.apache.org/jira/browse/BEAM-2029, since
+   * {@link org.apache.beam.sdk.testing.ValidatesRunner} tests do not yet run for the Spark
+   * runner in streaming mode.
+   */
+  @Test
+  public void testMultiOutputParDo() throws IOException {
+    Pipeline p = pipelineRule.createPipeline();
+    Instant instant = new Instant(0);
+    CreateStream<Integer> source1 =
+        CreateStream.of(VarIntCoder.of(), pipelineRule.batchDuration())
+            .emptyBatch()
+            .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5)))
+            .nextBatch(
+                TimestampedValue.of(1, instant),
+                TimestampedValue.of(2, instant),
+                TimestampedValue.of(3, instant))
+            .advanceNextBatchWatermarkToInfinity();
+
+    PCollection<Integer> inputs = p.apply(source1);
+
+    final TupleTag<Integer> mainTag = new TupleTag<>();
+    final TupleTag<Integer> additionalTag = new TupleTag<>();
+
+    PCollectionTuple outputs = inputs.apply(ParDo.of(new DoFn<Integer, Integer>() {
+
+      @SuppressWarnings("unused")
+      @ProcessElement
+      public void process(ProcessContext context) {
+        Integer element = context.element();
+        context.output(element);
+        context.output(additionalTag, element + 1);
+      }
+    }).withOutputTags(mainTag, TupleTagList.of(additionalTag)));
+
+    PCollection<Integer> output1 = outputs.get(mainTag).setCoder(VarIntCoder.of());
+    PCollection<Integer> output2 = outputs.get(additionalTag).setCoder(VarIntCoder.of());
+
+    PAssert.that(output1).containsInAnyOrder(1, 2, 3);
+    PAssert.that(output2).containsInAnyOrder(2, 3, 4);
+
+    p.run();
+  }
+
   @Test
   public void testElementAtPositiveInfinityThrows() {
     CreateStream<Integer> source =
