[FLINK-21326] Optimize the topology building in ExecutionGraph
Thesharing committed Feb 26, 2021
1 parent 8a8a260 commit cb77b97
Showing 18 changed files with 396 additions and 350 deletions.
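At its core, this commit replaces the per-edge ExecutionEdge[][] bookkeeping with shared ConsumedPartitionGroup / ConsumerVertexGroup objects that hold IDs, so an all-to-all connection between N producers and M consumers stores two groups instead of N x M edge objects. The following is a minimal, self-contained sketch of that grouping idea; the classes below are simplified stand-ins with hypothetical names and String IDs, not the real Flink types.

    import java.util.List;

    // Simplified stand-ins for ConsumedPartitionGroup / ConsumerVertexGroup,
    // only meant to illustrate the grouping idea behind this commit.
    final class ConsumedPartitionGroupSketch {
        private final List<String> partitionIds;
        ConsumedPartitionGroupSketch(List<String> partitionIds) { this.partitionIds = partitionIds; }
        List<String> getResultPartitions() { return partitionIds; }
    }

    final class ConsumerVertexGroupSketch {
        private final List<String> vertexIds;
        ConsumerVertexGroupSketch(List<String> vertexIds) { this.vertexIds = vertexIds; }
        List<String> getVertices() { return vertexIds; }
    }

    public class AllToAllGroupingSketch {
        public static void main(String[] args) {
            List<String> partitionIds = List.of("p0", "p1", "p2");
            List<String> consumerIds = List.of("v0", "v1", "v2");

            // Old model: every consumer vertex held one ExecutionEdge per consumed partition,
            // i.e. 3 x 3 = 9 edge objects for this small all-to-all connection.
            // New model: one shared group per side, referenced by every member.
            ConsumedPartitionGroupSketch consumedGroup = new ConsumedPartitionGroupSketch(partitionIds);
            ConsumerVertexGroupSketch consumerGroup = new ConsumerVertexGroupSketch(consumerIds);

            // Every consumer vertex stores the same consumedGroup instance, and every partition
            // the same consumerGroup instance, so memory stays O(N + M) instead of O(N * M).
            System.out.println(consumedGroup.getResultPartitions() + " -> " + consumerGroup.getVertices());
        }
    }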
@@ -27,7 +27,6 @@
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionEdge;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
@@ -37,6 +36,8 @@
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
import org.apache.flink.types.Either;
@@ -62,7 +63,7 @@ public class TaskDeploymentDescriptorFactory {
private final JobID jobID;
private final boolean allowUnknownPartitions;
private final int subtaskIndex;
private final ExecutionEdge[][] inputEdges;
private final List<List<IntermediateResultPartition>> consumedPartitions;

private TaskDeploymentDescriptorFactory(
ExecutionAttemptID executionId,
@@ -72,15 +73,15 @@ private TaskDeploymentDescriptorFactory(
JobID jobID,
boolean allowUnknownPartitions,
int subtaskIndex,
ExecutionEdge[][] inputEdges) {
List<List<IntermediateResultPartition>> consumedPartitions) {
this.executionId = executionId;
this.attemptNumber = attemptNumber;
this.serializedJobInformation = serializedJobInformation;
this.taskInfo = taskInfo;
this.jobID = jobID;
this.allowUnknownPartitions = allowUnknownPartitions;
this.subtaskIndex = subtaskIndex;
this.inputEdges = inputEdges;
this.consumedPartitions = consumedPartitions;
}

public TaskDeploymentDescriptor createDeploymentDescriptor(
@@ -101,18 +102,18 @@ public TaskDeploymentDescriptor createDeploymentDescriptor(
}

private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors() {
List<InputGateDeploymentDescriptor> inputGates = new ArrayList<>(inputEdges.length);
List<InputGateDeploymentDescriptor> inputGates = new ArrayList<>(consumedPartitions.size());

for (ExecutionEdge[] edges : inputEdges) {
for (List<IntermediateResultPartition> partitions : consumedPartitions) {
// If the produced partition has multiple consumers registered, we
// need to request the one matching our sub task index.
// TODO Refactor after removing the consumers from the intermediate result partitions
int numConsumerEdges = edges[0].getSource().getConsumers().get(0).size();
IntermediateResultPartition resultPartition = partitions.get(0);

int queueToRequest = subtaskIndex % numConsumerEdges;
int numConsumer = resultPartition.getConsumers().get(0).getVertices().size();

IntermediateResult consumedIntermediateResult =
edges[0].getSource().getIntermediateResult();
int queueToRequest = subtaskIndex % numConsumer;
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
IntermediateDataSetID resultId = consumedIntermediateResult.getId();
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();

@@ -121,25 +122,39 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
resultId,
partitionType,
queueToRequest,
getConsumedPartitionShuffleDescriptors(edges)));
getConsumedPartitionShuffleDescriptors(partitions)));
}

return inputGates;
}

private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors(ExecutionEdge[] edges) {
ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[edges.length];
private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors(
List<IntermediateResultPartition> partitions) {

ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[partitions.size()];
// Each edge is connected to a different result partition
for (int i = 0; i < edges.length; i++) {
for (int i = 0; i < partitions.size(); i++) {
shuffleDescriptors[i] =
getConsumedPartitionShuffleDescriptor(edges[i], allowUnknownPartitions);
getConsumedPartitionShuffleDescriptor(
partitions.get(i), allowUnknownPartitions);
}
return shuffleDescriptors;
}

public static TaskDeploymentDescriptorFactory fromExecutionVertex(
ExecutionVertex executionVertex, int attemptNumber) throws IOException {
ExecutionGraph executionGraph = executionVertex.getExecutionGraph();
final ExecutionGraph executionGraph = executionVertex.getExecutionGraph();

final List<List<IntermediateResultPartition>> consumedPartitions = new ArrayList<>();

for (ConsumedPartitionGroup partitionGroup : executionVertex.getAllConsumedPartitions()) {
List<IntermediateResultPartition> partitions = new ArrayList<>();
for (IntermediateResultPartitionID partitionId : partitionGroup.getResultPartitions()) {
partitions.add(executionGraph.getResultPartitionOrThrow(partitionId));
}
consumedPartitions.add(partitions);
}

return new TaskDeploymentDescriptorFactory(
executionVertex.getCurrentExecutionAttempt().getAttemptId(),
attemptNumber,
Expand All @@ -149,7 +164,7 @@ public static TaskDeploymentDescriptorFactory fromExecutionVertex(
executionGraph.getJobID(),
executionGraph.getScheduleMode().allowLazyDeployment(),
executionVertex.getParallelSubtaskIndex(),
executionVertex.getAllInputEdges());
consumedPartitions);
}

private static MaybeOffloaded<JobInformation> getSerializedJobInformation(
@@ -171,8 +186,8 @@ private static MaybeOffloaded<TaskInformation> getSerializedTaskInformation(
}

public static ShuffleDescriptor getConsumedPartitionShuffleDescriptor(
ExecutionEdge edge, boolean allowUnknownPartitions) {
IntermediateResultPartition consumedPartition = edge.getSource();
IntermediateResultPartition consumedPartition, boolean allowUnknownPartitions) {

Execution producer = consumedPartition.getProducer().getCurrentExecutionAttempt();

ExecutionState producerState = producer.getState();
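The reworked fromExecutionVertex above shows the other half of the design: a ConsumedPartitionGroup carries only IntermediateResultPartitionIDs, and the factory resolves each ID back to an IntermediateResultPartition through the ExecutionGraph (getResultPartitionOrThrow). Below is a rough, self-contained sketch of that ID-to-object lookup pattern, with a plain Map standing in for the graph's partition registry; all names are illustrative, not Flink API.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Sketch of the ID resolution used when building the deployment descriptor:
    // consumed groups carry IDs only; a graph-level registry maps them back to partitions.
    public class PartitionLookupSketch {
        public static void main(String[] args) {
            // Stand-in for ExecutionGraph#getResultPartitionOrThrow (here just a Map).
            Map<String, String> partitionRegistry = new HashMap<>();
            partitionRegistry.put("p0", "IntermediateResultPartition(p0)");
            partitionRegistry.put("p1", "IntermediateResultPartition(p1)");

            // One consumed group per input gate, holding partition IDs only.
            List<List<String>> consumedGroups = List.of(List.of("p0", "p1"));

            List<List<String>> consumedPartitions = new ArrayList<>();
            for (List<String> group : consumedGroups) {
                List<String> partitions = new ArrayList<>();
                for (String partitionId : group) {
                    String partition = partitionRegistry.get(partitionId);
                    if (partition == null) {
                        throw new IllegalStateException("Unknown partition: " + partitionId);
                    }
                    partitions.add(partition);
                }
                consumedPartitions.add(partitions);
            }
            System.out.println(consumedPartitions);
        }
    }

The newly added EdgeManagerBuildUtil below builds exactly these groups for the two distribution patterns.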
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/

package org.apache.flink.runtime.executiongraph;

import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/** Utilities for building {@link EdgeManager}. */
public class EdgeManagerBuildUtil {

public static void connectVertexToResult(
ExecutionJobVertex vertex,
IntermediateResult ires,
int inputNumber,
DistributionPattern distributionPattern) {

switch (distributionPattern) {
case POINTWISE:
connectPointwise(vertex.getTaskVertices(), ires, inputNumber);
break;
case ALL_TO_ALL:
connectAllToAll(vertex.getTaskVertices(), ires, inputNumber);
break;
default:
throw new RuntimeException("Unrecognized distribution pattern.");
}
}

private static void connectAllToAll(
ExecutionVertex[] taskVertices, IntermediateResult ires, int inputNumber) {

ConsumedPartitionGroup consumedPartitions =
new ConsumedPartitionGroup(
Arrays.stream(ires.getPartitions())
.map(IntermediateResultPartition::getPartitionId)
.collect(Collectors.toList()));
for (ExecutionVertex ev : taskVertices) {
ev.addConsumedPartitions(consumedPartitions, inputNumber);
}

ConsumerVertexGroup vertices =
new ConsumerVertexGroup(
Arrays.stream(taskVertices)
.map(ExecutionVertex::getID)
.collect(Collectors.toList()));
for (IntermediateResultPartition partition : ires.getPartitions()) {
partition.addConsumers(vertices);
}
}

private static void connectPointwise(
ExecutionVertex[] taskVertices, IntermediateResult ires, int inputNumber) {

final int sourceCount = ires.getPartitions().length;
final int targetCount = taskVertices.length;

if (sourceCount == targetCount) {
for (int i = 0; i < sourceCount; i++) {
ExecutionVertex executionVertex = taskVertices[i];
IntermediateResultPartition partition = ires.getPartitions()[i];

ConsumerVertexGroup consumerVertexGroup =
new ConsumerVertexGroup(executionVertex.getID());
partition.addConsumers(consumerVertexGroup);

ConsumedPartitionGroup consumedPartitionGroup =
new ConsumedPartitionGroup(partition.getPartitionId());
executionVertex.addConsumedPartitions(consumedPartitionGroup, inputNumber);
}
} else if (sourceCount > targetCount) {
for (int index = 0; index < targetCount; index++) {

ExecutionVertex executionVertex = taskVertices[index];
ConsumerVertexGroup consumerVertexGroup =
new ConsumerVertexGroup(executionVertex.getID());

int start = index * sourceCount / targetCount;
int end = (index + 1) * sourceCount / targetCount;

List<IntermediateResultPartitionID> consumedPartitions =
new ArrayList<>(end - start);

for (int i = start; i < end; i++) {
IntermediateResultPartition partition = ires.getPartitions()[i];
partition.addConsumers(consumerVertexGroup);

consumedPartitions.add(partition.getPartitionId());
}

ConsumedPartitionGroup consumedPartitionGroup =
new ConsumedPartitionGroup(consumedPartitions);
executionVertex.addConsumedPartitions(consumedPartitionGroup, inputNumber);
}
} else {
for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {

IntermediateResultPartition partition = ires.getPartitions()[partitionNum];
ConsumedPartitionGroup consumerPartitionGroup =
new ConsumedPartitionGroup(partition.getPartitionId());

float factor = ((float) targetCount) / sourceCount;
int start = (int) (Math.ceil(partitionNum * factor));
int end = (int) (Math.ceil((partitionNum + 1) * factor));

List<ExecutionVertexID> consumers = new ArrayList<>(end - start);

for (int i = start; i < end; i++) {
ExecutionVertex executionVertex = taskVertices[i];
executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}

ConsumerVertexGroup consumerVertexGroup = new ConsumerVertexGroup(consumers);
partition.addConsumers(consumerVertexGroup);
}
}
}
}
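As a quick sanity check of the range arithmetic in connectPointwise above, here is a standalone sketch of the sourceCount > targetCount branch: with 5 partitions and 2 consumer subtasks (example values chosen here for illustration), subtask 0 consumes partitions [0, 2) and subtask 1 consumes [2, 5).

    // Standalone check of the slice computation used in connectPointwise when
    // sourceCount > targetCount: each consumer subtask gets a contiguous range of partitions.
    public class PointwiseRangeSketch {
        public static void main(String[] args) {
            int sourceCount = 5; // number of result partitions (example value)
            int targetCount = 2; // number of consumer subtasks (example value)

            for (int index = 0; index < targetCount; index++) {
                int start = index * sourceCount / targetCount;
                int end = (index + 1) * sourceCount / targetCount;
                // Prints: subtask 0 -> partitions [0, 2), subtask 1 -> partitions [2, 5)
                System.out.println("subtask " + index + " -> partitions [" + start + ", " + end + ")");
            }
        }
    }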
@@ -49,6 +49,8 @@
import org.apache.flink.runtime.messages.Acknowledge;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.TaskNotRunningException;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.shuffle.NettyShuffleMaster;
import org.apache.flink.runtime.shuffle.PartitionDescriptor;
import org.apache.flink.runtime.shuffle.ProducerDescriptor;
@@ -436,7 +438,9 @@ public CompletableFuture<Execution> registerProducedPartitions(

for (IntermediateResultPartition partition : partitions) {
PartitionDescriptor partitionDescriptor = PartitionDescriptor.from(partition);
int maxParallelism = getPartitionMaxParallelism(partition);
int maxParallelism =
getPartitionMaxParallelism(
partition, vertex.getExecutionGraph()::getExecutionVertexOrThrow);
CompletableFuture<? extends ShuffleDescriptor> shuffleDescriptorFuture =
vertex.getExecutionGraph()
.getShuffleMaster()
@@ -468,15 +472,15 @@ public CompletableFuture<Execution> registerProducedPartitions(
});
}

private static int getPartitionMaxParallelism(IntermediateResultPartition partition) {
final List<List<ExecutionEdge>> consumers = partition.getConsumers();
private static int getPartitionMaxParallelism(
IntermediateResultPartition partition,
Function<ExecutionVertexID, ExecutionVertex> getVertexById) {
final List<ConsumerVertexGroup> consumers = partition.getConsumers();
Preconditions.checkArgument(
!consumers.isEmpty(),
consumers.size() == 1,
"Currently there has to be exactly one consumer in real jobs");
List<ExecutionEdge> consumer = consumers.get(0);
ExecutionJobVertex consumerVertex = consumer.get(0).getTarget().getJobVertex();
int maxParallelism = consumerVertex.getMaxParallelism();
return maxParallelism;
final List<ExecutionVertexID> consumerIds = consumers.get(0).getVertices();
return getVertexById.apply(consumerIds.get(0)).getJobVertex().getMaxParallelism();
}

/**
@@ -707,7 +711,10 @@ public CompletableFuture<?> suspend() {
return releaseFuture;
}

private void updatePartitionConsumers(final List<List<ExecutionEdge>> allConsumers) {
private void updatePartitionConsumers(final IntermediateResultPartition partition) {

final List<ConsumerVertexGroup> allConsumers = partition.getConsumers();

if (allConsumers.size() == 0) {
return;
}
@@ -718,8 +725,9 @@ private void updatePartitionConsumers(final List<List<ExecutionEdge>> allConsume
return;
}

for (ExecutionEdge edge : allConsumers.get(0)) {
final ExecutionVertex consumerVertex = edge.getTarget();
for (ExecutionVertexID consumerVertexId : allConsumers.get(0).getVertices()) {
final ExecutionVertex consumerVertex =
vertex.getExecutionGraph().getExecutionVertexOrThrow(consumerVertexId);
final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
final ExecutionState consumerState = consumer.getState();

@@ -729,7 +737,7 @@ private void updatePartitionConsumers(final List<List<ExecutionEdge>> allConsume
// sent after switching to running
// ----------------------------------------------------------------
if (consumerState == DEPLOYING || consumerState == RUNNING) {
final PartitionInfo partitionInfo = createPartitionInfo(edge);
final PartitionInfo partitionInfo = createPartitionInfo(partition);

if (consumerState == DEPLOYING) {
consumerVertex.cachePartitionInfo(partitionInfo);
@@ -740,11 +748,12 @@ private void updatePartitionConsumers(final List<List<ExecutionEdge>> allConsume
}
}

private static PartitionInfo createPartitionInfo(ExecutionEdge executionEdge) {
private static PartitionInfo createPartitionInfo(
IntermediateResultPartition consumedPartition) {
IntermediateDataSetID intermediateDataSetID =
executionEdge.getSource().getIntermediateResult().getId();
consumedPartition.getIntermediateResult().getId();
ShuffleDescriptor shuffleDescriptor =
getConsumedPartitionShuffleDescriptor(executionEdge, false);
getConsumedPartitionShuffleDescriptor(consumedPartition, false);
return new PartitionInfo(intermediateDataSetID, shuffleDescriptor);
}

@@ -954,7 +963,7 @@ private void finishPartitionsAndUpdateConsumers() {
finishedPartition.getIntermediateResult().getPartitions();

for (IntermediateResultPartition partition : allPartitionsOfNewlyFinishedResults) {
updatePartitionConsumers(partition.getConsumers());
updatePartitionConsumers(partition);
}
}
}