diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 6b3c831e60472..ea63ff5dc1580 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -125,8 +125,20 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpointRef.askWithRetry[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + val r = RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) + yarnSchedulerEndpoint.amEndpoint match { + case Some(am) => + try { + am.askWithRetry[Boolean](r) + } catch { + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + return false + } + case None => + logWarning("Attempted to request executors before the AM has registered!") + return false + } } /** @@ -209,7 +221,7 @@ private[spark] abstract class YarnSchedulerBackend( */ private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var amEndpoint: Option[RpcEndpointRef] = None + var amEndpoint: Option[RpcEndpointRef] = None private val askAmThreadPool = ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool")