Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ public enum LogKeys implements LogKey {
JOB_IDS,
JOIN_CONDITION,
JOIN_CONDITION_SUB_EXPR,
JOIN_SIDE,
JOIN_TYPE,
K8S_CONTEXT,
KEY,
Expand Down Expand Up @@ -537,6 +538,7 @@ public enum LogKeys implements LogKey {
NUM_TASK_CPUS,
NUM_TRAIN_WORD,
NUM_UNFINISHED_DECOMMISSIONED,
NUM_VALUES,
NUM_VERSIONS_RETAIN,
NUM_WEIGHTED_EXAMPLES,
NUM_WORKERS,
Expand Down
15 changes: 15 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -6757,6 +6757,21 @@
],
"sqlState" : "XXKST"
},
"STREAM_STREAM_JOIN_INCONSISTENT_STATE" : {
"message" : [
"Detected inconsistency in stream-stream join state."
],
"subClass" : {
"NULL_VALUE" : {
"message" : [
"Value at index <valueIndex> is null while numValues is <numValues>.",
"joinSide=<joinSide>, storeVersion=<storeVersion>, partitionId=<partitionId>.",
"Enable <configKey> as a workaround to skip null values."
]
}
},
"sqlState" : "XXKST"
},
"STRUCT_ARRAY_LENGTH_MISMATCH" : {
"message" : [
"Input row doesn't have expected number of values required by the schema. <expected> fields are required while <actual> values are provided."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID}
import org.apache.spark.internal.LogKeys.{END_INDEX, INDEX, JOIN_SIDE, KEY, NUM_VALUES, PARTITION_ID, START_INDEX, STATE_STORE_ID, STATE_STORE_VERSION}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
Expand Down Expand Up @@ -1025,7 +1025,9 @@ abstract class SymmetricHashJoinStateManagerBase(
override def getNext(): JoinedRow = {
while (index < numValues) {
val valuePair = keyWithIndexToValue.get(key, index)
if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
if (valuePair == null) {
handleNullValuePair(key, index, numValues)
skippedNullValueCount.foreach(_ += 1L)
index += 1
} else if (valuePair.matched) {
// See the NOTE in the method doc about rationale.
Expand Down Expand Up @@ -1265,8 +1267,10 @@ abstract class SymmetricHashJoinStateManagerBase(
/**
* Find the next value satisfying the condition, updating `currentKey` and `numValues` if
* needed. Returns null when no value can be found.
* Note that we will skip nulls explicitly if config setting for the same is
* set to true via STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.
*
* Null values from the state store are handled by [[handleNullValuePair]]:
* skipped if STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS is enabled,
* otherwise an exception is thrown with diagnostic context.
*/
private def findNextValueForIndex(): ValueAndMatchPair = {
// Loop across all values for the current key, and then all other keys, until we find a
Expand All @@ -1277,7 +1281,9 @@ abstract class SymmetricHashJoinStateManagerBase(
if (hasMoreValuesForCurrentKey) {
// First search the values for the current key.
val valuePair = keyWithIndexToValue.get(currentKey, index)
if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
if (valuePair == null) {
handleNullValuePair(currentKey, index, numValues)
skippedNullValueCount.foreach(_ += 1L)
index += 1
} else if (removalCondition(valuePair.value)) {
return valuePair
Expand Down Expand Up @@ -1328,6 +1334,40 @@ abstract class SymmetricHashJoinStateManagerBase(
/** Projects the key of unsafe row to internal row for printable log message. */
def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow = keyProjection(currentKey)

  // Set once the first null-value warning has been emitted so we log at most one
  // warning per state manager instance. The check-then-set below is not atomic,
  // so a concurrent caller could emit a duplicate warning; that is an acceptable
  // trade-off for a best-effort log-once guard.
  @volatile private var nullValueWarningLogged = false

  /**
   * Called when keyWithIndexToValue.get(key, index) returns null even though
   * keyToNumValues claims there should be a value at that index — i.e. the two
   * state stores disagree, indicating inconsistent stream-stream join state.
   *
   * Logs a warning (at most once per instance) with diagnostic context, then:
   *  - returns normally if skipNullsForStreamStreamJoins is enabled (the caller
   *    is expected to skip the null entry and advance the index), or
   *  - throws a [[StreamStreamJoinInconsistentStateNullValue]] carrying the same
   *    diagnostic context plus the config key to enable as a workaround.
   *
   * @param key       join key whose value list contains the unexpected null
   * @param nullIndex index within the key's value list where null was found
   * @param numValues number of values keyToNumValues reports for this key
   */
  protected def handleNullValuePair(
      key: UnsafeRow,
      nullIndex: Long,
      numValues: Long): Unit = {
    if (!nullValueWarningLogged) {
      nullValueWarningLogged = true
      // Project the raw UnsafeRow key into an InternalRow so the logged key is
      // human-readable (see getInternalRowOfKeyWithIndex above for the same idea).
      logWarning(log"Null value detected in stream-stream join state: " +
        log"joinSide=${MDC(JOIN_SIDE, joinSide)}, key=${MDC(KEY, keyProjection(key))}, " +
        log"index=${MDC(INDEX, nullIndex)}, numValues=${MDC(NUM_VALUES, numValues)}, " +
        log"storeVersion=${MDC(STATE_STORE_VERSION, stateInfo.get.storeVersion)}, " +
        log"partitionId=${MDC(PARTITION_ID, partitionId)}")
    }

    // Only skip nulls when the user has explicitly opted in; otherwise surface
    // the inconsistency immediately with a classified, actionable error.
    if (!storeConf.skipNullsForStreamStreamJoins) {
      throw StateStoreErrors.streamStreamJoinNullValue(
        valueIndex = nullIndex,
        numValues = numValues,
        joinSide = joinSide.toString,
        storeVersion = stateInfo.get.storeVersion,
        partitionId = partitionId,
        configKey = SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key)
    }
  }

/** Commit all the changes to all the state stores */
def commit(): Unit

Expand Down Expand Up @@ -1609,8 +1649,6 @@ abstract class SymmetricHashJoinStateManagerBase(
/**
* Get all values and indices for the provided key.
* Should not return null.
* Note that we will skip nulls explicitly if config setting for the same is
* set to true via STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.
*/
def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
new NextIterator[KeyWithIndexAndValue] {
Expand All @@ -1623,7 +1661,8 @@ abstract class SymmetricHashJoinStateManagerBase(
val keyWithIndex = keyWithIndexRow(key, index)
val valuePair =
valueRowConverter.convertValue(stateStore.get(keyWithIndex, colFamilyName))
if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
if (valuePair == null) {
handleNullValuePair(key, index, numValues)
skippedNullValueCount.foreach(_ += 1L)
index += 1
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,17 @@ object StateStoreErrors {
StateStoreUnknownInternalColumnFamily = {
new StateStoreUnknownInternalColumnFamily(colFamilyName)
}

def streamStreamJoinNullValue(
valueIndex: Long,
numValues: Long,
joinSide: String,
storeVersion: Long,
partitionId: Int,
configKey: String): StreamStreamJoinInconsistentStateNullValue = {
new StreamStreamJoinInconsistentStateNullValue(
valueIndex, numValues, joinSide, storeVersion, partitionId, configKey)
}
}

trait ConvertableToCannotLoadStoreError {
Expand Down Expand Up @@ -674,3 +685,20 @@ class StateStoreBaseCheckpointIdMismatch(
"actualBaseId" -> actualBaseId
)
)

/**
 * Thrown when stream-stream join state is inconsistent: the value store returned
 * null at `valueIndex` even though the key is recorded as having `numValues`
 * values. Maps to error class STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE.
 *
 * @param valueIndex   index within the key's value list where null was found
 * @param numValues    number of values recorded for the key
 * @param joinSide     which side of the join the state belongs to
 * @param storeVersion version of the state store being read
 * @param partitionId  partition whose state exhibited the inconsistency
 * @param configKey    SQL config that can be enabled to skip nulls as a workaround
 */
class StreamStreamJoinInconsistentStateNullValue(
    valueIndex: Long,
    numValues: Long,
    joinSide: String,
    storeVersion: Long,
    partitionId: Int,
    configKey: String)
  extends SparkRuntimeException(
    errorClass = "STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE",
    messageParameters = Map(
      "valueIndex" -> valueIndex.toString,
      "numValues" -> numValues.toString,
      "joinSide" -> joinSide,
      "storeVersion" -> storeVersion.toString,
      "partitionId" -> partitionId.toString,
      "configKey" -> configKey))
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkRuntimeException, SparkThrowable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, GenericInternalRow, JoinedRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
Expand Down Expand Up @@ -329,8 +330,8 @@ class SymmetricHashJoinStateManagerSuite extends SymmetricHashJoinStateManagerBa

/* Test removeByValue with nulls in middle simulated by updating numValues on the state manager */
private def testAllOperationsWithNullsInMiddle(stateFormatVersion: Int): Unit = {
// Test with skipNullsForStreamStreamJoins set to false which would throw a
// NullPointerException while iterating and also return null values as part of get
// Test with skipNullsForStreamStreamJoins set to false which would throw
// STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE while iterating
withJoinStateManager(inputValueAttributes, joinKeyExpressions, stateFormatVersion) { manager =>
implicit val mgr = manager

Expand All @@ -342,15 +343,9 @@ class SymmetricHashJoinStateManagerSuite extends SymmetricHashJoinStateManagerBa
updateNumValues(40, 7) // create nulls in between and end
removeByValue(50)
}
assert(ex.isInstanceOf[NullPointerException])
assert(getNumValues(40) === 7) // we should get 7 with no nulls skipped

removeByValue(300)
assert(getNumValues(40) === 1) // only 400 should remain
assert(get(40) === Seq(400))
removeByValue(400)
assert(get(40) === Seq.empty)
assertNumRows(stateFormatVersion, 0) // ensure all elements removed
assert(ex.isInstanceOf[SparkRuntimeException])
assert(ex.asInstanceOf[SparkThrowable].getCondition ==
"STREAM_STREAM_JOIN_INCONSISTENT_STATE.NULL_VALUE")
}

// Test with skipNullsForStreamStreamJoins set to true which would skip nulls
Expand Down