diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index 5a93c2b117..9751ba0b89 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -34,14 +34,6 @@ public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList)
     this.processedCallbackChain = new ArrayList<>();
   }
 
-  public AddBlockEvent(
-      String taskId, List<ShuffleBlockInfo> shuffleBlockInfoList, Runnable callback) {
-    this.taskId = taskId;
-    this.shuffleDataInfoList = shuffleBlockInfoList;
-    this.processedCallbackChain = new ArrayList<>();
-    addCallback(callback);
-  }
-
   /** @param callback, should not throw any exception and execute fast. */
   public void addCallback(Runnable callback) {
     processedCallbackChain.add(callback);
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
new file mode 100644
index 0000000000..116d1945de
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockFailureCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public interface BlockFailureCallback {
+  void onBlockFailure(ShuffleBlockInfo block);
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
new file mode 100644
index 0000000000..2b5dc0d09f
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BlockSuccessCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+public interface BlockSuccessCallback {
+  void onBlockSuccess(ShuffleBlockInfo block);
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 30f649f688..1517b7173c 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -88,14 +88,23 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
         () -> {
           String taskId = event.getTaskId();
           List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
+          SendShuffleDataResult result = null;
           try {
-            SendShuffleDataResult result =
+            result =
                 shuffleWriteClient.sendShuffleData(
                     rssAppId, shuffleBlockInfoList, () -> !isValidTask(taskId));
             putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
             putFailedBlockSendTracker(
                 taskToFailedBlockSendTracker, taskId, result.getFailedBlockSendTracker());
           } finally {
+            Set<Long> succeedBlockIds =
+                result == null || result.getSuccessBlockIds() == null
+                    ? Collections.emptySet()
+                    : result.getSuccessBlockIds();
+            for (ShuffleBlockInfo block : shuffleBlockInfoList) {
+              block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
+            }
+
             List<Runnable> callbackChain =
                 Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
             for (Runnable runnable : callbackChain) {
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index d8261047fc..efe376a344 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -408,14 +408,18 @@ private void requestExecutorMemory(long leastMem) {
     }
   }
 
+  public void releaseBlockResource(ShuffleBlockInfo block) {
+    this.freeAllocatedMemory(block.getFreeMemory());
+    block.getData().release();
+  }
+
   public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
     long totalSize = 0;
-    long memoryUsed = 0;
     List<AddBlockEvent> events = new ArrayList<>();
     List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
     for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
+      sbi.withCompletionCallback((block, isSuccessful) -> this.releaseBlockResource(block));
       totalSize += sbi.getSize();
-      memoryUsed += sbi.getFreeMemory();
       shuffleBlockInfosPerEvent.add(sbi);
       // split shuffle data according to the size
       if (totalSize > sendSizeLimit) {
@@ -427,20 +431,9 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
                   + totalSize
                   + " bytes");
         }
-        // Use final temporary variables for closures
-        final long memoryUsedTemp = memoryUsed;
-        final List<ShuffleBlockInfo> shuffleBlocksTemp = shuffleBlockInfosPerEvent;
-        events.add(
-            new AddBlockEvent(
-                taskId,
-                shuffleBlockInfosPerEvent,
-                () -> {
-                  freeAllocatedMemory(memoryUsedTemp);
-                  shuffleBlocksTemp.stream().forEach(x -> x.getData().release());
-                }));
+        events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
         shuffleBlockInfosPerEvent = Lists.newArrayList();
         totalSize = 0;
-        memoryUsed = 0;
       }
     }
     if (!shuffleBlockInfosPerEvent.isEmpty()) {
@@ -453,16 +446,7 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
               + " bytes");
       }
       // Use final temporary variables for closures
-      final long memoryUsedTemp = memoryUsed;
-      final List<ShuffleBlockInfo> shuffleBlocksTemp = shuffleBlockInfosPerEvent;
-      events.add(
-          new AddBlockEvent(
-              taskId,
-              shuffleBlockInfosPerEvent,
-              () -> {
-                freeAllocatedMemory(memoryUsedTemp);
-                shuffleBlocksTemp.stream().forEach(x -> x.getData().release());
-              }));
+      events.add(new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
     }
     return events;
   }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 38ebbbd372..22143bc0e9 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -371,6 +371,9 @@ public void spillByOwnTest() {
     long sum = 0L;
     List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
     for (AddBlockEvent event : events) {
+      for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+        block.executeCompletionCallback(true);
+      }
       event.getProcessedCallbackChain().stream().forEach(x -> x.run());
       sum += event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
     }
@@ -413,6 +416,9 @@ public void spillByOwnTest() {
           // ignore.
         }
       }
+      for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+        block.executeCompletionCallback(true);
+      }
       event.getProcessedCallbackChain().stream().forEach(x -> x.run());
       sum +=
           event.getShuffleDataInfoList().stream()
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0b4faef827..1b4df17478 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -1264,4 +1264,9 @@ private RemoteStorageInfo getRemoteStorageInfo() {
   public boolean isRssResubmitStage() {
     return rssResubmitStage;
   }
+
+  @VisibleForTesting
+  public void setDataPusher(DataPusher dataPusher) {
+    this.dataPusher = dataPusher;
+  }
 }
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 635b3593a2..8a22b73ba5 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -21,10 +21,10 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
@@ -46,6 +46,7 @@
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
@@ -83,6 +84,7 @@
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
 public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
@@ -94,7 +96,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
   private final String appId;
   private final int shuffleId;
   private WriteBufferManager bufferManager;
-  private final String taskId;
+  private String taskId;
   private final int numMaps;
   private final ShuffleDependency<K, V, C> shuffleDependency;
   private final Partitioner partitioner;
@@ -113,7 +115,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
   private SparkConf sparkConf;
-  private boolean blockSendFailureRetryEnabled;
+  private boolean blockFailSentRetryEnabled;
+  private int blockFailSentRetryMaxTimes = 1;
 
   /** used by columnar rss shuffle writer implementation */
   protected final long taskAttemptId;
@@ -122,7 +125,9 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<>();
 
-  private final Map<String, ShuffleServerInfo> faultyServers = new HashMap<>();
+  // shuffleServerId -> failoverShuffleServer
+  private final Map<String, ShuffleServerInfo> replacementShuffleServers =
+      JavaUtils.newConcurrentMap();
 
   // Only for tests
   @VisibleForTesting
@@ -192,7 +197,7 @@ private RssShuffleWriter(
     this.taskFailureCallback = taskFailureCallback;
     this.taskContext = context;
     this.sparkConf = sparkConf;
-    this.blockSendFailureRetryEnabled =
+    this.blockFailSentRetryEnabled =
         sparkConf.getBoolean(
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
                 + RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.key(),
@@ -269,8 +274,8 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
     long recordCount = 0;
     while (records.hasNext()) {
       recordCount++;
-      // Task should fast fail when sending data failed
-      checkIfBlocksFailed();
+      // fail fast, or collect the failed blocks for resending
+      checkDataIfAnyFailure();
 
       Product2<K, V> record = records.next();
       K key = record._1();
@@ -363,6 +368,17 @@ protected List<CompletableFuture<Long>> postBlockEvent(
       List<ShuffleBlockInfo> shuffleBlockInfoList) {
     List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      if (blockFailSentRetryEnabled) {
+        // on failure, do nothing: the failed block is kept for resending
+        for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+          block.withCompletionCallback(
+              (completionBlock, isSuccessful) -> {
+                if (isSuccessful) {
+                  bufferManager.releaseBlockResource(completionBlock);
+                }
+              });
+        }
+      }
       event.addCallback(
           () -> {
             boolean ret = finishEventQueue.add(new Object());
@@ -386,7 +402,7 @@ protected void checkBlockSendResult(Set<Long> blockIds) {
     while (true) {
       try {
         finishEventQueue.clear();
-        checkIfBlocksFailed();
+        checkDataIfAnyFailure();
         Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
         blockIds.removeAll(successBlockIds);
         if (blockIds.isEmpty()) {
@@ -422,105 +438,128 @@ protected void checkBlockSendResult(Set<Long> blockIds) {
     }
   }
 
-  private void checkIfBlocksFailed() {
-    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
-    if (blockSendFailureRetryEnabled && !failedBlockIds.isEmpty()) {
-      Set<TrackingBlockStatus> shouldResendBlockSet = shouldResendBlockStatusSet(failedBlockIds);
-      try {
-        reSendFailedBlockIds(shouldResendBlockSet);
-      } catch (Exception e) {
-        LOG.error("resend failed blocks failed.", e);
+  private void checkDataIfAnyFailure() {
+    if (blockFailSentRetryEnabled) {
+      collectFailedBlocksToResend();
+    } else {
+      if (hasAnyBlockFailure()) {
+        throw new RssSendFailedException("Fail to send the block");
       }
-      failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
     }
+  }
+
+  private boolean hasAnyBlockFailure() {
+    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
     if (!failedBlockIds.isEmpty()) {
-      String errorMsg =
-          "Send failed: Task["
-              + taskId
-              + "]"
-              + " failed because "
-              + failedBlockIds.size()
-              + " blocks can't be sent to shuffle server: "
-              + shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers();
-      LOG.error(errorMsg);
-      throw new RssSendFailedException(errorMsg);
+      LOG.error(
+          "Errors on sending blocks for task[{}]. {} blocks can't be sent to remote servers: {}",
+          taskId,
+          failedBlockIds.size(),
+          shuffleManager.getBlockIdsFailedSendTracker(taskId).getFaultyShuffleServers());
+      return true;
     }
+    return false;
   }
 
-  private Set<TrackingBlockStatus> shouldResendBlockStatusSet(Set<Long> failedBlockIds) {
-    FailedBlockSendTracker failedBlockTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId);
-    Set<TrackingBlockStatus> resendBlockStatusSet = Sets.newHashSet();
-    for (Long failedBlockId : failedBlockIds) {
-      failedBlockTracker.getFailedBlockStatus(failedBlockId).stream()
-          // todo: more status need reassign
-          .filter(
-              trackingBlockStatus -> trackingBlockStatus.getStatusCode() == StatusCode.NO_BUFFER)
-          .forEach(trackingBlockStatus -> resendBlockStatusSet.add(trackingBlockStatus));
+  private void collectFailedBlocksToResend() {
+    if (!blockFailSentRetryEnabled) {
+      return;
+    }
+
+    FailedBlockSendTracker failedTracker = shuffleManager.getBlockIdsFailedSendTracker(taskId);
+    Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
+    if (CollectionUtils.isEmpty(failedBlockIds)) {
+      return;
+    }
+
+    boolean isFastFail = false;
+    Set<TrackingBlockStatus> resendCandidates = new HashSet<>();
+    // check whether any block has exceeded the max resend count
+    for (Long blockId : failedBlockIds) {
+      List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
+      int retryIndex = failedBlockStatus.get(0).getShuffleBlockInfo().getRetryCnt();
+      // todo: make the max retry times configurable
+      if (retryIndex >= blockFailSentRetryMaxTimes) {
+        LOG.error(
+            "Partial blocks for taskId: [{}] exceeded the max retry times: [{}]. Fast fail! Faulty server list: {}",
+            taskId,
+            blockFailSentRetryMaxTimes,
+            failedBlockStatus.stream()
+                .map(x -> x.getShuffleServerInfo())
+                .collect(Collectors.toSet()));
+        isFastFail = true;
+        break;
+      }
+
+      // todo: with multiple replicas, a block already sent successfully elsewhere need not be resent
+      resendCandidates.addAll(failedBlockStatus);
     }
-    return resendBlockStatusSet;
+
+    if (isFastFail) {
+      // release the block data and the allocated memory
+      for (Long blockId : failedBlockIds) {
+        List<TrackingBlockStatus> failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
+        Optional<TrackingBlockStatus> blockStatus = failedBlockStatus.stream().findFirst();
+        if (blockStatus.isPresent()) {
+          blockStatus.get().getShuffleBlockInfo().executeCompletionCallback(true);
+        }
+      }
+
+      throw new RssSendFailedException(
+          "Errors on resending the blocks data to the remote shuffle-server.");
+    }
+
+    resendFailedBlocks(resendCandidates);
   }
 
-  private void reSendFailedBlockIds(Set<TrackingBlockStatus> failedBlockStatusSet) {
-    List<ShuffleBlockInfo> reAssignSeverBlockInfoList = Lists.newArrayList();
-    List<ShuffleBlockInfo> failedBlockInfoList = Lists.newArrayList();
+  private void resendFailedBlocks(Set<TrackingBlockStatus> failedBlockStatusSet) {
+    List<ShuffleBlockInfo> reassignBlocks = Lists.newArrayList();
     Map<ShuffleServerInfo, List<TrackingBlockStatus>> faultyServerToPartitions =
         failedBlockStatusSet.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo()));
-    faultyServerToPartitions.entrySet().stream()
-        .forEach(
-            t -> {
-              Set<Integer> partitionIds =
-                  t.getValue().stream()
-                      .map(x -> x.getShuffleBlockInfo().getPartitionId())
-                      .collect(Collectors.toSet());
-              ShuffleServerInfo dynamicShuffleServer = faultyServers.get(t.getKey().getId());
-              if (dynamicShuffleServer == null) {
-                dynamicShuffleServer =
-                    reAssignFaultyShuffleServer(partitionIds, t.getKey().getId());
-                faultyServers.put(t.getKey().getId(), dynamicShuffleServer);
-              }
-
-              ShuffleServerInfo finalDynamicShuffleServer = dynamicShuffleServer;
-              failedBlockStatusSet.forEach(
-                  trackingBlockStatus -> {
-                    ShuffleBlockInfo failedBlockInfo = trackingBlockStatus.getShuffleBlockInfo();
-                    failedBlockInfoList.add(failedBlockInfo);
-                    reAssignSeverBlockInfoList.add(
-                        new ShuffleBlockInfo(
-                            failedBlockInfo.getShuffleId(),
-                            failedBlockInfo.getPartitionId(),
-                            failedBlockInfo.getBlockId(),
-                            failedBlockInfo.getLength(),
-                            failedBlockInfo.getCrc(),
-                            failedBlockInfo.getData(),
-                            Lists.newArrayList(finalDynamicShuffleServer),
-                            failedBlockInfo.getUncompressLength(),
-                            failedBlockInfo.getFreeMemory(),
-                            taskAttemptId));
-                  });
-            });
-    clearFailedBlockIdsStates(failedBlockInfoList, faultyServers);
-    processShuffleBlockInfos(reAssignSeverBlockInfoList);
-    checkIfBlocksFailed();
+
+    for (Map.Entry<ShuffleServerInfo, List<TrackingBlockStatus>> entry :
+        faultyServerToPartitions.entrySet()) {
+      Set<Integer> partitionIds =
+          entry.getValue().stream()
+              .map(x -> x.getShuffleBlockInfo().getPartitionId())
+              .collect(Collectors.toSet());
+      ShuffleServerInfo replacement = replacementShuffleServers.get(entry.getKey().getId());
+      if (replacement == null) {
+        // todo: merge multiple requests into one
+        replacement = reassignFaultyShuffleServer(partitionIds, entry.getKey().getId());
+        replacementShuffleServers.put(entry.getKey().getId(), replacement);
+      }
+
+      for (TrackingBlockStatus blockStatus : failedBlockStatusSet) {
+        // clear the previous retry state of the block
+        ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
+        clearFailedBlockState(block);
+
+        final ShuffleBlockInfo newBlock = block;
+        newBlock.incrRetryCnt();
+        newBlock.reassignShuffleServers(Arrays.asList(replacement));
+
+        reassignBlocks.add(newBlock);
+      }
+    }
+
+    processShuffleBlockInfos(reassignBlocks);
   }
 
-  private void clearFailedBlockIdsStates(
-      List<ShuffleBlockInfo> failedBlockInfoList, Map<String, ShuffleServerInfo> faultyServers) {
-    failedBlockInfoList.forEach(
-        shuffleBlockInfo -> {
-          shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(shuffleBlockInfo.getBlockId());
-          shuffleBlockInfo.getShuffleServerInfos().stream()
-              .filter(s -> faultyServers.containsKey(s.getId()))
-              .forEach(
-                  s ->
-                      serverToPartitionToBlockIds
-                          .get(s)
-                          .get(shuffleBlockInfo.getPartitionId())
-                          .remove(shuffleBlockInfo.getBlockId()));
-          partitionLengths[shuffleBlockInfo.getPartitionId()] -= shuffleBlockInfo.getLength();
-        });
+  private void clearFailedBlockState(ShuffleBlockInfo block) {
+    shuffleManager.getBlockIdsFailedSendTracker(taskId).remove(block.getBlockId());
+    block.getShuffleServerInfos().stream()
+        .filter(s -> replacementShuffleServers.containsKey(s.getId()))
+        .forEach(
+            s ->
+                serverToPartitionToBlockIds
+                    .get(s)
+                    .get(block.getPartitionId())
+                    .remove(block.getBlockId()));
+    partitionLengths[block.getPartitionId()] -= block.getLength();
   }
 
-  private ShuffleServerInfo reAssignFaultyShuffleServer(
+  private ShuffleServerInfo reassignFaultyShuffleServer(
       Set<Integer> partitionIds, String faultyServerId) {
     RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
     String driver = rssConf.getString("driver.host", "");
@@ -611,6 +650,17 @@ public Option<MapStatus> stop(boolean success) {
         return Option.empty();
       }
     } finally {
+      if (blockFailSentRetryEnabled) {
+        if (success) {
+          if (CollectionUtils.isNotEmpty(shuffleManager.getFailedBlockIds(taskId))) {
+            LOG.error(
+                "Errors on stopping writer due to the remaining failed blockIds. This should not happen.");
+            return Option.empty();
+          }
+        } else {
+          shuffleManager.getBlockIdsFailedSendTracker(taskId).clearAndReleaseBlockResources();
+        }
+      }
       // free all memory & metadata, or memory leak happen in executor
       if (bufferManager != null) {
         bufferManager.freeAllMemory();
@@ -694,4 +744,29 @@ private void throwFetchFailedIfNecessary(Exception e) {
     }
     throw new RssException(e);
   }
+
+  @VisibleForTesting
+  protected void enableBlockFailSentRetry() {
+    this.blockFailSentRetryEnabled = true;
+  }
+
+  @VisibleForTesting
+  protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes) {
+    this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
+  }
+
+  @VisibleForTesting
+  protected void addReassignmentShuffleServer(String shuffleServerId, ShuffleServerInfo replacement) {
+    replacementShuffleServers.put(shuffleServerId, replacement);
+  }
+
+  @VisibleForTesting
+  protected void setTaskId(String taskId) {
+    this.taskId = taskId;
+  }
+
+  @VisibleForTesting
+  protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> getServerToPartitionToBlockIds() {
+    return serverToPartitionToBlockIds;
+  }
 }
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index b68d4b74e1..5ca85eced6 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -26,6 +26,7 @@
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -64,6 +65,7 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
@@ -73,6 +75,198 @@
 
 public class RssShuffleWriterTest {
 
+  private MutableList<Tuple2<String, String>> createMockRecords() {
+    MutableList<Tuple2<String, String>> data = new MutableList<>();
+    data.appendElem(new Tuple2<>("testKey2", "testValue2"));
+    data.appendElem(new Tuple2<>("testKey3", "testValue3"));
+    data.appendElem(new Tuple2<>("testKey4", "testValue4"));
+    data.appendElem(new Tuple2<>("testKey6", "testValue6"));
+    data.appendElem(new Tuple2<>("testKey1", "testValue1"));
+    data.appendElem(new Tuple2<>("testKey5", "testValue5"));
+    return data;
+  }
+
+  @Test
+  public void blockFailureResendTest() throws Exception {
+    SparkConf conf = new SparkConf();
+    conf.setAppName("testApp")
+        .setMaster("local[2]")
+        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
+
+    List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
+    Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+    taskToFailedBlockSendTracker.put("taskId", new FailedBlockSendTracker());
+
+    AtomicInteger sentFailureCnt = new AtomicInteger();
+    FakedDataPusher dataPusher =
+        new FakedDataPusher(
+            event -> {
+              assertEquals("taskId", event.getTaskId());
+              FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(event.getTaskId());
+              for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                boolean isSuccessful = true;
+                ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
+                if (shuffleServer.getId().equals("id1") && block.getRetryCnt() == 0) {
+                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
+                  sentFailureCnt.addAndGet(1);
+                  isSuccessful = false;
+                } else {
+                  successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
+                  successBlockIds.get(event.getTaskId()).add(block.getBlockId());
+                  shuffleBlockInfos.add(block);
+                }
+                block.executeCompletionCallback(isSuccessful);
+              }
+              return new CompletableFuture<>();
+            });
+
+    final RssShuffleManager manager =
+        TestUtils.createShuffleManager(
+            conf, false, dataPusher, successBlockIds, taskToFailedBlockSendTracker);
+    Serializer kryoSerializer = new KryoSerializer(conf);
+    Partitioner mockPartitioner = mock(Partitioner.class);
+    final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class);
+    RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class);
+    when(mockHandle.getDependency()).thenReturn(mockDependency);
+    when(mockDependency.serializer()).thenReturn(kryoSerializer);
+    when(mockDependency.partitioner()).thenReturn(mockPartitioner);
+    when(mockPartitioner.numPartitions()).thenReturn(3);
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
+    List<ShuffleServerInfo> ssi12 =
+        Arrays.asList(
+            new ShuffleServerInfo("id1", "0.0.0.1", 100),
+            new ShuffleServerInfo("id2", "0.0.0.2", 100));
+    partitionToServers.put(0, ssi12);
+    List<ShuffleServerInfo> ssi34 =
+        Arrays.asList(
+            new ShuffleServerInfo("id3", "0.0.0.3", 100),
+            new ShuffleServerInfo("id4", "0.0.0.4", 100));
+    partitionToServers.put(1, ssi34);
+    List<ShuffleServerInfo> ssi56 =
+        Arrays.asList(
+            new ShuffleServerInfo("id5", "0.0.0.5", 100),
+            new ShuffleServerInfo("id6", "0.0.0.6", 100));
+    partitionToServers.put(2, ssi56);
+    when(mockPartitioner.getPartition("testKey1")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey2")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey4")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey5")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey3")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey7")).thenReturn(0);
+    when(mockPartitioner.getPartition("testKey8")).thenReturn(1);
+    when(mockPartitioner.getPartition("testKey9")).thenReturn(2);
+    when(mockPartitioner.getPartition("testKey6")).thenReturn(2);
+
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
+    WriteBufferManager bufferManager =
+        new WriteBufferManager(
+            0,
+            0,
+            bufferOptions,
+            kryoSerializer,
+            partitionToServers,
+            mockTaskMemoryManager,
+            shuffleWriteMetrics,
+            RssSparkConfig.toRssConf(conf));
+    bufferManager.setTaskId("taskId");
+
+    WriteBufferManager bufferManagerSpy = spy(bufferManager);
+    TaskContext contextMock = mock(TaskContext.class);
+    ShuffleHandleInfo mockShuffleHandleInfo = mock(ShuffleHandleInfo.class);
+    RssShuffleWriter<String, String, String> rssShuffleWriter =
+        new RssShuffleWriter<>(
+            "appId",
+            0,
+            "taskId",
+            1L,
+            bufferManagerSpy,
+            shuffleWriteMetrics,
+            manager,
+            conf,
+            mockShuffleWriteClient,
+            mockHandle,
+            mockShuffleHandleInfo,
+            contextMock);
+    rssShuffleWriter.enableBlockFailSentRetry();
+    doReturn(100000L).when(bufferManagerSpy).acquireMemory(anyLong());
+    ShuffleServerInfo replacement = new ShuffleServerInfo("id10", "0.0.0.10", 100);
+    rssShuffleWriter.addReassignmentShuffleServer("id1", replacement);
+
+    RssShuffleWriter<String, String, String> rssShuffleWriterSpy = spy(rssShuffleWriter);
+    doNothing().when(rssShuffleWriterSpy).sendCommit();
+
+    // case 1. failed blocks will be resent
+    MutableList<Tuple2<String, String>> data = createMockRecords();
+    rssShuffleWriterSpy.write(data.iterator());
+
+    Awaitility.await()
+        .timeout(Duration.ofSeconds(5))
+        .until(() -> successBlockIds.get("taskId").size() == data.size());
+    assertEquals(2, sentFailureCnt.get());
+    assertEquals(0, taskToFailedBlockSendTracker.get("taskId").getFailedBlockIds().size());
+    assertEquals(6, shuffleWriteMetrics.recordsWritten());
+    assertEquals(
+        shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+        shuffleWriteMetrics.bytesWritten());
+    assertEquals(6, shuffleBlockInfos.size());
+
+    assertEquals(0, bufferManagerSpy.getUsedBytes());
+    assertEquals(0, bufferManagerSpy.getInSendListBytes());
+
+    // check the blockId -> servers mapping.
+    // server -> partitionId -> blockIds
+    Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds =
+        rssShuffleWriterSpy.getServerToPartitionToBlockIds();
+    assertEquals(2, serverToPartitionToBlockIds.get(replacement).get(0).size());
+
+    // case 2. if exceeding the max retry times, it will fast fail
+    rssShuffleWriterSpy.setBlockFailSentRetryMaxTimes(1);
+    rssShuffleWriterSpy.setTaskId("taskId2");
+    FakedDataPusher alwaysFailedDataPusher =
+        new FakedDataPusher(
+            event -> {
+              assertEquals("taskId2", event.getTaskId());
+              FailedBlockSendTracker tracker = taskToFailedBlockSendTracker.get(event.getTaskId());
+              for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+                boolean isSuccessful = true;
+                ShuffleServerInfo shuffleServer = block.getShuffleServerInfos().get(0);
+                if (shuffleServer.getId().equals("id1")) {
+                  tracker.add(block, shuffleServer, StatusCode.NO_BUFFER);
+                  isSuccessful = false;
+                } else {
+                  successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
+                  successBlockIds.get(event.getTaskId()).add(block.getBlockId());
+                }
+                block.executeCompletionCallback(isSuccessful);
+              }
+              return new CompletableFuture<>();
+            });
+    manager.setDataPusher(alwaysFailedDataPusher);
+
+    MutableList<Tuple2<String, String>> mockedData = createMockRecords();
+    try {
+      rssShuffleWriterSpy.write(mockedData.iterator());
+      fail();
+    } catch (Exception e) {
+      // ignore
+    }
+    assertEquals(0, bufferManagerSpy.getUsedBytes());
+    assertEquals(0, bufferManagerSpy.getInSendListBytes());
+  }
+
   @Test
   public void checkBlockSendResultTest() {
     SparkConf conf = new SparkConf();
@@ -161,8 +355,7 @@ public void checkBlockSendResultTest() {
     assertThrows(
         RuntimeException.class,
         () -> rssShuffleWriter.checkBlockSendResult(Sets.newHashSet(1L, 2L, 3L)));
-    System.out.println(e2.getMessage());
-    assertTrue(e3.getMessage().startsWith("Send failed:"));
+    assertTrue(e3.getMessage().startsWith("Fail to send the block"));
     successBlocks.clear();
     taskToFailedBlockSendTracker.clear();
   }
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
index 0c239c7e10..93e20dd02e 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -32,6 +32,12 @@
 
 public class FailedBlockSendTracker {
 
+  /**
+   * blockId -> list(trackingStatus)
+   *
+   * <p>This records the latest sending status of a blockId; it does not store the resending
+   * history. The list holds one entry per server when multiple replicas are written.
+   */
   private Map<Long, List<TrackingBlockStatus>> trackingBlockStatusMap;
 
   public FailedBlockSendTracker() {
@@ -55,7 +61,10 @@ public void remove(long blockId) {
     trackingBlockStatusMap.remove(blockId);
   }
 
-  public void clear() {
+  public void clearAndReleaseBlockResources() {
+    trackingBlockStatusMap.values().stream()
+        .flatMap(x -> x.stream())
+        .forEach(x -> x.getShuffleBlockInfo().executeCompletionCallback(true));
     trackingBlockStatusMap.clear();
   }
 
diff --git a/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java b/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java
new file mode 100644
index 0000000000..01ba694c3e
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/BlockCompletionCallback.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+public interface BlockCompletionCallback {
+  void onBlockCompletion(ShuffleBlockInfo block, boolean isSuccessful);
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 8de75d90d4..36dec5e257 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -36,6 +36,9 @@ public class ShuffleBlockInfo {
   private List<ShuffleServerInfo> shuffleServerInfos;
   private int uncompressLength;
   private long freeMemory;
+  private int retryCnt = 0;
+
+  private transient BlockCompletionCallback completionCallback;
 
   public ShuffleBlockInfo(
       int shuffleId,
@@ -153,7 +156,30 @@ public String toString() {
     return sb.toString();
   }
 
+  public void incrRetryCnt() {
+    this.retryCnt += 1;
+  }
+
+  public int getRetryCnt() {
+    return retryCnt;
+  }
+
+  public void reassignShuffleServers(List<ShuffleServerInfo> replacements) {
+    this.shuffleServerInfos = replacements;
+  }
+
   public synchronized void copyDataTo(ByteBuf to) {
     ByteBufUtils.copyByteBuf(data, to);
   }
+
+  public void withCompletionCallback(BlockCompletionCallback callback) {
+    this.completionCallback = callback;
+  }
+
+  public void executeCompletionCallback(boolean isSuccessful) {
+    if (completionCallback == null) {
+      return;
+    }
+    completionCallback.onBlockCompletion(this, isSuccessful);
+  }
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java
new file mode 100644
index 0000000000..2a46387022
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/function/TupleConsumer.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.function;
+
+@FunctionalInterface
+public interface TupleConsumer<T, F> {
+  void accept(T t, F f);
+}
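
Reviewer note, not part of the patch: a minimal sketch of the completion-callback contract introduced above, assuming (as the DataPusher change implies) that executeCompletionCallback is invoked exactly once per block after each send attempt. WriteBufferManagerLike is a hypothetical stand-in for org.apache.spark.shuffle.writer.WriteBufferManager; everything else uses only APIs added in this diff.

import org.apache.uniffle.common.ShuffleBlockInfo;

public class CompletionCallbackSketch {
  // Hypothetical stand-in for WriteBufferManager#releaseBlockResource.
  interface WriteBufferManagerLike {
    void releaseBlockResource(ShuffleBlockInfo block);
  }

  // Mirrors the wiring in RssShuffleWriter#postBlockEvent when blockFailSentRetryEnabled
  // is true: release the buffer on success, keep a failed block alive so it can be
  // resent to a replacement server.
  static void attachRetryAwareCallback(ShuffleBlockInfo block, WriteBufferManagerLike manager) {
    block.withCompletionCallback(
        (completedBlock, isSuccessful) -> {
          if (isSuccessful) {
            manager.releaseBlockResource(completedBlock);
          }
          // On failure: do nothing here. FailedBlockSendTracker still references the
          // block, and either a successful resend or clearAndReleaseBlockResources()
          // fires the callback again with isSuccessful == true.
        });
  }
}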