Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44635][CORE] Handle shuffle fetch failures in decommissions #42296

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 25 additions & 1 deletion core/src/main/scala/org/apache/spark/MapOutputTracker.scala
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.shuffle.{MetadataFetchFailedException, MetadataUpdateFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
import org.apache.spark.util._
import org.apache.spark.util.collection.OpenHashMap
Expand Down Expand Up @@ -1288,6 +1288,30 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
mapSizesByExecutorId.iter
}

def getMapOutputLocationWithRefresh(
Copy link
Contributor

@ukby1234 ukby1234 Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe return Option[BlockManagerId]?:

  def getMapOutputLocationWithRefresh(
                                       shuffleId: Int,
                                       mapId: Long,
                                       prevLocation: BlockManagerId): Option[BlockManagerId] = {
    // Try to get the cached location first in case other concurrent tasks
    // fetched the fresh location already
    getMapOutputLocation(shuffleId, mapId) match {
      case Some(location) =>
        if (location == prevLocation) {
          unregisterShuffle(shuffleId)
          getMapOutputLocation(shuffleId, mapId)
        } else {
          Some(location)
        }
      case _ =>
        None
    }
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still want to throw a MetadataFetchFailedException when failing to get a refreshed location here. So I would prefer returning a BlockManagerId and make it specific.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do the following with Option:

                val currentAddressOpt = mapOutputTrackerWorker
                  .getMapOutputLocationWithRefresh(shuffleId, mapId, address)
                currentAddressOpt match {
                  case Some(currentAddress) =>
                    if (currentAddress != address) {
                      logInfo(s"Map status location for block $blockId changed from $address " +
                        s"to $currentAddress")
                      remainingBlocks -= blockId
                      deferredBlocks.getOrElseUpdate(currentAddress, new ArrayBuffer[String]())
                        .append(blockId)
                      enqueueDeferredFetchRequestIfNecessary()
                    } else {
                      results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
                    }
                  case None =>
                    results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
                }

It is also consistent with other function signatures with getMapOutputLocation.

shuffleId: Int,
mapId: Long,
prevLocation: BlockManagerId): BlockManagerId = {
// Try to get the cached location first in case other concurrent tasks
// fetched the fresh location already
var currentLocationOpt = getMapOutputLocation(shuffleId, mapId)
if (currentLocationOpt.isDefined && currentLocationOpt.get == prevLocation) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if (currentLocationOpt.isDefined && currentLocationOpt.get == prevLocation) {
if (currentLocationOpt.exists(_ == prevLocation)) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will change to currentLocationOpt.contains(prevLocation).

// Address in the cache unchanged. Try to clean cache and get a fresh location
unregisterShuffle(shuffleId)
currentLocationOpt = getMapOutputLocation(shuffleId, mapId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: we end up removing both map and merge status here - for this call second call, pass canFetchMergeResult = true in getMapOutputLocation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Will do.

}
if (currentLocationOpt.isEmpty) {
throw new MetadataUpdateFailedException(shuffleId, mapId,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you reuse MetadataFetchFailedException? We can use the message field to distinguish the error case.

message = s"Failed to get map output location for shuffleId $shuffleId, mapId $mapId")
}
currentLocationOpt.get
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: currentLocationOpt.getOrElse( throw ... )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When shuffle fallback storage is enabled, this currentLocationOptcan be the FALLBACK_BLOCK_MANAGER_ID, and DeferFetchRequestResult below doesn't handle this special case.
so either 1) check the FetchRequest for fallback storage special ID 2)rewrite the RPC address to localhost so we get the blocks inside the fallback storage.

Copy link
Contributor

@mridulm mridulm Aug 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us filter it out here, and add support for fetching from fallback in a separate pr.

+CC @dongjoon-hyun as well.

}

private def getMapOutputLocation(shuffleId: Int, mapId: Long): Option[BlockManagerId] = {
val (mapOutputStatuses, _) = getStatuses(shuffleId, conf, false)
mapOutputStatuses.filter(_ != null).find(_.mapId == mapId).map(_.location)
}

override def getPushBasedShuffleMapSizesByExecutorId(
shuffleId: Int,
startMapIndex: Int,
Expand Down
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/TestUtils.scala
Expand Up @@ -491,6 +491,27 @@ private[spark] object TestUtils {
EnumSet.of(OWNER_READ, OWNER_EXECUTE, OWNER_WRITE))
file.getPath
}

def withConf[T](confPairs: (String, String)*)(f: => T): T = {
bozhang2820 marked this conversation as resolved.
Show resolved Hide resolved
val conf = SparkEnv.get.conf
val (keys, values) = confPairs.unzip
val currentValues = keys.map { key =>
if (conf.contains(key)) {
Some(conf.get(key))
} else {
None
}
}
(keys, values).zipped.foreach { (key, value) =>
conf.set(key, value)
}
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.set(key, value)
case (key, None) => conf.remove(key)
}
}
}
Comment on lines +497 to +515
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
val conf = SparkEnv.get.conf
val (keys, values) = confPairs.unzip
val currentValues = keys.map { key =>
if (conf.contains(key)) {
Some(conf.get(key))
} else {
None
}
}
(keys, values).zipped.foreach { (key, value) =>
conf.set(key, value)
}
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.set(key, value)
case (key, None) => conf.remove(key)
}
}
}
def withConf[T](confPairs: (String, String)*)(f: => T): T = {
val conf = SparkEnv.get.conf
val inputConfMap = confPairs.toMap
val modifiedValues = conf.getAll.filter(kv => inputConfMap.contains(kv._1)).toMap
inputConfMap.foreach { kv =>
conf.set(kv._1, kv._2)
}
try f finally {
inputConfMap.keys.foreach { key =>
if (modifiedValues.contains(key)) {
conf.set(key, modifiedValues(key))
} else {
conf.remove(key)
}
}
}
}

}


