diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java index 11eeb33089fa7..efeaf15849d3b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java @@ -1096,12 +1096,7 @@ public DataStreamSource fromCollection( // must not have null elements and mixed elements FromElementsFunction.checkCollection(data, typeInfo.getTypeClass()); - SourceFunction function; - try { - function = new FromElementsFunction<>(typeInfo.createSerializer(getConfig()), data); - } catch (IOException e) { - throw new RuntimeException(e.getMessage(), e); - } + SourceFunction function = new FromElementsFunction<>(data); return addSource(function, "Collection Source", typeInfo, Boundedness.BOUNDED) .setParallelism(1); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java index 574ffd715b155..d740f6cd85e3f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java @@ -18,8 +18,11 @@ package org.apache.flink.streaming.api.functions.source; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.core.memory.DataInputView; @@ -28,37 +31,44 @@ import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.operators.OutputTypeConfigurable; +import org.apache.flink.util.IterableUtils; import org.apache.flink.util.Preconditions; +import javax.annotation.Nullable; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Objects; /** * A stream source function that returns a sequence of elements. * - *

Upon construction, this source function serializes the elements using Flink's type - * information. That way, any object transport using Java serialization will not be affected by the - * serializability of the elements. + *

This source function serializes the elements using Flink's type information. That way, any + * object transport using Java serialization will not be affected by the serializability of the + * elements. * *

NOTE: This source has a parallelism of 1. * * @param The type of elements returned by this function. */ @PublicEvolving -public class FromElementsFunction implements SourceFunction, CheckpointedFunction { +public class FromElementsFunction + implements SourceFunction, CheckpointedFunction, OutputTypeConfigurable { private static final long serialVersionUID = 1L; /** The (de)serializer to be used for the data elements. */ - private final TypeSerializer serializer; + private @Nullable TypeSerializer serializer; /** The actual data elements, in serialized form. */ - private final byte[] elementsSerialized; + private byte[] elementsSerialized; /** The number of serialized elements. */ private final int numElements; @@ -72,30 +82,81 @@ public class FromElementsFunction implements SourceFunction, CheckpointedF /** Flag to make the source cancelable. */ private volatile boolean isRunning = true; + private final transient Iterable elements; + private transient ListState checkpointedState; + @SafeVarargs public FromElementsFunction(TypeSerializer serializer, T... elements) throws IOException { this(serializer, Arrays.asList(elements)); } public FromElementsFunction(TypeSerializer serializer, Iterable elements) throws IOException { + this.serializer = Preconditions.checkNotNull(serializer); + this.elements = elements; + this.numElements = + elements instanceof Collection + ? ((Collection) elements).size() + : (int) IterableUtils.toStream(elements).count(); + serializeElements(); + } + + @SafeVarargs + public FromElementsFunction(T... elements) { + this(Arrays.asList(elements)); + } + + public FromElementsFunction(Iterable elements) { + this.serializer = null; + this.elements = elements; + this.numElements = + elements instanceof Collection + ? ((Collection) elements).size() + : (int) IterableUtils.toStream(elements).count(); + checkIterable(elements, Object.class); + } + + @VisibleForTesting + @Nullable + public TypeSerializer getSerializer() { + return serializer; + } + + private void serializeElements() throws IOException { + Preconditions.checkState(serializer != null, "serializer not set"); ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(baos); - - int count = 0; try { for (T element : elements) { serializer.serialize(element, wrapper); - count++; } } catch (Exception e) { throw new IOException("Serializing the source elements failed: " + e.getMessage(), e); } - - this.serializer = serializer; this.elementsSerialized = baos.toByteArray(); - this.numElements = count; + } + + /** + * Set element type and re-serialize element if required. Should only be called before + * serialization/deserialization of this function. + */ + @Override + public void setOutputType(TypeInformation outTypeInfo, ExecutionConfig executionConfig) { + Preconditions.checkState( + elements != null, + "The output type should've been specified before shipping the graph to the cluster"); + checkIterable(elements, outTypeInfo.getTypeClass()); + TypeSerializer newSerializer = outTypeInfo.createSerializer(executionConfig); + if (Objects.equals(serializer, newSerializer)) { + return; + } + serializer = newSerializer; + try { + serializeElements(); + } catch (IOException ex) { + throw new UncheckedIOException(ex); + } } @Override @@ -127,6 +188,7 @@ public void initializeState(FunctionInitializationContext context) throws Except @Override public void run(SourceContext ctx) throws Exception { + Preconditions.checkState(serializer != null, "serializer not configured"); ByteArrayInputStream bais = new ByteArrayInputStream(elementsSerialized); final DataInputView input = new DataInputViewStreamWrapper(bais); @@ -222,6 +284,10 @@ public void snapshotState(FunctionSnapshotContext context) throws Exception { * @param The generic type of the collection to be checked. */ public static void checkCollection(Collection elements, Class viewedAs) { + checkIterable(elements, viewedAs); + } + + private static void checkIterable(Iterable elements, Class viewedAs) { for (OUT elem : elements) { if (elem == null) { throw new IllegalArgumentException("The collection contains a null element"); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java index 8079e22863d07..f86dad2065e3a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/StreamExecutionEnvironmentTest.java @@ -49,6 +49,7 @@ import java.util.NoSuchElementException; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -67,6 +68,35 @@ public void fromElementsWithBaseTypeTest2() { env.fromElements(SubClass.class, new SubClass(1, "Java"), new ParentClass(1, "hello")); } + @Test + public void testFromElementsDeducedType() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStreamSource source = env.fromElements("a", "b"); + + FromElementsFunction elementsFunction = + (FromElementsFunction) getFunctionFromDataSource(source); + assertEquals( + BasicTypeInfo.STRING_TYPE_INFO.createSerializer(env.getConfig()), + elementsFunction.getSerializer()); + } + + @Test + public void testFromElementsPostConstructionType() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStreamSource source = env.fromElements("a", "b"); + TypeInformation customType = new GenericTypeInfo<>(String.class); + + source.returns(customType); + + FromElementsFunction elementsFunction = + (FromElementsFunction) getFunctionFromDataSource(source); + assertNotEquals( + BasicTypeInfo.STRING_TYPE_INFO.createSerializer(env.getConfig()), + elementsFunction.getSerializer()); + assertEquals( + customType.createSerializer(env.getConfig()), elementsFunction.getSerializer()); + } + @Test @SuppressWarnings("unchecked") public void testFromCollectionParallelism() { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java index 819e32fbd5011..0be9550688f1c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java @@ -21,7 +21,9 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.api.java.typeutils.ValueTypeInfo; import org.apache.flink.core.memory.DataInputView; @@ -33,21 +35,42 @@ import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness; import org.apache.flink.types.Value; import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.InstantiationUtil; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** Tests for the {@link org.apache.flink.streaming.api.functions.source.FromElementsFunction}. */ public class FromElementsFunctionTest { + private static final String[] STRING_ARRAY_DATA = {"Oh", "boy", "what", "a", "show", "!"}; + private static final List STRING_LIST_DATA = Arrays.asList(STRING_ARRAY_DATA); + + @Rule public final ExpectedException thrown = ExpectedException.none(); + + private static List runSource(FromElementsFunction source) throws Exception { + List result = new ArrayList<>(); + FromElementsFunction clonedSource = InstantiationUtil.clone(source); + clonedSource.run(new ListSourceContext<>(result)); + return result; + } + @Test public void testStrings() { try { @@ -68,6 +91,106 @@ public void testStrings() { } } + @Test + public void testNullElement() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("contains a null element"); + + new FromElementsFunction<>("a", null, "b"); + } + + @Test + public void testSetOutputTypeWithNoSerializer() throws Exception { + FromElementsFunction source = new FromElementsFunction<>(STRING_ARRAY_DATA); + + assertNull(source.getSerializer()); + + source.setOutputType(BasicTypeInfo.STRING_TYPE_INFO, new ExecutionConfig()); + + assertNotNull(source.getSerializer()); + assertEquals( + BasicTypeInfo.STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), + source.getSerializer()); + + List result = runSource(source); + + assertEquals(STRING_LIST_DATA, result); + } + + @Test + public void testSetOutputTypeWithSameSerializer() throws Exception { + FromElementsFunction source = + new FromElementsFunction<>( + BasicTypeInfo.STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), + STRING_LIST_DATA); + + TypeSerializer existingSerializer = source.getSerializer(); + + source.setOutputType(BasicTypeInfo.STRING_TYPE_INFO, new ExecutionConfig()); + + TypeSerializer newSerializer = source.getSerializer(); + + assertEquals(existingSerializer, newSerializer); + + List result = runSource(source); + + assertEquals(STRING_LIST_DATA, result); + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testSetOutputTypeWithIncompatibleType() throws Exception { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("not all subclasses of java.lang.Integer"); + + FromElementsFunction source = new FromElementsFunction<>(STRING_LIST_DATA); + source.setOutputType((TypeInformation) BasicTypeInfo.INT_TYPE_INFO, new ExecutionConfig()); + } + + @Test + public void testSetOutputTypeWithExistingBrokenSerializer() throws Exception { + TypeInformation info = + new ValueTypeInfo<>(DeserializeTooMuchType.class); + + FromElementsFunction source = + new FromElementsFunction<>( + info.createSerializer(new ExecutionConfig()), new DeserializeTooMuchType()); + + TypeSerializer existingSerializer = source.getSerializer(); + + source.setOutputType( + new GenericTypeInfo<>(DeserializeTooMuchType.class), new ExecutionConfig()); + + TypeSerializer newSerializer = source.getSerializer(); + + assertNotEquals(existingSerializer, newSerializer); + + List result = runSource(source); + + assertThat(result, hasSize(1)); + assertThat(result.get(0), instanceOf(DeserializeTooMuchType.class)); + } + + @Test + public void testSetOutputTypeAfterTransferred() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage( + "The output type should've been specified before shipping the graph to the cluster"); + + FromElementsFunction source = + InstantiationUtil.clone(new FromElementsFunction<>(STRING_LIST_DATA)); + source.setOutputType(BasicTypeInfo.STRING_TYPE_INFO, new ExecutionConfig()); + } + + @Test + public void testNoSerializer() throws Exception { + thrown.expect(IllegalStateException.class); + thrown.expectMessage("serializer not configured"); + + FromElementsFunction source = new FromElementsFunction<>(STRING_LIST_DATA); + runSource(source); + } + @Test public void testNonJavaSerializableType() { try { @@ -79,8 +202,7 @@ public void testNonJavaSerializableType() { .createSerializer(new ExecutionConfig()), data); - List result = new ArrayList(); - source.run(new ListSourceContext(result)); + List result = runSource(source); assertEquals(Arrays.asList(data), result); } catch (Exception e) { @@ -89,6 +211,19 @@ public void testNonJavaSerializableType() { } } + @Test + public void testNonJavaSerializableTypeWithSetOutputType() throws Exception { + MyPojo[] data = {new MyPojo(1, 2), new MyPojo(3, 4), new MyPojo(5, 6)}; + + FromElementsFunction source = new FromElementsFunction<>(data); + + source.setOutputType(TypeExtractor.getForClass(MyPojo.class), new ExecutionConfig()); + + List result = runSource(source); + + assertEquals(Arrays.asList(data), result); + } + @Test public void testSerializationError() { try { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java index 8308b24c905cd..dc01b6043e8fa 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.graph; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.operators.ResourceSpec; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; @@ -40,6 +41,7 @@ import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; @@ -209,6 +211,25 @@ public void testVirtualTransformations() throws Exception { instanceof ShufflePartitioner); } + @Test + public void testOutputTypeConfigurationWithUdfStreamOperator() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + + OutputTypeConfigurableFunction function = new OutputTypeConfigurableFunction<>(); + + DataStream source = env.fromElements(1, 10); + + NoOpUdfOperator udfOperator = new NoOpUdfOperator<>(function); + + source.transform("no-op udf operator", BasicTypeInfo.INT_TYPE_INFO, udfOperator) + .addSink(new DiscardingSink<>()); + + env.getStreamGraph(); + + assertTrue(udfOperator instanceof AbstractUdfStreamOperator); + assertEquals(BasicTypeInfo.INT_TYPE_INFO, function.getTypeInformation()); + } + /** * Test whether an {@link OutputTypeConfigurable} implementation gets called with the correct * output type. In this test case the output type must be BasicTypeInfo.INT_TYPE_INFO. @@ -634,6 +655,32 @@ public void testSetSlotSharingResource() { equalTo(resourceProfile3)); } + private static class OutputTypeConfigurableFunction + implements OutputTypeConfigurable, Function { + private TypeInformation typeInformation; + + public TypeInformation getTypeInformation() { + return typeInformation; + } + + @Override + public void setOutputType(TypeInformation outTypeInfo, ExecutionConfig executionConfig) { + typeInformation = outTypeInfo; + } + } + + static class NoOpUdfOperator extends AbstractUdfStreamOperator + implements OneInputStreamOperator { + NoOpUdfOperator(Function function) { + super(function); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + output.collect(element); + } + } + static class OutputTypeConfigurableOperationWithTwoInputs extends AbstractStreamOperator implements TwoInputStreamOperator,