diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala
index 19807453ee28c..352a762ede895 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.storage
 
 import java.io.IOException
+import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable
@@ -30,7 +31,13 @@ import org.apache.spark.internal.{config, Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.shuffle.ShuffleBlockInfo
 import org.apache.spark.storage.BlockManagerMessages.ReplicateBlock
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.{EventLoop, ThreadUtils, Utils}
+
+private[storage] sealed trait DecommissionEvent
+
+private object PeriodicShuffleRefresh extends DecommissionEvent
+
+private case class ShuffleMigrationPeerAbortion(peer: BlockManagerId) extends DecommissionEvent
 
 /**
  * Class to handle block manager decommissioning retries.
@@ -38,7 +45,8 @@ import org.apache.spark.util.{ThreadUtils, Utils}
  */
 private[storage] class BlockManagerDecommissioner(
     conf: SparkConf,
-    bm: BlockManager) extends Logging {
+    bm: BlockManager)
+  extends EventLoop[DecommissionEvent]("decommission-event-loop") with Logging {
 
   private val fallbackStorage = FallbackStorage.getFallbackStorage(conf)
   private val maxReplicationFailuresForDecommission =
@@ -68,7 +76,7 @@ private[storage] class BlockManagerDecommissioner(
    * The producer/consumer model is chosen for shuffle block migration to maximize
    * the chance of migrating all shuffle blocks before the executor is forced to exit.
    */
-  private class ShuffleMigrationRunnable(peer: BlockManagerId) extends Runnable {
+  private[storage] class ShuffleMigrationRunnable(peer: BlockManagerId) extends Runnable {
     @volatile var keepRunning = true
 
     private def allowRetry(shuffleBlock: ShuffleBlockInfo, failureNum: Int): Boolean = {
@@ -196,6 +204,7 @@ private[storage] class BlockManagerDecommissioner(
             logError("Error occurred during shuffle blocks migration.", e)
         }
       }
+      post(ShuffleMigrationPeerAbortion(peer))
     }
   }
 
@@ -218,8 +227,15 @@ private[storage] class BlockManagerDecommissioner(
   @volatile private var stoppedShuffle =
     !conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)
 
-  private val migrationPeers =
-    mutable.HashMap[BlockManagerId, ShuffleMigrationRunnable]()
+  // All peers known to the current block manager, including standby, active and dead peers.
+  private[storage] val knownShuffleMigrationPeers = mutable.HashSet[BlockManagerId]()
+  // Standby peers used to quickly refresh the migration peers whenever an active migration
+  // peer aborts. A standby peer can become stale (exited or decommissioned) after waiting
+  // for a while, so the deque is used as a stack (LIFO) to prefer recently added peers.
+  private[storage] val standbyShuffleMigrationPeers = mutable.ArrayDeque[BlockManagerId]()
+  // All the active migration peers that are currently migrating shuffle blocks.
+  private[storage] val activeShuffleMigrationPeers =
+    new mutable.HashMap[BlockManagerId, ShuffleMigrationRunnable]()
 
   private val rddBlockMigrationExecutor =
     if (conf.get(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED)) {
@@ -264,38 +280,34 @@ private[storage] class BlockManagerDecommissioner(
 
   private val shuffleBlockMigrationRefreshExecutor =
     if (conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)) {
-      Some(ThreadUtils.newDaemonSingleThreadExecutor("block-manager-decommission-shuffle"))
+      Some(ThreadUtils.newDaemonSingleThreadScheduledExecutor("block-manager-decommission-shuffle"))
     } else None
 
-  private val shuffleBlockMigrationRefreshRunnable = new Runnable {
-    val sleepInterval = conf.get(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL)
-
-    override def run(): Unit = {
-      logInfo("Attempting to migrate all shuffle blocks")
-      while (!stopped && !stoppedShuffle) {
-        try {
-          val startTime = System.nanoTime()
-          shuffleBlocksLeft = refreshMigratableShuffleBlocks()
-          lastShuffleMigrationTime = startTime
-          logInfo(log"Finished current round refreshing migratable shuffle blocks, " +
-            log"waiting for ${MDC(SLEEP_TIME, sleepInterval)}ms before the " +
-            log"next round refreshing.")
-          Thread.sleep(sleepInterval)
-        } catch {
-          case _: InterruptedException if stopped =>
-            logInfo("Stop refreshing migratable shuffle blocks.")
-          case NonFatal(e) =>
-            logError("Error occurred during shuffle blocks migration.", e)
-            stoppedShuffle = true
-        }
-      }
+  private def shuffleRefresh(): Unit = {
+    if (stopped || stoppedShuffle) return
+    try {
+      logInfo("Start shuffle refresh")
+      val startTime = System.nanoTime()
+      refreshShuffleMigrationPeers()
+      shuffleBlocksLeft = refreshMigratableShuffleBlocks()
+      lastShuffleMigrationTime = startTime
+      logInfo(log"Finished current round shuffle refresh")
+    } catch {
+      case _: InterruptedException if stopped =>
+        logInfo("Stop refreshing migratable shuffle blocks.")
+      case NonFatal(e) =>
+        logError("Error occurred during shuffle blocks migration.", e)
+        stoppedShuffle = true
     }
   }
 
+  // The maximum number of distinct (non-decommissioned) peers that can
+  // be used to migrate shuffle blocks concurrently.
+  private val maxShuffleMigrationPeers = conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS)
+
   private val shuffleMigrationPool =
     if (conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)) {
-      Some(ThreadUtils.newDaemonCachedThreadPool("migrate-shuffles",
-        conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS)))
+      Some(ThreadUtils.newDaemonCachedThreadPool("migrate-shuffles", maxShuffleMigrationPeers))
     } else None
 
   /**
@@ -318,28 +330,57 @@ private[storage] class BlockManagerDecommissioner(
     logInfo(log"${MDC(COUNT, newShufflesToMigrate.size)} of " +
       log"${MDC(TOTAL, localShuffles.size)} local shuffles are added. " +
       log"In total, ${MDC(NUM_REMAINED, remainedShuffles)} shuffles are remained.")
 
+    // If we found any new shuffles to migrate or otherwise have not migrated everything.
+    newShufflesToMigrate.nonEmpty || migratingShuffles.size > numMigratedShuffles.get()
+  }
+
+  private def refreshShuffleMigrationPeers(): Unit = {
+    if (stopped || stoppedShuffle) return
+
+    logInfo("Start refreshing shuffle migration peers (active/standby/max=" +
+      s"${activeShuffleMigrationPeers.size}/" +
+      s"${standbyShuffleMigrationPeers.size}/" +
+      s"$maxShuffleMigrationPeers)")
+
+    @inline def refresh(): Unit = {
+      var numPeersToAdd = maxShuffleMigrationPeers - activeShuffleMigrationPeers.size
+      // Refill the active migration peers from the standby peers
+      while (standbyShuffleMigrationPeers.nonEmpty && numPeersToAdd > 0) {
+        val standByPeer = standbyShuffleMigrationPeers.removeHead()
+        val migrationThread = new ShuffleMigrationRunnable(standByPeer)
+        shuffleMigrationPool.foreach(_.submit(migrationThread))
+        activeShuffleMigrationPeers.put(standByPeer, migrationThread)
+        numPeersToAdd -= 1
+      }
+    }
+
+    refresh()
-    // Update the threads doing migrations
-    val livePeerSet = bm.getPeers(false).toSet
-    val currentPeerSet = migrationPeers.keys.toSet
-    val deadPeers = currentPeerSet.diff(livePeerSet)
-    // Randomize the orders of the peers to avoid hotspot nodes.
-    val newPeers = Utils.randomize(livePeerSet.diff(currentPeerSet))
-    migrationPeers ++= newPeers.map { peer =>
-      logDebug(s"Starting thread to migrate shuffle blocks to ${peer}")
-      val runnable = new ShuffleMigrationRunnable(peer)
-      shuffleMigrationPool.foreach(_.submit(runnable))
-      (peer, runnable)
+    if (activeShuffleMigrationPeers.size < maxShuffleMigrationPeers) {
+      // Refresh peers from the remote if possible
+      val livePeerSet = bm.getPeers(false).toSet
+      val deadPeers = knownShuffleMigrationPeers.diff(livePeerSet)
+      deadPeers.foreach { dp =>
+        activeShuffleMigrationPeers.remove(dp).foreach(_.keepRunning = false)
+        standbyShuffleMigrationPeers.removeFirst(_ == dp)
+      }
+      // Randomize the orders of the peers to avoid hotspot nodes.
+      val newPeers = Utils.randomize(livePeerSet.diff(knownShuffleMigrationPeers))
+      knownShuffleMigrationPeers ++= newPeers
+      newPeers.foreach(standbyShuffleMigrationPeers.prepend)
+      refresh()
     }
-    // A peer may have entered a decommissioning state, don't transfer any new blocks
-    deadPeers.foreach(migrationPeers.get(_).foreach(_.keepRunning = false))
-    // If we don't have anyone to migrate to give up
-    if (!migrationPeers.values.exists(_.keepRunning)) {
+
+    // Give up if we don't have any peer to migrate to
+    if (activeShuffleMigrationPeers.isEmpty) {
       logWarning("No available peers to receive Shuffle blocks, stop migration.")
       stoppedShuffle = true
+    } else {
+      logInfo("Finished refreshing shuffle migration peers (active/standby/max=" +
+        s"${activeShuffleMigrationPeers.size}/" +
+        s"${standbyShuffleMigrationPeers.size}/" +
+        s"$maxShuffleMigrationPeers)")
     }
-    // If we found any new shuffles to migrate or otherwise have not migrated everything.
-    newShufflesToMigrate.nonEmpty || migratingShuffles.size > numMigratedShuffles.get()
   }
 
   /**
@@ -349,7 +390,9 @@ private[storage] class BlockManagerDecommissioner(
     shuffleMigrationPool.foreach { threadPool =>
       logInfo("Stopping migrating shuffle blocks.")
       // Stop as gracefully as possible.
-      migrationPeers.values.foreach(_.keepRunning = false)
+      activeShuffleMigrationPeers.values.foreach(_.keepRunning = false)
+      standbyShuffleMigrationPeers.clear()
+      knownShuffleMigrationPeers.clear()
       threadPool.shutdownNow()
     }
   }
@@ -403,13 +446,21 @@ private[storage] class BlockManagerDecommissioner(
     replicatedSuccessfully
   }
 
-  def start(): Unit = {
+  override def start(): Unit = {
     logInfo("Starting block migration")
+    super.start()
     rddBlockMigrationExecutor.foreach(_.submit(rddBlockMigrationRunnable))
-    shuffleBlockMigrationRefreshExecutor.foreach(_.submit(shuffleBlockMigrationRefreshRunnable))
+    shuffleBlockMigrationRefreshExecutor.foreach(
+      _.scheduleAtFixedRate(
+        () => post(PeriodicShuffleRefresh),
+        0,
+        conf.get(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL),
+        TimeUnit.MILLISECONDS
+      )
+    )
   }
 
-  def stop(): Unit = {
+  override def stop(): Unit = {
     if (stopped) {
       return
     } else {
@@ -433,9 +484,23 @@ private[storage] class BlockManagerDecommissioner(
       case NonFatal(e) =>
         logError(s"Error during shutdown shuffle block migration thread", e)
     }
+    super.stop()
     logInfo("Stopped block migration")
   }
 
+  override def onReceive(event: DecommissionEvent): Unit = {
+    event match {
+      case PeriodicShuffleRefresh => shuffleRefresh()
+      case ShuffleMigrationPeerAbortion(peer) =>
+        activeShuffleMigrationPeers.remove(peer)
+        refreshShuffleMigrationPeers()
+    }
+  }
+
+  override def onError(e: Throwable): Unit = {
+    logError("Error in decommission event loop", e)
+  }
+
   /*
    * Returns the last migration time and a boolean for if all blocks have been migrated.
    * The last migration time is calculated to be the minimum of the last migration of any
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala
index b7ad6722faa8c..d1f46a2f63dd3 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala
@@ -423,4 +423,55 @@ class BlockManagerDecommissionUnitSuite extends SparkFunSuite with Matchers {
       bmDecomManager.stop()
     }
   }
+
+  test("SPARK-48637: on-demand migration peer refresh") {
+    // Set the refresh interval to Long.MaxValue to avoid the periodic peer refresh
+    sparkConf.set(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL, Long.MaxValue)
+    // Set the max concurrent shuffle migration threads to 1
+    sparkConf.set(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS, 1)
+
+    val bm = mock(classOf[BlockManager])
+    val peer1 = mock(classOf[BlockManagerId])
+    val peer2 = mock(classOf[BlockManagerId])
+    when(bm.getPeers(mc.any())).thenReturn(Seq(peer1, peer2))
+
+    val blockTransferService = mock(classOf[BlockTransferService])
+    when(bm.blockTransferService).thenReturn(blockTransferService)
+
+    val migratableShuffleBlockResolver = mock(classOf[MigratableResolver])
+    registerShuffleBlocks(migratableShuffleBlockResolver, Set())
+    when(bm.migratableResolver).thenReturn(migratableShuffleBlockResolver)
+
+    val decommissioner = new BlockManagerDecommissioner(sparkConf, bm)
+    assert(decommissioner.knownShuffleMigrationPeers.isEmpty)
+    assert(decommissioner.standbyShuffleMigrationPeers.isEmpty)
+    assert(decommissioner.activeShuffleMigrationPeers.isEmpty)
+    decommissioner.start()
+
+    eventually(timeout(10.seconds), interval(10.milliseconds)) {
+      assert(decommissioner.knownShuffleMigrationPeers.size === 2)
+      assert(decommissioner.standbyShuffleMigrationPeers.size === 1)
+      assert(decommissioner.activeShuffleMigrationPeers.size === 1)
+    }
+
+    val preStandByPeer = decommissioner.standbyShuffleMigrationPeers.head
+    val preActivePeer = decommissioner.activeShuffleMigrationPeers.head
+
+    // Terminate the active migration peer
+    preActivePeer._2.keepRunning = false
+    decommissioner.shufflesToMigrate.add((ShuffleBlockInfo(1, 1L), 0))
+    when(migratableShuffleBlockResolver.getMigrationBlocks(mc.eq(ShuffleBlockInfo(1, 1L))))
+      .thenReturn(List.empty)
+
+    eventually(timeout(10.seconds), interval(10.milliseconds)) {
+      assert(decommissioner.knownShuffleMigrationPeers.size === 2)
+      assert(decommissioner.standbyShuffleMigrationPeers.size === 0)
+      assert(decommissioner.activeShuffleMigrationPeers.size === 1)
+    }
+
+    // The previous active peer should be removed and the previous standby peer
+    // should have been promoted to an active peer by now
+    assert(!decommissioner.activeShuffleMigrationPeers.contains(preActivePeer._1))
+    assert(decommissioner.activeShuffleMigrationPeers.head._1 === preStandByPeer)
+  }
 }
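
For context on the new control flow: the patch replaces the sleep-loop refresher with a single-threaded event loop. The scheduled executor only posts PeriodicShuffleRefresh events, and a ShuffleMigrationRunnable that exits posts ShuffleMigrationPeerAbortion(peer), so all peer bookkeeping runs serially inside onReceive. The sketch below only illustrates that event-loop pattern; it is not Spark's org.apache.spark.util.EventLoop, and the SimpleEventLoop name and its members are hypothetical.

import java.util.concurrent.LinkedBlockingQueue

// Minimal sketch of a single-threaded event loop: post() may be called from any
// thread, while onReceive() is always invoked from the one daemon thread, so the
// handler never races with itself.
abstract class SimpleEventLoop[E](name: String) {
  private val queue = new LinkedBlockingQueue[E]()
  @volatile private var running = false

  private val thread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      while (running) {
        try {
          onReceive(queue.take())
        } catch {
          case _: InterruptedException => // woken up by stop(); loop re-checks `running`
          case t: Throwable => onError(t)
        }
      }
    }
  }

  // Enqueue an event from any thread; it will be handled serially by the loop thread.
  def post(event: E): Unit = queue.put(event)

  def start(): Unit = { running = true; thread.start() }

  def stop(): Unit = { running = false; thread.interrupt() }

  protected def onReceive(event: E): Unit

  protected def onError(t: Throwable): Unit
}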