From 1f306f4c7dcb430f300e99c8e57e8ea43f49159f Mon Sep 17 00:00:00 2001
From: Jungtaek Lim <kabhwan.opensource@gmail.com>
Date: Sat, 4 Apr 2026 09:02:19 +0900
Subject: [PATCH 1/2] WIP better error on stream-stream join NPE

---
 .../org/apache/spark/internal/LogKeys.java    |  2 +
 .../resources/error/error-conditions.json     | 15 +++++
 .../join/SymmetricHashJoinStateManager.scala  | 55 ++++++++++++++++---
 .../streaming/state/StateStoreErrors.scala    | 28 ++++++++++
 .../SymmetricHashJoinStateManagerSuite.scala  | 17 ++----
 5 files changed, 98 insertions(+), 19 deletions(-)

diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index 59df0423fad26..a6ce6ae7945de 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -319,6 +319,7 @@ public enum LogKeys implements LogKey {
   JOB_IDS,
   JOIN_CONDITION,
   JOIN_CONDITION_SUB_EXPR,
+  JOIN_SIDE,
   JOIN_TYPE,
   K8S_CONTEXT,
   KEY,
@@ -536,6 +537,7 @@ public enum LogKeys implements LogKey {
   NUM_TASKS,
   NUM_TASK_CPUS,
   NUM_TRAIN_WORD,
+  NUM_VALUES,
   NUM_UNFINISHED_DECOMMISSIONED,
   NUM_VERSIONS_RETAIN,
   NUM_WEIGHTED_EXAMPLES,
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 4da32888e7370..073aa6ff423d1 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -6757,6 +6757,21 @@
     ],
     "sqlState" : "XXKST"
   },
