diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index be3ba751af69..1f997592dbfb 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -511,6 +511,7 @@ private[spark] object LogKeys { case object NUM_ITERATIONS extends LogKey case object NUM_KAFKA_PULLS extends LogKey case object NUM_KAFKA_RECORDS_PULLED extends LogKey + case object NUM_LAGGING_STORES extends LogKey case object NUM_LEADING_SINGULAR_VALUES extends LogKey case object NUM_LEFT_PARTITION_VALUES extends LogKey case object NUM_LOADED_ENTRIES extends LogKey @@ -751,6 +752,9 @@ private[spark] object LogKeys { case object SLEEP_TIME extends LogKey case object SLIDE_DURATION extends LogKey case object SMALLEST_CLUSTER_INDEX extends LogKey + case object SNAPSHOT_EVENT extends LogKey + case object SNAPSHOT_EVENT_TIME_DELTA extends LogKey + case object SNAPSHOT_EVENT_VERSION_DELTA extends LogKey case object SNAPSHOT_VERSION extends LogKey case object SOCKET_ADDRESS extends LogKey case object SOURCE extends LogKey diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 93acb39944fa..c8571a58f9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2332,6 +2332,70 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG = + buildConf("spark.sql.streaming.stateStore.multiplierForMinVersionDiffToLog") + .internal() + .doc( + "Determines the version threshold for logging warnings when a state store falls behind. " + + "The coordinator logs a warning when the store's uploaded snapshot version trails the " + + "query's latest version by the configured number of deltas needed to create a snapshot, " + + "times this multiplier." + ) + .version("4.1.0") + .longConf + .checkValue(k => k >= 1L, "Must be greater than or equal to 1") + .createWithDefault(5L) + + val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG = + buildConf("spark.sql.streaming.stateStore.multiplierForMinTimeDiffToLog") + .internal() + .doc( + "Determines the time threshold for logging warnings when a state store falls behind. " + + "The coordinator logs a warning when the store's uploaded snapshot timestamp trails the " + + "current time by the configured maintenance interval, times this multiplier." + ) + .version("4.1.0") + .longConf + .checkValue(k => k >= 1L, "Must be greater than or equal to 1") + .createWithDefault(10L) + + val STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG = + buildConf("spark.sql.streaming.stateStore.coordinatorReportSnapshotUploadLag") + .internal() + .doc( + "When enabled, the state store coordinator will report state stores whose snapshots " + + "have not been uploaded for some time. See the conf snapshotLagReportInterval for " + + "the minimum time between reports, and the confs multiplierForMinVersionDiffToLog " + + "and multiplierForMinTimeDiffToLog for the logging thresholds."
+ ) + .version("4.1.0") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL = + buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval") + .internal() + .doc( + "The minimum amount of time between the state store coordinator's reports on " + + "state store instances trailing behind in snapshot uploads." + ) + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(5)) + + val STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT = + buildConf("spark.sql.streaming.stateStore.maxLaggingStoresToReport") + .internal() + .doc( + "Maximum number of state stores the coordinator will report as trailing in " + + "snapshot uploads. Stores are selected by how far they lag behind in " + + "snapshot version." + ) + .version("4.1.0") + .intConf + .checkValue(k => k >= 0, "Must be greater than or equal to 0") + .createWithDefault(5) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() @@ -5931,6 +5995,21 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreSkipNullsForStreamStreamJoins: Boolean = getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS) + def stateStoreCoordinatorMultiplierForMinVersionDiffToLog: Long = + getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) + + def stateStoreCoordinatorMultiplierForMinTimeDiffToLog: Long = + getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) + + def stateStoreCoordinatorReportSnapshotUploadLag: Boolean = + getConf(STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG) + + def stateStoreCoordinatorSnapshotLagReportInterval: Long = + getConf(STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + + def stateStoreCoordinatorMaxLaggingStoresToReport: Int = + getConf(STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 6ce6f06de113..6d4a3ecd3603 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] ( with Logging { private[sql] val stateStoreCoordinator = - StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) + StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf) private val listenerBus = new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 246057a5a9d0..8c1e5e901513 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -71,7 +71,8 @@ class IncrementalExecution( MutableMap[Long, Array[Array[String]]] = MutableMap[Long, Array[Array[String]]](), val stateSchemaMetadatas: MutableMap[Long, StateSchemaBroadcast] = MutableMap[Long, StateSchemaBroadcast](), - mode:
CommandExecutionMode.Value = CommandExecutionMode.ALL) + mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, + val isTerminatingTrigger: Boolean = false) extends QueryExecution(sparkSession, logicalPlan, mode = mode) with Logging { // Modified planner with stateful operations. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index c977a499edc0..1dd70ad985cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -858,7 +858,8 @@ class MicroBatchExecution( watermarkPropagator, execCtx.previousContext.isEmpty, currentStateStoreCkptId, - stateSchemaMetadatas) + stateSchemaMetadatas, + isTerminatingTrigger = trigger.isInstanceOf[AvailableNowTrigger.type]) execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d814a86c84c7..dc04ba3331e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream} import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan} import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress} +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent} import org.apache.spark.util.{Clock, Utils} @@ -61,6 +62,12 @@ class ProgressReporter( val noDataProgressEventInterval: Long = sparkSession.sessionState.conf.streamingNoDataProgressEventInterval + val coordinatorReportSnapshotUploadLag: Boolean = + sparkSession.sessionState.conf.stateStoreCoordinatorReportSnapshotUploadLag + + val stateStoreCoordinator: StateStoreCoordinatorRef = + sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + private val timestampFormat = DateTimeFormatter .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 @@ -283,6 +290,17 @@ abstract class ProgressContext( progressReporter.lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() progressReporter.updateProgress(newProgress) + // Ask the state store coordinator to log all lagging state stores + if (progressReporter.coordinatorReportSnapshotUploadLag) { + val latestVersion = lastEpochId + 1 + progressReporter.stateStoreCoordinator + .logLaggingStateStores( + lastExecution.runId, + latestVersion, + lastExecution.isTerminatingTrigger + ) + } + // Update the value since this trigger executes a batch successfully. 
this.execStatsOnLatestExecutedBatch = Some(execStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 648fe0f5b1fd..98d49596d11b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.util -import java.util.Locale +import java.util.{Locale, UUID} import java.util.concurrent.atomic.{AtomicLong, LongAdder} import scala.collection.mutable @@ -551,6 +551,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with val snapshotCurrentVersionMap = readSnapshotFile(version) if (snapshotCurrentVersionMap.isDefined) { synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } + + // Report the loaded snapshot's version to the coordinator + reportSnapshotUploadToCoordinator(version) + return snapshotCurrentVersionMap.get } @@ -580,6 +584,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with } synchronized { putStateIntoStateCacheMap(version, resultMap) } + + // Report the last available snapshot's version to the coordinator + reportSnapshotUploadToCoordinator(lastAvailableVersion) + resultMap } @@ -699,6 +707,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.OP_TYPE, opType)}") // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(version, v)) + // Report the snapshot upload event to the coordinator + reportSnapshotUploadToCoordinator(version) } /** @@ -1043,6 +1053,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keySchema, valueSchema) } + + /** Reports to the coordinator the store's latest snapshot version */ + private def reportSnapshotUploadToCoordinator(version: Long): Unit = { + if (storeConf.reportSnapshotUploadLag) { + // Attach the query run ID and current timestamp to the RPC message + val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) + val currentTimestamp = System.currentTimeMillis() + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), version, currentTimestamp) + ) + } + } } /** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 15df2fae8260..07553f51c60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -64,6 +64,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param stateStoreId StateStoreId for the state store * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system + * @param eventForwarder The RocksDBEventForwarder object for 
reporting events to the coordinator */ class RocksDB( dfsRootDir: String, @@ -73,7 +74,8 @@ class RocksDB( loggingId: String = "", useColumnFamilies: Boolean = false, enableStateStoreCheckpointIds: Boolean = false, - partitionId: Int = 0) extends Logging { + partitionId: Int = 0, + eventForwarder: Option[RocksDBEventForwarder] = None) extends Logging { import RocksDB._ @@ -403,6 +405,9 @@ class RocksDB( // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(latestSnapshotVersion) + openLocalRocksDB(metadata) if (loadedVersion != version) { @@ -480,6 +485,9 @@ class RocksDB( // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(latestSnapshotVersion) + openLocalRocksDB(metadata) if (loadedVersion != version) { @@ -617,6 +625,8 @@ class RocksDB( loadedVersion = -1 // invalidate loaded data throw t } + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(snapshotVersion) this } @@ -1495,6 +1505,8 @@ class RocksDB( log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}") // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) + // Report snapshot upload event to the coordinator. + reportSnapshotUploadToCoordinator(snapshot.version) } finally { snapshot.close() } @@ -1502,6 +1514,16 @@ class RocksDB( fileManagerMetrics } + /** Reports to the coordinator with the event listener that a snapshot finished uploading */ + private def reportSnapshotUploadToCoordinator(version: Long): Unit = { + if (conf.reportSnapshotUploadLag) { + // Note that we still report snapshot versions even when changelog checkpointing is disabled. + // The coordinator needs a way to determine whether upload messages are disabled or not, + // which would be different between RocksDB and HDFS stores due to changelog checkpointing. + eventForwarder.foreach(_.reportSnapshotUploaded(version)) + } + } + /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. 
*/ private def createLogger(): Logger = { val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) { @@ -1768,7 +1790,8 @@ case class RocksDBConf( highPriorityPoolRatio: Double, compressionCodec: String, allowFAllocate: Boolean, - compression: String) + compression: String, + reportSnapshotUploadLag: Boolean) object RocksDBConf { /** Common prefix of all confs in SQLConf that affects RocksDB */ @@ -1951,7 +1974,8 @@ object RocksDBConf { getRatioConf(HIGH_PRIORITY_POOL_RATIO_CONF), storeConf.compressionCodec, getBooleanConf(ALLOW_FALLOCATE_CONF), - getStringConf(COMPRESSION_CONF)) + getStringConf(COMPRESSION_CONF), + storeConf.reportSnapshotUploadLag) } def apply(): RocksDBConf = apply(new StateStoreConf()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 601caaa34290..6a36b8c01519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.state.StateStoreEncoding.Avro import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform @@ -67,7 +67,7 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal) val dataEncoderCacheKey = StateRowEncoderCacheKey( - queryRunId = getRunId(hadoopConf), + queryRunId = StateStoreProvider.getRunId(hadoopConf), operatorId = stateStoreId.operatorId, partitionId = stateStoreId.partitionId, stateStoreName = stateStoreId.storeName, @@ -390,6 +390,8 @@ private[sql] class RocksDBStateStoreProvider this.useColumnFamilies = useColumnFamilies this.stateStoreEncoding = storeConf.stateStoreEncodingFormat this.stateSchemaProvider = stateSchemaProvider + this.rocksDBEventForwarder = + Some(RocksDBEventForwarder(StateStoreProvider.getRunId(hadoopConf), stateStoreId)) if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -399,7 +401,7 @@ private[sql] class RocksDBStateStoreProvider rocksDB // lazy initialization val dataEncoderCacheKey = StateRowEncoderCacheKey( - queryRunId = getRunId(hadoopConf), + queryRunId = StateStoreProvider.getRunId(hadoopConf), operatorId = stateStoreId.operatorId, partitionId = stateStoreId.partitionId, stateStoreName = stateStoreId.storeName, @@ -523,6 +525,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var useColumnFamilies: Boolean = _ @volatile private var stateStoreEncoding: String = _ @volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _ + @volatile private var rocksDBEventForwarder: Option[RocksDBEventForwarder] = _ protected def createRocksDB( dfsRootDir: String, @@ -532,7 +535,8 @@ private[sql] class RocksDBStateStoreProvider loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds: 
Boolean, - partitionId: Int = 0): RocksDB = { + partitionId: Int = 0, + eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { new RocksDB( dfsRootDir, conf, @@ -541,7 +545,8 @@ private[sql] class RocksDBStateStoreProvider loggingId, useColumnFamilies, enableStateStoreCheckpointIds, - partitionId) + partitionId, + eventForwarder) } private[sql] lazy val rocksDB = { @@ -551,7 +556,8 @@ private[sql] class RocksDBStateStoreProvider val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, - useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId) + useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId, + rocksDBEventForwarder) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, @@ -822,16 +828,6 @@ object RocksDBStateStoreProvider { ) } - private def getRunId(hadoopConf: Configuration): String = { - val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) - if (runId != null) { - runId - } else { - assert(Utils.isTesting, "Failed to find query id/batch Id in task context") - UUID.randomUUID().toString - } - } - // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( @@ -991,3 +987,33 @@ class RocksDBStateStoreChangeDataReader( } } } + +/** + * Class used to relay events reported from a RocksDB instance to the state store coordinator. + * + * We pass this into the RocksDB instance to report specific events like snapshot uploads. + * This should only be used to report back to the coordinator for metrics and monitoring purposes. + */ +private[state] case class RocksDBEventForwarder(queryRunId: String, stateStoreId: StateStoreId) { + // Build the state store provider ID from the query run ID and the state store ID + private val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) + + /** + * Callback function from RocksDB to report events to the coordinator. + * Information from the store provider such as the state store ID and query run ID are + * attached here to report back to the coordinator. + * + * @param version The snapshot version that was just uploaded from RocksDB + */ + def reportSnapshotUploaded(version: Long): Unit = { + // Report the state store provider ID and the version to the coordinator + val currentTimestamp = System.currentTimeMillis() + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded( + providerId, + version, + currentTimestamp + ) + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ccb925287e77..63936305c7cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -593,7 +593,15 @@ trait StateStoreProvider { def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = Seq.empty } -object StateStoreProvider { +object StateStoreProvider extends Logging { + + /** + * The state store coordinator reference used to report events such as snapshot uploads from + * the state store providers. 
+ * For all other messages, refer to the coordinator reference in the [[StateStore]] object. + */ + @GuardedBy("this") + private var stateStoreCoordinatorRef: StateStoreCoordinatorRef = _ /** * Return a instance of the given provider class name. The instance will not be initialized. @@ -652,6 +660,47 @@ object StateStoreProvider { } } } + + /** + * Get the runId from the provided hadoopConf. If it is not found, generate a random UUID. + * + * @param hadoopConf Hadoop configuration used by the StateStore to save state data + */ + private[state] def getRunId(hadoopConf: Configuration): String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + UUID.randomUUID().toString + } + } + + /** + * Create the state store coordinator reference which will be reused across state store providers + * in the executor. + * This coordinator reference should only be used to report events from store providers regarding + * snapshot uploads to avoid lock contention with other coordinator RPC messages. + */ + private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { + val env = SparkEnv.get + if (env != null) { + val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in stateStoreCoordinatorRef may have + // become inactive as SparkContext + SparkEnv may have been restarted. Hence, when running in + // driver, always recreate the reference. + if (isDriver || stateStoreCoordinatorRef == null) { + logDebug("Getting StateStoreCoordinatorRef") + stateStoreCoordinatorRef = StateStoreCoordinatorRef.forExecutor(env) + } + logInfo(log"Retrieved reference to StateStoreCoordinator: " + + log"${MDC(LogKeys.STATE_STORE_COORDINATOR, stateStoreCoordinatorRef)}") + Some(stateStoreCoordinatorRef) + } else { + stateStoreCoordinatorRef = null + None + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 807534ee4569..e0450cfc4f69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -97,6 +97,12 @@ class StateStoreConf( val enableStateStoreCheckpointIds = StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf) + /** + * Whether the coordinator is reporting state stores trailing behind in snapshot uploads. + */ + val reportSnapshotUploadLag: Boolean = + sqlConf.stateStoreCoordinatorReportSnapshotUploadLag + /** * Additional configurations related to state store. 
This will capture all configs in * SQLConf that start with `spark.sql.streaming.stateStore.` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 84b77efea3ca..903f27fb2a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -22,9 +22,10 @@ import java.util.UUID import scala.collection.mutable import org.apache.spark.SparkEnv -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ @@ -55,6 +56,45 @@ private case class GetLocation(storeId: StateStoreProviderId) private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage +/** + * This message is used to report a state store has just finished uploading a snapshot, + * along with the timestamp in milliseconds and the snapshot version. + */ +private case class ReportSnapshotUploaded( + providerId: StateStoreProviderId, + version: Long, + timestamp: Long) + extends StateStoreCoordinatorMessage + +/** + * This message is used for the coordinator to look for all state stores that are lagging behind + * in snapshot uploads. The coordinator will then log a warning message for each lagging instance. + */ +private case class LogLaggingStateStores( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean) + extends StateStoreCoordinatorMessage + +/** + * Message used for testing. + * This message is used to retrieve the latest snapshot version reported for upload from a + * specific state store. + */ +private case class GetLatestSnapshotVersionForTesting(providerId: StateStoreProviderId) + extends StateStoreCoordinatorMessage + +/** + * Message used for testing. + * This message is used to retrieve all active state store instances falling behind in + * snapshot uploads, using version and time criteria. 
+ */ +private case class GetLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean) + extends StateStoreCoordinatorMessage + private object StopCoordinator extends StateStoreCoordinatorMessage @@ -66,9 +106,9 @@ object StateStoreCoordinatorRef extends Logging { /** * Create a reference to a [[StateStoreCoordinator]] */ - def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + def forDriver(env: SparkEnv, sqlConf: SQLConf): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator(env.rpcEnv) + val coordinator = new StateStoreCoordinator(env.rpcEnv, sqlConf) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -119,6 +159,46 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } + /** Inform that an executor has uploaded a snapshot */ + private[sql] def snapshotUploaded( + providerId: StateStoreProviderId, + version: Long, + timestamp: Long): Boolean = { + rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(providerId, version, timestamp)) + } + + /** Ask the coordinator to log all state store instances that are lagging behind in uploads */ + private[sql] def logLaggingStateStores( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean): Boolean = { + rpcEndpointRef.askSync[Boolean]( + LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger)) + } + + /** + * Endpoint used for testing. + * Get the latest snapshot version uploaded for a state store. + */ + private[state] def getLatestSnapshotVersionForTesting( + providerId: StateStoreProviderId): Option[Long] = { + rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(providerId)) + } + + /** + * Endpoint used for testing. + * Get the state store instances that are falling behind in snapshot uploads for a particular + * query run. + */ + private[state] def getLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean = false): Seq[StateStoreProviderId] = { + rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( + GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) + ) + } + private[state] def stop(): Unit = { rpcEndpointRef.askSync[Boolean](StopCoordinator) } @@ -129,10 +209,30 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. 
*/ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { +private class StateStoreCoordinator( + override val rpcEnv: RpcEnv, + val sqlConf: SQLConf) + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] + // Stores the latest snapshot upload event for a specific state store + private val stateStoreLatestUploadedSnapshot = + new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + + // Default snapshot upload event to use when a provider has never uploaded a snapshot + private val defaultSnapshotUploadEvent = SnapshotUploadEvent(0, 0) + + // Stores the last timestamp in milliseconds for each queryRunId indicating when the + // coordinator last reported on instances lagging behind on snapshot uploads. + // The initial timestamp is defaulted to 0 milliseconds. + private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] + + private def shouldCoordinatorReportSnapshotLag: Boolean = + sqlConf.stateStoreCoordinatorReportSnapshotUploadLag + + private def coordinatorLagReportInterval: Long = + sqlConf.stateStoreCoordinatorSnapshotLagReportInterval + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => logDebug(s"Reported state store $id is active at $executorId") @@ -164,13 +264,160 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) val storeIdsToRemove = instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove + // Also remove these instances from snapshot upload event tracking + stateStoreLatestUploadedSnapshot --= storeIdsToRemove + // Remove the corresponding run id entry tracking the last lag report time + lastFullSnapshotLagReportTimeMs -= runId logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) + case ReportSnapshotUploaded(providerId, version, timestamp) => + // Ignore this upload event if the registered latest version for the store is more recent, + // since it's possible that an older version gets uploaded after a new executor uploads for + // the same state store but with a newer snapshot. + logDebug(s"Snapshot version $version was uploaded for state store $providerId") + if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) { + stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) + } + context.reply(true) + + case LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger) => + val currentTimestamp = System.currentTimeMillis() + // Only log lagging instances if snapshot lag reporting and uploading are enabled, + // otherwise all instances will be considered lagging. + if (shouldCoordinatorReportSnapshotLag) { + val laggingStores = + findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger) + if (laggingStores.nonEmpty) { + logWarning( + log"StateStoreCoordinator Snapshot Lag Report for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + ) + // Report all stores that are behind in snapshot uploads. + // Only report the list of providers lagging behind if the last reported time + // is not recent for this query run.
The lag report interval denotes the minimum + // time between these full reports. + val timeSinceLastReport = + currentTimestamp - lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L) + if (timeSinceLastReport > coordinatorLagReportInterval) { + // Mark timestamp of the report and log the lagging instances + lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp) + // Only report the stores that are lagging the most behind in snapshot uploads. + laggingStores + .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) + .take(sqlConf.stateStoreCoordinatorMaxLaggingStoresToReport) + .foreach { providerId => + val baseLogMessage = + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}" + + val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { + case Some(snapshotEvent) => + val versionDelta = latestVersion - snapshotEvent.version + val timeDelta = currentTimestamp - snapshotEvent.timestamp + + baseLogMessage + log", " + + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + + log"version delta: " + + log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + baseLogMessage + log", latest snapshot: no upload for query run)" + } + logWarning(logMessage) + } + } + } + } + context.reply(true) + + case GetLatestSnapshotVersionForTesting(providerId) => + val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version) + logDebug(s"Got latest snapshot version of the state store $providerId: $version") + context.reply(version) + + case GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) => + val currentTimestamp = System.currentTimeMillis() + // Only report if snapshot lag reporting is enabled + if (shouldCoordinatorReportSnapshotLag) { + val laggingStores = + findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger) + logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") + context.reply(laggingStores) + } else { + context.reply(Seq.empty) + } + case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered logInfo("StateStoreCoordinator stopped") context.reply(true) } + + private def findLaggingStores( + queryRunId: UUID, + referenceVersion: Long, + referenceTimestamp: Long, + isTerminatingTrigger: Boolean): Seq[StateStoreProviderId] = { + // Determine alert thresholds from configurations for both time and version differences. + val snapshotVersionDeltaMultiplier = + sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog + val maintenanceIntervalMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog + val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot + val maintenanceInterval = sqlConf.streamingMaintenanceInterval + + // Use the configured multipliers multiplierForMinVersionDiffToLog and + // multiplierForMinTimeDiffToLog to determine the proper alert thresholds. + val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot + val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval + + // Look for active state store providers that are lagging behind in snapshot uploads. 
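+ // For example, with this patch's default multipliers (5 for version, 10 for time), and + // assuming Spark's default minDeltasForSnapshot (10) and maintenance interval (60s), a + // store is flagged only when it trails the latest version by more than 50 versions and + // its last upload is more than 600s old.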
+ // The coordinator should only consider providers that are part of this specific query run. + instances.view.keys + .filter(_.queryRunId == queryRunId) + .filter { storeProviderId => + // Stores that didn't upload a snapshot will be treated as a store with a snapshot of + // version 0 and timestamp 0ms. + val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse( + storeProviderId, + defaultSnapshotUploadEvent + ) + // A state store is considered lagging if it's behind in both version and time according + // to the configured thresholds. + val isBehindOnVersions = + referenceVersion - latestSnapshot.version > minVersionDeltaForLogging + val isBehindOnTime = + referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging + // If the query is using a trigger that self-terminates like OneTimeTrigger + // and AvailableNowTrigger, we ignore the time threshold check as the upload frequency + // is not fully dependent on the maintenance interval. + isBehindOnVersions && (isTerminatingTrigger || isBehindOnTime) + }.toSeq + } +} + +case class SnapshotUploadEvent( + version: Long, + timestamp: Long +) extends Ordered[SnapshotUploadEvent] { + + override def compare(otherEvent: SnapshotUploadEvent): Int = { + // Compare by version first, then by timestamp as tiebreaker + val versionCompare = this.version.compare(otherEvent.version) + if (versionCompare == 0) { + this.timestamp.compare(otherEvent.timestamp) + } else { + versionCompare + } + } + + override def toString(): String = { + s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index bffd2c5d9f70..692b5c0ebc3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -63,6 +63,9 @@ public void setUp() { spark = new TestSparkSession(); jsc = new JavaSparkContext(spark.sparkContext()); spark.loadTestData(); + + // Initialize state store coordinator endpoint + spark.streams().stateStoreCoordinator(); } @AfterEach diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala index 4711a45804fb..ebbdd1ad63ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala @@ -251,7 +251,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, - partitionId: Int): RocksDB = { + partitionId: Int, + eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { FailureInjectionRocksDBStateStoreProvider.createRocksDBWithFaultInjection( dfsRootDir, conf, @@ -260,7 +261,8 @@ class FailureInjectionRocksDBStateStoreProvide loggingId, useColumnFamilies, enableStateStoreCheckpointIds, - partitionId) + partitionId, + eventForwarder) } } @@ -277,7 +279,8 @@ object FailureInjectionRocksDBStateStoreProvider { loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds:
Boolean, - partitionId: Int): RocksDB = { + partitionId: Int, + eventForwarder: Option[RocksDBEventForwarder]): RocksDB = { new RocksDB( dfsRootDir, conf = conf, @@ -286,7 +289,8 @@ object FailureInjectionRocksDBStateStoreProvider { loggingId = loggingId, useColumnFamilies = useColumnFamilies, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, - partitionId = partitionId + partitionId = partitionId, + eventForwarder = eventForwarder ) { override def createFileManager( dfsRootDir: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala index 31fc51c4d56f..5c24ec209036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala @@ -523,7 +523,8 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest loggingId = s"[Thread-${Thread.currentThread.getId}]", useColumnFamilies = true, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, - partitionId = 0) + partitionId = 0, + eventForwarder = None) db.load(version, checkpointId) func(db) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 5aea0077e2aa..b13508682188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -52,6 +52,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid before { StateStore.stop() require(!StateStore.isMaintenanceRunning) + spark.streams.stateStoreCoordinator // initialize the lazy coordinator } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 2ebc533f7137..09118edc4357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -26,8 +26,10 @@ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} -import org.apache.spark.sql.functions.count -import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} +import org.apache.spark.sql.functions.{count, expr} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest, Trigger} import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -102,7 +104,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("multiple references have same underlying coordinator") { 
withCoordinatorRef(sc) { coordRef1 => - val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) + val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) @@ -125,7 +127,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator implicit val sqlContext = spark.sqlContext - spark.conf.set(SHUFFLE_PARTITIONS.key, "1") + spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores val inputData = MemoryStream[Int] @@ -155,16 +157,622 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + /** Lists the state store providers used for a test, and the set of lagging partition IDs */ + private val regularStateStoreProviders = Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName, Set.empty[Int]), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName, Set.empty[Int]) + ) + + /** Lists the state store providers used for a test, and the set of lagging partition IDs */ + private val faultyStateStoreProviders = Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + Set(0, 1) + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName, + Set(0, 1) + ) + ) + + private val allStateStoreProviders = + regularStateStoreProviders ++ faultyStateStoreProviders + + /** + * Verifies that snapshot upload RPC messages from state stores are registered and that + * the coordinator detected the correct lagging partitions. + */ + private def verifySnapshotUploadEvents( + coordRef: StateStoreCoordinatorRef, + query: StreamingQuery, + badPartitions: Set[Int], + storeNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)): Unit = { + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + // Verify for every store name listed + storeNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) + if (badPartitions.contains(partitionId)) { + assert(latestSnapshotVersion.getOrElse(0) == 0) + } else { + assert(latestSnapshotVersion.get >= 0) + } + } + } + // Verify that all bad partitions, and only the bad partitions, are marked as lagging. + // Join queries should have all their state stores marked as lagging, + // which would be 4 stores per partition instead of 1.
+ val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == badPartitions.size * storeNames.size) + assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions) + } + + /** Sets up a stateful dropDuplicate query for testing */ + private def setUpStatefulQuery( + inputData: MemoryStream[Int], queryName: String): StreamingQuery = { + // Set up a stateful drop duplicate query + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName(queryName) + .option("checkpointLocation", checkpointLocation.toString) + .start() + query + } + + allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + val inputData = MemoryStream[Int] + val query = setUpStatefulQuery(inputData, "query") + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + // Verify only the partitions in badPartitions are marked as lagging + verifySnapshotUploadEvents(coordRef, query, badPartitions) + query.stop() + } + } + } + + allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = 
rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + // Verify only the partitions in badPartitions are marked as lagging + verifySnapshotUploadEvents(coordRef, query, badPartitions, allJoinStateStoreNames) + query.stop() + } + } + } + + test("SPARK-51358: Verify coordinator properly handles simultaneous query runs") { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start and run two queries together with some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val query1 = setUpStatefulQuery(input1, "query1") + val query2 = setUpStatefulQuery(input2, "query2") + + // Go through several rounds of input to force snapshot uploads for both queries + (0 until 2).foreach { _ => + input1.addData(1, 2, 3) + input2.addData(1, 2, 3) + query1.processAllAvailable() + query2.processAllAvailable() + // Process twice the amount of data for the first query + input1.addData(1, 2, 3) + query1.processAllAvailable() + Thread.sleep(1000) + } + // Verify that the coordinator logged the correct lagging stores for the first query + val streamingQuery1 = query1.asInstanceOf[StreamingQueryWrapper].streamingQuery + val latestVersion1 = streamingQuery1.lastProgress.batchId + 1 + val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, latestVersion1) + + assert(laggingStores1.size == 2) + assert(laggingStores1.forall(_.storeId.partitionId <= 1)) + assert(laggingStores1.forall(_.queryRunId == query1.runId)) + + // Verify that the second query run hasn't reported anything yet due to lack of data + val streamingQuery2 = query2.asInstanceOf[StreamingQueryWrapper].streamingQuery + var latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + var laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + assert(laggingStores2.isEmpty) + + // Process some more data for the second query to force lag reports + input2.addData(1, 2, 3) + query2.processAllAvailable() + Thread.sleep(500) + + // Verify that the coordinator logged the correct lagging stores for the second query + latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + + assert(laggingStores2.size == 2) + assert(laggingStores2.forall(_.storeId.partitionId <= 1)) + assert(laggingStores2.forall(_.queryRunId == query2.runId)) + } + } + + test( + "SPARK-51358: 
Snapshot uploads in RocksDB are not reported if changelog " + + "checkpointing is disabled" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val query = setUpStatefulQuery(inputData, "query") + + // Go through two batches to force two snapshot uploads. + // This would be enough to pass the version check for lagging stores. + inputData.addData(1, 2, 3) + query.processAllAvailable() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Sleep for the duration of a maintenance interval - which should be enough + // to pass the time check for lagging stores. + Thread.sleep(100) + + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 + // Verify that no instances are marked as lagging, even when upload messages are sent. + // Since snapshot uploads are tied to commit, the lack of version difference should prevent + // the stores from being marked as lagging. 
+        assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
+        query.stop()
+    }
+  }
+
+  test("SPARK-51358: Snapshot lag reporting properly detects when all state stores are lagging") {
+    withCoordinatorAndSQLConf(
+      sc,
+      // Only use two partitions with the faulty store provider (both stores will skip uploads)
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+    ) {
+      case (coordRef, spark) =>
+        import spark.implicits._
+        implicit val sqlContext = spark.sqlContext
+        // Start a query and process some data to force snapshot uploads
+        val inputData = MemoryStream[Int]
+        val query = setUpStatefulQuery(inputData, "query")
+
+        // Go through several rounds of input to force snapshot uploads
+        (0 until 3).foreach { _ =>
+          inputData.addData(1, 2, 3)
+          query.processAllAvailable()
+          Thread.sleep(500)
+        }
+        val latestVersion =
+          query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1
+        // Verify that all instances are marked as lagging, since no upload messages are sent
+        assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).size == 2)
+        query.stop()
+    }
+  }
+}
+
+class StateStoreCoordinatorStreamingSuite extends StreamTest {
+  import testImplicits._
+
+  Seq(
+    ("RocksDB", classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName),
+    ("HDFS", classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName)
+  ).foreach { case (providerName, providerClassName) =>
+    test(
+      s"SPARK-51358: Restarting queries do not mark state stores as lagging for $providerName"
+    ) {
+      withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2",
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+        SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5",
+        SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+      ) {
+        withTempDir { srcDir =>
+          val inputData = MemoryStream[Int]
+          val query = inputData.toDF().dropDuplicates()
+          val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)
+          // Keep track of the state checkpoint directory for the second run
+          var stateCheckpoint = ""
+
+          testStream(query)(
+            StartStream(checkpointLocation = srcDir.getCanonicalPath),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
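+            // With minDeltasForSnapshot = 2, five batches should leave the healthy store
+            // with at least one uploaded snapshot, while the faulty partitions 0 and 1
+            // (which skip maintenance) never upload one.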
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            Execute { query =>
+              val coordRef =
+                query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+              stateCheckpoint = query.lastExecution.checkpointLocation
+              val latestVersion = query.lastProgress.batchId + 1
+
+              // Verify the coordinator logged snapshot uploads
+              (0 until numPartitions).foreach { partitionId =>
+                val storeId = StateStoreId(stateCheckpoint, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                if (partitionId <= 1) {
+                  // Verify state stores in partitions 0 and 1 are lagging and didn't upload
+                  assert(
+                    coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0
+                  )
+                } else {
+                  // Verify other stores have uploaded a snapshot and it's properly logged
+                  assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
+                }
+              }
+              // Verify that the normal state store (partitionId=2) is not lagging behind,
+              // and the faulty stores are reported as lagging.
+              val laggingStores =
+                coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+              assert(laggingStores.size == 2)
+              assert(laggingStores.forall(_.storeId.partitionId <= 1))
+            },
+            // Stopping the streaming query should deactivate and clear snapshot upload events
+            StopStream,
+            Execute { query =>
+              val coordRef =
+                query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+              val latestVersion = query.lastProgress.batchId + 1
+
+              // Verify we evicted the previous latest uploaded snapshots from the coordinator
+              (0 until numPartitions).foreach { partitionId =>
+                val storeId = StateStoreId(stateCheckpoint, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty)
+              }
+              // Verify that we are not reporting any lagging stores after eviction,
+              // since none of these state stores are active anymore.
+              assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
+            }
+          )
+          // Restart the query, but do not add too much data, so that we don't associate
+          // the current StateStoreProviderId (store id + query run id) with any new uploads.
+          testStream(query)(
+            StartStream(checkpointLocation = srcDir.getCanonicalPath),
+            // Process one round of data, which is enough to activate instances and force a
+            // lagging instance report, but not enough to trigger a snapshot upload yet.
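+            // (A single batch should create only one new delta file per store, staying
+            // below the minDeltasForSnapshot = 2 threshold configured above.)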
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            Execute { query =>
+              val coordRef =
+                query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+              val latestVersion = query.lastProgress.batchId + 1
+              // Verify that the state stores have restored their snapshot version from the
+              // checkpoint and reported their current version
+              (0 until numPartitions).foreach { partitionId =>
+                val storeId = StateStoreId(stateCheckpoint, 0, partitionId)
+                val providerId = StateStoreProviderId(storeId, query.runId)
+                val latestSnapshotVersion =
+                  coordRef.getLatestSnapshotVersionForTesting(providerId)
+                if (partitionId <= 1) {
+                  // Verify state stores in partitions 0/1 are still lagging and didn't upload
+                  assert(latestSnapshotVersion.getOrElse(0) == 0)
+                } else {
+                  // Verify other stores have uploaded a snapshot and it's properly logged
+                  assert(latestSnapshotVersion.get > 0)
+                }
+              }
+              // Sleep a bit to allow the coordinator to pass the time threshold and report lag
+              Thread.sleep(5 * 100)
+              // Verify that we're reporting the faulty state stores (partitions 0 and 1)
+              val laggingStores =
+                coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+              assert(laggingStores.size == 2)
+              assert(laggingStores.forall(_.storeId.partitionId <= 1))
+            },
+            StopStream
+          )
+        }
+      }
+    }
+  }
+
+  test("SPARK-51358: Restarting queries propagates the updated SQLConf to the coordinator") {
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName,
+      RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "1",
+      SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5",
+      SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+    ) {
+      withTempDir { srcDir =>
+        val inputData = MemoryStream[Int]
+        val query = inputData.toDF().dropDuplicates()
+
+        testStream(query)(
+          StartStream(checkpointLocation = srcDir.getCanonicalPath),
+          // Process multiple batches so the coordinator can start reporting lagging instances
+          AddData(inputData, 1, 2, 3),
+          ProcessAllAvailable(),
+          AddData(inputData, 1, 2, 3),
+          ProcessAllAvailable(),
+          AddData(inputData, 1, 2, 3),
+          ProcessAllAvailable(),
+          Execute { query =>
+            val coordRef =
+              query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+            val latestVersion = query.lastProgress.batchId + 1
+            // Sleep a bit to allow the coordinator to pass the time threshold and report lag
+            Thread.sleep(5 * 100)
+            // Verify that only the faulty stores are reported as lagging
+            val laggingStores =
+              coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
+            assert(laggingStores.size == 2)
+            assert(laggingStores.forall(_.storeId.partitionId <= 1))
+          },
+          // Stopping the streaming query should deactivate and clear snapshot upload events
+          StopStream
+        )
+        // Bump up the version multiplier, which should stop the coordinator from reporting
+        // lagging stores for the next few versions
+        spark.conf
+          .set(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key, "10")
+        // Restart the query and verify that the conf change is reflected in the coordinator
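+        // (With the multiplier raised to 10 and minDeltasForSnapshot = 1, the version gap
+        // after three more batches should stay below the new reporting threshold.)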
+        testStream(query)(
+          StartStream(checkpointLocation = srcDir.getCanonicalPath),
+          // Process the same amount of data as the first run
+          AddData(inputData, 1, 2, 3),
+          ProcessAllAvailable(),
+          AddData(inputData, 1, 2, 3),
+          ProcessAllAvailable(),
+          AddData(inputData, 1, 2, 3),
+          ProcessAllAvailable(),
+          Execute { query =>
+            val coordRef =
+              query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+            val latestVersion = query.lastProgress.batchId + 1
+            // Sleep the same amount to mimic the conditions from the first run
+            Thread.sleep(5 * 100)
+            // Verify that we are not reporting any lagging stores despite restarting,
+            // because of the higher version multiplier
+            assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty)
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  Seq(
+    ("RocksDB", classOf[RocksDBStateStoreProvider].getName),
+    ("HDFS", classOf[HDFSBackedStateStoreProvider].getName)
+  ).foreach { case (providerName, providerClassName) =>
+    test(
+      s"SPARK-51358: Infrequent maintenance with $providerName using Trigger.AvailableNow " +
+        s"should be reported"
+    ) {
+      withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
+        SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true",
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2",
+        SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50",
+        SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
+      ) {
+        withTempDir { srcDir =>
+          val inputData = MemoryStream[Int]
+          val query = inputData.toDF().dropDuplicates()
+
+          // Populate state stores with an initial snapshot, so that the upload timestamp
+          // isn't left at the default of 0 ms.
+          testStream(query)(
+            StartStream(checkpointLocation = srcDir.getCanonicalPath),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable()
+          )
+          // Increase the maintenance interval to a much larger value to stop snapshot uploads
+          spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000")
+          // Execute a few batches in a short span
+          testStream(query)(
+            AddData(inputData, 1, 2, 3),
+            StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath),
+            Execute { query =>
+              query.awaitTermination()
+              // Verify the query ran with the AvailableNow trigger
+              assert(query.lastExecution.isTerminatingTrigger)
+            },
+            AddData(inputData, 1, 2, 3),
+            StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath),
+            Execute { query =>
+              query.awaitTermination()
+            },
+            // Start without AvailableNow, otherwise the stream closes too quickly for the
+            // testing RPC call to report lagging state stores
+            StartStream(checkpointLocation = srcDir.getCanonicalPath),
+            // Process data to activate state stores, but not enough to trigger snapshot uploads
+            AddData(inputData, 1, 2, 3),
+            ProcessAllAvailable(),
+            Execute { query =>
+              val coordRef =
+                query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator
+              val latestVersion = query.lastProgress.batchId + 1
+              // Verify that all stores are reported as lagging despite the short burst,
+              // since the long maintenance interval prevented any new snapshot uploads.
+              // This test scenario mimics cases where snapshots have not been uploaded for
+              // a while due to the short running duration of AvailableNow.
+              val laggingStores = coordRef.getLaggingStoresForTesting(
+                query.runId,
+                latestVersion,
+                isTerminatingTrigger = true
+              )
+              assert(laggingStores.size == 2)
+              assert(laggingStores.forall(_.storeId.partitionId <= 1))
+            },
+            StopStream
+          )
+        }
+      }
+    }
+  }
 }
 
 object StateStoreCoordinatorSuite {
   def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = {
     var coordinatorRef: StateStoreCoordinatorRef = null
     try {
-      coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env)
+      coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf)
       body(coordinatorRef)
     } finally {
       if (coordinatorRef != null) coordinatorRef.stop()
     }
   }
+
+  def withCoordinatorAndSQLConf(sc: SparkContext, pairs: (String, String)*)(
+      body: (StateStoreCoordinatorRef, SparkSession) => Unit): Unit = {
+    var spark: SparkSession = null
+    var coordinatorRef: StateStoreCoordinatorRef = null
+    try {
+      spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      SparkSession.setActiveSession(spark)
+      coordinatorRef = spark.streams.stateStoreCoordinator
+      // Set up SQLConf entries
+      pairs.foreach { case (key, value) => spark.conf.set(key, value) }
+      body(coordinatorRef, spark)
+    } finally {
+      SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop()))
+      // Unset all custom SQLConf entries
+      if (spark != null) pairs.foreach { case (key, _) => spark.conf.unset(key) }
+      if (coordinatorRef != null) coordinatorRef.stop()
+      StateStore.stop()
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 8af42d6dec26..093e8b991cc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -427,6 +427,7 @@ abstract class StateVariableSuiteBase extends SharedSparkSession
   before {
     StateStore.stop()
     require(!StateStore.isMaintenanceRunning)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
   }
 
   after {