From b1920956881ea63eef7f76467f745b250e85c435 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Fri, 22 Mar 2024 16:24:36 +0800 Subject: [PATCH] [#1579] fix(spark): clear out previous stage attempt data synchronously for stage resubmit --- .../spark/shuffle/RssShuffleManager.java | 68 +++++++++---------- .../client/api/ShuffleWriteClient.java | 3 +- .../client/impl/ShuffleWriteClientImpl.java | 6 +- .../impl/grpc/ShuffleServerGrpcClient.java | 9 ++- .../request/RssRegisterShuffleRequest.java | 15 +++- proto/src/main/proto/Rss.proto | 1 + .../server/ShuffleServerGrpcService.java | 24 +++++++ .../uniffle/server/ShuffleTaskManager.java | 2 +- 8 files changed, 82 insertions(+), 46 deletions(-) diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index b360807b40..b65ce2635e 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -438,35 +438,18 @@ public ShuffleHandle registerShuffle( int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); - - // retryInterval must bigger than `rss.server.heartbeat.interval`, or maybe it will return the - // same result - long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); - int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - Map> partitionToServers; - try { - partitionToServers = - RetryUtils.retry( - () -> { - ShuffleAssignmentsInfo response = - shuffleWriteClient.getShuffleAssignments( - id.get(), - shuffleId, - dependency.partitioner().numPartitions(), - 1, - assignmentTags, - requiredShuffleServerNumber, - estimateTaskConcurrency); - registerShuffleServers( - id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage); - return response.getPartitionToServers(); - }, - retryInterval, - retryTimes); - } catch (Throwable throwable) { - throw new RssException("registerShuffle failed!", throwable); - } + + Map> partitionToServers = + requestShuffleAssignment( + shuffleId, + dependency.partitioner().numPartitions(), + 1, + requiredShuffleServerNumber, + estimateTaskConcurrency, + failuresShuffleServerIds, + false); + startHeartbeat(); shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions()); @@ -896,7 +879,8 @@ protected void registerShuffleServers( String appId, int shuffleId, Map> serverToPartitionRanges, - RemoteStorageInfo remoteStorage) { + RemoteStorageInfo remoteStorage, + boolean isStageRetry) { if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { return; } @@ -914,7 +898,8 @@ protected void registerShuffleServers( entry.getValue(), remoteStorage, dataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + isStageRetry); }); LOG.info( "Finish register shuffleId[" @@ -1146,8 +1131,11 @@ public synchronized boolean reassignAllShuffleServersForWholeStage( int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - /** Before reassigning ShuffleServer, clear the ShuffleServer list in ShuffleWriteClient. */ - shuffleWriteClient.unregisterShuffle(id.get(), shuffleId); + + /** + * this will clear up the previous stage attempt all data when registering the same shuffleId + * at the second time + */ Map> partitionToServers = requestShuffleAssignment( shuffleId, @@ -1155,7 +1143,8 @@ public synchronized boolean reassignAllShuffleServersForWholeStage( 1, requiredShuffleServerNumber, estimateTaskConcurrency, - failuresShuffleServerIds); + failuresShuffleServerIds, + true); /** * we need to clear the metadata of the completed task, otherwise some of the stage's data * will be lost @@ -1219,7 +1208,7 @@ private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffl Set faultyServerIds = Sets.newHashSet(faultyShuffleServerId); faultyServerIds.addAll(failuresShuffleServerIds); Map> partitionToServers = - requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds); + requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, false); if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) { return partitionToServers.get(0).get(0); } @@ -1232,10 +1221,13 @@ private Map> requestShuffleAssignment( int partitionNumPerRange, int assignmentShuffleServerNumber, int estimateTaskConcurrency, - Set faultyServerIds) { + Set faultyServerIds, + boolean isStageRetry) { Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); ClientUtils.validateClientType(clientType); assignmentTags.add(clientType); + // retryInterval must bigger than `rss.server.heartbeat.interval`, or maybe it will return the + // same result long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); faultyServerIds.addAll(failuresShuffleServerIds); @@ -1253,7 +1245,11 @@ private Map> requestShuffleAssignment( estimateTaskConcurrency, faultyServerIds); registerShuffleServers( - id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); + id.get(), + shuffleId, + response.getServerToPartitionRanges(), + getRemoteStorageInfo(), + isStageRetry); return response.getPartitionToServers(); }, retryInterval, diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index 88d97c336a..77390d72fb 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -50,7 +50,8 @@ void registerShuffle( List partitionRanges, RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite); + int maxConcurrencyPerPartitionToWrite, + boolean isStageRetry); boolean sendCommit( Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index 129dadc173..6531372908 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -529,7 +529,8 @@ public void registerShuffle( List partitionRanges, RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + boolean isStageRetry) { String user = null; try { user = UserGroupInformation.getCurrentUser().getShortUserName(); @@ -545,7 +546,8 @@ public void registerShuffle( remoteStorage, user, dataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + isStageRetry); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 5a6919e44a..21c2c4c127 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -159,7 +159,8 @@ private ShuffleRegisterResponse doRegisterShuffle( RemoteStorageInfo remoteStorageInfo, String user, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + boolean isStageRetry) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -167,7 +168,8 @@ private ShuffleRegisterResponse doRegisterShuffle( .setUser(user) .setShuffleDataDistribution(RssProtos.DataDistribution.valueOf(dataDistributionType.name())) .setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite) - .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)); + .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)) + .setIsStageRetry(isStageRetry); RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder(); rsBuilder.setPath(remoteStorageInfo.getPath()); Map remoteStorageConf = remoteStorageInfo.getConfItems(); @@ -400,7 +402,8 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getRemoteStorageInfo(), request.getUser(), request.getDataDistributionType(), - request.getMaxConcurrencyPerPartitionToWrite()); + request.getMaxConcurrencyPerPartitionToWrite(), + request.isStageRetry()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 2cd49bb6d3..f39ba412ec 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -35,6 +35,7 @@ public class RssRegisterShuffleRequest { private String user; private ShuffleDataDistributionType dataDistributionType; private int maxConcurrencyPerPartitionToWrite; + private boolean isStageRetry; public RssRegisterShuffleRequest( String appId, @@ -43,7 +44,8 @@ public RssRegisterShuffleRequest( RemoteStorageInfo remoteStorageInfo, String user, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + boolean isStageRetry) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -51,6 +53,7 @@ public RssRegisterShuffleRequest( this.user = user; this.dataDistributionType = dataDistributionType; this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; + this.isStageRetry = isStageRetry; } public RssRegisterShuffleRequest( @@ -67,7 +70,8 @@ public RssRegisterShuffleRequest( remoteStorageInfo, user, dataDistributionType, - RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue()); + RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), + false); } public RssRegisterShuffleRequest( @@ -79,7 +83,8 @@ public RssRegisterShuffleRequest( new RemoteStorageInfo(remoteStoragePath), StringUtils.EMPTY, ShuffleDataDistributionType.NORMAL, - RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue()); + RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), + false); } public String getAppId() { @@ -109,4 +114,8 @@ public ShuffleDataDistributionType getDataDistributionType() { public int getMaxConcurrencyPerPartitionToWrite() { return maxConcurrencyPerPartitionToWrite; } + + public boolean isStageRetry() { + return isStageRetry; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index ac9c762881..a4b7b576e6 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -184,6 +184,7 @@ message ShuffleRegisterRequest { string user = 5; DataDistribution shuffleDataDistribution = 6; int32 maxConcurrencyPerPartitionToWrite = 7; + bool isStageRetry = 8; } enum DataDistribution { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 695c2afc60..2b23ae9206 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -158,6 +158,30 @@ public void registerShuffle( String remoteStoragePath = req.getRemoteStorage().getPath(); String user = req.getUser(); + if (req.getIsStageRetry()) { + try { + long start = System.currentTimeMillis(); + shuffleServer.getShuffleTaskManager().removeShuffleDataSync(appId, shuffleId); + LOG.info( + "Deleted the previous stage attempt data due to stage recomputing for app: {}, " + + "shuffleId: {}. It costs {} ms", + appId, + shuffleId, + System.currentTimeMillis() - start); + } catch (Exception e) { + LOG.error( + "Errors on clearing previous stage attempt data for app: {}, shuffleId: {}", + appId, + shuffleId, + e); + StatusCode code = StatusCode.INTERNAL_ERROR; + reply = ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + } + ShuffleDataDistributionType shuffleDataDistributionType = ShuffleDataDistributionType.valueOf( Optional.ofNullable(req.getShuffleDataDistribution()) diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index f9ed125b07..c5c8ac971c 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -856,7 +856,7 @@ public void removeShuffleDataAsync(String appId) { } @VisibleForTesting - void removeShuffleDataSync(String appId, int shuffleId) { + public void removeShuffleDataSync(String appId, int shuffleId) { removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId)); }