diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
index eb4d9d9abc8e3..861a8e623a6e5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
@@ -56,7 +56,8 @@ public void initializeExecutor(String appId, String execId, Map
     if (blockManager == null) {
       throw new IllegalStateException("No blockManager available from the SparkEnv.");
     }
-    blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
+    blockResolver =
+        new IndexShuffleBlockResolver(sparkConf, blockManager, Map.of() /* Shouldn't be accessed */);
   }
 
   @Override
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index dde9b541b62fc..20b8d0809f329 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -21,6 +21,7 @@ import java.io._
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.nio.file.Files
+import java.util.{Map => JMap}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -40,6 +41,7 @@ import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -55,7 +57,8 @@ import org.apache.spark.util.Utils
 private[spark] class IndexShuffleBlockResolver(
     conf: SparkConf,
     // var for testing
-    var _blockManager: BlockManager = null)
+    var _blockManager: BlockManager = null,
+    val taskIdMapsForShuffle: JMap[Int, OpenHashSet[Long]] = JMap.of())
   extends ShuffleBlockResolver
   with Logging with MigratableResolver {
 
@@ -285,6 +288,21 @@ private[spark] class IndexShuffleBlockResolver(
             throw SparkCoreErrors.failedRenameTempFileError(fileTmp, file)
           }
         }
+        blockId match {
+          case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+            val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
+              shuffleId, _ => new OpenHashSet[Long](8)
+            )
+            mapTaskIds.add(mapId)
+
+          case ShuffleDataBlockId(shuffleId, mapId, _) =>
+            val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
+              shuffleId, _ => new OpenHashSet[Long](8)
+            )
+            mapTaskIds.add(mapId)
+
+          case _ => // Unreachable
+        }
         blockManager.reportBlockStatus(blockId, BlockStatus(StorageLevel.DISK_ONLY, 0, diskSize))
       }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 4d811b051de08..efffda43695cc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -87,7 +87,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
 
   private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
 
-  override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+  override val shuffleBlockResolver =
+    new IndexShuffleBlockResolver(conf, taskIdMapsForShuffle = taskIdMapsForShuffle)
 
   /**
    * Obtains a [[ShuffleHandle]] to pass to tasks.
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 063d391bb4bfd..fac84b2e9187d 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -314,7 +314,8 @@ public void writeWithoutSpilling() throws Exception {
 
   @Test
   public void writeChecksumFileWithoutSpill() throws Exception {
-    IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
+    IndexShuffleBlockResolver blockResolver =
+        new IndexShuffleBlockResolver(conf, blockManager, Map.of());
     ShuffleChecksumBlockId checksumBlockId =
         new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
     String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
@@ -344,7 +345,8 @@ public void writeChecksumFileWithoutSpill() throws Exception {
 
   @Test
   public void writeChecksumFileWithSpill() throws Exception {
-    IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
+    IndexShuffleBlockResolver blockResolver =
+        new IndexShuffleBlockResolver(conf, blockManager, Map.of());
     ShuffleChecksumBlockId checksumBlockId =
         new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
     String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
index ba665600a1cb7..febe1ac4bb4cf 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.storage
 
+import java.io.File
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Semaphore, TimeUnit}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.jdk.CollectionConverters._
 
+import org.apache.commons.io.FileUtils
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark._
@@ -353,4 +355,78 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
     import scala.language.reflectiveCalls
     assert(listener.removeReasonValidated)
   }
+
+  test("SPARK-46957: Migrated shuffle files should be able to cleanup from executor") {
+
+    val sparkTempDir = System.getProperty("java.io.tmpdir")
+
+    def shuffleFiles: Seq[File] = {
+      FileUtils
+        .listFiles(new File(sparkTempDir), Array("data", "index"), true)
+        .asScala
+        .toSeq
+    }
+
+    val existingShuffleFiles = shuffleFiles
+
+    val conf = new SparkConf()
+      .setAppName("SPARK-46957")
+      .setMaster("local-cluster[2,1,1024]")
+      .set(config.DECOMMISSION_ENABLED, true)
+      .set(config.STORAGE_DECOMMISSION_ENABLED, true)
+      .set(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED, true)
+    sc = new SparkContext(conf)
+    TestUtils.waitUntilExecutorsUp(sc, 2, 60000)
+    val shuffleBlockUpdates = new ArrayBuffer[BlockId]()
+    var isDecommissionedExecutorRemoved = false
+    val execToDecommission = sc.getExecutorIds().head
+    sc.addSparkListener(new SparkListener {
+      override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+        if (blockUpdated.blockUpdatedInfo.blockId.isShuffle) {
+          shuffleBlockUpdates += blockUpdated.blockUpdatedInfo.blockId
+        }
+      }
+
+      override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
+        assert(execToDecommission === executorRemoved.executorId)
+        isDecommissionedExecutorRemoved = true
+      }
+    })
+
+    // Run a job to create shuffle data
+    val result = sc.parallelize(1 to 1000, 10)
+      .map { i => (i % 2, i) }
+      .reduceByKey(_ + _).collect()
+
+    assert(result.head === (0, 250500))
+    assert(result.tail.head === (1, 250000))
+    sc.schedulerBackend
+      .asInstanceOf[StandaloneSchedulerBackend]
+      .decommissionExecutor(
+        execToDecommission,
+        ExecutorDecommissionInfo("test", None),
+        adjustTargetNumExecutors = true
+      )
+
+    eventually(timeout(1.minute), interval(10.milliseconds)) {
+      assert(isDecommissionedExecutorRemoved)
+      // Ensure some shuffle data has been migrated
+      assert(shuffleBlockUpdates.size >= 2)
+    }
+
+    val shuffleId = shuffleBlockUpdates
+      .find(_.isInstanceOf[ShuffleIndexBlockId])
+      .map(_.asInstanceOf[ShuffleIndexBlockId].shuffleId)
+      .get
+
+    val newShuffleFiles = shuffleFiles.diff(existingShuffleFiles)
+    assert(newShuffleFiles.size >= shuffleBlockUpdates.size)
+
+    // Remove the shuffle data
+    sc.shuffleDriverComponents.removeShuffle(shuffleId, true)
+
+    eventually(timeout(1.minute), interval(10.milliseconds)) {
+      assert(newShuffleFiles.intersect(shuffleFiles).isEmpty)
+    }
+  }
 }
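Context for the core change: executor-side shuffle cleanup in SortShuffleManager.unregisterShuffle walks taskIdMapsForShuffle and asks the resolver to delete the data/index files of every recorded map task ID. Blocks that arrive through decommission migration are written by IndexShuffleBlockResolver.putShuffleBlockAsStream, so before this patch they were never recorded in that map and their files survived removeShuffle on the receiving executor. The sketch below is illustrative only and not part of the patch: ShuffleCleanupSketch, recordMapOutput, and the removeDataByMap callback are hypothetical names, and scala.collection.mutable.HashSet stands in for Spark's internal OpenHashSet.

```scala
import java.util.concurrent.ConcurrentHashMap

import scala.collection.mutable

object ShuffleCleanupSketch {
  // Shared between the shuffle manager (local writers) and, with this patch,
  // the block resolver (migrated blocks received via putShuffleBlockAsStream).
  val taskIdMapsForShuffle = new ConcurrentHashMap[Int, mutable.HashSet[Long]]()

  // Called when a data/index block for (shuffleId, mapId) lands on this
  // executor's disk; the new blockId match in the diff plays this role for
  // migrated blocks.
  def recordMapOutput(shuffleId: Int, mapId: Long): Unit = {
    val mapTaskIds =
      taskIdMapsForShuffle.computeIfAbsent(shuffleId, _ => mutable.HashSet.empty[Long])
    mapTaskIds.synchronized { mapTaskIds += mapId }
  }

  // Mirrors the shape of the executor-side cleanup: drain the map task ids
  // recorded for this shuffle and delete each one's data/index files.
  def unregisterShuffle(shuffleId: Int)(removeDataByMap: (Int, Long) => Unit): Boolean = {
    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds =>
      mapTaskIds.foreach(mapId => removeDataByMap(shuffleId, mapId))
    }
    true
  }
}
```

Because SortShuffleManager now passes its own taskIdMapsForShuffle into the resolver, a later shuffle removal on the destination executor deletes migrated files along with locally written ones, which is what the new integration test asserts.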