From 958b491ffdfaec86d73002bc3a7ffc00d4e6f95b Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 28 Feb 2025 16:48:11 -0800 Subject: [PATCH 01/36] SPARK-51358 Introduce snapshot upload lag detection through StateStoreCoordinator --- .../apache/spark/sql/internal/SQLConf.scala | 13 + .../sql/classic/StreamingQueryManager.scala | 2 +- .../execution/streaming/state/RocksDB.scala | 13 + .../state/RocksDBStateStoreProvider.scala | 28 +- .../streaming/state/StateStore.scala | 6 + .../state/StateStoreCoordinator.scala | 120 +++++++- .../state/StateStoreCoordinatorSuite.scala | 288 +++++++++++++++++- 7 files changed, 459 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7e161fb9b7ab..e73bf4d5f2d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2249,6 +2249,19 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG = + buildConf("spark.sql.streaming.stateStore.minSnapshotVersionDeltaToLog") + .internal() + .doc( + "Minimum number of versions between the most recent uploaded snapshot version of a " + + "single state store instance and the most recent version across all state store " + + "instances to log a warning message." + ) + .version("4.0.0") + .intConf + .checkValue(k => k >= 0, "Must be greater than or equal to 0") + .createWithDefault(30) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 6ce6f06de113..6d4a3ecd3603 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] ( with Logging { private[sql] val stateStoreCoordinator = - StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) + StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf) private val listenerBus = new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 820322d1e0ee..09b354643fa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -134,6 +134,9 @@ class RocksDB( rocksDbOptions.setStatistics(new Statistics()) private val nativeStats = rocksDbOptions.statistics() + // Stores a StateStoreProvider reference for event callback such as snapshot upload reports + private var providerListener: Option[RocksDBEventListener] = None + private val workingDir = createTempDir("workingDir") private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"), hadoopConf, conf.compressionCodec, loggingId = loggingId) @@ -197,6 +200,11 @@ class RocksDB( @GuardedBy("acquireLock") private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false) + /** Attaches a 
RocksDBStateStoreProvider reference to the RocksDB instance for event callback. */ + def setListener(listener: RocksDBEventListener): Unit = { + providerListener = Some(listener) + } + private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = { colFamilyNameToInfoMap.get(cfName) } @@ -1467,6 +1475,11 @@ class RocksDB( log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " + log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}") lastUploadedSnapshotVersion.set(snapshot.version) + // Report to coordinator that the snapshot has been uploaded when + // changelog checkpointing is enabled, since that is when stores can lag behind. + if(enableChangelogCheckpointing) { + providerListener.foreach(_.reportSnapshotUploaded(snapshot.version)) + } } finally { snapshot.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 47721cea4359..ff86ed6c3782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -38,9 +38,14 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform import org.apache.spark.util.{NonFateSharingCache, Utils} +/** Trait representing the different events reported from RocksDB instance */ +trait RocksDBEventListener { + def reportSnapshotUploaded(version: Long): Unit +} + private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable - with SupportsFineGrainedReplay { + with SupportsFineGrainedReplay with RocksDBEventListener { import RocksDBStateStoreProvider._ class RocksDBStateStore(lastVersion: Long) extends StateStore { @@ -392,6 +397,10 @@ private[sql] class RocksDBStateStoreProvider rocksDB // lazy initialization + // Give the RocksDB instance a reference to this provider so it can call back to report + // specific events like snapshot uploads + rocksDB.setListener(this) + val dataEncoderCacheKey = StateRowEncoderCacheKey( queryRunId = getRunId(hadoopConf), operatorId = stateStoreId.operatorId, @@ -644,6 +653,23 @@ private[sql] class RocksDBStateStoreProvider throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName) } } + + /** Callback function from RocksDB to report events to the coordinator. + * Additional information such as state store ID and query run ID are populated here + * to report back to the coordinator. 
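+   * The provider registers itself as the listener on its RocksDB instance through
+   * `rocksDB.setListener(this)` when the store is initialized, and RocksDB invokes this
+   * callback only after a successful snapshot upload and only while changelog checkpointing
+   * is enabled.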
+ * + * @param version The snapshot version that was just uploaded from RocksDB + */ + def reportSnapshotUploaded(version: Long): Unit = { + // Collect the state store ID and query run ID to report back to the coordinator + StateStore.reportSnapshotUploaded( + StateStoreProviderId( + stateStoreId, + UUID.fromString(getRunId(hadoopConf)) + ), + version + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 09acc24aff98..9c032bda7002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1126,6 +1126,12 @@ object StateStore extends Logging { } } + def reportSnapshotUploaded(storeProviderId: StateStoreProviderId, snapshotVersion: Long): Unit = { + // Send current timestamp of uploaded snapshot as well + val currentTime = System.currentTimeMillis() + coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime)) + } + private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 84b77efea3ca..7484fd73f251 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ @@ -55,6 +56,15 @@ private case class GetLocation(storeId: StateStoreProviderId) private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage +private case class SnapshotUploaded(storeId: StateStoreProviderId, version: Long, timestamp: Long) + extends StateStoreCoordinatorMessage + +private case class GetLatestSnapshotVersion(storeId: StateStoreProviderId) + extends StateStoreCoordinatorMessage + +private case class GetLaggingStores() + extends StateStoreCoordinatorMessage + private object StopCoordinator extends StateStoreCoordinatorMessage @@ -66,9 +76,9 @@ object StateStoreCoordinatorRef extends Logging { /** * Create a reference to a [[StateStoreCoordinator]] */ - def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + def forDriver(env: SparkEnv, conf: SQLConf): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator(env.rpcEnv) + val coordinator = new StateStoreCoordinator(env.rpcEnv, conf) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -119,6 +129,25 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } + /** Inform that an executor has uploaded a snapshot */ + private[sql] def 
snapshotUploaded( + storeProviderId: StateStoreProviderId, + version: Long, + timestamp: Long): Unit = { + rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeProviderId, version, timestamp)) + } + + /** Get the latest snapshot version uploaded for a state store */ + private[sql] def getLatestSnapshotVersion( + stateStoreProviderId: StateStoreProviderId): Option[Long] = { + rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersion(stateStoreProviderId)) + } + + /** Get the state store instances that are falling behind in snapshot uploads */ + private[sql] def getLaggingStores(): Seq[StateStoreProviderId] = { + rpcEndpointRef.askSync[Seq[StateStoreProviderId]](GetLaggingStores) + } + private[state] def stop(): Unit = { rpcEndpointRef.askSync[Boolean](StopCoordinator) } @@ -129,10 +158,17 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. */ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { +private class StateStoreCoordinator( + override val rpcEnv: RpcEnv, + val sqlConf: SQLConf) + extends ThreadSafeRpcEndpoint + with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] + // Stores the latest snapshot version of a specific state store provider instance + private val stateStoreSnapshotVersions = + new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => logDebug(s"Reported state store $id is active at $executorId") @@ -168,9 +204,85 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) storeIdsToRemove.mkString(", ")) context.reply(true) + case SnapshotUploaded(providerId, version, timestamp) => + stateStoreSnapshotVersions.put(providerId, SnapshotUploadEvent(version, timestamp)) + logDebug(s"Snapshot uploaded at ${providerId} with version ${version}") + // Report all stores that are behind in snapshot uploads + val (laggingStores, latestSnapshot) = findLaggingStores() + if (laggingStores.nonEmpty) { + logWarning(s"Number of state stores falling behind: ${laggingStores.size}") + laggingStores.foreach { storeProviderId => + val snapshotEvent = + stateStoreSnapshotVersions.getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0)) + logWarning( + s"State store falling behind $storeProviderId " + + s"(current: $snapshotEvent, latest: $latestSnapshot)" + ) + } + } + context.reply(true) + + case GetLatestSnapshotVersion(providerId) => + val version = stateStoreSnapshotVersions.get(providerId).map(_.version) + logDebug(s"Got latest snapshot version of the state store $providerId: $version") + context.reply(version) + + case GetLaggingStores => + val (laggingStores, _) = findLaggingStores() + logDebug(s"Got lagging state stores ${laggingStores + .map( + id => + s"StateStoreId(operatorId=${id.storeId.operatorId}, " + + s"partitionId=${id.storeId.partitionId}, " + + s"storeName=${id.storeId.storeName})" + ) + .mkString(", ")}") + context.reply(laggingStores) + case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered logInfo("StateStoreCoordinator stopped") context.reply(true) } + + case class SnapshotUploadEvent( + version: Long, + timestamp: Long + ) extends Ordered[SnapshotUploadEvent] { + def 
isLagging(latest: SnapshotUploadEvent): Boolean = { + val versionDelta = latest.version - version + val timeDelta = latest.timestamp - timestamp + val minVersionDeltaForLogging = + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) + // Use 10 times the maintenance interval as the minimum time delta for logging + val minTimeDeltaForLogging = 10 * sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) + + versionDelta >= minVersionDeltaForLogging || + (version >= 0 && timeDelta > minTimeDeltaForLogging) + } + + override def compare(that: SnapshotUploadEvent): Int = { + this.version.compare(that.version) + } + + override def toString(): String = { + s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" + } + } + + private def findLaggingStores(): (Seq[StateStoreProviderId], SnapshotUploadEvent) = { + // Find the most updated instance to use as reference point + val latestSnapshot = instances + .map( + instance => stateStoreSnapshotVersions.getOrElse(instance._1, SnapshotUploadEvent(-1, 0)) + ) + .max + // Look for instances that are lagging behind in snapshot uploads + val laggingStores = instances.keys.filter { storeProviderId => + stateStoreSnapshotVersions + .getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0)) + .isLagging(latestSnapshot) + }.toSeq + (laggingStores, latestSnapshot) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 2ebc533f7137..fe9ce7452e3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} -import org.apache.spark.sql.functions.count -import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} +import org.apache.spark.sql.functions.{count, expr} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -102,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("multiple references have same underlying coordinator") { withCoordinatorRef(sc) { coordRef1 => - val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) @@ -125,7 +126,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator implicit val sqlContext = spark.sqlContext - spark.conf.set(SHUFFLE_PARTITIONS.key, "1") + spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores val inputData = MemoryStream[Int] @@ -155,16 +156,293 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStore.stop() } } + + test("snapshot uploads in RocksDB are not reported if 
changelog checkpointing is disabled") { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + inputData.addData(1, 2, 3) + query.processAllAvailable() + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + + // Verify stores do not report snapshot upload events to the coordinator. + // As a result, all stores will return nothing as the latest version + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => + val providerId = + StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) + assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) + } + query.stop() + } + } + + test("snapshot uploads in RocksDB are properly reported to the coordinator") { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + inputData.addData(1, 2, 3) + query.processAllAvailable() + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => + val providerId = + StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) + assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStores().isEmpty) + query.stop() + } + } + + test( + "snapshot uploads in RocksDBSkipMaintenanceOnCertainPartitionsProvider are properly " + + "reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + 
SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + inputData.addData(1, 2, 3) + query.processAllAvailable() + inputData.addData(1, 2, 3) + query.processAllAvailable() + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => + val providerId = + StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) + if(partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and did not upload anything + assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + } + } + // We should have two state stores (id 0 and 1) that are lagging behind at this point + val laggingStores = coordRef.getLaggingStores() + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + query.stop() + } + } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + test( + "snapshot uploads for join queries with RocksDBStateStoreProvider are properly " + + "reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + (0 until 5).foreach { _ => + input1.addData(1, 5) + query.processAllAvailable() + input2.addData(1, 5, 10) + query.processAllAvailable() + } + val stateCheckpointDir = + 
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => + allJoinStateStoreNames.foreach { storeName => + val providerId = + StateStoreProviderId( + StateStoreId(stateCheckpointDir, 0, partitionId, storeName), + query.runId + ) + assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + } + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStores().isEmpty) + query.stop() + } + } + + test( + "snapshot uploads for join queries with RocksDBSkipMaintenanceOnCertainPartitionsProvider " + + "are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + (0 until 5).foreach { _ => + input1.addData(1, 5) + query.processAllAvailable() + input2.addData(1, 5, 10) + query.processAllAvailable() + } + val stateCheckpointDir = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => + allJoinStateStoreNames.foreach { storeName => + val providerId = + StateStoreProviderId( + StateStoreId(stateCheckpointDir, 0, partitionId, storeName), + query.runId + ) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and did not upload anything + assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + } + } + } + // Verify that only stores from partition id 0 and 1 are lagging behind. + // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. 
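+        // (allJoinStateStoreNames yields four stores per partition: the keyToNumValues and
+        // keyWithIndexToValue stores for each of the left and right join sides.)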
+ val laggingStores = coordRef.getLaggingStores() + assert(laggingStores.size == 2 * 4) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + } + } } object StateStoreCoordinatorSuite { def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { var coordinatorRef: StateStoreCoordinatorRef = null try { - coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) body(coordinatorRef) } finally { if (coordinatorRef != null) coordinatorRef.stop() } } + + def withCoordinatorAndSQLConf(sc: SparkContext, pairs: (String, String)*)( + body: (StateStoreCoordinatorRef, SparkSession) => Unit): Unit = { + var coordinatorRef: StateStoreCoordinatorRef = null + try { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + SparkSession.setActiveSession(spark) + coordinatorRef = spark.streams.stateStoreCoordinator + // Set up SQLConf entries + pairs.foreach { case (key, value) => spark.conf.set(key, value) } + body(coordinatorRef, spark) + } finally { + SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + if (coordinatorRef != null) coordinatorRef.stop() + StateStore.stop() + } + } } From 7ffadd8e89bfa09c4ca371b4957bf182b19afaf6 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 28 Feb 2025 20:01:50 -0800 Subject: [PATCH 02/36] SPARK-51358 Make test less flaky --- .../state/StateStoreCoordinatorSuite.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index fe9ce7452e3e..7c4defdfead8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -157,7 +157,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } - test("snapshot uploads in RocksDB are not reported if changelog checkpointing is disabled") { + test( + "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " + + "checkpointing is disabled" + ) { withCoordinatorAndSQLConf( sc, SQLConf.SHUFFLE_PARTITIONS.key -> "5", @@ -198,7 +201,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } - test("snapshot uploads in RocksDB are properly reported to the coordinator") { + test("SPARK-51358: Snapshot uploads in RocksDB are properly reported to the coordinator") { withCoordinatorAndSQLConf( sc, SQLConf.SHUFFLE_PARTITIONS.key -> "5", @@ -242,8 +245,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } test( - "snapshot uploads in RocksDBSkipMaintenanceOnCertainPartitionsProvider are properly " + - "reported to the coordinator" + "SPARK-51358: Snapshot uploads in RocksDBSkipMaintenanceOnCertainPartitionsProvider " + + "are properly reported to the coordinator" ) { withCoordinatorAndSQLConf( sc, @@ -299,17 +302,17 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) test( - "snapshot uploads for join queries with RocksDBStateStoreProvider are properly " + - "reported to the coordinator" + "SPARK-51358: Snapshot uploads for join queries with RocksDBStateStoreProvider " + + "are 
properly reported to the coordinator" ) { withCoordinatorAndSQLConf( sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "50", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4" ) { case (coordRef, spark) => import spark.implicits._ @@ -354,18 +357,18 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } test( - "snapshot uploads for join queries with RocksDBSkipMaintenanceOnCertainPartitionsProvider " + - "are properly reported to the coordinator" + "SPARK-51358: Snapshot uploads for join queries with " + + "RocksDBSkipMaintenanceOnCertainPartitionsProvider are properly reported to the coordinator" ) { withCoordinatorAndSQLConf( sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "50", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4" ) { case (coordRef, spark) => import spark.implicits._ From 3c6a5f93a0d8387d777a6d5a7cb8ef8b9c638037 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 4 Mar 2025 07:58:49 -0800 Subject: [PATCH 03/36] SPARK-51358 Update logging and event listener init --- .../execution/streaming/state/RocksDB.scala | 10 ++--- .../state/RocksDBStateStoreProvider.scala | 8 +++- .../state/StateStoreCoordinator.scala | 39 +++++++++++++++---- .../state/StateStoreCoordinatorSuite.scala | 6 ++- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 09b354643fa4..2d210f87d433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -65,6 +65,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs + * @param providerListener A reference to the state store provider for event callback reporting */ class RocksDB( dfsRootDir: String, @@ -73,7 +74,9 @@ class RocksDB( hadoopConf: Configuration = new Configuration, loggingId: String = "", useColumnFamilies: Boolean = false, - enableStateStoreCheckpointIds: Boolean = false) extends Logging { + enableStateStoreCheckpointIds: Boolean = false, + providerListener: Option[RocksDBEventListener] = None) + extends Logging { import RocksDB._ @@ -134,9 +137,6 @@ class 
RocksDB( rocksDbOptions.setStatistics(new Statistics()) private val nativeStats = rocksDbOptions.statistics() - // Stores a StateStoreProvider reference for event callback such as snapshot upload reports - private var providerListener: Option[RocksDBEventListener] = None - private val workingDir = createTempDir("workingDir") private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"), hadoopConf, conf.compressionCodec, loggingId = loggingId) @@ -1477,7 +1477,7 @@ class RocksDB( lastUploadedSnapshotVersion.set(snapshot.version) // Report to coordinator that the snapshot has been uploaded when // changelog checkpointing is enabled, since that is when stores can lag behind. - if(enableChangelogCheckpointing) { + if (enableChangelogCheckpointing) { providerListener.foreach(_.reportSnapshotUploaded(snapshot.version)) } } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index ff86ed6c3782..409a050b1d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -38,7 +38,11 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform import org.apache.spark.util.{NonFateSharingCache, Utils} -/** Trait representing the different events reported from RocksDB instance */ +/** + * Trait representing the different events reported from RocksDB instance. + * Gives the RocksDB instance a reference to this provider so it can call back to report + * specific events like snapshot uploads. + */ trait RocksDBEventListener { def reportSnapshotUploaded(version: Long): Unit } @@ -534,7 +538,7 @@ private[sql] class RocksDBStateStoreProvider val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, - useColumnFamilies, storeConf.enableStateStoreCheckpointIds) + useColumnFamilies, storeConf.enableStateStoreCheckpointIds, Some(this)) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 7484fd73f251..5cc22e1969a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -59,9 +59,19 @@ private case class DeactivateInstances(runId: UUID) private case class SnapshotUploaded(storeId: StateStoreProviderId, version: Long, timestamp: Long) extends StateStoreCoordinatorMessage +/** + * Message used for testing. + * This message is used to retrieve the latest snapshot version reported for upload from a + * specific state store instance. + */ private case class GetLatestSnapshotVersion(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage +/** + * Message used for testing. 
+ * This message is used to retrieve the all active state store instance falling behind in + * snapshot uploads, whether it is through version or time criteria. + */ private case class GetLaggingStores() extends StateStoreCoordinatorMessage @@ -134,16 +144,23 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { storeProviderId: StateStoreProviderId, version: Long, timestamp: Long): Unit = { + logWarning(s"ZEYU: snapshotUploaded rpc endoint ref") rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeProviderId, version, timestamp)) } - /** Get the latest snapshot version uploaded for a state store */ + /** + * Endpoint used for testing. + * Get the latest snapshot version uploaded for a state store. + */ private[sql] def getLatestSnapshotVersion( stateStoreProviderId: StateStoreProviderId): Option[Long] = { rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersion(stateStoreProviderId)) } - /** Get the state store instances that are falling behind in snapshot uploads */ + /** + * Endpoint used for testing. + * Get the state store instances that are falling behind in snapshot uploads. + */ private[sql] def getLaggingStores(): Seq[StateStoreProviderId] = { rpcEndpointRef.askSync[Seq[StateStoreProviderId]](GetLaggingStores) } @@ -161,13 +178,14 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private class StateStoreCoordinator( override val rpcEnv: RpcEnv, val sqlConf: SQLConf) - extends ThreadSafeRpcEndpoint + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] // Stores the latest snapshot version of a specific state store provider instance private val stateStoreSnapshotVersions = new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + private var lastSnapshotUploadEvent: Option[SnapshotUploadEvent] = None override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => @@ -210,16 +228,19 @@ private class StateStoreCoordinator( // Report all stores that are behind in snapshot uploads val (laggingStores, latestSnapshot) = findLaggingStores() if (laggingStores.nonEmpty) { - logWarning(s"Number of state stores falling behind: ${laggingStores.size}") + logWarning(s"StateStoreCoordinator Snapshot Lag - " + + s"Number of state stores falling behind: ${laggingStores.size}" + + s"(Last upload: ${lastSnapshotUploadEvent.getOrElse(SnapshotUploadEvent(-1, 0))})") laggingStores.foreach { storeProviderId => val snapshotEvent = stateStoreSnapshotVersions.getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0)) logWarning( - s"State store falling behind $storeProviderId " + + s"StateStoreCoordinator Snapshot Lag - State store falling behind $storeProviderId " + s"(current: $snapshotEvent, latest: $latestSnapshot)" ) } } + lastSnapshotUploadEvent = Some(latestSnapshot) context.reply(true) case GetLatestSnapshotVersion(providerId) => @@ -258,7 +279,7 @@ private class StateStoreCoordinator( val minTimeDeltaForLogging = 10 * sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) versionDelta >= minVersionDeltaForLogging || - (version >= 0 && timeDelta > minTimeDeltaForLogging) + (version >= 0 && timeDelta > minTimeDeltaForLogging) } override def compare(that: SnapshotUploadEvent): Int = { @@ -271,12 +292,14 @@ private class StateStoreCoordinator( } private def findLaggingStores(): (Seq[StateStoreProviderId], SnapshotUploadEvent) = { + if (instances.isEmpty) { + return 
(Seq.empty, SnapshotUploadEvent(-1, 0)) + } // Find the most updated instance to use as reference point val latestSnapshot = instances .map( instance => stateStoreSnapshotVersions.getOrElse(instance._1, SnapshotUploadEvent(-1, 0)) - ) - .max + ).max // Look for instances that are lagging behind in snapshot uploads val laggingStores = instances.keys.filter { storeProviderId => stateStoreSnapshotVersions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 7c4defdfead8..bccee854130a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -225,6 +225,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() + // Add and commit data multiple times to force new snapshot versions inputData.addData(1, 2, 3) query.processAllAvailable() inputData.addData(1, 2, 3) @@ -272,6 +273,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() + // Add and commit data multiple times to force new snapshot versions inputData.addData(1, 2, 3) query.processAllAvailable() inputData.addData(1, 2, 3) @@ -282,7 +284,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) - if(partitionId <= 1) { + if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and did not upload anything assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) } else { @@ -330,6 +332,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() + // Add and commit data multiple times to force new snapshot versions (0 until 5).foreach { _ => input1.addData(1, 5) query.processAllAvailable() @@ -386,6 +389,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() + // Add and commit data multiple times to force new snapshot versions (0 until 5).foreach { _ => input1.addData(1, 5) query.processAllAvailable() From 6056856e7c490b3670c17f403c99aa3bf0ab9b1f Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 4 Mar 2025 08:30:33 -0800 Subject: [PATCH 04/36] SPARK-51358 Remove log --- .../sql/execution/streaming/state/StateStoreCoordinator.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 5cc22e1969a6..0e8fa14a1c25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -144,7 +144,6 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: 
RpcEndpointRef) { storeProviderId: StateStoreProviderId, version: Long, timestamp: Long): Unit = { - logWarning(s"ZEYU: snapshotUploaded rpc endoint ref") rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeProviderId, version, timestamp)) } From 41eaba40c3c6a6f545f3e3addded34f9915ec337 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 4 Mar 2025 08:49:27 -0800 Subject: [PATCH 05/36] SPARK-51358 Remove setListener --- .../apache/spark/sql/execution/streaming/state/RocksDB.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 2d210f87d433..0e70c02da606 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -200,11 +200,6 @@ class RocksDB( @GuardedBy("acquireLock") private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false) - /** Attaches a RocksDBStateStoreProvider reference to the RocksDB instance for event callback. */ - def setListener(listener: RocksDBEventListener): Unit = { - providerListener = Some(listener) - } - private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = { colFamilyNameToInfoMap.get(cfName) } From 4117326f06b23bd34e52e104da86677bfb78473e Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 4 Mar 2025 09:07:11 -0800 Subject: [PATCH 06/36] SPARK-51358 Remove setListener call --- .../execution/streaming/state/RocksDBStateStoreProvider.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 409a050b1d10..0c740db33912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -401,10 +401,6 @@ private[sql] class RocksDBStateStoreProvider rocksDB // lazy initialization - // Give the RocksDB instance a reference to this provider so it can call back to report - // specific events like snapshot uploads - rocksDB.setListener(this) - val dataEncoderCacheKey = StateRowEncoderCacheKey( queryRunId = getRunId(hadoopConf), operatorId = stateStoreId.operatorId, From d039f7356886eb8ee7c30debefd1ed7800c251d3 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 6 Mar 2025 15:16:17 -0800 Subject: [PATCH 07/36] SPARK-51358 Add additional detail to docstring --- .../streaming/state/RocksDBStateStoreProvider.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 0c740db33912..2d386fcf69c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -39,9 +39,11 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util.{NonFateSharingCache, Utils} /** - * Trait representing the different events reported from RocksDB instance. 
- * Gives the RocksDB instance a reference to this provider so it can call back to report - * specific events like snapshot uploads. + * Trait representing events reported from a RocksDB instance. + * + * The internal RocksDB instance can use a provider with a `RocksDBEventListener` reference to + * report specific events like snapshot uploads. This should only be used to report back to the + * coordinator for metrics and monitoring purposes. */ trait RocksDBEventListener { def reportSnapshotUploaded(version: Long): Unit From cf6da39bccc4ea6c707fd098b7f6da026447ba4d Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 7 Mar 2025 15:05:44 -0800 Subject: [PATCH 08/36] SPARK-51358 Fix StateStoreSuite and use log interpolator --- .../org/apache/spark/internal/LogKey.scala | 4 +++ .../state/StateStoreCoordinator.scala | 32 ++++++++++++------- .../state/RocksDBStateStoreSuite.scala | 1 + .../state/StateStoreCoordinatorSuite.scala | 2 -- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index 318f32c52b90..5913d3edadc5 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -510,6 +510,7 @@ private[spark] object LogKeys { case object NUM_ITERATIONS extends LogKey case object NUM_KAFKA_PULLS extends LogKey case object NUM_KAFKA_RECORDS_PULLED extends LogKey + case object NUM_LAGGING_STORES extends LogKey case object NUM_LEADING_SINGULAR_VALUES extends LogKey case object NUM_LEFT_PARTITION_VALUES extends LogKey case object NUM_LOADED_ENTRIES extends LogKey @@ -749,6 +750,9 @@ private[spark] object LogKeys { case object SLEEP_TIME extends LogKey case object SLIDE_DURATION extends LogKey case object SMALLEST_CLUSTER_INDEX extends LogKey + case object SNAPSHOT_EVENT extends LogKey + case object SNAPSHOT_EVENT_TIME_DELTA extends LogKey + case object SNAPSHOT_EVENT_VERSION_DELTA extends LogKey case object SNAPSHOT_VERSION extends LogKey case object SOCKET_ADDRESS extends LogKey case object SOURCE extends LogKey diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 0e8fa14a1c25..8769d247664f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -22,7 +22,7 @@ import java.util.UUID import scala.collection.mutable import org.apache.spark.SparkEnv -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.internal.SQLConf @@ -184,7 +184,6 @@ private class StateStoreCoordinator( // Stores the latest snapshot version of a specific state store provider instance private val stateStoreSnapshotVersions = new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] - private var lastSnapshotUploadEvent: Option[SnapshotUploadEvent] = None override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => @@ -227,19 +226,28 @@ private class 
StateStoreCoordinator( // Report all stores that are behind in snapshot uploads val (laggingStores, latestSnapshot) = findLaggingStores() if (laggingStores.nonEmpty) { - logWarning(s"StateStoreCoordinator Snapshot Lag - " + - s"Number of state stores falling behind: ${laggingStores.size}" + - s"(Last upload: ${lastSnapshotUploadEvent.getOrElse(SnapshotUploadEvent(-1, 0))})") + logWarning( + log"StateStoreCoordinator Snapshot Lag - Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)} " + + log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)})" + ) laggingStores.foreach { storeProviderId => - val snapshotEvent = - stateStoreSnapshotVersions.getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0)) - logWarning( - s"StateStoreCoordinator Snapshot Lag - State store falling behind $storeProviderId " + - s"(current: $snapshotEvent, latest: $latestSnapshot)" - ) + val logMessage = stateStoreSnapshotVersions.get(storeProviderId) match { + case Some(snapshotEvent) => + val versionDelta = latestSnapshot.version - snapshotEvent.version + val timeDelta = latestSnapshot.timestamp - snapshotEvent.timestamp + + log"StateStoreCoordinator Snapshot Lag - State store falling behind " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + log"StateStoreCoordinator Snapshot Lag - State store falling behind " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} (never uploaded)" + } + logWarning(logMessage) } } - lastSnapshotUploadEvent = Some(latestSnapshot) context.reply(true) case GetLatestSnapshotVersion(providerId) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 5aea0077e2aa..b13508682188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -52,6 +52,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid before { StateStore.stop() require(!StateStore.isMaintenanceRunning) + spark.streams.stateStoreCoordinator // initialize the lazy coordinator } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index bccee854130a..fc20604fa262 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -335,7 +335,6 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Add and commit data multiple times to force new snapshot versions (0 until 5).foreach { _ => input1.addData(1, 5) - query.processAllAvailable() input2.addData(1, 5, 10) query.processAllAvailable() } @@ -392,7 +391,6 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Add and commit data multiple times to force new snapshot versions (0 until 5).foreach { _ => input1.addData(1, 5) - query.processAllAvailable() input2.addData(1, 
5, 10) query.processAllAvailable() } From 6ea790f967264830bd82f387499341378858cf92 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Wed, 12 Mar 2025 22:00:11 -0700 Subject: [PATCH 09/36] SPARK-51358 Update coordinator logic, add additional configurations, and clean up nits --- .../apache/spark/sql/internal/SQLConf.scala | 43 +++++- .../execution/streaming/state/RocksDB.scala | 2 +- .../state/RocksDBStateStoreProvider.scala | 8 +- .../streaming/state/StateStore.scala | 4 +- .../streaming/state/StateStoreConf.scala | 7 + .../state/StateStoreCoordinator.scala | 135 +++++++++++------- .../state/StateStoreCoordinatorSuite.scala | 37 ++--- 7 files changed, 166 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f81c9cdc09c8..841d48bc6f04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2246,9 +2246,47 @@ object SQLConf { ) .version("4.0.0") .intConf - .checkValue(k => k >= 0, "Must be greater than or equal to 0") + .checkValue(k => k >= 1, "Must be greater than or equal to 1") .createWithDefault(30) + val STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG = + buildConf("spark.sql.streaming.stateStore.maintenanceMultiplierForMinTimeDeltaToLog") + .internal() + .doc( + "The multiplier used to determine the minimum time threshold between the single " + + "state store instance and the most recent version across all state store instances " + + "to log a warning message. The threshold is calculated as the maintenance interval, " + + "multiplied by this multiplier." + ) + .version("4.0.0") + .intConf + .checkValue(k => k >= 1, "Must be greater than or equal to 1") + .createWithDefault(20) + + val STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED = + buildConf("spark.sql.streaming.stateStore.coordinatorReportUpload.enabled") + .internal() + .doc( + "When true, the state store instances will send messages to the state store " + + "coordinator to report upload events whenever it finishes uploading a snapshot." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + + val STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL = + buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval") + .internal() + .doc( + "The minimum amount of time between the state store coordinator's full report on " + + "state store instances falling behind in snapshot uploads. The reports are not " + + "guaranteed to be separated by this interval, because the coordinator only checks " + + "for lagging instances when it receives a new snapshot upload message." 
+ ) + .version("4.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(5)) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() @@ -5789,6 +5827,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreSkipNullsForStreamStreamJoins: Boolean = getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS) + def stateStoreCoordinatorReportUploadEnabled: Boolean = + getConf(STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index db43507753c9..77b77541f6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -65,7 +65,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs - * @param providerListener A reference to the state store provider for event callback reporting + * @param providerListener The parent RocksDBStateStoreProvider object used for event reports */ class RocksDB( dfsRootDir: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 095bf6f01e0a..1f300cfece79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -41,9 +41,8 @@ import org.apache.spark.util.{NonFateSharingCache, Utils} /** * Trait representing events reported from a RocksDB instance. * - * The internal RocksDB instance can use a provider with a `RocksDBEventListener` reference to - * report specific events like snapshot uploads. This should only be used to report back to the - * coordinator for metrics and monitoring purposes. + * We pass this into the internal RocksDB instance to report specific events like snapshot uploads. + * This should only be used to report back to the coordinator for metrics and monitoring purposes. 
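To make the callback path easier to follow in review, here is a stand-alone sketch of the pattern the revised trait doc describes: the provider implements the listener, registers itself with the internal RocksDB wrapper, and the wrapper fires the callback once a snapshot upload completes. Only the trait and its reportSnapshotUploaded(version) method come from the patch; the other class names are illustrative stand-ins, not the real Spark classes.

```scala
// Minimal sketch of the listener wiring (illustrative stand-in classes).
trait RocksDBEventListener {
  def reportSnapshotUploaded(version: Long): Unit
}

// Stand-in for the internal RocksDB wrapper: it only keeps an optional listener
// and notifies it after a (simulated) snapshot upload.
class FakeRocksDB {
  private var providerListener: Option[RocksDBEventListener] = None

  def setListener(listener: RocksDBEventListener): Unit =
    providerListener = Some(listener)

  def uploadSnapshot(version: Long): Unit = {
    // ... the real implementation uploads the snapshot to DFS here ...
    providerListener.foreach(_.reportSnapshotUploaded(version))
  }
}

// Stand-in for the provider: it registers itself and reacts to the event.
class FakeProvider extends RocksDBEventListener {
  private val db = new FakeRocksDB
  db.setListener(this)

  override def reportSnapshotUploaded(version: Long): Unit =
    println(s"snapshot version $version uploaded")

  def runMaintenance(): Unit = db.uploadSnapshot(42L)
}

object ListenerWiringSketch extends App {
  new FakeProvider().runMaintenance() // prints: snapshot version 42 uploaded
}
```

In the patch itself the provider's implementation does not print; it forwards the event on toward the coordinator, which the following hunks wire up.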
*/ trait RocksDBEventListener { def reportSnapshotUploaded(version: Long): Unit @@ -655,6 +654,9 @@ private[sql] class RocksDBStateStoreProvider * @param version The snapshot version that was just uploaded from RocksDB */ def reportSnapshotUploaded(version: Long): Unit = { + if (!storeConf.stateStoreCoordinatorReportUploadEnabled) { + return + } // Collect the state store ID and query run ID to report back to the coordinator StateStore.reportSnapshotUploaded( StateStoreProviderId( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 4d2a2c681d67..a40e68eb6127 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1033,7 +1033,9 @@ object StateStore extends Logging { } } - def reportSnapshotUploaded(storeProviderId: StateStoreProviderId, snapshotVersion: Long): Unit = { + private[state] def reportSnapshotUploaded( + storeProviderId: StateStoreProviderId, + snapshotVersion: Long): Unit = { // Send current timestamp of uploaded snapshot as well val currentTime = System.currentTimeMillis() coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 9d26bf8fdf2e..50b11d86fb32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -92,6 +92,13 @@ class StateStoreConf( val enableStateStoreCheckpointIds = StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf) + /** + * Whether to report snapshot uploaded messages from the internal RocksDB instance + * to the state store coordinator. + */ + val stateStoreCoordinatorReportUploadEnabled: Boolean = + sqlConf.stateStoreCoordinatorReportUploadEnabled + /** * Additional configurations related to state store. This will capture all configs in * SQLConf that start with `spark.sql.streaming.stateStore.` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 8769d247664f..42be484168c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -56,7 +56,14 @@ private case class GetLocation(storeId: StateStoreProviderId) private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage -private case class SnapshotUploaded(storeId: StateStoreProviderId, version: Long, timestamp: Long) +/** + * This message is used to report a state store instance has just finished uploading a snapshot, + * along with the timestamp in milliseconds and the snapshot version. 
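To sketch what travels in this message and how the coordinator treats it (including the stale-report guard introduced later in this patch), here is a stand-alone illustration mirroring the message defined just below. The String key stands in for StateStoreProviderId and a direct method call stands in for the RPC; both are simplifications, not the actual Spark types.

```scala
import scala.collection.mutable

// Illustrative sketch of the upload-report flow; a String key stands in for
// StateStoreProviderId and a local method call stands in for the RPC.
final case class ReportSnapshotUploaded(storeId: String, version: Long, timestamp: Long)

object CoordinatorSketch {
  // Latest (version, timestamp) recorded per provider.
  private val latest = mutable.HashMap.empty[String, (Long, Long)]

  def receive(msg: ReportSnapshotUploaded): Unit = {
    // Ignore stale reports: an older snapshot can be reported after a newer one.
    if (!latest.get(msg.storeId).exists(_._1 >= msg.version)) {
      latest.put(msg.storeId, (msg.version, msg.timestamp))
    }
  }

  def report(storeId: String, version: Long): Unit =
    // The sender attaches the current time so the coordinator can track time-based lag too.
    receive(ReportSnapshotUploaded(storeId, version, System.currentTimeMillis()))
}
```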
+ */ +private case class ReportSnapshotUploaded( + storeId: StateStoreProviderId, + version: Long, + timestamp: Long) extends StateStoreCoordinatorMessage /** @@ -64,7 +71,7 @@ private case class SnapshotUploaded(storeId: StateStoreProviderId, version: Long * This message is used to retrieve the latest snapshot version reported for upload from a * specific state store instance. */ -private case class GetLatestSnapshotVersion(storeId: StateStoreProviderId) +private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage /** @@ -72,7 +79,7 @@ private case class GetLatestSnapshotVersion(storeId: StateStoreProviderId) * This message is used to retrieve the all active state store instance falling behind in * snapshot uploads, whether it is through version or time criteria. */ -private case class GetLaggingStores() +private object GetLaggingStoresForTesting extends StateStoreCoordinatorMessage private object StopCoordinator @@ -144,24 +151,24 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { storeProviderId: StateStoreProviderId, version: Long, timestamp: Long): Unit = { - rpcEndpointRef.askSync[Boolean](SnapshotUploaded(storeProviderId, version, timestamp)) + rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(storeProviderId, version, timestamp)) } /** * Endpoint used for testing. * Get the latest snapshot version uploaded for a state store. */ - private[sql] def getLatestSnapshotVersion( + private[state] def getLatestSnapshotVersionForTesting( stateStoreProviderId: StateStoreProviderId): Option[Long] = { - rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersion(stateStoreProviderId)) + rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(stateStoreProviderId)) } /** * Endpoint used for testing. * Get the state store instances that are falling behind in snapshot uploads. */ - private[sql] def getLaggingStores(): Seq[StateStoreProviderId] = { - rpcEndpointRef.askSync[Seq[StateStoreProviderId]](GetLaggingStores) + private[state] def getLaggingStoresForTesting(): Seq[StateStoreProviderId] = { + rpcEndpointRef.askSync[Seq[StateStoreProviderId]](GetLaggingStoresForTesting) } private[state] def stop(): Unit = { @@ -181,10 +188,17 @@ private class StateStoreCoordinator( with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] - // Stores the latest snapshot version of a specific state store provider instance - private val stateStoreSnapshotVersions = + // Stores the latest snapshot upload event for a specific state store provider instance + private val stateStoreLatestUploadedSnapshot = new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) + + // Stores the last timestamp in milliseconds where the coordinator did a full report on + // instances lagging behind on snapshot uploads. The initial timestamp is defaulted to + // 0 milliseconds. 
+ private var lastFullSnapshotLagReport = 0L + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => logDebug(s"Reported state store $id is active at $executorId") @@ -220,51 +234,64 @@ private class StateStoreCoordinator( storeIdsToRemove.mkString(", ")) context.reply(true) - case SnapshotUploaded(providerId, version, timestamp) => - stateStoreSnapshotVersions.put(providerId, SnapshotUploadEvent(version, timestamp)) - logDebug(s"Snapshot uploaded at ${providerId} with version ${version}") - // Report all stores that are behind in snapshot uploads - val (laggingStores, latestSnapshot) = findLaggingStores() - if (laggingStores.nonEmpty) { + case ReportSnapshotUploaded(providerId, version, timestamp) => + // Ignore this upload event if the version isn't more recent for this provider, + // since it's possible that an old version gets uploaded after a new executor uploads + // for the same provider but with a newer snapshot. + if (stateStoreLatestUploadedSnapshot + .getOrElse(providerId, defaultSnapshotUploadEvent) + .version >= version) { + context.reply(true) + } + stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) + logDebug(s"Snapshot version $version was uploaded for provider $providerId") + // Report all stores that are behind in snapshot uploads. + // Only report the full list of providers lagging behind if the last reported time + // is not recent. The lag report interval denotes the minimum time between these + // full reports. + val (laggingStores, latestSnapshotPerQuery) = findLaggingStores() + val coordinatorLagReportInterval = + SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + if (laggingStores.nonEmpty && + System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { + // Mark timestamp of the full report and log the lagging instances + lastFullSnapshotLagReport = System.currentTimeMillis() logWarning( - log"StateStoreCoordinator Snapshot Lag - Number of state stores falling behind: " + - log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)} " + - log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)})" + log"StateStoreCoordinator Snapshot Lag Detected - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" ) laggingStores.foreach { storeProviderId => - val logMessage = stateStoreSnapshotVersions.get(storeProviderId) match { + val latestSnapshot = latestSnapshotPerQuery(storeProviderId.queryRunId) + val logMessage = stateStoreLatestUploadedSnapshot.get(storeProviderId) match { case Some(snapshotEvent) => val versionDelta = latestSnapshot.version - snapshotEvent.version val timeDelta = latestSnapshot.timestamp - snapshotEvent.timestamp - log"StateStoreCoordinator Snapshot Lag - State store falling behind " + + log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + + log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" case None => - log"StateStoreCoordinator Snapshot Lag - State store falling behind " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, 
providerId)} (never uploaded)" + log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + + log"never uploaded)" } logWarning(logMessage) } } context.reply(true) - case GetLatestSnapshotVersion(providerId) => - val version = stateStoreSnapshotVersions.get(providerId).map(_.version) + case GetLatestSnapshotVersionForTesting(providerId) => + val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version) logDebug(s"Got latest snapshot version of the state store $providerId: $version") context.reply(version) - case GetLaggingStores => + case GetLaggingStoresForTesting => val (laggingStores, _) = findLaggingStores() - logDebug(s"Got lagging state stores ${laggingStores - .map( - id => - s"StateStoreId(operatorId=${id.storeId.operatorId}, " + - s"partitionId=${id.storeId.partitionId}, " + - s"storeName=${id.storeId.storeName})" - ) - .mkString(", ")}") + logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") context.reply(laggingStores) case StopCoordinator => @@ -281,9 +308,17 @@ private class StateStoreCoordinator( val versionDelta = latest.version - version val timeDelta = latest.timestamp - timestamp val minVersionDeltaForLogging = - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) - // Use 10 times the maintenance interval as the minimum time delta for logging - val minTimeDeltaForLogging = 10 * sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) + SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) + // Use a multiple of the maintenance interval as the minimum time delta for logging + val maintenanceMultiplierForThreshold = + SQLConf.get.getConf( + SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG + ) + val minTimeDeltaForLogging = + maintenanceMultiplierForThreshold * SQLConf.get.getConf( + SQLConf.STREAMING_MAINTENANCE_INTERVAL + ) + versionDelta >= minVersionDeltaForLogging || (version >= 0 && timeDelta > minTimeDeltaForLogging) @@ -298,21 +333,25 @@ private class StateStoreCoordinator( } } - private def findLaggingStores(): (Seq[StateStoreProviderId], SnapshotUploadEvent) = { + private def findLaggingStores(): (Seq[StateStoreProviderId], Map[UUID, SnapshotUploadEvent]) = { if (instances.isEmpty) { - return (Seq.empty, SnapshotUploadEvent(-1, 0)) + return (Seq.empty, Map.empty) } - // Find the most updated instance to use as reference point - val latestSnapshot = instances - .map( - instance => stateStoreSnapshotVersions.getOrElse(instance._1, SnapshotUploadEvent(-1, 0)) - ).max + // Group instances by queryRunId and find the latest snapshot upload for each query + val latestSnapshotsByQuery = instances.groupBy(_._1.queryRunId).view.mapValues { + queryInstances => + queryInstances.map { + case (storeProviderId, _) => + stateStoreLatestUploadedSnapshot.getOrElse(storeProviderId, defaultSnapshotUploadEvent) + }.max + }.toMap // Look for instances that are lagging behind in snapshot uploads val laggingStores = instances.keys.filter { storeProviderId => - stateStoreSnapshotVersions - .getOrElse(storeProviderId, SnapshotUploadEvent(-1, 0)) + val latestSnapshot = latestSnapshotsByQuery(storeProviderId.queryRunId) + stateStoreLatestUploadedSnapshot + .getOrElse(storeProviderId, defaultSnapshotUploadEvent) .isLagging(latestSnapshot) }.toSeq - (laggingStores, latestSnapshot) + (laggingStores, 
latestSnapshotsByQuery) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index fc20604fa262..2efd536723cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -167,7 +167,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false" + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ @@ -195,7 +196,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) - assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } query.stop() } @@ -209,7 +210,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ @@ -237,10 +239,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) - assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStores().isEmpty) + assert(coordRef.getLaggingStoresForTesting().isEmpty) query.stop() } } @@ -257,7 +259,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2" + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ @@ -286,14 +289,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), 
query.runId) if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and did not upload anything - assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } else { // Verify other stores have uploaded a snapshot and it's logged by the coordinator - assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } // We should have two state stores (id 0 and 1) that are lagging behind at this point - val laggingStores = coordRef.getLaggingStores() + val laggingStores = coordRef.getLaggingStoresForTesting() assert(laggingStores.size == 2) assert(laggingStores.forall(_.storeId.partitionId <= 1)) query.stop() @@ -314,7 +317,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4" + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4", + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ @@ -349,11 +353,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStoreId(stateCheckpointDir, 0, partitionId, storeName), query.runId ) - assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStores().isEmpty) + assert(coordRef.getLaggingStoresForTesting().isEmpty) query.stop() } } @@ -370,7 +374,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4" + SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4", + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ @@ -407,16 +412,16 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and did not upload anything - assert(coordRef.getLatestSnapshotVersion(providerId).isEmpty) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } else { // Verify other stores have uploaded a snapshot and it's logged by the coordinator - assert(coordRef.getLatestSnapshotVersion(providerId).get >= 0) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } } // Verify that only stores from partition id 0 and 1 are lagging behind. // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. 
- val laggingStores = coordRef.getLaggingStores() + val laggingStores = coordRef.getLaggingStoresForTesting() assert(laggingStores.size == 2 * 4) assert(laggingStores.forall(_.storeId.partitionId <= 1)) } From f2b84d4b32921b5a71c1548f806d9430e17ec3e4 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 13 Mar 2025 10:15:17 -0700 Subject: [PATCH 10/36] SPARK-51358 Clean up comments and styling --- .../apache/spark/sql/internal/SQLConf.scala | 14 +-- .../execution/streaming/state/RocksDB.scala | 2 +- .../state/RocksDBStateStoreProvider.scala | 7 +- .../streaming/state/StateStore.scala | 2 +- .../state/StateStoreCoordinator.scala | 108 ++++++++++-------- 5 files changed, 71 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 841d48bc6f04..4e3c5b11c1db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2256,7 +2256,7 @@ object SQLConf { "The multiplier used to determine the minimum time threshold between the single " + "state store instance and the most recent version across all state store instances " + "to log a warning message. The threshold is calculated as the maintenance interval, " + - "multiplied by this multiplier." + "multiplied by this value." ) .version("4.0.0") .intConf @@ -2267,8 +2267,8 @@ object SQLConf { buildConf("spark.sql.streaming.stateStore.coordinatorReportUpload.enabled") .internal() .doc( - "When true, the state store instances will send messages to the state store " + - "coordinator to report upload events whenever it finishes uploading a snapshot." + "If enabled, state store instances will send a message to the state store " + + "coordinator whenever they complete a snapshot upload." ) .version("4.0.0") .booleanConf @@ -2278,10 +2278,10 @@ object SQLConf { buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval") .internal() .doc( - "The minimum amount of time between the state store coordinator's full report on " + - "state store instances falling behind in snapshot uploads. The reports are not " + - "guaranteed to be separated by this interval, because the coordinator only checks " + - "for lagging instances when it receives a new snapshot upload message." + "The minimum amount of time between the state store coordinator's report on " + + "state store instances lagging in snapshot uploads. The reports may be delayed " + + "as the coordinator only checks for lagging instances upon receiving a new " + + "snapshot upload message." ) .version("4.0.0") .timeConf(TimeUnit.MILLISECONDS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 77b77541f6d2..16032a6eafa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1468,7 +1468,7 @@ class RocksDB( log"with uniqueId: ${MDC(LogKeys.UUID, snapshot.uniqueId)} " + log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. 
" + log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}") - // Report to coordinator that the snapshot has been uploaded when + // Only report to the coordinator that the snapshot has been uploaded when // changelog checkpointing is enabled, since that is when stores can lag behind. if (enableChangelogCheckpointing) { providerListener.foreach(_.reportSnapshotUploaded(snapshot.version)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 1f300cfece79..86b04a45f55e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -647,9 +647,10 @@ private[sql] class RocksDBStateStoreProvider } } - /** Callback function from RocksDB to report events to the coordinator. - * Additional information such as state store ID and query run ID are populated here - * to report back to the coordinator. + /** + * Callback function from RocksDB to report events to the coordinator. + * Additional information such as the state store ID and the query run ID are + * attached here to report back to the coordinator. * * @param version The snapshot version that was just uploaded from RocksDB */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index a40e68eb6127..f94c1dc187b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1036,7 +1036,7 @@ object StateStore extends Logging { private[state] def reportSnapshotUploaded( storeProviderId: StateStoreProviderId, snapshotVersion: Long): Unit = { - // Send current timestamp of uploaded snapshot as well + // Attach the current timestamp of uploaded snapshot and send the message to the coordinator val currentTime = System.currentTimeMillis() coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 42be484168c0..863e97fa3045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -235,51 +235,49 @@ private class StateStoreCoordinator( context.reply(true) case ReportSnapshotUploaded(providerId, version, timestamp) => - // Ignore this upload event if the version isn't more recent for this provider, - // since it's possible that an old version gets uploaded after a new executor uploads - // for the same provider but with a newer snapshot. - if (stateStoreLatestUploadedSnapshot - .getOrElse(providerId, defaultSnapshotUploadEvent) - .version >= version) { - context.reply(true) - } - stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) - logDebug(s"Snapshot version $version was uploaded for provider $providerId") - // Report all stores that are behind in snapshot uploads. 
- // Only report the full list of providers lagging behind if the last reported time - // is not recent. The lag report interval denotes the minimum time between these - // full reports. - val (laggingStores, latestSnapshotPerQuery) = findLaggingStores() - val coordinatorLagReportInterval = - SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) - if (laggingStores.nonEmpty && + // Ignore this upload event if the registered latest version for the provider is more recent, + // since it's possible that an older version gets uploaded after a new executor uploads for + // the same provider but with a newer snapshot. + if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) { + stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) + logDebug(s"Snapshot version $version was uploaded for provider $providerId") + + // Report all stores that are behind in snapshot uploads. + // Only report the full list of providers lagging behind if the last reported time + // is not recent. The lag report interval denotes the minimum time between these + // full reports. + val (laggingStores, latestSnapshotPerQuery) = findLaggingStores() + val coordinatorLagReportInterval = + SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + if (laggingStores.nonEmpty && System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { - // Mark timestamp of the full report and log the lagging instances - lastFullSnapshotLagReport = System.currentTimeMillis() - logWarning( - log"StateStoreCoordinator Snapshot Lag Detected - " + - log"Number of state stores falling behind: " + - log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" - ) - laggingStores.foreach { storeProviderId => - val latestSnapshot = latestSnapshotPerQuery(storeProviderId.queryRunId) - val logMessage = stateStoreLatestUploadedSnapshot.get(storeProviderId) match { - case Some(snapshotEvent) => - val versionDelta = latestSnapshot.version - snapshotEvent.version - val timeDelta = latestSnapshot.timestamp - snapshotEvent.timestamp - - log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + - log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + - log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" - case None => - log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + - log"never uploaded)" + // Mark timestamp of the full report and log the lagging instances + lastFullSnapshotLagReport = System.currentTimeMillis() + logWarning( + log"StateStoreCoordinator Snapshot Lag Detected - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + ) + laggingStores.foreach { storeProviderId => + val latestSnapshot = latestSnapshotPerQuery(storeProviderId.queryRunId) + val logMessage = stateStoreLatestUploadedSnapshot.get(storeProviderId) match { + case Some(snapshotEvent) => + val versionDelta = latestSnapshot.version - snapshotEvent.version + val timeDelta = latestSnapshot.timestamp - snapshotEvent.timestamp + + log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + + 
log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + + log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + + log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + + log"never uploaded)" + } + logWarning(logMessage) } - logWarning(logMessage) } } context.reply(true) @@ -304,12 +302,13 @@ private class StateStoreCoordinator( version: Long, timestamp: Long ) extends Ordered[SnapshotUploadEvent] { + def isLagging(latest: SnapshotUploadEvent): Boolean = { val versionDelta = latest.version - version val timeDelta = latest.timestamp - timestamp - val minVersionDeltaForLogging = - SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) - // Use a multiple of the maintenance interval as the minimum time delta for logging + + // Determine alert thresholds from configurations for both time and version differences. + // Use a multiple of the maintenance interval as the minimum time delta for logging. val maintenanceMultiplierForThreshold = SQLConf.get.getConf( SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG @@ -318,7 +317,8 @@ private class StateStoreCoordinator( maintenanceMultiplierForThreshold * SQLConf.get.getConf( SQLConf.STREAMING_MAINTENANCE_INTERVAL ) - + val minVersionDeltaForLogging = + SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) versionDelta >= minVersionDeltaForLogging || (version >= 0 && timeDelta > minTimeDeltaForLogging) @@ -334,24 +334,32 @@ private class StateStoreCoordinator( } private def findLaggingStores(): (Seq[StateStoreProviderId], Map[UUID, SnapshotUploadEvent]) = { + // Skip this check if there are no active instances if (instances.isEmpty) { return (Seq.empty, Map.empty) } + // Group instances by queryRunId and find the latest snapshot upload for each query - val latestSnapshotsByQuery = instances.groupBy(_._1.queryRunId).view.mapValues { - queryInstances => + val latestSnapshotsByQuery = instances + .groupBy(_._1.queryRunId) + .view + .mapValues { queryInstances => + // Determine the latest snapshot upload across all instances for this query queryInstances.map { case (storeProviderId, _) => stateStoreLatestUploadedSnapshot.getOrElse(storeProviderId, defaultSnapshotUploadEvent) }.max - }.toMap + }.toMap + // Look for instances that are lagging behind in snapshot uploads val laggingStores = instances.keys.filter { storeProviderId => + // Compare this instance with the respective query's latest snapshot val latestSnapshot = latestSnapshotsByQuery(storeProviderId.queryRunId) stateStoreLatestUploadedSnapshot .getOrElse(storeProviderId, defaultSnapshotUploadEvent) .isLagging(latestSnapshot) }.toSeq + (laggingStores, latestSnapshotsByQuery) } } From 77aa7dbe53bc555bc79a7ab7fe81106ea98a2851 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 13 Mar 2025 10:35:42 -0700 Subject: [PATCH 11/36] SPARK-51358 Temporarily add faulty provider for tests --- .../streaming/state/StateStoreCoordinatorSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 2efd536723cb..6734b8901b47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -31,6 +31,18 @@ import org.apache.spark.sql.functions.{count, expr} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils +// SkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running +// maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test +// snapshot upload lag can be observed through StreamingQueryProgress metrics. +class SkipMaintenanceOnCertainPartitionsProvider extends RocksDBStateStoreProvider { + override def doMaintenance(): Unit = { + if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { + return + } + super.doMaintenance() + } +} + class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import StateStoreCoordinatorSuite._ From 3148211d60104e5bd8d59d38840679da4d15be8c Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 13 Mar 2025 16:49:28 -0700 Subject: [PATCH 12/36] SPARK-51358 Switch to SparkConf --- .../apache/spark/sql/internal/SQLConf.scala | 17 ++++---- .../sql/classic/StreamingQueryManager.scala | 2 +- .../state/StateStoreCoordinator.scala | 42 +++++++++---------- .../state/StateStoreCoordinatorSuite.scala | 32 ++++++++++---- 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4e3c5b11c1db..b082f9e995af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2249,19 +2249,18 @@ object SQLConf { .checkValue(k => k >= 1, "Must be greater than or equal to 1") .createWithDefault(30) - val STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG = - buildConf("spark.sql.streaming.stateStore.maintenanceMultiplierForMinTimeDeltaToLog") + val STATE_STORE_COORDINATOR_MIN_SNAPSHOT_TIME_DELTA_TO_LOG = + buildConf("spark.sql.streaming.stateStore.minSnapshotTimeDeltaToLog") .internal() .doc( - "The multiplier used to determine the minimum time threshold between the single " + - "state store instance and the most recent version across all state store instances " + - "to log a warning message. The threshold is calculated as the maintenance interval, " + - "multiplied by this value." + "Minimum time between the timestamps of the most recent uploaded snapshot of a single " + + "state store instance and the most recent snapsnot upload's timestamp across all " + + "state store instances to log a warning message. It is recommended for this value to be " + + "longer than the maintenance interval." 
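As a usage note before the next hunks: snapshot upload reporting is off by default and only meaningful when changelog checkpointing is enabled, so the tests (and any experiment) have to opt in explicitly. The sketch below mirrors the pairs the suite hands to its withCoordinatorAndSQLConf helper; it assumes it lives alongside the suite in org.apache.spark.sql.execution.streaming.state, and the small interval and delta values are test-sized examples rather than recommendations. The exact threshold configurations keep changing across this patch series, so only the opt-in flag and the changelog checkpointing requirement should be treated as stable here.

```scala
import org.apache.spark.sql.internal.SQLConf

// Conf pairs mirroring what the suite passes to withCoordinatorAndSQLConf (example values).
val lagDetectionConfs: Seq[(String, String)] = Seq(
  SQLConf.SHUFFLE_PARTITIONS.key -> "3",
  // Small maintenance interval and snapshot cadence so uploads happen quickly in tests.
  SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
  SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
  // Lag detection only applies to RocksDB with changelog checkpointing enabled.
  SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
  RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true",
  // Opt in to upload reports from state store instances to the coordinator.
  SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true"
)
```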
) .version("4.0.0") - .intConf - .checkValue(k => k >= 1, "Must be greater than or equal to 1") - .createWithDefault(20) + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(30)) val STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED = buildConf("spark.sql.streaming.stateStore.coordinatorReportUpload.enabled") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 6d4a3ecd3603..6ce6f06de113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] ( with Logging { private[sql] val stateStoreCoordinator = - StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf) + StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) private val listenerBus = new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 863e97fa3045..3c0813d1136e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.collection.mutable -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation @@ -93,9 +93,9 @@ object StateStoreCoordinatorRef extends Logging { /** * Create a reference to a [[StateStoreCoordinator]] */ - def forDriver(env: SparkEnv, conf: SQLConf): StateStoreCoordinatorRef = synchronized { + def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator(env.rpcEnv, conf) + val coordinator = new StateStoreCoordinator(env.rpcEnv, env.conf) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -183,7 +183,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private class StateStoreCoordinator( override val rpcEnv: RpcEnv, - val sqlConf: SQLConf) + conf: SparkConf) extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] @@ -192,6 +192,13 @@ private class StateStoreCoordinator( private val stateStoreLatestUploadedSnapshot = new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + // Determine alert thresholds from configurations for both time and version differences. 
+ private val minTimeDeltaForLogging = + conf.get(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_TIME_DELTA_TO_LOG) + private val minVersionDeltaForLogging = + conf.get(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) + + // Default snapshot upload event to use when a provider has never uploaded a snapshot private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) // Stores the last timestamp in milliseconds where the coordinator did a full report on @@ -255,8 +262,8 @@ private class StateStoreCoordinator( lastFullSnapshotLagReport = System.currentTimeMillis() logWarning( log"StateStoreCoordinator Snapshot Lag Detected - " + - log"Number of state stores falling behind: " + - log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" ) laggingStores.foreach { storeProviderId => val latestSnapshot = latestSnapshotPerQuery(storeProviderId.queryRunId) @@ -307,25 +314,18 @@ private class StateStoreCoordinator( val versionDelta = latest.version - version val timeDelta = latest.timestamp - timestamp - // Determine alert thresholds from configurations for both time and version differences. - // Use a multiple of the maintenance interval as the minimum time delta for logging. - val maintenanceMultiplierForThreshold = - SQLConf.get.getConf( - SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG - ) - val minTimeDeltaForLogging = - maintenanceMultiplierForThreshold * SQLConf.get.getConf( - SQLConf.STREAMING_MAINTENANCE_INTERVAL - ) - val minVersionDeltaForLogging = - SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) - versionDelta >= minVersionDeltaForLogging || (version >= 0 && timeDelta > minTimeDeltaForLogging) } - override def compare(that: SnapshotUploadEvent): Int = { - this.version.compare(that.version) + override def compare(otherEvent: SnapshotUploadEvent): Int = { + // Compare by version first, then by timestamp as tiebreaker + val versionCompare = this.version.compare(otherEvent.version) + if (versionCompare == 0) { + this.timestamp.compare(otherEvent.timestamp) + } else { + versionCompare + } } override def toString(): String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 6734b8901b47..f4cc94cc3f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -115,7 +115,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("multiple references have same underlying coordinator") { withCoordinatorRef(sc) { coordRef1 => - val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) @@ -222,12 +222,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - 
SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2", SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext + // Set directly to SparkConf + sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "2") // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] @@ -256,6 +257,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Verify that we should not have any state stores lagging behind assert(coordRef.getLaggingStoresForTesting().isEmpty) query.stop() + + // Manually unset the custom SparkConf configs + sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } @@ -271,12 +275,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "2", SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext + // Set directly to SparkConf + sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "2") // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] @@ -312,6 +317,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(laggingStores.size == 2) assert(laggingStores.forall(_.storeId.partitionId <= 1)) query.stop() + + // Manually unset the custom SparkConf configs + sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } @@ -325,16 +333,17 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorAndSQLConf( sc, SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "50", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4", SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext + // Set directly to SparkConf + sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "3") // Start a join query and run some data to force snapshot uploads val input1 = MemoryStream[Int] @@ -371,6 +380,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Verify that we should not have any state stores lagging behind assert(coordRef.getLaggingStoresForTesting().isEmpty) query.stop() + + // Manually unset the custom SparkConf configs + sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } @@ -381,17 +393,18 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { withCoordinatorAndSQLConf( sc, SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "50", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> 
"100", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key -> "4", SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext + // Set directly to SparkConf + sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "3") // Start a join query and run some data to force snapshot uploads val input1 = MemoryStream[Int] @@ -436,6 +449,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val laggingStores = coordRef.getLaggingStoresForTesting() assert(laggingStores.size == 2 * 4) assert(laggingStores.forall(_.storeId.partitionId <= 1)) + + // Manually unset the custom SparkConf configs + sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } } @@ -444,7 +460,7 @@ object StateStoreCoordinatorSuite { def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { var coordinatorRef: StateStoreCoordinatorRef = null try { - coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) body(coordinatorRef) } finally { if (coordinatorRef != null) coordinatorRef.stop() From 6ba4dcf5e50e8d38138d36e02594056b1617c7d4 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Mon, 17 Mar 2025 19:36:08 -0700 Subject: [PATCH 13/36] SPARK-51358 Use multipliers for alert thresholds --- .../apache/spark/sql/internal/SQLConf.scala | 40 +++++----- .../sql/classic/StreamingQueryManager.scala | 2 +- .../state/StateStoreCoordinator.scala | 35 +++++---- .../state/StateStoreCoordinatorSuite.scala | 73 +++++++++---------- 4 files changed, 79 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b082f9e995af..925286f66fe8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2236,31 +2236,35 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG = - buildConf("spark.sql.streaming.stateStore.minSnapshotVersionDeltaToLog") + val STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG = + buildConf("spark.sql.streaming.stateStore.minSnapshotDeltaMultiplierForMinVersionDeltaToLog") .internal() .doc( - "Minimum number of versions between the most recent uploaded snapshot version of a " + - "single state store instance and the most recent version across all state store " + - "instances to log a warning message." + "This multiplier determines the minimum version threshold for logging warnings when a " + + "state store instance falls behind. The coordinator logs a warning if a state store's " + + "last uploaded snapshot's version lags behind the most recent snapshot version by this " + + "threshold. The threshold is calculated as the configured minimum number of deltas " + + "needed to create a snapshot, multiplied by this multiplier." 
) - .version("4.0.0") + .version("4.1.0") .intConf .checkValue(k => k >= 1, "Must be greater than or equal to 1") - .createWithDefault(30) + .createWithDefault(5) - val STATE_STORE_COORDINATOR_MIN_SNAPSHOT_TIME_DELTA_TO_LOG = - buildConf("spark.sql.streaming.stateStore.minSnapshotTimeDeltaToLog") + val STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG = + buildConf("spark.sql.streaming.stateStore.maintenanceMultiplierForMinTimeDeltaToLog") .internal() .doc( - "Minimum time between the timestamps of the most recent uploaded snapshot of a single " + - "state store instance and the most recent snapsnot upload's timestamp across all " + - "state store instances to log a warning message. It is recommended for this value to be " + - "longer than the maintenance interval." + "This multiplier determines the minimum time threshold for logging warnings when a " + + "state store instance falls behind. The coordinator logs a warning if a state store's " + + "last snapshot upload time lags behind the most recent snapshot upload by this " + + "threshold. The threshold is calculated as the maintenance interval multiplied by " + + "this multiplier." ) - .version("4.0.0") - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(TimeUnit.MINUTES.toMillis(30)) + .version("4.1.0") + .intConf + .checkValue(k => k >= 1, "Must be greater than or equal to 1") + .createWithDefault(10) val STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED = buildConf("spark.sql.streaming.stateStore.coordinatorReportUpload.enabled") @@ -2269,7 +2273,7 @@ object SQLConf { "If enabled, state store instances will send a message to the state store " + "coordinator whenever they complete a snapshot upload." ) - .version("4.0.0") + .version("4.1.0") .booleanConf .createWithDefault(false) @@ -2282,7 +2286,7 @@ object SQLConf { "as the coordinator only checks for lagging instances upon receiving a new " + "snapshot upload message." 
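To make the multiplier arithmetic above concrete, here is a worked example. The per-snapshot delta count and maintenance interval values (10 deltas, 60 seconds) are assumptions about the surrounding Spark defaults rather than values set in this patch; the two multipliers are the defaults introduced just above.

```scala
// Worked example of the alert thresholds (the first two values are assumed Spark defaults,
// not values set in this patch; the multipliers are the defaults defined above).
val minDeltasForSnapshot = 10            // spark.sql.streaming.stateStore.minDeltasForSnapshot
val maintenanceIntervalMs = 60 * 1000L   // spark.sql.streaming.stateStore.maintenanceInterval
val snapshotVersionDeltaMultiplier = 5   // minSnapshotDeltaMultiplierForMinVersionDeltaToLog
val maintenanceIntervalMultiplier = 10   // maintenanceMultiplierForMinTimeDeltaToLog

val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot  // 50 versions
val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceIntervalMs     // 600000 ms

// With the isLagging check later in this patch, a store is flagged only when it trails its
// query's latest upload by at least 50 versions and by more than 10 minutes. A store that has
// never uploaded starts at version -1 and timestamp 0, so it is flagged once the query's
// latest upload clears those thresholds on its own.
```

The snapshotLagReportInterval defined above acts as a rate limit rather than a schedule: the coordinator only evaluates lag when an upload report arrives, and it emits at most one full report per interval. A minimal sketch of that throttle, with a five-minute constant standing in for the configuration value:

```scala
// Sketch of the report throttle; 5 minutes stands in for snapshotLagReportInterval.
object LagReportThrottle {
  private val reportIntervalMs: Long = 5 * 60 * 1000L
  // 0 means "never reported", so the first qualifying check can always fire.
  private var lastFullSnapshotLagReport: Long = 0L

  def maybeReport(laggingStoreCount: Int): Unit = {
    val now = System.currentTimeMillis()
    if (laggingStoreCount > 0 && now - lastFullSnapshotLagReport > reportIntervalMs) {
      lastFullSnapshotLagReport = now // remember when this full report happened
      println(s"Snapshot lag detected - state stores falling behind: $laggingStoreCount")
    }
  }
}
```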
) - .version("4.0.0") + .version("4.1.0") .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(TimeUnit.MINUTES.toMillis(5)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala index 6ce6f06de113..6d4a3ecd3603 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala @@ -53,7 +53,7 @@ class StreamingQueryManager private[sql] ( with Logging { private[sql] val stateStoreCoordinator = - StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) + StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env, sqlConf) private val listenerBus = new StreamingQueryListenerBus(Some(sparkSession.sparkContext.listenerBus)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 3c0813d1136e..0f34f6bd47c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.collection.mutable -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.SparkEnv import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation @@ -93,9 +93,9 @@ object StateStoreCoordinatorRef extends Logging { /** * Create a reference to a [[StateStoreCoordinator]] */ - def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + def forDriver(env: SparkEnv, sqlConf: SQLConf): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator(env.rpcEnv, env.conf) + val coordinator = new StateStoreCoordinator(env.rpcEnv, sqlConf) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -183,7 +183,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private class StateStoreCoordinator( override val rpcEnv: RpcEnv, - conf: SparkConf) + val sqlConf: SQLConf) extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] @@ -192,12 +192,6 @@ private class StateStoreCoordinator( private val stateStoreLatestUploadedSnapshot = new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] - // Determine alert thresholds from configurations for both time and version differences. - private val minTimeDeltaForLogging = - conf.get(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_TIME_DELTA_TO_LOG) - private val minVersionDeltaForLogging = - conf.get(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) - // Default snapshot upload event to use when a provider has never uploaded a snapshot private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) @@ -255,7 +249,7 @@ private class StateStoreCoordinator( // full reports. 
val (laggingStores, latestSnapshotPerQuery) = findLaggingStores() val coordinatorLagReportInterval = - SQLConf.get.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) if (laggingStores.nonEmpty && System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { // Mark timestamp of the full report and log the lagging instances @@ -314,8 +308,23 @@ private class StateStoreCoordinator( val versionDelta = latest.version - version val timeDelta = latest.timestamp - timestamp - versionDelta >= minVersionDeltaForLogging || - (version >= 0 && timeDelta > minTimeDeltaForLogging) + // Determine alert thresholds from configurations for both time and version differences. + val snapshotVersionDeltaMultiplier = sqlConf.getConf( + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG) + val maintenanceIntervalMultiplier = sqlConf.getConf( + SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG) + val minDeltasForSnapshot = sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + val maintenanceInterval = sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) + + // Use the configured multipliers to determine the proper alert thresholds + val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot + val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval + + // Mark a state store as lagging if it is behind in both version and time. + // In the case that a snapshot was never uploaded, we treat version -1 as the preceding + // version of 0, and only rely on the version delta condition. + // Time requirement will be automatically satisfied as the initial timestamp is 0. 
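+      // For example, with the default multipliers above (5 for versions, 10 for time) and
+      // assuming the usual defaults of minDeltasForSnapshot = 10 and a 60s maintenance
+      // interval, a store is only flagged once it trails the latest snapshot by at least
+      // 50 versions and by more than 600 seconds.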
+ versionDelta >= minVersionDeltaForLogging && timeDelta > minTimeDeltaForLogging } override def compare(otherEvent: SnapshotUploadEvent): Int = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index f4cc94cc3f2d..06d30c487ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -115,7 +115,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("multiple references have same underlying coordinator") { withCoordinatorRef(sc) { coordRef1 => - val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID) @@ -222,13 +222,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> + "2" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Set directly to SparkConf - sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "2") // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] @@ -240,11 +240,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() - // Add and commit data multiple times to force new snapshot versions - inputData.addData(1, 2, 3) - query.processAllAvailable() - inputData.addData(1, 2, 3) - query.processAllAvailable() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 2).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(1000) + } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation @@ -257,9 +258,6 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Verify that we should not have any state stores lagging behind assert(coordRef.getLaggingStoresForTesting().isEmpty) query.stop() - - // Manually unset the custom SparkConf configs - sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } @@ -275,13 +273,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> + "2" ) { case 
(coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Set directly to SparkConf - sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "2") // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] @@ -293,11 +291,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() - // Add and commit data multiple times to force new snapshot versions - inputData.addData(1, 2, 3) - query.processAllAvailable() - inputData.addData(1, 2, 3) - query.processAllAvailable() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 2).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(1000) + } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation @@ -317,9 +316,6 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(laggingStores.size == 2) assert(laggingStores.forall(_.storeId.partitionId <= 1)) query.stop() - - // Manually unset the custom SparkConf configs - sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } @@ -337,13 +333,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> + "5" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Set directly to SparkConf - sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "3") // Start a join query and run some data to force snapshot uploads val input1 = MemoryStream[Int] @@ -357,11 +353,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() - // Add and commit data multiple times to force new snapshot versions + // Add, commit, and wait multiple times to force snapshot versions and time difference (0 until 5).foreach { _ => input1.addData(1, 5) input2.addData(1, 5, 10) query.processAllAvailable() + Thread.sleep(500) } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation @@ -380,9 +377,6 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Verify that we should not have any state stores lagging behind assert(coordRef.getLaggingStoresForTesting().isEmpty) query.stop() - - // Manually unset the custom SparkConf configs - sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } @@ -398,13 +392,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> 
"true" + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> + "5" ) { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Set directly to SparkConf - sc.conf.set(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG.key, "3") // Start a join query and run some data to force snapshot uploads val input1 = MemoryStream[Int] @@ -418,11 +412,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() - // Add and commit data multiple times to force new snapshot versions + // Add, commit, and wait multiple times to force snapshot versions and time difference (0 until 5).foreach { _ => input1.addData(1, 5) input2.addData(1, 5, 10) query.processAllAvailable() + Thread.sleep(500) } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation @@ -449,9 +444,6 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val laggingStores = coordRef.getLaggingStoresForTesting() assert(laggingStores.size == 2 * 4) assert(laggingStores.forall(_.storeId.partitionId <= 1)) - - // Manually unset the custom SparkConf configs - sc.conf.remove(SQLConf.STATE_STORE_COORDINATOR_MIN_SNAPSHOT_VERSION_DELTA_TO_LOG) } } } @@ -460,7 +452,7 @@ object StateStoreCoordinatorSuite { def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { var coordinatorRef: StateStoreCoordinatorRef = null try { - coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env, new SQLConf) body(coordinatorRef) } finally { if (coordinatorRef != null) coordinatorRef.stop() @@ -469,9 +461,10 @@ object StateStoreCoordinatorSuite { def withCoordinatorAndSQLConf(sc: SparkContext, pairs: (String, String)*)( body: (StateStoreCoordinatorRef, SparkSession) => Unit): Unit = { + var spark: SparkSession = null var coordinatorRef: StateStoreCoordinatorRef = null try { - val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + spark = SparkSession.builder().sparkContext(sc).getOrCreate() SparkSession.setActiveSession(spark) coordinatorRef = spark.streams.stateStoreCoordinator // Set up SQLConf entries @@ -479,6 +472,8 @@ object StateStoreCoordinatorSuite { body(coordinatorRef, spark) } finally { SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop())) + // Unset all custom SQLConf entries + if (spark != null) pairs.foreach { case (key, _) => spark.conf.unset(key) } if (coordinatorRef != null) coordinatorRef.stop() StateStore.stop() } From 11b5343db663a6e4d7cd12b5e36e802005f98e28 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 18 Mar 2025 14:55:51 -0700 Subject: [PATCH 14/36] SPARK-51358 Compare diff to batches instead of relative instances --- .../apache/spark/sql/internal/SQLConf.scala | 15 +- .../streaming/ProgressReporter.scala | 4 + .../state/RocksDBStateStoreProvider.scala | 21 ++- .../streaming/state/StateStore.scala | 11 ++ .../state/StateStoreCoordinator.scala | 166 ++++++++++-------- .../state/StateStoreCoordinatorSuite.scala | 32 ++-- 6 files changed, 141 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 
52de6cfb1d51..4dbda130ac04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2262,8 +2262,8 @@ object SQLConf { .doc( "This multiplier determines the minimum version threshold for logging warnings when a " + "state store instance falls behind. The coordinator logs a warning if a state store's " + - "last uploaded snapshot's version lags behind the most recent snapshot version by this " + - "threshold. The threshold is calculated as the configured minimum number of deltas " + + "last uploaded snapshot's version lags behind the query's latest known version by " + + "this threshold. The threshold is calculated as the configured minimum number of deltas " + "needed to create a snapshot, multiplied by this multiplier." ) .version("4.1.0") @@ -2277,9 +2277,8 @@ object SQLConf { .doc( "This multiplier determines the minimum time threshold for logging warnings when a " + "state store instance falls behind. The coordinator logs a warning if a state store's " + - "last snapshot upload time lags behind the most recent snapshot upload by this " + - "threshold. The threshold is calculated as the maintenance interval multiplied by " + - "this multiplier." + "last snapshot upload time lags behind the current time by this threshold. " + + "The threshold is calculated as the maintenance interval multiplied by this multiplier." ) .version("4.1.0") .intConf @@ -2301,10 +2300,10 @@ object SQLConf { buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval") .internal() .doc( - "The minimum amount of time between the state store coordinator's report on " + + "The minimum amount of time between the state store coordinator's full report on " + "state store instances lagging in snapshot uploads. The reports may be delayed " + - "as the coordinator only checks for lagging instances upon receiving a new " + - "snapshot upload message." + "as the coordinator only checks for lagging instances upon receiving a message " + + "instructing it to do so." ) .version("4.1.0") .timeConf(TimeUnit.MILLISECONDS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 3ac07cf1d730..0f607e60c990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream} import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan} import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress} +import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent} import org.apache.spark.util.{Clock, Utils} @@ -283,6 +284,9 @@ abstract class ProgressContext( progressReporter.lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() progressReporter.updateProgress(newProgress) + // Ask the state store coordinator to look for any lagging instances and report them. 
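+      // This runs once per successfully completed micro-batch, so lag detection piggybacks
+      // on progress reporting rather than on individual snapshot upload messages.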
+ StateStore.constructLaggingInstanceReport(lastExecution.runId, lastEpochId) + // Update the value since this trigger executes a batch successfully. this.execStatsOnLatestExecutedBatch = Some(execStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 630fabc5c523..81a3f8825635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -664,17 +664,16 @@ private[sql] class RocksDBStateStoreProvider * @param version The snapshot version that was just uploaded from RocksDB */ def reportSnapshotUploaded(version: Long): Unit = { - if (!storeConf.stateStoreCoordinatorReportUploadEnabled) { - return - } - // Collect the state store ID and query run ID to report back to the coordinator - StateStore.reportSnapshotUploaded( - StateStoreProviderId( - stateStoreId, - UUID.fromString(getRunId(hadoopConf)) - ), - version - ) + if (storeConf.stateStoreCoordinatorReportUploadEnabled) { + // Collect the state store ID and query run ID to report back to the coordinator + StateStore.reportSnapshotUploaded( + StateStoreProviderId( + stateStoreId, + UUID.fromString(getRunId(hadoopConf)) + ), + version + ) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index f87c98338ce7..b37233e556de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1139,6 +1139,17 @@ object StateStore extends Logging { coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime)) } + private[streaming] def constructLaggingInstanceReport( + queryRunId: UUID, + latestVersion: Long): Unit = { + // Asks the coordinator to check and report for any lagging state store instances + // Attach the current time, indicating when the current batch just completed + val currentTime = System.currentTimeMillis() + coordinatorRef.foreach( + _.constructLaggingInstanceReport(queryRunId, latestVersion, currentTime) + ) + } + private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 0f34f6bd47c2..6eaa3d0285ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -66,6 +66,16 @@ private case class ReportSnapshotUploaded( timestamp: Long) extends StateStoreCoordinatorMessage +/** + * This message is used for the coordinator to look for all state stores that are lagging behind + * in snapshot uploads. The coordinator will then log a warning message for each lagging instance. 
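+ * The message carries the query's latest version (the batch ID that just finished) and a
+ * timestamp taken at the end of the batch, and is sent once per completed micro-batch from
+ * the progress reporting path.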
+ */ +private case class ConstructLaggingInstanceReport( + queryRunId: UUID, + latestVersion: Long, + timestamp: Long) + extends StateStoreCoordinatorMessage + /** * Message used for testing. * This message is used to retrieve the latest snapshot version reported for upload from a @@ -77,9 +87,12 @@ private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreProvide /** * Message used for testing. * This message is used to retrieve the all active state store instance falling behind in - * snapshot uploads, whether it is through version or time criteria. + * snapshot uploads, using version and time criterias. */ -private object GetLaggingStoresForTesting +private case class GetLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + timestamp: Long) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -154,6 +167,16 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(storeProviderId, version, timestamp)) } + /** Ask the coordinator to find all state store instances that are lagging behind in uploads */ + private[sql] def constructLaggingInstanceReport( + queryRunId: UUID, + latestVersion: Long, + timestamp: Long): Unit = { + rpcEndpointRef.askSync[Boolean]( + ConstructLaggingInstanceReport(queryRunId, latestVersion, timestamp) + ) + } + /** * Endpoint used for testing. * Get the latest snapshot version uploaded for a state store. @@ -165,10 +188,16 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { /** * Endpoint used for testing. - * Get the state store instances that are falling behind in snapshot uploads. + * Get the state store instances that are falling behind in snapshot uploads for a particular + * query run. */ - private[state] def getLaggingStoresForTesting(): Seq[StateStoreProviderId] = { - rpcEndpointRef.askSync[Seq[StateStoreProviderId]](GetLaggingStoresForTesting) + private[state] def getLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + timestamp: Long): Seq[StateStoreProviderId] = { + rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( + GetLaggingStoresForTesting(queryRunId, latestVersion, timestamp) + ) } private[state] def stop(): Unit = { @@ -239,47 +268,54 @@ private class StateStoreCoordinator( // Ignore this upload event if the registered latest version for the provider is more recent, // since it's possible that an older version gets uploaded after a new executor uploads for // the same provider but with a newer snapshot. + logDebug(s"Snapshot version $version was uploaded for provider $providerId") if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) { stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) - logDebug(s"Snapshot version $version was uploaded for provider $providerId") - - // Report all stores that are behind in snapshot uploads. - // Only report the full list of providers lagging behind if the last reported time - // is not recent. The lag report interval denotes the minimum time between these - // full reports. 
- val (laggingStores, latestSnapshotPerQuery) = findLaggingStores() - val coordinatorLagReportInterval = - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) - if (laggingStores.nonEmpty && - System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { - // Mark timestamp of the full report and log the lagging instances - lastFullSnapshotLagReport = System.currentTimeMillis() - logWarning( - log"StateStoreCoordinator Snapshot Lag Detected - " + - log"Number of state stores falling behind: " + - log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" - ) - laggingStores.foreach { storeProviderId => - val latestSnapshot = latestSnapshotPerQuery(storeProviderId.queryRunId) - val logMessage = stateStoreLatestUploadedSnapshot.get(storeProviderId) match { - case Some(snapshotEvent) => - val versionDelta = latestSnapshot.version - snapshotEvent.version - val timeDelta = latestSnapshot.timestamp - snapshotEvent.timestamp - - log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + - log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + - log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" - case None => - log"StateStoreCoordinator Snapshot Lag Detected - State store falling behind " + - log"${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, latestSnapshot)}, " + - log"never uploaded)" - } - logWarning(logMessage) + } + context.reply(true) + + case ConstructLaggingInstanceReport(queryRunId, latestVersion, timestamp) => + val laggingStores = findLaggingStores(queryRunId, latestVersion, timestamp) + logWarning( + log"StateStoreCoordinator Snapshot Lag Report for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + ) + // Report all stores that are behind in snapshot uploads. + // Only report the full list of providers lagging behind if the last reported time + // is not recent. The lag report interval denotes the minimum time between these + // full reports. 
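+      // For example, with the default snapshotLagReportInterval of 5 minutes, at most one
+      // full per-store report is logged every 5 minutes, even if lagging stores are found
+      // after every batch.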
+ val coordinatorLagReportInterval = + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + if (laggingStores.nonEmpty && + System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { + // Mark timestamp of the full report and log the lagging instances + lastFullSnapshotLagReport = System.currentTimeMillis() + laggingStores.foreach { providerId => + val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { + case Some(snapshotEvent) => + val versionDelta = latestVersion - snapshotEvent.version + val timeDelta = timestamp - snapshotEvent.timestamp + + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Provider falling behind: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + + log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Provider falling behind: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + + log"latest snapshot: never uploaded)" } + logWarning(logMessage) } + } else if (laggingStores.nonEmpty) { + logInfo(log"StateStoreCoordinator Snapshot Lag Report - last full report was too recent") } context.reply(true) @@ -288,8 +324,8 @@ private class StateStoreCoordinator( logDebug(s"Got latest snapshot version of the state store $providerId: $version") context.reply(version) - case GetLaggingStoresForTesting => - val (laggingStores, _) = findLaggingStores() + case GetLaggingStoresForTesting(queryRunId, latestVersion, timestamp) => + val laggingStores = findLaggingStores(queryRunId, latestVersion, timestamp) logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") context.reply(laggingStores) @@ -304,9 +340,9 @@ private class StateStoreCoordinator( timestamp: Long ) extends Ordered[SnapshotUploadEvent] { - def isLagging(latest: SnapshotUploadEvent): Boolean = { - val versionDelta = latest.version - version - val timeDelta = latest.timestamp - timestamp + def isLagging(latestVersion: Long, latestTimestamp: Long): Boolean = { + val versionDelta = latestVersion - version + val timeDelta = latestTimestamp - timestamp // Determine alert thresholds from configurations for both time and version differences. 
val snapshotVersionDeltaMultiplier = sqlConf.getConf( @@ -342,33 +378,17 @@ private class StateStoreCoordinator( } } - private def findLaggingStores(): (Seq[StateStoreProviderId], Map[UUID, SnapshotUploadEvent]) = { - // Skip this check if there are no active instances - if (instances.isEmpty) { - return (Seq.empty, Map.empty) - } - - // Group instances by queryRunId and find the latest snapshot upload for each query - val latestSnapshotsByQuery = instances - .groupBy(_._1.queryRunId) - .view - .mapValues { queryInstances => - // Determine the latest snapshot upload across all instances for this query - queryInstances.map { - case (storeProviderId, _) => - stateStoreLatestUploadedSnapshot.getOrElse(storeProviderId, defaultSnapshotUploadEvent) - }.max - }.toMap - + private def findLaggingStores( + queryRunId: UUID, + referenceVersion: Long, + referenceTimestamp: Long): Seq[StateStoreProviderId] = { // Look for instances that are lagging behind in snapshot uploads - val laggingStores = instances.keys.filter { storeProviderId => - // Compare this instance with the respective query's latest snapshot - val latestSnapshot = latestSnapshotsByQuery(storeProviderId.queryRunId) - stateStoreLatestUploadedSnapshot - .getOrElse(storeProviderId, defaultSnapshotUploadEvent) - .isLagging(latestSnapshot) + instances.keys.filter { storeProviderId => + // Only consider instances that are part of this specific query run + storeProviderId.queryRunId == queryRunId && + stateStoreLatestUploadedSnapshot + .getOrElse(storeProviderId, defaultSnapshotUploadEvent) + .isLagging(referenceVersion, referenceTimestamp) }.toSeq - - (laggingStores, latestSnapshotsByQuery) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 06d30c487ebb..5406560006ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -31,18 +31,6 @@ import org.apache.spark.sql.functions.{count, expr} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils -// SkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running -// maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test -// snapshot upload lag can be observed through StreamingQueryProgress metrics. 
-class SkipMaintenanceOnCertainPartitionsProvider extends RocksDBStateStoreProvider { - override def doMaintenance(): Unit = { - if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { - return - } - super.doMaintenance() - } -} - class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import StateStoreCoordinatorSuite._ @@ -248,6 +236,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val batchId = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + val timestamp = System.currentTimeMillis() // Verify all stores have uploaded a snapshot and it's logged by the coordinator (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => @@ -256,7 +247,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting().isEmpty) + assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) query.stop() } } @@ -299,6 +290,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val batchId = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + val timestamp = System.currentTimeMillis() (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val providerId = @@ -312,7 +306,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } // We should have two state stores (id 0 and 1) that are lagging behind at this point - val laggingStores = coordRef.getLaggingStoresForTesting() + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp) assert(laggingStores.size == 2) assert(laggingStores.forall(_.storeId.partitionId <= 1)) query.stop() @@ -362,6 +356,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val batchId = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + val timestamp = System.currentTimeMillis() // Verify all state stores for join queries are reporting snapshot uploads (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => @@ -375,7 +372,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting().isEmpty) + assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) query.stop() } } @@ -421,6 +418,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val batchId = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + val timestamp = System.currentTimeMillis() // Verify all state stores for join queries are reporting snapshot uploads (0 until 
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => @@ -441,7 +441,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } // Verify that only stores from partition id 0 and 1 are lagging behind. // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. - val laggingStores = coordRef.getLaggingStoresForTesting() + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp) assert(laggingStores.size == 2 * 4) assert(laggingStores.forall(_.storeId.partitionId <= 1)) } From 6ed3366272d21bb0623d5d67030770442d3ad707 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 18 Mar 2025 15:46:56 -0700 Subject: [PATCH 15/36] SPARK-51358 Verify config turns lag reports off properly --- .../state/StateStoreCoordinator.scala | 100 +++++++++++------- .../state/StateStoreCoordinatorSuite.scala | 10 +- 2 files changed, 70 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 6eaa3d0285ff..4ae102407d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -275,47 +275,51 @@ private class StateStoreCoordinator( context.reply(true) case ConstructLaggingInstanceReport(queryRunId, latestVersion, timestamp) => - val laggingStores = findLaggingStores(queryRunId, latestVersion, timestamp) - logWarning( - log"StateStoreCoordinator Snapshot Lag Report for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Number of state stores falling behind: " + - log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" - ) - // Report all stores that are behind in snapshot uploads. - // Only report the full list of providers lagging behind if the last reported time - // is not recent. The lag report interval denotes the minimum time between these - // full reports. 
- val coordinatorLagReportInterval = - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) - if (laggingStores.nonEmpty && - System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { - // Mark timestamp of the full report and log the lagging instances - lastFullSnapshotLagReport = System.currentTimeMillis() - laggingStores.foreach { providerId => - val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { - case Some(snapshotEvent) => - val versionDelta = latestVersion - snapshotEvent.version - val timeDelta = timestamp - snapshotEvent.timestamp - - log"StateStoreCoordinator Snapshot Lag Detected for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Provider falling behind: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + - log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + - log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" - case None => - log"StateStoreCoordinator Snapshot Lag Detected for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Provider falling behind: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: never uploaded)" + // Only log lagging instances if the snapshot report upload is enabled, + // otherwise all instances will be considered lagging. + if (isSnapshotUploadReportEnabled) { + val laggingStores = findLaggingStores(queryRunId, latestVersion, timestamp) + logWarning( + log"StateStoreCoordinator Snapshot Lag Report for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + ) + // Report all stores that are behind in snapshot uploads. + // Only report the full list of providers lagging behind if the last reported time + // is not recent. The lag report interval denotes the minimum time between these + // full reports. 
+ val coordinatorLagReportInterval = + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + if (laggingStores.nonEmpty && + System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { + // Mark timestamp of the full report and log the lagging instances + lastFullSnapshotLagReport = System.currentTimeMillis() + laggingStores.foreach { providerId => + val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { + case Some(snapshotEvent) => + val versionDelta = latestVersion - snapshotEvent.version + val timeDelta = timestamp - snapshotEvent.timestamp + + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Provider: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + + log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Provider: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + + log"latest snapshot: never uploaded)" + } + logWarning(logMessage) } - logWarning(logMessage) + } else if (laggingStores.nonEmpty) { + logInfo(log"StateStoreCoordinator Snapshot Lag Report - last full report was too recent") } - } else if (laggingStores.nonEmpty) { - logInfo(log"StateStoreCoordinator Snapshot Lag Report - last full report was too recent") } context.reply(true) @@ -378,10 +382,28 @@ private class StateStoreCoordinator( } } + private def isSnapshotUploadReportEnabled: Boolean = { + // Only find lagging instances if the snapshot report upload is enabled. + // If RocksDB's changelog checkpointing is disabled, then this should be disabled as well. + val isChangelogCheckpointEnabled = + sqlConf + .getConfString( + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled", + "false" + ).toBoolean + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED) && + isChangelogCheckpointEnabled + } + private def findLaggingStores( queryRunId: UUID, referenceVersion: Long, referenceTimestamp: Long): Seq[StateStoreProviderId] = { + // Do not report any instance as lagging if the snapshot report upload is disabled, + // since it will treat all active instances as stores that have never uploaded. 
+ if (!isSnapshotUploadReportEnabled) { + return Seq.empty + } // Look for instances that are lagging behind in snapshot uploads instances.keys.filter { storeProviderId => // Only consider instances that are part of this specific query run diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 5406560006ce..d317a27e075a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -168,7 +168,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true" + SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> + "1" ) { case (coordRef, spark) => import spark.implicits._ @@ -190,6 +193,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { query.processAllAvailable() val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation + val batchId = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + val timestamp = System.currentTimeMillis() // Verify stores do not report snapshot upload events to the coordinator. 
// As a result, all stores will return nothing as the latest version @@ -198,6 +204,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } + // Verify that no instances are marked as lagging + assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) query.stop() } } From 8cb4bbf03604de4e4360ec4125167cb5cf0f482b Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Wed, 19 Mar 2025 15:02:15 -0700 Subject: [PATCH 16/36] SPARK-51358 Switch event reporting to an object --- .../streaming/ProgressReporter.scala | 12 +++- .../execution/streaming/state/RocksDB.scala | 11 ++- .../state/RocksDBStateStoreProvider.scala | 72 +++++++++++-------- .../streaming/state/StateStore.scala | 11 --- .../state/StateStoreCoordinator.scala | 28 +++----- .../state/StateStoreCoordinatorSuite.scala | 15 ++-- 6 files changed, 67 insertions(+), 82 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 0f607e60c990..9bbf6c5c0287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream} import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan} import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress} -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent} import org.apache.spark.util.{Clock, Utils} @@ -62,6 +62,9 @@ class ProgressReporter( val noDataProgressEventInterval: Long = sparkSession.sessionState.conf.streamingNoDataProgressEventInterval + val stateStoreCoordinator: StateStoreCoordinatorRef = + sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + private val timestampFormat = DateTimeFormatter .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 @@ -285,7 +288,12 @@ abstract class ProgressContext( progressReporter.updateProgress(newProgress) // Ask the state store coordinator to look for any lagging instances and report them. - StateStore.constructLaggingInstanceReport(lastExecution.runId, lastEpochId) + progressReporter.stateStoreCoordinator + .constructLaggingInstanceReport( + lastExecution.runId, + lastEpochId, + triggerClock.getTimeMillis() + ) // Update the value since this trigger executes a batch successfully. 
this.execStatsOnLatestExecutedBatch = Some(execStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 31a7ff1ad4d2..3f189cd065a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -65,7 +65,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs - * @param providerListener The parent RocksDBStateStoreProvider object used for event reports + * @param eventListener The RocksDBEventListener object for reporting events to the coordinator */ class RocksDB( dfsRootDir: String, @@ -75,7 +75,7 @@ class RocksDB( loggingId: String = "", useColumnFamilies: Boolean = false, enableStateStoreCheckpointIds: Boolean = false, - providerListener: Option[RocksDBEventListener] = None) + eventListener: Option[RocksDBEventListener] = None) extends Logging { import RocksDB._ @@ -1475,11 +1475,8 @@ class RocksDB( log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}") // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) - // Only report to the coordinator that the snapshot has been uploaded when - // changelog checkpointing is enabled, since that is when stores can lag behind. - if (enableChangelogCheckpointing) { - providerListener.foreach(_.reportSnapshotUploaded(snapshot.version)) - } + // Report snapshot upload event to the coordinator. + eventListener.foreach(_.reportSnapshotUploaded(snapshot.version)) } finally { snapshot.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 81a3f8825635..c5f4f80f5274 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -38,19 +38,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform import org.apache.spark.util.{NonFateSharingCache, Utils} -/** - * Trait representing events reported from a RocksDB instance. - * - * We pass this into the internal RocksDB instance to report specific events like snapshot uploads. - * This should only be used to report back to the coordinator for metrics and monitoring purposes. 
- */ -trait RocksDBEventListener { - def reportSnapshotUploaded(version: Long): Unit -} private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable - with SupportsFineGrainedReplay with RocksDBEventListener { + with SupportsFineGrainedReplay { import RocksDBStateStoreProvider._ class RocksDBStateStore(lastVersion: Long) extends StateStore { @@ -395,6 +386,7 @@ private[sql] class RocksDBStateStoreProvider this.useColumnFamilies = useColumnFamilies this.stateStoreEncoding = storeConf.stateStoreEncodingFormat this.stateSchemaProvider = stateSchemaProvider + this.rocksDBEventListener = RocksDBEventListener(getRunId(hadoopConf), stateStoreId, storeConf) if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -528,6 +520,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var useColumnFamilies: Boolean = _ @volatile private var stateStoreEncoding: String = _ @volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _ + @volatile private var rocksDBEventListener: RocksDBEventListener = _ private[sql] lazy val rocksDB = { val dfsRootDir = stateStoreId.storeCheckpointLocation().toString @@ -536,7 +529,7 @@ private[sql] class RocksDBStateStoreProvider val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, - useColumnFamilies, storeConf.enableStateStoreCheckpointIds, Some(this)) + useColumnFamilies, storeConf.enableStateStoreCheckpointIds, Some(rocksDBEventListener)) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, @@ -655,26 +648,6 @@ private[sql] class RocksDBStateStoreProvider throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName) } } - - /** - * Callback function from RocksDB to report events to the coordinator. - * Additional information such as the state store ID and the query run ID are - * attached here to report back to the coordinator. - * - * @param version The snapshot version that was just uploaded from RocksDB - */ - def reportSnapshotUploaded(version: Long): Unit = { - if (storeConf.stateStoreCoordinatorReportUploadEnabled) { - // Collect the state store ID and query run ID to report back to the coordinator - StateStore.reportSnapshotUploaded( - StateStoreProviderId( - stateStoreId, - UUID.fromString(getRunId(hadoopConf)) - ), - version - ) - } - } } @@ -996,3 +969,40 @@ class RocksDBStateStoreChangeDataReader( } } } + +/** + * Object used to relay events reported from a RocksDB instance to the state store coordinator. + * + * We pass this into the RocksDB instance to report specific events like snapshot uploads. + * This should only be used to report back to the coordinator for metrics and monitoring purposes. 
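+ * The reporting path is: RocksDB snapshot upload -> RocksDBEventListener.reportSnapshotUploaded
+ * -> StateStore.reportSnapshotUploaded -> StateStoreCoordinator RPC. The relay is skipped when
+ * stateStoreCoordinatorReportUploadEnabled is false.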
+ */ +private[state] case class RocksDBEventListener( + queryRunId: String, + stateStoreId: StateStoreId, + storeConf: StateStoreConf) { + + /** ID of the state store provider managing the RocksDB instance */ + private val stateStoreProviderId: StateStoreProviderId = + StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) + + /** Whether the event listener should relay these messages to the state store coordinator */ + private val coordinatorReportUploadEnabled: Boolean = + storeConf.stateStoreCoordinatorReportUploadEnabled + + /** + * Callback function from RocksDB to report events to the coordinator. + * Additional information such as the state store ID and the query run ID are + * attached here to report back to the coordinator. + * + * @param version The snapshot version that was just uploaded from RocksDB + */ + def reportSnapshotUploaded(version: Long): Unit = { + // Only report to the coordinator if this is enabled, as sometimes we do not need + // to track for lagging instances. + // Also ignore message if we are missing the provider ID from lack of initialization. + if (coordinatorReportUploadEnabled) { + // Report the provider ID and the version to the coordinator + StateStore.reportSnapshotUploaded(stateStoreProviderId, version) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index b37233e556de..f87c98338ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1139,17 +1139,6 @@ object StateStore extends Logging { coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime)) } - private[streaming] def constructLaggingInstanceReport( - queryRunId: UUID, - latestVersion: Long): Unit = { - // Asks the coordinator to check and report for any lagging state store instances - // Attach the current time, indicating when the current batch just completed - val currentTime = System.currentTimeMillis() - coordinatorRef.foreach( - _.constructLaggingInstanceReport(queryRunId, latestVersion, currentTime) - ) - } - private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 4ae102407d0e..0d97f02dc644 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -274,11 +274,11 @@ private class StateStoreCoordinator( } context.reply(true) - case ConstructLaggingInstanceReport(queryRunId, latestVersion, timestamp) => + case ConstructLaggingInstanceReport(queryRunId, latestVersion, endOfBatchTimestamp) => // Only log lagging instances if the snapshot report upload is enabled, // otherwise all instances will be considered lagging. 
- if (isSnapshotUploadReportEnabled) { - val laggingStores = findLaggingStores(queryRunId, latestVersion, timestamp) + if (sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED)) { + val laggingStores = findLaggingStores(queryRunId, latestVersion, endOfBatchTimestamp) logWarning( log"StateStoreCoordinator Snapshot Lag Report for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + @@ -291,15 +291,16 @@ private class StateStoreCoordinator( // full reports. val coordinatorLagReportInterval = sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + val currentTimestamp = System.currentTimeMillis() if (laggingStores.nonEmpty && - System.currentTimeMillis() - lastFullSnapshotLagReport > coordinatorLagReportInterval) { + currentTimestamp - lastFullSnapshotLagReport > coordinatorLagReportInterval) { // Mark timestamp of the full report and log the lagging instances - lastFullSnapshotLagReport = System.currentTimeMillis() + lastFullSnapshotLagReport = currentTimestamp laggingStores.foreach { providerId => val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { case Some(snapshotEvent) => val versionDelta = latestVersion - snapshotEvent.version - val timeDelta = timestamp - snapshotEvent.timestamp + val timeDelta = endOfBatchTimestamp - snapshotEvent.timestamp log"StateStoreCoordinator Snapshot Lag Detected for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + @@ -382,26 +383,13 @@ private class StateStoreCoordinator( } } - private def isSnapshotUploadReportEnabled: Boolean = { - // Only find lagging instances if the snapshot report upload is enabled. - // If RocksDB's changelog checkpointing is disabled, then this should be disabled as well. - val isChangelogCheckpointEnabled = - sqlConf - .getConfString( - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled", - "false" - ).toBoolean - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED) && - isChangelogCheckpointEnabled - } - private def findLaggingStores( queryRunId: UUID, referenceVersion: Long, referenceTimestamp: Long): Seq[StateStoreProviderId] = { // Do not report any instance as lagging if the snapshot report upload is disabled, // since it will treat all active instances as stores that have never uploaded. 
- if (!isSnapshotUploadReportEnabled) { + if (!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED)) { return Seq.empty } // Look for instances that are lagging behind in snapshot uploads diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index d317a27e075a..19b669c6d5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -171,7 +171,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG.key -> "1", SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> - "1" + "2" ) { case (coordRef, spark) => import spark.implicits._ @@ -191,20 +191,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { query.processAllAvailable() inputData.addData(1, 2, 3) query.processAllAvailable() - val stateCheckpointDir = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation val batchId = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId val timestamp = System.currentTimeMillis() - // Verify stores do not report snapshot upload events to the coordinator. - // As a result, all stores will return nothing as the latest version - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => - val providerId = - StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) - } - // Verify that no instances are marked as lagging + // Verify that no instances are marked as lagging, even when upload messages are sent. + // Since snapshot uploads are tied to commit, the lack of version difference should prevent + // the stores from being marked as lagging. 
assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) query.stop() } From 8b1fd5b718dc0942c5a034fc61fbb882c62f6d7f Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 20 Mar 2025 10:40:24 -0700 Subject: [PATCH 17/36] SPARK-51358 Fix feedback --- .../apache/spark/sql/internal/SQLConf.scala | 46 ++++++------ .../streaming/ProgressReporter.scala | 15 ++-- .../state/RocksDBStateStoreProvider.scala | 15 ++-- .../streaming/state/StateStoreConf.scala | 7 +- .../state/StateStoreCoordinator.scala | 64 +++++++--------- .../state/StateStoreCoordinatorSuite.scala | 74 +++++++++---------- 6 files changed, 100 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4dbda130ac04..7db3a8d9e105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2256,54 +2256,52 @@ object SQLConf { .booleanConf .createWithDefault(true) - val STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG = - buildConf("spark.sql.streaming.stateStore.minSnapshotDeltaMultiplierForMinVersionDeltaToLog") + val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG = + buildConf("spark.sql.streaming.stateStore.multiplierForMinVersionDiffToLog") .internal() .doc( - "This multiplier determines the minimum version threshold for logging warnings when a " + - "state store instance falls behind. The coordinator logs a warning if a state store's " + - "last uploaded snapshot's version lags behind the query's latest known version by " + - "this threshold. The threshold is calculated as the configured minimum number of deltas " + - "needed to create a snapshot, multiplied by this multiplier." + "Determines the version threshold for logging warnings when a state store falls behind. " + + "The coordinator logs a warning when the store's uploaded snapshot version trails the " + + "query's latest version by the configured number of deltas needed to create a snapshot, " + + "times this multiplier." ) .version("4.1.0") .intConf .checkValue(k => k >= 1, "Must be greater than or equal to 1") .createWithDefault(5) - val STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG = - buildConf("spark.sql.streaming.stateStore.maintenanceMultiplierForMinTimeDeltaToLog") + val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG = + buildConf("spark.sql.streaming.stateStore.multiplierForMinTimeDiffToLog") .internal() .doc( - "This multiplier determines the minimum time threshold for logging warnings when a " + - "state store instance falls behind. The coordinator logs a warning if a state store's " + - "last snapshot upload time lags behind the current time by this threshold. " + - "The threshold is calculated as the maintenance interval multiplied by this multiplier." + "Determines the time threshold for logging warnings when a state store falls behind. " + + "The coordinator logs a warning when the store's uploaded snapshot timestamp trails the " + + "current time by the configured maintenance interval, times this multiplier." 
) .version("4.1.0") .intConf .checkValue(k => k >= 1, "Must be greater than or equal to 1") .createWithDefault(10) - val STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED = - buildConf("spark.sql.streaming.stateStore.coordinatorReportUpload.enabled") + val STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG = + buildConf("spark.sql.streaming.stateStore.coordinatorReportSnapshotUploadLag") .internal() .doc( - "If enabled, state store instances will send a message to the state store " + - "coordinator whenever they complete a snapshot upload." + "When enabled, the state store coordinator will report state stores whose snapshot " + + "have not been uploaded for some time. See the conf snapshotLagReportInterval for " + + "the minimum time between reports, and the conf multiplierForMinVersionDiffToLog " + + "and multiplierForMinTimeDiffToLog for the logging thresholds." ) .version("4.1.0") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL = buildConf("spark.sql.streaming.stateStore.snapshotLagReportInterval") .internal() .doc( - "The minimum amount of time between the state store coordinator's full report on " + - "state store instances lagging in snapshot uploads. The reports may be delayed " + - "as the coordinator only checks for lagging instances upon receiving a message " + - "instructing it to do so." + "The minimum amount of time between the state store coordinator's reports on " + + "state store instances trailing behind in snapshot uploads." ) .version("4.1.0") .timeConf(TimeUnit.MILLISECONDS) @@ -5853,8 +5851,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreSkipNullsForStreamStreamJoins: Boolean = getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS) - def stateStoreCoordinatorReportUploadEnabled: Boolean = - getConf(STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED) + def stateStoreCoordinatorReportSnapshotUploadLag: Boolean = + getConf(STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG) def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 9bbf6c5c0287..ed17b55c5e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -62,6 +62,9 @@ class ProgressReporter( val noDataProgressEventInterval: Long = sparkSession.sessionState.conf.streamingNoDataProgressEventInterval + val coordinatorReportSnapshotUploadLag: Boolean = + sparkSession.sessionState.conf.stateStoreCoordinatorReportSnapshotUploadLag + val stateStoreCoordinator: StateStoreCoordinatorRef = sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator @@ -287,13 +290,11 @@ abstract class ProgressContext( progressReporter.lastNoExecutionProgressEventTime = triggerClock.getTimeMillis() progressReporter.updateProgress(newProgress) - // Ask the state store coordinator to look for any lagging instances and report them. 
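// Usage sketch (illustration only; values are arbitrary examples and `spark` is an assumed
// active SparkSession): the lag reporting behaviour is tuned through the confs defined above.
spark.conf.set("spark.sql.streaming.stateStore.coordinatorReportSnapshotUploadLag", "true")
spark.conf.set("spark.sql.streaming.stateStore.multiplierForMinVersionDiffToLog", "5")
spark.conf.set("spark.sql.streaming.stateStore.multiplierForMinTimeDiffToLog", "10")
spark.conf.set("spark.sql.streaming.stateStore.snapshotLagReportInterval", "5m")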
- progressReporter.stateStoreCoordinator - .constructLaggingInstanceReport( - lastExecution.runId, - lastEpochId, - triggerClock.getTimeMillis() - ) + // Ask the state store coordinator to log all lagging state stores + if (progressReporter.coordinatorReportSnapshotUploadLag) { + progressReporter.stateStoreCoordinator + .logLaggingStateStores(lastExecution.runId, lastEpochId + 1) + } // Update the value since this trigger executes a batch successfully. this.execStatsOnLatestExecutedBatch = Some(execStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index c5f4f80f5274..bbd2a4264c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform import org.apache.spark.util.{NonFateSharingCache, Utils} - private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable with SupportsFineGrainedReplay { @@ -971,7 +970,7 @@ class RocksDBStateStoreChangeDataReader( } /** - * Object used to relay events reported from a RocksDB instance to the state store coordinator. + * Class used to relay events reported from a RocksDB instance to the state store coordinator. * * We pass this into the RocksDB instance to report specific events like snapshot uploads. * This should only be used to report back to the coordinator for metrics and monitoring purposes. @@ -985,9 +984,9 @@ private[state] case class RocksDBEventListener( private val stateStoreProviderId: StateStoreProviderId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) - /** Whether the event listener should relay these messages to the state store coordinator */ - private val coordinatorReportUploadEnabled: Boolean = - storeConf.stateStoreCoordinatorReportUploadEnabled + /** Whether the coordinator is logging state stores lagging behind */ + private val coordinatorReportSnapshotUploadLagEnabled: Boolean = + storeConf.stateStoreCoordinatorReportSnapshotUploadLag /** * Callback function from RocksDB to report events to the coordinator. @@ -997,10 +996,8 @@ private[state] case class RocksDBEventListener( * @param version The snapshot version that was just uploaded from RocksDB */ def reportSnapshotUploaded(version: Long): Unit = { - // Only report to the coordinator if this is enabled, as sometimes we do not need - // to track for lagging instances. - // Also ignore message if we are missing the provider ID from lack of initialization. 
- if (coordinatorReportUploadEnabled) { + // Only report to the coordinator if it is reporting lagging stores + if (coordinatorReportSnapshotUploadLagEnabled) { // Report the provider ID and the version to the coordinator StateStore.reportSnapshotUploaded(stateStoreProviderId, version) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 07e4766b000e..26c77fd2ea3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -98,11 +98,10 @@ class StateStoreConf( StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf) /** - * Whether to report snapshot uploaded messages from the internal RocksDB instance - * to the state store coordinator. + * Whether the coordinator is reporting state stores trailing behind in snapshot uploads. */ - val stateStoreCoordinatorReportUploadEnabled: Boolean = - sqlConf.stateStoreCoordinatorReportUploadEnabled + val stateStoreCoordinatorReportSnapshotUploadLag: Boolean = + sqlConf.stateStoreCoordinatorReportSnapshotUploadLag /** * Additional configurations related to state store. This will capture all configs in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 0d97f02dc644..88c481bdb57b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -70,10 +70,9 @@ private case class ReportSnapshotUploaded( * This message is used for the coordinator to look for all state stores that are lagging behind * in snapshot uploads. The coordinator will then log a warning message for each lagging instance. */ -private case class ConstructLaggingInstanceReport( +private case class LogLaggingStateStores( queryRunId: UUID, - latestVersion: Long, - timestamp: Long) + latestVersion: Long) extends StateStoreCoordinatorMessage /** @@ -86,13 +85,12 @@ private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreProvide /** * Message used for testing. - * This message is used to retrieve the all active state store instance falling behind in - * snapshot uploads, using version and time criterias. + * This message is used to retrieve all active state store instance falling behind in + * snapshot uploads, using version and time criteria. 
*/ private case class GetLaggingStoresForTesting( queryRunId: UUID, - latestVersion: Long, - timestamp: Long) + latestVersion: Long) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -167,14 +165,9 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(storeProviderId, version, timestamp)) } - /** Ask the coordinator to find all state store instances that are lagging behind in uploads */ - private[sql] def constructLaggingInstanceReport( - queryRunId: UUID, - latestVersion: Long, - timestamp: Long): Unit = { - rpcEndpointRef.askSync[Boolean]( - ConstructLaggingInstanceReport(queryRunId, latestVersion, timestamp) - ) + /** Ask the coordinator to log all state store instances that are lagging behind in uploads */ + private[sql] def logLaggingStateStores(queryRunId: UUID, latestVersion: Long): Unit = { + rpcEndpointRef.askSync[Boolean](LogLaggingStateStores(queryRunId, latestVersion)) } /** @@ -193,10 +186,9 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private[state] def getLaggingStoresForTesting( queryRunId: UUID, - latestVersion: Long, - timestamp: Long): Seq[StateStoreProviderId] = { + latestVersion: Long): Seq[StateStoreProviderId] = { rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( - GetLaggingStoresForTesting(queryRunId, latestVersion, timestamp) + GetLaggingStoresForTesting(queryRunId, latestVersion) ) } @@ -227,7 +219,7 @@ private class StateStoreCoordinator( // Stores the last timestamp in milliseconds where the coordinator did a full report on // instances lagging behind on snapshot uploads. The initial timestamp is defaulted to // 0 milliseconds. - private var lastFullSnapshotLagReport = 0L + private var lastFullSnapshotLagReportTimeMs = 0L override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => @@ -274,11 +266,12 @@ private class StateStoreCoordinator( } context.reply(true) - case ConstructLaggingInstanceReport(queryRunId, latestVersion, endOfBatchTimestamp) => + case LogLaggingStateStores(queryRunId, latestVersion) => // Only log lagging instances if the snapshot report upload is enabled, // otherwise all instances will be considered lagging. - if (sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED)) { - val laggingStores = findLaggingStores(queryRunId, latestVersion, endOfBatchTimestamp) + val currentTimestamp = System.currentTimeMillis() + val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) + if (laggingStores.nonEmpty) { logWarning( log"StateStoreCoordinator Snapshot Lag Report for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + @@ -291,16 +284,15 @@ private class StateStoreCoordinator( // full reports. 
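// Illustration only: with the default snapshotLagReportInterval of 5 minutes, the detailed
// per-store warnings below are emitted at most once every 5 minutes, even if lagging stores
// are detected after every completed batch in between.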
val coordinatorLagReportInterval = sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) - val currentTimestamp = System.currentTimeMillis() if (laggingStores.nonEmpty && - currentTimestamp - lastFullSnapshotLagReport > coordinatorLagReportInterval) { + currentTimestamp - lastFullSnapshotLagReportTimeMs > coordinatorLagReportInterval) { // Mark timestamp of the full report and log the lagging instances - lastFullSnapshotLagReport = currentTimestamp + lastFullSnapshotLagReportTimeMs = currentTimestamp laggingStores.foreach { providerId => val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { case Some(snapshotEvent) => - val versionDelta = latestVersion - snapshotEvent.version - val timeDelta = endOfBatchTimestamp - snapshotEvent.timestamp + val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) + val timeDelta = currentTimestamp - snapshotEvent.timestamp log"StateStoreCoordinator Snapshot Lag Detected for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + @@ -314,12 +306,10 @@ private class StateStoreCoordinator( log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + log"Provider: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: never uploaded)" + log"latest snapshot: no upload for query run)" } logWarning(logMessage) } - } else if (laggingStores.nonEmpty) { - logInfo(log"StateStoreCoordinator Snapshot Lag Report - last full report was too recent") } } context.reply(true) @@ -329,8 +319,9 @@ private class StateStoreCoordinator( logDebug(s"Got latest snapshot version of the state store $providerId: $version") context.reply(version) - case GetLaggingStoresForTesting(queryRunId, latestVersion, timestamp) => - val laggingStores = findLaggingStores(queryRunId, latestVersion, timestamp) + case GetLaggingStoresForTesting(queryRunId, latestVersion) => + val currentTimestamp = System.currentTimeMillis() + val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") context.reply(laggingStores) @@ -346,14 +337,15 @@ private class StateStoreCoordinator( ) extends Ordered[SnapshotUploadEvent] { def isLagging(latestVersion: Long, latestTimestamp: Long): Boolean = { - val versionDelta = latestVersion - version + // Use version 0 for stores that have not uploaded a snapshot version for this run. + val versionDelta = latestVersion - Math.max(version, 0) val timeDelta = latestTimestamp - timestamp // Determine alert thresholds from configurations for both time and version differences. 
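// Worked example (illustration only, assuming default conf values): with
// minDeltasForSnapshot = 10, version multiplier = 5, maintenance interval = 60s and time
// multiplier = 10, the thresholds are 5 * 10 = 50 versions and 10 * 60s = 10 minutes, so a
// store is reported only when its last upload trails by more than 50 versions and is more
// than 10 minutes old.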
val snapshotVersionDeltaMultiplier = sqlConf.getConf( - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG) + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) val maintenanceIntervalMultiplier = sqlConf.getConf( - SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG) + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) val minDeltasForSnapshot = sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) val maintenanceInterval = sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) @@ -389,7 +381,7 @@ private class StateStoreCoordinator( referenceTimestamp: Long): Seq[StateStoreProviderId] = { // Do not report any instance as lagging if the snapshot report upload is disabled, // since it will treat all active instances as stores that have never uploaded. - if (!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED)) { + if (!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)) { return Seq.empty } // Look for instances that are lagging behind in snapshot uploads diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 19b669c6d5eb..3724cc44206b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -168,10 +168,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MAINTENANCE_MULTIPLIER_FOR_MIN_TIME_DELTA_TO_LOG.key -> "1", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> - "2" + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2" ) { case (coordRef, spark) => import spark.implicits._ @@ -187,18 +187,18 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() - inputData.addData(1, 2, 3) - query.processAllAvailable() - inputData.addData(1, 2, 3) - query.processAllAvailable() - val batchId = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId - val timestamp = System.currentTimeMillis() - + // Go through several rounds of input to force snapshot uploads + (0 until 5).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(1000) + } + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 // Verify that no instances are marked as lagging, even when upload messages are sent. // Since snapshot uploads are tied to commit, the lack of version difference should prevent // the stores from being marked as lagging. 
- assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) query.stop() } } @@ -211,9 +211,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> - "2" + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2" ) { case (coordRef, spark) => import spark.implicits._ @@ -237,9 +236,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val batchId = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId - val timestamp = System.currentTimeMillis() + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 // Verify all stores have uploaded a snapshot and it's logged by the coordinator (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => @@ -248,7 +246,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) query.stop() } } @@ -265,9 +263,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> - "2" + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2" ) { case (coordRef, spark) => import spark.implicits._ @@ -291,9 +288,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val batchId = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId - val timestamp = System.currentTimeMillis() + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val providerId = @@ -307,7 +303,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } // We should have two state stores (id 0 and 1) that are lagging behind at this point - val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp) + val laggingStores = 
coordRef.getLaggingStoresForTesting(query.runId, latestVersion) assert(laggingStores.size == 2) assert(laggingStores.forall(_.storeId.partitionId <= 1)) query.stop() @@ -328,9 +324,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> - "5" + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5" ) { case (coordRef, spark) => import spark.implicits._ @@ -357,9 +352,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val batchId = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId - val timestamp = System.currentTimeMillis() + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 // Verify all state stores for join queries are reporting snapshot uploads (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => @@ -373,7 +367,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp).isEmpty) + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) query.stop() } } @@ -390,9 +384,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_UPLOAD_ENABLED.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_DELTA_MULTIPLIER_FOR_MIN_VERSION_DELTA_TO_LOG.key -> - "5" + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5" ) { case (coordRef, spark) => import spark.implicits._ @@ -419,9 +412,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } val stateCheckpointDir = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val batchId = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId - val timestamp = System.currentTimeMillis() + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 // Verify all state stores for join queries are reporting snapshot uploads (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => @@ -442,7 +434,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } // Verify that only stores from partition id 0 and 1 are lagging behind. // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. 
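// Illustration only: stream-stream joins keep four state stores per partition (a
// keyToNumValues and a keyWithIndexToValue store for each join side), which is where the
// expected 2 partitions * 4 stores = 8 lagging stores comes from.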
- val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, batchId, timestamp) + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) assert(laggingStores.size == 2 * 4) assert(laggingStores.forall(_.storeId.partitionId <= 1)) } From 4ada2855dbca20b6b900049bdb08f52f5fb69f94 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 20 Mar 2025 17:24:47 -0700 Subject: [PATCH 18/36] SPARK-51358 Add additional edge cases and HDFS support and tests --- .../apache/spark/sql/internal/SQLConf.scala | 13 + .../state/HDFSBackedStateStoreProvider.scala | 4 + .../execution/streaming/state/RocksDB.scala | 16 +- .../state/RocksDBStateStoreProvider.scala | 25 +- .../streaming/state/StateStore.scala | 6 +- .../state/StateStoreCoordinator.scala | 131 ++-- .../state/StateStoreCoordinatorSuite.scala | 592 +++++++++++------- 7 files changed, 465 insertions(+), 322 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7db3a8d9e105..d2968958140c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2307,6 +2307,19 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(TimeUnit.MINUTES.toMillis(5)) + val STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT = + buildConf("spark.sql.streaming.stateStore.maxLaggingStoresToReport") + .internal() + .doc( + "Maximum number of state stores the coordinator will report as trailing in " + + "snapshot uploads. Stores are selected based on the most lagging behind in " + + "snapshot version." + ) + .version("4.1.0") + .intConf + .checkValue(k => k >= 0, "Must be greater than or equal to 0") + .createWithDefault(10) + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 8c1f2eeb41a9..75902e373d26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -677,6 +677,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with logInfo(log"Written snapshot file for version ${MDC(LogKeys.FILE_VERSION, version)} of " + log"${MDC(LogKeys.STATE_STORE_PROVIDER, this)} at ${MDC(LogKeys.FILE_NAME, targetFile)} " + log"for ${MDC(LogKeys.OP_TYPE, opType)}") + // Report snapshot upload event to the coordinator, and include the store ID with the message. 
+ if (storeConf.stateStoreCoordinatorReportSnapshotUploadLag) { + StateStore.reportSnapshotUploaded(stateStoreId, version) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 3f189cd065a3..a89142e94af2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -75,8 +75,7 @@ class RocksDB( loggingId: String = "", useColumnFamilies: Boolean = false, enableStateStoreCheckpointIds: Boolean = false, - eventListener: Option[RocksDBEventListener] = None) - extends Logging { + eventListener: Option[RocksDBEventListener] = None) extends Logging { import RocksDB._ @@ -1476,7 +1475,12 @@ class RocksDB( // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) // Report snapshot upload event to the coordinator. - eventListener.foreach(_.reportSnapshotUploaded(snapshot.version)) + if (conf.stateStoreCoordinatorReportSnapshotUploadLag) { + // Note that we still report uploads even when changelog checkpointing is enabled. + // The coordinator needs a way to determine whether upload messages are disabled or not, + // which would be different between RocksDB and HDFS stores due to changelog checkpointing. + eventListener.foreach(_.reportSnapshotUploaded(snapshot.version)) + } } finally { snapshot.close() } @@ -1725,7 +1729,8 @@ case class RocksDBConf( highPriorityPoolRatio: Double, compressionCodec: String, allowFAllocate: Boolean, - compression: String) + compression: String, + stateStoreCoordinatorReportSnapshotUploadLag: Boolean) object RocksDBConf { /** Common prefix of all confs in SQLConf that affects RocksDB */ @@ -1908,7 +1913,8 @@ object RocksDBConf { getRatioConf(HIGH_PRIORITY_POOL_RATIO_CONF), storeConf.compressionCodec, getBooleanConf(ALLOW_FALLOCATE_CONF), - getStringConf(COMPRESSION_CONF)) + getStringConf(COMPRESSION_CONF), + storeConf.stateStoreCoordinatorReportSnapshotUploadLag) } def apply(): RocksDBConf = apply(new StateStoreConf()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index bbd2a4264c52..24206c4f273f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -385,7 +385,7 @@ private[sql] class RocksDBStateStoreProvider this.useColumnFamilies = useColumnFamilies this.stateStoreEncoding = storeConf.stateStoreEncodingFormat this.stateSchemaProvider = stateSchemaProvider - this.rocksDBEventListener = RocksDBEventListener(getRunId(hadoopConf), stateStoreId, storeConf) + this.rocksDBEventListener = RocksDBEventListener(stateStoreId) if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -975,31 +975,16 @@ class RocksDBStateStoreChangeDataReader( * We pass this into the RocksDB instance to report specific events like snapshot uploads. * This should only be used to report back to the coordinator for metrics and monitoring purposes. 
*/ -private[state] case class RocksDBEventListener( - queryRunId: String, - stateStoreId: StateStoreId, - storeConf: StateStoreConf) { - - /** ID of the state store provider managing the RocksDB instance */ - private val stateStoreProviderId: StateStoreProviderId = - StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) - - /** Whether the coordinator is logging state stores lagging behind */ - private val coordinatorReportSnapshotUploadLagEnabled: Boolean = - storeConf.stateStoreCoordinatorReportSnapshotUploadLag - +private[state] case class RocksDBEventListener(stateStoreId: StateStoreId) { /** * Callback function from RocksDB to report events to the coordinator. - * Additional information such as the state store ID and the query run ID are + * Information from the store provider such as the state store ID is * attached here to report back to the coordinator. * * @param version The snapshot version that was just uploaded from RocksDB */ def reportSnapshotUploaded(version: Long): Unit = { - // Only report to the coordinator if it is reporting lagging stores - if (coordinatorReportSnapshotUploadLagEnabled) { - // Report the provider ID and the version to the coordinator - StateStore.reportSnapshotUploaded(stateStoreProviderId, version) - } + // Report the state store ID and the version to the coordinator + StateStore.reportSnapshotUploaded(stateStoreId, version) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index f87c98338ce7..59e26b71f80c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1131,12 +1131,10 @@ object StateStore extends Logging { } } - private[state] def reportSnapshotUploaded( - storeProviderId: StateStoreProviderId, - snapshotVersion: Long): Unit = { + private[state] def reportSnapshotUploaded(storeId: StateStoreId, snapshotVersion: Long): Unit = { // Attach the current timestamp of uploaded snapshot and send the message to the coordinator val currentTime = System.currentTimeMillis() - coordinatorRef.foreach(_.snapshotUploaded(storeProviderId, snapshotVersion, currentTime)) + coordinatorRef.foreach(_.snapshotUploaded(storeId, snapshotVersion, currentTime)) } private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 88c481bdb57b..51024f57546f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -60,27 +60,22 @@ private case class DeactivateInstances(runId: UUID) * This message is used to report a state store instance has just finished uploading a snapshot, * along with the timestamp in milliseconds and the snapshot version. 
*/ -private case class ReportSnapshotUploaded( - storeId: StateStoreProviderId, - version: Long, - timestamp: Long) +private case class ReportSnapshotUploaded(storeId: StateStoreId, version: Long, timestamp: Long) extends StateStoreCoordinatorMessage /** * This message is used for the coordinator to look for all state stores that are lagging behind * in snapshot uploads. The coordinator will then log a warning message for each lagging instance. */ -private case class LogLaggingStateStores( - queryRunId: UUID, - latestVersion: Long) +private case class LogLaggingStateStores(queryRunId: UUID, latestVersion: Long) extends StateStoreCoordinatorMessage /** * Message used for testing. * This message is used to retrieve the latest snapshot version reported for upload from a - * specific state store instance. + * specific state store. */ -private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreProviderId) +private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreId) extends StateStoreCoordinatorMessage /** @@ -88,9 +83,7 @@ private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreProvide * This message is used to retrieve all active state store instance falling behind in * snapshot uploads, using version and time criteria. */ -private case class GetLaggingStoresForTesting( - queryRunId: UUID, - latestVersion: Long) +private case class GetLaggingStoresForTesting(queryRunId: UUID, latestVersion: Long) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -159,10 +152,10 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { /** Inform that an executor has uploaded a snapshot */ private[sql] def snapshotUploaded( - storeProviderId: StateStoreProviderId, + storeId: StateStoreId, version: Long, timestamp: Long): Unit = { - rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(storeProviderId, version, timestamp)) + rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(storeId, version, timestamp)) } /** Ask the coordinator to log all state store instances that are lagging behind in uploads */ @@ -174,9 +167,8 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Endpoint used for testing. * Get the latest snapshot version uploaded for a state store. 
*/ - private[state] def getLatestSnapshotVersionForTesting( - stateStoreProviderId: StateStoreProviderId): Option[Long] = { - rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(stateStoreProviderId)) + private[state] def getLatestSnapshotVersionForTesting(storeId: StateStoreId): Option[Long] = { + rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(storeId)) } /** @@ -186,8 +178,8 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private[state] def getLaggingStoresForTesting( queryRunId: UUID, - latestVersion: Long): Seq[StateStoreProviderId] = { - rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( + latestVersion: Long): Seq[StateStoreId] = { + rpcEndpointRef.askSync[Seq[StateStoreId]]( GetLaggingStoresForTesting(queryRunId, latestVersion) ) } @@ -205,13 +197,12 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private class StateStoreCoordinator( override val rpcEnv: RpcEnv, val sqlConf: SQLConf) - extends ThreadSafeRpcEndpoint - with Logging { + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] - // Stores the latest snapshot upload event for a specific state store provider instance + // Stores the latest snapshot upload event for a specific state store private val stateStoreLatestUploadedSnapshot = - new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] + new mutable.HashMap[StateStoreId, SnapshotUploadEvent] // Default snapshot upload event to use when a provider has never uploaded a snapshot private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) @@ -256,13 +247,13 @@ private class StateStoreCoordinator( storeIdsToRemove.mkString(", ")) context.reply(true) - case ReportSnapshotUploaded(providerId, version, timestamp) => - // Ignore this upload event if the registered latest version for the provider is more recent, + case ReportSnapshotUploaded(storeId, version, timestamp) => + // Ignore this upload event if the registered latest version for the store is more recent, // since it's possible that an older version gets uploaded after a new executor uploads for - // the same provider but with a newer snapshot. - logDebug(s"Snapshot version $version was uploaded for provider $providerId") - if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) { - stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) + // the same state store but with a newer snapshot. 
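// Illustration only: if version 12 is already recorded for this store, a late-arriving report
// for version 10 is ignored, while a report for version 15 replaces the recorded event.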
+ logDebug(s"Snapshot version $version was uploaded for state store $storeId") + if (!stateStoreLatestUploadedSnapshot.get(storeId).exists(_.version >= version)) { + stateStoreLatestUploadedSnapshot.put(storeId, SnapshotUploadEvent(version, timestamp)) } context.reply(true) @@ -288,35 +279,40 @@ private class StateStoreCoordinator( currentTimestamp - lastFullSnapshotLagReportTimeMs > coordinatorLagReportInterval) { // Mark timestamp of the full report and log the lagging instances lastFullSnapshotLagReportTimeMs = currentTimestamp - laggingStores.foreach { providerId => - val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { - case Some(snapshotEvent) => - val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) - val timeDelta = currentTimestamp - snapshotEvent.timestamp - - log"StateStoreCoordinator Snapshot Lag Detected for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Provider: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + - log"version delta: ${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + - log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" - case None => - log"StateStoreCoordinator Snapshot Lag Detected for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Provider: ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, providerId)} " + - log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: no upload for query run)" + // Only report the stores that are lagging the most behind in snapshot uploads. + laggingStores + .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) + .take(sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT)) + .foreach { storeId => + val logMessage = stateStoreLatestUploadedSnapshot.get(storeId) match { + case Some(snapshotEvent) => + val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) + val timeDelta = currentTimestamp - snapshotEvent.timestamp + + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + + log"version delta: " + + log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + log"StateStoreCoordinator Snapshot Lag Detected for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " + + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + + log"latest snapshot: no upload for query run)" + } + logWarning(logMessage) } - logWarning(logMessage) - } } } context.reply(true) - case GetLatestSnapshotVersionForTesting(providerId) => - val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version) - logDebug(s"Got latest snapshot version of the state store $providerId: $version") + case GetLatestSnapshotVersionForTesting(storeId) => + val version = stateStoreLatestUploadedSnapshot.get(storeId).map(_.version) + logDebug(s"Got latest snapshot version of the state store $storeId: $version") context.reply(version) case 
GetLaggingStoresForTesting(queryRunId, latestVersion) => @@ -354,10 +350,9 @@ private class StateStoreCoordinator( val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval // Mark a state store as lagging if it is behind in both version and time. - // In the case that a snapshot was never uploaded, we treat version -1 as the preceding - // version of 0, and only rely on the version delta condition. - // Time requirement will be automatically satisfied as the initial timestamp is 0. - versionDelta >= minVersionDeltaForLogging && timeDelta > minTimeDeltaForLogging + // For stores that have never uploaded a snapshot, the time requirement will + // be automatically satisfied as the initial timestamp is 0. + versionDelta > minVersionDeltaForLogging && timeDelta > minTimeDeltaForLogging } override def compare(otherEvent: SnapshotUploadEvent): Int = { @@ -378,19 +373,23 @@ private class StateStoreCoordinator( private def findLaggingStores( queryRunId: UUID, referenceVersion: Long, - referenceTimestamp: Long): Seq[StateStoreProviderId] = { - // Do not report any instance as lagging if the snapshot report upload is disabled, - // since it will treat all active instances as stores that have never uploaded. + referenceTimestamp: Long): Seq[StateStoreId] = { + // Do not report any instance as lagging if report snapshot upload is disabled. if (!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)) { return Seq.empty } - // Look for instances that are lagging behind in snapshot uploads - instances.keys.filter { storeProviderId => - // Only consider instances that are part of this specific query run - storeProviderId.queryRunId == queryRunId && + // Look for state stores that are lagging behind in snapshot uploads + instances.keys + .filter { storeProviderId => + // Only consider active providers that are part of this specific query run, + // but look through all state stores under this store ID, as it's possible that + // the same query re-runs with a new run ID but has already uploaded some snapshots. + storeProviderId.queryRunId == queryRunId && stateStoreLatestUploadedSnapshot - .getOrElse(storeProviderId, defaultSnapshotUploadEvent) + .getOrElse(storeProviderId.storeId, defaultSnapshotUploadEvent) .isLagging(referenceVersion, referenceTimestamp) - }.toSeq + } + .map(_.storeId) + .toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 3724cc44206b..e1a9b33557f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -29,8 +29,34 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWra import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.functions.{count, expr} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.Utils +// RocksDBSkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running +// maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test +// snapshot upload lag can be observed through StreamingQueryProgress metrics. 
+class RocksDBSkipMaintenanceOnCertainPartitionsProvider extends RocksDBStateStoreProvider { + override def doMaintenance(): Unit = { + if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { + return + } + super.doMaintenance() + } +} + +// HDFSBackedSkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running +// maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test +// snapshot upload lag can be observed through StreamingQueryProgress metrics. +class HDFSBackedSkipMaintenanceOnCertainPartitionsProvider extends HDFSBackedStateStoreProvider { + override def doMaintenance(): Unit = { + if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { + return + } + super.doMaintenance() + } +} + + class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import StateStoreCoordinatorSuite._ @@ -157,6 +183,277 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 2).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(1000) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + 
"HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 3).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(1000) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload anything + assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's logged by the coordinator + assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + } + } + // We should have two state stores (id 0 and 1) that are lagging behind at this point + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.partitionId <= 1)) + query.stop() + } + } + } + + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + 
SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 5).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + allJoinStateStoreNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + } + } + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + query.stop() + } + } + } + + Seq( + ( + "RocksDBSkipMaintenanceOnCertainPartitionsProvider", + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + ), + ( + "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + ) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as 
"rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + // Verify all state stores for join queries are reporting snapshot uploads + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + allJoinStateStoreNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload + assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + } + } + } + // Verify that only stores from partition id 0 and 1 are lagging behind. + // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. + val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2 * 4) + assert(laggingStores.forall(_.partitionId <= 1)) + } + } + } + test( "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " + "checkpointing is disabled" @@ -165,13 +462,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { sc, SQLConf.SHUFFLE_PARTITIONS.key -> "5", SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2" + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { case (coordRef, spark) => import spark.implicits._ @@ -202,241 +500,81 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { query.stop() } } +} - test("SPARK-51358: Snapshot uploads in RocksDB are properly reported to the coordinator") { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - 
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] - val aggregated = inputData.toDF().dropDuplicates() - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = aggregated.writeStream - .format("memory") - .outputMode("update") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 2).foreach { _ => - inputData.addData(1, 2, 3) - query.processAllAvailable() - Thread.sleep(1000) - } - val stateCheckpointDir = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val latestVersion = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 - - // Verify all stores have uploaded a snapshot and it's logged by the coordinator - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => - val providerId = - StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) - query.stop() - } - } - - test( - "SPARK-51358: Snapshot uploads in RocksDBSkipMaintenanceOnCertainPartitionsProvider " + - "are properly reported to the coordinator" - ) { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] - val aggregated = inputData.toDF().dropDuplicates() - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = aggregated.writeStream - .format("memory") - .outputMode("update") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 2).foreach { _ => - inputData.addData(1, 2, 3) - query.processAllAvailable() - Thread.sleep(1000) - } - val stateCheckpointDir = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val latestVersion = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 - - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => - val providerId = - StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, partitionId), query.runId) - if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are lagging and did not upload anything - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) 
- } else { - // Verify other stores have uploaded a snapshot and it's logged by the coordinator - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - } - // We should have two state stores (id 0 and 1) that are lagging behind at this point - val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - query.stop() - } - } - - private val allJoinStateStoreNames: Seq[String] = - SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) +class StateStoreCoordinatorStreamingSuite extends StreamTest { + import testImplicits._ - test( - "SPARK-51358: Snapshot uploads for join queries with RocksDBStateStoreProvider " + - "are properly reported to the coordinator" - ) { - withCoordinatorAndSQLConf( - sc, + test("SPARK-51358: Restarting queries do not mark state stores as lagging") { + withSQLConf( SQLConf.SHUFFLE_PARTITIONS.key -> "3", SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a join query and run some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] - val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, expr("leftKey = rightKey")) - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = joined.writeStream - .format("memory") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 5).foreach { _ => - input1.addData(1, 5) - input2.addData(1, 5, 10) - query.processAllAvailable() - Thread.sleep(500) - } - val stateCheckpointDir = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val latestVersion = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 - - // Verify all state stores for join queries are reporting snapshot uploads - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => - allJoinStateStoreNames.foreach { storeName => - val providerId = - StateStoreProviderId( - StateStoreId(stateCheckpointDir, 0, partitionId, storeName), - query.runId - ) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - } - // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) - query.stop() - } - } - - test( - "SPARK-51358: Snapshot uploads for join queries with " + - "RocksDBSkipMaintenanceOnCertainPartitionsProvider are properly reported to the coordinator" - ) { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - 
SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[SkipMaintenanceOnCertainPartitionsProvider].getName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5" + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a join query and run some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] - val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, expr("leftKey = rightKey")) - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = joined.writeStream - .format("memory") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 5).foreach { _ => - input1.addData(1, 5) - input2.addData(1, 5, 10) - query.processAllAvailable() - Thread.sleep(500) - } - val stateCheckpointDir = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation - val latestVersion = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 - - // Verify all state stores for join queries are reporting snapshot uploads - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => - allJoinStateStoreNames.foreach { storeName => - val providerId = - StateStoreProviderId( - StateStoreId(stateCheckpointDir, 0, partitionId, storeName), - query.runId - ) - if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are lagging and did not upload anything - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) - } else { - // Verify other stores have uploaded a snapshot and it's logged by the coordinator - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + // Keep track of batch IDs, checkpoint directories, and snapshot versions for the next run + var stateCheckpoint = "" + var latestVersion = 0L + var latestSnapshotVersion = Seq.empty[Long] + + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) + stateCheckpoint = query.lastExecution.checkpointLocation + latestVersion = query.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + latestSnapshotVersion = (0 until numPartitions).map { partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val snapshotVersion = 
coordRef.getLatestSnapshotVersionForTesting(storeId).get + assert(snapshotVersion >= 0) + snapshotVersion } - } - } - // Verify that only stores from partition id 0 and 1 are lagging behind. - // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. - val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2 * 4) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) + // Verify that we should not have any state stores lagging behind + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + }, + StopStream + ) + + // Restart the query, but do not add any data yet so that the associated + // StateStoreProviderId (store id + query run id) in the coordinator does + // not have any uploads linked to it. + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) + + // Verify all stores still have uploaded snapshots from the previous run + (0 until numPartitions).map { partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val snapshotVersion = coordRef.getLatestSnapshotVersionForTesting(storeId).get + assert(snapshotVersion >= latestSnapshotVersion(partitionId)) + } + // Verify that we should not have any state stores lagging behind despite the restart + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + }, + StopStream + ) + } } } } From 3d79a8021b1239db1ca64ca260a2ee121a016054 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 21 Mar 2025 09:00:58 -0700 Subject: [PATCH 19/36] SPARK-51358 Make report interval's granularity per query --- .../streaming/ProgressReporter.scala | 3 ++- .../execution/streaming/state/RocksDB.scala | 2 +- .../state/StateStoreCoordinator.scala | 23 ++++++++++--------- .../state/StateStoreCoordinatorSuite.scala | 20 ++++++++++++---- .../streaming/state/ValueStateSuite.scala | 1 + 5 files changed, 32 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index ed17b55c5e74..8dfbb6615877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -292,8 +292,9 @@ abstract class ProgressContext( // Ask the state store coordinator to log all lagging state stores if (progressReporter.coordinatorReportSnapshotUploadLag) { + val latestVersion = lastEpochId + 1 progressReporter.stateStoreCoordinator - .logLaggingStateStores(lastExecution.runId, lastEpochId + 1) + .logLaggingStateStores(lastExecution.runId, latestVersion) } // Update the value since this trigger executes a batch successfully. 
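
As an illustrative aside (not part of the patch): the call above hands the coordinator the latest version (last batch ID + 1), and the coordinator then decides which stores count as lagging. The standalone Scala sketch below condenses that decision; the object, method, and parameter names are invented here, and it deliberately ignores the report-interval throttling and the run-start grace period handled elsewhere in this series. The thresholds mirror how findLaggingStores derives them from STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, STREAMING_MAINTENANCE_INTERVAL, and the two coordinator multiplier configs.

// Illustrative sketch only: not part of the patch.
object SnapshotLagCheckSketch {
  def isLagging(
      latestVersion: Long,          // latest batch ID + 1, as passed to logLaggingStateStores
      currentTimeMs: Long,          // wall-clock time when the check runs
      uploadedVersion: Long,        // last snapshot version reported for the store, -1 if none
      uploadedTimeMs: Long,         // timestamp of that upload, 0 if none
      minDeltasForSnapshot: Long,   // spark.sql.streaming.stateStore.minDeltasForSnapshot
      maintenanceIntervalMs: Long,  // spark.sql.streaming.stateStore.maintenanceInterval
      versionMultiplier: Long,      // coordinator multiplier for the minimum version delta
      timeMultiplier: Long          // coordinator multiplier for the minimum time delta
  ): Boolean = {
    val minVersionDelta = versionMultiplier * minDeltasForSnapshot
    val minTimeDelta = timeMultiplier * maintenanceIntervalMs
    // Stores that never uploaded count as having uploaded version 0 at time 0, so the time
    // requirement is satisfied immediately and only the version gap matters for them.
    val versionDelta = latestVersion - math.max(uploadedVersion, 0L)
    val timeDelta = currentTimeMs - uploadedTimeMs
    // A store is reported only when it is behind on both axes.
    versionDelta > minVersionDelta && timeDelta > minTimeDelta
  }
}

For instance, with minDeltasForSnapshot = 1 and a version multiplier of 2 (values several tests in this series use), a store must be more than 2 versions behind, and also behind on time, before it can be reported.
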
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index a89142e94af2..23d1ea9e51a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1476,7 +1476,7 @@ class RocksDB( lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) // Report snapshot upload event to the coordinator. if (conf.stateStoreCoordinatorReportSnapshotUploadLag) { - // Note that we still report uploads even when changelog checkpointing is enabled. + // Note that we still report uploads even when changelog checkpointing is disabled. // The coordinator needs a way to determine whether upload messages are disabled or not, // which would be different between RocksDB and HDFS stores due to changelog checkpointing. eventListener.foreach(_.reportSnapshotUploaded(snapshot.version)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 51024f57546f..c5da9a73fd3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -207,10 +207,10 @@ private class StateStoreCoordinator( // Default snapshot upload event to use when a provider has never uploaded a snapshot private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) - // Stores the last timestamp in milliseconds where the coordinator did a full report on - // instances lagging behind on snapshot uploads. The initial timestamp is defaulted to - // 0 milliseconds. - private var lastFullSnapshotLagReportTimeMs = 0L + // Stores the last timestamp in milliseconds for each queryRunId indicating when the + // coordinator did a report on instances lagging behind on snapshot uploads. + // The initial timestamp is defaulted to 0 milliseconds. + private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => @@ -270,15 +270,16 @@ private class StateStoreCoordinator( log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" ) // Report all stores that are behind in snapshot uploads. - // Only report the full list of providers lagging behind if the last reported time - // is not recent. The lag report interval denotes the minimum time between these - // full reports. + // Only report the list of providers lagging behind if the last reported time + // is not recent for this query run. The lag report interval denotes the minimum + // time between these full reports. 
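+      // The last report time is tracked per queryRunId and defaults to 0 ms via getOrElse,
+      // so the first lag report for a given query run is not throttled by this interval.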
val coordinatorLagReportInterval = sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) - if (laggingStores.nonEmpty && - currentTimestamp - lastFullSnapshotLagReportTimeMs > coordinatorLagReportInterval) { - // Mark timestamp of the full report and log the lagging instances - lastFullSnapshotLagReportTimeMs = currentTimestamp + val timeSinceLastReport = + currentTimestamp - lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L) + if (timeSinceLastReport > coordinatorLagReportInterval) { + // Mark timestamp of the report and log the lagging instances + lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp) // Only report the stores that are lagging the most behind in snapshot uploads. laggingStores .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index e1a9b33557f0..d05b7f866498 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -510,11 +510,12 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.SHUFFLE_PARTITIONS.key -> "3", SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "2", SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => @@ -527,6 +528,14 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Force 3 rounds of snapshot uploads. + // MIN_DELTAS_FOR_SNAPSHOT is 2, so we do this 2*3 times. + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), AddData(inputData, 1, 2, 3), ProcessAllAvailable(), AddData(inputData, 1, 2, 3), @@ -552,12 +561,15 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { }, StopStream ) - - // Restart the query, but do not add any data yet so that the associated + // Restart the query, but do not add too much data so that the associated // StateStoreProviderId (store id + query run id) in the coordinator does // not have any uploads linked to it. testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Perform one round of data, which is enough to activate instances and force a + // lagging instance report, but not enough to trigger a snapshot upload yet. 
+ AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), Execute { query => val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index 8af42d6dec26..093e8b991cc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -427,6 +427,7 @@ abstract class StateVariableSuiteBase extends SharedSparkSession before { StateStore.stop() require(!StateStore.isMaintenanceRunning) + spark.streams.stateStoreCoordinator // initialize the lazy coordinator } after { From a100f49df91fb044fcf13f8419e369cd0ffe2f38 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 21 Mar 2025 16:49:25 -0700 Subject: [PATCH 20/36] SPARK-51358 Better handling of edge case with wiped out coordinator --- .../state/HDFSBackedStateStoreProvider.scala | 5 +- .../state/RocksDBStateStoreProvider.scala | 29 ++-- .../streaming/state/StateStore.scala | 21 ++- .../state/StateStoreCoordinator.scala | 133 ++++++++++-------- .../state/StateStoreCoordinatorSuite.scala | 78 +++++++--- 5 files changed, 163 insertions(+), 103 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 75902e373d26..600beca5a96f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.util -import java.util.Locale +import java.util.{Locale, UUID} import java.util.concurrent.atomic.LongAdder import scala.collection.mutable @@ -679,7 +679,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.OP_TYPE, opType)}") // Report snapshot upload event to the coordinator, and include the store ID with the message. 
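+    // The query run ID below is read from the Hadoop configuration (StreamExecution.RUN_ID_KEY),
+    // so the coordinator can tell apart uploads from different runs of the same query.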
if (storeConf.stateStoreCoordinatorReportSnapshotUploadLag) { - StateStore.reportSnapshotUploaded(stateStoreId, version) + val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) + StateStore.reportSnapshotUploaded(StateStoreProviderId(stateStoreId, runId), version) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 24206c4f273f..b6aa478240fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StreamExecution} +import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.execution.streaming.state.StateStoreEncoding.Avro import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform @@ -67,7 +67,7 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal) val cfId = rocksDB.createColFamilyIfAbsent(colFamilyName, isInternal) val dataEncoderCacheKey = StateRowEncoderCacheKey( - queryRunId = getRunId(hadoopConf), + queryRunId = StateStoreProvider.getRunId(hadoopConf), operatorId = stateStoreId.operatorId, partitionId = stateStoreId.partitionId, stateStoreName = stateStoreId.storeName, @@ -385,7 +385,8 @@ private[sql] class RocksDBStateStoreProvider this.useColumnFamilies = useColumnFamilies this.stateStoreEncoding = storeConf.stateStoreEncodingFormat this.stateSchemaProvider = stateSchemaProvider - this.rocksDBEventListener = RocksDBEventListener(stateStoreId) + this.rocksDBEventListener = + RocksDBEventListener(StateStoreProvider.getRunId(hadoopConf), stateStoreId) if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -395,7 +396,7 @@ private[sql] class RocksDBStateStoreProvider rocksDB // lazy initialization val dataEncoderCacheKey = StateRowEncoderCacheKey( - queryRunId = getRunId(hadoopConf), + queryRunId = StateStoreProvider.getRunId(hadoopConf), operatorId = stateStoreId.operatorId, partitionId = stateStoreId.partitionId, stateStoreName = stateStoreId.storeName, @@ -799,16 +800,6 @@ object RocksDBStateStoreProvider { ) } - private def getRunId(hadoopConf: Configuration): String = { - val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) - if (runId != null) { - runId - } else { - assert(Utils.isTesting, "Failed to find query id/batch Id in task context") - UUID.randomUUID().toString - } - } - // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( @@ -975,16 +966,18 @@ class RocksDBStateStoreChangeDataReader( * We pass this into the RocksDB instance to report specific events like snapshot uploads. * This should only be used to report back to the coordinator for metrics and monitoring purposes. 
*/ -private[state] case class RocksDBEventListener(stateStoreId: StateStoreId) { +private[state] case class RocksDBEventListener(queryRunId: String, stateStoreId: StateStoreId) { + // Build the state store provider ID from the query run ID and the state store ID + private val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) /** * Callback function from RocksDB to report events to the coordinator. - * Information from the store provider such as the state store ID is + * Information from the store provider such as the state store ID and query run ID are * attached here to report back to the coordinator. * * @param version The snapshot version that was just uploaded from RocksDB */ def reportSnapshotUploaded(version: Long): Unit = { - // Report the state store ID and the version to the coordinator - StateStore.reportSnapshotUploaded(stateStoreId, version) + // Report the state store provider ID and the version to the coordinator + StateStore.reportSnapshotUploaded(providerId, version) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 59e26b71f80c..dbf998ab9a60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -652,6 +652,21 @@ object StateStoreProvider { } } } + + /** + * Get the runId from the provided hadoopConf. If it is not found, generate a random UUID. + * + * @param hadoopConf Hadoop configuration used by the StateStore to save state data + */ + private[state] def getRunId(hadoopConf: Configuration): String = { + val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY) + if (runId != null) { + runId + } else { + assert(Utils.isTesting, "Failed to find query id/batch Id in task context") + UUID.randomUUID().toString + } + } } /** @@ -1131,10 +1146,12 @@ object StateStore extends Logging { } } - private[state] def reportSnapshotUploaded(storeId: StateStoreId, snapshotVersion: Long): Unit = { + private[state] def reportSnapshotUploaded( + providerId: StateStoreProviderId, + snapshotVersion: Long): Unit = { // Attach the current timestamp of uploaded snapshot and send the message to the coordinator val currentTime = System.currentTimeMillis() - coordinatorRef.foreach(_.snapshotUploaded(storeId, snapshotVersion, currentTime)) + coordinatorRef.foreach(_.snapshotUploaded(providerId, snapshotVersion, currentTime)) } private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index c5da9a73fd3a..df5d52abac8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -57,10 +57,13 @@ private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage /** - * This message is used to report a state store instance has just finished uploading a snapshot, + * This message is used to report a state store has just finished uploading a snapshot, * along with the timestamp in milliseconds and the snapshot version. 
*/ -private case class ReportSnapshotUploaded(storeId: StateStoreId, version: Long, timestamp: Long) +private case class ReportSnapshotUploaded( + providerId: StateStoreProviderId, + version: Long, + timestamp: Long) extends StateStoreCoordinatorMessage /** @@ -75,12 +78,12 @@ private case class LogLaggingStateStores(queryRunId: UUID, latestVersion: Long) * This message is used to retrieve the latest snapshot version reported for upload from a * specific state store. */ -private case class GetLatestSnapshotVersionForTesting(storeId: StateStoreId) +private case class GetLatestSnapshotVersionForTesting(providerId: StateStoreProviderId) extends StateStoreCoordinatorMessage /** * Message used for testing. - * This message is used to retrieve all active state store instance falling behind in + * This message is used to retrieve all active state store instances falling behind in * snapshot uploads, using version and time criteria. */ private case class GetLaggingStoresForTesting(queryRunId: UUID, latestVersion: Long) @@ -152,10 +155,10 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { /** Inform that an executor has uploaded a snapshot */ private[sql] def snapshotUploaded( - storeId: StateStoreId, + providerId: StateStoreProviderId, version: Long, timestamp: Long): Unit = { - rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(storeId, version, timestamp)) + rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(providerId, version, timestamp)) } /** Ask the coordinator to log all state store instances that are lagging behind in uploads */ @@ -167,8 +170,9 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Endpoint used for testing. * Get the latest snapshot version uploaded for a state store. */ - private[state] def getLatestSnapshotVersionForTesting(storeId: StateStoreId): Option[Long] = { - rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(storeId)) + private[state] def getLatestSnapshotVersionForTesting( + providerId: StateStoreProviderId): Option[Long] = { + rpcEndpointRef.askSync[Option[Long]](GetLatestSnapshotVersionForTesting(providerId)) } /** @@ -178,8 +182,8 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private[state] def getLaggingStoresForTesting( queryRunId: UUID, - latestVersion: Long): Seq[StateStoreId] = { - rpcEndpointRef.askSync[Seq[StateStoreId]]( + latestVersion: Long): Seq[StateStoreProviderId] = { + rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( GetLaggingStoresForTesting(queryRunId, latestVersion) ) } @@ -202,7 +206,7 @@ private class StateStoreCoordinator( // Stores the latest snapshot upload event for a specific state store private val stateStoreLatestUploadedSnapshot = - new mutable.HashMap[StateStoreId, SnapshotUploadEvent] + new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] // Default snapshot upload event to use when a provider has never uploaded a snapshot private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) @@ -212,6 +216,10 @@ private class StateStoreCoordinator( // The initial timestamp is defaulted to 0 milliseconds. private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] + // Stores the start time of the query's run. Queries that started recently should not + // have their state stores reported as lagging since we may not have all the information yet. 
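+  // An entry is created lazily when the coordinator sees the first LogLaggingStateStores
+  // message for a given query run.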
+ private val queryRunStartTimeMs = new mutable.HashMap[UUID, Long] + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => logDebug(s"Reported state store $id is active at $executorId") @@ -243,24 +251,30 @@ private class StateStoreCoordinator( val storeIdsToRemove = instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove + // Also remove these instances from snapshot upload event tracking + stateStoreLatestUploadedSnapshot --= storeIdsToRemove logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) - case ReportSnapshotUploaded(storeId, version, timestamp) => + case ReportSnapshotUploaded(providerId, version, timestamp) => // Ignore this upload event if the registered latest version for the store is more recent, // since it's possible that an older version gets uploaded after a new executor uploads for // the same state store but with a newer snapshot. - logDebug(s"Snapshot version $version was uploaded for state store $storeId") - if (!stateStoreLatestUploadedSnapshot.get(storeId).exists(_.version >= version)) { - stateStoreLatestUploadedSnapshot.put(storeId, SnapshotUploadEvent(version, timestamp)) + logDebug(s"Snapshot version $version was uploaded for state store $providerId") + if (!stateStoreLatestUploadedSnapshot.get(providerId).exists(_.version >= version)) { + stateStoreLatestUploadedSnapshot.put(providerId, SnapshotUploadEvent(version, timestamp)) } context.reply(true) case LogLaggingStateStores(queryRunId, latestVersion) => + val currentTimestamp = System.currentTimeMillis() + // Mark the query run's start time if the coordinator has never seen this query run before + if (!queryRunStartTimeMs.contains(queryRunId)) { + queryRunStartTimeMs.put(queryRunId, currentTimestamp) + } // Only log lagging instances if the snapshot report upload is enabled, // otherwise all instances will be considered lagging. 
- val currentTimestamp = System.currentTimeMillis() val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) if (laggingStores.nonEmpty) { logWarning( @@ -284,15 +298,15 @@ private class StateStoreCoordinator( laggingStores .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) .take(sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT)) - .foreach { storeId => - val logMessage = stateStoreLatestUploadedSnapshot.get(storeId) match { + .foreach { providerId => + val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { case Some(snapshotEvent) => val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) val timeDelta = currentTimestamp - snapshotEvent.timestamp log"StateStoreCoordinator Snapshot Lag Detected for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " + + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + log"version delta: " + @@ -301,7 +315,7 @@ private class StateStoreCoordinator( case None => log"StateStoreCoordinator Snapshot Lag Detected for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, storeId)} " + + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + log"latest snapshot: no upload for query run)" } @@ -311,9 +325,9 @@ private class StateStoreCoordinator( } context.reply(true) - case GetLatestSnapshotVersionForTesting(storeId) => - val version = stateStoreLatestUploadedSnapshot.get(storeId).map(_.version) - logDebug(s"Got latest snapshot version of the state store $storeId: $version") + case GetLatestSnapshotVersionForTesting(providerId) => + val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version) + logDebug(s"Got latest snapshot version of the state store $providerId: $version") context.reply(version) case GetLaggingStoresForTesting(queryRunId, latestVersion) => @@ -333,29 +347,6 @@ private class StateStoreCoordinator( timestamp: Long ) extends Ordered[SnapshotUploadEvent] { - def isLagging(latestVersion: Long, latestTimestamp: Long): Boolean = { - // Use version 0 for stores that have not uploaded a snapshot version for this run. - val versionDelta = latestVersion - Math.max(version, 0) - val timeDelta = latestTimestamp - timestamp - - // Determine alert thresholds from configurations for both time and version differences. - val snapshotVersionDeltaMultiplier = sqlConf.getConf( - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) - val maintenanceIntervalMultiplier = sqlConf.getConf( - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) - val minDeltasForSnapshot = sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) - val maintenanceInterval = sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) - - // Use the configured multipliers to determine the proper alert thresholds - val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot - val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval - - // Mark a state store as lagging if it is behind in both version and time. 
- // For stores that have never uploaded a snapshot, the time requirement will - // be automatically satisfied as the initial timestamp is 0. - versionDelta > minVersionDeltaForLogging && timeDelta > minTimeDeltaForLogging - } - override def compare(otherEvent: SnapshotUploadEvent): Int = { // Compare by version first, then by timestamp as tiebreaker val versionCompare = this.version.compare(otherEvent.version) @@ -374,23 +365,45 @@ private class StateStoreCoordinator( private def findLaggingStores( queryRunId: UUID, referenceVersion: Long, - referenceTimestamp: Long): Seq[StateStoreId] = { + referenceTimestamp: Long): Seq[StateStoreProviderId] = { // Do not report any instance as lagging if report snapshot upload is disabled. if (!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)) { return Seq.empty } - // Look for state stores that are lagging behind in snapshot uploads - instances.keys - .filter { storeProviderId => - // Only consider active providers that are part of this specific query run, - // but look through all state stores under this store ID, as it's possible that - // the same query re-runs with a new run ID but has already uploaded some snapshots. - storeProviderId.queryRunId == queryRunId && - stateStoreLatestUploadedSnapshot - .getOrElse(storeProviderId.storeId, defaultSnapshotUploadEvent) - .isLagging(referenceVersion, referenceTimestamp) - } - .map(_.storeId) - .toSeq + + // Determine alert thresholds from configurations for both time and version differences. + val snapshotVersionDeltaMultiplier = + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) + val maintenanceIntervalMultiplier = + sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) + val minDeltasForSnapshot = sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + val maintenanceInterval = sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) + + // Use the configured multipliers to determine the proper alert thresholds + val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot + val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval + + // Do not report any instance as lagging if this query run started recently, since the + // coordinator may be missing some information from the state stores. + // A run is considered recent if the time between now and the start of the run does not pass + // the time requirement for lagging instances (maintenance interval, times a multiplier). + if (referenceTimestamp - queryRunStartTimeMs(queryRunId) <= minTimeDeltaForLogging) { + return Seq.empty + } + // Look for active state store providers that are lagging behind in snapshot uploads + instances.keys.filter { storeProviderId => + // Only consider providers that are part of this specific query run + val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse( + storeProviderId, + defaultSnapshotUploadEvent + ) + storeProviderId.queryRunId == queryRunId && ( + // Mark a state store as lagging if it's behind in both version and time. + // Stores that didn't upload a snapshot will be treated as a store with a snapshot of + // version 0. 
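+        // For example, with minDeltasForSnapshot = 5 and a version multiplier of 2, a store
+        // must be more than 10 versions behind (and also behind on time) to be reported.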
+ referenceVersion - Math.max(latestSnapshot.version, 0) > minVersionDeltaForLogging && + referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging + ) + }.toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index d05b7f866498..a204d362fa2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -231,7 +231,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) - assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } // Verify that we should not have any state stores lagging behind assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) @@ -293,18 +294,19 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { partitionId => val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and didn't upload anything - assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } else { // Verify other stores have uploaded a snapshot and it's logged by the coordinator - assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } // We should have two state stores (id 0 and 1) that are lagging behind at this point val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) assert(laggingStores.size == 2) - assert(laggingStores.forall(_.partitionId <= 1)) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) query.stop() } } @@ -367,7 +369,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { partitionId => allJoinStateStoreNames.foreach { storeName => val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) - assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } // Verify that we should not have any state stores lagging behind @@ -436,12 +439,13 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { partitionId => allJoinStateStoreNames.foreach { storeName => val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and didn't upload - assert(coordRef.getLatestSnapshotVersionForTesting(storeId).isEmpty) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } else { // Verify other stores have uploaded a snapshot and it's properly 
logged - assert(coordRef.getLatestSnapshotVersionForTesting(storeId).get >= 0) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } } @@ -449,7 +453,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) assert(laggingStores.size == 2 * 4) - assert(laggingStores.forall(_.partitionId <= 1)) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) } } } @@ -515,16 +519,15 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { withTempDir { srcDir => val inputData = MemoryStream[Int] val query = inputData.toDF().dropDuplicates() - // Keep track of batch IDs, checkpoint directories, and snapshot versions for the next run + // Keep track of state checkpoint directory and latest version for the second run var stateCheckpoint = "" - var latestVersion = 0L - var latestSnapshotVersion = Seq.empty[Long] + var firstRunLatestVersion = 0L testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), @@ -547,19 +550,36 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) stateCheckpoint = query.lastExecution.checkpointLocation - latestVersion = query.lastProgress.batchId + 1 + firstRunLatestVersion = query.lastProgress.batchId + 1 // Verify all stores have uploaded a snapshot and it's logged by the coordinator - latestSnapshotVersion = (0 until numPartitions).map { partitionId => + (0 until numPartitions).map { partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val snapshotVersion = coordRef.getLatestSnapshotVersionForTesting(storeId).get + val providerId = StateStoreProviderId(storeId, query.runId) + val snapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId).get assert(snapshotVersion >= 0) snapshotVersion } // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + assert(coordRef.getLaggingStoresForTesting(query.runId, firstRunLatestVersion).isEmpty) }, - StopStream + // Stopping the streaming query should deactivate and clear snapshot uploaded events + StopStream, + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) + val latestVersion = query.lastProgress.batchId + 1 + + // Verify we evicted the previous latest uploaded snapshots from the coordinator + (0 until numPartitions).map { partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + } + // Verify that we are not reporting any 
lagging stores after eviction + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + } ) // Restart the query, but do not add too much data so that the associated // StateStoreProviderId (store id + query run id) in the coordinator does @@ -573,15 +593,31 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { Execute { query => val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that we are not reporting any lagging stores despite restarting + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + }, + // Force a snapshot upload + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) - // Verify all stores still have uploaded snapshots from the previous run + // Verify that these state stores are properly restored from the checkpoint (0 until numPartitions).map { partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val snapshotVersion = coordRef.getLatestSnapshotVersionForTesting(storeId).get - assert(snapshotVersion >= latestSnapshotVersion(partitionId)) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) + assert(latestSnapshotVersion.get >= firstRunLatestVersion) } - // Verify that we should not have any state stores lagging behind despite the restart + // Verify that we are not reporting any lagging stores after the restart assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) }, StopStream From 2c07bf3e2b80bc3eb1fd8925088f9820eb2c0bea Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 21 Mar 2025 17:55:30 -0700 Subject: [PATCH 21/36] SPARK-51358 Use version requirement as well for lagging stores report --- .../state/StateStoreCoordinator.scala | 65 +++++++++++-------- .../state/StateStoreCoordinatorSuite.scala | 64 +++++++++++------- 2 files changed, 79 insertions(+), 50 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index df5d52abac8b..a2be16751438 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -216,9 +216,10 @@ private class StateStoreCoordinator( // The initial timestamp is defaulted to 0 milliseconds. private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] - // Stores the start time of the query's run. Queries that started recently should not - // have their state stores reported as lagging since we may not have all the information yet. - private val queryRunStartTimeMs = new mutable.HashMap[UUID, Long] + // Stores the time and latest version of the query run's start. + // Queries that started recently should not have their state stores reported as lagging + // since we may not have all the information yet. 
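+  // Each value is a (start timestamp in ms, latest version at start) pair for that query run.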
+ private val queryRunStartingPoint = new mutable.HashMap[UUID, (Long, Long)] override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => @@ -269,9 +270,10 @@ private class StateStoreCoordinator( case LogLaggingStateStores(queryRunId, latestVersion) => val currentTimestamp = System.currentTimeMillis() - // Mark the query run's start time if the coordinator has never seen this query run before - if (!queryRunStartTimeMs.contains(queryRunId)) { - queryRunStartTimeMs.put(queryRunId, currentTimestamp) + // Mark the query run's starting timestamp and latest version if the coordinator + // has never seen this query run before. + if (!queryRunStartingPoint.contains(queryRunId)) { + queryRunStartingPoint.put(queryRunId, (currentTimestamp, latestVersion)) } // Only log lagging instances if the snapshot report upload is enabled, // otherwise all instances will be considered lagging. @@ -342,26 +344,6 @@ private class StateStoreCoordinator( context.reply(true) } - case class SnapshotUploadEvent( - version: Long, - timestamp: Long - ) extends Ordered[SnapshotUploadEvent] { - - override def compare(otherEvent: SnapshotUploadEvent): Int = { - // Compare by version first, then by timestamp as tiebreaker - val versionCompare = this.version.compare(otherEvent.version) - if (versionCompare == 0) { - this.timestamp.compare(otherEvent.timestamp) - } else { - versionCompare - } - } - - override def toString(): String = { - s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" - } - } - private def findLaggingStores( queryRunId: UUID, referenceVersion: Long, @@ -386,8 +368,13 @@ private class StateStoreCoordinator( // Do not report any instance as lagging if this query run started recently, since the // coordinator may be missing some information from the state stores. // A run is considered recent if the time between now and the start of the run does not pass - // the time requirement for lagging instances (maintenance interval, times a multiplier). - if (referenceTimestamp - queryRunStartTimeMs(queryRunId) <= minTimeDeltaForLogging) { + // the time requirement for lagging instances. + // Similarly, the run is also considered too recent if not enough versions have passed + // since the start of the run. + val (runStartingTimeMs, runStartingVersion) = queryRunStartingPoint(queryRunId) + + if (referenceTimestamp - runStartingTimeMs <= minTimeDeltaForLogging || + referenceVersion - runStartingVersion <= minVersionDeltaForLogging) { return Seq.empty } // Look for active state store providers that are lagging behind in snapshot uploads @@ -399,6 +386,8 @@ private class StateStoreCoordinator( ) storeProviderId.queryRunId == queryRunId && ( // Mark a state store as lagging if it's behind in both version and time. + // A state store is considered lagging if it's behind in both version and time according + // to the configured thresholds. // Stores that didn't upload a snapshot will be treated as a store with a snapshot of // version 0. 
referenceVersion - Math.max(latestSnapshot.version, 0) > minVersionDeltaForLogging && @@ -407,3 +396,23 @@ private class StateStoreCoordinator( }.toSeq } } + +case class SnapshotUploadEvent( + version: Long, + timestamp: Long +) extends Ordered[SnapshotUploadEvent] { + + override def compare(otherEvent: SnapshotUploadEvent): Int = { + // Compare by version first, then by timestamp as tiebreaker + val versionCompare = this.version.compare(otherEvent.version) + if (versionCompare == 0) { + this.timestamp.compare(otherEvent.timestamp) + } else { + versionCompare + } + } + + override def toString(): String = { + s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index a204d362fa2e..45f3f3efe2bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -514,8 +514,9 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.SHUFFLE_PARTITIONS.key -> "3", SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", @@ -525,18 +526,14 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { withTempDir { srcDir => val inputData = MemoryStream[Int] val query = inputData.toDF().dropDuplicates() + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) // Keep track of state checkpoint directory and latest version for the second run var stateCheckpoint = "" var firstRunLatestVersion = 0L testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), - // Force 3 rounds of snapshot uploads. - // MIN_DELTAS_FOR_SNAPSHOT is 2, so we do this 2*3 times. 
- AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), + // Process 4 batches so that the coordinator can start reporting lagging instances AddData(inputData, 1, 2, 3), ProcessAllAvailable(), AddData(inputData, 1, 2, 3), @@ -548,7 +545,6 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { Execute { query => val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) stateCheckpoint = query.lastExecution.checkpointLocation firstRunLatestVersion = query.lastProgress.batchId + 1 @@ -556,19 +552,29 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { (0 until numPartitions).map { partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) val providerId = StateStoreProviderId(storeId, query.runId) - val snapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId).get - assert(snapshotVersion >= 0) - snapshotVersion + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } } - // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, firstRunLatestVersion).isEmpty) + // Sleep a bit to ensure that the coordinator can start reporting lagging stores. + // The sleep duration is the maintenance interval times the config's multiplier. + Thread.sleep(5 * 100) + // Verify that the normal state store (partitionId=2) is not lagging behind, + // and the faulty stores are reported as lagging. + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, firstRunLatestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) }, // Stopping the streaming query should deactivate and clear snapshot uploaded events StopStream, Execute { query => val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) val latestVersion = query.lastProgress.batchId + 1 // Verify we evicted the previous latest uploaded snapshots from the coordinator @@ -577,7 +583,8 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val providerId = StateStoreProviderId(storeId, query.runId) assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) } - // Verify that we are not reporting any lagging stores after eviction + // Verify that we are not reporting any lagging stores after eviction, + // since none of these state stores are active anymore. assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) } ) @@ -594,10 +601,11 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val latestVersion = query.lastProgress.batchId + 1 - // Verify that we are not reporting any lagging stores despite restarting + // Verify that we are not reporting any lagging stores despite restarting, + // because the query started too recently. 
assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) }, - // Force a snapshot upload + // Process 3 more batches, so that we pass the version threshold for lag reports AddData(inputData, 1, 2, 3), ProcessAllAvailable(), AddData(inputData, 1, 2, 3), @@ -608,17 +616,29 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val latestVersion = query.lastProgress.batchId + 1 - val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) // Verify that these state stores are properly restored from the checkpoint (0 until numPartitions).map { partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) val providerId = StateStoreProviderId(storeId, query.runId) val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) - assert(latestSnapshotVersion.get >= firstRunLatestVersion) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are still lagging and didn't upload + assert(latestSnapshotVersion.isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(latestSnapshotVersion.get >= firstRunLatestVersion) + } } - // Verify that we are not reporting any lagging stores after the restart - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + // Sleep a bit to ensure that the coordinator has enough time to receive upload events + // The sleep duration is the maintenance interval times the config's multiplier + Thread.sleep(5 * 100) + // Verify that we're back to reporting the faulty state stores (partitionId 0 and 1) + // since enough versions and time has passed since the query's restart. + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) }, StopStream ) From 9b7b75e1393575b061760a3f0f317fcec215d4cc Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Mon, 24 Mar 2025 12:53:11 -0700 Subject: [PATCH 22/36] SPARK-51358 Try separate coordinatorRef --- .../apache/spark/sql/internal/SQLConf.scala | 24 +++- .../state/HDFSBackedStateStoreProvider.scala | 9 +- .../state/RocksDBStateStoreProvider.scala | 9 +- .../streaming/state/StateStore.scala | 31 +++++- .../state/StateStoreCoordinator.scala | 103 +++++++++--------- .../state/StateStoreCoordinatorSuite.scala | 8 +- 6 files changed, 120 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d2968958140c..f24cf8d3bc66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2266,9 +2266,9 @@ object SQLConf { "times this multiplier." ) .version("4.1.0") - .intConf - .checkValue(k => k >= 1, "Must be greater than or equal to 1") - .createWithDefault(5) + .longConf + .checkValue(k => k >= 1L, "Must be greater than or equal to 1") + .createWithDefault(5L) val STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG = buildConf("spark.sql.streaming.stateStore.multiplierForMinTimeDiffToLog") @@ -2279,9 +2279,9 @@ object SQLConf { "current time by the configured maintenance interval, times this multiplier." 
) .version("4.1.0") - .intConf - .checkValue(k => k >= 1, "Must be greater than or equal to 1") - .createWithDefault(10) + .longConf + .checkValue(k => k >= 1L, "Must be greater than or equal to 1") + .createWithDefault(10L) val STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG = buildConf("spark.sql.streaming.stateStore.coordinatorReportSnapshotUploadLag") @@ -5864,9 +5864,21 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreSkipNullsForStreamStreamJoins: Boolean = getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS) + def stateStoreCoordinatorMultiplierForMinVersionDiffToLog: Long = + getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) + + def stateStoreCoordinatorMultiplierForMinTimeDiffToLog: Long = + getConf(STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) + def stateStoreCoordinatorReportSnapshotUploadLag: Boolean = getConf(STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG) + def stateStoreCoordinatorSnapshotLagReportInterval: Long = + getConf(STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) + + def stateStoreCoordinatorMaxLaggingStoresToReport: Int = + getConf(STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 600beca5a96f..9027d04028dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -680,7 +680,14 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with // Report snapshot upload event to the coordinator, and include the store ID with the message. 
if (storeConf.stateStoreCoordinatorReportSnapshotUploadLag) { val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) - StateStore.reportSnapshotUploaded(StateStoreProviderId(stateStoreId, runId), version) + val currentTimestamp = System.currentTimeMillis() + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded( + StateStoreProviderId(stateStoreId, runId), + version, + currentTimestamp + ) + ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index b6aa478240fe..cc0bb2cdb393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -978,6 +978,13 @@ private[state] case class RocksDBEventListener(queryRunId: String, stateStoreId: */ def reportSnapshotUploaded(version: Long): Unit = { // Report the state store provider ID and the version to the coordinator - StateStore.reportSnapshotUploaded(providerId, version) + val currentTimestamp = System.currentTimeMillis() + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded( + providerId, + version, + currentTimestamp + ) + ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dbf998ab9a60..47b62ee4de8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -593,7 +593,10 @@ trait StateStoreProvider { def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = Seq.empty } -object StateStoreProvider { +object StateStoreProvider extends Logging { + + @GuardedBy("this") + var stateStoreCoordinatorRef: StateStoreCoordinatorRef = null /** * Return a instance of the given provider class name. The instance will not be initialized. @@ -667,6 +670,32 @@ object StateStoreProvider { UUID.randomUUID().toString } } + + /** + * Create the state store coordinator reference which will be reused across state stores within + * the executor JVM process. + */ + private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { + val env = SparkEnv.get + if (env != null) { + val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER + // If running locally, then the coordinator reference in stateStoreCoordinatorRef may be have + // become inactive as SparkContext + SparkEnv may have been restarted. Hence, when running in + // driver, always recreate the reference. 
+ if (isDriver || stateStoreCoordinatorRef == null) { + logDebug("Getting StateStoreCoordinatorRef") + if (isDriver || stateStoreCoordinatorRef == null) { + stateStoreCoordinatorRef = StateStoreCoordinatorRef.forExecutor(SparkEnv.get) + logInfo(log"Retrieved reference to StateStoreCoordinator: " + + log"${MDC(LogKeys.STATE_STORE_ID, stateStoreCoordinatorRef)}") + } + } + Some(stateStoreCoordinatorRef) + } else { + stateStoreCoordinatorRef = null + None + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index a2be16751438..8ea3100bad2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -221,6 +221,10 @@ private class StateStoreCoordinator( // since we may not have all the information yet. private val queryRunStartingPoint = new mutable.HashMap[UUID, (Long, Long)] + def coordinatorLagReportInterval: Long = { + sqlConf.stateStoreCoordinatorSnapshotLagReportInterval + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => logDebug(s"Reported state store $id is active at $executorId") @@ -274,55 +278,53 @@ private class StateStoreCoordinator( // has never seen this query run before. if (!queryRunStartingPoint.contains(queryRunId)) { queryRunStartingPoint.put(queryRunId, (currentTimestamp, latestVersion)) - } - // Only log lagging instances if the snapshot report upload is enabled, - // otherwise all instances will be considered lagging. - val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) - if (laggingStores.nonEmpty) { - logWarning( - log"StateStoreCoordinator Snapshot Lag Report for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Number of state stores falling behind: " + - log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" - ) - // Report all stores that are behind in snapshot uploads. - // Only report the list of providers lagging behind if the last reported time - // is not recent for this query run. The lag report interval denotes the minimum - // time between these full reports. - val coordinatorLagReportInterval = - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL) - val timeSinceLastReport = - currentTimestamp - lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L) - if (timeSinceLastReport > coordinatorLagReportInterval) { - // Mark timestamp of the report and log the lagging instances - lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp) - // Only report the stores that are lagging the most behind in snapshot uploads. 
- laggingStores - .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) - .take(sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT)) - .foreach { providerId => - val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { - case Some(snapshotEvent) => - val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) - val timeDelta = currentTimestamp - snapshotEvent.timestamp - - log"StateStoreCoordinator Snapshot Lag Detected for " + - log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + - log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " + - log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + - log"version delta: " + - log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + - log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" - case None => + } else { + // Only log lagging instances if the snapshot report upload is enabled, + // otherwise all instances will be considered lagging. + val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) + if (laggingStores.nonEmpty) { + logWarning( + log"StateStoreCoordinator Snapshot Lag Report for " + + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + + log"Number of state stores falling behind: " + + log"${MDC(LogKeys.NUM_LAGGING_STORES, laggingStores.size)}" + ) + // Report all stores that are behind in snapshot uploads. + // Only report the list of providers lagging behind if the last reported time + // is not recent for this query run. The lag report interval denotes the minimum + // time between these full reports. + val timeSinceLastReport = + currentTimestamp - lastFullSnapshotLagReportTimeMs.getOrElse(queryRunId, 0L) + if (timeSinceLastReport > coordinatorLagReportInterval) { + // Mark timestamp of the report and log the lagging instances + lastFullSnapshotLagReportTimeMs.put(queryRunId, currentTimestamp) + // Only report the stores that are lagging the most behind in snapshot uploads. 
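As an aside, the sort in the added block relies on SnapshotUploadEvent's ordering (version first, timestamp as tiebreaker), with stores that never uploaded falling back to the default event so they surface at the front of the capped report. A self-contained sketch of that selection; the map, store names, and cap are hypothetical stand-ins for the coordinator's internal state, not code from the patch:

    // Sketch only: the class mirrors the SnapshotUploadEvent defined in this patch set.
    case class SnapshotUploadEvent(version: Long, timestamp: Long)
        extends Ordered[SnapshotUploadEvent] {
      override def compare(other: SnapshotUploadEvent): Int =
        if (version != other.version) version.compare(other.version)
        else timestamp.compare(other.timestamp)
    }

    val neverUploaded = SnapshotUploadEvent(0, 0)   // placeholder for stores with no upload yet
    val uploads = Map("store-2" -> SnapshotUploadEvent(14, 9000L))
    val activeStores = Seq("store-0", "store-1", "store-2")
    val maxToReport = 5

    val mostLagging = activeStores
      .sortBy(id => uploads.getOrElse(id, neverUploaded))
      .take(maxToReport)
    // store-0 and store-1 sort ahead of store-2 because their fallback event has version 0
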
+ laggingStores + .sortBy(stateStoreLatestUploadedSnapshot.getOrElse(_, defaultSnapshotUploadEvent)) + .take(sqlConf.stateStoreCoordinatorMaxLaggingStoresToReport) + .foreach { providerId => + val baseLogMessage = log"StateStoreCoordinator Snapshot Lag Detected for " + log"queryRunId=${MDC(LogKeys.QUERY_RUN_ID, queryRunId)} - " + log"Store ID: ${MDC(LogKeys.STATE_STORE_ID, providerId.storeId)} " + - log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}, " + - log"latest snapshot: no upload for query run)" + log"(Latest batch ID: ${MDC(LogKeys.BATCH_ID, latestVersion)}" + + val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { + case Some(snapshotEvent) => + val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) + val timeDelta = currentTimestamp - snapshotEvent.timestamp + + baseLogMessage + log", " + + log"latest snapshot: ${MDC(LogKeys.SNAPSHOT_EVENT, snapshotEvent)}, " + + log"version delta: " + + log"${MDC(LogKeys.SNAPSHOT_EVENT_VERSION_DELTA, versionDelta)}, " + + log"time delta: ${MDC(LogKeys.SNAPSHOT_EVENT_TIME_DELTA, timeDelta)}ms)" + case None => + baseLogMessage + log", latest snapshot: no upload for query run)" + } + logWarning(logMessage) } - logWarning(logMessage) - } + } } } context.reply(true) @@ -349,17 +351,16 @@ private class StateStoreCoordinator( referenceVersion: Long, referenceTimestamp: Long): Seq[StateStoreProviderId] = { // Do not report any instance as lagging if report snapshot upload is disabled. - if (!sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG)) { + if (!sqlConf.stateStoreCoordinatorReportSnapshotUploadLag) { return Seq.empty } // Determine alert thresholds from configurations for both time and version differences. val snapshotVersionDeltaMultiplier = - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG) - val maintenanceIntervalMultiplier = - sqlConf.getConf(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG) - val minDeltasForSnapshot = sqlConf.getConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) - val maintenanceInterval = sqlConf.getConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL) + sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog + val maintenanceIntervalMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog + val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot + val maintenanceInterval = sqlConf.streamingMaintenanceInterval // Use the configured multipliers to determine the proper alert thresholds val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 45f3f3efe2bd..0699c4f6f32f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -218,7 +218,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 2).foreach { _ => + (0 until 4).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() Thread.sleep(1000) @@ -282,7 +282,7 @@ class StateStoreCoordinatorSuite 
extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 3).foreach { _ => + (0 until 4).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() Thread.sleep(1000) @@ -354,7 +354,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 5).foreach { _ => + (0 until 7).foreach { _ => input1.addData(1, 5) input2.addData(1, 5, 10) query.processAllAvailable() @@ -425,7 +425,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 6).foreach { _ => + (0 until 7).foreach { _ => input1.addData(1, 5) input2.addData(1, 5, 10) query.processAllAvailable() From 7a3dca47988e00bbb372e5dc30a66a541e30a640 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Mon, 24 Mar 2025 14:51:32 -0700 Subject: [PATCH 23/36] SPARK-51358 Clean up fix for coordinatorRef --- .../spark/sql/execution/streaming/state/StateStore.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 47b62ee4de8d..7234e0c134b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -596,7 +596,7 @@ trait StateStoreProvider { object StateStoreProvider extends Logging { @GuardedBy("this") - var stateStoreCoordinatorRef: StateStoreCoordinatorRef = null + private var stateStoreCoordinatorRef: StateStoreCoordinatorRef = _ /** * Return a instance of the given provider class name. The instance will not be initialized. @@ -672,8 +672,8 @@ object StateStoreProvider extends Logging { } /** - * Create the state store coordinator reference which will be reused across state stores within - * the executor JVM process. + * Create the state store coordinator reference which will be reused across state store providers + * in the executor. 
*/ private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { val env = SparkEnv.get From ea73d47220723377e6752b04ccbacca1fadce2ef Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Mon, 24 Mar 2025 15:12:37 -0700 Subject: [PATCH 24/36] SPARK-51358 Switch to case class for query start --- .../execution/streaming/state/StateStore.scala | 18 ++++-------------- .../state/StateStoreCoordinator.scala | 12 +++++++----- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7234e0c134b8..e9f07d969d7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -679,17 +679,15 @@ object StateStoreProvider extends Logging { val env = SparkEnv.get if (env != null) { val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER - // If running locally, then the coordinator reference in stateStoreCoordinatorRef may be have + // If running locally, then the coordinator reference in stateStoreCoordinatorRef may have // become inactive as SparkContext + SparkEnv may have been restarted. Hence, when running in // driver, always recreate the reference. if (isDriver || stateStoreCoordinatorRef == null) { logDebug("Getting StateStoreCoordinatorRef") - if (isDriver || stateStoreCoordinatorRef == null) { - stateStoreCoordinatorRef = StateStoreCoordinatorRef.forExecutor(SparkEnv.get) - logInfo(log"Retrieved reference to StateStoreCoordinator: " + - log"${MDC(LogKeys.STATE_STORE_ID, stateStoreCoordinatorRef)}") - } + stateStoreCoordinatorRef = StateStoreCoordinatorRef.forExecutor(env) } + logInfo(log"Retrieved reference to StateStoreCoordinator: " + + log"${MDC(LogKeys.STATE_STORE_COORDINATOR, stateStoreCoordinatorRef)}") Some(stateStoreCoordinatorRef) } else { stateStoreCoordinatorRef = null @@ -1175,14 +1173,6 @@ object StateStore extends Logging { } } - private[state] def reportSnapshotUploaded( - providerId: StateStoreProviderId, - snapshotVersion: Long): Unit = { - // Attach the current timestamp of uploaded snapshot and send the message to the coordinator - val currentTime = System.currentTimeMillis() - coordinatorRef.foreach(_.snapshotUploaded(providerId, snapshotVersion, currentTime)) - } - private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 8ea3100bad2d..c1fc67438f25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -219,7 +219,7 @@ private class StateStoreCoordinator( // Stores the time and latest version of the query run's start. // Queries that started recently should not have their state stores reported as lagging // since we may not have all the information yet. 
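The change just below swaps the (Long, Long) tuple for a small named case class, which makes the "too recent" guard self-describing at its call sites. A minimal illustration with made-up values (names other than QueryStartInfo are hypothetical):

    // Sketch only: named fields instead of positional tuple access.
    case class QueryStartInfo(version: Long, startTimestamp: Long)

    val start = QueryStartInfo(version = 3L, startTimestamp = 1000L)
    def runIsTooRecent(latestVersion: Long, nowMs: Long,
        minVersionDelta: Long, minTimeDelta: Long): Boolean =
      nowMs - start.startTimestamp <= minTimeDelta ||
        latestVersion - start.version <= minVersionDelta
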
- private val queryRunStartingPoint = new mutable.HashMap[UUID, (Long, Long)] + private val queryRunStartingPoint = new mutable.HashMap[UUID, QueryStartInfo] def coordinatorLagReportInterval: Long = { sqlConf.stateStoreCoordinatorSnapshotLagReportInterval @@ -277,7 +277,7 @@ private class StateStoreCoordinator( // Mark the query run's starting timestamp and latest version if the coordinator // has never seen this query run before. if (!queryRunStartingPoint.contains(queryRunId)) { - queryRunStartingPoint.put(queryRunId, (currentTimestamp, latestVersion)) + queryRunStartingPoint.put(queryRunId, QueryStartInfo(latestVersion, currentTimestamp)) } else { // Only log lagging instances if the snapshot report upload is enabled, // otherwise all instances will be considered lagging. @@ -372,10 +372,10 @@ private class StateStoreCoordinator( // the time requirement for lagging instances. // Similarly, the run is also considered too recent if not enough versions have passed // since the start of the run. - val (runStartingTimeMs, runStartingVersion) = queryRunStartingPoint(queryRunId) + val queryStartInfo = queryRunStartingPoint(queryRunId) - if (referenceTimestamp - runStartingTimeMs <= minTimeDeltaForLogging || - referenceVersion - runStartingVersion <= minVersionDeltaForLogging) { + if (referenceTimestamp - queryStartInfo.startTimestamp <= minTimeDeltaForLogging || + referenceVersion - queryStartInfo.version <= minVersionDeltaForLogging) { return Seq.empty } // Look for active state store providers that are lagging behind in snapshot uploads @@ -417,3 +417,5 @@ case class SnapshotUploadEvent( s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" } } + +case class QueryStartInfo(version: Long, startTimestamp: Long) From 3de70084ed1f5089826716504c878d3799c3e7f7 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 25 Mar 2025 10:31:35 -0700 Subject: [PATCH 25/36] SPARK-51358 Add simultaneous query test --- .../state/StateStoreCoordinator.scala | 72 ++++++++++------ .../state/StateStoreCoordinatorSuite.scala | 86 ++++++++++++++++++- 2 files changed, 130 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index c1fc67438f25..a8e850d3e691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -221,6 +221,39 @@ private class StateStoreCoordinator( // since we may not have all the information yet. private val queryRunStartingPoint = new mutable.HashMap[UUID, QueryStartInfo] + def shouldCoordinatorReportSnapshotLag( + runId: UUID, + referenceVersion: Long, + referenceTimestamp: Long): Boolean = { + // Definitely do not report if it is disabled or the corresponding run id did not start yet. + if (!sqlConf.stateStoreCoordinatorReportSnapshotUploadLag || + !queryRunStartingPoint.contains(runId)) { + false + } else { + // Determine alert thresholds from configurations for both time and version differences. 
+ val snapshotVersionDeltaMultiplier = + sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog + val maintenanceIntervalMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog + val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot + val maintenanceInterval = sqlConf.streamingMaintenanceInterval + + // Use the configured multipliers to determine the proper alert thresholds + val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot + val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval + + // Do not report any instance as lagging if this query run started recently, since the + // coordinator may be missing some information from the state stores. + // A run is considered recent if the time between now and the start of the run does not pass + // the time requirement for lagging instances. + // Similarly, the run is also considered too recent if not enough versions have passed + // since the start of the run. + val queryStartInfo = queryRunStartingPoint(runId) + + referenceTimestamp - queryStartInfo.timestamp > minTimeDeltaForLogging && + referenceVersion - queryStartInfo.version > minVersionDeltaForLogging + } + } + def coordinatorLagReportInterval: Long = { sqlConf.stateStoreCoordinatorSnapshotLagReportInterval } @@ -258,6 +291,9 @@ private class StateStoreCoordinator( instances --= storeIdsToRemove // Also remove these instances from snapshot upload event tracking stateStoreLatestUploadedSnapshot --= storeIdsToRemove + // Remove the corresponding run id entries for report time and starting time + lastFullSnapshotLagReportTimeMs -= runId + queryRunStartingPoint -= runId logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) @@ -278,7 +314,7 @@ private class StateStoreCoordinator( // has never seen this query run before. if (!queryRunStartingPoint.contains(queryRunId)) { queryRunStartingPoint.put(queryRunId, QueryStartInfo(latestVersion, currentTimestamp)) - } else { + } else if (shouldCoordinatorReportSnapshotLag(queryRunId, latestVersion, currentTimestamp)) { // Only log lagging instances if the snapshot report upload is enabled, // otherwise all instances will be considered lagging. val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) @@ -336,9 +372,14 @@ private class StateStoreCoordinator( case GetLaggingStoresForTesting(queryRunId, latestVersion) => val currentTimestamp = System.currentTimeMillis() - val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) - logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") - context.reply(laggingStores) + // Only report if the corresponding run has all the necessary information to start reporting + if (shouldCoordinatorReportSnapshotLag(queryRunId, latestVersion, currentTimestamp)) { + val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) + logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") + context.reply(laggingStores) + } else { + context.reply(Seq.empty) + } case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered @@ -350,11 +391,6 @@ private class StateStoreCoordinator( queryRunId: UUID, referenceVersion: Long, referenceTimestamp: Long): Seq[StateStoreProviderId] = { - // Do not report any instance as lagging if report snapshot upload is disabled. 
- if (!sqlConf.stateStoreCoordinatorReportSnapshotUploadLag) { - return Seq.empty - } - // Determine alert thresholds from configurations for both time and version differences. val snapshotVersionDeltaMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog @@ -366,18 +402,6 @@ private class StateStoreCoordinator( val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval - // Do not report any instance as lagging if this query run started recently, since the - // coordinator may be missing some information from the state stores. - // A run is considered recent if the time between now and the start of the run does not pass - // the time requirement for lagging instances. - // Similarly, the run is also considered too recent if not enough versions have passed - // since the start of the run. - val queryStartInfo = queryRunStartingPoint(queryRunId) - - if (referenceTimestamp - queryStartInfo.startTimestamp <= minTimeDeltaForLogging || - referenceVersion - queryStartInfo.version <= minVersionDeltaForLogging) { - return Seq.empty - } // Look for active state store providers that are lagging behind in snapshot uploads instances.keys.filter { storeProviderId => // Only consider providers that are part of this specific query run @@ -392,12 +416,14 @@ private class StateStoreCoordinator( // Stores that didn't upload a snapshot will be treated as a store with a snapshot of // version 0. referenceVersion - Math.max(latestSnapshot.version, 0) > minVersionDeltaForLogging && - referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging + referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging ) }.toSeq } } +case class QueryStartInfo(version: Long, timestamp: Long) + case class SnapshotUploadEvent( version: Long, timestamp: Long @@ -417,5 +443,3 @@ case class SnapshotUploadEvent( s"SnapshotUploadEvent(version=$version, timestamp=$timestamp)" } } - -case class QueryStartInfo(version: Long, startTimestamp: Long) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 0699c4f6f32f..07687ff28c8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -218,10 +218,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 4).foreach { _ => + (0 until 6).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(1000) + Thread.sleep(500) } val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation @@ -282,10 +282,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 4).foreach { _ => + (0 until 6).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(1000) + Thread.sleep(500) } val 
streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation @@ -458,6 +458,84 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } + test("SPARK-51358: Verify coordinator properly handles simultaneous query runs") { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start and run two queries together with some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val dedupe1 = input1.toDF().dropDuplicates() + val dedupe2 = input2.toDF().dropDuplicates() + val checkpointLocation1 = Utils.createTempDir().getAbsoluteFile + val checkpointLocation2 = Utils.createTempDir().getAbsoluteFile + val query1 = dedupe1.writeStream + .format("memory") + .outputMode("update") + .queryName("query1") + .option("checkpointLocation", checkpointLocation1.toString) + .start() + val query2 = dedupe2.writeStream + .format("memory") + .outputMode("update") + .queryName("query2") + .option("checkpointLocation", checkpointLocation2.toString) + .start() + // Go through several rounds of input to force snapshot uploads for both queries + (0 until 3).foreach { _ => + input1.addData(1, 2, 3) + input2.addData(1, 2, 3) + query1.processAllAvailable() + query2.processAllAvailable() + // Process twice the amount of data for the first query + input1.addData(1, 2, 3) + query1.processAllAvailable() + Thread.sleep(1000) + } + // Verify that the coordinator logged the correct lagging stores for the first query + val streamingQuery1 = query1.asInstanceOf[StreamingQueryWrapper].streamingQuery + val latestVersion1 = streamingQuery1.lastProgress.batchId + 1 + val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, latestVersion1) + + assert(laggingStores1.size == 2) + assert(laggingStores1.forall(_.storeId.partitionId <= 1)) + assert(laggingStores1.forall(_.queryRunId == query1.runId)) + + // Verify that the second query run hasn't reported anything yet due to lack of data + val streamingQuery2 = query2.asInstanceOf[StreamingQueryWrapper].streamingQuery + var latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + var laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + assert(laggingStores2.isEmpty) + + // Process some more data for the second query to force lag reports + input2.addData(1, 2, 3) + query2.processAllAvailable() + Thread.sleep(500) + + // Verify that the coordinator logged the correct lagging stores for the second query + latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + + assert(laggingStores2.size == 2) + assert(laggingStores2.forall(_.storeId.partitionId <= 1)) + 
assert(laggingStores2.forall(_.queryRunId == query2.runId)) + } + } + test( "SPARK-51358: Snapshot uploads in RocksDB are not reported if changelog " + "checkpointing is disabled" From 70b7a8afaa0a77fc0c34250acc64a4c267dd83c4 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 25 Mar 2025 11:25:31 -0700 Subject: [PATCH 26/36] SPARK-51358 Remove additional faulty providers from tests --- .../state/StateStoreCoordinatorSuite.scala | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 07687ff28c8c..4249a60ad071 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -32,31 +32,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.Utils -// RocksDBSkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running -// maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test -// snapshot upload lag can be observed through StreamingQueryProgress metrics. -class RocksDBSkipMaintenanceOnCertainPartitionsProvider extends RocksDBStateStoreProvider { - override def doMaintenance(): Unit = { - if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { - return - } - super.doMaintenance() - } -} - -// HDFSBackedSkipMaintenanceOnCertainPartitionsProvider is a test-only provider that skips running -// maintenance for partitions 0 and 1 (these are arbitrary choices). This is used to test -// snapshot upload lag can be observed through StreamingQueryProgress metrics. -class HDFSBackedSkipMaintenanceOnCertainPartitionsProvider extends HDFSBackedStateStoreProvider { - override def doMaintenance(): Unit = { - if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) { - return - } - super.doMaintenance() - } -} - - class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { import StateStoreCoordinatorSuite._ From a69d44e1d3fc6d259ba6f239929db77af317d71d Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Wed, 26 Mar 2025 10:19:46 -0700 Subject: [PATCH 27/36] SPARK-51358 Switch default to version 0 and clean up rest of feedback --- .../state/HDFSBackedStateStoreProvider.scala | 2 +- .../execution/streaming/state/RocksDB.scala | 6 ++-- .../streaming/state/StateStoreConf.scala | 2 +- .../state/StateStoreCoordinator.scala | 30 +++++++++---------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 24ca7ddf54a0..74f386bc715f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -700,7 +700,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with // Compare and update with the version that was just uploaded. 
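A small aside on the compare-and-update just below: taking the maximum on the AtomicLong means a late-arriving report for an older snapshot never moves the tracked version backwards. A standalone sketch of the same pattern (my illustration, not code from the patch):

    import java.util.concurrent.atomic.AtomicLong

    // Tracks the highest snapshot version reported so far, regardless of report order.
    val lastUploaded = new AtomicLong(-1L)
    def record(version: Long): Long = lastUploaded.updateAndGet(v => Math.max(version, v))

    record(12L)  // returns 12
    record(10L)  // still 12; the older report does not win
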
lastUploadedSnapshotVersion.updateAndGet(v => Math.max(version, v)) // Report snapshot upload event to the coordinator, and include the store ID with the message. - if (storeConf.stateStoreCoordinatorReportSnapshotUploadLag) { + if (storeConf.reportSnapshotUploadLag) { val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) val currentTimestamp = System.currentTimeMillis() StateStoreProvider.coordinatorRef.foreach( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 23d1ea9e51a1..1446b6fd956c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -1475,7 +1475,7 @@ class RocksDB( // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) // Report snapshot upload event to the coordinator. - if (conf.stateStoreCoordinatorReportSnapshotUploadLag) { + if (conf.reportSnapshotUploadLag) { // Note that we still report uploads even when changelog checkpointing is disabled. // The coordinator needs a way to determine whether upload messages are disabled or not, // which would be different between RocksDB and HDFS stores due to changelog checkpointing. @@ -1730,7 +1730,7 @@ case class RocksDBConf( compressionCodec: String, allowFAllocate: Boolean, compression: String, - stateStoreCoordinatorReportSnapshotUploadLag: Boolean) + reportSnapshotUploadLag: Boolean) object RocksDBConf { /** Common prefix of all confs in SQLConf that affects RocksDB */ @@ -1914,7 +1914,7 @@ object RocksDBConf { storeConf.compressionCodec, getBooleanConf(ALLOW_FALLOCATE_CONF), getStringConf(COMPRESSION_CONF), - storeConf.stateStoreCoordinatorReportSnapshotUploadLag) + storeConf.reportSnapshotUploadLag) } def apply(): RocksDBConf = apply(new StateStoreConf()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 26c77fd2ea3b..e0450cfc4f69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -100,7 +100,7 @@ class StateStoreConf( /** * Whether the coordinator is reporting state stores trailing behind in snapshot uploads. 
*/ - val stateStoreCoordinatorReportSnapshotUploadLag: Boolean = + val reportSnapshotUploadLag: Boolean = sqlConf.stateStoreCoordinatorReportSnapshotUploadLag /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index a8e850d3e691..cf56085aae0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -209,14 +209,14 @@ private class StateStoreCoordinator( new mutable.HashMap[StateStoreProviderId, SnapshotUploadEvent] // Default snapshot upload event to use when a provider has never uploaded a snapshot - private val defaultSnapshotUploadEvent = SnapshotUploadEvent(-1, 0) + private val defaultSnapshotUploadEvent = SnapshotUploadEvent(0, 0) // Stores the last timestamp in milliseconds for each queryRunId indicating when the // coordinator did a report on instances lagging behind on snapshot uploads. // The initial timestamp is defaulted to 0 milliseconds. private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] - // Stores the time and latest version of the query run's start. + // Stores the time and version registered at the query run's start. // Queries that started recently should not have their state stores reported as lagging // since we may not have all the information yet. private val queryRunStartingPoint = new mutable.HashMap[UUID, QueryStartInfo] @@ -347,7 +347,7 @@ private class StateStoreCoordinator( val logMessage = stateStoreLatestUploadedSnapshot.get(providerId) match { case Some(snapshotEvent) => - val versionDelta = latestVersion - Math.max(snapshotEvent.version, 0) + val versionDelta = latestVersion - snapshotEvent.version val timeDelta = currentTimestamp - snapshotEvent.timestamp baseLogMessage + log", " + @@ -402,23 +402,23 @@ private class StateStoreCoordinator( val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval - // Look for active state store providers that are lagging behind in snapshot uploads - instances.keys.filter { storeProviderId => - // Only consider providers that are part of this specific query run - val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse( - storeProviderId, - defaultSnapshotUploadEvent - ) - storeProviderId.queryRunId == queryRunId && ( + // Look for active state store providers that are lagging behind in snapshot uploads. + // The coordinator should only consider providers that are part of this specific query run. + instances.view.keys + .filter(_.queryRunId == queryRunId) + .filter { storeProviderId => + val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse( + storeProviderId, + defaultSnapshotUploadEvent + ) // Mark a state store as lagging if it's behind in both version and time. // A state store is considered lagging if it's behind in both version and time according // to the configured thresholds. // Stores that didn't upload a snapshot will be treated as a store with a snapshot of - // version 0. - referenceVersion - Math.max(latestSnapshot.version, 0) > minVersionDeltaForLogging && + // version 0 and timestamp 0ms. 
+ referenceVersion - latestSnapshot.version > minVersionDeltaForLogging && referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging - ) - }.toSeq + }.toSeq } } From 29000ecb6e879871035e49834cb57648c7dbeb13 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 27 Mar 2025 10:57:48 -0700 Subject: [PATCH 28/36] SPARK-51358 Add additional tests --- .../state/StateStoreCoordinatorSuite.scala | 122 +++++++++++++++++- 1 file changed, 121 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 4249a60ad071..b8ff957c1563 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -546,7 +546,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until 5).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(1000) + Thread.sleep(500) } val latestVersion = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 @@ -557,6 +557,50 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { query.stop() } } + + test("SPARK-51358: Snapshot lag reports properly detects when all state stores are lagging") { + withCoordinatorAndSQLConf( + sc, + // Only use two partitions with the faulty store provider (both stores will skip uploads) + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + + // Start a query and run some data to force snapshot uploads + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Go through several rounds of input to force snapshot uploads + (0 until 5).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 + // Verify that all instances are marked as lagging, since no upload messages are being sent + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).size == 2) + query.stop() + } + } } class StateStoreCoordinatorStreamingSuite extends StreamTest { @@ -698,6 +742,82 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { } } } + + test("SPARK-51358: Restarting queries 
with updated SQLConf get propagated to the coordinator") { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "1", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + // Keep track of state checkpoint directory and latest version for the second run + var stateCheckpoint = "" + + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process multiple batches so that the coordinator can start reporting lagging instances + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + stateCheckpoint = query.lastExecution.checkpointLocation + val latestVersion = query.lastProgress.batchId + 1 + // Sleep a bit to ensure that the coordinator can start reporting lagging stores. + // The sleep duration is the maintenance interval times the config's multiplier. + Thread.sleep(5 * 100) + // Verify that only the faulty stores are reported as lagging + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + // Stopping the streaming query should deactivate and clear snapshot uploaded events + StopStream + ) + // Bump up version multiplier, which would stop the coordinator from reporting + // lagging stores for the next few versions + spark.conf + .set(SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key, "10") + // Restart the query, and verify the conf change reflects in the coordinator + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process the same amount of data as the first run + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Sleep a bit to ensure that the coordinator can start reporting lagging stores. 
+ Thread.sleep(5 * 100) + // Verify that we are not reporting any lagging stores despite restarting + // because of the higher version multiplier + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + }, + StopStream + ) + } + } + } } object StateStoreCoordinatorSuite { From 9856463bebfa029393437b35f88961ffde9a83da Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Tue, 1 Apr 2025 17:36:39 -0700 Subject: [PATCH 29/36] SPARK-51358 Fix case for AvailableNow and repeated restarts --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../streaming/IncrementalExecution.scala | 3 +- .../streaming/MicroBatchExecution.scala | 3 +- .../streaming/ProgressReporter.scala | 6 +- .../state/HDFSBackedStateStoreProvider.scala | 8 ++ .../execution/streaming/state/RocksDB.scala | 34 ++++- .../state/RocksDBStateStoreProvider.scala | 29 +++- .../state/StateStoreCoordinator.scala | 107 ++++++--------- .../state/StateStoreCoordinatorSuite.scala | 127 +++++++++++------- 9 files changed, 195 insertions(+), 124 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b30ce998b195..2b048215b429 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2353,7 +2353,7 @@ object SQLConf { .version("4.1.0") .intConf .checkValue(k => k >= 0, "Must be greater than or equal to 0") - .createWithDefault(10) + .createWithDefault(5) val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 246057a5a9d0..8c1e5e901513 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -71,7 +71,8 @@ class IncrementalExecution( MutableMap[Long, Array[Array[String]]] = MutableMap[Long, Array[Array[String]]](), val stateSchemaMetadatas: MutableMap[Long, StateSchemaBroadcast] = MutableMap[Long, StateSchemaBroadcast](), - mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) + mode: CommandExecutionMode.Value = CommandExecutionMode.ALL, + val isTerminatingTrigger: Boolean = false) extends QueryExecution(sparkSession, logicalPlan, mode = mode) with Logging { // Modified planner with stateful operations. 
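Note: the coordinator-side check adjusted later in this patch (StateStoreCoordinator.scala) effectively reduces to the sketch below, with the new isTerminatingTrigger flag threaded through from IncrementalExecution above. This is a simplified, illustrative sketch; the parameter names are stand-ins for the SQLConf values and the reported SnapshotUploadEvent fields, not the literal code in the diff.

  // Sketch only: approximate form of the coordinator's lag predicate.
  // A store that never reported an upload is treated as version 0 / timestamp 0 ms.
  def isStoreLagging(
      referenceVersion: Long,      // latest version of this query run
      referenceTimestamp: Long,    // current time in ms
      snapshotVersion: Long,       // last reported snapshot version (0 if none)
      snapshotTimestamp: Long,     // last reported upload time in ms (0 if none)
      minDeltasForSnapshot: Int,   // existing snapshot config
      maintenanceIntervalMs: Long, // existing maintenance interval config
      versionMultiplier: Int,      // multiplier conf for the version threshold
      timeMultiplier: Int,         // multiplier conf for the time threshold
      isTerminatingTrigger: Boolean): Boolean = {
    val minVersionDeltaForLogging = versionMultiplier.toLong * minDeltasForSnapshot
    val minTimeDeltaForLogging = timeMultiplier * maintenanceIntervalMs
    val isBehindOnVersions = referenceVersion - snapshotVersion > minVersionDeltaForLogging
    val isBehindOnTime = referenceTimestamp - snapshotTimestamp > minTimeDeltaForLogging
    // Self-terminating triggers (OneTimeTrigger, AvailableNowTrigger) skip the time
    // check, since maintenance may never get a chance to run between such short executions.
    isBehindOnVersions && (isTerminatingTrigger || isBehindOnTime)
  }

A store therefore only appears in the lag report once the run has advanced far enough past its last reported snapshot in both versions and wall-clock time, except under terminating triggers, where the version gap alone is enough.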
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index fe06cbb19c3a..f57da05f6d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -858,7 +858,8 @@ class MicroBatchExecution( watermarkPropagator, execCtx.previousContext.isEmpty, currentStateStoreCkptId, - stateSchemaMetadatas) + stateSchemaMetadatas, + isTerminatingTrigger = trigger.isInstanceOf[AvailableNowTrigger.type]) execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 8dfbb6615877..9a2139fa4aa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -294,7 +294,11 @@ abstract class ProgressContext( if (progressReporter.coordinatorReportSnapshotUploadLag) { val latestVersion = lastEpochId + 1 progressReporter.stateStoreCoordinator - .logLaggingStateStores(lastExecution.runId, latestVersion) + .logLaggingStateStores( + lastExecution.runId, + latestVersion, + lastExecution.isTerminatingTrigger + ) } // Update the value since this trigger executes a batch successfully. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 74f386bc715f..2d25da5f6c1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -1038,6 +1038,14 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with logDebug(s"Loading snapshot at version $snapshotVersion and apply delta files to version " + s"$endVersion takes $elapsedMs ms.") + // Report snapshot version loaded to the coordinator, and include the store ID with the message + if (storeConf.reportSnapshotUploadLag) { + val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) + // Since the snapshot was uploaded at a previous time, we set the upload timestamp as 0ms. 
+ StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), snapshotVersion, 0L)) + } + result } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 5ff353743a58..0713a22068a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -64,7 +64,6 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param stateStoreId StateStoreId for the state store * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system - * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs * @param eventListener The RocksDBEventListener object for reporting events to the coordinator */ class RocksDB( @@ -390,6 +389,9 @@ class RocksDB( // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) + // Report this snapshot version to the coordinator + reportSnapshotVersionToCoordinator(latestSnapshotVersion) + openLocalRocksDB(metadata) if (loadedVersion != version) { @@ -467,6 +469,9 @@ class RocksDB( // Initialize maxVersion upon successful load from DFS fileManager.setMaxSeenVersion(version) + // Report this snapshot version to the coordinator + reportSnapshotVersionToCoordinator(latestSnapshotVersion) + openLocalRocksDB(metadata) if (loadedVersion != version) { @@ -604,6 +609,8 @@ class RocksDB( loadedVersion = -1 // invalidate loaded data throw t } + // Report this snapshot version to the coordinator + reportSnapshotVersionToCoordinator(snapshotVersion) this } @@ -1478,12 +1485,7 @@ class RocksDB( // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(snapshot.version, v)) // Report snapshot upload event to the coordinator. - if (conf.reportSnapshotUploadLag) { - // Note that we still report uploads even when changelog checkpointing is disabled. - // The coordinator needs a way to determine whether upload messages are disabled or not, - // which would be different between RocksDB and HDFS stores due to changelog checkpointing. - eventListener.foreach(_.reportSnapshotUploaded(snapshot.version)) - } + reportSnapshotUploadToCoordinator(snapshot.version) } finally { snapshot.close() } @@ -1491,6 +1493,24 @@ class RocksDB( fileManagerMetrics } + /** Reports to the coordinator with the event listener that a snapshot finished uploading */ + private def reportSnapshotUploadToCoordinator(version: Long): Unit = { + if (conf.reportSnapshotUploadLag) { + // Note that we still report snapshot versions even when changelog checkpointing is disabled. + // The coordinator needs a way to determine whether upload messages are disabled or not, + // which would be different between RocksDB and HDFS stores due to changelog checkpointing. 
+ eventListener.foreach(_.reportSnapshotUploaded(version)) + } + } + + /** Reports to the coordinator the store's latest loaded snapshot version */ + private def reportSnapshotVersionToCoordinator(version: Long): Unit = { + // Skip reporting if the snapshot version is 0, which means there are no snapshots + if (conf.reportSnapshotUploadLag && version > 0L) { + eventListener.foreach(_.reportSnapshotVersion(version)) + } + } + /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. */ private def createLogger(): Logger = { val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 79c83598c4e7..e0be9e2a911d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -533,8 +533,17 @@ private[sql] class RocksDBStateStoreProvider s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})" val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) - new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, - useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId, Some(rocksDBEventListener)) + new RocksDB( + dfsRootDir, + RocksDBConf(storeConf), + localRootDir, + hadoopConf, + storeIdStr, + useColumnFamilies, + storeConf.enableStateStoreCheckpointIds, + stateStoreId.partitionId, + Some(rocksDBEventListener) + ) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, @@ -974,6 +983,7 @@ class RocksDBStateStoreChangeDataReader( private[state] case class RocksDBEventListener(queryRunId: String, stateStoreId: StateStoreId) { // Build the state store provider ID from the query run ID and the state store ID private val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) + /** * Callback function from RocksDB to report events to the coordinator. * Information from the store provider such as the state store ID and query run ID are @@ -992,4 +1002,19 @@ private[state] case class RocksDBEventListener(queryRunId: String, stateStoreId: ) ) } + + /** + * Report the state store's last loaded snapshot version to the coordinator. + * This method is used when the store loads a snapshot version instead of + * uploading a snapshot version, as the context of the snapshot upload timestamp + * is missing. + * + * @param version The snapshot version that was just uploaded from RocksDB + */ + def reportSnapshotVersion(version: Long): Unit = { + // Report the state store provider ID and the version to the coordinator. + // Since this is not the time when the snapshot was uploaded, we'll use 0ms to + // prevent the coordinator to use time lag checks for this store. 
+ StateStoreProvider.coordinatorRef.foreach(_.snapshotUploaded(providerId, version, 0L)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index cf56085aae0a..c7df34c77e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -70,7 +70,10 @@ private case class ReportSnapshotUploaded( * This message is used for the coordinator to look for all state stores that are lagging behind * in snapshot uploads. The coordinator will then log a warning message for each lagging instance. */ -private case class LogLaggingStateStores(queryRunId: UUID, latestVersion: Long) +private case class LogLaggingStateStores( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean) extends StateStoreCoordinatorMessage /** @@ -86,7 +89,10 @@ private case class GetLatestSnapshotVersionForTesting(providerId: StateStoreProv * This message is used to retrieve all active state store instances falling behind in * snapshot uploads, using version and time criteria. */ -private case class GetLaggingStoresForTesting(queryRunId: UUID, latestVersion: Long) +private case class GetLaggingStoresForTesting( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean) extends StateStoreCoordinatorMessage private object StopCoordinator @@ -162,8 +168,12 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { } /** Ask the coordinator to log all state store instances that are lagging behind in uploads */ - private[sql] def logLaggingStateStores(queryRunId: UUID, latestVersion: Long): Unit = { - rpcEndpointRef.askSync[Boolean](LogLaggingStateStores(queryRunId, latestVersion)) + private[sql] def logLaggingStateStores( + queryRunId: UUID, + latestVersion: Long, + isTerminatingTrigger: Boolean): Unit = { + rpcEndpointRef.askSync[Boolean]( + LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger)) } /** @@ -182,9 +192,10 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { */ private[state] def getLaggingStoresForTesting( queryRunId: UUID, - latestVersion: Long): Seq[StateStoreProviderId] = { + latestVersion: Long, + isTerminatingTrigger: Boolean = false): Seq[StateStoreProviderId] = { rpcEndpointRef.askSync[Seq[StateStoreProviderId]]( - GetLaggingStoresForTesting(queryRunId, latestVersion) + GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) ) } @@ -216,47 +227,11 @@ private class StateStoreCoordinator( // The initial timestamp is defaulted to 0 milliseconds. private val lastFullSnapshotLagReportTimeMs = new mutable.HashMap[UUID, Long] - // Stores the time and version registered at the query run's start. - // Queries that started recently should not have their state stores reported as lagging - // since we may not have all the information yet. - private val queryRunStartingPoint = new mutable.HashMap[UUID, QueryStartInfo] - - def shouldCoordinatorReportSnapshotLag( - runId: UUID, - referenceVersion: Long, - referenceTimestamp: Long): Boolean = { - // Definitely do not report if it is disabled or the corresponding run id did not start yet. 
- if (!sqlConf.stateStoreCoordinatorReportSnapshotUploadLag || - !queryRunStartingPoint.contains(runId)) { - false - } else { - // Determine alert thresholds from configurations for both time and version differences. - val snapshotVersionDeltaMultiplier = - sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog - val maintenanceIntervalMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinTimeDiffToLog - val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot - val maintenanceInterval = sqlConf.streamingMaintenanceInterval - - // Use the configured multipliers to determine the proper alert thresholds - val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot - val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval - - // Do not report any instance as lagging if this query run started recently, since the - // coordinator may be missing some information from the state stores. - // A run is considered recent if the time between now and the start of the run does not pass - // the time requirement for lagging instances. - // Similarly, the run is also considered too recent if not enough versions have passed - // since the start of the run. - val queryStartInfo = queryRunStartingPoint(runId) - - referenceTimestamp - queryStartInfo.timestamp > minTimeDeltaForLogging && - referenceVersion - queryStartInfo.version > minVersionDeltaForLogging - } - } + private def shouldCoordinatorReportSnapshotLag: Boolean = + sqlConf.stateStoreCoordinatorReportSnapshotUploadLag - def coordinatorLagReportInterval: Long = { + private def coordinatorLagReportInterval: Long = sqlConf.stateStoreCoordinatorSnapshotLagReportInterval - } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId, providerIdsToCheck) => @@ -293,7 +268,6 @@ private class StateStoreCoordinator( stateStoreLatestUploadedSnapshot --= storeIdsToRemove // Remove the corresponding run id entries for report time and starting time lastFullSnapshotLagReportTimeMs -= runId - queryRunStartingPoint -= runId logDebug(s"Deactivating instances related to checkpoint location $runId: " + storeIdsToRemove.mkString(", ")) context.reply(true) @@ -308,16 +282,13 @@ private class StateStoreCoordinator( } context.reply(true) - case LogLaggingStateStores(queryRunId, latestVersion) => + case LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger) => val currentTimestamp = System.currentTimeMillis() - // Mark the query run's starting timestamp and latest version if the coordinator - // has never seen this query run before. - if (!queryRunStartingPoint.contains(queryRunId)) { - queryRunStartingPoint.put(queryRunId, QueryStartInfo(latestVersion, currentTimestamp)) - } else if (shouldCoordinatorReportSnapshotLag(queryRunId, latestVersion, currentTimestamp)) { - // Only log lagging instances if the snapshot report upload is enabled, - // otherwise all instances will be considered lagging. - val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) + // Only log lagging instances if snapshot lag reporting and uploading is enabled, + // otherwise all instances will be considered lagging. 
+ if (shouldCoordinatorReportSnapshotLag) { + val laggingStores = + findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger) if (laggingStores.nonEmpty) { logWarning( log"StateStoreCoordinator Snapshot Lag Report for " + @@ -370,11 +341,12 @@ private class StateStoreCoordinator( logDebug(s"Got latest snapshot version of the state store $providerId: $version") context.reply(version) - case GetLaggingStoresForTesting(queryRunId, latestVersion) => + case GetLaggingStoresForTesting(queryRunId, latestVersion, isTerminatingTrigger) => val currentTimestamp = System.currentTimeMillis() - // Only report if the corresponding run has all the necessary information to start reporting - if (shouldCoordinatorReportSnapshotLag(queryRunId, latestVersion, currentTimestamp)) { - val laggingStores = findLaggingStores(queryRunId, latestVersion, currentTimestamp) + // Only report if snapshot lag reporting is enabled + if (shouldCoordinatorReportSnapshotLag) { + val laggingStores = + findLaggingStores(queryRunId, latestVersion, currentTimestamp, isTerminatingTrigger) logDebug(s"Got lagging state stores: ${laggingStores.mkString(", ")}") context.reply(laggingStores) } else { @@ -390,7 +362,8 @@ private class StateStoreCoordinator( private def findLaggingStores( queryRunId: UUID, referenceVersion: Long, - referenceTimestamp: Long): Seq[StateStoreProviderId] = { + referenceTimestamp: Long, + isTerminatingTrigger: Boolean): Seq[StateStoreProviderId] = { // Determine alert thresholds from configurations for both time and version differences. val snapshotVersionDeltaMultiplier = sqlConf.stateStoreCoordinatorMultiplierForMinVersionDiffToLog @@ -407,6 +380,8 @@ private class StateStoreCoordinator( instances.view.keys .filter(_.queryRunId == queryRunId) .filter { storeProviderId => + // Stores that didn't upload a snapshot will be treated as a store with a snapshot of + // version 0 and timestamp 0ms. val latestSnapshot = stateStoreLatestUploadedSnapshot.getOrElse( storeProviderId, defaultSnapshotUploadEvent @@ -414,16 +389,18 @@ private class StateStoreCoordinator( // Mark a state store as lagging if it's behind in both version and time. // A state store is considered lagging if it's behind in both version and time according // to the configured thresholds. - // Stores that didn't upload a snapshot will be treated as a store with a snapshot of - // version 0 and timestamp 0ms. - referenceVersion - latestSnapshot.version > minVersionDeltaForLogging && + val isBehindOnVersions = + referenceVersion - latestSnapshot.version > minVersionDeltaForLogging + val isBehindOnTime = referenceTimestamp - latestSnapshot.timestamp > minTimeDeltaForLogging + // If the query is using a trigger that self-terminates like OneTimeTrigger + // and AvailableNowTrigger, we ignore the time threshold check as the upload frequency + // is not fully dependent on the maintenance interval. 
+ isBehindOnVersions && (isTerminatingTrigger || isBehindOnTime) }.toSeq } } -case class QueryStartInfo(version: Long, timestamp: Long) - case class SnapshotUploadEvent( version: Long, timestamp: Long diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index b8ff957c1563..d3c63edf1845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWra import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.functions.{count, expr} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -471,7 +471,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation2.toString) .start() // Go through several rounds of input to force snapshot uploads for both queries - (0 until 3).foreach { _ => + (0 until 2).foreach { _ => input1.addData(1, 2, 3) input2.addData(1, 2, 3) query1.processAllAvailable() @@ -611,7 +611,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { SQLConf.SHUFFLE_PARTITIONS.key -> "3", SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", @@ -624,13 +624,13 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val inputData = MemoryStream[Int] val query = inputData.toDF().dropDuplicates() val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) - // Keep track of state checkpoint directory and latest version for the second run + // Keep track of state checkpoint directory for the second run var stateCheckpoint = "" - var firstRunLatestVersion = 0L testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), - // Process 4 batches so that the coordinator can start reporting lagging instances + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), AddData(inputData, 1, 2, 3), ProcessAllAvailable(), AddData(inputData, 1, 2, 3), @@ -643,9 +643,9 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator stateCheckpoint = query.lastExecution.checkpointLocation - firstRunLatestVersion = query.lastProgress.batchId + 1 + val latestVersion = query.lastProgress.batchId + 1 - // Verify all stores have uploaded a snapshot and it's logged by the coordinator + // Verify the coordinator logged snapshot uploads (0 until numPartitions).map { partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) val providerId = StateStoreProviderId(storeId, query.runId) @@ -657,13 +657,10 
@@ class StateStoreCoordinatorStreamingSuite extends StreamTest { assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } - // Sleep a bit to ensure that the coordinator can start reporting lagging stores. - // The sleep duration is the maintenance interval times the config's multiplier. - Thread.sleep(5 * 100) // Verify that the normal state store (partitionId=2) is not lagging behind, // and the faulty stores are reported as lagging. val laggingStores = - coordRef.getLaggingStoresForTesting(query.runId, firstRunLatestVersion) + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) assert(laggingStores.size == 2) assert(laggingStores.forall(_.storeId.partitionId <= 1)) }, @@ -685,9 +682,8 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) } ) - // Restart the query, but do not add too much data so that the associated - // StateStoreProviderId (store id + query run id) in the coordinator does - // not have any uploads linked to it. + // Restart the query, but do not add too much data so that we don't associate + // the current StateStoreProviderId (store id + query run id) with any new uploads. testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), // Perform one round of data, which is enough to activate instances and force a @@ -698,23 +694,8 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val latestVersion = query.lastProgress.batchId + 1 - // Verify that we are not reporting any lagging stores despite restarting, - // because the query started too recently. - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) - }, - // Process 3 more batches, so that we pass the version threshold for lag reports - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - - // Verify that these state stores are properly restored from the checkpoint + // Verify that the state stores have restored their snapshot version from the + // checkpoint and reported their current version (0 until numPartitions).map { partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) val providerId = StateStoreProviderId(storeId, query.runId) @@ -724,14 +705,10 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { assert(latestSnapshotVersion.isEmpty) } else { // Verify other stores have uploaded a snapshot and it's properly logged - assert(latestSnapshotVersion.get >= firstRunLatestVersion) + assert(latestSnapshotVersion.get > 0) } } - // Sleep a bit to ensure that the coordinator has enough time to receive upload events - // The sleep duration is the maintenance interval times the config's multiplier - Thread.sleep(5 * 100) - // Verify that we're back to reporting the faulty state stores (partitionId 0 and 1) - // since enough versions and time has passed since the query's restart. 
+ // Verify that we're reporting the faulty state stores (partitionId 0 and 1) val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) assert(laggingStores.size == 2) @@ -760,8 +737,6 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { withTempDir { srcDir => val inputData = MemoryStream[Int] val query = inputData.toDF().dropDuplicates() - // Keep track of state checkpoint directory and latest version for the second run - var stateCheckpoint = "" testStream(query)( StartStream(checkpointLocation = srcDir.getCanonicalPath), @@ -775,11 +750,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { Execute { query => val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - stateCheckpoint = query.lastExecution.checkpointLocation val latestVersion = query.lastProgress.batchId + 1 - // Sleep a bit to ensure that the coordinator can start reporting lagging stores. - // The sleep duration is the maintenance interval times the config's multiplier. - Thread.sleep(5 * 100) // Verify that only the faulty stores are reported as lagging val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) @@ -807,8 +778,6 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val latestVersion = query.lastProgress.batchId + 1 - // Sleep a bit to ensure that the coordinator can start reporting lagging stores. - Thread.sleep(5 * 100) // Verify that we are not reporting any lagging stores despite restarting // because of the higher version multiplier assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) @@ -818,6 +787,72 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { } } } + + test("SPARK-51358: Infrequent maintenance using Trigger.AvailableNow should be reported") { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + + // Populate state stores with an initial snapshot, so that timestamp isn't marked + // as the default 0ms. 
+ testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable() + ) + // Increase maintenance interval to a much larger value to stop snapshot uploads + spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000") + // Execute a few batches in a short span + testStream(query)( + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + // Verify the query ran with the AvailableNow trigger + assert(query.lastExecution.isTerminatingTrigger) + }, + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => query.awaitTermination() }, + // Start without available now, otherwise the stream closes too quickly for the + // testing RPC call to report lagging state stores + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process data to activate state stores, but not enough to trigger snapshot uploads + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that all faulty stores are reported as lagging despite the short burst. + // This test scenario mimics cases where snapshots have not been uploaded for a while + // due to the short running duration of AvailableNow. + val laggingStores = coordRef. + getLaggingStoresForTesting(query.runId, latestVersion, isTerminatingTrigger = true) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) + } + } + } } object StateStoreCoordinatorSuite { From b2a7ccb9e07ecfe92dc949b3ee05522351c1272f Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 3 Apr 2025 15:21:31 -0700 Subject: [PATCH 30/36] SPARK-51358 Add extra tests for HDFS as well --- .../state/HDFSBackedStateStoreProvider.scala | 29 +- .../state/StateStoreCoordinatorSuite.scala | 368 ++++++++++-------- 2 files changed, 217 insertions(+), 180 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 2d25da5f6c1c..c33d7073e591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -551,6 +551,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with val snapshotCurrentVersionMap = readSnapshotFile(version) if (snapshotCurrentVersionMap.isDefined) { synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } + + // Report the loaded snapshot's version to the coordinator + reportSnapshotVersionToCoordinator(version) + return snapshotCurrentVersionMap.get } @@ -580,6 +584,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with } synchronized { putStateIntoStateCacheMap(version, resultMap) } + + // Report the last available snapshot's version to the coordinator + reportSnapshotVersionToCoordinator(lastAvailableVersion) + resultMap } @@ -1038,14 +1046,6 @@ 
private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with logDebug(s"Loading snapshot at version $snapshotVersion and apply delta files to version " + s"$endVersion takes $elapsedMs ms.") - // Report snapshot version loaded to the coordinator, and include the store ID with the message - if (storeConf.reportSnapshotUploadLag) { - val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) - // Since the snapshot was uploaded at a previous time, we set the upload timestamp as 0ms. - StateStoreProvider.coordinatorRef.foreach( - _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), snapshotVersion, 0L)) - } - result } @@ -1063,6 +1063,19 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keySchema, valueSchema) } + + /** Reports to the coordinator the store's loaded snapshot version */ + private def reportSnapshotVersionToCoordinator(version: Long): Unit = { + // Skip reporting if the snapshot version is 0, which means there are no snapshots + if (storeConf.reportSnapshotUploadLag && version > 0L) { + val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) + // Since this is not the time when the snapshot was uploaded, we'll use 0ms to + // prevent the coordinator to use time lag checks for this store. + StateStoreProvider.coordinatorRef.foreach( + _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), version, 0L) + ) + } + } } /** [[StateStoreChangeDataReader]] implementation for [[HDFSBackedStateStoreProvider]] */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index d3c63edf1845..8c858fbe60cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -606,118 +606,128 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { class StateStoreCoordinatorStreamingSuite extends StreamTest { import testImplicits._ - test("SPARK-51358: Restarting queries do not mark state stores as lagging") { - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" - ) { - withTempDir { srcDir => - val inputData = MemoryStream[Int] - val query = inputData.toDF().dropDuplicates() - val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) - // Keep track of state checkpoint directory for the second run - var stateCheckpoint = "" - - testStream(query)( - StartStream(checkpointLocation = srcDir.getCanonicalPath), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 
2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - stateCheckpoint = query.lastExecution.checkpointLocation - val latestVersion = query.lastProgress.batchId + 1 - - // Verify the coordinator logged snapshot uploads - (0 until numPartitions).map { partitionId => - val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are lagging and didn't upload - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) - } else { - // Verify other stores have uploaded a snapshot and it's properly logged - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + Seq( + ("RocksDB", classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName), + ("HDFS", classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Restarting queries do not mark state stores as lagging for $providerName" + ) { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) + // Keep track of state checkpoint directory for the second run + var stateCheckpoint = "" + + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + stateCheckpoint = query.lastExecution.checkpointLocation + val latestVersion = query.lastProgress.batchId + 1 + + // Verify the coordinator logged snapshot uploads + (0 until numPartitions).map { + partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } + } + // Verify that the normal state store (partitionId=2) is not lagging 
behind, + // and the faulty stores are reported as lagging. + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + // Stopping the streaming query should deactivate and clear snapshot uploaded events + StopStream, + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + + // Verify we evicted the previous latest uploaded snapshots from the coordinator + (0 until numPartitions).map { partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + } + // Verify that we are not reporting any lagging stores after eviction, + // since none of these state stores are active anymore. + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) } - } - // Verify that the normal state store (partitionId=2) is not lagging behind, - // and the faulty stores are reported as lagging. - val laggingStores = - coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - }, - // Stopping the streaming query should deactivate and clear snapshot uploaded events - StopStream, - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - - // Verify we evicted the previous latest uploaded snapshots from the coordinator - (0 until numPartitions).map { partitionId => - val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) - } - // Verify that we are not reporting any lagging stores after eviction, - // since none of these state stores are active anymore. - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + ) + // Restart the query, but do not add too much data so that we don't associate + // the current StateStoreProviderId (store id + query run id) with any new uploads. + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Perform one round of data, which is enough to activate instances and force a + // lagging instance report, but not enough to trigger a snapshot upload yet. 
+ AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that the state stores have restored their snapshot version from the + // checkpoint and reported their current version + (0 until numPartitions).map { + partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = + coordRef.getLatestSnapshotVersionForTesting(providerId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are still lagging and didn't upload + assert(latestSnapshotVersion.isEmpty) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(latestSnapshotVersion.get > 0) + } + } + // Verify that we're reporting the faulty state stores (partitionId 0 and 1) + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) } - ) - // Restart the query, but do not add too much data so that we don't associate - // the current StateStoreProviderId (store id + query run id) with any new uploads. - testStream(query)( - StartStream(checkpointLocation = srcDir.getCanonicalPath), - // Perform one round of data, which is enough to activate instances and force a - // lagging instance report, but not enough to trigger a snapshot upload yet. - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - // Verify that the state stores have restored their snapshot version from the - // checkpoint and reported their current version - (0 until numPartitions).map { partitionId => - val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) - if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are still lagging and didn't upload - assert(latestSnapshotVersion.isEmpty) - } else { - // Verify other stores have uploaded a snapshot and it's properly logged - assert(latestSnapshotVersion.get > 0) - } - } - // Verify that we're reporting the faulty state stores (partitionId 0 and 1) - val laggingStores = - coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - }, - StopStream - ) + } } - } } test("SPARK-51358: Restarting queries with updated SQLConf get propagated to the coordinator") { @@ -788,70 +798,84 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { } } - test("SPARK-51358: Infrequent maintenance using Trigger.AvailableNow should be reported") { - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "2", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - 
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" - ) { - withTempDir { srcDir => - val inputData = MemoryStream[Int] - val query = inputData.toDF().dropDuplicates() - - // Populate state stores with an initial snapshot, so that timestamp isn't marked - // as the default 0ms. - testStream(query)( - StartStream(checkpointLocation = srcDir.getCanonicalPath), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable() - ) - // Increase maintenance interval to a much larger value to stop snapshot uploads - spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000") - // Execute a few batches in a short span - testStream(query)( - AddData(inputData, 1, 2, 3), - StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), - Execute { query => - query.awaitTermination() - // Verify the query ran with the AvailableNow trigger - assert(query.lastExecution.isTerminatingTrigger) - }, - AddData(inputData, 1, 2, 3), - StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), - Execute { query => query.awaitTermination() }, - // Start without available now, otherwise the stream closes too quickly for the - // testing RPC call to report lagging state stores - StartStream(checkpointLocation = srcDir.getCanonicalPath), - // Process data to activate state stores, but not enough to trigger snapshot uploads - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - // Verify that all faulty stores are reported as lagging despite the short burst. - // This test scenario mimics cases where snapshots have not been uploaded for a while - // due to the short running duration of AvailableNow. - val laggingStores = coordRef. 
- getLaggingStoresForTesting(query.runId, latestVersion, isTerminatingTrigger = true) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - }, - StopStream - ) + Seq( + ("RocksDB", classOf[RocksDBStateStoreProvider].getName), + ("HDFS", classOf[HDFSBackedStateStoreProvider].getName) + ).foreach { + case (providerName, providerClassName) => + test( + s"SPARK-51358: Infrequent maintenance with $providerName using Trigger.AvailableNow " + + s"should be reported" + ) { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" + ) { + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + + // Populate state stores with an initial snapshot, so that timestamp isn't marked + // as the default 0ms. + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable() + ) + // Increase maintenance interval to a much larger value to stop snapshot uploads + spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000") + // Execute a few batches in a short span + testStream(query)( + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + // Verify the query ran with the AvailableNow trigger + assert(query.lastExecution.isTerminatingTrigger) + }, + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + }, + // Start without available now, otherwise the stream closes too quickly for the + // testing RPC call to report lagging state stores + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process data to activate state stores, but not enough to trigger snapshot uploads + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that all faulty stores are reported as lagging despite the short burst. + // This test scenario mimics cases where snapshots have not been uploaded for + // a while due to the short running duration of AvailableNow. 
+ val laggingStores = coordRef.getLaggingStoresForTesting( + query.runId, + latestVersion, + isTerminatingTrigger = true + ) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) + } + } } - } } } From ec585a32b4a93fdc91c90c970e38f20ddf52bb0c Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 4 Apr 2025 07:58:17 -0700 Subject: [PATCH 31/36] SPARK-51358 Report timestamp when loading snapshot --- .../state/HDFSBackedStateStoreProvider.scala | 31 ++++++------------- .../execution/streaming/state/RocksDB.scala | 14 ++------- .../state/RocksDBStateStoreProvider.scala | 15 --------- .../state/StateStoreCoordinatorSuite.scala | 22 +++++++++---- 4 files changed, 29 insertions(+), 53 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index c33d7073e591..98d49596d11b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -553,7 +553,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } // Report the loaded snapshot's version to the coordinator - reportSnapshotVersionToCoordinator(version) + reportSnapshotUploadToCoordinator(version) return snapshotCurrentVersionMap.get } @@ -586,7 +586,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with synchronized { putStateIntoStateCacheMap(version, resultMap) } // Report the last available snapshot's version to the coordinator - reportSnapshotVersionToCoordinator(lastAvailableVersion) + reportSnapshotUploadToCoordinator(lastAvailableVersion) resultMap } @@ -707,18 +707,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with log"for ${MDC(LogKeys.OP_TYPE, opType)}") // Compare and update with the version that was just uploaded. lastUploadedSnapshotVersion.updateAndGet(v => Math.max(version, v)) - // Report snapshot upload event to the coordinator, and include the store ID with the message. 
- if (storeConf.reportSnapshotUploadLag) { - val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) - val currentTimestamp = System.currentTimeMillis() - StateStoreProvider.coordinatorRef.foreach( - _.snapshotUploaded( - StateStoreProviderId(stateStoreId, runId), - version, - currentTimestamp - ) - ) - } + // Report the snapshot upload event to the coordinator + reportSnapshotUploadToCoordinator(version) } /** @@ -1064,15 +1054,14 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with keySchema, valueSchema) } - /** Reports to the coordinator the store's loaded snapshot version */ - private def reportSnapshotVersionToCoordinator(version: Long): Unit = { - // Skip reporting if the snapshot version is 0, which means there are no snapshots - if (storeConf.reportSnapshotUploadLag && version > 0L) { + /** Reports to the coordinator the store's latest snapshot version */ + private def reportSnapshotUploadToCoordinator(version: Long): Unit = { + if (storeConf.reportSnapshotUploadLag) { + // Attach the query run ID and current timestamp to the RPC message val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) - // Since this is not the time when the snapshot was uploaded, we'll use 0ms to - // prevent the coordinator to use time lag checks for this store. + val currentTimestamp = System.currentTimeMillis() StateStoreProvider.coordinatorRef.foreach( - _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), version, 0L) + _.snapshotUploaded(StateStoreProviderId(stateStoreId, runId), version, currentTimestamp) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 0713a22068a8..9344bddf71aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -390,7 +390,7 @@ class RocksDB( fileManager.setMaxSeenVersion(version) // Report this snapshot version to the coordinator - reportSnapshotVersionToCoordinator(latestSnapshotVersion) + reportSnapshotUploadToCoordinator(latestSnapshotVersion) openLocalRocksDB(metadata) @@ -470,7 +470,7 @@ class RocksDB( fileManager.setMaxSeenVersion(version) // Report this snapshot version to the coordinator - reportSnapshotVersionToCoordinator(latestSnapshotVersion) + reportSnapshotUploadToCoordinator(latestSnapshotVersion) openLocalRocksDB(metadata) @@ -610,7 +610,7 @@ class RocksDB( throw t } // Report this snapshot version to the coordinator - reportSnapshotVersionToCoordinator(snapshotVersion) + reportSnapshotUploadToCoordinator(snapshotVersion) this } @@ -1503,14 +1503,6 @@ class RocksDB( } } - /** Reports to the coordinator the store's latest loaded snapshot version */ - private def reportSnapshotVersionToCoordinator(version: Long): Unit = { - // Skip reporting if the snapshot version is 0, which means there are no snapshots - if (conf.reportSnapshotUploadLag && version > 0L) { - eventListener.foreach(_.reportSnapshotVersion(version)) - } - } - /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. 
*/ private def createLogger(): Logger = { val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index e0be9e2a911d..a35f3acaa71a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -1002,19 +1002,4 @@ private[state] case class RocksDBEventListener(queryRunId: String, stateStoreId: ) ) } - - /** - * Report the state store's last loaded snapshot version to the coordinator. - * This method is used when the store loads a snapshot version instead of - * uploading a snapshot version, as the context of the snapshot upload timestamp - * is missing. - * - * @param version The snapshot version that was just uploaded from RocksDB - */ - def reportSnapshotVersion(version: Long): Unit = { - // Report the state store provider ID and the version to the coordinator. - // Since this is not the time when the snapshot was uploaded, we'll use 0ms to - // prevent the coordinator to use time lag checks for this store. - StateStoreProvider.coordinatorRef.foreach(_.snapshotUploaded(providerId, version, 0L)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 8c858fbe60cd..1ef8251158d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -271,10 +271,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) val providerId = StateStoreProviderId(storeId, query.runId) if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are lagging and didn't upload anything - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + // Verify state stores in partition 0/1 are lagging and didn't upload anything + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0) } else { - // Verify other stores have uploaded a snapshot and it's logged by the coordinator + // Verify other stores uploaded a snapshot and it's logged by the coordinator assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) } } @@ -417,7 +417,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val providerId = StateStoreProviderId(storeId, query.runId) if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and didn't upload - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + assert( + coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 + ) } else { // Verify other stores have uploaded a snapshot and it's properly logged assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) @@ -658,7 +660,9 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val providerId = StateStoreProviderId(storeId, query.runId) if (partitionId <= 1) { // Verify state stores in partition 0 and 1 are lagging and didn't upload - 
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + assert( + coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 + ) } else { // Verify other stores have uploaded a snapshot and it's properly logged assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) @@ -711,12 +715,14 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { coordRef.getLatestSnapshotVersionForTesting(providerId) if (partitionId <= 1) { // Verify state stores in partition 0/1 are still lagging and didn't upload - assert(latestSnapshotVersion.isEmpty) + assert(latestSnapshotVersion.getOrElse(0) == 0) } else { // Verify other stores have uploaded a snapshot and it's properly logged assert(latestSnapshotVersion.get > 0) } } + // Sleep a bit to allow the coordinator to pass the time threshold and report lag + Thread.sleep(5 * 100) // Verify that we're reporting the faulty state stores (partitionId 0 and 1) val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) @@ -761,6 +767,8 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val latestVersion = query.lastProgress.batchId + 1 + // Sleep a bit to allow the coordinator to pass the time threshold and report lag + Thread.sleep(5 * 100) // Verify that only the faulty stores are reported as lagging val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) @@ -788,6 +796,8 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { val coordRef = query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator val latestVersion = query.lastProgress.batchId + 1 + // Sleep the same amount to mimic conditions from first run + Thread.sleep(5 * 100) // Verify that we are not reporting any lagging stores despite restarting // because of the higher version multiplier assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) From 1710fe4c56600fd4055f7eefe409a7c91d18c3c8 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 4 Apr 2025 08:11:40 -0700 Subject: [PATCH 32/36] SPARK-51358 Bump to retrigger From d8b2184a317e208fe5f3a4af73c5e6306c073978 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Fri, 4 Apr 2025 11:20:40 -0700 Subject: [PATCH 33/36] SPARK-51358 Fix test --- .../test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index bffd2c5d9f70..692b5c0ebc3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -63,6 +63,9 @@ public void setUp() { spark = new TestSparkSession(); jsc = new JavaSparkContext(spark.sparkContext()); spark.loadTestData(); + + // Initialize state store coordinator endpoint + spark.streams().stateStoreCoordinator(); } @AfterEach From 735b356fef8955b48a554264bf9b2ffedcfc1da1 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Wed, 9 Apr 2025 15:58:26 -0700 Subject: [PATCH 34/36] SPARK-51358 Fix merge and nits, and rename the event listener to forwarder --- .../execution/streaming/state/RocksDB.scala | 6 +++--- .../state/RocksDBStateStoreProvider.scala | 14 +++++++------- .../state/StateStoreCoordinator.scala | 7 ++++--- ...ailureInjectionCheckpointFileManager.scala | 12 ++++++++---- 
...cksDBCheckpointFailureInjectionSuite.scala | 3 ++- .../state/StateStoreCoordinatorSuite.scala | 19 ++++++++++++------- 6 files changed, 36 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 62938675e423..07553f51c60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -64,7 +64,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple * @param stateStoreId StateStoreId for the state store * @param localRootDir Root directory in local disk that is used to working and checkpointing dirs * @param hadoopConf Hadoop configuration for talking to the remote file system - * @param eventListener The RocksDBEventListener object for reporting events to the coordinator + * @param eventForwarder The RocksDBEventForwarder object for reporting events to the coordinator */ class RocksDB( dfsRootDir: String, @@ -75,7 +75,7 @@ class RocksDB( useColumnFamilies: Boolean = false, enableStateStoreCheckpointIds: Boolean = false, partitionId: Int = 0, - eventListener: Option[RocksDBEventListener] = None) extends Logging { + eventForwarder: Option[RocksDBEventForwarder] = None) extends Logging { import RocksDB._ @@ -1520,7 +1520,7 @@ class RocksDB( // Note that we still report snapshot versions even when changelog checkpointing is disabled. // The coordinator needs a way to determine whether upload messages are disabled or not, // which would be different between RocksDB and HDFS stores due to changelog checkpointing. - eventListener.foreach(_.reportSnapshotUploaded(version)) + eventForwarder.foreach(_.reportSnapshotUploaded(version)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index ec2504e7562a..6a36b8c01519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -390,8 +390,8 @@ private[sql] class RocksDBStateStoreProvider this.useColumnFamilies = useColumnFamilies this.stateStoreEncoding = storeConf.stateStoreEncodingFormat this.stateSchemaProvider = stateSchemaProvider - this.rocksDBEventListener = - RocksDBEventListener(StateStoreProvider.getRunId(hadoopConf), stateStoreId) + this.rocksDBEventForwarder = + Some(RocksDBEventForwarder(StateStoreProvider.getRunId(hadoopConf), stateStoreId)) if (useMultipleValuesPerKey) { require(useColumnFamilies, "Multiple values per key support requires column families to be" + @@ -525,7 +525,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var useColumnFamilies: Boolean = _ @volatile private var stateStoreEncoding: String = _ @volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _ - @volatile private var rocksDBEventListener: RocksDBEventListener = _ + @volatile private var rocksDBEventForwarder: Option[RocksDBEventForwarder] = _ protected def createRocksDB( dfsRootDir: String, @@ -536,7 +536,7 @@ private[sql] class RocksDBStateStoreProvider useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, partitionId: Int = 0, - eventListener: 
Option[RocksDBEventListener] = None): RocksDB = { + eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { new RocksDB( dfsRootDir, conf, @@ -546,7 +546,7 @@ private[sql] class RocksDBStateStoreProvider useColumnFamilies, enableStateStoreCheckpointIds, partitionId, - eventListener) + eventForwarder) } private[sql] lazy val rocksDB = { @@ -557,7 +557,7 @@ private[sql] class RocksDBStateStoreProvider val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr, useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId, - Some(rocksDBEventListener)) + rocksDBEventForwarder) } private val keyValueEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, @@ -994,7 +994,7 @@ class RocksDBStateStoreChangeDataReader( * We pass this into the RocksDB instance to report specific events like snapshot uploads. * This should only be used to report back to the coordinator for metrics and monitoring purposes. */ -private[state] case class RocksDBEventListener(queryRunId: String, stateStoreId: StateStoreId) { +private[state] case class RocksDBEventForwarder(queryRunId: String, stateStoreId: StateStoreId) { // Build the state store provider ID from the query run ID and the state store ID private val providerId = StateStoreProviderId(stateStoreId, UUID.fromString(queryRunId)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index c7df34c77e5b..903f27fb2a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -163,7 +163,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[sql] def snapshotUploaded( providerId: StateStoreProviderId, version: Long, - timestamp: Long): Unit = { + timestamp: Long): Boolean = { rpcEndpointRef.askSync[Boolean](ReportSnapshotUploaded(providerId, version, timestamp)) } @@ -171,7 +171,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[sql] def logLaggingStateStores( queryRunId: UUID, latestVersion: Long, - isTerminatingTrigger: Boolean): Unit = { + isTerminatingTrigger: Boolean): Boolean = { rpcEndpointRef.askSync[Boolean]( LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger)) } @@ -371,7 +371,8 @@ private class StateStoreCoordinator( val minDeltasForSnapshot = sqlConf.stateStoreMinDeltasForSnapshot val maintenanceInterval = sqlConf.streamingMaintenanceInterval - // Use the configured multipliers to determine the proper alert thresholds + // Use the configured multipliers multiplierForMinVersionDiffToLog and + // multiplierForMinTimeDiffToLog to determine the proper alert thresholds. 
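    // As a concrete illustration (using values taken from the tests in this patch, not the
    // production defaults): with stateStoreMinDeltasForSnapshot = 1 and
    // multiplierForMinVersionDiffToLog = 2, the version threshold below works out to
    // 2 * 1 = 2 versions, and with a 100 ms maintenance interval and
    // multiplierForMinTimeDiffToLog = 50, the time threshold works out to 50 * 100 ms = 5 s.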
val minVersionDeltaForLogging = snapshotVersionDeltaMultiplier * minDeltasForSnapshot val minTimeDeltaForLogging = maintenanceIntervalMultiplier * maintenanceInterval diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala index 4711a45804fb..ebbdd1ad63ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala @@ -251,7 +251,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, - partitionId: Int): RocksDB = { + partitionId: Int, + eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = { FailureInjectionRocksDBStateStoreProvider.createRocksDBWithFaultInjection( dfsRootDir, conf, @@ -260,7 +261,8 @@ class FailureInjectionRocksDBStateStoreProvider extends RocksDBStateStoreProvide loggingId, useColumnFamilies, enableStateStoreCheckpointIds, - partitionId) + partitionId, + eventForwarder) } } @@ -277,7 +279,8 @@ object FailureInjectionRocksDBStateStoreProvider { loggingId: String, useColumnFamilies: Boolean, enableStateStoreCheckpointIds: Boolean, - partitionId: Int): RocksDB = { + partitionId: Int, + eventForwarder: Option[RocksDBEventForwarder]): RocksDB = { new RocksDB( dfsRootDir, conf = conf, @@ -286,7 +289,8 @@ object FailureInjectionRocksDBStateStoreProvider { loggingId = loggingId, useColumnFamilies = useColumnFamilies, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, - partitionId = partitionId + partitionId = partitionId, + eventForwarder = eventForwarder ) { override def createFileManager( dfsRootDir: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala index 31fc51c4d56f..5c24ec209036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala @@ -523,7 +523,8 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest loggingId = s"[Thread-${Thread.currentThread.getId}]", useColumnFamilies = true, enableStateStoreCheckpointIds = enableStateStoreCheckpointIds, - partitionId = 0) + partitionId = 0, + eventForwarder = None) db.load(version, checkpointId) func(db) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 1ef8251158d7..b3ac66742bd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -527,7 +527,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "false", 
SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "1", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "1", SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { case (coordRef, spark) => @@ -544,12 +544,17 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .queryName("query") .option("checkpointLocation", checkpointLocation.toString) .start() - // Go through several rounds of input to force snapshot uploads - (0 until 5).foreach { _ => - inputData.addData(1, 2, 3) - query.processAllAvailable() - Thread.sleep(500) - } + // Go through two batches to force two snapshot uploads. + // This would be enough to pass the version check for lagging stores. + inputData.addData(1, 2, 3) + query.processAllAvailable() + inputData.addData(1, 2, 3) + query.processAllAvailable() + + // Sleep for the duration of a maintenance interval - which should be enough + // to pass the time check for lagging stores. + Thread.sleep(100) + val latestVersion = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 // Verify that no instances are marked as lagging, even when upload messages are sent. From 2d60ea9f48e04247b6625408ecc6f8b510b2abc8 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 10 Apr 2025 10:17:00 -0700 Subject: [PATCH 35/36] SPARK-51358 Clean up and dedupe test code --- .../state/StateStoreCoordinatorSuite.scala | 819 +++++++----------- 1 file changed, 335 insertions(+), 484 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index b3ac66742bd9..09118edc4357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWra import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.functions.{count, expr} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.streaming.{StreamingQuery, StreamTest, Trigger} import org.apache.spark.util.Utils class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { @@ -158,281 +158,162 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { } } - Seq( - ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), - ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) - ).foreach { - case (providerName, providerClassName) => - test( - s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" - ) { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] - val aggregated = inputData.toDF().dropDuplicates() - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = aggregated.writeStream - .format("memory") - .outputMode("update") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 6).foreach { _ => - inputData.addData(1, 2, 3) - query.processAllAvailable() - Thread.sleep(500) - } - val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery - val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation - val latestVersion = streamingQuery.lastProgress.batchId + 1 - - // Verify all stores have uploaded a snapshot and it's logged by the coordinator - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { - partitionId => - val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) - query.stop() - } - } - } + private val allJoinStateStoreNames: Seq[String] = + SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) - Seq( + /** Lists the state store providers used for a test, and the set of lagging partition IDs */ + private val regularStateStoreProviders = Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName, Set.empty[Int]), + ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName, Set.empty[Int]) + ) + + /** Lists the state store providers used for a test, and the set of lagging partition IDs */ + private val faultyStateStoreProviders = Seq( ( "RocksDBSkipMaintenanceOnCertainPartitionsProvider", - classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName + classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName, + Set(0, 1) ), ( "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", - classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName + classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName, + Set(0, 1) ) - ).foreach { - case (providerName, providerClassName) => - test( - s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" - ) { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - 
SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a query and run some data to force snapshot uploads - val inputData = MemoryStream[Int] - val aggregated = inputData.toDF().dropDuplicates() - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = aggregated.writeStream - .format("memory") - .outputMode("update") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 6).foreach { _ => - inputData.addData(1, 2, 3) - query.processAllAvailable() - Thread.sleep(500) - } - val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery - val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation - val latestVersion = streamingQuery.lastProgress.batchId + 1 - - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { - partitionId => - val storeId = StateStoreId(stateCheckpointDir, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - if (partitionId <= 1) { - // Verify state stores in partition 0/1 are lagging and didn't upload anything - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0) - } else { - // Verify other stores uploaded a snapshot and it's logged by the coordinator - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - } - // We should have two state stores (id 0 and 1) that are lagging behind at this point - val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - query.stop() + ) + + private val allStateStoreProviders = + regularStateStoreProviders ++ faultyStateStoreProviders + + /** + * Verifies snapshot upload RPC messages from state stores are registered and verifies + * the coordinator detected the correct lagging partitions. + */ + private def verifySnapshotUploadEvents( + coordRef: StateStoreCoordinatorRef, + query: StreamingQuery, + badPartitions: Set[Int], + storeNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)): Unit = { + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery + val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation + val latestVersion = streamingQuery.lastProgress.batchId + 1 + + // Verify all stores have uploaded a snapshot and it's logged by the coordinator + (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { + partitionId => + // Verify for every store name listed + storeNames.foreach { storeName => + val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) + if (badPartitions.contains(partitionId)) { + assert(latestSnapshotVersion.getOrElse(0) == 0) + } else { + assert(latestSnapshotVersion.get >= 0) + } } - } + } + // Verify that only the bad partitions are all marked as lagging. + // Join queries should have all their state stores marked as lagging, + // which would be 4 stores per partition instead of 1. 
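        // For example, with badPartitions = Set(0, 1): a dropDuplicates query using only the
        // default store name expects 2 * 1 = 2 lagging stores, while a stream-stream join
        // passing allJoinStateStoreNames expects 2 * 4 = 8, since each partition hosts four
        // join state stores.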
+ val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == badPartitions.size * storeNames.size) + assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions) } - private val allJoinStateStoreNames: Seq[String] = - SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) + /** Sets up a stateful dropDuplicate query for testing */ + private def setUpStatefulQuery( + inputData: MemoryStream[Int], queryName: String): StreamingQuery = { + // Set up a stateful drop duplicate query + val aggregated = inputData.toDF().dropDuplicates() + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = aggregated.writeStream + .format("memory") + .outputMode("update") + .queryName(queryName) + .option("checkpointLocation", checkpointLocation.toString) + .start() + query + } - Seq( - ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName), - ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName) - ).foreach { - case (providerName, providerClassName) => - test( - s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + - s"reported to the coordinator" + allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) => + test( + s"SPARK-51358: Snapshot uploads in $providerName are properly reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", - SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a join query and run some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] - val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, expr("leftKey = rightKey")) - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = joined.writeStream - .format("memory") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - 
(0 until 7).foreach { _ => - input1.addData(1, 5) - input2.addData(1, 5, 10) - query.processAllAvailable() - Thread.sleep(500) - } - val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery - val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation - val latestVersion = streamingQuery.lastProgress.batchId + 1 - - // Verify all state stores for join queries are reporting snapshot uploads - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { - partitionId => - allJoinStateStoreNames.foreach { storeName => - val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) - val providerId = StateStoreProviderId(storeId, query.runId) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - } - // Verify that we should not have any state stores lagging behind - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) - query.stop() - } + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + val inputData = MemoryStream[Int] + val query = setUpStatefulQuery(inputData, "query") + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 6).foreach { _ => + inputData.addData(1, 2, 3) + query.processAllAvailable() + Thread.sleep(500) + } + // Verify only the partitions in badPartitions are marked as lagging + verifySnapshotUploadEvents(coordRef, query, badPartitions) + query.stop() } + } } - Seq( - ( - "RocksDBSkipMaintenanceOnCertainPartitionsProvider", - classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName - ), - ( - "HDFSBackedSkipMaintenanceOnCertainPartitionsProvider", - classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName - ) - ).foreach { - case (providerName, providerClassName) => - test( - s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + - s"reported to the coordinator" + allStateStoreProviders.foreach { case (providerName, providerClassName, badPartitions) => + test( + s"SPARK-51358: Snapshot uploads for join queries with $providerName are properly " + + s"reported to the coordinator" + ) { + withCoordinatorAndSQLConf( + sc, + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", + SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" ) { - withCoordinatorAndSQLConf( - sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "5", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0", 
- SQLConf.STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT.key -> "5" - ) { - case (coordRef, spark) => - import spark.implicits._ - implicit val sqlContext = spark.sqlContext - - // Start a join query and run some data to force snapshot uploads - val input1 = MemoryStream[Int] - val input2 = MemoryStream[Int] - val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") - val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") - val joined = df1.join(df2, expr("leftKey = rightKey")) - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = joined.writeStream - .format("memory") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() - // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 7).foreach { _ => - input1.addData(1, 5) - input2.addData(1, 5, 10) - query.processAllAvailable() - Thread.sleep(500) - } - val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery - val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation - val latestVersion = streamingQuery.lastProgress.batchId + 1 - // Verify all state stores for join queries are reporting snapshot uploads - (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach { - partitionId => - allJoinStateStoreNames.foreach { storeName => - val storeId = StateStoreId(stateCheckpointDir, 0, partitionId, storeName) - val providerId = StateStoreProviderId(storeId, query.runId) - if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are lagging and didn't upload - assert( - coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 - ) - } else { - // Verify other stores have uploaded a snapshot and it's properly logged - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - } - } - // Verify that only stores from partition id 0 and 1 are lagging behind. - // Each partition has 4 stores for join queries, so there are 2 * 4 = 8 lagging stores. 
- val laggingStores = coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2 * 4) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - } + case (coordRef, spark) => + import spark.implicits._ + implicit val sqlContext = spark.sqlContext + // Start a join query and run some data to force snapshot uploads + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue") + val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue") + val joined = df1.join(df2, expr("leftKey = rightKey")) + val checkpointLocation = Utils.createTempDir().getAbsoluteFile + val query = joined.writeStream + .format("memory") + .queryName("query") + .option("checkpointLocation", checkpointLocation.toString) + .start() + // Add, commit, and wait multiple times to force snapshot versions and time difference + (0 until 7).foreach { _ => + input1.addData(1, 5) + input2.addData(1, 5, 10) + query.processAllAvailable() + Thread.sleep(500) + } + // Verify only the partitions in badPartitions are marked as lagging + verifySnapshotUploadEvents(coordRef, query, badPartitions, allJoinStateStoreNames) + query.stop() } + } } test("SPARK-51358: Verify coordinator properly handles simultaneous query runs") { @@ -452,26 +333,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Start and run two queries together with some data to force snapshot uploads val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] - val dedupe1 = input1.toDF().dropDuplicates() - val dedupe2 = input2.toDF().dropDuplicates() - val checkpointLocation1 = Utils.createTempDir().getAbsoluteFile - val checkpointLocation2 = Utils.createTempDir().getAbsoluteFile - val query1 = dedupe1.writeStream - .format("memory") - .outputMode("update") - .queryName("query1") - .option("checkpointLocation", checkpointLocation1.toString) - .start() - val query2 = dedupe2.writeStream - .format("memory") - .outputMode("update") - .queryName("query2") - .option("checkpointLocation", checkpointLocation2.toString) - .start() + val query1 = setUpStatefulQuery(input1, "query1") + val query2 = setUpStatefulQuery(input2, "query2") + // Go through several rounds of input to force snapshot uploads for both queries (0 until 2).foreach { _ => input1.addData(1, 2, 3) @@ -533,17 +400,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] - val aggregated = inputData.toDF().dropDuplicates() - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = aggregated.writeStream - .format("memory") - .outputMode("update") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() + val query = setUpStatefulQuery(inputData, "query") + // Go through two batches to force two snapshot uploads. // This would be enough to pass the version check for lagging stores. 
inputData.addData(1, 2, 3) @@ -584,19 +444,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { case (coordRef, spark) => import spark.implicits._ implicit val sqlContext = spark.sqlContext - // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] - val aggregated = inputData.toDF().dropDuplicates() - val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val query = aggregated.writeStream - .format("memory") - .outputMode("update") - .queryName("query") - .option("checkpointLocation", checkpointLocation.toString) - .start() + val query = setUpStatefulQuery(inputData, "query") + // Go through several rounds of input to force snapshot uploads - (0 until 5).foreach { _ => + (0 until 3).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() Thread.sleep(500) @@ -616,129 +469,128 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { Seq( ("RocksDB", classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName), ("HDFS", classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName) - ).foreach { - case (providerName, providerClassName) => - test( - s"SPARK-51358: Restarting queries do not mark state stores as lagging for $providerName" + ).foreach { case (providerName, providerClassName) => + test( + s"SPARK-51358: Restarting queries do not mark state stores as lagging for $providerName" + ) { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "3", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "5", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" - ) { - withTempDir { srcDir => - val inputData = MemoryStream[Int] - val query = inputData.toDF().dropDuplicates() - val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) - // Keep track of state checkpoint directory for the second run - var stateCheckpoint = "" - - testStream(query)( - StartStream(checkpointLocation = srcDir.getCanonicalPath), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - 
query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - stateCheckpoint = query.lastExecution.checkpointLocation - val latestVersion = query.lastProgress.batchId + 1 - - // Verify the coordinator logged snapshot uploads - (0 until numPartitions).map { - partitionId => - val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - if (partitionId <= 1) { - // Verify state stores in partition 0 and 1 are lagging and didn't upload - assert( - coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 - ) - } else { - // Verify other stores have uploaded a snapshot and it's properly logged - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) - } - } - // Verify that the normal state store (partitionId=2) is not lagging behind, - // and the faulty stores are reported as lagging. - val laggingStores = - coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - }, - // Stopping the streaming query should deactivate and clear snapshot uploaded events - StopStream, - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - - // Verify we evicted the previous latest uploaded snapshots from the coordinator - (0 until numPartitions).map { partitionId => + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS) + // Keep track of state checkpoint directory for the second run + var stateCheckpoint = "" + + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + stateCheckpoint = query.lastExecution.checkpointLocation + val latestVersion = query.lastProgress.batchId + 1 + + // Verify the coordinator logged snapshot uploads + (0 until numPartitions).map { + partitionId => val storeId = StateStoreId(stateCheckpoint, 0, partitionId) val providerId = StateStoreProviderId(storeId, query.runId) - assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) - } - // Verify that we are not reporting any lagging stores after eviction, - // since none of these state stores are active anymore. - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + if (partitionId <= 1) { + // Verify state stores in partition 0 and 1 are lagging and didn't upload + assert( + coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) == 0 + ) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0) + } } - ) - // Restart the query, but do not add too much data so that we don't associate - // the current StateStoreProviderId (store id + query run id) with any new uploads. 
- testStream(query)( - StartStream(checkpointLocation = srcDir.getCanonicalPath), - // Perform one round of data, which is enough to activate instances and force a - // lagging instance report, but not enough to trigger a snapshot upload yet. - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - // Verify that the state stores have restored their snapshot version from the - // checkpoint and reported their current version - (0 until numPartitions).map { - partitionId => - val storeId = StateStoreId(stateCheckpoint, 0, partitionId) - val providerId = StateStoreProviderId(storeId, query.runId) - val latestSnapshotVersion = - coordRef.getLatestSnapshotVersionForTesting(providerId) - if (partitionId <= 1) { - // Verify state stores in partition 0/1 are still lagging and didn't upload - assert(latestSnapshotVersion.getOrElse(0) == 0) - } else { - // Verify other stores have uploaded a snapshot and it's properly logged - assert(latestSnapshotVersion.get > 0) - } - } - // Sleep a bit to allow the coordinator to pass the time threshold and report lag - Thread.sleep(5 * 100) - // Verify that we're reporting the faulty state stores (partitionId 0 and 1) - val laggingStores = - coordRef.getLaggingStoresForTesting(query.runId, latestVersion) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - }, - StopStream - ) - } + // Verify that the normal state store (partitionId=2) is not lagging behind, + // and the faulty stores are reported as lagging. + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + // Stopping the streaming query should deactivate and clear snapshot uploaded events + StopStream, + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + + // Verify we evicted the previous latest uploaded snapshots from the coordinator + (0 until numPartitions).map { partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + assert(coordRef.getLatestSnapshotVersionForTesting(providerId).isEmpty) + } + // Verify that we are not reporting any lagging stores after eviction, + // since none of these state stores are active anymore. + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + } + ) + // Restart the query, but do not add too much data so that we don't associate + // the current StateStoreProviderId (store id + query run id) with any new uploads. + testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Perform one round of data, which is enough to activate instances and force a + // lagging instance report, but not enough to trigger a snapshot upload yet. 
+ AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that the state stores have restored their snapshot version from the + // checkpoint and reported their current version + (0 until numPartitions).map { + partitionId => + val storeId = StateStoreId(stateCheckpoint, 0, partitionId) + val providerId = StateStoreProviderId(storeId, query.runId) + val latestSnapshotVersion = + coordRef.getLatestSnapshotVersionForTesting(providerId) + if (partitionId <= 1) { + // Verify state stores in partition 0/1 are still lagging and didn't upload + assert(latestSnapshotVersion.getOrElse(0) == 0) + } else { + // Verify other stores have uploaded a snapshot and it's properly logged + assert(latestSnapshotVersion.get > 0) + } + } + // Sleep a bit to allow the coordinator to pass the time threshold and report lag + Thread.sleep(5 * 100) + // Verify that we're reporting the faulty state stores (partitionId 0 and 1) + val laggingStores = + coordRef.getLaggingStoresForTesting(query.runId, latestVersion) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) } } + } } test("SPARK-51358: Restarting queries with updated SQLConf get propagated to the coordinator") { @@ -816,81 +668,80 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest { Seq( ("RocksDB", classOf[RocksDBStateStoreProvider].getName), ("HDFS", classOf[HDFSBackedStateStoreProvider].getName) - ).foreach { - case (providerName, providerClassName) => - test( - s"SPARK-51358: Infrequent maintenance with $providerName using Trigger.AvailableNow " + - s"should be reported" + ).foreach { case (providerName, providerClassName) => + test( + s"SPARK-51358: Infrequent maintenance with $providerName using Trigger.AvailableNow " + + s"should be reported" + ) { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, + RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", + SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", + SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50", + SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" ) { - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "2", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName, - RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled" -> "true", - SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key -> "true", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key -> "2", - SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG.key -> "50", - SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0" - ) { - withTempDir { srcDir => - val inputData = MemoryStream[Int] - val query = inputData.toDF().dropDuplicates() - - // Populate state stores with an initial snapshot, so 
that timestamp isn't marked - // as the default 0ms. - testStream(query)( - StartStream(checkpointLocation = srcDir.getCanonicalPath), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - AddData(inputData, 1, 2, 3), - ProcessAllAvailable() - ) - // Increase maintenance interval to a much larger value to stop snapshot uploads - spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000") - // Execute a few batches in a short span - testStream(query)( - AddData(inputData, 1, 2, 3), - StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), - Execute { query => - query.awaitTermination() - // Verify the query ran with the AvailableNow trigger - assert(query.lastExecution.isTerminatingTrigger) - }, - AddData(inputData, 1, 2, 3), - StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), - Execute { query => - query.awaitTermination() - }, - // Start without available now, otherwise the stream closes too quickly for the - // testing RPC call to report lagging state stores - StartStream(checkpointLocation = srcDir.getCanonicalPath), - // Process data to activate state stores, but not enough to trigger snapshot uploads - AddData(inputData, 1, 2, 3), - ProcessAllAvailable(), - Execute { query => - val coordRef = - query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator - val latestVersion = query.lastProgress.batchId + 1 - // Verify that all faulty stores are reported as lagging despite the short burst. - // This test scenario mimics cases where snapshots have not been uploaded for - // a while due to the short running duration of AvailableNow. - val laggingStores = coordRef.getLaggingStoresForTesting( - query.runId, - latestVersion, - isTerminatingTrigger = true - ) - assert(laggingStores.size == 2) - assert(laggingStores.forall(_.storeId.partitionId <= 1)) - }, - StopStream - ) - } + withTempDir { srcDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF().dropDuplicates() + + // Populate state stores with an initial snapshot, so that timestamp isn't marked + // as the default 0ms. 
+ testStream(query)( + StartStream(checkpointLocation = srcDir.getCanonicalPath), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + AddData(inputData, 1, 2, 3), + ProcessAllAvailable() + ) + // Increase maintenance interval to a much larger value to stop snapshot uploads + spark.conf.set(SQLConf.STREAMING_MAINTENANCE_INTERVAL.key, "60000") + // Execute a few batches in a short span + testStream(query)( + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + // Verify the query ran with the AvailableNow trigger + assert(query.lastExecution.isTerminatingTrigger) + }, + AddData(inputData, 1, 2, 3), + StartStream(Trigger.AvailableNow, checkpointLocation = srcDir.getCanonicalPath), + Execute { query => + query.awaitTermination() + }, + // Start without available now, otherwise the stream closes too quickly for the + // testing RPC call to report lagging state stores + StartStream(checkpointLocation = srcDir.getCanonicalPath), + // Process data to activate state stores, but not enough to trigger snapshot uploads + AddData(inputData, 1, 2, 3), + ProcessAllAvailable(), + Execute { query => + val coordRef = + query.sparkSession.sessionState.streamingQueryManager.stateStoreCoordinator + val latestVersion = query.lastProgress.batchId + 1 + // Verify that all faulty stores are reported as lagging despite the short burst. + // This test scenario mimics cases where snapshots have not been uploaded for + // a while due to the short running duration of AvailableNow. + val laggingStores = coordRef.getLaggingStoresForTesting( + query.runId, + latestVersion, + isTerminatingTrigger = true + ) + assert(laggingStores.size == 2) + assert(laggingStores.forall(_.storeId.partitionId <= 1)) + }, + StopStream + ) } } + } } } From e4c0cf9854116a03e1078084a5ed94a75a755710 Mon Sep 17 00:00:00 2001 From: Zeyu Chen Date: Thu, 10 Apr 2025 10:27:08 -0700 Subject: [PATCH 36/36] SPARK-51358 Add comment for coordRef entry point --- .../spark/sql/execution/streaming/state/StateStore.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 0e560945f646..63936305c7cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -595,6 +595,11 @@ trait StateStoreProvider { object StateStoreProvider extends Logging { + /** + * The state store coordinator reference used to report events such as snapshot uploads from + * the state store providers. + * For all other messages, refer to the coordinator reference in the [[StateStore]] object. + */ @GuardedBy("this") private var stateStoreCoordinatorRef: StateStoreCoordinatorRef = _ @@ -674,6 +679,8 @@ object StateStoreProvider extends Logging { /** * Create the state store coordinator reference which will be reused across state store providers * in the executor. + * This coordinator reference should only be used to report events from store providers regarding + * snapshot uploads to avoid lock contention with other coordinator RPC messages. */ private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { val env = SparkEnv.get
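Note on the lag detection flow exercised by the tests above: executors report each successful snapshot upload through the provider-side coordinator reference documented in this hunk, and the driver-side coordinator decides whether an active store is lagging by comparing its most recent upload against the query's latest version and against the time elapsed since that upload. The sketch below is a simplified, self-contained model of that check and is not code from this patch: StoreKey, UploadRecord, LagDetectorSketch and its methods are invented names, and the threshold formula (a multiplier on STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT for the version gap, a multiplier on the maintenance interval for staleness) is an assumption chosen only to be consistent with the conf values set in these tests. The real coordinator also rate-limits warnings via STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL and special-cases terminating triggers (the isTerminatingTrigger flag above); both are omitted here.

import scala.collection.mutable

// Illustrative stand-ins: the real coordinator keys its bookkeeping by
// StateStoreProviderId (a StateStoreId plus the query run id).
case class StoreKey(operatorId: Long, partitionId: Int, runId: String)
case class UploadRecord(snapshotVersion: Long, uploadTimeMs: Long)

class LagDetectorSketch(
    minDeltasForSnapshot: Int,   // cf. SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT
    maintenanceIntervalMs: Long, // cf. SQLConf.STREAMING_MAINTENANCE_INTERVAL
    versionMultiplier: Int,      // cf. STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG
    timeMultiplier: Int) {       // cf. STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_TIME_DIFF_TO_LOG

  private val latestUploads = mutable.Map.empty[StoreKey, UploadRecord]

  // Executor-side path: a provider reports a successful snapshot upload.
  def reportSnapshotUploaded(key: StoreKey, version: Long, nowMs: Long): Unit =
    latestUploads(key) = UploadRecord(version, nowMs)

  // Query stop/restart: entries recorded under the old run id are evicted, which is
  // why getLatestSnapshotVersionForTesting returns None after StopStream in the tests.
  def deactivateRun(runId: String): Unit =
    latestUploads.filterInPlace((key, _) => key.runId != runId)

  // Driver-side check: a store lags when its last upload trails the query's latest
  // version by more than versionMultiplier * minDeltasForSnapshot versions, or when
  // that upload is older than timeMultiplier * maintenanceIntervalMs.
  def laggingStores(
      activeStores: Seq[StoreKey],
      latestVersion: Long,
      nowMs: Long): Seq[StoreKey] = {
    val minVersionGap = versionMultiplier.toLong * minDeltasForSnapshot
    val maxUploadAgeMs = timeMultiplier.toLong * maintenanceIntervalMs
    activeStores.filter { key =>
      latestUploads.get(key) match {
        case Some(UploadRecord(version, uploadTimeMs)) =>
          latestVersion - version > minVersionGap || nowMs - uploadTimeMs > maxUploadAgeMs
        case None =>
          // No upload recorded under this run id yet: count it as lagging once the
          // query has advanced past the version threshold.
          latestVersion > minVersionGap
      }
    }
  }
}

object LagDetectorSketchDemo extends App {
  // Mirrors the AvailableNow test's conf: min deltas 1, maintenance interval 100 ms,
  // version multiplier 2, time multiplier 50.
  val detector = new LagDetectorSketch(1, 100L, 2, 50)
  val stores = (0 until 3).map(p => StoreKey(operatorId = 0L, partitionId = p, runId = "run-1"))
  detector.reportSnapshotUploaded(stores(2), version = 4L, nowMs = 0L)
  // Partitions 0 and 1 never uploaded, so at latestVersion 5 they exceed the gap of 2.
  println(detector.laggingStores(stores, latestVersion = 5L, nowMs = 1000L))
}

Under those settings the demo flags exactly the two stores in partitions 0 and 1 that never reported an upload, which is the shape of the assertions on getLaggingStoresForTesting in the suite; treat the numbers as an illustration of the threshold arithmetic rather than as the coordinator's actual implementation.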