Fix integration test
rickyma committed Apr 17, 2024
1 parent 84dda59 commit cfb02b9
Showing 6 changed files with 56 additions and 25 deletions.
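
In short: the Spark integration tests previously exercised only the GRPC client type. Judging from the diff below, this commit has the test base start a second shuffle server of ServerType.GRPC_NETTY with its own storage directories, renames updateSparkConfWithRss to updateSparkConfWithRssGrpc, pins RSS_CLIENT_TYPE explicitly via the ClientType enum in both helpers, runs and verifies each Spark app under both transports, and fixes the enableDynamicCLientConf typo along the way.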
File 1 of 6:
@@ -47,14 +47,25 @@ public static void setupServers(@TempDir File tmpDir) throws Exception {
     dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name());
     addDynamicConf(coordinatorConf, dynamicConf);
     createCoordinatorServer(coordinatorConf);
-    ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC);
+
+    ShuffleServerConf grpcShuffleServerConf = getShuffleServerConf(ServerType.GRPC);
     File dataDir1 = new File(tmpDir, "data1");
     File dataDir2 = new File(tmpDir, "data2");
-    String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath();
-    shuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name());
-    shuffleServerConf.setBoolean(ShuffleServerConf.RSS_TEST_MODE_ENABLE, true);
-    shuffleServerConf.setString("rss.storage.basePath", basePath);
-    createShuffleServer(shuffleServerConf);
+    String grpcBasePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath();
+    grpcShuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name());
+    grpcShuffleServerConf.setBoolean(ShuffleServerConf.RSS_TEST_MODE_ENABLE, true);
+    grpcShuffleServerConf.setString("rss.storage.basePath", grpcBasePath);
+    createShuffleServer(grpcShuffleServerConf);
+
+    ShuffleServerConf nettyShuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY);
+    File dataDir3 = new File(tmpDir, "data3");
+    File dataDir4 = new File(tmpDir, "data4");
+    String nettyBasePath = dataDir3.getAbsolutePath() + "," + dataDir4.getAbsolutePath();
+    nettyShuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name());
+    nettyShuffleServerConf.setBoolean(ShuffleServerConf.RSS_TEST_MODE_ENABLE, true);
+    nettyShuffleServerConf.setString("rss.storage.basePath", nettyBasePath);
+    createShuffleServer(nettyShuffleServerConf);

     startServers();
   }

@@ -76,13 +87,19 @@ public void run() throws Exception {
     Map resultWithoutRss = runSparkApp(sparkConf, fileName);
     results.add(resultWithoutRss);

-    updateSparkConfWithRss(sparkConf);
+    updateSparkConfWithRssGrpc(sparkConf);
     updateSparkConfCustomer(sparkConf);
     for (Codec.Type type : new Codec.Type[] {Codec.Type.NOOP, Codec.Type.ZSTD, Codec.Type.LZ4}) {
       sparkConf.set("spark." + COMPRESSION_TYPE.key().toLowerCase(), type.name());
       Map resultWithRss = runSparkApp(sparkConf, fileName);
       results.add(resultWithRss);
     }
+    updateSparkConfWithRssNetty(sparkConf);
+    for (Codec.Type type : new Codec.Type[] {Codec.Type.NOOP, Codec.Type.ZSTD, Codec.Type.LZ4}) {
+      sparkConf.set("spark." + COMPRESSION_TYPE.key().toLowerCase(), type.name());
+      Map resultWithRss = runSparkApp(sparkConf, fileName);
+      results.add(resultWithRss);
+    }

     for (int i = 1; i < results.size(); i++) {
       verifyTestResult(results.get(0), results.get(i));
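
Note that each server gets its own pair of local storage directories (data1/data2 for the GRPC server, data3/data4 for the GRPC_NETTY server), so the two servers never share a rss.storage.basePath.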
File 2 of 6:
@@ -63,7 +63,7 @@ public static void setupServers(@TempDir File tmpDir) throws Exception {
   public void testMemoryRelease() throws Exception {
     final String fileName = generateTextFile(10000, 10000);
     SparkConf sparkConf = createSparkConf();
-    updateSparkConfWithRss(sparkConf);
+    updateSparkConfWithRssGrpc(sparkConf);
     sparkConf.set("spark.executor.memory", "500m");
     sparkConf.set("spark.unsafe.exceptionOnMemoryLeak", "true");
     updateRssStorage(sparkConf);
File 3 of 6:
@@ -141,16 +141,16 @@ private void doTestRssShuffleManager(
       BlockIdLayout clientConfLayout,
       BlockIdLayout dynamicConfLayout,
       BlockIdLayout expectedLayout,
-      boolean enableDynamicCLientConf)
+      boolean enableDynamicClientConf)
       throws Exception {
     Map<String, String> dynamicConf = startServers(dynamicConfLayout);

     SparkConf conf = createSparkConf();
-    updateSparkConfWithRss(conf);
+    updateSparkConfWithRssGrpc(conf);
     // enable stage recompute
     conf.set("spark." + RssClientConfig.RSS_RESUBMIT_STAGE, "true");
     // enable dynamic client conf
-    conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED, enableDynamicCLientConf);
+    conf.set(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED, enableDynamicClientConf);
     // configure storage type
     conf.set("spark." + RssClientConfig.RSS_STORAGE_TYPE, StorageType.MEMORY_LOCALFILE.name());
     // restarting the coordinator may cause RssException: There isn't enough shuffle servers
File 4 of 6:
@@ -29,6 +29,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import org.apache.uniffle.common.ClientType;
+
 import static org.junit.jupiter.api.Assertions.assertEquals;

 public abstract class SparkIntegrationTestBase extends IntegrationTestBase {
@@ -54,24 +56,24 @@ public void run() throws Exception {
     final long durationWithoutRss = System.currentTimeMillis() - start;

     Uninterruptibles.sleepUninterruptibly(2, TimeUnit.SECONDS);
-    updateSparkConfWithRss(sparkConf);
+    updateSparkConfWithRssGrpc(sparkConf);
     updateSparkConfCustomer(sparkConf);
     start = System.currentTimeMillis();
-    Map resultWithRss = runSparkApp(sparkConf, fileName);
-    final long durationWithRss = System.currentTimeMillis() - start;
+    Map resultWithRssGrpc = runSparkApp(sparkConf, fileName);
+    final long durationWithRssGrpc = System.currentTimeMillis() - start;

-    verifyTestResult(resultWithoutRss, resultWithRss);
+    updateSparkConfWithRssNetty(sparkConf);
+    start = System.currentTimeMillis();
+    Map resultWithRssNetty = runSparkApp(sparkConf, fileName);
+    final long durationWithRssNetty = System.currentTimeMillis() - start;
+    verifyTestResult(resultWithoutRss, resultWithRssGrpc);
+    verifyTestResult(resultWithoutRss, resultWithRssNetty);

     LOG.info(
         "Test: durationWithoutRss["
             + durationWithoutRss
-            + "], durationWithRss["
-            + durationWithRss
-            + "]"
+            + "], durationWithRssGrpc["
+            + durationWithRssGrpc
+            + "], durationWithRssNetty["
+            + durationWithRssNetty
+            + "]"
@@ -90,16 +92,16 @@ protected Map runSparkApp(SparkConf sparkConf, String testFileName) throws Exception {
       spark.close();
     }
     spark = SparkSession.builder().config(sparkConf).getOrCreate();
-    Map resultWithRss = runTest(spark, testFileName);
+    Map result = runTest(spark, testFileName);
     spark.stop();
-    return resultWithRss;
+    return result;
   }

   protected SparkConf createSparkConf() {
     return new SparkConf().setAppName(this.getClass().getSimpleName()).setMaster("local[4]");
   }

-  public void updateSparkConfWithRss(SparkConf sparkConf) {
+  public void updateSparkConfWithRssGrpc(SparkConf sparkConf) {
     sparkConf.set("spark.shuffle.manager", "org.apache.spark.shuffle.RssShuffleManager");
     sparkConf.set(
         "spark.shuffle.sort.io.plugin.class", "org.apache.spark.shuffle.RssShuffleDataIo");
@@ -118,10 +120,11 @@ public void updateSparkConfWithRss(SparkConf sparkConf) {
     sparkConf.set(RssSparkConfig.RSS_CLIENT_READ_BUFFER_SIZE.key(), "1m");
     sparkConf.set(RssSparkConfig.RSS_HEARTBEAT_INTERVAL.key(), "2000");
     sparkConf.set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true");
+    sparkConf.set(RssSparkConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
   }

   public void updateSparkConfWithRssNetty(SparkConf sparkConf) {
-    sparkConf.set(RssSparkConfig.RSS_CLIENT_TYPE, "GRPC_NETTY");
+    sparkConf.set(RssSparkConfig.RSS_CLIENT_TYPE, ClientType.GRPC_NETTY.name());
   }

   protected void verifyTestResult(Map expected, Map actual) {
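
Taken together, the changes to SparkIntegrationTestBase turn run() into a two-pass comparison. Below is a minimal sketch of the resulting flow, reassembled from the changed lines above; the method and variable names come from the diff, while the setup and timing code are elided:

    // Sketch only: reassembled from the diff, not a verbatim copy of the test base.
    SparkConf sparkConf = createSparkConf();
    Map resultWithoutRss = runSparkApp(sparkConf, fileName);   // vanilla Spark shuffle

    updateSparkConfWithRssGrpc(sparkConf);                     // sets RSS_CLIENT_TYPE to ClientType.GRPC
    Map resultWithRssGrpc = runSparkApp(sparkConf, fileName);

    updateSparkConfWithRssNetty(sparkConf);                    // flips RSS_CLIENT_TYPE to ClientType.GRPC_NETTY
    Map resultWithRssNetty = runSparkApp(sparkConf, fileName);

    verifyTestResult(resultWithoutRss, resultWithRssGrpc);     // both transports must match the vanilla result
    verifyTestResult(resultWithoutRss, resultWithRssNetty);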
File 5 of 6:
@@ -140,8 +140,19 @@ public void updateSparkConfCustomer(SparkConf sparkConf) {
   }

   @Override
-  public void updateSparkConfWithRss(SparkConf sparkConf) {
-    super.updateSparkConfWithRss(sparkConf);
+  public void updateSparkConfWithRssGrpc(SparkConf sparkConf) {
+    super.updateSparkConfWithRssGrpc(sparkConf);
+    addMultiReplicaConf(sparkConf);
+  }
+
+  @Override
+  public void updateSparkConfWithRssNetty(SparkConf sparkConf) {
+    super.updateSparkConfWithRssNetty(sparkConf);
+    // Add multi replica conf
+    addMultiReplicaConf(sparkConf);
+  }
+
+  private static void addMultiReplicaConf(SparkConf sparkConf) {
     // Add multi replica conf
     sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA.key(), String.valueOf(replicateWrite));
     sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_WRITE.key(), String.valueOf(replicateWrite));
File 6 of 6:
@@ -140,8 +140,8 @@ public void updateSparkConfCustomer(SparkConf sparkConf) {
   }

   @Override
-  public void updateSparkConfWithRss(SparkConf sparkConf) {
-    super.updateSparkConfWithRss(sparkConf);
+  public void updateSparkConfWithRssGrpc(SparkConf sparkConf) {
+    super.updateSparkConfWithRssGrpc(sparkConf);
     // Add multi replica conf
     sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA.key(), String.valueOf(replicateWrite));
     sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_WRITE.key(), String.valueOf(replicateWrite));
