[SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer #48838
jingz-db wants to merge 25 commits into apache:master
Conversation
bogao007
left a comment
LGTM overall, just several minor comments.
if timeMode != "none":
    batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
    watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
else:
    batch_timestamp = -1
    watermark_timestamp = -1
Can we abstract this into a separate method and share it in both UDFs to reduce redundant code?
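Roughly, the shared helper could look like the sketch below; the name and exact signature here are placeholders (the thread later refers to it as get_timestamps), and it simply mirrors the branch quoted above:

```python
# Hypothetical shared helper so both UDFs call one function instead of
# duplicating the timeMode check.
def get_timestamps(stateful_processor_api_client, time_mode):
    if time_mode != "none":
        batch_timestamp = stateful_processor_api_client.get_batch_timestamp()
        watermark_timestamp = stateful_processor_api_client.get_watermark_timestamp()
    else:
        batch_timestamp = -1
        watermark_timestamp = -1
    return batch_timestamp, watermark_timestamp
```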
Timer value for the current batch that processes the input rows.
Users can get the processing or event time timestamp from TimerValues.
"""
return iter([])
Why do we change the ... placeholder here?
Sorry, I confused this with handleExpiredTimer. Changed it back to ...
| """ | ||
| return iter([]) | ||
|
|
||
| def handleExpiredTimer( |
Just double-checking: this method is not required for users to implement, correct?
Correct. Added a comment line in the docstring explicitly saying this is optional to implement.
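As a rough sketch, the optional default on StatefulProcessor could look like this (parameter names are approximations, not taken from the PR diff):

```python
def handleExpiredTimer(self, key, timer_values, expired_timer_info):
    """
    Optional to implement. Override this only if the processor registers
    timers and needs to emit rows when a timer expires; the default
    implementation emits nothing.
    """
    return iter([])
```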
result_iter_list = [data_iter]
# process with valid expiry time info and with empty input rows,
# only timer related rows will be emitted
# process with expiry timers, only timer related rows will be emitted
I get confused about this every time. Is this relying on the behavior that an expired timer will be removed, so we won't list the same timer as expired multiple times? That is very easy to forget.
Is there any way we can just move this out and do it after we process all the input? Can this be done in transformWithStateUDF/transformWithStateWithInitStateUDF with key = null?
Thanks so much for catching this! I introduced a nasty correctness bug in my prior timer implementation. I have now moved all timer handling code into serializer.py, where the expired timers are processed per partition.
Left an explanation of what caused the correctness issue in my prior implementation here, just in case you are curious: #48838 (comment)
return iter([])

result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)
Ideally this shouldn't be called for every key. If we split the handling of timer expiration from the handling of input rows, we would only need to call this once.
Moved get_timestamps() into stateful_processor_api_client.py inside __init__(), so we only make the API call once per batch.
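A sketch of the described change is below; the real constructor takes more arguments and the two getters issue RPCs to the JVM state server, so the bodies here are placeholders:

```python
class StatefulProcessorApiClient:
    def __init__(self, time_mode="processingTime"):
        # Fetch the timestamps once per batch here, instead of once per grouping key.
        if time_mode != "none":
            self.batch_timestamp = self.get_batch_timestamp()
            self.watermark_timestamp = self.get_watermark_timestamp()
        else:
            self.batch_timestamp = -1
            self.watermark_timestamp = -1

    def get_batch_timestamp(self):
        return 0  # placeholder for the real RPC to the state server

    def get_watermark_timestamp(self):
        return 0  # placeholder for the real RPC to the state server
```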
I'll revisit the PR once my comments are addressed (or @jingz-db has a reasonable point for not doing this), as my proposal would change the code non-trivially.
| """ | ||
| Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow | ||
| RecordBatches, and write batches to stream. | ||
| Read through chained return results from a single partition of handleInputRows. |
@bogao007 Could you revisit this change? It has changed since you last reviewed, because I found a correctness bug in my prior timer change. Thanks!
In my prior implementation, a correctness issue occurs if multiple keys expire on a single partition. E.g., the test case test_transform_with_state_init_state_with_timers will fail if we set the number of partitions to "1".
Previously we called get_expiry_timers_iterator() and handleExpiredTimer() in group_ops.py inside the UDF, which is called per key. So we register a timer for key "0" inside handleInitialState() and then enter get_expiry_timers_iterator(). Because the UDF for key "3" has not been called yet at that point, the timer for key "3" is not registered. We only see key "0" expire and only get Row(id="0-expired") in the output of the first batch. When we enter the UDF for key "3", since TransformWithStateInPandasStateServer enforces that expiryTimestampIter is only consumed once per partition, the JVM returns none for key "3" because the iterator was already consumed for key "0". This is the correctness issue.
I have now moved handleExpiredTimer inside serializer.py, so get_expiry_timers_iterator() is called after handleInitialState() has executed for all keys on the partition, and it is also chained after handleInputRows() has been called for all keys on the same partition.
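A toy illustration of the failure mode, with a plain Python generator standing in for the JVM-side expiry iterator (all names here are made up for the example):

```python
# The JVM only serves the expired-timer list once per partition.
def expired_timers_once():
    yield ("0", 100)  # key "0" expired at ts=100
    yield ("3", 100)  # key "3" expired at ts=100

# Buggy shape: consuming the iterator inside the per-key UDF call means the
# first key drains it and later keys see nothing.
expiry_iter = expired_timers_once()
per_key = []
for key in ["0", "3"]:
    expired_for_key = [k for k, _ in expiry_iter if k == key]  # drains everything on the first pass
    per_key.append((key, expired_for_key))
assert per_key == [("0", ["0"]), ("3", [])]  # key "3" is silently dropped

# Fixed shape: consume the iterator once per partition, after all per-key
# calls, and emit timer rows for every expired key.
all_expired = list(expired_timers_once())
assert all_expired == [("0", 100), ("3", 100)]
```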
Thanks @jingz-db for the detailed explanation! Do you think we should add a test case where multiple keys expire in the same partition? E.g., we could either set the partition number to 1 or increase the input to have more keys.
+1 to verifying this explicitly in a test.
Added test_transform_with_state_with_timers_single_partition to run all timer suites with a single partition.
bogao007
left a comment
Did a first pass after the change. LGTM overall, mostly minor comments.
else:
    expiry_list_iter = iter([[]])

def timer_iter_wrapper(func, *args, **kwargs):
Nit: can we move this method definition to the top of dump_stream to follow the same pattern in this file? This would also make the code easier to read.
Moved it to just below where statefulProcessorApiClient is initialized; we need to access that object from timer_iter_wrapper.
| """ | ||
| Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow | ||
| RecordBatches, and write batches to stream. | ||
| Read through chained return results from a single partition of handleInputRows. |
It might be better to document what the structure of the input iterator looks like, given we have added a bunch of new objects as the UDF output. Either add it here or down below where the args are defined.
| """ | ||
| Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow | ||
| RecordBatches, and write batches to stream. | ||
| Read through chained return results from a single partition of handleInputRows. |
There was a problem hiding this comment.
Thanks @jingz-db for the detailed explaination! Do you think if we should add a test case where multiple keys are expired in the same partition? Like we either set partition num to 1 or increase the input to have more keys
| """ | ||
| Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow | ||
| RecordBatches, and write batches to stream. | ||
| Read through chained return results from a single partition of handleInputRows. |
Any reason we can't do this in if key is None: in transformWithStateUDF and transformWithStateWithInitStateUDF?
This was my suggestion, and I believe you can just retrieve the expired timers and timestamps, call handleExpiredTimer() with that information, and be done. I don't think this complication is necessary; if we can't do this under key is None for some reason, I suspect fixing that would be much easier.
I'll wait for the next update on whether my suggestion works or not. I think the complexity would be very different, so I would like to defer further review until then.
TL;DR: if we put the timer handling code inside if key is None, we add more code complexity.
We need to make a tradeoff between adding the complication in serializer.py or in TransformWithStateInPandasPythonRunner if we put the above timer handling code inside if key is None.
If we put the timer handling logic inside if key is None, we will need to call dump_stream() again in the finally block here: https://github.com/apache/spark/blob/master/python/pyspark/worker.py#L1966. Calling dump_stream() twice means we need to properly handle how the JVM receives batches. Currently we reuse the read() function inside PythonArrowOutput, and the reader stops reading when the Python dump_stream signals the end here: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala#L118. Since we would now call dump_stream() twice, we would need to override this function in TransformWithStateInPandasPythonRunner and continue reading one more time after receiving the end signal. The extra complexity is that we would also need to handle the case where some partitions have no timer iterator and never start the additional dump stream writer, and decide how to handle exceptions if one of the dump_stream calls fails. Additionally, we need to set the statefulHandlerState to TIMER_PROCESSED after all timer rows are processed, so we would need code changes inside worker.py to set this properly. That means we would need the StatefulProcessorHandlerApiClient object inside worker.py to set the state correctly, which brings similar code complexity to what we have now in serializer.py (returning one extra StatefulProcessorHandlerApiClient from transformWithStateWithInitStateUDF and deserializing it from out_iter). We cannot set the TIMER_PROCESSED state in group_ops.py because the output row iterator is not fully consumed there; it is only fully consumed after dump_stream is called inside worker.py.
So either way we have to deal with extra complexity. I personally think putting the timer handling code into serializer.py is slightly better because it is closer to how we deal with timers on the Scala side: we chain the timer output rows after the data handling rows into a single iterator.
Let me know if you have suggestions on which way is better.
This is the implementation of my suggestion based on 0c5ab3f. I've confirmed that pyspark.sql.tests.pandas.test_pandas_transform_with_state passed with this change, though I haven't added the new tests you've added later.
I think this is a lot simpler: we just add two markers to the input iterator which carry over the mode, and the flow does not change at all. No tricks with teeing and chaining iterators, minimal changes to the data structure, etc.
How does this work? It is just the same as how we use iterators in the Scala codebase: with an iterator in Scala, we pull one entry, process it and produce output, then pull another entry. The generator yields one entry for every grouping key, then the marker for timers, and then the marker for completion. Each entry triggers the function which eventually calls the user function; the user function is expected to return an iterator, but the logic that produces the iterator should be synchronous (no async and no laziness, otherwise I guess it could even fail without my change).
So by the time the timer marker is evaluated, the function calls for all grouping keys must already have been done. Same for the completion marker. This is the same as the Scala implementation.
As a side effect, updating the phase is corrected in this commit.
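A self-contained sketch of the marker-based flow described above; TransformWithStateInPandasFuncMode.PROCESS_DATA appears in the diff later in this thread, while the other enum members and helper names here are placeholders:

```python
from enum import Enum
from itertools import groupby

class FuncMode(Enum):  # stand-in for TransformWithStateInPandasFuncMode
    PROCESS_DATA = 1
    PROCESS_TIMER = 2
    COMPLETE = 3

def generate_data_batches(rows):
    # One entry per grouping key, then the timer marker, then the completion marker.
    for key, group in groupby(rows, key=lambda x: x[0]):
        yield (FuncMode.PROCESS_DATA, key, list(group))
    yield (FuncMode.PROCESS_TIMER, None, None)
    yield (FuncMode.COMPLETE, None, None)

def udf(mode, key, rows):
    if mode == FuncMode.PROCESS_DATA:
        return [f"data:{key}"]
    if mode == FuncMode.PROCESS_TIMER:
        # Pulled only after every per-key call above has run, so all timers
        # registered in this batch are visible here.
        return ["expired-timers"]
    return []  # COMPLETE: flush state / advance the handle phase

out = [r for entry in generate_data_batches([("a", 1), ("a", 2), ("b", 3)])
       for r in udf(*entry)]
assert out == ["data:a", "data:b", "expired-timers"]
```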
If you agree with this, please pick up the commit above. You've already added some commits, and I can't partially revert by myself.
My fork is public, so you can add my repo as a remote, fetch the branch, and cherry-pick the commit into this PR branch, resolving the merge conflicts. I'd recommend taking a wholly different route: perform a hard reset of this PR branch to my commit (git reset --hard f8952b213ba7f2cbfbc78ef145552317812e9f9b), and then add more commits to address the other review comments.
Thanks for putting out the commit! I cherry-picked your change and this is now looking much cleaner!
self.max_state.update((max_event_time,))
self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
self.max_state.update((max_event_time,))
self.handle.registerTimer(timer_values.get_current_watermark_in_ms() + 1)
I modified this to current batch timestamp + 1 to test a more common use case, as registering a timer at the current batch timestamp is not very common.
We should also check that a timer can expire in the same batch. So I am keeping the event time suite with a timer expiring in the same batch, and registering a future timestamp in the processing time suite.
HeartSaVioR
left a comment
Only a nit and a linter failure. Thanks for the patience.
for k, g in groupby(data_batches, key=lambda x: x[0]):
    yield (k, g)
    yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
nit: this looks inconsistent? Here we use a tuple with explicit () and in the class below we don't use (). Not a huge deal if the linter does not complain, but while we are here (the linter is failing)...
if batch_id == 0:
    assert set(batch_df.sort("id").collect()) == {Row(id="a", timestamp="20")}
elif batch_id == 1:
    # check timer registered in the same batch is expired
nit: let's add comments on the watermark for late events and the watermark for eviction per batch, to help verify the output. E.g. in batch_id == 1, the watermark for eviction is 10 but the watermark for late events is 0, hence 4 is accepted. Noting that the timestamp value in the expired row follows the watermark for eviction would also be helpful.
I just pushed a commit addressing my own review comments as well as the linter failure. These are nits, so I think it shouldn't matter.
CI has passed: https://github.com/jingz-db/spark/runs/33636466911 Thanks! Merging to master.
What changes were proposed in this pull request?
As on the Scala side, we modify the timer API with a separate handleExpiredTimer function inside StatefulProcessor; this PR changes the Python timer API to match the API on the Scala side. It also adds a timer parameter passed into the handleInitialState function, to support registering timers in the first batch for initial state rows.

Why are the changes needed?
This change aligns with the Scala side of the APIs: #48553
Does this PR introduce any user-facing change?
Yes.
We add a new user-defined function to explicitly handle expired timers:
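The code sample that originally followed is not preserved in this excerpt; below is a minimal sketch of what the new handler can look like, assuming the StatefulProcessor import path used in PySpark and the registerTimer / TimerValues helpers quoted elsewhere in this thread (parameter names are approximations):

```python
import pandas as pd
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

class SimpleTimerProcessor(StatefulProcessor):
    def init(self, handle):
        self.handle = handle

    def handleInputRows(self, key, rows, timer_values):
        # Register an event-time timer one millisecond past the current watermark.
        self.handle.registerTimer(timer_values.get_current_watermark_in_ms() + 1)
        return iter([])

    def handleExpiredTimer(self, key, timer_values, expired_timer_info):
        # expired_timer_info carries the expired timestamp; emit one row per
        # expired timer for this grouping key.
        yield pd.DataFrame({"id": [key[0]], "expired": [True]})

    def close(self):
        pass
```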
We also add a new timer parameter to enable users to register timers for keys that exist in the initial state:
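Again as a sketch under the same assumptions, the initial-state handler now also receives the timer values, so timers can be registered for initial-state keys in the very first batch:

```python
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

class InitialStateTimerProcessor(StatefulProcessor):
    def init(self, handle):
        self.handle = handle

    def handleInitialState(self, key, initialState, timer_values):
        # initialState holds the initial-state rows for this key; a timer can
        # now be registered for it in the first batch.
        self.handle.registerTimer(timer_values.get_current_watermark_in_ms() + 1)

    def handleInputRows(self, key, rows, timer_values):
        return iter([])

    def close(self):
        pass
```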
How was this patch tested?
Added a new test in test_pandas_transform_with_state.

Was this patch authored or co-authored using generative AI tooling?
No