Skip to content

Commit

Permalink
[fix-up] Optimize the topology building in ExecutionGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Thesharing committed Feb 25, 2021
1 parent d49a6a6 commit e19dc30
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 38 deletions.
Expand Up @@ -33,17 +33,17 @@
public class EdgeManagerBuildUtil {

public static void connectVertexToResult(
ExecutionVertex[] taskVertices,
ExecutionJobVertex vertex,
IntermediateResult ires,
int inputNumber,
DistributionPattern distributionPattern) {

switch (distributionPattern) {
case POINTWISE:
connectPointwise(taskVertices, ires, inputNumber);
connectPointwise(vertex.getTaskVertices(), ires, inputNumber);
break;
case ALL_TO_ALL:
connectAllToAll(taskVertices, ires, inputNumber);
connectAllToAll(vertex.getTaskVertices(), ires, inputNumber);
break;
default:
throw new RuntimeException("Unrecognized distribution pattern.");
Expand Down Expand Up @@ -122,35 +122,17 @@ private static void connectPointwise(
ConsumedPartitionGroup consumerPartitionGroup =
new ConsumedPartitionGroup(partition.getPartitionId());

List<ExecutionVertexID> consumers = new ArrayList<>(targetCount / sourceCount + 1);

if (targetCount % sourceCount == 0) {
int factor = targetCount / sourceCount;
int start = partitionNum * factor;
for (int i = 0; i < factor; i++) {
ExecutionVertex executionVertex = taskVertices[start + i];
executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}
} else {
float factor = ((float) targetCount) / sourceCount;
int mirrorPartitionNumber = sourceCount - 1 - partitionNum;
int start = (int) (mirrorPartitionNumber * factor);
int end =
(mirrorPartitionNumber == sourceCount - 1)
? targetCount
: (int) ((mirrorPartitionNumber + 1) * factor);

for (int i = 0; i < end - start; i++) {
int mirrorVertexSubTaskIndex = start + i;
int vertexSubtaskIndex = targetCount - 1 - mirrorVertexSubTaskIndex;

ExecutionVertex executionVertex = taskVertices[vertexSubtaskIndex];
executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}
float factor = ((float) targetCount) / sourceCount;
int start = (int) (Math.ceil(partitionNum * factor));
int end = (int) (Math.ceil((partitionNum + 1) * factor));

List<ExecutionVertexID> consumers = new ArrayList<>(end - start);

for (int i = start; i < end; i++) {
ExecutionVertex executionVertex = taskVertices[i];
executionVertex.addConsumedPartitions(consumerPartitionGroup, inputNumber);

consumers.add(executionVertex.getID());
}

ConsumerVertexGroup consumerVertexGroup = new ConsumerVertexGroup(consumers);
Expand Down
Expand Up @@ -439,7 +439,7 @@ public CompletableFuture<Execution> registerProducedPartitions(
PartitionDescriptor partitionDescriptor = PartitionDescriptor.from(partition);
int maxParallelism =
getPartitionMaxParallelism(
partition, vertex.getExecutionGraph()::getVertexOrThrow);
partition, vertex.getExecutionGraph()::getExecutionVertexOrThrow);
CompletableFuture<? extends ShuffleDescriptor> shuffleDescriptorFuture =
vertex.getExecutionGraph()
.getShuffleMaster()
Expand Down Expand Up @@ -726,7 +726,7 @@ private void updatePartitionConsumers(final IntermediateResultPartition partitio

for (ExecutionVertexID consumerVertexId : allConsumers.get(0).getVertices()) {
final ExecutionVertex consumerVertex =
vertex.getExecutionGraph().getVertexOrThrow(consumerVertexId);
vertex.getExecutionGraph().getExecutionVertexOrThrow(consumerVertexId);
final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
final ExecutionState consumerState = consumer.getState();

Expand Down
Expand Up @@ -692,7 +692,7 @@ public void registerResultPartition(
resultPartitionsById.put(id, partition);
}

public ExecutionVertex getVertexOrThrow(ExecutionVertexID id) {
public ExecutionVertex getExecutionVertexOrThrow(ExecutionVertexID id) {
return checkNotNull(executionVerticesById.get(id));
}

Expand Down
Expand Up @@ -455,7 +455,7 @@ public void connectToPredecessors(

this.inputs.add(ires);

connectVertexToResult(taskVertices, ires, num, edge.getDistributionPattern());
connectVertexToResult(this, ires, num, edge.getDistributionPattern());
}
}

Expand Down
Expand Up @@ -26,8 +26,8 @@
* A distribution pattern determines, which sub tasks of a producing task are connected to which
* consuming sub tasks.
*
* <p>It affects how {@link ExecutionVertex} and {@link IntermediateResultPartition} is connected in
* {@link EdgeManagerBuildUtil}
* <p>It affects how {@link ExecutionVertex} and {@link IntermediateResultPartition} are connected
* in {@link EdgeManagerBuildUtil}
*/
public enum DistributionPattern {

Expand Down

0 comments on commit e19dc30

Please sign in to comment.