Expand Down
Expand Up @@ -528,6 +528,15 @@ package object config {
.bytesConf(ByteUnit.BYTE)
.createOptional

private[spark] val STORAGE_DECOMMISSION_SHUFFLE_REFRESH =
ConfigBuilder("spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled")
.doc("If true, executors will try to refresh the cached locations for the shuffle blocks" +
"when fetch failures happens (and decommission shuffle block migration is enabled), " +
"and retry fetching when the location changes.")
.version("3.5.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to 4.0.0

.booleanConf
.createWithDefault(false)

private[spark] val STORAGE_REPLICATION_TOPOLOGY_FILE =
ConfigBuilder("spark.storage.replication.topologyFile")
.version("2.1.0")
Expand Down
Expand Up @@ -70,3 +70,12 @@ private[spark] class MetadataFetchFailedException(
reduceId: Int,
message: String)
extends FetchFailedException(null, shuffleId, -1L, -1, reduceId, message)

/**
* Failed to update shuffle metadata (in cases like decommission).
*/
private[spark] class MetadataUpdateFailedException(
shuffleId: Int,
mapId: Long,
message: String)
extends FetchFailedException(null, shuffleId, mapId, -1, -1, message)
8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
Expand Up @@ -217,6 +217,14 @@ object BlockId {
val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r
val TEST = "test_(.*)".r

def getShuffleIdAndMapId(blockId: BlockId): (Int, Long) = blockId match {
case ShuffleBlockId(shuffleId, mapId, _) => (shuffleId, mapId)
case ShuffleBlockBatchId(shuffleId, mapId, _, _) => (shuffleId, mapId)
case ShuffleDataBlockId(shuffleId, mapId, _) => (shuffleId, mapId)
case ShuffleIndexBlockId(shuffleId, mapId, _) => (shuffleId, mapId)
case _ => throw new SparkException(s"Unexpected shuffle BlockId $blockId")
}

def apply(name: String): BlockId = name match {
case RDD(rddId, splitIndex) =>
RDDBlockId(rddId.toInt, splitIndex.toInt)
Expand Down
Expand Up @@ -33,10 +33,10 @@ import io.netty.util.internal.OutOfDirectMemoryError
import org.apache.commons.io.IOUtils
import org.roaringbitmap.RoaringBitmap

import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
import org.apache.spark.{MapOutputTracker, MapOutputTrackerWorker, SparkEnv, SparkException, TaskContext}
import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle._
import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
Expand Down Expand Up @@ -111,6 +111,14 @@ final class ShuffleBlockFetcherIterator(
// nodes, rather than blocking on reading output from one node.
private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)

private val isShuffleMigrationEnabled =
SparkEnv.get.conf.get(config.DECOMMISSION_ENABLED) &&
SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_ENABLED) &&
SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)

private val shouldPerformShuffleLocationRefresh =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about make this one of the constructor argument? One of the benefit is that you don't need to write tests with TestUtils.withConf

isShuffleMigrationEnabled && SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_REFRESH)

/**
* Total number of blocks to fetch.
*/
Expand Down Expand Up @@ -264,18 +272,22 @@ final class ShuffleBlockFetcherIterator(
case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
}.toMap
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
val deferredBlocks = new ArrayBuffer[String]()
val deferredBlocks = new HashMap[BlockManagerId, Queue[String]]()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
val deferredBlocks = new HashMap[BlockManagerId, Queue[String]]()
val deferredBlocks = new HashMap[BlockManagerId, ArrayBuffer[String]]()

val blockIds = req.blocks.map(_.blockId.toString)
val address = req.address
val requestStartTime = clock.nanoTime()

@inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
val blocks = deferredBlocks.map { blockId =>
val (size, mapIndex) = infoMap(blockId)
FetchBlockInfo(BlockId(blockId), size, mapIndex)
val newAddressToBlocks = new HashMap[BlockManagerId, Queue[FetchBlockInfo]]()
deferredBlocks.foreach { case (blockManagerId, blockIds) =>
val blocks = blockIds.map { blockId =>
val (size, mapIndex) = infoMap(blockId)
FetchBlockInfo(BlockId(blockId), size, mapIndex)
}
newAddressToBlocks.put(blockManagerId, blocks)
}
mridulm marked this conversation as resolved.
Show resolved Hide resolved
results.put(DeferFetchRequestResult(FetchRequest(address, blocks)))
results.put(DeferFetchRequestResult(address, newAddressToBlocks))
deferredBlocks.clear()
}
}
Expand Down Expand Up @@ -344,7 +356,8 @@ final class ShuffleBlockFetcherIterator(
s"due to Netty OOM, will retry")
}
remainingBlocks -= blockId
deferredBlocks += blockId
deferredBlocks.getOrElseUpdate(address, new Queue[String]())
.enqueue(blockId)
enqueueDeferredFetchRequestIfNecessary()
}

Expand All @@ -355,8 +368,23 @@ final class ShuffleBlockFetcherIterator(
updateMergedReqsDuration(wasReqForMergedChunks = true)
results.put(FallbackOnPushMergedFailureResult(
block, address, infoMap(blockId)._1, remainingBlocks.isEmpty))
} else {
} else if (!shouldPerformShuffleLocationRefresh) {
results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
} else {
val (shuffleId, mapId) = BlockId.getShuffleIdAndMapId(block)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move the getShuffleIdAndMapId into the Try ?
We will effectively block shuffle indefinitely in case getShuffleIdAndMapId throws an exception (it should not currently - but code could evolve).

Something like:

                Try {
                  val (shuffleId, mapId) = BlockId.getShuffleIdAndMapId(block)
                  mapOutputTrackerWorker
                    .getMapOutputLocationWithRefresh(shuffleId, mapId, address)
                } match {

val mapOutputTrackerWorker = mapOutputTracker.asInstanceOf[MapOutputTrackerWorker]
val currentAddress = mapOutputTrackerWorker
.getMapOutputLocationWithRefresh(shuffleId, mapId, address)
if (currentAddress != address) {
logInfo(s"Map status location for block $blockId changed from $address " +
s"to $currentAddress")
remainingBlocks -= blockId
deferredBlocks.getOrElseUpdate(currentAddress, new Queue[String]())
.enqueue(blockId)
enqueueDeferredFetchRequestIfNecessary()
} else {
results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
}
}
}
}
Expand Down Expand Up @@ -970,15 +998,16 @@ final class ShuffleBlockFetcherIterator(
}
throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))

