[#719] feat(netty): Optimize allocation strategy (#739)
### What changes were proposed in this pull request?
Users can choose between Netty and gRPC as the shuffle data transfer method through client configuration.

### Why are the changes needed?
Fix: #719 

### Does this PR introduce _any_ user-facing change?
No. However, users who want to use Netty as the data transfer method need to set `spark.rss.client.type=GRPC_NETTY` (Spark) or `mapreduce.rss.client.type=GRPC_NETTY` (MapReduce).
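
For illustration, a minimal sketch of how a Spark job might opt in (the property key comes from this PR; the class name here is just an example):

```java
import org.apache.spark.SparkConf;

public class NettyClientTypeExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf();
    // Switch shuffle data transfer from the default gRPC path to Netty.
    conf.set("spark.rss.client.type", "GRPC_NETTY");
    // A MapReduce job would set the equivalent key on its job configuration:
    // jobConf.set("mapreduce.rss.client.type", "GRPC_NETTY");
  }
}
```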

### How was this patch tested?
New unit tests.
smallzhongfeng committed Mar 20, 2023
1 parent d60d675 commit 2cb22ff
Showing 9 changed files with 141 additions and 9 deletions.
@@ -130,6 +130,9 @@ public static void main(String[] args) {
       assignmentTags.addAll(Arrays.asList(rawTags.split(",")));
     }
     assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
+    String clientType = conf.get(RssMRConfig.RSS_CLIENT_TYPE);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
 
     final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
         new ThreadFactory() {
@@ -263,12 +263,15 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff
 
     // get all register info according to coordinator's response
     Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
+
     int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
 
     // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe it will return the same result
     long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
     int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
 
     Map<Integer, List<ShuffleServerInfo>> partitionToServers;
     try {
       partitionToServers = RetryUtils.retry(() -> {
@@ -339,6 +339,8 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
         id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);
 
     Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
 
     int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
 
@@ -18,12 +18,16 @@
 package org.apache.uniffle.client.util;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.storage.util.StorageType;
@@ -122,4 +126,11 @@ public static void validateTestModeConf(boolean testMode, String storageType) {
           + "because of the poor performance of these two types.");
     }
   }
+
+  public static void validateClientType(String clientType) {
+    Set<String> types = Arrays.stream(ClientType.values()).map(Enum::name).collect(Collectors.toSet());
+    if (!types.contains(clientType)) {
+      throw new IllegalArgumentException(String.format("The value of %s should be one of %s", clientType, types));
+    }
+  }
 }
@@ -134,4 +134,17 @@ public void testWaitUntilDoneOrFail() {
     List<CompletableFuture<Boolean>> futures3 = getFutures(false);
     Awaitility.await().timeout(4, TimeUnit.SECONDS).until(() -> waitUntilDoneOrFail(futures3, true));
   }
+
+  @Test
+  public void testValidateClientType() {
+    String clientType = "GRPC_NETTY";
+    ClientUtils.validateClientType(clientType);
+    clientType = "test";
+    try {
+      ClientUtils.validateClientType(clientType);
+      fail();
+    } catch (Exception e) {
+      // expected: "test" is not a valid client type
+    }
+  }
 }
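
As a side note, since these tests already use JUnit 5 (`org.junit.jupiter`), the negative case could equally be written with `assertThrows`; a minimal sketch, not part of this commit:

```java
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;

import org.apache.uniffle.client.util.ClientUtils;

class ValidateClientTypeSketch {
  @Test
  void rejectsUnknownClientTypes() {
    // A name defined in ClientType passes validation.
    assertDoesNotThrow(() -> ClientUtils.validateClientType("GRPC_NETTY"));
    // Any other string is rejected with an IllegalArgumentException.
    assertThrows(IllegalArgumentException.class, () -> ClientUtils.validateClientType("test"));
  }
}
```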
@@ -27,6 +27,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import org.apache.hadoop.conf.Configuration;
 import org.junit.jupiter.api.AfterEach;
@@ -36,6 +37,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.coordinator.metric.CoordinatorMetrics;
 
 import static org.awaitility.Awaitility.await;
@@ -48,6 +51,8 @@ public class SimpleClusterManagerTest {
   private static final Logger LOG = LoggerFactory.getLogger(SimpleClusterManagerTest.class);
 
   private final Set<String> testTags = Sets.newHashSet("test");
+  private final Set<String> nettyTags = Sets.newHashSet("test", ClientType.GRPC_NETTY.name());
+  private final Set<String> grpcTags = Sets.newHashSet("test", ClientType.GRPC.name());
 
   @BeforeEach
   public void setUp() {
@@ -79,15 +84,15 @@ public void getServerListTest() throws Exception {
     try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) {
 
       ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
-          10, testTags, true);
+          10, grpcTags, true);
       ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
-          10, testTags, true);
+          10, grpcTags, true);
       ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
-          11, testTags, true);
+          11, grpcTags, true);
       clusterManager.add(sn1);
       clusterManager.add(sn2);
       clusterManager.add(sn3);
-      List<ServerNode> serverNodes = clusterManager.getServerList(testTags);
+      List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
       assertEquals(3, serverNodes.size());
       Set<String> expectedIds = Sets.newHashSet("sn1", "sn2", "sn3");
       assertEquals(expectedIds, serverNodes.stream().map(ServerNode::getId).collect(Collectors.toSet()));
@@ -98,7 +103,7 @@ public void getServerListTest() throws Exception {
       sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
           10, Sets.newHashSet("test", "new_tag"), true);
       ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
-          10, testTags, true);
+          10, grpcTags, true);
       clusterManager.add(sn1);
       clusterManager.add(sn2);
       clusterManager.add(sn4);
@@ -109,7 +114,7 @@ public void getServerListTest() throws Exception {
       assertTrue(serverNodes.contains(sn4));
 
       Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
-      assertEquals(2, tagToNodes.size());
+      assertEquals(3, tagToNodes.size());
 
       Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
       assertEquals(2, newTagNodes.size());
@@ -124,6 +129,67 @@ public void getServerListTest() throws Exception {
     }
   }
 
+  @Test
+  public void getServerListForNettyTest() throws Exception {
+    CoordinatorConf ssc = new CoordinatorConf();
+    ssc.setLong(CoordinatorConf.COORDINATOR_HEARTBEAT_TIMEOUT, 30 * 1000L);
+    try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) {
+
+      ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
+          10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+      ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
+          10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+      ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
+          11, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+      ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 50L, 20,
+          11, grpcTags, true);
+      clusterManager.add(sn1);
+      clusterManager.add(sn2);
+      clusterManager.add(sn3);
+      clusterManager.add(sn4);
+
+      List<ServerNode> serverNodes2 = clusterManager.getServerList(nettyTags);
+      assertEquals(3, serverNodes2.size());
+
+      List<ServerNode> serverNodes3 = clusterManager.getServerList(grpcTags);
+      assertEquals(1, serverNodes3.size());
+
+      List<ServerNode> serverNodes4 = clusterManager.getServerList(testTags);
+      assertEquals(4, serverNodes4.size());
+
+      Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
+      assertEquals(3, tagToNodes.size());
+
+      // tag changes
+      sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
+          10, Sets.newHashSet("new_tag"), true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+      sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
+          10, Sets.newHashSet("test", "new_tag"),
+          true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+      sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
+          10, grpcTags, true);
+      clusterManager.add(sn1);
+      clusterManager.add(sn2);
+      clusterManager.add(sn4);
+      Set<ServerNode> testTagNodesForNetty = tagToNodes.get(ClientType.GRPC_NETTY.name());
+      assertEquals(1, testTagNodesForNetty.size());
+
+      List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
+      assertEquals(1, serverNodes.size());
+      assertTrue(serverNodes.contains(sn4));
+
+      Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
+      assertEquals(2, newTagNodes.size());
+      assertTrue(newTagNodes.contains(sn1));
+      assertTrue(newTagNodes.contains(sn2));
+      Set<ServerNode> testTagNodes = tagToNodes.get("test");
+      assertEquals(3, testTagNodes.size());
+      assertTrue(testTagNodes.contains(sn2));
+      assertTrue(testTagNodes.contains(sn3));
+      assertTrue(testTagNodes.contains(sn4));
+    }
+  }
+
   @Test
   public void testGetCorrectServerNodesWhenOneNodeRemovedAndUnhealthyNodeFound() throws Exception {
     CoordinatorConf ssc = new CoordinatorConf();
@@ -17,10 +17,12 @@
 
 package org.apache.uniffle.test;
 
+import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
@@ -31,10 +33,12 @@
 import org.apache.uniffle.client.request.RssGetShuffleAssignmentsRequest;
 import org.apache.uniffle.client.response.RssApplicationInfoResponse;
 import org.apache.uniffle.client.response.RssGetShuffleAssignmentsResponse;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.ShuffleRegisterInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.storage.StorageInfo;
 import org.apache.uniffle.common.storage.StorageMedia;
@@ -121,11 +125,28 @@ public void getShuffleRegisterInfoTest() {

   @Test
   public void getShuffleAssignmentsTest() throws Exception {
-    String appId = "getShuffleAssignmentsTest";
+    final String appId = "getShuffleAssignmentsTest";
     CoordinatorTestUtils.waitForRegister(coordinatorClient,2);
+    // If shuffleServerHeartbeatTest ran before this test, the two servers' tags will be
+    // [ss_v4, GRPC_NETTY] and [ss_v4, GRPC], respectively. We need to change the first
+    // server's tag from GRPC_NETTY to GRPC, so restart it without the Netty port configured.
+    shuffleServers.get(0).stopServer();
+    RssConf shuffleServerConf = shuffleServers.get(0).getShuffleServerConf();
+    Class<RssConf> clazz = RssConf.class;
+    Field field = clazz.getDeclaredField("settings");
+    field.setAccessible(true);
+    ((ConcurrentHashMap) field.get(shuffleServerConf)).remove(ShuffleServerConf.NETTY_SERVER_PORT.key());
+    String storageTypeJsonSource = String.format("{\"%s\": \"ssd\"}", baseDir);
+    withEnvironmentVariables("RSS_ENV_KEY", storageTypeJsonSource).execute(() -> {
+      ShuffleServer ss = new ShuffleServer((ShuffleServerConf) shuffleServerConf);
+      ss.start();
+      shuffleServers.set(0, ss);
+    });
+    Thread.sleep(5000);
+    // add the GRPC tag, matching what the client sends when ClientType is GRPC
     RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest(
         appId, 1, 10, 4, 1,
-        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name()));
     RssGetShuffleAssignmentsResponse response = coordinatorClient.getShuffleAssignments(request);
     Set<Integer> expectedStart = Sets.newHashSet(0, 4, 8);

@@ -157,7 +178,7 @@ public void getShuffleAssignmentsTest() throws Exception {
 
     request = new RssGetShuffleAssignmentsRequest(
         appId, 1, 10, 4, 2,
-        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name()));
     response = coordinatorClient.getShuffleAssignments(request);
     serverToPartitionRanges = response.getServerToPartitionRanges();
     assertEquals(2, serverToPartitionRanges.size());
@@ -44,6 +44,8 @@
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 
+import org.apache.uniffle.common.ClientType;
+
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 
@@ -165,6 +167,7 @@ private void runRssApp(Configuration jobConf) throws Exception {
     jobConf.set(MRJobConfig.MAPREDUCE_APPLICATION_CLASSPATH,
         "$PWD/rss.jar/" + localFile.getName() + "," + MRJobConfig.DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH);
     jobConf.set(RssMRConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM);
+    jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
     updateRssConfiguration(jobConf);
     runMRApp(jobConf, getTestTool(), getTestArgs());
 
server/src/main/java/org/apache/uniffle/server/ShuffleServer.java (10 additions, 0 deletions)
@@ -35,6 +35,7 @@
 import picocli.CommandLine;
 
 import org.apache.uniffle.common.Arguments;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.common.exception.InvalidRequestException;
 import org.apache.uniffle.common.metrics.GRPCMetrics;
@@ -248,9 +249,18 @@ private void initServerTags() {
     if (CollectionUtils.isNotEmpty(configuredTags)) {
       tags.addAll(configuredTags);
     }
+    tagServer();
     LOG.info("Server tags: {}", tags);
   }
+
+  private void tagServer() {
+    if (nettyServerEnabled) {
+      tags.add(ClientType.GRPC_NETTY.name());
+    } else {
+      tags.add(ClientType.GRPC.name());
+    }
+  }
 
   private void registerMetrics() throws Exception {
     LOG.info("Register metrics");
     CollectorRegistry shuffleServerCollectorRegistry = new CollectorRegistry(true);
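
Taken together, the mechanism is: the shuffle server advertises `GRPC_NETTY` or `GRPC` as a tag (see `tagServer()` above), the client appends the same tag to its assignment request, and the coordinator only assigns servers whose tag set contains every requested tag. A minimal sketch of that matching rule, with hypothetical `Node`/`getServerList` stand-ins rather than the real `SimpleClusterManager` code:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class TagMatchingSketch {
  static class Node {
    final String id;
    final Set<String> tags;

    Node(String id, Set<String> tags) {
      this.id = id;
      this.tags = tags;
    }
  }

  // A server is eligible only if it carries every tag the client requested,
  // e.g. {SHUFFLE_SERVER_VERSION, "GRPC_NETTY"} for a Netty-enabled client.
  static List<Node> getServerList(List<Node> nodes, Set<String> requiredTags) {
    List<Node> result = new ArrayList<>();
    for (Node node : nodes) {
      if (node.tags.containsAll(requiredTags)) {
        result.add(node);
      }
    }
    return result;
  }

  public static void main(String[] args) {
    Node grpcNode = new Node("sn1", new HashSet<>(Arrays.asList("ss_v4", "GRPC")));
    Node nettyNode = new Node("sn2", new HashSet<>(Arrays.asList("ss_v4", "GRPC_NETTY")));
    // A GRPC_NETTY client only sees nettyNode; a GRPC client only sees grpcNode.
    List<Node> matched = getServerList(Arrays.asList(grpcNode, nettyNode),
        new HashSet<>(Arrays.asList("ss_v4", "GRPC_NETTY")));
    System.out.println(matched.get(0).id); // prints sn2
  }
}
```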
