Skip to content

Commit

Permalink
[BEAM-7739] Implement ReadModifyWriteState Py sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
rakeshcusat committed Aug 12, 2019
1 parent 2bdf953 commit 21c9e61
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 12 deletions.
5 changes: 3 additions & 2 deletions model/pipeline/src/main/proto/beam_runner_api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -411,15 +411,16 @@ message Parameter {

message StateSpec {
oneof spec {
ValueStateSpec value_spec = 1;
ReadModifyWriteStateSpec value_spec = 1; // This has been deprecated, and should not be used.
BagStateSpec bag_spec = 2;
CombiningStateSpec combining_spec = 3;
MapStateSpec map_spec = 4;
SetStateSpec set_spec = 5;
ReadModifyWriteStateSpec read_modify_write_spec = 6;
}
}

message ValueStateSpec {
message ReadModifyWriteStateSpec {
string coder_id = 1;
}

Expand Down
46 changes: 46 additions & 0 deletions sdks/python/apache_beam/runners/direct/direct_userstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from apache_beam.transforms import userstate
from apache_beam.transforms.trigger import _ListStateTag
from apache_beam.transforms.trigger import _SetStateTag
from apache_beam.transforms.trigger import _ReadModifyWriteStateTag


class DirectRuntimeState(userstate.RuntimeState):
Expand All @@ -40,6 +41,11 @@ def for_spec(state_spec, state_tag, current_value_accessor):
current_value_accessor)
elif isinstance(state_spec, userstate.SetStateSpec):
return SetRuntimeState(state_spec, state_tag, current_value_accessor)

elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
return ReadModifyWriteRuntimeState(state_spec,
state_tag,
current_value_accessor)
else:
raise ValueError('Invalid state spec: %s' % state_spec)

Expand Down Expand Up @@ -110,6 +116,36 @@ def is_modified(self):
return self._modified


class ReadModifyWriteRuntimeState(DirectRuntimeState,
userstate.ReadModifyWriteRuntimeState):
"""Read modify write state interface object passed to user code."""

def __init__(self, state_spec, state_tag, current_value_accessor):
super(ReadModifyWriteRuntimeState, self).__init__(
state_spec, state_tag, current_value_accessor)
self._value = UNREAD_VALUE
self._modified = False
self._cleared = False

def read(self):
if self._value is UNREAD_VALUE:
self._value = self._decode(self._current_value_accessor())

return self._value

def add(self, value):
self._modified = True
self._value = value

def clear(self):
self._cleared = True
self._value = UNREAD_VALUE
self._modified = False

def is_modified(self):
return self._modified and self._value is not UNREAD_VALUE


class CombiningValueRuntimeState(
DirectRuntimeState, userstate.CombiningValueRuntimeState):
"""Combining value state interface object passed to user code."""
Expand Down Expand Up @@ -169,6 +205,8 @@ def __init__(self, step_context, dofn, key_coder):
state_tag = _ListStateTag(state_key)
elif isinstance(state_spec, userstate.SetStateSpec):
state_tag = _SetStateTag(state_key)
elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
state_tag = _ReadModifyWriteStateTag(state_key)
else:
raise ValueError('Invalid state spec: %s' % state_spec)
self.state_tags[state_spec] = state_tag
Expand Down Expand Up @@ -225,6 +263,14 @@ def commit(self):
for new_value in runtime_state._current_accumulator:
state.add_state(
window, state_tag, state_spec.coder.encode(new_value))
elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
if runtime_state._cleared:
state.clear_state(window, state_tag)

if runtime_state.is_modified():
state.add_state(window,
state_tag,
state_spec.coder.encode(runtime_state._value))
else:
raise ValueError('Invalid state spec: %s' % state_spec)

Expand Down
62 changes: 60 additions & 2 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def clear(self):
def _commit(self):
if self._cleared:
self._state_handler.blocking_clear(self._state_key)

if self._added_elements:
value_coder_impl = self._value_coder.get_impl()
out = coder_impl.create_OutputStream()
Expand All @@ -439,6 +440,53 @@ def _commit(self):
self._state_handler.blocking_append(self._state_key, out.get())


class SynchronousReadModifyWriteRuntimeState(
userstate.ReadModifyWriteRuntimeState):

def __init__(self, state_handler, state_key, value_coder):
self._state_handler = state_handler
self._state_key = state_key
self._value_coder = value_coder
self._cleared = False
self._added_element = None

def read(self):
# TODO: not sure whther return an iterable object or single value. If we
# are choosing to be consistent with other states (SetState, BagState)
# then we should choose an iteratable object but this state is suppose
# to return only one value. Currently, I am choosing to just return only
# one value.

if self._cleared:
return None
elif self._added_element:
return self._added_element
else:
elements = [element for element in _StateBackedIterable(
self._state_handler, self._state_key, self._value_coder)]
return elements[0] if elements else None

def add(self, value):
if self._cleared:
self._state_handler.blocking_clear(self._state_key)
self._cleared = False
self._added_element = value

def clear(self):
self._cleared = True
self._added_element = None

def _commit(self):
if self._cleared:
self._state_handler.blocking_clear(self._state_key)

if self._added_element:
value_coder_impl = self._value_coder.get_impl()
out = coder_impl.create_OutputStream()
value_coder_impl.encode_to_stream(self._added_element, out, True)
self._state_handler.blocking_append(self._state_key, out.get())


class OutputTimer(object):
def __init__(self, key, window, receiver):
self._key = key
Expand Down Expand Up @@ -502,7 +550,7 @@ def _create_state(self, state_spec, key, window):
if isinstance(state_spec,
(userstate.BagStateSpec, userstate.CombiningValueStateSpec)):
bag_state = SynchronousBagRuntimeState(
self._state_handler,
state_handler=self._state_handler,
state_key=beam_fn_api_pb2.StateKey(
bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
ptransform_id=self._transform_id,
Expand All @@ -516,7 +564,17 @@ def _create_state(self, state_spec, key, window):
return CombiningValueRuntimeState(bag_state, state_spec.combine_fn)
elif isinstance(state_spec, userstate.SetStateSpec):
return SynchronousSetRuntimeState(
self._state_handler,
state_handler=self._state_handler,
state_key=beam_fn_api_pb2.StateKey(
bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
ptransform_id=self._transform_id,
user_state_id=state_spec.name,
window=self._window_coder.encode(window),
key=self._key_coder.encode(key))),
value_coder=state_spec.coder)
elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
return SynchronousReadModifyWriteRuntimeState(
state_handler=self._state_handler,
state_key=beam_fn_api_pb2.StateKey(
bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
ptransform_id=self._transform_id,
Expand Down
31 changes: 26 additions & 5 deletions sdks/python/apache_beam/transforms/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def with_prefix(self, prefix):
return _SetStateTag(prefix + self.tag)


class _ReadModifyWriteStateTag(_StateTag):
"""StateTag pointing to an element."""

def __repr__(self):
return 'ReadModifyWriteState(%s)' % (self.tag)

def with_prefix(self, prefix):
return _ReadModifyWriteStateTag(prefix + self.tag)


class _CombiningValueStateTag(_StateTag):
"""StateTag pointing to an element, accumulated with a combiner.
Expand Down Expand Up @@ -865,11 +875,19 @@ def add_state(self, window, tag, value):
def get_state(self, window, tag):
if isinstance(tag, _CombiningValueStateTag):
original_tag, tag = tag, tag.without_extraction()

values = [self.raw_state.get_state(window_id, tag)
for window_id in self._get_ids(window)]
if isinstance(tag, _ValueStateTag):
raise ValueError(
'Merging requested for non-mergeable state tag: %r.' % tag)

if isinstance(tag, _ReadModifyWriteStateTag):
# TODO: Need better logic here. I think we should just get the latest one
# based on the window.
for vs in values:
for v in vs:
if v:
return v
return None

elif isinstance(tag, _CombiningValueStateTag):
return original_tag.combine_fn.extract_output(
original_tag.combine_fn.merge_accumulators(values))
Expand Down Expand Up @@ -1231,7 +1249,8 @@ def get_window(self, window_id):
def add_state(self, window, tag, value):
if self.defensive_copy:
value = copy.deepcopy(value)
if isinstance(tag, _ValueStateTag):
if isinstance(tag, _ReadModifyWriteStateTag):
# TODO: need to add some thing here.
self.state[window][tag.tag] = value
elif isinstance(tag, _CombiningValueStateTag):
# TODO(robertwb): Store merged accumulators.
Expand All @@ -1247,7 +1266,9 @@ def add_state(self, window, tag, value):

def get_state(self, window, tag):
values = self.state[window][tag.tag]
if isinstance(tag, _ValueStateTag):
if isinstance(tag, _ReadModifyWriteStateTag):
# since we have stored only one item, values will
# have only one item.
return values
elif isinstance(tag, _CombiningValueStateTag):
return tag.combine_fn.apply(values)
Expand Down
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/transforms/userstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,31 @@ def to_runner_api(self, context):
element_coder_id=context.coders.get_id(self.coder)))


class ReadModifyWriteStateSpec(StateSpec):
"""
Specification of a user DoFn read modify write State Cell.
"""
def __init__(self, name, coder):
"""
Initialize the specification for Read modify write state.
Args:
name (str): The name by which the state is identified.
coder (Coder): Coder specifying how to encode the value.
"""
if not isinstance(name, str):
raise TypeError("ReadModifyWriteState name is not a string")
if not isinstance(coder, Coder):
raise TypeError("ReadModifyWriteState coder is not of type Coder")
self.name = name
self.coder = coder

def to_runner_api(self, context):
return beam_runner_api_pb2.StateSpec(
read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec(
coder_id=context.coders.get_id(self.coder)))


class CombiningValueStateSpec(StateSpec):
"""Specification for a user DoFn combining value state cell."""

Expand Down Expand Up @@ -267,6 +292,7 @@ def set(self, timestamp):

class RuntimeState(object):
"""State interface object passed to user code."""

def prefetch(self):
# The default implementation here does nothing.
pass
Expand All @@ -291,6 +317,10 @@ class SetRuntimeState(AccumulatingRuntimeState):
"""Set state interface object passed to user code."""


class ReadModifyWriteRuntimeState(AccumulatingRuntimeState):
"""ReadModifyWrite state information object passed to user code."""


class CombiningValueRuntimeState(AccumulatingRuntimeState):
"""Combining value state interface object passed to user code."""

Expand Down

0 comments on commit 21c9e61

Please sign in to comment.