From 8941a4abcada873c26af924e129173dc33d66d71 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 22 Dec 2017 23:05:03 -0800 Subject: [PATCH] [SPARK-22789] Map-only continuous processing execution ## What changes were proposed in this pull request? Basic continuous execution, supporting map/flatMap/filter, with commits and advancement through RPC. ## How was this patch tested? new unit-ish tests (exercising execution end to end) Author: Jose Torres Closes #19984 from jose-torres/continuous-impl. --- project/MimaExcludes.scala | 5 + .../UnsupportedOperationChecker.scala | 25 +- .../apache/spark/sql/internal/SQLConf.scala | 21 ++ .../sources/v2/reader/ContinuousReader.java | 6 + .../sources/v2/reader/MicroBatchReader.java | 6 + .../apache/spark/sql/streaming/Trigger.java | 54 +++ .../spark/sql/execution/SparkStrategies.scala | 7 + .../datasources/v2/DataSourceV2ScanExec.scala | 20 +- .../datasources/v2/WriteToDataSourceV2.scala | 60 ++- .../streaming/BaseStreamingSource.java | 8 - .../execution/streaming/HDFSMetadataLog.scala | 14 + .../streaming/MicroBatchExecution.scala | 44 ++- .../sql/execution/streaming/OffsetSeq.scala | 2 +- .../streaming/ProgressReporter.scala | 10 +- .../streaming/RateSourceProvider.scala | 9 +- .../streaming/RateStreamOffset.scala | 5 +- .../spark/sql/execution/streaming/Sink.scala | 2 +- .../sql/execution/streaming/Source.scala | 2 +- .../execution/streaming/StreamExecution.scala | 20 +- .../execution/streaming/StreamProgress.scala | 19 +- .../streaming/StreamingRelation.scala | 47 +++ .../ContinuousDataSourceRDDIter.scala | 217 +++++++++++ .../continuous/ContinuousExecution.scala | 349 ++++++++++++++++++ .../ContinuousRateStreamSource.scala | 11 +- .../continuous/ContinuousTrigger.scala | 70 ++++ .../continuous/EpochCoordinator.scala | 191 ++++++++++ .../sources/RateStreamSourceV2.scala | 19 +- .../streaming/sources/memoryV2.scala | 13 + .../sql/streaming/DataStreamReader.scala | 38 +- .../sql/streaming/DataStreamWriter.scala | 19 +- .../sql/streaming/StreamingQueryManager.scala | 45 ++- .../org/apache/spark/sql/QueryTest.scala | 56 ++- .../streaming/RateSourceV2Suite.scala | 30 +- .../spark/sql/streaming/StreamSuite.scala | 17 +- .../spark/sql/streaming/StreamTest.scala | 55 ++- .../continuous/ContinuousSuite.scala | 316 ++++++++++++++++ 36 files changed, 1682 insertions(+), 150 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9902fedb65d59..81584af6813ea 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-22789: Map-only continuous processing execution + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$9"), + // SPARK-22372: Make cluster submission use SparkApplication. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getSecretKeyFromUserCredentials"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.isYarnMode"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 04502d04d9509..b55043c270644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ @@ -339,6 +339,29 @@ object UnsupportedOperationChecker { } } + def checkForContinuous(plan: LogicalPlan, outputMode: OutputMode): Unit = { + checkForStreaming(plan, outputMode) + + plan.foreachUp { implicit subPlan => + subPlan match { + case (_: Project | _: Filter | _: MapElements | _: MapPartitions | + _: DeserializeToObject | _: SerializeFromObject) => + case node if node.nodeName == "StreamingRelationV2" => + case node => + throwError(s"Continuous processing does not support ${node.nodeName} operations.") + } + + subPlan.expressions.foreach { e => + if (e.collectLeaves().exists { + case (_: CurrentTimestamp | _: CurrentDate) => true + case _ => false + }) { + throwError(s"Continuous processing does not support current time operations.") + } + } + } + } + private def throwErrorIf( condition: Boolean, msg: String)(implicit operator: LogicalPlan): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bdc8d92e84079..84fe4bb711a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1044,6 +1044,22 @@ object SQLConf { "When this conf is not set, the value from `spark.redaction.string.regex` is used.") .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = + buildConf("spark.sql.streaming.continuous.executorQueueSize") + .internal() + .doc("The size (measured in number of rows) of the queue used in continuous execution to" + + " buffer the results of a ContinuousDataReader.") + .intConf + .createWithDefault(1024) + + val CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS = + buildConf("spark.sql.streaming.continuous.executorPollIntervalMs") + .internal() + .doc("The interval at which continuous execution readers will 
poll to check whether" + + " the epoch has advanced on the driver.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(100) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1357,6 +1373,11 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) + + def continuousStreamingExecutorPollIntervalMs: Long = + getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java index 1baf82c2df762..34141d6cd85fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReader.java @@ -65,4 +65,10 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade default boolean needsReconfiguration() { return false; } + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java index 438e3f55b7bcf..bd15c07d87f6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReader.java @@ -61,4 +61,10 @@ public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSourc * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader */ Offset deserializeOffset(String json); + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + void commit(Offset end); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index d31790a285687..33ae9a9e87668 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -22,6 +22,7 @@ import scala.concurrent.duration.Duration; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; /** @@ -95,4 +96,57 @@ public static Trigger ProcessingTime(String interval) { public static Trigger Once() { return OneTimeTrigger$.MODULE$; } + + /** + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + * + * @since 2.3.0 + */ + public static Trigger Continuous(long intervalMs) { + return ContinuousTrigger.apply(intervalMs); + } + + /** + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. 
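+   * The interval controls only how often progress is checkpointed; records are processed
+   * as soon as they arrive rather than being grouped into periodic batches.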
+ * + * {{{ + * import java.util.concurrent.TimeUnit + * df.writeStream.trigger(Trigger.Continuous(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.3.0 + */ + public static Trigger Continuous(long interval, TimeUnit timeUnit) { + return ContinuousTrigger.create(interval, timeUnit); + } + + /** + * (Scala-friendly) + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + * + * {{{ + * import scala.concurrent.duration._ + * df.writeStream.trigger(Trigger.Continuous(10.seconds)) + * }}} + * @since 2.3.0 + */ + public static Trigger Continuous(Duration interval) { + return ContinuousTrigger.apply(interval); + } + + /** + * A trigger that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + * + * {{{ + * df.writeStream.trigger(Trigger.Continuous("10 seconds")) + * }}} + * @since 2.3.0 + */ + public static Trigger Continuous(String interval) { + return ContinuousTrigger.apply(interval); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e713cd7bbe2b..8c6c324d456c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,8 +31,10 @@ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.types.StructType /** * Converts a logical plan into zero or more SparkPlans. 
This API is exposed for experimenting @@ -374,6 +376,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { StreamingRelationExec(s.sourceName, s.output) :: Nil case s: StreamingExecutionRelation => StreamingRelationExec(s.toString, s.output) :: Nil + case s: StreamingRelationV2 => + StreamingRelationExec(s.sourceName, s.output) :: Nil case _ => Nil } } @@ -404,6 +408,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + case MemoryPlanV2(sink, output) => + val encoder = RowEncoder(StructType.fromAttributes(output)) + LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil case logical.Distinct(child) => throw new IllegalStateException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3f243dc44e043..e4fca1b10dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType @@ -52,10 +54,20 @@ case class DataSourceV2ScanExec( }.asJava } - val inputRDD = new DataSourceRDD(sparkContext, readTasks) - .asInstanceOf[RDD[InternalRow]] + val inputRDD = reader match { + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + + case _ => + new DataSourceRDD(sparkContext, readTasks) + } + val numOutputRows = longMetric("numOutputRows") - inputRDD.map { r => + inputRDD.asInstanceOf[RDD[InternalRow]].map { r => numOutputRows += 1 r } @@ -73,7 +85,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) } } -class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) +class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) extends DataReader[UnsafeRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index b72d15ed15aed..1862da8892cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.{SparkEnv, SparkException, TaskContext} import 
org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -58,10 +60,22 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) s"The input RDD has ${messages.length} partitions.") try { + val runTask = writer match { + case w: ContinuousWriter => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.runContinuous(writeTask, context, iter) + case _ => + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter) + } + sparkContext.runJob( rdd, - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter), + runTask, rdd.partitions.indices, (index, message: WriterCommitMessage) => messages(index) = message ) @@ -70,6 +84,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) writer.commit(messages) logInfo(s"Data source writer $writer committed.") } catch { + case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => + // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -109,6 +125,44 @@ object DataWritingSparkTask extends Logging { logError(s"Writer for partition ${context.partitionId()} aborted.") }) } + + def runContinuous( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), + SparkEnv.get) + val currentMsg: WriterCommitMessage = null + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + do { + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. 
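+              // The enclosing do-while loop re-checks context.isInterrupted(), so after a
+              // swallowed interrupt the task exits the loop instead of writing another epoch.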
+ } + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } while (!context.isInterrupted()) + + currentMsg + } } class InternalRowDataWriterFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java index 3a02cbfe7afe3..c44b8af2552f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming; -import org.apache.spark.sql.sources.v2.reader.Offset; - /** * The shared interface between V1 streaming sources and V2 streaming readers. * @@ -26,12 +24,6 @@ * directly, and will be removed in future versions. */ public interface BaseStreamingSource { - /** - * Informs the source that Spark has completed processing all data for offsets less than or - * equal to `end` and will only request offsets greater than `end` in the future. - */ - void commit(Offset end); - /** Stop this source and free any resources it has allocated. */ void stop(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 43cf0ef1da8ca..6e8154d58d4c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -266,6 +266,20 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } + /** + * Removes all log entries later than thresholdBatchId (exclusive). 
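+   * Continuous execution uses this on restart to drop offset log entries recorded after the
+   * last committed epoch, keeping the offset and commit logs aligned.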
+ */ + def purgeAfter(thresholdBatchId: Long): Unit = { + val batchIds = fileManager.list(metadataPath, batchFilesFilter) + .map(f => pathToBatchId(f.getPath)) + + for (batchId <- batchIds if batchId > thresholdBatchId) { + val path = batchIdToPath(batchId) + fileManager.delete(path) + logTrace(s"Removed metadata log file: $path") + } + } + private def createFileManager(): FileManager = { val hadoopConf = sparkSession.sessionState.newHadoopConf() try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 4a3de8bae4bc9..20f9810faa5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.MicroBatchReadSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -41,6 +42,8 @@ class MicroBatchExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { + @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -53,6 +56,7 @@ class MicroBatchExecution( s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() + val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -64,6 +68,17 @@ class MicroBatchExecution( // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) + case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) + if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val source = v1DataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output)(sparkSession) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct @@ -170,12 +185,14 @@ class MicroBatchExecution( * Make a call to getBatch using the offsets from previous batch. * because certain sources (e.g., KafkaSource) assume on restart the last * batch will be executed before getOffset is called again. 
*/ - availableOffsets.foreach { ao: (Source, Offset) => - val (source, end) = ao - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + availableOffsets.foreach { + case (source: Source, end: Offset) => + if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { + val start = committedOffsets.get(source) + source.getBatch(start, end) + } + case nonV1Tuple => + throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets @@ -219,11 +236,12 @@ class MicroBatchExecution( val hasNewData = { awaitProgressLock.lock() try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } + val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) + } }.toMap availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) @@ -298,7 +316,7 @@ class MicroBatchExecution( val prevBatchOff = offsetLog.get(currentBatchId - 1) if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { - case (src, off) => src.commit(off) + case (src: Source, off) => src.commit(off) } } else { throw new IllegalStateException(s"batch $currentBatchId doesn't exist") @@ -331,7 +349,7 @@ class MicroBatchExecution( // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { - case (source, available) + case (source: Source, available) if committedOffsets.get(source).map(_ != available).getOrElse(true) => val current = committedOffsets.get(source) val batch = source.getBatch(current, available) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 4e0a468b962a2..a1b63a6de3823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -38,7 +38,7 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * This method is typically used to associate a serialized offset with actual sources (which * cannot be serialized). 
*/ - def toStreamProgress(sources: Seq[Source]): StreamProgress = { + def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { assert(sources.size == offsets.size) new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index b1c3a8ab235ab..1c9043613cb69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Clock trait ProgressReporter extends Logging { case class ExecutionStats( - inputRows: Map[Source, Long], + inputRows: Map[BaseStreamingSource, Long], stateOperators: Seq[StateOperatorProgress], eventTimeStats: Map[String, String]) @@ -53,11 +53,11 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[Source, DataFrame] + protected def newData: Map[BaseStreamingSource, DataFrame] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress - protected def sources: Seq[Source] - protected def sink: Sink + protected def sources: Seq[BaseStreamingSource] + protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata protected def currentBatchId: Long protected def sparkSession: SparkSession @@ -230,7 +230,7 @@ trait ProgressReporter extends Logging { } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[Source, Long] = + val numInputRows: Map[BaseStreamingSource, Long] = if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 41761324cf6ac..3f85fa913f28c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -52,7 +52,7 @@ import org.apache.spark.util.{ManualClock, SystemClock} * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. 
*/ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with MicroBatchReadSupport with ContinuousReadSupport{ + with DataSourceV2 with ContinuousReadSupport { override def sourceSchema( sqlContext: SQLContext, @@ -107,13 +107,6 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister ) } - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = { - new RateStreamV2Reader(options) - } - override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 726d8574af52b..65d6d18936167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -22,8 +22,11 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 -case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, (Long, Long)]) +case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) extends v2.reader.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } + + +case class ValueRunTimeMsPair(value: Long, runTimeMs: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index d10cd3044ecdf..34bc085d920c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.DataFrame * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same * batch. */ -trait Sink { +trait Sink extends BaseStreamingSink { /** * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 311942f6dbd84..dbbd59e06909c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. 
*/ -trait Source { +trait Source extends BaseStreamingSource { /** Returns the schema of the data from this source */ def schema: StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 129995dcf3607..3e76bf7b7ca8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -44,6 +44,7 @@ trait State case object INITIALIZING extends State case object ACTIVE extends State case object TERMINATED extends State +case object RECONFIGURING extends State /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. @@ -59,7 +60,7 @@ abstract class StreamExecution( override val name: String, private val checkpointRoot: String, analyzedPlan: LogicalPlan, - val sink: Sink, + val sink: BaseStreamingSink, val trigger: Trigger, val triggerClock: Clock, val outputMode: OutputMode, @@ -147,30 +148,25 @@ abstract class StreamExecution( * Pretty identified string of printing in logs. Format is * If name is set "queryName [id = xyz, runId = abc]" else "[id = xyz, runId = abc]" */ - private val prettyIdString = + protected val prettyIdString = Option(name).map(_ + " ").getOrElse("") + s"[id = $id, runId = $runId]" - /** - * All stream sources present in the query plan. This will be set when generating logical plan. - */ - @volatile protected var sources: Seq[Source] = Seq.empty - /** * A list of unique sources in the query plan. This will be set when generating logical plan. */ - @volatile protected var uniqueSources: Seq[Source] = Seq.empty + @volatile protected var uniqueSources: Seq[BaseStreamingSource] = Seq.empty /** Defines the internal state of execution */ - private val state = new AtomicReference[State](INITIALIZING) + protected val state = new AtomicReference[State](INITIALIZING) @volatile var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[Source, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, DataFrame] = _ @volatile - private var streamDeathCause: StreamingQueryException = null + protected var streamDeathCause: StreamingQueryException = null /* Get the call site in the caller thread; will pass this into the micro batch thread */ private val callSite = Utils.getCallSite() @@ -389,7 +385,7 @@ abstract class StreamExecution( } /** Stops all streaming sources safely. */ - private def stopSources(): Unit = { + protected def stopSources(): Unit = { uniqueSources.foreach { source => try { source.stop() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index a3f3662e6f4c9..8531070b1bc49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -23,25 +23,28 @@ import scala.collection.{immutable, GenTraversableOnce} * A helper class that looks like a Map[Source, Offset]. 
*/ class StreamProgress( - val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset]) - extends scala.collection.immutable.Map[Source, Offset] { + val baseMap: immutable.Map[BaseStreamingSource, Offset] = + new immutable.HashMap[BaseStreamingSource, Offset]) + extends scala.collection.immutable.Map[BaseStreamingSource, Offset] { - def toOffsetSeq(source: Seq[Source], metadata: OffsetSeqMetadata): OffsetSeq = { + def toOffsetSeq(source: Seq[BaseStreamingSource], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } override def toString: String = baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") - override def +[B1 >: Offset](kv: (Source, B1)): Map[Source, B1] = baseMap + kv + override def +[B1 >: Offset](kv: (BaseStreamingSource, B1)): Map[BaseStreamingSource, B1] = { + baseMap + kv + } - override def get(key: Source): Option[Offset] = baseMap.get(key) + override def get(key: BaseStreamingSource): Option[Offset] = baseMap.get(key) - override def iterator: Iterator[(Source, Offset)] = baseMap.iterator + override def iterator: Iterator[(BaseStreamingSource, Offset)] = baseMap.iterator - override def -(key: Source): Map[Source, Offset] = baseMap - key + override def -(key: BaseStreamingSource): Map[BaseStreamingSource, Offset] = baseMap - key - def ++(updates: GenTraversableOnce[(Source, Offset)]): StreamProgress = { + def ++(updates: GenTraversableOnce[(BaseStreamingSource, Offset)]): StreamProgress = { new StreamProgress(baseMap ++ updates) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 6b82c78ea653d..0ca2e7854d94b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -75,6 +76,52 @@ case class StreamingExecutionRelation( ) } +// We have to pack in the V1 data source as a shim, for the case when a source implements +// continuous processing (which is always V2) but only has V1 microbatch support. We don't +// know at read time whether the query is continuous or not, so we need to be able to +// swap a V1 relation back in. +/** + * Used to link a [[DataSourceV2]] into a streaming + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating + * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]], + * and should be converted before passing to [[StreamExecution]]. 
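+ * [[MicroBatchExecution]] converts this node back to the wrapped V1 relation when the data
+ * source does not support V2 microbatch reading, while continuous execution converts it to a
+ * [[ContinuousExecutionRelation]].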
+ */ +case class StreamingRelationV2( + dataSource: DataSourceV2, + sourceName: String, + extraOptions: Map[String, String], + output: Seq[Attribute], + v1DataSource: DataSource)(session: SparkSession) + extends LeafNode { + override def isStreaming: Boolean = true + override def toString: String = sourceName + + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) +} + +/** + * Used to link a [[DataSourceV2]] into a continuous processing execution. + */ +case class ContinuousExecutionRelation( + source: ContinuousReadSupport, + extraOptions: Map[String, String], + output: Seq[Attribute])(session: SparkSession) + extends LeafNode { + + override def isStreaming: Boolean = true + override def toString: String = source.toString + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. + override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) + ) +} + /** * A dummy physical plan for [[StreamingRelation]] to support * [[org.apache.spark.sql.Dataset.explain]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala new file mode 100644 index 0000000000000..89fb2ace20917 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.streaming.ProcessingTime +import org.apache.spark.util.{SystemClock, ThreadUtils} + +class ContinuousDataSourceRDD( + sc: SparkContext, + sqlContext: SQLContext, + @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize + private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs + + override protected def getPartitions: Array[Partition] = { + readTasks.asScala.zipWithIndex.map { + case (readTask, index) => new DataSourceRDDPartition(index, readTask) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + + val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) + + // This queue contains two types of messages: + // * (null, null) representing an epoch boundary. + // * (row, off) containing a data row and its corresponding PartitionOffset. 
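+    // The queue is fed by two daemon threads: EpochPollRunnable puts a (null, null) marker
+    // whenever the driver advances the epoch, and DataReaderThread puts the rows it reads
+    // from the source. The iterator below drains the queue, reporting the current partition
+    // offset to the epoch coordinator at each marker.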
+ val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) + + val epochPollFailed = new AtomicBoolean(false) + val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--${runId}--${context.partitionId()}") + val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) + epochPollExecutor.scheduleWithFixedDelay( + epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + // Important sequencing - we must get start offset before the data reader thread begins + val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset + + val dataReaderFailed = new AtomicBoolean(false) + val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + reader.close() + dataReaderThread.interrupt() + epochPollExecutor.shutdown() + }) + + val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) + new Iterator[UnsafeRow] { + private val POLL_TIMEOUT_MS = 1000 + + private var currentEntry: (UnsafeRow, PartitionOffset) = _ + private var currentOffset: PartitionOffset = startOffset + private var currentEpoch = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def hasNext(): Boolean = { + while (currentEntry == null) { + if (context.isInterrupted() || context.isCompleted()) { + currentEntry = (null, null) + } + if (dataReaderFailed.get()) { + throw new SparkException("data read failed", dataReaderThread.failureReason) + } + if (epochPollFailed.get()) { + throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + + currentEntry match { + // epoch boundary marker + case (null, null) => + epochEndpoint.send(ReportPartitionOffset( + context.partitionId(), + currentEpoch, + currentOffset)) + currentEpoch += 1 + currentEntry = null + false + // real row + case (_, offset) => + currentOffset = offset + true + } + } + + override def next(): UnsafeRow = { + if (currentEntry == null) throw new NoSuchElementException("No current row was set") + val r = currentEntry._1 + currentEntry = null + r + } + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + } +} + +case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset + +class EpochPollRunnable( + queue: BlockingQueue[(UnsafeRow, PartitionOffset)], + context: TaskContext, + failedFlag: AtomicBoolean) + extends Thread with Logging { + private[continuous] var failureReason: Throwable = _ + + private val epochEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) + for (i <- currentEpoch to newEpoch - 1) { + queue.put((null, null)) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + failedFlag.set(true) + throw t + } + } +} + +class DataReaderThread( + reader: DataReader[UnsafeRow], + queue: BlockingQueue[(UnsafeRow, PartitionOffset)], + context: TaskContext, + failedFlag: AtomicBoolean) + extends Thread( + s"continuous-reader--${context.partitionId()}--" + + 
s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { + private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) + try { + while (!context.isInterrupted && !context.isCompleted()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!context.isInterrupted && !context.isCompleted()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put((reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException if context.isInterrupted() => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + + case t: Throwable => + failureReason = t + failedFlag.set(true) + // Don't rethrow the exception in this thread. It's not needed, and the default Spark + // exception handler will kill the executor. + } + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala new file mode 100644 index 0000000000000..1c35b06bd4b85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, ContinuousWriteSupport, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.ContinuousWriter +import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{Clock, Utils} + +class ContinuousExecution( + sparkSession: SparkSession, + name: String, + checkpointRoot: String, + analyzedPlan: LogicalPlan, + sink: ContinuousWriteSupport, + trigger: Trigger, + triggerClock: Clock, + outputMode: OutputMode, + extraOptions: Map[String, String], + deleteCheckpointOnStop: Boolean) + extends StreamExecution( + sparkSession, name, checkpointRoot, analyzedPlan, sink, + trigger, triggerClock, outputMode, deleteCheckpointOnStop) { + + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + override protected def sources: Seq[BaseStreamingSource] = continuousSources + + override lazy val logicalPlan: LogicalPlan = { + assert(queryExecutionThread eq Thread.currentThread, + "logicalPlan must be initialized in StreamExecutionThread " + + s"but the current thread was ${Thread.currentThread}") + val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() + analyzedPlan.transform { + case r @ StreamingRelationV2( + source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + toExecutionRelationMap.getOrElseUpdate(r, { + ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) + }) + case StreamingRelationV2(_, sourceName, _, _, _) => + throw new AnalysisException( + s"Data source $sourceName does not support continuous processing.") + } + } + + private val triggerExecutor = trigger match { + case ContinuousTrigger(t) => ProcessingTimeExecutor(ProcessingTime(t), triggerClock) + case _ => throw new IllegalStateException(s"Unsupported type of trigger: $trigger") + } + + override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { + do { + try { + runContinuous(sparkSessionForStream) + } catch { + case _: InterruptedException if state.get().equals(RECONFIGURING) => + // swallow exception and run again + state.set(ACTIVE) + } + } while (state.get() == ACTIVE) + } + + /** + * Populate the start offsets to start the execution at the current offsets stored in the sink + * (i.e. avoid reprocessing data that we have already processed). 
This function must be called + * before any processing occurs and will populate the following fields: + * - currentBatchId + * - committedOffsets + * The basic structure of this method is as follows: + * + * Identify (from the commit log) the latest epoch that has committed + * IF last epoch exists THEN + * Get end offsets for the epoch + * Set those offsets as the current commit progress + * Set the next epoch ID as the last + 1 + * Return the end offsets of the last epoch as start for the next one + * DONE + * ELSE + * Start a new query log + * DONE + */ + private def getStartOffsets(sparkSessionToRunBatches: SparkSession): OffsetSeq = { + // Note that this will need a slight modification for exactly once. If ending offsets were + // reported but not committed for any epochs, we must replay exactly to those offsets. + // For at least once, we can just ignore those reports and risk duplicates. + commitLog.getLatest() match { + case Some((latestEpochId, _)) => + val nextOffsets = offsetLog.get(latestEpochId).getOrElse { + throw new IllegalStateException( + s"Batch $latestEpochId was committed without end epoch offsets!") + } + committedOffsets = nextOffsets.toStreamProgress(sources) + + // Forcibly align commit and offset logs by slicing off any spurious offset logs from + // a previous run. We can't allow commits to an epoch that a previous run reached but + // this run has not. + offsetLog.purgeAfter(latestEpochId) + + currentBatchId = latestEpochId + 1 + logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") + nextOffsets + case None => + // We are starting this stream for the first time. Offsets are all None. + logInfo(s"Starting new streaming query.") + currentBatchId = 0 + OffsetSeq.fill(continuousSources.map(_ => null): _*) + } + } + + /** + * Do a continuous run. + * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. + */ + private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + // A list of attributes that will need to be updated. + val replacements = new ArrayBuffer[(Attribute, Attribute)] + // Translate from continuous relation to the underlying data source. + var nextSourceId = 0 + continuousSources = logicalPlan.collect { + case ContinuousExecutionRelation(dataSource, extraReaderOptions, output) => + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + nextSourceId += 1 + + dataSource.createContinuousReader( + java.util.Optional.empty[StructType](), + metadataPath, + new DataSourceV2Options(extraReaderOptions.asJava)) + } + uniqueSources = continuousSources.distinct + + val offsets = getStartOffsets(sparkSessionForQuery) + + var insertedSourceId = 0 + val withNewSources = logicalPlan transform { + case ContinuousExecutionRelation(_, _, output) => + val reader = continuousSources(insertedSourceId) + insertedSourceId += 1 + val newOutput = reader.readSchema().toAttributes + + assert(output.size == newOutput.size, + s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + + s"${Utils.truncatedString(newOutput, ",")}") + replacements ++= output.zip(newOutput) + + val loggedOffset = offsets.offsets(0) + val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) + reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) + DataSourceV2Relation(newOutput, reader) + } + + // Rewire the plan to use the new attributes that were returned by the source. 
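+    // The same pass rejects CurrentTimestamp and CurrentDate, which are not yet supported
+    // in continuous processing.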
+ val replacementMap = AttributeMap(replacements) + val triggerLogicalPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => + replacementMap(a).withMetadata(a.metadata) + case (_: CurrentTimestamp | _: CurrentDate) => + throw new IllegalStateException( + "CurrentTimestamp and CurrentDate not yet supported for continuous processing") + } + + val writer = sink.createContinuousWriter( + s"$runId", + triggerLogicalPlan.schema, + outputMode, + new DataSourceV2Options(extraOptions.asJava)) + val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan) + + val reader = withSink.collect { + case DataSourceV2Relation(_, r: ContinuousReader) => r + }.head + + reportTimeTaken("queryPlanning") { + lastExecution = new IncrementalExecution( + sparkSessionForQuery, + withSink, + outputMode, + checkpointFile("state"), + runId, + currentBatchId, + offsetSeqMetadata) + lastExecution.executedPlan // Force the lazy generation of execution plan + } + + sparkSession.sparkContext.setLocalProperty( + ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) + sparkSession.sparkContext.setLocalProperty( + ContinuousExecution.RUN_ID_KEY, runId.toString) + + // Use the parent Spark session for the endpoint since it's where this query ID is registered. + val epochEndpoint = + EpochCoordinatorRef.create( + writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + val epochUpdateThread = new Thread(new Runnable { + override def run: Unit = { + try { + triggerExecutor.execute(() => { + startTrigger() + + if (reader.needsReconfiguration()) { + state.set(RECONFIGURING) + stopSources() + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + // No need to join - this thread is about to end anyway. + } + false + } else if (isActive) { + currentBatchId = epochEndpoint.askSync[Long](IncrementAndGetEpoch) + logInfo(s"New epoch $currentBatchId is starting.") + true + } else { + false + } + }) + } catch { + case _: InterruptedException => + // Cleanly stop the query. + return + } + } + }, s"epoch update thread for $prettyIdString") + + try { + epochUpdateThread.setDaemon(true) + epochUpdateThread.start() + + reportTimeTaken("runContinuous") { + SQLExecution.withNewExecutionId( + sparkSessionForQuery, lastExecution)(lastExecution.toRdd) + } + } finally { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + + epochUpdateThread.interrupt() + epochUpdateThread.join() + } + } + + /** + * Report ending partition offsets for the given reader at the given epoch. + */ + def addOffset( + epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + assert(continuousSources.length == 1, "only one continuous source supported currently") + + if (partitionOffsets.contains(null)) { + // If any offset is null, that means the corresponding partition hasn't seen any data yet, so + // there's nothing meaningful to add to the offset log. + } + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + synchronized { + if (queryExecutionThread.isAlive) { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + } else { + return + } + } + } + + /** + * Mark the specified epoch as committed. All readers must have reported end offsets for the epoch + * before this is called. 
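+   * This is expected to be invoked by the epoch coordinator once every writer partition has
+   * reported a successful commit for the epoch.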
+ */ + def commit(epoch: Long): Unit = { + assert(continuousSources.length == 1, "only one continuous source supported currently") + assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + if (queryExecutionThread.isAlive) { + commitLog.add(epoch) + val offset = offsetLog.get(epoch).get.offsets(0).get + committedOffsets ++= Seq(continuousSources(0) -> offset) + } else { + return + } + } + + if (minLogEntriesToMaintain < currentBatchId) { + offsetLog.purge(currentBatchId - minLogEntriesToMaintain) + commitLog.purge(currentBatchId - minLogEntriesToMaintain) + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() + } + } + + /** + * Blocks the current thread until execution has committed at or after the specified epoch. + */ + private[sql] def awaitEpoch(epoch: Long): Unit = { + def notDone = { + val latestCommit = commitLog.getLatest() + latestCommit match { + case Some((latestEpoch, _)) => + latestEpoch < epoch + case None => true + } + } + + while (notDone) { + awaitProgressLock.lock() + try { + awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } + } finally { + awaitProgressLock.unlock() + } + } + } +} + +object ContinuousExecution { + val START_EPOCH_KEY = "__continuous_start_epoch" + val RUN_ID_KEY = "__run_id" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 4c3a1ee201ac1..89a8562b4b59e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -25,7 +25,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ @@ -47,13 +47,14 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => (i, (currVal, nextRead)) + case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) } override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, (Long, Long)]](json)) + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } override def readSchema(): StructType = RateSourceProvider.SCHEMA @@ -85,8 +86,8 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. 
RateStreamReadTask( - start._1, // starting row value - start._2, // starting time in ms + start.value, + start.runTimeMs, i, numPartitions, perPartitionRate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala new file mode 100644 index 0000000000000..90e1766c4d9f1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at + * the specified interval. + */ +@InterfaceStability.Evolving +case class ContinuousTrigger(intervalMs: Long) extends Trigger { + require(intervalMs >= 0, "the interval of trigger should not be negative") +} + +private[sql] object ContinuousTrigger { + def apply(interval: String): ContinuousTrigger = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "interval cannot be null or blank.") + } + val cal = if (interval.startsWith("interval")) { + CalendarInterval.fromString(interval) + } else { + CalendarInterval.fromString("interval " + interval) + } + if (cal == null) { + throw new IllegalArgumentException(s"Invalid interval: $interval") + } + if (cal.months > 0) { + throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") + } + new ContinuousTrigger(cal.microseconds / 1000) + } + + def apply(interval: Duration): ContinuousTrigger = { + ContinuousTrigger(interval.toMillis) + } + + def create(interval: String): ContinuousTrigger = { + apply(interval) + } + + def create(interval: Long, unit: TimeUnit): ContinuousTrigger = { + ContinuousTrigger(unit.toMillis(interval)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala new file mode 100644 index 0000000000000..7f1e8abd79b99 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.{ContinuousWriter, WriterCommitMessage} +import org.apache.spark.util.RpcUtils + +private[continuous] sealed trait EpochCoordinatorMessage extends Serializable + +// Driver epoch trigger message +/** + * Atomically increment the current epoch and get the new value. + */ +private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage + +// Init messages +/** + * Set the reader and writer partition counts. Tasks may not be started until the coordinator + * has acknowledged these messages. + */ +private[sql] case class SetReaderPartitions(numPartitions: Int) extends EpochCoordinatorMessage +case class SetWriterPartitions(numPartitions: Int) extends EpochCoordinatorMessage + +// Partition task messages +/** + * Get the current epoch. + */ +private[sql] case object GetCurrentEpoch extends EpochCoordinatorMessage +/** + * Commit a partition at the specified epoch with the given message. + */ +private[sql] case class CommitPartitionEpoch( + partitionId: Int, + epoch: Long, + message: WriterCommitMessage) extends EpochCoordinatorMessage +/** + * Report that a partition is ending the specified epoch at the specified offset. + */ +private[sql] case class ReportPartitionOffset( + partitionId: Int, + epoch: Long, + offset: PartitionOffset) extends EpochCoordinatorMessage + + +/** Helper object used to create reference to [[EpochCoordinator]]. */ +private[sql] object EpochCoordinatorRef extends Logging { + private def endpointName(runId: String) = s"EpochCoordinator-$runId" + + /** + * Create a reference to a new [[EpochCoordinator]]. 
+ */ + def create( + writer: ContinuousWriter, + reader: ContinuousReader, + query: ContinuousExecution, + startEpoch: Long, + session: SparkSession, + env: SparkEnv): RpcEndpointRef = synchronized { + val coordinator = new EpochCoordinator( + writer, reader, query, startEpoch, session, env.rpcEnv) + val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + logInfo("Registered EpochCoordinator endpoint") + ref + } + + def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + logDebug("Retrieved existing EpochCoordinator endpoint") + rpcEndpointRef + } +} + +/** + * Handles three major epoch coordination tasks for continuous processing: + * + * * Maintains a local epoch counter (the "driver epoch"), incremented by IncrementAndGetEpoch + * and pollable from executors by GetCurrentEpoch. Note that this epoch is *not* immediately + * reflected anywhere in ContinuousExecution. + * * Collates ReportPartitionOffset messages, and forwards to ContinuousExecution when all + * readers have ended a given epoch. + * * Collates CommitPartitionEpoch messages, and forwards to ContinuousExecution when all readers + * have both committed and reported an end offset for a given epoch. + */ +private[continuous] class EpochCoordinator( + writer: ContinuousWriter, + reader: ContinuousReader, + query: ContinuousExecution, + startEpoch: Long, + session: SparkSession, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + + private var numReaderPartitions: Int = _ + private var numWriterPartitions: Int = _ + + private var currentDriverEpoch = startEpoch + + // (epoch, partition) -> message + private val partitionCommits = + mutable.Map[(Long, Int), WriterCommitMessage]() + // (epoch, partition) -> offset + private val partitionOffsets = + mutable.Map[(Long, Int), PartitionOffset]() + + private def resolveCommitsAtEpoch(epoch: Long) = { + val thisEpochCommits = + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val nextEpochOffsets = + partitionOffsets.collect { case ((e, _), o) if e == epoch => o } + + if (thisEpochCommits.size == numWriterPartitions && + nextEpochOffsets.size == numReaderPartitions) { + logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, thisEpochCommits.toArray) + query.commit(epoch) + + // Cleanup state from before this epoch, now that we know all partitions are forever past it. 
+      for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) {
+        partitionCommits.remove(k)
+      }
+      for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
+        partitionOffsets.remove(k)
+      }
+    }
+  }
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case CommitPartitionEpoch(partitionId, epoch, message) =>
+      logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
+      if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
+        partitionCommits.put((epoch, partitionId), message)
+        resolveCommitsAtEpoch(epoch)
+      }
+
+    case ReportPartitionOffset(partitionId, epoch, offset) =>
+      partitionOffsets.put((epoch, partitionId), offset)
+      val thisEpochOffsets =
+        partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
+      if (thisEpochOffsets.size == numReaderPartitions) {
+        logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets")
+        query.addOffset(epoch, reader, thisEpochOffsets.toSeq)
+        resolveCommitsAtEpoch(epoch)
+      }
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case GetCurrentEpoch =>
+      val result = currentDriverEpoch
+      logDebug(s"Epoch $result")
+      context.reply(result)
+
+    case IncrementAndGetEpoch =>
+      currentDriverEpoch += 1
+      context.reply(currentDriverEpoch)
+
+    case SetReaderPartitions(numPartitions) =>
+      numReaderPartitions = numPartitions
+      context.reply(())
+
+    case SetWriterPartitions(numPartitions) =>
+      numWriterPartitions = numPartitions
+      context.reply(())
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
index 45dc7d75cbc8d..1c66aed8690a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
@@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.RateStreamOffset
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
 import org.apache.spark.sql.sources.v2.DataSourceV2Options
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
@@ -71,7 +71,7 @@ class RateStreamV2Reader(options: DataSourceV2Options)
     val currentTime = clock.getTimeMillis()
     RateStreamOffset(
       this.start.partitionToValueAndRunTimeMs.map {
-        case startOffset @ (part, (currentVal, currentReadTime)) =>
+        case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
           // Calculate the number of rows we should advance in this partition (based on the
           // current time), and output a corresponding offset.
val readInterval = currentTime - currentReadTime @@ -79,9 +79,9 @@ class RateStreamV2Reader(options: DataSourceV2Options) if (numNewRows <= 0) { startOffset } else { - (part, - (currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + (part, ValueRunTimeMsPair( + currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) } } ) @@ -98,15 +98,15 @@ class RateStreamV2Reader(options: DataSourceV2Options) } override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, (Long, Long)]](json)) + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } override def createReadTasks(): java.util.List[ReadTask[Row]] = { val startMap = start.partitionToValueAndRunTimeMs val endMap = end.partitionToValueAndRunTimeMs endMap.keys.toSeq.map { part => - val (endVal, _) = endMap(part) - val (startVal, startTimeMs) = startMap(part) + val ValueRunTimeMsPair(endVal, _) = endMap(part) + val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) val packedRows = mutable.ListBuffer[(Long, Long)]() var outVal = startVal + numPartitions @@ -158,7 +158,8 @@ object RateStreamSourceV2 { // by the increment that will later be applied. The first row output in each // partition will have a value equal to the partition index. (i, - ((i - numPartitions).toLong, + ValueRunTimeMsPair( + (i - numPartitions).toLong, creationTimeMs)) }.toMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 94c5dd63089b1..972248d5e4df8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport} @@ -177,3 +179,14 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) override def abort(): Unit = {} } + + +/** + * Used to query the data that has been written into a [[MemorySink]]. 
+ */ +case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { + private val sizePerRow = output.map(_.dataType.defaultSize).sum + + override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 41aa02c2b5e35..f17935e86f459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,8 +26,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems, @@ -153,13 +155,33 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo "read files of Hive data source directly.") } - val dataSource = - DataSource( - sparkSession, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap) - Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() + val options = new DataSourceV2Options(extraOptions.asJava) + // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. + // We can't be sure at this point whether we'll actually want to use V2, since we don't know the + // writer or whether the query is continuous. + val v1DataSource = DataSource( + sparkSession, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap) + ds match { + case s: ContinuousReadSupport => + val tempReader = s.createContinuousReader( + java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + // Generate the V1 node to catch errors thrown within generation. + StreamingRelation(v1DataSource) + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + case _ => + // Code path for data source v1. 
+ Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0be69b98abc8a..db588ae282f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -26,7 +26,9 @@ import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -240,14 +242,23 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { if (extraOptions.get("queryName").isEmpty) { throw new AnalysisException("queryName must be specified for memory sink") } - val sink = new MemorySink(df.schema, outputMode) - val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) + val (sink, resultDf) = trigger match { + case _: ContinuousTrigger => + val s = new MemorySinkV2() + val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) + (s, r) + case _ => + val s = new MemorySink(df.schema, outputMode) + val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) + (s, r) + } val chkpointLoc = extraOptions.get("checkpointLocation") val recoverFromChkpoint = outputMode == OutputMode.Complete() val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), chkpointLoc, df, + extraOptions.toMap, sink, outputMode, useTempCheckpointLocation = true, @@ -262,6 +273,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, + extraOptions.toMap, sink, outputMode, useTempCheckpointLocation = true, @@ -277,6 +289,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, + extraOptions.toMap, dataSource.createSink(outputMode), outputMode, useTempCheckpointLocation = source == "console", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 555d6e23f9385..e808ffaa96410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,8 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.ContinuousWriteSupport import 
org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -188,7 +190,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], df: DataFrame, - sink: Sink, + extraOptions: Map[String, String], + sink: BaseStreamingSink, outputMode: OutputMode, useTempCheckpointLocation: Boolean, recoverFromCheckpointLocation: Boolean, @@ -237,16 +240,32 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - new StreamingQueryWrapper(new MicroBatchExecution( - sparkSession, - userSpecifiedName.orNull, - checkpointLocation, - analyzedPlan, - sink, - trigger, - triggerClock, - outputMode, - deleteCheckpointOnStop)) + sink match { + case v1Sink: Sink => + new StreamingQueryWrapper(new MicroBatchExecution( + sparkSession, + userSpecifiedName.orNull, + checkpointLocation, + analyzedPlan, + v1Sink, + trigger, + triggerClock, + outputMode, + deleteCheckpointOnStop)) + case v2Sink: ContinuousWriteSupport => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( + sparkSession, + userSpecifiedName.orNull, + checkpointLocation, + analyzedPlan, + v2Sink, + trigger, + triggerClock, + outputMode, + extraOptions, + deleteCheckpointOnStop)) + } } /** @@ -269,7 +288,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], df: DataFrame, - sink: Sink, + extraOptions: Map[String, String], + sink: BaseStreamingSink, outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, @@ -279,6 +299,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName, userSpecifiedCheckpointLocation, df, + extraOptions, sink, outputMode, useTempCheckpointLocation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index fcaca3d75b74f..9fb8be423614b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -297,31 +297,47 @@ object QueryTest { }) } + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") + + s""" + |== Results == + |${ + sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") + } + """.stripMargin + } + + def includesRows( + expectedRows: Seq[Row], + sparkAnswer: Seq[Row]): Option[String] = { + if (!prepareAnswer(expectedRows, true).toSet.subsetOf(prepareAnswer(sparkAnswer, true).toSet)) { + return Some(genError(expectedRows, sparkAnswer, true)) + } + None + } + def sameRows( expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], isSorted: Boolean = false): Option[String] = { if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { - 
val getRowType: Option[Row] => String = row => - row.map(row => - if (row.schema == null) { - "struct<>" - } else { - s"${row.schema.catalogString}" - }).getOrElse("struct<>") - - val errorMessage = - s""" - |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - getRowType(expectedAnswer.headOption) +: - prepareAnswer(expectedAnswer, isSorted).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - getRowType(sparkAnswer.headOption) +: - prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + return Some(genError(expectedAnswer, sparkAnswer, isSorted)) } None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 6514c5f0fdfeb..dc833b2ccaa22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -29,16 +29,6 @@ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Optio import org.apache.spark.sql.streaming.StreamTest class RateSourceV2Suite extends StreamTest { - test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[RateStreamV2Reader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - test("microbatch - numPartitions propagated") { val reader = new RateStreamV2Reader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) @@ -49,8 +39,8 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - set offset") { val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) - val startOffset = RateStreamOffset(Map((0, (0, 1000)))) - val endOffset = RateStreamOffset(Map((0, (0, 2000)))) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) assert(reader.getStartOffset() == startOffset) assert(reader.getEndOffset() == endOffset) @@ -63,15 +53,15 @@ class RateSourceV2Suite extends StreamTest { reader.setOffsetRange(Optional.empty(), Optional.empty()) reader.getStartOffset() match { case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0)._2 == reader.creationTimeMs) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) case _ => throw new IllegalStateException("unexpected offset type") } reader.getEndOffset() match { case r: RateStreamOffset => // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted // longer than 100ms. It should never be early. 
- assert(r.partitionToValueAndRunTimeMs(0)._1 >= 9) - assert(r.partitionToValueAndRunTimeMs(0)._2 >= reader.creationTimeMs + 100) + assert(r.partitionToValueAndRunTimeMs(0).value >= 9) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -80,8 +70,8 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - predetermined batch size") { val reader = new RateStreamV2Reader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, (0, 1000)))) - val endOffset = RateStreamOffset(Map((0, (20, 2000)))) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) val tasks = reader.createReadTasks() assert(tasks.size == 1) @@ -93,8 +83,8 @@ class RateSourceV2Suite extends StreamTest { new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, (currentVal, currentReadTime)) => - (part, (currentVal + 33, currentReadTime + 1000)) + case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) }.toMap) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -135,7 +125,7 @@ class RateSourceV2Suite extends StreamTest { val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) - ._2 + .runTimeMs val r = t.createDataReader().asInstanceOf[RateStreamDataReader] for (rowIndex <- 0 to 9) { r.next() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 755490308b5b9..c65e5d3dd75c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -77,10 +77,23 @@ class StreamSuite extends StreamTest { } test("StreamingRelation.computeStats") { + withTempDir { dir => + val df = spark.readStream.format("csv").schema(StructType(Seq())).load(dir.getCanonicalPath) + val streamingRelation = df.logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == + spark.sessionState.conf.defaultSizeInBytes) + } + } + + test("StreamingRelationV2.computeStats") { val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { - case s: StreamingRelation => s + case s: StreamingRelationV2 => s } - assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert(streamingRelation.nonEmpty, "cannot find StreamingExecutionRelation") assert( streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 71a474ef63e84..fb9ebc81dd750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -33,11 +33,14 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext @@ -168,6 +171,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } + case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + } + case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${checkFunction.toString()}" @@ -237,6 +246,25 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be AssertOnQuery(query => { func(query); true }) } + object AwaitEpoch { + def apply(epoch: Long): AssertOnQuery = + Execute { + case s: ContinuousExecution => s.awaitEpoch(epoch) + case _ => throw new IllegalStateException("microbatch cannot await epoch") + } + } + + object IncrementEpoch { + def apply(): AssertOnQuery = + Execute { + case s: ContinuousExecution => + val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + .askSync[Long](IncrementAndGetEpoch) + s.awaitEpoch(newEpoch - 1) + case _ => throw new IllegalStateException("microbatch cannot increment epoch") + } + } + /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. 
@@ -246,7 +274,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be */ def testStream( _stream: Dataset[_], - outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { + outputMode: OutputMode = OutputMode.Append, + useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -259,7 +288,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() @volatile @@ -308,7 +337,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be "" } - def testState = + def testState = { + val sinkDebugString = sink match { + case s: MemorySink => s.toDebugString + case s: MemorySinkV2 => s.toDebugString + } s""" |== Progress == |$testActions @@ -321,12 +354,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == - |${sink.toDebugString} + |$sinkDebugString | | |== Plan == |${if (currentStream != null) currentStream.lastExecution else ""} """.stripMargin + } def verify(condition: => Boolean, message: String): Unit = { if (!condition) { @@ -383,7 +417,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - try if (lastOnly) sink.latestBatchData else sink.allData catch { + val (latestBatchData, allData) = sink match { + case s: MemorySink => (s.latestBatchData, s.allData) + case s: MemorySinkV2 => (s.latestBatchData, s.allData) + } + try if (lastOnly) latestBatchData else allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } @@ -423,6 +461,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be None, Some(metadataRoot), stream, + Map(), sink, outputMode, trigger = trigger, @@ -594,6 +633,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } + case CheckAnswerRowsByFunc(checkFunction, lastOnly) => val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) sparkAnswer.foreach { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala new file mode 100644 index 0000000000000..eda0d8ad48313 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} +import java.nio.channels.ClosedByInterruptException +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} + +import scala.reflect.ClassTag +import scala.util.control.ControlThrowable + +import com.google.common.util.concurrent.UncheckedExecutionException +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.Range +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class ContinuousSuiteBase extends StreamTest { + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + protected def waitForRateSourceTriggers(query: StreamExecution, numTriggers: Int): Unit = { + query match { + case s: ContinuousExecution => + assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") + val reader = s.lastExecution.executedPlan.collectFirst { + case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + }.get + + val deltaMs = numTriggers * 1000 + 300 + while (System.currentTimeMillis < reader.creationTime + deltaMs) { + Thread.sleep(reader.creationTime + deltaMs - System.currentTimeMillis) + } + } + } + + // A continuous trigger that will only fire the initial time for the duration of a test. + // This allows clean testing with manual epoch advancement. 
+ protected val longContinuousTrigger = Trigger.Continuous("1 hour") +} + +class ContinuousSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("basic rate source") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + StopStream, + StartStream(longContinuousTrigger), + AwaitEpoch(2), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StopStream) + } + + test("map") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .map(r => r.getLong(0) * 2) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + } + + test("flatMap") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + } + + test("filter") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .where('value > 5) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + + test("deduplicate") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + .dropDuplicates() + + val except = intercept[AnalysisException] { + testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + } + + assert(except.message.contains( + "Continuous processing does not support Deduplicate operations.")) + } + + test("timestamp") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select(current_timestamp()) + + val except = intercept[AnalysisException] { + testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + } + + assert(except.message.contains( + "Continuous processing does not support current time operations.")) + } + + test("repeatedly restart") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + StopStream, + StartStream(longContinuousTrigger), + StopStream, + StartStream(longContinuousTrigger), + StopStream, + 
StartStream(longContinuousTrigger), + AwaitEpoch(2), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StopStream) + } + + test("query without test harness") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "2") + .option("rowsPerSecond", "2") + .load() + .select('value) + val query = df.writeStream + .format("memory") + .queryName("noharness") + .trigger(Trigger.Continuous(100)) + .start() + val continuousExecution = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution] + continuousExecution.awaitEpoch(0) + waitForRateSourceTriggers(continuousExecution, 2) + query.stop() + + val results = spark.read.table("noharness").collect() + assert(Set(0, 1, 2, 3).map(Row(_)).subsetOf(results.toSet)) + } +} + +class ContinuousStressSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("only one epoch") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 201)), + IncrementEpoch(), + Execute { query => + val data = query.sink.asInstanceOf[MemorySinkV2].allData + val vals = data.map(_.getLong(0)).toSet + assert(scala.Range(0, 25000).forall { i => + vals.contains(i) + }) + }) + } + + test("automatic epoch advancement") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 201)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) + } + + test("restarts") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(10), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(20), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(21), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(22), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(25), + StopStream, + StartStream(Trigger.Continuous(2012)), + StopStream, + StartStream(Trigger.Continuous(2012)), + AwaitEpoch(50), + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) + } +}
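
Not part of the patch: a minimal end-to-end usage sketch of the new code path, assuming a live SparkSession named `spark`. It mirrors the "query without test harness" test above; the `rate` source and `memory` sink are the two implementations this change wires up for continuous mode, and the option names and `Trigger.Continuous` come straight from ContinuousSuite.

    import org.apache.spark.sql.streaming.Trigger

    // With a ContinuousTrigger the memory sink is instantiated as MemorySinkV2, and a sink that
    // implements ContinuousWriteSupport makes StreamingQueryManager start a ContinuousExecution.
    val df = spark.readStream
      .format("rate")                 // implements ContinuousReadSupport
      .option("numPartitions", "2")
      .option("rowsPerSecond", "2")
      .load()
      .select("value")

    val query = df.writeStream
      .format("memory")
      .queryName("continuous_demo")             // queryName is required for the memory sink
      .trigger(Trigger.Continuous("1 second"))  // commit an epoch (checkpoint) about once per second
      .start()

    // Rows accumulate in the named in-memory table until the query is stopped.
    spark.read.table("continuous_demo").show()
    query.stop()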