
Commit

This closes #1579
kennknowles committed Dec 12, 2016
2 parents 321547f + 74b0bef commit bfd21d7
Showing 5 changed files with 63 additions and 31 deletions.
@@ -54,6 +54,11 @@ public interface SparkPipelineOptions
Long getMinReadTimeMillis();
void setMinReadTimeMillis(Long minReadTimeMillis);

@Description("Max records per micro-batch. For streaming sources only.")
@Default.Long(-1)
Long getMaxRecordsPerBatch();
void setMaxRecordsPerBatch(Long maxRecordsPerBatch);

@Description("A value between 0-1 to describe the percentage of a micro-batch dedicated "
+ "to reading from UnboundedSource.")
@Default.Double(0.1)
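
For reference, here is how the new bound could be configured by a user; a minimal sketch assuming the usual PipelineOptionsFactory flow, with the class name, flag values, and the 1000-record cap purely illustrative:

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class MaxRecordsPerBatchExample {
  public static void main(String[] args) {
    // e.g. --runner=SparkRunner --maxRecordsPerBatch=1000
    SparkPipelineOptions options =
        PipelineOptionsFactory.fromArgs(args).withValidation().as(SparkPipelineOptions.class);
    // Or set it programmatically. The default of -1 disables the explicit bound
    // and leaves micro-batch sizing to the RateController.
    options.setMaxRecordsPerBatch(1000L);
    Pipeline p = Pipeline.create(options);
    // ... apply transforms and run ...
  }
}
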
@@ -53,7 +53,7 @@
* {@link SparkPipelineOptions#getMinReadTimeMillis()}.
* Records bound is controlled by the {@link RateController} mechanism.
*/
public class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
extends InputDStream<Tuple2<Source<T>, CheckpointMarkT>> {
private static final Logger LOG = LoggerFactory.getLogger(SourceDStream.class);

@@ -64,10 +64,16 @@ public class SourceDStream<T, CheckpointMarkT extends UnboundedSource.Checkpoint
// in case of resuming/recovering from checkpoint, the DStream will be reconstructed and this
// property should not be reset.
private final int initialParallelism;
// The bound on max records is optional.
// If it is set explicitly via PipelineOptions it takes precedence;
// otherwise the bound may be activated via the RateController.
private Long boundMaxRecords = null;

SourceDStream(
StreamingContext ssc,
UnboundedSource<T, CheckpointMarkT> unboundedSource,
SparkRuntimeContext runtimeContext) {

public SourceDStream(StreamingContext ssc,
UnboundedSource<T, CheckpointMarkT> unboundedSource,
SparkRuntimeContext runtimeContext) {
super(ssc, JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
this.unboundedSource = unboundedSource;
this.runtimeContext = runtimeContext;
Expand All @@ -80,10 +86,15 @@ public SourceDStream(StreamingContext ssc,
checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
}

public void setMaxRecordsPerBatch(long maxRecordsPerBatch) {
boundMaxRecords = maxRecordsPerBatch;
}

@Override
public scala.Option<RDD<Tuple2<Source<T>, CheckpointMarkT>>> compute(Time validTime) {
long maxNumRecords = boundMaxRecords != null ? boundMaxRecords : rateControlledMaxRecords();
MicrobatchSource<T, CheckpointMarkT> microbatchSource = new MicrobatchSource<>(
unboundedSource, boundReadDuration, initialParallelism, rateControlledMaxRecords(), -1,
unboundedSource, boundReadDuration, initialParallelism, maxNumRecords, -1,
id());
RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>(
ssc().sc(), runtimeContext, microbatchSource);
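
The decision in compute() above carries the core of the change: a bound set explicitly through PipelineOptions wins, and only in its absence does the RateController-driven estimate apply. A commented restatement of that precedence (rateControlledMaxRecords() is treated as an opaque call here; its implementation is outside this hunk):

// Applied once per micro-batch in compute():
long maxNumRecords = (boundMaxRecords != null)
    ? boundMaxRecords              // explicit bound, set from SparkPipelineOptions#getMaxRecordsPerBatch()
    : rateControlledMaxRecords();  // otherwise defer to Spark's RateController mechanism
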
@@ -61,19 +61,25 @@ public class SparkUnboundedSource {
JavaDStream<WindowedValue<T>> read(JavaStreamingContext jssc,
SparkRuntimeContext rc,
UnboundedSource<T, CheckpointMarkT> source) {
SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
SourceDStream<T, CheckpointMarkT> sourceDStream = new SourceDStream<>(jssc.ssc(), source, rc);
// if max records per batch was set by the user.
if (maxRecordsPerBatch > 0) {
sourceDStream.setMaxRecordsPerBatch(maxRecordsPerBatch);
}
JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
JavaPairInputDStream$.MODULE$.fromInputDStream(new SourceDStream<>(jssc.ssc(), source, rc),
JavaPairInputDStream$.MODULE$.fromInputDStream(sourceDStream,
JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
JavaSparkContext$.MODULE$.<CheckpointMarkT>fakeClassTag());

// call mapWithState to read from checkpointable sources.
// TODO: consider broadcasting the rc instead of re-sending it every batch.
JavaMapWithStateDStream<Source<T>, CheckpointMarkT, byte[],
Iterator<WindowedValue<T>>> mapWithStateDStream = inputDStream.mapWithState(
StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));

// set the checkpoint duration for the read stream, if configured.
checkpointStream(mapWithStateDStream, rc);
checkpointStream(mapWithStateDStream, options);
// flatmap and report read elements. Use the inputDStream's id to tie the reported
// info back to the inputDStream it originated from.
int id = inputDStream.inputDStream().id();
@@ -97,9 +103,8 @@ private static <T> String getSourceName(Source<T> source, int id) {
}

private static void checkpointStream(JavaDStream<?> dStream,
SparkRuntimeContext rc) {
long checkpointDurationMillis = rc.getPipelineOptions().as(SparkPipelineOptions.class)
.getCheckpointDurationMillis();
SparkPipelineOptions options) {
long checkpointDurationMillis = options.getCheckpointDurationMillis();
if (checkpointDurationMillis > 0) {
dStream.checkpoint(new Duration(checkpointDurationMillis));
}
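
Taken together, SparkUnboundedSource.read() now draws all of its knobs from SparkPipelineOptions. A minimal sketch of a configuration that exercises both gates above (maxRecordsPerBatch > 0 and checkpointDurationMillis > 0), using only setters that appear elsewhere in this commit; the values are illustrative:

SparkPipelineOptions options = PipelineOptionsFactory.as(SparkPipelineOptions.class);
options.setBatchIntervalMillis(5000L);       // micro-batch interval
options.setMaxRecordsPerBatch(10000L);       // > 0 activates sourceDStream.setMaxRecordsPerBatch(...)
options.setCheckpointDurationMillis(5000L);  // > 0 activates dStream.checkpoint(new Duration(...))
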
@@ -76,10 +76,15 @@ public static void init() throws IOException {

@Test
public void testEarliest2Topics() throws Exception {
Duration batchIntervalDuration = Duration.standardSeconds(5);
SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
// It seems that the consumer's first "position" lookup (in unit test) takes +200 msec,
// so to be on the safe side we'll set to 750 msec.
options.setMinReadTimeMillis(750L);
// provide a generous enough batch-interval to have everything fit in one micro-batch.
options.setBatchIntervalMillis(batchIntervalDuration.getMillis());
// provide a very generous read-time bound; we rely on the max-records bound here.
options.setMinReadTimeMillis(batchIntervalDuration.minus(1).getMillis());
// bound the read on the number of messages - 2 topics of 4 messages each.
options.setMaxRecordsPerBatch(8L);

//--- setup
// two topics.
final String topic1 = "topic1";
@@ -90,8 +95,6 @@ public void testEarliest2Topics() throws Exception {
);
// expected.
final String[] expected = {"k1,v1", "k2,v2", "k3,v3", "k4,v4"};
// batch and window duration.
final Duration batchAndWindowDuration = Duration.standardSeconds(1);

// write to both topics ahead.
produce(topic1, messages);
@@ -114,17 +117,27 @@ public void testEarliest2Topics() throws Exception {
PCollection<String> deduped =
p.apply(read.withoutMetadata()).setCoder(
KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
.apply(Window.<KV<String, String>>into(FixedWindows.of(batchAndWindowDuration)))
.apply(Window.<KV<String, String>>into(FixedWindows.of(batchIntervalDuration)))
.apply(ParDo.of(new FormatKVFn()))
.apply(Distinct.<String>create());

PAssertStreaming.runAndAssertContents(p, deduped, expected, Duration.standardSeconds(1L));
// graceful shutdown will make sure the first batch (at least) finishes.
Duration timeout = Duration.standardSeconds(1L);
PAssertStreaming.runAndAssertContents(p, deduped, expected, timeout);
}

@Test
public void testLatest() throws Exception {
Duration batchIntervalDuration = Duration.standardSeconds(5);
SparkContextOptions options =
commonOptions.withTmpCheckpointDir(checkpointParentDir).as(SparkContextOptions.class);
// provide a generous enough batch-interval to have everything fit in one micro-batch.
options.setBatchIntervalMillis(batchIntervalDuration.getMillis());
// provide a very generous read-time bound; we rely on the max-records bound here.
options.setMinReadTimeMillis(batchIntervalDuration.minus(1).getMillis());
// bound the read on the number of messages - 1 topic of 4 messages.
options.setMaxRecordsPerBatch(4L);

//--- setup
final String topic = "topic";
// messages.
@@ -133,16 +146,11 @@ public void testLatest() throws Exception {
);
// expected.
final String[] expected = {"k1,v1", "k2,v2", "k3,v3", "k4,v4"};
// batch and window duration.
final Duration batchAndWindowDuration = Duration.standardSeconds(1);

// write once the first batch completes; this guarantees latest-like behaviour.
options.setListeners(Collections.<JavaStreamingListener>singletonList(
KafkaWriteOnBatchCompleted.once(messages, Collections.singletonList(topic),
EMBEDDED_KAFKA_CLUSTER.getProps(), EMBEDDED_KAFKA_CLUSTER.getBrokerList())));
// It seems that the consumer's first "position" lookup (in unit test) takes +200 msec,
// so to be on the safe side we'll set to 750 msec.
options.setMinReadTimeMillis(750L);

//------- test: read and format.
Pipeline p = Pipeline.create(options);
Expand All @@ -161,7 +169,7 @@ public void testLatest() throws Exception {
PCollection<String> formatted =
p.apply(read.withoutMetadata()).setCoder(
KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
.apply(Window.<KV<String, String>>into(FixedWindows.of(batchAndWindowDuration)))
.apply(Window.<KV<String, String>>into(FixedWindows.of(batchIntervalDuration)))
.apply(ParDo.of(new FormatKVFn()));

// run for more than 1 batch interval, so that reading of latest is attempted in the
@@ -112,10 +112,14 @@ private static void produce() {

@Test
public void testRun() throws Exception {
Duration batchIntervalDuration = Duration.standardSeconds(5);
SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
// It seems that the consumer's first "position" lookup (in unit test) takes +200 msec,
// so to be on the safe side we'll set to 750 msec.
options.setMinReadTimeMillis(750L);
// provide a generous enough batch-interval to have everything fit in one micro-batch.
options.setBatchIntervalMillis(batchIntervalDuration.getMillis());
// provide a very generous read-time bound; we rely on the max-records bound here.
options.setMinReadTimeMillis(batchIntervalDuration.minus(1).getMillis());
// bound the read on the number of messages - 1 topic of 4 messages.
options.setMaxRecordsPerBatch(4L);

// checkpoint after first (and only) interval.
options.setCheckpointDurationMillis(options.getBatchIntervalMillis());
@@ -164,10 +168,9 @@ private static SparkPipelineResult run(SparkPipelineOptions options) {
.apply(Window.<KV<String, String>>into(FixedWindows.of(windowDuration)))
.apply(ParDo.of(new FormatAsText()));

// requires a graceful stop so that checkpointing of the first run would finish successfully
// before stopping and attempting to resume.
return PAssertStreaming.runAndAssertContents(p, formattedKV, EXPECTED,
Duration.standardSeconds(1L));
// graceful shutdown will make sure the first batch (at least) finishes.
Duration timeout = Duration.standardSeconds(1L);
return PAssertStreaming.runAndAssertContents(p, formattedKV, EXPECTED, timeout);
}

@AfterClass

