[SPARK-47363][SS] Initial State without state reader implementation for State API v2. #45467

6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -3542,6 +3542,12 @@
],
"sqlState" : "42802"
},
"STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
"message" : [
"Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=<groupingKey>."
],
"sqlState" : "42802"
},
"STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
"message" : [
"Failed to create column family with unsupported starting character and name=<colFamilyName>."
6 changes: 6 additions & 0 deletions docs/sql-error-conditions.md
@@ -2174,6 +2174,12 @@
Failed to perform stateful processor operation=`<operationType>` with invalid handle state=`<handleState>`

Failed to perform stateful processor operation=`<operationType>` with invalid timeoutMode=`<timeoutMode>`

### STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY

[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=`<groupingKey>`.
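
A hedged illustration of how this condition can arise (hypothetical snippet, not from the Spark docs; it assumes the check fires when the initial state contains the same grouping key more than once):

```scala
import spark.implicits._

// Both rows share grouping key "a", so initial state handling would try to
// initialize state for "a" twice and hit this error condition.
val initialState = Seq(("a", 1L), ("a", 2L)).toDS().groupByKey(_._1)
```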

### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS

[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -91,3 +91,22 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
statefulProcessorHandle
}
}

/**
* Stateful processor with support for specifying initial state.
* Accepts a user-defined type as initial state to be initialized in the first batch.
* This can be used for starting a new streaming query with existing state from a
* previous streaming query.
*/
@Experimental
@Evolving
trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] {

/**
* Function that will be invoked only in the first batch for users to process initial states.
*
* @param key - grouping key
* @param initialState - A row in the initial state to be processed
*/
def handleInitialState(key: K, initialState: S): Unit
}
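
For illustration, a minimal sketch of a processor implementing this trait (hypothetical class, not part of this PR; it assumes the v2 state API surface as of this change — `init`, `getHandle`, `getValueState`, `handleInputRows`, and `ValueState`):

```scala
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessorWithInitialState, TimerValues, ValueState}

// Seeds a per-key running count from the initial state, then keeps counting
// input rows in subsequent batches.
class RunningCountProcessor[K]
  extends StatefulProcessorWithInitialState[K, String, (K, Long), Long] {

  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode): Unit = {
    countState = getHandle.getValueState[Long]("countState")
  }

  // Invoked only in the first batch, once per initial-state row. Duplicate
  // grouping keys in the initial state are rejected by the operator with
  // STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY.
  override def handleInitialState(key: K, initialState: Long): Unit = {
    countState.update(initialState)
  }

  override def handleInputRows(
      key: K,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(K, Long)] = {
    val current = if (countState.exists()) countState.get() else 0L
    val updated = current + inputRows.size
    countState.update(updated)
    Iterator.single((key, updated))
  }
}
```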
@@ -588,7 +588,46 @@ object TransformWithState {
outputMode,
keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
CatalystSerde.generateObjAttr[U],
child
child,
hasInitialState = false,
// The following parameters are not used in the physical plan when hasInitialState = false
initialStateGroupingAttrs = groupingAttributes,
initialStateDataAttrs = dataAttributes,
initialStateDeserializer =
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
initialState = LocalRelation(encoderFor[K].schema) // empty data set
)
CatalystSerde.serialize[U](mapped)
}

// This apply() overload constructs TransformWithState with hasInitialState set to true
def apply[K: Encoder, V: Encoder, U: Encoder, S: Encoder](
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
statefulProcessor: StatefulProcessor[K, V, U],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
child: LogicalPlan,
initialStateGroupingAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialState: LogicalPlan): LogicalPlan = {
val keyEncoder = encoderFor[K]
val mapped = new TransformWithState(
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
groupingAttributes,
dataAttributes,
statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
timeoutMode,
outputMode,
keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
CatalystSerde.generateObjAttr[U],
child,
hasInitialState = true,
initialStateGroupingAttrs,
initialStateDataAttrs,
UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
initialState
)
CatalystSerde.serialize[U](mapped)
}
@@ -604,10 +643,18 @@ case class TransformWithState(
outputMode: OutputMode,
keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
child: LogicalPlan,
hasInitialState: Boolean = false,
initialStateGroupingAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: LogicalPlan) extends BinaryNode with ObjectProducer {

override protected def withNewChildInternal(newChild: LogicalPlan): TransformWithState =
copy(child = newChild)
override def left: LogicalPlan = child
override def right: LogicalPlan = initialState
override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
copy(child = newLeft, initialState = newRight)
}

/** Factory for constructing new `FlatMapGroupsInR` nodes. */
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.internal.TypedAggUtils
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}

/**
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql](
)
}

/**
* (Scala-specific)
* Invokes methods defined in the stateful processor used in arbitrary state API v2.
* Works the same as the overload above, but additionally accepts a user-provided initial state.
*
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
* @param statefulProcessor Instance of StatefulProcessorWithInitialState whose functions will
* be invoked by the operator.
* @param timeoutMode The timeout mode of the stateful processor.
* @param outputMode The output mode of the stateful processor. Defaults to APPEND mode.
* @param initialState User provided initial state that will be used to initialize the state
* for the query in the first batch.
*/
private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
Dataset[U](
sparkSession,
TransformWithState[K, V, U, S](
groupingAttributes,
dataAttributes,
statefulProcessor,
timeoutMode,
outputMode,
child = logicalPlan,
initialState.groupingAttributes,
initialState.dataAttributes,
initialState.queryExecution.logical
Contributor: Shall we follow the practice we did in flatMapGroupsWithState, for safeness' sake?

Suggested change:
initialState.queryExecution.analyzed

)
)
}
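
A hedged usage sketch of this overload (transformWithState is still private[sql] in this change, so a real caller would live inside Spark itself; the dataset names and path below are hypothetical):

```scala
// Seed the new query's state from the output of a previous query.
val initialState: KeyValueGroupedDataset[String, Long] =
  spark.read.parquet("/tmp/prev-query-output")   // assumed location
    .as[(String, Long)]
    .groupByKey(_._1)
    .mapValues(_._2)

val counts = inputStream                         // a streaming Dataset[String]
  .groupByKey(identity)
  .transformWithState(
    new RunningCountProcessor[String](),         // the sketch shown earlier
    TimeoutMode.NoTimeouts(),
    OutputMode.Update(),
    initialState)
```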

/**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
@@ -752,7 +752,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case TransformWithState(
keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode,
keyEncoder, outputAttr, child) =>
keyEncoder, outputAttr, child, hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
initialStateDeserializer, initialState) =>
val execPlan = TransformWithStateExec(
keyDeserializer,
valueDeserializer,
@@ -767,7 +769,13 @@
batchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
planLater(child))
planLater(child),
isStreaming = true,
hasInitialState,
initialStateGroupingAttrs,
initialStateDataAttrs,
initialStateDeserializer,
planLater(initialState))
execPlan :: Nil
case _ =>
Nil
@@ -918,10 +926,14 @@
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder,
outputObjAttr, child) =>
outputObjAttr, child, hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
initialStateDeserializer, initialState) =>
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
keyEncoder, outputObjAttr, planLater(child)) :: Nil
keyEncoder, outputObjAttr, planLater(child), hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
initialStateDeserializer, planLater(initialState)) :: Nil

case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
@@ -268,11 +268,13 @@ class IncrementalExecution(
)

case t: TransformWithStateExec =>
val hasInitialState = (isFirstBatch && t.hasInitialState)
Contributor: I don't think we want to allow adding state in the middle of the query lifecycle. Here isFirstBatch does not mean batch ID = 0; it means this is the first batch in this query run.

This should follow the logic above for FlatMapGroupsWithStateExec: currentBatchId == 0L.

Contributor: Please let me know if this is different functionality than what we had in flatMapGroupsWithState.

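A one-line sketch of the change the reviewer is suggesting (assuming currentBatchId is in scope at this point, as it is for the FlatMapGroupsWithStateExec case above):

```scala
// Gate initial state handling on the query's very first batch (batch ID 0),
// not on the first batch of the current run.
val hasInitialState = (currentBatchId == 0L && t.hasInitialState)
```
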
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None
eventTimeWatermarkForEviction = None,
hasInitialState = hasInitialState
)

case m: FlatMapGroupsInPandasWithStateExec =>