diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
index 0634a2f05b41..808ac8e6226b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
@@ -1317,8 +1317,8 @@ abstract class BaseStreamingDeduplicateExec
     val result = baseIterator.filter { r =>
       val row = r.asInstanceOf[UnsafeRow]
       val key = getKey(row)
-      val value = store.get(key)
-      if (value == null) {
+      val keyExists = store.keyExists(key)
+      if (!keyExists) {
         putDupInfoIntoState(store, row, key, reusedDupInfoRow)
         numUpdatedStateRows += 1
         numOutputRows += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index b1c9dee5a459..77b3aa821c6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -974,6 +974,26 @@ class RocksDB(
     }
   }
 
+  /**
+   * Checks whether the given key exists. This method gives a 100% guarantee of a
+   * correct result, whether the key exists or not.
+   *
+   * @param key The key to check
+   * @param cfName The column family name
+   * @return true if the key exists, false otherwise
+   */
+  def keyExists(
+      key: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Boolean = {
+    updateMemoryUsageIfNeeded()
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+    db.keyExists(keyWithPrefix)
+  }
+
   /**
    * Get the values for a given key if present, that were merged (via merge).
    * This returns the values as an iterator of index range, to allow inline access
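A note on the "100% guarantee" wording in the new doc comment: RocksDB also exposes the probabilistic `keyMayExist`, which may return false positives because it can answer from bloom filters and the block cache without consulting the actual data. The wrapper above relies on the definitive native check instead. Below is a minimal standalone sketch of the distinction against raw RocksJava; it assumes a rocksdbjni release that ships the definitive `keyExists` (recent 8.x/9.x), and the path and keys are illustrative:

```scala
import org.rocksdb.{Options, RocksDB}

object KeyExistsSketch {
  def main(args: Array[String]): Unit = {
    RocksDB.loadLibrary()
    val options = new Options().setCreateIfMissing(true)
    val db = RocksDB.open(options, "/tmp/key-exists-sketch")
    try {
      db.put("present".getBytes, "v".getBytes)
      // Definitive: reads the actual data if needed, so negatives are trustworthy.
      assert(db.keyExists("present".getBytes))
      assert(!db.keyExists("absent".getBytes))
      // Probabilistic: `true` only means "may exist"; callers still have to
      // confirm with a real read before acting on it.
      val mayExist = db.keyMayExist("absent".getBytes, null)
      println(s"keyMayExist(absent) = $mayExist (false positives are allowed)")
    } finally {
      db.close()
      options.close()
    }
  }
}
```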
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 88dad93b5d15..65f42d0f1684 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -251,6 +251,15 @@ private[sql] class RocksDBStateStoreProvider
       value
     }
 
+    override def keyExists(key: UnsafeRow, colFamilyName: String): Boolean = {
+      validateAndTransitionState(UPDATE)
+      verify(key != null, "Key cannot be null")
+      verifyColFamilyOperations("keyExists", colFamilyName)
+
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      rocksDB.keyExists(kvEncoder._1.encodeKey(key), colFamilyName)
+    }
+
     /**
      * Provides an iterator containing all values of a non-null key.
      *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index ee2048c8d95f..15a4d517661d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -115,6 +115,23 @@ trait ReadStateStore {
       key: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow
 
+  /**
+   * Checks whether a key exists in the store, with a 100% guarantee of a correct result.
+   *
+   * The default implementation calls get() and checks whether the result is null.
+   * Implementations backed by RocksDB should override this to use the native
+   * keyExists() method for better performance.
+   *
+   * @param key The key to check
+   * @param colFamilyName The column family name
+   * @return true if the key exists, false otherwise
+   */
+  def keyExists(
+      key: UnsafeRow,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Boolean = {
+    get(key, colFamilyName) != null
+  }
+
   /**
    * Provides an iterator containing all values of a non-null key. If key does not exist,
    * an empty iterator is returned. Implementations should make sure to return an empty
@@ -305,6 +322,12 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow =
     store.get(key, colFamilyName)
 
+  override def keyExists(
+      key: UnsafeRow,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Boolean = {
+    store.keyExists(key, colFamilyName)
+  }
+
   override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
     : StateStoreIterator[UnsafeRowPair] = store.iterator(colFamilyName)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index f7d7d0bc921f..b4cc6f315eb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -70,6 +70,12 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
     innerStore.get(key, colFamilyName)
   }
 
+  override def keyExists(
+      key: UnsafeRow,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Boolean = {
+    innerStore.keyExists(key, colFamilyName)
+  }
+
   override def valuesIterator(
       key: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow] = {
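Why do `WrappedReadStateStore` and the test wrapper above override `keyExists` at all, when the trait default (`get(...) != null`) is already correct? Because the default, dispatched on a wrapper, routes through the wrapper's `get` and pays for a full value fetch; explicit delegation keeps RocksDB's native fast path reachable through any number of wrapping layers. A toy model of the dispatch (the names here are illustrative, not Spark's):

```scala
trait Store {
  def get(key: String): Option[String]
  // Correct-but-slow fallback, mirroring ReadStateStore's default.
  def keyExists(key: String): Boolean = get(key).isDefined
}

final class NativeStore(data: Map[String, String]) extends Store {
  def get(key: String): Option[String] = data.get(key)
  // Stands in for RocksDB's cheap native existence check.
  override def keyExists(key: String): Boolean = data.contains(key)
}

final class Wrapper(inner: Store) extends Store {
  def get(key: String): Option[String] = inner.get(key)
  // Without this override, the trait default would call Wrapper.get and
  // materialize the value; delegating preserves the native fast path.
  override def keyExists(key: String): Boolean = inner.keyExists(key)
}
```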
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 61e551b851c2..ec820bcb583b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -1393,6 +1393,79 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
     encodeMethod.invoke(db, key.getBytes, cfName).asInstanceOf[Array[Byte]]
   }
 
+  testWithStateStoreCheckpointIdsAndColumnFamilies(
+    "RocksDB: keyExists over 1000 present and absent keys across CFs",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) {
+    case (enableStateStoreCheckpointIds, colFamiliesEnabled) =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+
+      val conf = dbConf.copy(compactOnCommit = false)
+      withDB(
+        remoteDir,
+        conf = conf,
+        useColumnFamilies = colFamiliesEnabled,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds) { db =>
+        val totalPresent = 500
+        val totalAbsent = 500
+
+        // Generate present keys; absent keys (below) use a disjoint prefix
+        val presentKeysAll = (0 until totalPresent).map(i => s"present_$i")
+
+        // Insert present keys
+        db.load(0)
+        // If column families are enabled, create a CF and use it uniformly (after load)
+        val cfNameOpt =
+          if (colFamiliesEnabled) {
+            val cf = "test_cf_random"
+            db.createColFamilyIfAbsent(cf, isInternal = false)
+            Some(cf)
+          } else {
+            None
+          }
+        cfNameOpt match {
+          case Some(cf) =>
+            presentKeysAll.foreach { k => db.put(k, s"v_$k", cf) }
+          case None =>
+            presentKeysAll.foreach { k => db.put(k, s"v_$k") }
+        }
+
+        // Generate absent keys using a different prefix to avoid overlap
+        val absentKeysAll = (0 until totalAbsent).map(i => s"absent_$i")
+
+        // Validation helper to avoid duplication
+        def validate(label: String): Unit = {
+          cfNameOpt match {
+            case Some(cf) =>
+              presentKeysAll.foreach { k =>
+                assert(db.keyExists(k, cf),
+                  s"$label Expected keyExists(true) for present CF key $k")
+              }
+              absentKeysAll.foreach { k =>
+                assert(!db.keyExists(k, cf),
+                  s"$label Expected keyExists(false) for absent CF key $k")
+              }
+            case None =>
+              presentKeysAll.foreach { k =>
+                assert(db.keyExists(k),
+                  s"$label Expected keyExists(true) for present default key $k")
+              }
+              absentKeysAll.foreach { k =>
+                assert(!db.keyExists(k),
+                  s"$label Expected keyExists(false) for absent default key $k")
+              }
+          }
+        }
+
+        // First check before commit
+        validate("(pre-commit)")
+
+        // Commit and re-check
+        db.commit()
+        validate("(post-commit)")
+      }
+  }
+
   testWithStateStoreCheckpointIdsAndColumnFamilies(s"RocksDB: get, put, iterator, commit, load",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) {
     case (enableStateStoreCheckpointIds, colFamiliesEnabled) =>
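The pre-commit/post-commit split in the test is the interesting part: it checks that `keyExists` stays correct across a commit boundary, where data moves from uncommitted in-memory write state into the committed store. Loosely analogous at the raw RocksDB level is checking before and after a memtable flush, which exercises both the in-memory and the on-disk (SST) read paths. A standalone sketch of that idea against RocksJava (illustrative path; assumes a rocksdbjni build with `keyExists`):

```scala
import org.rocksdb.{FlushOptions, Options, RocksDB}

object FlushPathsSketch {
  def main(args: Array[String]): Unit = {
    RocksDB.loadLibrary()
    val options = new Options().setCreateIfMissing(true)
    val db = RocksDB.open(options, "/tmp/key-exists-flush-sketch")
    try {
      db.put("k".getBytes, "v".getBytes)
      assert(db.keyExists("k".getBytes)) // answered from the memtable
      db.flush(new FlushOptions().setWaitForFlush(true))
      assert(db.keyExists("k".getBytes)) // answered from flushed SST files
      assert(!db.keyExists("missing".getBytes)) // definitive negative on both paths
    } finally {
      db.close()
      options.close()
    }
  }
}
```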