From f1cdb558e57c27e745a2f6a5f8440edee3412154 Mon Sep 17 00:00:00 2001 From: Aviem Zur Date: Sun, 13 Nov 2016 13:57:07 +0200 Subject: [PATCH 1/2] Unify spark-runner EvaluationContext and StreamingEvaluationContext --- .../beam/runners/spark/SparkRunner.java | 4 +- .../spark/translation/BoundedDataset.java | 114 ++++++++ .../runners/spark/translation/Dataset.java | 34 +++ .../spark/translation/EvaluationContext.java | 228 ++++++--------- .../translation/TransformTranslator.java | 93 +++--- .../SparkRunnerStreamingContextFactory.java | 7 +- .../streaming/StreamingEvaluationContext.java | 272 ------------------ .../StreamingTransformTranslator.java | 135 ++++----- .../streaming/UnboundedDataset.java | 103 +++++++ 9 files changed, 459 insertions(+), 531 deletions(-) create mode 100644 runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java create mode 100644 runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java delete mode 100644 runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java create mode 100644 runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index 45c7f55982ac..6bbef39e63c0 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -26,7 +26,6 @@ import org.apache.beam.runners.spark.translation.TransformEvaluator; import org.apache.beam.runners.spark.translation.TransformTranslator; import org.apache.beam.runners.spark.translation.streaming.SparkRunnerStreamingContextFactory; -import org.apache.beam.runners.spark.translation.streaming.StreamingEvaluationContext; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -49,6 +48,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + /** * The SparkRunner translate operations defined on a pipeline to a representation * executable by Spark, and then submitting the job to Spark to be executed. If we wanted to run @@ -136,7 +136,7 @@ public EvaluationResult run(Pipeline pipeline) { jssc.start(); // if recovering from checkpoint, we have to reconstruct the EvaluationResult instance. - return contextFactory.getCtxt() == null ? new StreamingEvaluationContext(jssc.sc(), + return contextFactory.getCtxt() == null ? new EvaluationContext(jssc.sc(), pipeline, jssc, mOptions.getTimeout()) : contextFactory.getCtxt(); } else { if (mOptions.getTimeout() > 0) { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java new file mode 100644 index 000000000000..774efb9916cd --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.common.base.Function; +import com.google.common.collect.Iterables; +import java.util.List; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaRDDLike; +import org.apache.spark.api.java.JavaSparkContext; + +/** + * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are sometimes + * created from a collection of objects (using RDD parallelize) and then only used to create View + * objects; in which case they do not need to be converted to bytes since they are not transferred + * across the network until they are broadcast. + */ +public class BoundedDataset implements Dataset { + // only set if creating an RDD from a static collection + @Nullable private transient JavaSparkContext jsc; + + private Iterable> windowedValues; + private Coder coder; + private JavaRDD> rdd; + + BoundedDataset(JavaRDD> rdd) { + this.rdd = rdd; + } + + BoundedDataset(Iterable values, JavaSparkContext jsc, Coder coder) { + this.windowedValues = + Iterables.transform(values, WindowingHelpers.windowValueFunction()); + this.jsc = jsc; + this.coder = coder; + } + + @SuppressWarnings("ConstantConditions") + public JavaRDD> getRDD() { + if (rdd == null) { + WindowedValue.ValueOnlyWindowedValueCoder windowCoder = + WindowedValue.getValueOnlyCoder(coder); + rdd = jsc.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder)) + .map(CoderHelpers.fromByteFunction(windowCoder)); + } + return rdd; + } + + Iterable> getValues(PCollection pcollection) { + if (windowedValues == null) { + WindowFn windowFn = + pcollection.getWindowingStrategy().getWindowFn(); + Coder windowCoder = windowFn.windowCoder(); + final WindowedValue.WindowedValueCoder windowedValueCoder; + if (windowFn instanceof GlobalWindows) { + windowedValueCoder = + WindowedValue.ValueOnlyWindowedValueCoder.of(pcollection.getCoder()); + } else { + windowedValueCoder = + WindowedValue.FullWindowedValueCoder.of(pcollection.getCoder(), windowCoder); + } + JavaRDDLike bytesRDD = + rdd.map(CoderHelpers.toByteFunction(windowedValueCoder)); + List clientBytes = bytesRDD.collect(); + windowedValues = Iterables.transform(clientBytes, + new Function>() { + @Override + public WindowedValue apply(byte[] bytes) { + return CoderHelpers.fromByteArray(bytes, windowedValueCoder); + } + }); + } + return windowedValues; + } + + @Override + public void cache() { + rdd.cache(); + } + + @Override + public void action() { + rdd.count(); + } + + @Override + public void setName(String name) { + rdd.setName(name); + } + +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java new file mode 100644 index 000000000000..36b03feb77c7 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import java.io.Serializable; + + +/** + * Holder for Spark RDD/DStream. + */ +public interface Dataset extends Serializable { + + void cache(); + + void action(); + + void setName(String name); +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 6ccec8569e1f..3e09b3994c7d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -20,17 +20,15 @@ import static com.google.common.base.Preconditions.checkArgument; -import com.google.common.base.Function; import com.google.common.collect.Iterables; import java.io.IOException; import java.util.LinkedHashMap; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; import org.apache.beam.runners.spark.EvaluationResult; import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton; -import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.translation.streaming.UnboundedDataset; import org.apache.beam.sdk.AggregatorRetrievalException; import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.Pipeline; @@ -39,17 +37,15 @@ import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; -import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.spark.api.java.JavaRDDLike; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.joda.time.Duration; @@ -58,91 +54,47 @@ */ public class EvaluationContext implements EvaluationResult { private final JavaSparkContext jsc; - private final Pipeline pipeline; + private JavaStreamingContext jssc; private final SparkRuntimeContext runtime; - private final Map> pcollections = new LinkedHashMap<>(); - private final Set> leafRdds = new LinkedHashSet<>(); - private final Set multireads = new LinkedHashSet<>(); + private final Pipeline pipeline; + private long timeout; + private final Map datasets = new LinkedHashMap<>(); + private final Map pcollections = new LinkedHashMap<>(); + private final Set leaves = new LinkedHashSet<>(); + private final Set multiReads = new LinkedHashSet<>(); private final Map pobjects = new LinkedHashMap<>(); private final Map>> pview = new LinkedHashMap<>(); - protected AppliedPTransform currentTransform; + private AppliedPTransform currentTransform; + private State state; public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) { this.jsc = jsc; this.pipeline = pipeline; this.runtime = new SparkRuntimeContext(pipeline, jsc); + this.state = State.DONE; } - /** - * Holds an RDD or values for deferred conversion to an RDD if needed. PCollections are - * sometimes created from a collection of objects (using RDD parallelize) and then - * only used to create View objects; in which case they do not need to be - * converted to bytes since they are not transferred across the network until they are - * broadcast. - */ - private class RDDHolder { - - private Iterable> windowedValues; - private Coder coder; - private JavaRDDLike, ?> rdd; - - RDDHolder(Iterable values, Coder coder) { - this.windowedValues = - Iterables.transform(values, WindowingHelpers.windowValueFunction()); - this.coder = coder; - } - - RDDHolder(JavaRDDLike, ?> rdd) { - this.rdd = rdd; - } - - JavaRDDLike, ?> getRDD() { - if (rdd == null) { - WindowedValue.ValueOnlyWindowedValueCoder windowCoder = - WindowedValue.getValueOnlyCoder(coder); - rdd = jsc.parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder)) - .map(CoderHelpers.fromByteFunction(windowCoder)); - } - return rdd; - } - - Iterable> getValues(PCollection pcollection) { - if (windowedValues == null) { - WindowFn windowFn = - pcollection.getWindowingStrategy().getWindowFn(); - Coder windowCoder = windowFn.windowCoder(); - final WindowedValue.WindowedValueCoder windowedValueCoder; - if (windowFn instanceof GlobalWindows) { - windowedValueCoder = - WindowedValue.ValueOnlyWindowedValueCoder.of(pcollection.getCoder()); - } else { - windowedValueCoder = - WindowedValue.FullWindowedValueCoder.of(pcollection.getCoder(), windowCoder); - } - JavaRDDLike bytesRDD = - rdd.map(CoderHelpers.toByteFunction(windowedValueCoder)); - List clientBytes = bytesRDD.collect(); - windowedValues = Iterables.transform(clientBytes, - new Function>() { - @Override - public WindowedValue apply(byte[] bytes) { - return CoderHelpers.fromByteArray(bytes, windowedValueCoder); - } - }); - } - return windowedValues; - } + public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline, + JavaStreamingContext jssc, long timeout) { + this(jsc, pipeline); + this.jssc = jssc; + this.timeout = timeout; + this.state = State.RUNNING; } - protected JavaSparkContext getSparkContext() { + JavaSparkContext getSparkContext() { return jsc; } + public JavaStreamingContext getStreamingContext() { + return jssc; + } + public Pipeline getPipeline() { return pipeline; } - protected SparkRuntimeContext getRuntimeContext() { + public SparkRuntimeContext getRuntimeContext() { return runtime; } @@ -150,11 +102,7 @@ public void setCurrentTransform(AppliedPTransform transform) { this.currentTransform = transform; } - protected AppliedPTransform getCurrentTransform() { - return currentTransform; - } - - protected T getInput(PTransform transform) { + public T getInput(PTransform transform) { checkArgument(currentTransform != null && currentTransform.getTransform() == transform, "can only be called with current transform"); @SuppressWarnings("unchecked") @@ -162,7 +110,7 @@ protected T getInput(PTransform transform) { return input; } - protected T getOutput(PTransform transform) { + public T getOutput(PTransform transform) { checkArgument(currentTransform != null && currentTransform.getTransform() == transform, "can only be called with current transform"); @SuppressWarnings("unchecked") @@ -170,81 +118,73 @@ protected T getOutput(PTransform transform) { return output; } - protected void setOutputRDD(PTransform transform, - JavaRDDLike, ?> rdd) { - setRDD((PValue) getOutput(transform), rdd); + public void putDataset(PTransform transform, Dataset dataset) { + putDataset((PValue) getOutput(transform), dataset); } - protected void setOutputRDDFromValues(PTransform transform, Iterable values, - Coder coder) { - pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder)); + public void putDataset(PValue pvalue, Dataset dataset) { + try { + dataset.setName(pvalue.getName()); + } catch (IllegalStateException e) { + // name not set, ignore + } + datasets.put(pvalue, dataset); + leaves.add(dataset); } - public void setPView(PValue view, Iterable> value) { + void putBoundedDatasetFromValues(PTransform transform, Iterable values, + Coder coder) { + datasets.put((PValue) getOutput(transform), new BoundedDataset<>(values, jsc, coder)); + } + + public void setUnboundedDatasetFromQueue( + PTransform transform, Iterable> values, Coder coder) { + datasets.put((PValue) getOutput(transform), new UnboundedDataset<>(values, jssc, coder)); + } + void setPView(PValue view, Iterable> value) { pview.put(view, value); } - protected boolean hasOutputRDD(PTransform transform) { - PValue pvalue = (PValue) getOutput(transform); - return pcollections.containsKey(pvalue); + public Dataset borrowDataset(PTransform transform) { + return borrowDataset((PValue) getInput(transform)); } - public JavaRDDLike getRDD(PValue pvalue) { - RDDHolder rddHolder = pcollections.get(pvalue); - JavaRDDLike rdd = rddHolder.getRDD(); - leafRdds.remove(rddHolder); - if (multireads.contains(pvalue)) { + public Dataset borrowDataset(PValue pvalue) { + Dataset dataset = datasets.get(pvalue); + leaves.remove(dataset); + if (multiReads.contains(pvalue)) { // Ensure the RDD is marked as cached - rdd.rdd().cache(); + dataset.cache(); } else { - multireads.add(pvalue); - } - return rdd; - } - - protected void setRDD(PValue pvalue, JavaRDDLike, ?> rdd) { - try { - rdd.rdd().setName(pvalue.getName()); - } catch (IllegalStateException e) { - // name not set, ignore + multiReads.add(pvalue); } - RDDHolder rddHolder = new RDDHolder<>(rdd); - pcollections.put(pvalue, rddHolder); - leafRdds.add(rddHolder); - } - - protected JavaRDDLike getInputRDD(PTransform transform) { - return getRDD((PValue) getInput(transform)); + return dataset; } - Iterable> getPCollectionView(PCollectionView view) { return pview.get(view); } /** - * Computes the outputs for all RDDs that are leaves in the DAG and do not have any - * actions (like saving to a file) registered on them (i.e. they are performed for side - * effects). + * Computes the outputs for all RDDs that are leaves in the DAG and do not have any actions (like + * saving to a file) registered on them (i.e. they are performed for side effects). */ public void computeOutputs() { - for (RDDHolder rddHolder : leafRdds) { - JavaRDDLike rdd = rddHolder.getRDD(); - rdd.rdd().cache(); // cache so that any subsequent get() is cheap - rdd.count(); // force the RDD to be computed + for (Dataset dataset : leaves) { + dataset.cache(); // cache so that any subsequent get() is cheap. + dataset.action(); // force computation. } } + @SuppressWarnings("unchecked") @Override public T get(PValue value) { if (pobjects.containsKey(value)) { - @SuppressWarnings("unchecked") T result = (T) pobjects.get(value); return result; } if (pcollections.containsKey(value)) { - JavaRDDLike rdd = pcollections.get(value).getRDD(); - @SuppressWarnings("unchecked") + JavaRDD rdd = ((BoundedDataset) pcollections.get(value)).getRDD(); T res = (T) Iterables.getOnlyElement(rdd.collect()); pobjects.put(value, res); return res; @@ -271,27 +211,37 @@ public MetricResults metrics() { @Override public Iterable get(PCollection pcollection) { @SuppressWarnings("unchecked") - RDDHolder rddHolder = (RDDHolder) pcollections.get(pcollection); - Iterable> windowedValues = rddHolder.getValues(pcollection); + BoundedDataset boundedDataset = (BoundedDataset) datasets.get(pcollection); + Iterable> windowedValues = boundedDataset.getValues(pcollection); return Iterables.transform(windowedValues, WindowingHelpers.unwindowValueFunction()); } Iterable> getWindowedValues(PCollection pcollection) { @SuppressWarnings("unchecked") - RDDHolder rddHolder = (RDDHolder) pcollections.get(pcollection); - return rddHolder.getValues(pcollection); + BoundedDataset boundedDataset = (BoundedDataset) datasets.get(pcollection); + return boundedDataset.getValues(pcollection); } @Override public void close(boolean gracefully) { - // graceful stop is used for streaming. + if (isStreamingPipeline()) { + // stop streaming context + if (timeout > 0) { + jssc.awaitTerminationOrTimeout(timeout); + } else { + jssc.awaitTermination(); + } + // stop streaming context gracefully, so checkpointing (and other computations) get to + // finish before shutdown. + jssc.stop(false, gracefully); + } + state = State.DONE; SparkContextFactory.stopSparkContext(jsc); } - /** The runner is blocking. */ @Override public State getState() { - return State.DONE; + return state; } @Override @@ -307,9 +257,19 @@ public State waitUntilFinish() { @Override public State waitUntilFinish(Duration duration) { - // This is no-op, since Spark runner in batch is blocking. - // It needs to be updated once SparkRunner supports non-blocking execution: - // https://issues.apache.org/jira/browse/BEAM-595 - return State.DONE; + if (isStreamingPipeline()) { + throw new UnsupportedOperationException( + "Spark runner EvaluationContext does not support waitUntilFinish for streaming " + + "pipelines."); + } else { + // This is no-op, since Spark runner in batch is blocking. + // It needs to be updated once SparkRunner supports non-blocking execution: + // https://issues.apache.org/jira/browse/BEAM-595 + return State.DONE; + } + } + + private boolean isStreamingPipeline() { + return jssc != null; } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 2e682c437862..10e3376d16ae 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -16,7 +16,6 @@ * limitations under the License. */ - package org.apache.beam.runners.spark.translation; import static com.google.common.base.Preconditions.checkState; @@ -73,11 +72,9 @@ import org.apache.spark.Accumulator; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaRDDLike; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; - import scala.Tuple2; @@ -101,11 +98,11 @@ public void evaluate(Flatten.FlattenPCollectionList transform, EvaluationCont } else { JavaRDD>[] rdds = new JavaRDD[pcs.size()]; for (int i = 0; i < rdds.length; i++) { - rdds[i] = (JavaRDD>) context.getRDD(pcs.get(i)); + rdds[i] = ((BoundedDataset) context.borrowDataset(pcs.get(i))).getRDD(); } unionRDD = context.getSparkContext().union(rdds); } - context.setOutputRDD(transform, unionRDD); + context.putDataset(transform, new BoundedDataset<>(unionRDD)); } }; } @@ -116,7 +113,7 @@ private static TransformEvaluator> groupByKey() { public void evaluate(GroupByKey transform, EvaluationContext context) { @SuppressWarnings("unchecked") JavaRDD>> inRDD = - (JavaRDD>>) context.getInputRDD(transform); + ((BoundedDataset>) context.borrowDataset(transform)).getRDD(); @SuppressWarnings("unchecked") final KvCoder coder = (KvCoder) context.getInput(transform).getCoder(); @@ -124,8 +121,9 @@ public void evaluate(GroupByKey transform, EvaluationContext context) { final Accumulator accum = AccumulatorSingleton.getInstance(context.getSparkContext()); - context.setOutputRDD(transform, GroupCombineFunctions.groupByKey(inRDD, accum, coder, - context.getRuntimeContext(), context.getInput(transform).getWindowingStrategy())); + context.putDataset(transform, + new BoundedDataset<>(GroupCombineFunctions.groupByKey(inRDD, accum, coder, + context.getRuntimeContext(), context.getInput(transform).getWindowingStrategy()))); } }; } @@ -146,16 +144,17 @@ public void evaluate(Combine.GroupedValues transform, CombineFnUtil.toFnWithContext(transform.getFn()); @SuppressWarnings("unchecked") - JavaRDDLike>>, ?> inRDD = - (JavaRDDLike>>, ?>) - context.getInputRDD(transform); + JavaRDD>>> inRDD = + ((BoundedDataset>>) + context.borrowDataset(transform)).getRDD(); - SparkKeyedCombineFn combineFnWithContext = + SparkKeyedCombineFn combineFnWithContext = new SparkKeyedCombineFn<>(fn, context.getRuntimeContext(), TranslationUtils.getSideInputs(transform.getSideInputs(), context), - windowingStrategy); - context.setOutputRDD(transform, inRDD.map(new TranslationUtils.CombineGroupedValues<>( - combineFnWithContext))); + windowingStrategy); + context.putDataset(transform, new BoundedDataset<>(inRDD.map(new TranslationUtils + .CombineGroupedValues<>( + combineFnWithContext)))); } }; } @@ -182,10 +181,11 @@ public void evaluate(Combine.Globally transform, EvaluationCont @SuppressWarnings("unchecked") JavaRDD> inRdd = - (JavaRDD>) context.getInputRDD(transform); + ((BoundedDataset) context.borrowDataset(transform)).getRDD(); - context.setOutputRDD(transform, GroupCombineFunctions.combineGlobally(inRdd, combineFn, - iCoder, oCoder, runtimeContext, windowingStrategy, sideInputs, hasDefault)); + context.putDataset(transform, new BoundedDataset<>(GroupCombineFunctions + .combineGlobally(inRdd, combineFn, + iCoder, oCoder, runtimeContext, windowingStrategy, sideInputs, hasDefault))); } }; } @@ -212,10 +212,11 @@ public void evaluate(Combine.PerKey transform, @SuppressWarnings("unchecked") JavaRDD>> inRdd = - (JavaRDD>>) context.getInputRDD(transform); + ((BoundedDataset>) context.borrowDataset(transform)).getRDD(); - context.setOutputRDD(transform, GroupCombineFunctions.combinePerKey(inRdd, combineFn, - inputCoder, runtimeContext, windowingStrategy, sideInputs)); + context.putDataset(transform, new BoundedDataset<>(GroupCombineFunctions + .combinePerKey(inRdd, combineFn, + inputCoder, runtimeContext, windowingStrategy, sideInputs))); } }; } @@ -225,8 +226,8 @@ private static TransformEvaluator @Override public void evaluate(ParDo.Bound transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike, ?> inRDD = - (JavaRDDLike, ?>) context.getInputRDD(transform); + JavaRDD> inRDD = + ((BoundedDataset) context.borrowDataset(transform)).getRDD(); @SuppressWarnings("unchecked") final WindowFn windowFn = (WindowFn) context.getInput(transform).getWindowingStrategy().getWindowFn(); @@ -234,9 +235,9 @@ public void evaluate(ParDo.Bound transform, EvaluationContext c AccumulatorSingleton.getInstance(context.getSparkContext()); Map, KV, BroadcastHelper>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); - context.setOutputRDD(transform, - inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(), - context.getRuntimeContext(), sideInputs, windowFn))); + context.putDataset(transform, + new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(), + context.getRuntimeContext(), sideInputs, windowFn)))); } }; } @@ -247,8 +248,8 @@ public void evaluate(ParDo.Bound transform, EvaluationContext c @Override public void evaluate(ParDo.BoundMulti transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike, ?> inRDD = - (JavaRDDLike, ?>) context.getInputRDD(transform); + JavaRDD> inRDD = + ((BoundedDataset) context.borrowDataset(transform)).getRDD(); @SuppressWarnings("unchecked") final WindowFn windowFn = (WindowFn) context.getInput(transform).getWindowingStrategy().getWindowFn(); @@ -268,7 +269,7 @@ public void evaluate(ParDo.BoundMulti transform, EvaluationCont // Object is the best we can do since different outputs can have different tags JavaRDD> values = (JavaRDD>) (JavaRDD) filtered.values(); - context.setRDD(e.getValue(), values); + context.putDataset(e.getValue(), new BoundedDataset<>(values)); } } }; @@ -281,8 +282,8 @@ private static TransformEvaluator> readText() { public void evaluate(TextIO.Read.Bound transform, EvaluationContext context) { String pattern = transform.getFilepattern(); JavaRDD> rdd = context.getSparkContext().textFile(pattern) - .map(WindowingHelpers.windowFunction()); - context.setOutputRDD(transform, rdd); + .map(WindowingHelpers.windowFunction()); + context.putDataset(transform, new BoundedDataset<>(rdd)); } }; } @@ -293,7 +294,7 @@ private static TransformEvaluator> writeText() { public void evaluate(TextIO.Write.Bound transform, EvaluationContext context) { @SuppressWarnings("unchecked") JavaPairRDD last = - ((JavaRDDLike, ?>) context.getInputRDD(transform)) + ((BoundedDataset) context.borrowDataset(transform)).getRDD() .map(WindowingHelpers.unwindowFunction()) .mapToPair(new PairFunction() { @@ -331,7 +332,7 @@ public T call(AvroKey key) { return key.datum(); } }).map(WindowingHelpers.windowFunction()); - context.setOutputRDD(transform, rdd); + context.putDataset(transform, new BoundedDataset<>(rdd)); } }; } @@ -349,7 +350,7 @@ public void evaluate(AvroIO.Write.Bound transform, EvaluationContext context) AvroJob.setOutputKeySchema(job, transform.getSchema()); @SuppressWarnings("unchecked") JavaPairRDD, NullWritable> last = - ((JavaRDDLike, ?>) context.getInputRDD(transform)) + ((BoundedDataset) context.borrowDataset(transform)).getRDD() .map(WindowingHelpers.unwindowFunction()) .mapToPair(new PairFunction, NullWritable>() { @Override @@ -377,7 +378,7 @@ public void evaluate(Read.Bounded transform, EvaluationContext context) { JavaRDD> input = new SourceRDD.Bounded<>( jsc.sc(), transform.getSource(), runtimeContext).toJavaRDD(); // cache to avoid re-evaluation of the source by Spark's lazy DAG evaluation. - context.setOutputRDD(transform, input.cache()); + context.putDataset(transform, new BoundedDataset<>(input.cache())); } }; } @@ -388,7 +389,7 @@ private static TransformEvaluator> readHadoop() public void evaluate(HadoopIO.Read.Bound transform, EvaluationContext context) { String pattern = transform.getFilepattern(); JavaSparkContext jsc = context.getSparkContext(); - @SuppressWarnings ("unchecked") + @SuppressWarnings("unchecked") JavaPairRDD file = jsc.newAPIHadoopFile(pattern, transform.getFormatClass(), transform.getKeyClass(), transform.getValueClass(), @@ -400,7 +401,7 @@ public KV call(Tuple2 t2) throws Exception { return KV.of(t2._1(), t2._2()); } }).map(WindowingHelpers.>windowFunction()); - context.setOutputRDD(transform, rdd); + context.putDataset(transform, new BoundedDataset<>(rdd)); } }; } @@ -410,8 +411,8 @@ private static TransformEvaluator> writeHadoop @Override public void evaluate(HadoopIO.Write.Bound transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaPairRDD last = ((JavaRDDLike>, ?>) context - .getInputRDD(transform)) + JavaPairRDD last = ((BoundedDataset>) context.borrowDataset(transform)) + .getRDD() .map(WindowingHelpers.>unwindowFunction()) .mapToPair(new PairFunction, K, V>() { @Override @@ -492,20 +493,20 @@ private static TransformEvaluator> @Override public void evaluate(Window.Bound transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike, ?> inRDD = - (JavaRDDLike, ?>) context.getInputRDD(transform); + JavaRDD> inRDD = + ((BoundedDataset) context.borrowDataset(transform)).getRDD(); if (TranslationUtils.skipAssignWindows(transform, context)) { - context.setOutputRDD(transform, inRDD); + context.putDataset(transform, new BoundedDataset<>(inRDD)); } else { @SuppressWarnings("unchecked") WindowFn windowFn = (WindowFn) transform.getWindowFn(); OldDoFn addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); Accumulator accum = AccumulatorSingleton.getInstance(context.getSparkContext()); - context.setOutputRDD(transform, - inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn, - context.getRuntimeContext(), null, null))); + context.putDataset(transform, + new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn, + context.getRuntimeContext(), null, null)))); } } }; @@ -519,7 +520,7 @@ public void evaluate(Create.Values transform, EvaluationContext context) { // Use a coder to convert the objects in the PCollection to byte arrays, so they // can be transferred over the network. Coder coder = context.getOutput(transform).getCoder(); - context.setOutputRDDFromValues(transform, elems, coder); + context.putBoundedDatasetFromValues(transform, elems, coder); } }; } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java index a670f61a3130..e5ad19f18aa8 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java @@ -22,6 +22,7 @@ import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.SparkRunner; +import org.apache.beam.runners.spark.translation.EvaluationContext; import org.apache.beam.runners.spark.translation.SparkContextFactory; import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; import org.apache.beam.runners.spark.translation.TransformTranslator; @@ -53,7 +54,7 @@ public SparkRunnerStreamingContextFactory(Pipeline pipeline, SparkPipelineOption this.options = options; } - private StreamingEvaluationContext ctxt; + private EvaluationContext ctxt; @Override public JavaStreamingContext create() { @@ -71,7 +72,7 @@ public JavaStreamingContext create() { JavaSparkContext jsc = SparkContextFactory.getSparkContext(options); JavaStreamingContext jssc = new JavaStreamingContext(jsc, batchDuration); - ctxt = new StreamingEvaluationContext(jsc, pipeline, jssc, + ctxt = new EvaluationContext(jsc, pipeline, jssc, options.getTimeout()); pipeline.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt)); ctxt.computeOutputs(); @@ -94,7 +95,7 @@ public JavaStreamingContext create() { return jssc; } - public StreamingEvaluationContext getCtxt() { + public EvaluationContext getCtxt() { return ctxt; } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java deleted file mode 100644 index bfba3163eaf6..000000000000 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.translation.streaming; - - -import com.google.common.collect.Iterables; - -import java.io.IOException; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Queue; -import java.util.Set; -import java.util.concurrent.LinkedBlockingQueue; -import org.apache.beam.runners.spark.coders.CoderHelpers; -import org.apache.beam.runners.spark.translation.EvaluationContext; -import org.apache.beam.runners.spark.translation.SparkRuntimeContext; -import org.apache.beam.runners.spark.translation.WindowingHelpers; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.PValue; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaRDDLike; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.VoidFunction; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaDStreamLike; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.joda.time.Duration; - - -/** - * Streaming evaluation context helps to handle streaming. - */ -public class StreamingEvaluationContext extends EvaluationContext { - - private final JavaStreamingContext jssc; - private final long timeout; - private final Map> pstreams = new LinkedHashMap<>(); - private final Set> leafStreams = new LinkedHashSet<>(); - - public StreamingEvaluationContext(JavaSparkContext jsc, Pipeline pipeline, - JavaStreamingContext jssc, long timeout) { - super(jsc, pipeline); - this.jssc = jssc; - this.timeout = timeout; - } - - /** - * DStream holder Can also crate a DStream from a supplied queue of values, but mainly for - * testing. - */ - private class DStreamHolder { - - private Iterable> values; - private Coder coder; - private JavaDStream> dStream; - - DStreamHolder(Iterable> values, Coder coder) { - this.values = values; - this.coder = coder; - } - - DStreamHolder(JavaDStream> dStream) { - this.dStream = dStream; - } - - @SuppressWarnings("unchecked") - JavaDStream> getDStream() { - if (dStream == null) { - WindowedValue.ValueOnlyWindowedValueCoder windowCoder = - WindowedValue.getValueOnlyCoder(coder); - // create the DStream from queue - Queue>> rddQueue = new LinkedBlockingQueue<>(); - JavaRDD> lastRDD = null; - for (Iterable v : values) { - Iterable> windowedValues = - Iterables.transform(v, WindowingHelpers.windowValueFunction()); - JavaRDD> rdd = getSparkContext().parallelize( - CoderHelpers.toByteArrays(windowedValues, windowCoder)).map( - CoderHelpers.fromByteFunction(windowCoder)); - rddQueue.offer(rdd); - lastRDD = rdd; - } - // create dstream from queue, one at a time, - // with last as default in case batches repeat (graceful stops for example). - // if the stream is empty, avoid creating a default empty RDD. - // mainly for unit test so no reason to have this configurable. - dStream = lastRDD != null ? jssc.queueStream(rddQueue, true, lastRDD) - : jssc.queueStream(rddQueue, true); - } - return dStream; - } - } - - void setDStreamFromQueue( - PTransform transform, Iterable> values, Coder coder) { - pstreams.put((PValue) getOutput(transform), new DStreamHolder<>(values, coder)); - } - - void setStream(PTransform transform, JavaDStream> dStream) { - setStream((PValue) getOutput(transform), dStream); - } - - void setStream(PValue pvalue, JavaDStream> dStream) { - DStreamHolder dStreamHolder = new DStreamHolder<>(dStream); - pstreams.put(pvalue, dStreamHolder); - leafStreams.add(dStreamHolder); - } - - boolean hasStream(PTransform transform) { - PValue pvalue = (PValue) getInput(transform); - return hasStream(pvalue); - } - - boolean hasStream(PValue pvalue) { - return pstreams.containsKey(pvalue); - } - - JavaDStreamLike getStream(PTransform transform) { - return getStream((PValue) getInput(transform)); - } - - JavaDStreamLike getStream(PValue pvalue) { - DStreamHolder dStreamHolder = pstreams.get(pvalue); - JavaDStreamLike dStream = dStreamHolder.getDStream(); - leafStreams.remove(dStreamHolder); - return dStream; - } - - // used to set the RDD from the DStream in the RDDHolder for transformation - void setInputRDD( - PTransform transform, JavaRDDLike, ?> rdd) { - setRDD((PValue) getInput(transform), rdd); - } - - // used to get the RDD transformation output and use it as the DStream transformation output - JavaRDDLike getOutputRDD(PTransform transform) { - return getRDD((PValue) getOutput(transform)); - } - - public JavaStreamingContext getStreamingContext() { - return jssc; - } - - @Override - public void computeOutputs() { - super.computeOutputs(); // in case the pipeline contains bounded branches as well. - for (DStreamHolder streamHolder : leafStreams) { - computeOutput(streamHolder); - } // force a DStream action - } - - private static void computeOutput(DStreamHolder streamHolder) { - JavaDStream> dStream = streamHolder.getDStream(); - // cache in DStream level not RDD - // because there could be a difference in StorageLevel if the DStream is windowed. - dStream.dstream().cache(); - dStream.foreachRDD(new VoidFunction>>() { - @Override - public void call(JavaRDD> rdd) throws Exception { - rdd.count(); - } - }); - } - - @Override - public void close(boolean gracefully) { - if (timeout > 0) { - jssc.awaitTerminationOrTimeout(timeout); - } else { - jssc.awaitTermination(); - } - // stop streaming context gracefully, so checkpointing (and other computations) get to - // finish before shutdown. - jssc.stop(false, gracefully); - state = State.DONE; - super.close(false); - } - - private State state = State.RUNNING; - - @Override - public State getState() { - return state; - } - - @Override - public State cancel() throws IOException { - throw new UnsupportedOperationException( - "Spark runner StreamingEvaluationContext does not support cancel."); - } - - @Override - public State waitUntilFinish() { - throw new UnsupportedOperationException( - "Spark runner StreamingEvaluationContext does not support waitUntilFinish."); - } - - @Override - public State waitUntilFinish(Duration duration) { - throw new UnsupportedOperationException( - "Spark runner StreamingEvaluationContext does not support waitUntilFinish."); - } - - //---------------- override in order to expose in package - @Override - protected InputT getInput(PTransform transform) { - return super.getInput(transform); - } - @Override - protected OutputT getOutput(PTransform transform) { - return super.getOutput(transform); - } - - @Override - protected JavaSparkContext getSparkContext() { - return super.getSparkContext(); - } - - @Override - protected SparkRuntimeContext getRuntimeContext() { - return super.getRuntimeContext(); - } - - @Override - public void setCurrentTransform(AppliedPTransform transform) { - super.setCurrentTransform(transform); - } - - @Override - protected AppliedPTransform getCurrentTransform() { - return super.getCurrentTransform(); - } - - @Override - protected void setOutputRDD(PTransform transform, - JavaRDDLike, ?> rdd) { - super.setOutputRDD(transform, rdd); - } - - @Override - protected void setOutputRDDFromValues(PTransform transform, Iterable values, - Coder coder) { - super.setOutputRDDFromValues(transform, values, coder); - } - - @Override - protected boolean hasOutputRDD(PTransform transform) { - return super.hasOutputRDD(transform); - } -} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 71c27df6cd31..d536e9a31961 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -29,6 +29,8 @@ import org.apache.beam.runners.spark.io.ConsoleIO; import org.apache.beam.runners.spark.io.CreateStream; import org.apache.beam.runners.spark.io.SparkUnboundedSource; +import org.apache.beam.runners.spark.translation.BoundedDataset; +import org.apache.beam.runners.spark.translation.Dataset; import org.apache.beam.runners.spark.translation.DoFnFunction; import org.apache.beam.runners.spark.translation.EvaluationContext; import org.apache.beam.runners.spark.translation.GroupCombineFunctions; @@ -71,15 +73,13 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaDStreamLike; import org.apache.spark.streaming.api.java.JavaPairDStream; - /** * Supports translation between a Beam transform, and Spark's operations on DStreams. */ -public final class StreamingTransformTranslator { +final class StreamingTransformTranslator { private StreamingTransformTranslator() { } @@ -89,9 +89,8 @@ private static TransformEvaluator> print() { @Override public void evaluate(ConsoleIO.Write.Unbound transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaDStreamLike, ?, JavaRDD>> dstream = - (JavaDStreamLike, ?, JavaRDD>>) - ((StreamingEvaluationContext) context).getStream(transform); + JavaDStream> dstream = + ((UnboundedDataset) (context).borrowDataset(transform)).getDStream(); dstream.map(WindowingHelpers.unwindowFunction()).print(transform.getNum()); } }; @@ -101,9 +100,9 @@ private static TransformEvaluator> readUnbounded() { return new TransformEvaluator>() { @Override public void evaluate(Read.Unbounded transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - sec.setStream(transform, SparkUnboundedSource.read(sec.getStreamingContext(), - sec.getRuntimeContext(), transform.getSource())); + context.putDataset(transform, + new UnboundedDataset<>(SparkUnboundedSource.read(context.getStreamingContext(), + context.getRuntimeContext(), transform.getSource()))); } }; } @@ -112,10 +111,9 @@ private static TransformEvaluator> createFromQu return new TransformEvaluator>() { @Override public void evaluate(CreateStream.QueuedValues transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; Iterable> values = transform.getQueuedValues(); - Coder coder = sec.getOutput(transform).getCoder(); - sec.setDStreamFromQueue(transform, values, coder); + Coder coder = context.getOutput(transform).getCoder(); + context.setUnboundedDatasetFromQueue(transform, values, coder); } }; } @@ -125,23 +123,23 @@ private static TransformEvaluator> flatten @SuppressWarnings("unchecked") @Override public void evaluate(Flatten.FlattenPCollectionList transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - PCollectionList pcs = sec.getInput(transform); + PCollectionList pcs = context.getInput(transform); // since this is a streaming pipeline, at least one of the PCollections to "flatten" are // unbounded, meaning it represents a DStream. // So we could end up with an unbounded unified DStream. final List>> rdds = new ArrayList<>(); final List>> dStreams = new ArrayList<>(); - for (PCollection pcol: pcs.getAll()) { - if (sec.hasStream(pcol)) { - dStreams.add((JavaDStream>) sec.getStream(pcol)); + for (PCollection pcol : pcs.getAll()) { + Dataset dataset = context.borrowDataset(pcol); + if (dataset instanceof UnboundedDataset) { + dStreams.add(((UnboundedDataset) dataset).getDStream()); } else { - rdds.add((JavaRDD>) context.getRDD(pcol)); + rdds.add(((BoundedDataset) dataset).getRDD()); } } // start by unifying streams into a single stream. JavaDStream> unifiedStreams = - sec.getStreamingContext().union(dStreams.remove(0), dStreams); + context.getStreamingContext().union(dStreams.remove(0), dStreams); // now unify in RDDs. if (rdds.size() > 0) { JavaDStream> joined = unifiedStreams.transform( @@ -152,9 +150,9 @@ public JavaRDD> call(JavaRDD> streamRdd) return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds); } }); - sec.setStream(transform, joined); + context.putDataset(transform, new UnboundedDataset<>(joined)); } else { - sec.setStream(transform, unifiedStreams); + context.putDataset(transform, new UnboundedDataset<>(unifiedStreams)); } } }; @@ -164,12 +162,11 @@ private static TransformEvaluator> return new TransformEvaluator>() { @Override public void evaluate(Window.Bound transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; @SuppressWarnings("unchecked") WindowFn windowFn = (WindowFn) transform.getWindowFn(); @SuppressWarnings("unchecked") JavaDStream> dStream = - (JavaDStream>) sec.getStream(transform); + ((UnboundedDataset) context.borrowDataset(transform)).getDStream(); // get the right window durations. Duration windowDuration; Duration slideDuration; @@ -188,10 +185,10 @@ public void evaluate(Window.Bound transform, EvaluationContext context) { dStream.window(windowDuration, slideDuration); //--- then we apply windowing to the elements if (TranslationUtils.skipAssignWindows(transform, context)) { - sec.setStream(transform, windowedDStream); + context.putDataset(transform, new UnboundedDataset<>(windowedDStream)); } else { final OldDoFn addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); JavaDStream> outStream = windowedDStream.transform( new Function>, JavaRDD>>() { @Override @@ -202,7 +199,7 @@ public JavaRDD> call(JavaRDD> rdd) throws Exce new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null, null)); } }); - sec.setStream(transform, outStream); + context.putDataset(transform, new UnboundedDataset<>(outStream)); } } }; @@ -212,18 +209,16 @@ private static TransformEvaluator> groupByKey() { return new TransformEvaluator>() { @Override public void evaluate(GroupByKey transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - @SuppressWarnings("unchecked") JavaDStream>> dStream = - (JavaDStream>>) sec.getStream(transform); + ((UnboundedDataset>) context.borrowDataset(transform)).getDStream(); @SuppressWarnings("unchecked") - final KvCoder coder = (KvCoder) sec.getInput(transform).getCoder(); + final KvCoder coder = (KvCoder) context.getInput(transform).getCoder(); - final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final WindowingStrategy windowingStrategy = - sec.getInput(transform).getWindowingStrategy(); + context.getInput(transform).getWindowingStrategy(); JavaDStream>>> outStream = dStream.transform(new Function>>, @@ -237,7 +232,7 @@ public JavaRDD>>> call( windowingStrategy); } }); - sec.setStream(transform, outStream); + context.putDataset(transform, new UnboundedDataset<>(outStream)); } }; } @@ -245,29 +240,29 @@ public JavaRDD>>> call( private static TransformEvaluator> combineGrouped() { return new TransformEvaluator>() { + @SuppressWarnings("unchecked") @Override public void evaluate(Combine.GroupedValues transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; // get the applied combine function. PCollection>> input = - sec.getInput(transform); + context.getInput(transform); WindowingStrategy windowingStrategy = input.getWindowingStrategy(); - @SuppressWarnings("unchecked") final CombineWithContext.KeyedCombineFnWithContext fn = (CombineWithContext.KeyedCombineFnWithContext) CombineFnUtil.toFnWithContext(transform.getFn()); - @SuppressWarnings("unchecked") JavaDStream>>> dStream = - (JavaDStream>>>) sec.getStream(transform); + ((UnboundedDataset>>) context.borrowDataset(transform)) + .getDStream(); - SparkKeyedCombineFn combineFnWithContext = - new SparkKeyedCombineFn<>(fn, sec.getRuntimeContext(), + SparkKeyedCombineFn combineFnWithContext = + new SparkKeyedCombineFn<>(fn, context.getRuntimeContext(), TranslationUtils.getSideInputs(transform.getSideInputs(), context), - windowingStrategy); - sec.setStream(transform, dStream.map(new TranslationUtils.CombineGroupedValues<>( - combineFnWithContext))); + windowingStrategy); + context.putDataset(transform, new UnboundedDataset<>(dStream.map(new TranslationUtils + .CombineGroupedValues<>( + combineFnWithContext)))); } }; } @@ -276,26 +271,24 @@ public void evaluate(Combine.GroupedValues transform, combineGlobally() { return new TransformEvaluator>() { + @SuppressWarnings("unchecked") @Override public void evaluate(Combine.Globally transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - final PCollection input = sec.getInput(transform); + final PCollection input = context.getInput(transform); // serializable arguments to pass. - final Coder iCoder = sec.getInput(transform).getCoder(); - final Coder oCoder = sec.getOutput(transform).getCoder(); - @SuppressWarnings("unchecked") + final Coder iCoder = context.getInput(transform).getCoder(); + final Coder oCoder = context.getOutput(transform).getCoder(); final CombineWithContext.CombineFnWithContext combineFn = (CombineWithContext.CombineFnWithContext) CombineFnUtil.toFnWithContext(transform.getFn()); final WindowingStrategy windowingStrategy = input.getWindowingStrategy(); - final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final Map, KV, BroadcastHelper>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); final boolean hasDefault = transform.isInsertDefault(); - @SuppressWarnings("unchecked") JavaDStream> dStream = - (JavaDStream>) sec.getStream(transform); + ((UnboundedDataset) context.borrowDataset(transform)).getDStream(); JavaDStream> outStream = dStream.transform( new Function>, JavaRDD>>() { @@ -307,7 +300,7 @@ public JavaRDD> call(JavaRDD> rdd) } }); - sec.setStream(transform, outStream); + context.putDataset(transform, new UnboundedDataset<>(outStream)); } }; } @@ -315,27 +308,24 @@ public JavaRDD> call(JavaRDD> rdd) private static TransformEvaluator> combinePerKey() { return new TransformEvaluator>() { + @SuppressWarnings("unchecked") @Override public void evaluate(final Combine.PerKey transform, final EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - final PCollection> input = sec.getInput(transform); + final PCollection> input = context.getInput(transform); // serializable arguments to pass. - @SuppressWarnings("unchecked") final KvCoder inputCoder = - (KvCoder) sec.getInput(transform).getCoder(); - @SuppressWarnings("unchecked") + (KvCoder) context.getInput(transform).getCoder(); final CombineWithContext.KeyedCombineFnWithContext combineFn = (CombineWithContext.KeyedCombineFnWithContext) CombineFnUtil.toFnWithContext(transform.getFn()); final WindowingStrategy windowingStrategy = input.getWindowingStrategy(); - final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final Map, KV, BroadcastHelper>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); - @SuppressWarnings("unchecked") JavaDStream>> dStream = - (JavaDStream>>) sec.getStream(transform); + ((UnboundedDataset>) context.borrowDataset(transform)).getDStream(); JavaDStream>> outStream = dStream.transform(new Function>>, @@ -347,26 +337,24 @@ public JavaRDD>> call( windowingStrategy, sideInputs); } }); - sec.setStream(transform, outStream); + context.putDataset(transform, new UnboundedDataset<>(outStream)); } }; } private static TransformEvaluator> parDo() { return new TransformEvaluator>() { + @SuppressWarnings("unchecked") @Override public void evaluate(final ParDo.Bound transform, final EvaluationContext context) { - final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final Map, KV, BroadcastHelper>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); - @SuppressWarnings("unchecked") final WindowFn windowFn = - (WindowFn) sec.getInput(transform).getWindowingStrategy().getWindowFn(); - @SuppressWarnings("unchecked") + (WindowFn) context.getInput(transform).getWindowingStrategy().getWindowFn(); JavaDStream> dStream = - (JavaDStream>) sec.getStream(transform); + ((UnboundedDataset) context.borrowDataset(transform)).getDStream(); JavaDStream> outStream = dStream.transform(new Function>, @@ -381,7 +369,7 @@ public JavaRDD> call(JavaRDD> rdd) } }); - sec.setStream(transform, outStream); + context.putDataset(transform, new UnboundedDataset<>(outStream)); } }; } @@ -392,16 +380,15 @@ public JavaRDD> call(JavaRDD> rdd) @Override public void evaluate(final ParDo.BoundMulti transform, final EvaluationContext context) { - final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final Map, KV, BroadcastHelper>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); @SuppressWarnings("unchecked") final WindowFn windowFn = - (WindowFn) sec.getInput(transform).getWindowingStrategy().getWindowFn(); + (WindowFn) context.getInput(transform).getWindowingStrategy().getWindowFn(); @SuppressWarnings("unchecked") JavaDStream> dStream = - (JavaDStream>) sec.getStream(transform); + ((UnboundedDataset) context.borrowDataset(transform)).getDStream(); JavaPairDStream, WindowedValue> all = dStream.transformToPair( new Function>, JavaPairRDD, WindowedValue>>() { @@ -414,7 +401,7 @@ public JavaPairRDD, WindowedValue> call( runtimeContext, transform.getMainOutputTag(), sideInputs, windowFn)); } }).cache(); - PCollectionTuple pct = sec.getOutput(transform); + PCollectionTuple pct = context.getOutput(transform); for (Map.Entry, PCollection> e : pct.getAll().entrySet()) { @SuppressWarnings("unchecked") JavaPairDStream, WindowedValue> filtered = @@ -424,7 +411,7 @@ public JavaPairRDD, WindowedValue> call( JavaDStream> values = (JavaDStream>) (JavaDStream) TranslationUtils.dStreamValues(filtered); - sec.setStream(e.getValue(), values); + context.putDataset(e.getValue(), new UnboundedDataset<>(values)); } } }; diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java new file mode 100644 index 000000000000..67adee2b18aa --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation.streaming; + +import com.google.common.collect.Iterables; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.runners.spark.translation.Dataset; +import org.apache.beam.runners.spark.translation.WindowingHelpers; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + + +/** + * DStream holder Can also crate a DStream from a supplied queue of values, but mainly for testing. + */ +public class UnboundedDataset implements Dataset { + // only set if creating a DStream from a static collection + @Nullable private transient JavaStreamingContext jssc; + + private Iterable> values; + private Coder coder; + private JavaDStream> dStream; + + UnboundedDataset(JavaDStream> dStream) { + this.dStream = dStream; + } + + public UnboundedDataset(Iterable> values, JavaStreamingContext jssc, Coder coder) { + this.values = values; + this.jssc = jssc; + this.coder = coder; + } + + @SuppressWarnings("ConstantConditions") + JavaDStream> getDStream() { + if (dStream == null) { + WindowedValue.ValueOnlyWindowedValueCoder windowCoder = + WindowedValue.getValueOnlyCoder(coder); + // create the DStream from queue + Queue>> rddQueue = new LinkedBlockingQueue<>(); + JavaRDD> lastRDD = null; + for (Iterable v : values) { + Iterable> windowedValues = + Iterables.transform(v, WindowingHelpers.windowValueFunction()); + JavaRDD> rdd = jssc.sc().parallelize( + CoderHelpers.toByteArrays(windowedValues, windowCoder)).map( + CoderHelpers.fromByteFunction(windowCoder)); + rddQueue.offer(rdd); + lastRDD = rdd; + } + // create DStream from queue, one at a time, + // with last as default in case batches repeat (graceful stops for example). + // if the stream is empty, avoid creating a default empty RDD. + // mainly for unit test so no reason to have this configurable. + dStream = lastRDD != null ? jssc.queueStream(rddQueue, true, lastRDD) + : jssc.queueStream(rddQueue, true); + } + return dStream; + } + + @Override + public void cache() { + dStream.cache(); + } + + @Override + public void action() { + dStream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) throws Exception { + rdd.count(); + } + }); + } + + @Override + public void setName(String name) { + // ignore + } +} From 1f614776f83b8f9536332e0686d53923e29677fe Mon Sep 17 00:00:00 2001 From: Aviem Zur Date: Sun, 13 Nov 2016 14:16:51 +0200 Subject: [PATCH 2/2] PR 1291 review changes. --- .../runners/spark/translation/EvaluationContext.java | 10 ++++++---- .../runners/spark/translation/TransformTranslator.java | 6 +++--- .../streaming/StreamingTransformTranslator.java | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 3e09b3994c7d..aaf757318fff 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -71,6 +71,7 @@ public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) { this.jsc = jsc; this.pipeline = pipeline; this.runtime = new SparkRuntimeContext(pipeline, jsc); + // A batch pipeline is blocking by nature this.state = State.DONE; } @@ -137,11 +138,12 @@ void putBoundedDatasetFromValues(PTransform transform, Iterable val datasets.put((PValue) getOutput(transform), new BoundedDataset<>(values, jsc, coder)); } - public void setUnboundedDatasetFromQueue( + public void putUnboundedDatasetFromQueue( PTransform transform, Iterable> values, Coder coder) { datasets.put((PValue) getOutput(transform), new UnboundedDataset<>(values, jssc, coder)); } - void setPView(PValue view, Iterable> value) { + + void putPView(PValue view, Iterable> value) { pview.put(view, value); } @@ -259,8 +261,8 @@ public State waitUntilFinish() { public State waitUntilFinish(Duration duration) { if (isStreamingPipeline()) { throw new UnsupportedOperationException( - "Spark runner EvaluationContext does not support waitUntilFinish for streaming " + - "pipelines."); + "Spark runner EvaluationContext does not support waitUntilFinish for streaming " + + "pipelines."); } else { // This is no-op, since Spark runner in batch is blocking. // It needs to be updated once SparkRunner supports non-blocking execution: diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 10e3376d16ae..c902ee30a9a9 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -531,7 +531,7 @@ private static TransformEvaluator> viewAsSingleton() { public void evaluate(View.AsSingleton transform, EvaluationContext context) { Iterable> iter = context.getWindowedValues(context.getInput(transform)); - context.setPView(context.getOutput(transform), iter); + context.putPView(context.getOutput(transform), iter); } }; } @@ -542,7 +542,7 @@ private static TransformEvaluator> viewAsIter() { public void evaluate(View.AsIterable transform, EvaluationContext context) { Iterable> iter = context.getWindowedValues(context.getInput(transform)); - context.setPView(context.getOutput(transform), iter); + context.putPView(context.getOutput(transform), iter); } }; } @@ -555,7 +555,7 @@ public void evaluate(View.CreatePCollectionView transform, EvaluationContext context) { Iterable> iter = context.getWindowedValues(context.getInput(transform)); - context.setPView(context.getOutput(transform), iter); + context.putPView(context.getOutput(transform), iter); } }; } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index d536e9a31961..b30f0793135c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -113,7 +113,7 @@ private static TransformEvaluator> createFromQu public void evaluate(CreateStream.QueuedValues transform, EvaluationContext context) { Iterable> values = transform.getQueuedValues(); Coder coder = context.getOutput(transform).getCoder(); - context.setUnboundedDatasetFromQueue(transform, values, coder); + context.putUnboundedDatasetFromQueue(transform, values, coder); } }; }