Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-38809][SS] Implement option to skip null values in symmetric hash implementation of stream-stream joins #36090

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1897,6 +1897,19 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

/**
* SPARK-38809 - Config option to allow skipping null values for hash based stream-stream joins.
* It's possible for us to see nulls if state was written with an older version of Spark,
* the state was corrupted on disk or if we had an issue with the state iterators.
*/
val STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS =
buildConf("spark.sql.streaming.stateStore.skipNullsForStreamStreamJoins.enabled")
.internal()
.doc("When true, this config will skip null values in hash based stream-stream joins.")
.version("3.3.0")
.booleanConf
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
.createWithDefault(false)

val VARIABLE_SUBSTITUTE_ENABLED =
buildConf("spark.sql.variable.substitute")
.doc("This enables substitution using syntax like `${var}`, `${system:var}`, " +
Expand Down Expand Up @@ -3877,6 +3890,9 @@ class SQLConf extends Serializable with Logging {

def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)

/** Value of the STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS entry as returned by getConf. */
def stateStoreSkipNullsForStreamStreamJoins: Boolean = {
  getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS)
}

/** Value of the CHECKPOINT_LOCATION entry as returned by getConf. */
def checkpointLocation: Option[String] = {
  getConf(CHECKPOINT_LOCATION)
}

