Commit 6ae48e4

Merge 581233f into 2d9bf27

aviemzur committed Mar 22, 2017
2 parents 2d9bf27 + 581233f commit 6ae48e4

Showing 4 changed files with 86 additions and 30 deletions.
@@ -92,7 +92,6 @@ public boolean apply(NativeTransform debugTransform) {
@Override
<TransformT extends PTransform<? super PInput, POutput>> void
doVisitTransform(TransformHierarchy.Node node) {
super.doVisitTransform(node);
@SuppressWarnings("unchecked")
TransformT transform = (TransformT) node.getTransform();
@SuppressWarnings("unchecked")
@@ -36,7 +36,6 @@
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;


@@ -60,47 +59,64 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
private final UnboundedSource<T, CheckpointMarkT> unboundedSource;
private final SparkRuntimeContext runtimeContext;
private final Duration boundReadDuration;
private final int numPartitions;
// the initial parallelism, set by Spark's backend, will be determined once when the job starts.
// in case of resuming/recovering from checkpoint, the DStream will be reconstructed and this
// property should not be reset.
private final int initialParallelism;
// the bound on max records is optional.
// in case it is set explicitly via PipelineOptions, it takes precedence
// otherwise it could be activated via RateController.
private Long boundMaxRecords = null;
private final long boundMaxRecords;

SourceDStream(
StreamingContext ssc,
UnboundedSource<T, CheckpointMarkT> unboundedSource,
SparkRuntimeContext runtimeContext) {

SparkRuntimeContext runtimeContext,
Long boundMaxRecords,
int defaultParallelism) {
super(ssc, JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
this.unboundedSource = unboundedSource;
this.runtimeContext = runtimeContext;

SparkPipelineOptions options = runtimeContext.getPipelineOptions().as(
SparkPipelineOptions.class);

this.boundReadDuration = boundReadDuration(options.getReadTimePercentage(),
options.getMinReadTimeMillis());
// set initial parallelism once.
this.initialParallelism = ssc().sc().defaultParallelism();
checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
}

public void setMaxRecordsPerBatch(long maxRecordsPerBatch) {
boundMaxRecords = maxRecordsPerBatch;
checkArgument(defaultParallelism > 0, "Number of partitions must be greater than zero.");
this.initialParallelism = defaultParallelism;

this.boundMaxRecords = boundMaxRecords > 0 ? boundMaxRecords : rateControlledMaxRecords();

try {
this.numPartitions =
createMicrobatchSource()
.splitIntoBundles(initialParallelism, options)
.size();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

@Override
public scala.Option<RDD<Tuple2<Source<T>, CheckpointMarkT>>> compute(Time validTime) {
long maxNumRecords = boundMaxRecords != null ? boundMaxRecords : rateControlledMaxRecords();
MicrobatchSource<T, CheckpointMarkT> microbatchSource = new MicrobatchSource<>(
unboundedSource, boundReadDuration, initialParallelism, maxNumRecords, -1,
id());
RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>(
ssc().sc(), runtimeContext, microbatchSource);
RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd =
new SourceRDD.Unbounded<>(
ssc().sc(),
runtimeContext,
createMicrobatchSource(),
numPartitions);
return scala.Option.apply(rdd);
}


private MicrobatchSource<T, CheckpointMarkT> createMicrobatchSource() {
return new MicrobatchSource<>(unboundedSource, boundReadDuration, initialParallelism,
boundMaxRecords, -1, id());
}

@Override
public void start() { }

@@ -112,6 +128,10 @@ public String name() {
return "Beam UnboundedSource [" + id() + "]";
}

public int getNumPartitions() {
return numPartitions;
}

//---- Bound by time.

// return the largest between the proportional read time (%batchDuration dedicated for read)
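
The rest of SourceDStream is collapsed in this diff, so the comment above is cut off mid-sentence. As a hedged sketch of the bound it describes (not the file's collapsed body): boundReadDuration, called from the constructor with getReadTimePercentage() and getMinReadTimeMillis(), caps each micro-batch read at the larger of the proportional share of the batch duration and the configured minimum read time. In the sketch below, batchDurationMillis is an assumed parameter standing in for the streaming batch interval.

  // Hedged sketch only; the actual method body is collapsed above.
  private static org.joda.time.Duration boundReadDuration(
      long batchDurationMillis, double readTimePercentage, long minReadTimeMillis) {
    // bound the read by the larger of the proportional read time and the configured minimum.
    long proportionalMillis = Math.round(batchDurationMillis * readTimePercentage);
    return new org.joda.time.Duration(Math.max(proportionalMillis, minReadTimeMillis));
  }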
@@ -30,15 +30,17 @@
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.Dependency;
import org.apache.spark.HashPartitioner;
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.Partition;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Option;


/**
@@ -213,21 +215,25 @@ public Source<T> getSource() {
*/
public static class Unbounded<T, CheckpointMarkT extends
UnboundedSource.CheckpointMark> extends RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> {

private final MicrobatchSource<T, CheckpointMarkT> microbatchSource;
private final SparkRuntimeContext runtimeContext;
private final Partitioner partitioner;

// to satisfy Scala API.
private static final scala.collection.immutable.List<Dependency<?>> NIL =
scala.collection.JavaConversions
.asScalaBuffer(Collections.<Dependency<?>>emptyList()).toList();

public Unbounded(SparkContext sc,
SparkRuntimeContext runtimeContext,
MicrobatchSource<T, CheckpointMarkT> microbatchSource) {
SparkRuntimeContext runtimeContext,
MicrobatchSource<T, CheckpointMarkT> microbatchSource,
int initialNumPartitions) {
super(sc, NIL,
JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
this.runtimeContext = runtimeContext;
this.microbatchSource = microbatchSource;
this.partitioner = new HashPartitioner(initialNumPartitions);
}

@Override
@@ -246,6 +252,13 @@ public Partition[] getPartitions() {
}
}

@Override
public Option<Partitioner> partitioner() {
// setting the partitioner helps to "keep" the same partitioner in the following
// mapWithState read for Read.Unbounded, preventing a post-mapWithState shuffle.
return scala.Some.apply(partitioner);
}

@Override
public scala.collection.Iterator<scala.Tuple2<Source<T>, CheckpointMarkT>>
compute(Partition split, TaskContext context) {
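
(The body of compute is collapsed above.) The partitioner() override is the other half of the change in SparkUnboundedSource below: mapWithState partitions its input by a HashPartitioner sized from StateSpec.numPartitions, and Spark only skips that repartitioning when the upstream RDD already reports an equal partitioner. A rough, hypothetical illustration of that condition (not Spark or Beam code):

  // Hypothetical helper for illustration only: the pre-mapWithState shuffle is skipped when
  // the upstream RDD's partitioner equals the target one (HashPartitioner equality is by
  // number of partitions), which is what partitioner() above arranges.
  static boolean shuffleNeeded(
      scala.Option<org.apache.spark.Partitioner> upstreamPartitioner,
      org.apache.spark.Partitioner targetPartitioner) {
    return upstreamPartitioner.isEmpty() || !upstreamPartitioner.get().equals(targetPartitioner);
  }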
@@ -49,7 +49,8 @@
import org.apache.spark.streaming.dstream.DStream;
import org.apache.spark.streaming.scheduler.StreamInputInfo;
import org.joda.time.Instant;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.runtime.BoxedUnit;

@@ -70,27 +71,36 @@
*/
public class SparkUnboundedSource {

private static final Logger LOG = LoggerFactory.getLogger(SparkUnboundedSource.class);

public static <T, CheckpointMarkT extends CheckpointMark> UnboundedDataset<T> read(
JavaStreamingContext jssc,
SparkRuntimeContext rc,
UnboundedSource<T, CheckpointMarkT> source) {

SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
SourceDStream<T, CheckpointMarkT> sourceDStream = new SourceDStream<>(jssc.ssc(), source, rc);
// if max records per batch was set by the user.
if (maxRecordsPerBatch > 0) {
sourceDStream.setMaxRecordsPerBatch(maxRecordsPerBatch);
}

Integer defaultParallelism = jssc.sc().defaultParallelism();

SourceDStream<T, CheckpointMarkT> sourceDStream =
new SourceDStream<>(
jssc.ssc(),
source,
rc,
options.getMaxRecordsPerBatch(), defaultParallelism);

JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
JavaPairInputDStream$.MODULE$.fromInputDStream(sourceDStream,
JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
JavaSparkContext$.MODULE$.<CheckpointMarkT>fakeClassTag());
JavaSparkContext$.MODULE$.<CheckpointMarkT>fakeClassTag());

// call mapWithState to read from checkpointable sources.
JavaMapWithStateDStream<Source<T>, CheckpointMarkT, Tuple2<byte[], Instant>,
Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream = inputDStream.mapWithState(
StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));
Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream =
inputDStream.mapWithState(
StateSpec
.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc))
.numPartitions(sourceDStream.getNumPartitions()));

// set checkpoint duration for read stream, if set.
checkpointStream(mapWithStateDStream, options);
@@ -113,13 +123,27 @@ public Metadata call(Tuple2<Iterable<byte[]>, Metadata> t2) throws Exception {
WindowedValue.FullWindowedValueCoder.of(
source.getDefaultOutputCoder(),
GlobalWindow.Coder.INSTANCE);

JavaDStream<WindowedValue<T>> readUnboundedStream = mapWithStateDStream.flatMap(
new FlatMapFunction<Tuple2<Iterable<byte[]>, Metadata>, byte[]>() {
@Override
public Iterable<byte[]> call(Tuple2<Iterable<byte[]>, Metadata> t2) throws Exception {
return t2._1();
}
}).map(CoderHelpers.fromByteFunction(coder));

if (sourceDStream.getNumPartitions() < defaultParallelism) {
// Repartition up to default parallelism if there are too few partitions.
LOG.info(
"Less partitions than default parallelism for source {} "
+ "(partitions={} < parallelism={}). "
+ "Repartitioning up to default parallelism.",
source,
sourceDStream.getNumPartitions(),
defaultParallelism);
readUnboundedStream = readUnboundedStream.repartition(defaultParallelism);
}

return new UnboundedDataset<>(readUnboundedStream, Collections.singletonList(id));
}
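
For context, a hedged usage sketch (hypothetical pipeline, not part of this commit): the bounds consumed by read() above are ordinary SparkPipelineOptions, so an explicit per-batch record limit is set on the options before running a streaming pipeline on the Spark runner; when it is left unset, the rate-controlled bound from SourceDStream applies instead.

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class BoundedReadExample {
  public static void main(String[] args) {
    SparkPipelineOptions options =
        PipelineOptionsFactory.fromArgs(args).as(SparkPipelineOptions.class);
    options.setRunner(SparkRunner.class);
    options.setStreaming(true);
    // Explicit bound on records read per micro-batch; a non-positive value falls back to the
    // RateController-derived bound, as in SourceDStream above (setter name assumed from the
    // getMaxRecordsPerBatch accessor used in this diff).
    options.setMaxRecordsPerBatch(500000L);
    Pipeline pipeline = Pipeline.create(options);
    // ... apply an unbounded Read and the rest of the pipeline here, then:
    pipeline.run();
  }
}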
