Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#719] feat(netty): Optimize allocation strategy #739

Merged
merged 6 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ public static void main(String[] args) {
assignmentTags.addAll(Arrays.asList(rawTags.split(",")));
}
assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
String clientType = conf.get(RssMRConfig.RSS_CLIENT_TYPE);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);

final ScheduledExecutorService scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
new ThreadFactory() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,15 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff

// get all register info according to coordinator's response
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);

int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);

// retryInterval must be bigger than `rss.server.heartbeat.timeout`; otherwise a retry may return the same result
long retryInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);

Map<Integer, List<ShuffleServerInfo>> partitionToServers;
try {
partitionToServers = RetryUtils.retry(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<
id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, shuffleWriteClient);

Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
ClientUtils.validateClientType(clientType);
assignmentTags.add(clientType);

int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
package org.apache.uniffle.client.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.storage.util.StorageType;
Expand Down Expand Up @@ -122,4 +126,11 @@ public static void validateTestModeConf(boolean testMode, String storageType) {
+ "because of the poor performance of these two types.");
}
}

/**
 * Validates that {@code clientType} is the name of one of the {@link ClientType} enum constants.
 *
 * @param clientType the client type string to check
 * @throws IllegalArgumentException if {@code clientType} is not the name of any {@link ClientType} constant
 */
public static void validateClientType(String clientType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we convert this method to CheckValue#checkValueFunc like ConfigUtil#POSITIVE_LONG_VALIDATOR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, i will try it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem to work because there is no checkValue on mr's client side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it. My mistake.

// Collect the names of all ClientType constants; the check below is an exact string match against them.
Set<String> types = Arrays.stream(ClientType.values()).map(Enum::name).collect(Collectors.toSet());
if (!types.contains(clientType)) {
throw new IllegalArgumentException(String.format("The value of %s should be one of %s", clientType, types));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,17 @@ public void testWaitUntilDoneOrFail() {
List<CompletableFuture<Boolean>> futures3 = getFutures(false);
Awaitility.await().timeout(4, TimeUnit.SECONDS).until(() -> waitUntilDoneOrFail(futures3, true));
}

/**
 * Verifies that {@code ClientUtils.validateClientType} accepts the known type name
 * {@code "GRPC_NETTY"} and rejects the unknown name {@code "test"} with an
 * {@link IllegalArgumentException}.
 */
@Test
public void testValidateClientType() {
  // A name that matches a ClientType constant must pass without throwing.
  ClientUtils.validateClientType("GRPC_NETTY");

  final String unknownType = "test";
  try {
    ClientUtils.validateClientType(unknownType);
    fail("validateClientType should reject unknown client type: " + unknownType);
  } catch (IllegalArgumentException expected) {
    // Expected: only IllegalArgumentException counts as a rejection. Any other
    // exception type propagates and fails the test instead of being swallowed.
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;

import org.apache.uniffle.common.util.UnitConverter;
Expand Down Expand Up @@ -645,4 +646,8 @@ public String getEnv(String key) {
return System.getenv(key);
}

/**
 * Returns the {@code settings} map held by this configuration object.
 *
 * <p>The field itself is returned, not a copy, so modifications made through the
 * returned map act on this object's settings. Marked {@code @VisibleForTesting}.
 *
 * @return the {@code settings} map
 */
@VisibleForTesting
public Map<String, Object> getSettings() {
smallzhongfeng marked this conversation as resolved.
Show resolved Hide resolved
return settings;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -36,6 +37,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ServerStatus;
import org.apache.uniffle.coordinator.metric.CoordinatorMetrics;

import static org.awaitility.Awaitility.await;
Expand All @@ -48,6 +51,8 @@ public class SimpleClusterManagerTest {
private static final Logger LOG = LoggerFactory.getLogger(SimpleClusterManagerTest.class);

private final Set<String> testTags = Sets.newHashSet("test");
private final Set<String> nettyTags = Sets.newHashSet("test", ClientType.GRPC_NETTY.name());
private final Set<String> grpcTags = Sets.newHashSet("test", ClientType.GRPC.name());

@BeforeEach
public void setUp() {
Expand Down Expand Up @@ -79,15 +84,15 @@ public void getServerListTest() throws Exception {
try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) {

ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
10, testTags, true);
10, grpcTags, true);
ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
10, testTags, true);
10, grpcTags, true);
ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
11, testTags, true);
11, grpcTags, true);
clusterManager.add(sn1);
clusterManager.add(sn2);
clusterManager.add(sn3);
List<ServerNode> serverNodes = clusterManager.getServerList(testTags);
List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
assertEquals(3, serverNodes.size());
Set<String> expectedIds = Sets.newHashSet("sn1", "sn2", "sn3");
assertEquals(expectedIds, serverNodes.stream().map(ServerNode::getId).collect(Collectors.toSet()));
Expand All @@ -98,7 +103,7 @@ public void getServerListTest() throws Exception {
sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
10, Sets.newHashSet("test", "new_tag"), true);
ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
10, testTags, true);
10, grpcTags, true);
clusterManager.add(sn1);
clusterManager.add(sn2);
clusterManager.add(sn4);
Expand All @@ -109,7 +114,7 @@ public void getServerListTest() throws Exception {
assertTrue(serverNodes.contains(sn4));

Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
assertEquals(2, tagToNodes.size());
assertEquals(3, tagToNodes.size());

Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
assertEquals(2, newTagNodes.size());
Expand All @@ -124,6 +129,67 @@ public void getServerListTest() throws Exception {
}
}

