From 8e4f2efea08b7013a2702543dc3860f9d277e3ac Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 1 Apr 2016 17:44:44 +0000 Subject: [PATCH 1/7] [SPARK-1239] Don't fetch all map output statuses at each reducer during shuffles --- .../org/apache/spark/MapOutputTracker.scala | 235 +++++++++++++++--- .../scala/org/apache/spark/SparkEnv.scala | 6 +- .../apache/spark/MapOutputTrackerSuite.scala | 91 ++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 7 +- .../BlockManagerReplicationSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 4 +- .../streaming/ReceivedBlockHandlerSuite.scala | 4 +- 7 files changed, 299 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3a5caa3510eb..7923aff2e8de 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,13 +18,15 @@ package org.apache.spark import java.io._ -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor} import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag +import scala.util.control.NonFatal +import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus @@ -37,31 +39,18 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) + /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) extends RpcEndpoint with Logging { - val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.length - if (serializedSize > maxRpcMessageSize) { - - val msg = s"Map output statuses were $serializedSize bytes which " + - s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)." - - /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. - * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ - val exception = new SparkException(msg) - logError(msg, exception) - context.sendFailure(exception) - } else { - context.reply(mapOutputStatuses) - } + val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -270,12 +259,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging /** * MapOutputTracker for the driver. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf) +private[spark] class MapOutputTrackerMaster(conf: SparkConf, + broadcastManager: BroadcastManager, isLocal: Boolean) extends MapOutputTracker(conf) { /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch + // The size at which we use Broadcast to send the map output statuses to the executors + private val minSizeForBroadcast = conf.getInt("spark.shuffle.mapOutput.minSizeForBroadcast", + 1024 * 512) + /** Whether to compute locality preferences for reduce tasks */ private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) @@ -296,10 +290,89 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + + // Kept in sync with cachedSerializedStatuses explicitly + // This is required so that the Broadcast variable remains in scope until we remove + // the shuffleId explicitly or implicitly. + private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() + + // This is to prevent multiple serializations of the same shuffle - which happens when + // there is a request storm when shuffle start. + private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() + + // requests for map output statuses + private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + + // Thread pool used for handling map output status requests. This is a separate thread pool + // to ensure we don't block the normal dispatcher threads. + private val threadpool: ThreadPoolExecutor = { + val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + def post(message: GetMapOutputMessage): Unit = { + mapOutputRequests.offer(message) + } + + /** Message loop used for dispatching messages. */ + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val data = mapOutputRequests.take() + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + mapOutputRequests.offer(PoisonPill) + return + } + val context = data.context + val shuffleId = data.shuffleId + val hostPort = context.senderAddress.hostPort + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) + val serializedSize = mapOutputStatuses.length + if (serializedSize > maxRpcMessageSize) { + val msg = s"Map output statuses were $serializedSize bytes which " + + s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)." + + // For SPARK-1244 we'll opt for just logging an error and then sending it to + // the sender. A bigger refactoring (SPARK-1239) will ultimately remove this + // entire code path. + val exception = new SparkException(msg) + logError(msg, exception) + context.sendFailure(exception) + } else { + context.reply(mapOutputStatuses) + } + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new GetMapOutputMessage(-99, null) + + // Exposed for testing + private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + def registerShuffle(shuffleId: Int, numMaps: Int) { if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } + // add in advance + shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -337,6 +410,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) cachedSerializedStatuses.remove(shuffleId) + cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) + shuffleIdLocks.remove(shuffleId) } /** Check if the given shuffle is being tracked */ @@ -428,40 +503,93 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + private def removeBroadcast(bcast: Broadcast[_]): Unit = { + if (null != bcast) { + broadcastManager.unbroadcast(bcast.id, + removeFromDriver = true, blocking = false) + } + } + + private def clearCachedBroadcast(): Unit = { + for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) + cachedSerializedBroadcast.clear() + } + def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null var epochGotten: Long = -1 epochLock.synchronized { if (epoch > cacheEpoch) { cachedSerializedStatuses.clear() + clearCachedBroadcast() cacheEpoch = epoch } cachedSerializedStatuses.get(shuffleId) match { case Some(bytes) => return bytes case None => + logDebug("cached status not found for : " + shuffleId) statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch } } - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val bytes = MapOutputTracker.serializeMapStatuses(statuses) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes + + var shuffleIdLock = shuffleIdLocks.get(shuffleId) + if (null == shuffleIdLock) { + val newLock = new Object() + // in general, this condition should be false - but good to be paranoid + val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) + shuffleIdLock = if (null != prevLock) prevLock else newLock + } + val newbytes = shuffleIdLock.synchronized { + + // double check to make sure someone else didn't serialize and cache the same + // mapstatus while we were waiting on the synchronize + epochLock.synchronized { + if (epoch > cacheEpoch) { + cachedSerializedStatuses.clear() + clearCachedBroadcast() + cacheEpoch = epoch + } + cachedSerializedStatuses.get(shuffleId) match { + case Some(bytes) => + return bytes + case None => + logDebug("shuffle lock cached status not found for : " + shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) + epochGotten = epoch + } + } + + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "statuses"; let's serialize and return that + val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, + isLocal, minSizeForBroadcast) + logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) + // Add them into the table only if the epoch hasn't changed while we were working + epochLock.synchronized { + if (epoch == epochGotten) { + cachedSerializedStatuses(shuffleId) = bytes + if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast + } else { + logInfo("epoch changed, not caching!") + removeBroadcast(bcast) + } } + bytes } - bytes + newbytes } override def stop() { + mapOutputRequests.offer(PoisonPill) + threadpool.shutdown() sendTracker(StopMapOutputTracker) mapStatuses.clear() trackerEndpoint = null cachedSerializedStatuses.clear() + clearCachedBroadcast() + shuffleIdLocks.clear() } } @@ -477,12 +605,16 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { val ENDPOINT_NAME = "MapOutputTracker" + private val DIRECT = 0 + private val BROADCAST = 1 // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { + def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager, + isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = { val out = new ByteArrayOutputStream + out.write(DIRECT) val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) Utils.tryWithSafeFinally { // Since statuses can be modified in parallel, sync on it @@ -492,16 +624,51 @@ private[spark] object MapOutputTracker extends Logging { } { objOut.close() } - out.toByteArray + val arr = out.toByteArray + if (minBroadcastSize >= 0 && arr.length >= minBroadcastSize) { + // Use broadcast instead. + // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! + val bcast = broadcastManager.newBroadcast(arr, isLocal) + // toByteArray creates copy, so we can reuse out + out.reset() + out.write(BROADCAST) + val oos = new ObjectOutputStream(new GZIPOutputStream(out)) + oos.writeObject(bcast) + oos.close() + val outArr = out.toByteArray + logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + (outArr, bcast) + } else { + (arr, null) + } } // Opposite of serializeMapStatuses. def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { - val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - Utils.tryWithSafeFinally { - objIn.readObject().asInstanceOf[Array[MapStatus]] - } { - objIn.close() + assert (bytes.length > 0) + + def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { + val objIn = new ObjectInputStream(new GZIPInputStream( + new ByteArrayInputStream(arr, off, len))) + Utils.tryWithSafeFinally { + objIn.readObject() + } { + objIn.close() + } + } + + bytes(0) match { + case DIRECT => + deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] + case BROADCAST => + // deserialize the Broadcast, pull .value array out of it, and then deserialize that + val bcast = deserializeObject(bytes, 1, bytes.length - 1). + asInstanceOf[Broadcast[Array[Byte]]] + logInfo("Broadcast mapstatuses size = " + bytes.length + + ", actual size = " + bcast.value.length) + // Important - ignore the DIRECT tag ! Start from offset 1 + deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 700e2cb3f91b..c9175ae04a1e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -285,8 +285,10 @@ object SparkEnv extends Logging { } } + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) + val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf) + new MapOutputTrackerMaster(conf, broadcastManager, isLocal) } else { new MapOutputTrackerWorker(conf) } @@ -326,8 +328,6 @@ object SparkEnv extends Logging { serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) - val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ddf48765ec30..024813f42319 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.{any, isA} import org.mockito.Mockito._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException @@ -29,6 +30,14 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf + // disabled by default. + conf.setIfMissing("spark.shuffle.mapOutput.minSizeForBroadcast", "-1") + + private def newTrackerMaster(sparkConf: SparkConf = conf) = { + val broadcastManager = new BroadcastManager(true, sparkConf, + new SecurityManager(sparkConf)) + new MapOutputTrackerMaster(sparkConf, broadcastManager, true) + } def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { @@ -37,7 +46,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master start and stop") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() @@ -46,7 +55,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master register shuffle and fetch") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -62,13 +71,14 @@ class MapOutputTrackerSuite extends SparkFunSuite { Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) + assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() rpcEnv.shutdown() } test("master register and unregister shuffle") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -80,6 +90,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) @@ -90,7 +101,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master register shuffle and unregister map output and fetch") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -101,6 +112,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) + assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) @@ -118,7 +130,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val hostname = "localhost" val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster() masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) @@ -127,14 +139,20 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.trackerEndpoint = slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) + masterTracker.registerShuffle(10, 1) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + assert(0 == masterTracker.getNumCachedSerializedBroadcast) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === @@ -147,6 +165,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // failure should be cached intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() slaveTracker.stop() @@ -158,8 +177,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "-1") - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster(newConf) val rpcEnv = createRpcEnv("spark") val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) @@ -172,6 +192,8 @@ class MapOutputTrackerSuite extends SparkFunSuite { val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + // receiveAndReply async so need to wait + assert(0 == masterTracker.getNumCachedSerializedBroadcast) verify(rpcCallContext).reply(any()) verify(rpcCallContext, never()).sendFailure(any()) @@ -183,8 +205,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "-1") - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster(newConf) val rpcEnv = createRpcEnv("test") val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) @@ -201,16 +224,20 @@ class MapOutputTrackerSuite extends SparkFunSuite { val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) - verify(rpcCallContext, never()).reply(any()) - verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) -// masterTracker.stop() // this throws an exception + // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast + // to be used. + verify(rpcCallContext, timeout(5000).never()).reply(any()) + verify(rpcCallContext, timeout(5000)).sendFailure(isA(classOf[SparkException])) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) + + // masterTracker.stop() // this throws an exception rpcEnv.shutdown() } test("getLocationsWithLargestOutputs with multiple outputs in same machine") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) // Setup 3 map tasks @@ -242,4 +269,46 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.stop() rpcEnv.shutdown() } + + test("remote fetch exceeds RPC size") { + // Same test as "remote fetch exceeds rpc frame size" - except we force use of broadcast here + val newConf = new SparkConf + newConf.set("spark.rpc.message.maxSize", "1") + newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize + + // needs TorrentBroadcast so need a SparkContext + val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf) + try { + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val rpcEnv = sc.env.rpcEnv + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.stop(masterTracker.trackerEndpoint) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + + // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Note that the size is hand-selected here because map output statuses are compressed before + // being sent. + masterTracker.registerShuffle(20, 100) + (0 until 100).foreach { i => + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + } + val senderAddress = RpcAddress("localhost", 12345) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + // should succeed since majority of data is broadcast and actual serialized + // message size is small + verify(rpcCallContext, timeout(5000)).reply(any()) + verify(rpcCallContext, timeout(5000).never()).sendFailure(any()) + assert(1 == masterTracker.getNumCachedSerializedBroadcast) + masterTracker.unregisterShuffle(20) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) + + } finally { + LocalSparkContext.stop(sc) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2293c11dad73..4ae5d7e0ccbe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode @@ -157,6 +158,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } var mapOutputTracker: MapOutputTrackerMaster = null + var broadcastManager: BroadcastManager = null + var securityMgr: SecurityManager = null var scheduler: DAGScheduler = null var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null @@ -208,7 +211,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou cancelledStages.clear() cacheLocations.clear() results.clear() - mapOutputTracker = new MapOutputTrackerMaster(conf) + securityMgr = new SecurityManager(conf) + broadcastManager = new BroadcastManager(true, conf, securityMgr) + mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) scheduler = new DAGScheduler( sc, taskScheduler, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 98e8450fa145..2dbffd41582b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -43,7 +44,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private var rpcEnv: RpcEnv = null private var master: BlockManagerMaster = null private val securityMgr = new SecurityManager(conf) - private val mapOutputTracker = new MapOutputTrackerMaster(conf) + private val bcastManager = new BroadcastManager(true, conf, securityMgr) + private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) private val shuffleManager = new HashShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6fc32cb30a3b..22d9d703fd11 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.{MemoryMode, StaticMemoryManager} import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -59,7 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) - val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false)) + val bcastManager = new BroadcastManager(true, conf, securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) val shuffleManager = new HashShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 4e77cd6347d1..60a023636145 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.netty.NettyBlockTransferService @@ -57,7 +58,8 @@ class ReceivedBlockHandlerSuite val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) + val broadcastManager = new BroadcastManager(true, conf, securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) val shuffleManager = new HashShuffleManager(conf) val serializer = new KryoSerializer(conf) var serializerManager = new SerializerManager(serializer, conf) From 3c1def02b80ebfed55904e504609aa02de6559ce Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 1 Apr 2016 18:49:49 +0000 Subject: [PATCH 2/7] Update unit test --- .../scala/org/apache/spark/MapOutputTracker.scala | 2 +- .../org/apache/spark/MapOutputTrackerSuite.scala | 14 +++----------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7923aff2e8de..9a1052503857 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -572,7 +572,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, cachedSerializedStatuses(shuffleId) = bytes if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast } else { - logInfo("epoch changed, not caching!") + logInfo("Epoch changed, not caching!") removeBroadcast(bcast) } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 024813f42319..613f419a9e1e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -139,24 +139,19 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.trackerEndpoint = slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - assert(0 == masterTracker.getNumCachedSerializedBroadcast) - masterTracker.registerShuffle(10, 1) - assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(0 == masterTracker.getNumCachedSerializedBroadcast) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } - assert(0 == masterTracker.getNumCachedSerializedBroadcast) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() @@ -192,7 +187,6 @@ class MapOutputTrackerSuite extends SparkFunSuite { val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) - // receiveAndReply async so need to wait assert(0 == masterTracker.getNumCachedSerializedBroadcast) verify(rpcCallContext).reply(any()) verify(rpcCallContext, never()).sendFailure(any()) @@ -227,8 +221,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast // to be used. - verify(rpcCallContext, timeout(5000).never()).reply(any()) - verify(rpcCallContext, timeout(5000)).sendFailure(isA(classOf[SparkException])) + verify(rpcCallContext, timeout(30000)).sendFailure(isA(classOf[SparkException])) assert(0 == masterTracker.getNumCachedSerializedBroadcast) // masterTracker.stop() // this throws an exception @@ -300,8 +293,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) // should succeed since majority of data is broadcast and actual serialized // message size is small - verify(rpcCallContext, timeout(5000)).reply(any()) - verify(rpcCallContext, timeout(5000).never()).sendFailure(any()) + verify(rpcCallContext, timeout(30000)).reply(any()) assert(1 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterShuffle(20) assert(0 == masterTracker.getNumCachedSerializedBroadcast) From ab17d52e9c8e246d555ac03e4c60c03e3ed54820 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 1 Apr 2016 21:18:54 +0000 Subject: [PATCH 3/7] Fix tests --- .../test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 5 +++-- .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 613f419a9e1e..dc4a28a212e6 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -187,9 +187,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast + // to be used. + verify(rpcCallContext, timeout(30000)).reply(any()) assert(0 == masterTracker.getNumCachedSerializedBroadcast) - verify(rpcCallContext).reply(any()) - verify(rpcCallContext, never()).sendFailure(any()) // masterTracker.stop() // this throws an exception rpcEnv.shutdown() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 22d9d703fd11..7d4a498f9013 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -60,8 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) - val bcastManager = new BroadcastManager(true, conf, securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) + val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) val shuffleManager = new HashShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test From fcef95be071dac058511616d9db2ecedea75348a Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Wed, 6 Apr 2016 21:03:06 +0000 Subject: [PATCH 4/7] create function for duplicated code --- .../org/apache/spark/MapOutputTracker.scala | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 9a1052503857..46d1c166dd79 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -517,34 +517,13 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null + var retBytes: Array[Byte] = null var epochGotten: Long = -1 - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - return bytes - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) - epochGotten = epoch - } - } - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - val newbytes = shuffleIdLock.synchronized { - - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize + // Check to see if we have a cached version, returns true if it does + // and has side effect of setting retBytes. If not returns false + // with side effect of setting statuses + def checkCachedStatuses(): Boolean = { epochLock.synchronized { if (epoch > cacheEpoch) { cachedSerializedStatuses.clear() @@ -553,13 +532,31 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } cachedSerializedStatuses.get(shuffleId) match { case Some(bytes) => - return bytes + retBytes = bytes + true case None => - logDebug("shuffle lock cached status not found for : " + shuffleId) + logDebug("cached status not found for : " + shuffleId) statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch + false } } + } + + if (checkCachedStatuses()) return retBytes + var shuffleIdLock = shuffleIdLocks.get(shuffleId) + if (null == shuffleIdLock) { + val newLock = new Object() + // in general, this condition should be false - but good to be paranoid + val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) + shuffleIdLock = if (null != prevLock) prevLock else newLock + } + // synchronize so we only serialize/broadcast it once since multiple threads call + // in parallel + shuffleIdLock.synchronized { + // double check to make sure someone else didn't serialize and cache the same + // mapstatus while we were waiting on the synchronize + if (checkCachedStatuses()) return retBytes // If we got here, we failed to find the serialized locations in the cache, so we pulled // out a snapshot of the locations as "statuses"; let's serialize and return that @@ -578,7 +575,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } bytes } - newbytes } override def stop() { From 0382155ef73437e490abdf83af2b232adbab0eb6 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 5 May 2016 19:12:22 +0000 Subject: [PATCH 5/7] Replace maxRpcMessageSize run time check with up front check to make sure min broadcast size <= to the rpc size --- .../org/apache/spark/MapOutputTracker.scala | 27 ++++++-------- .../apache/spark/MapOutputTrackerSuite.scala | 37 +++---------------- .../spark/storage/BlockManagerSuite.scala | 1 - 3 files changed, 17 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 46d1c166dd79..c7852612604c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -315,6 +315,16 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, pool } + // Make sure that that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { + val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + + "message that is to large." + logError(msg) + throw new IllegalArgumentException(msg) + } + def post(message: GetMapOutputMessage): Unit = { mapOutputRequests.offer(message) } @@ -337,20 +347,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.length - if (serializedSize > maxRpcMessageSize) { - val msg = s"Map output statuses were $serializedSize bytes which " + - s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)." - - // For SPARK-1244 we'll opt for just logging an error and then sending it to - // the sender. A bigger refactoring (SPARK-1239) will ultimately remove this - // entire code path. - val exception = new SparkException(msg) - logError(msg, exception) - context.sendFailure(exception) - } else { - context.reply(mapOutputStatuses) - } + context.reply(mapOutputStatuses) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -621,7 +618,7 @@ private[spark] object MapOutputTracker extends Logging { objOut.close() } val arr = out.toByteArray - if (minBroadcastSize >= 0 && arr.length >= minBroadcastSize) { + if (arr.length >= minBroadcastSize) { // Use broadcast instead. // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! val bcast = broadcastManager.newBroadcast(arr, isLocal) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index dc4a28a212e6..c6aebc19fd12 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -30,8 +30,6 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf - // disabled by default. - conf.setIfMissing("spark.shuffle.mapOutput.minSizeForBroadcast", "-1") private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, @@ -172,7 +170,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast - newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "-1") + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "1048576") val masterTracker = newTrackerMaster(newConf) val rpcEnv = createRpcEnv("spark") @@ -196,37 +194,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("remote fetch exceeds max RPC message size") { + test("min broadcast size exceeds max RPC message size") { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast - newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "-1") + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString) - val masterTracker = newTrackerMaster(newConf) - val rpcEnv = createRpcEnv("test") - val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) - rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - - // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. - // Note that the size is hand-selected here because map output statuses are compressed before - // being sent. - masterTracker.registerShuffle(20, 100) - (0 until 100).foreach { i => - masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) - } - val senderAddress = RpcAddress("localhost", 12345) - val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) - - // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast - // to be used. - verify(rpcCallContext, timeout(30000)).sendFailure(isA(classOf[SparkException])) - assert(0 == masterTracker.getNumCachedSerializedBroadcast) - - // masterTracker.stop() // this throws an exception - rpcEnv.shutdown() + intercept[IllegalArgumentException] { newTrackerMaster(newConf) } } test("getLocationsWithLargestOutputs with multiple outputs in same machine") { @@ -264,8 +238,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("remote fetch exceeds RPC size") { - // Same test as "remote fetch exceeds rpc frame size" - except we force use of broadcast here + test("remote fetch using broadcast") { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 86a7bcaae7af..a2580304c4ed 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -62,7 +62,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val securityMgr = new SecurityManager(new SparkConf(false)) val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) - val shuffleManager = new HashShuffleManager(new SparkConf(false)) val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test From 2a520569dbcd5e0ff7a974f2e981bf99ee2c29b7 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 6 May 2016 19:09:10 +0000 Subject: [PATCH 6/7] Change to use getSizeAsBytes instead of getInt --- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c7852612604c..a43294899f19 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -267,8 +267,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, private var cacheEpoch = epoch // The size at which we use Broadcast to send the map output statuses to the executors - private val minSizeForBroadcast = conf.getInt("spark.shuffle.mapOutput.minSizeForBroadcast", - 1024 * 512) + private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", + "512k").toInt /** Whether to compute locality preferences for reduce tasks */ private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) From 396632a030e2fc5618e5946e735b9fecacd2bda1 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 6 May 2016 19:16:51 +0000 Subject: [PATCH 7/7] fix scalastyle --- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a43294899f19..6bd950205fad 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -267,8 +267,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, private var cacheEpoch = epoch // The size at which we use Broadcast to send the map output statuses to the executors - private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", - "512k").toInt + private val minSizeForBroadcast = + conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt /** Whether to compute locality preferences for reduce tasks */ private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true)