diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 78ac00909ea1a..5b72755da2216 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,6 +22,7 @@ import java.net.Socket import java.util.Locale import scala.collection.JavaConverters._ +import scala.collection.concurrent import scala.collection.mutable import scala.util.Properties @@ -339,19 +340,26 @@ object SparkEnv extends Logging { None } - val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( - BlockManagerMaster.DRIVER_ENDPOINT_NAME, - new BlockManagerMasterEndpoint( - rpcEnv, - isLocal, - conf, - listenerBus, - if (conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)) { - externalShuffleClient - } else { - None - })), - conf, isDriver) + // Mapping from block manager id to the block manager's information. + val blockManagerInfo = new concurrent.TrieMap[BlockManagerId, BlockManagerInfo]() + val blockManagerMaster = new BlockManagerMaster( + registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_ENDPOINT_NAME, + new BlockManagerMasterEndpoint( + rpcEnv, + isLocal, + conf, + listenerBus, + if (conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)) { + externalShuffleClient + } else { + None + }, blockManagerInfo)), + registerOrLookupEndpoint( + BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME, + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)), + conf, + isDriver) val blockTransferService = new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress, diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c3e1cd8b23f14..fe3a48440991a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -267,7 +267,7 @@ 
private[spark] class DAGScheduler( executorUpdates: mutable.Map[(Int, Int), ExecutorMetrics]): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, executorUpdates)) - blockManagerMaster.driverEndpoint.askSync[Boolean]( + blockManagerMaster.driverHeartbeatEndPoint.askSync[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(10.minutes, "BlockManagerHeartbeat")) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 525304fe3c9d3..9678c917882cd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private[spark] class BlockManagerMaster( var driverEndpoint: RpcEndpointRef, + var driverHeartbeatEndPoint: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { @@ -230,6 +231,11 @@ class BlockManagerMaster( if (driverEndpoint != null && isDriver) { tell(StopBlockManagerMaster) driverEndpoint = null + if (driverHeartbeatEndPoint.askSync[Boolean](StopBlockManagerMaster)) { + driverHeartbeatEndPoint = null + } else { + logWarning("Failed to stop BlockManagerMasterHeartbeatEndpoint") + } logInfo("BlockManagerMaster stopped") } } @@ -245,4 +251,5 @@ class BlockManagerMaster( private[spark] object BlockManagerMaster { val DRIVER_ENDPOINT_NAME = "BlockManagerMaster" + val DRIVER_HEARTBEAT_ENDPOINT_NAME = "BlockManagerMasterHeartbeat" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 02d0e1a834909..7e2027701c33a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -36,7 +36,7 @@ import 
org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** - * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses + * BlockManagerMasterEndpoint is an [[IsolatedRpcEndpoint]] on the master node to track statuses * of all slaves' block managers. */ private[spark] @@ -45,12 +45,10 @@ class BlockManagerMasterEndpoint( val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus, - externalBlockStoreClient: Option[ExternalBlockStoreClient]) + externalBlockStoreClient: Option[ExternalBlockStoreClient], + blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo]) extends IsolatedRpcEndpoint with Logging { - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] - // Mapping from external shuffle service block manager id to the block statuses. private val blockStatusByShuffleService = new mutable.HashMap[BlockManagerId, JHashMap[BlockId, BlockStatus]] @@ -144,9 +142,6 @@ class BlockManagerMasterEndpoint( case StopBlockManagerMaster => context.reply(true) stop() - - case BlockManagerHeartbeat(blockManagerId) => - context.reply(heartbeatReceived(blockManagerId)) } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -290,19 +285,6 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) } - /** - * Return true if the driver knows about the given block manager. Otherwise, return false, - * indicating that the block manager should re-register. - */ - private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.isDriver && !isLocal - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - true - } - } - // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. 
private def removeBlockFromWorkers(blockId: BlockId): Unit = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterHeartbeatEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterHeartbeatEndpoint.scala new file mode 100644 index 0000000000000..b06002123d803 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterHeartbeatEndpoint.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.storage.BlockManagerMessages.{BlockManagerHeartbeat, StopBlockManagerMaster} + +/** + * Separate heartbeat out of BlockManagerMasterEndpoint due to performance considerations.
+ */ +private[spark] class BlockManagerMasterHeartbeatEndpoint( + override val rpcEnv: RpcEnv, + isLocal: Boolean, + blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo]) + extends ThreadSafeRpcEndpoint with Logging { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case BlockManagerHeartbeat(blockManagerId) => + context.reply(heartbeatReceived(blockManagerId)) + + case StopBlockManagerMaster => + stop() + context.reply(true) + + case _ => // do nothing for unexpected events + } + + /** + * Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { + if (!blockManagerInfo.contains(blockManagerId)) { + blockManagerId.isDriver && !isLocal + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + true + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c27d50ab66e66..1e3b59f7e97d8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -245,7 +245,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations - val blockManagerMaster = new BlockManagerMaster(null, conf, true) { + val blockManagerMaster = new BlockManagerMaster(null, null, conf, true) { override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). 
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d8f42ea9557d9..59ace850d0bd2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.Locale +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions @@ -97,9 +98,12 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite conf.set(STORAGE_CACHED_PEERS_TTL, 10) sc = new SparkContext("local", "test", conf) + val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf), None)), conf, true) + new LiveListenerBus(conf), None, blockManagerInfo)), + rpcEnv.setupEndpoint("blockmanagerHeartbeat", + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 484d246959ec2..8595f73fe5dd5 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,6 +21,7 @@ import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.duration._ @@ -142,10 +143,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // need to create a SparkContext is to initialize 
LiveListenerBus. sc = mock(classOf[SparkContext]) when(sc.conf).thenReturn(conf) - master = spy(new BlockManagerMaster( - rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf), None)), conf, true)) + + val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() + master = spy(new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(conf), None, blockManagerInfo)), + rpcEnv.setupEndpoint("blockmanagerHeartbeat", + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true)) val initialize = PrivateMethod[Unit](Symbol("initialize")) SizeEstimator invokePrivate initialize() @@ -468,7 +472,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - val reregister = !master.driverEndpoint.askSync[Boolean]( + val reregister = !master.driverHeartbeatEndPoint.askSync[Boolean]( BlockManagerHeartbeat(store.blockManagerId)) assert(reregister) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 286095e4ee0d7..0976494b6d094 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming import java.io.File import java.nio.ByteBuffer +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.reflect.ClassTag @@ -87,9 +88,12 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) 
conf.set("spark.driver.port", rpcEnv.address.port.toString) + val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]() blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf), None)), conf, true) + new LiveListenerBus(conf), None, blockManagerInfo)), + rpcEnv.setupEndpoint("blockmanagerHeartbeat", + new BlockManagerMasterHeartbeatEndpoint(rpcEnv, true, blockManagerInfo)), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf)