diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 7b1e2af1b824f..75c87a0553a7a 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -112,9 +112,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
           logDebug("Got cleaning task " + task)
           referenceBuffer -= reference.get
           task match {
-            case CleanRDD(rddId) => doCleanupRDD(rddId)
-            case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
-            case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
+            case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = false)
+            case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId, blocking = false)
+            case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = false)
           }
         }
       } catch {
@@ -124,10 +124,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /** Perform RDD cleanup. */
-  private def doCleanupRDD(rddId: Int) {
+  private def doCleanupRDD(rddId: Int, blocking: Boolean) {
     try {
       logDebug("Cleaning RDD " + rddId)
-      sc.unpersistRDD(rddId, blocking = false)
+      sc.unpersistRDD(rddId, blocking)
       listeners.foreach(_.rddCleaned(rddId))
       logInfo("Cleaned RDD " + rddId)
     } catch {
@@ -135,12 +135,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     }
   }
 
-  /** Perform shuffle cleanup. */
-  private def doCleanupShuffle(shuffleId: Int) {
+  /** Perform shuffle cleanup, optionally blocking until completion. */
+  private def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
     try {
       logDebug("Cleaning shuffle " + shuffleId)
       mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-      blockManagerMaster.removeShuffle(shuffleId)
+      blockManagerMaster.removeShuffle(shuffleId, blocking)
       listeners.foreach(_.shuffleCleaned(shuffleId))
       logInfo("Cleaned shuffle " + shuffleId)
     } catch {
@@ -149,10 +149,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /** Perform broadcast cleanup. */
-  private def doCleanupBroadcast(broadcastId: Long) {
+  private def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
     try {
       logDebug("Cleaning broadcast " + broadcastId)
-      broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
+      broadcastManager.unbroadcast(broadcastId, removeFromDriver = true, blocking = blocking)
       listeners.foreach(_.broadcastCleaned(broadcastId))
       logInfo("Cleaned broadcast " + broadcastId)
     } catch {
@@ -164,18 +164,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private def broadcastManager = sc.env.broadcastManager
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
-  // Used for testing
+  // Used for testing; these explicitly block until cleanup has completed
 
   def cleanupRDD(rdd: RDD[_]) {
-    doCleanupRDD(rdd.id)
+    doCleanupRDD(rdd.id, blocking = true)
   }
 
   def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
-    doCleanupShuffle(shuffleDependency.shuffleId)
+    doCleanupShuffle(shuffleDependency.shuffleId, blocking = true)
   }
 
   def cleanupBroadcast[T](broadcast: Broadcast[T]) {
-    doCleanupBroadcast(broadcast.id)
+    doCleanupBroadcast(broadcast.id, blocking = true)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index e8a97d1754901..f28b6565a830c 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -61,22 +61,31 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
 
   def value: T
 
+  /**
+   * Asynchronously delete cached copies of this broadcast on the executors.
+   * If the broadcast is used after this is called, it will need to be re-sent to each executor.
+   */
+  def unpersist() {
+    unpersist(blocking = false)
+  }
+
   /**
    * Delete cached copies of this broadcast on the executors. If the broadcast is used after
    * this is called, it will need to be re-sent to each executor.
+   * @param blocking Whether to block until unpersisting has completed
    */
-  def unpersist()
+  def unpersist(blocking: Boolean)
 
   /**
    * Remove all persisted state associated with this broadcast on both the executors and
    * the driver.
    */
-  private[spark] def destroy() {
+  private[spark] def destroy(blocking: Boolean) {
     _isValid = false
-    onDestroy()
+    onDestroy(blocking)
   }
 
-  protected def onDestroy()
+  protected def onDestroy(blocking: Boolean)
 
   /**
    * If this broadcast is no longer valid, throw an exception.
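Review note: since this hunk changes the public `Broadcast` API, here is a minimal sketch of how the two overloads read from user code (hypothetical driver snippet, assuming a running `SparkContext` named `sc`; `destroy` is `private[spark]`, so only `unpersist` is user-facing):

```scala
// Broadcast a small list and use it in a job.
val data = sc.broadcast(List(1, 2, 3, 4))
sc.parallelize(1 to 10).map(_ + data.value.sum).count()

// Fire-and-forget: returns immediately; executors drop their copies eventually.
data.unpersist()

// Blocking: returns only once the removal has been confirmed.
data.unpersist(blocking = true)

// The broadcast stays valid after unpersist; it is simply re-sent on next use.
sc.parallelize(1 to 10).map(_ + data.value.sum).count()
```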
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 9ff1675e76a5e..a7867bcaabfc2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -29,6 +29,6 @@ import org.apache.spark.SparkConf
 trait BroadcastFactory {
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager)
   def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
-  def unbroadcast(id: Long, removeFromDriver: Boolean)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean)
   def stop()
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index c3ea16ff9eb5e..cf62aca4d45e8 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -60,7 +60,7 @@ private[spark] class BroadcastManager(
     broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
   }
 
-  def unbroadcast(id: Long, removeFromDriver: Boolean) {
-    broadcastFactory.unbroadcast(id, removeFromDriver)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index f4e2e222f4984..2d5e0352f4265 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -50,12 +50,12 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors.
    */
-  def unpersist() {
-    HttpBroadcast.unpersist(id, removeFromDriver = false)
+  def unpersist(blocking: Boolean) {
+    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking = blocking)
   }
 
-  protected def onDestroy() {
-    HttpBroadcast.unpersist(id, removeFromDriver = true)
+  protected def onDestroy(blocking: Boolean) {
+    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking = blocking)
   }
 
   // Used by the JVM when serializing this object
@@ -194,8 +194,8 @@ private[spark] object HttpBroadcast extends Logging {
    * If removeFromDriver is true, also remove these persisted blocks on the driver
    * and delete the associated broadcast file.
    */
-  def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
-    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
     if (removeFromDriver) {
       val file = getFile(id)
       files.remove(file.toString)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
index 4affa922156c9..2958e4f4c658a 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -34,9 +34,10 @@ class HttpBroadcastFactory extends BroadcastFactory {
 
   /**
    * Remove all persisted state associated with the HTTP broadcast with the given ID.
-   * @param removeFromDriver Whether to remove state from the driver.
+   * @param removeFromDriver Whether to remove state from the driver
+   * @param blocking Whether to block until unbroadcast has completed
    */
-  def unbroadcast(id: Long, removeFromDriver: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+    HttpBroadcast.unpersist(id, removeFromDriver, blocking)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 73eeedb8d1f63..7f37e306f0d07 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -53,12 +53,12 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
   /**
    * Remove all persisted state associated with this Torrent broadcast on the executors.
    */
-  def unpersist() {
-    TorrentBroadcast.unpersist(id, removeFromDriver = false)
+  def unpersist(blocking: Boolean) {
+    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking = blocking)
   }
 
-  protected def onDestroy() {
-    TorrentBroadcast.unpersist(id, removeFromDriver = true)
+  protected def onDestroy(blocking: Boolean) {
+    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking = blocking)
   }
 
   private def sendBroadcast() {
@@ -242,8 +242,8 @@ private[spark] object TorrentBroadcast extends Logging {
    * Remove all persisted blocks associated with this torrent broadcast on the executors.
    * If removeFromDriver is true, also remove these persisted blocks on the driver.
    */
-  def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
-    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index eabe792b550bb..feb0e945fac19 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -36,8 +36,9 @@ class TorrentBroadcastFactory extends BroadcastFactory {
   /**
    * Remove all persisted state associated with the torrent broadcast with the given ID.
    * @param removeFromDriver Whether to remove state from the driver.
+   * @param blocking Whether to block until unbroadcast has completed
    */
-  def unbroadcast(id: Long, removeFromDriver: Boolean) {
-    TorrentBroadcast.unpersist(id, removeFromDriver)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+    TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 616d24ccd8b6e..4c8e718539ec7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -829,12 +829,13 @@ private[spark] class BlockManager(
   /**
-   * Remove all blocks belonging to the given broadcast.
+   * Remove all blocks belonging to the given broadcast. Returns the number of blocks removed.
    */
-  def removeBroadcast(broadcastId: Long, tellMaster: Boolean) {
+  def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
     logInfo("Removing broadcast " + broadcastId)
     val blocksToRemove = blockInfo.keys.collect {
       case bid @ BroadcastBlockId(`broadcastId`, _) => bid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
+    blocksToRemove.size
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 73074e2188e65..29300de7d6638 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -117,14 +117,28 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     }
   }
 
-  /** Remove all blocks belonging to the given shuffle asynchronously. */
-  def removeShuffle(shuffleId: Int) {
-    askDriverWithReply(RemoveShuffle(shuffleId))
+  /** Remove all blocks belonging to the given shuffle, optionally blocking until completion. */
+  def removeShuffle(shuffleId: Int, blocking: Boolean) {
+    val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+    future.onFailure {
+      case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+    }
+    if (blocking) {
+      Await.result(future, timeout)
+    }
   }
 
-  /** Remove all blocks belonging to the given broadcast asynchronously. */
-  def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) {
-    askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster))
+  /** Remove all blocks belonging to the given broadcast, optionally blocking until completion. */
+  def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+    val future = askDriverWithReply[Future[Seq[Int]]](RemoveBroadcast(broadcastId, removeFromMaster))
+    future.onFailure {
+      case e: Throwable =>
+        logError("Failed to remove broadcast " + broadcastId +
+          " with removeFromMaster = " + removeFromMaster, e)
+    }
+    if (blocking) {
+      Await.result(future, timeout)
+    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 3b63bf3f3774d..f238820942e34 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -100,12 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
       sender ! removeRdd(rddId)
 
     case RemoveShuffle(shuffleId) =>
-      removeShuffle(shuffleId)
-      sender ! true
+      sender ! removeShuffle(shuffleId)
 
     case RemoveBroadcast(broadcastId, removeFromDriver) =>
-      removeBroadcast(broadcastId, removeFromDriver)
-      sender ! true
+      sender ! removeBroadcast(broadcastId, removeFromDriver)
 
     case RemoveBlock(blockId) =>
      removeBlockFromWorkers(blockId)
@@ -150,15 +148,22 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     // The dispatcher is used as an implicit argument into the Future sequence construction.
     import context.dispatcher
     val removeMsg = RemoveRdd(rddId)
-    Future.sequence(blockManagerInfo.values.map { bm =>
-      bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
-    }.toSeq)
+    Future.sequence(
+      blockManagerInfo.values.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+      }.toSeq
+    )
   }
 
-  private def removeShuffle(shuffleId: Int) {
+  private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
     // Nothing to do in the BlockManagerMasterActor data structures
+    import context.dispatcher
     val removeMsg = RemoveShuffle(shuffleId)
-    blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg }
+    Future.sequence(
+      blockManagerInfo.values.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+      }.toSeq
+    )
   }
 
   /**
@@ -166,12 +171,18 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
    * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
    * from the executors, but not from the driver.
    */
-  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
+  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
     // TODO: Consolidate usages of
+    import context.dispatcher
     val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
-    blockManagerInfo.values
-      .filter { info => removeFromDriver || info.blockManagerId.executorId != "" }
-      .foreach { bm => bm.slaveActor ! removeMsg }
+    val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+      removeFromDriver || info.blockManagerId.executorId != ""
+    }
+    Future.sequence(
+      requiredBlockManagers.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+      }.toSeq
+    )
   }
 
   private def removeBlockManager(blockManagerId: BlockManagerId) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 2396ca49a7d3f..5c91ad36371bc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
 
 import scala.concurrent.Future
 
-import akka.actor.Actor
+import akka.actor.{Actor, ActorRef}
 
 import org.apache.spark.{Logging, MapOutputTracker}
 import org.apache.spark.storage.BlockManagerMessages._
@@ -39,35 +39,44 @@ class BlockManagerSlaveActor(
   // Operations that involve removing blocks may be slow and should be done asynchronously
   override def receive = {
     case RemoveBlock(blockId) =>
-      val removeBlock = Future { blockManager.removeBlock(blockId) }
-      removeBlock.onFailure { case t: Throwable =>
-        logError("Error in removing block " + blockId, t)
+      doAsync("removing block " + blockId, sender) {
+        blockManager.removeBlock(blockId)
+        true
       }
 
     case RemoveRdd(rddId) =>
-      val removeRdd = Future { sender ! blockManager.removeRdd(rddId) }
-      removeRdd.onFailure { case t: Throwable =>
-        logError("Error in removing RDD " + rddId, t)
+      doAsync("removing RDD " + rddId, sender) {
+        blockManager.removeRdd(rddId)
       }
 
     case RemoveShuffle(shuffleId) =>
-      val removeShuffle = Future {
+      doAsync("removing shuffle " + shuffleId, sender) {
         blockManager.shuffleBlockManager.removeShuffle(shuffleId)
-        if (mapOutputTracker != null) {
-          mapOutputTracker.unregisterShuffle(shuffleId)
-        }
-      }
-      removeShuffle.onFailure { case t: Throwable =>
-        logError("Error in removing shuffle " + shuffleId, t)
       }
 
     case RemoveBroadcast(broadcastId, tellMaster) =>
-      val removeBroadcast = Future { blockManager.removeBroadcast(broadcastId, tellMaster) }
-      removeBroadcast.onFailure { case t: Throwable =>
-        logError("Error in removing broadcast " + broadcastId, t)
+      doAsync("removing broadcast " + broadcastId, sender) {
+        blockManager.removeBroadcast(broadcastId, tellMaster)
       }
 
     case GetBlockStatus(blockId, _) =>
       sender ! blockManager.getStatus(blockId)
   }
+
+  private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+    val future = Future {
+      logDebug(actionMessage)
+      val response = body
+      response
+    }
+    future.onSuccess { case response =>
+      logDebug("Done " + actionMessage + ", response is " + response)
+      responseActor ! response
+      logDebug("Sent response: " + response + " to " + responseActor)
+    }
+    future.onFailure { case t: Throwable =>
+      logError("Error in " + actionMessage, t)
+      responseActor ! null.asInstanceOf[T]
+    }
+  }
 }
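Review note: the control flow above is easier to see outside the actor plumbing. The slave runs its work inside a `Future` and replies with the result; the master collects the replies with `Future.sequence` and blocks on them only when the caller asked to. A self-contained sketch of that shape (plain `scala.concurrent`, illustrative names, using the same 2.10-era `onFailure` callback API as the patch):

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

// Stand-in for a slave actor: remove blocks asynchronously, report a count.
def slaveRemove(executorId: Int): Future[Int] = Future {
  // ... remove the blocks held by this executor ...
  executorId % 3 // pretend this many blocks were removed
}

// Stand-in for BlockManagerMaster: fan out, log failures, optionally wait.
def masterRemove(executorIds: Seq[Int], blocking: Boolean): Unit = {
  val future = Future.sequence(executorIds.map(slaveRemove))
  future.onFailure {
    case t: Throwable => println("Failed to remove blocks: " + t)
  }
  if (blocking) {
    val counts = Await.result(future, 30.seconds)
    println("Removed " + counts.sum + " blocks in total")
  }
}

masterRemove(1 to 5, blocking = true)
```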
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 06233153c56d4..1f9732565709d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -176,7 +176,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   }
 
   /** Remove all the blocks / files related to a particular shuffle. */
-  private def removeShuffleBlocks(shuffleId: ShuffleId) {
+  private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
     shuffleStates.get(shuffleId) match {
       case Some(state) =>
         if (consolidateShuffleFiles) {
@@ -190,8 +190,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
           }
         }
         logInfo("Deleted all files for shuffle " + shuffleId)
+        true
       case None =>
         logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
+        false
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index f1bfb6666ddda..e2f6ba80e0dbb 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -78,30 +78,48 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
   }
 
-  test("Unpersisting HttpBroadcast on executors only") {
-    testUnpersistHttpBroadcast(2, removeFromDriver = false)
+  test("Unpersisting HttpBroadcast on executors only in local mode") {
+    testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
   }
 
-  test("Unpersisting HttpBroadcast on executors and driver") {
-    testUnpersistHttpBroadcast(2, removeFromDriver = true)
+  test("Unpersisting HttpBroadcast on executors and driver in local mode") {
+    testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
   }
 
-  test("Unpersisting TorrentBroadcast on executors only") {
-    testUnpersistTorrentBroadcast(2, removeFromDriver = false)
+  test("Unpersisting HttpBroadcast on executors only in distributed mode") {
+    testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
   }
 
-  test("Unpersisting TorrentBroadcast on executors and driver") {
-    testUnpersistTorrentBroadcast(2, removeFromDriver = true)
+  test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
+    testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
   }
 
+  test("Unpersisting TorrentBroadcast on executors only in local mode") {
+    testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
+    testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
+    testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
+    testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
+  }
+
   /**
-   * Verify the persistence of state associated with an HttpBroadcast in a local-cluster.
+   * Verify the persistence of state associated with an HttpBroadcast in either local mode or
+   * local-cluster mode (when distributed = true).
    *
    * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
    * In between each step, this test verifies that the broadcast blocks and the broadcast file
    * are present only on the expected nodes.
    */
-  private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
+  private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+    val numSlaves = if (distributed) 2 else 0
+
     def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
 
     // Verify that the broadcast file is created, and blocks are persisted only on the driver
@@ -115,7 +133,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       assert(status.memSize > 0, "Block should be in memory store on the driver")
       assert(status.diskSize === 0, "Block should not be in disk store on the driver")
     }
-    assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+    if (distributed) {
+      // The broadcast file is only generated in distributed mode
+      assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+    }
   }
 
   // Verify that blocks are persisted in both the executors and the driver
@@ -138,12 +159,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     val expectedNumBlocks = if (removeFromDriver) 0 else 1
     val possiblyNot = if (removeFromDriver) "" else " not"
     assert(statuses.size === expectedNumBlocks,
-      "Block should%s be unpersisted on the driver".format(possiblyNot))
-    assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
-      "Broadcast file should%s be deleted".format(possiblyNot))
+        "Block should%s be unpersisted on the driver".format(possiblyNot))
+      if (distributed) {
+        // The broadcast file is only generated in distributed mode
+        assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+          "Broadcast file should%s be deleted".format(possiblyNot))
+      }
     }
 
-    testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation,
+    testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
       afterUsingBroadcast, afterUnpersist, removeFromDriver)
   }
 
@@ -154,13 +178,20 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
    * In between each step, this test verifies that the broadcast blocks are present only on the
    * expected nodes.
    */
-  private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
+  private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+    val numSlaves = if (distributed) 2 else 0
+
     def getBlockIds(id: Long) = {
       val broadcastBlockId = BroadcastBlockId(id)
       val metaBlockId = BroadcastBlockId(id, "meta")
       // Assume broadcast value is small enough to fit into 1 piece
       val pieceBlockId = BroadcastBlockId(id, "piece0")
-      Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+      if (distributed) {
+        // The metadata and piece blocks are generated only in distributed mode
+        Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+      } else {
+        Seq[BroadcastBlockId](broadcastBlockId)
+      }
     }
 
     // Verify that blocks are persisted only on the driver
@@ -187,7 +218,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
         statuses.head match { case (bm, _) => assert(bm.executorId === "") }
       } else {
         // Other blocks are on both the executors and the driver
-        assert(statuses.size === numSlaves + 1)
+        assert(statuses.size === numSlaves + 1,
+          blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
         statuses.foreach { case (_, status) =>
           assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
           assert(status.memSize > 0, "Block should be in memory store")
@@ -209,7 +241,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       }
     }
 
-    testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation,
+    testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
      afterUsingBroadcast, afterUnpersist, removeFromDriver)
   }
 
@@ -223,7 +255,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
    * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable.
    */
   private def testUnpersistBroadcast(
-      numSlaves: Int,
+      distributed: Boolean,
+      numSlaves: Int,  // used only when distributed = true
      broadcastConf: SparkConf,
       getBlockIds: Long => Seq[BroadcastBlockId],
       afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
@@ -231,7 +264,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
       removeFromDriver: Boolean) {
 
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+    sc = if (distributed) {
+      new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+    } else {
+      new SparkContext("local", "test", broadcastConf)
+    }
     val blockManagerMaster = sc.env.blockManager.master
     val list = List[Int](1, 2, 3, 4)
 
@@ -241,15 +278,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     afterCreation(blocks, blockManagerMaster)
 
     // Use broadcast variable on all executors
-    val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum))
-    assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+    val partitions = 10
+    assert(partitions > numSlaves)
+    val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
     afterUsingBroadcast(blocks, blockManagerMaster)
 
     // Unpersist broadcast
     if (removeFromDriver) {
-      broadcast.destroy()
+      broadcast.destroy(blocking = true)
     } else {
-      broadcast.unpersist()
+      broadcast.unpersist(blocking = true)
     }
     afterUnpersist(blocks, blockManagerMaster)
 
@@ -260,8 +299,8 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       // Instead, crash the driver by directly accessing the broadcast value.
       intercept[SparkException] { broadcast.value }
     } else {
-      val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum))
-      assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+      val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+      assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
     }
   }
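Review note: one behavioral consequence worth calling out is that with `blocking = true`, failures and timeouts surface synchronously at the call site rather than only in the logs, which is what lets the updated tests assert that cleanup has actually finished. A sketch of the difference (illustrative only, no Spark dependencies):

```scala
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

// A removal future that never completes, standing in for an unresponsive executor.
val stuck = Promise[Boolean]().future

// With blocking = false (the cleaner's default), the caller returns immediately
// and a failure is only ever visible through the onFailure log message.

// With blocking = true, Await.result propagates the timeout to the caller.
try {
  Await.result(stuck, 1.second)
} catch {
  case e: TimeoutException => println("Cleanup did not complete in time: " + e)
}
```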