core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala (57 additions, 4 deletions)
@@ -19,10 +19,11 @@ package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag
import scala.reflect._

import org.apache.spark._
import org.apache.spark.util.Utils
import org.apache.spark.storage.{BlockId, RDDBlockId, StorageLevel}
import org.apache.spark.util.{CompletionIterator, Utils}

private[spark]
class CartesianPartition(
@@ -72,8 +73,60 @@ class CartesianRDD[T: ClassTag, U: ClassTag](

override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
val (iter2, readCachedBlock) =
getOrCacheBlock(rdd2, currSplit.s2, context, StorageLevel.MEMORY_AND_DISK)
val resultIter = for (x <- rdd1.iterator(currSplit.s1, context); y <- iter2) yield (x, y)

CompletionIterator[(T, U), Iterator[(T, U)]](resultIter,
removeBlock(RDDBlockId(rdd2.id, currSplit.s2.index), readCachedBlock))
}

/**
* Try to get the block from the local block manager; if it is not present locally, fetch it
* from a remote block manager and cache it locally.
*
* Because the block may be in use by another task on the same executor, when this task
* completes we try to remove the block in a non-blocking manner; if it cannot be removed
* immediately, it is marked as removable instead.
*/
private def getOrCacheBlock(
rdd: RDD[U],
partition: Partition,
context: TaskContext,
level: StorageLevel): (Iterator[U], Boolean) = {
val blockId = RDDBlockId(rdd.id, partition.index)
var readCachedBlock = true
// This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
val iterator = SparkEnv.get.blockManager.getOrElseUpdate(blockId, level, classTag[U], () => {
readCachedBlock = false
rdd.computeOrReadCheckpoint(partition, context)
}, true) match {
case Left(blockResult) =>
if (readCachedBlock) {
val existingMetrics = context.taskMetrics().inputMetrics
existingMetrics.incBytesRead(blockResult.bytes)
new InterruptibleIterator[U](context, blockResult.data.asInstanceOf[Iterator[U]]) {
override def next(): U = {
existingMetrics.incRecordsRead(1)
delegate.next()
}
}
} else {
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[U]])
}
case Right(iter) =>
new InterruptibleIterator(context, iter.asInstanceOf[Iterator[U]])
}

(iterator, readCachedBlock)
}

private def removeBlock(blockId: BlockId, readCachedBlock: Boolean): Unit = {
val blockManager = SparkEnv.get.blockManager
if (!readCachedBlock || blockManager.isRemovable(blockId)) {
blockManager.removeOrMarkAsRemovable(blockId, true)
}
}

override def getDependencies: Seq[Dependency[_]] = List(
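The cleanup in compute relies on CompletionIterator, which runs its callback exactly once, when the wrapped iterator is exhausted, so the block released by removeBlock stays readable for the whole lifetime of the task's output iterator. A minimal sketch of that pattern (CompletionIterator is private[spark], so this assumes code compiled under an org.apache.spark subpackage; releaseBlock and the sample data are illustrative, not part of the patch):

package org.apache.spark.example

import org.apache.spark.util.CompletionIterator

object CompletionIteratorSketch {
  // Hypothetical cleanup action standing in for removeBlock(...) above.
  def releaseBlock(): Unit = println("all pairs consumed, releasing the block")

  def main(args: Array[String]): Unit = {
    val pairs: Iterator[(Int, Int)] =
      for (x <- Iterator(1, 2); y <- Iterator(10, 20)) yield (x, y)
    // The callback fires only after the last pair is produced.
    val guarded = CompletionIterator[(Int, Int), Iterator[(Int, Int)]](pairs, releaseBlock())
    guarded.foreach(println) // prints the four pairs, then the release message
  }
}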
core/src/main/scala/org/apache/spark/storage/BlockManager.scala (79 additions, 7 deletions)
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io._
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.util.concurrent.ConcurrentHashMap

import scala.collection.mutable
import scala.collection.mutable.HashMap
@@ -202,6 +203,9 @@ private[spark] class BlockManager(

private var blockReplicationPolicy: BlockReplicationPolicy = _

// Tracks blocks that have been marked as removable.
private lazy val removableBlocks = ConcurrentHashMap.newKeySet[BlockId]()

/**
* Initializes the BlockManager with the given appId. This is not performed in the constructor as
* the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -679,23 +683,54 @@ private[spark] class BlockManager(
}

/**
* Get a block from the block manager (either local or remote).
*
* This acquires a read lock on the block if the block was stored locally and does not acquire
* any locks if the block was fetched from a remote block manager. The read lock will
* automatically be freed once the result's `data` iterator is fully consumed.
*/
def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
getOrCacheRemote(blockId, false, Some(StorageLevel.NONE))
}

/**
* Get a block from the block manager (either local or remote), optionally caching a block
* that was fetched from a remote block manager locally.
*
* This acquires a read lock on the block if the block was stored locally and does not acquire
* any locks if the block was fetched from a remote block manager. The read lock will
* automatically be freed once the result's `data` iterator is fully consumed.
*
* @param blockId the block to fetch.
* @param cacheRemote whether to cache a block that was fetched remotely.
* @param storageLevel the storage level used to cache the block; must be set when cacheRemote
* is enabled.
*/
def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
def getOrCacheRemote[T: ClassTag](
blockId: BlockId,
cacheRemote: Boolean = false,
storageLevel: Option[StorageLevel] = None): Option[BlockResult] = {
val local = getLocalValues(blockId)
if (local.isDefined) {
logInfo(s"Found block $blockId locally")
return local
}
val remote = getRemoteValues[T](blockId)
if (remote.isDefined) {
logInfo(s"Found block $blockId remotely")
return remote
remote match {
case Some(blockResult) =>
logInfo(s"Found block $blockId remotely")
if (cacheRemote) {
assert(storageLevel.isDefined && storageLevel.get.isValid,
"A valid storage level must be provided when cacheRemote is enabled.")
val putResult =
if (putIterator(blockId, blockResult.data, storageLevel.get)) "succeeded" else "failed"
logInfo(s"Caching block $blockId fetched remotely: $putResult")
}
return remote
case None =>
}

None
}

@@ -740,10 +775,11 @@ private[spark] class BlockManager(
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
makeIterator: () => Iterator[T],
cacheRemote: Boolean = false): Either[BlockResult, Iterator[T]] = {
// Attempt to read the block from local or remote storage. If it's present, then we don't need
// to go through the local-get-or-put path.
get[T](blockId)(classTag) match {
getOrCacheRemote[T](blockId, cacheRemote, Some(level))(classTag) match {
case Some(block) =>
return Left(block)
case _ =>
@@ -1457,6 +1493,41 @@
}
}

/**
* Returns true if the given block is currently marked as removable.
*/
def isRemovable(blockId: BlockId): Boolean = {
removableBlocks.contains(blockId)
}

/**
* Try to remove the block without blocking. Mark it as removable if it is in use.
*/
def removeOrMarkAsRemovable(blockId: BlockId, tellMaster: Boolean = true): Unit = {
// Try to lock for writing without blocking.
blockInfoManager.lockForWriting(blockId, false) match {
case None =>
// Because the lock was acquired in a non-blocking manner, the block may not exist or may be in use by another task.
blockInfoManager.synchronized {
blockInfoManager.get(blockId) match {
case None =>
// The block no longer exists; just clear any stale removable mark.
logWarning(s"Asked to remove block $blockId, which does not exist")
removableBlocks.remove(blockId)
case Some(_) =>
// The block is in use, mark it as removable.
logDebug(s"Marking block $blockId as removable")
removableBlocks.add(blockId)
}
}
case Some(info) =>
logDebug(s"Removing block $blockId")
removableBlocks.remove(blockId)
removeBlockInternal(blockId, tellMaster = tellMaster && info.tellMaster)
addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty)
}
}

private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = {
Option(TaskContext.get()).foreach { c =>
c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
Expand All @@ -1474,6 +1545,7 @@ private[spark] class BlockManager(
// Closing should be idempotent, but maybe not for the NioBlockTransferService.
shuffleClient.close()
}
removableBlocks.clear()
diskBlockManager.stop()
rpcEnv.stop(slaveEndpoint)
blockInfoManager.clear()
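The remove-or-mark logic above pairs a non-blocking write-lock attempt with a concurrent set recording removals that had to be deferred. A standalone sketch of the same idea built from JDK primitives only (a ReentrantReadWriteLock stands in for BlockInfoManager; every name here is illustrative, not Spark API):

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantReadWriteLock

object RemovableBlockSketch {
  private val lock = new ReentrantReadWriteLock()
  private val removable = ConcurrentHashMap.newKeySet[String]()

  // Mirrors removeOrMarkAsRemovable: remove immediately if no reader holds
  // the lock, otherwise record the id so a later attempt can finish the job.
  def removeOrMark(id: String): Unit = {
    if (lock.writeLock().tryLock()) {
      try {
        removable.remove(id)
        println(s"removed $id")
      } finally {
        lock.writeLock().unlock()
      }
    } else {
      removable.add(id)
      println(s"$id is in use, marked as removable")
    }
  }

  def main(args: Array[String]): Unit = {
    lock.readLock().lock()   // simulate another task holding a read lock
    removeOrMark("rdd_2_0")  // cannot take the write lock, so it is marked
    lock.readLock().unlock()
    removeOrMark("rdd_2_0")  // now succeeds and clears the mark
  }
}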
core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

@@ -18,20 +18,19 @@
package org.apache.spark.storage

import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.concurrent.Future
import scala.language.implicitConversions
import scala.language.postfixOps
import scala.reflect.ClassTag

import org.mockito.{Matchers => mc}
import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest._
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._

import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.DataReadMethod
@@ -101,6 +100,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
blockManager
}

private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
TaskContext.setTaskContext(
new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
}
}

override def beforeEach(): Unit = {
super.beforeEach()
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
@@ -1255,6 +1264,77 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(master.getLocations("item").isEmpty)
}

test("cache block fetch remotely") {
store = makeBlockManager(8000, "executor1")
store2 = makeBlockManager(8000, "executor2")
val arr = new Array[Byte](4000)
store.putSingle("block1", arr, StorageLevel.MEMORY_AND_DISK, true)
assert(store.getSingleAndReleaseLock("block1").isDefined, "block1 was not in store")

// By default, the remotely fetched block is not cached.
store2.getOrCacheRemote("block1")
assert(!store2.getLocalAndReleaseLock("block1").isDefined,
"block1 should not be cached by store2")

// Cache the remotely fetched block.
store2.getOrCacheRemote("block1", true, Some(StorageLevel.MEMORY_AND_DISK))
assert(store2.getLocalAndReleaseLock("block1").isDefined,
"block1 should be cached by store2")
assert(master.getLocations("block1").size == 2,
"master did not report 2 locations for block1")
}
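
The test above exercises the caching path through getOrCacheRemote. Callers that go through getOrElseUpdate instead keep its existing contract: Left(blockResult) means the values are served from (or were successfully stored in) the block manager, while Right(iter) means the store failed and the caller must consume the freshly computed iterator directly. A hedged sketch of executor-side usage (the block id, element type, and data are illustrative):

import scala.reflect.classTag

import org.apache.spark.SparkEnv
import org.apache.spark.storage.{RDDBlockId, StorageLevel}

// Executor-side only: SparkEnv.get is how executors reach their BlockManager.
def readOrComputePartition(): Iterator[Int] = {
  val blockId = RDDBlockId(2, 0) // illustrative rdd id and partition index
  SparkEnv.get.blockManager.getOrElseUpdate(
    blockId, StorageLevel.MEMORY_AND_DISK, classTag[Int],
    () => Iterator(1, 2, 3), // computed only when the block is missing everywhere
    true                     // cacheRemote: keep a local copy of a remote fetch
  ) match {
    case Left(blockResult) => blockResult.data.asInstanceOf[Iterator[Int]]
    case Right(iter) => iter
  }
}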

test("remote block with blocking") {
store = makeBlockManager(8000, "executor1")
val arr = new Array[Byte](4000)
store.putSingle("block", arr, StorageLevel.MEMORY_AND_DISK, true)
withTaskId(0) {
store.get("block")
}
val future = Future {
withTaskId(1) {
store.removeBlock("block")
master.getLocations("block").isEmpty
}
}
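// Give the future time to start and block on acquiring the write lock.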
Thread.sleep(300)
assert(store.getStatus("block").isDefined, "block was not in store")
withTaskId(0) {
store.releaseLock("block")
}
assert(ThreadUtils.awaitResult(future, 1.seconds))
}

test("remote block without blocking") {
store = makeBlockManager(8000, "executor1")
val arr = new Array[Byte](4000)
store.putSingle("block", arr, StorageLevel.MEMORY_AND_DISK, true)
withTaskId(0) {
// Lock the block with a read lock.
store.get("block")
}
val future = Future {
withTaskId(1) {
store.removeOrMarkAsRemovable("block")
store.isRemovable("block")
}
}
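// Give the future time to attempt the non-blocking remove and mark the block.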
Thread.sleep(300)
assert(store.getStatus("block").isDefined, "block should not be removed")
assert(ThreadUtils.awaitResult(future, 1.seconds), "block should be marked as removable")
withTaskId(0) {
store.releaseLock("block")
}
val future1 = Future {
withTaskId(1) {
store.removeOrMarkAsRemovable("block")
!store.isRemovable("block")
}
}
assert(ThreadUtils.awaitResult(future1, 1.seconds), "block should not be marked as removable")
assert(master.getLocations("block").isEmpty, "block should be removed")
}

class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
var numCalls = 0

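
Taken together, the change is transparent to user code: a plain cartesian call now fetches a remote partition of the second RDD once per task and caches it at MEMORY_AND_DISK, instead of re-fetching it for every element of the first RDD's partition, and the cached block is removed (or marked removable) when the task's output iterator completes. A minimal driver program that would exercise the new path (app name and sizes are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object CartesianCacheDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cartesian-cache-demo"))
    try {
      val left = sc.parallelize(1 to 1000, numSlices = 8)
      val right = sc.parallelize(1 to 1000, numSlices = 8)
      // Each task computing a (left, right) partition pair now reads the
      // right-hand block from the local block manager after the first fetch.
      val count = left.cartesian(right).count()
      println(s"pairs: $count") // 1000 * 1000 = 1000000
    } finally {
      sc.stop()
    }
  }
}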