From 9c399b47e8c3ceaa4525ae1a1c9dcca76b02b752 Mon Sep 17 00:00:00 2001
From: Minchu Yang
Date: Fri, 5 Nov 2021 14:17:02 -0700
Subject: [PATCH] Add unit test

---
Review note: the test body was revised during review (cleanup moved into try/finally), so
the post-image blob id on the "index" line is stale; regenerate with git format-patch.

 .../apache/spark/MapOutputTrackerSuite.scala | 68 +++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 8bebecfe14914..0ee2c77997973 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -854,4 +854,72 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     masterTracker.stop()
     rpcEnv.shutdown()
   }
+
+  // Registers four map outputs and one merge result for shuffle 10 on the master tracker,
+  // then checks through a worker tracker that getMapSizesByExecutorId returns only the
+  // individual map blocks while getPushBasedShuffleMapSizesByExecutorId returns the merged
+  // block followed by the one remaining individual block.
+  test("SPARK-37023: Avoid fetching merge status when shuffleMergeEnabled is false") {
+    val newConf = new SparkConf
+    newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    newConf.set(IS_TESTING, true)
+    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(newConf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf))
+
+    val workerRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(newConf))
+    val workerTracker = new MapOutputTrackerWorker(newConf)
+    workerTracker.trackerEndpoint =
+      workerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    // Run the scenario inside try/finally so that both trackers and both RPC environments
+    // are stopped even when an assertion fails.
+    try {
+      masterTracker.registerShuffle(10, 4, 1)
+      workerTracker.updateEpoch(masterTracker.getEpoch)
+      // The bitmap holds 0, 1 and 3. Index 2 is left out, and ShuffleBlockId(10, 2, 0) is the
+      // block the push-based assertion below expects next to the merged block.
+      val bitmap = new RoaringBitmap()
+      bitmap.add(0)
+      bitmap.add(1)
+      bitmap.add(3)
+
+      val blockMgrId = BlockManagerId("a", "hostA", 1000)
+      masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+      masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+      masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+      masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+      masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
+        bitmap, 3000L))
+      workerTracker.updateEpoch(masterTracker.getEpoch)
+      // Expected per-block size: 1000 after a compressSize/decompressSize round trip.
+      val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+
+      val mapSizesByExecutorId = workerTracker.getMapSizesByExecutorId(10, 0)
+      // mapSizesByExecutorId does not contain the merged block, since merge status is not
+      // fetched
+      assert(mapSizesByExecutorId.toSeq ===
+        Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+          (ShuffleBlockId(10, 1, 0), size1000, 1),
+          (ShuffleBlockId(10, 2, 0), size1000, 2),
+          (ShuffleBlockId(10, 3, 0), size1000, 3)))))
+      val pushBasedShuffleMapSizesByExecutorId =
+        workerTracker.getPushBasedShuffleMapSizesByExecutorId(10, 0)
+      // pushBasedShuffleMapSizesByExecutorId will contain the merged block, since merge status
+      // is fetched
+      assert(pushBasedShuffleMapSizesByExecutorId.iter.toSeq ===
+        Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1),
+          (ShuffleBlockId(10, 2, 0), size1000, 2)))))
+    } finally {
+      masterTracker.stop()
+      workerTracker.stop()
+      rpcEnv.shutdown()
+      workerRpcEnv.shutdown()
+    }
+  }
 }