Skip to content

Commit

Permalink
[FLINK-13248] [runtime] Enabling custom factories for one input stream operators to be passed in DataStream
Browse files Browse the repository at this point in the history
  • Loading branch information
Arvid Heise committed Aug 27, 2019
1 parent bdf014c commit ddcfc5d
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 37 deletions.
Expand Up @@ -60,10 +60,13 @@
import org.apache.flink.streaming.api.functions.timestamps.AscendingTimestampExtractor;
import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
import org.apache.flink.streaming.api.operators.ProcessOperator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamFilter;
import org.apache.flink.streaming.api.operators.StreamFlatMap;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamSink;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
Expand Down Expand Up @@ -1172,21 +1175,53 @@ public DataStreamSink<T> writeUsingOutputFormat(OutputFormat<T> format) {
* @param <R>
* type of the return stream
* @return the data stream constructed
* @see #transform(String, TypeInformation, OneInputStreamOperatorFactory)
*/
@PublicEvolving
public <R> SingleOutputStreamOperator<R> transform(String operatorName, TypeInformation<R> outTypeInfo, OneInputStreamOperator<T, R> operator) {
public <R> SingleOutputStreamOperator<R> transform(
String operatorName,
TypeInformation<R> outTypeInfo,
OneInputStreamOperator<T, R> operator) {

return transformImpl(operatorName, outTypeInfo, SimpleOperatorFactory.of(operator));
}

/**
* Transforms this DataStream with a user-defined operator that is created by the given factory, producing
* elements described by the given output type information.
*
* <p>This is the factory-based counterpart of
* {@link #transform(String, TypeInformation, OneInputStreamOperator)}: that overload wraps its operator with
* {@code SimpleOperatorFactory.of(operator)}, whereas this one hands the supplied factory on unchanged.
*
* <p>This method uses the rather new operator factories and should only be used when custom factories are needed.
*
* @param operatorName name of the operator, for logging purposes
* @param outTypeInfo the output type of the operator
* @param operatorFactory the factory for the operator.
* @param <R> type of the return stream
* @return the data stream constructed.
*/
@PublicEvolving
public <R> SingleOutputStreamOperator<R> transform(
String operatorName,
TypeInformation<R> outTypeInfo,
OneInputStreamOperatorFactory<T, R> operatorFactory) {
return transformImpl(operatorName, outTypeInfo, operatorFactory);
}

private <R> SingleOutputStreamOperator<R> transformImpl(
String operatorName,
TypeInformation<R> outTypeInfo,
StreamOperatorFactory<R> operatorFactory) {

// read the output type of the input Transform to coax out errors about MissingTypeInfo
transformation.getOutputType();

OneInputTransformation<T, R> resultTransform = new OneInputTransformation<>(
this.transformation,
operatorName,
operator,
operatorFactory,
outTypeInfo,
environment.getParallelism());

@SuppressWarnings({ "unchecked", "rawtypes" })
@SuppressWarnings({"unchecked", "rawtypes"})
SingleOutputStreamOperator<R> returnStream = new SingleOutputStreamOperator(environment, resultTransform);

getExecutionEnvironment().addOperator(resultTransform);
Expand Down
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.operators;

/**
* A factory to create {@link OneInputStreamOperator}.
*
* <p>The interface declares no methods of its own: it extends {@link StreamOperatorFactory} and only adds the
* input type parameter, which lets an API that accepts a factory tie it to the element type fed into the operator.
*
* @param <IN> The input type of the operator.
* @param <OUT> The output type of the operator.
*/
public interface OneInputStreamOperatorFactory<IN, OUT> extends StreamOperatorFactory<OUT> {
}
Expand Up @@ -55,6 +55,7 @@
import org.apache.flink.streaming.api.operators.SetupableStreamOperator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl;
import org.apache.flink.streaming.api.operators.YieldingOperatorFactory;
Expand Down Expand Up @@ -89,7 +90,9 @@
*/
public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {

protected final StreamOperator<OUT> operator;
protected StreamOperator<OUT> operator;

protected final StreamOperatorFactory<OUT> factory;

protected final ConcurrentLinkedQueue<Object> outputList;

Expand Down Expand Up @@ -150,32 +153,76 @@ public AbstractStreamOperatorTestHarness(
int subtaskIndex,
OperatorID operatorID) throws Exception {
this(
operator,
new MockEnvironmentBuilder()
.setTaskName("MockTask")
.setMemorySize(3 * 1024 * 1024)
.setInputSplitProvider(new MockInputSplitProvider())
.setBufferSize(1024)
.setMaxParallelism(maxParallelism)
.setParallelism(parallelism)
.setSubtaskIndex(subtaskIndex)
.build(),
true,
operatorID);
operator,
SimpleOperatorFactory.of(operator),
new MockEnvironmentBuilder()
.setTaskName("MockTask")
.setMemorySize(3 * 1024 * 1024)
.setInputSplitProvider(new MockInputSplitProvider())
.setBufferSize(1024)
.setMaxParallelism(maxParallelism)
.setParallelism(parallelism)
.setSubtaskIndex(subtaskIndex)
.build(),
true,
operatorID);
}

/**
* Creates a harness whose operator comes from the given factory and which runs in the supplied environment.
*
* <p>No operator instance is handed on ({@code null} is passed in its place), a fresh {@link OperatorID} is
* generated, and {@code environmentIsInternal} is {@code false} because the environment was provided by the caller.
*
* @param factory factory for the operator under test
* @param env the mock environment to use
*/
public AbstractStreamOperatorTestHarness(
StreamOperatorFactory<OUT> factory,
MockEnvironment env) throws Exception {
this(null, factory, env, false, new OperatorID());
}

/**
* Creates a factory-based harness with the given parallelism settings and a freshly generated
* {@link OperatorID}; all other work is delegated to the overload that also takes the operator ID.
*
* @param factory factory for the operator under test
* @param maxParallelism max parallelism forwarded to the mock environment
* @param parallelism parallelism forwarded to the mock environment
* @param subtaskIndex subtask index forwarded to the mock environment
*/
public AbstractStreamOperatorTestHarness(
StreamOperatorFactory<OUT> factory,
int maxParallelism,
int parallelism,
int subtaskIndex) throws Exception {
this(
factory,
maxParallelism,
parallelism,
subtaskIndex,
new OperatorID());
}

/**
* Creates a factory-based harness that builds its own {@link MockEnvironment}.
*
* <p>The environment is named {@code "MockTask"}, uses a memory size of {@code 3 * 1024 * 1024}, a buffer size of
* {@code 1024} and a {@link MockInputSplitProvider}, and takes its parallelism settings from the arguments.
* No operator instance is handed on ({@code null}); {@code environmentIsInternal} is {@code true} because the
* environment is created here rather than supplied by the caller.
*
* @param factory factory for the operator under test
* @param maxParallelism max parallelism set on the mock environment
* @param parallelism parallelism set on the mock environment
* @param subtaskIndex subtask index set on the mock environment
* @param operatorID the operator ID to use
*/
public AbstractStreamOperatorTestHarness(
StreamOperatorFactory<OUT> factory,
int maxParallelism,
int parallelism,
int subtaskIndex,
OperatorID operatorID) throws Exception {
this(
null,
factory,
new MockEnvironmentBuilder()
.setTaskName("MockTask")
.setMemorySize(3 * 1024 * 1024)
.setInputSplitProvider(new MockInputSplitProvider())
.setBufferSize(1024)
.setMaxParallelism(maxParallelism)
.setParallelism(parallelism)
.setSubtaskIndex(subtaskIndex)
.build(),
true,
operatorID);
}

public AbstractStreamOperatorTestHarness(
StreamOperator<OUT> operator,
MockEnvironment env) throws Exception {
this(operator, env, false, new OperatorID());
this(operator, SimpleOperatorFactory.of(operator), env, false, new OperatorID());
}

private AbstractStreamOperatorTestHarness(
StreamOperator<OUT> operator,
StreamOperatorFactory<OUT> factory,
MockEnvironment env,
boolean environmentIsInternal,
OperatorID operatorID) throws Exception {
this.operator = operator;
this.factory = factory;
this.outputList = new ConcurrentLinkedQueue<>();
this.sideOutputLists = new HashMap<>();

Expand Down Expand Up @@ -287,13 +334,14 @@ public void setup(TypeSerializer<OUT> outputSerializer) {
createStreamTaskStateManager(environment, stateBackend, processingTimeService);
mockTask.setStreamTaskStateInitializer(streamTaskStateInitializer);

if (operator instanceof SetupableStreamOperator) {
((SetupableStreamOperator) operator).setup(mockTask, config, new MockOutput(outputSerializer));
}
SimpleOperatorFactory<OUT> factory = SimpleOperatorFactory.of(operator);
if (factory instanceof YieldingOperatorFactory) {
((YieldingOperatorFactory) factory).setMailboxExecutor(mockTask.getMailboxExecutorFactory().apply(config.getChainIndex()));
}
if (operator == null) {
this.operator = factory.createStreamOperator(mockTask, config, new MockOutput(outputSerializer));
} else if (operator instanceof SetupableStreamOperator) {
((SetupableStreamOperator) operator).setup(mockTask, config, new MockOutput(outputSerializer));
}
setupCalled = true;
this.mockTask.init();
}
Expand Down
Expand Up @@ -23,6 +23,7 @@
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Preconditions;
Expand All @@ -39,8 +40,6 @@
public class OneInputStreamOperatorTestHarness<IN, OUT>
extends AbstractStreamOperatorTestHarness<OUT> {

private final OneInputStreamOperator<IN, OUT> oneInputOperator;

private long currentWatermark;

public OneInputStreamOperatorTestHarness(
Expand Down Expand Up @@ -91,16 +90,56 @@ public OneInputStreamOperatorTestHarness(
int subtaskIndex,
OperatorID operatorID) throws Exception {
super(operator, maxParallelism, parallelism, subtaskIndex, operatorID);

this.oneInputOperator = operator;
}

/**
* Creates a harness for the given operator instance running in the supplied environment; construction is
* delegated entirely to the superclass.
*
* @param operator the operator under test
* @param environment the mock environment to use
*/
public OneInputStreamOperatorTestHarness(
OneInputStreamOperator<IN, OUT> operator,
MockEnvironment environment) throws Exception {
super(operator, environment);
}

/**
* Creates a factory-based harness in the supplied environment and registers the input serializer on the
* harness config.
*
* @param factory factory for the operator under test
* @param typeSerializerIn serializer for the input type; must not be {@code null}
* @param environment the mock environment to use
*/
public OneInputStreamOperatorTestHarness(
OneInputStreamOperatorFactory<IN, OUT> factory,
TypeSerializer<IN> typeSerializerIn,
MockEnvironment environment) throws Exception {
this(factory, environment);

// The null check runs only after the delegated constructor has completed.
config.setTypeSerializerIn1(Preconditions.checkNotNull(typeSerializerIn));
}

/**
* Creates a harness whose operator comes from the given factory, running in the supplied environment;
* construction is delegated entirely to the superclass.
*
* @param factory factory for the operator under test
* @param environment the mock environment to use
*/
public OneInputStreamOperatorTestHarness(
OneInputStreamOperatorFactory<IN, OUT> factory,
MockEnvironment environment) throws Exception {
super(factory, environment);
}

/**
* Creates a factory-based harness with max parallelism 1, parallelism 1 and subtask index 0, and registers
* the input serializer on the harness config.
*
* @param factory factory for the operator under test
* @param typeSerializerIn serializer for the input type; must not be {@code null}
*/
public OneInputStreamOperatorTestHarness(
OneInputStreamOperatorFactory<IN, OUT> factory,
TypeSerializer<IN> typeSerializerIn) throws Exception {
this(factory, 1, 1, 0);

// The null check runs only after the delegated constructor has completed.
config.setTypeSerializerIn1(Preconditions.checkNotNull(typeSerializerIn));
}

/**
* Creates a factory-based harness with the given parallelism settings and a freshly generated
* {@link OperatorID}.
*
* @param factory factory for the operator under test
* @param maxParallelism max parallelism passed on to the delegated constructor
* @param parallelism parallelism passed on to the delegated constructor
* @param subtaskIndex subtask index passed on to the delegated constructor
*/
public OneInputStreamOperatorTestHarness(
OneInputStreamOperatorFactory<IN, OUT> factory,
int maxParallelism,
int parallelism,
int subtaskIndex) throws Exception {
this(factory, maxParallelism, parallelism, subtaskIndex, new OperatorID());
}

/**
* Creates a factory-based harness with the given parallelism settings and operator ID; construction is
* delegated entirely to the superclass.
*
* @param factory factory for the operator under test
* @param maxParallelism max parallelism passed on to the superclass
* @param parallelism parallelism passed on to the superclass
* @param subtaskIndex subtask index passed on to the superclass
* @param operatorID the operator ID to use
*/
public OneInputStreamOperatorTestHarness(
OneInputStreamOperatorFactory<IN, OUT> factory,
int maxParallelism,
int parallelism,
int subtaskIndex,
OperatorID operatorID) throws Exception {
super(factory, maxParallelism, parallelism, subtaskIndex, operatorID);
}

this.oneInputOperator = operator;
/**
* Returns the operator under test, cast to {@link OneInputStreamOperator}.
*
* <p>NOTE(review): when the harness was built from a factory, the inherited {@code operator} field starts out
* as {@code null} and appears to be populated only in {@code setup(...)} - presumably this returns
* {@code null} before then; confirm against the superclass.
*
* @return the inherited {@code operator} field viewed as a one-input operator
*/
public OneInputStreamOperator<IN, OUT> getOneInputOperator() {
// The field is declared as StreamOperator<OUT>, so this narrowing cast is what supplies the IN type.
return (OneInputStreamOperator<IN, OUT>) this.operator;
}

public void processElement(IN value, long timestamp) throws Exception {
Expand All @@ -109,13 +148,13 @@ public void processElement(IN value, long timestamp) throws Exception {

public void processElement(StreamRecord<IN> element) throws Exception {
operator.setKeyContextElement1(element);
oneInputOperator.processElement(element);
getOneInputOperator().processElement(element);
}

public void processElements(Collection<StreamRecord<IN>> elements) throws Exception {
for (StreamRecord<IN> element: elements) {
operator.setKeyContextElement1(element);
oneInputOperator.processElement(element);
getOneInputOperator().processElement(element);
}
}

Expand All @@ -125,16 +164,16 @@ public void processWatermark(long watermark) throws Exception {

public void processWatermark(Watermark mark) throws Exception {
currentWatermark = mark.getTimestamp();
oneInputOperator.processWatermark(mark);
getOneInputOperator().processWatermark(mark);
}

/**
* Returns the timestamp of the watermark most recently passed to {@code processWatermark(Watermark)}.
*
* @return the last recorded watermark timestamp
*/
public long getCurrentWatermark() {
return currentWatermark;
}

public void endInput() throws Exception {
if (oneInputOperator instanceof BoundedOneInput) {
((BoundedOneInput) oneInputOperator).endInput();
if (getOneInputOperator() instanceof BoundedOneInput) {
((BoundedOneInput) getOneInputOperator()).endInput();
} else {
throw new UnsupportedOperationException("The operator is not BoundedOneInput");
}
Expand Down
Expand Up @@ -408,12 +408,8 @@ class HarnessTestBase extends StreamingWithStateTestBase {
}

def getOperator(testHarness: OneInputStreamOperatorTestHarness[_, _])
: AbstractUdfStreamOperator[_, _] = {
val operatorField = classOf[OneInputStreamOperatorTestHarness[_, _]]
.getDeclaredField("oneInputOperator")
operatorField.setAccessible(true)
operatorField.get(testHarness).asInstanceOf[AbstractUdfStreamOperator[_, _]]
}
: AbstractUdfStreamOperator[_, _] =
testHarness.getOneInputOperator.asInstanceOf[AbstractUdfStreamOperator[_, _]]

def verify(expected: JQueue[Object], actual: JQueue[Object]): Unit = {
verify(expected, actual, new RowResultSortComparator)
Expand Down

0 comments on commit ddcfc5d

Please sign in to comment.