[SPARK-37023][CORE] Avoid fetching merge status when shuffleMergeEnabled is false for a shuffleDependency during retry

### What changes were proposed in this pull request?

At a high level, this adds a helper method `getMapSizesByExecutorIdImpl` on which both `getMapSizesByExecutorId` and `getPushBasedShuffleMapSizesByExecutorId` rely. It takes a `useMergeResult` parameter indicating whether fetching merge results is needed, and passes it as `canFetchMergeResult` into `getStatuses`. The resulting structure is sketched below.
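
For orientation, here is a minimal self-contained sketch of that delegation (simplified stand-in types and bodies, not the verbatim Spark code — the real implementation is in the `MapOutputTracker.scala` diff further down):

```scala
object DelegationSketch {
  // Simplified stand-in for the real MapSizesByExecutorId result type.
  case class MapSizesByExecutorId(iter: Iterator[String], enableBatchFetch: Boolean)

  // Stand-in for the worker's push-based-shuffle config flag.
  val fetchMergeResult = true

  // Regular read path: never asks for merge results, so leftover merge
  // statuses can no longer trip the batch-fetch assertion.
  def getMapSizesByExecutorId(shuffleId: Int): Iterator[String] = {
    val sizes = getMapSizesByExecutorIdImpl(shuffleId, useMergeResult = false)
    assert(sizes.enableBatchFetch)
    sizes.iter
  }

  // Push-based read path: merge results are wanted when the feature is on.
  def getPushBasedShuffleMapSizesByExecutorId(shuffleId: Int): MapSizesByExecutorId =
    getMapSizesByExecutorIdImpl(shuffleId, useMergeResult = true)

  private def getMapSizesByExecutorIdImpl(
      shuffleId: Int,
      useMergeResult: Boolean): MapSizesByExecutorId = {
    // The caller's intent gates whether merge statuses are fetched at all.
    val canFetchMergeResult = if (useMergeResult) fetchMergeResult else false
    if (canFetchMergeResult) {
      MapSizesByExecutorId(Iterator("merged + regular blocks"), enableBatchFetch = false)
    } else {
      MapSizesByExecutorId(Iterator("regular map blocks"), enableBatchFetch = true)
    }
  }
}
```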

### Why are the changes needed?

In some stage retry cases, `shuffleDependency.shuffleMergeEnabled` can be set to false even though merge statuses exist, because the Driver has already collected the merge statuses for the shuffle dependency before the retry. In that case, the current implementation sets `enableBatchFetch` to false (merge statuses are present), which makes the assertion in `MapOutputTracker.getMapSizesByExecutorId` fail:
```
assert(mapSizesByExecutorId.enableBatchFetch == true)
```

The proposed fix resolves this by skipping the merge-status fetch entirely when the caller does not need merge results, as illustrated below.
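
For illustration only, the before/after behavior can be condensed into this toy model (hypothetical helper names; the real logic lives in `MapOutputTrackerWorker.getStatuses` and `convertMapStatuses`):

```scala
// Toy model of the bug and the fix (hypothetical helper; the real logic is in
// MapOutputTrackerWorker.getStatuses and MapOutputTracker.convertMapStatuses).
object RetryScenario {
  // Merged blocks cannot be batch-fetched, so fetching merge statuses
  // forces enableBatchFetch to false for the result.
  def enableBatchFetch(mergeStatusesFetched: Boolean): Boolean = !mergeStatusesFetched

  def main(args: Array[String]): Unit = {
    // Before the fix: the worker fetched merge statuses whenever the driver had
    // them, even for a retried dependency with shuffleMergeEnabled = false ...
    val beforeFix = enableBatchFetch(mergeStatusesFetched = true) // false
    // ... so getMapSizesByExecutorId's assert(enableBatchFetch == true) threw.
    println(s"before fix: enableBatchFetch = $beforeFix")

    // After the fix: the regular path passes useMergeResult = false, merge
    // statuses are never fetched, and the assertion holds.
    val afterFix = enableBatchFetch(mergeStatusesFetched = false) // true
    assert(afterFix)
  }
}
```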

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Passed the existing UTs. A new unit test covering this scenario is also added in `MapOutputTrackerSuite` (see the diff below).

Closes #34461 from rmcyang/SPARK-37023.

Authored-by: Minchu Yang <minyang@minyang-mn3.linkedin.biz>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
Minchu Yang authored and Mridul Muralidharan committed Nov 11, 2021
1 parent a116045 commit f1532a2
Showing 2 changed files with 107 additions and 12 deletions.
63 changes: 51 additions & 12 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
```diff
@@ -538,12 +538,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
-      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
-      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
-    assert(mapSizesByExecutorId.enableBatchFetch == true)
-    mapSizesByExecutorId.iter
-  }
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
   /**
    * Called from executors to get the server URIs and output sizes for each shuffle block that
@@ -1096,7 +1091,20 @@ private[spark] class MapOutputTrackerMaster(
   }
 
   // This method is only called in local-mode.
-  def getPushBasedShuffleMapSizesByExecutorId(
+  override def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
+    assert(mapSizesByExecutorId.enableBatchFetch == true)
+    mapSizesByExecutorId.iter
+  }
+
+  // This method is only called in local-mode.
+  override def getPushBasedShuffleMapSizesByExecutorId(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
@@ -1174,14 +1182,44 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    */
   private val fetchingLock = new KeyLock[Int]
 
+  override def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    val mapSizesByExecutorId = getMapSizesByExecutorIdImpl(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = false)
+    assert(mapSizesByExecutorId.enableBatchFetch == true)
+    mapSizesByExecutorId.iter
+  }
+
   override def getPushBasedShuffleMapSizesByExecutorId(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): MapSizesByExecutorId = {
+    getMapSizesByExecutorIdImpl(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = true)
+  }
+
+  private def getMapSizesByExecutorIdImpl(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      useMergeResult: Boolean): MapSizesByExecutorId = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf,
+      // EnableBatchFetch can be set to false during stage retry when the
+      // shuffleDependency.shuffleMergeEnabled is set to false, and Driver
+      // has already collected the mergedStatus for its shuffle dependency.
+      // In this case, boolean check helps to insure that the unnecessary
+      // mergeStatus won't be fetched, thus mergedOutputStatuses won't be
+      // passed to convertMapStatuses. See details in [SPARK-37023].
+      if (useMergeResult) fetchMergeResult else false)
     try {
       val actualEndMapIndex =
         if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
@@ -1205,7 +1243,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
     // Fetch the map statuses and merge statuses again since they might have already been
     // cleared by another task running in the same executor.
-    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf, fetchMergeResult)
     try {
       val mergeStatus = mergeResultStatuses(partitionId)
       // If the original MergeStatus is no longer available, we cannot identify the list of
@@ -1230,7 +1268,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
    // Fetch the map statuses and merge statuses again since they might have already been
     // cleared by another task running in the same executor.
-    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf, fetchMergeResult)
     try {
       MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
         chunkTracker)
@@ -1252,8 +1290,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    */
   private def getStatuses(
       shuffleId: Int,
-      conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
-    if (fetchMergeResult) {
+      conf: SparkConf,
+      canFetchMergeResult: Boolean): (Array[MapStatus], Array[MergeStatus]) = {
+    if (canFetchMergeResult) {
       val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
       val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull
 
```
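Pulled out of the diff above, the essential gating decision is `if (useMergeResult) fetchMergeResult else false`: merge statuses are fetched only when the caller wants them and push-based shuffle fetching is enabled. A standalone sketch (simplified; the real `getStatuses` also RPCs to the driver and caches results under `fetchingLock`):

```scala
// Standalone sketch of the merge-status gating (simplified signatures).
object GatingSketch {
  def statusKindsToFetch(useMergeResult: Boolean, fetchMergeResult: Boolean): String = {
    // Equivalent to the diff's `if (useMergeResult) fetchMergeResult else false`.
    val canFetchMergeResult = useMergeResult && fetchMergeResult
    if (canFetchMergeResult) "map statuses + merge statuses" else "map statuses only"
  }

  def main(args: Array[String]): Unit = {
    // Regular reads skip merge statuses even with push-based shuffle enabled:
    assert(statusKindsToFetch(useMergeResult = false, fetchMergeResult = true) == "map statuses only")
    // Push-based reads fetch them only when the feature is actually on:
    assert(statusKindsToFetch(useMergeResult = true, fetchMergeResult = false) == "map statuses only")
    assert(statusKindsToFetch(useMergeResult = true, fetchMergeResult = true) == "map statuses + merge statuses")
  }
}
```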
56 changes: 56 additions & 0 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
```diff
@@ -854,4 +854,60 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     masterTracker.stop()
     rpcEnv.shutdown()
   }
+
+  test("SPARK-37023: Avoid fetching merge status when shuffleMergeEnabled is false") {
+    val newConf = new SparkConf
+    newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    newConf.set(IS_TESTING, true)
+    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(newConf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(newConf))
+    val slaveTracker = new MapOutputTrackerWorker(newConf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val bitmap = new RoaringBitmap()
+    bitmap.add(0)
+    bitmap.add(1)
+    bitmap.add(3)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
+      bitmap, 3000L))
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+
+    val mapSizesByExecutorId = slaveTracker.getMapSizesByExecutorId(10, 0)
+    // mapSizesByExecutorId does not contain the merged block, since merge status is not fetched
+    assert(mapSizesByExecutorId.toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+        (ShuffleBlockId(10, 1, 0), size1000, 1),
+        (ShuffleBlockId(10, 2, 0), size1000, 2),
+        (ShuffleBlockId(10, 3, 0), size1000, 3)))))
+    val pushBasedShuffleMapSizesByExecutorId =
+      slaveTracker.getPushBasedShuffleMapSizesByExecutorId(10, 0)
+    // pushBasedShuffleMapSizesByExecutorId will contain the merged block, since merge status
+    // is fetched
+    assert(pushBasedShuffleMapSizesByExecutorId.iter.toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1),
+        (ShuffleBlockId(10, 2, 0), size1000, 2)))))
+
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
 }
```
