From 1e4804b944bd91887748fc0ad2771825752f3a28 Mon Sep 17 00:00:00 2001 From: Fantasy-Jay <13631435453@163.com> Date: Thu, 3 Aug 2023 14:28:04 +0800 Subject: [PATCH] [#1068] feat(tez): Fail fast in client when failed to send data to server. (#1069) ### What changes were proposed in this pull request? Currently, it only checks for blocks that failed to send after all buffer data has been sent. This check also needs to be moved forward into the addRecord method, allowing it to fail fast. ### Why are the changes needed? Fix: # ([1068](https://github.com/apache/incubator-uniffle/issues/1068)) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Add more test case in WriteBufferManager. --- .../sort/buffer/WriteBufferManager.java | 24 ++-- .../sort/buffer/WriteBufferManagerTest.java | 126 ++++++++++++++++-- 2 files changed, 129 insertions(+), 21 deletions(-) diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java index 06e4194b9e..7da3208cbb 100644 --- a/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java +++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManager.java @@ -172,6 +172,9 @@ public void addRecord(int partitionId, K key, V value) throws InterruptedExcepti memoryLock.unlock(); } + // Fail fast if there are some failed blocks. + checkFailedBlocks(); + if (!buffers.containsKey(partitionId)) { WriteBuffer sortWriterBuffer = new WriteBuffer( @@ -282,14 +285,7 @@ public void waitSendFinished() { } long start = System.currentTimeMillis(); while (true) { - if (failedBlockIds.size() > 0) { - String errorMsg = - "Send failed: failed because " - + failedBlockIds.size() - + " blocks can't be sent to shuffle server."; - LOG.error(errorMsg); - throw new RssException(errorMsg); - } + checkFailedBlocks(); allBlockIds.removeAll(successBlockIds); if (allBlockIds.isEmpty()) { break; @@ -335,6 +331,18 @@ public void waitSendFinished() { sortTime); } + // Check if there are some failed blocks, if true then throw Exception. + private void checkFailedBlocks() { + if (failedBlockIds.size() > 0) { + String errorMsg = + "Send failed: failed because " + + failedBlockIds.size() + + " blocks can't be sent to shuffle server."; + LOG.error(errorMsg); + throw new RssException(errorMsg); + } + } + ShuffleBlockInfo createShuffleBlock(WriteBuffer wb) { byte[] data = wb.getData(); copyTime += wb.getCopyTime(); diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index 8f2af65b9f..1449649c3f 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -44,7 +44,6 @@ import org.apache.tez.common.counters.TaskCounter; import org.apache.tez.common.counters.TezCounter; import org.apache.tez.dag.records.TezTaskAttemptID; -import org.apache.tez.dag.records.TezVertexID; import org.apache.tez.runtime.api.OutputContext; import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; import org.apache.tez.runtime.library.output.OutputTestHelpers; @@ -67,6 +66,7 @@ import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class WriteBufferManagerTest { @@ -99,7 +99,9 @@ public void testWriteException(@TempDir File tmpDir) throws IOException, Interru long sendCheckInterval = 500L; long sendCheckTimeout = 5; int bitmapSplitNum = 1; - int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2); + int shuffleId = + RssTezUtils.computeShuffleId( + tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2); Configuration conf = new Configuration(); FileSystem localFs = FileSystem.getLocal(conf); @@ -139,7 +141,7 @@ public void testWriteException(@TempDir File tmpDir) throws IOException, Interru rssConf, partitionToServers, numMaps, - isMemoryShuffleEnabled(storageType), + StorageType.withMemory(StorageType.valueOf(storageType)), sendCheckInterval, sendCheckTimeout, bitmapSplitNum, @@ -197,7 +199,9 @@ public void testWriteNormal(@TempDir File tmpDir) throws IOException, Interrupte long sendCheckInterval = 500L; long sendCheckTimeout = 60 * 1000 * 10L; int bitmapSplitNum = 1; - int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2); + int shuffleId = + RssTezUtils.computeShuffleId( + tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2); Configuration conf = new Configuration(); FileSystem localFs = FileSystem.getLocal(conf); @@ -237,7 +241,7 @@ public void testWriteNormal(@TempDir File tmpDir) throws IOException, Interrupte rssConf, partitionToServers, numMaps, - isMemoryShuffleEnabled(storageType), + StorageType.withMemory(StorageType.valueOf(storageType)), sendCheckInterval, sendCheckTimeout, bitmapSplitNum, @@ -305,7 +309,9 @@ public void testCommitBlocksWhenMemoryShuffleDisabled(@TempDir File tmpDir) long sendCheckInterval = 500L; long sendCheckTimeout = 60 * 1000 * 10L; int bitmapSplitNum = 1; - int shuffleId = getShuffleId(tezTaskAttemptID, 1, 2); + int shuffleId = + RssTezUtils.computeShuffleId( + tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2); Configuration conf = new Configuration(); FileSystem localFs = FileSystem.getLocal(conf); @@ -371,15 +377,109 @@ public void testCommitBlocksWhenMemoryShuffleDisabled(@TempDir File tmpDir) writeClient.mockedShuffleServer.getFlushBlockSize()); } - private int getShuffleId(TezTaskAttemptID tezTaskAttemptID, int upVertexId, int downVertexId) { - TezVertexID tezVertexID = tezTaskAttemptID.getTaskID().getVertexID(); + @Test + public void testFastFailWhenSendBlocksFailed(@TempDir File tmpDir) + throws IOException, InterruptedException { + TezTaskAttemptID tezTaskAttemptID = + TezTaskAttemptID.fromString("attempt_1681717153064_3770270_1_00_000000_0"); + final long maxMemSize = 10240; + final String appId = "application_1681717153064_3770270"; + final long taskAttemptId = 0; + final Set successBlockIds = Sets.newConcurrentHashSet(); + final Set failedBlockIds = Sets.newConcurrentHashSet(); + MockShuffleWriteClient writeClient = new MockShuffleWriteClient(); + // set mode = 1 to fake sending shuffle data failed. + writeClient.setMode(1); + RawComparator comparator = WritableComparator.get(BytesWritable.class); + long maxSegmentSize = 3 * 1024; + SerializationFactory serializationFactory = new SerializationFactory(new JobConf()); + Serializer keySerializer = + serializationFactory.getSerializer(BytesWritable.class); + Serializer valSerializer = + serializationFactory.getSerializer(BytesWritable.class); + // note: max buffer size is tiny. + long maxBufferSize = 14 * 1024; + double memoryThreshold = 0.8f; + int sendThreadNum = 1; + double sendThreshold = 0.2f; + int batch = 50; + int numMaps = 1; + RssConf rssConf = new RssConf(); + Map> partitionToServers = new HashMap<>(); + long sendCheckInterval = 500L; + long sendCheckTimeout = 60 * 1000 * 10L; + int bitmapSplitNum = 1; int shuffleId = - RssTezUtils.computeShuffleId(tezVertexID.getDAGId().getId(), upVertexId, downVertexId); - return shuffleId; - } + RssTezUtils.computeShuffleId( + tezTaskAttemptID.getTaskID().getVertexID().getDAGId().getId(), 1, 2); + + Configuration conf = new Configuration(); + FileSystem localFs = FileSystem.getLocal(conf); + Path workingDir = + new Path( + System.getProperty( + "test.build.data", System.getProperty("java.io.tmpdir", tmpDir.toString())), + RssOrderedPartitionedKVOutputTest.class.getName()) + .makeQualified(localFs.getUri(), localFs.getWorkingDirectory()); + conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, Text.class.getName()); + conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, Text.class.getName()); + conf.set( + TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS, HashPartitioner.class.getName()); + conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, workingDir.toString()); + OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir); + TezCounter mapOutputByteCounter = + outputContext.getCounters().findCounter(TaskCounter.OUTPUT_BYTES); - private boolean isMemoryShuffleEnabled(String storageType) { - return StorageType.withMemory(StorageType.valueOf(storageType)); + WriteBufferManager bufferManager = + new WriteBufferManager( + tezTaskAttemptID, + maxMemSize, + appId, + taskAttemptId, + successBlockIds, + failedBlockIds, + writeClient, + comparator, + maxSegmentSize, + keySerializer, + valSerializer, + maxBufferSize, + memoryThreshold, + sendThreadNum, + sendThreshold, + batch, + rssConf, + partitionToServers, + numMaps, + false, + sendCheckInterval, + sendCheckTimeout, + bitmapSplitNum, + shuffleId, + true, + mapOutputByteCounter); + + Random random = new Random(); + RssException rssException = + assertThrows( + RssException.class, + () -> { + for (int i = 0; i < 10000; i++) { + byte[] key = new byte[20]; + byte[] value = new byte[1024]; + random.nextBytes(key); + random.nextBytes(value); + int partitionId = random.nextInt(50); + bufferManager.addRecord( + partitionId, new BytesWritable(key), new BytesWritable(value)); + } + }); + assertTrue(rssException.getMessage().contains("Send failed")); + + rssException = assertThrows(RssException.class, bufferManager::waitSendFinished); + assertTrue(rssException.getMessage().contains("Send failed")); + + assertTrue(mapOutputByteCounter.getValue() < 10520000); } class MockShuffleServer {