/** Value of the UNSUPPORTED_OPERATION_CHECK_ENABLED entry as returned by getConf. */
def isUnsupportedOperationCheckEnabled: Boolean = {
  getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class StateStoreConf(
val formatValidationCheckValue: Boolean =
extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "true") == "true"

/**
 * Whether null values should be skipped for hash based stream-stream joins.
 * Taken from `sqlConf.stateStoreSkipNullsForStreamStreamJoins`.
 */
val skipNullsForStreamStreamJoins: Boolean = {
  sqlConf.stateStoreSkipNullsForStreamStreamJoins
}

/**
 * Compression codec used to compress delta and snapshot files.
 * Taken from `sqlConf.stateStoreCompressionCodec`.
 */
val compressionCodec: String = {
  sqlConf.stateStoreCompressionCodec
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,12 @@ class SymmetricHashJoinStateManager(
valueRemoved = false
}

// Find the next value satisfying the condition, updating `currentKey` and `numValues` if
// needed. Returns null when no value can be found.
/**
* Find the next value satisfying the condition, updating `currentKey` and `numValues` if
* needed. Returns null when no value can be found.
* Note that null values are skipped explicitly when the config
* STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS is set to true.
*/
private def findNextValueForIndex(): ValueAndMatchPair = {
// Loop across all values for the current key, and then all other keys, until we find a
// value satisfying the removal condition.
Expand All @@ -233,7 +237,9 @@ class SymmetricHashJoinStateManager(
if (hasMoreValuesForCurrentKey) {
// First search the values for the current key.
val valuePair = keyWithIndexToValue.get(currentKey, index)
if (removalCondition(valuePair.value)) {
if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
index += 1
} else if (removalCondition(valuePair.value)) {
return valuePair
} else {
index += 1
Expand Down Expand Up @@ -597,22 +603,30 @@ class SymmetricHashJoinStateManager(
/**
* Get all values and indices for the provided key.
* Should not return null.
* Note that null values are skipped explicitly when the config
* STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS is set to true.
*/
def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
val keyWithIndexAndValue = new KeyWithIndexAndValue()
var index = 0
new NextIterator[KeyWithIndexAndValue] {
private val keyWithIndexAndValue = new KeyWithIndexAndValue()
private var index: Long = 0L

private def hasMoreValues = index < numValues
override protected def getNext(): KeyWithIndexAndValue = {
if (index >= numValues) {
finished = true
null
} else {
while (hasMoreValues) {
anishshri-db marked this conversation as resolved.
Show resolved Hide resolved
val keyWithIndex = keyWithIndexRow(key, index)
val valuePair = valueRowConverter.convertValue(stateStore.get(keyWithIndex))
keyWithIndexAndValue.withNew(key, index, valuePair)
index += 1
keyWithIndexAndValue
if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
index += 1
} else {
keyWithIndexAndValue.withNew(key, index, valuePair)
index += 1
return keyWithIndexAndValue
}
}

finished = true
return null
}

override protected def close(): Unit = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -52,6 +53,12 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
}
}

// Register one "all operations with nulls in middle" test per supported state format version.
for (version <- SymmetricHashJoinStateManager.supportedVersions) {
  test(s"StreamingJoinStateManager V${version} - all operations with nulls in middle") {
    testAllOperationsWithNullsInMiddle(version)
  }
}

SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
test(s"SPARK-35689: StreamingJoinStateManager V${version} - " +
"printable key of keyWithIndexToValue") {
Expand Down Expand Up @@ -167,6 +174,55 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
}
}

/**
 * Tests append / get / removeByValue when the values stored for a key contain nulls.
 * The nulls are simulated by updating numValues on the state manager so that it exceeds
 * the number of values that were actually appended.
 *
 * The scenario runs twice: once with skipNullsForStreamStreamJoins left at its default
 * (false), and once with it set to true.
 *
 * @param stateFormatVersion state format version handed to `withJoinStateManager`
 */
private def testAllOperationsWithNullsInMiddle(stateFormatVersion: Int): Unit = {
  // Test with skipNullsForStreamStreamJoins set to false which would throw a
  // NullPointerException while iterating and also return null values as part of get
  withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager =>
    implicit val mgr: SymmetricHashJoinStateManager = manager

    // Assert on the exact exception type directly, so a failure reports the unexpected
    // exception instead of a bare `isInstanceOf` check evaluating to false.
    intercept[NullPointerException] {
      appendAndTest(40, 50, 200, 300)
      assert(numRows === 3)
      updateNumValues(40, 4) // create a null at the end
      append(40, 400)
      updateNumValues(40, 7) // create nulls in between and end
      removeByValue(50)
    }
    assert(getNumValues(40) === 7) // we should get 7 with no nulls skipped

    removeByValue(300)
    assert(getNumValues(40) === 1) // only 400 should remain
    assert(get(40) === Seq(400))
    removeByValue(400)
    assert(get(40) === Seq.empty)
    assert(numRows === 0) // ensure all elements removed
  }

  // Test with skipNullsForStreamStreamJoins set to true which would skip nulls
  // and continue iterating as part of removeByValue as well as get
  withJoinStateManager(
    inputValueAttribs,
    joinKeyExprs,
    stateFormatVersion,
    skipNullsForStreamStreamJoins = true) { manager =>
    implicit val mgr: SymmetricHashJoinStateManager = manager

    appendAndTest(40, 50, 200, 300)
    assert(numRows === 3)
    updateNumValues(40, 4) // create a null at the end
    append(40, 400)
    updateNumValues(40, 7) // create nulls in between and end

    removeByValue(50)
    assert(getNumValues(40) === 3) // we should now get (400, 200, 300) with nulls skipped

    removeByValue(300)
    assert(getNumValues(40) === 1) // only 400 should remain
    assert(get(40) === Seq(400))
    removeByValue(400)
    assert(get(40) === Seq.empty)
    assert(numRows === 0) // ensure all elements removed
  }
}

val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
val inputValueSchema = new StructType()
.add(StructField("time", IntegerType, metadata = watermarkMetadata))
Expand Down Expand Up @@ -205,6 +261,11 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
manager.updateNumValuesTestOnly(toJoinKeyRow(key), numValues)
}

/** Counts the entries that `manager.get` yields for the join key row built from `key`. */
def getNumValues(key: Int)(implicit manager: SymmetricHashJoinStateManager): Int = {
  val joinKeyRow = toJoinKeyRow(key)
  manager.get(joinKeyRow).size
}

/**
 * Looks up the join key row built from `key`, maps each result through `toValueInt`
 * and returns the resulting ints in sorted order.
 */
def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
  val storedRows = manager.get(toJoinKeyRow(key))
  val storedInts = storedRows.map(toValueInt).toSeq
  storedInts.sorted
}
Expand Down Expand Up @@ -232,22 +293,26 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
manager.metrics.numKeys
}


def withJoinStateManager(
inputValueAttribs: Seq[Attribute],
joinKeyExprs: Seq[Expression],
stateFormatVersion: Int)(f: SymmetricHashJoinStateManager => Unit): Unit = {
inputValueAttribs: Seq[Attribute],
joinKeyExprs: Seq[Expression],
stateFormatVersion: Int,
skipNullsForStreamStreamJoins: Boolean = false)
(f: SymmetricHashJoinStateManager => Unit): Unit = {

withTempDir { file =>
val storeConf = new StateStoreConf()
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
val manager = new SymmetricHashJoinStateManager(
LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration,
partitionId = 0, stateFormatVersion)
try {
f(manager)
} finally {
manager.abortIfNeeded()
withSQLConf(SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key ->
skipNullsForStreamStreamJoins.toString) {
val storeConf = new StateStoreConf(spark.sqlContext.conf)
val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
val manager = new SymmetricHashJoinStateManager(
LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration,
partitionId = 0, stateFormatVersion)
try {
f(manager)
} finally {
manager.abortIfNeeded()
}
}
}
StateStore.stop()
Expand Down