diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java index 15981aace6..044eae0bd5 100644 --- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java +++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java @@ -130,6 +130,9 @@ public static void main(String[] args) { assignmentTags.addAll(Arrays.asList(rawTags.split(","))); } assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION); + String clientType = conf.get(RssMRConfig.RSS_CLIENT_TYPE); + ClientUtils.validateClientType(clientType); + assignmentTags.add(clientType); final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor( new ThreadFactory() { diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 2ecb6f8f09..61e201cbb2 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -263,12 +263,15 @@ public ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff // get all register info according to coordinator's response Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + ClientUtils.validateClientType(clientType); + assignmentTags.add(clientType); int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL); int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES); + Map> partitionToServers; try { partitionToServers = RetryUtils.retry(() -> { diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index e70026ac1a..4a574bc5ec 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -339,6 +339,8 @@ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency< id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient); Set assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf); + ClientUtils.validateClientType(clientType); + assignmentTags.add(clientType); int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf); diff --git a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java index 0bdf7cf00b..d9b51883dc 100644 --- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java +++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java @@ -18,12 +18,16 @@ package org.apache.uniffle.client.util; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import org.apache.uniffle.client.api.ShuffleWriteClient; +import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.util.Constants; import org.apache.uniffle.storage.util.StorageType; @@ -122,4 +126,11 @@ public static void validateTestModeConf(boolean testMode, String storageType) { + "because of the poor performance of these two types."); } } + + public static void validateClientType(String clientType) { + Set types = Arrays.stream(ClientType.values()).map(Enum::name).collect(Collectors.toSet()); + if (!types.contains(clientType)) { + throw new IllegalArgumentException(String.format("The value of %s should be one of %s", clientType, types)); + } + } } diff --git a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java index 77f9cba5c3..577162a465 100644 --- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java +++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java @@ -134,4 +134,17 @@ public void testWaitUntilDoneOrFail() { List> futures3 = getFutures(false); Awaitility.await().timeout(4, TimeUnit.SECONDS).until(() -> waitUntilDoneOrFail(futures3, true)); } + + @Test + public void testValidateClientType() { + String clientType = "GRPC_NETTY"; + ClientUtils.validateClientType(clientType); + clientType = "test"; + try { + ClientUtils.validateClientType(clientType); + fail(); + } catch (Exception e) { + // Ignore + } + } } diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java index c8dccf0831..123dca4d51 100644 --- a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java +++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java @@ -27,6 +27,7 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.hadoop.conf.Configuration; import org.junit.jupiter.api.AfterEach; @@ -36,6 +37,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.ServerStatus; import org.apache.uniffle.coordinator.metric.CoordinatorMetrics; import static org.awaitility.Awaitility.await; @@ -48,6 +51,8 @@ public class SimpleClusterManagerTest { private static final Logger LOG = LoggerFactory.getLogger(SimpleClusterManagerTest.class); private final Set testTags = Sets.newHashSet("test"); + private final Set nettyTags = Sets.newHashSet("test", ClientType.GRPC_NETTY.name()); + private final Set grpcTags = Sets.newHashSet("test", ClientType.GRPC.name()); @BeforeEach public void setUp() { @@ -79,15 +84,15 @@ public void getServerListTest() throws Exception { try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) { ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20, - 10, testTags, true); + 10, grpcTags, true); ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21, - 10, testTags, true); + 10, grpcTags, true); ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20, - 11, testTags, true); + 11, grpcTags, true); clusterManager.add(sn1); clusterManager.add(sn2); clusterManager.add(sn3); - List serverNodes = clusterManager.getServerList(testTags); + List serverNodes = clusterManager.getServerList(grpcTags); assertEquals(3, serverNodes.size()); Set expectedIds = Sets.newHashSet("sn1", "sn2", "sn3"); assertEquals(expectedIds, serverNodes.stream().map(ServerNode::getId).collect(Collectors.toSet())); @@ -98,7 +103,7 @@ public void getServerListTest() throws Exception { sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21, 10, Sets.newHashSet("test", "new_tag"), true); ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20, - 10, testTags, true); + 10, grpcTags, true); clusterManager.add(sn1); clusterManager.add(sn2); clusterManager.add(sn4); @@ -109,7 +114,7 @@ public void getServerListTest() throws Exception { assertTrue(serverNodes.contains(sn4)); Map> tagToNodes = clusterManager.getTagToNodes(); - assertEquals(2, tagToNodes.size()); + assertEquals(3, tagToNodes.size()); Set newTagNodes = tagToNodes.get("new_tag"); assertEquals(2, newTagNodes.size()); @@ -124,6 +129,67 @@ public void getServerListTest() throws Exception { } } + @Test + public void getServerListForNettyTest() throws Exception { + CoordinatorConf ssc = new CoordinatorConf(); + ssc.setLong(CoordinatorConf.COORDINATOR_HEARTBEAT_TIMEOUT, 30 * 1000L); + try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) { + + ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20, + 10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1); + ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21, + 10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1); + ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20, + 11, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1); + ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 50L, 20, + 11, grpcTags, true); + clusterManager.add(sn1); + clusterManager.add(sn2); + clusterManager.add(sn3); + clusterManager.add(sn4); + + List serverNodes2 = clusterManager.getServerList(nettyTags); + assertEquals(3, serverNodes2.size()); + + List serverNodes3 = clusterManager.getServerList(grpcTags); + assertEquals(1, serverNodes3.size()); + + List serverNodes4 = clusterManager.getServerList(testTags); + assertEquals(4, serverNodes4.size()); + + Map> tagToNodes = clusterManager.getTagToNodes(); + assertEquals(3, tagToNodes.size()); + + // tag changes + sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20, + 10, Sets.newHashSet("new_tag"), true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1); + sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21, + 10, Sets.newHashSet("test", "new_tag"), + true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1); + sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20, + 10, grpcTags, true); + clusterManager.add(sn1); + clusterManager.add(sn2); + clusterManager.add(sn4); + Set testTagNodesForNetty = tagToNodes.get(ClientType.GRPC_NETTY.name()); + assertEquals(1, testTagNodesForNetty.size()); + + List serverNodes = clusterManager.getServerList(grpcTags); + assertEquals(1, serverNodes.size()); + assertTrue(serverNodes.contains(sn4)); + + Set newTagNodes = tagToNodes.get("new_tag"); + assertEquals(2, newTagNodes.size()); + assertTrue(newTagNodes.contains(sn1)); + assertTrue(newTagNodes.contains(sn2)); + Set testTagNodes = tagToNodes.get("test"); + assertEquals(3, testTagNodes.size()); + assertTrue(testTagNodes.contains(sn2)); + assertTrue(testTagNodes.contains(sn3)); + assertTrue(testTagNodes.contains(sn4)); + } + } + @Test public void testGetCorrectServerNodesWhenOneNodeRemovedAndUnhealthyNodeFound() throws Exception { CoordinatorConf ssc = new CoordinatorConf(); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java index e3091a4309..b6f4e3ad2e 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java @@ -17,10 +17,12 @@ package org.apache.uniffle.test; +import java.lang.reflect.Field; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -31,10 +33,12 @@ import org.apache.uniffle.client.request.RssGetShuffleAssignmentsRequest; import org.apache.uniffle.client.response.RssApplicationInfoResponse; import org.apache.uniffle.client.response.RssGetShuffleAssignmentsResponse; +import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.ShuffleRegisterInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssBaseConf; +import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.storage.StorageInfo; import org.apache.uniffle.common.storage.StorageMedia; @@ -121,11 +125,28 @@ public void getShuffleRegisterInfoTest() { @Test public void getShuffleAssignmentsTest() throws Exception { - String appId = "getShuffleAssignmentsTest"; + final String appId = "getShuffleAssignmentsTest"; CoordinatorTestUtils.waitForRegister(coordinatorClient,2); + // When the shuffleServerHeartbeat Test is completed before the current test, + // the server's tags will be [ss_v4, GRPC_NETTY] and [ss_v4, GRPC], respectively. + // We need to remove the first machine's tag from GRPC_NETTY to GRPC + shuffleServers.get(0).stopServer(); + RssConf shuffleServerConf = shuffleServers.get(0).getShuffleServerConf(); + Class clazz = RssConf.class; + Field field = clazz.getDeclaredField("settings"); + field.setAccessible(true); + ((ConcurrentHashMap) field.get(shuffleServerConf)).remove(ShuffleServerConf.NETTY_SERVER_PORT.key()); + String storageTypeJsonSource = String.format("{\"%s\": \"ssd\"}", baseDir); + withEnvironmentVariables("RSS_ENV_KEY", storageTypeJsonSource).execute(() -> { + ShuffleServer ss = new ShuffleServer((ShuffleServerConf) shuffleServerConf); + ss.start(); + shuffleServers.set(0, ss); + }); + Thread.sleep(5000); + // add tag when ClientType is `GRPC` RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest( appId, 1, 10, 4, 1, - Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name())); RssGetShuffleAssignmentsResponse response = coordinatorClient.getShuffleAssignments(request); Set expectedStart = Sets.newHashSet(0, 4, 8); @@ -157,7 +178,7 @@ public void getShuffleAssignmentsTest() throws Exception { request = new RssGetShuffleAssignmentsRequest( appId, 1, 10, 4, 2, - Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)); + Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name())); response = coordinatorClient.getShuffleAssignments(request); serverToPartitionRanges = response.getServerToPartitionRanges(); assertEquals(2, serverToPartitionRanges.size()); diff --git a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java index 1ceab307f2..ed22f95375 100644 --- a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java +++ b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java @@ -44,6 +44,8 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; +import org.apache.uniffle.common.ClientType; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -165,6 +167,7 @@ private void runRssApp(Configuration jobConf) throws Exception { jobConf.set(MRJobConfig.MAPREDUCE_APPLICATION_CLASSPATH, "$PWD/rss.jar/" + localFile.getName() + "," + MRJobConfig.DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH); jobConf.set(RssMRConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM); + jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name()); updateRssConfiguration(jobConf); runMRApp(jobConf, getTestTool(), getTestArgs()); diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java index 3a47a81d1f..263884c1ea 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java @@ -35,6 +35,7 @@ import picocli.CommandLine; import org.apache.uniffle.common.Arguments; +import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.ServerStatus; import org.apache.uniffle.common.exception.InvalidRequestException; import org.apache.uniffle.common.metrics.GRPCMetrics; @@ -248,9 +249,18 @@ private void initServerTags() { if (CollectionUtils.isNotEmpty(configuredTags)) { tags.addAll(configuredTags); } + tagServer(); LOG.info("Server tags: {}", tags); } + private void tagServer() { + if (nettyServerEnabled) { + tags.add(ClientType.GRPC_NETTY.name()); + } else { + tags.add(ClientType.GRPC.name()); + } + } + private void registerMetrics() throws Exception { LOG.info("Register metrics"); CollectorRegistry shuffleServerCollectorRegistry = new CollectorRegistry(true);