diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index 5d285f89f22f5..5836944123226 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler.cluster import java.net.URL +import java.util.concurrent.atomic.AtomicReference import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.language.reflectiveCalls @@ -32,15 +33,35 @@ import org.apache.spark.ui.TestFilter class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext { + private var yarnSchedulerBackend: YarnSchedulerBackend = _ + + override def afterEach(): Unit = { + try { + if (yarnSchedulerBackend != null) { + yarnSchedulerBackend.stop() + } + } finally { + super.afterEach() + } + } + test("RequestExecutors reflects node blacklist and is serializable") { sc = new SparkContext("local", "YarnSchedulerBackendSuite") - val sched = mock[TaskSchedulerImpl] - when(sched.sc).thenReturn(sc) - val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) { + // Subclassing the TaskSchedulerImpl here instead of using Mockito. For details see SPARK-26891.
+ val sched = new TaskSchedulerImpl(sc) { + val blacklistedNodes = new AtomicReference[Set[String]]() + + def setNodeBlacklist(nodeBlacklist: Set[String]): Unit = blacklistedNodes.set(nodeBlacklist) + + override def nodeBlacklist(): Set[String] = blacklistedNodes.get() + } + + val yarnSchedulerBackendExtended = new YarnSchedulerBackend(sched, sc) { def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = { this.hostToLocalTaskCount = hostToLocalTaskCount } } + yarnSchedulerBackend = yarnSchedulerBackendExtended val ser = new JavaSerializer(sc.conf).newInstance() for { blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c")) @@ -50,9 +71,9 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc Map("a" -> 1, "b" -> 2) ) } { - yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount) - when(sched.nodeBlacklist()).thenReturn(blacklist) - val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested) + yarnSchedulerBackendExtended.setHostToLocalTaskCount(hostToLocalCount) + sched.setNodeBlacklist(blacklist) + val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numRequested) assert(req.requestedTotal === numRequested) assert(req.nodeBlacklist === blacklist) assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty) @@ -75,9 +96,9 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc // Before adding the "YARN" filter, should get the code from the filter in SparkConf. assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_BAD_GATEWAY) - val backend = new YarnSchedulerBackend(sched, sc) { } + yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) { } - backend.addWebUIFilter(classOf[TestFilter2].getName(), + yarnSchedulerBackend.addWebUIFilter(classOf[TestFilter2].getName(), Map("responseCode" -> HttpServletResponse.SC_NOT_ACCEPTABLE.toString), "") sc.ui.get.getHandlers.foreach { h =>