From 9f0a9c4a37f55d0e9591bb59957afcdf745a678d Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Mon, 25 Jun 2018 14:46:12 -0700
Subject: [PATCH 1/2] [SPARK-24631][tests] Avoid cross-job pollution in
 TestUtils / SpillListener.

There is a narrow race in this code that is triggered when the code
run in assertSpilled / assertNotSpilled submits more than a single job.

SpillListener assumed that only a single job was run, and so would only
block waiting for that single job to finish when `numSpilledStages` was
called. But some tests (like SQL tests that call `checkAnswer`) run
more than one job, and so that wait was basically a no-op. This could
cause the listener installed by the next test to receive events from
the previous test's job, which could cause test failures in certain
cases.

The change fixes that race, and also uninstalls listeners after the
test runs, so they don't accumulate when the SparkContext is shared
among multiple tests.
---
 .../scala/org/apache/spark/TestUtils.scala | 49 +++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index b5c4c705dcbc7..5071b6750d467 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets
 import java.security.SecureRandom
 import java.security.cert.X509Certificate
 import java.util.{Arrays, Properties}
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.util.concurrent.{TimeoutException, TimeUnit}
 import java.util.jar.{JarEntry, JarOutputStream}
 import javax.net.ssl._
 import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
@@ -173,10 +173,11 @@ private[spark] object TestUtils {
    * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
    */
   def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
-    val spillListener = new SpillListener
-    sc.addSparkListener(spillListener)
-    body
-    assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
+    withListener(sc, new SpillListener) { listener =>
+      val ret = body
+      assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
+      ret
+    }
   }
 
   /**
@@ -184,10 +185,11 @@ private[spark] object TestUtils {
    * did not spill.
    */
   def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
-    val spillListener = new SpillListener
-    sc.addSparkListener(spillListener)
-    body
-    assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
+    withListener(sc, new SpillListener) { listener =>
+      val ret = body
+      assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
+      ret
+    }
   }
 
   /**
@@ -233,6 +235,21 @@ private[spark] object TestUtils {
     }
   }
 
+  /**
+   * Runs some code with the given listener installed in the SparkContext. After the code runs,
+   * this method will wait until all events posted to the listener bus are processed, and then
+   * remove the listener from the bus.
+   */
+  def withListener[L <: SparkListener, T](sc: SparkContext, listener: L) (body: L => T): T = {
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
+    } finally {
+      sc.listenerBus.waitUntilEmpty(TimeUnit.SECONDS.toMillis(10))
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   /**
    * Wait until at least `numExecutors` executors are up, or throw `TimeoutException` if the waiting
    * time elapsed before `numExecutors` executors up. Exposed for testing.
@@ -289,21 +306,17 @@ private[spark] object TestUtils {
 private class SpillListener extends SparkListener {
   private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
   private val spilledStageIds = new mutable.HashSet[Int]
-  private val stagesDone = new CountDownLatch(1)
 
-  def numSpilledStages: Int = {
-    // Long timeout, just in case somehow the job end isn't notified.
-    // Fails if a timeout occurs
-    assert(stagesDone.await(10, TimeUnit.SECONDS))
+  def numSpilledStages: Int = synchronized {
     spilledStageIds.size
   }
 
-  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
     stageIdToTaskMetrics.getOrElseUpdate(
       taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
   }
 
-  override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = {
+  override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = synchronized {
     val stageId = stageComplete.stageInfo.stageId
     val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
     val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
@@ -311,8 +324,4 @@ private class SpillListener extends SparkListener {
       spilledStageIds += stageId
     }
   }
-
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
-    stagesDone.countDown()
-  }
 }
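
Reviewer note: the fix works because withListener() only returns after the
listener bus has been drained, regardless of how many jobs the body submits.
The sketch below shows the shape of test body that used to trip the race. It
is illustrative only: `sc` is assumed to be a test SparkContext, the data
sizes are made up, whether the stages actually spill depends on the memory
configuration, and since TestUtils is private[spark] real callers live
inside Spark's own packages.

    import org.apache.spark.{SparkContext, TestUtils}

    // A test body that submits two jobs, the way SQL tests that call
    // checkAnswer do. The old SpillListener released its latch after the
    // first job's end event, so events from the second job could still be
    // in flight when the next test registered its own listener.
    def multiJobSpillTest(sc: SparkContext): Unit = {
      TestUtils.assertSpilled(sc, "multi-job aggregation") {
        val pairs = sc.parallelize(1 to 1000000, 10).map(i => (i % 1000, i))
        pairs.groupByKey().count()    // first job
        pairs.groupByKey().collect()  // second job; previously raced
      }
    }
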
From 18d5ebfd201deaebf774835ec5eb08d2b6d08454 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Tue, 26 Jun 2018 09:16:00 -0700
Subject: [PATCH 2/2] Clean up return types.

---
 core/src/main/scala/org/apache/spark/TestUtils.scala | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 5071b6750d467..6cc8fe1173d2e 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -172,11 +172,10 @@ private[spark] object TestUtils {
   /**
    * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
    */
-  def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
+  def assertSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
     withListener(sc, new SpillListener) { listener =>
-      val ret = body
+      body
       assert(listener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
-      ret
     }
   }
 
@@ -184,11 +183,10 @@ private[spark] object TestUtils {
    * Run some code involving jobs submitted to the given context and assert that the jobs
    * did not spill.
    */
-  def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
+  def assertNotSpilled(sc: SparkContext, identifier: String)(body: => Unit): Unit = {
     withListener(sc, new SpillListener) { listener =>
-      val ret = body
+      body
       assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
-      ret
     }
   }
 
@@ -240,7 +238,7 @@ private[spark] object TestUtils {
    * this method will wait until all events posted to the listener bus are processed, and then
    * remove the listener from the bus.
    */
-  def withListener[L <: SparkListener, T](sc: SparkContext, listener: L) (body: L => T): T = {
+  def withListener[L <: SparkListener](sc: SparkContext, listener: L) (body: L => Unit): Unit = {
     sc.addSparkListener(listener)
     try {
       body(listener)
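
With the return types cleaned up, withListener() is a plain loan pattern:
register the listener, run the body, then drain the bus and deregister in a
finally block. Below is a minimal usage sketch under stated assumptions:
StageCountListener is a hypothetical listener invented for illustration, and
because TestUtils is private[spark] this only compiles from code living under
an org.apache.spark package.

    import org.apache.spark.{SparkContext, TestUtils}
    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

    // Hypothetical listener that counts completed stages. It synchronizes
    // for the same reason SpillListener now does: event delivery on the
    // bus thread can interleave with reads from the test thread.
    class StageCountListener extends SparkListener {
      private var stages = 0
      def stageCount: Int = synchronized { stages }
      override def onStageCompleted(event: SparkListenerStageCompleted): Unit =
        synchronized { stages += 1 }
    }

    def countStages(sc: SparkContext): Int = {
      val listener = new StageCountListener
      TestUtils.withListener(sc, listener) { _ =>
        sc.parallelize(1 to 100, 4).count()
      }
      // By this point withListener has drained the bus and removed the
      // listener, so the count is stable and no events leak onward.
      listener.stageCount
    }

The ordering in the finally block is the important design choice here:
waitUntilEmpty() runs before removeListener(), so every event posted by the
body is delivered to the listener before it is detached.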