+  "STREAM_STREAM_JOIN_INCONSISTENT_STATE" : {
+    "message" : [
+      "Detected inconsistency in stream-stream join state."
+    ],
+    "subClass" : {
+      "NULL_VALUE" : {
+        "message" : [
+          "Value at index <valueIndex> is null while numValues is <numValues>.",
+          "joinSide=<joinSide>, storeVersion=<storeVersion>, partitionId=<partitionId>.",
+          "Enable <configKey> as a workaround to skip null values."
+ ] + } + }, + "sqlState" : "XXKST" + }, "STRUCT_ARRAY_LENGTH_MISMATCH" : { "message" : [ "Input row doesn't have expected number of values required by the schema. fields are required while values are provided." diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala index fc2a69312fe79..1f758cc7a5bd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID} +import org.apache.spark.internal.LogKeys.{END_INDEX, INDEX, JOIN_SIDE, KEY, NUM_VALUES, PARTITION_ID, START_INDEX, STATE_STORE_ID, STATE_STORE_VERSION} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes @@ -1025,7 +1025,9 @@ abstract class SymmetricHashJoinStateManagerBase( override def getNext(): JoinedRow = { while (index < numValues) { val valuePair = keyWithIndexToValue.get(key, index) - if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) { + if (valuePair == null) { + handleNullValuePair(key, index, numValues) + skippedNullValueCount.foreach(_ += 1L) index += 1 } else if (valuePair.matched) { // See the NOTE in the method doc about rationale. 
@@ -1265,8 +1267,10 @@ abstract class SymmetricHashJoinStateManagerBase( /** * Find the next value satisfying the condition, updating `currentKey` and `numValues` if * needed. Returns null when no value can be found. - * Note that we will skip nulls explicitly if config setting for the same is - * set to true via STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS. + * + * Null values from the state store are handled by [[handleNullValuePair]]: + * skipped if STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS is enabled, + * otherwise an exception is thrown with diagnostic context. */ private def findNextValueForIndex(): ValueAndMatchPair = { // Loop across all values for the current key, and then all other keys, until we find a @@ -1277,7 +1281,9 @@ abstract class SymmetricHashJoinStateManagerBase( if (hasMoreValuesForCurrentKey) { // First search the values for the current key. val valuePair = keyWithIndexToValue.get(currentKey, index) - if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) { + if (valuePair == null) { + handleNullValuePair(currentKey, index, numValues) + skippedNullValueCount.foreach(_ += 1L) index += 1 } else if (removalCondition(valuePair.value)) { return valuePair @@ -1328,6 +1334,40 @@ abstract class SymmetricHashJoinStateManagerBase( /** Projects the key of unsafe row to internal row for printable log message. */ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow = keyProjection(currentKey) + @volatile private var nullValueWarningLogged = false + + /** + * Called when keyWithIndexToValue.get(key, index) returns null even though + * keyToNumValues claims there should be a value at that index. + * + * Logs a warning (once per instance) with diagnostic context, then either + * returns normally (if skipNullsForStreamStreamJoins is enabled, caller skips) + * or throws a [[StreamStreamJoinInconsistentStateNullValue]]. 
+ */ + protected def handleNullValuePair( + key: UnsafeRow, + nullIndex: Long, + numValues: Long): Unit = { + if (!nullValueWarningLogged) { + nullValueWarningLogged = true + logWarning(log"Null value detected in stream-stream join state: " + + log"joinSide=${MDC(JOIN_SIDE, joinSide)}, key=${MDC(KEY, keyProjection(key))}, " + + log"index=${MDC(INDEX, nullIndex)}, numValues=${MDC(NUM_VALUES, numValues)}, " + + log"storeVersion=${MDC(STATE_STORE_VERSION, stateInfo.get.storeVersion)}, " + + log"partitionId=${MDC(PARTITION_ID, partitionId)}") + } + + if (!storeConf.skipNullsForStreamStreamJoins) { + throw StateStoreErrors.streamStreamJoinNullValue( + valueIndex = nullIndex, + numValues = numValues, + joinSide = joinSide.toString, + storeVersion = stateInfo.get.storeVersion, + partitionId = partitionId, + configKey = SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key) + } + } + /** Commit all the changes to all the state stores */ def commit(): Unit @@ -1609,8 +1649,6 @@ abstract class SymmetricHashJoinStateManagerBase( /** * Get all values and indices for the provided key. * Should not return null. - * Note that we will skip nulls explicitly if config setting for the same is - * set to true via STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS. 
*/ def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { new NextIterator[KeyWithIndexAndValue] { @@ -1623,7 +1661,8 @@ abstract class SymmetricHashJoinStateManagerBase( val keyWithIndex = keyWithIndexRow(key, index) val valuePair = valueRowConverter.convertValue(stateStore.get(keyWithIndex, colFamilyName)) - if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) { + if (valuePair == null) { + handleNullValuePair(key, index, numValues) skippedNullValueCount.foreach(_ += 1L) index += 1 } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 422d479fd1f54..a00c6a8bc73c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -306,6 +306,17 @@ object StateStoreErrors { StateStoreUnknownInternalColumnFamily = { new StateStoreUnknownInternalColumnFamily(colFamilyName) } + + def streamStreamJoinNullValue( + valueIndex: Long, + numValues: Long, + joinSide: String, + storeVersion: Long, + partitionId: Int, + configKey: String): StreamStreamJoinInconsistentStateNullValue = { + new StreamStreamJoinInconsistentStateNullValue( + valueIndex, numValues, joinSide, storeVersion, partitionId, configKey) + } } trait ConvertableToCannotLoadStoreError { @@ -674,3 +685,20 @@ class StateStoreBaseCheckpointIdMismatch( "actualBaseId" -> actualBaseId ) ) + +class StreamStreamJoinInconsistentStateNullValue( + valueIndex: Long, + numValues: Long, + joinSide: String, + storeVersion: Long, + partitionId: Int, + configKey: String) + extends SparkRuntimeException( + errorClass = "STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE", + messageParameters = Map( + "valueIndex" -> valueIndex.toString, + "numValues" -> numValues.toString, + "joinSide" -> 
joinSide, + "storeVersion" -> storeVersion.toString, + "partitionId" -> partitionId.toString, + "configKey" -> configKey)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index 1042f01463b05..4717d2fb40544 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -24,6 +24,7 @@ import java.util.UUID import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfter +import org.apache.spark.{SparkRuntimeException, SparkThrowable} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, GenericInternalRow, JoinedRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -329,8 +330,8 @@ class SymmetricHashJoinStateManagerSuite extends SymmetricHashJoinStateManagerBa /* Test removeByValue with nulls in middle simulated by updating numValues on the state manager */ private def testAllOperationsWithNullsInMiddle(stateFormatVersion: Int): Unit = { - // Test with skipNullsForStreamStreamJoins set to false which would throw a - // NullPointerException while iterating and also return null values as part of get + // Test with skipNullsForStreamStreamJoins set to false which would throw + // STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE while iterating withJoinStateManager(inputValueAttributes, joinKeyExpressions, stateFormatVersion) { manager => implicit val mgr = manager @@ -342,15 +343,9 @@ class SymmetricHashJoinStateManagerSuite extends SymmetricHashJoinStateManagerBa updateNumValues(40, 7) // create nulls in 
between and end removeByValue(50) } - assert(ex.isInstanceOf[NullPointerException]) - assert(getNumValues(40) === 7) // we should get 7 with no nulls skipped - - removeByValue(300) - assert(getNumValues(40) === 1) // only 400 should remain - assert(get(40) === Seq(400)) - removeByValue(400) - assert(get(40) === Seq.empty) - assertNumRows(stateFormatVersion, 0) // ensure all elements removed + assert(ex.isInstanceOf[SparkRuntimeException]) + assert(ex.asInstanceOf[SparkThrowable].getCondition == + "STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE") } // Test with skipNullsForStreamStreamJoins set to true which would skip nulls From 44e57b318dd7f584569a0eed1051e265625d48a3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 6 Apr 2026 14:09:32 +0900 Subject: [PATCH 2/2] LogKeys orderliness fix --- .../src/main/java/org/apache/spark/internal/LogKeys.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index a6ce6ae7945de..a13153429541c 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -537,8 +537,8 @@ public enum LogKeys implements LogKey { NUM_TASKS, NUM_TASK_CPUS, NUM_TRAIN_WORD, - NUM_VALUES, NUM_UNFINISHED_DECOMMISSIONED, + NUM_VALUES, NUM_VERSIONS_RETAIN, NUM_WEIGHTED_EXAMPLES, NUM_WORKERS,