/**
 * Checks tag-based server lookup in {@code SimpleClusterManager} when nodes carry a
 * client-type tag (GRPC_NETTY or GRPC) in addition to the common "test" tag, both
 * before and after some nodes are re-added with different tag sets.
 */
@Test
public void getServerListForNettyTest() throws Exception {
CoordinatorConf ssc = new CoordinatorConf();
ssc.setLong(CoordinatorConf.COORDINATOR_HEARTBEAT_TIMEOUT, 30 * 1000L);
try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, new Configuration())) {

// sn1-sn3 are created with nettyTags {"test", GRPC_NETTY}; sn4 with grpcTags {"test", GRPC}.
ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
11, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 50L, 20,
11, grpcTags, true);
clusterManager.add(sn1);
clusterManager.add(sn2);
clusterManager.add(sn3);
clusterManager.add(sn4);

// Requesting the netty tag set is expected to yield three nodes.
List<ServerNode> serverNodes2 = clusterManager.getServerList(nettyTags);
assertEquals(3, serverNodes2.size());

// Requesting the grpc tag set is expected to yield one node.
List<ServerNode> serverNodes3 = clusterManager.getServerList(grpcTags);
assertEquals(1, serverNodes3.size());

// Requesting only "test" is expected to yield all four nodes.
List<ServerNode> serverNodes4 = clusterManager.getServerList(testTags);
assertEquals(4, serverNodes4.size());

// Three distinct tags have been used so far: "test", GRPC_NETTY and GRPC.
Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
assertEquals(3, tagToNodes.size());

// tag changes
// sn1, sn2 and sn4 are re-added under the same ids with new tag sets; sn3 is not re-added.
sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
10, Sets.newHashSet("new_tag"), true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
10, Sets.newHashSet("test", "new_tag"),
true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
10, grpcTags, true);
clusterManager.add(sn1);
clusterManager.add(sn2);
clusterManager.add(sn4);
// NOTE: tagToNodes was fetched before the re-adds above; the assertions below rely on
// that same map reflecting the later add() calls.
// sn1 and sn2 no longer carry GRPC_NETTY, so one node (presumably sn3) should remain under it.
Set<ServerNode> testTagNodesForNetty = tagToNodes.get(ClientType.GRPC_NETTY.name());
assertEquals(1, testTagNodesForNetty.size());

// The grpc tag set should still resolve to the (re-added) sn4 only.
List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
assertEquals(1, serverNodes.size());
assertTrue(serverNodes.contains(sn4));

// "new_tag" was introduced by the re-added sn1 and sn2.
Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
assertEquals(2, newTagNodes.size());
assertTrue(newTagNodes.contains(sn1));
assertTrue(newTagNodes.contains(sn2));
// sn1 dropped "test", leaving sn2, sn3 and sn4 under that tag.
Set<ServerNode> testTagNodes = tagToNodes.get("test");
assertEquals(3, testTagNodes.size());
assertTrue(testTagNodes.contains(sn2));
assertTrue(testTagNodes.contains(sn3));
assertTrue(testTagNodes.contains(sn4));
}
}

