Skip to content

Commit

Permalink
[FLINK-19434][DataStream API] Add source input chaining to StreamingJ…
Browse files Browse the repository at this point in the history
…obGraphGenerator (part 3)

Some cleanups in StreamingJobGraphGenerator:

  - checking for chainability via ChainingStrategy
  - passing chain entry points as parameters, not as fields
  - fixing the input gate index for network unions
  • Loading branch information
StephanEwen committed Oct 2, 2020
1 parent ce4f67b commit 7e1dda8
Showing 1 changed file with 85 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.operators.ResourceSpec;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.IllegalConfigurationException;
Expand Down Expand Up @@ -66,7 +67,6 @@
import org.apache.flink.streaming.runtime.tasks.StreamIterationHead;
import org.apache.flink.streaming.runtime.tasks.StreamIterationTail;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.SerializedValue;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -139,8 +139,6 @@ public static JobGraph createJobGraph(StreamGraph streamGraph, @Nullable JobID j
private final StreamGraphHasher defaultStreamGraphHasher;
private final List<StreamGraphHasher> legacyStreamGraphHashers;

private final Map<Integer, OperatorChainInfo> chainInfos;

private StreamingJobGraphGenerator(StreamGraph streamGraph, @Nullable JobID jobID) {
this.streamGraph = streamGraph;
this.defaultStreamGraphHasher = new StreamGraphHasherV2();
Expand All @@ -155,7 +153,6 @@ private StreamingJobGraphGenerator(StreamGraph streamGraph, @Nullable JobID jobI
this.chainedPreferredResources = new HashMap<>();
this.chainedInputOutputFormats = new HashMap<>();
this.physicalEdgesInOrder = new ArrayList<>();
this.chainInfos = new HashMap<>();

jobGraph = new JobGraph(jobID, streamGraph.getJobName());
}
Expand Down Expand Up @@ -260,7 +257,7 @@ private void setPhysicalEdges() {
}
}

private Collection<OperatorChainInfo> buildChainedInputsAndGetHeadInputs(
private Map<Integer, OperatorChainInfo> buildChainedInputsAndGetHeadInputs(
final Map<Integer, byte[]> hashes,
final List<Map<Integer, byte[]>> legacyHashes) {

Expand All @@ -274,8 +271,9 @@ private Collection<OperatorChainInfo> buildChainedInputsAndGetHeadInputs(
// as long as only NAry ops support this chaining, we need to skip the other parts
final StreamEdge sourceOutEdge = sourceNode.getOutEdges().get(0);
final StreamNode target = streamGraph.getStreamNode(sourceOutEdge.getTargetId());
final ChainingStrategy targetChainingStrategy = target.getOperatorFactory().getChainingStrategy();

if (isMultipleInput(target) && isChainableIgnoringInEdgesSize(sourceOutEdge, streamGraph)) {
if (targetChainingStrategy == ChainingStrategy.HEAD_WITH_SOURCES && isChainableInput(sourceOutEdge, streamGraph)) {
final OperatorID opId = new OperatorID(hashes.get(sourceNodeId));
final StreamConfig.SourceInputConfig inputConfig = new StreamConfig.SourceInputConfig(sourceOutEdge);
final StreamConfig operatorConfig = new StreamConfig(new Configuration());
Expand All @@ -296,15 +294,12 @@ private Collection<OperatorChainInfo> buildChainedInputsAndGetHeadInputs(
}
}

chainEntryPoints.computeIfAbsent(
chainEntryPoints.put(
sourceNodeId,
(k) -> new OperatorChainInfo(sourceNodeId, hashes, legacyHashes, chainedSources, streamGraph));
new OperatorChainInfo(sourceNodeId, hashes, legacyHashes, chainedSources, streamGraph));
}

for (Map.Entry<Integer, OperatorChainInfo> entry : chainEntryPoints.entrySet()) {
chainInfos.put(entry.getKey(), entry.getValue());
}
return chainEntryPoints.values();
return chainEntryPoints;
}

/**
Expand All @@ -315,17 +310,25 @@ private Collection<OperatorChainInfo> buildChainedInputsAndGetHeadInputs(
private void setChaining(Map<Integer, byte[]> hashes, List<Map<Integer, byte[]>> legacyHashes) {
// we separate out the sources that run as inputs to another operator (chained inputs)
// from the sources that needs to run as the main (head) operator.
final Collection<OperatorChainInfo> entryPoints = buildChainedInputsAndGetHeadInputs(hashes, legacyHashes);
final Map<Integer, OperatorChainInfo> chainEntryPoints = buildChainedInputsAndGetHeadInputs(hashes, legacyHashes);
final Collection<OperatorChainInfo> initialEntryPoints = new ArrayList<>(chainEntryPoints.values());

for (OperatorChainInfo info : entryPoints) {
// iterate over a copy of the values, because this map gets modified while the chains are created
for (OperatorChainInfo info : initialEntryPoints) {
createChain(
info.getStartNodeId(),
1, // operators start at position 1 because 0 is for chained source inputs
info);
info,
chainEntryPoints);
}
}

private List<StreamEdge> createChain(Integer currentNodeId, int chainIndex, OperatorChainInfo chainInfo) {
private List<StreamEdge> createChain(
final Integer currentNodeId,
final int chainIndex,
final OperatorChainInfo chainInfo,
final Map<Integer, OperatorChainInfo> chainEntryPoints) {

Integer startNodeId = chainInfo.getStartNodeId();
if (!builtVertices.contains(startNodeId)) {

Expand All @@ -346,17 +349,18 @@ private List<StreamEdge> createChain(Integer currentNodeId, int chainIndex, Oper

for (StreamEdge chainable : chainableOutputs) {
transitiveOutEdges.addAll(
createChain(chainable.getTargetId(), chainIndex + 1, chainInfo));
createChain(chainable.getTargetId(), chainIndex + 1, chainInfo, chainEntryPoints));
}

for (StreamEdge nonChainable : nonChainableOutputs) {
transitiveOutEdges.add(nonChainable);
createChain(
nonChainable.getTargetId(),
1, // operators start at position 1 because 0 is for chained source inputs
chainInfos.computeIfAbsent(
chainEntryPoints.computeIfAbsent(
nonChainable.getTargetId(),
(k) -> chainInfo.newChain(nonChainable.getTargetId())));
(k) -> chainInfo.newChain(nonChainable.getTargetId())),
chainEntryPoints);
}

chainedNames.put(currentNodeId, createChainedName(currentNodeId, chainableOutputs));
Expand Down Expand Up @@ -540,20 +544,34 @@ private void setVertexConfig(

// build the inputs as a combination of source and network inputs
final List<StreamEdge> inEdges = vertex.getInEdges();
final StreamConfig.InputConfig[] inputConfigs = new StreamConfig.InputConfig[inEdges.size()];
final TypeSerializer<?>[] inputSerializers = vertex.getTypeSerializersIn();

final StreamConfig.InputConfig[] inputConfigs = new StreamConfig.InputConfig[inputSerializers.length];

int inputGateCount = 0;
for (int i = 0; i < inEdges.size(); i++) {
final StreamEdge inEdge = inEdges.get(i);
for (final StreamEdge inEdge : inEdges) {
final ChainedSourceInfo chainedSource = chainedSources.get(inEdge.getSourceId());

final int inputIndex = inEdge.getTypeNumber() == 0
? 0 // single input operator
: inEdge.getTypeNumber() - 1; // in case of 2 or more inputs

if (chainedSource != null) {
inputConfigs[i] = chainedSource.getInputConfig();
// chained source is the input
if (inputConfigs[inputIndex] != null) {
throw new IllegalStateException("Trying to union a chained source with another input.");
}
inputConfigs[inputIndex] = chainedSource.getInputConfig();
chainedConfigs
.computeIfAbsent(vertexID, (key) -> new HashMap<>())
.put(inEdge.getSourceId(), chainedSource.getOperatorConfig());
.computeIfAbsent(vertexID, (key) -> new HashMap<>())
.put(inEdge.getSourceId(), chainedSource.getOperatorConfig());
} else {
inputConfigs[i] = new StreamConfig.NetworkInputConfig(
vertex.getTypeSerializerIn(Math.max(0, inEdge.getTypeNumber() - 1)), inputGateCount++);
// network input. null if we move to a new input, non-null if this is a further edge
// that is union-ed into the same input
if (inputConfigs[inputIndex] == null) {
inputConfigs[inputIndex] = new StreamConfig.NetworkInputConfig(
inputSerializers[inputIndex], inputGateCount++);
}
}
}
config.setInputs(inputConfigs);
Expand Down Expand Up @@ -722,10 +740,10 @@ public static boolean isChainable(StreamEdge edge, StreamGraph streamGraph) {
StreamNode downStreamVertex = streamGraph.getTargetVertex(edge);

return downStreamVertex.getInEdges().size() == 1
&& isChainableIgnoringInEdgesSize(edge, streamGraph);
&& isChainableInput(edge, streamGraph);
}

private static boolean isChainableIgnoringInEdgesSize(StreamEdge edge, StreamGraph streamGraph) {
private static boolean isChainableInput(StreamEdge edge, StreamGraph streamGraph) {
StreamNode upStreamVertex = streamGraph.getSourceVertex(edge);
StreamNode downStreamVertex = streamGraph.getTargetVertex(edge);

Expand All @@ -748,17 +766,47 @@ static boolean areOperatorsChainable(
return false;
}

if (upStreamOperator.getChainingStrategy() == ChainingStrategy.NEVER ||
downStreamOperator.getChainingStrategy() != ChainingStrategy.ALWAYS) {
// yielding operators cannot be chained to legacy sources
// unfortunately the information that vertices have been chained is not preserved at this point
if (downStreamOperator instanceof YieldingOperatorFactory &&
getHeadOperator(upStreamVertex, streamGraph).isStreamSource()) {
return false;
}

// yielding operators cannot be chained to legacy sources
if (downStreamOperator instanceof YieldingOperatorFactory) {
// unfortunately the information that vertices have been chained is not preserved at this point
return !getHeadOperator(upStreamVertex, streamGraph).isStreamSource();
// we use switch/case here to make sure this is exhaustive if ever values are added to the
// ChainingStrategy enum
boolean isChainable;

switch (upStreamOperator.getChainingStrategy()) {
case NEVER:
isChainable = false;
break;
case ALWAYS:
case HEAD:
case HEAD_WITH_SOURCES:
isChainable = true;
break;
default:
throw new RuntimeException("Unknown chaining strategy: " + upStreamOperator.getChainingStrategy());
}
return true;

switch (downStreamOperator.getChainingStrategy()) {
case NEVER:
case HEAD:
isChainable = false;
break;
case ALWAYS:
// keep the value from upstream
break;
case HEAD_WITH_SOURCES:
// only if upstream is a source
isChainable &= (upStreamOperator instanceof SourceOperatorFactory);
break;
default:
throw new RuntimeException("Unknown chaining strategy: " + upStreamOperator.getChainingStrategy());
}

return isChainable;
}

/**
Expand Down Expand Up @@ -1093,18 +1141,9 @@ private void configureCheckpointing() {
jobGraph.setSnapshotSettings(settings);
}

private static boolean isMultipleInput(StreamNode vertex) {
StreamOperatorFactory<?> factory = vertex.getOperatorFactory();
// FIXME super hack!
return factory.getClass().getName().equals(
"org.apache.flink.table.runtime.operators.multipleinput.BatchMultipleInputStreamOperatorFactory") ||
factory.getClass().getName().equals(
"org.apache.flink.table.runtime.operators.multipleinput.StreamMultipleInputStreamOperatorFactory");
}

/**
* A private class to help maintain the information of an operator chain during the recursive call in
* {@link #createChain(Integer, int, OperatorChainInfo)}.
* {@link #createChain(Integer, int, OperatorChainInfo, Map)}.
*/
private static class OperatorChainInfo {
private final Integer startNodeId;
Expand Down

0 comments on commit 7e1dda8

Please sign in to comment.