[BACKPORT][SPARK-23040][CORE] Returns interruptible iterator for shuffle reader #20954

Closed · wants to merge 2 commits
14 changes: 13 additions & 1 deletion core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -95,7 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
     }
 
     // Sort the output if there is a sort ordering defined.
-    dep.keyOrdering match {
+    val resultIter = dep.keyOrdering match {
       case Some(keyOrd: Ordering[K]) =>
         // Create an ExternalSorter to sort the data.
         val sorter =
@@ -104,9 +104,21 @@ private[spark] class BlockStoreShuffleReader[K, C](
         context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
         context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+        // Use a completion callback to stop the sorter when the task finishes or is cancelled.
+        context.addTaskCompletionListener(_ => {
+          sorter.stop()
+        })
         CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
       case None =>
         aggregatedIter
     }
+
+    resultIter match {
+      case _: InterruptibleIterator[Product2[K, C]] => resultIter
+      case _ =>
+        // Wrap in another interruptible iterator to support task cancellation, since the
+        // aggregator and/or sorter may have already consumed the previous interruptible iterator.
+        new InterruptibleIterator[Product2[K, C]](context, resultIter)
+    }
   }
 }
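For context on the two wrappers this change combines: InterruptibleIterator re-checks the task's kill flag on every hasNext call, and CompletionIterator runs a cleanup action exactly once when its delegate is exhausted. The sketch below only illustrates those semantics and is not Spark's actual source (the real classes are org.apache.spark.InterruptibleIterator and org.apache.spark.util.CompletionIterator; the *Sketch names are made up):

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Sketch of InterruptibleIterator's behaviour: poll the kill flag on every
// hasNext so a cancelled task stops at the next element boundary instead of
// draining the rest of the shuffle output.
class InterruptibleIteratorSketch[T](context: TaskContext, delegate: Iterator[T])
  extends Iterator[T] {
  override def hasNext: Boolean = {
    if (context.isInterrupted()) throw new TaskKilledException
    delegate.hasNext
  }
  override def next(): T = delegate.next()
}

// Sketch of CompletionIterator's behaviour: run a cleanup action exactly once
// when the wrapped iterator is exhausted, e.g. releasing the ExternalSorter's
// buffers after the last record is produced.
class CompletionIteratorSketch[A](sub: Iterator[A], completion: () => Unit)
  extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) {
      completed = true
      completion()
    }
    more
  }
  override def next(): A = sub.next()
}
```

In the patched read(), the task-completion listener covers the cancellation path, where the sorter's iterator may never be exhausted; the CompletionIterator covers the normal path; and the final InterruptibleIterator wrapper guarantees that cancellation is observed even after the aggregator or sorter has drained the upstream iterator.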
65 changes: 64 additions & 1 deletion core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
@@ -26,7 +27,7 @@ import scala.concurrent.duration._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.Matchers
 
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter
   override def afterEach() {
     try {
       resetSparkContext()
+      JobCancellationSuite.taskStartedSemaphore.drainPermits()
+      JobCancellationSuite.taskCancelledSemaphore.drainPermits()
+      JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits()
+      JobCancellationSuite.executionOfInterruptibleCounter.set(0)
     } finally {
       super.afterEach()
     }
@@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter
     f2.get()
   }
 
+  test("interruptible iterator of shuffle reader") {
+    // In this test we create a Spark job with two stages. The second stage is cancelled during
+    // execution, and a counter is used to make sure that the corresponding tasks are indeed
+    // cancelled.
+    import JobCancellationSuite._
+    sc = new SparkContext("local[2]", "test interruptible iterator")
+
+    val taskCompletedSem = new Semaphore(0)
+
+    sc.addSparkListener(new SparkListener {
+      override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+        // Release taskCancelledSemaphore when the cancelTasks event has been posted.
+        if (stageCompleted.stageInfo.stageId == 1) {
+          taskCancelledSemaphore.release(1000)
+        }
+      }
+
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        if (taskEnd.stageId == 1) { // make sure tasks are completed
+          taskCompletedSem.release()
+        }
+      }
+    })
+
+    val f = sc.parallelize(1 to 1000).map { i => (i, i) }
+      .repartitionAndSortWithinPartitions(new HashPartitioner(1))
+      .mapPartitions { iter =>
+        taskStartedSemaphore.release()
+        iter
+      }.foreachAsync { x =>
+        if (x._1 >= 10) {
+          // This block is only partially executed: it blocks once x._1 >= 10, and the next
+          // iteration is cancelled if the source iterator is interruptible. In that case the
+          // counter is incremented at most 10 times (for x._1 in 1..10).
+          taskCancelledSemaphore.acquire()
+        }
+        executionOfInterruptibleCounter.getAndIncrement()
+      }
+
+    taskStartedSemaphore.acquire()
+    // The job is cancelled when:
+    // 1. a task in the reduce stage has started, guaranteed by the previous line, and
+    // 2. the task in the reduce stage is blocked after processing at most 10 records, since
+    //    taskCancelledSemaphore is not released until the cancelTasks event is posted.
+    // After the job is cancelled, the task in the reduce stage is interrupted and no further
+    // iterations are executed.
+    f.cancel()
+
+    val e = intercept[SparkException](f.get()).getCause
+    assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+
+    // Make sure the tasks are indeed completed.
+    taskCompletedSem.acquire()
+    assert(executionOfInterruptibleCounter.get() <= 10)
+  }
+
   def testCount() {
     // Cancel before launching any tasks
     {
@@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAfter
 
 
 object JobCancellationSuite {
+  // To avoid any headaches, reset these global variables in the companion class's afterEach block.
   val taskStartedSemaphore = new Semaphore(0)
   val taskCancelledSemaphore = new Semaphore(0)
   val twoJobsSharingStageSemaphore = new Semaphore(0)
+  val executionOfInterruptibleCounter = new AtomicInteger(0)
 }
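The driver-side behaviour this test locks in can also be seen with a minimal standalone sketch along the following lines (assumptions: the object name is hypothetical, countAsync comes from Spark's AsyncRDDActions, and the sleep is a crude stand-in for the test's semaphore handshake). Before the fix, a reduce task whose aggregator or sorter had already consumed the fetched iterator could run to completion despite the cancel; with the fix it stops at the next hasNext check.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CancelShuffleJobSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("cancel-demo"))
    try {
      // A two-stage job: the shuffle forces a reduce stage, whose shuffle
      // reader is the iterator patched by this PR.
      val f = sc.parallelize(1 to 1000000, 2)
        .map(i => (i % 100, i))
        .groupByKey()
        .countAsync()
      Thread.sleep(500) // crude stand-in for the test's semaphore handshake
      // With the patch, reduce-stage tasks observe the kill flag at the next
      // hasNext call instead of draining the remaining shuffle records.
      f.cancel()
    } finally {
      sc.stop()
    }
  }
}
```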