diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java index b1b9507b39..ed1f90f94d 100644 --- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java +++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java @@ -55,6 +55,11 @@ public class RssMRConfig { MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_DATA_TRANSFER_POOL_SIZE; public static final int RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE = RssClientConfig.RSS_DATA_TRANFER_POOL_SIZE_DEFAULT_VALUE; + public static final String RSS_DATA_COMMIT_POOL_SIZE = + MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_DATA_COMMIT_POOL_SIZE; + public static final int RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE = + RssClientConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE; + public static final String RSS_CLIENT_SEND_THREAD_NUM = MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_SEND_THREAD_NUM; public static final int RSS_CLIENT_DEFAULT_SEND_THREAD_NUM = diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java index 684a5ec883..740de51eed 100644 --- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java +++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRUtils.java @@ -91,11 +91,13 @@ public static ShuffleWriteClient createShuffleClient(JobConf jobConf) { RssMRConfig.RSS_DATA_REPLICA_SKIP_ENABLED_DEFAULT_VALUE); int dataTransferPoolSize = jobConf.getInt(RssMRConfig.RSS_DATA_TRANSFER_POOL_SIZE, RssMRConfig.RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE); + int dataCommitPoolSize = jobConf.getInt(RssMRConfig.RSS_DATA_COMMIT_POOL_SIZE, + RssMRConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE); ShuffleWriteClient client = ShuffleClientFactory .getInstance() .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, replica, replicaWrite, replicaRead, replicaSkipEnabled, - dataTransferPoolSize); + dataTransferPoolSize, dataCommitPoolSize); return client; } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java index 041e21f82a..a9e845a8bc 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java @@ -109,6 +109,11 @@ public class RssSparkConfig { SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_DATA_TRANSFER_POOL_SIZE; public static final int RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE = RssClientConfig.RSS_DATA_TRANFER_POOL_SIZE_DEFAULT_VALUE; + public static final String RSS_DATA_COMMIT_POOL_SIZE = + SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_DATA_COMMIT_POOL_SIZE; + public static final int RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE = + RssClientConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE; + public static final boolean RSS_DATA_REPLICA_SKIP_ENABLED_DEFAULT_VALUE = RssClientConfig.RSS_DATA_REPLICA_SKIP_ENABLED_DEFAULT_VALUE; public static final String RSS_OZONE_DFS_NAMENODE_ODFS_ENABLE = diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 705d66b39f..4198460073 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -81,6 +81,7 @@ public class RssShuffleManager implements ShuffleManager { private final int dataReplicaRead; private final boolean dataReplicaSkipEnabled; private final int dataTransferPoolSize; + private final int dataCommitPoolSize; private boolean heartbeatStarted = false; private boolean dynamicConfEnabled = false; private RemoteStorageInfo remoteStorage; @@ -168,10 +169,13 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE); int heartBeatThreadNum = sparkConf.getInt(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM, RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE); - shuffleWriteClient = ShuffleClientFactory + this.dataCommitPoolSize = sparkConf.getInt(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE, + RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE); + this.shuffleWriteClient = ShuffleClientFactory .getInstance() .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, - dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize); + dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize, + dataCommitPoolSize); registerCoordinator(); // fetch client conf and apply them if necessary and disable ESS if (isDriver && dynamicConfEnabled) { diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index b71f1b3b30..124ae6134b 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -81,6 +81,7 @@ public class RssShuffleManager implements ShuffleManager { private final int dataReplicaRead; private final boolean dataReplicaSkipEnabled; private final int dataTransferPoolSize; + private final int dataCommitPoolSize; private ShuffleWriteClient shuffleWriteClient; private final Map> taskToSuccessBlockIds; private final Map> taskToFailedBlockIds; @@ -169,11 +170,14 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { this.dataTransferPoolSize = sparkConf.getInt(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE, RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE); + this.dataCommitPoolSize = sparkConf.getInt(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE, + RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE); shuffleWriteClient = ShuffleClientFactory .getInstance() .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, - dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize); + dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize, + dataCommitPoolSize); registerCoordinator(); // fetch client conf and apply them if necessary and disable ESS if (isDriver && dynamicConfEnabled) { @@ -238,11 +242,14 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM_DEFAULT_VALUE); this.dataTransferPoolSize = sparkConf.getInt(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE, RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE); + this.dataCommitPoolSize = sparkConf.getInt(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE, + RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE); shuffleWriteClient = ShuffleClientFactory .getInstance() .createShuffleWriteClient(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, - dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize); + dataReplica, dataReplicaWrite, dataReplicaRead, dataReplicaSkipEnabled, dataTransferPoolSize, + dataCommitPoolSize); this.taskToSuccessBlockIds = taskToSuccessBlockIds; this.taskToFailedBlockIds = taskToFailedBlockIds; if (loop != null) { diff --git a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java index 5afa43de12..d11a07f178 100644 --- a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java +++ b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java @@ -36,9 +36,10 @@ public static ShuffleClientFactory getInstance() { public ShuffleWriteClient createShuffleWriteClient( String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum, - int replica, int replicaWrite, int replicaRead, boolean replicaSkipEnabled, int dataTransferPoolSize) { + int replica, int replicaWrite, int replicaRead, boolean replicaSkipEnabled, int dataTransferPoolSize, + int dataCommitPoolSize) { return new ShuffleWriteClientImpl(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, - replica, replicaWrite, replicaRead, replicaSkipEnabled, dataTransferPoolSize); + replica, replicaWrite, replicaRead, replicaSkipEnabled, dataTransferPoolSize, dataCommitPoolSize); } public ShuffleReadClient createShuffleReadClient(CreateShuffleReadClientRequest request) { diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index ce4c24731a..b078d42f80 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -88,11 +88,20 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient { private int replicaRead; private boolean replicaSkipEnabled; private int dataTranferPoolSize; + private int dataCommitPoolSize = -1; private final ForkJoinPool dataTransferPool; - public ShuffleWriteClientImpl(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum, - int replica, int replicaWrite, int replicaRead, boolean replicaSkipEnabled, - int dataTranferPoolSize) { + public ShuffleWriteClientImpl( + String clientType, + int retryMax, + long retryIntervalMax, + int heartBeatThreadNum, + int replica, + int replicaWrite, + int replicaRead, + boolean replicaSkipEnabled, + int dataTranferPoolSize, + int dataCommitPoolSize) { this.clientType = clientType; this.retryMax = retryMax; this.retryIntervalMax = retryIntervalMax; @@ -105,6 +114,7 @@ public ShuffleWriteClientImpl(String clientType, int retryMax, long retryInterva this.replicaSkipEnabled = replicaSkipEnabled; this.dataTranferPoolSize = dataTranferPoolSize; this.dataTransferPool = new ForkJoinPool(dataTranferPoolSize); + this.dataCommitPoolSize = dataCommitPoolSize; } private boolean sendShuffleDataAsync( @@ -247,43 +257,62 @@ public SendShuffleDataResult sendShuffleData(String appId, List shuffleServerInfoSet, String appId, int shuffleId, int numMaps) { + ForkJoinPool forkJoinPool = new ForkJoinPool( + dataCommitPoolSize == -1 ? shuffleServerInfoSet.size() : dataCommitPoolSize + ); AtomicInteger successfulCommit = new AtomicInteger(0); - shuffleServerInfoSet.stream().forEach(ssi -> { - RssSendCommitRequest request = new RssSendCommitRequest(appId, shuffleId); - String errorMsg = "Failed to commit shuffle data to " + ssi + " for shuffleId[" + shuffleId + "]"; - long startTime = System.currentTimeMillis(); - try { - RssSendCommitResponse response = getShuffleServerClient(ssi).sendCommit(request); - if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { - int commitCount = response.getCommitCount(); - LOG.info("Successfully sendCommit for appId[" + appId + "], shuffleId[" + shuffleId - + "] to ShuffleServer[" + ssi.getId() + "], cost " - + (System.currentTimeMillis() - startTime) + " ms, got committed maps[" - + commitCount + "], map number of stage is " + numMaps); - if (commitCount >= numMaps) { - RssFinishShuffleResponse rfsResponse = - getShuffleServerClient(ssi).finishShuffle(new RssFinishShuffleRequest(appId, shuffleId)); - if (rfsResponse.getStatusCode() != ResponseStatusCode.SUCCESS) { - String msg = "Failed to finish shuffle to " + ssi + " for shuffleId[" + shuffleId - + "] with statusCode " + rfsResponse.getStatusCode(); + try { + forkJoinPool.submit(() -> { + shuffleServerInfoSet.parallelStream().forEach(ssi -> { + RssSendCommitRequest request = new RssSendCommitRequest(appId, shuffleId); + String errorMsg = "Failed to commit shuffle data to " + ssi + " for shuffleId[" + shuffleId + "]"; + long startTime = System.currentTimeMillis(); + try { + RssSendCommitResponse response = getShuffleServerClient(ssi).sendCommit(request); + if (response.getStatusCode() == ResponseStatusCode.SUCCESS) { + int commitCount = response.getCommitCount(); + LOG.info("Successfully sendCommit for appId[" + appId + "], shuffleId[" + shuffleId + + "] to ShuffleServer[" + ssi.getId() + "], cost " + + (System.currentTimeMillis() - startTime) + " ms, got committed maps[" + + commitCount + "], map number of stage is " + numMaps); + if (commitCount >= numMaps) { + RssFinishShuffleResponse rfsResponse = + getShuffleServerClient(ssi).finishShuffle(new RssFinishShuffleRequest(appId, shuffleId)); + if (rfsResponse.getStatusCode() != ResponseStatusCode.SUCCESS) { + String msg = "Failed to finish shuffle to " + ssi + " for shuffleId[" + shuffleId + + "] with statusCode " + rfsResponse.getStatusCode(); + LOG.error(msg); + throw new Exception(msg); + } else { + LOG.info("Successfully finish shuffle to " + ssi + " for shuffleId[" + shuffleId + "]"); + } + } + } else { + String msg = errorMsg + " with statusCode " + response.getStatusCode(); LOG.error(msg); throw new Exception(msg); - } else { - LOG.info("Successfully finish shuffle to " + ssi + " for shuffleId[" + shuffleId + "]"); } + successfulCommit.incrementAndGet(); + } catch (Exception e) { + LOG.error(errorMsg, e); } - } else { - String msg = errorMsg + " with statusCode " + response.getStatusCode(); - LOG.error(msg); - throw new Exception(msg); - } - successfulCommit.incrementAndGet(); - } catch (Exception e) { - LOG.error(errorMsg, e); - } - }); + }); + }).join(); + } finally { + forkJoinPool.shutdownNow(); + } + // check if every commit/finish call is successful return successfulCommit.get() == shuffleServerInfoSet.size(); } @@ -508,6 +537,7 @@ public void sendAppHeartbeat(String appId, long timeoutMs) { public void close() { heartBeatExecutorService.shutdownNow(); coordinatorClients.forEach(CoordinatorClient::close); + dataTransferPool.shutdownNow(); } private void throwExceptionIfNecessary(ClientResponse response, String errorMsg) { diff --git a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java index 0c519a8ed6..22f662b67b 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java +++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java @@ -36,6 +36,8 @@ public class RssClientConfig { public static final boolean RSS_DATA_REPLICA_SKIP_ENABLED_DEFAULT_VALUE = true; public static final String RSS_DATA_TRANSFER_POOL_SIZE = "rss.client.data.transfer.pool.size"; public static final int RSS_DATA_TRANFER_POOL_SIZE_DEFAULT_VALUE = Runtime.getRuntime().availableProcessors(); + public static final String RSS_DATA_COMMIT_POOL_SIZE = "rss.client.data.commit.pool.size"; + public static final int RSS_DATA_COMMIT_POOL_SIZE_DEFAULT_VALUE = -1; public static final String RSS_HEARTBEAT_INTERVAL = "rss.heartbeat.interval"; public static final long RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE = 10 * 1000L; public static final String RSS_HEARTBEAT_TIMEOUT = "rss.heartbeat.timeout"; diff --git a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java index 67efea7a9e..d7f539edc6 100644 --- a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java +++ b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java @@ -42,7 +42,7 @@ public class ShuffleWriteClientImplTest { @Test public void testSendData() { ShuffleWriteClientImpl shuffleWriteClient = - new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1); + new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1, 1); ShuffleServerClient mockShuffleServerClient = mock(ShuffleServerClient.class); ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient); doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any()); diff --git a/docs/client_guide.md b/docs/client_guide.md index 14b914d125..216b76a542 100644 --- a/docs/client_guide.md +++ b/docs/client_guide.md @@ -87,6 +87,7 @@ These configurations are shared by all types of clients. |.rss.client.read.buffer.size|14m|The max data size read from storage| |.rss.client.send.threadPool.size|5|The thread size for send shuffle data to shuffle server| |.rss.client.assignment.tags|-|The comma-separated list of tags for deciding assignment shuffle servers. Notice that the SHUFFLE_SERVER_VERSION will always as the assignment tag whether this conf is set or not| +|.rss.client.data.commit.pool.size|The number of assigned shuffle servers|The thread size for sending commit to shuffle servers| Notice: 1. `` should be `spark` or `mapreduce` diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java index 416af72a99..9ab84d4604 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java @@ -147,7 +147,7 @@ public static void setupServers() throws Exception { @Test public void testTags() throws Exception { ShuffleWriteClientImpl shuffleWriteClient = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, - 1, 1, 1, true, 1); + 1, 1, 1, true, 1, 1); shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM); // Case1 : only set the single default shuffle version tag diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java index 0148f69869..3dc71f4160 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java @@ -258,7 +258,7 @@ private void registerShuffleServer(String testAppId, int replica, int replicaWrite, int replicaRead, boolean replicaSkip) { shuffleWriteClientImpl = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, - replica, replicaWrite, replicaRead, replicaSkip, 1); + replica, replicaWrite, replicaRead, replicaSkip, 1, 1); List allServers = Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1, shuffleServerInfo2, shuffleServerInfo3, shuffleServerInfo4); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java index 4339239405..a1cc6a1cf3 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java @@ -101,7 +101,7 @@ public void createClient() { public void clearResourceTest() throws Exception { final ShuffleWriteClient shuffleWriteClient = ShuffleClientFactory.getInstance().createShuffleWriteClient( - "GRPC", 2, 10000L, 4, 1, 1, 1, true, 1); + "GRPC", 2, 10000L, 4, 1, 1, 1, true, 1, 1); shuffleWriteClient.registerCoordinators("127.0.0.1:19999"); shuffleWriteClient.registerShuffle( new ShuffleServerInfo("127.0.0.1-20001", "127.0.0.1", 20001), diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java index 53f28697ca..f2e35c1280 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java @@ -88,7 +88,7 @@ public static void setupServers() throws Exception { @BeforeEach public void createClient() { shuffleWriteClientImpl = new ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1, - 1, 1, 1, true, 1); + 1, 1, 1, true, 1, 1); } @AfterEach