diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java index 1f7aae942d..df3f667be9 100644 --- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java +++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java @@ -216,6 +216,13 @@ public RawKeyValueIterator run() throws IOException, InterruptedException { LOG.info("In reduce: " + reduceId + ", Rss MR client starts to fetch blocks from RSS server"); JobConf readerJobConf = getRemoteConf(); boolean expectedTaskIdsBitmapFilterEnable = serverInfoList.size() > 1; + int retryMax = + rssJobConf.getInt( + RssMRConfig.RSS_CLIENT_RETRY_MAX, RssMRConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); + long retryIntervalMax = + rssJobConf.getLong( + RssMRConfig.RSS_CLIENT_RETRY_INTERVAL_MAX, + RssMRConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance() .createShuffleReadClient( @@ -232,6 +239,8 @@ public RawKeyValueIterator run() throws IOException, InterruptedException { .hadoopConf(readerJobConf) .idHelper(new MRIdHelper()) .expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable) + .retryMax(retryMax) + .retryIntervalMax(retryIntervalMax) .rssConf(RssMRConfig.toRssConf(rssJobConf))); RssFetcher fetcher = new RssFetcher( diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java index 75855ba7a4..76bfed6084 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java @@ -108,6 +108,14 @@ public RssShuffleReader( @Override public Iterator> read() { LOG.info("Shuffle read started:" + getReadInfo()); + int retryMax = + rssConf.getInteger( + RssClientConfig.RSS_CLIENT_RETRY_MAX, + RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); + long retryIntervalMax = + rssConf.getLong( + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX, + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance() .createShuffleReadClient( @@ -123,6 +131,8 @@ public Iterator> read() { .shuffleServerInfoList(shuffleServerInfoList) .hadoopConf(hadoopConf) .expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable) + .retryMax(retryMax) + .retryIntervalMax(retryIntervalMax) .rssConf(rssConf)); RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator( diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java index 3d7b58bbbc..0c7f3be9e3 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java @@ -246,6 +246,14 @@ class MultiPartitionIterator extends AbstractIterator> { boolean expectedTaskIdsBitmapFilterEnable = !(mapStartIndex == 0 && mapEndIndex == Integer.MAX_VALUE) || shuffleServerInfoList.size() > 1; + int retryMax = + rssConf.getInteger( + RssClientConfig.RSS_CLIENT_RETRY_MAX, + RssClientConfig.RSS_CLIENT_RETRY_MAX_DEFAULT_VALUE); + long retryIntervalMax = + rssConf.getLong( + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX, + RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance() .createShuffleReadClient( @@ -262,6 +270,8 @@ class MultiPartitionIterator extends AbstractIterator> { .hadoopConf(hadoopConf) .shuffleDataDistributionType(dataDistributionType) .expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable) + .retryMax(retryMax) + .retryIntervalMax(retryIntervalMax) .rssConf(rssConf)); RssShuffleDataIterator iterator = new RssShuffleDataIterator<>( diff --git a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java index ce7d90041b..8efdd44a3d 100644 --- a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java +++ b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java @@ -212,6 +212,8 @@ public static class ReadClientBuilder { private int indexReadLimit; private long readBufferSize; private ClientType clientType; + private int retryMax; + private long retryIntervalMax; public ReadClientBuilder appId(String appId) { this.appId = appId; @@ -310,6 +312,16 @@ public ReadClientBuilder clientType(ClientType clientType) { return this; } + public ReadClientBuilder retryMax(int retryMax) { + this.retryMax = retryMax; + return this; + } + + public ReadClientBuilder retryIntervalMax(long retryIntervalMax) { + this.retryIntervalMax = retryIntervalMax; + return this; + } + public ReadClientBuilder() {} public String getAppId() { @@ -388,6 +400,14 @@ public ClientType getClientType() { return clientType; } + public int getRetryMax() { + return retryMax; + } + + public long getRetryIntervalMax() { + return retryIntervalMax; + } + public ShuffleReadClientImpl build() { return new ShuffleReadClientImpl(this); } diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java index 49bb2de3e4..4a789bfa2c 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java @@ -168,6 +168,8 @@ private void init(ShuffleClientFactory.ReadClientBuilder builder) { request.setExpectTaskIds(taskIdBitmap); request.setClientConf(builder.getRssConf()); request.setClientType(builder.getClientType()); + request.setRetryMax(builder.getRetryMax()); + request.setRetryIntervalMax(builder.getRetryIntervalMax()); if (builder.isExpectedTaskIdsBitmapFilterEnable()) { request.useExpectedTaskIdsBitmapFilter(); } diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java new file mode 100644 index 0000000000..abefb1b0dd --- /dev/null +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RpcClientRetryTest.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.impl.ShuffleReadClientImpl; +import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; +import org.apache.uniffle.client.response.SendShuffleDataResult; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.PartitionRange; +import org.apache.uniffle.common.RemoteStorageInfo; +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.rpc.ServerType; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.coordinator.CoordinatorServer; +import org.apache.uniffle.server.MockedGrpcServer; +import org.apache.uniffle.server.MockedShuffleServer; +import org.apache.uniffle.server.ShuffleServer; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.util.StorageType; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class RpcClientRetryTest extends ShuffleReadWriteBase { + + private static ShuffleServerInfo shuffleServerInfo0; + private static ShuffleServerInfo shuffleServerInfo1; + private static ShuffleServerInfo shuffleServerInfo2; + private static MockedShuffleWriteClientImpl shuffleWriteClientImpl; + + private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(StorageType storageType) { + return ShuffleClientFactory.newReadBuilder() + .storageType(storageType.name()) + .shuffleId(0) + .partitionId(0) + .indexReadLimit(100) + .partitionNumPerRange(1) + .partitionNum(10) + .readBufferSize(1000); + } + + public static MockedShuffleServer createMockedShuffleServer(int id, File tmpDir) + throws Exception { + ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC); + File dataDir1 = new File(tmpDir, id + "_1"); + File dataDir2 = new File(tmpDir, id + "_2"); + String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath(); + shuffleServerConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE.name()); + shuffleServerConf.setString("rss.storage.basePath", basePath); + shuffleServerConf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE, 5.0); + shuffleServerConf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, 15.0); + shuffleServerConf.set(ShuffleServerConf.SERVER_BUFFER_CAPACITY, 600L); + return new MockedShuffleServer(shuffleServerConf); + } + + @BeforeAll + public static void initCluster(@TempDir File tmpDir) throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + createCoordinatorServer(coordinatorConf); + + grpcShuffleServers.add(createMockedShuffleServer(0, tmpDir)); + grpcShuffleServers.add(createMockedShuffleServer(1, tmpDir)); + grpcShuffleServers.add(createMockedShuffleServer(2, tmpDir)); + + shuffleServerInfo0 = + new ShuffleServerInfo( + String.format("127.0.0.1-%s", grpcShuffleServers.get(0).getGrpcPort()), + grpcShuffleServers.get(0).getIp(), + grpcShuffleServers.get(0).getGrpcPort()); + shuffleServerInfo1 = + new ShuffleServerInfo( + String.format("127.0.0.1-%s", grpcShuffleServers.get(1).getGrpcPort()), + grpcShuffleServers.get(1).getIp(), + grpcShuffleServers.get(1).getGrpcPort()); + shuffleServerInfo2 = + new ShuffleServerInfo( + String.format("127.0.0.1-%s", grpcShuffleServers.get(2).getGrpcPort()), + grpcShuffleServers.get(2).getIp(), + grpcShuffleServers.get(2).getGrpcPort()); + for (CoordinatorServer coordinator : coordinators) { + coordinator.start(); + } + for (ShuffleServer shuffleServer : grpcShuffleServers) { + shuffleServer.start(); + } + } + + public static void cleanCluster() throws Exception { + for (CoordinatorServer coordinator : coordinators) { + coordinator.stopServer(); + } + for (ShuffleServer shuffleServer : grpcShuffleServers) { + shuffleServer.stopServer(); + } + grpcShuffleServers = Lists.newArrayList(); + coordinators = Lists.newArrayList(); + } + + @AfterAll + public static void cleanEnv() throws Exception { + if (shuffleWriteClientImpl != null) { + shuffleWriteClientImpl.close(); + } + cleanCluster(); + } + + private static Stream testRpcRetryLogicProvider() { + return Stream.of( + Arguments.of(StorageType.MEMORY_LOCALFILE), + // According to SERVER_BUFFER_CAPACITY & SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, + // data will be flushed to disk, so read from disk only + Arguments.of(StorageType.LOCALFILE)); + } + + @ParameterizedTest + @MethodSource("testRpcRetryLogicProvider") + public void testRpcRetryLogic(StorageType storageType) { + String testAppId = "testRpcRetryLogic"; + registerShuffleServer(testAppId, 3, 2, 2, true); + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + + List blocks = + createShuffleBlockList( + 0, + 0, + 0, + 3, + 25, + blockIdBitmap, + expectedData, + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)); + + SendShuffleDataResult result = shuffleWriteClientImpl.sendShuffleData(testAppId, blocks); + Roaring64NavigableMap failedBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + Roaring64NavigableMap successfulBlockIdBitmap = Roaring64NavigableMap.bitmapOf(); + for (Long blockId : result.getSuccessBlockIds()) { + successfulBlockIdBitmap.addLong(blockId); + } + for (Long blockId : result.getFailedBlockIds()) { + failedBlockIdBitmap.addLong(blockId); + } + assertEquals(0, failedBlockIdBitmap.getLongCardinality()); + assertEquals(blockIdBitmap, successfulBlockIdBitmap); + + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + + ShuffleReadClientImpl readClient1 = + baseReadBuilder(storageType) + .appId(testAppId) + .blockIdBitmap(blockIdBitmap) + .taskIdBitmap(taskIdBitmap) + .shuffleServerInfoList( + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)) + .retryMax(3) + .retryIntervalMax(1) + .build(); + + // The data cannot be read because the maximum number of retries is 3 + enableFirstNReadRequestsToFail(4); + try { + validateResult(readClient1, expectedData); + fail(); + } catch (Exception e) { + // do nothing + } + disableFirstNReadRequestsToFail(); + + ShuffleReadClientImpl readClient2 = + baseReadBuilder(storageType) + .appId(testAppId) + .blockIdBitmap(blockIdBitmap) + .taskIdBitmap(taskIdBitmap) + .shuffleServerInfoList( + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2)) + .retryMax(3) + .retryIntervalMax(1) + .build(); + + // The data can be read because the reader will retry + enableFirstNReadRequestsToFail(1); + validateResult(readClient2, expectedData); + disableFirstNReadRequestsToFail(); + } + + private static void enableFirstNReadRequestsToFail(int failedCount) { + for (ShuffleServer server : grpcShuffleServers) { + ((MockedGrpcServer) server.getServer()) + .getService() + .enableFirstNReadRequestToFail(failedCount); + } + } + + private static void disableFirstNReadRequestsToFail() { + for (ShuffleServer server : grpcShuffleServers) { + ((MockedGrpcServer) server.getServer()).getService().resetFirstNReadRequestToFail(); + } + } + + static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl { + MockedShuffleWriteClientImpl(ShuffleClientFactory.WriteClientBuilder builder) { + super(builder); + } + + public SendShuffleDataResult sendShuffleData( + String appId, List shuffleBlockInfoList) { + return super.sendShuffleData(appId, shuffleBlockInfoList, () -> false); + } + } + + private void registerShuffleServer( + String testAppId, int replica, int replicaWrite, int replicaRead, boolean replicaSkip) { + + shuffleWriteClientImpl = + new MockedShuffleWriteClientImpl( + ShuffleClientFactory.newWriteBuilder() + .clientType(ClientType.GRPC.name()) + .retryMax(3) + .retryIntervalMax(1000) + .heartBeatThreadNum(1) + .replica(replica) + .replicaWrite(replicaWrite) + .replicaRead(replicaRead) + .replicaSkipEnabled(replicaSkip) + .dataTransferPoolSize(1) + .dataCommitPoolSize(1) + .unregisterThreadPoolSize(10) + .unregisterRequestTimeSec(10)); + + List allServers = + Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2); + + for (int i = 0; i < replica; i++) { + shuffleWriteClientImpl.registerShuffle( + allServers.get(i), + testAppId, + 0, + Lists.newArrayList(new PartitionRange(0, 0)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + 1); + } + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 5a6919e44a..3ab81a0c5f 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleServerClient; +import org.apache.uniffle.client.request.RetryableRequest; import org.apache.uniffle.client.request.RssAppHeartBeatRequest; import org.apache.uniffle.client.request.RssFinishShuffleRequest; import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest; @@ -109,12 +110,26 @@ import org.apache.uniffle.proto.ShuffleServerGrpc; import org.apache.uniffle.proto.ShuffleServerGrpc.ShuffleServerBlockingStub; +import static org.apache.uniffle.proto.RssProtos.StatusCode.NO_BUFFER; + public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServerClient { private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerGrpcClient.class); protected static final long FAILED_REQUIRE_ID = -1; protected long rpcTimeout; private ShuffleServerBlockingStub blockingStub; + /** + * A single instance of the Random class is created as a member variable to be reused throughout + * `ShuffleServerGrpcClient`. This approach has the following benefits: 1. Performance + * optimization: It avoids the overhead of creating and destroying objects frequently, reducing + * memory allocation and garbage collection costs. 2. Randomness: Reusing the same Random object + * helps maintain the randomness of the generated numbers. If multiple Random objects are created + * in a short period of time, their seeds may be the same or very close, leading to less random + * numbers. + */ + protected Random random = new Random(); + + protected static final int BACK_OFF_BASE = 2000; @VisibleForTesting public ShuffleServerGrpcClient(String host, int port) { @@ -237,8 +252,6 @@ public long requirePreAllocation( long start = System.currentTimeMillis(); int retry = 0; long result = FAILED_REQUIRE_ID; - Random random = new Random(); - final int backOffBase = 2000; if (LOG.isDebugEnabled()) { LOG.debug( "Requiring buffer for appId: {}, shuffleId: {}, partitionIds: {} with {} bytes from {}:{}", @@ -258,7 +271,7 @@ public long requirePreAllocation( "Exception happened when requiring pre-allocated buffer from {}:{}", host, port, e); return result; } - if (rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER + if (rpcResponse.getStatus() != NO_BUFFER && rpcResponse.getStatus() != RssProtos.StatusCode.NO_BUFFER_FOR_HUGE_PARTITION) { break; } @@ -291,7 +304,7 @@ public long requirePreAllocation( long backoffTime = Math.min( retryIntervalMax, - backOffBase * (1L << Math.min(retry, 16)) + random.nextInt(backOffBase)); + BACK_OFF_BASE * (1L << Math.min(retry, 16)) + random.nextInt(BACK_OFF_BASE)); Thread.sleep(backoffTime); } catch (Exception e) { LOG.warn( @@ -822,7 +835,6 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request .setLength(request.getLength()) .setTimestamp(start) .build(); - GetLocalShuffleDataResponse rpcResponse = getBlockingStub().getLocalShuffleData(rpcRequest); String requestInfo = "appId[" + request.getAppId() @@ -831,22 +843,29 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request + "], partitionId[" + request.getPartitionId() + "]"; - LOG.info( - "GetShuffleData from {}:{} for {} cost {} ms", - host, - port, - requestInfo, - System.currentTimeMillis() - start); - - RssProtos.StatusCode statusCode = rpcResponse.getStatus(); - + int retry = 0; + GetLocalShuffleDataResponse rpcResponse; + while (true) { + rpcResponse = getBlockingStub().getLocalShuffleData(rpcRequest); + if (rpcResponse.getStatus() != NO_BUFFER) { + break; + } + waitOrThrow( + request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start); + retry++; + } RssGetShuffleDataResponse response; - switch (statusCode) { + switch (rpcResponse.getStatus()) { case SUCCESS: + LOG.info( + "GetShuffleData from {}:{} for {} cost {} ms", + host, + port, + requestInfo, + System.currentTimeMillis() - start); response = new RssGetShuffleDataResponse( StatusCode.SUCCESS, ByteBuffer.wrap(rpcResponse.getData().toByteArray())); - break; default: String msg = @@ -874,8 +893,6 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ .setPartitionNumPerRange(request.getPartitionNumPerRange()) .setPartitionNum(request.getPartitionNum()) .build(); - long start = System.currentTimeMillis(); - GetLocalShuffleIndexResponse rpcResponse = getBlockingStub().getLocalShuffleIndex(rpcRequest); String requestInfo = "appId[" + request.getAppId() @@ -884,18 +901,27 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ + "], partitionId[" + request.getPartitionId() + "]"; - LOG.info( - "GetShuffleIndex from {}:{} for {} cost {} ms", - host, - port, - requestInfo, - System.currentTimeMillis() - start); - - RssProtos.StatusCode statusCode = rpcResponse.getStatus(); - + long start = System.currentTimeMillis(); + int retry = 0; + GetLocalShuffleIndexResponse rpcResponse; + while (true) { + rpcResponse = getBlockingStub().getLocalShuffleIndex(rpcRequest); + if (rpcResponse.getStatus() != NO_BUFFER) { + break; + } + waitOrThrow( + request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start); + retry++; + } RssGetShuffleIndexResponse response; - switch (statusCode) { + switch (rpcResponse.getStatus()) { case SUCCESS: + LOG.info( + "GetShuffleIndex from {}:{} for {} cost {} ms", + host, + port, + requestInfo, + System.currentTimeMillis() - start); response = new RssGetShuffleIndexResponse( StatusCode.SUCCESS, @@ -944,8 +970,6 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( .setSerializedExpectedTaskIdsBitmap(serializedTaskIdsBytes) .setTimestamp(start) .build(); - - GetMemoryShuffleDataResponse rpcResponse = getBlockingStub().getMemoryShuffleData(rpcRequest); String requestInfo = "appId[" + request.getAppId() @@ -954,20 +978,28 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( + "], partitionId[" + request.getPartitionId() + "]"; - LOG.info( - "GetInMemoryShuffleData from {}:{} for " - + requestInfo - + " cost " - + (System.currentTimeMillis() - start) - + " ms", - host, - port); - - RssProtos.StatusCode statusCode = rpcResponse.getStatus(); - + int retry = 0; + GetMemoryShuffleDataResponse rpcResponse; + while (true) { + rpcResponse = getBlockingStub().getMemoryShuffleData(rpcRequest); + if (rpcResponse.getStatus() != NO_BUFFER) { + break; + } + waitOrThrow( + request, retry, requestInfo, StatusCode.fromProto(rpcResponse.getStatus()), start); + retry++; + } RssGetInMemoryShuffleDataResponse response; - switch (statusCode) { + switch (rpcResponse.getStatus()) { case SUCCESS: + LOG.info( + "GetInMemoryShuffleData from {}:{} for " + + requestInfo + + " cost " + + (System.currentTimeMillis() - start) + + " ms", + host, + port); response = new RssGetInMemoryShuffleDataResponse( StatusCode.SUCCESS, @@ -995,6 +1027,47 @@ public String getClientInfo() { return "ShuffleServerGrpcClient for host[" + host + "], port[" + port + "]"; } + protected void waitOrThrow( + RetryableRequest request, int retry, String requestInfo, StatusCode statusCode, long start) { + if (retry >= request.getRetryMax()) { + String msg = + String.format( + "ShuffleServer %s:%s is full when %s due to %s, after %d retries, cost %d ms", + host, + port, + request.operationType(), + statusCode, + request.getRetryMax(), + System.currentTimeMillis() - start); + LOG.error(msg); + throw new RssFetchFailedException(msg); + } + try { + long backoffTime = + Math.min( + request.getRetryIntervalMax(), + BACK_OFF_BASE * (1L << Math.min(retry, 16)) + random.nextInt(BACK_OFF_BASE)); + LOG.warn( + "Can't acquire buffer for {} from {}:{} when executing {}, due to {}. " + + "Will retry {} more time(s) after waiting {} milliseconds.", + requestInfo, + host, + port, + request.operationType(), + statusCode, + request.getRetryMax() - retry, + backoffTime); + Thread.sleep(backoffTime); + } catch (InterruptedException e) { + LOG.warn( + "Exception happened when executing {} from {}:{}", + request.operationType(), + host, + port, + e); + } + } + private List toShufflePartitionRanges( List partitionRanges) { List ret = Lists.newArrayList(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index fc8aa0272b..f677b63856 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -221,12 +221,27 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( + "], lastBlockId[" + request.getLastBlockId() + "]"; - RpcResponse rpcResponse = transportClient.sendRpcSync(getMemoryShuffleDataRequest, rpcTimeout); - GetMemoryShuffleDataResponse getMemoryShuffleDataResponse = - (GetMemoryShuffleDataResponse) rpcResponse; - StatusCode statusCode = rpcResponse.getStatusCode(); - switch (statusCode) { + long start = System.currentTimeMillis(); + int retry = 0; + RpcResponse rpcResponse; + GetMemoryShuffleDataResponse getMemoryShuffleDataResponse; + while (true) { + rpcResponse = transportClient.sendRpcSync(getMemoryShuffleDataRequest, rpcTimeout); + getMemoryShuffleDataResponse = (GetMemoryShuffleDataResponse) rpcResponse; + if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) { + break; + } + waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start); + retry++; + } + switch (rpcResponse.getStatusCode()) { case SUCCESS: + LOG.info( + "GetInMemoryShuffleData from {}:{} for {} cost {} ms", + host, + nettyPort, + requestInfo, + System.currentTimeMillis() - start); return new RssGetInMemoryShuffleDataResponse( StatusCode.SUCCESS, getMemoryShuffleDataResponse.body(), @@ -236,7 +251,7 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( "Can't get shuffle in memory data from " + host + ":" - + port + + nettyPort + " for " + requestInfo + ", errorMsg:" @@ -257,8 +272,6 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ request.getPartitionId(), request.getPartitionNumPerRange(), request.getPartitionNum()); - long start = System.currentTimeMillis(); - RpcResponse rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout); String requestInfo = "appId[" + request.getAppId() @@ -266,17 +279,27 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ + request.getShuffleId() + "], partitionId[" + request.getPartitionId(); - LOG.info( - "GetShuffleIndex from {}:{} for {} cost {} ms", - host, - port, - requestInfo, - System.currentTimeMillis() - start); - GetLocalShuffleIndexResponse getLocalShuffleIndexResponse = - (GetLocalShuffleIndexResponse) rpcResponse; - StatusCode statusCode = rpcResponse.getStatusCode(); - switch (statusCode) { + long start = System.currentTimeMillis(); + int retry = 0; + RpcResponse rpcResponse; + GetLocalShuffleIndexResponse getLocalShuffleIndexResponse; + while (true) { + rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout); + getLocalShuffleIndexResponse = (GetLocalShuffleIndexResponse) rpcResponse; + if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) { + break; + } + waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start); + retry++; + } + switch (rpcResponse.getStatusCode()) { case SUCCESS: + LOG.info( + "GetShuffleIndex from {}:{} for {} cost {} ms", + host, + nettyPort, + requestInfo, + System.currentTimeMillis() - start); return new RssGetShuffleIndexResponse( StatusCode.SUCCESS, getLocalShuffleIndexResponse.body(), @@ -286,7 +309,7 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ "Can't get shuffle index from " + host + ":" - + port + + nettyPort + " for " + requestInfo + ", errorMsg:" @@ -310,8 +333,6 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request request.getOffset(), request.getLength(), System.currentTimeMillis()); - long start = System.currentTimeMillis(); - RpcResponse rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout); String requestInfo = "appId[" + request.getAppId() @@ -320,17 +341,27 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request + "], partitionId[" + request.getPartitionId() + "]"; - LOG.info( - "GetShuffleData from {}:{} for {} cost {} ms", - host, - port, - requestInfo, - System.currentTimeMillis() - start); - GetLocalShuffleDataResponse getLocalShuffleDataResponse = - (GetLocalShuffleDataResponse) rpcResponse; - StatusCode statusCode = rpcResponse.getStatusCode(); - switch (statusCode) { + long start = System.currentTimeMillis(); + int retry = 0; + RpcResponse rpcResponse; + GetLocalShuffleDataResponse getLocalShuffleDataResponse; + while (true) { + rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, rpcTimeout); + getLocalShuffleDataResponse = (GetLocalShuffleDataResponse) rpcResponse; + if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) { + break; + } + waitOrThrow(request, retry, requestInfo, rpcResponse.getStatusCode(), start); + retry++; + } + switch (rpcResponse.getStatusCode()) { case SUCCESS: + LOG.info( + "GetShuffleData from {}:{} for {} cost {} ms", + host, + nettyPort, + requestInfo, + System.currentTimeMillis() - start); return new RssGetShuffleDataResponse( StatusCode.SUCCESS, getLocalShuffleDataResponse.body()); default: @@ -338,7 +369,7 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request "Can't get shuffle data from " + host + ":" - + port + + nettyPort + " for " + requestInfo + ", errorMsg:" diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RetryableRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RetryableRequest.java new file mode 100644 index 0000000000..2abe4b2fca --- /dev/null +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RetryableRequest.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.request; + +public abstract class RetryableRequest { + protected int retryMax; + protected long retryIntervalMax; + + public int getRetryMax() { + return retryMax; + } + + public long getRetryIntervalMax() { + return retryIntervalMax; + } + + public abstract String operationType(); +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java index bf3534e1a5..64c41104f7 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetInMemoryShuffleDataRequest.java @@ -17,9 +17,10 @@ package org.apache.uniffle.client.request; +import com.google.common.annotations.VisibleForTesting; import org.roaringbitmap.longlong.Roaring64NavigableMap; -public class RssGetInMemoryShuffleDataRequest { +public class RssGetInMemoryShuffleDataRequest extends RetryableRequest { private final String appId; private final int shuffleId; private final int partitionId; @@ -33,13 +34,28 @@ public RssGetInMemoryShuffleDataRequest( int partitionId, long lastBlockId, int readBufferSize, - Roaring64NavigableMap expectedTaskIds) { + Roaring64NavigableMap expectedTaskIds, + int retryMax, + long retryIntervalMax) { this.appId = appId; this.shuffleId = shuffleId; this.partitionId = partitionId; this.lastBlockId = lastBlockId; this.readBufferSize = readBufferSize; this.expectedTaskIds = expectedTaskIds; + this.retryMax = retryMax; + this.retryIntervalMax = retryIntervalMax; + } + + @VisibleForTesting + public RssGetInMemoryShuffleDataRequest( + String appId, + int shuffleId, + int partitionId, + long lastBlockId, + int readBufferSize, + Roaring64NavigableMap expectedTaskIds) { + this(appId, shuffleId, partitionId, lastBlockId, readBufferSize, expectedTaskIds, 1, 0); } public String getAppId() { @@ -65,4 +81,9 @@ public int getReadBufferSize() { public Roaring64NavigableMap getExpectedTaskIds() { return expectedTaskIds; } + + @Override + public String operationType() { + return "GetInMemoryShuffleData"; + } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java index 0b9997a786..5801922171 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java @@ -17,7 +17,9 @@ package org.apache.uniffle.client.request; -public class RssGetShuffleDataRequest { +import com.google.common.annotations.VisibleForTesting; + +public class RssGetShuffleDataRequest extends RetryableRequest { private final String appId; private final int shuffleId; @@ -34,7 +36,9 @@ public RssGetShuffleDataRequest( int partitionNumPerRange, int partitionNum, long offset, - int length) { + int length, + int retryMax, + long retryIntervalMax) { this.appId = appId; this.shuffleId = shuffleId; this.partitionId = partitionId; @@ -42,6 +46,20 @@ public RssGetShuffleDataRequest( this.partitionNum = partitionNum; this.offset = offset; this.length = length; + this.retryMax = retryMax; + this.retryIntervalMax = retryIntervalMax; + } + + @VisibleForTesting + public RssGetShuffleDataRequest( + String appId, + int shuffleId, + int partitionId, + int partitionNumPerRange, + int partitionNum, + long offset, + int length) { + this(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, offset, length, 1, 0); } public String getAppId() { @@ -71,4 +89,9 @@ public long getOffset() { public int getLength() { return length; } + + @Override + public String operationType() { + return "GetShuffleData"; + } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java index 8ae5da712c..0e61206a22 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleIndexRequest.java @@ -17,7 +17,9 @@ package org.apache.uniffle.client.request; -public class RssGetShuffleIndexRequest { +import com.google.common.annotations.VisibleForTesting; + +public class RssGetShuffleIndexRequest extends RetryableRequest { private final String appId; private final int shuffleId; @@ -26,12 +28,26 @@ public class RssGetShuffleIndexRequest { private final int partitionNum; public RssGetShuffleIndexRequest( - String appId, int shuffleId, int partitionId, int partitionNumPerRange, int partitionNum) { + String appId, + int shuffleId, + int partitionId, + int partitionNumPerRange, + int partitionNum, + int retryMax, + long retryIntervalMax) { this.appId = appId; this.shuffleId = shuffleId; this.partitionId = partitionId; this.partitionNumPerRange = partitionNumPerRange; this.partitionNum = partitionNum; + this.retryMax = retryMax; + this.retryIntervalMax = retryIntervalMax; + } + + @VisibleForTesting + public RssGetShuffleIndexRequest( + String appId, int shuffleId, int partitionId, int partitionNumPerRange, int partitionNum) { + this(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, 1, 0); } public String getAppId() { @@ -53,4 +69,9 @@ public int getPartitionNumPerRange() { public int getPartitionNum() { return partitionNum; } + + @Override + public String operationType() { + return "GetShuffleIndex"; + } } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java index 9ea2e84f21..bd71c3bc4d 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java @@ -132,12 +132,6 @@ public class ShuffleServerConf extends RssBaseConf { .withDescription( "Expired time (ms) for application which has no heartbeat with coordinator"); - public static final ConfigOption SERVER_MEMORY_REQUEST_RETRY_MAX = - ConfigOptions.key("rss.server.memory.request.retry.max") - .intType() - .defaultValue(50) - .withDescription("Max times to retry for memory request"); - public static final ConfigOption SERVER_PRE_ALLOCATION_EXPIRED = ConfigOptions.key("rss.server.preAllocation.expired") .longType() diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index dcc2717c58..9f8f79eb56 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -674,7 +674,7 @@ public void getLocalShuffleData( storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId)); } - if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) { + if (shuffleServer.getShuffleBufferManager().requireReadMemory(length)) { try { long start = System.currentTimeMillis(); sdr = @@ -722,7 +722,7 @@ public void getLocalShuffleData( shuffleServer.getShuffleBufferManager().releaseReadMemory(length); } } else { - status = StatusCode.INTERNAL_ERROR; + status = StatusCode.NO_BUFFER; msg = "Can't require memory to get shuffle data"; LOG.error(msg + " for " + requestInfo); reply = @@ -766,7 +766,7 @@ public void getLocalShuffleIndex( shuffleServer .getShuffleServerConf() .getLong(ShuffleServerConf.SERVER_SHUFFLE_INDEX_SIZE_HINT); - if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(assumedFileSize)) { + if (shuffleServer.getShuffleBufferManager().requireReadMemory(assumedFileSize)) { ShuffleIndexResult shuffleIndexResult = null; try { long start = System.currentTimeMillis(); @@ -812,7 +812,7 @@ public void getLocalShuffleIndex( shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize); } } else { - status = StatusCode.INTERNAL_ERROR; + status = StatusCode.NO_BUFFER; msg = "Can't require memory to get shuffle index"; LOG.error(msg + " for " + requestInfo); reply = @@ -853,7 +853,7 @@ public void getMemoryShuffleData( "appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId[" + partitionId + "]"; // todo: if can get the exact memory size? - if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize)) { + if (shuffleServer.getShuffleBufferManager().requireReadMemory(readBufferSize)) { ShuffleDataResult shuffleDataResult = null; try { Roaring64NavigableMap expectedTaskIds = null; @@ -915,7 +915,7 @@ public void getMemoryShuffleData( shuffleServer.getShuffleBufferManager().releaseReadMemory(readBufferSize); } } else { - status = StatusCode.INTERNAL_ERROR; + status = StatusCode.NO_BUFFER; msg = "Can't require memory to get in memory shuffle data"; LOG.error(msg + " for " + requestInfo); reply = diff --git a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java index 4d42b0576f..8f41a07956 100644 --- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java +++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java @@ -60,7 +60,6 @@ public class ShuffleBufferManager { private final ShuffleFlushManager shuffleFlushManager; private long capacity; private long readCapacity; - private int retryNum; private long highWaterMark; private long lowWaterMark; private boolean bufferFlushEnabled; @@ -111,7 +110,6 @@ public ShuffleBufferManager( readCapacity); this.shuffleFlushManager = shuffleFlushManager; this.bufferPool = new ConcurrentHashMap<>(); - this.retryNum = conf.getInteger(ShuffleServerConf.SERVER_MEMORY_REQUEST_RETRY_MAX); this.highWaterMark = (long) (capacity @@ -424,35 +422,37 @@ private void releaseFlushMemory(long size) { ShuffleServerMetrics.gaugeInFlushBufferSize.set(inFlushSize.get()); } - public boolean requireReadMemoryWithRetry(long size) { + public boolean requireReadMemory(long size) { ShuffleServerMetrics.counterTotalRequireReadMemoryNum.inc(); - for (int i = 0; i < retryNum; i++) { - synchronized (this) { - if (readDataMemory.get() + size < readCapacity) { - readDataMemory.addAndGet(size); - ShuffleServerMetrics.gaugeReadBufferUsedSize.inc(size); - return true; - } + boolean isSuccessful = false; + + do { + long currentReadDataMemory = readDataMemory.get(); + long newReadDataMemory = currentReadDataMemory + size; + if (newReadDataMemory >= readCapacity) { + break; } - LOG.info( + if (readDataMemory.compareAndSet(currentReadDataMemory, newReadDataMemory)) { + ShuffleServerMetrics.gaugeReadBufferUsedSize.inc(size); + isSuccessful = true; + break; + } + } while (true); + + if (!isSuccessful) { + LOG.error( "Can't require[" + size + "] for read data, current[" + readDataMemory.get() + "], capacity[" + readCapacity - + "], re-try " - + i - + " times"); + + "]"); ShuffleServerMetrics.counterTotalRequireReadMemoryRetryNum.inc(); - try { - Thread.sleep(1000); - } catch (Exception e) { - LOG.warn("Error happened when require memory", e); - } + ShuffleServerMetrics.counterTotalRequireReadMemoryFailedNum.inc(); } - ShuffleServerMetrics.counterTotalRequireReadMemoryFailedNum.inc(); - return false; + + return isSuccessful; } public void releaseReadMemory(long size) { diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index 2e0c070e9a..e87f9aa4e6 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -263,7 +263,7 @@ public void handleGetMemoryShuffleDataRequest( "appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId[" + partitionId + "]"; // todo: if can get the exact memory size? - if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize)) { + if (shuffleServer.getShuffleBufferManager().requireReadMemory(readBufferSize)) { ShuffleDataResult shuffleDataResult = null; try { shuffleDataResult = @@ -308,7 +308,7 @@ public void handleGetMemoryShuffleDataRequest( req.getRequestId(), status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER); } } else { - status = StatusCode.INTERNAL_ERROR; + status = StatusCode.NO_BUFFER; msg = "Can't require memory to get in memory shuffle data"; LOG.error(msg + " for " + requestInfo); response = @@ -347,7 +347,7 @@ public void handleGetLocalShuffleIndexRequest( shuffleServer .getShuffleServerConf() .getLong(ShuffleServerConf.SERVER_SHUFFLE_INDEX_SIZE_HINT); - if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(assumedFileSize)) { + if (shuffleServer.getShuffleBufferManager().requireReadMemory(assumedFileSize)) { ShuffleIndexResult shuffleIndexResult = null; try { final long start = System.currentTimeMillis(); @@ -392,7 +392,7 @@ public void handleGetLocalShuffleIndexRequest( req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L); } } else { - status = StatusCode.INTERNAL_ERROR; + status = StatusCode.NO_BUFFER; msg = "Can't require memory to get shuffle index"; LOG.error(msg + " for " + requestInfo); response = @@ -447,7 +447,7 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId)); } - if (shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) { + if (shuffleServer.getShuffleBufferManager().requireReadMemory(length)) { ShuffleDataResult sdr = null; try { final long start = System.currentTimeMillis(); @@ -486,7 +486,7 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat req.getRequestId(), status, msg, new NettyManagedBuffer(Unpooled.EMPTY_BUFFER)); } } else { - status = StatusCode.INTERNAL_ERROR; + status = StatusCode.NO_BUFFER; msg = "Can't require memory to get shuffle data"; LOG.error(msg + " for " + requestInfo); response = diff --git a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java index eafd832928..87b9abd51d 100644 --- a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java +++ b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java @@ -22,11 +22,14 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.Lists; import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.UnsafeByteOperations; import io.grpc.stub.StreamObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.proto.RssProtos; @@ -45,7 +48,11 @@ public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService { private boolean recordGetShuffleResult = false; private long numOfFailedReadRequest = 0; - private AtomicInteger failedReadRequest = new AtomicInteger(0); + private AtomicInteger failedGetShuffleResultRequest = new AtomicInteger(0); + private AtomicInteger failedGetShuffleResultForMultiPartRequest = new AtomicInteger(0); + private AtomicInteger failedGetMemoryShuffleDataRequest = new AtomicInteger(0); + private AtomicInteger failedGetLocalShuffleDataRequest = new AtomicInteger(0); + private AtomicInteger failedGetLocalShuffleIndexRequest = new AtomicInteger(0); public void enableMockedTimeout(long timeout) { mockedTimeout = timeout; @@ -69,7 +76,11 @@ public void enableFirstNReadRequestToFail(int n) { public void resetFirstNReadRequestToFail() { numOfFailedReadRequest = 0; - failedReadRequest.set(0); + failedGetShuffleResultRequest.set(0); + failedGetShuffleResultForMultiPartRequest.set(0); + failedGetMemoryShuffleDataRequest.set(0); + failedGetLocalShuffleDataRequest.set(0); + failedGetLocalShuffleIndexRequest.set(0); } public MockedShuffleServerGrpcService(ShuffleServer shuffleServer) { @@ -111,7 +122,7 @@ public void getShuffleResult( Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS); } if (numOfFailedReadRequest > 0) { - int currentFailedReadRequest = failedReadRequest.getAndIncrement(); + int currentFailedReadRequest = failedGetShuffleResultRequest.getAndIncrement(); if (currentFailedReadRequest < numOfFailedReadRequest) { LOG.info( "This request is failed as mocked failure, current/firstN: {}/{}", @@ -128,11 +139,11 @@ public void getShuffleResultForMultiPart( RssProtos.GetShuffleResultForMultiPartRequest request, StreamObserver responseObserver) { if (mockedTimeout > 0) { - LOG.info("Add a mocked timeout on getShuffleResult"); + LOG.info("Add a mocked timeout on getShuffleResultForMultiPart"); Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS); } if (numOfFailedReadRequest > 0) { - int currentFailedReadRequest = failedReadRequest.getAndIncrement(); + int currentFailedReadRequest = failedGetShuffleResultForMultiPartRequest.getAndIncrement(); if (currentFailedReadRequest < numOfFailedReadRequest) { LOG.info( "This request is failed as mocked failure, current/firstN: {}/{}", @@ -163,15 +174,81 @@ public void getMemoryShuffleData( RssProtos.GetMemoryShuffleDataRequest request, StreamObserver responseObserver) { if (numOfFailedReadRequest > 0) { - int currentFailedReadRequest = failedReadRequest.getAndIncrement(); + int currentFailedReadRequest = failedGetMemoryShuffleDataRequest.getAndIncrement(); if (currentFailedReadRequest < numOfFailedReadRequest) { LOG.info( "This request is failed as mocked failure, current/firstN: {}/{}", currentFailedReadRequest, numOfFailedReadRequest); - throw new RuntimeException("This request is failed as mocked failure"); + StatusCode status = StatusCode.NO_BUFFER; + String msg = + "Can't require memory to get in memory shuffle data (This request is failed as mocked failure)"; + RssProtos.GetMemoryShuffleDataResponse reply = + RssProtos.GetMemoryShuffleDataResponse.newBuilder() + .setData(UnsafeByteOperations.unsafeWrap(new byte[] {})) + .addAllShuffleDataBlockSegments(Lists.newArrayList()) + .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; } } super.getMemoryShuffleData(request, responseObserver); } + + @Override + public void getLocalShuffleData( + RssProtos.GetLocalShuffleDataRequest request, + StreamObserver responseObserver) { + if (numOfFailedReadRequest > 0) { + int currentFailedReadRequest = failedGetLocalShuffleDataRequest.getAndIncrement(); + if (currentFailedReadRequest < numOfFailedReadRequest) { + LOG.info( + "This request is failed as mocked failure, current/firstN: {}/{}", + currentFailedReadRequest, + numOfFailedReadRequest); + StatusCode status = StatusCode.NO_BUFFER; + String msg = + "Can't require memory to get shuffle data (This request is failed as mocked failure)"; + RssProtos.GetLocalShuffleDataResponse reply = + RssProtos.GetLocalShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + } + super.getLocalShuffleData(request, responseObserver); + } + + @Override + public void getLocalShuffleIndex( + RssProtos.GetLocalShuffleIndexRequest request, + StreamObserver responseObserver) { + if (numOfFailedReadRequest > 0) { + int currentFailedReadRequest = failedGetLocalShuffleIndexRequest.getAndIncrement(); + if (currentFailedReadRequest < numOfFailedReadRequest) { + LOG.info( + "This request is failed as mocked failure, current/firstN: {}/{}", + currentFailedReadRequest, + numOfFailedReadRequest); + StatusCode status = StatusCode.NO_BUFFER; + String msg = + "Can't require memory to get shuffle index (This request is failed as mocked failure)"; + RssProtos.GetLocalShuffleIndexResponse reply = + RssProtos.GetLocalShuffleIndexResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + } + super.getLocalShuffleIndex(request, responseObserver); + } } diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java index 1db3326bea..819c26e040 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java +++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java @@ -134,7 +134,9 @@ private ClientReadHandler getMemoryClientReadHandler( request.getPartitionId(), request.getReadBufferSize(), shuffleServerClient, - expectTaskIds); + expectTaskIds, + request.getRetryMax(), + request.getRetryIntervalMax()); return memoryClientReadHandler; } @@ -155,7 +157,9 @@ private ClientReadHandler getLocalfileClientReaderHandler( request.getProcessBlockIds(), shuffleServerClient, request.getDistributionType(), - request.getExpectTaskIds()); + request.getExpectTaskIds(), + request.getRetryMax(), + request.getRetryIntervalMax()); } private ClientReadHandler getHadoopClientReadHandler( diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java index 9fc50884ad..2b5ea8f7a9 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java @@ -17,6 +17,7 @@ package org.apache.uniffle.storage.handler.impl; +import com.google.common.annotations.VisibleForTesting; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,6 +38,8 @@ public class LocalFileClientReadHandler extends DataSkippableReadHandler { private final int partitionNumPerRange; private final int partitionNum; private ShuffleServerClient shuffleServerClient; + private int retryMax; + private long retryIntervalMax; public LocalFileClientReadHandler( String appId, @@ -50,7 +53,9 @@ public LocalFileClientReadHandler( Roaring64NavigableMap processBlockIds, ShuffleServerClient shuffleServerClient, ShuffleDataDistributionType distributionType, - Roaring64NavigableMap expectTaskIds) { + Roaring64NavigableMap expectTaskIds, + int retryMax, + long retryIntervalMax) { super( appId, shuffleId, @@ -63,9 +68,11 @@ public LocalFileClientReadHandler( this.shuffleServerClient = shuffleServerClient; this.partitionNumPerRange = partitionNumPerRange; this.partitionNum = partitionNum; + this.retryMax = retryMax; + this.retryIntervalMax = retryIntervalMax; } - /** Only for test */ + @VisibleForTesting public LocalFileClientReadHandler( String appId, int shuffleId, @@ -89,7 +96,9 @@ public LocalFileClientReadHandler( processBlockIds, shuffleServerClient, ShuffleDataDistributionType.NORMAL, - Roaring64NavigableMap.bitmapOf()); + Roaring64NavigableMap.bitmapOf(), + 1, + 0); } @Override @@ -97,7 +106,13 @@ public ShuffleIndexResult readShuffleIndex() { ShuffleIndexResult shuffleIndexResult = null; RssGetShuffleIndexRequest request = new RssGetShuffleIndexRequest( - appId, shuffleId, partitionId, partitionNumPerRange, partitionNum); + appId, + shuffleId, + partitionId, + partitionNumPerRange, + partitionNum, + retryMax, + retryIntervalMax); try { shuffleIndexResult = shuffleServerClient.getShuffleIndex(request).getShuffleIndexResult(); } catch (RssFetchFailedException e) { @@ -141,17 +156,16 @@ public ShuffleDataResult readShuffleData(ShuffleDataSegment shuffleDataSegment) partitionNumPerRange, partitionNum, shuffleDataSegment.getOffset(), - expectedLength); + expectedLength, + retryMax, + retryIntervalMax); try { RssGetShuffleDataResponse response = shuffleServerClient.getShuffleData(request); result = new ShuffleDataResult(response.getShuffleData(), shuffleDataSegment.getBufferSegments()); } catch (Exception e) { throw new RssException( - "Failed to read shuffle data with " - + shuffleServerClient.getClientInfo() - + " due to " - + e.getMessage()); + "Failed to read shuffle data with " + shuffleServerClient.getClientInfo(), e); } if (result.getDataBuffer().remaining() != expectedLength) { throw new RssException( diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java index a3a7931921..f1fbe2361c 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java @@ -19,6 +19,7 @@ import java.util.List; +import com.google.common.annotations.VisibleForTesting; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,6 +38,8 @@ public class MemoryClientReadHandler extends AbstractClientReadHandler { private long lastBlockId = Constants.INVALID_BLOCK_ID; private ShuffleServerClient shuffleServerClient; private Roaring64NavigableMap expectTaskIds; + private int retryMax; + private long retryIntervalMax; public MemoryClientReadHandler( String appId, @@ -44,13 +47,28 @@ public MemoryClientReadHandler( int partitionId, int readBufferSize, ShuffleServerClient shuffleServerClient, - Roaring64NavigableMap expectTaskIds) { + Roaring64NavigableMap expectTaskIds, + int retryMax, + long retryIntervalMax) { this.appId = appId; this.shuffleId = shuffleId; this.partitionId = partitionId; this.readBufferSize = readBufferSize; this.shuffleServerClient = shuffleServerClient; this.expectTaskIds = expectTaskIds; + this.retryMax = retryMax; + this.retryIntervalMax = retryIntervalMax; + } + + @VisibleForTesting + public MemoryClientReadHandler( + String appId, + int shuffleId, + int partitionId, + int readBufferSize, + ShuffleServerClient shuffleServerClient, + Roaring64NavigableMap expectTaskIds) { + this(appId, shuffleId, partitionId, readBufferSize, shuffleServerClient, expectTaskIds, 1, 0); } @Override @@ -59,7 +77,14 @@ public ShuffleDataResult readShuffleData() { RssGetInMemoryShuffleDataRequest request = new RssGetInMemoryShuffleDataRequest( - appId, shuffleId, partitionId, lastBlockId, readBufferSize, expectTaskIds); + appId, + shuffleId, + partitionId, + lastBlockId, + readBufferSize, + expectTaskIds, + retryMax, + retryIntervalMax); try { RssGetInMemoryShuffleDataResponse response = @@ -70,10 +95,7 @@ public ShuffleDataResult readShuffleData() { } catch (Exception e) { // todo: fault tolerance solution should be added throw new RssFetchFailedException( - "Failed to read in memory shuffle data with " - + shuffleServerClient.getClientInfo() - + " due to " - + e); + "Failed to read in memory shuffle data with " + shuffleServerClient.getClientInfo(), e); } // update lastBlockId for next rpc call diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java index 38c7e9efbd..9b73dc85a4 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java +++ b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java @@ -39,6 +39,8 @@ public class CreateShuffleReadHandlerRequest { private int partitionNumPerRange; private int partitionNum; private int readBufferSize; + private int retryMax; + private long retryIntervalMax; private String storageBasePath; private RssBaseConf rssBaseConf; private Configuration hadoopConf; @@ -129,6 +131,22 @@ public void setReadBufferSize(int readBufferSize) { this.readBufferSize = readBufferSize; } + public int getRetryMax() { + return retryMax; + } + + public void setRetryMax(int retryMax) { + this.retryMax = retryMax; + } + + public long getRetryIntervalMax() { + return retryIntervalMax; + } + + public void setRetryIntervalMax(long retryIntervalMax) { + this.retryIntervalMax = retryIntervalMax; + } + public String getStorageBasePath() { return storageBasePath; } diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java index 884f2b969e..2a55ae4de2 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandlerTest.java @@ -34,7 +34,6 @@ import org.apache.uniffle.client.request.RssGetShuffleDataRequest; import org.apache.uniffle.client.response.RssGetShuffleDataResponse; import org.apache.uniffle.client.response.RssGetShuffleIndexResponse; -import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleDataResult; import org.apache.uniffle.common.ShufflePartitionedBlock; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; @@ -135,9 +134,7 @@ public void testDataInconsistent() throws Exception { readBufferSize, expectBlockIds, processBlockIds, - mockShuffleServerClient, - ShuffleDataDistributionType.NORMAL, - Roaring64NavigableMap.bitmapOf()); + mockShuffleServerClient); int totalSegment = ((blockSize * actualWriteDataBlock) / bytesPerSegment) + 1; int readBlocks = 0; for (int i = 0; i < totalSegment; i++) {