[BEAM-7741] Implement SetState for Python SDK #9090
Run Portable_Python PreCommit
R: @angoenka
c990453 to 754f4ea
Run Python_PVR_Flink PreCommit
Thanks for looking at this.
def __init__(self, state_spec, state_tag, current_value_accessor):
  super(SetRuntimeState, self).__init__(
      state_spec, state_tag, current_value_accessor)
  # TODO: What is current_value_accessor? where does cached value is stored?
I think it gets the underlying value from the runner. It looks like it's only used in the direct runner.
def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
            timer1=DoFn.TimerParam(EXPIRY_TIMER)):
  unused_key, value = element
  buffer.add(str(value).encode('latin1'))
Can we just use StringCoder and avoid the explicit encoding? Or if the input is integers, just store the int itself?
sure, I will take care of this.
@on_timer(EXPIRY_TIMER)
def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
  yield b''.join(sorted(buffer.read()))
Just yield sorted(buffer.read()).
sure thing
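As a stand-alone illustration of the two review suggestions in this thread, the sketch below (plain Python, no Beam; the sample values are made up) shows why storing the integers themselves is simpler than storing `latin1`-encoded strings: byte strings sort lexicographically, and joining them discards the boundaries between elements.

```python
# Made-up sample values standing in for the elements added to the state.
values = [9, 10, 2]

# Encoding each value as the original test did stores byte strings.
encoded = [str(v).encode('latin1') for v in values]

# Byte strings compare lexicographically, so b'10' sorts before b'2'.
sorted_encoded = sorted(encoded)

# Joining the sorted byte strings loses the element boundaries.
joined = b''.join(sorted_encoded)

# Keeping the ints preserves numeric order and one entry per element.
sorted_ints = sorted(values)

print(sorted_encoded, joined, sorted_ints)
```

With ints stored directly, the timer callback can emit the sorted list as-is and the test can compare against plain integers.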
def emit_values(self, bag_state=beam.DoFn.StateParam(SET_STATE)):
  for value in bag_state.read():
    yield value
  yield 'extra'
This shouldn't be needed.
done
| test_stream
| beam.Map(lambda x: ('mykey', x))
| beam.ParDo(SimpleTestSetStatefulDoFn())
| beam.ParDo(self.record_dofn()))
Is there a reason assert_that doesn't work here?
I haven't tried assert_that here. I used the existing test cases as a reference. I will update the test cases.
assert_that doesn't work with state and timers. There is a JIRA ticket for this that was filed a while back (https://issues.apache.org/jira/browse/BEAM-5295). I can take care of this in a separate ticket.
@on_timer(CLEAR_TIMER)
def clear_values(self, bag_state=beam.DoFn.StateParam(SET_STATE)):
  bag_state.clear()
Maybe do an add, clear, add to make sure it really works.
done
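To make the "add, clear, add" check concrete, here is a minimal in-memory sketch of the semantics being tested. `ToySetState` is a hypothetical stand-in written for this illustration, not Beam's runtime class; only the attribute names `_cleared` and `_added_elements` are borrowed from the diff under review.

```python
class ToySetState:
    """Hypothetical in-memory model of a set state cell (not Beam's class)."""

    def __init__(self, persisted=()):
        self._persisted = set(persisted)  # values the runner already holds
        self._cleared = False             # clear() was requested in this bundle
        self._added_elements = set()      # values added in this bundle

    def add(self, value):
        self._added_elements.add(value)

    def clear(self):
        # Drop both the persisted view and anything added so far.
        self._cleared = True
        self._added_elements = set()

    def read(self):
        base = set() if self._cleared else self._persisted
        return base | self._added_elements


state = ToySetState(persisted={1, 2})
state.add(3)     # add
state.clear()    # clear
state.add(100)   # add again
result = sorted(state.read())
print(result)
```

Only the value added after the clear should survive, which is what the follow-up add in the test is meant to verify.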
  self.assertEqual(
      ['extra'],
      StatefulDoFnOnDirectRunnerTest.all_records)

def test_stateful_dofn_nonkeyed_input(self):
Can you add tests without using TestStream to ensure that the FnApiRunner paths also work?
@robertwb do you have an example of the test case?
It'd be pretty much the same as the existing test, but using Create rather than TestStream. Create just emits all elements at MIN_TIMESTAMP, and then you can set timers to perform future actions.
Added some cases to make sure FnApiRunner path is also executed
Run Python PreCommit
  yield 'extra'

@on_timer(CLEAR_TIMER)
def clear_values(self, bag_state=beam.DoFn.StateParam(SET_STATE)):
s/bag_state/set_state/g
done
R: @robertwb Can you take another look? I have taken care of your comments.
Run Portable_Python PreCommit
Thanks. Just a couple more comments.
def read(self):
  return _ConcatIterable(
      {} if self._cleared else _StateBackedIterable(
{} is a dict, not a set. Also, should we be converting this to a frozen set here?
I still need to take care of this comment
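For reference, the point about `{}` can be checked in plain Python. This is only a language-level illustration and says nothing about what the final `read()` returns.

```python
# `{}` is the empty dict literal; an empty set needs the set() constructor.
empty_literal = {}
empty_set = set()
kinds = (type(empty_literal).__name__, type(empty_set).__name__)

# frozenset deduplicates like set but exposes no mutating methods,
# so callers cannot modify the value they get back from a read.
frozen = frozenset([3, 1, 3])

print(kinds, sorted(frozen), hasattr(frozen, 'add'))
```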
def _commit(self):
  if self._cleared:
    self._state_handler.blocking_clear(self._state_key)
  if self._added_elements:
One downside of this way of doing things is that if we add the same element many times from different bundles, the underlying bag state keeps growing. Perhaps we should (periodically?) "compact" by reading everything in, deduplicating via set(), and then writing it all back out. (A TODO is fine here if you don't want to implement it now.)
I would prefer to have the compaction logic but I will do some research on my end before making the final decision.
@robertwb Check the latest implementation of the read() method. I think this is what you recommended. Let me know if you want me to implement it in a different way.
Looks good, thanks.
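The compaction idea discussed in this thread can be sketched independently of Beam. `ToyBagBackend` and `compact` below are hypothetical names for this illustration, with a plain list standing in for runner-side bag state. This is not the code that was merged.

```python
class ToyBagBackend:
    """Hypothetical append-only bag standing in for runner-side bag state."""

    def __init__(self):
        self.items = []

    def append(self, values):
        self.items.extend(values)

    def read(self):
        return list(self.items)

    def clear(self):
        self.items = []


def compact(bag):
    """Read everything, deduplicate via set(), and write it all back out."""
    unique = set(bag.read())
    bag.clear()
    bag.append(sorted(unique))
    return unique


bag = ToyBagBackend()
# Three "bundles" that repeatedly add overlapping elements.
for bundle in ([1, 2], [2, 3], [1, 3]):
    bag.append(bundle)

size_before = len(bag.read())
unique = compact(bag)
size_after = len(bag.read())
print(size_before, size_after)
```

The observable set contents are the same before and after compaction; only the amount of data held by the backing bag shrinks.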
@@ -424,18 +426,18 @@ def clear_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)):

   def test_simple_set_stateful_dofn(self):
     class SimpleTestSetStatefulDoFn(DoFn):
-      BUFFER_STATE = SetStateSpec('buffer', BytesCoder())
+      BUFFER_STATE = SetStateSpec('buffer', FastPrimitivesCoder())
Use the more specific VarIntCoder here instead.
  bag_state.clear()
def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
  set_state.clear()
  set_state.add('value2')
Add an int, something "different" like 100, for consistency.
Wouldn't SetState give us an error, since we are using StrUtf8Coder and would be passing an int?
I can pass some different string value instead.
      StatefulDoFnOnDirectRunnerTest.all_records)
  self.assertEqual(['value2'], StatefulDoFnOnDirectRunnerTest.all_records)

def test_stateful_set_state_fn_runner(self):
test_stateful_set_state_portably
    yield aggregated_value

p = TestPipeline()
values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4)])
Duplicates?
Sure, I will add it.
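A quick way to see what a duplicate input should produce is to model the DoFn's logic (add the value, then emit the sum of everything in the set) in plain Python. The helper name and the input list are assumptions made for this sketch, which also assumes elements arrive in the listed order.

```python
def running_distinct_sums(values):
    """Mimic the DoFn: add each value to a set, then emit the set's sum."""
    seen = set()
    for value in values:
        seen.add(value)
        yield sum(seen)


# The final 3 is a duplicate, so it must not change the aggregated value.
outputs = list(running_distinct_sums([1, 2, 3, 4, 3]))
print(outputs)
```

If the state behaved like a bag instead of a set, the last output would be 13 rather than 10, so a duplicate in the input distinguishes the two.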
set_state.add(value)
for saved_value in set_state.read():
  aggregated_value += saved_value
yield aggregated_value
Maybe set a timer and test clear as well?
I tried to add this test case:
def test_stateful_set_state_clean_portably(self):

  class SetStateClearingStatefulDoFn(beam.DoFn):
    SET_STATE = SetStateSpec('buffer', VarIntCoder())
    EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
    CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

    def process(self,
                element,
                set_state=beam.DoFn.StateParam(SET_STATE),
                emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
                clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
      _, value = element
      set_state.add(value)
      if value == 2:
        clear_timer.set(3)
      elif value == 3:
        emit_timer.set(4)

    @on_timer(EMIT_TIMER)
    def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
      yield sorted(set_state.read())

    @on_timer(CLEAR_TIMER)
    def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
      set_state.clear()
      set_state.add(100)

  p = TestPipeline()
  values = p | beam.Create(range(1, 5))
  actual_values = (values
                   | beam.Map(lambda t: window.TimestampedValue(t, t))
                   | beam.WindowInto(window.FixedWindows(1))
                   | beam.ParDo(SetStateClearingStatefulDoFn()))
  assert_that(actual_values, equal_to([100]))
  result = p.run()
  result.wait_until_finish()
But it didn't work; it complained that timers are not supported:
(beam) ➜ python git:(BEAM-7741-implement-setstate) ✗ python -m unittest apache_beam.transforms.userstate_test.StatefulDoFnOnDirectRunnerTest.test_stateful_set_state_clean_portably
WARNING:root:Key coder FastPrimitivesCoder for transform <ParDo(PTransform) label=[ParDo(SetStateClearingStatefulDoFn)]> with stateful DoFn may not be deterministic. This may cause incorrect behavior for complex key types. Consider adding an input type hint for this transform.
E
======================================================================
ERROR: test_stateful_set_state_clean_portably (apache_beam.transforms.userstate_test.StatefulDoFnOnDirectRunnerTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "apache_beam/transforms/userstate_test.py", line 567, in test_stateful_set_state_clean_portably
    result = p.run()
  File "apache_beam/testing/test_pipeline.py", line 107, in run
    else test_runner_api))
  File "apache_beam/pipeline.py", line 406, in run
    self._options).run(False)
  File "apache_beam/pipeline.py", line 419, in run
    return self.runner.run_pipeline(self, self._options)
  File "apache_beam/runners/direct/direct_runner.py", line 128, in run_pipeline
    return runner.run_pipeline(pipeline, options)
  File "apache_beam/runners/portability/fn_api_runner.py", line 319, in run_pipeline
    default_environment=self._default_environment))
  File "apache_beam/runners/portability/fn_api_runner.py", line 323, in run_via_runner_api
    stage_context, stages = self.create_stages(pipeline_proto)
  File "apache_beam/runners/portability/fn_api_runner.py", line 385, in create_stages
    use_state_iterables=self._use_state_iterables)
  File "apache_beam/runners/portability/fn_api_runner_transforms.py", line 489, in create_and_optimize_stages
    stages = list(phase(stages, pipeline_context))
  File "apache_beam/runners/portability/fn_api_runner_transforms.py", line 1168, in inject_timer_pcollections
    raise NotImplementedError('Timers and side inputs.')
NotImplementedError: Timers and side inputs.
----------------------------------------------------------------------
Ran 1 test in 0.207s

FAILED (errors=1)
This is a little confusing to me because I have used timers on the Flink runner and they work fine. This could be a bug on the portable runner side.
I still have to debug this; I might need some help with it.
Hmm... for some reason it looks like it's assuming you're using a timer and a side input in the same DoFn.
@robertwb I spent some time debugging this. It seems like FnApiRunner runs into issues when two timers are used.
class TwoTimersDoFn(beam.DoFn):
  EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
  CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)

  def process(self,
              element,
              emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
              clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)
             ):
    _, value = element
    if value == 2:
      clear_timer.set(3)
      pass
    elif value == 3:
      emit_timer.set(4)
      pass

  @on_timer(EMIT_TIMER)
  def emit_values(self):
    yield 'emit1'

  @on_timer(CLEAR_TIMER)
  def clear_values(self):
    yield 'clear1'

p = TestPipeline()
values = p | beam.Create([('key', 1),
                          ('key', 2),
                          ('key', 3),
                          ('key', 4),
                          ('key', 5),
                          ('key', 6)])
actual_values = (values
                 | beam.Map(lambda t: window.TimestampedValue(t, t[1]))
                 | beam.WindowInto(window.FixedWindows(1))
                 | beam.ParDo(TwoTimersDoFn()))
assert_that(actual_values, equal_to([100]))
result = p.run()
result.wait_until_finish()
I tested it on the master branch and it fails consistently with the same error message. I couldn't figure out the root cause, but it is definitely triggered by using two timers; it works fine when only one timer is used.
I feel I should file a separate JIRA ticket to address this issue. Thoughts?
Thanks for the analysis. Filing a JIRA for the multiple-timer issue sounds like the right thing to do here.
I checked, there is already a Jira ticket filed by @tweise (https://issues.apache.org/jira/browse/BEAM-7074)
@robertwb can you take another look at this PR? Also check this comment: https://github.com/apache/beam/pull/9090/files#r309342482
Mostly looks good. Just a couple of small comments.
"""Specification for a user DoFn Set State cell""" | ||

def __init__(self, name, coder):
  assert isinstance(name, str)
TypeError would be a better thing to raise here.
I can change this, but the other StateSpec classes do it the same way. It is probably better to change their asserts in a separate PR.
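For context on the suggestion, here is a minimal sketch of validating with an explicit `TypeError` instead of `assert`. `ToyStateSpec` and its error message are hypothetical and are not Beam's `SetStateSpec`. One relevant language fact is that `assert` statements are skipped when Python runs with `-O`, whereas an explicit raise always executes.

```python
class ToyStateSpec:
    """Hypothetical spec class, used only to illustrate the validation style."""

    def __init__(self, name):
        # An explicit check still runs under `python -O`, unlike `assert`.
        if not isinstance(name, str):
            raise TypeError('name is not a string: %r' % (name,))
        self.name = name


try:
    ToyStateSpec(42)
    error = None
except TypeError as exc:
    error = str(exc)

print(error)
```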
_, value = element
set_state.add(value)

if value == 5:
We can't, in general, depend on elements being received in order. Perhaps trigger based on the size of the set?
That's a good idea. I will change the test case.
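The ordering concern can be demonstrated without a runner. The sketch below (plain Python; both helper functions are hypothetical) compares a trigger keyed on a specific value with one keyed on the size of the set, across every arrival order of five distinct elements.

```python
import itertools


def fire_position_by_value(elements, trigger_value):
    """Return how many elements were seen when `value == trigger_value` fires."""
    for count, (_, value) in enumerate(elements, start=1):
        if value == trigger_value:
            return count
    return None


def fire_position_by_size(elements, expected_size):
    """Return how many elements were seen when the set reaches expected_size."""
    seen = set()
    for count, (_, value) in enumerate(elements, start=1):
        seen.add(value)
        if len(seen) == expected_size:
            return count
    return None


elements = [('key', v) for v in range(1, 6)]
orders = list(itertools.permutations(elements))

# The value-based trigger can fire after any number of elements.
value_positions = {fire_position_by_value(order, 5) for order in orders}
# The size-based trigger always fires once all five distinct values are in.
size_positions = {fire_position_by_size(order, 5) for order in orders}
print(sorted(value_positions), sorted(size_positions))
```

With the size-based condition, the state always holds all five values when the trigger fires, regardless of delivery order, so the expected output is deterministic.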
78c8a84 to ef1dc33
@robertwb I have taken care of your comments. Can you approve this now?
R: @robertwb PTAL
Thanks!
SetState is already implemented in the Java SDK; to ensure feature parity, this PR implements SetState in the Python SDK as well.