diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 588f7d28155b9..af26abc09892f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -538,12 +538,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
-      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
-      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
-    assert(mapSizesByExecutorId.enableBatchFetch == true)
-    mapSizesByExecutorId.iter
-  }
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
   /**
    * Called from executors to get the server URIs and output sizes for each shuffle block that
@@ -1096,7 +1091,20 @@ private[spark] class MapOutputTrackerMaster(
   }
 
   // This method is only called in local-mode.
-  def getPushBasedShuffleMapSizesByExecutorId(
+  override def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
+    assert(mapSizesByExecutorId.enableBatchFetch == true)
+    mapSizesByExecutorId.iter
+  }
+
+  // This method is only called in local-mode.
+  override def getPushBasedShuffleMapSizesByExecutorId(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
@@ -1174,14 +1182,44 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    */
  private val fetchingLock = new KeyLock[Int]
 
+  override def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    val mapSizesByExecutorId = getMapSizesByExecutorIdImpl(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = false)
+    assert(mapSizesByExecutorId.enableBatchFetch == true)
+    mapSizesByExecutorId.iter
+  }
+
   override def getPushBasedShuffleMapSizesByExecutorId(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): MapSizesByExecutorId = {
+    getMapSizesByExecutorIdImpl(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = true)
+  }
+
+  private def getMapSizesByExecutorIdImpl(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      useMergeResult: Boolean): MapSizesByExecutorId = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf,
+      // EnableBatchFetch can be set to false during stage retry when the
+      // shuffleDependency.shuffleMergeEnabled is set to false, and the driver
+      // has already collected the mergedStatus for its shuffle dependency.
+      // In this case, the boolean check helps ensure that the unnecessary
+      // mergeStatus won't be fetched, thus mergedOutputStatuses won't be
+      // passed to convertMapStatuses. See details in [SPARK-37023].
+      if (useMergeResult) fetchMergeResult else false)
     try {
       val actualEndMapIndex =
         if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
@@ -1205,7 +1243,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
         logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
         // Fetch the map statuses and merge statuses again since they might have already been
         // cleared by another task running in the same executor.
-        val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+        val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf, fetchMergeResult)
         try {
           val mergeStatus = mergeResultStatuses(partitionId)
           // If the original MergeStatus is no longer available, we cannot identify the list of
@@ -1230,7 +1268,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
         logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
         // Fetch the map statuses and merge statuses again since they might have already been
         // cleared by another task running in the same executor.
-        val (mapOutputStatuses, _) = getStatuses(shuffleId, conf)
+        val (mapOutputStatuses, _) = getStatuses(shuffleId, conf, fetchMergeResult)
         try {
           MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
             mapOutputStatuses, chunkTracker)
@@ -1252,8 +1290,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    */
   private def getStatuses(
       shuffleId: Int,
-      conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
-    if (fetchMergeResult) {
+      conf: SparkConf,
+      canFetchMergeResult: Boolean): (Array[MapStatus], Array[MergeStatus]) = {
+    if (canFetchMergeResult) {
       val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
       val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull
 
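The worker-side change above threads a single boolean from the two public entry points down to getStatuses, so the merge-status lookup (and the RPC behind it) is skipped entirely on the batch-fetch path. Below is a minimal standalone sketch of that flag-threading pattern, not Spark code; every name in it (FlagThreadingSketch, MapStat, MergeStat) is a hypothetical stand-in for the corresponding tracker pieces.

// Minimal standalone sketch of the flag-threading pattern in the patch above.
// All names here are hypothetical stand-ins, not Spark classes.
object FlagThreadingSketch {
  case class MapStat(size: Long)    // stand-in for MapStatus
  case class MergeStat(size: Long)  // stand-in for MergeStatus

  // Stand-in for the tracker-level push-based-shuffle flag.
  private val fetchMergeResult = true

  // Models getStatuses(shuffleId, conf, canFetchMergeResult): merge statuses
  // are only looked up (or fetched remotely) when the caller can use them.
  private def getStatuses(canFetchMergeResult: Boolean): (Seq[MapStat], Seq[MergeStat]) = {
    val maps = Seq(MapStat(1000L))
    val merges = if (canFetchMergeResult) Seq(MergeStat(3000L)) else Seq.empty[MergeStat]
    (maps, merges)
  }

  // Models getMapSizesByExecutorId: the batch-fetch path never asks for
  // merge results (useMergeResult = false).
  def mapSizes(): Seq[MapStat] =
    getStatuses(canFetchMergeResult = false)._1

  // Models getPushBasedShuffleMapSizesByExecutorId: merge results are
  // requested, still gated by the tracker-level flag.
  def pushBasedMapSizes(): (Seq[MapStat], Seq[MergeStat]) =
    getStatuses(canFetchMergeResult = fetchMergeResult)

  def main(args: Array[String]): Unit = {
    println(mapSizes())          // List(MapStat(1000)) -- no merge fetch
    println(pushBasedMapSizes()) // (List(MapStat(1000)),List(MergeStat(3000)))
  }
}

The point of the indirection is that getMapSizesByExecutorId keeps its old contract (plain map statuses, batch fetch enabled) while getPushBasedShuffleMapSizesByExecutorId remains the only caller that can pay for merge statuses.
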
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 8bebecfe14914..0ee2c77997973 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -854,4 +854,60 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     masterTracker.stop()
     rpcEnv.shutdown()
   }
+
+  test("SPARK-37023: Avoid fetching merge status when shuffleMergeEnabled is false") {
+    val newConf = new SparkConf
+    newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    newConf.set(IS_TESTING, true)
+    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(newConf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(newConf))
+    val slaveTracker = new MapOutputTrackerWorker(newConf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val bitmap = new RoaringBitmap()
+    bitmap.add(0)
+    bitmap.add(1)
+    bitmap.add(3)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
+      bitmap, 3000L))
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+
+    val mapSizesByExecutorId = slaveTracker.getMapSizesByExecutorId(10, 0)
+    // mapSizesByExecutorId does not contain the merged block, since merge status is not fetched
+    assert(mapSizesByExecutorId.toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+        (ShuffleBlockId(10, 1, 0), size1000, 1),
+        (ShuffleBlockId(10, 2, 0), size1000, 2),
+        (ShuffleBlockId(10, 3, 0), size1000, 3)))))
+    val pushBasedShuffleMapSizesByExecutorId =
+      slaveTracker.getPushBasedShuffleMapSizesByExecutorId(10, 0)
+    // pushBasedShuffleMapSizesByExecutorId will contain the merged block, since merge status
+    // is fetched
+    assert(pushBasedShuffleMapSizesByExecutorId.iter.toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1),
+        (ShuffleBlockId(10, 2, 0), size1000, 2)))))
+
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
 }
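
The test's two assertions follow directly from the MergeStatus bitmap it registers: maps 0, 1, and 3 were merged into the push-merged block (reported as ShuffleMergedBlockId(10, 0, 0) with size 3000 and -1 as the map-index marker), so only map 2 still has to be fetched as a regular ShuffleBlockId. A small sketch of that coverage logic, using the same org.roaringbitmap dependency the suite already uses (MergedCoverageSketch and its local names are hypothetical):

import org.roaringbitmap.RoaringBitmap

// Sketch of how the second assertion's expected blocks fall out of the
// MergeStatus bitmap registered in the test: maps 0, 1, 3 are covered by
// the merged block, leaving only map 2 to fetch individually.
object MergedCoverageSketch {
  def main(args: Array[String]): Unit = {
    val tracker = new RoaringBitmap()
    Seq(0, 1, 3).foreach(i => tracker.add(i))

    val numMaps = 4
    val unmerged = (0 until numMaps).filterNot(i => tracker.contains(i))

    println(s"merged block covers ${tracker.getCardinality} of $numMaps maps")
    println(s"maps still fetched individually: $unmerged") // Vector(2)
  }
}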