diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 5c55034e88df5..eb0e7ce76fc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -122,6 +122,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + require(value != null, "Cannot put a null value") verify(state == UPDATING, "Cannot put after already committed or aborted") val keyCopy = key.copy() val valueCopy = value.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 7c69e6f710bf7..ee4e2ae962d0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -99,8 +99,8 @@ trait ReadStateStore { trait StateStore extends ReadStateStore { /** - * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in - * the params can be reused, and must make copies of the data as needed for persistence. + * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows + * in the params can be reused, and must make copies of the data as needed for persistence. */ def put(key: UnsafeRow, value: UnsafeRow): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index dae771c613131..915b0abccf165 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -269,18 +269,14 @@ class SymmetricHashJoinStateManager( // The backing store is arraylike - we as the caller are responsible for filling back in // any hole. So we swap the last element into the hole and decrement numValues to shorten. // clean - if (numValues > 1) { + if (index != numValues - 1) { val valuePairAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) if (valuePairAtMaxIndex != null) { keyWithIndexToValue.put(currentKey, index, valuePairAtMaxIndex.value, valuePairAtMaxIndex.matched) - } else { - keyWithIndexToValue.put(currentKey, index, null, false) } - keyWithIndexToValue.remove(currentKey, numValues - 1) - } else { - keyWithIndexToValue.remove(currentKey, 0) } + keyWithIndexToValue.remove(currentKey, numValues - 1) numValues -= 1 valueRemoved = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 7e8f955a4594a..b82d32e916797 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1012,6 +1012,19 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] assert(combinedMetrics.customMetrics(customTimingMetric) == 400L) } + test("SPARK-35659: StateStore.put cannot put null value") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(getLatestData(provider).isEmpty) + + val store = provider.getStore(0) + val err = intercept[IllegalArgumentException] { + store.put(stringToRow("key"), null) + } + assert(err.getMessage.contains("Cannot put a null value")) + } + /** Return a new provider with a random id */ def newStoreProvider(): ProviderClass