Count input partition size while reading, so "count reporting" is much lighter and we don't count the entire stream.
Sela committed Dec 6, 2016
1 parent dcad795 commit deb998f
Showing 2 changed files with 101 additions and 42 deletions.
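Before the diff, a minimal plain-Java sketch of the idea behind this change (CountWhileReadingSketch, PartitionMetadata and readPartition below are hypothetical stand-ins, not code from this commit): each partition's reader counts records as it emits them and pairs the data with a small metadata object, so the reporting step only sums one count per source split instead of counting the entire stream again.

import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class CountWhileReadingSketch {

  // Stand-in for SparkUnboundedSource.InputPartitionMetadata: just the record count.
  static final class PartitionMetadata {
    final long numRecords;

    PartitionMetadata(long numRecords) {
      this.numRecords = numRecords;
    }
  }

  // "Read" one partition, counting elements while they are read rather than afterwards.
  static <T> Map.Entry<List<T>, PartitionMetadata> readPartition(List<T> partition) {
    long count = 0;
    for (T ignored : partition) {
      count++; // counted during the read, no second pass over the data
    }
    return new SimpleImmutableEntry<>(partition, new PartitionMetadata(count));
  }

  public static void main(String[] args) {
    List<List<String>> partitions =
        Arrays.asList(Arrays.asList("a", "b"), Arrays.asList("c", "d", "e"));

    // Reporting only touches one small metadata object per source split.
    long total = 0;
    for (List<String> partition : partitions) {
      total += readPartition(partition).getValue().numRecords;
    }
    System.out.println("Read " + total + " records"); // Read 5 records
  }
}

In the actual change, the pair is a Tuple2<Iterator<WindowedValue<T>>, InputPartitionMetadata> produced by mapSourceFunction, and ReportingDStream collects only the metadata to report per-batch record counts.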
SparkUnboundedSource.java
@@ -18,16 +18,19 @@

package org.apache.beam.runners.spark.io;

import java.io.Serializable;
import java.util.Collections;
import java.util.Iterator;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.stateful.StateSpecFunctions;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.sdk.io.Source;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.rdd.RDD;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.StateSpec;
@@ -40,6 +43,9 @@
import org.apache.spark.streaming.dstream.DStream;
import org.apache.spark.streaming.scheduler.StreamInputInfo;

import scala.Tuple2;
import scala.runtime.BoxedUnit;


/**
* A "composite" InputDStream implementation for {@link UnboundedSource}s.
@@ -58,31 +64,59 @@
public class SparkUnboundedSource {

public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
JavaDStream<WindowedValue<T>> read(JavaStreamingContext jssc,
SparkRuntimeContext rc,
UnboundedSource<T, CheckpointMarkT> source) {
JavaDStream<WindowedValue<T>> read(
JavaStreamingContext jssc,
SparkRuntimeContext rc,
UnboundedSource<T, CheckpointMarkT> source) {

JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
JavaPairInputDStream$.MODULE$.fromInputDStream(new SourceDStream<>(jssc.ssc(), source, rc),
JavaPairInputDStream$.MODULE$.fromInputDStream(
new SourceDStream<>(jssc.ssc(), source, rc),
JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
JavaSparkContext$.MODULE$.<CheckpointMarkT>fakeClassTag());
JavaSparkContext$.MODULE$.<CheckpointMarkT>fakeClassTag());

// call mapWithState to read from checkpointable sources.
//TODO: consider broadcasting the rc instead of re-sending every batch.
JavaMapWithStateDStream<Source<T>, CheckpointMarkT, byte[],
Iterator<WindowedValue<T>>> mapWithStateDStream = inputDStream.mapWithState(
StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));
Tuple2<Iterator<WindowedValue<T>>, InputPartitionMetadata>> mapWithStateDStream =
inputDStream.mapWithState(
StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));

// set checkpoint duration for read stream, if set.
checkpointStream(mapWithStateDStream, rc);
// flatmap and report read elements. Use the inputDStream's id to tie between the reported
// info and the inputDStream it originated from.
int id = inputDStream.inputDStream().id();
ReportingFlatMappedDStream<WindowedValue<T>> reportingFlatMappedDStream =
new ReportingFlatMappedDStream<>(mapWithStateDStream.dstream(), id,
getSourceName(source, id));
// cache because checkpoint intervals are greater than batch intervals (default: X10).
mapWithStateDStream.cache();

return JavaDStream.fromDStream(reportingFlatMappedDStream,
JavaSparkContext$.MODULE$.<WindowedValue<T>>fakeClassTag());
// report count.
// use the inputDStream's id to tie between the reported info
// and the inputDStream it originated from.
int id = inputDStream.inputDStream().id();
JavaDStream<InputPartitionMetadata> metadataStream = mapWithStateDStream.map(
new Function<Tuple2<Iterator<WindowedValue<T>>, InputPartitionMetadata>,
InputPartitionMetadata>() {
@Override
public InputPartitionMetadata call(
Tuple2<Iterator<WindowedValue<T>>, InputPartitionMetadata> t2) throws Exception {
return t2._2();
}
});
// register the ReportingDStream op.
new ReportingDStream(metadataStream.dstream(), id, getSourceName(source, id)).register();

// the actual stream of data to output.
return mapWithStateDStream.flatMap(new FlatMapFunction<Tuple2<Iterator<WindowedValue<T>>,
InputPartitionMetadata>, WindowedValue<T>>() {
@Override
public Iterable<WindowedValue<T>> call(
Tuple2<Iterator<WindowedValue<T>>, InputPartitionMetadata> t2) throws Exception {
final Iterator<WindowedValue<T>> itr = t2._1();
return new Iterable<WindowedValue<T>>() {
@Override
public Iterator<WindowedValue<T>> iterator() {
return itr;
}
};
}
});
}

private static <T> String getSourceName(Source<T> source, int id) {
@@ -106,20 +140,20 @@ private static void checkpointStream(JavaDStream<?> dStream,
}

/**
* A flatMap DStream function that "flattens" the Iterators read by the
* {@link MicrobatchSource.Reader}s, while reporting the properties of the read to the
* A DStream function that reports the properties of the read to the
* {@link org.apache.spark.streaming.scheduler.InputInfoTracker} for RateControl purposes
* and visibility.
*/
private static class ReportingFlatMappedDStream<T> extends DStream<T> {
private final DStream<Iterator<T>> parent;
private static class ReportingDStream extends DStream<BoxedUnit> {
private final DStream<InputPartitionMetadata> parent;
private final int inputDStreamId;
private final String sourceName;

ReportingFlatMappedDStream(DStream<Iterator<T>> parent,
int inputDStreamId,
String sourceName) {
super(parent.ssc(), JavaSparkContext$.MODULE$.<T>fakeClassTag());
ReportingDStream(
DStream<InputPartitionMetadata> parent,
int inputDStreamId,
String sourceName) {
super(parent.ssc(), JavaSparkContext$.MODULE$.<BoxedUnit>fakeClassTag());
this.parent = parent;
this.inputDStreamId = inputDStreamId;
this.sourceName = sourceName;
@@ -137,31 +171,48 @@ public scala.collection.immutable.List<DStream<?>> dependencies() {
}

@Override
public scala.Option<RDD<T>> compute(Time validTime) {
public scala.Option<RDD<BoxedUnit>> compute(Time validTime) {
// compute parent.
scala.Option<RDD<Iterator<T>>> computedParentRDD = parent.getOrCompute(validTime);
scala.Option<RDD<InputPartitionMetadata>> parentRDDOpt = parent.getOrCompute(validTime);
// compute this DStream - take single-iterator partitions and flatMap them.
if (computedParentRDD.isDefined()) {
RDD<T> computedRDD = computedParentRDD.get().toJavaRDD()
.flatMap(TranslationUtils.<T>flattenIter()).rdd().cache();
if (parentRDDOpt.isDefined()) {
JavaRDD<InputPartitionMetadata> parentRDD = parentRDDOpt.get().toJavaRDD();
long count = 0;
// number of elements to collect is the number of Source splits, a very limited number.
for (InputPartitionMetadata metadata: parentRDD.collect()) {
count += metadata.numRecords;
}
// report - for RateEstimator and visibility.
report(validTime, computedRDD.count());
return scala.Option.apply(computedRDD);
report(validTime, count);
} else {
report(validTime, 0);
return scala.Option.empty();
}
return scala.Option.empty();
}

private void report(Time batchTime, long count) {
// metadata - #records read and a description.
scala.collection.immutable.Map<String, Object> metadata =
new scala.collection.immutable.Map.Map1<String, Object>(
StreamInputInfo.METADATA_KEY_DESCRIPTION(),
String.format("Read %d records from %s for batch time: %s", count, sourceName,
batchTime));
String.format(
"Read %d records from %s for batch time: %s",
count,
sourceName,
batchTime));
StreamInputInfo streamInputInfo = new StreamInputInfo(inputDStreamId, count, metadata);
ssc().scheduler().inputInfoTracker().reportInfo(batchTime, streamInputInfo);
}
}

/**
* Metadata for an input partition of a microbatch.
*/
public static class InputPartitionMetadata implements Serializable {
private final Long numRecords;

public InputPartitionMetadata(Long numRecords) {
this.numRecords = numRecords;
}
}
}
StateSpecFunctions.java
@@ -29,6 +29,7 @@
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.io.SparkUnboundedSource;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.BoundedSource;
@@ -43,6 +44,7 @@
import org.slf4j.LoggerFactory;

import scala.Option;
import scala.Tuple2;
import scala.runtime.AbstractFunction3;

/**
@@ -91,15 +93,19 @@ private abstract static class SerializableFunction3<T1, T2, T3, T4>
* @return The appropriate {@link org.apache.spark.streaming.StateSpec} function.
*/
public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, /* CheckpointMarkT */State<byte[]>,
Iterator<WindowedValue<T>>> mapSourceFunction(final SparkRuntimeContext runtimeContext) {
scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, /* CheckpointMarkT */State<byte[]>,
Tuple2<Iterator<WindowedValue<T>>, SparkUnboundedSource.InputPartitionMetadata>>
mapSourceFunction(final SparkRuntimeContext runtimeContext) {

return new SerializableFunction3<Source<T>, Option<CheckpointMarkT>, State<byte[]>,
Iterator<WindowedValue<T>>>() {
Tuple2<Iterator<WindowedValue<T>>, SparkUnboundedSource.InputPartitionMetadata>>() {

@Override
public Iterator<WindowedValue<T>> apply(Source<T> source, scala.Option<CheckpointMarkT>
startCheckpointMark, State<byte[]> state) {
public Tuple2<Iterator<WindowedValue<T>>, SparkUnboundedSource.InputPartitionMetadata> apply(
Source<T> source,
scala.Option<CheckpointMarkT> startCheckpointMark,
State<byte[]> state) {

// source as MicrobatchSource
MicrobatchSource<T, CheckpointMarkT> microbatchSource =
(MicrobatchSource<T, CheckpointMarkT>) source;
@@ -159,8 +165,10 @@ public Iterator<WindowedValue<T>> apply(Source<T> source, scala.Option<Checkpoin
} catch (IOException e) {
throw new RuntimeException("Failed to read from reader.", e);
}

return Iterators.unmodifiableIterator(readValues.iterator());
SparkUnboundedSource.InputPartitionMetadata metadata =
new SparkUnboundedSource.InputPartitionMetadata((long) readValues.size());
Iterator<WindowedValue<T>> unmodItr = Iterators.unmodifiableIterator(readValues.iterator());
return new Tuple2<>(unmodItr, metadata);
}
};
}
