core/src/main/scala/org/apache/spark/storage/BlockManagerDecommissioner.scala
@@ -18,6 +18,7 @@
package org.apache.spark.storage

import java.io.IOException
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable
@@ -30,15 +31,22 @@ import org.apache.spark.internal.{config, Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.shuffle.ShuffleBlockInfo
import org.apache.spark.storage.BlockManagerMessages.ReplicateBlock
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.util.{EventLoop, ThreadUtils, Utils}

private[storage] sealed trait DecommissionEvent

private object PeriodicShuffleRefresh extends DecommissionEvent

private case class ShuffleMigrationPeerAbortion(peer: BlockManagerId) extends DecommissionEvent

/**
* Class to handle block manager decommissioning retries.
* It creates a thread to retry migrating all RDD cache and shuffle blocks.
*/
private[storage] class BlockManagerDecommissioner(
conf: SparkConf,
bm: BlockManager) extends Logging {
bm: BlockManager)
extends EventLoop[DecommissionEvent]("decommission-event-loop") with Logging {

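The decommissioner now runs on Spark's internal single-threaded `EventLoop` (`org.apache.spark.util.EventLoop`): events are `post()`ed onto a queue and handled one at a time by `onReceive`, so the peer bookkeeping introduced below needs no extra locking. As a rough illustration of the pattern, here is a minimal re-implementation (a sketch, not Spark's actual class, which also has `onStart`/`onStop` hooks):

```scala
import java.util.concurrent.LinkedBlockingQueue

// Minimal sketch of the single-threaded event-loop pattern.
abstract class MiniEventLoop[E](name: String) {
  private val queue = new LinkedBlockingQueue[E]()
  @volatile private var running = true

  private val thread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = while (running) {
      try onReceive(queue.take())   // handle events strictly one at a time
      catch {
        case _: InterruptedException => // stopping
        case t: Throwable => onError(t)
      }
    }
  }

  def post(event: E): Unit = queue.put(event) // safe to call from any thread
  def start(): Unit = thread.start()
  def stop(): Unit = { running = false; thread.interrupt() }

  protected def onReceive(event: E): Unit
  protected def onError(e: Throwable): Unit
}
```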
private val fallbackStorage = FallbackStorage.getFallbackStorage(conf)
private val maxReplicationFailuresForDecommission =
@@ -68,7 +76,7 @@ private[storage] class BlockManagerDecommissioner(
* The producer/consumer model is chosen for shuffle block migration to maximize
* the chance of migrating all shuffle blocks before the executor is forced to exit.
*/
private class ShuffleMigrationRunnable(peer: BlockManagerId) extends Runnable {
private[storage] class ShuffleMigrationRunnable(peer: BlockManagerId) extends Runnable {
@volatile var keepRunning = true

private def allowRetry(shuffleBlock: ShuffleBlockInfo, failureNum: Int): Boolean = {
@@ -196,6 +204,7 @@
logError("Error occurred during shuffle blocks migration.", e)
}
}
post(ShuffleMigrationPeerAbortion(peer))
}
}

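Each `ShuffleMigrationRunnable` is the consumer side of the producer/consumer model described above: it drains blocks from the shared `shufflesToMigrate` queue, pushes them to its single peer, and, with this change, notifies the event loop when it exits so a standby peer can be promoted. A simplified sketch of that shape, with hypothetical `Peer`/`ShuffleBlock` types standing in for Spark's internals:

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.util.control.NonFatal

case class Peer(host: String, port: Int)
case class ShuffleBlock(shuffleId: Int, mapId: Long)

class MiniMigrationWorker(
    peer: Peer,
    queue: LinkedBlockingQueue[ShuffleBlock], // shared producer/consumer queue
    onExit: Peer => Unit)                     // stands in for post(ShuffleMigrationPeerAbortion(peer))
  extends Runnable {
  @volatile var keepRunning = true

  override def run(): Unit = {
    try {
      while (keepRunning) {
        val block = queue.take()              // block until a shuffle is available
        try migrate(block)                    // transfer this shuffle's files to `peer`
        catch { case NonFatal(_) => queue.put(block) } // crude retry: requeue for another worker
      }
    } catch {
      case _: InterruptedException =>         // pool is shutting down
    } finally {
      onExit(peer)                            // the new abort notification in this PR
    }
  }

  private def migrate(block: ShuffleBlock): Unit = () // placeholder for the real transfer
}
```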
@@ -218,8 +227,15 @@
@volatile private var stoppedShuffle =
!conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)

private val migrationPeers =
mutable.HashMap[BlockManagerId, ShuffleMigrationRunnable]()
// All peers known to the current block manager, including standby, active, and dead peers.
private[storage] val knownShuffleMigrationPeers = mutable.HashSet[BlockManagerId]()
// Standby peers used to quickly refresh the migration peers when an active migration
// peer is aborted. A standby peer can go stale (exited or decommissioned) while it
// waits, so this is used as a stack to make picking a stale peer less likely.
private[storage] val standbyShuffleMigrationPeers = mutable.ArrayDeque[BlockManagerId]()
// All the active migration peers that are currently migrating shuffle blocks.
private[storage] val activeShuffleMigrationPeers =
new mutable.HashMap[BlockManagerId, ShuffleMigrationRunnable]()

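Since freshly discovered peers are the least likely to have gone away, refills take from the head of `standbyShuffleMigrationPeers` while new discoveries are prepended, giving LIFO behavior. A tiny standalone illustration of the `ArrayDeque` operations involved (peer names here are made up):

```scala
import scala.collection.mutable

object StandbyStackSketch extends App {
  val standby = mutable.ArrayDeque[String]()
  standby.prepend("peer-seen-earlier")  // discovered in an older refresh round
  standby.prepend("peer-seen-just-now") // discovered in the latest round

  // A refill pops the most recently seen peer first, minimizing stale picks.
  assert(standby.removeHead() == "peer-seen-just-now")

  // A peer reported dead can be dropped from wherever it sits in the deque.
  standby.removeFirst(_ == "peer-seen-earlier")
  assert(standby.isEmpty)
}
```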
private val rddBlockMigrationExecutor =
if (conf.get(config.STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED)) {
@@ -264,38 +280,34 @@

private val shuffleBlockMigrationRefreshExecutor =
if (conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)) {
Some(ThreadUtils.newDaemonSingleThreadExecutor("block-manager-decommission-shuffle"))
Some(ThreadUtils.newDaemonSingleThreadScheduledExecutor("block-manager-decommission-shuffle"))
} else None

private val shuffleBlockMigrationRefreshRunnable = new Runnable {
val sleepInterval = conf.get(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL)

override def run(): Unit = {
logInfo("Attempting to migrate all shuffle blocks")
while (!stopped && !stoppedShuffle) {
try {
val startTime = System.nanoTime()
shuffleBlocksLeft = refreshMigratableShuffleBlocks()
lastShuffleMigrationTime = startTime
logInfo(log"Finished current round refreshing migratable shuffle blocks, " +
log"waiting for ${MDC(SLEEP_TIME, sleepInterval)}ms before the " +
log"next round refreshing.")
Thread.sleep(sleepInterval)
} catch {
case _: InterruptedException if stopped =>
logInfo("Stop refreshing migratable shuffle blocks.")
case NonFatal(e) =>
logError("Error occurred during shuffle blocks migration.", e)
stoppedShuffle = true
}
}
private def shuffleRefresh(): Unit = {
if (stopped || stoppedShuffle) return
try {
logInfo("Start shuffle refresh")
val startTime = System.nanoTime()
refreshShuffleMigrationPeers()
shuffleBlocksLeft = refreshMigratableShuffleBlocks()
lastShuffleMigrationTime = startTime
logInfo(log"Finished current round shuffle refresh")
} catch {
case _: InterruptedException if stopped =>
logInfo("Stop refreshing migratable shuffle blocks.")
case NonFatal(e) =>
logError("Error occurred during shuffle blocks migration.", e)
stoppedShuffle = true
}
}

// The maximum number of distinct (non-decommissioned) peers that can
// be used to migrate shuffle blocks concurrently.
private val maxShuffleMigrationPeers = conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS)

private val shuffleMigrationPool =
if (conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)) {
Some(ThreadUtils.newDaemonCachedThreadPool("migrate-shuffles",
conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS)))
Some(ThreadUtils.newDaemonCachedThreadPool("migrate-shuffles", maxShuffleMigrationPeers))
} else None

/**
@@ -318,28 +330,57 @@
logInfo(log"${MDC(COUNT, newShufflesToMigrate.size)} of " +
log"${MDC(TOTAL, localShuffles.size)} local shuffles are added. " +
log"In total, ${MDC(NUM_REMAINED, remainedShuffles)} shuffles are remained.")
// If we found any new shuffles to migrate or otherwise have not migrated everything.
newShufflesToMigrate.nonEmpty || migratingShuffles.size > numMigratedShuffles.get()
}

private def refreshShuffleMigrationPeers(): Unit = {
if (stopped || stoppedShuffle) return

logInfo("Start refreshing shuffle migration peers (active/standby/max=" +
s"${activeShuffleMigrationPeers.size}/" +
s"${standbyShuffleMigrationPeers.size}/" +
s"$maxShuffleMigrationPeers)")

@inline def refresh(): Unit = {
var numPeersToAdd = maxShuffleMigrationPeers - activeShuffleMigrationPeers.size
// Refill migration peers from the standby peers
while (standbyShuffleMigrationPeers.nonEmpty && numPeersToAdd > 0) {
val standByPeer = standbyShuffleMigrationPeers.removeHead()
val migrationThread = new ShuffleMigrationRunnable(standByPeer)
shuffleMigrationPool.foreach(_.submit(migrationThread))
activeShuffleMigrationPeers.put(standByPeer, migrationThread)
numPeersToAdd -= 1
}
}

refresh()

// Update the threads doing migrations
val livePeerSet = bm.getPeers(false).toSet
val currentPeerSet = migrationPeers.keys.toSet
val deadPeers = currentPeerSet.diff(livePeerSet)
// Randomize the orders of the peers to avoid hotspot nodes.
val newPeers = Utils.randomize(livePeerSet.diff(currentPeerSet))
migrationPeers ++= newPeers.map { peer =>
logDebug(s"Starting thread to migrate shuffle blocks to ${peer}")
val runnable = new ShuffleMigrationRunnable(peer)
shuffleMigrationPool.foreach(_.submit(runnable))
(peer, runnable)
if (activeShuffleMigrationPeers.size < maxShuffleMigrationPeers) {
// Fetch fresh peers from the block manager master if possible
val livePeerSet = bm.getPeers(false).toSet
val deadPeers = knownShuffleMigrationPeers.diff(livePeerSet)
deadPeers.foreach { dp =>
activeShuffleMigrationPeers.remove(dp).foreach(_.keepRunning = false)
standbyShuffleMigrationPeers.removeFirst(_ == dp)
}
// Randomize the order of the peers to avoid hotspot nodes.
val newPeers = Utils.randomize(livePeerSet.diff(knownShuffleMigrationPeers))
knownShuffleMigrationPeers ++= newPeers
newPeers.foreach(standbyShuffleMigrationPeers.prepend)
refresh()
}
// A peer may have entered a decommissioning state, don't transfer any new blocks
deadPeers.foreach(migrationPeers.get(_).foreach(_.keepRunning = false))
// If we don't have anyone to migrate to give up
if (!migrationPeers.values.exists(_.keepRunning)) {

// Give up if we don't have any peer to migrate to
if (activeShuffleMigrationPeers.isEmpty) {
logWarning("No available peers to receive Shuffle blocks, stop migration.")
stoppedShuffle = true
} else {
logInfo("Finish refreshing shuffle migration peers (active/standby/max=" +
s"${activeShuffleMigrationPeers.size}/" +
s"${standbyShuffleMigrationPeers.size}/" +
s"$maxShuffleMigrationPeers)")
}
// If we found any new shuffles to migrate or otherwise have not migrated everything.
newShufflesToMigrate.nonEmpty || migratingShuffles.size > numMigratedShuffles.get()
}

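Putting the pieces together: `refreshShuffleMigrationPeers` first does a cheap refill from the standby stack, and only falls back to fetching live peers (an RPC via `bm.getPeers`) when still under capacity. Below is a self-contained sketch of that two-phase algorithm, using a plain `String` peer type and a hypothetical `startWorker` in place of submitting a `ShuffleMigrationRunnable` to the pool:

```scala
import scala.collection.mutable
import scala.util.Random

object PeerRefreshSketch {
  type Peer = String
  val maxPeers = 4
  val known = mutable.HashSet[Peer]()
  val standby = mutable.ArrayDeque[Peer]()
  val active = mutable.HashMap[Peer, AnyRef]()

  private def startWorker(p: Peer): AnyRef = new AnyRef // would submit a runnable

  def refresh(fetchLivePeers: () => Set[Peer]): Unit = {
    def refill(): Unit =
      while (standby.nonEmpty && active.size < maxPeers) {
        val p = standby.removeHead()
        active.put(p, startWorker(p))
      }

    refill()                                  // phase 1: no RPC, just promote standbys
    if (active.size < maxPeers) {             // phase 2: still short, ask for live peers
      val live = fetchLivePeers()
      known.diff(live).foreach { dead =>      // forget peers that disappeared
        active.remove(dead)                   // (the real code also flips keepRunning)
        standby.removeFirst(_ == dead)
      }
      val fresh = Random.shuffle((live -- known).toSeq) // randomize to avoid hotspots
      known ++= fresh
      fresh.foreach(standby.prepend)          // newest peers land on top of the stack
      refill()
    }
  }
}
```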
/**
Expand All @@ -349,7 +390,9 @@ private[storage] class BlockManagerDecommissioner(
shuffleMigrationPool.foreach { threadPool =>
logInfo("Stopping migrating shuffle blocks.")
// Stop as gracefully as possible.
migrationPeers.values.foreach(_.keepRunning = false)
activeShuffleMigrationPeers.values.foreach(_.keepRunning = false)
standbyShuffleMigrationPeers.clear()
knownShuffleMigrationPeers.clear()
threadPool.shutdownNow()
}
}
@@ -403,13 +446,21 @@
replicatedSuccessfully
}

def start(): Unit = {
override def start(): Unit = {
logInfo("Starting block migration")
super.start()
rddBlockMigrationExecutor.foreach(_.submit(rddBlockMigrationRunnable))
shuffleBlockMigrationRefreshExecutor.foreach(_.submit(shuffleBlockMigrationRefreshRunnable))
shuffleBlockMigrationRefreshExecutor.foreach(
_.scheduleAtFixedRate(
() => post(PeriodicShuffleRefresh),
0,
conf.get(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL),
TimeUnit.MILLISECONDS
)
)
}

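Note that `start()` now uses the scheduled executor purely as a ticker: it posts `PeriodicShuffleRefresh` at the configured reattempt interval, and all actual refresh work runs on the event-loop thread. A minimal standalone sketch of the pattern (the interval constant is a made-up stand-in for the config value):

```scala
import java.util.concurrent.{Executors, TimeUnit}

object TickerSketch extends App {
  val ticker = Executors.newSingleThreadScheduledExecutor()
  val intervalMs = 1000L // stand-in for STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL

  def post(event: String): Unit = println(s"event loop received: $event")

  // The scheduled thread only posts; the refresh itself runs on the
  // event-loop thread, so handlers never race one another.
  ticker.scheduleAtFixedRate(
    () => post("PeriodicShuffleRefresh"), 0L, intervalMs, TimeUnit.MILLISECONDS)
}
```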
def stop(): Unit = {
override def stop(): Unit = {
if (stopped) {
return
} else {
@@ -433,9 +484,23 @@
case NonFatal(e) =>
logError(s"Error during shutdown shuffle block migration thread", e)
}
super.stop()
logInfo("Stopped block migration")
}

override def onReceive(event: DecommissionEvent): Unit = {
event match {
case PeriodicShuffleRefresh => shuffleRefresh()
case ShuffleMigrationPeerAbortion(peer) =>
activeShuffleMigrationPeers.remove(peer)
refreshShuffleMigrationPeers()
}
}

override def onError(e: Throwable): Unit = {
logError("Error in decommission event loop", e)
}

/*
* Returns the last migration time and whether all blocks have been migrated.
* The last migration time is calculated to be the minimum of the last migration of any
core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionUnitSuite.scala
@@ -423,4 +423,55 @@ class BlockManagerDecommissionUnitSuite extends SparkFunSuite with Matchers {
bmDecomManager.stop()
}
}

test("SPARK-48637: on-demand migration peer refresh") {
// Set the refresh interval to Long.MaxValue to avoid the periodic peer refresh
sparkConf.set(config.STORAGE_DECOMMISSION_REPLICATION_REATTEMPT_INTERVAL, Long.MaxValue)
// Set the max concurrent shuffle migration threads to 1
sparkConf.set(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS, 1)

val bm = mock(classOf[BlockManager])
val peer1 = mock(classOf[BlockManagerId])
val peer2 = mock(classOf[BlockManagerId])
when(bm.getPeers(mc.any())).thenReturn(Seq(peer1, peer2))

val blockTransferService = mock(classOf[BlockTransferService])
when(bm.blockTransferService).thenReturn(blockTransferService)

val migratableShuffleBlockResolver = mock(classOf[MigratableResolver])
registerShuffleBlocks(migratableShuffleBlockResolver, Set())
when(bm.migratableResolver).thenReturn(migratableShuffleBlockResolver)

val decommissioner = new BlockManagerDecommissioner(sparkConf, bm)
assert(decommissioner.knownShuffleMigrationPeers.isEmpty)
assert(decommissioner.standbyShuffleMigrationPeers.isEmpty)
assert(decommissioner.activeShuffleMigrationPeers.isEmpty)
decommissioner.start()

eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(decommissioner.knownShuffleMigrationPeers.size === 2)
assert(decommissioner.standbyShuffleMigrationPeers.size === 1)
assert(decommissioner.activeShuffleMigrationPeers.size === 1)
}

val preStandByPeer = decommissioner.standbyShuffleMigrationPeers.head
val preActivePeer = decommissioner.activeShuffleMigrationPeers.head

// Terminate the active migration peer
preActivePeer._2.keepRunning = false
decommissioner.shufflesToMigrate.add((ShuffleBlockInfo(1, 1L), 0))
when(migratableShuffleBlockResolver.getMigrationBlocks(mc.eq(ShuffleBlockInfo(1, 1L))))
.thenReturn(List.empty)

eventually(timeout(10.seconds), interval(10.milliseconds)) {
assert(decommissioner.knownShuffleMigrationPeers.size === 2)
assert(decommissioner.standbyShuffleMigrationPeers.size === 0)
assert(decommissioner.activeShuffleMigrationPeers.size === 1)
}

// The previous active peer should be removed and the previous standby peer
// should now have been promoted to active
assert(!decommissioner.activeShuffleMigrationPeers.contains(preActivePeer._1))
assert(decommissioner.activeShuffleMigrationPeers.head._1 === preStandByPeer)
}
}