
[SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query.

## What changes were proposed in this pull request?

Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination.
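
To make the scheme concrete, here is a minimal standalone Scala sketch (the object and method names are invented for illustration and are not part of this patch): the run ID is generated once and never changes, while every (re)start of the continuous execution derives a fresh epoch coordinator ID of the form `s"$runId--${UUID.randomUUID}"`, so RPC messages addressed to a previous execution's coordinator endpoint cannot reach the new one.

```scala
import java.util.UUID

// Hypothetical sketch of the ID scheme introduced by this patch (not Spark code).
object EpochCoordinatorIdSketch {
  // Fixed for the lifetime of the query, like the static runId kept by StreamExecution.
  val runId: UUID = UUID.randomUUID()

  // A fresh coordinator ID per (re)configuration: the run ID plus another random UUID.
  def newEpochCoordinatorId(): String = s"$runId--${UUID.randomUUID()}"

  def main(args: Array[String]): Unit = {
    val first = newEpochCoordinatorId()   // initial execution
    val second = newEpochCoordinatorId()  // after a reconfiguration
    // Both share the run ID prefix, but name distinct coordinator endpoints.
    assert(first.startsWith(runId.toString) && second.startsWith(runId.toString))
    assert(first != second)
  }
}
```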

## How was this patch tested?

New and existing unit tests.

Author: Jose Torres <jose@databricks.com>

Closes #20282 from jose-torres/fix-runid.

(cherry picked from commit e946c63)
Signed-off-by: Shixiong Zhu <zsxwing@gmail.com>
jose-torres authored and zsxwing committed Jan 17, 2018
1 parent dbd2a55 commit 79ccd0cadf09c41c0f4b5853a54798be17a20584
@@ -58,7 +58,8 @@ case class DataSourceV2ScanExec(

case _: ContinuousReader =>
EpochCoordinatorRef.get(
-sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+sparkContext.env)
.askSync[Unit](SetReaderPartitions(readTasks.size()))
new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
.asInstanceOf[RDD[InternalRow]]
@@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
val runTask = writer match {
case w: ContinuousWriter =>
EpochCoordinatorRef.get(
-sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+sparkContext.env)
.askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))

(context: TaskContext, iter: Iterator[InternalRow]) =>
@@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging {
iter: Iterator[InternalRow]): WriterCommitMessage = {
val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber())
val epochCoordinator = EpochCoordinatorRef.get(
-context.getLocalProperty(ContinuousExecution.RUN_ID_KEY),
+context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
SparkEnv.get)
val currentMsg: WriterCommitMessage = null
var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
@@ -142,8 +142,7 @@ abstract class StreamExecution(

override val id: UUID = UUID.fromString(streamMetadata.id)

-override def runId: UUID = currentRunId
-protected var currentRunId = UUID.randomUUID
+override val runId: UUID = UUID.randomUUID

/**
* Pretty identified string of printing in logs. Format is
@@ -59,7 +59,7 @@ class ContinuousDataSourceRDD(

val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()

-val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
+val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)

// This queue contains two types of messages:
// * (null, null) representing an epoch boundary.
@@ -68,7 +68,7 @@ class ContinuousDataSourceRDD(

val epochPollFailed = new AtomicBoolean(false)
val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
-s"epoch-poll--${runId}--${context.partitionId()}")
+s"epoch-poll--$coordinatorId--${context.partitionId()}")
val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
epochPollExecutor.scheduleWithFixedDelay(
epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
@@ -86,7 +86,7 @@ class ContinuousDataSourceRDD(
epochPollExecutor.shutdown()
})

-val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get)
+val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
new Iterator[UnsafeRow] {
private val POLL_TIMEOUT_MS = 1000

@@ -150,7 +150,7 @@ class EpochPollRunnable(
private[continuous] var failureReason: Throwable = _

private val epochEndpoint = EpochCoordinatorRef.get(
-context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get)
+context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong

override def run(): Unit = {
@@ -177,7 +177,7 @@ class DataReaderThread(
failedFlag: AtomicBoolean)
extends Thread(
s"continuous-reader--${context.partitionId()}--" +
-s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") {
+s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") {
private[continuous] var failureReason: Throwable = _

override def run(): Unit = {
@@ -57,6 +57,9 @@ class ContinuousExecution(
@volatile protected var continuousSources: Seq[ContinuousReader] = _
override protected def sources: Seq[BaseStreamingSource] = continuousSources

+// For use only in test harnesses.
+private[sql] var currentEpochCoordinatorId: String = _
+
override lazy val logicalPlan: LogicalPlan = {
assert(queryExecutionThread eq Thread.currentThread,
"logicalPlan must be initialized in StreamExecutionThread " +
@@ -149,7 +152,6 @@ class ContinuousExecution(
* @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
*/
private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
-currentRunId = UUID.randomUUID
// A list of attributes that will need to be updated.
val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Translate from continuous relation to the underlying data source.
@@ -219,15 +221,19 @@ class ContinuousExecution(
lastExecution.executedPlan // Force the lazy generation of execution plan
}

-sparkSession.sparkContext.setLocalProperty(
+sparkSessionForQuery.sparkContext.setLocalProperty(
ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString)
-sparkSession.sparkContext.setLocalProperty(
-ContinuousExecution.RUN_ID_KEY, runId.toString)
+// Add another random ID on top of the run ID, to distinguish epoch coordinators across
+// reconfigurations.
+val epochCoordinatorId = s"$runId--${UUID.randomUUID}"
+currentEpochCoordinatorId = epochCoordinatorId
+sparkSessionForQuery.sparkContext.setLocalProperty(
+ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)

// Use the parent Spark session for the endpoint since it's where this query ID is registered.
val epochEndpoint =
EpochCoordinatorRef.create(
-writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get)
+writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
val epochUpdateThread = new Thread(new Runnable {
override def run: Unit = {
try {
@@ -359,5 +365,5 @@ class ContinuousExecution(

object ContinuousExecution {
val START_EPOCH_KEY = "__continuous_start_epoch"
-val RUN_ID_KEY = "__run_id"
+val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
}
@@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset(

/** Helper object used to create reference to [[EpochCoordinator]]. */
private[sql] object EpochCoordinatorRef extends Logging {
-private def endpointName(runId: String) = s"EpochCoordinator-$runId"
+private def endpointName(id: String) = s"EpochCoordinator-$id"

/**
* Create a reference to a new [[EpochCoordinator]].
@@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging {
writer: ContinuousWriter,
reader: ContinuousReader,
query: ContinuousExecution,
+epochCoordinatorId: String,
startEpoch: Long,
session: SparkSession,
env: SparkEnv): RpcEndpointRef = synchronized {
val coordinator = new EpochCoordinator(
writer, reader, query, startEpoch, session, env.rpcEnv)
-val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator)
+val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator)
logInfo("Registered EpochCoordinator endpoint")
ref
}

-def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized {
-val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv)
+def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized {
+val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv)
logDebug("Retrieved existing EpochCoordinator endpoint")
rpcEndpointRef
}
@@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def apply(): AssertOnQuery =
Execute {
case s: ContinuousExecution =>
-val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get)
+val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get)
.askSync[Long](IncrementAndGetEpoch)
s.awaitEpoch(newEpoch - 1)
case _ => throw new IllegalStateException("microbatch cannot increment epoch")
@@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}

+test("continuous processing listeners should receive QueryTerminatedEvent") {
+val df = spark.readStream.format("rate").load()
+val listeners = (1 to 5).map(_ => new EventCollector)
+try {
+listeners.foreach(listener => spark.streams.addListener(listener))
+testStream(df, OutputMode.Append, useV2Sink = true)(
+StartStream(Trigger.Continuous(1000)),
+StopStream,
+AssertOnQuery { query =>
+eventually(Timeout(streamingTimeout)) {
+listeners.foreach(listener => assert(listener.terminationEvent !== null))
+listeners.foreach(listener => assert(listener.terminationEvent.id === query.id))
+listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId))
+listeners.foreach(listener => assert(listener.terminationEvent.exception === None))
+}
+listeners.foreach(listener => listener.checkAsyncErrors())
+listeners.foreach(listener => listener.reset())
+true
+}
+)
+} finally {
+listeners.foreach(spark.streams.removeListener)
+}
+}
+
test("adding and removing listener") {
def isListenerActive(listener: EventCollector): Boolean = {
listener.reset()
