Fixed failing BroadcastSuite unit tests by introducing blocking for removeShuffle and removeBroadcast in BlockManager*
tdas committed Apr 4, 2014
1 parent a430f06 · commit 104a89a
Showing 14 changed files with 190 additions and 103 deletions.
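
The tests were failing because removeShuffle and removeBroadcast were fire-and-forget: a test could assert that blocks were gone while their removal was still in flight on the executors. The commit threads a `blocking` flag through the cleanup path so callers (tests in particular) can wait for confirmation. A standalone sketch of that race and the blocking fix (all names are illustrative, not Spark API):

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

object BlockingRemovalSketch {
  // Stand-in for the distributed block store.
  private val blocks = scala.collection.concurrent.TrieMap("broadcast_0" -> Array[Byte](1, 2, 3))

  // Fire-and-forget removal: the call returns before the block is actually gone.
  private def removeAsync(id: String): Future[Unit] =
    Future { Thread.sleep(50); blocks.remove(id); () }

  def main(args: Array[String]) {
    val removal = removeAsync("broadcast_0")
    // A test asserting here races the removal and fails intermittently:
    //   assert(!blocks.contains("broadcast_0"))
    Await.result(removal, 1.second) // blocking = true: wait for confirmation
    assert(!blocks.contains("broadcast_0")) // now deterministic
  }
}
```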
28 changes: 14 additions & 14 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -112,9 +112,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
         logDebug("Got cleaning task " + task)
         referenceBuffer -= reference.get
         task match {
-          case CleanRDD(rddId) => doCleanupRDD(rddId)
-          case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
-          case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
+          case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = false)
+          case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId, blocking = false)
+          case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = false)
         }
       }
     } catch {
@@ -124,23 +124,23 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Perform RDD cleanup. */
-  private def doCleanupRDD(rddId: Int) {
+  private def doCleanupRDD(rddId: Int, blocking: Boolean) {
     try {
       logDebug("Cleaning RDD " + rddId)
-      sc.unpersistRDD(rddId, blocking = false)
+      sc.unpersistRDD(rddId, blocking)
       listeners.foreach(_.rddCleaned(rddId))
       logInfo("Cleaned RDD " + rddId)
     } catch {
       case t: Throwable => logError("Error cleaning RDD " + rddId, t)
     }
   }

-  /** Perform shuffle cleanup. */
-  private def doCleanupShuffle(shuffleId: Int) {
+  /** Perform shuffle cleanup, asynchronously. */
+  private def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
     try {
       logDebug("Cleaning shuffle " + shuffleId)
       mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-      blockManagerMaster.removeShuffle(shuffleId)
+      blockManagerMaster.removeShuffle(shuffleId, blocking)
       listeners.foreach(_.shuffleCleaned(shuffleId))
       logInfo("Cleaned shuffle " + shuffleId)
     } catch {
@@ -149,10 +149,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Perform broadcast cleanup. */
-  private def doCleanupBroadcast(broadcastId: Long) {
+  private def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
     try {
       logDebug("Cleaning broadcast " + broadcastId)
-      broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
+      broadcastManager.unbroadcast(broadcastId, true, blocking)
       listeners.foreach(_.broadcastCleaned(broadcastId))
       logInfo("Cleaned broadcast " + broadcastId)
     } catch {
@@ -164,18 +164,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private def broadcastManager = sc.env.broadcastManager
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

-  // Used for testing
+  // Used for testing, explicitly blocks until cleanup is completed

   def cleanupRDD(rdd: RDD[_]) {
-    doCleanupRDD(rdd.id)
+    doCleanupRDD(rdd.id, blocking = true)
   }

   def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
-    doCleanupShuffle(shuffleDependency.shuffleId)
+    doCleanupShuffle(shuffleDependency.shuffleId, blocking = true)
   }

   def cleanupBroadcast[T](broadcast: Broadcast[T]) {
-    doCleanupBroadcast(broadcast.id)
+    doCleanupBroadcast(broadcast.id, blocking = true)
   }
 }

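The test hooks above block by design. A hypothetical BroadcastSuite-style fragment (assumes an existing SparkContext `sc` whose private[spark] `cleaner` field is exposed as an Option[ContextCleaner], so the code must live under the org.apache.spark package):

```scala
// Hypothetical test fragment, not part of the commit.
val bc = sc.broadcast(Array(1, 2, 3))
sc.cleaner.foreach(_.cleanupBroadcast(bc)) // runs doCleanupBroadcast(bc.id, blocking = true)
// Once this returns, the broadcast's blocks are gone from every block manager,
// so the test's assertions no longer race the cleanup.
```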
17 changes: 13 additions & 4 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -61,22 +61,31 @@ abstract class Broadcast[T](val id: Long) extends Serializable {

   def value: T

+  /**
+   * Asynchronously delete cached copies of this broadcast on the executors.
+   * If the broadcast is used after this is called, it will need to be re-sent to each executor.
+   */
+  def unpersist() {
+    unpersist(blocking = false)
+  }
+
   /**
    * Delete cached copies of this broadcast on the executors. If the broadcast is used after
    * this is called, it will need to be re-sent to each executor.
+   * @param blocking Whether to block until unpersisting has completed
    */
-  def unpersist()
+  def unpersist(blocking: Boolean)

   /**
    * Remove all persisted state associated with this broadcast on both the executors and
    * the driver.
    */
-  private[spark] def destroy() {
+  private[spark] def destroy(blocking: Boolean) {
     _isValid = false
-    onDestroy()
+    onDestroy(blocking)
   }

-  protected def onDestroy()
+  protected def onDestroy(blocking: Boolean)

   /**
    * If this broadcast is no longer valid, throw an exception.
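With the split above, the no-arg unpersist() keeps the old asynchronous behavior while unpersist(blocking = true) waits for the executors. A short usage sketch, assuming an existing SparkContext `sc`:

```scala
val words = sc.broadcast(Seq("a", "b", "c"))
words.unpersist()                 // asynchronous: delegates to unpersist(blocking = false)

val lookup = sc.broadcast(Map(1 -> "one"))
lookup.unpersist(blocking = true) // returns only after cached copies are removed
```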
core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -29,6 +29,6 @@ import org.apache.spark.SparkConf
 trait BroadcastFactory {
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager)
   def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
-  def unbroadcast(id: Long, removeFromDriver: Boolean)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean)
   def stop()
 }
core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -60,7 +60,7 @@ private[spark] class BroadcastManager(
     broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
   }

-  def unbroadcast(id: Long, removeFromDriver: Boolean) {
-    broadcastFactory.unbroadcast(id, removeFromDriver)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+    broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
   }
 }
core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -50,12 +50,12 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   /**
    * Remove all persisted state associated with this HTTP broadcast on the executors.
    */
-  def unpersist() {
-    HttpBroadcast.unpersist(id, removeFromDriver = false)
+  def unpersist(blocking: Boolean) {
+    HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
   }

-  protected def onDestroy() {
-    HttpBroadcast.unpersist(id, removeFromDriver = true)
+  protected def onDestroy(blocking: Boolean) {
+    HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }

   // Used by the JVM when serializing this object
@@ -194,8 +194,8 @@ private[spark] object HttpBroadcast extends Logging {
    * If removeFromDriver is true, also remove these persisted blocks on the driver
    * and delete the associated broadcast file.
    */
-  def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
-    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
     if (removeFromDriver) {
       val file = getFile(id)
       files.remove(file.toString)
core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -34,9 +34,10 @@ class HttpBroadcastFactory extends BroadcastFactory {

   /**
    * Remove all persisted state associated with the HTTP broadcast with the given ID.
-   * @param removeFromDriver Whether to remove state from the driver.
+   * @param removeFromDriver Whether to remove state from the driver
+   * @param blocking Whether to block until unbroadcasted
    */
-  def unbroadcast(id: Long, removeFromDriver: Boolean) {
-    HttpBroadcast.unpersist(id, removeFromDriver)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+    HttpBroadcast.unpersist(id, removeFromDriver, blocking)
   }
 }
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -53,12 +53,12 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   /**
    * Remove all persisted state associated with this Torrent broadcast on the executors.
    */
-  def unpersist() {
-    TorrentBroadcast.unpersist(id, removeFromDriver = false)
+  def unpersist(blocking: Boolean) {
+    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
   }

-  protected def onDestroy() {
-    TorrentBroadcast.unpersist(id, removeFromDriver = true)
+  protected def onDestroy(blocking: Boolean) {
+    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }

   private def sendBroadcast() {
@@ -242,8 +242,8 @@ private[spark] object TorrentBroadcast extends Logging {
    * Remove all persisted blocks associated with this torrent broadcast on the executors.
    * If removeFromDriver is true, also remove these persisted blocks on the driver.
    */
-  def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
-    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+    SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -36,8 +36,9 @@ class TorrentBroadcastFactory extends BroadcastFactory {
   /**
    * Remove all persisted state associated with the torrent broadcast with the given ID.
    * @param removeFromDriver Whether to remove state from the driver.
+   * @param blocking Whether to block until unbroadcasted
    */
-  def unbroadcast(id: Long, removeFromDriver: Boolean) {
-    TorrentBroadcast.unpersist(id, removeFromDriver)
+  def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+    TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
   }
 }
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -829,12 +829,13 @@ private[spark] class BlockManager(
   /**
    * Remove all blocks belonging to the given broadcast.
    */
-  def removeBroadcast(broadcastId: Long, tellMaster: Boolean) {
+  def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
     logInfo("Removing broadcast " + broadcastId)
     val blocksToRemove = blockInfo.keys.collect {
       case bid @ BroadcastBlockId(`broadcastId`, _) => bid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
+    blocksToRemove.size
   }

   /**
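removeBroadcast now returns the number of blocks it removed; that count becomes the slave actor's reply, which the master aggregates (see BlockManagerMasterActor below). A minimal sketch of the collect-then-remove pattern, with strings standing in for parsed BlockIds:

```scala
import scala.collection.mutable

// Illustrative only: block names mimic the "broadcast_<id>[_<field>]" convention.
val blockInfo = mutable.Map(
  "broadcast_7" -> "meta", "broadcast_7_piece0" -> "piece", "rdd_1_0" -> "partition")

def removeBroadcast(broadcastId: Long): Int = {
  // The real code pattern-matches BroadcastBlockId rather than comparing strings.
  val toRemove = blockInfo.keys.filter(_.startsWith("broadcast_" + broadcastId)).toList
  toRemove.foreach(blockInfo.remove) // removeBlock(blockId, tellMaster) in the real code
  toRemove.size                      // this count is the reply sent back to the master
}

assert(removeBroadcast(7) == 2)
```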
core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -117,14 +117,28 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging {
     }
   }

-  /** Remove all blocks belonging to the given shuffle asynchronously. */
-  def removeShuffle(shuffleId: Int) {
-    askDriverWithReply(RemoveShuffle(shuffleId))
+  /** Remove all blocks belonging to the given shuffle. */
+  def removeShuffle(shuffleId: Int, blocking: Boolean) {
+    val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+    future.onFailure {
+      case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+    }
+    if (blocking) {
+      Await.result(future, timeout)
+    }
   }

-  /** Remove all blocks belonging to the given broadcast asynchronously. */
-  def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) {
-    askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster))
+  /** Remove all blocks belonging to the given broadcast. */
+  def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+    val future = askDriverWithReply[Future[Seq[Int]]](RemoveBroadcast(broadcastId, removeFromMaster))
+    future.onFailure {
+      case e: Throwable =>
+        logError("Failed to remove broadcast " + broadcastId +
+          " with removeFromMaster = " + removeFromMaster, e)
+    }
+    if (blocking) {
+      Await.result(future, timeout)
+    }
   }

   /**
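Both methods now follow one pattern: ask the master actor, get back a Future, attach a failure logger, and optionally await. A self-contained sketch of that pattern, where a plain Future stands in for askDriverWithReply and the Seq[Int] for per-slave removal counts (uses the Scala 2.10-era onFailure callback, matching the code above):

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

def removeWithOptionalBlocking(blocking: Boolean) {
  val future: Future[Seq[Int]] = Future { Seq(2, 0, 1) } // pretend per-slave removal counts
  future.onFailure {
    case e: Throwable => println("Failed to remove: " + e)
  }
  if (blocking) {
    Await.result(future, 30.seconds) // rethrows failures; returns only when done
  }
}

removeWithOptionalBlocking(blocking = true)  // caller observes completion or the error
removeWithOptionalBlocking(blocking = false) // fire-and-forget; errors are only logged
```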
core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -100,12 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
       sender ! removeRdd(rddId)

     case RemoveShuffle(shuffleId) =>
-      removeShuffle(shuffleId)
-      sender ! true
+      sender ! removeShuffle(shuffleId)

     case RemoveBroadcast(broadcastId, removeFromDriver) =>
-      removeBroadcast(broadcastId, removeFromDriver)
-      sender ! true
+      sender ! removeBroadcast(broadcastId, removeFromDriver)

     case RemoveBlock(blockId) =>
       removeBlockFromWorkers(blockId)
@@ -150,28 +148,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     // The dispatcher is used as an implicit argument into the Future sequence construction.
     import context.dispatcher
     val removeMsg = RemoveRdd(rddId)
-    Future.sequence(blockManagerInfo.values.map { bm =>
-      bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
-    }.toSeq)
+    Future.sequence(
+      blockManagerInfo.values.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+      }.toSeq
+    )
   }

-  private def removeShuffle(shuffleId: Int) {
+  private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
     // Nothing to do in the BlockManagerMasterActor data structures
+    import context.dispatcher
     val removeMsg = RemoveShuffle(shuffleId)
-    blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg }
+    Future.sequence(
+      blockManagerInfo.values.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+      }.toSeq
+    )
   }

   /**
    * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
    * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
    * from the executors, but not from the driver.
    */
-  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
+  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
     // TODO: Consolidate usages of <driver>
+    import context.dispatcher
     val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
-    blockManagerInfo.values
-      .filter { info => removeFromDriver || info.blockManagerId.executorId != "<driver>" }
-      .foreach { bm => bm.slaveActor ! removeMsg }
+    val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+      removeFromDriver || info.blockManagerId.executorId != "<driver>"
+    }
+    Future.sequence(
+      requiredBlockManagers.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+      }.toSeq
+    )
  }

   private def removeBlockManager(blockManagerId: BlockManagerId) {
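Instead of tell-and-forget, each removal now asks every slave and folds the per-slave reply futures into one with Future.sequence, so the master can hand a single awaitable future back to BlockManagerMaster. The gather step in isolation, with plain futures standing in for slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

val perSlaveCounts: Seq[Future[Int]] = Seq(Future(2), Future(0), Future(5))
val all: Future[Seq[Int]] = Future.sequence(perSlaveCounts)
// One awaitable future instead of N; if any reply fails or times out, `all` fails too.
println(Await.result(all, 10.seconds)) // List(2, 0, 5)
```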
core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage

 import scala.concurrent.Future

-import akka.actor.Actor
+import akka.actor.{ActorRef, Actor}

 import org.apache.spark.{Logging, MapOutputTracker}
 import org.apache.spark.storage.BlockManagerMessages._
@@ -39,35 +39,44 @@ class BlockManagerSlaveActor(
   // Operations that involve removing blocks may be slow and should be done asynchronously
   override def receive = {
     case RemoveBlock(blockId) =>
-      val removeBlock = Future { blockManager.removeBlock(blockId) }
-      removeBlock.onFailure { case t: Throwable =>
-        logError("Error in removing block " + blockId, t)
+      doAsync("removing block", sender) {
+        blockManager.removeBlock(blockId)
+        true
       }

     case RemoveRdd(rddId) =>
-      val removeRdd = Future { sender ! blockManager.removeRdd(rddId) }
-      removeRdd.onFailure { case t: Throwable =>
-        logError("Error in removing RDD " + rddId, t)
+      doAsync("removing RDD", sender) {
+        blockManager.removeRdd(rddId)
       }

     case RemoveShuffle(shuffleId) =>
-      val removeShuffle = Future {
+      doAsync("removing shuffle", sender) {
         blockManager.shuffleBlockManager.removeShuffle(shuffleId)
         if (mapOutputTracker != null) {
           mapOutputTracker.unregisterShuffle(shuffleId)
         }
       }
-      removeShuffle.onFailure { case t: Throwable =>
-        logError("Error in removing shuffle " + shuffleId, t)
-      }

     case RemoveBroadcast(broadcastId, tellMaster) =>
-      val removeBroadcast = Future { blockManager.removeBroadcast(broadcastId, tellMaster) }
-      removeBroadcast.onFailure { case t: Throwable =>
-        logError("Error in removing broadcast " + broadcastId, t)
+      doAsync("removing broadcast", sender) {
+        blockManager.removeBroadcast(broadcastId, tellMaster)
       }

     case GetBlockStatus(blockId, _) =>
       sender ! blockManager.getStatus(blockId)
   }

+  private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+    val future = Future {
+      logDebug(actionMessage)
+      val response = body
+      response
+    }
+    future.onSuccess { case response =>
+      logDebug("Successful in " + actionMessage + ", response is " + response)
+      responseActor ! response
+      logDebug("Sent response: " + response + " to " + responseActor)
+    }
+    future.onFailure { case t: Throwable =>
+      logError("Error in " + actionMessage, t)
+      responseActor ! null.asInstanceOf[T]
+    }
+  }
 }
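doAsync centralizes the run-in-a-future, reply-to-sender, log-on-failure pattern. Passing `sender` in as a parameter matters: Akka's sender is only stable inside receive, so it must be captured before the future's callbacks run. An Akka-free sketch of the helper, with a callback standing in for `responseActor ! response`:

```scala
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

def doAsync[T](actionMessage: String)(reply: Any => Unit)(body: => T) {
  val future = Future {
    println("Started " + actionMessage)
    body
  }
  future.onSuccess { case response =>
    reply(response) // `responseActor ! response` in the actor version
  }
  future.onFailure { case t: Throwable =>
    println("Error in " + actionMessage + ": " + t)
    reply(null) // signal failure instead of letting the master's Await time out
  }
}

doAsync("removing broadcast 0")(r => println("reply: " + r)) {
  3 // pretend three blocks were removed
}
Thread.sleep(200) // give the background future time to finish in this sketch
```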
