From 8bb45552dd276a9ef002bf257f415bcf6e971a2f Mon Sep 17 00:00:00 2001 From: Aviem Zur Date: Tue, 29 Nov 2016 09:51:12 +0200 Subject: [PATCH] Add InputDStream id to MicrobatchSource hashcode. Done to avoid collisions between splits of different sources. --- .../runners/spark/io/MicrobatchSource.java | 20 +++++++++++++------ .../beam/runners/spark/io/SourceDStream.java | 3 ++- .../spark/stateful/StateSpecFunctions.java | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java index 4a174aaf9b2e..565637597073 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/MicrobatchSource.java @@ -54,6 +54,7 @@ public class MicrobatchSource(splits.get(i), maxReadTime, 1, numRecords[i], i)); + result.add(new MicrobatchSource<>(splits.get(i), maxReadTime, 1, numRecords[i], i, sourceId)); } return result; } @@ -137,8 +140,8 @@ public Coder getCheckpointMarkCoder() { return source.getCheckpointMarkCoder(); } - public int getSplitId() { - return splitId; + public String getId() { + return sourceId + "_" + splitId; } @Override @@ -150,13 +153,18 @@ public boolean equals(Object o) { return false; } MicrobatchSource that = (MicrobatchSource) o; - + if (sourceId != that.sourceId) { + return false; + } return splitId == that.splitId; + } @Override public int hashCode() { - return splitId; + int result = sourceId; + result = 31 * result + splitId; + return result; } /** diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java index 4e47757dc0d3..84b247b265da 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java @@ -83,7 +83,8 @@ public SourceDStream(StreamingContext ssc, @Override public scala.Option, CheckpointMarkT>>> compute(Time validTime) { MicrobatchSource microbatchSource = new MicrobatchSource<>( - unboundedSource, boundReadDuration, initialParallelism, rateControlledMaxRecords(), -1); + unboundedSource, boundReadDuration, initialParallelism, rateControlledMaxRecords(), -1, + id()); RDD, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>( ssc().sc(), runtimeContext, microbatchSource); return scala.Option.apply(rdd); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java index 48849c2d8feb..053f4ac76fa8 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java @@ -144,7 +144,7 @@ public Iterator> apply(Source source, scala.Option