Skip to content

Commit

Permalink
[fix-up] Optimize the topology building in ExecutionGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Thesharing committed Feb 25, 2021
1 parent a744454 commit d49a6a6
Show file tree
Hide file tree
Showing 12 changed files with 71 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -64,9 +63,7 @@ public class TaskDeploymentDescriptorFactory {
private final JobID jobID;
private final boolean allowUnknownPartitions;
private final int subtaskIndex;
private final List<ConsumedPartitionGroup> consumedPartitions;
private final Map<IntermediateResultPartitionID, IntermediateResultPartition>
resultPartitionsById;
private final List<List<IntermediateResultPartition>> consumedPartitions;

private TaskDeploymentDescriptorFactory(
ExecutionAttemptID executionId,
Expand All @@ -76,8 +73,7 @@ private TaskDeploymentDescriptorFactory(
JobID jobID,
boolean allowUnknownPartitions,
int subtaskIndex,
List<ConsumedPartitionGroup> consumedPartitions,
Map<IntermediateResultPartitionID, IntermediateResultPartition> resultPartitionsById) {
List<List<IntermediateResultPartition>> consumedPartitions) {
this.executionId = executionId;
this.attemptNumber = attemptNumber;
this.serializedJobInformation = serializedJobInformation;
Expand All @@ -86,7 +82,6 @@ private TaskDeploymentDescriptorFactory(
this.allowUnknownPartitions = allowUnknownPartitions;
this.subtaskIndex = subtaskIndex;
this.consumedPartitions = consumedPartitions;
this.resultPartitionsById = resultPartitionsById;
}

public TaskDeploymentDescriptor createDeploymentDescriptor(
Expand All @@ -109,12 +104,11 @@ public TaskDeploymentDescriptor createDeploymentDescriptor(
private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors() {
List<InputGateDeploymentDescriptor> inputGates = new ArrayList<>(consumedPartitions.size());

for (ConsumedPartitionGroup partitions : consumedPartitions) {
for (List<IntermediateResultPartition> partitions : consumedPartitions) {
// If the produced partition has multiple consumers registered, we
// need to request the one matching our sub task index.
// TODO Refactor after removing the consumers from the intermediate result partitions
IntermediateResultPartition resultPartition =
resultPartitionsById.get(partitions.getResultPartitions().get(0));
IntermediateResultPartition resultPartition = partitions.get(0);

int numConsumer = resultPartition.getConsumers().get(0).getVertices().size();

Expand All @@ -128,29 +122,39 @@ private List<InputGateDeploymentDescriptor> createInputGateDeploymentDescriptors
resultId,
partitionType,
queueToRequest,
getConsumedPartitionShuffleDescriptors(
partitions.getResultPartitions())));
getConsumedPartitionShuffleDescriptors(partitions)));
}

return inputGates;
}

private ShuffleDescriptor[] getConsumedPartitionShuffleDescriptors(
List<IntermediateResultPartitionID> partitions) {
List<IntermediateResultPartition> partitions) {

ShuffleDescriptor[] shuffleDescriptors = new ShuffleDescriptor[partitions.size()];
// Each edge is connected to a different result partition
for (int i = 0; i < partitions.size(); i++) {
shuffleDescriptors[i] =
getConsumedPartitionShuffleDescriptor(
resultPartitionsById.get(partitions.get(i)), allowUnknownPartitions);
partitions.get(i), allowUnknownPartitions);
}
return shuffleDescriptors;
}

public static TaskDeploymentDescriptorFactory fromExecutionVertex(
ExecutionVertex executionVertex, int attemptNumber) throws IOException {
ExecutionGraph executionGraph = executionVertex.getExecutionGraph();
final ExecutionGraph executionGraph = executionVertex.getExecutionGraph();

final List<List<IntermediateResultPartition>> consumedPartitions = new ArrayList<>();

for (ConsumedPartitionGroup partitionGroup : executionVertex.getAllConsumedPartitions()) {
List<IntermediateResultPartition> partitions = new ArrayList<>();
for (IntermediateResultPartitionID partitionId : partitionGroup.getResultPartitions()) {
partitions.add(executionGraph.getResultPartitionOrThrow(partitionId));
}
consumedPartitions.add(partitions);
}

return new TaskDeploymentDescriptorFactory(
executionVertex.getCurrentExecutionAttempt().getAttemptId(),
attemptNumber,
Expand All @@ -160,8 +164,7 @@ public static TaskDeploymentDescriptorFactory fromExecutionVertex(
executionGraph.getJobID(),
executionGraph.getScheduleMode().allowLazyDeployment(),
executionVertex.getParallelSubtaskIndex(),
executionVertex.getAllConsumedPartitions(),
executionGraph.getIntermediateResultPartitionMapping());
consumedPartitions);
}

private static MaybeOffloaded<JobInformation> getSerializedJobInformation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
/** Utilities for building {@link EdgeManager}. */
public class EdgeManagerBuildUtil {

public static void registerToExecutionEdgeManager(
public static void connectVertexToResult(
ExecutionVertex[] taskVertices,
IntermediateResult ires,
int inputNumber,
Expand All @@ -59,7 +59,7 @@ private static void connectAllToAll(
.map(IntermediateResultPartition::getPartitionId)
.collect(Collectors.toList()));
for (ExecutionVertex ev : taskVertices) {
ev.setConsumedPartitions(consumedPartitions, inputNumber);
ev.addConsumedPartitions(consumedPartitions, inputNumber);
}

ConsumerVertexGroup vertices =
Expand All @@ -68,7 +68,7 @@ private static void connectAllToAll(
.map(ExecutionVertex::getID)
.collect(Collectors.toList()));
for (IntermediateResultPartition partition : ires.getPartitions()) {
partition.setConsumers(vertices);
partition.addConsumers(vertices);
}
}

Expand All @@ -85,11 +85,11 @@ private static void connectPointwise(

ConsumerVertexGroup consumerVertexGroup =
new ConsumerVertexGroup(executionVertex.getID());
partition.setConsumers(consumerVertexGroup);
partition.addConsumers(consumerVertexGroup);

ConsumedPartitionGroup consumedPartitionGroup =
new ConsumedPartitionGroup(partition.getPartitionId());
executionVertex.setConsumedPartitions(consumedPartitionGroup, inputNumber);
executionVertex.addConsumedPartitions(consumedPartitionGroup, inputNumber);
}
} else if (sourceCount > targetCount) {
for (int index = 0; index < targetCount; index++) {
Expand All @@ -98,35 +98,22 @@ private static void connectPointwise(
ConsumerVertexGroup consumerVertexGroup =
new ConsumerVertexGroup(executionVertex.getID());

List<IntermediateResultPartitionID> consumedPartitions =
new ArrayList<>(sourceCount / targetCount + 1);

if (sourceCount % targetCount == 0) {
int factor = sourceCount / targetCount;
int start = index * factor;
for (int i = 0; i < factor; i++) {
IntermediateResultPartition partition = ires.getPartitions()[start + i];
partition.setConsumers(consumerVertexGroup);
int start = index * sourceCount / targetCount;
int end = (index + 1) * sourceCount / targetCount;

consumedPartitions.add(partition.getPartitionId());
}
} else {
float factor = ((float) sourceCount) / targetCount;
int start = (int) (index * factor);
int end =
(index == targetCount - 1) ? sourceCount : (int) ((index + 1) * factor);
List<IntermediateResultPartitionID> consumedPartitions =
new ArrayList<>(end - start);

for (int i = 0; i < end - start; i++) {
IntermediateResultPartition partition = ires.getPartitions()[start + i];
partition.setConsumers(consumerVertexGroup);
for (int i = start; i < end; i++) {
IntermediateResultPartition partition = ires.getPartitions()[i];
partition.addConsumers(consumerVertexGroup);

consumedPartitions.add(partition.getPartitionId());
}
consumedPartitions.add(partition.getPartitionId());
}

ConsumedPartitionGroup consumedPartitionGroup =
new ConsumedPartitionGroup(consumedPartitions);
executionVertex.setConsumedPartitions(consumedPartitionGroup, inputNumber);
executionVertex.addConsumedPartitions(consumedPartitionGroup, inputNumber);
}
} else {
for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
Expand All @@ -142,7 +129,7 @@ private static void connectPointwise(
int start = partitionNum * factor;
for (int i = 0; i < factor; i++) {
ExecutionVertex executionVertex = taskVertices[start + i];
executionVertex.setConsumedPartitions(consumerPartitionGroup, inputNumber);
executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}
Expand All @@ -160,14 +147,14 @@ private static void connectPointwise(
int vertexSubtaskIndex = targetCount - 1 - mirrorVertexSubTaskIndex;

ExecutionVertex executionVertex = taskVertices[vertexSubtaskIndex];
executionVertex.setConsumedPartitions(consumerPartitionGroup, inputNumber);
executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}
}

ConsumerVertexGroup consumerVertexGroup = new ConsumerVertexGroup(consumers);
partition.setConsumers(consumerVertexGroup);
partition.addConsumers(consumerVertexGroup);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ public CompletableFuture<Execution> registerProducedPartitions(
PartitionDescriptor partitionDescriptor = PartitionDescriptor.from(partition);
int maxParallelism =
getPartitionMaxParallelism(
partition, vertex.getExecutionGraph().getExecutionVertexMapping());
partition, vertex.getExecutionGraph()::getVertexOrThrow);
CompletableFuture<? extends ShuffleDescriptor> shuffleDescriptorFuture =
vertex.getExecutionGraph()
.getShuffleMaster()
Expand Down Expand Up @@ -473,13 +473,13 @@ public CompletableFuture<Execution> registerProducedPartitions(

private static int getPartitionMaxParallelism(
IntermediateResultPartition partition,
Map<ExecutionVertexID, ExecutionVertex> verticesById) {
Function<ExecutionVertexID, ExecutionVertex> getVertexById) {
final List<ConsumerVertexGroup> consumers = partition.getConsumers();
Preconditions.checkArgument(
consumers.size() == 1,
"Currently there has to be exactly one consumer in real jobs");
final List<ExecutionVertexID> consumerIds = consumers.get(0).getVertices();
return verticesById.get(consumerIds.get(0)).getJobVertex().getMaxParallelism();
return getVertexById.apply(consumerIds.get(0)).getJobVertex().getMaxParallelism();
}

/**
Expand Down Expand Up @@ -710,9 +710,10 @@ public CompletableFuture<?> suspend() {
return releaseFuture;
}

private void updatePartitionConsumers(
final IntermediateResultPartition partition,
final List<ConsumerVertexGroup> allConsumers) {
private void updatePartitionConsumers(final IntermediateResultPartition partition) {

final List<ConsumerVertexGroup> allConsumers = partition.getConsumers();

if (allConsumers.size() == 0) {
return;
}
Expand All @@ -725,7 +726,7 @@ private void updatePartitionConsumers(

for (ExecutionVertexID consumerVertexId : allConsumers.get(0).getVertices()) {
final ExecutionVertex consumerVertex =
vertex.getExecutionGraph().getVertex(consumerVertexId);
vertex.getExecutionGraph().getVertexOrThrow(consumerVertexId);
final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
final ExecutionState consumerState = consumer.getState();

Expand Down Expand Up @@ -975,7 +976,7 @@ private void finishPartitionsAndUpdateConsumers() {
finishedPartition.getIntermediateResult().getPartitions();

for (IntermediateResultPartition partition : allPartitionsOfNewlyFinishedResults) {
updatePartitionConsumers(partition, partition.getConsumers());
updatePartitionConsumers(partition);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,21 +692,13 @@ public void registerResultPartition(
resultPartitionsById.put(id, partition);
}

public ExecutionVertex getVertex(ExecutionVertexID id) {
return executionVerticesById.get(id);
public ExecutionVertex getVertexOrThrow(ExecutionVertexID id) {
return checkNotNull(executionVerticesById.get(id));
}

public IntermediateResultPartition getResultPartition(final IntermediateResultPartitionID id) {
return resultPartitionsById.get(id);
}

public Map<IntermediateResultPartitionID, IntermediateResultPartition>
getIntermediateResultPartitionMapping() {
return Collections.unmodifiableMap(resultPartitionsById);
}

public Map<ExecutionVertexID, ExecutionVertex> getExecutionVertexMapping() {
return Collections.unmodifiableMap(executionVerticesById);
public IntermediateResultPartition getResultPartitionOrThrow(
final IntermediateResultPartitionID id) {
return checkNotNull(resultPartitionsById.get(id));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil.registerToExecutionEdgeManager;
import static org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil.connectVertexToResult;
import static org.apache.flink.util.Preconditions.checkNotNull;

/**
Expand Down Expand Up @@ -455,7 +455,7 @@ public void connectToPredecessors(

this.inputs.add(ires);

registerToExecutionEdgeManager(taskVertices, ires, num, edge.getDistributionPattern());
connectVertexToResult(taskVertices, ires, num, edge.getDistributionPattern());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ public Map<IntermediateResultPartitionID, IntermediateResultPartition> getProduc
// Graph building
// --------------------------------------------------------------------------------------------

public void setConsumedPartitions(ConsumedPartitionGroup consumedPartitions, int inputNum) {
public void addConsumedPartitions(ConsumedPartitionGroup consumedPartitions, int inputNum) {

getExecutionGraph()
.getEdgeManager()
Expand Down Expand Up @@ -419,7 +419,7 @@ public Collection<CompletableFuture<TaskManagerLocation>> getPreferredLocationsB
// look-up assigned slot of input source
CompletableFuture<TaskManagerLocation> locationFuture =
getExecutionGraph()
.getResultPartition(sourceId)
.getResultPartitionOrThrow(sourceId)
.getProducer()
.getCurrentTaskManagerLocationFuture();
// add input location
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ void resetForNewExecution() {
hasDataProduced = false;
}

public void setConsumers(ConsumerVertexGroup consumers) {
producer.getExecutionGraph().getEdgeManager().addPartitionConsumers(partitionId, consumers);
public void addConsumers(ConsumerVertexGroup consumers) {
getEdgeManager().addPartitionConsumers(partitionId, consumers);
}

EdgeManager getEdgeManager() {
private EdgeManager getEdgeManager() {
return producer.getExecutionGraph().getEdgeManager();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@

package org.apache.flink.runtime.jobgraph;

import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;

/**
* A distribution pattern determines which sub tasks of a producing task are connected to which
* consuming sub tasks.
*
* <p>It affects how {@link ExecutionVertex} and {@link IntermediateResultPartition} are connected
* in {@link EdgeManagerBuildUtil}.
*/
public enum DistributionPattern {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProd
for (IntermediateResultPartitionID consumedPartitionId :
consumedPartitions.getResultPartitions()) {
ExecutionVertex producer =
executionGraph.getResultPartition(consumedPartitionId).getProducer();
executionGraph.getResultPartitionOrThrow(consumedPartitionId).getProducer();
producers.add(producer.getID());
}
resultPartitionProducers.add(producers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ public void testTddProducedPartitionsLazyScheduling() throws Exception {
TaskDeploymentDescriptorFactory tddFactory =
TaskDeploymentDescriptorFactory.fromExecutionVertex(vertex, 1);

result.getPartitions()[0].setConsumers(new ConsumerVertexGroup(vertex.getID()));
result.getPartitions()[0].addConsumers(new ConsumerVertexGroup(vertex.getID()));

TaskManagerLocation location =
new TaskManagerLocation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ private static void assertGraphEquals(
.mapToObj(originalVertex::getConsumedPartitions)
.map(ConsumedPartitionGroup::getResultPartitions)
.flatMap(Collection::stream)
.map(originalGraph::getResultPartition)
.map(originalGraph::getResultPartitionOrThrow)
.collect(Collectors.toList());
Iterable<DefaultResultPartition> adaptedConsumedPartitions =
adaptedVertex.getConsumedResults();
Expand Down

0 comments on commit d49a6a6

Please sign in to comment.