@Test
public void testGetCorrectServerNodesWhenOneNodeRemovedAndUnhealthyNodeFound() throws Exception {
CoordinatorConf ssc = new CoordinatorConf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.uniffle.test;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -31,10 +32,12 @@
import org.apache.uniffle.client.request.RssGetShuffleAssignmentsRequest;
import org.apache.uniffle.client.response.RssApplicationInfoResponse;
import org.apache.uniffle.client.response.RssGetShuffleAssignmentsResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ShuffleRegisterInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.storage.StorageInfo;
import org.apache.uniffle.common.storage.StorageMedia;
Expand Down Expand Up @@ -121,11 +124,30 @@ public void getShuffleRegisterInfoTest() {

@Test
public void getShuffleAssignmentsTest() throws Exception {
String appId = "getShuffleAssignmentsTest";
final String appId = "getShuffleAssignmentsTest";
CoordinatorTestUtils.waitForRegister(coordinatorClient,2);
// When the shuffleServerHeartbeat Test is completed before the current test,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shuffleServerConf.setInteger(ShuffleServerConf.NETTY_SERVER_PORT, SHUFFLE_SERVER_PORT + 5);

Because this shuffleServerConf modifies the Netty port number of one server, one fewer machine will be allocated. Therefore, to keep the original test working, we need to set the port number back to a negative number, which is equivalent to setting the label to GRPC.

// the server's tags will be [ss_v4, GRPC_NETTY] and [ss_v4, GRPC], respectively.
// We need to change the first machine's tag from GRPC_NETTY to GRPC
shuffleServers.get(0).stopServer();
RssConf shuffleServerConf = shuffleServers.get(0).getShuffleServerConf();
Map<String, Object> originConf = shuffleServerConf.getSettings();
Class<RssConf> clazz = RssConf.class;
Field field = clazz.getDeclaredField("settings");
field.setAccessible(true);
originConf.remove(ShuffleServerConf.NETTY_SERVER_PORT.key());
field.set(shuffleServerConf, originConf);
Copy link
Contributor

@jerqi jerqi Mar 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to set the value again; they are the same object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it.

String storageTypeJsonSource = String.format("{\"%s\": \"ssd\"}", baseDir);
withEnvironmentVariables("RSS_ENV_KEY", storageTypeJsonSource).execute(() -> {
ShuffleServer ss = new ShuffleServer((ShuffleServerConf) shuffleServerConf);
ss.start();
shuffleServers.set(0, ss);
});
Thread.sleep(5000);
// add tag when ClientType is `GRPC`
RssGetShuffleAssignmentsRequest request = new RssGetShuffleAssignmentsRequest(
appId, 1, 10, 4, 1,
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name()));
RssGetShuffleAssignmentsResponse response = coordinatorClient.getShuffleAssignments(request);
Set<Integer> expectedStart = Sets.newHashSet(0, 4, 8);

Expand Down Expand Up @@ -157,7 +179,7 @@ public void getShuffleAssignmentsTest() throws Exception {

request = new RssGetShuffleAssignmentsRequest(
appId, 1, 10, 4, 2,
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, ClientType.GRPC.name()));
response = coordinatorClient.getShuffleAssignments(request);
serverToPartitionRanges = response.getServerToPartitionRanges();
assertEquals(2, serverToPartitionRanges.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;

import org.apache.uniffle.common.ClientType;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

Expand Down Expand Up @@ -165,6 +167,7 @@ private void runRssApp(Configuration jobConf) throws Exception {
jobConf.set(MRJobConfig.MAPREDUCE_APPLICATION_CLASSPATH,
"$PWD/rss.jar/" + localFile.getName() + "," + MRJobConfig.DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH);
jobConf.set(RssMRConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM);
jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
updateRssConfiguration(jobConf);
runMRApp(jobConf, getTestTool(), getTestArgs());

Expand Down
10 changes: 10 additions & 0 deletions server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import picocli.CommandLine;

import org.apache.uniffle.common.Arguments;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ServerStatus;
import org.apache.uniffle.common.exception.InvalidRequestException;
import org.apache.uniffle.common.metrics.GRPCMetrics;
Expand Down Expand Up @@ -248,9 +249,18 @@ private void initServerTags() {
if (CollectionUtils.isNotEmpty(configuredTags)) {
tags.addAll(configuredTags);
}
tagServer();
LOG.info("Server tags: {}", tags);
}

/**
 * Adds exactly one client-type tag to {@code tags}: {@code GRPC_NETTY} when
 * {@code nettyServerEnabled} is true, otherwise {@code GRPC}.
 */
private void tagServer() {
  final ClientType supportedType = nettyServerEnabled ? ClientType.GRPC_NETTY : ClientType.GRPC;
  tags.add(supportedType.name());
}

private void registerMetrics() throws Exception {
LOG.info("Register metrics");
CollectorRegistry shuffleServerCollectorRegistry = new CollectorRegistry(true);
Expand Down