From 4ac5cafe90a371cf616f97cb202d5016b68616d1 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Fri, 29 Jul 2016 10:35:01 -0700 Subject: [PATCH] Use input type in coder inference for MapElements and FlatMapElements Previously, the input TypeDescriptor was unknown, so we would fail to infer a coder for things like MapElements.of(SimpleFunction) even if the input PCollection provided a coder for T. Now, the input type is plumbed appropriately and the coder is inferred. --- .../beam/sdk/transforms/FlatMapElements.java | 126 ++++++++++++------ .../beam/sdk/transforms/MapElements.java | 60 +++++---- .../beam/sdk/transforms/SimpleFunction.java | 34 +++++ .../sdk/transforms/FlatMapElementsTest.java | 48 +++++++ .../beam/sdk/transforms/MapElementsTest.java | 84 ++++++++++++ 5 files changed, 288 insertions(+), 64 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java index 694592ed86b0..04d993cb0d62 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java @@ -17,8 +17,10 @@ */ package org.apache.beam.sdk.transforms; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import java.lang.reflect.ParameterizedType; @@ -45,8 +47,16 @@ public class FlatMapElements * descriptor need not be provided. */ public static MissingOutputTypeDescriptor - via(SerializableFunction> fn) { - return new MissingOutputTypeDescriptor<>(fn); + via(SerializableFunction> fn) { + + // TypeDescriptor interacts poorly with the wildcards needed to correctly express + // covariance and contravariance in Java, so instead we cast it to an invariant + // function here. + @SuppressWarnings("unchecked") // safe covariant cast + SerializableFunction> simplerFn = + (SerializableFunction>) fn; + + return new MissingOutputTypeDescriptor<>(simplerFn); } /** @@ -72,16 +82,15 @@ public class FlatMapElements *

To use a Java 8 lambda, see {@link #via(SerializableFunction)}. */ public static FlatMapElements - via(SimpleFunction> fn) { - - @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing - TypeDescriptor> iterableType = (TypeDescriptor) fn.getOutputTypeDescriptor(); - - @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType - TypeDescriptor outputType = - (TypeDescriptor) getIterableElementType(iterableType); - - return new FlatMapElements<>(fn, outputType); + via(SimpleFunction> fn) { + // TypeDescriptor interacts poorly with the wildcards needed to correctly express + // covariance and contravariance in Java, so instead we cast it to an invariant + // function here. + @SuppressWarnings("unchecked") // safe covariant cast + SimpleFunction> simplerFn = + (SimpleFunction>) fn; + + return new FlatMapElements<>(simplerFn, fn.getClass()); } /** @@ -91,18 +100,80 @@ public class FlatMapElements */ public static final class MissingOutputTypeDescriptor { - private final SerializableFunction> fn; + private final SerializableFunction> fn; private MissingOutputTypeDescriptor( - SerializableFunction> fn) { + SerializableFunction> fn) { this.fn = fn; } public FlatMapElements withOutputType(TypeDescriptor outputType) { - return new FlatMapElements<>(fn, outputType); + TypeDescriptor> iterableOutputType = TypeDescriptors.iterables(outputType); + + return new FlatMapElements<>( + SimpleFunction.fromSerializableFunctionWithOutputType(fn, + iterableOutputType), + fn.getClass()); } } + ////////////////////////////////////////////////////////////////////////////////////////////////// + + private final SimpleFunction> fn; + private final DisplayData.Item fnClassDisplayData; + + private FlatMapElements( + SimpleFunction> fn, + Class fnClass) { + this.fn = fn; + this.fnClassDisplayData = DisplayData.item("flatMapFn", fnClass).withLabel("FlatMap Function"); + } + + @Override + public PCollection apply(PCollection input) { + return input.apply( + "FlatMap", + ParDo.of( + new DoFn() { + private static final long serialVersionUID = 0L; + + @ProcessElement + public void processElement(ProcessContext c) { + for (OutputT element : fn.apply(c.element())) { + c.output(element); + } + } + + @Override + public TypeDescriptor getInputTypeDescriptor() { + return fn.getInputTypeDescriptor(); + } + + @Override + public TypeDescriptor getOutputTypeDescriptor() { + @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing + TypeDescriptor> iterableType = + (TypeDescriptor) fn.getOutputTypeDescriptor(); + + @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType + TypeDescriptor outputType = + (TypeDescriptor) getIterableElementType(iterableType); + + return outputType; + } + })); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(fnClassDisplayData); + } + + /** + * Does a best-effort job of getting the best {@link TypeDescriptor} for the type of the + * elements contained in the iterable described by the given {@link TypeDescriptor}. + */ private static TypeDescriptor getIterableElementType( TypeDescriptor> iterableTypeDescriptor) { @@ -118,29 +189,4 @@ private static TypeDescriptor getIterableElementType( (ParameterizedType) iterableTypeDescriptor.getSupertype(Iterable.class).getType(); return TypeDescriptor.of(iterableType.getActualTypeArguments()[0]); } - - ////////////////////////////////////////////////////////////////////////////////////////////////// - - private final SerializableFunction> fn; - private final transient TypeDescriptor outputType; - - private FlatMapElements( - SerializableFunction> fn, - TypeDescriptor outputType) { - this.fn = fn; - this.outputType = outputType; - } - - @Override - public PCollection apply(PCollection input) { - return input.apply("Map", ParDo.of(new DoFn() { - private static final long serialVersionUID = 0L; - @ProcessElement - public void processElement(ProcessContext c) { - for (OutputT element : fn.apply(c.element())) { - c.output(element); - } - } - })).setTypeDescriptorInternal(outputType); - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java index b7b9a5fa3d1f..429d3fca8be3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java @@ -67,9 +67,9 @@ public class MapElements * })); * } */ - public static MapElements - via(final SimpleFunction fn) { - return new MapElements<>(fn, fn.getOutputTypeDescriptor()); + public static MapElements via( + final SimpleFunction fn) { + return new MapElements<>(fn, fn.getClass()); } /** @@ -85,42 +85,54 @@ private MissingOutputTypeDescriptor(SerializableFunction fn) { this.fn = fn; } - public MapElements withOutputType(TypeDescriptor outputType) { - return new MapElements<>(fn, outputType); + public MapElements withOutputType(final TypeDescriptor outputType) { + return new MapElements<>( + SimpleFunction.fromSerializableFunctionWithOutputType(fn, outputType), fn.getClass()); } + } /////////////////////////////////////////////////////////////////// - private final SerializableFunction fn; - private final transient TypeDescriptor outputType; + private final SimpleFunction fn; + private final DisplayData.Item fnClassDisplayData; - private MapElements( - SerializableFunction fn, - TypeDescriptor outputType) { + private MapElements(SimpleFunction fn, Class fnClass) { this.fn = fn; - this.outputType = outputType; + this.fnClassDisplayData = DisplayData.item("mapFn", fnClass).withLabel("Map Function"); } @Override public PCollection apply(PCollection input) { - return input.apply("Map", ParDo.of(new DoFn() { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(fn.apply(c.element())); - } - - @Override - public void populateDisplayData(DisplayData.Builder builder) { - MapElements.this.populateDisplayData(builder); - } - })).setTypeDescriptorInternal(outputType); + return input.apply( + "Map", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(fn.apply(c.element())); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + MapElements.this.populateDisplayData(builder); + } + + @Override + public TypeDescriptor getInputTypeDescriptor() { + return fn.getInputTypeDescriptor(); + } + + @Override + public TypeDescriptor getOutputTypeDescriptor() { + return fn.getOutputTypeDescriptor(); + } + })); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); - builder.add(DisplayData.item("mapFn", fn.getClass()) - .withLabel("Map Function")); + builder.add(fnClassDisplayData); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java index 889435262c0b..6c540cc034dc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java @@ -27,6 +27,12 @@ public abstract class SimpleFunction implements SerializableFunction { + public static + SimpleFunction fromSerializableFunctionWithOutputType( + SerializableFunction fn, TypeDescriptor outputType) { + return new SimpleFunctionWithOutputType<>(fn, outputType); + } + /** * Returns a {@link TypeDescriptor} capturing what is known statically * about the input type of this {@code OldDoFn} instance's most-derived @@ -52,4 +58,32 @@ public TypeDescriptor getInputTypeDescriptor() { public TypeDescriptor getOutputTypeDescriptor() { return new TypeDescriptor(this) {}; } + + /** + * A {@link SimpleFunction} built from a {@link SerializableFunction}, having + * a known output type that is explicitly set. + */ + private static class SimpleFunctionWithOutputType + extends SimpleFunction { + + private final SerializableFunction fn; + private final TypeDescriptor outputType; + + public SimpleFunctionWithOutputType( + SerializableFunction fn, + TypeDescriptor outputType) { + this.fn = fn; + this.outputType = outputType; + } + + @Override + public OutputT apply(InputT input) { + return fn.apply(input); + } + + @Override + public TypeDescriptor getOutputTypeDescriptor() { + return outputType; + } + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java index 057fd19ee941..781e14398960 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.transforms; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; + import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; @@ -24,6 +26,7 @@ import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; @@ -102,6 +105,51 @@ public Set apply(String input) { pipeline.run(); } + /** + * A {@link SimpleFunction} to test that the coder registry can propagate coders + * that are bound to type variables. + */ + private static class PolymorphicSimpleFunction extends SimpleFunction> { + @Override + public Iterable apply(T input) { + return Collections.emptyList(); + } + } + + /** + * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}. + */ + @Test + public void testPolymorphicSimpleFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + + // This is the function that needs to propagate the input T to output T + .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction())) + + // This is a consumer to ensure that all coder inference logic is executed. + .apply("Test Consumer", MapElements.via(new SimpleFunction, Integer>() { + @Override + public Integer apply(Iterable input) { + return 42; + } + })); + } + + @Test + public void testSimpleFunctionClassDisplayData() { + SimpleFunction> simpleFn = new SimpleFunction>() { + @Override + public List apply(Integer input) { + return Collections.emptyList(); + } + }; + + FlatMapElements simpleMap = FlatMapElements.via(simpleFn); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("flatMapFn", simpleFn.getClass())); + } + @Test @Category(NeedsRunner.class) public void testVoidValues() throws Exception { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java index b4751d2e8e08..dbf884424085 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java @@ -53,6 +53,29 @@ public class MapElementsTest implements Serializable { @Rule public transient ExpectedException thrown = ExpectedException.none(); + /** + * A {@link SimpleFunction} to test that the coder registry can propagate coders + * that are bound to type variables. + */ + private static class PolymorphicSimpleFunction extends SimpleFunction { + @Override + public T apply(T input) { + return input; + } + } + + /** + * A {@link SimpleFunction} to test that the coder registry can propagate coders + * that are bound to type variables, when the variable appears nested in the + * output. + */ + private static class NestedPolymorphicSimpleFunction extends SimpleFunction> { + @Override + public KV apply(T input) { + return KV.of(input, "hello"); + } + } + /** * Basic test of {@link MapElements} with a {@link SimpleFunction}. */ @@ -73,6 +96,55 @@ public Integer apply(Integer input) { pipeline.run(); } + /** + * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}. + */ + @Test + public void testPolymorphicSimpleFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = pipeline + .apply(Create.of(1, 2, 3)) + + // This is the function that needs to propagate the input T to output T + .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction())) + + // This is a consumer to ensure that all coder inference logic is executed. + .apply("Test Consumer", MapElements.via(new SimpleFunction() { + @Override + public Integer apply(Integer input) { + return input; + } + })); + } + + /** + * Test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction} + * where the type variable occurs nested within other concrete type constructors. + */ + @Test + public void testNestedPolymorphicSimpleFunction() throws Exception { + Pipeline pipeline = TestPipeline.create(); + PCollection output = + pipeline + .apply(Create.of(1, 2, 3)) + + // This is the function that needs to propagate the input T to output T + .apply( + "Polymorphic Identity", + MapElements.via(new NestedPolymorphicSimpleFunction())) + + // This is a consumer to ensure that all coder inference logic is executed. + .apply( + "Test Consumer", + MapElements.via( + new SimpleFunction, Integer>() { + @Override + public Integer apply(KV input) { + return 42; + } + })); + } + /** * Basic test of {@link MapElements} with a {@link SerializableFunction}. This style is * generally discouraged in Java 7, in favor of {@link SimpleFunction}. @@ -147,6 +219,18 @@ public Integer apply(Integer input) { hasDisplayItem("mapFn", serializableFn.getClass())); } + @Test + public void testSimpleFunctionClassDisplayData() { + SimpleFunction simpleFn = new SimpleFunction() { + @Override + public Integer apply(Integer input) { + return input; + } + }; + + MapElements simpleMap = MapElements.via(simpleFn); + assertThat(DisplayData.from(simpleMap), hasDisplayItem("mapFn", simpleFn.getClass())); + } @Test public void testSimpleFunctionDisplayData() { SimpleFunction simpleFn = new SimpleFunction() {