[BEAM-7044] portable Spark: support stateful dofns
ibzib committed Jun 10, 2019
1 parent 9c5f768 commit 8c07b77
Showing 8 changed files with 171 additions and 34 deletions.
3 changes: 1 addition & 2 deletions runners/spark/job-server/build.gradle
@@ -109,8 +109,7 @@ def portableValidatesRunnerTask(String name) {
excludeCategories 'org.apache.beam.sdk.testing.UsesMapState'
excludeCategories 'org.apache.beam.sdk.testing.UsesSetState'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
// TODO re-enable when state is supported
excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo'
// TODO(BEAM-7221)
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
// SplittableDoFnTests
excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
@@ -142,19 +142,37 @@ public static <K, V> PairFunction<Tuple2<K, V>, ByteArray, byte[]> toByteFunctio
}

/**
* A function wrapper for converting a byte array pair to a key-value pair.
* A function for converting a byte array pair to a key-value pair.
*
* @param keyCoder Coder to deserialize keys.
* @param valueCoder Coder to deserialize values.
* @param <K> The type of the key being deserialized.
* @param <V> The type of the value being deserialized.
* @return A function that accepts a pair of byte arrays and returns a key-value pair.
*/
public static <K, V> PairFunction<Tuple2<ByteArray, byte[]>, K, V> fromByteFunction(
final Coder<K> keyCoder, final Coder<V> valueCoder) {
return tuple ->
new Tuple2<>(
fromByteArray(tuple._1().getValue(), keyCoder), fromByteArray(tuple._2(), valueCoder));
public static class FromByteFunction<K, V>
implements PairFunction<Tuple2<ByteArray, byte[]>, K, V>,
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Function<
Tuple2<ByteArray, byte[]>, Tuple2<K, V>> {
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;

/**
* @param keyCoder Coder to deserialize keys.
* @param valueCoder Coder to deserialize values.
*/
public FromByteFunction(final Coder<K> keyCoder, final Coder<V> valueCoder) {
this.keyCoder = keyCoder;
this.valueCoder = valueCoder;
}

@Override
public Tuple2<K, V> call(Tuple2<ByteArray, byte[]> tuple) {
return new Tuple2<>(
fromByteArray(tuple._1().getValue(), keyCoder), fromByteArray(tuple._2(), valueCoder));
}

@Override
public Tuple2<K, V> apply(Tuple2<ByteArray, byte[]> tuple) {
return call(tuple);
}
}
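
Implementing both Spark's PairFunction and the vendored Guava Function lets one deserializer object serve two call sites: eager decoding of a shuffled pair RDD and lazy decoding while iterating a grouped value. A minimal sketch of both uses, with hypothetical helper names and assuming the coders and the vendored Guava Iterables are in scope:

  // Hypothetical helpers, for illustration only.
  static <K, V> JavaPairRDD<K, V> decodeRdd(
      JavaPairRDD<ByteArray, byte[]> encoded, Coder<K> keyCoder, Coder<V> valueCoder) {
    // Used as a Spark PairFunction at a shuffle boundary.
    return encoded.mapToPair(new CoderHelpers.FromByteFunction<>(keyCoder, valueCoder));
  }

  static <K, V> Iterable<Tuple2<K, V>> decodeGroup(
      Iterable<Tuple2<ByteArray, byte[]>> group, Coder<K> keyCoder, Coder<V> valueCoder) {
    // Used as a Guava Function, so each element is deserialized lazily while iterating.
    return Iterables.transform(group, new CoderHelpers.FromByteFunction<>(keyCoder, valueCoder));
  }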

/**
@@ -72,6 +72,33 @@ public static <K, V> JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupByKeyOnly(
.mapPartitions(TranslationUtils.fromPairFlatMapFunction(), true);
}

/**
* Spark-level group by key operation that keeps original Beam {@link KV} pairs unchanged.
*
* @return {@link JavaPairRDD} where the first value in the pair is the serialized key, and the
*     second is an iterable of the windowed {@link KV} pairs that share that key.
*/
static <K, V> JavaPairRDD<ByteArray, Iterable<WindowedValue<KV<K, V>>>> groupByKeyPair(
JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) {
// we use coders to convert objects in the PCollection to byte arrays, so they
// can be transferred over the network for the shuffle.
JavaPairRDD<ByteArray, byte[]> pairRDD =
rdd.map(new ReifyTimestampsAndWindowsFunction<>())
.map(WindowedValue::getValue)
.mapToPair(TranslationUtils.toPairFunction())
.mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder));

JavaPairRDD<ByteArray, Iterable<Tuple2<ByteArray, byte[]>>> groupedRDD =
pairRDD.groupBy((value) -> value._1);

return groupedRDD
.mapValues(
it -> Iterables.transform(it, new CoderHelpers.FromByteFunction<>(keyCoder, wvCoder)))
.mapValues(it -> Iterables.transform(it, new TranslationUtils.FromPairFunction()))
.mapValues(
it -> Iterables.transform(it, new TranslationUtils.ToKVByWindowInValueFunction<>()));
}
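
Note that, unlike groupByKeyOnly above, the grouped values here remain complete windowed KV elements; only the serialized key is used for partitioning. A hedged usage sketch (hypothetical names and coders, placed in the same package since the method is package-private):

  // Illustration only: group windowed String/Integer KVs by their encoded key.
  static JavaPairRDD<ByteArray, Iterable<WindowedValue<KV<String, Integer>>>> groupExample(
      JavaRDD<WindowedValue<KV<String, Integer>>> input) {
    Coder<String> keyCoder = StringUtf8Coder.of();
    WindowedValue.WindowedValueCoder<Integer> wvCoder =
        WindowedValue.FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);
    return GroupCombineFunctions.groupByKeyPair(input, keyCoder, wvCoder);
  }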

/** Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} transformation. */
public static <InputT, AccumT> Optional<Iterable<WindowedValue<AccumT>>> combineGlobally(
JavaRDD<WindowedValue<InputT>> rdd,
@@ -169,8 +196,8 @@ public static <K, V> JavaRDD<WindowedValue<KV<K, V>>> reshuffle(
.mapToPair(TranslationUtils.toPairFunction())
.mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder))
.repartition(rdd.getNumPartitions())
.mapToPair(CoderHelpers.fromByteFunction(keyCoder, wvCoder))
.map(TranslationUtils.fromPairFunction())
.map(TranslationUtils.toKVByWindowInValue());
.mapToPair(new CoderHelpers.FromByteFunction(keyCoder, wvCoder))
.map(new TranslationUtils.FromPairFunction())
.map(new TranslationUtils.ToKVByWindowInValueFunction<>());
}
}
@@ -19,14 +19,17 @@

import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.createOutputMap;
import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy;
import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.runners.core.SystemReduceFn;
@@ -41,6 +44,7 @@
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.io.SourceRDD;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
@@ -60,6 +64,7 @@
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Sets;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
@@ -206,7 +211,6 @@ private static <InputT, OutputT, SideInputT> void translateExecutableStage(
}
String inputPCollectionId = stagePayload.getInput();
Dataset inputDataset = context.popDataset(inputPCollectionId);
JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());

@@ -223,14 +227,55 @@ private static <InputT, OutputT, SideInputT> void translateExecutableStage(
broadcastVariablesBuilder.put(collectionId, tuple2);
}

SparkExecutableStageFunction<InputT, SideInputT> function =
new SparkExecutableStageFunction<>(
stagePayload,
context.jobInfo,
outputExtractionMap,
broadcastVariablesBuilder.build(),
MetricsAccumulator.getInstance());
JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
JavaRDD<RawUnionValue> staged;
if (stagePayload.getTimersCount() > 0) {
throw new UnsupportedOperationException(
"Timers are not yet supported in Spark portable runner (BEAM-7221)");
}
if (stagePayload.getUserStatesCount() > 0) {
Components components = pipeline.getComponents();
Coder<WindowedValue<InputT>> windowedInputCoder =
instantiateCoder(inputPCollectionId, components);
Coder valueCoder =
((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
// Stateful stages are only allowed for KV inputs so that elements can be grouped by key
if (!(valueCoder instanceof KvCoder)) {
throw new IllegalStateException(
String.format(
Locale.ENGLISH,
"The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
inputPCollectionId,
valueCoder.getClass().getSimpleName()));
}
Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
Coder innerValueCoder = ((KvCoder) valueCoder).getValueCoder();
WindowingStrategy windowingStrategy = getWindowingStrategy(inputPCollectionId, components);
WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
WindowedValue.WindowedValueCoder wvCoder =
WindowedValue.FullWindowedValueCoder.of(innerValueCoder, windowFn.windowCoder());

JavaPairRDD<ByteArray, Iterable<WindowedValue<KV>>> groupedByKey =
groupByKeyPair(inputDataset, keyCoder, wvCoder);
SparkExecutableStageFunction<KV, SideInputT> function =
new SparkExecutableStageFunction<>(
stagePayload,
context.jobInfo,
outputExtractionMap,
broadcastVariablesBuilder.build(),
MetricsAccumulator.getInstance());
staged = groupedByKey.flatMap(function.forPair());
} else {
JavaRDD<WindowedValue<InputT>> inputRdd2 = ((BoundedDataset<InputT>) inputDataset).getRDD();
SparkExecutableStageFunction<InputT, SideInputT> function2 =
new SparkExecutableStageFunction<>(
stagePayload,
context.jobInfo,
outputExtractionMap,
broadcastVariablesBuilder.build(),
MetricsAccumulator.getInstance());
staged = inputRdd2.mapPartitions(function2);
}

String intermediateId = getExecutableStageIntermediateId(transformNode);
context.pushDataset(
intermediateId,
@@ -274,6 +319,13 @@ public void setName(String name) {
}
}

/** Wrapper to help with type inference for {@link GroupCombineFunctions#groupByKeyPair}. */
private static <K, V> JavaPairRDD<ByteArray, Iterable<WindowedValue<KV<K, V>>>> groupByKeyPair(
Dataset dataset, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) {
JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) dataset).getRDD();
return GroupCombineFunctions.groupByKeyPair(inputRdd, keyCoder, wvCoder);
}

/**
* Collect and serialize the data and then broadcast the result. *This can be expensive.*
*
@@ -41,11 +41,13 @@
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.InMemoryBagUserStateFactory;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;
@@ -77,6 +79,7 @@ class SparkExecutableStageFunction<InputT, SideInputT>
private final Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
sideInputs;
private final MetricsContainerStepMapAccumulator metricsAccumulator;
private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory;

SparkExecutableStageFunction(
RunnerApi.ExecutableStagePayload stagePayload,
@@ -105,6 +108,11 @@ class SparkExecutableStageFunction<InputT, SideInputT>
this.metricsAccumulator = metricsAccumulator;
}

/** Call the executable stage function on the values of a PairRDD, ignoring the key. */
FlatMapFunction<Tuple2<ByteArray, Iterable<WindowedValue<InputT>>>, RawUnionValue> forPair() {
return (input) -> call(input._2.iterator());
}
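
Because the grouped RDD delivers all values for a key in a single Tuple2, one invocation of call() processes an entire key's worth of elements, and the ByteArray key itself is never deserialized. A sketch of the equivalence (hypothetical helper, same package):

  // Illustration only: forPair() applied to one grouped entry is just call() over that key's values.
  static <SideInputT> Iterator<RawUnionValue> processOneKey(
      SparkExecutableStageFunction<KV, SideInputT> fn,
      Tuple2<ByteArray, Iterable<WindowedValue<KV>>> keyAndValues) throws Exception {
    return fn.call(keyAndValues._2.iterator());
  }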

@Override
public Iterator<RawUnionValue> call(Iterator<WindowedValue<InputT>> inputs) throws Exception {
try (JobBundleFactory jobBundleFactory = jobBundleFactoryCreator.create()) {
@@ -176,7 +184,23 @@ public <T> List<T> getSideInput(String pCollectionId) {
} catch (IOException e) {
throw new RuntimeException("Failed to setup state handler", e);
}

// Need to discard the old key's state
if (bagUserStateHandlerFactory != null) {
bagUserStateHandlerFactory.resetForNewKey();
}
final StateRequestHandler userStateHandler;
if (executableStage.getUserStates().size() > 0) {
bagUserStateHandlerFactory = new InMemoryBagUserStateFactory();
userStateHandler =
StateRequestHandlers.forBagUserStateHandlerFactory(
processBundleDescriptor, bagUserStateHandlerFactory);
} else {
userStateHandler = StateRequestHandler.unsupported();
}

handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
handlerMap.put(StateKey.TypeCase.BAG_USER_STATE, userStateHandler);
return StateRequestHandlers.delegateBasedUponType(handlerMap);
}

@@ -317,8 +317,8 @@ public void evaluate(
JavaRDD<WindowedValue<KV<K, OutputT>>> outRdd =
accumulatePerKey
.flatMapValues(sparkCombineFn::extractOutput)
.map(TranslationUtils.fromPairFunction())
.map(TranslationUtils.toKVByWindowInValue());
.map(new TranslationUtils.FromPairFunction())
.map(new TranslationUtils.ToKVByWindowInValueFunction<>());

context.putDataset(transform, new BoundedDataset<>(outRdd));
}
@@ -140,8 +140,19 @@ public static <K, V> PairFlatMapFunction<Iterator<KV<K, V>>, K, V> toPairFlatMap
}

/** A pair to {@link KV} function. */
static <K, V> Function<Tuple2<K, V>, KV<K, V>> fromPairFunction() {
return t2 -> KV.of(t2._1(), t2._2());
static class FromPairFunction<K, V>
implements Function<Tuple2<K, V>, KV<K, V>>,
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Function<
Tuple2<K, V>, KV<K, V>> {
@Override
public KV<K, V> call(Tuple2<K, V> t2) {
return KV.of(t2._1(), t2._2());
}

@Override
public KV<K, V> apply(Tuple2<K, V> t2) {
return call(t2);
}
}

/** A pair to {@link KV} flatmap function. */
@@ -160,11 +171,21 @@ static <K, V> FlatMapFunction<Iterator<Tuple2<K, V>>, KV<K, V>> fromPairFlatMapF
}

/** Extract window from a {@link KV} with {@link WindowedValue} value. */
static <K, V> Function<KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>> toKVByWindowInValue() {
return kv -> {
static class ToKVByWindowInValueFunction<K, V>
implements Function<KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>>,
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Function<
KV<K, WindowedValue<V>>, WindowedValue<KV<K, V>>> {

@Override
public WindowedValue<KV<K, V>> call(KV<K, WindowedValue<V>> kv) {
WindowedValue<V> wv = kv.getValue();
return wv.withValue(KV.of(kv.getKey(), wv.getValue()));
};
}

@Override
public WindowedValue<KV<K, V>> apply(KV<K, WindowedValue<V>> kv) {
return call(kv);
}
}
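
FromPairFunction and ToKVByWindowInValueFunction are applied back to back when decoding shuffled data: the Spark Tuple2 first becomes a Beam KV whose value still carries the window metadata, and that metadata is then hoisted onto the whole pair. A sketch of the composition on a single element (hypothetical helper, same package):

  // Illustration only: decode one shuffled element step by step.
  static <K, V> WindowedValue<KV<K, V>> decodeOne(Tuple2<K, WindowedValue<V>> shuffled) {
    KV<K, WindowedValue<V>> asKv =
        new TranslationUtils.FromPairFunction<K, WindowedValue<V>>().call(shuffled);
    return new TranslationUtils.ToKVByWindowInValueFunction<K, V>().call(asKv);
  }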

/**
@@ -97,10 +97,6 @@ def test_metrics(self):
# Skip until Spark runner supports metrics.
raise unittest.SkipTest("BEAM-7219")

def test_pardo_state_only(self):
# Skip until Spark runner supports user state.
raise unittest.SkipTest("BEAM-7044")

def test_pardo_timers(self):
# Skip until Spark runner supports timers.
raise unittest.SkipTest("BEAM-7221")