[#1579] fix(spark): clear out previous stage attempt data synchronously for stage resubmit
zuston committed Mar 22, 2024
1 parent c3c0c37 commit 4d3a892
Showing 8 changed files with 102 additions and 46 deletions.
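The change threads a new isStageRetry flag from the Spark client's shuffle registration through ShuffleWriteClient and the ShuffleRegisterRequest proto; when the flag is set, the shuffle server deletes the previous stage attempt's data synchronously (via ShuffleTaskManager#removeShuffleDataSync) before completing the registration. The sketch below is a minimal, self-contained illustration of that ordering only; ShuffleDataStore and its fields are hypothetical stand-ins for the server's task manager, not the real API.

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for the shuffle server's task manager. Only the ordering matters:
// on a stage retry, the old attempt's data for the shuffleId is removed *before* the new
// registration is accepted, so readers never mix data from two attempts.
class ShuffleDataStore {
  private final Map<String, Set<Integer>> registered = new HashMap<>();

  void registerShuffle(String appId, int shuffleId, boolean isStageRetry) {
    if (isStageRetry) {
      // Synchronous cleanup of the previous stage attempt's data (the real server calls
      // ShuffleTaskManager#removeShuffleDataSync at this point).
      removeShuffleDataSync(appId, shuffleId);
    }
    registered.computeIfAbsent(appId, k -> new HashSet<>()).add(shuffleId);
  }

  void removeShuffleDataSync(String appId, int shuffleId) {
    Set<Integer> shuffles = registered.get(appId);
    if (shuffles != null) {
      shuffles.remove(shuffleId);
    }
  }

  public static void main(String[] args) {
    ShuffleDataStore store = new ShuffleDataStore();
    store.registerShuffle("app_1", 0, false); // first attempt
    store.registerShuffle("app_1", 0, true);  // stage resubmit: old data is purged first
  }
}

Doing the cleanup synchronously inside the register call means the resubmitted stage cannot start writing while data from the failed attempt is still being served.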
@@ -438,35 +438,18 @@ public <K, V, C> ShuffleHandle registerShuffle(

int requiredShuffleServerNumber =
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);

// retryInterval must bigger than `rss.server.heartbeat.interval`, or maybe it will return the
// same result
long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
Map<Integer, List<ShuffleServerInfo>> partitionToServers;
try {
partitionToServers =
RetryUtils.retry(
() -> {
ShuffleAssignmentsInfo response =
shuffleWriteClient.getShuffleAssignments(
id.get(),
shuffleId,
dependency.partitioner().numPartitions(),
1,
assignmentTags,
requiredShuffleServerNumber,
estimateTaskConcurrency);
registerShuffleServers(
id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage);
return response.getPartitionToServers();
},
retryInterval,
retryTimes);
} catch (Throwable throwable) {
throw new RssException("registerShuffle failed!", throwable);
}

Map<Integer, List<ShuffleServerInfo>> partitionToServers =
requestShuffleAssignment(
shuffleId,
dependency.partitioner().numPartitions(),
1,
requiredShuffleServerNumber,
estimateTaskConcurrency,
failuresShuffleServerIds,
false);

startHeartbeat();

shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions());
@@ -896,7 +879,8 @@ protected void registerShuffleServers(
String appId,
int shuffleId,
Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges,
RemoteStorageInfo remoteStorage) {
RemoteStorageInfo remoteStorage,
boolean isStageRetry) {
if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) {
return;
}
@@ -914,7 +898,8 @@
entry.getValue(),
remoteStorage,
dataDistributionType,
maxConcurrencyPerPartitionToWrite);
maxConcurrencyPerPartitionToWrite,
isStageRetry);
});
LOG.info(
"Finish register shuffleId["
@@ -1146,16 +1131,20 @@ public synchronized boolean reassignAllShuffleServersForWholeStage(
int requiredShuffleServerNumber =
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
/** Before reassigning shuffle servers, clear the shuffle server list in ShuffleWriteClient. */
shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);

/**
* This clears all data from the previous stage attempt when the same shuffleId is
* registered for the second time.
*/
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
requestShuffleAssignment(
shuffleId,
numPartitions,
1,
requiredShuffleServerNumber,
estimateTaskConcurrency,
failuresShuffleServerIds);
failuresShuffleServerIds,
true);
/**
* We need to clear the metadata of the completed tasks, otherwise some of the stage's data
* will be lost
@@ -1219,7 +1208,7 @@ private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffl
Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
faultyServerIds.addAll(failuresShuffleServerIds);
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds);
requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, false);
if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) {
return partitionToServers.get(0).get(0);
}
@@ -1232,10 +1221,13 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
int partitionNumPerRange,
int assignmentShuffleServerNumber,
int estimateTaskConcurrency,
Set<String> faultyServerIds) {
Set<String> faultyServerIds,
boolean isStageRetry) {
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);
// retryInterval must be larger than `rss.server.heartbeat.interval`, otherwise the retry may
// return the same assignment result
long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
faultyServerIds.addAll(failuresShuffleServerIds);
@@ -1253,7 +1245,11 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
estimateTaskConcurrency,
faultyServerIds);
registerShuffleServers(
id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
id.get(),
shuffleId,
response.getServerToPartitionRanges(),
getRemoteStorageInfo(),
isStageRetry);
return response.getPartitionToServers();
},
retryInterval,
@@ -43,14 +43,35 @@ SendShuffleDataResult sendShuffleData(

void registerApplicationInfo(String appId, long timeoutMs, String user);

default void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
String appId,
int shuffleId,
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite) {
registerShuffle(
shuffleServerInfo,
appId,
shuffleId,
partitionRanges,
remoteStorage,
dataDistributionType,
maxConcurrencyPerPartitionToWrite,
false);
}

void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
String appId,
int shuffleId,
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite);
int maxConcurrencyPerPartitionToWrite,
boolean isStageRetry);

boolean sendCommit(
Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int shuffleId, int numMaps);
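The default method added above keeps the interface change source-compatible: callers of the existing seven-argument registerShuffle continue to compile and are routed to the new overload with isStageRetry set to false, while implementations only need to provide the eight-argument variant. A stripped-down sketch of the same pattern follows; the Client interface and LoggingClient class here are hypothetical illustrations, not the real ShuffleWriteClient.

// Simplified illustration of evolving an interface with a default overload.
interface Client {
  // Old call sites keep using this signature; it delegates with the new flag defaulted to false.
  default void register(String appId, int shuffleId) {
    register(appId, shuffleId, false);
  }

  // New signature that implementations must provide.
  void register(String appId, int shuffleId, boolean isStageRetry);
}

class LoggingClient implements Client {
  @Override
  public void register(String appId, int shuffleId, boolean isStageRetry) {
    System.out.println("register " + appId + "/" + shuffleId + " stageRetry=" + isStageRetry);
  }

  public static void main(String[] args) {
    Client client = new LoggingClient();
    client.register("app_1", 0);        // legacy call site, isStageRetry defaults to false
    client.register("app_1", 0, true);  // stage-retry path
  }
}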
@@ -529,7 +529,8 @@ public void registerShuffle(
List<PartitionRange> partitionRanges,
RemoteStorageInfo remoteStorage,
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite) {
int maxConcurrencyPerPartitionToWrite,
boolean isStageRetry) {
String user = null;
try {
user = UserGroupInformation.getCurrentUser().getShortUserName();
@@ -545,7 +546,8 @@
remoteStorage,
user,
dataDistributionType,
maxConcurrencyPerPartitionToWrite);
maxConcurrencyPerPartitionToWrite,
isStageRetry);
RssRegisterShuffleResponse response =
getShuffleServerClient(shuffleServerInfo).registerShuffle(request);

@@ -159,15 +159,17 @@ private ShuffleRegisterResponse doRegisterShuffle(
RemoteStorageInfo remoteStorageInfo,
String user,
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite) {
int maxConcurrencyPerPartitionToWrite,
boolean isStageRetry) {
ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder();
reqBuilder
.setAppId(appId)
.setShuffleId(shuffleId)
.setUser(user)
.setShuffleDataDistribution(RssProtos.DataDistribution.valueOf(dataDistributionType.name()))
.setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite)
.addAllPartitionRanges(toShufflePartitionRanges(partitionRanges));
.addAllPartitionRanges(toShufflePartitionRanges(partitionRanges))
.setIsStageRetry(isStageRetry);
RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder();
rsBuilder.setPath(remoteStorageInfo.getPath());
Map<String, String> remoteStorageConf = remoteStorageInfo.getConfItems();
@@ -400,7 +402,8 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ
request.getRemoteStorageInfo(),
request.getUser(),
request.getDataDistributionType(),
request.getMaxConcurrencyPerPartitionToWrite());
request.getMaxConcurrencyPerPartitionToWrite(),
request.isStageRetry());

RssRegisterShuffleResponse response;
RssProtos.StatusCode statusCode = rpcResponse.getStatus();
@@ -35,6 +35,7 @@ public class RssRegisterShuffleRequest {
private String user;
private ShuffleDataDistributionType dataDistributionType;
private int maxConcurrencyPerPartitionToWrite;
private boolean isStageRetry;

public RssRegisterShuffleRequest(
String appId,
@@ -43,14 +44,16 @@ public RssRegisterShuffleRequest(
RemoteStorageInfo remoteStorageInfo,
String user,
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite) {
int maxConcurrencyPerPartitionToWrite,
boolean isStageRetry) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionRanges = partitionRanges;
this.remoteStorageInfo = remoteStorageInfo;
this.user = user;
this.dataDistributionType = dataDistributionType;
this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite;
this.isStageRetry = isStageRetry;
}

public RssRegisterShuffleRequest(
@@ -67,7 +70,8 @@ public RssRegisterShuffleRequest(
remoteStorageInfo,
user,
dataDistributionType,
RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue());
RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
false);
}

public RssRegisterShuffleRequest(
@@ -79,7 +83,8 @@
new RemoteStorageInfo(remoteStoragePath),
StringUtils.EMPTY,
ShuffleDataDistributionType.NORMAL,
RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue());
RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
false);
}

public String getAppId() {
@@ -109,4 +114,8 @@ public ShuffleDataDistributionType getDataDistributionType() {
public int getMaxConcurrencyPerPartitionToWrite() {
return maxConcurrencyPerPartitionToWrite;
}

public boolean isStageRetry() {
return isStageRetry;
}
}
1 change: 1 addition & 0 deletions proto/src/main/proto/Rss.proto
@@ -184,6 +184,7 @@ message ShuffleRegisterRequest {
string user = 5;
DataDistribution shuffleDataDistribution = 6;
int32 maxConcurrencyPerPartitionToWrite = 7;
bool isStageRetry = 8;
}

enum DataDistribution {
@@ -158,6 +158,30 @@ public void registerShuffle(
String remoteStoragePath = req.getRemoteStorage().getPath();
String user = req.getUser();

if (req.getIsStageRetry()) {
try {
long start = System.currentTimeMillis();
shuffleServer.getShuffleTaskManager().removeShuffleDataSync(appId, shuffleId);
LOG.info(
"Deleted the previous stage attempt data due to stage recomputing for app: {}, "
+ "shuffleId: {}. It costs {} ms",
appId,
shuffleId,
System.currentTimeMillis() - start);
} catch (Exception e) {
LOG.error(
"Errors on clearing previous stage attempt data for app: {}, shuffleId: {}",
appId,
shuffleId,
e);
StatusCode code = StatusCode.INTERNAL_ERROR;
reply = ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
return;
}
}

ShuffleDataDistributionType shuffleDataDistributionType =
ShuffleDataDistributionType.valueOf(
Optional.ofNullable(req.getShuffleDataDistribution())
@@ -856,7 +856,7 @@ public void removeShuffleDataAsync(String appId) {
}

@VisibleForTesting
void removeShuffleDataSync(String appId, int shuffleId) {
public void removeShuffleDataSync(String appId, int shuffleId) {
removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId));
}
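If the synchronous cleanup on the server fails, registerShuffle replies with INTERNAL_ERROR; on the client, registration runs inside RetryUtils.retry using RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL and RSS_CLIENT_ASSIGNMENT_RETRY_TIMES, so a failed registration attempt can be retried. The loop below is a simplified, self-contained stand-in for that retry shape, not the real RetryUtils implementation.

import java.util.concurrent.Callable;

// Simplified stand-in for RetryUtils.retry: run the task, and on failure wait retryIntervalMs
// and try again, up to retryTimes attempts, rethrowing the last error.
class SimpleRetry {
  static <T> T retry(Callable<T> task, long retryIntervalMs, int retryTimes) throws Exception {
    Exception last = null;
    for (int attempt = 0; attempt < retryTimes; attempt++) {
      try {
        return task.call();
      } catch (Exception e) {
        last = e;
        Thread.sleep(retryIntervalMs);
      }
    }
    throw last;
  }

  public static void main(String[] args) throws Exception {
    // Succeeds on the third attempt, mimicking a registration that initially fails
    // (e.g. the server returns INTERNAL_ERROR while cleaning up old attempt data).
    int[] calls = {0};
    String result =
        retry(
            () -> {
              if (++calls[0] < 3) {
                throw new IllegalStateException("register failed");
              }
              return "registered";
            },
            10L,
            5);
    System.out.println(result + " after " + calls[0] + " attempts");
  }
}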
