[SPARK-24063][SS] Add maximum epoch queue threshold for ContinuousExecution #23156

Closed
wants to merge 10 commits
@@ -1413,6 +1413,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize")
.internal()
.doc("The max number of entries to be stored in queue to wait for late epochs. " +
"If this parameter is exceeded by the size of the queue, stream will stop with an error.")
.intConf
.createWithDefault(10000)

val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -2016,6 +2024,9 @@ class SQLConf extends Serializable with Logging {

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE)

def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)

def continuousStreamingExecutorPollIntervalMs: Long =
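The new entry is an internal conf that defaults to 10000, so nothing changes for existing queries unless it is overridden. As a rough illustration of how the limit could be lowered and read back (the session setup below is an assumed example, not part of this diff):

import org.apache.spark.sql.SparkSession

// Illustrative only: start a local session with a smaller epoch backlog limit.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("epoch-backlog-sketch")
  .config("spark.sql.streaming.continuous.epochBacklogQueueSize", "100")
  .getOrCreate()

// The runtime conf reflects the override; inside Spark's sql package the new
// SQLConf getter continuousStreamingEpochBacklogQueueSize returns the same value.
assert(spark.conf.get("spark.sql.streaming.continuous.epochBacklogQueueSize").toInt == 100)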
@@ -60,6 +60,10 @@ class ContinuousExecution(
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _

// Throwable that caused the execution to fail
private var failure: Option[Throwable] = None
protected val failureLock = new AnyRef

override val logicalPlan: LogicalPlan = {
val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]()
analyzedPlan.transform {
@@ -277,6 +281,10 @@ class ContinuousExecution(
lastExecution.toRdd
}
}

failureLock.synchronized {
failure.foreach(throw _)
}
} catch {
case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
state.get() == RECONFIGURING =>
@@ -390,6 +398,41 @@ class ContinuousExecution(
}
}

/**
* Stores the error and stops the query execution thread to terminate the query in a new thread.
*/
def stopInNewThread(error: Throwable): Unit = {
failureLock.synchronized {
failure match {
case None =>
logError(s"Query $prettyIdString received exception $error")
failure = Some(error)
stopInNewThread()
case _ =>
// Stop already initiated
}
}
}

/**
* Stops the query execution thread to terminate the query in a new thread.
*/
private def stopInNewThread(): Unit = {
new Thread("stop-continuous-execution") {
setDaemon(true)

override def run(): Unit = {
try {
ContinuousExecution.this.stop()
} catch {
case e: Throwable =>
logError(e.getMessage, e)
throw e
}
}
}.start()
}

/**
* Stops the query execution thread to terminate the query.
*/
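The additions to ContinuousExecution follow a small failure-propagation pattern: remember only the first error under failureLock, stop the query from a separate daemon thread, and rethrow the stored error from the query's run loop so the stream terminates with that cause. A condensed, self-contained sketch of the pattern, with hypothetical names that only mirror the shape of the code above:

// Hypothetical sketch, not Spark code: record the first failure under a lock,
// stop asynchronously, and rethrow the stored error from the run loop.
class FailFastRunner {
  private var failure: Option[Throwable] = None
  private val failureLock = new AnyRef
  @volatile private var running = true

  def stopInNewThread(error: Throwable): Unit = failureLock.synchronized {
    if (failure.isEmpty) {                 // only the first error is kept
      failure = Some(error)
      new Thread("stop-runner") {
        setDaemon(true)
        // Qualify with the outer instance, otherwise stop() resolves to Thread.stop().
        override def run(): Unit = FailFastRunner.this.stop()
      }.start()
    }
  }

  def stop(): Unit = { running = false }

  def runLoop(): Unit = {
    while (running) {
      // ... one unit of work ...
      failureLock.synchronized(failure.foreach(throw _))   // surface the stored error
      Thread.sleep(10)
    }
  }
}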
@@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator(
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {

private val epochBacklogQueueSize =
session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize

private var queryWritesStopped: Boolean = false

private var numReaderPartitions: Int = _
@@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator(
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
partitionCommits.put((epoch, partitionId), message)
resolveCommitsAtEpoch(epoch)
checkProcessingQueueBoundaries()
}

case ReportPartitionOffset(partitionId, epoch, offset) =>
@@ -223,6 +227,22 @@
query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq)
resolveCommitsAtEpoch(epoch)
}
checkProcessingQueueBoundaries()
}

private def checkProcessingQueueBoundaries(): Unit = {
if (partitionOffsets.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " +
"exceeded it's maximum"))
}
if (partitionCommits.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " +
"exceeded it's maximum"))
}
if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) {
query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " +
"exceeded it's maximum"))
}
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
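From a driver program's perspective, the coordinator's stopInNewThread call terminates the stream with the stored IllegalStateException. A hedged sketch of how that failure would typically be observed; only the error-message text and the trigger mode come from this diff, while the query itself is an assumed example reusing the spark session from the earlier sketch:

import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}

// Illustrative continuous query; with a low epochBacklogQueueSize a stalled
// epoch can overflow one of the coordinator queues.
val query = spark.readStream
  .format("rate")
  .load()
  .writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()

try {
  query.awaitTermination()
} catch {
  // The stored IllegalStateException is expected to surface as the cause of
  // the StreamingQueryException thrown here.
  case e: StreamingQueryException =>
    assert(e.getCause.getMessage.contains("queue has exceeded its maximum"))
}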
@@ -343,3 +343,31 @@ class ContinuousMetaSuite extends ContinuousSuiteBase {
}
}
}

class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
import testImplicits._

// We need to specify spark.sql.streaming.continuous.epochBacklogQueueSize.
override protected def createSparkSession = new TestSparkSession(
new SparkContext(
"local[1]",
"continuous-stream-test-sql-context",
sparkConf.set("spark.sql.testkey", "true")
.set("spark.sql.streaming.continuous.epochBacklogQueueSize", "10")))

test("epoch backlog overflow") {
val df = spark.readStream
.format("rate")
.option("numPartitions", "2")
.option("rowsPerSecond", "500")
.load()
.select('value)

testStream(df, useV2Sink = true)(
StartStream(Trigger.Continuous(1)),
ExpectFailure[IllegalStateException] { e =>
assert(e.getMessage.contains("queue has exceeded its maximum"))
}
)
}
}
@@ -17,9 +17,9 @@

package org.apache.spark.sql.streaming.continuous

import org.mockito.{ArgumentCaptor, InOrder}
import org.mockito.ArgumentMatchers.{any, eq => eqTo}
import org.mockito.InOrder
import org.mockito.Mockito.{inOrder, never, verify}
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mockito.MockitoSugar

@@ -43,14 +43,20 @@ class EpochCoordinatorSuite
private var writeSupport: StreamingWriteSupport = _
private var query: ContinuousExecution = _
private var orderVerifier: InOrder = _
private val epochBacklogQueueSize = 10

override def beforeEach(): Unit = {
val reader = mock[ContinuousReadSupport]
writeSupport = mock[StreamingWriteSupport]
query = mock[ContinuousExecution]
orderVerifier = inOrder(writeSupport, query)

spark = new TestSparkSession()
spark = new TestSparkSession(
new SparkContext(
"local[2]", "test-sql-context",
new SparkConf().set("spark.sql.testkey", "true")
.set("spark.sql.streaming.continuous.epochBacklogQueueSize",
epochBacklogQueueSize.toString)))

epochCoordinator
= EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get)
@@ -186,6 +192,74 @@ class EpochCoordinatorSuite
verifyCommitsInOrderOf(List(1, 2, 3, 4, 5))
}

test("several epochs, max epoch backlog reached by partitionOffsets") {

setWriterPartitions(1)
setReaderPartitions(1)

reportPartitionOffset(0, 1)

// Commit messages not arriving

for (i <- 2 to epochBacklogQueueSize + 1) {
reportPartitionOffset(0, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 1) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the partition offset queue has exceeded it's maximum")
}

test("several epochs, max epoch backlog reached by partitionCommits") {

setWriterPartitions(1)
setReaderPartitions(1)

commitPartitionEpoch(0, 1)

// Offset messages not arriving

for (i <- 2 to epochBacklogQueueSize + 1) {
commitPartitionEpoch(0, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 1) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the partition commit queue has exceeded it's maximum")
}

test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") {

setWriterPartitions(2)
setReaderPartitions(2)

commitPartitionEpoch(0, 1)
reportPartitionOffset(0, 1)

// Epoch 1 messages from the second partition never arrive

// +2 because the first epoch never completes, so one extra epoch is needed to overflow the queue
for (i <- 2 to epochBacklogQueueSize + 2) {
commitPartitionEpoch(0, i)
reportPartitionOffset(0, i)
commitPartitionEpoch(1, i)
reportPartitionOffset(1, i)
}

makeSynchronousCall()

for (i <- 1 to epochBacklogQueueSize + 2) {
verifyNoCommitFor(i)
}
verifyStoppedWithException("Size of the epoch queue has exceeded it's maximum")
}

private def setWriterPartitions(numPartitions: Int): Unit = {
epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions))
}
@@ -221,4 +295,13 @@ class EpochCoordinatorSuite
private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = {
epochs.foreach(verifyCommit)
}

private def verifyStoppedWithException(msg: String): Unit = {
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture())

import scala.collection.JavaConverters._
val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg)
assert(throwable.isDefined, "Stream stopped with an exception but the expected message is missing")
}
}