From 6befb9caf5a219cd4ddb5b7dbee631c438de7ae6 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Wed, 27 Jul 2016 14:23:15 -0700 Subject: [PATCH] Replace ParDo with simpler transforms where possible There are a number of places in the Java SDK where we use ParDo.of(DoFn) when MapElements or other higher-level composites are applicable and readable. This change alters a number of those. --- .../apache/beam/sdk/transforms/Combine.java | 28 +++++++++---------- .../org/apache/beam/sdk/transforms/Count.java | 8 +++--- .../beam/sdk/transforms/FlatMapElements.java | 4 +-- .../apache/beam/sdk/transforms/Flatten.java | 12 ++++---- .../org/apache/beam/sdk/transforms/Keys.java | 8 +++--- .../apache/beam/sdk/transforms/KvSwap.java | 9 +++--- .../beam/sdk/transforms/MapElements.java | 16 ++++++++--- .../beam/sdk/transforms/RemoveDuplicates.java | 8 +++--- .../apache/beam/sdk/transforms/Values.java | 8 +++--- .../apache/beam/sdk/transforms/WithKeys.java | 9 +++--- .../beam/sdk/transforms/windowing/Window.java | 11 ++++---- .../org/apache/beam/sdk/PipelineTest.java | 12 ++++---- .../org/apache/beam/sdk/io/WriteTest.java | 4 ++- .../beam/sdk/transforms/MapElementsTest.java | 8 +++--- .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 10 ++++--- 15 files changed, 81 insertions(+), 74 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index a8258000b614..ebe9bed255c8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -2134,14 +2134,14 @@ public void processElement(ProcessContext c) { inputCoder.getValueCoder())) .setWindowingStrategyInternal(preCombineStrategy) .apply("PreCombineHot", Combine.perKey(hotPreCombine)) - .apply("StripNonce", ParDo.of( - new DoFn, AccumT>, - KV>>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(KV.of( - c.element().getKey().getKey(), - InputOrAccum.accum(c.element().getValue()))); + .apply("StripNonce", MapElements.via( + new SimpleFunction, AccumT>, + KV>>() { + @Override + public KV> apply(KV, AccumT> elem) { + return KV.of( + elem.getKey().getKey(), + InputOrAccum.accum(elem.getValue())); } })) .setCoder(KvCoder.of(inputCoder.getKeyCoder(), inputOrAccumCoder)) @@ -2150,12 +2150,12 @@ public void processElement(ProcessContext c) { PCollection>> preprocessedCold = split .get(cold) .setCoder(inputCoder) - .apply("PrepareCold", ParDo.of( - new DoFn, KV>>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(KV.of(c.element().getKey(), - InputOrAccum.input(c.element().getValue()))); + .apply("PrepareCold", MapElements.via( + new SimpleFunction, KV>>() { + @Override + public KV> apply(KV element) { + return KV.of(element.getKey(), + InputOrAccum.input(element.getValue())); } })) .setCoder(KvCoder.of(inputCoder.getKeyCoder(), inputOrAccumCoder)); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java index ac59c767504e..195c5d17ed88 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Count.java @@ -107,10 +107,10 @@ public PerElement() { } public PCollection> apply(PCollection input) { return input - .apply("Init", ParDo.of(new DoFn>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(KV.of(c.element(), (Void) null)); + .apply("Init", MapElements.via(new SimpleFunction>() { + @Override + public KV apply(T element) { + return KV.of(element, (Void) null); } })) .apply(Count.perKey()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java index 6f9e3d8ac078..2837c40cc3a6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java @@ -29,7 +29,7 @@ * {@link PCollection} and merging the results. */ public class FlatMapElements -extends PTransform, PCollection> { +extends PTransform, PCollection> { /** * For a {@code SerializableFunction>} {@code fn}, * returns a {@link PTransform} that applies {@code fn} to every element of the input @@ -130,7 +130,7 @@ private FlatMapElements( } @Override - public PCollection apply(PCollection input) { + public PCollection apply(PCollection input) { return input.apply( "FlatMap", ParDo.of( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index 7e09d7e4dd3b..f3f4f887078d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -173,13 +173,11 @@ public PCollection apply(PCollection> in) { @SuppressWarnings("unchecked") Coder elemCoder = ((IterableLikeCoder) inCoder).getElemCoder(); - return in.apply("FlattenIterables", ParDo.of( - new DoFn, T>() { - @ProcessElement - public void processElement(ProcessContext c) { - for (T i : c.element()) { - c.output(i); - } + return in.apply("FlattenIterables", FlatMapElements.via( + new SimpleFunction, Iterable>() { + @Override + public Iterable apply(Iterable element) { + return element; } })) .setCoder(elemCoder); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Keys.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Keys.java index 5ac1866a3590..2405adf41e4b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Keys.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Keys.java @@ -58,10 +58,10 @@ private Keys() { } @Override public PCollection apply(PCollection> in) { return - in.apply("Keys", ParDo.of(new DoFn, K>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(c.element().getKey()); + in.apply("Keys", MapElements.via(new SimpleFunction, K>() { + @Override + public K apply(KV kv) { + return kv.getKey(); } })); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/KvSwap.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/KvSwap.java index d4386d2a8107..2b81ebfdf1de 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/KvSwap.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/KvSwap.java @@ -62,11 +62,10 @@ private KvSwap() { } @Override public PCollection> apply(PCollection> in) { return - in.apply("KvSwap", ParDo.of(new DoFn, KV>() { - @ProcessElement - public void processElement(ProcessContext c) { - KV e = c.element(); - c.output(KV.of(e.getValue(), e.getKey())); + in.apply("KvSwap", MapElements.via(new SimpleFunction, KV>() { + @Override + public KV apply(KV kv) { + return KV.of(kv.getValue(), kv.getKey()); } })); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java index 17ad6e74a13f..73e4359831db 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java @@ -25,7 +25,7 @@ * {@code PTransform}s for mapping a simple function over the elements of a {@link PCollection}. */ public class MapElements -extends PTransform, PCollection> { +extends PTransform, PCollection> { /** * For a {@code SerializableFunction} {@code fn} and output type descriptor, @@ -44,8 +44,16 @@ public class MapElements * descriptor need not be provided. */ public static MissingOutputTypeDescriptor - via(SerializableFunction fn) { - return new MissingOutputTypeDescriptor<>(fn); + via(SerializableFunction fn) { + + // TypeDescriptor interacts poorly with the wildcards needed to correctly express + // covariance and contravariance in Java, so instead we cast it to an invariant + // function here. + @SuppressWarnings("unchecked") // safe covariant cast + SerializableFunction simplerFn = + (SerializableFunction) fn; + + return new MissingOutputTypeDescriptor<>(simplerFn); } /** @@ -103,7 +111,7 @@ private MapElements(SimpleFunction fn, Class fnClass) { } @Override - public PCollection apply(PCollection input) { + public PCollection apply(PCollection input) { return input.apply( "Map", ParDo.of( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/RemoveDuplicates.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/RemoveDuplicates.java index bba4b5130957..2744b14c8913 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/RemoveDuplicates.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/RemoveDuplicates.java @@ -85,10 +85,10 @@ public static WithRepresentativeValues withRepresentativeValueF @Override public PCollection apply(PCollection in) { return in - .apply("CreateIndex", ParDo.of(new DoFn>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(KV.of(c.element(), (Void) null)); + .apply("CreateIndex", MapElements.via(new SimpleFunction>() { + @Override + public KV apply(T element) { + return KV.of(element, (Void) null); } })) .apply(Combine.perKey( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Values.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Values.java index 34342db53c6b..d21d100764a0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Values.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Values.java @@ -58,10 +58,10 @@ private Values() { } @Override public PCollection apply(PCollection> in) { return - in.apply("Values", ParDo.of(new DoFn, V>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(c.element().getValue()); + in.apply("Values", MapElements.via(new SimpleFunction, V>() { + @Override + public V apply(KV kv) { + return kv.getValue(); } })); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java index 2a44963e6fc2..8b061f6bdc57 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java @@ -113,11 +113,10 @@ public WithKeys withKeyType(TypeDescriptor keyType) { @Override public PCollection> apply(PCollection in) { PCollection> result = - in.apply("AddKeys", ParDo.of(new DoFn>() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(KV.of(fn.apply(c.element()), - c.element())); + in.apply("AddKeys", MapElements.via(new SimpleFunction>() { + @Override + public KV apply(V element) { + return KV.of(fn.apply(element), element); } })); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java index c1b0237e7a69..9dd069cf9952 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/Window.java @@ -21,10 +21,10 @@ import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.NonDeterministicException; -import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.util.WindowingStrategy.AccumulationMode; @@ -645,10 +645,9 @@ public PCollection apply(PCollection input) { // We first apply a (trivial) transform to the input PCollection to produce a new // PCollection. This ensures that we don't modify the windowing strategy of the input // which may be used elsewhere. - .apply("Identity", ParDo.of(new DoFn() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(c.element()); + .apply("Identity", MapElements.via(new SimpleFunction() { + @Override public T apply(T element) { + return element; } })) // Then we modify the windowing strategy. diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java index 8b8649994a06..d7b3ac54de9d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java @@ -36,10 +36,10 @@ import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; @@ -146,10 +146,10 @@ public void testMultipleApply() { private static PTransform, PCollection> addSuffix( final String suffix) { - return ParDo.of(new DoFn() { - @ProcessElement - public void processElement(DoFn.ProcessContext c) { - c.output(c.element() + suffix); + return MapElements.via(new SimpleFunction() { + @Override + public String apply(String input) { + return input + suffix; } }); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java index 705b77cb3986..56877148ec28 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java @@ -90,7 +90,9 @@ public class WriteTest { // Static counts of the number of records per shard. private static List recordsPerShard = new ArrayList<>(); - private static final MapElements IDENTITY_MAP = + @SuppressWarnings("unchecked") // covariant cast + private static final PTransform, PCollection> IDENTITY_MAP = + (PTransform) MapElements.via(new SimpleFunction() { @Override public String apply(String input) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java index e86a1289f82b..7217bca663fa 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java @@ -233,7 +233,7 @@ public Integer apply(Integer input) { } @Test public void testSimpleFunctionDisplayData() { - SimpleFunction simpleFn = new SimpleFunction() { + SimpleFunction simpleFn = new SimpleFunction() { @Override public Integer apply(Integer input) { return input; @@ -255,17 +255,17 @@ public void populateDisplayData(DisplayData.Builder builder) { @Test @Category(RunnableOnService.class) public void testPrimitiveDisplayData() { - SimpleFunction mapFn = new SimpleFunction() { + SimpleFunction mapFn = new SimpleFunction() { @Override public Integer apply(Integer input) { return input; } }; - MapElements map = MapElements.via(mapFn); + MapElements map = MapElements.via(mapFn); DisplayDataEvaluator evaluator = DisplayDataEvaluator.create(); - Set displayData = evaluator.displayDataForPrimitiveTransforms(map); + Set displayData = evaluator.displayDataForPrimitiveTransforms(map); assertThat("MapElements should include the mapFn in its primitive display data", displayData, hasItem(hasDisplayItem("mapFn", mapFn.getClass()))); } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 2383105d2e61..8a0c7880e97d 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -34,9 +34,11 @@ import org.apache.beam.sdk.io.kafka.KafkaCheckpointMark.PartitionMark; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.ExposedByteArrayInputStream; import org.apache.beam.sdk.values.KV; @@ -1314,10 +1316,10 @@ private KafkaValueWrite(TypedWrite kvWriteTransform) { public PDone apply(PCollection input) { return input .apply("Kafka values with default key", - ParDo.of(new DoFn>() { - @ProcessElement - public void processElement(ProcessContext ctx) throws Exception { - ctx.output(KV.of(null, ctx.element())); + MapElements.via(new SimpleFunction>() { + @Override + public KV apply(V element) { + return KV.of(null, element); } })) .setCoder(KvCoder.of(VoidCoder.of(), kvWriteTransform.valueCoder))