[SPARK-24022][TEST] Make SparkContextSuite not flaky #21105

Closed · wants to merge 1 commit
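The change replaces the shared coordination state in the SparkContextSuite companion object (the cancelJob/cancelStage flags and the semaphore released by the listener) with a listener and a CountDownLatch created per loop iteration: the listener cancels the stage or job as soon as it starts, counts the latch down, and the test body waits on that latch before asserting that no tasks are still running, then removes the listener. The sketch below illustrates that latch handshake in isolation, without Spark; the helper names (doCancel, awaitNoRunningTasks) and the object name are placeholders for this sketch, not part of the test.

import java.util.concurrent.{CountDownLatch, TimeUnit}

object LatchHandshakeSketch {

  def main(args: Array[String]): Unit = {
    val latch = new CountDownLatch(1)

    // Stands in for the SparkListener callback (onTaskStart / onJobStart):
    // issue the cancellation once, then signal the test body.
    val listenerThread = new Thread(new Runnable {
      override def run(): Unit = {
        doCancel()        // in the real test: sc.cancelStage(stageId, REASON) or sc.cancelJob(jobId, REASON)
        latch.countDown() // the cancellation has now been issued
      }
    })
    listenerThread.start()

    // Stands in for the test body: wait (with a timeout, so a broken listener
    // cannot hang the suite) until the cancellation was issued, then check the
    // post-conditions.
    assert(latch.await(20, TimeUnit.SECONDS), "listener never issued the cancellation")
    awaitNoRunningTasks()
  }

  // Placeholder helpers, named only for this sketch.
  private def doCancel(): Unit = println("cancelling with a custom reason")
  private def awaitNoRunningTasks(): Unit = println("no tasks left running")
}

Creating the listener and the latch inside the loop also lets the test call sc.removeSparkListener(listener) at the end of each iteration, so the "stage" listener cannot interfere with the "job" pass.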
61 changes: 26 additions & 35 deletions core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io.File
 import java.net.{MalformedURLException, URI}
 import java.nio.charset.StandardCharsets
-import java.util.concurrent.{Semaphore, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
 
 import scala.concurrent.duration._
 
@@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 
   test("Cancelling stages/jobs with custom reasons.") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
-    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
     val REASON = "You shall not pass"
-    val slices = 10
 
-    val listener = new SparkListener {
-      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-        if (SparkContextSuite.cancelStage) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
-          }
-          sc.cancelStage(taskStart.stageId, REASON)
-          SparkContextSuite.cancelStage = false
-          SparkContextSuite.semaphore.release(slices)
-        }
-      }
-
-      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
-        if (SparkContextSuite.cancelJob) {
-          eventually(timeout(10.seconds)) {
-            assert(SparkContextSuite.isTaskStarted)
-          }
-          sc.cancelJob(jobStart.jobId, REASON)
-          SparkContextSuite.cancelJob = false
-          SparkContextSuite.semaphore.release(slices)
-        }
-      }
-    }
-    sc.addSparkListener(listener)
-
-    for (cancelWhat <- Seq("stage", "job")) {
-      SparkContextSuite.semaphore.drainPermits()
-      SparkContextSuite.isTaskStarted = false
-      SparkContextSuite.cancelStage = (cancelWhat == "stage")
-      SparkContextSuite.cancelJob = (cancelWhat == "job")
+    for (cancelWhat <- Seq("stage", "job")) {
+      // This countdown latch used to make sure stage or job canceled in listener
+      val latch = new CountDownLatch(1)
+
+      val listener = cancelWhat match {
+        case "stage" =>
+          new SparkListener {
+            override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+              sc.cancelStage(taskStart.stageId, REASON)
+              latch.countDown()
+            }
+          }
+        case "job" =>
+          new SparkListener {
+            override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+              sc.cancelJob(jobStart.jobId, REASON)
+              latch.countDown()
+            }
+          }
+      }
+      sc.addSparkListener(listener)
 
       val ex = intercept[SparkException] {
-        sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
-          SparkContextSuite.isTaskStarted = true
-          // Block waiting for the listener to cancel the stage or job.
-          SparkContextSuite.semaphore.acquire()
+        sc.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+          x.synchronized {
+            x.wait()
+          }
           x
         }.count()
       }
@@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
           fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
       }
 
+      latch.await(20, TimeUnit.SECONDS)
       eventually(timeout(20.seconds)) {
         assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
       }
+      sc.removeSparkListener(listener)
     }
   }
 
@@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 }
 
 object SparkContextSuite {
-  @volatile var cancelJob = false
-  @volatile var cancelStage = false
   @volatile var isTaskStarted = false
   @volatile var taskKilled = false
   @volatile var taskSucceeded = false