
Commit c9e55a4

This closes #2328
aviemzur committed Mar 26, 2017
2 parents 348d335 + b32f048
Showing 2 changed files with 25 additions and 16 deletions.
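
In short: GroupCombineFunctions.combineGlobally no longer rejects empty RDDs with a precondition. It detects the empty input itself and returns Optional<Iterable<WindowedValue<AccumT>>>, so the Combine.Globally translator can fall back to the combine fn's default value (or an empty RDD) when the result is absent.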
GroupCombineFunctions.java

@@ -18,8 +18,7 @@
 
 package org.apache.beam.runners.spark.translation;
 
-import static com.google.common.base.Preconditions.checkArgument;
-
+import com.google.common.base.Optional;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.util.ByteArray;
 import org.apache.beam.sdk.coders.Coder;
@@ -67,14 +66,12 @@ public static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> g
   /**
    * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} transformation.
    */
-  public static <InputT, AccumT> Iterable<WindowedValue<AccumT>> combineGlobally(
+  public static <InputT, AccumT> Optional<Iterable<WindowedValue<AccumT>>> combineGlobally(
       JavaRDD<WindowedValue<InputT>> rdd,
       final SparkGlobalCombineFn<InputT, AccumT, ?> sparkCombineFn,
       final Coder<InputT> iCoder,
       final Coder<AccumT> aCoder,
       final WindowingStrategy<?, ?> windowingStrategy) {
-    checkArgument(!rdd.isEmpty(), "CombineGlobally computation should be skipped for empty RDDs.");
-
     // coders.
     final WindowedValue.FullWindowedValueCoder<InputT> wviCoder =
         WindowedValue.FullWindowedValueCoder.of(iCoder,
@@ -93,6 +90,11 @@ public static <InputT, AccumT> Iterable<WindowedValue<AccumT>> combineGlobally(
     //---- AccumT: A
     //---- InputT: I
     JavaRDD<byte[]> inputRDDBytes = rdd.map(CoderHelpers.toByteFunction(wviCoder));
+
+    if (inputRDDBytes.isEmpty()) {
+      return Optional.absent();
+    }
+
     /*Itr<WV<A>>*/ byte[] accumulatedBytes = inputRDDBytes.aggregate(
         CoderHelpers.toByteArray(sparkCombineFn.zeroValue(), iterAccumCoder),
         new Function2</*A*/ byte[], /*I*/ byte[], /*A*/ byte[]>() {
@@ -115,7 +117,8 @@ public static <InputT, AccumT> Iterable<WindowedValue<AccumT>> combineGlobally(
           }
         }
     );
-    return CoderHelpers.fromByteArray(accumulatedBytes, iterAccumCoder);
+
+    return Optional.of(CoderHelpers.fromByteArray(accumulatedBytes, iterAccumCoder));
   }
 
   /**
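
The net effect in GroupCombineFunctions: rather than a checkArgument that forced every caller to pre-screen with rdd.isEmpty(), combineGlobally now probes the serialized input once and reports an empty input as Guava's Optional.absent(), leaving the default-value decision to the caller. Below is a minimal sketch of that pattern outside Beam, assuming a local Spark master; the class name EmptyRddCombineDemo and the integer-sum combine are illustrative only.

import java.util.Collections;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import com.google.common.base.Optional;

public class EmptyRddCombineDemo {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("empty-rdd-combine-demo");
    try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
      JavaRDD<Integer> input = jsc.parallelize(Collections.<Integer>emptyList());

      // Same guard as the patched combineGlobally: probe emptiness up front and
      // signal "no result" with Optional.absent() instead of failing a precondition.
      Optional<Integer> combined = input.isEmpty()
          ? Optional.<Integer>absent()
          : Optional.of(input.aggregate(0, (acc, v) -> acc + v, (a, b) -> a + b));

      // The caller, not the combiner, decides what an absent result means
      // (the translator below substitutes the combine fn's default value).
      System.out.println(combined.isPresent() ? combined.get() : "absent: caller supplies default");
    }
  }
}

Note that isEmpty() is itself a small Spark job (it tries to take one element), so the probe is not free; the patched code runs it once, on inputRDDBytes, after the input has been mapped to bytes.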
TransformTranslator.java

@@ -27,6 +27,7 @@
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
+import com.google.common.base.Optional;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -259,9 +260,20 @@ public void evaluate(
             ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
 
         JavaRDD<WindowedValue<OutputT>> outRdd;
-        // handle empty input RDD, which will naturally skip the entire execution
-        // as Spark will not run on empty RDDs.
-        if (inRdd.isEmpty()) {
+
+        Optional<Iterable<WindowedValue<AccumT>>> maybeAccumulated =
+            GroupCombineFunctions.combineGlobally(inRdd, sparkCombineFn, iCoder, aCoder,
+                windowingStrategy);
+
+        if (maybeAccumulated.isPresent()) {
+          Iterable<WindowedValue<OutputT>> output =
+              sparkCombineFn.extractOutput(maybeAccumulated.get());
+          outRdd = context.getSparkContext()
+              .parallelize(CoderHelpers.toByteArrays(output, wvoCoder))
+              .map(CoderHelpers.fromByteFunction(wvoCoder));
+        } else {
+          // handle empty input RDD, which will naturally skip the entire execution
+          // as Spark will not run on empty RDDs.
           JavaSparkContext jsc = new JavaSparkContext(inRdd.context());
           if (hasDefault) {
             OutputT defaultValue = combineFn.defaultValue();
@@ -272,14 +284,8 @@
           } else {
             outRdd = jsc.emptyRDD();
           }
-        } else {
-          Iterable<WindowedValue<AccumT>> accumulated = GroupCombineFunctions.combineGlobally(
-              inRdd, sparkCombineFn, iCoder, aCoder, windowingStrategy);
-          Iterable<WindowedValue<OutputT>> output = sparkCombineFn.extractOutput(accumulated);
-          outRdd = context.getSparkContext()
-              .parallelize(CoderHelpers.toByteArrays(output, wvoCoder))
-              .map(CoderHelpers.fromByteFunction(wvoCoder));
         }
+
         context.putDataset(transform, new BoundedDataset<>(outRdd));
       }
 
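
On the caller side, the two arms of the Optional map directly onto Beam's Combine.Globally semantics: a present accumulator is extracted and parallelized into the output dataset, while an absent one yields the combine fn's default value, or an empty RDD when no default applies. A pipeline-level sketch of that user-visible contract follows; the class name EmptyGlobalCombineDemo is illustrative, and it assumes a Beam runner (SparkRunner, DirectRunner) on the classpath.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.PCollection;

public class EmptyGlobalCombineDemo {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // An empty bounded input: the case the absent branch above handles.
    PCollection<Integer> empty = p.apply(Create.empty(VarIntCoder.of()));

    // Sum.integersGlobally() carries a default value (0), so even with no
    // input the hasDefault branch must emit exactly one element.
    PCollection<Integer> sum = empty.apply(Sum.integersGlobally());
    PAssert.that(sum).containsInAnyOrder(0);

    p.run().waitUntilFinish();
  }
}

Appending .withoutDefaults() to the same transform should instead produce an empty output PCollection, matching the jsc.emptyRDD() branch.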
