[SPARK-44635][CORE] Handle shuffle fetch failures in decommissions #42296

Closed · wants to merge 7 commits · Changes from 5 commits
26 changes: 26 additions & 0 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -1288,6 +1288,32 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
mapSizesByExecutorId.iter
}

  def getMapOutputLocationWithRefresh(
@ukby1234 (Contributor) commented on Aug 14, 2023:

Maybe return Option[BlockManagerId]?:

  def getMapOutputLocationWithRefresh(
                                       shuffleId: Int,
                                       mapId: Long,
                                       prevLocation: BlockManagerId): Option[BlockManagerId] = {
    // Try to get the cached location first in case other concurrent tasks
    // fetched the fresh location already
    getMapOutputLocation(shuffleId, mapId) match {
      case Some(location) =>
        if (location == prevLocation) {
          unregisterShuffle(shuffleId)
          getMapOutputLocation(shuffleId, mapId)
        } else {
          Some(location)
        }
      case _ =>
        None
    }
  }

The PR author (Contributor) replied:

We still want to throw a MetadataFetchFailedException when failing to get a refreshed location here, so I would prefer returning a BlockManagerId and making this specific.

The reviewer (Contributor) replied:

We can do the following with Option:

                val currentAddressOpt = mapOutputTrackerWorker
                  .getMapOutputLocationWithRefresh(shuffleId, mapId, address)
                currentAddressOpt match {
                  case Some(currentAddress) =>
                    if (currentAddress != address) {
                      logInfo(s"Map status location for block $blockId changed from $address " +
                        s"to $currentAddress")
                      remainingBlocks -= blockId
                      deferredBlocks.getOrElseUpdate(currentAddress, new ArrayBuffer[String]())
                        .append(blockId)
                      enqueueDeferredFetchRequestIfNecessary()
                    } else {
                      results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
                    }
                  case None =>
                    results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
                }

It is also consistent with other function signatures, such as getMapOutputLocation.

      shuffleId: Int,
      mapId: Long,
      prevLocation: BlockManagerId): BlockManagerId = {
    // Try to get the cached location first in case other concurrent tasks
    // fetched the fresh location already
    var currentLocationOpt = getMapOutputLocation(shuffleId, mapId)
    if (currentLocationOpt.contains(prevLocation)) {
      // Address in the cache unchanged. Try to clean cache and get a fresh location
      unregisterShuffle(shuffleId)
      currentLocationOpt = getMapOutputLocation(shuffleId, mapId, canFetchMergeResult = true)
A reviewer (Contributor) suggested:

Suggested change:
- currentLocationOpt = getMapOutputLocation(shuffleId, mapId, canFetchMergeResult = true)
+ currentLocationOpt = getMapOutputLocation(shuffleId, mapId, fetchMergeResult)

    }
    currentLocationOpt.getOrElse(
      throw new MetadataFetchFailedException(shuffleId, -1,
        message = s"Failed to get map output location for shuffleId $shuffleId, mapId $mapId")
    )
  }

  private def getMapOutputLocation(
      shuffleId: Int,
      mapId: Long,
      canFetchMergeResult: Boolean = false): Option[BlockManagerId] = {
    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf, canFetchMergeResult)
    mapOutputStatuses.filter(_ != null).find(_.mapId == mapId).map(_.location)
  }

  override def getPushBasedShuffleMapSizesByExecutorId(
      shuffleId: Int,
      startMapIndex: Int,
22 changes: 22 additions & 0 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -491,6 +491,28 @@ private[spark] object TestUtils {
EnumSet.of(OWNER_READ, OWNER_EXECUTE, OWNER_WRITE))
file.getPath
}

  /** Sets all configs specified in `confPairs`, calls `f`, and then restores them. */
  def withConf[T](confPairs: (String, String)*)(f: => T): T = {
    val conf = SparkEnv.get.conf
    val (keys, values) = confPairs.unzip
    val currentValues = keys.map { key =>
      if (conf.contains(key)) {
        Some(conf.get(key))
      } else {
        None
      }
    }
    (keys, values).zipped.foreach { (key, value) =>
      conf.set(key, value)
    }
    try f finally {
      keys.zip(currentValues).foreach {
        case (key, Some(value)) => conf.set(key, value)
        case (key, None) => conf.remove(key)
      }
    }
  }

bozhang2820 marked this conversation as resolved.
A reviewer (Contributor) commented on lines +497 to +515:

nit: Suggested change, replacing the implementation above with:

  def withConf[T](confPairs: (String, String)*)(f: => T): T = {
    val conf = SparkEnv.get.conf
    val inputConfMap = confPairs.toMap
    val modifiedValues = conf.getAll.filter(kv => inputConfMap.contains(kv._1)).toMap
    inputConfMap.foreach { kv =>
      conf.set(kv._1, kv._2)
    }
    try f finally {
      inputConfMap.keys.foreach { key =>
        if (modifiedValues.contains(key)) {
          conf.set(key, modifiedValues(key))
        } else {
          conf.remove(key)
        }
      }
    }
  }
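For reference, a minimal usage sketch of withConf; the config key and value below are illustrative only, and it assumes a running SparkEnv since withConf reads SparkEnv.get.conf:

  // Overrides the config for the duration of the block, then restores the previous
  // value (or removes the key if it was previously unset), even if the body throws.
  val value = TestUtils.withConf("spark.reducer.maxSizeInFlight" -> "16m") {
    SparkEnv.get.conf.get("spark.reducer.maxSizeInFlight")
  }
  assert(value == "16m")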

}


@@ -528,6 +528,15 @@ package object config {
.bytesConf(ByteUnit.BYTE)
.createOptional

  private[spark] val STORAGE_DECOMMISSION_SHUFFLE_REFRESH =
    ConfigBuilder("spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled")
      .doc("If true, executors will try to refresh the cached locations for the shuffle blocks " +
        "when fetch failures happen (and decommission shuffle block migration is enabled), " +
        "and retry fetching when the location changes.")
      .version("4.0.0")
      .booleanConf
      .createWithDefault(false)
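As a usage sketch, the new flag only takes effect together with the decommission and shuffle-block-migration configs checked by the fetcher; the literal key names below are stated as an assumption rather than taken from this diff:

  import org.apache.spark.SparkConf

  // Hypothetical application configuration enabling location refresh on fetch failures.
  val conf = new SparkConf()
    .set("spark.decommission.enabled", "true") // DECOMMISSION_ENABLED
    .set("spark.storage.decommission.enabled", "true") // STORAGE_DECOMMISSION_ENABLED
    .set("spark.storage.decommission.shuffleBlocks.enabled", "true") // STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED
    .set("spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled", "true")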

private[spark] val STORAGE_REPLICATION_TOPOLOGY_FILE =
ConfigBuilder("spark.storage.replication.topologyFile")
.version("2.1.0")
8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -217,6 +217,14 @@ object BlockId {
val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r
val TEST = "test_(.*)".r

  def getShuffleIdAndMapId(blockId: BlockId): (Int, Long) = blockId match {
    case ShuffleBlockId(shuffleId, mapId, _) => (shuffleId, mapId)
    case ShuffleBlockBatchId(shuffleId, mapId, _, _) => (shuffleId, mapId)
    case ShuffleDataBlockId(shuffleId, mapId, _) => (shuffleId, mapId)
    case ShuffleIndexBlockId(shuffleId, mapId, _) => (shuffleId, mapId)
    case _ => throw new SparkException(s"Unexpected shuffle BlockId $blockId")
  }

def apply(name: String): BlockId = name match {
case RDD(rddId, splitIndex) =>
RDDBlockId(rddId.toInt, splitIndex.toInt)
@@ -27,16 +27,16 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
- import scala.util.{Failure, Success}
+ import scala.util.{Failure, Success, Try}

import io.netty.util.internal.OutOfDirectMemoryError
import org.apache.commons.io.IOUtils
import org.roaringbitmap.RoaringBitmap

- import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
+ import org.apache.spark.{MapOutputTracker, MapOutputTrackerWorker, SparkEnv, SparkException, TaskContext}
import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
import org.apache.spark.errors.SparkCoreErrors
- import org.apache.spark.internal.Logging
+ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle._
import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
@@ -111,6 +111,14 @@ final class ShuffleBlockFetcherIterator(
// nodes, rather than blocking on reading output from one node.
private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)

private val isShuffleMigrationEnabled =
  SparkEnv.get.conf.get(config.DECOMMISSION_ENABLED) &&
    SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_ENABLED) &&
    SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)

private val shouldPerformShuffleLocationRefresh =
  isShuffleMigrationEnabled && SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_REFRESH)

A reviewer (Contributor) commented on shouldPerformShuffleLocationRefresh:

What about making this one of the constructor arguments? One benefit is that you would not need to write tests with TestUtils.withConf. (See the sketch below.)
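For illustration, a minimal sketch of that alternative; the parameter name shouldRefreshLocations is hypothetical and not part of this patch:

  // Sketch only: thread the flag through the constructor so a test can pass it
  // directly instead of mutating the global conf via TestUtils.withConf.
  final class ShuffleBlockFetcherIterator(
      // ... existing constructor parameters ...
      shouldRefreshLocations: Boolean) {
    private val shouldPerformShuffleLocationRefresh =
      isShuffleMigrationEnabled && shouldRefreshLocations
    // ...
  }

The call site that constructs the iterator would then read STORAGE_DECOMMISSION_SHUFFLE_REFRESH once and pass the value in.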

/**
* Total number of blocks to fetch.
*/
@@ -152,7 +160,7 @@ final class ShuffleBlockFetcherIterator(
private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]()

/** Current bytes in flight from our requests */
- private[this] var bytesInFlight = 0L
+ private var bytesInFlight = 0L

/** Current number of requests in flight */
private[this] var reqsInFlight = 0
@@ -264,18 +272,22 @@
case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
}.toMap
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
- val deferredBlocks = new ArrayBuffer[String]()
+ val deferredBlocks = new HashMap[BlockManagerId, ArrayBuffer[String]]()
val blockIds = req.blocks.map(_.blockId.toString)
val address = req.address
val requestStartTime = clock.nanoTime()

@inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
- val blocks = deferredBlocks.map { blockId =>
- val (size, mapIndex) = infoMap(blockId)
- FetchBlockInfo(BlockId(blockId), size, mapIndex)
+ val newAddressToBlocks = new HashMap[BlockManagerId, ArrayBuffer[FetchBlockInfo]]()
+ deferredBlocks.foreach { case (blockManagerId, blockIds) =>
+ val blocks = blockIds.map { blockId =>
+ val (size, mapIndex) = infoMap(blockId)
+ FetchBlockInfo(BlockId(blockId), size, mapIndex)
+ }
+ newAddressToBlocks.put(blockManagerId, blocks)
}
- results.put(DeferFetchRequestResult(FetchRequest(address, blocks)))
+ results.put(DeferFetchRequestResult(address, newAddressToBlocks))
deferredBlocks.clear()
}
}

mridulm marked this conversation as resolved.
@@ -344,7 +356,7 @@
s"due to Netty OOM, will retry")
}
remainingBlocks -= blockId
- deferredBlocks += blockId
+ deferredBlocks.getOrElseUpdate(address, new ArrayBuffer[String]()) += blockId
enqueueDeferredFetchRequestIfNecessary()
}

@@ -355,8 +367,28 @@
updateMergedReqsDuration(wasReqForMergedChunks = true)
results.put(FallbackOnPushMergedFailureResult(
block, address, infoMap(blockId)._1, remainingBlocks.isEmpty))
- } else {
+ } else if (!shouldPerformShuffleLocationRefresh) {
results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
} else {
val (shuffleId, mapId) = BlockId.getShuffleIdAndMapId(block)
A reviewer (Contributor) commented:

Should we move getShuffleIdAndMapId into the Try?
We will effectively block the shuffle indefinitely if getShuffleIdAndMapId throws an exception (it should not currently, but the code could evolve).

Something like:

                Try {
                  val (shuffleId, mapId) = BlockId.getShuffleIdAndMapId(block)
                  mapOutputTrackerWorker
                    .getMapOutputLocationWithRefresh(shuffleId, mapId, address)
                } match {

val mapOutputTrackerWorker = mapOutputTracker.asInstanceOf[MapOutputTrackerWorker]
Try(mapOutputTrackerWorker
.getMapOutputLocationWithRefresh(shuffleId, mapId, address)) match {
A reviewer (Contributor) commented:

Refreshing map output locations in a Netty callback thread can cause a deadlock. Here is the reason (a schematic sketch follows below):

  1. Some map output locations are stored via broadcast variables.
  2. This code path runs inside a synchronized block.
  3. The Netty response that fetches the broadcast variables might be blocked behind other handlers, such as the shuffle success handler.
  4. Because the shuffle success handler also requires the same lock from 2), this is a deadlock.

I hit this situation while testing this patch.
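To make the hazard concrete, here is a schematic sketch of that pattern; it is not Spark code, and all names are illustrative:

  import java.util.concurrent.{CountDownLatch, Executors}

  object DeadlockSketch {
    private val mapStatusLock = new Object
    // Single thread standing in for the Netty event loop, which handles responses in order.
    private val eventLoop = Executors.newSingleThreadExecutor()

    // Called from a fetch-failure callback; it holds mapStatusLock (step 2) while refreshing.
    def refreshLocationsWhileHoldingLock(): Unit = mapStatusLock.synchronized {
      val broadcastFetched = new CountDownLatch(1)
      // Steps 3/4: another queued handler (e.g. a shuffle success handler) runs first
      // on the event loop and needs the same lock this thread is holding.
      eventLoop.submit(new Runnable {
        override def run(): Unit = mapStatusLock.synchronized { /* update shuffle state */ }
      })
      // Step 1: the refresh needs map statuses stored in a broadcast variable; the
      // response would be handled on the same event loop, behind the task above.
      eventLoop.submit(new Runnable {
        override def run(): Unit = broadcastFetched.countDown()
      })
      broadcastFetched.await() // blocks while holding the lock, so neither side can proceed
    }
  }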

case Success(newAddress) =>
if (newAddress != address) {
logInfo(s"Map status location for block $blockId changed from $address " +
s"to $newAddress")
remainingBlocks -= blockId
deferredBlocks.getOrElseUpdate(newAddress,
new ArrayBuffer[String]()) += blockId
enqueueDeferredFetchRequestIfNecessary()
} else {
results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
}
case Failure(ex) =>
ex.addSuppressed(e)
results.put(FailureFetchResult(block, infoMap(blockId)._2, address, ex))
}
}
}
}
@@ -970,15 +1002,16 @@
}
throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))

- case DeferFetchRequestResult(request) =>
- val address = request.address
- numBlocksInFlightPerAddress(address) -= request.blocks.size
- bytesInFlight -= request.size
+ case DeferFetchRequestResult(failedAddress, newAddressToBlocks) =>
+ numBlocksInFlightPerAddress(failedAddress) -= newAddressToBlocks.values.map(_.size).sum
+ bytesInFlight -= newAddressToBlocks.values.map(_.map(_.size).sum).sum
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
- val defReqQueue =
- deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
- defReqQueue.enqueue(request)
+ newAddressToBlocks.foreach { case (newAddress, blocks) =>
+ val defReqQueue =
+ deferredFetchRequests.getOrElseUpdate(newAddress, new Queue[FetchRequest]())
+ defReqQueue.enqueue(FetchRequest(newAddress, blocks))
+ }
result = null

case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
@@ -1167,8 +1200,10 @@
// immediately, defer the request until the next time it can be processed.

// Process any outstanding deferred fetch requests if possible.
// Skip when the address is the local BM Id (this may happen when shuffle migration is enabled)
if (deferredFetchRequests.nonEmpty) {
- for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
+ for ((remoteAddress, defReqQueue) <- deferredFetchRequests
+ if remoteAddress != blockManager.blockManagerId) {
bozhang2820 marked this conversation as resolved.
while (isRemoteBlockFetchable(defReqQueue) &&
!isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
val request = defReqQueue.dequeue()
@@ -1196,6 +1231,16 @@
}
}

if (deferredFetchRequests.contains(blockManager.blockManagerId)) {
// This might happen when shuffle migration is enabled
// Change the remote fetches to local fetches here
val defReqQueue = deferredFetchRequests(blockManager.blockManagerId)
val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
localBlocks ++= defReqQueue.flatMap(req => req.blocks.map(i => (i.blockId, i.mapIndex)))
fetchLocalBlocks(localBlocks)
deferredFetchRequests -= blockManager.blockManagerId
}

def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
if (request.forMergedMetas) {
pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
@@ -1580,7 +1625,9 @@ object ShuffleBlockFetcherIterator {
 * Result of a fetch request that should be deferred for some reason, e.g., Netty OOM
 */

bozhang2820 marked this conversation as resolved.
private[storage]
- case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+ case class DeferFetchRequestResult(
+ failedAddress: BlockManagerId,
+ newAddressToBlocks: HashMap[BlockManagerId, ArrayBuffer[FetchBlockInfo]]) extends FetchResult

/**
* Result of an un-successful fetch of either of these: