Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-47570][SS] Integrate range scan encoder changes with timer implementation #45709

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,14 @@ class StatefulProcessorHandleImpl(
}

/**
* Function to retrieve all registered timers for all grouping keys
* Function to retrieve all expired registered timers for all grouping keys
* @param expiryTimestampMs Expiry threshold in milliseconds; this function
* will return all timers whose timestamp is less than the passed threshold
* @return - iterator of registered timers for all grouping keys
*/
def getExpiredTimers(): Iterator[(Any, Long)] = {
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
verifyTimerOperations("get_expired_timers")
timerState.getExpiredTimers()
timerState.getExpiredTimers(expiryTimestampMs)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TimerStateImpl(

val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
schemaForValueRow, NoPrefixKeyStateEncoderSpec(keySchemaForSecIndex),
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
useMultipleValuesPerKey = false, isInternal = true)

private def getGroupingKey(cfName: String): Any = {
Expand All @@ -110,7 +110,6 @@ class TimerStateImpl(

// We maintain a secondary index that inverts the ordering of the timestamp
// and grouping key
// TODO: use range scan encoder to encode the secondary index key
private def encodeSecIndexKey(groupingKey: Any, expiryTimestampMs: Long): UnsafeRow = {
val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
val keyRow = secIndexKeyEncoder(InternalRow(expiryTimestampMs, keyByteArr))
Expand Down Expand Up @@ -187,10 +186,15 @@ class TimerStateImpl(
}

/**
* Function to get all the registered timers for all grouping keys
* Function to get all the expired registered timers for all grouping keys.
* Performs a range scan on the timestamp and stops iterating once the key row timestamp equals
* or exceeds the limit (as the timestamp key is sorted in increasing order).
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: maybe add a small comment here mentioning that we perform a range scan and stop iterating once the key row timestamp exceeds the threshold

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is not user-facing, right? If it is, I'd suggest avoiding implementation details. It doesn't look user-facing to me, but I wanted to mention it just in case.

Copy link
Contributor Author

@jingz-db jingz-db Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out (I did not realize this before but looks like we did the right thing)!

* will return all timers that have timestamp less than passed threshold.
* @return - iterator of all the registered timers for all grouping keys
*/
def getExpiredTimers(): Iterator[(Any, Long)] = {
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
// this iter is increasingly sorted on timestamp
val iter = store.iterator(tsToKeyCFName)

new NextIterator[(Any, Long)] {
Expand All @@ -199,7 +203,12 @@ class TimerStateImpl(
val rowPair = iter.next()
val keyRow = rowPair.key
val result = getTimerRowFromSecIndex(keyRow)
result
if (result._2 < expiryTimestampMs) {
result
} else {
finished = true
null.asInstanceOf[(Any, Long)]
}
} else {
finished = true
null.asInstanceOf[(Any, Long)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,18 @@ case class TransformWithStateExec(
case ProcessingTime =>
assert(batchTimestampMs.isDefined)
val batchTimestamp = batchTimestampMs.get
val procTimeIter = processorHandle.getExpiredTimers()
procTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
if (expiryTimestampMs < batchTimestamp) {
processorHandle.getExpiredTimers(batchTimestamp)
.flatMap { case (keyObj, expiryTimestampMs) =>
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
} else {
Iterator.empty
}
}

case EventTime =>
assert(eventTimeWatermarkForEviction.isDefined)
val watermark = eventTimeWatermarkForEviction.get
val eventTimeIter = processorHandle.getExpiredTimers()
eventTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
if (expiryTimestampMs < watermark) {
processorHandle.getExpiredTimers(watermark)
.flatMap { case (keyObj, expiryTimestampMs) =>
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
} else {
Iterator.empty
}
}

case _ => Iterator.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class TimerSuite extends StateVariableSuiteBase {
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
timerState.registerTimer(1L * 1000)
assert(timerState.listTimers().toSet === Set(1000L))
assert(timerState.getExpiredTimers().toSet === Set(("test_key", 1000L)))
assert(timerState.getExpiredTimers(Long.MaxValue).toSeq === Seq(("test_key", 1000L)))
assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty[Long])

timerState.registerTimer(20L * 1000)
assert(timerState.listTimers().toSet === Set(20000L, 1000L))
Expand All @@ -69,8 +70,10 @@ class TimerSuite extends StateVariableSuiteBase {
timerState1.registerTimer(1L * 1000)
timerState2.registerTimer(15L * 1000)
assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
assert(timerState1.getExpiredTimers().toSet ===
Set(("test_key", 15000L), ("test_key", 1000L)))
assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
Seq(("test_key", 1000L), ("test_key", 15000L)))
// if the timestamp equals expiryTimestampMs, the timer is not considered expired
assert(timerState1.getExpiredTimers(15000L).toSeq === Seq(("test_key", 1000L)))
assert(timerState1.listTimers().toSet === Set(15000L, 1000L))

timerState1.registerTimer(20L * 1000)
Expand Down Expand Up @@ -99,15 +102,67 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.removeImplicitKey()

ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
assert(timerState1.getExpiredTimers().toSet ===
Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
Seq(("test_key1", 1000L), ("test_key1", 2000L), ("test_key2", 15000L)))
assert(timerState1.getExpiredTimers(10000L).toSeq ===
Seq(("test_key1", 1000L), ("test_key1", 2000L)))
assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
ImplicitGroupingKeyTracker.removeImplicitKey()

ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
assert(timerState2.listTimers().toSet === Set(15000L))
assert(timerState2.getExpiredTimers().toSet ===
Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
assert(timerState2.getExpiredTimers(1500L).toSeq === Seq(("test_key1", 1000L)))
}
}

testWithTimeOutMode("Range scan on second index timer key - " +
  "verify timestamp is sorted for single instance") { timeoutMode =>
  tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
    val stateStore = provider.getStore(0)

    ImplicitGroupingKeyTracker.setImplicitKey("test_key")
    val timers = new TimerStateImpl(stateStore, timeoutMode,
      Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])

    // Register the timestamps in non-sorted order; the assertions below expect
    // getExpiredTimers to hand them back in ascending order.
    val registeredTimestamps =
      Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L)
    registeredTimestamps.foreach(ts => timers.registerTimer(ts))
    val ascendingTimestamps = registeredTimestamps.sorted

    // Timestamps of the timers returned for the given threshold, in iteration order.
    def expiredTimestamps(thresholdMs: Long): Seq[Long] =
      timers.getExpiredTimers(thresholdMs).toSeq.map(_._2)

    // With the largest possible threshold every registered timer is returned.
    assert(expiredTimestamps(Long.MaxValue) === ascendingTimestamps)
    // Only timestamps strictly below the threshold are returned.
    assert(expiredTimestamps(4200L) === ascendingTimestamps.takeWhile(_ < 4200L))
    // With the smallest possible threshold nothing is returned.
    assert(timers.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
    ImplicitGroupingKeyTracker.removeImplicitKey()
  }
}

testWithTimeOutMode("test range scan on second index timer key - " +
  "verify timestamp is sorted for multiple instances") { timeoutMode =>
  tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
    val stateStore = provider.getStore(0)

    // Builds a fresh timer state instance on the shared store for every call.
    def newTimerState(): TimerStateImpl = new TimerStateImpl(stateStore, timeoutMode,
      Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])

    // The first two instances are created and populated while "test_key1" is
    // the implicit grouping key.
    ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
    val firstTimers = newTimerState()
    val firstTimestamps = Seq(64L, 32L, 1024L, 4096L, 0L, 1L)
    firstTimestamps.foreach(ts => firstTimers.registerTimer(ts))

    val secondTimers = newTimerState()
    val secondTimestamps = Seq(931L, 8000L, 452300L, 4200L)
    secondTimestamps.foreach(ts => secondTimers.registerTimer(ts))
    ImplicitGroupingKeyTracker.removeImplicitKey()

    // The third instance is created and populated under "test_key3".
    ImplicitGroupingKeyTracker.setImplicitKey("test_key3")
    val thirdTimers = newTimerState()
    val thirdTimestamps = Seq(1L, 2L, 8L, 3L)
    thirdTimestamps.foreach(ts => thirdTimers.registerTimer(ts))
    ImplicitGroupingKeyTracker.removeImplicitKey()

    // All assertions read through the first instance and expect the timers
    // registered through all three instances, in ascending timestamp order.
    val allTimestampsAscending =
      (firstTimestamps ++ secondTimestamps ++ thirdTimestamps).sorted

    assert(firstTimers.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) ===
      allTimestampsAscending)
    assert(firstTimers.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
    // Only timestamps strictly below the threshold are returned.
    assert(firstTimers.getExpiredTimers(8000L).toSeq.map(_._2) ===
      allTimestampsAscending.takeWhile(_ < 8000L))
  }
}
}