Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Minchu Yang committed Nov 5, 2021
1 parent b1cca0c commit 9c399b4
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -854,4 +854,60 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
masterTracker.stop()
rpcEnv.shutdown()
}

test("SPARK-37023: Avoid fetching merge status when shuffleMergeEnabled is false") {
val newConf = new SparkConf
newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
newConf.set(IS_TESTING, true)
newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
val hostname = "localhost"
val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(newConf))

val masterTracker = newTrackerMaster()
masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf))

val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(newConf))
val slaveTracker = new MapOutputTrackerWorker(newConf)
slaveTracker.trackerEndpoint =
slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)

masterTracker.registerShuffle(10, 4, 1)
slaveTracker.updateEpoch(masterTracker.getEpoch)
val bitmap = new RoaringBitmap()
bitmap.add(0)
bitmap.add(1)
bitmap.add(3)

val blockMgrId = BlockManagerId("a", "hostA", 1000)
masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))

masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, 0,
bitmap, 3000L))
slaveTracker.updateEpoch(masterTracker.getEpoch)
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))

val mapSizesByExecutorId = slaveTracker.getMapSizesByExecutorId(10, 0)
// mapSizesByExecutorId does not contain the merged block, since merge status is not fetched
assert(mapSizesByExecutorId.toSeq ===
Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
(ShuffleBlockId(10, 1, 0), size1000, 1),
(ShuffleBlockId(10, 2, 0), size1000, 2),
(ShuffleBlockId(10, 3, 0), size1000, 3)))))
val pushBasedShuffleMapSizesByExecutorId =
slaveTracker.getPushBasedShuffleMapSizesByExecutorId(10, 0)
// pushBasedShuffleMapSizesByExecutorId will contain the merged block, since merge status
// is fetched
assert(pushBasedShuffleMapSizesByExecutorId.iter.toSeq ===
Seq((blockMgrId, ArrayBuffer((ShuffleMergedBlockId(10, 0, 0), 3000, -1),
(ShuffleBlockId(10, 2, 0), size1000, 2)))))

masterTracker.stop()
slaveTracker.stop()
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
}
}

0 comments on commit 9c399b4

Please sign in to comment.