diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java index 283122f0b9..f0f25816c3 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java @@ -366,6 +366,13 @@ public class RssSparkConfig { .doc(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT.description())) .createWithDefault(-1); + public static final ConfigEntry RSS_RESUBMIT_STAGE = + createBooleanBuilder( + new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_RESUBMIT_STAGE) + .internal() + .doc("Whether to enable the resubmit stage.")) + .createWithDefault(false); + // spark2 doesn't have this key defined public static final String SPARK_SHUFFLE_COMPRESS_KEY = "spark.shuffle.compress"; diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index a2c7edbbc3..d20fca5867 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -54,7 +54,6 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.util.ClientUtils; -import org.apache.uniffle.client.util.RssClientConfig; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleAssignmentsInfo; @@ -183,7 +182,7 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { if (isDriver) { heartBeatScheduledExecutorService = ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); - if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false) + if (sparkConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE) && RssSparkShuffleUtils.isStageResubmitSupported()) { LOG.info("stage resubmit is supported and enabled"); // start shuffle manager server diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index e1d535d335..d69196974a 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -61,7 +61,6 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.util.ClientUtils; -import org.apache.uniffle.client.util.RssClientConfig; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleAssignmentsInfo; @@ -207,7 +206,7 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { if (isDriver) { heartBeatScheduledExecutorService = ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); - if (rssConf.getBoolean(RssClientConfig.RSS_RESUBMIT_STAGE, false) + if (sparkConf.get(RssSparkConfig.RSS_RESUBMIT_STAGE) && RssSparkShuffleUtils.isStageResubmitSupported()) { LOG.info("stage resubmit is supported and enabled"); // start shuffle manager server diff --git a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java index 1391308f8b..4464d74970 100644 --- a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java +++ b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java @@ -512,6 +512,8 @@ public void excludeNodesNoDelayTest() throws Exception { assertEquals(4, scm.getNodesNum()); assertEquals(2, scm.getExcludeNodes().size()); } + File blacklistFile = new File(excludeNodesPath); + assertTrue(blacklistFile.delete()); } private void writeExcludeHosts(String path, Set values) throws Exception { diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java index 282243da9e..419efb4410 100644 --- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RSSStageResubmitTest.java @@ -41,10 +41,12 @@ public class RSSStageResubmitTest extends SparkIntegrationTestBase { @BeforeAll public static void setupServers() throws Exception { - CoordinatorConf coordinatorConf = getCoordinatorConf(); + final CoordinatorConf coordinatorConf = getCoordinatorConf(); Map dynamicConf = Maps.newHashMap(); dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test"); dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name()); + dynamicConf.put( + RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_RESUBMIT_STAGE, "true"); addDynamicConf(coordinatorConf, dynamicConf); createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(); @@ -79,8 +81,6 @@ protected SparkConf createSparkConf() { @Override public void updateSparkConfCustomer(SparkConf sparkConf) { - sparkConf.set( - RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssClientConfig.RSS_RESUBMIT_STAGE, "true"); sparkConf.set("spark.task.maxFailures", String.valueOf(maxTaskFailures)); }