Skip to content

Commit

Permalink
[FLINK-21110] Optimize the topology building in ExecutionGraph
Browse files Browse the repository at this point in the history
* Replace ExecutionEdge with ConsumerVertexGroup and ConsumedPartitionGroup
* Fix the related tests
  • Loading branch information
Thesharing committed Feb 4, 2021
1 parent a6ba6f1 commit 85af2bb
Show file tree
Hide file tree
Showing 18 changed files with 421 additions and 348 deletions.
Expand Up @@ -27,7 +27,6 @@
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionEdge;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
Expand All @@ -37,6 +36,8 @@
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
import org.apache.flink.types.Either;
Expand All @@ -48,6 +49,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -62,7 +64,9 @@ public class TaskDeploymentDescriptorFactory {
private final JobID jobID;
private final boolean allowUnknownPartitions;
private final int subtaskIndex;
private final ExecutionEdge[][] inputEdges;
private final List<ConsumedPartitionGroup> consumedPartitions;
private final Map<IntermediateResultPartitionID, IntermediateResultPartition>
resultPartitionsById;

private TaskDeploymentDescriptorFactory(
ExecutionAttemptID executionId,
Expand All @@ -72,15 +76,17 @@ private TaskDeploymentDescriptorFactory(
JobID jobID,
boolean allowUnknownPartitions,
int subtaskIndex,
ExecutionEdge[][] inputEdges) {
List<ConsumedPartitionGroup> consumedPartitions,
Map<IntermediateResultPartitionID, IntermediateResultPartition> resultPartitionsById) {
this.executionId = executionId;
this.attemptNumber = attemptNumber;
this.serializedJobInformation = serializedJobInformation;
this.taskInfo = taskInfo;
this.jobID = jobID;
this.allowUnknownPartitions = allowUnknownPartitions;
this.subtaskIndex = subtaskIndex;
this.inputEdges = inputEdges;
this.consumedPartitions = consumedPartitions;
this.resultPartitionsById = resultPartitionsById;
}

public TaskDeploymentDescriptor createDeploymentDescriptor(
Expand All @@ -103,18 +109,19 @@ public TaskDeploymentDescriptor createDeploymentDescriptor(
}

private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors() {
List<InputGateDeploymentDescriptor> inputGates = new ArrayList<>(inputEdges.length);
List<InputGateDeploymentDescriptor> inputGates = new ArrayList<>(consumedPartitions.size());

for (ExecutionEdge[] edges : inputEdges) {
for (ConsumedPartitionGroup partitions : consumedPartitions) {
// If the produced partition has multiple consumers registered, we
// need to request the one matching our sub task index.
// TODO Refactor after removing the consumers from the intermediate result partitions
int numConsumerEdges = edges[0].getSource().getConsumers().get(0).size();
IntermediateResultPartition resultPartition =
resultPartitionsById.get(partitions.getResultPartitions().get(0));

int queueToRequest = subtaskIndex % numConsumerEdges;
int numConsumer = resultPartition.getConsumers().get(0).getVertices().size();

IntermediateResult consumedIntermediateResult =
edges[0].getSource().getIntermediateResult();
int queueToRequest = subtaskIndex % numConsumer;
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
IntermediateDataSetID resultId = consumedIntermediateResult.getId();
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();

Expand All @@ -123,18 +130,22 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
resultId,
partitionType,
queueToRequest,
getConsumedPartitionShuffleDescriptors(edges)));
getConsumedPartitionShuffleDescriptors(
partitions.getResultPartitions())));
}

return inputGates;
}

private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors(ExecutionEdge[] edges) {
ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[edges.length];
private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors(
List<IntermediateResultPartitionID> partitions) {

ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[partitions.size()];
// Each edge is connected to a different result partition
for (int i = 0; i < edges.length; i++) {
for (int i = 0; i < partitions.size(); i++) {
shuffleDescriptors[i] =
getConsumedPartitionShuffleDescriptor(edges[i], allowUnknownPartitions);
getConsumedPartitionShuffleDescriptor(
resultPartitionsById.get(partitions.get(i)), allowUnknownPartitions);
}
return shuffleDescriptors;
}
Expand All @@ -151,7 +162,8 @@ public static TaskDeploymentDescriptorFactory fromExecutionVertex(
executionGraph.getJobID(),
executionGraph.getScheduleMode().allowLazyDeployment(),
executionVertex.getParallelSubtaskIndex(),
executionVertex.getAllInputEdges());
executionVertex.getAllConsumedPartitions(),
executionGraph.getIntermediateResultPartitionMapping());
}