case DeferFetchRequestResult(request) =>
val address = request.address
numBlocksInFlightPerAddress(address) -= request.blocks.size
bytesInFlight -= request.size
case DeferFetchRequestResult(failedAddress, newAddressToBlocks) =>
numBlocksInFlightPerAddress(failedAddress) -= newAddressToBlocks.values.map(_.size).sum
bytesInFlight -= newAddressToBlocks.values.map(_.map(_.size).sum).sum
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
val defReqQueue =
deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
defReqQueue.enqueue(request)
newAddressToBlocks.foreach { case (newAddress, blocks) =>
val defReqQueue =
deferredFetchRequests.getOrElseUpdate(newAddress, new Queue[FetchRequest]())
defReqQueue.enqueue(FetchRequest(newAddress, blocks))
}
result = null

case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
Expand Down Expand Up @@ -1580,7 +1609,9 @@ object ShuffleBlockFetcherIterator {
* Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM
bozhang2820 marked this conversation as resolved.
Show resolved Hide resolved
*/
private[storage]
case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
case class DeferFetchRequestResult(
failedAddress: BlockManagerId,
newAddressToBlocks: HashMap[BlockManagerId, Queue[FetchBlockInfo]]) extends FetchResult

/**
* Result of an un-successful fetch of either of these:
Expand Down
Expand Up @@ -39,7 +39,7 @@ import org.mockito.stubbing.Answer
import org.roaringbitmap.RoaringBitmap
import org.scalatest.PrivateMethodTester

import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
import org.apache.spark.{MapOutputTrackerWorker, SharedSparkContext, SparkEnv, SparkFunSuite, TaskContext, TestUtils}
import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
import org.apache.spark.network._
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
Expand All @@ -51,16 +51,19 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator._
import org.apache.spark.util.Utils


class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite
with PrivateMethodTester
with SharedSparkContext {

private var transfer: BlockTransferService = _
private var mapOutputTracker: MapOutputTracker = _
private var mapOutputTracker: MapOutputTrackerWorker = _

override def beforeEach(): Unit = {
transfer = mock(classOf[BlockTransferService])
mapOutputTracker = mock(classOf[MapOutputTracker])
mapOutputTracker = mock(classOf[MapOutputTrackerWorker])
when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
.thenReturn(Seq.empty.iterator)
when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(), any())).thenReturn(null)
}

private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
Expand Down Expand Up @@ -664,6 +667,66 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}

test("handle map output location change") {
TestUtils.withConf(
"spark.decommission.enabled" -> "true",
"spark.storage.decommission.enabled" -> "true",
"spark.storage.decommission.shuffleBlocks.enabled" -> "true"
) {
val remoteBmId = BlockManagerId("test-remote-1", "test-remote-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
)

answerFetchBlocks { invocation =>
val host = invocation.getArgument[String](0)
val listener = invocation.getArgument[BlockFetchingListener](4)
host match {
case "test-remote-1" =>
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
// TODO: update exception type here
listener.onBlockFetchFailure(
ShuffleBlockId(0, 1, 0).toString, new RuntimeException())
listener.onBlockFetchFailure(
ShuffleBlockId(0, 2, 0).toString, new RuntimeException())
case "test-remote-2" =>
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}

when(mapOutputTracker.getMapOutputLocationWithRefresh(any(), any(), any()))
.thenAnswer { invocation =>
val mapId = invocation.getArgument[Long](1)
mapId match {
case 0 => BlockManagerId("test-remote-1", "test-remote-1", 2)
case 1 => BlockManagerId("test-remote-2", "test-remote-2", 2)
case 2 => BlockManagerId("test-remote-2", "test-remote-2", 2)
}
}

Seq(true, false).foreach { isEnabled =>
SparkEnv.get.conf.set(
"spark.storage.decommission.shuffleBlocks.refreshLocationsEnabled", isEnabled.toString)
val iterator = createShuffleBlockIteratorWithDefaults(
Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0))
)
if (isEnabled) {
assert(iterator.toList.map(_._1) == blocks.keys.toList)
} else {
intercept[FetchFailedException] {
iterator.toList
}
}
}
}
}

test("fail all blocks if any of the remote request fails") {
// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
Expand Down