diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index 5c9f401b8a..430e2ff584 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -517,7 +517,8 @@ public void registerShuffle( List partitionRanges, RemoteStorageInfo remoteStorage, ShuffleDataDistributionType distributionType, - int maxConcurrencyPerPartitionToWrite) {} + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber) {} @Override public boolean sendCommit( diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java index b858312b0f..7664b47d6d 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java @@ -504,7 +504,8 @@ public void registerShuffle( List partitionRanges, RemoteStorageInfo storageType, ShuffleDataDistributionType distributionType, - int maxConcurrencyPerPartitionToWrite) {} + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber) {} @Override public boolean sendCommit( diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmit.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmit.java new file mode 100644 index 0000000000..a0550de4c5 --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmit.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle; + +public class RssStageResubmit { + private String stageIdAndNumber; + private boolean isReassigned; + + public RssStageResubmit(String stageIdAndNumber, boolean isReassigned) { + this.stageIdAndNumber = stageIdAndNumber; + this.isReassigned = isReassigned; + } + + public String getStageIdAndNumber() { + return stageIdAndNumber; + } + + public void setStageIdAndNumber(String stageIdAndNumber) { + this.stageIdAndNumber = stageIdAndNumber; + } + + public boolean isReassigned() { + return isReassigned; + } + + public void setReassigned(boolean reassigned) { + isReassigned = reassigned; + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java new file mode 100644 index 0000000000..058453caff --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssStageResubmitManager.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle; + +import java.util.Map; +import java.util.Set; + +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.util.JavaUtils; + +public class RssStageResubmitManager { + + private static final Logger LOG = LoggerFactory.getLogger(RssStageResubmitManager.class); + /** Ids of the shuffleServers that failed on write */ + private Set failuresShuffleServerIds; + /** + * Prevents multiple tasks that report FetchFailed from triggering multiple ShuffleServer + * reassignments. Maps shuffleId to the stageId/attemptNumber pair and its reassigned flag. + */ + private Map serverAssignedInfos; + + public RssStageResubmitManager() { + this.failuresShuffleServerIds = Sets.newConcurrentHashSet(); + this.serverAssignedInfos = JavaUtils.newConcurrentMap(); + } + + public Set getFailuresShuffleServerIds() { + return failuresShuffleServerIds; + } + + public void setFailuresShuffleServerIds(Set failuresShuffleServerIds) { + this.failuresShuffleServerIds = failuresShuffleServerIds; + } + + public void recordFailuresShuffleServer(String shuffleServerId) { + failuresShuffleServerIds.add(shuffleServerId); + } + + public RssStageResubmit recordAndGetServerAssignedInfo(int shuffleId, String stageIdAndAttempt) { + + return serverAssignedInfos.computeIfAbsent( + shuffleId, id -> new RssStageResubmit(stageIdAndAttempt, false)); + } + + public void recordAndGetServerAssignedInfo( + int shuffleId, String stageIdAndAttempt, boolean isRetried) { + serverAssignedInfos +
.computeIfAbsent(shuffleId, id -> new RssStageResubmit(stageIdAndAttempt, false)) + .setReassigned(isRetried); + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java new file mode 100644 index 0000000000..167227b88e --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.handle; + +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.proto.RssProtos; + +public class StageAttemptShuffleHandleInfo extends ShuffleHandleInfoBase { + private static final Logger LOGGER = LoggerFactory.getLogger(StageAttemptShuffleHandleInfo.class); + + private ShuffleHandleInfo current; + private LinkedList historyHandles; + + public StageAttemptShuffleHandleInfo(ShuffleHandleInfo shuffleServerInfo) { + super(0, null); + this.current = shuffleServerInfo; + this.historyHandles = Lists.newLinkedList(); + } + + public StageAttemptShuffleHandleInfo( + ShuffleHandleInfo currentShuffleServerInfo, LinkedList historyHandles) { + super(0, null); + this.current = currentShuffleServerInfo; + this.historyHandles = historyHandles; + } + + @Override + public Set getServers() { + return current.getServers(); + } + + @Override + public Map> getAvailablePartitionServersForWriter() { + return current.getAvailablePartitionServersForWriter(); + } + + @Override + public Map> getAllPartitionServersForReader() { + return current.getAllPartitionServersForReader(); + } + + @Override + public PartitionDataReplicaRequirementTracking createPartitionReplicaTracking() { + return current.createPartitionReplicaTracking(); + } + + /** + * When a Stage retry occurs, replace the current shuffleHandleInfo and record the historical + * shuffleHandleInfo. 
+ */ + public void replaceCurrentShuffleHandleInfo(ShuffleHandleInfo shuffleHandleInfo) { + this.historyHandles.add(current); + this.current = shuffleHandleInfo; + } + + public ShuffleHandleInfo getCurrent() { + return current; + } + + public LinkedList getHistoryHandles() { + return historyHandles; + } + + public static RssProtos.StageAttemptShuffleHandleInfo toProto( + StageAttemptShuffleHandleInfo handleInfo) { + synchronized (handleInfo) { + LinkedList mutableShuffleHandleInfoLinkedList = + Lists.newLinkedList(); + RssProtos.MutableShuffleHandleInfo currentMutableShuffleHandleInfo = + MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo) handleInfo.getCurrent()); + for (ShuffleHandleInfo historyHandle : handleInfo.getHistoryHandles()) { + mutableShuffleHandleInfoLinkedList.add( + MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo) historyHandle)); + } + RssProtos.StageAttemptShuffleHandleInfo handleProto = + RssProtos.StageAttemptShuffleHandleInfo.newBuilder() + .setCurrentMutableShuffleHandleInfo(currentMutableShuffleHandleInfo) + .addAllHistoryMutableShuffleHandleInfo(mutableShuffleHandleInfoLinkedList) + .build(); + return handleProto; + } + } + + public static StageAttemptShuffleHandleInfo fromProto( + RssProtos.StageAttemptShuffleHandleInfo handleProto) { + if (handleProto == null) { + return null; + } + + MutableShuffleHandleInfo mutableShuffleHandleInfo = + MutableShuffleHandleInfo.fromProto(handleProto.getCurrentMutableShuffleHandleInfo()); + List historyMutableShuffleHandleInfoList = + handleProto.getHistoryMutableShuffleHandleInfoList(); + LinkedList historyHandles = Lists.newLinkedList(); + for (RssProtos.MutableShuffleHandleInfo shuffleHandleInfo : + historyMutableShuffleHandleInfoList) { + historyHandles.add(MutableShuffleHandleInfo.fromProto(shuffleHandleInfo)); + } + + StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = + new StageAttemptShuffleHandleInfo(mutableShuffleHandleInfo, historyHandles); + return 
stageAttemptShuffleHandleInfo; + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java index 9751ba0b89..f989fdb0b1 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java @@ -25,11 +25,18 @@ public class AddBlockEvent { private String taskId; + private int stageAttemptNumber; private List shuffleDataInfoList; private List processedCallbackChain; public AddBlockEvent(String taskId, List shuffleDataInfoList) { + this(taskId, 0, shuffleDataInfoList); + } + + public AddBlockEvent( + String taskId, int stageAttemptNumber, List shuffleDataInfoList) { this.taskId = taskId; + this.stageAttemptNumber = stageAttemptNumber; this.shuffleDataInfoList = shuffleDataInfoList; this.processedCallbackChain = new ArrayList<>(); } @@ -43,6 +50,10 @@ public String getTaskId() { return taskId; } + public int getStageAttemptNumber() { + return stageAttemptNumber; + } + public List getShuffleDataInfoList() { return shuffleDataInfoList; } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java index e9ef2ba614..bdf0cf8496 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java @@ -92,7 +92,10 @@ public CompletableFuture send(AddBlockEvent event) { try { result = shuffleWriteClient.sendShuffleData( - rssAppId, shuffleBlockInfoList, () -> !isValidTask(taskId)); + rssAppId, + event.getStageAttemptNumber(), + shuffleBlockInfoList, + () -> !isValidTask(taskId)); putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds()); putFailedBlockSendTracker( 
taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker()); diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 77cb173e37..2a26f401bf 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -19,16 +19,24 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.commons.collections.CollectionUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.UserGroupInformation; import org.apache.spark.MapOutputTracker; @@ -38,22 +46,41 @@ import org.apache.spark.SparkException; import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.shuffle.RssSparkShuffleUtils; +import org.apache.spark.shuffle.RssStageResubmit; +import org.apache.spark.shuffle.RssStageResubmitManager; +import org.apache.spark.shuffle.ShuffleHandleInfoManager; import org.apache.spark.shuffle.ShuffleManager; import org.apache.spark.shuffle.SparkVersionUtils; +import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; +import org.apache.spark.shuffle.handle.ShuffleHandleInfo; +import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; 
import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.CoordinatorClient; +import org.apache.uniffle.client.api.ShuffleManagerClient; +import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.CoordinatorClientFactory; +import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.request.RssFetchClientConfRequest; +import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest; import org.apache.uniffle.client.response.RssFetchClientConfResponse; +import org.apache.uniffle.client.response.RssPartitionToShuffleServerWithStageRetryResponse; +import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse; +import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.PartitionRange; +import org.apache.uniffle.common.ReceivingFailureServer; import org.apache.uniffle.common.RemoteStorageInfo; +import org.apache.uniffle.common.ShuffleAssignmentsInfo; +import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.ConfigOption; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.shuffle.BlockIdManager; import static org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX; @@ -65,7 +92,34 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac private Method unregisterAllMapOutputMethod; private Method registerShuffleMethod; private volatile BlockIdManager blockIdManager; + protected ShuffleDataDistributionType dataDistributionType; private Object blockIdManagerLock = new Object(); + protected 
AtomicReference id = new AtomicReference<>(); + protected String appId = ""; + protected ShuffleWriteClient shuffleWriteClient; + protected boolean dynamicConfEnabled; + protected int maxConcurrencyPerPartitionToWrite; + protected String clientType; + + protected SparkConf sparkConf; + protected ShuffleManagerClient shuffleManagerClient; + /** Whether to enable the dynamic shuffleServer function rewrite and reread functions */ + protected boolean rssResubmitStage; + /** + * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is dynamically allocated. + * ShuffleServer is not obtained from RssShuffleHandle, but from this mapping. + */ + protected ShuffleHandleInfoManager shuffleHandleInfoManager; + + protected RssStageResubmitManager rssStageResubmitManager; + + protected int partitionReassignMaxServerNum; + + protected boolean blockIdSelfManagedEnabled; + + protected boolean taskBlockSendFailureRetryEnabled; + + protected boolean shuffleManagerRpcServiceEnabled; public BlockIdManager getBlockIdManager() { if (blockIdManager == null) { @@ -520,4 +574,417 @@ protected static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkCo return new RemoteStorageInfo( sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), confItems); } + + /** + * In Stage Retry mode, obtain the Shuffle Server list from the Driver based on shuffleId. 
+ * + * @param shuffleId shuffleId + * @return ShuffleHandleInfo + */ + protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoWithStageRetry( + int shuffleId) { + RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = + new RssPartitionToShuffleServerRequest(shuffleId); + RssPartitionToShuffleServerWithStageRetryResponse rpcPartitionToShufflerServer = + getOrCreateShuffleManagerClient() + .getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest); + StageAttemptShuffleHandleInfo shuffleHandleInfo = + StageAttemptShuffleHandleInfo.fromProto( + rpcPartitionToShufflerServer.getShuffleHandleInfoProto()); + return shuffleHandleInfo; + } + + /** + * In Block Retry mode, obtain the Shuffle Server list from the Driver based on shuffleId. + * + * @param shuffleId shuffleId + * @return ShuffleHandleInfo + */ + protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBlockRetry( + int shuffleId) { + RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = + new RssPartitionToShuffleServerRequest(shuffleId); + RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer = + getOrCreateShuffleManagerClient() + .getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest); + MutableShuffleHandleInfo shuffleHandleInfo = + MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle()); + return shuffleHandleInfo; + } + + // todo: automatically close the client when it is idle to avoid too many connections for spark + // driver.
+ protected ShuffleManagerClient getOrCreateShuffleManagerClient() { + if (shuffleManagerClient == null) { + RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); + String driver = rssConf.getString("driver.host", ""); + int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); + this.shuffleManagerClient = + ShuffleManagerClientFactory.getInstance() + .createShuffleManagerClient(ClientType.GRPC, driver, port); + } + return shuffleManagerClient; + } + + @Override + public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) { + return shuffleHandleInfoManager.get(shuffleId); + } + + /** + * @return the maximum number of fetch failures per shuffle partition before that shuffle stage + * should be recomputed + */ + @Override + public int getMaxFetchFailures() { + final String TASK_MAX_FAILURE = "spark.task.maxFailures"; + return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1); + } + + /** + * Add the shuffleServer that failed to write to the failure list + * + * @param shuffleServerId + */ + @Override + public void addFailuresShuffleServerInfos(String shuffleServerId) { + rssStageResubmitManager.recordFailuresShuffleServer(shuffleServerId); + } + + /** + * Reassign the ShuffleServer list for ShuffleId + * + * @param shuffleId + * @param numPartitions + */ + @Override + public boolean reassignOnStageResubmit( + int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) { + String stageIdAndAttempt = stageId + "_" + stageAttemptNumber; + RssStageResubmit rssStageResubmit = + rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, stageIdAndAttempt); + synchronized (rssStageResubmit) { + Boolean needReassign = rssStageResubmit.isReassigned(); + if (!needReassign) { + int requiredShuffleServerNumber = + RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); + int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); + + /** + * this will clear up the previous stage attempt all data 
when registering the same + * shuffleId at the second time + */ + Map> partitionToServers = + requestShuffleAssignment( + shuffleId, + numPartitions, + 1, + requiredShuffleServerNumber, + estimateTaskConcurrency, + rssStageResubmitManager.getFailuresShuffleServerIds(), + stageAttemptNumber); + /** + * we need to clear the metadata of the completed task, otherwise some of the stage's data + * will be lost + */ + try { + unregisterAllMapOutput(shuffleId); + } catch (SparkException e) { + LOG.error("Clear MapoutTracker Meta failed!"); + throw new RssException("Clear MapoutTracker Meta failed!", e); + } + MutableShuffleHandleInfo shuffleHandleInfo = + new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo()); + StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = + (StageAttemptShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); + stageAttemptShuffleHandleInfo.replaceCurrentShuffleHandleInfo(shuffleHandleInfo); + rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, stageIdAndAttempt, true); + return true; + } else { + LOG.info( + "The Stage:{} has been reassigned in an Attempt{},Return without performing any operation", + stageId, + stageAttemptNumber); + return false; + } + } + } + + /** this is only valid on driver side that exposed to being invoked by grpc server */ + @Override + public MutableShuffleHandleInfo reassignOnBlockSendFailure( + int shuffleId, Map> partitionToFailureServers) { + long startTime = System.currentTimeMillis(); + MutableShuffleHandleInfo handleInfo = + (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); + synchronized (handleInfo) { + // If the reassignment servers for one partition exceeds the max reassign server num, + // it should fast fail. 
+ handleInfo.checkPartitionReassignServerNum( + partitionToFailureServers.keySet(), partitionReassignMaxServerNum); + + Map> newServerToPartitions = new HashMap<>(); + // receivingFailureServer -> partitionId -> replacementServerIds. For logging + Map>> reassignResult = new HashMap<>(); + + for (Map.Entry> entry : + partitionToFailureServers.entrySet()) { + int partitionId = entry.getKey(); + for (ReceivingFailureServer receivingFailureServer : entry.getValue()) { + StatusCode code = receivingFailureServer.getStatusCode(); + String serverId = receivingFailureServer.getServerId(); + + boolean serverHasReplaced = false; + Set replacements = handleInfo.getReplacements(serverId); + if (CollectionUtils.isEmpty(replacements)) { + final int requiredServerNum = 1; + Set excludedServers = new HashSet<>(handleInfo.listExcludedServers()); + excludedServers.add(serverId); + replacements = + reassignServerForTask( + shuffleId, Sets.newHashSet(partitionId), excludedServers, requiredServerNum); + } else { + serverHasReplaced = true; + } + + Set updatedReassignServers = + handleInfo.updateAssignment(partitionId, serverId, replacements); + + reassignResult + .computeIfAbsent(serverId, x -> new HashMap<>()) + .computeIfAbsent(partitionId, x -> new HashSet<>()) + .addAll( + updatedReassignServers.stream().map(x -> x.getId()).collect(Collectors.toSet())); + + if (serverHasReplaced) { + for (ShuffleServerInfo serverInfo : updatedReassignServers) { + newServerToPartitions + .computeIfAbsent(serverInfo, x -> new ArrayList<>()) + .add(new PartitionRange(partitionId, partitionId)); + } + } + } + } + if (!newServerToPartitions.isEmpty()) { + LOG.info( + "Register the new partition->servers assignment on reassign. {}", + newServerToPartitions); + registerShuffleServers(id.get(), shuffleId, newServerToPartitions, getRemoteStorageInfo()); + } + + LOG.info( + "Finished reassignOnBlockSendFailure request and cost {}(ms). 
Reassign result: {}", + System.currentTimeMillis() - startTime, + reassignResult); + + return handleInfo; + } + } + + /** + * Creating the shuffleAssignmentInfo from the servers and partitionIds + * + * @param servers + * @param partitionIds + * @return + */ + private ShuffleAssignmentsInfo createShuffleAssignmentsInfo( + Set servers, Set partitionIds) { + Map> newPartitionToServers = new HashMap<>(); + List partitionRanges = new ArrayList<>(); + for (Integer partitionId : partitionIds) { + newPartitionToServers.put(partitionId, new ArrayList<>(servers)); + partitionRanges.add(new PartitionRange(partitionId, partitionId)); + } + Map> serverToPartitionRanges = new HashMap<>(); + for (ShuffleServerInfo server : servers) { + serverToPartitionRanges.put(server, partitionRanges); + } + return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges); + } + + /** Request the new shuffle-servers to replace faulty server. */ + private Set reassignServerForTask( + int shuffleId, + Set partitionIds, + Set excludedServers, + int requiredServerNum) { + AtomicReference> replacementsRef = + new AtomicReference<>(new HashSet<>()); + requestShuffleAssignment( + shuffleId, + requiredServerNum, + 1, + requiredServerNum, + 1, + excludedServers, + shuffleAssignmentsInfo -> { + if (shuffleAssignmentsInfo == null) { + return null; + } + Set replacements = + shuffleAssignmentsInfo.getPartitionToServers().values().stream() + .flatMap(x -> x.stream()) + .collect(Collectors.toSet()); + replacementsRef.set(replacements); + return createShuffleAssignmentsInfo(replacements, partitionIds); + }); + return replacementsRef.get(); + } + + private Map> requestShuffleAssignment( + int shuffleId, + int partitionNum, + int partitionNumPerRange, + int assignmentShuffleServerNumber, + int estimateTaskConcurrency, + Set faultyServerIds, + Function reassignmentHandler) { + Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + ClientUtils.validateClientType(clientType); 
+ assignmentTags.add(clientType); + long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); + int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); + faultyServerIds.addAll(rssStageResubmitManager.getFailuresShuffleServerIds()); + try { + return RetryUtils.retry( + () -> { + ShuffleAssignmentsInfo response = + shuffleWriteClient.getShuffleAssignments( + id.get(), + shuffleId, + partitionNum, + partitionNumPerRange, + assignmentTags, + assignmentShuffleServerNumber, + estimateTaskConcurrency, + faultyServerIds); + LOG.info("Finished reassign"); + if (reassignmentHandler != null) { + response = reassignmentHandler.apply(response); + } + registerShuffleServers( + id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); + return response.getPartitionToServers(); + }, + retryInterval, + retryTimes); + } catch (Throwable throwable) { + throw new RssException("registerShuffle failed!", throwable); + } + } + + protected Map> requestShuffleAssignment( + int shuffleId, + int partitionNum, + int partitionNumPerRange, + int assignmentShuffleServerNumber, + int estimateTaskConcurrency, + Set faultyServerIds, + int stageAttemptNumber) { + Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + ClientUtils.validateClientType(clientType); + assignmentTags.add(clientType); + + long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); + int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); + faultyServerIds.addAll(rssStageResubmitManager.getFailuresShuffleServerIds()); + try { + return RetryUtils.retry( + () -> { + ShuffleAssignmentsInfo response = + shuffleWriteClient.getShuffleAssignments( + appId, + shuffleId, + partitionNum, + partitionNumPerRange, + assignmentTags, + assignmentShuffleServerNumber, + estimateTaskConcurrency, + faultyServerIds); + registerShuffleServers( + appId, + shuffleId, + 
response.getServerToPartitionRanges(), + getRemoteStorageInfo(), + stageAttemptNumber); + return response.getPartitionToServers(); + }, + retryInterval, + retryTimes); + } catch (Throwable throwable) { + throw new RssException("registerShuffle failed!", throwable); + } + } + + protected void registerShuffleServers( + String appId, + int shuffleId, + Map> serverToPartitionRanges, + RemoteStorageInfo remoteStorage, + int stageAttemptNumber) { + if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { + return; + } + LOG.info("Start to register shuffleId {}", shuffleId); + long start = System.currentTimeMillis(); + serverToPartitionRanges.entrySet().stream() + .forEach( + entry -> { + shuffleWriteClient.registerShuffle( + entry.getKey(), + appId, + shuffleId, + entry.getValue(), + remoteStorage, + ShuffleDataDistributionType.NORMAL, + maxConcurrencyPerPartitionToWrite, + stageAttemptNumber); + }); + LOG.info( + "Finish register shuffleId {} with {} ms", shuffleId, (System.currentTimeMillis() - start)); + } + + @VisibleForTesting + protected void registerShuffleServers( + String appId, + int shuffleId, + Map> serverToPartitionRanges, + RemoteStorageInfo remoteStorage) { + if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { + return; + } + LOG.info("Start to register shuffleId[{}]", shuffleId); + long start = System.currentTimeMillis(); + Set>> entries = + serverToPartitionRanges.entrySet(); + entries.stream() + .forEach( + entry -> { + shuffleWriteClient.registerShuffle( + entry.getKey(), + appId, + shuffleId, + entry.getValue(), + remoteStorage, + dataDistributionType, + maxConcurrencyPerPartitionToWrite); + }); + LOG.info( + "Finish register shuffleId[{}] with {} ms", + shuffleId, + (System.currentTimeMillis() - start)); + } + + protected RemoteStorageInfo getRemoteStorageInfo() { + String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()); + RemoteStorageInfo defaultRemoteStorage = + new 
RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "")); + return ClientUtils.fetchRemoteStorage( + appId, defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient); + } + + public boolean isRssResubmitStage() { + return rssResubmitStage; + } } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java index b213600410..52ded5db7d 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java @@ -37,12 +37,6 @@ public interface RssShuffleManagerInterface { /** @return the unique spark id for rss shuffle */ String getAppId(); - /** - * @return the maximum number of fetch failures per shuffle partition before that shuffle stage - * should be re-submitted - */ - int getMaxFetchFailures(); - /** * @param shuffleId the shuffle id to query * @return the num of partitions(a.k.a reduce tasks) for shuffle with shuffle id. 
@@ -63,6 +57,8 @@ public interface RssShuffleManagerInterface { */ void unregisterAllMapOutput(int shuffleId) throws SparkException; + BlockIdManager getBlockIdManager(); + /** * Get ShuffleHandleInfo with ShuffleId * @@ -71,6 +67,12 @@ public interface RssShuffleManagerInterface { */ ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId); + /** + * @return the maximum number of fetch failures per shuffle partition before that shuffle stage + * should be re-submitted + */ + int getMaxFetchFailures(); + /** * Add the shuffleServer that failed to write to the failure list * @@ -78,11 +80,8 @@ public interface RssShuffleManagerInterface { */ void addFailuresShuffleServerInfos(String shuffleServerId); - boolean reassignAllShuffleServersForWholeStage( - int stageId, int stageAttemptNumber, int shuffleId, int numMaps); + boolean reassignOnStageResubmit(int stageId, int stageAttemptNumber, int shuffleId, int numMaps); MutableShuffleHandleInfo reassignOnBlockSendFailure( int shuffleId, Map> partitionToFailureServers); - - BlockIdManager getBlockIdManager(); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index 5aaf23a718..846b00f506 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -30,6 +30,7 @@ import com.google.protobuf.UnsafeByteOperations; import io.grpc.stub.StreamObserver; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; +import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -188,10 +189,37 @@ public void reportShuffleFetchFailure( } @Override - public void 
getPartitionToShufflerServer( + public void getPartitionToShufflerServerWithStageRetry( RssProtos.PartitionToShuffleServerRequest request, - StreamObserver responseObserver) { - RssProtos.PartitionToShuffleServerResponse reply; + StreamObserver responseObserver) { + RssProtos.PartitionToShuffleServerWithStageRetryResponse reply; + RssProtos.StatusCode code; + int shuffleId = request.getShuffleId(); + StageAttemptShuffleHandleInfo shuffleHandle = + (StageAttemptShuffleHandleInfo) shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId); + if (shuffleHandle != null) { + code = RssProtos.StatusCode.SUCCESS; + reply = + RssProtos.PartitionToShuffleServerWithStageRetryResponse.newBuilder() + .setStatus(code) + .setShuffleHandleInfo(StageAttemptShuffleHandleInfo.toProto(shuffleHandle)) + .build(); + } else { + code = RssProtos.StatusCode.INVALID_REQUEST; + reply = + RssProtos.PartitionToShuffleServerWithStageRetryResponse.newBuilder() + .setStatus(code) + .build(); + } + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + + @Override + public void getPartitionToShufflerServerWithBlockRetry( + RssProtos.PartitionToShuffleServerRequest request, + StreamObserver responseObserver) { + RssProtos.RssReassignOnBlockSendFailureResponse reply; RssProtos.StatusCode code; int shuffleId = request.getShuffleId(); MutableShuffleHandleInfo shuffleHandle = @@ -199,13 +227,13 @@ public void getPartitionToShufflerServer( if (shuffleHandle != null) { code = RssProtos.StatusCode.SUCCESS; reply = - RssProtos.PartitionToShuffleServerResponse.newBuilder() + RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder() .setStatus(code) - .setShuffleHandleInfo(MutableShuffleHandleInfo.toProto(shuffleHandle)) + .setHandle(MutableShuffleHandleInfo.toProto(shuffleHandle)) .build(); } else { code = RssProtos.StatusCode.INVALID_REQUEST; - reply = RssProtos.PartitionToShuffleServerResponse.newBuilder().setStatus(code).build(); + reply = 
RssProtos.RssReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).build(); } responseObserver.onNext(reply); responseObserver.onCompleted(); @@ -220,7 +248,7 @@ public void reassignShuffleServers( int shuffleId = request.getShuffleId(); int numPartitions = request.getNumPartitions(); boolean needReassign = - shuffleManager.reassignAllShuffleServersForWholeStage( + shuffleManager.reassignOnStageResubmit( stageId, stageAttemptNumber, shuffleId, numPartitions); RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS; RssProtos.ReassignServersReponse reply = diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java index 2a608bd4c9..e070456757 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java @@ -69,6 +69,15 @@ public SendShuffleDataResult sendShuffleData( String appId, List shuffleBlockInfoList, Supplier needCancelRequest) { + return sendShuffleData(appId, 0, shuffleBlockInfoList, needCancelRequest); + } + + @Override + public SendShuffleDataResult sendShuffleData( + String appId, + int stageAttemptNumber, + List shuffleBlockInfoList, + Supplier needCancelRequest) { return fakedShuffleDataResult; } diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java index 66bb26de81..df40e9d161 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java +++ b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java @@ -23,13 +23,11 @@ import java.util.Set; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; -import 
org.apache.spark.shuffle.handle.ShuffleHandleInfoBase; +import org.apache.spark.shuffle.handle.ShuffleHandleInfo; import org.apache.uniffle.common.ReceivingFailureServer; import org.apache.uniffle.shuffle.BlockIdManager; -import static org.mockito.Mockito.mock; - public class DummyRssShuffleManager implements RssShuffleManagerInterface { public Set unregisteredShuffleIds = new LinkedHashSet<>(); @@ -38,11 +36,6 @@ public String getAppId() { return "testAppId"; } - @Override - public int getMaxFetchFailures() { - return 2; - } - @Override public int getPartitionNum(int shuffleId) { return 16; @@ -59,15 +52,25 @@ public void unregisterAllMapOutput(int shuffleId) { } @Override - public ShuffleHandleInfoBase getShuffleHandleInfoByShuffleId(int shuffleId) { + public BlockIdManager getBlockIdManager() { + return null; + } + + @Override + public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) { return null; } + @Override + public int getMaxFetchFailures() { + return 0; + } + @Override public void addFailuresShuffleServerInfos(String shuffleServerId) {} @Override - public boolean reassignAllShuffleServersForWholeStage( + public boolean reassignOnStageResubmit( int stageId, int stageAttemptNumber, int shuffleId, int numMaps) { return false; } @@ -75,11 +78,6 @@ public boolean reassignAllShuffleServersForWholeStage( @Override public MutableShuffleHandleInfo reassignOnBlockSendFailure( int shuffleId, Map> partitionToFailureServers) { - return mock(MutableShuffleHandleInfo.class); - } - - @Override - public BlockIdManager getBlockIdManager() { return null; } } diff --git a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java index 6dc2abbf69..ac3fbda7e3 100644 --- a/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java +++ 
b/client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcServiceTest.java @@ -35,7 +35,7 @@ public class ShuffleManagerGrpcServiceTest { // create mock of RssShuffleManagerInterface. - private static RssShuffleManagerInterface mockShuffleManager; + private static RssShuffleManagerBase mockShuffleManager; private static final String appId = "app-123"; private static final int maxFetchFailures = 2; private static final int shuffleId = 0; @@ -65,7 +65,7 @@ public void onCompleted() { @BeforeAll public static void setup() { - mockShuffleManager = mock(RssShuffleManagerInterface.class); + mockShuffleManager = mock(RssShuffleManagerBase.class); Mockito.when(mockShuffleManager.getAppId()).thenReturn(appId); Mockito.when(mockShuffleManager.getNumMaps(shuffleId)).thenReturn(numMaps); Mockito.when(mockShuffleManager.getPartitionNum(shuffleId)).thenReturn(numReduces); diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 6f5b255be9..59b6eb2539 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -37,13 +37,13 @@ import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; -import org.apache.spark.SparkException; import org.apache.spark.TaskContext; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.ShuffleHandleInfo; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; +import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; import org.apache.spark.shuffle.reader.RssShuffleReader; import org.apache.spark.shuffle.writer.AddBlockEvent; import 
org.apache.spark.shuffle.writer.DataPusher; @@ -54,20 +54,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.uniffle.client.api.ShuffleManagerClient; -import org.apache.uniffle.client.api.ShuffleWriteClient; -import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.impl.FailedBlockSendTracker; -import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest; -import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.client.util.RssClientConfig; -import org.apache.uniffle.common.ClientType; -import org.apache.uniffle.common.PartitionRange; -import org.apache.uniffle.common.ReceivingFailureServer; import org.apache.uniffle.common.RemoteStorageInfo; -import org.apache.uniffle.common.ShuffleAssignmentsInfo; -import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; @@ -76,7 +66,6 @@ import org.apache.uniffle.common.rpc.GrpcServer; import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.JavaUtils; -import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.common.util.RssUtils; import org.apache.uniffle.common.util.ThreadUtils; import org.apache.uniffle.shuffle.RssShuffleClientFactory; @@ -94,10 +83,6 @@ public class RssShuffleManager extends RssShuffleManagerBase { private final long heartbeatInterval; private final long heartbeatTimeout; private ScheduledExecutorService heartBeatScheduledExecutorService; - private SparkConf sparkConf; - private String appId = ""; - private String clientType; - private ShuffleWriteClient shuffleWriteClient; private Map> taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); private Map taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap(); 
@@ -109,41 +94,16 @@ public class RssShuffleManager extends RssShuffleManagerBase { private final int dataCommitPoolSize; private Set failedTaskIds = Sets.newConcurrentHashSet(); private boolean heartbeatStarted = false; - private boolean dynamicConfEnabled; private final int maxFailures; private final boolean speculation; private final BlockIdLayout blockIdLayout; private final String user; private final String uuid; private DataPusher dataPusher; - private final int maxConcurrencyPerPartitionToWrite; - private final Map shuffleIdToPartitionNum = JavaUtils.newConcurrentMap(); private final Map shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap(); private GrpcServer shuffleManagerServer; private ShuffleManagerGrpcService service; - private ShuffleManagerClient shuffleManagerClient; - /** - * Mapping between ShuffleId and ShuffleServer list. ShuffleServer list is dynamically allocated. - * ShuffleServer is not obtained from RssShuffleHandle, but from this mapping. - */ - private Map shuffleIdToShuffleHandleInfo = - JavaUtils.newConcurrentMap(); - /** Whether to enable the dynamic shuffleServer function rewrite and reread functions */ - private boolean rssResubmitStage; - - private boolean taskBlockSendFailureRetry; - - private boolean shuffleManagerRpcServiceEnabled; - /** A list of shuffleServer for Write failures */ - private Set failuresShuffleServerIds = Sets.newHashSet(); - /** - * Prevent multiple tasks from reporting FetchFailed, resulting in multiple ShuffleServer - * assignments, stageID, Attemptnumber Whether to reassign the combination flag; - */ - private Map serverAssignedInfos = JavaUtils.newConcurrentMap(); - - private boolean blockIdSelfManagedEnabled; public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) { @@ -178,15 +138,11 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { this.maxConcurrencyPerPartitionToWrite = 
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE); LOG.info( - "Check quorum config [" - + dataReplica - + ":" - + dataReplicaWrite - + ":" - + dataReplicaRead - + ":" - + dataReplicaSkipEnabled - + "]"); + "Check quorum config [{}:{}:{}:{}]", + dataReplica, + dataReplicaWrite, + dataReplicaRead, + dataReplicaSkipEnabled); RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, dataReplicaRead); this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE); @@ -211,10 +167,11 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { this.rssResubmitStage = rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false) && RssSparkShuffleUtils.isStageResubmitSupported(); - this.taskBlockSendFailureRetry = rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED); + this.taskBlockSendFailureRetryEnabled = + rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED); this.blockIdSelfManagedEnabled = rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED); this.shuffleManagerRpcServiceEnabled = - taskBlockSendFailureRetry || rssResubmitStage || blockIdSelfManagedEnabled; + taskBlockSendFailureRetryEnabled || rssResubmitStage || blockIdSelfManagedEnabled; if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) { if (isDriver) { heartBeatScheduledExecutorService = @@ -276,6 +233,8 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { poolSize, keepAliveTime); } + this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(); } // This method is called in Spark driver side, @@ -351,8 +310,6 @@ public ShuffleHandle registerShuffle( ClientUtils.fetchRemoteStorage( appId, defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient); - int partitionNumPerRange = sparkConf.get(RssSparkConfig.RSS_PARTITION_NUM_PER_RANGE); - // get all register info according to coordinator's response Set assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf); ClientUtils.validateClientType(clientType); @@ -360,44 +317,32 @@ public ShuffleHandle registerShuffle( int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); + int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - // retryInterval must bigger than `rss.server.heartbeat.interval`, or maybe it will return the - // same result - long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); - int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); - - Map> partitionToServers; - try { - partitionToServers = - RetryUtils.retry( - () -> { - ShuffleAssignmentsInfo response = - shuffleWriteClient.getShuffleAssignments( - appId, - shuffleId, - dependency.partitioner().numPartitions(), - partitionNumPerRange, - assignmentTags, - requiredShuffleServerNumber, - -1); - registerShuffleServers( - appId, shuffleId, response.getServerToPartitionRanges(), remoteStorage); - return response.getPartitionToServers(); - }, - retryInterval, - retryTimes); - } catch (Throwable throwable) { - throw new RssException("registerShuffle failed!", throwable); - } + Map> partitionToServers = + requestShuffleAssignment( + shuffleId, + dependency.partitioner().numPartitions(), + 1, + requiredShuffleServerNumber, + estimateTaskConcurrency, + rssStageResubmitManager.getFailuresShuffleServerIds(), + 0); startHeartbeat(); shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions()); shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length); - if (shuffleManagerRpcServiceEnabled) { - MutableShuffleHandleInfo handleInfo = + if (shuffleManagerRpcServiceEnabled && rssResubmitStage) { + ShuffleHandleInfo handleInfo = new MutableShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage); - shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo); + 
StageAttemptShuffleHandleInfo stageAttemptShuffleHandleInfo = + new StageAttemptShuffleHandleInfo(handleInfo); + shuffleHandleInfoManager.register(shuffleId, stageAttemptShuffleHandleInfo); + } else if (shuffleManagerRpcServiceEnabled && taskBlockSendFailureRetryEnabled) { + ShuffleHandleInfo shuffleHandleInfo = + new MutableShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage); + shuffleHandleInfoManager.register(shuffleId, shuffleHandleInfo); } Broadcast hdlInfoBd = RssSparkShuffleUtils.broadcastShuffleHdlInfo( @@ -433,37 +378,6 @@ private void startHeartbeat() { } } - @VisibleForTesting - protected void registerShuffleServers( - String appId, - int shuffleId, - Map> serverToPartitionRanges, - RemoteStorageInfo remoteStorage) { - if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { - return; - } - LOG.info("Start to register shuffleId[" + shuffleId + "]"); - long start = System.currentTimeMillis(); - serverToPartitionRanges.entrySet().stream() - .forEach( - entry -> { - shuffleWriteClient.registerShuffle( - entry.getKey(), - appId, - shuffleId, - entry.getValue(), - remoteStorage, - ShuffleDataDistributionType.NORMAL, - maxConcurrencyPerPartitionToWrite); - }); - LOG.info( - "Finish register shuffleId[" - + shuffleId - + "] with " - + (System.currentTimeMillis() - start) - + " ms"); - } - @VisibleForTesting protected void registerCoordinator() { String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key()); @@ -491,9 +405,12 @@ public ShuffleWriter getWriter( int shuffleId = rssHandle.getShuffleId(); String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); + if (shuffleManagerRpcServiceEnabled && rssResubmitStage) { + // In Stage Retry mode, Get the ShuffleServer list from the Driver 
based on the shuffleId + shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId); + } else if (shuffleManagerRpcServiceEnabled && taskBlockSendFailureRetryEnabled) { + // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId + shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId); } else { shuffleHandleInfo = new SimpleShuffleHandleInfo( @@ -561,9 +478,12 @@ public ShuffleReader getReader( + "]"); start = System.currentTimeMillis(); ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); + if (shuffleManagerRpcServiceEnabled && rssResubmitStage) { + // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId. + shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId); + } else if (shuffleManagerRpcServiceEnabled && taskBlockSendFailureRetryEnabled) { + // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId + shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId); } else { shuffleHandleInfo = new SimpleShuffleHandleInfo( @@ -762,16 +682,6 @@ public String getAppId() { return appId; } - /** - * @return the maximum number of fetch failures per shuffle partition before that shuffle stage - * should be recomputed - */ - @Override - public int getMaxFetchFailures() { - final String TASK_MAX_FAILURE = "spark.task.maxFailures"; - return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1); - } - /** * @param shuffleId the shuffleId to query * @return the num of partitions(a.k.a reduce tasks) for shuffle with shuffle id. 
@@ -806,51 +716,6 @@ private Roaring64NavigableMap getShuffleResult( } } - public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) { - return taskToFailedBlockSendTracker.get(taskId); - } - - @Override - public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) { - return shuffleIdToShuffleHandleInfo.get(shuffleId); - } - - private ShuffleManagerClient createShuffleManagerClient(String host, int port) { - // Host can be inferred from `spark.driver.bindAddress`, which would be set when SparkContext is - // constructed. - return ShuffleManagerClientFactory.getInstance() - .createShuffleManagerClient(ClientType.GRPC, host, port); - } - - private ShuffleManagerClient getOrCreateShuffleManagerClient() { - if (shuffleManagerClient == null) { - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - String driver = rssConf.getString("driver.host", ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - this.shuffleManagerClient = - ShuffleManagerClientFactory.getInstance() - .createShuffleManagerClient(ClientType.GRPC, driver, port); - } - return shuffleManagerClient; - } - - /** - * Get the ShuffleServer list from the Driver based on the shuffleId - * - * @param shuffleId shuffleId - * @return ShuffleHandleInfo - */ - private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) { - RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = - new RssPartitionToShuffleServerRequest(shuffleId); - RssPartitionToShuffleServerResponse handleInfoResponse = - getOrCreateShuffleManagerClient() - .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest); - MutableShuffleHandleInfo shuffleHandleInfo = - MutableShuffleHandleInfo.fromProto(handleInfoResponse.getShuffleHandleInfoProto()); - return shuffleHandleInfo; - } - /** * Add the shuffleServer that failed to write to the failure list * @@ -858,122 +723,21 @@ private synchronized MutableShuffleHandleInfo 
getRemoteShuffleHandleInfo(int shu */ @Override public void addFailuresShuffleServerInfos(String shuffleServerId) { - failuresShuffleServerIds.add(shuffleServerId); - } - - /** - * Reassign the ShuffleServer list for ShuffleId - * - * @param shuffleId - * @param numPartitions - */ - @Override - public synchronized boolean reassignAllShuffleServersForWholeStage( - int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) { - String stageIdAndAttempt = stageId + "_" + stageAttemptNumber; - Boolean needReassign = serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false); - if (!needReassign) { - int requiredShuffleServerNumber = - RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); - int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - /** Before reassigning ShuffleServer, clear the ShuffleServer list in ShuffleWriteClient. */ - shuffleWriteClient.unregisterShuffle(appId, shuffleId); - Map> partitionToServers = - requestShuffleAssignment( - shuffleId, - numPartitions, - 1, - requiredShuffleServerNumber, - estimateTaskConcurrency, - failuresShuffleServerIds); - /** - * we need to clear the metadata of the completed task, otherwise some of the stage's data - * will be lost - */ - try { - unregisterAllMapOutput(shuffleId); - } catch (SparkException e) { - LOG.error("Clear MapoutTracker Meta failed!"); - throw new RssException("Clear MapoutTracker Meta failed!", e); - } - MutableShuffleHandleInfo handleInfo = - new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo()); - shuffleIdToShuffleHandleInfo.put(shuffleId, handleInfo); - serverAssignedInfos.put(stageIdAndAttempt, true); - return true; - } else { - LOG.info( - "The Stage:{} has been reassigned in an Attempt{},Return without performing any operation", - stageId, - stageAttemptNumber); - return false; - } + rssStageResubmitManager.recordFailuresShuffleServer(shuffleServerId); } - @Override - public 
MutableShuffleHandleInfo reassignOnBlockSendFailure( - int shuffleId, Map> partitionToFailureServers) { - throw new RssException("Illegal access for reassignOnBlockSendFailure that is not supported."); + public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) { + return taskToFailedBlockSendTracker.get(taskId); } private ShuffleServerInfo assignShuffleServer(int shuffleId, String faultyShuffleServerId) { Set faultyServerIds = Sets.newHashSet(faultyShuffleServerId); - faultyServerIds.addAll(failuresShuffleServerIds); + faultyServerIds.addAll(rssStageResubmitManager.getFailuresShuffleServerIds()); Map> partitionToServers = - requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds); + requestShuffleAssignment(shuffleId, 1, 1, 1, 1, faultyServerIds, 0); if (partitionToServers.get(0) != null && partitionToServers.get(0).size() == 1) { return partitionToServers.get(0).get(0); } return null; } - - private Map> requestShuffleAssignment( - int shuffleId, - int partitionNum, - int partitionNumPerRange, - int assignmentShuffleServerNumber, - int estimateTaskConcurrency, - Set faultyServerIds) { - Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); - ClientUtils.validateClientType(clientType); - assignmentTags.add(clientType); - - long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); - int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); - faultyServerIds.addAll(failuresShuffleServerIds); - try { - return RetryUtils.retry( - () -> { - ShuffleAssignmentsInfo response = - shuffleWriteClient.getShuffleAssignments( - appId, - shuffleId, - partitionNum, - partitionNumPerRange, - assignmentTags, - assignmentShuffleServerNumber, - estimateTaskConcurrency, - faultyServerIds); - registerShuffleServers( - appId, shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); - return response.getPartitionToServers(); - }, - retryInterval, - retryTimes); - } catch 
(Throwable throwable) { - throw new RssException("registerShuffle failed!", throwable); - } - } - - private RemoteStorageInfo getRemoteStorageInfo() { - String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()); - RemoteStorageInfo defaultRemoteStorage = - new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "")); - return ClientUtils.fetchRemoteStorage( - appId, defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient); - } - - public boolean isRssResubmitStage() { - return rssResubmitStage; - } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index e06896a531..0daceeb48d 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -542,6 +542,7 @@ private void throwFetchFailedIfNecessary(Exception e) { RssReportShuffleWriteFailureResponse response = shuffleManagerClient.reportShuffleWriteFailure(req); if (response.getReSubmitWholeStage()) { + // The shuffle server is reassigned. 
RssReassignServersRequest rssReassignServersRequest = new RssReassignServersRequest( taskContext.stageId(), diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index f70e38a7d0..2be1c7708c 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -18,10 +18,7 @@ package org.apache.spark.shuffle; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -30,7 +27,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; import java.util.stream.Collectors; import scala.Tuple2; @@ -41,13 +37,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Sets; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import org.apache.commons.collections4.CollectionUtils; import org.apache.hadoop.conf.Configuration; import org.apache.spark.MapOutputTracker; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; -import org.apache.spark.SparkException; import org.apache.spark.TaskContext; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.executor.ShuffleReadMetrics; @@ -55,6 +49,7 @@ import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo; import org.apache.spark.shuffle.handle.ShuffleHandleInfo; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; +import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo; import org.apache.spark.shuffle.reader.RssShuffleReader; import 
org.apache.spark.shuffle.writer.AddBlockEvent; import org.apache.spark.shuffle.writer.DataPusher; @@ -67,19 +62,10 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking; -import org.apache.uniffle.client.api.ShuffleManagerClient; -import org.apache.uniffle.client.api.ShuffleWriteClient; -import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.impl.FailedBlockSendTracker; -import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest; -import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.client.util.RssClientConfig; -import org.apache.uniffle.common.ClientType; -import org.apache.uniffle.common.PartitionRange; -import org.apache.uniffle.common.ReceivingFailureServer; import org.apache.uniffle.common.RemoteStorageInfo; -import org.apache.uniffle.common.ShuffleAssignmentsInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; @@ -87,10 +73,8 @@ import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; import org.apache.uniffle.common.rpc.GrpcServer; -import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.JavaUtils; -import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.common.util.RssUtils; import org.apache.uniffle.common.util.ThreadUtils; import org.apache.uniffle.shuffle.RssShuffleClientFactory; @@ -105,10 +89,8 @@ public class RssShuffleManager extends RssShuffleManagerBase { private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManager.class); - private final String clientType; private final long heartbeatInterval; private final long 
heartbeatTimeout; - private AtomicReference id = new AtomicReference<>(); private final int dataReplica; private final int dataReplicaWrite; private final int dataReplicaRead; @@ -119,10 +101,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { private final Map taskToFailedBlockSendTracker; private ScheduledExecutorService heartBeatScheduledExecutorService; private boolean heartbeatStarted = false; - private boolean dynamicConfEnabled; - private final ShuffleDataDistributionType dataDistributionType; private final BlockIdLayout blockIdLayout; - private final int maxConcurrencyPerPartitionToWrite; private final int maxFailures; private final boolean speculation; private String user; @@ -135,31 +114,6 @@ public class RssShuffleManager extends RssShuffleManagerBase { private ShuffleManagerGrpcService service; private GrpcServer shuffleManagerServer; - /** used by columnar rss shuffle writer implementation */ - protected SparkConf sparkConf; - - protected ShuffleWriteClient shuffleWriteClient; - - private ShuffleManagerClient shuffleManagerClient; - /** Whether to enable the dynamic shuffleServer function rewrite and reread functions */ - private boolean rssResubmitStage; - - private boolean taskBlockSendFailureRetryEnabled; - - private boolean shuffleManagerRpcServiceEnabled; - /** A list of shuffleServer for Write failures */ - private Set failuresShuffleServerIds; - /** - * Prevent multiple tasks from reporting FetchFailed, resulting in multiple ShuffleServer - * assignments, stageID, Attemptnumber Whether to reassign the combination flag; - */ - private Map serverAssignedInfos; - - private final int partitionReassignMaxServerNum; - - private final ShuffleHandleInfoManager shuffleHandleInfoManager = new ShuffleHandleInfoManager(); - private boolean blockIdSelfManagedEnabled; - public RssShuffleManager(SparkConf conf, boolean isDriver) { this.sparkConf = conf; boolean supportsRelocation = @@ -305,10 +259,10 @@ public RssShuffleManager(SparkConf conf, 
boolean isDriver) { failedTaskIds, poolSize, keepAliveTime); - this.failuresShuffleServerIds = Sets.newHashSet(); - this.serverAssignedInfos = JavaUtils.newConcurrentMap(); this.partitionReassignMaxServerNum = rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); + this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(); } public CompletableFuture sendData(AddBlockEvent event) { @@ -399,6 +353,8 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s this.dataPusher = dataPusher; this.partitionReassignMaxServerNum = rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM); + this.shuffleHandleInfoManager = new ShuffleHandleInfoManager(); + this.rssStageResubmitManager = new RssStageResubmitManager(); } // This method is called in Spark driver side, @@ -440,6 +396,7 @@ public ShuffleHandle registerShuffle( if (id.get() == null) { id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid); + appId = id.get(); dataPusher.setRssAppId(id.get()); } LOG.info("Generate application id used in rss: " + id.get()); @@ -474,43 +431,30 @@ public ShuffleHandle registerShuffle( int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); - - // retryInterval must bigger than `rss.server.heartbeat.interval`, or maybe it will return the - // same result - long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); - int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - Map> partitionToServers; - try { - partitionToServers = - RetryUtils.retry( - () -> { - ShuffleAssignmentsInfo response = - shuffleWriteClient.getShuffleAssignments( - id.get(), - shuffleId, - dependency.partitioner().numPartitions(), - 1, - assignmentTags, - requiredShuffleServerNumber, - 
estimateTaskConcurrency); - registerShuffleServers( - id.get(), shuffleId, response.getServerToPartitionRanges(), remoteStorage); - return response.getPartitionToServers(); - }, - retryInterval, - retryTimes); - } catch (Throwable throwable) { - throw new RssException("registerShuffle failed!", throwable); - } - startHeartbeat(); + Map> partitionToServers = + requestShuffleAssignment( + shuffleId, + dependency.partitioner().numPartitions(), + 1, + requiredShuffleServerNumber, + estimateTaskConcurrency, + rssStageResubmitManager.getFailuresShuffleServerIds(), + 0); + startHeartbeat(); shuffleIdToPartitionNum.putIfAbsent(shuffleId, dependency.partitioner().numPartitions()); shuffleIdToNumMapTasks.putIfAbsent(shuffleId, dependency.rdd().partitions().length); - if (shuffleManagerRpcServiceEnabled) { - MutableShuffleHandleInfo handleInfo = + if (shuffleManagerRpcServiceEnabled && rssResubmitStage) { + ShuffleHandleInfo shuffleHandleInfo = new MutableShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage); + StageAttemptShuffleHandleInfo handleInfo = + new StageAttemptShuffleHandleInfo(shuffleHandleInfo); shuffleHandleInfoManager.register(shuffleId, handleInfo); + } else if (shuffleManagerRpcServiceEnabled && taskBlockSendFailureRetryEnabled) { + ShuffleHandleInfo shuffleHandleInfo = + new MutableShuffleHandleInfo(shuffleId, partitionToServers, remoteStorage); + shuffleHandleInfoManager.register(shuffleId, shuffleHandleInfo); } Broadcast hdlInfoBd = RssSparkShuffleUtils.broadcastShuffleHdlInfo( @@ -545,9 +489,12 @@ public ShuffleWriter getWriter( writeMetrics = context.taskMetrics().shuffleWriteMetrics(); } ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); + if (shuffleManagerRpcServiceEnabled && rssResubmitStage) { + // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the 
shuffleId. + shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId); + } else if (shuffleManagerRpcServiceEnabled && taskBlockSendFailureRetryEnabled) { + // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId. + shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId); } else { shuffleHandleInfo = new SimpleShuffleHandleInfo( @@ -687,9 +634,12 @@ public ShuffleReader getReaderImpl( final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions(); int shuffleId = rssShuffleHandle.getShuffleId(); ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); + if (shuffleManagerRpcServiceEnabled && rssResubmitStage) { + // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId. + shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId); + } else if (shuffleManagerRpcServiceEnabled && taskBlockSendFailureRetryEnabled) { + // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
+ shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId); } else { shuffleHandleInfo = new SimpleShuffleHandleInfo( @@ -936,39 +886,6 @@ public void clearTaskMeta(String taskId) { taskToFailedBlockSendTracker.remove(taskId); } - @VisibleForTesting - protected void registerShuffleServers( - String appId, - int shuffleId, - Map> serverToPartitionRanges, - RemoteStorageInfo remoteStorage) { - if (serverToPartitionRanges == null || serverToPartitionRanges.isEmpty()) { - return; - } - LOG.info("Start to register shuffleId[" + shuffleId + "]"); - long start = System.currentTimeMillis(); - Set>> entries = - serverToPartitionRanges.entrySet(); - entries.stream() - .forEach( - entry -> { - shuffleWriteClient.registerShuffle( - entry.getKey(), - appId, - shuffleId, - entry.getValue(), - remoteStorage, - dataDistributionType, - maxConcurrencyPerPartitionToWrite); - }); - LOG.info( - "Finish register shuffleId[" - + shuffleId - + "] with " - + (System.currentTimeMillis() - start) - + " ms"); - } - @VisibleForTesting protected void registerCoordinator() { String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key()); @@ -1023,16 +940,6 @@ public String getAppId() { return id.get(); } - /** - * @return the maximum number of fetch failures per shuffle partition before that shuffle stage - * should be recomputed - */ - @Override - public int getMaxFetchFailures() { - final String TASK_MAX_FAILURE = "spark.task.maxFailures"; - return Math.max(1, sparkConf.getInt(TASK_MAX_FAILURE, 4) - 1); - } - @Override public int getPartitionNum(int shuffleId) { return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0); @@ -1134,279 +1041,6 @@ public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) { return taskToFailedBlockSendTracker.get(taskId); } - @Override - public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) { - return shuffleHandleInfoManager.get(shuffleId); - } - - // todo: automatic close client when the client is 
idle to avoid too much connections for spark - // driver. - private ShuffleManagerClient getOrCreateShuffleManagerClient() { - if (shuffleManagerClient == null) { - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - String driver = rssConf.getString("driver.host", ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - this.shuffleManagerClient = - ShuffleManagerClientFactory.getInstance() - .createShuffleManagerClient(ClientType.GRPC, driver, port); - } - return shuffleManagerClient; - } - - /** - * Get the ShuffleServer list from the Driver based on the shuffleId - * - * @param shuffleId shuffleId - * @return ShuffleHandleInfo - */ - private synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfo(int shuffleId) { - RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = - new RssPartitionToShuffleServerRequest(shuffleId); - RssPartitionToShuffleServerResponse rpcPartitionToShufflerServer = - getOrCreateShuffleManagerClient() - .getPartitionToShufflerServer(rssPartitionToShuffleServerRequest); - MutableShuffleHandleInfo shuffleHandleInfo = - MutableShuffleHandleInfo.fromProto( - rpcPartitionToShufflerServer.getShuffleHandleInfoProto()); - return shuffleHandleInfo; - } - - /** - * Add the shuffleServer that failed to write to the failure list - * - * @param shuffleServerId - */ - @Override - public void addFailuresShuffleServerInfos(String shuffleServerId) { - failuresShuffleServerIds.add(shuffleServerId); - } - - /** - * Reassign the ShuffleServer list for ShuffleId - * - * @param shuffleId - * @param numPartitions - */ - @Override - public synchronized boolean reassignAllShuffleServersForWholeStage( - int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) { - String stageIdAndAttempt = stageId + "_" + stageAttemptNumber; - Boolean needReassign = serverAssignedInfos.computeIfAbsent(stageIdAndAttempt, id -> false); - if (!needReassign) { - int requiredShuffleServerNumber = - 
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); - int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf); - /** Before reassigning ShuffleServer, clear the ShuffleServer list in ShuffleWriteClient. */ - shuffleWriteClient.unregisterShuffle(id.get(), shuffleId); - Map> partitionToServers = - requestShuffleAssignment( - shuffleId, - numPartitions, - 1, - requiredShuffleServerNumber, - estimateTaskConcurrency, - failuresShuffleServerIds, - null); - /** - * we need to clear the metadata of the completed task, otherwise some of the stage's data - * will be lost - */ - try { - unregisterAllMapOutput(shuffleId); - } catch (SparkException e) { - LOG.error("Clear MapoutTracker Meta failed!"); - throw new RssException("Clear MapoutTracker Meta failed!", e); - } - MutableShuffleHandleInfo handleInfo = - new MutableShuffleHandleInfo(shuffleId, partitionToServers, getRemoteStorageInfo()); - shuffleHandleInfoManager.register(shuffleId, handleInfo); - serverAssignedInfos.put(stageIdAndAttempt, true); - return true; - } else { - LOG.info( - "The Stage:{} has been reassigned in an Attempt{},Return without performing any operation", - stageId, - stageAttemptNumber); - return false; - } - } - - /** this is only valid on driver side that exposed to being invoked by grpc server */ - @Override - public MutableShuffleHandleInfo reassignOnBlockSendFailure( - int shuffleId, Map> partitionToFailureServers) { - long startTime = System.currentTimeMillis(); - MutableShuffleHandleInfo handleInfo = - (MutableShuffleHandleInfo) shuffleHandleInfoManager.get(shuffleId); - synchronized (handleInfo) { - // If the reassignment servers for one partition exceeds the max reassign server num, - // it should fast fail. - handleInfo.checkPartitionReassignServerNum( - partitionToFailureServers.keySet(), partitionReassignMaxServerNum); - - Map> newServerToPartitions = new HashMap<>(); - // receivingFailureServer -> partitionId -> replacementServerIds. 
For logging - Map>> reassignResult = new HashMap<>(); - - for (Map.Entry> entry : - partitionToFailureServers.entrySet()) { - int partitionId = entry.getKey(); - for (ReceivingFailureServer receivingFailureServer : entry.getValue()) { - StatusCode code = receivingFailureServer.getStatusCode(); - String serverId = receivingFailureServer.getServerId(); - - boolean serverHasReplaced = false; - Set replacements = handleInfo.getReplacements(serverId); - if (CollectionUtils.isEmpty(replacements)) { - final int requiredServerNum = 1; - Set excludedServers = new HashSet<>(handleInfo.listExcludedServers()); - excludedServers.add(serverId); - replacements = - reassignServerForTask( - shuffleId, Sets.newHashSet(partitionId), excludedServers, requiredServerNum); - } else { - serverHasReplaced = true; - } - - Set updatedReassignServers = - handleInfo.updateAssignment(partitionId, serverId, replacements); - - reassignResult - .computeIfAbsent(serverId, x -> new HashMap<>()) - .computeIfAbsent(partitionId, x -> new HashSet<>()) - .addAll( - updatedReassignServers.stream().map(x -> x.getId()).collect(Collectors.toSet())); - - if (serverHasReplaced) { - for (ShuffleServerInfo serverInfo : updatedReassignServers) { - newServerToPartitions - .computeIfAbsent(serverInfo, x -> new ArrayList<>()) - .add(new PartitionRange(partitionId, partitionId)); - } - } - } - } - if (!newServerToPartitions.isEmpty()) { - LOG.info( - "Register the new partition->servers assignment on reassign. {}", - newServerToPartitions); - registerShuffleServers(id.get(), shuffleId, newServerToPartitions, getRemoteStorageInfo()); - } - - LOG.info( - "Finished reassignOnBlockSendFailure request and cost {}(ms). 
Reassign result: {}", - System.currentTimeMillis() - startTime, - reassignResult); - - return handleInfo; - } - } - - /** - * Creating the shuffleAssignmentInfo from the servers and partitionIds - * - * @param servers - * @param partitionIds - * @return - */ - private ShuffleAssignmentsInfo createShuffleAssignmentsInfo( - Set servers, Set partitionIds) { - Map> newPartitionToServers = new HashMap<>(); - List partitionRanges = new ArrayList<>(); - for (Integer partitionId : partitionIds) { - newPartitionToServers.put(partitionId, new ArrayList<>(servers)); - partitionRanges.add(new PartitionRange(partitionId, partitionId)); - } - Map> serverToPartitionRanges = new HashMap<>(); - for (ShuffleServerInfo server : servers) { - serverToPartitionRanges.put(server, partitionRanges); - } - return new ShuffleAssignmentsInfo(newPartitionToServers, serverToPartitionRanges); - } - - /** Request the new shuffle-servers to replace faulty server. */ - private Set reassignServerForTask( - int shuffleId, - Set partitionIds, - Set excludedServers, - int requiredServerNum) { - AtomicReference> replacementsRef = - new AtomicReference<>(new HashSet<>()); - requestShuffleAssignment( - shuffleId, - requiredServerNum, - 1, - requiredServerNum, - 1, - excludedServers, - shuffleAssignmentsInfo -> { - if (shuffleAssignmentsInfo == null) { - return null; - } - Set replacements = - shuffleAssignmentsInfo.getPartitionToServers().values().stream() - .flatMap(x -> x.stream()) - .collect(Collectors.toSet()); - replacementsRef.set(replacements); - return createShuffleAssignmentsInfo(replacements, partitionIds); - }); - return replacementsRef.get(); - } - - private Map> requestShuffleAssignment( - int shuffleId, - int partitionNum, - int partitionNumPerRange, - int assignmentShuffleServerNumber, - int estimateTaskConcurrency, - Set faultyServerIds, - Function reassignmentHandler) { - Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); - ClientUtils.validateClientType(clientType); 
- assignmentTags.add(clientType); - long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); - int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); - faultyServerIds.addAll(failuresShuffleServerIds); - try { - return RetryUtils.retry( - () -> { - ShuffleAssignmentsInfo response = - shuffleWriteClient.getShuffleAssignments( - id.get(), - shuffleId, - partitionNum, - partitionNumPerRange, - assignmentTags, - assignmentShuffleServerNumber, - estimateTaskConcurrency, - faultyServerIds); - LOG.info("Finished reassign"); - if (reassignmentHandler != null) { - response = reassignmentHandler.apply(response); - } - registerShuffleServers( - id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo()); - return response.getPartitionToServers(); - }, - retryInterval, - retryTimes); - } catch (Throwable throwable) { - throw new RssException("registerShuffle failed!", throwable); - } - } - - private RemoteStorageInfo getRemoteStorageInfo() { - String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()); - RemoteStorageInfo defaultRemoteStorage = - new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "")); - return ClientUtils.fetchRemoteStorage( - id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient); - } - - public boolean isRssResubmitStage() { - return rssResubmitStage; - } - @VisibleForTesting public void setDataPusher(DataPusher dataPusher) { this.dataPusher = dataPusher; diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index 0805dfe350..d0eab5ae1a 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -588,7 +588,8 @@ public void registerShuffle( List partitionRanges, RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite) {} + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber) {} @Override public boolean sendCommit( diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index 7d8f533924..efd39e35af 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -35,6 +35,16 @@ public interface ShuffleWriteClient { + default SendShuffleDataResult sendShuffleData( + String appId, + int stageAttemptNumber, + List shuffleBlockInfoList, + Supplier needCancelRequest) { + throw new UnsupportedOperationException( + this.getClass().getName() + + " doesn't implement getShuffleAssignments with faultyServerIds"); + } + SendShuffleDataResult sendShuffleData( String appId, List shuffleBlockInfoList, @@ -44,6 +54,25 @@ SendShuffleDataResult sendShuffleData( void registerApplicationInfo(String appId, long timeoutMs, String user); + default void registerShuffle( + ShuffleServerInfo shuffleServerInfo, + String appId, + int shuffleId, + List partitionRanges, + RemoteStorageInfo remoteStorage, + ShuffleDataDistributionType dataDistributionType, + int maxConcurrencyPerPartitionToWrite) { + registerShuffle( + shuffleServerInfo, + appId, + shuffleId, + partitionRanges, + remoteStorage, + dataDistributionType, + maxConcurrencyPerPartitionToWrite, + 0); + } + void registerShuffle( ShuffleServerInfo shuffleServerInfo, String appId, @@ -51,7 +80,8 @@ void registerShuffle( List partitionRanges, RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, - int 
maxConcurrencyPerPartitionToWrite); + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber); boolean sendCommit( Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps); diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index ed240c8877..27e0d76832 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -156,6 +156,7 @@ public ShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder builder) { private boolean sendShuffleDataAsync( String appId, + int stageAttemptNumber, Map>>> serverToBlocks, Map> serverToBlockIds, Map blockIdsSendSuccessTracker, @@ -185,7 +186,11 @@ private boolean sendShuffleDataAsync( // todo: compact unnecessary blocks that reach replicaWrite RssSendShuffleDataRequest request = new RssSendShuffleDataRequest( - appId, retryMax, retryIntervalMax, shuffleIdToBlocks); + appId, + stageAttemptNumber, + retryMax, + retryIntervalMax, + shuffleIdToBlocks); long s = System.currentTimeMillis(); RssSendShuffleDataResponse response = getShuffleServerClient(ssi).sendShuffleData(request); @@ -307,10 +312,19 @@ void genServerToBlocks( }); } + @Override + public SendShuffleDataResult sendShuffleData( + String appId, + List shuffleBlockInfoList, + Supplier needCancelRequest) { + return sendShuffleData(appId, 0, shuffleBlockInfoList, needCancelRequest); + } + /** The batch of sending belongs to the same task */ @Override public SendShuffleDataResult sendShuffleData( String appId, + int stageAttemptNumber, List shuffleBlockInfoList, Supplier needCancelRequest) { @@ -393,6 +407,7 @@ public SendShuffleDataResult sendShuffleData( boolean isAllSuccess = sendShuffleDataAsync( appId, + stageAttemptNumber, primaryServerToBlocks, primaryServerToBlockIds, blockIdsSendSuccessTracker, @@ -409,6 
+424,7 @@ public SendShuffleDataResult sendShuffleData( LOG.info("The sending of primary round is failed partially, so start the secondary round"); sendShuffleDataAsync( appId, + stageAttemptNumber, secondaryServerToBlocks, secondaryServerToBlockIds, blockIdsSendSuccessTracker, @@ -538,7 +554,8 @@ public void registerShuffle( List partitionRanges, RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber) { String user = null; try { user = UserGroupInformation.getCurrentUser().getShortUserName(); @@ -554,7 +571,8 @@ public void registerShuffle( remoteStorage, user, dataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + stageAttemptNumber); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java index a77b0d3c7a..9fefb98f6b 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/SendShuffleDataRequest.java @@ -30,6 +30,8 @@ public class SendShuffleDataRequest extends RequestMessage { private String appId; private int shuffleId; + + private int stageAttemptNumber; private long requireId; private Map> partitionToBlocks; private long timestamp; @@ -41,12 +43,24 @@ public SendShuffleDataRequest( long requireId, Map> partitionToBlocks, long timestamp) { + this(requestId, appId, shuffleId, 0, requireId, partitionToBlocks, timestamp); + } + + public SendShuffleDataRequest( + long requestId, + String appId, + int shuffleId, + int stageAttemptNumber, + long requireId, + Map> partitionToBlocks, + long timestamp) { super(requestId); this.appId = 
appId; this.shuffleId = shuffleId; this.requireId = requireId; this.partitionToBlocks = partitionToBlocks; this.timestamp = timestamp; + this.stageAttemptNumber = stageAttemptNumber; } @Override @@ -146,6 +160,10 @@ public void setTimestamp(long timestamp) { this.timestamp = timestamp; } + public int getStageAttemptNumber() { + return stageAttemptNumber; + } + @Override public String getOperationType() { return "sendShuffleData"; diff --git a/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java b/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java index 79e35ecabe..ff8ac231c0 100644 --- a/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java +++ b/common/src/main/java/org/apache/uniffle/common/rpc/StatusCode.java @@ -35,6 +35,7 @@ public enum StatusCode { ACCESS_DENIED(8), INVALID_REQUEST(9), NO_BUFFER_FOR_HUGE_PARTITION(10), + STAGE_RETRY_IGNORE(11), UNKNOWN(-1); static final Map VALUE_MAP = diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java index 45d570e779..f34d0a6cce 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java @@ -28,7 +28,7 @@ import org.apache.uniffle.client.request.RssReportShuffleResultRequest; import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; import org.apache.uniffle.client.response.RssGetShuffleResultResponse; -import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse; +import org.apache.uniffle.client.response.RssPartitionToShuffleServerWithStageRetryResponse; import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse; import org.apache.uniffle.client.response.RssReassignServersReponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; 
@@ -40,12 +40,23 @@ RssReportShuffleFetchFailureResponse reportShuffleFetchFailure( RssReportShuffleFetchFailureRequest request); /** - * Gets the mapping between partitions and ShuffleServer from the ShuffleManager server + * In Stage Retry mode,Gets the mapping between partitions and ShuffleServer from the + * ShuffleManager server. * * @param req request * @return RssPartitionToShuffleServerResponse */ - RssPartitionToShuffleServerResponse getPartitionToShufflerServer( + RssPartitionToShuffleServerWithStageRetryResponse getPartitionToShufflerServerWithStageRetry( + RssPartitionToShuffleServerRequest req); + + /** + * In Block Retry mode,Gets the mapping between partitions and ShuffleServer from the + * ShuffleManager server. + * + * @param req request + * @return RssPartitionToShuffleServerResponse + */ + RssReassignOnBlockSendFailureResponse getPartitionToShufflerServerWithBlockRetry( RssPartitionToShuffleServerRequest req); RssReportShuffleWriteFailureResponse reportShuffleWriteFailure( diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java index bebee89112..17a73162c3 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java @@ -32,7 +32,7 @@ import org.apache.uniffle.client.request.RssReportShuffleResultRequest; import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; import org.apache.uniffle.client.response.RssGetShuffleResultResponse; -import org.apache.uniffle.client.response.RssPartitionToShuffleServerResponse; +import org.apache.uniffle.client.response.RssPartitionToShuffleServerWithStageRetryResponse; import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse; import 
org.apache.uniffle.client.response.RssReassignServersReponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; @@ -90,14 +90,26 @@ public RssReportShuffleFetchFailureResponse reportShuffleFetchFailure( } @Override - public RssPartitionToShuffleServerResponse getPartitionToShufflerServer( + public RssPartitionToShuffleServerWithStageRetryResponse + getPartitionToShufflerServerWithStageRetry(RssPartitionToShuffleServerRequest req) { + RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto(); + RssProtos.PartitionToShuffleServerWithStageRetryResponse partitionToShufflerServer = + getBlockingStub().getPartitionToShufflerServerWithStageRetry(protoRequest); + RssPartitionToShuffleServerWithStageRetryResponse + rssPartitionToShuffleServerWithStageRetryResponse = + RssPartitionToShuffleServerWithStageRetryResponse.fromProto(partitionToShufflerServer); + return rssPartitionToShuffleServerWithStageRetryResponse; + } + + @Override + public RssReassignOnBlockSendFailureResponse getPartitionToShufflerServerWithBlockRetry( RssPartitionToShuffleServerRequest req) { RssProtos.PartitionToShuffleServerRequest protoRequest = req.toProto(); - RssProtos.PartitionToShuffleServerResponse partitionToShufflerServer = - getBlockingStub().getPartitionToShufflerServer(protoRequest); - RssPartitionToShuffleServerResponse rssPartitionToShuffleServerResponse = - RssPartitionToShuffleServerResponse.fromProto(partitionToShufflerServer); - return rssPartitionToShuffleServerResponse; + RssProtos.RssReassignOnBlockSendFailureResponse partitionToShufflerServer = + getBlockingStub().getPartitionToShufflerServerWithBlockRetry(protoRequest); + RssReassignOnBlockSendFailureResponse rssReassignOnBlockSendFailureResponse = + RssReassignOnBlockSendFailureResponse.fromProto(partitionToShufflerServer); + return rssReassignOnBlockSendFailureResponse; } @Override diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index f20cd85f51..55fbfadbd8 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -174,7 +174,8 @@ private ShuffleRegisterResponse doRegisterShuffle( RemoteStorageInfo remoteStorageInfo, String user, ShuffleDataDistributionType dataDistributionType, - int maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -182,7 +183,8 @@ private ShuffleRegisterResponse doRegisterShuffle( .setUser(user) .setShuffleDataDistribution(RssProtos.DataDistribution.valueOf(dataDistributionType.name())) .setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite) - .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)); + .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)) + .setStageAttemptNumber(stageAttemptNumber); RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder(); rsBuilder.setPath(remoteStorageInfo.getPath()); Map remoteStorageConf = remoteStorageInfo.getConfItems(); @@ -433,7 +435,8 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getRemoteStorageInfo(), request.getUser(), request.getDataDistributionType(), - request.getMaxConcurrencyPerPartitionToWrite()); + request.getMaxConcurrencyPerPartitionToWrite(), + request.getStageAttemptNumber()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); @@ -464,6 +467,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ String appId = request.getAppId(); Map>> 
shuffleIdToBlocks = request.getShuffleIdToBlocks(); + int stageAttemptNumber = request.getStageAttemptNumber(); boolean isSuccessful = true; AtomicReference failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR); @@ -528,6 +532,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ .setRequireBufferId(requireId) .addAllShuffleData(shuffleData) .setTimestamp(start) + .setStageAttemptNumber(stageAttemptNumber) .build(); SendShuffleDataResponse response = getBlockingStub().sendShuffleData(rpcRequest); if (LOG.isDebugEnabled()) { diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index ca06d555ed..fcb464ce15 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -108,6 +108,7 @@ public String getClientInfo() { public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest request) { Map>> shuffleIdToBlocks = request.getShuffleIdToBlocks(); + int stageAttemptNumber = request.getStageAttemptNumber(); boolean isSuccessful = true; AtomicReference failedStatusCode = new AtomicReference<>(StatusCode.INTERNAL_ERROR); @@ -128,6 +129,7 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ requestId(), request.getAppId(), shuffleId, + stageAttemptNumber, 0L, stb.getValue(), System.currentTimeMillis()); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java index 98fd012416..4cbdc4448d 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java +++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java @@ -64,6 +64,30 @@ public RssGetShuffleAssignmentsRequest( int assignmentShuffleServerNumber, int estimateTaskConcurrency, Set faultyServerIds) { + this( + appId, + shuffleId, + partitionNum, + partitionNumPerRange, + dataReplica, + requiredTags, + assignmentShuffleServerNumber, + estimateTaskConcurrency, + faultyServerIds, + 0); + } + + public RssGetShuffleAssignmentsRequest( + String appId, + int shuffleId, + int partitionNum, + int partitionNumPerRange, + int dataReplica, + Set requiredTags, + int assignmentShuffleServerNumber, + int estimateTaskConcurrency, + Set faultyServerIds, + int stageAttemptNumber) { this.appId = appId; this.shuffleId = shuffleId; this.partitionNum = partitionNum; diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 2cd49bb6d3..7e42be653e 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -35,6 +35,7 @@ public class RssRegisterShuffleRequest { private String user; private ShuffleDataDistributionType dataDistributionType; private int maxConcurrencyPerPartitionToWrite; + private int stageAttemptNumber; public RssRegisterShuffleRequest( String appId, @@ -44,6 +45,26 @@ public RssRegisterShuffleRequest( String user, ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite) { + this( + appId, + shuffleId, + partitionRanges, + remoteStorageInfo, + user, + dataDistributionType, + maxConcurrencyPerPartitionToWrite, + 0); + } + + public RssRegisterShuffleRequest( + String appId, + int shuffleId, + List partitionRanges, + RemoteStorageInfo remoteStorageInfo, + String user, + ShuffleDataDistributionType 
dataDistributionType, + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -51,6 +72,7 @@ public RssRegisterShuffleRequest( this.user = user; this.dataDistributionType = dataDistributionType; this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; + this.stageAttemptNumber = stageAttemptNumber; } public RssRegisterShuffleRequest( @@ -67,7 +89,8 @@ public RssRegisterShuffleRequest( remoteStorageInfo, user, dataDistributionType, - RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue()); + RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), + 0); } public RssRegisterShuffleRequest( @@ -79,7 +102,8 @@ public RssRegisterShuffleRequest( new RemoteStorageInfo(remoteStoragePath), StringUtils.EMPTY, ShuffleDataDistributionType.NORMAL, - RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue()); + RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), + 0); } public String getAppId() { @@ -109,4 +133,8 @@ public ShuffleDataDistributionType getDataDistributionType() { public int getMaxConcurrencyPerPartitionToWrite() { return maxConcurrencyPerPartitionToWrite; } + + public int getStageAttemptNumber() { + return stageAttemptNumber; + } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java index 8fbf18f29c..1b5fdcff80 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssSendShuffleDataRequest.java @@ -25,6 +25,7 @@ public class RssSendShuffleDataRequest { private String appId; + private int stageAttemptNumber; private int retryMax; private long retryIntervalMax; private Map>> shuffleIdToBlocks; @@ -34,10 +35,20 @@ public 
RssSendShuffleDataRequest( int retryMax, long retryIntervalMax, Map>> shuffleIdToBlocks) { + this(appId, 0, retryMax, retryIntervalMax, shuffleIdToBlocks); + } + + public RssSendShuffleDataRequest( + String appId, + int stageAttemptNumber, + int retryMax, + long retryIntervalMax, + Map>> shuffleIdToBlocks) { this.appId = appId; this.retryMax = retryMax; this.retryIntervalMax = retryIntervalMax; this.shuffleIdToBlocks = shuffleIdToBlocks; + this.stageAttemptNumber = stageAttemptNumber; } public String getAppId() { @@ -52,6 +63,10 @@ public long getRetryIntervalMax() { return retryIntervalMax; } + public int getStageAttemptNumber() { + return stageAttemptNumber; + } + public Map>> getShuffleIdToBlocks() { return shuffleIdToBlocks; } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerWithStageRetryResponse.java similarity index 68% rename from internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java rename to internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerWithStageRetryResponse.java index 9daa002ed2..0f150eb8e1 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerResponse.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssPartitionToShuffleServerWithStageRetryResponse.java @@ -20,24 +20,24 @@ import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.proto.RssProtos; -public class RssPartitionToShuffleServerResponse extends ClientResponse { - private RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto; +public class RssPartitionToShuffleServerWithStageRetryResponse extends ClientResponse { + private RssProtos.StageAttemptShuffleHandleInfo shuffleHandleInfoProto; - public RssPartitionToShuffleServerResponse( + public 
RssPartitionToShuffleServerWithStageRetryResponse( StatusCode statusCode, String message, - RssProtos.MutableShuffleHandleInfo shuffleHandleInfoProto) { + RssProtos.StageAttemptShuffleHandleInfo shuffleHandleInfoProto) { super(statusCode, message); this.shuffleHandleInfoProto = shuffleHandleInfoProto; } - public RssProtos.MutableShuffleHandleInfo getShuffleHandleInfoProto() { + public RssProtos.StageAttemptShuffleHandleInfo getShuffleHandleInfoProto() { return shuffleHandleInfoProto; } - public static RssPartitionToShuffleServerResponse fromProto( - RssProtos.PartitionToShuffleServerResponse response) { - return new RssPartitionToShuffleServerResponse( + public static RssPartitionToShuffleServerWithStageRetryResponse fromProto( + RssProtos.PartitionToShuffleServerWithStageRetryResponse response) { + return new RssPartitionToShuffleServerWithStageRetryResponse( StatusCode.valueOf(response.getStatus().name()), response.getMsg(), response.getShuffleHandleInfo()); diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 5b3d9e1331..df6bb30562 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -184,6 +184,7 @@ message ShuffleRegisterRequest { string user = 5; DataDistribution shuffleDataDistribution = 6; int32 maxConcurrencyPerPartitionToWrite = 7; + int32 stageAttemptNumber = 8; } enum DataDistribution { @@ -221,6 +222,7 @@ message SendShuffleDataRequest { int64 requireBufferId = 3; repeated ShuffleData shuffleData = 4; int64 timestamp = 5; + int32 stageAttemptNumber = 6; } message SendShuffleDataResponse { @@ -305,6 +307,7 @@ enum StatusCode { ACCESS_DENIED = 8; INVALID_REQUEST = 9; NO_BUFFER_FOR_HUGE_PARTITION = 10; + STAGE_RETRY_IGNORE = 11; // add more status } @@ -528,8 +531,10 @@ message CancelDecommissionResponse { // per application. 
service ShuffleManager { rpc reportShuffleFetchFailure (ReportShuffleFetchFailureRequest) returns (ReportShuffleFetchFailureResponse); - // Gets the mapping between partitions and ShuffleServer from the ShuffleManager server - rpc getPartitionToShufflerServer(PartitionToShuffleServerRequest) returns (PartitionToShuffleServerResponse); + // Gets the mapping between partitions and ShuffleServer from the ShuffleManager server on Stage Retry. + rpc getPartitionToShufflerServerWithStageRetry(PartitionToShuffleServerRequest) returns (PartitionToShuffleServerWithStageRetryResponse); + // Gets the mapping between partitions and ShuffleServer from the ShuffleManager server on Block Retry. + rpc getPartitionToShufflerServerWithBlockRetry(PartitionToShuffleServerRequest) returns (RssReassignOnBlockSendFailureResponse); // Report write failures to ShuffleManager rpc reportShuffleWriteFailure (ReportShuffleWriteFailureRequest) returns (ReportShuffleWriteFailureResponse); // Reassign the RPC interface of the ShuffleServer list @@ -563,10 +568,15 @@ message PartitionToShuffleServerRequest { int32 shuffleId = 2; } -message PartitionToShuffleServerResponse { +message PartitionToShuffleServerWithStageRetryResponse { StatusCode status = 1; string msg = 2; - MutableShuffleHandleInfo shuffleHandleInfo = 3; + StageAttemptShuffleHandleInfo shuffleHandleInfo = 3; +} + +message StageAttemptShuffleHandleInfo { + repeated MutableShuffleHandleInfo historyMutableShuffleHandleInfo= 1; + MutableShuffleHandleInfo currentMutableShuffleHandleInfo = 2; } message MutableShuffleHandleInfo { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 7e1eb88cfc..8ed8d21093 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -157,6 +157,36 @@ public void 
registerShuffle( int shuffleId = req.getShuffleId(); String remoteStoragePath = req.getRemoteStorage().getPath(); String user = req.getUser(); + int stageAttemptNumber = req.getStageAttemptNumber(); + + if (stageAttemptNumber > 0) { + ShuffleTaskInfo taskInfo = shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId); + int attemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId); + if (stageAttemptNumber > attemptNumber) { + taskInfo.refreshLatestStageAttemptNumber(shuffleId, stageAttemptNumber); + try { + long start = System.currentTimeMillis(); + shuffleServer.getShuffleTaskManager().removeShuffleDataSync(appId, shuffleId); + LOG.info( + "Deleted the previous stage attempt data due to stage recomputing for app: {}, " + + "shuffleId: {}. It costs {} ms", + appId, + shuffleId, + System.currentTimeMillis() - start); + } catch (Exception e) { + LOG.error( + "Errors on clearing previous stage attempt data for app: {}, shuffleId: {}", + appId, + shuffleId, + e); + StatusCode code = StatusCode.INTERNAL_ERROR; + reply = ShuffleRegisterResponse.newBuilder().setStatus(code.toProto()).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + } + } ShuffleDataDistributionType shuffleDataDistributionType = ShuffleDataDistributionType.valueOf( @@ -210,6 +240,22 @@ public void sendShuffleData( int shuffleId = req.getShuffleId(); long requireBufferId = req.getRequireBufferId(); long timestamp = req.getTimestamp(); + int stageAttemptNumber = req.getStageAttemptNumber(); + ShuffleTaskInfo taskInfo = shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId); + Integer latestStageAttemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId); + // The Stage retry occurred, and the task before StageNumber was simply ignored and not + // processed if the task was being sent. 
+ if (stageAttemptNumber < latestStageAttemptNumber) { + String responseMessage = "A retry has occurred at the Stage, sending data is invalid."; + reply = + SendShuffleDataResponse.newBuilder() + .setStatus(StatusCode.STAGE_RETRY_IGNORE.toProto()) + .setRetMsg(responseMessage) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } if (timestamp > 0) { /* * Here we record the transport time, but we don't consider the impact of data size on transport time. diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java index b6806f6343..bbbfded017 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java @@ -66,6 +66,8 @@ public class ShuffleTaskInfo { private final Map> partitionBlockCounters; + private final Map latestStageAttemptNumbers; + public ShuffleTaskInfo(String appId) { this.appId = appId; this.currentTimes = System.currentTimeMillis(); @@ -78,6 +80,7 @@ public ShuffleTaskInfo(String appId) { this.existHugePartition = new AtomicBoolean(false); this.specification = new AtomicReference<>(); this.partitionBlockCounters = JavaUtils.newConcurrentMap(); + this.latestStageAttemptNumbers = JavaUtils.newConcurrentMap(); } public Long getCurrentTimes() { @@ -220,6 +223,14 @@ public long getBlockNumber(int shuffleId, int partitionId) { return counter.get(); } + public Integer getLatestStageAttemptNumber(int shuffleId) { + return latestStageAttemptNumbers.computeIfAbsent(shuffleId, key -> 0); + } + + public void refreshLatestStageAttemptNumber(int shuffleId, int stageAttemptNumber) { + latestStageAttemptNumbers.put(shuffleId, stageAttemptNumber); + } + @Override public String toString() { return "ShuffleTaskInfo{" diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index a98eac26c2..8fe597d03a 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -910,7 +910,7 @@ public void removeShuffleDataAsync(String appId) { } @VisibleForTesting - void removeShuffleDataSync(String appId, int shuffleId) { + public void removeShuffleDataSync(String appId, int shuffleId) { removeResourcesByShuffleIds(appId, Arrays.asList(shuffleId)); } diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index 668b53cba0..448c12a237 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -57,6 +57,7 @@ import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.ShuffleServerMetrics; +import org.apache.uniffle.server.ShuffleTaskInfo; import org.apache.uniffle.server.ShuffleTaskManager; import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo; import org.apache.uniffle.server.buffer.ShuffleBufferManager; @@ -102,6 +103,18 @@ public void handleSendShuffleDataRequest(TransportClient client, SendShuffleData int shuffleId = req.getShuffleId(); long requireBufferId = req.getRequireId(); long timestamp = req.getTimestamp(); + int stageAttemptNumber = req.getStageAttemptNumber(); + ShuffleTaskInfo taskInfo = shuffleServer.getShuffleTaskManager().getShuffleTaskInfo(appId); + Integer latestStageAttemptNumber = taskInfo.getLatestStageAttemptNumber(shuffleId); + // The Stage retry occurred, and the task before StageNumber was simply ignored and not + // processed if the task was being sent. 
+ if (stageAttemptNumber < latestStageAttemptNumber) { + String responseMessage = "A retry has occurred at the Stage, sending data is invalid."; + rpcResponse = + new RpcResponse(req.getRequestId(), StatusCode.STAGE_RETRY_IGNORE, responseMessage); + client.getChannel().writeAndFlush(rpcResponse); + return; + } if (timestamp > 0) { /* * Here we record the transport time, but we don't consider the impact of data size on transport time.