[FLINK-1345] [streaming] Chaining refactor + ChainingStrategy exposed through the API for operators
gyfora authored and mbalassi committed Jan 21, 2015
1 parent 26535c4 commit 3e30c6f
Showing 16 changed files with 159 additions and 157 deletions.
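
The core of the change is that each StreamInvokable now reports a ChainingStrategy instead of a plain isChainable() flag, and the JobGraphBuilder chains two operators only when the downstream operator accepts chaining unconditionally and the upstream operator may start or continue a chain. Below is a small, self-contained sketch of that decision rule, modeled on the isChainable() condition visible in this diff; the enum and helper class are illustrative stand-ins (including the assumed NEVER value), not Flink's actual API:

// Simplified illustration of the chaining decision introduced by this commit.
// ChainingStrategy.ALWAYS / HEAD and the shape of the check mirror
// JobGraphBuilder.isChainable() below; NEVER and the surrounding types are
// assumed stand-ins, not Flink's real classes.
enum ChainingStrategy { ALWAYS, HEAD, NEVER }

class ChainingCheck {
    // Two operators are chained when the edge between them is the only edge on both
    // sides, records are forwarded without repartitioning, both run with the same
    // parallelism, the downstream operator accepts chaining unconditionally (ALWAYS)
    // and the upstream operator may start or continue a chain (HEAD or ALWAYS).
    static boolean chainable(ChainingStrategy upstream, ChainingStrategy downstream,
            boolean singleEdge, boolean forwardPartitioning,
            int upstreamParallelism, int downstreamParallelism, boolean chainingEnabled) {
        return chainingEnabled
                && singleEdge
                && forwardPartitioning
                && upstreamParallelism == downstreamParallelism
                && downstream == ChainingStrategy.ALWAYS
                && (upstream == ChainingStrategy.HEAD || upstream == ChainingStrategy.ALWAYS);
    }
}

The remaining conditions in the real check (single in/out edge, no output selector, FORWARD partitioning, equal parallelism, the global chaining flag) are passed in as booleans here to keep the sketch short.
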
@@ -36,6 +36,7 @@
 import org.apache.flink.streaming.api.collector.OutputSelector;
 import org.apache.flink.streaming.api.invokable.ChainableInvokable;
 import org.apache.flink.streaming.api.invokable.StreamInvokable;
+import org.apache.flink.streaming.api.invokable.StreamInvokable.ChainingStrategy;
 import org.apache.flink.streaming.api.invokable.operator.co.CoInvokable;
 import org.apache.flink.streaming.api.streamrecord.StreamRecordSerializer;
 import org.apache.flink.streaming.api.streamvertex.CoStreamVertex;
@@ -64,7 +65,7 @@ public class JobGraphBuilder {
     private Map<String, Integer> vertexParallelism;
     private Map<String, Long> bufferTimeout;
     private Map<String, List<String>> outEdgeList;
-    private Map<String, List<Integer>> outEdgeType;
+    private Map<String, List<Integer>> outEdgeIndex;
     private Map<String, List<List<String>>> outEdgeNames;
     private Map<String, List<Boolean>> outEdgeSelectAll;
     private Map<String, List<String>> inEdgeList;
@@ -85,8 +86,7 @@ public class JobGraphBuilder {
     private Map<String, Map<String, OperatorState<?>>> operatorStates;
     private Map<String, InputFormat<String, ?>> inputFormatList;
     private Map<String, List<String>> chainedVertices;
-    private Map<String, List<ChainableInvokable<?, ?>>> chainedInvokable;
-    private Map<String, List<StreamRecordSerializer<?>>> chainedSerializer;
+    private Map<String, String> lastInChains;

     private Set<String> sources;
     private Set<String> builtVertices;
@@ -97,11 +97,19 @@ public class JobGraphBuilder {
  */
 public JobGraphBuilder() {

+    initGraph();
+
+    if (LOG.isDebugEnabled()) {
+        LOG.debug("JobGraph created");
+    }
+}
+
+public void initGraph() {
     streamVertices = new HashMap<String, AbstractJobVertex>();
     vertexParallelism = new HashMap<String, Integer>();
     bufferTimeout = new HashMap<String, Long>();
     outEdgeList = new HashMap<String, List<String>>();
-    outEdgeType = new HashMap<String, List<Integer>>();
+    outEdgeIndex = new HashMap<String, List<Integer>>();
     outEdgeNames = new HashMap<String, List<List<String>>>();
     outEdgeSelectAll = new HashMap<String, List<Boolean>>();
     inEdgeList = new HashMap<String, List<String>>();
@@ -121,16 +129,11 @@ public JobGraphBuilder() {
     iterationWaitTime = new HashMap<String, Long>();
     operatorStates = new HashMap<String, Map<String, OperatorState<?>>>();
     inputFormatList = new HashMap<String, InputFormat<String, ?>>();
-    chainedInvokable = new HashMap<String, List<ChainableInvokable<?, ?>>>();
-    chainedSerializer = new HashMap<String, List<StreamRecordSerializer<?>>>();
     chainedVertices = new HashMap<String, List<String>>();
+    lastInChains = new HashMap<String, String>();

     sources = new HashSet<String>();
     builtVertices = new HashSet<String>();
-
-    if (LOG.isDebugEnabled()) {
-        LOG.debug("JobGraph created");
-    }
 }

 /**
@@ -198,7 +201,7 @@ public void addIterationHead(String vertexName, String iterationHead, Integer it
     iterationIds.put(vertexName, iterationID);
     iterationIDtoHeadName.put(iterationID, vertexName);

-    setBytesFrom(iterationHead, vertexName);
+    setSerializersFrom(iterationHead, vertexName);

     setEdge(vertexName, iterationHead, outPartitioning
             .get(inEdgeList.get(iterationHead).get(0)).get(0), 0, new ArrayList<String>(),
@@ -241,7 +244,7 @@ public void addIterationTail(String vertexName, String iterationTail, Integer it
     iterationIds.put(vertexName, iterationID);
     iterationIDtoTailName.put(iterationID, vertexName);

-    setBytesFrom(iterationTail, vertexName);
+    setSerializersFrom(iterationTail, vertexName);
     iterationWaitTime.put(iterationIDtoTailName.get(iterationID), waitTime);

     if (LOG.isDebugEnabled()) {
@@ -288,12 +291,13 @@ private void addVertex(String vertexName, Class<? extends AbstractInvokable> ver
     invokableObjects.put(vertexName, invokableObject);
     operatorNames.put(vertexName, operatorName);
     outEdgeList.put(vertexName, new ArrayList<String>());
-    outEdgeType.put(vertexName, new ArrayList<Integer>());
+    outEdgeIndex.put(vertexName, new ArrayList<Integer>());
     outEdgeNames.put(vertexName, new ArrayList<List<String>>());
     outEdgeSelectAll.put(vertexName, new ArrayList<Boolean>());
     inEdgeList.put(vertexName, new ArrayList<String>());
     outPartitioning.put(vertexName, new ArrayList<StreamPartitioner<?>>());
     iterationTailCount.put(vertexName, 0);
+    lastInChains.put(vertexName, vertexName);
 }

 private void addTypeSerializers(String vertexName, StreamRecordSerializer<?> in1,
@@ -327,11 +331,16 @@ private void createVertex(String vertexName) {
         }
     }

+    List<String> chainedNames = chainedVertices.get(vertexName);
+    boolean isChained = chainedNames != null;
+    int numChained = isChained ? chainedNames.size() : 0;
+    String lastInChain = lastInChains.get(vertexName);
+
     // Get vertex attributes
     Class<? extends AbstractInvokable> vertexClass = vertexClasses.get(vertexName);
     StreamInvokable<?, ?> invokableObject = invokableObjects.get(vertexName);
     int parallelism = vertexParallelism.get(vertexName);
-    byte[] outputSelector = outputSelectors.get(vertexName);
+    byte[] outputSelector = outputSelectors.get(lastInChain);
     Map<String, OperatorState<?>> state = operatorStates.get(vertexName);

     // Create vertex object
@@ -347,16 +356,17 @@ private void createVertex(String vertexName) {
         LOG.debug("Parallelism set: {} for {}", parallelism, vertexName);
     }

-    // Set vertex config
-
     StreamConfig config = new StreamConfig(vertex.getConfiguration());

-    config.setBufferTimeout(bufferTimeout.get(vertexName));
+    config.setBufferTimeout(bufferTimeout.get(lastInChain));

     config.setTypeSerializerIn1(typeSerializersIn1.get(vertexName));
     config.setTypeSerializerIn2(typeSerializersIn2.get(vertexName));
     config.setTypeSerializerOut1(typeSerializersOut1.get(vertexName));
     config.setTypeSerializerOut2(typeSerializersOut2.get(vertexName));

+    // Set vertex config
     config.setUserInvokable(invokableObject);
     config.setOutputSelector(outputSelector);
     config.setOperatorStates(state);
@@ -371,15 +381,12 @@ private void createVertex(String vertexName) {
         vertex.setInputSplitSource(inputFormatList.get(vertexName));
     }

-    List<ChainableInvokable<?, ?>> chainedInvokables = chainedInvokable.get(vertexName);
-    List<StreamRecordSerializer<?>> chainedSerializers = chainedSerializer.get(vertexName);
-
-    int numChained = chainedInvokables == null ? 0 : chainedInvokables.size();
     config.setNumberofChainedTasks(numChained);

     for (int i = 0; i < numChained; i++) {
-        config.setChainedInvokable(chainedInvokables.get(i), i);
-        config.setChainedSerializer(chainedSerializers.get(i), i);
+        config.setChainedInvokable(
+                (ChainableInvokable<?, ?>) invokableObjects.get(chainedNames.get(i)), i);
+        config.setChainedSerializer(typeSerializersIn1.get(chainedNames.get(i)), i);
     }

     streamVertices.put(vertexName, vertex);
@@ -390,12 +397,15 @@ private void createVertex(String vertexName) {
 }

 private void chainRecursively(String chainStart, String current, String next) {
+    // We chain the next operator to the start of this chain
     chainTasks(chainStart, next);
-    // Add multiple chaining here
+    // Now recursively chain the outputs of next (depth first)
     for (String output : outEdgeList.get(next)) {
         if (isChainable(next, output)) {
+            // Recursive call
             chainRecursively(chainStart, next, output);
         } else {
+            // If not chainable we continue building the jobgraph from there
             createVertex(output);
         }
     }
@@ -405,42 +415,22 @@ private boolean isChainable(String vertexName, String outName) {
     return outEdgeList.get(vertexName).size() == 1
             && inEdgeList.get(outName).size() == 1
             && outputSelectors.get(vertexName) == null
-            && invokableObjects.get(outName).isChainable()
+            && invokableObjects.get(outName).getChainingStrategy() == ChainingStrategy.ALWAYS
+            && (invokableObjects.get(vertexName).getChainingStrategy() == ChainingStrategy.HEAD || invokableObjects
+                    .get(vertexName).getChainingStrategy() == ChainingStrategy.ALWAYS)
             && outPartitioning.get(vertexName).get(0).getStrategy() == PartitioningStrategy.FORWARD
             && vertexParallelism.get(vertexName) == vertexParallelism.get(outName) && chaining;
 }

 private void chainTasks(String first, String second) {
-    List<ChainableInvokable<?, ?>> chainedInvokables = chainedInvokable.get(first);
-    if (chainedInvokables == null) {
-        chainedInvokables = new ArrayList<ChainableInvokable<?, ?>>();
-    }
-    chainedInvokables.add((ChainableInvokable<?, ?>) invokableObjects.get(second));
-    chainedInvokable.put(first, chainedInvokables);
-
-    List<StreamRecordSerializer<?>> chainedSerializers = chainedSerializer.get(first);
-    if (chainedSerializers == null) {
-        chainedSerializers = new ArrayList<StreamRecordSerializer<?>>();
-    }
-    chainedSerializers.add(typeSerializersIn1.get(second));
-    chainedSerializer.put(first, chainedSerializers);
-
     List<String> chained = chainedVertices.get(first);
     if (chained == null) {
         chained = new ArrayList<String>();
     }
     chained.add(second);
     chainedVertices.put(first, chained);

-    outEdgeList.put(first, outEdgeList.get(second));
-    typeSerializersOut1.put(first, typeSerializersOut1.get(second));
-    outPartitioning.put(first, outPartitioning.get(second));
-    outEdgeType.put(first, outEdgeType.get(second));
-    outEdgeNames.put(first, outEdgeNames.get(second));
-    outEdgeSelectAll.put(first, outEdgeSelectAll.get(second));
-    outPartitioning.put(first, outPartitioning.get(second));
-    bufferTimeout.put(first, bufferTimeout.get(second));
-    outputSelectors.put(first, outputSelectors.get(second));
+    lastInChains.put(first, second);
 }
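
The other half of the refactor is how the builder tracks chains: instead of copying the chained operator's outgoing edges, partitioners, serializers, buffer timeout and output selectors onto the chain head (the deleted puts above), it records the last operator of every chain in lastInChains and redirects those lookups through it, as the later hunks in connect() and buildJobGraph() show. A trimmed-down sketch of that indirection, using hypothetical maps rather than the builder's full state:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified version of the indirection used by JobGraphBuilder:
// outgoing-edge data is always read from whatever operator is currently the
// last element of the chain that starts at the given head vertex.
class ChainIndex {
    final Map<String, String> lastInChains = new HashMap<String, String>();        // head -> last operator in its chain
    final Map<String, List<String>> outEdges = new HashMap<String, List<String>>(); // operator -> downstream operators

    void addVertex(String name) {
        lastInChains.put(name, name);            // a fresh vertex is the last element of its own chain
        outEdges.put(name, new ArrayList<String>());
    }

    void chain(String head, String next) {
        lastInChains.put(head, next);            // the chain now ends at 'next'
    }

    List<String> outEdgesOfChain(String head) {
        // mirrors outEdgeList.get(lastInChains.get(head)) in the diff
        return outEdges.get(lastInChains.get(head));
    }
}

This keeps the per-operator maps untouched, so every vertex can still be looked up under its own name while the chain head transparently resolves to its tail's outputs.
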


@@ -477,8 +467,10 @@ private <T> void connect(String upStreamVertexName, String downStreamVertexName,

     int outputIndex = upStreamVertex.getNumberOfProducedIntermediateDataSets() - 1;

-    config.setOutputName(outputIndex, outEdgeNames.get(upStreamVertexName).get(outputIndex));
-    config.setSelectAll(outputIndex, outEdgeSelectAll.get(upStreamVertexName).get(outputIndex));
+    config.setOutputName(outputIndex, outEdgeNames.get(lastInChains.get(upStreamVertexName))
+            .get(outputIndex));
+    config.setSelectAll(outputIndex, outEdgeSelectAll.get(lastInChains.get(upStreamVertexName))
+            .get(outputIndex));
     config.setPartitioner(outputIndex, partitionerObject);
     config.setNumberOfOutputChannels(outputIndex, vertexParallelism.get(downStreamVertexName));
 }
@@ -547,7 +539,7 @@ public void setEdge(String upStreamVertexName, String downStreamVertexName,
         StreamPartitioner<?> partitionerObject, int typeNumber, List<String> outputNames,
         boolean selectAll) {
     outEdgeList.get(upStreamVertexName).add(downStreamVertexName);
-    outEdgeType.get(upStreamVertexName).add(typeNumber);
+    outEdgeIndex.get(upStreamVertexName).add(typeNumber);
     inEdgeList.get(downStreamVertexName).add(upStreamVertexName);
     outPartitioning.get(upStreamVertexName).add(partitionerObject);
     outEdgeNames.get(upStreamVertexName).add(outputNames);
@@ -608,7 +600,7 @@ public <OUT> void setOutType(String id, TypeInformation<OUT> outType) {
  * @param to
  *            to
  */
-public void setBytesFrom(String from, String to) {
+public void setSerializersFrom(String from, String to) {
     operatorNames.put(to, operatorNames.get(from));

     typeSerializersIn1.put(to, typeSerializersOut1.get(from));
@@ -691,19 +683,20 @@ private void buildJobGraph() {
     for (String upStreamVertexName : builtVertices) {
         int i = 0;

-        List<Integer> outEdgeTypeList = outEdgeType.get(upStreamVertexName);
+        List<Integer> outEdgeTypeList = outEdgeIndex.get(lastInChains.get(upStreamVertexName));

-        for (String downStreamVertexName : outEdgeList.get(upStreamVertexName)) {
+        for (String downStreamVertexName : outEdgeList
+                .get(lastInChains.get(upStreamVertexName))) {
             StreamConfig downStreamVertexConfig = new StreamConfig(streamVertices.get(
                     downStreamVertexName).getConfiguration());

             int inputNumber = downStreamVertexConfig.getNumberOfInputs();

-            downStreamVertexConfig.setInputType(inputNumber++, outEdgeTypeList.get(i));
+            downStreamVertexConfig.setInputIndex(inputNumber++, outEdgeTypeList.get(i));
             downStreamVertexConfig.setNumberOfInputs(inputNumber);

             connect(upStreamVertexName, downStreamVertexName,
-                    outPartitioning.get(upStreamVertexName).get(i));
+                    outPartitioning.get(lastInChains.get(upStreamVertexName)).get(i));
             i++;
         }
     }
@@ -24,7 +24,6 @@

 import org.apache.commons.lang3.SerializationException;
 import org.apache.commons.lang3.SerializationUtils;
-import org.apache.flink.api.common.functions.AbstractRichFunction;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.api.collector.OutputSelector;
 import org.apache.flink.streaming.api.invokable.ChainableInvokable;
@@ -170,7 +169,7 @@ public void setDirectedEmit(boolean directedEmit) {
     config.setBoolean(DIRECTED_EMIT, directedEmit);
 }

-public boolean getDirectedEmit() {
+public boolean isDirectedEmit() {
     return config.getBoolean(DIRECTED_EMIT, false);
 }

@@ -239,7 +238,7 @@ public void setSelectAll(int outputIndex, Boolean selectAll) {
     }
 }

-public boolean getSelectAll(int outputIndex) {
+public boolean isSelectAll(int outputIndex) {
     return config.getBoolean(OUTPUT_SELECT_ALL + outputIndex, false);
 }

@@ -272,26 +271,14 @@ public int getNumberOfOutputs() {
     return config.getInteger(NUMBER_OF_OUTPUTS, 0);
 }

-public void setInputType(int inputNumber, Integer inputTypeNumber) {
+public void setInputIndex(int inputNumber, Integer inputTypeNumber) {
     config.setInteger(INPUT_TYPE + inputNumber++, inputTypeNumber);
 }

-public int getInputType(int inputNumber) {
+public int getInputIndex(int inputNumber) {
     return config.getInteger(INPUT_TYPE + inputNumber, 0);
 }

-public void setFunctionClass(Class<? extends AbstractRichFunction> functionClass) {
-    config.setClass("functionClass", functionClass);
-}
-
-public Class<? extends AbstractRichFunction> getFunctionClass(ClassLoader cl) {
-    try {
-        return config.getClass("functionClass", null, cl);
-    } catch (ClassNotFoundException e) {
-        throw new RuntimeException("Could not load function class", e);
-    }
-}
-
 public void setOperatorStates(Map<String, OperatorState<?>> states) {
     config.setBytes(OPERATOR_STATES, SerializationUtils.serialize((Serializable) states));
 }
@@ -400,7 +400,7 @@ protected <OUT> CoInvokable<IN1, IN2, OUT> getReduceInvokable(

     @SuppressWarnings({ "unchecked", "rawtypes" })
     SingleOutputStreamOperator<OUT, ?> returnStream = new SingleOutputStreamOperator(
-            environment, functionName, outTypeInfo);
+            environment, functionName, outTypeInfo, functionInvokable);

     dataStream1.jobGraphBuilder.addCoTask(returnStream.getId(), functionInvokable,
             getInputType1(), getInputType2(), outTypeInfo, functionName,
@@ -1083,7 +1083,7 @@ private DataStreamSink<OUT> writeToFile(OutputFormat<OUT> format, long millis) {
 protected <R> DataStream<OUT> addIterationSource(Integer iterationID, long waitTime) {

     DataStream<R> returnStream = new DataStreamSource<R>(environment, "iterationSource", null,
-            true);
+            null, true);

     jobGraphBuilder.addIterationHead(returnStream.getId(), this.getId(), iterationID,
             degreeOfParallelism, waitTime);
@@ -1110,7 +1110,7 @@ protected <R> DataStream<OUT> addIterationSource(Integer iterationID, long waitT
     DataStream<OUT> inputStream = this.copy();
     @SuppressWarnings({ "unchecked", "rawtypes" })
     SingleOutputStreamOperator<R, ?> returnStream = new SingleOutputStreamOperator(environment,
-            operatorName, outTypeInfo);
+            operatorName, outTypeInfo, invokable);

     jobGraphBuilder.addStreamVertex(returnStream.getId(), invokable, getType(), outTypeInfo,
             operatorName, degreeOfParallelism);
@@ -1174,10 +1174,13 @@ protected <X> void connectGraph(DataStream<X> inputStream, String outputID, int
  */
 public DataStreamSink<OUT> addSink(SinkFunction<OUT> sinkFunction) {

-    DataStreamSink<OUT> returnStream = new DataStreamSink<OUT>(environment, "sink", getType());
+    StreamInvokable<OUT, OUT> sinkInvokable = new SinkInvokable<OUT>(clean(sinkFunction));

-    jobGraphBuilder.addStreamVertex(returnStream.getId(), new SinkInvokable<OUT>(
-            clean(sinkFunction)), getType(), null, "sink", degreeOfParallelism);
+    DataStreamSink<OUT> returnStream = new DataStreamSink<OUT>(environment, "sink", getType(),
+            sinkInvokable);
+
+    jobGraphBuilder.addStreamVertex(returnStream.getId(), sinkInvokable, getType(), null,
+            "sink", degreeOfParallelism);

     this.connectGraph(this.copy(), returnStream.getId(), 0);

@@ -19,6 +19,7 @@

 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.invokable.StreamInvokable;

 /**
  * Represents the end of a DataStream.
@@ -29,8 +30,8 @@
 public class DataStreamSink<IN> extends SingleOutputStreamOperator<IN, DataStreamSink<IN>> {

     protected DataStreamSink(StreamExecutionEnvironment environment, String operatorType,
-            TypeInformation<IN> outTypeInfo) {
-        super(environment, operatorType, outTypeInfo);
+            TypeInformation<IN> outTypeInfo, StreamInvokable<?,?> invokable) {
+        super(environment, operatorType, outTypeInfo, invokable);
     }

     protected DataStreamSink(DataStream<IN> dataStream) {