Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46961][SS] Using ProcessorContext to store and retrieve handle #45359

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -3320,6 +3320,13 @@
],
"sqlState" : "42802"
},
"STATE_STORE_HANDLE_NOT_INITIALIZED" : {
"message" : [
"The handle has not been initialized for this StatefulProcessor.",
"Please only use the StatefulProcessor within the transformWithState operator."
],
"sqlState" : "42802"
},
"STATE_STORE_MULTIPLE_VALUES_PER_KEY" : {
"message" : [
"Store does not support multiple values per key"
Expand Down
7 changes: 7 additions & 0 deletions docs/sql-error-conditions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2079,6 +2079,13 @@ Star (*) is not allowed in a select list when GROUP BY an ordinal position is us

Failed to remove default column family with reserved name=`<colFamilyName>`.

### STATE_STORE_HANDLE_NOT_INITIALIZED

[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

The handle has not been initialized for this StatefulProcessor.
Please only use the StatefulProcessor within the transformWithState operator.

### STATE_STORE_MULTIPLE_VALUES_PER_KEY

[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
e)
}

def stateStoreHandleNotInitialized(): SparkRuntimeException = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this to StateStoreErrors.scala to keep the error classes in this area in a common location?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericm-db Could you move it, please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MaxGekk We can't move it to StateStoreErrors.scala because it is thrown from sql/api/, which cannot depend on that module.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. OK, let's keep it here then.

new SparkRuntimeException(
errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED",
messageParameters = Map.empty)
}

def failToRecognizePatternAfterUpgradeError(
pattern: String, e: Throwable): SparkUpgradeException = {
new SparkUpgradeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
import java.io.Serializable

import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.errors.ExecutionErrors

/**
* Represents the arbitrary stateful logic that needs to be provided by the user to perform
Expand All @@ -29,17 +30,18 @@ import org.apache.spark.annotation.{Evolving, Experimental}
@Evolving
private[sql] trait StatefulProcessor[K, I, O] extends Serializable {

/**
* Handle to the stateful processor that provides access to the state store and other
* stateful processing related APIs.
*/
private var statefulProcessorHandle: StatefulProcessorHandle = null

/**
* Function that will be invoked as the first method that allows for users to
* initialize all their state variables and perform other init actions before handling data.
* @param handle - reference to the statefulProcessorHandle that the user can use to perform
* actions like creating state variables, accessing queryInfo etc. Please refer to
* [[StatefulProcessorHandle]] for more details.
* @param outputMode - output mode for the stateful processor
*/
def init(
handle: StatefulProcessorHandle,
outputMode: OutputMode): Unit
def init(outputMode: OutputMode): Unit

/**
* Function that will allow users to interact with input data rows along with the grouping key
Expand All @@ -59,5 +61,27 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
* Function called as the last method that allows for users to perform
* any cleanup or teardown operations.
*/
def close (): Unit
def close (): Unit = {}

/**
* Function to set the stateful processor handle that will be used to interact with the state
* store and other stateful processor related operations.
*
* @param handle - instance of StatefulProcessorHandle
*/
final def setHandle(handle: StatefulProcessorHandle): Unit = {
statefulProcessorHandle = handle
}

/**
* Function to get the stateful processor handle that will be used to interact with the state
*
* @return handle - instance of StatefulProcessorHandle
*/
final def getHandle: StatefulProcessorHandle = {
if (statefulProcessorHandle == null) {
throw ExecutionErrors.stateStoreHandleNotInitialized()
}
statefulProcessorHandle
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ case class TransformWithStateExec(
setStoreMetrics(store)
setOperatorMetrics()
statefulProcessor.close()
statefulProcessor.setHandle(null)
processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
})
}
Expand Down Expand Up @@ -228,7 +229,8 @@ case class TransformWithStateExec(
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, keyEncoder, isStreaming)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.init(processorHandle, outputMode)
statefulProcessor.setHandle(processorHandle)
statefulProcessor.init(outputMode)
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
processDataWithPartition(singleIterator, store, processorHandle)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ case class InputRow(key: String, action: String, value: String)
class TestListStateProcessor
extends StatefulProcessor[String, InputRow, (String, String)] {

@transient var _processorHandle: StatefulProcessorHandle = _
@transient var _listState: ListState[String] = _

override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
_processorHandle = handle
_listState = handle.getListState("testListState")
override def init(outputMode: OutputMode): Unit = {
_listState = getHandle.getListState("testListState")
}

override def handleInputRows(
Expand Down Expand Up @@ -84,14 +82,12 @@ class TestListStateProcessor
class ToggleSaveAndEmitProcessor
extends StatefulProcessor[String, String, String] {

@transient var _processorHandle: StatefulProcessorHandle = _
@transient var _listState: ListState[String] = _
@transient var _valueState: ValueState[Boolean] = _

override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
_processorHandle = handle
_listState = handle.getListState("testListState")
_valueState = handle.getValueState("testValueState")
override def init(outputMode: OutputMode): Unit = {
_listState = getHandle.getListState("testListState")
_valueState = getHandle.getValueState("testValueState")
}

override def handleInputRows(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.streaming

import org.apache.spark.SparkException
import org.apache.spark.{SparkException, SparkRuntimeException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException}
Expand All @@ -30,14 +30,9 @@ object TransformWithStateSuiteUtils {
class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)]
with Logging {
@transient private var _countState: ValueState[Long] = _
@transient var _processorHandle: StatefulProcessorHandle = _

override def init(
handle: StatefulProcessorHandle,
outputMode: OutputMode) : Unit = {
_processorHandle = handle
assert(handle.getQueryInfo().getBatchId >= 0)
_countState = _processorHandle.getValueState[Long]("countState")

override def init(outputMode: OutputMode): Unit = {
_countState = getHandle.getValueState[Long]("countState")
}

override def handleInputRows(
Expand All @@ -62,17 +57,11 @@ class RunningCountMostRecentStatefulProcessor
with Logging {
@transient private var _countState: ValueState[Long] = _
@transient private var _mostRecent: ValueState[String] = _
@transient var _processorHandle: StatefulProcessorHandle = _

override def init(
handle: StatefulProcessorHandle,
outputMode: OutputMode) : Unit = {
_processorHandle = handle
assert(handle.getQueryInfo().getBatchId >= 0)
_countState = _processorHandle.getValueState[Long]("countState")
_mostRecent = _processorHandle.getValueState[String]("mostRecent")
}

override def init(outputMode: OutputMode): Unit = {
_countState = getHandle.getValueState[Long]("countState")
_mostRecent = getHandle.getValueState[String]("mostRecent")
}
override def handleInputRows(
key: String,
inputRows: Iterator[(String, String)],
Expand All @@ -96,15 +85,10 @@ class MostRecentStatefulProcessorWithDeletion
extends StatefulProcessor[String, (String, String), (String, String)]
with Logging {
@transient private var _mostRecent: ValueState[String] = _
@transient var _processorHandle: StatefulProcessorHandle = _

override def init(
handle: StatefulProcessorHandle,
outputMode: OutputMode) : Unit = {
_processorHandle = handle
assert(handle.getQueryInfo().getBatchId >= 0)
_processorHandle.deleteIfExists("countState")
_mostRecent = _processorHandle.getValueState[String]("mostRecent")

override def init(outputMode: OutputMode): Unit = {
getHandle.deleteIfExists("countState")
_mostRecent = getHandle.getValueState[String]("mostRecent")
}

override def handleInputRows(
Expand Down Expand Up @@ -132,7 +116,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
// Trying to create value state here should fail
_tempState = _processorHandle.getValueState[Long]("tempState")
_tempState = getHandle.getValueState[Long]("tempState")
Iterator.empty
}
}
Expand Down Expand Up @@ -195,6 +179,20 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
}

test("Use statefulProcessor without transformWithState - handle should be absent") {
val processor = new RunningCountStatefulProcessor()
val ex = intercept[Exception] {
processor.getHandle
}
checkError(
ex.asInstanceOf[SparkRuntimeException],
errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED",
parameters = Map.empty
)
assert(ex.getMessage.contains("The handle has not been initialized" +
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
" for this StatefulProcessor."))
}

test("transformWithState - batch should succeed") {
val inputData = Seq("a", "b")
val result = inputData.toDS()
Expand Down