private static MaybeOffloaded<JobInformation> getSerializedJobInformation(
Expand All @@ -173,8 +185,8 @@ private static MaybeOffloaded<TaskInformation> getSerializedTaskInformation(
}

public static ShuffleDescriptor getConsumedPartitionShuffleDescriptor(
ExecutionEdge edge, boolean allowUnknownPartitions) {
IntermediateResultPartition consumedPartition = edge.getSource();
IntermediateResultPartition consumedPartition, boolean allowUnknownPartitions) {

Execution producer = consumedPartition.getProducer().getCurrentExecutionAttempt();

ExecutionState producerState = producer.getState();
Expand Down
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/

package org.apache.flink.runtime.executiongraph;

import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/** Utilities for building {@link EdgeManager}. */
public class EdgeManagerBuildUtil {

public static void registerToExecutionEdgeManager(
ExecutionVertex[] taskVertices,
IntermediateResult ires,
int inputNumber,
DistributionPattern distributionPattern) {

switch (distributionPattern) {
case POINTWISE:
connectPointwise(taskVertices, ires, inputNumber);
break;
case ALL_TO_ALL:
connectAllToAll(taskVertices, ires, inputNumber);
break;
default:
throw new RuntimeException("Unrecognized distribution pattern.");
}
}

private static void connectAllToAll(
ExecutionVertex[] taskVertices, IntermediateResult ires, int inputNumber) {

ConsumedPartitionGroup consumedPartitions =
new ConsumedPartitionGroup(
Arrays.stream(ires.getPartitions())
.map(IntermediateResultPartition::getPartitionId)
.collect(Collectors.toList()));
for (ExecutionVertex ev : taskVertices) {
ev.setConsumedPartitions(consumedPartitions, inputNumber);
}

ConsumerVertexGroup vertices =
new ConsumerVertexGroup(
Arrays.stream(taskVertices)
.map(ExecutionVertex::getID)
.collect(Collectors.toList()));
for (IntermediateResultPartition partition : ires.getPartitions()) {
partition.setConsumers(vertices);
}
}

private static void connectPointwise(
ExecutionVertex[] taskVertices, IntermediateResult ires, int inputNumber) {

final int sourceCount = ires.getPartitions().length;
final int targetCount = taskVertices.length;

if (sourceCount == targetCount) {
for (int i = 0; i < sourceCount; i++) {
ExecutionVertex executionVertex = taskVertices[i];
IntermediateResultPartition partition = ires.getPartitions()[i];

ConsumerVertexGroup consumerVertexGroup =
new ConsumerVertexGroup(executionVertex.getID());
partition.setConsumers(consumerVertexGroup);

ConsumedPartitionGroup consumedPartitionGroup =
new ConsumedPartitionGroup(partition.getPartitionId());
executionVertex.setConsumedPartitions(consumedPartitionGroup, inputNumber);
}
} else if (sourceCount > targetCount) {
for (int index = 0; index < targetCount; index++) {

ExecutionVertex executionVertex = taskVertices[index];
ConsumerVertexGroup consumerVertexGroup =
new ConsumerVertexGroup(executionVertex.getID());

List<IntermediateResultPartitionID> consumedPartitions =
new ArrayList<>(sourceCount / targetCount + 1);

if (sourceCount % targetCount == 0) {
int factor = sourceCount / targetCount;
int start = index * factor;
for (int i = 0; i < factor; i++) {
IntermediateResultPartition partition = ires.getPartitions()[start + i];
partition.setConsumers(consumerVertexGroup);

consumedPartitions.add(partition.getPartitionId());
}
} else {
float factor = ((float) sourceCount) / targetCount;
int start = (int) (index * factor);
int end =
(index == targetCount - 1) ? sourceCount : (int) ((index + 1) * factor);

for (int i = 0; i < end - start; i++) {
IntermediateResultPartition partition = ires.getPartitions()[start + i];
partition.setConsumers(consumerVertexGroup);

consumedPartitions.add(partition.getPartitionId());
}
}

ConsumedPartitionGroup consumedPartitionGroup =
new ConsumedPartitionGroup(consumedPartitions);
executionVertex.setConsumedPartitions(consumedPartitionGroup, inputNumber);
}
} else {
for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {

IntermediateResultPartition partition = ires.getPartitions()[partitionNum];
ConsumedPartitionGroup consumerPartitionGroup =
new ConsumedPartitionGroup(partition.getPartitionId());

List<ExecutionVertexID> consumers = new ArrayList<>(targetCount / sourceCount + 1);

if (targetCount % sourceCount == 0) {
int factor = targetCount / sourceCount;
int start = partitionNum * factor;
for (int i = 0; i < factor; i++) {
ExecutionVertex executionVertex = taskVertices[start + i];
executionVertex.setConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}
} else {
float factor = ((float) targetCount) / sourceCount;
int mirrorPartitionNumber = sourceCount - 1 - partitionNum;
int start = (int) (mirrorPartitionNumber * factor);
int end =
(mirrorPartitionNumber == sourceCount - 1)
? targetCount
: (int) ((mirrorPartitionNumber + 1) * factor);

for (int i = 0; i < end - start; i++) {
int mirrorVertexSubTaskIndex = start + i;
int vertexSubtaskIndex = targetCount - 1 - mirrorVertexSubTaskIndex;

ExecutionVertex executionVertex = taskVertices[vertexSubtaskIndex];
executionVertex.setConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}
}

ConsumerVertexGroup consumerVertexGroup = new ConsumerVertexGroup(consumers);
partition.setConsumers(consumerVertexGroup);
}
}
}
}

0 comments on commit 85af2bb

Please sign in to comment.