From 975dec257364d68b5ada3bced7f139e88853722a Mon Sep 17 00:00:00 2001
From: Sela
Date: Sun, 18 Dec 2016 14:36:53 +0200
Subject: [PATCH 1/2] SparkUnboundedSource mapWithStateDStream input data
 should be in serialized form for shuffle and checkpointing. Emit read count
 and watermark per microbatch.

---
 .../runners/spark/io/MicrobatchSource.java         |   4 +-
 .../spark/io/SparkUnboundedSource.java             | 120 +++++++++++++-----
 .../spark/stateful/StateSpecFunctions.java         |  37 ++--
 3 files changed, 115 insertions(+), 46 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java
index 565637597073..d5ed8dddce4a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java
@@ -263,8 +263,8 @@ public CheckpointMarkT getCheckpointMark() {
       return (CheckpointMarkT) reader.getCheckpointMark();
     }
 
-    public long getNumRecordsRead() {
-      return recordsRead;
+    public Instant getWatermark() {
+      return reader.getWatermark();
     }
   }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index 394b02373436..9a4550419803 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -18,16 +18,20 @@
 package org.apache.beam.runners.spark.io;
 
+import java.io.Serializable;
 import java.util.Collections;
-import java.util.Iterator;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.stateful.StateSpecFunctions;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
-import org.apache.beam.runners.spark.translation.TranslationUtils;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext$;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.StateSpec;
@@ -39,6 +43,10 @@
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.apache.spark.streaming.dstream.DStream;
 import org.apache.spark.streaming.scheduler.StreamInputInfo;
+import org.joda.time.Instant;
+
+import scala.Tuple2;
+import scala.runtime.BoxedUnit;
 
 
 /**
@@ -75,20 +83,39 @@ JavaDStream<WindowedValue<T>> read(JavaStreamingContext jssc,
 
     // call mapWithState to read from a checkpointable sources.
     JavaMapWithStateDStream<Source<T>, CheckpointMarkT, byte[],
-        Iterator<WindowedValue<T>>> mapWithStateDStream = inputDStream.mapWithState(
+        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream = inputDStream.mapWithState(
         StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));
 
     // set checkpoint duration for read stream, if set.
     checkpointStream(mapWithStateDStream, options);
-    // flatmap and report read elements. Use the inputDStream's id to tie between the reported
-    // info and the inputDStream it originated from.
-    int id = inputDStream.inputDStream().id();
-    ReportingFlatMappedDStream<WindowedValue<T>> reportingFlatMappedDStream =
-        new ReportingFlatMappedDStream<>(mapWithStateDStream.dstream(), id,
-            getSourceName(source, id));
+    // cache since checkpointing is less frequent.
+    mapWithStateDStream.cache();
 
-    return JavaDStream.fromDStream(reportingFlatMappedDStream,
-        JavaSparkContext$.MODULE$.<WindowedValue<T>>fakeClassTag());
+    // report the number of input elements for this InputDStream to the InputInfoTracker.
+    int id = inputDStream.inputDStream().id();
+    JavaDStream<Metadata> metadataDStream = mapWithStateDStream.map(
+        new Function<Tuple2<Iterable<byte[]>, Metadata>, Metadata>() {
+          @Override
+          public Metadata call(Tuple2<Iterable<byte[]>, Metadata> t2) throws Exception {
+            return t2._2();
+          }
+        });
+
+    // register the ReportingDStream op.
+    new ReportingDStream(metadataDStream.dstream(), id, getSourceName(source, id)).register();
+
+    // output the actual (deserialized) stream.
+    WindowedValue.FullWindowedValueCoder<T> coder =
+        WindowedValue.FullWindowedValueCoder.of(
+            source.getDefaultOutputCoder(),
+            GlobalWindow.Coder.INSTANCE);
+    return mapWithStateDStream.flatMap(
+        new FlatMapFunction<Tuple2<Iterable<byte[]>, Metadata>, byte[]>() {
+          @Override
+          public Iterable<byte[]> call(Tuple2<Iterable<byte[]>, Metadata> t2) throws Exception {
+            return t2._1();
+          }
+        }).map(CoderHelpers.fromByteFunction(coder));
   }
 
   private static String getSourceName(Source source, int id) {
@@ -111,20 +138,20 @@ private static void checkpointStream(JavaDStream dStream,
   }
 
   /**
-   * A flatMap DStream function that "flattens" the Iterators read by the
-   * {@link MicrobatchSource.Reader}s, while reporting the properties of the read to the
-   * {@link org.apache.spark.streaming.scheduler.InputInfoTracker} for RateControl purposes
-   * and visibility.
+   * A DStream function that reports the properties of the read to the
+   * {@link org.apache.spark.streaming.scheduler.InputInfoTracker}
+   * for RateControl purposes and visibility.
    */
-  private static class ReportingFlatMappedDStream<T> extends DStream<T> {
-    private final DStream<Iterator<T>> parent;
+  private static class ReportingDStream extends DStream<BoxedUnit> {
+    private final DStream<Metadata> parent;
     private final int inputDStreamId;
     private final String sourceName;
 
-    ReportingFlatMappedDStream(DStream<Iterator<T>> parent,
-                               int inputDStreamId,
-                               String sourceName) {
-      super(parent.ssc(), JavaSparkContext$.MODULE$.<T>fakeClassTag());
+    ReportingDStream(
+        DStream<Metadata> parent,
+        int inputDStreamId,
+        String sourceName) {
+      super(parent.ssc(), JavaSparkContext$.MODULE$.<BoxedUnit>fakeClassTag());
       this.parent = parent;
       this.inputDStreamId = inputDStreamId;
       this.sourceName = sourceName;
     }
@@ -142,20 +169,19 @@ public scala.collection.immutable.List<DStream<?>> dependencies() {
     }
 
     @Override
-    public scala.Option<RDD<T>> compute(Time validTime) {
+    public scala.Option<RDD<BoxedUnit>> compute(Time validTime) {
       // compute parent.
-      scala.Option<RDD<Iterator<T>>> computedParentRDD = parent.getOrCompute(validTime);
-      // compute this DStream - take single-iterator partitions an flatMap them.
-      if (computedParentRDD.isDefined()) {
-        RDD<T> computedRDD = computedParentRDD.get().toJavaRDD()
-            .flatMap(TranslationUtils.<T>flattenIter()).rdd().cache();
-        // report - for RateEstimator and visibility.
-        report(validTime, computedRDD.count());
-        return scala.Option.apply(computedRDD);
-      } else {
-        report(validTime, 0);
-        return scala.Option.empty();
+      scala.Option<RDD<Metadata>> parentRDDOpt = parent.getOrCompute(validTime);
+      long count = 0;
+      if (parentRDDOpt.isDefined()) {
+        JavaRDD<Metadata> parentRDD = parentRDDOpt.get().toJavaRDD();
+        for (Metadata metadata: parentRDD.collect()) {
+          count += metadata.getNumRecords();
+        }
       }
+      // report - for RateEstimator and visibility.
+      report(validTime, count);
+      return scala.Option.empty();
     }
 
     private void report(Time batchTime, long count) {
       // metadata - #records read and a description.
       scala.collection.immutable.Map<String, Object> metadata =
           new scala.collection.immutable.Map.Map1<String, Object>(
               StreamInputInfo.METADATA_KEY_DESCRIPTION(),
-              String.format("Read %d records from %s for batch time: %s", count, sourceName,
-                  batchTime));
+              String.format(
+                  "Read %d records from %s for batch time: %s",
+                  count,
+                  sourceName,
+                  batchTime));
       StreamInputInfo streamInputInfo = new StreamInputInfo(inputDStreamId, count, metadata);
       ssc().scheduler().inputInfoTracker().reportInfo(batchTime, streamInputInfo);
     }
   }
+
+  /**
+   * A metadata holder for an input stream partition.
+   */
+  public static class Metadata implements Serializable {
+    private final long numRecords;
+    private final Instant watermark;
+
+    public Metadata(long numRecords, Instant watermark) {
+      this.numRecords = numRecords;
+      this.watermark = watermark;
+    }
+
+    public long getNumRecords() {
+      return numRecords;
+    }
+
+    public Instant getWatermark() {
+      return watermark;
+    }
+  }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
index 053f4ac76fa8..ffe0ddd077af 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
@@ -29,6 +29,7 @@
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
 import org.apache.beam.runners.spark.io.MicrobatchSource;
+import org.apache.beam.runners.spark.io.SparkUnboundedSource.Metadata;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.BoundedSource;
@@ -39,10 +40,12 @@
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.streaming.State;
 import org.apache.spark.streaming.StateSpec;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import scala.Option;
+import scala.Tuple2;
 import scala.runtime.AbstractFunction3;
 
 /**
@@ -92,14 +95,17 @@ private abstract static class SerializableFunction3
    */
   public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
       scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, /* CheckpointMarkT */State<byte[]>,
-          Iterator<WindowedValue<T>>> mapSourceFunction(final SparkRuntimeContext runtimeContext) {
+          Tuple2<Iterable<byte[]>, Metadata>> mapSourceFunction(
+              final SparkRuntimeContext runtimeContext) {
 
     return new SerializableFunction3<Source<T>, Option<CheckpointMarkT>, State<byte[]>,
-        Iterator<WindowedValue<T>>>() {
+        Tuple2<Iterable<byte[]>, Metadata>>() {
 
       @Override
-      public Iterator<WindowedValue<T>> apply(Source<T> source, scala.Option<CheckpointMarkT>
-          startCheckpointMark, State<byte[]> state) {
+      public Tuple2<Iterable<byte[]>, Metadata> apply(
+          Source<T> source,
+          scala.Option<CheckpointMarkT> startCheckpointMark,
+          State<byte[]> state) {
         // source as MicrobatchSource
         MicrobatchSource<T, CheckpointMarkT> microbatchSource =
             (MicrobatchSource<T, CheckpointMarkT>) source;
@@ -130,18 +136,25 @@ public Iterator<WindowedValue<T>> apply(Source<T> source, scala.Option<Checkpoi
-          // read microbatch
-          final List<WindowedValue<T>> readValues = new ArrayList<>();
+          // read microbatch as a serialized collection.
+          final List<byte[]> readValues = new ArrayList<>();
+          final Instant watermark;
+          WindowedValue.FullWindowedValueCoder<T> coder =
+              WindowedValue.FullWindowedValueCoder.of(
+                  source.getDefaultOutputCoder(),
+                  GlobalWindow.Coder.INSTANCE);
           try {
             // measure how long a read takes per-partition.
             Stopwatch stopwatch = Stopwatch.createStarted();
             boolean finished = !reader.start();
             while (!finished) {
-              readValues.add(WindowedValue.of(reader.getCurrent(), reader.getCurrentTimestamp(),
-                  GlobalWindow.INSTANCE, PaneInfo.NO_FIRING));
+              WindowedValue<T> wv = WindowedValue.of(reader.getCurrent(),
+                  reader.getCurrentTimestamp(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
+              readValues.add(CoderHelpers.toByteArray(wv, coder));
               finished = !reader.advance();
             }
+            watermark = ((MicrobatchSource.Reader) reader).getWatermark();
             // close and checkpoint reader.
             reader.close();
             LOG.info("Source id {} spent {} msec on reading.", microbatchSource.getId(),
                 stopwatch.stop().elapsed(TimeUnit.MILLISECONDS));
@@ -160,7 +173,13 @@ public Iterator<WindowedValue<T>> apply(Source<T> source, scala.Option<Checkpoi
-          return Iterators.unmodifiableIterator(readValues.iterator());
+          Iterable<byte[]> iterable = new Iterable<byte[]>() {
+            @Override
+            public Iterator<byte[]> iterator() {
+              return Iterators.unmodifiableIterator(readValues.iterator());
+            }
+          };
+          return new Tuple2<>(iterable, new Metadata(readValues.size(), watermark));
         }
       };
 }

From 566663bd915b8ccacf18b71da16a0a434013ef41 Mon Sep 17 00:00:00 2001
From: Sela
Date: Sun, 18 Dec 2016 15:16:23 +0200
Subject: [PATCH 2/2] Report the input global watermark for batch to the UI.

---
 .../beam/runners/spark/io/SparkUnboundedSource.java | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index 9a4550419803..f03dc8c52054 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -173,25 +173,30 @@ public scala.Option<RDD<BoxedUnit>> compute(Time validTime) {
       // compute parent.
       scala.Option<RDD<Metadata>> parentRDDOpt = parent.getOrCompute(validTime);
       long count = 0;
+      Instant globalWatermark = new Instant(Long.MIN_VALUE);
       if (parentRDDOpt.isDefined()) {
         JavaRDD<Metadata> parentRDD = parentRDDOpt.get().toJavaRDD();
         for (Metadata metadata: parentRDD.collect()) {
           count += metadata.getNumRecords();
+          // a monotonically increasing watermark.
+          globalWatermark = globalWatermark.isBefore(metadata.getWatermark())
+              ? metadata.getWatermark() : globalWatermark;
         }
       }
       // report - for RateEstimator and visibility.
-      report(validTime, count);
+      report(validTime, count, globalWatermark);
       return scala.Option.empty();
     }
 
-    private void report(Time batchTime, long count) {
+    private void report(Time batchTime, long count, Instant watermark) {
       // metadata - #records read and a description.
       scala.collection.immutable.Map<String, Object> metadata =
           new scala.collection.immutable.Map.Map1<String, Object>(
               StreamInputInfo.METADATA_KEY_DESCRIPTION(),
               String.format(
-                  "Read %d records from %s for batch time: %s",
+                  "Read %d records with observed watermark %s, from %s for batch time: %s",
                   count,
+                  watermark,
                   sourceName,
                   batchTime));
       StreamInputInfo streamInputInfo = new StreamInputInfo(inputDStreamId, count, metadata);
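
Note on [PATCH 1/2]: the point of keeping the read output as byte[] is that
everything Spark shuffles or writes to the checkpoint directory is already
coder-serialized, and WindowedValue objects are only materialized again at the
very end of the read pipeline. A minimal sketch of that encode/decode
round-trip follows; the element coder (StringUtf8Coder) and the class name are
illustrative assumptions, not part of the patch.

import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.joda.time.Instant;

public class SerializedReadRoundTrip {
  public static void main(String[] args) {
    // Same coder construction as read() and mapSourceFunction(): the element
    // coder paired with the global-window coder.
    WindowedValue.FullWindowedValueCoder<String> coder =
        WindowedValue.FullWindowedValueCoder.of(
            StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);

    WindowedValue<String> original = WindowedValue.of(
        "some-record", new Instant(0), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);

    // mapSourceFunction encodes each element right after reading it ...
    byte[] bytes = CoderHelpers.toByteArray(original, coder);

    // ... and read() decodes only after shuffle/checkpointing, which is what
    // CoderHelpers.fromByteFunction(coder) does per element.
    WindowedValue<String> roundTripped = CoderHelpers.fromByteArray(bytes, coder);

    System.out.println(original.getValue().equals(roundTripped.getValue())); // true
  }
}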
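
Note on [PATCH 2/2]: the global watermark is folded out of the per-partition
Metadata by keeping the maximum instant seen, so the value reported to the UI
cannot regress within a batch; collecting the Metadata RDD is cheap since it
carries one small object per partition. A standalone sketch of that fold
follows, using a simplified, hypothetical stand-in for
SparkUnboundedSource.Metadata.

import java.util.Arrays;
import java.util.List;
import org.joda.time.Instant;

public class WatermarkFold {
  // Hypothetical stand-in for SparkUnboundedSource.Metadata: the record count
  // and the watermark one partition observed for the microbatch.
  static final class Metadata {
    final long numRecords;
    final Instant watermark;

    Metadata(long numRecords, Instant watermark) {
      this.numRecords = numRecords;
      this.watermark = watermark;
    }
  }

  public static void main(String[] args) {
    List<Metadata> partitions = Arrays.asList(
        new Metadata(12, new Instant(1000)),
        new Metadata(7, new Instant(3000)),
        new Metadata(0, new Instant(2000)));

    // Mirrors ReportingDStream.compute(): sum the counts, take the max watermark.
    long count = 0;
    Instant globalWatermark = new Instant(Long.MIN_VALUE);
    for (Metadata metadata : partitions) {
      count += metadata.numRecords;
      globalWatermark = globalWatermark.isBefore(metadata.watermark)
          ? metadata.watermark : globalWatermark;
    }

    // Prints: count=19, watermark=1970-01-01T00:00:03.000Z
    System.out.println("count=" + count + ", watermark=" + globalWatermark);
  }
}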