diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index a34ceb9f1145..f7eb1e63d7bd 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -320,6 +320,11 @@
       "An error occurred during loading state."
     ],
     "subClass" : {
+      "AUTO_SNAPSHOT_REPAIR_FAILED" : {
+        "message" : [
+          "Failed to load snapshot version <latestSnapshot> for state store <stateStoreId>. An attempt to auto repair using snapshot versions (<selectedSnapshots>) out of available snapshots (<eligibleSnapshots>) also failed."
+        ]
+      },
       "CANNOT_FIND_BASE_SNAPSHOT_CHECKPOINT" : {
         "message" : [
           "Cannot find a base snapshot checkpoint with lineage: <lineage>."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8907629ad37..d2d7edc65121 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2546,6 +2546,43 @@ object SQLConf {
     .intConf
     .createWithDefault(10)
 
+  val STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.autoSnapshotRepair.enabled")
+      .internal()
+      .doc("When true, enables automatic repair of state store snapshots when a bad snapshot is " +
+        "detected while loading the state store, to prevent the query from failing. " +
+        "Typically, queries will fail when they are unable to load a snapshot, " +
+        "but this helps them recover by skipping the bad snapshot and using the change files. " +
+        "NOTE: For the RocksDB state store, changelog checkpointing must be enabled.")
+      .version("4.1.0")
+      .booleanConf
+      // Disable in tests, so that tests will fail if they encounter a bad snapshot
+      .createWithDefault(!Utils.isTesting)
+
+  val STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING =
+    buildConf("spark.sql.streaming.stateStore.autoSnapshotRepair.numFailuresBeforeActivating")
+      .internal()
+      .doc(
+        "When autoSnapshotRepair is enabled, it will wait for the specified number of snapshot " +
+        "load failures before it attempts to repair."
+      )
+      .version("4.1.0")
+      .intConf
+      .checkValue(k => k > 0, "Must allow at least 1 failure before activating autoSnapshotRepair")
+      .createWithDefault(1)
+
+  val STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY =
+    buildConf("spark.sql.streaming.stateStore.autoSnapshotRepair.maxChangeFileReplay")
+      .internal()
+      .doc(
+        "When autoSnapshotRepair is enabled, this specifies the maximum number of change " +
+        "files allowed to be replayed to rebuild state due to bad snapshots."
+ ) + .version("4.1.0") + .intConf + .checkValue(k => k > 0, "Must allow at least 1 change file replay") + .createWithDefault(50) + val STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT = buildConf("spark.sql.streaming.stateStore.numStateStoreInstanceMetricsToReport") .internal() @@ -6729,6 +6766,15 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + def stateStoreAutoSnapshotRepairEnabled: Boolean = + getConf(STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED) + + def stateStoreAutoSnapshotRepairNumFailuresBeforeActivating: Int = + getConf(STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING) + + def stateStoreAutoSnapshotRepairMaxChangeFileReplay: Int = + getConf(STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY) + def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED) def stateStoreSkipNullsForStreamStreamJoins: Boolean = diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala index 580b8e1114f9..ef3bb4711c85 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala @@ -103,6 +103,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L lastProgress.stateOperators.head.customMetrics.keySet().asScala == Set( "loadedMapCacheHitCount", "loadedMapCacheMissCount", + "numSnapshotsAutoRepaired", "stateOnCurrentVersionSizeBytes", "SnapshotLastUploaded.partition_0_default")) assert(lastProgress.sources.nonEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoader.scala new file mode 100644 index 000000000000..d94f10d49fbd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoader.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.immutable.ArraySeq
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.fs.{Path, PathFilter}
+
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKeys.{NUM_RETRIES, NUM_RETRY, VERSION_NUM}
+import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
+
+/**
+ * [[AutoSnapshotLoader]] is used to handle loading a state store snapshot version from the
+ * checkpoint directory. It supports auto snapshot repair, which automatically handles a
+ * corrupt snapshot by skipping it and using another snapshot version before the corrupt one.
+ * If no snapshot exists before the corrupt one, it uses the version 0 snapshot
+ * (representing the initial/empty snapshot).
+ *
+ * @param autoSnapshotRepairEnabled If true, corrupt snapshots will be handled automatically
+ * @param numFailuresBeforeActivating If auto snapshot repair is enabled,
+ *                                    number of failures before activating it
+ * @param maxChangeFileReplay If auto snapshot repair is enabled, maximum difference between
+ *                            the requested snapshot version and the selected snapshot version
+ * @param loggingId Identifier appended to log messages
+ */
+abstract class AutoSnapshotLoader(
+    autoSnapshotRepairEnabled: Boolean,
+    numFailuresBeforeActivating: Int,
+    maxChangeFileReplay: Int,
+    loggingId: String = "") extends Logging {
+
+  override protected def logName: String = s"${super.logName} $loggingId"
+
+  /** Called before loading a snapshot from the checkpoint directory */
+  protected def beforeLoad(): Unit
+
+  /**
+   * Attempt to load the specified snapshot version from the checkpoint directory.
+   * Should throw an exception if the snapshot is corrupt.
+   * @note Must support loading version 0
+   */
+  protected def loadSnapshotFromCheckpoint(snapshotVersion: Long): Unit
+
+  /** Called when load fails, to do any necessary cleanup/variable reset */
+  protected def onLoadSnapshotFromCheckpointFailure(): Unit
+
+  /** Get a list of eligible snapshot versions in the checkpoint directory that can be loaded */
+  protected def getEligibleSnapshots(versionToLoad: Long): Seq[Long]
+
+  /**
+   * Load the latest snapshot for the specified version from the checkpoint directory.
+   * If auto snapshot repair is enabled, the loaded snapshot version may be lower than
+   * the latest snapshot version if the latest is corrupt.
+ * + * @param versionToLoad The version to load latest snapshot for + * @return The actual loaded snapshot version and if it was due to auto repair + * */ + def loadSnapshot(versionToLoad: Long): (Long, Boolean) = { + val eligibleSnapshots = + (getEligibleSnapshots(versionToLoad) :+ 0L) // always include the initial snapshot + .distinct // Ensure no duplicate version numbers + .sorted(Ordering[Long].reverse) + + // Start with the latest snapshot + val firstEligibleSnapshot = eligibleSnapshots.head + + // no retry if auto snapshot repair is not enabled + val maxNumFailures = if (autoSnapshotRepairEnabled) numFailuresBeforeActivating else 1 + var numFailuresForFirstSnapshot = 0 + var lastException: Throwable = null + var loadedSnapshot: Option[Long] = None + while (loadedSnapshot.isEmpty && numFailuresForFirstSnapshot < maxNumFailures) { + beforeLoad() // if this fails, then we should fail + try { + // try to load the first eligible snapshot + loadSnapshotFromCheckpoint(firstEligibleSnapshot) + loadedSnapshot = Some(firstEligibleSnapshot) + } catch { + // Swallow only if auto snapshot repair is enabled + // If auto snapshot repair is not enabled, we should fail immediately + case NonFatal(e) if autoSnapshotRepairEnabled => + onLoadSnapshotFromCheckpointFailure() + numFailuresForFirstSnapshot += 1 + logError(log"Failed to load snapshot version " + + log"${MDC(VERSION_NUM, firstEligibleSnapshot)}, " + + log"attempt ${MDC(NUM_RETRY, numFailuresForFirstSnapshot)} out of " + + log"${MDC(NUM_RETRIES, maxNumFailures)} attempts", e) + lastException = e + case e: Throwable => + onLoadSnapshotFromCheckpointFailure() + throw e + } + } + + var autoRepairCompleted = false + if (loadedSnapshot.isEmpty) { + // we would only get here if auto snapshot repair is enabled + assert(autoSnapshotRepairEnabled) + + val remainingEligibleSnapshots = if (eligibleSnapshots.length > 1) { + // skip the first snapshot, since we already tried it + eligibleSnapshots.tail + } else { + // no more snapshots to try + Seq.empty + } + + // select remaining snapshots that are within the maxChangeFileReplay limit + val selectedRemainingSnapshots = remainingEligibleSnapshots.filter( + s => versionToLoad - s <= maxChangeFileReplay) + + logInfo(log"Attempting to auto repair snapshot by skipping " + + log"snapshot version ${MDC(VERSION_NUM, firstEligibleSnapshot)} " + + log"and trying to load with one of the selected snapshots " + + log"${MDC(VERSION_NUM, selectedRemainingSnapshots)}, out of eligible snapshots " + + log"${MDC(VERSION_NUM, remainingEligibleSnapshots)}. " + + log"maxChangeFileReplay: ${MDC(VERSION_NUM, maxChangeFileReplay)}") + + // Now try to load using any of the selected snapshots, + // remember they are sorted in descending order + for (snapshotVersion <- selectedRemainingSnapshots if loadedSnapshot.isEmpty) { + beforeLoad() // if this fails, then we should fail + try { + loadSnapshotFromCheckpoint(snapshotVersion) + loadedSnapshot = Some(snapshotVersion) + logInfo(log"Successfully loaded snapshot version " + + log"${MDC(VERSION_NUM, snapshotVersion)}. 
Repair complete.") + } catch { + case NonFatal(e) => + logError(log"Failed to load snapshot version " + + log"${MDC(VERSION_NUM, snapshotVersion)}, will retry repair with " + + log"the next eligible snapshot version", e) + onLoadSnapshotFromCheckpointFailure() + lastException = e + } + } + + if (loadedSnapshot.isEmpty) { + // we tried all eligible snapshots and failed to load any of them + logError(log"Auto snapshot repair failed to load any snapshot:" + + log" latestSnapshotVersion: ${MDC(VERSION_NUM, firstEligibleSnapshot)}, " + + log"attemptedSnapshots: ${MDC(VERSION_NUM, selectedRemainingSnapshots)}, " + + log"eligibleSnapshots: ${MDC(VERSION_NUM, remainingEligibleSnapshots)}, " + + log"maxChangeFileReplay: ${MDC(VERSION_NUM, maxChangeFileReplay)}", lastException) + throw StateStoreErrors.autoSnapshotRepairFailed( + loggingId, firstEligibleSnapshot, selectedRemainingSnapshots, remainingEligibleSnapshots, + lastException) + } else { + autoRepairCompleted = true + } + } + + // we would only get here if we successfully loaded a snapshot + (loadedSnapshot.get, autoRepairCompleted) + } +} + +object SnapshotLoaderHelper { + /** Get all the snapshot versions that can be used to load this version */ + def getEligibleSnapshotsForVersion( + version: Long, + fm: CheckpointFileManager, + dfsPath: Path, + pathFilter: PathFilter, + fileSuffix: String): Seq[Long] = { + if (fm.exists(dfsPath)) { + ArraySeq.unsafeWrapArray( + fm.list(dfsPath, pathFilter) + .map(_.getPath.getName.stripSuffix(fileSuffix)) + .map(_.toLong) + ).filter(_ <= version) + } else { + Seq(0L) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index aa4fa9bfaf62..f1c9c94e7bf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ import java.util import java.util.{Locale, UUID} -import java.util.concurrent.atomic.{AtomicLong, LongAdder} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, LongAdder} import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -295,7 +295,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with def getMetricsForProvider(): Map[String, Long] = synchronized { Map("memoryUsedBytes" -> SizeEstimator.estimate(loadedMaps), metricLoadedMapCacheHit.name -> loadedMapCacheHitCount.sum(), - metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum()) + metricLoadedMapCacheMiss.name -> loadedMapCacheMissCount.sum(), + metricNumSnapshotsAutoRepaired.name -> (if (performedSnapshotAutoRepair.get()) 1 else 0) + ) } /** Get the state store for making updates to create a new `version` of the store. 
*/ @@ -324,6 +326,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with if (version < 0) { throw QueryExecutionErrors.unexpectedStateStoreVersion(version) } + + performedSnapshotAutoRepair.set(false) val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey) if (version > 0) { newMap.putAll(loadMap(version)) @@ -426,6 +430,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = { metricStateOnCurrentVersionSizeBytes :: metricLoadedMapCacheHit :: metricLoadedMapCacheMiss :: + metricNumSnapshotsAutoRepaired :: Nil } @@ -471,6 +476,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with mgr } } + private val onlySnapshotFiles = new PathFilter { + override def accept(path: Path): Boolean = path.toString.endsWith(".snapshot") + } private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private val loadedMapCacheHitCount: LongAdder = new LongAdder @@ -479,6 +487,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with // This is updated when the maintenance task writes the snapshot file and read by the task // thread. -1 represents no version has ever been uploaded. private val lastUploadedSnapshotVersion: AtomicLong = new AtomicLong(-1L) + // Was snapshot auto repair performed when loading the current version + private val performedSnapshotAutoRepair: AtomicBoolean = new AtomicBoolean(false) private lazy val metricStateOnCurrentVersionSizeBytes: StateStoreCustomSizeMetric = StateStoreCustomSizeMetric("stateOnCurrentVersionSizeBytes", @@ -492,6 +502,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with StateStoreCustomSumMetric("loadedMapCacheMissCount", "count of cache miss on states cache in provider") + private lazy val metricNumSnapshotsAutoRepaired: StateStoreCustomMetric = + StateStoreCustomSumMetric("numSnapshotsAutoRepaired", + "number of snapshots that were automatically repaired during store load") + private lazy val instanceMetricSnapshotLastUpload: StateStoreInstanceMetric = StateStoreSnapshotLastUploadInstanceMetric() @@ -593,52 +607,78 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with loadedMapCacheMissCount.increment() val (result, elapsedMs) = Utils.timeTakenMs { - val snapshotCurrentVersionMap = readSnapshotFile(version) - if (snapshotCurrentVersionMap.isDefined) { - synchronized { putStateIntoStateCacheMap(version, snapshotCurrentVersionMap.get) } + val (loadedVersion, loadedMap) = loadSnapshot(version) + val finalMap = if (loadedVersion == version) { + loadedMap + } else { + // Load all the deltas from the version after the loadedVersion up to the target version. + // The loadedVersion is the one with a full snapshot, so it doesn't need deltas. + val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey) + resultMap.putAll(loadedMap) + for (deltaVersion <- loadedVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) + } + resultMap + } - // Report the loaded snapshot's version to the coordinator - reportSnapshotUploadToCoordinator(version) + // Synchronize and update the state cache map + synchronized { putStateIntoStateCacheMap(version, finalMap) } - return snapshotCurrentVersionMap.get - } + // Report the snapshot found to the coordinator + reportSnapshotUploadToCoordinator(loadedVersion) - // Find the most recent map before this version that we can. 
- // [SPARK-22305] This must be done iteratively to avoid stack overflow. - var lastAvailableVersion = version - var lastAvailableMap: Option[HDFSBackedStateStoreMap] = None - while (lastAvailableMap.isEmpty) { - lastAvailableVersion -= 1 + finalMap + } - if (lastAvailableVersion <= 0) { + logDebug(s"Loading state for $version takes $elapsedMs ms.") + + result + } + + /** Loads the latest snapshot for the version we want to load and + * returns the snapshot version and map representing the snapshot */ + private def loadSnapshot(versionToLoad: Long): (Long, HDFSBackedStateStoreMap) = { + var loadedMap: Option[HDFSBackedStateStoreMap] = None + val storeIdStr = s"StateStoreId(opId=${stateStoreId_.operatorId}," + + s"partId=${stateStoreId_.partitionId},name=${stateStoreId_.storeName})" + + val snapshotLoader = new AutoSnapshotLoader( + storeConf.autoSnapshotRepairEnabled, + storeConf.autoSnapshotRepairNumFailuresBeforeActivating, + storeConf.autoSnapshotRepairMaxChangeFileReplay, + storeIdStr) { + override protected def beforeLoad(): Unit = {} + + override protected def loadSnapshotFromCheckpoint(snapshotVersion: Long): Unit = { + loadedMap = if (snapshotVersion <= 0) { // Use an empty map for versions 0 or less. - lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)) + Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)) } else { - lastAvailableMap = - synchronized { Option(loadedMaps.get(lastAvailableVersion)) } - .orElse(readSnapshotFile(lastAvailableVersion)) + // first try to get the map from the cache + synchronized { Option(loadedMaps.get(snapshotVersion)) } + .orElse(readSnapshotFile(snapshotVersion)) } } - // Load all the deltas from the version after the last available one up to the target version. - // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey) - resultMap.putAll(lastAvailableMap.get) - for (deltaVersion <- lastAvailableVersion + 1 to version) { - updateFromDeltaFile(deltaVersion, resultMap) - } + override protected def onLoadSnapshotFromCheckpointFailure(): Unit = {} - synchronized { putStateIntoStateCacheMap(version, resultMap) } + override protected def getEligibleSnapshots(versionToLoad: Long): Seq[Long] = { + val snapshotVersions = SnapshotLoaderHelper.getEligibleSnapshotsForVersion( + versionToLoad, fm, baseDir, onlySnapshotFiles, fileSuffix = ".snapshot") - // Report the last available snapshot's version to the coordinator - reportSnapshotUploadToCoordinator(lastAvailableVersion) + // Get locally cached versions, so we can use the locally cached version if available. 
+ val cachedVersions = synchronized { + loadedMaps.keySet.asScala.toSeq + }.filter(_ <= versionToLoad) - resultMap + // Combine the two sets of versions, so we can check both during load + (snapshotVersions ++ cachedVersions).distinct + } } - logDebug(s"Loading state for $version takes $elapsedMs ms.") - - result + val (loadedVersion, autoRepairCompleted) = snapshotLoader.loadSnapshot(versionToLoad) + performedSnapshotAutoRepair.set(autoRepairCompleted) + (loadedVersion, loadedMap.get) } private def writeUpdateToDeltaFile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index fb3ef606b8f3..f8570c583387 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -219,6 +219,9 @@ class RocksDB( @volatile private var numInternalKeysOnLoadedVersion = 0L @volatile private var numInternalKeysOnWritingVersion = 0L + // Was snapshot auto repair performed when loading the current version + @volatile private var performedSnapshotAutoRepair = false + @volatile private var fileManagerMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS // SPARK-46249 - Keep track of recorded metrics per version which can be used for querying later @@ -541,24 +544,9 @@ class RocksDB( try { if (loadedVersion != version) { closeDB(ignoreException = false) - val latestSnapshotVersion = fileManager.getLatestSnapshotVersion(version) - val metadata = fileManager.loadCheckpointFromDfs( - latestSnapshotVersion, - workingDir, - rocksDBFileMapping) - - loadedVersion = latestSnapshotVersion - - // reset the last snapshot version to the latest available snapshot version - lastSnapshotVersion = latestSnapshotVersion - - // Initialize maxVersion upon successful load from DFS - fileManager.setMaxSeenVersion(version) - - // Report this snapshot version to the coordinator - reportSnapshotUploadToCoordinator(latestSnapshotVersion) - openLocalRocksDB(metadata) + // load the latest snapshot + loadSnapshotWithoutCheckpointId(version) if (loadedVersion != version) { val versionsAndUniqueIds: Array[(Long, Option[String])] = @@ -589,6 +577,54 @@ class RocksDB( this } + private def loadSnapshotWithoutCheckpointId(versionToLoad: Long): Long = { + // Don't allow auto snapshot repair if changelog checkpointing is not enabled + // since it relies on changelog to rebuild state. 
+ val allowAutoSnapshotRepair = if (enableChangelogCheckpointing) { + conf.stateStoreConf.autoSnapshotRepairEnabled + } else { + false + } + val snapshotLoader = new AutoSnapshotLoader( + allowAutoSnapshotRepair, + conf.stateStoreConf.autoSnapshotRepairNumFailuresBeforeActivating, + conf.stateStoreConf.autoSnapshotRepairMaxChangeFileReplay, + loggingId) { + override protected def beforeLoad(): Unit = closeDB(ignoreException = false) + + override protected def loadSnapshotFromCheckpoint(snapshotVersion: Long): Unit = { + val remoteMetaData = fileManager.loadCheckpointFromDfs(snapshotVersion, + workingDir, rocksDBFileMapping) + + loadedVersion = snapshotVersion + // Initialize maxVersion upon successful load from DFS + fileManager.setMaxSeenVersion(snapshotVersion) + + openLocalRocksDB(remoteMetaData) + + // By setting this to the snapshot version we successfully loaded, + // if auto snapshot repair is enabled, and we end up skipping the latest snapshot + // and used an older one, we will create a new snapshot at commit time + // if the loaded one is old enough. + lastSnapshotVersion = snapshotVersion + // Report this snapshot version to the coordinator + reportSnapshotUploadToCoordinator(snapshotVersion) + } + + override protected def onLoadSnapshotFromCheckpointFailure(): Unit = { + loadedVersion = -1 // invalidate loaded data + } + + override protected def getEligibleSnapshots(version: Long): Seq[Long] = { + fileManager.getEligibleSnapshotsForVersion(version) + } + } + + val (version, autoRepairCompleted) = snapshotLoader.loadSnapshot(versionToLoad) + performedSnapshotAutoRepair = autoRepairCompleted + version + } + /** * Function to check if col family is internal or not based on information recorded in * checkpoint metadata. @@ -657,6 +693,7 @@ class RocksDB( assert(version >= 0) recordedMetrics = None + performedSnapshotAutoRepair = false // Reset the load metrics before loading loadMetrics.clear() @@ -1622,7 +1659,8 @@ class RocksDB( filesReused = fileManagerMetrics.filesReused, lastUploadedSnapshotVersion = lastUploadedSnapshotVersion.get(), zipFileBytesUncompressed = fileManagerMetrics.zipFileBytesUncompressed, - nativeOpsMetrics = nativeOpsMetrics) + nativeOpsMetrics = nativeOpsMetrics, + numSnapshotsAutoRepaired = if (performedSnapshotAutoRepair) 1 else 0) } /** @@ -2067,7 +2105,8 @@ case class RocksDBConf( compression: String, reportSnapshotUploadLag: Boolean, fileChecksumEnabled: Boolean, - maxVersionsToDeletePerMaintenance: Int) + maxVersionsToDeletePerMaintenance: Int, + stateStoreConf: StateStoreConf) object RocksDBConf { /** Common prefix of all confs in SQLConf that affects RocksDB */ @@ -2267,7 +2306,8 @@ object RocksDBConf { getStringConf(COMPRESSION_CONF), storeConf.reportSnapshotUploadLag, storeConf.checkpointFileChecksumEnabled, - storeConf.maxVersionsToDeletePerMaintenance) + storeConf.maxVersionsToDeletePerMaintenance, + storeConf) } def apply(): RocksDBConf = apply(new StateStoreConf()) @@ -2289,7 +2329,8 @@ case class RocksDBMetrics( filesReused: Long, zipFileBytesUncompressed: Option[Long], nativeOpsMetrics: Map[String, Long], - lastUploadedSnapshotVersion: Long) { + lastUploadedSnapshotVersion: Long, + numSnapshotsAutoRepaired: Long) { def json: String = Serialization.write(this)(RocksDBMetrics.format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala index 2e86ff70d58f..92fa5d0350fa 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala @@ -427,6 +427,12 @@ class RocksDBFileManager( } } + /** Get all the snapshot versions that can be used to load this version */ + def getEligibleSnapshotsForVersion(version: Long): Seq[Long] = { + SnapshotLoaderHelper.getEligibleSnapshotsForVersion( + version, fm, new Path(dfsRootDir), onlyZipFiles, fileSuffix = ".zip") + } + /** * Based on the ground truth lineage loaded from changelog file (lineage), this function * does file listing to find all snapshot (version, uniqueId) pairs, and finds diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 2cc4c8a870ae..e01e1e0f86ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -552,7 +552,8 @@ private[sql] class RocksDBStateStoreProvider CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE -> rocksDBMetrics.pinnedBlocksMemUsage, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS -> rocksDBMetrics.numInternalKeys, CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES -> internalColFamilyCnt(), - CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES -> externalColFamilyCnt() + CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES -> externalColFamilyCnt(), + CUSTOM_METRIC_NUM_SNAPSHOTS_AUTO_REPAIRED -> rocksDBMetrics.numSnapshotsAutoRepaired ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes => Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> bytes)).getOrElse(Map()) @@ -1261,6 +1262,9 @@ object RocksDBStateStoreProvider { // Total SST file size val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric( "rocksdbSstFileSize", "RocksDB: size of all SST files") + val CUSTOM_METRIC_NUM_SNAPSHOTS_AUTO_REPAIRED = StateStoreCustomSumMetric( + "rocksdbNumSnapshotsAutoRepaired", + "RocksDB: number of snapshots that were automatically repaired during store load") val ALL_CUSTOM_METRICS = Seq( CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, CUSTOM_METRIC_PUT_TIME, @@ -1276,7 +1280,7 @@ object RocksDBStateStoreProvider { CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS, CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES, CUSTOM_METRIC_LOAD_FROM_SNAPSHOT_TIME, CUSTOM_METRIC_LOAD_TIME, CUSTOM_METRIC_REPLAY_CHANGE_LOG, - CUSTOM_METRIC_NUM_REPLAY_CHANGE_LOG_FILES) + CUSTOM_METRIC_NUM_REPLAY_CHANGE_LOG_FILES, CUSTOM_METRIC_NUM_SNAPSHOTS_AUTO_REPAIRED) val CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED = StateStoreSnapshotLastUploadInstanceMetric() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 74904a37f450..a765f52a2272 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -48,6 +48,17 @@ class StateStoreConf( */ val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot + /** Whether we should enable automatic snapshot repair */ + val autoSnapshotRepairEnabled: Boolean = 
sqlConf.stateStoreAutoSnapshotRepairEnabled + + /** Number of failures before activating auto snapshot repair when enabled */ + val autoSnapshotRepairNumFailuresBeforeActivating: Int = + sqlConf.stateStoreAutoSnapshotRepairNumFailuresBeforeActivating + + /** Maximum number of change files allowed to be replayed when auto snapshot repair is enabled */ + val autoSnapshotRepairMaxChangeFileReplay: Int = + sqlConf.stateStoreAutoSnapshotRepairMaxChangeFileReplay + /** Minimum versions a State Store implementation should retain to allow rollbacks */ val minVersionsToRetain: Int = sqlConf.minBatchesToRetain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 970499a054b5..23bb54d86348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -254,6 +254,16 @@ object StateStoreErrors { new StateStoreUnexpectedEmptyFileInRocksDBZip(fileName, zipFileName) } + def autoSnapshotRepairFailed( + stateStoreId: String, + latestSnapshot: Long, + selectedSnapshots: Seq[Long], + eligibleSnapshots: Seq[Long], + cause: Throwable): StateStoreAutoSnapshotRepairFailed = { + new StateStoreAutoSnapshotRepairFailed( + stateStoreId, latestSnapshot, selectedSnapshots, eligibleSnapshots, cause) + } + def cannotLoadStore(e: Throwable): Throwable = { e match { case e: SparkException @@ -583,3 +593,18 @@ class StateStoreUnexpectedEmptyFileInRocksDBZip(fileName: String, zipFileName: S "fileName" -> fileName, "zipFileName" -> zipFileName), cause = null) + +class StateStoreAutoSnapshotRepairFailed( + stateStoreId: String, + latestSnapshot: Long, + selectedSnapshots: Seq[Long], + eligibleSnapshots: Seq[Long], + cause: Throwable) + extends SparkRuntimeException( + errorClass = "CANNOT_LOAD_STATE_STORE.AUTO_SNAPSHOT_REPAIR_FAILED", + messageParameters = Map( + "latestSnapshot" -> latestSnapshot.toString, + "stateStoreId" -> stateStoreId, + "selectedSnapshots" -> selectedSnapshots.mkString(","), + "eligibleSnapshots" -> eligibleSnapshots.mkString(",")), + cause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoaderSuite.scala new file mode 100644 index 000000000000..186248b43bc8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/AutoSnapshotLoaderSuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.collection.mutable.ListBuffer + +import org.apache.spark.SparkFunSuite + +/** + * Suite to test [[AutoSnapshotLoader]]. Tests different behaviors including + * when repair is enabled/disabled, when numFailuresBeforeActivating is set, + * when maxChangeFileReplay is set. + */ +class AutoSnapshotLoaderSuite extends SparkFunSuite { + test("successful snapshot load without auto repair") { + // Test auto repair on or off + Seq(true, false).foreach { enabled => + val loader = new TestAutoSnapshotLoader( + autoSnapshotRepairEnabled = enabled, + eligibleSnapshots = Seq(2, 4), + failSnapshots = Seq.empty) + + val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5) + assert(!autoRepairCompleted) + assert(versionLoaded == 4, "Should load the latest snapshot version") + assert(loader.getRequestedSnapshotVersions == Seq(4), + "Should have requested only the latest snapshot version") + } + } + + test("snapshot load failure gets repaired") { + def createLoader(autoRepair: Boolean): TestAutoSnapshotLoader = + new TestAutoSnapshotLoader( + autoSnapshotRepairEnabled = autoRepair, + eligibleSnapshots = Seq(2, 4), + failSnapshots = Seq(4)) + + // load without auto repair enabled + var loader = createLoader(autoRepair = false) + + // This should fail to load v5 due to snapshot 4 failure, even though snapshot 2 exists + val ex = intercept[TestLoadException] { + loader.loadSnapshot(5) + } + assert(ex.snapshotVersion == 4, "Load failure should be due to version 4") + + // Now try to load with auto repair enabled + loader = createLoader(autoRepair = true) + val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5) + assert(autoRepairCompleted) + assert(versionLoaded == 2, "Should have loaded the snapshot version before the corrupt one") + assert(loader.getRequestedSnapshotVersions == Seq(4, 2)) + } + + test("repair works even when all snapshots are corrupt") { + val loader = new TestAutoSnapshotLoader( + autoSnapshotRepairEnabled = true, + eligibleSnapshots = Seq(2, 4), + failSnapshots = Seq(2, 4)) + + val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5) + assert(autoRepairCompleted) + assert(versionLoaded == 0, "Load 0 since no good snapshots available") + assert(loader.getRequestedSnapshotVersions == Seq(4, 2, 0)) + } + + test("number of failures before activating auto repair") { + def createLoader(numFailures: Int): TestAutoSnapshotLoader = + new TestAutoSnapshotLoader( + autoSnapshotRepairEnabled = true, + numFailuresBeforeActivating = numFailures, + eligibleSnapshots = Seq(2, 4), + failSnapshots = Seq(4)) + + (1 to 5).foreach { numFailures => + val loader = createLoader(numFailures) + val (versionLoaded, autoRepairCompleted) = loader.loadSnapshot(5) + assert(autoRepairCompleted) + assert(versionLoaded == 2, "Should have loaded the snapshot version before the corrupt one") + assert(loader.getRequestedSnapshotVersions == Seq.fill(numFailures)(4) :+ 2, + s"should have tried to load version 4 $numFailures times before falling back to version 2") + } + } + + test("maximum change file replay") { + def createLoader(maxChangeFileReplay: Int, fail: Seq[Long]): TestAutoSnapshotLoader = + new TestAutoSnapshotLoader( + autoSnapshotRepairEnabled = true, + maxChangeFileReplay = maxChangeFileReplay, + eligibleSnapshots = Seq(2, 4, 5), + failSnapshots = fail) + + var loader = createLoader(maxChangeFileReplay = 1, fail = Seq(5)) + // repair with max change file replay = 1, should load snapshot 4 + val (versionLoaded, 
autoRepairCompleted) = loader.loadSnapshot(5) + assert(autoRepairCompleted) + assert(versionLoaded == 4) + assert(loader.getRequestedSnapshotVersions == Seq(5, 4)) + + // repair with max change file replay = 2, should fail since we can't use the older snapshots + loader = createLoader(maxChangeFileReplay = 2, fail = Seq(5, 4)) + val ex = intercept[StateStoreAutoSnapshotRepairFailed] { + loader.loadSnapshot(5) + } + + checkError( + exception = ex, + condition = "CANNOT_LOAD_STATE_STORE.AUTO_SNAPSHOT_REPAIR_FAILED", + parameters = Map( + "latestSnapshot" -> "5", + "stateStoreId" -> "test", + "selectedSnapshots" -> "4", // only selected 4 due to maxChangeFileReplay = 2 + "eligibleSnapshots" -> "4,2,0") + ) + assert(loader.getRequestedSnapshotVersions == Seq(5, 4)) + assert(ex.getCause.asInstanceOf[TestLoadException].snapshotVersion == 4) + + // repair with max change file replay = 3, should load snapshot 2 + loader = createLoader(maxChangeFileReplay = 3, fail = Seq(5, 4)) + val (versionLoaded_, autoRepairCompleted_) = loader.loadSnapshot(5) + assert(autoRepairCompleted_) + assert(versionLoaded_ == 2) + assert(loader.getRequestedSnapshotVersions == Seq(5, 4, 2)) + } +} + +/** + * A test implementation of [[AutoSnapshotLoader]] for testing purposes. + * Allows tracking of requested snapshot versions and simulating load failures. + * */ +class TestAutoSnapshotLoader( + autoSnapshotRepairEnabled: Boolean, + numFailuresBeforeActivating: Int = 1, + maxChangeFileReplay: Int = 10, + loggingId: String = "test", + eligibleSnapshots: Seq[Long], + failSnapshots: Seq[Long] = Seq.empty) extends AutoSnapshotLoader( + autoSnapshotRepairEnabled, numFailuresBeforeActivating, maxChangeFileReplay, loggingId) { + + // track snapshot versions requested via loadSnapshotFromCheckpoint + private val requestedSnapshotVersions = ListBuffer[Long]() + def getRequestedSnapshotVersions: Seq[Long] = requestedSnapshotVersions.toSeq + + override protected def beforeLoad(): Unit = {} + + override protected def loadSnapshotFromCheckpoint(snapshotVersion: Long): Unit = { + // Track the snapshot version + requestedSnapshotVersions += snapshotVersion + + // throw exception if the snapshot version is in the failSnapshots list + if (failSnapshots.contains(snapshotVersion)) { + throw new TestLoadException(snapshotVersion) + } + } + + override protected def onLoadSnapshotFromCheckpointFailure(): Unit = {} + + override protected def getEligibleSnapshots(versionToLoad: Long): Seq[Long] = eligibleSnapshots +} + +class TestLoadException(val snapshotVersion: Long) + extends IllegalStateException(s"Cannot load snapshot version $snapshotVersion") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala index 38e5b15465b8..0bf95ce92797 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala @@ -110,10 +110,13 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest "rocksdbTotalBytesReadThroughIterator", "rocksdbTotalBytesWrittenByFlush", "rocksdbPinnedBlocksMemoryUsage", "rocksdbNumInternalColFamiliesKeys", "rocksdbNumExternalColumnFamilies", "rocksdbNumInternalColumnFamilies", + "rocksdbNumSnapshotsAutoRepaired", "SnapshotLastUploaded.partition_0_default", 
"rocksdbChangeLogWriterCommitLatencyMs", "rocksdbSaveZipFilesLatencyMs", "rocksdbLoadFromSnapshotLatencyMs", "rocksdbLoadLatencyMs", "rocksdbReplayChangeLogLatencyMs", "rocksdbNumReplayChangelogFiles")) + assert(stateOperatorMetrics.customMetrics.get("rocksdbNumSnapshotsAutoRepaired") == 0, + "Should be 0 since we didn't repair any snapshot") } } finally { query.stop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index da6c3e62798e..de16aa38fe5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -3692,6 +3692,83 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession } } + testWithChangelogCheckpointingEnabled("Auto snapshot repair") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> false.toString, + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "2" + ) { + withTempDir { dir => + val remoteDir = dir.getCanonicalPath + withDB(remoteDir) { db => + db.load(0) + db.put("a", "0") + db.commit() + assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0) + + db.load(1) + db.put("b", "1") + db.commit() // snapshot is created + assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0) + db.doMaintenance() // upload snapshot 2.zip + + db.load(2) + db.put("c", "2") + db.commit() + assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0) + + db.load(3) + db.put("d", "3") + db.commit() // snapshot is created + assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 0) + db.doMaintenance() // upload snapshot 4.zip + } + + def corruptFile(file: File): Unit = + // overwrite the file content to become empty + new PrintWriter(file) { close() } + + // corrupt snapshot 4.zip + corruptFile(new File(remoteDir, "4.zip")) + + withDB(remoteDir) { db => + // this should fail when trying to load from remote + val ex = intercept[java.nio.file.NoSuchFileException] { + db.load(4) + } + // would fail while trying to read the metadata file from the empty zip file + assert(ex.getMessage.contains("/metadata")) + } + + // Enable auto snapshot repair + withSQLConf(SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED.key -> true.toString, + SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING.key -> "1", + SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY.key -> "5" + ) { + withDB(remoteDir) { db => + // this should now succeed + db.load(4) + assert(toStr(db.get("a")) == "0") + db.put("e", "4") + db.commit() // a new snapshot (5.zip) will be created since previous one is corrupt + assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 1) + db.doMaintenance() // upload snapshot 5.zip + } + + // corrupt all snapshot files + Seq(2, 5).foreach { v => corruptFile(new File(remoteDir, s"$v.zip")) } + + withDB(remoteDir) { db => + // this load should succeed due to auto repair, even though all snapshots are bad + db.load(5) + assert(toStr(db.get("b")) == "1") + db.commit() + assert(db.metricsOpt.get.numSnapshotsAutoRepaired == 1) + } + } + } + } + } + testWithChangelogCheckpointingEnabled("SPARK-51922 - Changelog writer v1 with large key" + " does not cause UTFDataFormatException") { val remoteDir = Utils.createTempDir() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 6bb64315e356..807397d96918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, IOException, ObjectInputStream, ObjectOutputStream, PrintWriter} import java.net.URI import java.util import java.util.UUID @@ -1420,6 +1420,92 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1")) } + test("Auto snapshot repair") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> false.toString, + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1" // for hdfs means every 2 versions + ) { + val storeId = StateStoreId(newDir(), 0L, 1) + val remoteDir = storeId.storeCheckpointLocation().toString + + def numSnapshotsAutoRepaired(store: StateStore): Long = { + store.metrics.customMetrics + .find(m => m._1.name == "numSnapshotsAutoRepaired").get._2 + } + + tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider => + var store = provider.getStore(0) + put(store, "a", 0, 0) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 0) + + store = provider.getStore(1) + put(store, "b", 1, 1) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 0) + provider.doMaintenance() // upload snapshot 2.snapshot + + store = provider.getStore(2) + put(store, "c", 2, 2) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 0) + + store = provider.getStore(3) + put(store, "d", 3, 3) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 0) + provider.doMaintenance() // upload snapshot 4.snapshot + } + + def corruptFile(file: File): Unit = + // overwrite the file content to become empty + new PrintWriter(file) { close() } + + // corrupt 4.snapshot + corruptFile(new File(remoteDir, "4.snapshot")) + + tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider => + // this should fail when trying to load from remote + val ex = intercept[SparkException] { + provider.getStore(4) + } + assert(ex.getCause.isInstanceOf[java.io.EOFException]) + } + + // Enable auto snapshot repair + withSQLConf(SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_ENABLED.key -> true.toString, + SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_NUM_FAILURES_BEFORE_ACTIVATING.key -> "1", + SQLConf.STATE_STORE_AUTO_SNAPSHOT_REPAIR_MAX_CHANGE_FILE_REPLAY.key -> "6" + ) { + tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider => + // this should now succeed + var store = provider.getStore(4) + assert(get(store, "a", 0).contains(0)) + put(store, "e", 4, 4) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 1) + + store = provider.getStore(5) + put(store, "f", 5, 5) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 0) + provider.doMaintenance() // upload snapshot 6.snapshot + } + + // corrupt all snapshot files + Seq(2, 6).foreach { v => corruptFile(new File(remoteDir, s"$v.snapshot"))} + + tryWithProviderResource(newStoreProviderWithClonedConf(storeId)) { provider => + // this load should succeed due to auto repair, even though all snapshots 
are bad + val store = provider.getStore(6) + assert(get(store, "b", 1).contains(1)) + store.commit() + assert(numSnapshotsAutoRepaired(store) == 1) + } + } + } + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) }
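
Reviewer note (not part of the diff): below is a minimal, hypothetical sketch of how the new confs introduced in SQLConf above could be exercised on an existing SparkSession named `spark`. The three autoSnapshotRepair keys are the ones added in this change; the provider and changelog-checkpointing keys are the pre-existing settings the repair path relies on (changelog checkpointing must be on for the RocksDB provider, per the conf doc).

// Sketch only: enable auto snapshot repair for a streaming query (assumes a SparkSession `spark`).
spark.conf.set("spark.sql.streaming.stateStore.providerClass",
  "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
// Repair rebuilds state by replaying change files, so changelog checkpointing must be enabled.
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", "true")
spark.conf.set("spark.sql.streaming.stateStore.autoSnapshotRepair.enabled", "true")
// Tolerate one snapshot-load failure before falling back to an older snapshot.
spark.conf.set("spark.sql.streaming.stateStore.autoSnapshotRepair.numFailuresBeforeActivating", "1")
// Cap how many change files may be replayed on top of the older snapshot.
spark.conf.set("spark.sql.streaming.stateStore.autoSnapshotRepair.maxChangeFileReplay", "50")
// With these set, a bad snapshot encountered during store load is skipped and shows up in the
// "numSnapshotsAutoRepaired" / "rocksdbNumSnapshotsAutoRepaired" custom metrics instead of failing the query.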