diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 53fcc2748b4e0..7f056b8feae27 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -35,11 +35,15 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea def value = value_ def unpersist(removeSource: Boolean) { - SparkEnv.get.blockManager.master.removeBlock(blockId) - SparkEnv.get.blockManager.removeBlock(blockId) + HttpBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(blockId) + SparkEnv.get.blockManager.removeBlock(blockId) + } if (removeSource) { - HttpBroadcast.cleanupById(id) + HttpBroadcast.synchronized { + HttpBroadcast.cleanupById(id) + } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 11e74675491c6..e6a8ae199e723 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -33,19 +33,55 @@ extends Broadcast[T](id) with Logging with Serializable { def value = value_ def unpersist(removeSource: Boolean) { - SparkEnv.get.blockManager.master.removeBlock(broadcastId) - SparkEnv.get.blockManager.removeBlock(broadcastId) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(broadcastId) + SparkEnv.get.blockManager.removeBlock(broadcastId) + } + + if (!removeSource) { + //We can't tell BlockManager master to remove blocks from all nodes except driver, + //so we need to save them here in order to store them on disk later. + //This may be inefficient if blocks were already dropped to disk, + //but since unpersist is supposed to be called right after working with + //a broadcast this should not happen (and getting them from memory is cheap). + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + + for (pid <- 0 until totalBlocks) { + val pieceId = pieceBlockId(pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + } + + for (pid <- 0 until totalBlocks) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid)) + } + } if (removeSource) { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.removeBlock(metaId) } - SparkEnv.get.blockManager.removeBlock(metaId) } else { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid)) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.dropFromMemory(metaId) } - SparkEnv.get.blockManager.dropFromMemory(metaId) + + for (i <- 0 until totalBlocks) { + val pieceId = pieceBlockId(i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true) + } + } + arrayOfBlocks = null } } @@ -128,11 +164,6 @@ extends Broadcast[T](id) with Logging with Serializable { } private def resetWorkerVariables() { - if (arrayOfBlocks != null) { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) - } - } arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 4e47a06c1fed2..5dff0e95b31ba 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -187,9 +187,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ def dropFromMemory(blockId: BlockId) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. + // This should never be null if called from ensureFreeSpace as only one + // thread should be dropping blocks and removing entries. + // However the check is required in other cases. if (entry != null) { val data = if (entry.deserialized) { Left(entry.value.asInstanceOf[ArrayBuffer[Any]])