From cb77b97495f74a6eaab935c9279b1cf9f5780fa7 Mon Sep 17 00:00:00 2001 From: Thesharing Date: Wed, 24 Feb 2021 10:48:59 +0800 Subject: [PATCH] [FLINK-21326] Optimize the topology building in ExecutionGraph --- .../TaskDeploymentDescriptorFactory.java | 53 ++++--- .../executiongraph/EdgeManagerBuildUtil.java | 143 +++++++++++++++++ .../runtime/executiongraph/Execution.java | 41 +++-- .../runtime/executiongraph/ExecutionEdge.java | 51 ------ .../executiongraph/ExecutionGraph.java | 33 ++++ .../executiongraph/ExecutionJobVertex.java | 8 +- .../executiongraph/ExecutionVertex.java | 145 ++++-------------- .../executiongraph/IntermediateResult.java | 15 -- .../IntermediateResultPartition.java | 28 ++-- .../runtime/jobgraph/DistributionPattern.java | 21 +-- ...raphToInputsLocationsRetrieverAdapter.java | 15 +- .../adapter/DefaultExecutionTopology.java | 11 +- .../runtime/shuffle/PartitionDescriptor.java | 8 +- .../ExecutionGraphTestUtils.java | 13 +- .../ExecutionVertexDeploymentTest.java | 20 +-- .../executiongraph/PointwisePatternTest.java | 83 ++++++---- .../adapter/DefaultExecutionTopologyTest.java | 18 +-- .../partitioner/RescalePartitionerTest.java | 40 +++-- 18 files changed, 396 insertions(+), 350 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionEdge.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java index 2f69f9eb3c87c4..3c03704784d515 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java @@ -27,7 +27,6 @@ import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; -import org.apache.flink.runtime.executiongraph.ExecutionEdge; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; @@ -37,6 +36,8 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.shuffle.ShuffleDescriptor; import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor; import org.apache.flink.types.Either; @@ -62,7 +63,7 @@ public class TaskDeploymentDescriptorFactory { private final JobID jobID; private final boolean allowUnknownPartitions; private final int subtaskIndex; - private final ExecutionEdge[][] inputEdges; + private final List> consumedPartitions; private TaskDeploymentDescriptorFactory( ExecutionAttemptID executionId, @@ -72,7 +73,7 @@ private TaskDeploymentDescriptorFactory( JobID jobID, boolean allowUnknownPartitions, int subtaskIndex, - ExecutionEdge[][] inputEdges) { + List> consumedPartitions) { this.executionId = executionId; this.attemptNumber = attemptNumber; this.serializedJobInformation = 
serializedJobInformation; @@ -80,7 +81,7 @@ private TaskDeploymentDescriptorFactory( this.jobID = jobID; this.allowUnknownPartitions = allowUnknownPartitions; this.subtaskIndex = subtaskIndex; - this.inputEdges = inputEdges; + this.consumedPartitions = consumedPartitions; } public TaskDeploymentDescriptor createDeploymentDescriptor( @@ -101,18 +102,18 @@ public TaskDeploymentDescriptor createDeploymentDescriptor( } private List createInputGateDeploymentDescriptors() { - List inputGates = new ArrayList<>(inputEdges.length); + List inputGates = new ArrayList<>(consumedPartitions.size()); - for (ExecutionEdge[] edges : inputEdges) { + for (List partitions : consumedPartitions) { // If the produced partition has multiple consumers registered, we // need to request the one matching our sub task index. // TODO Refactor after removing the consumers from the intermediate result partitions - int numConsumerEdges = edges[0].getSource().getConsumers().get(0).size(); + IntermediateResultPartition resultPartition = partitions.get(0); - int queueToRequest = subtaskIndex % numConsumerEdges; + int numConsumer = resultPartition.getConsumers().get(0).getVertices().size(); - IntermediateResult consumedIntermediateResult = - edges[0].getSource().getIntermediateResult(); + int queueToRequest = subtaskIndex % numConsumer; + IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); IntermediateDataSetID resultId = consumedIntermediateResult.getId(); ResultPartitionType partitionType = consumedIntermediateResult.getResultType(); @@ -121,25 +122,39 @@ private List createInputGateDeploymentDescriptors resultId, partitionType, queueToRequest, - getConsumedPartitionShuffleDescriptors(edges))); + getConsumedPartitionShuffleDescriptors(partitions))); } return inputGates; } - private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors(ExecutionEdge[] edges) { - ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[edges.length]; + private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors( + List partitions) { + + ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[partitions.size()]; // Each edge is connected to a different result partition - for (int i = 0; i < edges.length; i++) { + for (int i = 0; i < partitions.size(); i++) { shuffleDescriptors[i] = - getConsumedPartitionShuffleDescriptor(edges[i], allowUnknownPartitions); + getConsumedPartitionShuffleDescriptor( + partitions.get(i), allowUnknownPartitions); } return shuffleDescriptors; } public static TaskDeploymentDescriptorFactory fromExecutionVertex( ExecutionVertex executionVertex, int attemptNumber) throws IOException { - ExecutionGraph executionGraph = executionVertex.getExecutionGraph(); + final ExecutionGraph executionGraph = executionVertex.getExecutionGraph(); + + final List> consumedPartitions = new ArrayList<>(); + + for (ConsumedPartitionGroup partitionGroup : executionVertex.getAllConsumedPartitions()) { + List partitions = new ArrayList<>(); + for (IntermediateResultPartitionID partitionId : partitionGroup.getResultPartitions()) { + partitions.add(executionGraph.getResultPartitionOrThrow(partitionId)); + } + consumedPartitions.add(partitions); + } + return new TaskDeploymentDescriptorFactory( executionVertex.getCurrentExecutionAttempt().getAttemptId(), attemptNumber, @@ -149,7 +164,7 @@ public static TaskDeploymentDescriptorFactory fromExecutionVertex( executionGraph.getJobID(), executionGraph.getScheduleMode().allowLazyDeployment(), executionVertex.getParallelSubtaskIndex(), - 
executionVertex.getAllInputEdges()); + consumedPartitions); } private static MaybeOffloaded getSerializedJobInformation( @@ -171,8 +186,8 @@ private static MaybeOffloaded getSerializedTaskInformation( } public static ShuffleDescriptor getConsumedPartitionShuffleDescriptor( - ExecutionEdge edge, boolean allowUnknownPartitions) { - IntermediateResultPartition consumedPartition = edge.getSource(); + IntermediateResultPartition consumedPartition, boolean allowUnknownPartitions) { + Execution producer = consumedPartition.getProducer().getCurrentExecutionAttempt(); ExecutionState producerState = producer.getState(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java new file mode 100644 index 00000000000000..13b85629bd5b25 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +package org.apache.flink.runtime.executiongraph; + +import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** Utilities for building {@link EdgeManager}. 
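+ * It connects the {@link ExecutionVertex}s of an {@link ExecutionJobVertex} to the partitions of a consumed {@link IntermediateResult} according to the given {@link DistributionPattern}.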
*/ +public class EdgeManagerBuildUtil { + + public static void connectVertexToResult( + ExecutionJobVertex vertex, + IntermediateResult ires, + int inputNumber, + DistributionPattern distributionPattern) { + + switch (distributionPattern) { + case POINTWISE: + connectPointwise(vertex.getTaskVertices(), ires, inputNumber); + break; + case ALL_TO_ALL: + connectAllToAll(vertex.getTaskVertices(), ires, inputNumber); + break; + default: + throw new RuntimeException("Unrecognized distribution pattern."); + } + } + + private static void connectAllToAll( + ExecutionVertex[] taskVertices, IntermediateResult ires, int inputNumber) { + + ConsumedPartitionGroup consumedPartitions = + new ConsumedPartitionGroup( + Arrays.stream(ires.getPartitions()) + .map(IntermediateResultPartition::getPartitionId) + .collect(Collectors.toList())); + for (ExecutionVertex ev : taskVertices) { + ev.addConsumedPartitions(consumedPartitions, inputNumber); + } + + ConsumerVertexGroup vertices = + new ConsumerVertexGroup( + Arrays.stream(taskVertices) + .map(ExecutionVertex::getID) + .collect(Collectors.toList())); + for (IntermediateResultPartition partition : ires.getPartitions()) { + partition.addConsumers(vertices); + } + } + + private static void connectPointwise( + ExecutionVertex[] taskVertices, IntermediateResult ires, int inputNumber) { + + final int sourceCount = ires.getPartitions().length; + final int targetCount = taskVertices.length; + + if (sourceCount == targetCount) { + for (int i = 0; i < sourceCount; i++) { + ExecutionVertex executionVertex = taskVertices[i]; + IntermediateResultPartition partition = ires.getPartitions()[i]; + + ConsumerVertexGroup consumerVertexGroup = + new ConsumerVertexGroup(executionVertex.getID()); + partition.addConsumers(consumerVertexGroup); + + ConsumedPartitionGroup consumedPartitionGroup = + new ConsumedPartitionGroup(partition.getPartitionId()); + executionVertex.addConsumedPartitions(consumedPartitionGroup, inputNumber); + } + } else if (sourceCount > targetCount) { + for (int index = 0; index < targetCount; index++) { + + ExecutionVertex executionVertex = taskVertices[index]; + ConsumerVertexGroup consumerVertexGroup = + new ConsumerVertexGroup(executionVertex.getID()); + + int start = index * sourceCount / targetCount; + int end = (index + 1) * sourceCount / targetCount; + + List consumedPartitions = + new ArrayList<>(end - start); + + for (int i = start; i < end; i++) { + IntermediateResultPartition partition = ires.getPartitions()[i]; + partition.addConsumers(consumerVertexGroup); + + consumedPartitions.add(partition.getPartitionId()); + } + + ConsumedPartitionGroup consumedPartitionGroup = + new ConsumedPartitionGroup(consumedPartitions); + executionVertex.addConsumedPartitions(consumedPartitionGroup, inputNumber); + } + } else { + for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) { + + IntermediateResultPartition partition = ires.getPartitions()[partitionNum]; + ConsumedPartitionGroup consumerPartitionGroup = + new ConsumedPartitionGroup(partition.getPartitionId()); + + float factor = ((float) targetCount) / sourceCount; + int start = (int) (Math.ceil(partitionNum * factor)); + int end = (int) (Math.ceil((partitionNum + 1) * factor)); + + List consumers = new ArrayList<>(end - start); + + for (int i = start; i < end; i++) { + ExecutionVertex executionVertex = taskVertices[i]; + executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber); + + consumers.add(executionVertex.getID()); + } + + ConsumerVertexGroup consumerVertexGroup = 
new ConsumerVertexGroup(consumers); + partition.addConsumers(consumerVertexGroup); + } + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java index 4db17fd377a9e9..5b3593d956c04b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java @@ -49,6 +49,8 @@ import org.apache.flink.runtime.messages.Acknowledge; import org.apache.flink.runtime.operators.coordination.OperatorEvent; import org.apache.flink.runtime.operators.coordination.TaskNotRunningException; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.shuffle.NettyShuffleMaster; import org.apache.flink.runtime.shuffle.PartitionDescriptor; import org.apache.flink.runtime.shuffle.ProducerDescriptor; @@ -436,7 +438,9 @@ public CompletableFuture registerProducedPartitions( for (IntermediateResultPartition partition : partitions) { PartitionDescriptor partitionDescriptor = PartitionDescriptor.from(partition); - int maxParallelism = getPartitionMaxParallelism(partition); + int maxParallelism = + getPartitionMaxParallelism( + partition, vertex.getExecutionGraph()::getExecutionVertexOrThrow); CompletableFuture shuffleDescriptorFuture = vertex.getExecutionGraph() .getShuffleMaster() @@ -468,15 +472,15 @@ public CompletableFuture registerProducedPartitions( }); } - private static int getPartitionMaxParallelism(IntermediateResultPartition partition) { - final List> consumers = partition.getConsumers(); + private static int getPartitionMaxParallelism( + IntermediateResultPartition partition, + Function getVertexById) { + final List consumers = partition.getConsumers(); Preconditions.checkArgument( - !consumers.isEmpty(), + consumers.size() == 1, "Currently there has to be exactly one consumer in real jobs"); - List consumer = consumers.get(0); - ExecutionJobVertex consumerVertex = consumer.get(0).getTarget().getJobVertex(); - int maxParallelism = consumerVertex.getMaxParallelism(); - return maxParallelism; + final List consumerIds = consumers.get(0).getVertices(); + return getVertexById.apply(consumerIds.get(0)).getJobVertex().getMaxParallelism(); } /** @@ -707,7 +711,10 @@ public CompletableFuture suspend() { return releaseFuture; } - private void updatePartitionConsumers(final List> allConsumers) { + private void updatePartitionConsumers(final IntermediateResultPartition partition) { + + final List allConsumers = partition.getConsumers(); + if (allConsumers.size() == 0) { return; } @@ -718,8 +725,9 @@ private void updatePartitionConsumers(final List> allConsume return; } - for (ExecutionEdge edge : allConsumers.get(0)) { - final ExecutionVertex consumerVertex = edge.getTarget(); + for (ExecutionVertexID consumerVertexId : allConsumers.get(0).getVertices()) { + final ExecutionVertex consumerVertex = + vertex.getExecutionGraph().getExecutionVertexOrThrow(consumerVertexId); final Execution consumer = consumerVertex.getCurrentExecutionAttempt(); final ExecutionState consumerState = consumer.getState(); @@ -729,7 +737,7 @@ private void updatePartitionConsumers(final List> allConsume // sent after switching to running // ---------------------------------------------------------------- if (consumerState == DEPLOYING || consumerState == RUNNING) { - final PartitionInfo 
partitionInfo = createPartitionInfo(edge); + final PartitionInfo partitionInfo = createPartitionInfo(partition); if (consumerState == DEPLOYING) { consumerVertex.cachePartitionInfo(partitionInfo); @@ -740,11 +748,12 @@ private void updatePartitionConsumers(final List> allConsume } } - private static PartitionInfo createPartitionInfo(ExecutionEdge executionEdge) { + private static PartitionInfo createPartitionInfo( + IntermediateResultPartition consumedPartition) { IntermediateDataSetID intermediateDataSetID = - executionEdge.getSource().getIntermediateResult().getId(); + consumedPartition.getIntermediateResult().getId(); ShuffleDescriptor shuffleDescriptor = - getConsumedPartitionShuffleDescriptor(executionEdge, false); + getConsumedPartitionShuffleDescriptor(consumedPartition, false); return new PartitionInfo(intermediateDataSetID, shuffleDescriptor); } @@ -954,7 +963,7 @@ private void finishPartitionsAndUpdateConsumers() { finishedPartition.getIntermediateResult().getPartitions(); for (IntermediateResultPartition partition : allPartitionsOfNewlyFinishedResults) { - updatePartitionConsumers(partition.getConsumers()); + updatePartitionConsumers(partition); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionEdge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionEdge.java deleted file mode 100644 index ab949f29947b00..00000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionEdge.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.runtime.executiongraph; - -public class ExecutionEdge { - - private final IntermediateResultPartition source; - - private final ExecutionVertex target; - - private final int inputNum; - - public ExecutionEdge(IntermediateResultPartition source, ExecutionVertex target, int inputNum) { - this.source = source; - this.target = target; - this.inputNum = inputNum; - } - - public IntermediateResultPartition getSource() { - return source; - } - - public ExecutionVertex getTarget() { - return target; - } - - public int getInputNum() { - return inputNum; - } - - @Override - public String toString() { - return "ExecutionEdge [" + source + " <=> " + target + "]"; - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java index 9e0f6d591ccf3f..ed2a5f841ad130 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java @@ -279,6 +279,12 @@ public class ExecutionGraph implements AccessExecutionGraph { private final ExecutionDeploymentListener executionDeploymentListener; private final ExecutionStateUpdateListener executionStateUpdateListener; + private final EdgeManager edgeManager; + + private final Map executionVerticesById; + private final Map + resultPartitionsById; + // -------------------------------------------------------------------------------------------- // Constructors // -------------------------------------------------------------------------------------------- @@ -351,6 +357,10 @@ public ExecutionGraph( this.executionDeploymentListener = executionDeploymentListener; this.executionStateUpdateListener = executionStateUpdateListener; + + this.edgeManager = new EdgeManager(); + this.executionVerticesById = new HashMap<>(); + this.resultPartitionsById = new HashMap<>(); } public void start(@Nonnull ComponentMainThreadExecutor jobMasterMainThreadExecutor) { @@ -668,6 +678,29 @@ public Iterator iterator() { }; } + public EdgeManager getEdgeManager() { + return edgeManager; + } + + public void registerExecutionVertex(ExecutionVertexID id, ExecutionVertex vertex) { + executionVerticesById.put(id, vertex); + } + + public void registerResultPartition( + IntermediateResultPartitionID id, IntermediateResultPartition partition) { + + resultPartitionsById.put(id, partition); + } + + public ExecutionVertex getExecutionVertexOrThrow(ExecutionVertexID id) { + return checkNotNull(executionVerticesById.get(id)); + } + + public IntermediateResultPartition getResultPartitionOrThrow( + final IntermediateResultPartitionID id) { + return checkNotNull(resultPartitionsById.get(id)); + } + @Override public long getStatusTimestamp(JobStatus status) { return this.stateTimestamps[status.ordinal()]; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java index 8e844134b4310e..1d5a0cec80af4e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java @@ -70,6 +70,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil.connectVertexToResult; 
import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -454,12 +455,7 @@ public void connectToPredecessors( this.inputs.add(ires); - int consumerIndex = ires.registerConsumer(); - - for (int i = 0; i < parallelism; i++) { - ExecutionVertex ev = taskVertices[i]; - ev.connectSource(num, ires, edge, consumerIndex); - } + connectVertexToResult(this, ires, num, edge.getDistributionPattern()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java index 1bc0fda348acf3..2e6918ef80116a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java @@ -29,11 +29,10 @@ import org.apache.flink.runtime.clusterframework.types.ResourceProfile; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; -import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmaster.LogicalSlot; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.util.EvictingBoundedList; @@ -75,8 +74,6 @@ public class ExecutionVertex private final Map resultPartitions; - private final ExecutionEdge[][] inputEdges; - private final int subTaskIndex; private final ExecutionVertexID executionVertexId; @@ -133,7 +130,7 @@ public ExecutionVertex( resultPartitions.put(irp.getPartitionId(), irp); } - this.inputEdges = new ExecutionEdge[jobVertex.getJobVertex().getInputs().size()][]; + getExecutionGraph().registerExecutionVertex(executionVertexId, this); this.priorExecutions = new EvictingBoundedList<>(maxPriorExecutionHistoryLength); @@ -202,19 +199,24 @@ public ExecutionVertexID getID() { } public int getNumberOfInputs() { - return this.inputEdges.length; + return getAllConsumedPartitions().size(); + } + + public List getAllConsumedPartitions() { + return getExecutionGraph().getEdgeManager().getVertexConsumedPartitions(executionVertexId); } - public ExecutionEdge[] getInputEdges(int input) { - if (input < 0 || input >= inputEdges.length) { + public ConsumedPartitionGroup getConsumedPartitions(int input) { + final List allConsumedPartitions = getAllConsumedPartitions(); + + if (input < 0 || input >= allConsumedPartitions.size()) { throw new IllegalArgumentException( - String.format("Input %d is out of range [0..%d)", input, inputEdges.length)); + String.format( + "Input %d is out of range [0..%d)", + input, allConsumedPartitions.size())); } - return inputEdges[input]; - } - public ExecutionEdge[][] getAllInputEdges() { - return inputEdges; + return allConsumedPartitions.get(input); } public InputSplit getNextInputSplit(String host) { @@ -321,107 +323,11 @@ public Map getProduc // Graph building // -------------------------------------------------------------------------------------------- - public void connectSource( - int inputNumber, IntermediateResult source, JobEdge edge, int consumerNumber) { - - final DistributionPattern pattern = edge.getDistributionPattern(); - final 
IntermediateResultPartition[] sourcePartitions = source.getPartitions(); - - ExecutionEdge[] edges; - - switch (pattern) { - case POINTWISE: - edges = connectPointwise(sourcePartitions, inputNumber); - break; - - case ALL_TO_ALL: - edges = connectAllToAll(sourcePartitions, inputNumber); - break; - - default: - throw new RuntimeException("Unrecognized distribution pattern."); - } - - inputEdges[inputNumber] = edges; - - // add the consumers to the source - // for now (until the receiver initiated handshake is in place), we need to register the - // edges as the execution graph - for (ExecutionEdge ee : edges) { - ee.getSource().addConsumer(ee, consumerNumber); - } - } - - private ExecutionEdge[] connectAllToAll( - IntermediateResultPartition[] sourcePartitions, int inputNumber) { - ExecutionEdge[] edges = new ExecutionEdge[sourcePartitions.length]; - - for (int i = 0; i < sourcePartitions.length; i++) { - IntermediateResultPartition irp = sourcePartitions[i]; - edges[i] = new ExecutionEdge(irp, this, inputNumber); - } - - return edges; - } - - private ExecutionEdge[] connectPointwise( - IntermediateResultPartition[] sourcePartitions, int inputNumber) { - final int numSources = sourcePartitions.length; - final int parallelism = getTotalNumberOfParallelSubtasks(); - - // simple case same number of sources as targets - if (numSources == parallelism) { - return new ExecutionEdge[] { - new ExecutionEdge(sourcePartitions[subTaskIndex], this, inputNumber) - }; - } else if (numSources < parallelism) { - - int sourcePartition; - - // check if the pattern is regular or irregular - // we use int arithmetics for regular, and floating point with rounding for irregular - if (parallelism % numSources == 0) { - // same number of targets per source - int factor = parallelism / numSources; - sourcePartition = subTaskIndex / factor; - } else { - // different number of targets per source - float factor = ((float) parallelism) / numSources; - sourcePartition = (int) (subTaskIndex / factor); - } - - return new ExecutionEdge[] { - new ExecutionEdge(sourcePartitions[sourcePartition], this, inputNumber) - }; - } else { - if (numSources % parallelism == 0) { - // same number of targets per source - int factor = numSources / parallelism; - int startIndex = subTaskIndex * factor; - - ExecutionEdge[] edges = new ExecutionEdge[factor]; - for (int i = 0; i < factor; i++) { - edges[i] = - new ExecutionEdge(sourcePartitions[startIndex + i], this, inputNumber); - } - return edges; - } else { - float factor = ((float) numSources) / parallelism; + public void addConsumedPartitions(ConsumedPartitionGroup consumedPartitions, int inputNum) { - int start = (int) (subTaskIndex * factor); - int end = - (subTaskIndex == getTotalNumberOfParallelSubtasks() - 1) - ? sourcePartitions.length - : (int) ((subTaskIndex + 1) * factor); - - ExecutionEdge[] edges = new ExecutionEdge[end - start]; - for (int i = 0; i < edges.length; i++) { - edges[i] = new ExecutionEdge(sourcePartitions[start + i], this, inputNumber); - } - - return edges; - } - } + getExecutionGraph() + .getEdgeManager() + .addVertexConsumedPartitions(executionVertexId, consumedPartitions, inputNum); } /** @@ -493,8 +399,10 @@ public Optional getPreferredLocationBasedOnState() { * input-based preference. 
*/ public Collection> getPreferredLocationsBasedOnInputs() { + final List allConsumedPartitions = getAllConsumedPartitions(); + // otherwise, base the preferred locations on the input connections - if (inputEdges == null) { + if (allConsumedPartitions == null) { return Collections.emptySet(); } else { Set> locations = @@ -503,16 +411,15 @@ public Collection> getPreferredLocationsB new HashSet<>(getTotalNumberOfParallelSubtasks()); // go over all inputs - for (int i = 0; i < inputEdges.length; i++) { + for (ConsumedPartitionGroup sources : allConsumedPartitions) { inputLocations.clear(); - ExecutionEdge[] sources = inputEdges[i]; if (sources != null) { // go over all input sources - for (int k = 0; k < sources.length; k++) { + for (IntermediateResultPartitionID sourceId : sources.getResultPartitions()) { // look-up assigned slot of input source CompletableFuture locationFuture = - sources[k] - .getSource() + getExecutionGraph() + .getResultPartitionOrThrow(sourceId) .getProducer() .getCurrentTaskManagerLocationFuture(); // add input location diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java index 0db89206454567..988dfd87e1d7b9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java @@ -51,8 +51,6 @@ public class IntermediateResult { private int partitionsAssigned; - private int numConsumers; - private final int connectionIndex; private final ResultPartitionType resultType; @@ -145,19 +143,6 @@ public ResultPartitionType getResultType() { return resultType; } - public int registerConsumer() { - final int index = numConsumers; - numConsumers++; - - for (IntermediateResultPartition p : partitions) { - if (p.addConsumerGroup() != index) { - throw new RuntimeException( - "Inconsistent consumer mapping between intermediate result partitions."); - } - } - return index; - } - public int getConnectionIndex() { return connectionIndex; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java index e61a362def6cbc..0493d161839453 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java @@ -20,8 +20,8 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; -import java.util.ArrayList; import java.util.List; public class IntermediateResultPartition { @@ -32,8 +32,6 @@ public class IntermediateResultPartition { private final IntermediateResultPartitionID partitionId; - private List> consumers; - /** Whether this partition has produced some data. 
*/ private boolean hasDataProduced = false; @@ -41,8 +39,9 @@ public IntermediateResultPartition( IntermediateResult totalResult, ExecutionVertex producer, int partitionNumber) { this.totalResult = totalResult; this.producer = producer; - this.consumers = new ArrayList>(0); this.partitionId = new IntermediateResultPartitionID(totalResult.getId(), partitionNumber); + + producer.getExecutionGraph().registerResultPartition(partitionId, this); } public ExecutionVertex getProducer() { @@ -65,8 +64,8 @@ public ResultPartitionType getResultType() { return totalResult.getResultType(); } - public List> getConsumers() { - return consumers; + public List getConsumers() { + return getEdgeManager().getPartitionConsumers(partitionId); } public void markDataProduced() { @@ -90,21 +89,12 @@ void resetForNewExecution() { hasDataProduced = false; } - int addConsumerGroup() { - int pos = consumers.size(); - - // NOTE: currently we support only one consumer per result!!! - if (pos != 0) { - throw new RuntimeException( - "Currently, each intermediate result can only have one consumer."); - } - - consumers.add(new ArrayList()); - return pos; + public void addConsumers(ConsumerVertexGroup consumers) { + getEdgeManager().addPartitionConsumers(partitionId, consumers); } - void addConsumer(ExecutionEdge edge, int consumerNumber) { - consumers.get(consumerNumber).add(edge); + private EdgeManager getEdgeManager() { + return producer.getExecutionGraph().getEdgeManager(); } boolean markFinished() { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/DistributionPattern.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/DistributionPattern.java index f054311b53f093..e4a4796f4c6c1f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/DistributionPattern.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/DistributionPattern.java @@ -18,29 +18,22 @@ package org.apache.flink.runtime.jobgraph; +import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil; import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; /** * A distribution pattern determines, which sub tasks of a producing task are connected to which * consuming sub tasks. + * + *

It affects how {@link ExecutionVertex} and {@link IntermediateResultPartition} are connected + in {@link EdgeManagerBuildUtil}. */ public enum DistributionPattern { - /** - * Each producing sub task is connected to each sub task of the consuming task. - * - *

{@link - * ExecutionVertex#connectAllToAll(org.apache.flink.runtime.executiongraph.IntermediateResultPartition[], - * int)} - */ + /** Each producing sub task is connected to each sub task of the consuming task. */ ALL_TO_ALL, - /** - * Each producing sub task is connected to one or more subtask(s) of the consuming task. - * - *

{@link - * ExecutionVertex#connectPointwise(org.apache.flink.runtime.executiongraph.IntermediateResultPartition[], - * int)} - */ + /** Each producing sub task is connected to one or more subtask(s) of the consuming task. */ POINTWISE } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java index 23a1380f3a553c..b46ee91a2f3b5f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java @@ -19,10 +19,11 @@ package org.apache.flink.runtime.scheduler; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.executiongraph.ExecutionEdge; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; @@ -51,11 +52,13 @@ public Collection> getConsumedResultPartitionsProd List> resultPartitionProducers = new ArrayList<>(ev.getNumberOfInputs()); - for (int i = 0; i < ev.getNumberOfInputs(); i++) { - ExecutionEdge[] inputEdges = ev.getInputEdges(i); - List producers = new ArrayList<>(inputEdges.length); - for (ExecutionEdge inputEdge : inputEdges) { - ExecutionVertex producer = inputEdge.getSource().getProducer(); + for (ConsumedPartitionGroup consumedPartitions : ev.getAllConsumedPartitions()) { + List producers = + new ArrayList<>(consumedPartitions.getResultPartitions().size()); + for (IntermediateResultPartitionID consumedPartitionId : + consumedPartitions.getResultPartitions()) { + ExecutionVertex producer = + executionGraph.getResultPartitionOrThrow(consumedPartitionId).getProducer(); producers.add(producer.getID()); } resultPartitionProducers.add(producers); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java index c10233c4292f2f..50112920523f9e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.scheduler.adapter; -import org.apache.flink.runtime.executiongraph.ExecutionEdge; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; @@ -26,6 +25,7 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationConstraint; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import 
org.apache.flink.runtime.scheduler.strategy.ResultPartitionState; import org.apache.flink.runtime.scheduler.strategy.SchedulingExecutionVertex; @@ -211,10 +211,11 @@ private static void connectVerticesToConsumedPartitions( final DefaultExecutionVertex schedulingVertex = mapEntry.getValue(); final ExecutionVertex executionVertex = mapEntry.getKey(); - for (int index = 0; index < executionVertex.getNumberOfInputs(); index++) { - for (ExecutionEdge edge : executionVertex.getInputEdges(index)) { - DefaultResultPartition partition = - resultPartitions.get(edge.getSource().getPartitionId()); + for (ConsumedPartitionGroup consumedPartitionGroup : + executionVertex.getAllConsumedPartitions()) { + for (IntermediateResultPartitionID consumedPartition : + consumedPartitionGroup.getResultPartitions()) { + DefaultResultPartition partition = resultPartitions.get(consumedPartition); schedulingVertex.addConsumedResult(partition); partition.addConsumer(schedulingVertex); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java index 4fd2e1c0a3ba30..7a127793c79eb9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/PartitionDescriptor.java @@ -19,12 +19,12 @@ package org.apache.flink.runtime.shuffle; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.runtime.executiongraph.ExecutionEdge; import org.apache.flink.runtime.executiongraph.IntermediateResult; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; import java.io.Serializable; import java.util.List; @@ -113,13 +113,13 @@ public static PartitionDescriptor from(IntermediateResultPartition partition) { // If no consumers are known at this point, we use a single subpartition, otherwise we have // one for each consuming sub task. 
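// At most one consumer group per partition is supported for now; the check below enforces this.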
int numberOfSubpartitions = 1; - List> consumers = partition.getConsumers(); - if (!consumers.isEmpty() && !consumers.get(0).isEmpty()) { + List consumers = partition.getConsumers(); + if (!consumers.isEmpty() && !consumers.get(0).getVertices().isEmpty()) { if (consumers.size() > 1) { throw new IllegalStateException( "Currently, only a single consumer group per partition is supported."); } - numberOfSubpartitions = consumers.get(0).size(); + numberOfSubpartitions = consumers.get(0).getVertices().size(); } IntermediateResult result = partition.getIntermediateResult(); return new PartitionDescriptor( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java index 31e7f360f0043c..bff132db9425bd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -509,13 +510,15 @@ public static void verifyGeneratedExecutionJobVertex( assertEquals(inputJobVertices.size(), ev.getNumberOfInputs()); for (int i = 0; i < inputJobVertices.size(); i++) { - ExecutionEdge[] inputEdges = ev.getInputEdges(i); - assertEquals(inputJobVertices.get(i).getParallelism(), inputEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(i).getResultPartitions(); + assertEquals( + inputJobVertices.get(i).getParallelism(), consumedPartitions.size()); int expectedPartitionNum = 0; - for (ExecutionEdge inEdge : inputEdges) { - assertEquals(i, inEdge.getInputNum()); - assertEquals(expectedPartitionNum, inEdge.getSource().getPartitionNumber()); + for (IntermediateResultPartitionID consumedPartitionId : consumedPartitions) { + assertEquals( + expectedPartitionNum, consumedPartitionId.getPartitionNumber()); expectedPartitionNum++; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexDeploymentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexDeploymentTest.java index 9f60818189a145..42552b4454d9e8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexDeploymentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexDeploymentTest.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.jobmaster.TestingLogicalSlot; import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder; import org.apache.flink.runtime.messages.Acknowledge; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.apache.flink.runtime.testutils.DirectScheduledExecutorService; @@ -52,8 +53,6 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; 
public class ExecutionVertexDeploymentTest extends TestLogger { @@ -304,10 +303,7 @@ public void testTddProducedPartitionsLazyScheduling() throws Exception { TaskDeploymentDescriptorFactory tddFactory = TaskDeploymentDescriptorFactory.fromExecutionVertex(vertex, 1); - ExecutionEdge mockEdge = createMockExecutionEdge(1); - - result.getPartitions()[0].addConsumerGroup(); - result.getPartitions()[0].addConsumer(mockEdge, 0); + result.getPartitions()[0].addConsumers(new ConsumerVertexGroup(vertex.getID())); TaskManagerLocation location = new TaskManagerLocation( @@ -333,16 +329,4 @@ public void testTddProducedPartitionsLazyScheduling() throws Exception { assertEquals(scheduleMode.allowLazyDeployment(), desc.notifyPartitionDataAvailable()); } } - - private ExecutionEdge createMockExecutionEdge(int maxParallelism) { - ExecutionVertex targetVertex = mock(ExecutionVertex.class); - ExecutionJobVertex targetJobVertex = mock(ExecutionJobVertex.class); - - when(targetVertex.getJobVertex()).thenReturn(targetJobVertex); - when(targetJobVertex.getMaxParallelism()).thenReturn(maxParallelism); - - ExecutionEdge edge = mock(ExecutionEdge.class); - when(edge.getTarget()).thenReturn(targetVertex); - return edge; - } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java index 3c05cd955da872..2d4f91f4fbeaf5 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/PointwisePatternTest.java @@ -21,14 +21,18 @@ import org.apache.flink.runtime.JobException; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -45,10 +49,12 @@ public void testNToN() throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertEquals(1, inEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(0).getResultPartitions(); + assertEquals(1, consumedPartitions.size()); - assertEquals(ev.getParallelSubtaskIndex(), inEdges[0].getSource().getPartitionNumber()); + assertEquals( + ev.getParallelSubtaskIndex(), consumedPartitions.get(0).getPartitionNumber()); } } @@ -61,14 +67,16 @@ public void test2NToN() throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertEquals(2, inEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(0).getResultPartitions(); + assertEquals(2, consumedPartitions.size()); assertEquals( - ev.getParallelSubtaskIndex() * 2, inEdges[0].getSource().getPartitionNumber()); + ev.getParallelSubtaskIndex() * 2, + consumedPartitions.get(0).getPartitionNumber()); assertEquals( ev.getParallelSubtaskIndex() * 2 + 1, - 
inEdges[1].getSource().getPartitionNumber()); + consumedPartitions.get(1).getPartitionNumber()); } } @@ -81,17 +89,19 @@ public void test3NToN() throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertEquals(3, inEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(0).getResultPartitions(); + assertEquals(3, consumedPartitions.size()); assertEquals( - ev.getParallelSubtaskIndex() * 3, inEdges[0].getSource().getPartitionNumber()); + ev.getParallelSubtaskIndex() * 3, + consumedPartitions.get(0).getPartitionNumber()); assertEquals( ev.getParallelSubtaskIndex() * 3 + 1, - inEdges[1].getSource().getPartitionNumber()); + consumedPartitions.get(1).getPartitionNumber()); assertEquals( ev.getParallelSubtaskIndex() * 3 + 2, - inEdges[2].getSource().getPartitionNumber()); + consumedPartitions.get(2).getPartitionNumber()); } } @@ -104,11 +114,13 @@ public void testNTo2N() throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertEquals(1, inEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(0).getResultPartitions(); + assertEquals(1, consumedPartitions.size()); assertEquals( - ev.getParallelSubtaskIndex() / 2, inEdges[0].getSource().getPartitionNumber()); + ev.getParallelSubtaskIndex() / 2, + consumedPartitions.get(0).getPartitionNumber()); } } @@ -121,11 +133,13 @@ public void testNTo7N() throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertEquals(1, inEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(0).getResultPartitions(); + assertEquals(1, consumedPartitions.size()); assertEquals( - ev.getParallelSubtaskIndex() / 7, inEdges[0].getSource().getPartitionNumber()); + ev.getParallelSubtaskIndex() / 7, + consumedPartitions.get(0).getPartitionNumber()); } } @@ -183,10 +197,11 @@ private void testLowToHigh(int lowDop, int highDop) throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertEquals(1, inEdges.length); + List consumedPartitions = + ev.getConsumedPartitions(0).getResultPartitions(); + assertEquals(1, consumedPartitions.size()); - timesUsed[inEdges[0].getSource().getPartitionNumber()]++; + timesUsed[consumedPartitions.get(0).getPartitionNumber()]++; } for (int used : timesUsed) { @@ -209,14 +224,19 @@ private void testHighToLow(int highDop, int lowDop) throws Exception { for (ExecutionVertex ev : target.getTaskVertices()) { assertEquals(1, ev.getNumberOfInputs()); - ExecutionEdge[] inEdges = ev.getInputEdges(0); - assertTrue(inEdges.length >= factor && inEdges.length <= factor + delta); - - for (ExecutionEdge ee : inEdges) { - timesUsed[ee.getSource().getPartitionNumber()]++; + List consumedPartitions = + ev.getAllConsumedPartitions().stream() + .map(ConsumedPartitionGroup::getResultPartitions) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + assertTrue( + consumedPartitions.size() >= factor + && consumedPartitions.size() <= factor + delta); + + for (IntermediateResultPartitionID consumedPartition : consumedPartitions) { + timesUsed[consumedPartition.getPartitionNumber()]++; } } - for (int used : timesUsed) { assertEquals(1, used); } @@ -260,14 +280,15 @@ private void 
testConnections( for (int vertexIndex = 0; vertexIndex < target.getTaskVertices().length; vertexIndex++) { ExecutionVertex ev = target.getTaskVertices()[vertexIndex]; - ExecutionEdge[] edges = ev.getInputEdges(0); + List partitionIds = + ev.getConsumedPartitions(0).getResultPartitions(); - assertEquals(expectedConsumedPartitionNumber[vertexIndex].length, edges.length); + assertEquals(expectedConsumedPartitionNumber[vertexIndex].length, partitionIds.size()); - for (int partitionIndex = 0; partitionIndex < edges.length; partitionIndex++) { + for (int partitionIndex = 0; partitionIndex < partitionIds.size(); partitionIndex++) { assertEquals( expectedConsumedPartitionNumber[vertexIndex][partitionIndex], - edges[partitionIndex].getSource().getPartitionNumber()); + partitionIds.get(partitionIndex).getPartitionNumber()); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java index a7d6bbbc5a5f24..a1dc6b5b96286c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.scheduler.adapter; -import org.apache.flink.runtime.executiongraph.ExecutionEdge; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; @@ -26,6 +25,8 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState; import org.apache.flink.util.IterableUtils; @@ -37,7 +38,6 @@ import org.junit.Before; import org.junit.Test; -import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.List; @@ -182,9 +182,10 @@ private static void assertGraphEquals( List originalConsumedPartitions = IntStream.range(0, originalVertex.getNumberOfInputs()) - .mapToObj(originalVertex::getInputEdges) - .flatMap(Arrays::stream) - .map(ExecutionEdge::getSource) + .mapToObj(originalVertex::getConsumedPartitions) + .map(ConsumedPartitionGroup::getResultPartitions) + .flatMap(Collection::stream) + .map(originalGraph::getResultPartitionOrThrow) .collect(Collectors.toList()); Iterable adaptedConsumedPartitions = adaptedVertex.getConsumedResults(); @@ -227,20 +228,19 @@ private static void assertPartitionsEquals( assertPartitionEquals(originalPartition, adaptedPartition); - List originalConsumers = + List originalConsumerIds = originalPartition.getConsumers().stream() + .map(ConsumerVertexGroup::getVertices) .flatMap(Collection::stream) - .map(ExecutionEdge::getTarget) .collect(Collectors.toList()); Iterable adaptedConsumers = adaptedPartition.getConsumers(); - for (ExecutionVertex originalConsumer : originalConsumers) { + for (ExecutionVertexID originalId : originalConsumerIds) { // it is sufficient to verify that some vertex exists with the correct ID here, // since deep equality is 
verified later in the main loop // this DOES rely on an implicit assumption that the vertices objects returned by // the topology are // identical to those stored in the partition - ExecutionVertexID originalId = originalConsumer.getID(); assertTrue( IterableUtils.toStream(adaptedConsumers) .anyMatch( diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java index b920f64bab672a..e0bd6087822d42 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java @@ -21,11 +21,12 @@ import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.JobException; -import org.apache.flink.runtime.executiongraph.ExecutionEdge; import org.apache.flink.runtime.executiongraph.ExecutionGraph; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.executiongraph.TestingExecutionGraphBuilder; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.streaming.api.datastream.DataStream; @@ -138,11 +139,17 @@ public void flatMap( Map mapInputPartitionCounts = new HashMap<>(); for (ExecutionVertex mapTaskVertex : mapTaskVertices) { assertEquals(1, mapTaskVertex.getNumberOfInputs()); - assertEquals(1, mapTaskVertex.getInputEdges(0).length); - ExecutionEdge inputEdge = mapTaskVertex.getInputEdges(0)[0]; + assertEquals(1, mapTaskVertex.getConsumedPartitions(0).getResultPartitions().size()); + IntermediateResultPartitionID consumedPartitionId = + mapTaskVertex.getConsumedPartitions(0).getResultPartitions().get(0); assertEquals( - sourceVertex.getID(), inputEdge.getSource().getProducer().getJobvertexId()); - int inputPartition = inputEdge.getSource().getPartitionNumber(); + sourceVertex.getID(), + mapTaskVertex + .getExecutionGraph() + .getResultPartitionOrThrow(consumedPartitionId) + .getProducer() + .getJobvertexId()); + int inputPartition = consumedPartitionId.getPartitionNumber(); if (!mapInputPartitionCounts.containsKey(inputPartition)) { mapInputPartitionCounts.put(inputPartition, 1); } else { @@ -159,22 +166,29 @@ public void flatMap( assertEquals(1, execSinkVertex.getInputs().size()); assertEquals(2, execSinkVertex.getParallelism()); ExecutionVertex[] sinkTaskVertices = execSinkVertex.getTaskVertices(); + ExecutionGraph executionGraph = execSinkVertex.getGraph(); // verify each sink instance has two inputs from the map and that each map subpartition // only occurs in one unique input edge Set mapSubpartitions = new HashSet<>(); for (ExecutionVertex sinkTaskVertex : sinkTaskVertices) { assertEquals(1, sinkTaskVertex.getNumberOfInputs()); - assertEquals(2, sinkTaskVertex.getInputEdges(0).length); - ExecutionEdge inputEdge1 = sinkTaskVertex.getInputEdges(0)[0]; - ExecutionEdge inputEdge2 = sinkTaskVertex.getInputEdges(0)[1]; - assertEquals(mapVertex.getID(), inputEdge1.getSource().getProducer().getJobvertexId()); - assertEquals(mapVertex.getID(), 
inputEdge2.getSource().getProducer().getJobvertexId()); - - int inputPartition1 = inputEdge1.getSource().getPartitionNumber(); + assertEquals(2, sinkTaskVertex.getConsumedPartitions(0).getResultPartitions().size()); + IntermediateResultPartitionID consumedPartitionId1 = + sinkTaskVertex.getConsumedPartitions(0).getResultPartitions().get(0); + IntermediateResultPartition consumedPartition1 = + executionGraph.getResultPartitionOrThrow(consumedPartitionId1); + IntermediateResultPartitionID consumedPartitionId2 = + sinkTaskVertex.getConsumedPartitions(0).getResultPartitions().get(1); + IntermediateResultPartition consumedPartition2 = + executionGraph.getResultPartitionOrThrow(consumedPartitionId2); + assertEquals(mapVertex.getID(), consumedPartition1.getProducer().getJobvertexId()); + assertEquals(mapVertex.getID(), consumedPartition2.getProducer().getJobvertexId()); + + int inputPartition1 = consumedPartition1.getPartitionNumber(); assertFalse(mapSubpartitions.contains(inputPartition1)); mapSubpartitions.add(inputPartition1); - int inputPartition2 = inputEdge2.getSource().getPartitionNumber(); + int inputPartition2 = consumedPartition2.getPartitionNumber(); assertFalse(mapSubpartitions.contains(inputPartition2)); mapSubpartitions.add(inputPartition2); }
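For illustration only (not part of the patch): the pointwise index arithmetic that EdgeManagerBuildUtil#connectPointwise applies can be reproduced in a small standalone program to see which result partitions each consumer subtask is wired to. The class and method names below are made up for this sketch; only the start/end computations are taken from the code above.

public class PointwiseMappingSketch {

    // Mirrors the three branches of connectPointwise: 1:1, scale-in (more
    // producers than consumers) and scale-out (more consumers than producers).
    static void printMapping(int sourceCount, int targetCount) {
        System.out.printf("sources=%d, targets=%d%n", sourceCount, targetCount);
        if (sourceCount >= targetCount) {
            // Scale-in (and the trivial 1:1 case): every consumer subtask reads
            // a contiguous range of result partitions, using integer division.
            for (int index = 0; index < targetCount; index++) {
                int start = index * sourceCount / targetCount;
                int end = (index + 1) * sourceCount / targetCount;
                System.out.printf("  consumer %d <- partitions [%d, %d)%n", index, start, end);
            }
        } else {
            // Scale-out: every result partition fans out to a contiguous range
            // of consumer subtasks, using float multiplication plus ceiling.
            float factor = ((float) targetCount) / sourceCount;
            for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
                int start = (int) Math.ceil(partitionNum * factor);
                int end = (int) Math.ceil((partitionNum + 1) * factor);
                System.out.printf("  partition %d -> consumers [%d, %d)%n", partitionNum, start, end);
            }
        }
    }

    public static void main(String[] args) {
        printMapping(4, 4); // 1:1 - consumer i reads partition i
        printMapping(7, 3); // scale-in - consumers read partitions [0,2), [2,4), [4,7)
        printMapping(3, 7); // scale-out - partitions feed consumers [0,3), [3,5), [5,7)
    }
}

In both non-trivial branches the ranges are contiguous and cover every index exactly once, which is the property PointwisePatternTest verifies with its timesUsed counters.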