Skip to content

Commit

Permalink
partially resolve comments, will refactor test case
Browse files Browse the repository at this point in the history
  • Loading branch information
jingz-db committed Mar 13, 2024
1 parent 8fbd501 commit 25d7bab
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
}

/**
* Used similarly to StatefulProcessor. Represents the arbitrary stateful logic that needs to
* be provided by the user to perform stateful manipulations on keyed streams.
* Stateful processor with support for specifying initial state.
* Accepts a user-defined type as the initial state, which is initialized in the first batch.
* This can be used to start a new streaming query with existing state from a
* previous streaming query.
*/
@Experimental
@Evolving
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ case class FlatMapGroupsWithState(
}

object TransformWithState {
def apply[K: Encoder, V: Encoder, U: Encoder, S: Encoder](
def apply[K: Encoder, V: Encoder, U: Encoder](
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
statefulProcessor: StatefulProcessor[K, V, U],
Expand All @@ -595,7 +595,7 @@ object TransformWithState {
initialStateDataAttrs = dataAttributes,
initialStateDeserializer =
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
initialState = LocalRelation(encoderFor[S].schema) // empty data set
initialState = LocalRelation(encoderFor[K].schema) // empty data set
)
CatalystSerde.serialize[U](mapped)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
outputMode: OutputMode = OutputMode.Append()): Dataset[U] = {
Dataset[U](
sparkSession,
// The last K type is only there to silence a compiler error
TransformWithState[K, V, U, K](
TransformWithState[K, V, U](
groupingAttributes,
dataAttributes,
statefulProcessor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ case class TransformWithStateExec(
override def requiredChildDistribution: Seq[Distribution] = {
StatefulOperatorPartitioning.getCompatibleDistribution(
groupingAttributes, getStateInfo, conf) ::
StatefulOperatorPartitioning.getCompatibleDistribution(
initialStateGroupingAttrs, getStateInfo, conf) ::
Nil
StatefulOperatorPartitioning.getCompatibleDistribution(
initialStateGroupingAttrs, getStateInfo, conf) ::
Nil
}

/**
Expand Down Expand Up @@ -136,8 +136,9 @@ case class TransformWithStateExec(
mappedIterator
}

private def processInitialStateRows(keyRow: UnsafeRow, initStateIter: Iterator[InternalRow]):
Unit = {
private def processInitialStateRows(
keyRow: UnsafeRow,
initStateIter: Iterator[InternalRow]): Unit = {
val getKeyObj =
ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)

Expand Down

0 comments on commit 25d7bab

Please sign in to comment.