From 9ab77be61e8803806cc36280255386b8f7ffcb4b Mon Sep 17 00:00:00 2001 From: Amit Sela Date: Mon, 27 Mar 2017 15:30:03 +0300 Subject: [PATCH] Force a "default" partitioner based on Spark default parallelism to avoid unnecessary shuffles in the composite GBK implementation. Add Javadoc. --- .../SparkGroupAlsoByWindowViaWindowSet.java | 35 ++- .../translation/GroupCombineFunctions.java | 38 ++- .../spark/translation/TranslationUtils.java | 245 ++++++++++++++---- 3 files changed, 248 insertions(+), 70 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index 2f1713a18177..1f2fcb69c996 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -54,6 +54,8 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.Partitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext$; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; @@ -135,12 +137,35 @@ JavaDStream>>> groupAlsoByWindow( //---- InputT: I DStream>*/ byte[]>> pairDStream = inputDStream - .map(WindowingHelpers.>>>unwindowFunction()) - .mapToPair(TranslationUtils.>>toPairFunction()) - // move to bytes and use coders for deserialization because there's a shuffle - // and checkpointing involved. - .mapToPair(CoderHelpers.toByteFunction(keyCoder, itrWvCoder)) + .transformToPair( + new Function< + JavaRDD>>>>, + JavaPairRDD>() { + // we use mapPartitions with the RDD API because its the only available API + // that allows to preserve partitioning. 
+ @Override + public JavaPairRDD call( + JavaRDD>>>> rdd) + throws Exception { + return rdd.mapPartitions( + TranslationUtils.functionToFlatMapFunction( + WindowingHelpers + .>>>unwindowFunction()), + true) + .mapPartitionsToPair( + TranslationUtils + .>>toPairFlatMapFunction(), + true) + // move to bytes representation and use coders for deserialization + // because of checkpointing. + .mapPartitionsToPair( + TranslationUtils.pairFunctionToPairFlatMapFunction( + CoderHelpers.toByteFunction(keyCoder, itrWvCoder)), + true); + } + }) .dstream(); + PairDStreamFunctions pairDStreamFunctions = DStream.toPairDStreamFunctions( pairDStream, diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java index 917a9eef6c61..6a67ccee6e9d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java @@ -28,13 +28,14 @@ import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; +import org.apache.spark.HashPartitioner; +import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; - /** * A set of group/combine functions to apply to Spark {@link org.apache.spark.rdd.RDD}s. */ @@ -49,18 +50,31 @@ public static JavaRDD>>>> g JavaRDD>> rdd, Coder keyCoder, WindowedValueCoder wvCoder) { - - // Use coders to convert objects in the PCollection to byte arrays, so they + // we use coders to convert objects in the PCollection to byte arrays, so they // can be transferred over the network for the shuffle. 
- return rdd - .map(new ReifyTimestampsAndWindowsFunction()) - .map(WindowingHelpers.>>unwindowFunction()) - .mapToPair(TranslationUtils.>toPairFunction()) - .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)) - .groupByKey() - .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)) - .map(TranslationUtils.>>fromPairFunction()) - .map(WindowingHelpers.>>>windowFunction()); + JavaPairRDD pairRDD = + rdd + .map(new ReifyTimestampsAndWindowsFunction()) + .map(WindowingHelpers.>>unwindowFunction()) + .mapToPair(TranslationUtils.>toPairFunction()) + .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)); + // use a default parallelism HashPartitioner. + Partitioner partitioner = new HashPartitioner(rdd.rdd().sparkContext().defaultParallelism()); + + // using mapPartitions allows to preserve the partitioner + // and avoid unnecessary shuffle downstream. + return pairRDD + .groupByKey(partitioner) + .mapPartitionsToPair( + TranslationUtils.pairFunctionToPairFlatMapFunction( + CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)), + true) + .mapPartitions( + TranslationUtils.>>fromPairFlatMapFunction(), true) + .mapPartitions( + TranslationUtils.functionToFlatMapFunction( + WindowingHelpers.>>>windowFunction()), + true); } /** diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java index 8545b360b31b..ef1ff9f26eb3 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -19,8 +19,10 @@ package org.apache.beam.runners.spark.translation; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; import com.google.common.collect.Maps; import java.io.Serializable; +import java.util.Iterator; import java.util.List; import 
java.util.Map; import org.apache.beam.runners.core.InMemoryStateInternals; @@ -41,7 +43,9 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; @@ -49,21 +53,17 @@ import scala.Tuple2; -/** - * A set of utilities to help translating Beam transformations into Spark transformations. - */ +/** A set of utilities to help translating Beam transformations into Spark transformations. */ public final class TranslationUtils { - private TranslationUtils() { - } + private TranslationUtils() {} /** * In-memory state internals factory. * * @param State key type. */ - static class InMemoryStateInternalsFactory implements StateInternalsFactory, - Serializable { + static class InMemoryStateInternalsFactory implements StateInternalsFactory, Serializable { @Override public StateInternals stateInternalsForKey(K key) { return InMemoryStateInternals.forKey(key); @@ -73,12 +73,12 @@ public StateInternals stateInternalsForKey(K key) { /** * A SparkKeyedCombineFn function applied to grouped KVs. * - * @param Grouped key type. - * @param Grouped values type. + * @param Grouped key type. + * @param Grouped values type. * @param Output type. 
*/ - public static class CombineGroupedValues implements - Function>>, WindowedValue>> { + public static class CombineGroupedValues + implements Function>>, WindowedValue>> { private final SparkKeyedCombineFn fn; public CombineGroupedValues(SparkKeyedCombineFn fn) { @@ -88,44 +88,46 @@ public CombineGroupedValues(SparkKeyedCombineFn fn) { @Override public WindowedValue> call(WindowedValue>> windowedKv) throws Exception { - return WindowedValue.of(KV.of(windowedKv.getValue().getKey(), fn.apply(windowedKv)), - windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); + return WindowedValue.of( + KV.of(windowedKv.getValue().getKey(), fn.apply(windowedKv)), + windowedKv.getTimestamp(), + windowedKv.getWindows(), + windowedKv.getPane()); } } /** * Checks if the window transformation should be applied or skipped. * - *

<p> - * Avoid running assign windows if both source and destination are global window - * or if the user has not specified the WindowFn (meaning they are just messing - * with triggering or allowed lateness). - *
</p>
+ *
<p>
Avoid running assign windows if both source and destination are global window or if the user + * has not specified the WindowFn (meaning they are just messing with triggering or allowed + * lateness). * * @param transform The {@link Window.Assign} transformation. - * @param context The {@link EvaluationContext}. - * @param PCollection type. - * @param {@link BoundedWindow} type. + * @param context The {@link EvaluationContext}. + * @param PCollection type. + * @param {@link BoundedWindow} type. * @return if to apply the transformation. */ - public static boolean - skipAssignWindows(Window.Assign transform, EvaluationContext context) { + public static boolean skipAssignWindows( + Window.Assign transform, EvaluationContext context) { @SuppressWarnings("unchecked") WindowFn windowFn = (WindowFn) transform.getWindowFn(); return windowFn == null || (context.getInput(transform).getWindowingStrategy().getWindowFn() - instanceof GlobalWindows - && windowFn instanceof GlobalWindows); + instanceof GlobalWindows + && windowFn instanceof GlobalWindows); } /** Transform a pair stream into a value stream. */ public static JavaDStream dStreamValues(JavaPairDStream pairDStream) { - return pairDStream.map(new Function, T2>() { - @Override - public T2 call(Tuple2 v1) throws Exception { - return v1._2(); - } - }); + return pairDStream.map( + new Function, T2>() { + @Override + public T2 call(Tuple2 v1) throws Exception { + return v1._2(); + } + }); } /** {@link KV} to pair function. */ @@ -138,7 +140,33 @@ public Tuple2 call(KV kv) { }; } - /** A pair to {@link KV} function . */ + /** {@link KV} to pair flatmap function. 
*/ + public static PairFlatMapFunction>, K, V> toPairFlatMapFunction() { + return new PairFlatMapFunction>, K, V>() { + @Override + public Iterable> call(final Iterator> itr) { + final Iterator> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function, Tuple2>() { + + @Override + public Tuple2 apply(KV kv) { + return new Tuple2<>(kv.getKey(), kv.getValue()); + } + }); + return new Iterable>() { + + @Override + public Iterator> iterator() { + return outputItr; + } + }; + } + }; + } + + /** A pair to {@link KV} function . */ static Function, KV> fromPairFunction() { return new Function, KV>() { @Override @@ -148,22 +176,48 @@ public KV call(Tuple2 t2) { }; } - /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */ - public static PairFunction>, K, WindowedValue>> - toPairByKeyInWindowedValue() { - return new PairFunction>, K, WindowedValue>>() { + /** A pair to {@link KV} flatmap function . */ + static FlatMapFunction>, KV> fromPairFlatMapFunction() { + return new FlatMapFunction>, KV>() { + @Override + public Iterable> call(Iterator> itr) { + final Iterator> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function, KV>() { + @Override + public KV apply(Tuple2 t2) { + return KV.of(t2._1(), t2._2()); + } + }); + return new Iterable>() { @Override - public Tuple2>> call( - WindowedValue> windowedKv) throws Exception { - return new Tuple2<>(windowedKv.getValue().getKey(), windowedKv); - } + public Iterator> iterator() { + return outputItr; + } }; } + }; + } + + /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */ + public static + PairFunction>, K, WindowedValue>> + toPairByKeyInWindowedValue() { + return new PairFunction>, K, WindowedValue>>() { + @Override + public Tuple2>> call(WindowedValue> windowedKv) + throws Exception { + return new Tuple2<>(windowedKv.getValue().getKey(), windowedKv); + } + }; + } /** Extract window from a {@link KV} with {@link WindowedValue} value. 
*/ static Function>, WindowedValue>> toKVByWindowInValue() { return new Function>, WindowedValue>>() { - @Override public WindowedValue> call(KV> kv) throws Exception { + @Override + public WindowedValue> call(KV> kv) throws Exception { WindowedValue wv = kv.getValue(); return wv.withValue(KV.of(kv.getKey(), wv.getValue())); } @@ -193,28 +247,25 @@ public Boolean call(Tuple2, WindowedValue> input) { /** * Create SideInputs as Broadcast variables. * - * @param views The {@link PCollectionView}s. + * @param views The {@link PCollectionView}s. * @param context The {@link EvaluationContext}. * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}. */ - static Map, KV, SideInputBroadcast>> - getSideInputs(List> views, EvaluationContext context) { + static Map, KV, SideInputBroadcast>> getSideInputs( + List> views, EvaluationContext context) { return getSideInputs(views, context.getSparkContext(), context.getPViews()); } /** * Create SideInputs as Broadcast variables. * - * @param views The {@link PCollectionView}s. + * @param views The {@link PCollectionView}s. * @param context The {@link JavaSparkContext}. - * @param pviews The {@link SparkPCollectionView}. + * @param pviews The {@link SparkPCollectionView}. * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}. 
*/ - public static Map, KV, SideInputBroadcast>> - getSideInputs( - List> views, - JavaSparkContext context, - SparkPCollectionView pviews) { + public static Map, KV, SideInputBroadcast>> getSideInputs( + List> views, JavaSparkContext context, SparkPCollectionView pviews) { if (views == null) { return ImmutableMap.of(); } else { @@ -223,7 +274,8 @@ public Boolean call(Tuple2, WindowedValue> input) { for (PCollectionView view : views) { SideInputBroadcast helper = pviews.getPCollectionView(view, context); WindowingStrategy windowingStrategy = view.getWindowingStrategyInternal(); - sideInputs.put(view.getTagInternal(), + sideInputs.put( + view.getTagInternal(), KV., SideInputBroadcast>of(windowingStrategy, helper)); } return sideInputs; @@ -270,9 +322,96 @@ public static void rejectStateAndTimers(DoFn doFn) { public static VoidFunction emptyVoidFunction() { return new VoidFunction() { - @Override public void call(T t) throws Exception { + @Override + public void call(T t) throws Exception { // Empty implementation. } }; } + + /** + * A utility method that adapts {@link PairFunction} to a {@link PairFlatMapFunction} with an + * {@link Iterator} input. This is particularly useful because it allows to use functions written + * for mapToPair functions in flatmapToPair functions. + * + * @param pairFunction the {@link PairFunction} to adapt. + * @param the input type. + * @param the output key type. + * @param the output value type. + * @return a {@link PairFlatMapFunction} that accepts an {@link Iterator} as an input and applies + * the {@link PairFunction} on every element. 
+ */ + public static PairFlatMapFunction, K, V> pairFunctionToPairFlatMapFunction( + final PairFunction pairFunction) { + return new PairFlatMapFunction, K, V>() { + + @Override + public Iterable> call(Iterator itr) throws Exception { + final Iterator> outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function>() { + + @Override + public Tuple2 apply(T t) { + try { + return pairFunction.call(t); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + return new Iterable>() { + + @Override + public Iterator> iterator() { + return outputItr; + } + }; + } + }; + } + + /** + * A utility method that adapts {@link Function} to a {@link FlatMapFunction} with an {@link + * Iterator} input. This is particularly useful because it allows to use functions written for map + * functions in flatmap functions. + * + * @param func the {@link Function} to adapt. + * @param the input type. + * @param the output type. + * @return a {@link FlatMapFunction} that accepts an {@link Iterator} as an input and applies the + * {@link Function} on every element. + */ + public static + FlatMapFunction, OutputT> functionToFlatMapFunction( + final Function func) { + return new FlatMapFunction, OutputT>() { + + @Override + public Iterable call(Iterator itr) throws Exception { + final Iterator outputItr = + Iterators.transform( + itr, + new com.google.common.base.Function() { + + @Override + public OutputT apply(InputT t) { + try { + return func.call(t); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + return new Iterable() { + + @Override + public Iterator iterator() { + return outputItr; + } + }; + } + }; + } }