diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 39e7dd0d9e88..009a6f760d48 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -133,14 +133,8 @@ public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
             WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());
 
         // --- group by key only.
-        Long bundleSize =
-            context.getSerializableOptions().get().as(SparkPipelineOptions.class).getBundleSize();
-        Partitioner partitioner =
-            (bundleSize > 0)
-                ? new HashPartitioner(context.getSparkContext().defaultParallelism())
-                : null;
         JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey =
-            GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, partitioner);
+            GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, getPartitioner(context));
 
         // --- now group also by window.
         // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
@@ -377,6 +371,7 @@ public void evaluate(
                 (KvCoder) context.getInput(transform).getCoder(),
                 windowingStrategy.getWindowFn().windowCoder(),
                 (JavaRDD) inRDD,
+                getPartitioner(context),
                 (MultiDoFnFunction) multiDoFnFunction);
           } else {
             all = inRDD.mapPartitionsToPair(multiDoFnFunction);
@@ -420,6 +415,7 @@ private static JavaPairRDD<TupleTag<?>, WindowedValue<?>> statefulParDoTransform(
       KvCoder<K, V> kvCoder,
       Coder<? extends BoundedWindow> windowCoder,
       JavaRDD<WindowedValue<KV<K, V>>> kvInRDD,
+      Partitioner partitioner,
       MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction) {
     Coder<K> keyCoder = kvCoder.getKeyCoder();
 
@@ -427,7 +423,7 @@ private static JavaPairRDD<TupleTag<?>, WindowedValue<?>> statefulParDoTransform(
         WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), windowCoder);
 
     JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupRDD =
-        GroupCombineFunctions.groupByKeyOnly(kvInRDD, keyCoder, wvCoder, null);
+        GroupCombineFunctions.groupByKeyOnly(kvInRDD, keyCoder, wvCoder, partitioner);
 
     return groupRDD
         .map(
@@ -550,6 +546,14 @@ public String toNativeString() {
         };
   }
 
+  private static Partitioner getPartitioner(EvaluationContext context) {
+    Long bundleSize =
+        context.getSerializableOptions().get().as(SparkPipelineOptions.class).getBundleSize();
+    return (bundleSize > 0)
+        ? null
+        : new HashPartitioner(context.getSparkContext().defaultParallelism());
+  }
+
   private static final Map<String, TransformEvaluator<?>> EVALUATORS = new HashMap<>();
 
   static {
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index a03aa178a224..7fe5bdefbb56 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -83,6 +83,7 @@
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
 import org.apache.spark.Accumulator;
+import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.JavaSparkContext$;
@@ -305,7 +306,11 @@ public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
         JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKeyStream =
             dStream.transform(
                 rdd ->
-                    GroupCombineFunctions.groupByKeyOnly(rdd, coder.getKeyCoder(), wvCoder, null));
+                    GroupCombineFunctions.groupByKeyOnly(
+                        rdd,
+                        coder.getKeyCoder(),
+                        wvCoder,
+                        new HashPartitioner(rdd.rdd().sparkContext().defaultParallelism())));
 
         // --- now group also by window.
         JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream =
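Note on the change (reviewer sketch, not part of the patch): the new `getPartitioner(EvaluationContext)` helper centralizes the choice the batch translator previously made inline, and swaps its branches: a configured bundle size (`bundleSize > 0`) now yields `null`, deferring to Spark's default partitioning, while the default configuration pins every shuffle to a single `HashPartitioner` sized to `defaultParallelism`. Below is a minimal standalone Java sketch of that selection logic; the `choosePartitioner` helper, class name, and `local[2]` context are hypothetical stand-ins for illustration only.

```java
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaSparkContext;

public class PartitionerChoiceSketch {

  // Hypothetical stand-in for TransformTranslator#getPartitioner: a non-zero
  // bundle size returns null ("let Spark pick a default partitioner"), while
  // the default (0) returns a HashPartitioner over defaultParallelism so that
  // successive groupByKeyOnly shuffles share one partitioning.
  static Partitioner choosePartitioner(JavaSparkContext jsc, long bundleSize) {
    return (bundleSize > 0)
        ? null
        : new HashPartitioner(jsc.defaultParallelism());
  }

  public static void main(String[] args) {
    try (JavaSparkContext jsc = new JavaSparkContext("local[2]", "partitioner-sketch")) {
      // Default bundle size (0): an explicit partitioner, 2 partitions under local[2].
      Partitioner pinned = choosePartitioner(jsc, 0L);
      System.out.println(pinned.numPartitions()); // 2

      // Configured bundle size: null, i.e. defer to Spark's default partitioning.
      System.out.println(choosePartitioner(jsc, 64L)); // null
    }
  }
}
```

The streaming translator, by contrast, has no bundle-size path: per the second hunk it always builds a `HashPartitioner` from the transformed RDD's own `sparkContext().defaultParallelism()`.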