diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index 4b446bafed523..dce8b6066f710 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.operators.ResourceSpec; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; @@ -66,7 +67,6 @@ import org.apache.flink.streaming.runtime.tasks.StreamIterationHead; import org.apache.flink.streaming.runtime.tasks.StreamIterationTail; import org.apache.flink.util.FlinkRuntimeException; -import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; import org.apache.commons.lang3.StringUtils; @@ -139,8 +139,6 @@ public static JobGraph createJobGraph(StreamGraph streamGraph, @Nullable JobID j private final StreamGraphHasher defaultStreamGraphHasher; private final List legacyStreamGraphHashers; - private final Map chainInfos; - private StreamingJobGraphGenerator(StreamGraph streamGraph, @Nullable JobID jobID) { this.streamGraph = streamGraph; this.defaultStreamGraphHasher = new StreamGraphHasherV2(); @@ -155,7 +153,6 @@ private StreamingJobGraphGenerator(StreamGraph streamGraph, @Nullable JobID jobI this.chainedPreferredResources = new HashMap<>(); this.chainedInputOutputFormats = new HashMap<>(); this.physicalEdgesInOrder = new ArrayList<>(); - this.chainInfos = new HashMap<>(); jobGraph = new JobGraph(jobID, streamGraph.getJobName()); } @@ -260,7 +257,7 @@ private void setPhysicalEdges() { } } - private Collection buildChainedInputsAndGetHeadInputs( + private Map buildChainedInputsAndGetHeadInputs( final Map hashes, final List> legacyHashes) { @@ -274,8 +271,9 @@ private Collection buildChainedInputsAndGetHeadInputs( // as long as only NAry ops support this chaining, we need to skip the other parts final StreamEdge sourceOutEdge = sourceNode.getOutEdges().get(0); final StreamNode target = streamGraph.getStreamNode(sourceOutEdge.getTargetId()); + final ChainingStrategy targetChainingStrategy = target.getOperatorFactory().getChainingStrategy(); - if (isMultipleInput(target) && isChainableIgnoringInEdgesSize(sourceOutEdge, streamGraph)) { + if (targetChainingStrategy == ChainingStrategy.HEAD_WITH_SOURCES && isChainableInput(sourceOutEdge, streamGraph)) { final OperatorID opId = new OperatorID(hashes.get(sourceNodeId)); final StreamConfig.SourceInputConfig inputConfig = new StreamConfig.SourceInputConfig(sourceOutEdge); final StreamConfig operatorConfig = new StreamConfig(new Configuration()); @@ -296,15 +294,12 @@ private Collection buildChainedInputsAndGetHeadInputs( } } - chainEntryPoints.computeIfAbsent( + chainEntryPoints.put( sourceNodeId, - (k) -> new OperatorChainInfo(sourceNodeId, hashes, legacyHashes, chainedSources, streamGraph)); + new OperatorChainInfo(sourceNodeId, hashes, legacyHashes, chainedSources, streamGraph)); } - for (Map.Entry entry : chainEntryPoints.entrySet()) { - chainInfos.put(entry.getKey(), entry.getValue()); - } - return chainEntryPoints.values(); + return chainEntryPoints; } /** @@ -315,17 +310,25 @@ private Collection buildChainedInputsAndGetHeadInputs( private void setChaining(Map hashes, List> legacyHashes) { // we separate out the sources that run as inputs to another operator (chained inputs) // from the sources that needs to run as the main (head) operator. - final Collection entryPoints = buildChainedInputsAndGetHeadInputs(hashes, legacyHashes); + final Map chainEntryPoints = buildChainedInputsAndGetHeadInputs(hashes, legacyHashes); + final Collection initialEntryPoints = new ArrayList<>(chainEntryPoints.values()); - for (OperatorChainInfo info : entryPoints) { + // iterate over a copy of the values, because this map get + for (OperatorChainInfo info : initialEntryPoints) { createChain( info.getStartNodeId(), 1, // operators start at position 1 because 0 is for chained source inputs - info); + info, + chainEntryPoints); } } - private List createChain(Integer currentNodeId, int chainIndex, OperatorChainInfo chainInfo) { + private List createChain( + final Integer currentNodeId, + final int chainIndex, + final OperatorChainInfo chainInfo, + final Map chainEntryPoints) { + Integer startNodeId = chainInfo.getStartNodeId(); if (!builtVertices.contains(startNodeId)) { @@ -346,7 +349,7 @@ private List createChain(Integer currentNodeId, int chainIndex, Oper for (StreamEdge chainable : chainableOutputs) { transitiveOutEdges.addAll( - createChain(chainable.getTargetId(), chainIndex + 1, chainInfo)); + createChain(chainable.getTargetId(), chainIndex + 1, chainInfo, chainEntryPoints)); } for (StreamEdge nonChainable : nonChainableOutputs) { @@ -354,9 +357,10 @@ private List createChain(Integer currentNodeId, int chainIndex, Oper createChain( nonChainable.getTargetId(), 1, // operators start at position 1 because 0 is for chained source inputs - chainInfos.computeIfAbsent( + chainEntryPoints.computeIfAbsent( nonChainable.getTargetId(), - (k) -> chainInfo.newChain(nonChainable.getTargetId()))); + (k) -> chainInfo.newChain(nonChainable.getTargetId())), + chainEntryPoints); } chainedNames.put(currentNodeId, createChainedName(currentNodeId, chainableOutputs)); @@ -540,20 +544,34 @@ private void setVertexConfig( // build the inputs as a combination of source and network inputs final List inEdges = vertex.getInEdges(); - final StreamConfig.InputConfig[] inputConfigs = new StreamConfig.InputConfig[inEdges.size()]; + final TypeSerializer[] inputSerializers = vertex.getTypeSerializersIn(); + + final StreamConfig.InputConfig[] inputConfigs = new StreamConfig.InputConfig[inputSerializers.length]; int inputGateCount = 0; - for (int i = 0; i < inEdges.size(); i++) { - final StreamEdge inEdge = inEdges.get(i); + for (final StreamEdge inEdge : inEdges) { final ChainedSourceInfo chainedSource = chainedSources.get(inEdge.getSourceId()); + + final int inputIndex = inEdge.getTypeNumber() == 0 + ? 0 // single input operator + : inEdge.getTypeNumber() - 1; // in case of 2 or more inputs + if (chainedSource != null) { - inputConfigs[i] = chainedSource.getInputConfig(); + // chained source is the input + if (inputConfigs[inputIndex] != null) { + throw new IllegalStateException("Trying to union a chained source with another input."); + } + inputConfigs[inputIndex] = chainedSource.getInputConfig(); chainedConfigs - .computeIfAbsent(vertexID, (key) -> new HashMap<>()) - .put(inEdge.getSourceId(), chainedSource.getOperatorConfig()); + .computeIfAbsent(vertexID, (key) -> new HashMap<>()) + .put(inEdge.getSourceId(), chainedSource.getOperatorConfig()); } else { - inputConfigs[i] = new StreamConfig.NetworkInputConfig( - vertex.getTypeSerializerIn(Math.max(0, inEdge.getTypeNumber() - 1)), inputGateCount++); + // network input. null if we move to a new input, non-null if this is a further edge + // that is union-ed into the same input + if (inputConfigs[inputIndex] == null) { + inputConfigs[inputIndex] = new StreamConfig.NetworkInputConfig( + inputSerializers[inputIndex], inputGateCount++); + } } } config.setInputs(inputConfigs); @@ -722,10 +740,10 @@ public static boolean isChainable(StreamEdge edge, StreamGraph streamGraph) { StreamNode downStreamVertex = streamGraph.getTargetVertex(edge); return downStreamVertex.getInEdges().size() == 1 - && isChainableIgnoringInEdgesSize(edge, streamGraph); + && isChainableInput(edge, streamGraph); } - private static boolean isChainableIgnoringInEdgesSize(StreamEdge edge, StreamGraph streamGraph) { + private static boolean isChainableInput(StreamEdge edge, StreamGraph streamGraph) { StreamNode upStreamVertex = streamGraph.getSourceVertex(edge); StreamNode downStreamVertex = streamGraph.getTargetVertex(edge); @@ -748,17 +766,47 @@ static boolean areOperatorsChainable( return false; } - if (upStreamOperator.getChainingStrategy() == ChainingStrategy.NEVER || - downStreamOperator.getChainingStrategy() != ChainingStrategy.ALWAYS) { + // yielding operators cannot be chained to legacy sources + // unfortunately the information that vertices have been chained is not preserved at this point + if (downStreamOperator instanceof YieldingOperatorFactory && + getHeadOperator(upStreamVertex, streamGraph).isStreamSource()) { return false; } - // yielding operators cannot be chained to legacy sources - if (downStreamOperator instanceof YieldingOperatorFactory) { - // unfortunately the information that vertices have been chained is not preserved at this point - return !getHeadOperator(upStreamVertex, streamGraph).isStreamSource(); + // we use switch/case here to make sure this is exhaustive if ever values are added to the + // ChainingStrategy enum + boolean isChainable; + + switch (upStreamOperator.getChainingStrategy()) { + case NEVER: + isChainable = false; + break; + case ALWAYS: + case HEAD: + case HEAD_WITH_SOURCES: + isChainable = true; + break; + default: + throw new RuntimeException("Unknown chaining strategy: " + upStreamOperator.getChainingStrategy()); } - return true; + + switch (downStreamOperator.getChainingStrategy()) { + case NEVER: + case HEAD: + isChainable = false; + break; + case ALWAYS: + // keep the value from upstream + break; + case HEAD_WITH_SOURCES: + // only if upstream is a source + isChainable &= (upStreamOperator instanceof SourceOperatorFactory); + break; + default: + throw new RuntimeException("Unknown chaining strategy: " + upStreamOperator.getChainingStrategy()); + } + + return isChainable; } /** @@ -1093,18 +1141,9 @@ private void configureCheckpointing() { jobGraph.setSnapshotSettings(settings); } - private static boolean isMultipleInput(StreamNode vertex) { - StreamOperatorFactory factory = vertex.getOperatorFactory(); - // FIXME super hack! - return factory.getClass().getName().equals( - "org.apache.flink.table.runtime.operators.multipleinput.BatchMultipleInputStreamOperatorFactory") || - factory.getClass().getName().equals( - "org.apache.flink.table.runtime.operators.multipleinput.StreamMultipleInputStreamOperatorFactory"); - } - /** * A private class to help maintain the information of an operator chain during the recursive call in - * {@link #createChain(Integer, int, OperatorChainInfo)}. + * {@link #createChain(Integer, int, OperatorChainInfo, Map)}. */ private static class OperatorChainInfo { private final Integer startNodeId;