[FLINK-22865][python] Optimize state serialize/deserialize in PyFlink #16069

Status: Closed · wants to merge 2 commits
4 changes: 4 additions & 0 deletions flink-python/pyflink/common/serializer.py
@@ -61,6 +61,10 @@ def _get_coder(self):
         deserialize_func = self.deserialize

         class CoderAdapter(object):
+            def get_impl(self):
+                return CoderAdapterIml()
+
+        class CoderAdapterIml(object):

             def encode_nested(self, element):
                 bytes_io = BytesIO()
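The adapter split mirrors the two-level coder protocol used by PyFlink's internal coders: the coder object exposes `get_impl()`, and the returned impl does the actual byte work. Below is a minimal, self-contained sketch of that shape; `serialize_func`/`deserialize_func` are hypothetical stand-ins, whereas in serializer.py they are the enclosing `TypeSerializer`'s bound methods, and `decode_nested` is included only to make the round trip visible.

```python
from io import BytesIO

# Hypothetical stand-ins; in serializer.py these close over the
# surrounding TypeSerializer's serialize/deserialize methods.
def serialize_func(element, stream):
    stream.write(element)

def deserialize_func(stream):
    return stream.read()

class CoderAdapter(object):
    # Two-level protocol: the coder hands out an impl object on demand.
    def get_impl(self):
        return CoderAdapterIml()

class CoderAdapterIml(object):
    def encode_nested(self, element):
        bytes_io = BytesIO()
        serialize_func(element, bytes_io)
        return bytes_io.getvalue()

    def decode_nested(self, data):
        return deserialize_func(BytesIO(data))

impl = CoderAdapter().get_impl()
assert impl.decode_nested(impl.encode_nested(b"payload")) == b"payload"
```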
7 changes: 3 additions & 4 deletions flink-python/pyflink/datastream/data_stream.py
@@ -932,8 +932,7 @@ def __init__(self, reduce_function: ReduceFunction):

     def open(self, runtime_context: RuntimeContext):
         self._reduce_value_state = runtime_context.get_state(
-            ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()),
-                                 Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), output_type))
         self._reduce_function.open(runtime_context)
         from pyflink.fn_execution.datastream.runtime_context import StreamingRuntimeContext
         self._in_batch_execution_mode = \
@@ -1137,7 +1136,7 @@ def get_execution_environment(self):
         return self._keyed_stream.get_execution_environment()

     def get_input_type(self):
-        return self._keyed_stream.get_type()
+        return _from_java_type(self._keyed_stream._original_data_type_info.get_java_type_info())

Review comment (Contributor): What's the purpose of this change?

     def trigger(self, trigger: Trigger):
         """
@@ -1201,7 +1200,7 @@ def _get_result_data_stream(
             self.get_execution_environment())
         window_serializer = self._window_assigner.get_window_serializer()
         window_state_descriptor = ListStateDescriptor(
-            "window-contents", Types.PICKLED_BYTE_ARRAY())
+            "window-contents", self.get_input_type())
         window_operation_descriptor = WindowOperationDescriptor(
             self._window_assigner,
             self._window_trigger,
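With `output_type` flowing into the reduce state descriptor, the running aggregate is serialized with the stream's own type serializer instead of being pickled. A hedged usage sketch follows; the pipeline wiring is illustrative and not part of the diff.

```python
from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds = env.from_collection(
    [("a", 1), ("a", 2), ("b", 3)],
    type_info=Types.TUPLE([Types.STRING(), Types.INT()]))

# reduce() now creates its internal ValueState with the stream's
# output type, so the running value avoids pickle round-trips.
ds.key_by(lambda v: v[0]) \
  .reduce(lambda a, b: (a[0], a[1] + b[1])) \
  .print()

env.execute("reduce_with_typed_state")
```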
26 changes: 1 addition & 25 deletions flink-python/pyflink/datastream/state.py
@@ -19,7 +19,7 @@

 from typing import TypeVar, Generic, Iterable, List, Iterator, Dict, Tuple

-from pyflink.common.typeinfo import TypeInformation, Types, PickledBytesTypeInfo
+from pyflink.common.typeinfo import TypeInformation, Types

 __all__ = [
     'ValueStateDescriptor',
@@ -316,10 +316,6 @@ def __init__(self, name: str, value_type_info: TypeInformation):
         :param name: The name of the state.
         :param value_type_info: the type information of the state.
         """
-        if not isinstance(value_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the value could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s."
-                             % type(value_type_info))
         super(ValueStateDescriptor, self).__init__(name, value_type_info)


@@ -336,10 +332,6 @@ def __init__(self, name: str, elem_type_info: TypeInformation):
         :param name: The name of the state.
         :param elem_type_info: the type information of the state element.
         """
-        if not isinstance(elem_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the element could only be "
-                             "PickledBytesTypeInfo (created via Types.PICKLED_BYTE_ARRAY()) "
-                             "currently, got %s" % type(elem_type_info))
         super(ListStateDescriptor, self).__init__(name, Types.LIST(elem_type_info))


@@ -357,14 +349,6 @@ def __init__(self, name: str, key_type_info: TypeInformation, value_type_info: TypeInformation):
         :param key_type_info: The type information of the key.
         :param value_type_info: the type information of the value.
         """
-        if not isinstance(key_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the key could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(key_type_info))
-        if not isinstance(value_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the value could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(value_type_info))
         super(MapStateDescriptor, self).__init__(name, Types.MAP(key_type_info, value_type_info))


@@ -392,10 +376,6 @@ def __init__(self,
             reduce_function = ReduceFunctionWrapper(reduce_function)  # type: ignore
         else:
             raise TypeError("The input must be a ReduceFunction or a callable function!")
-        if not isinstance(type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the state could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(type_info))
         self._reduce_function = reduce_function

     def get_reduce_function(self):
@@ -418,10 +398,6 @@ def __init__(self,
         from pyflink.datastream.functions import AggregateFunction
         if not isinstance(agg_function, AggregateFunction):
             raise TypeError("The input must be a pyflink.datastream.functions.AggregateFunction!")
-        if not isinstance(state_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the state could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(state_type_info))
         self._agg_function = agg_function

     def get_agg_function(self):
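With the `PickledBytesTypeInfo` checks deleted, every descriptor constructor now accepts arbitrary `TypeInformation`. A short sketch of what becomes legal (descriptor names here are illustrative):

```python
from pyflink.common.typeinfo import Types
from pyflink.datastream.state import (ListStateDescriptor, MapStateDescriptor,
                                      ValueStateDescriptor)

# Each of these previously raised ValueError unless the type information
# was Types.PICKLED_BYTE_ARRAY(); any TypeInformation is now accepted.
value_state = ValueStateDescriptor("last_seen", Types.LONG())
list_state = ListStateDescriptor("buffer", Types.STRING())
map_state = MapStateDescriptor("index", Types.STRING(), Types.LONG())
```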
35 changes: 15 additions & 20 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -214,9 +214,9 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.pre1 = runtime_context.get_state(
-            ValueStateDescriptor("pre1", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("pre1", Types.STRING()))
         self.pre2 = runtime_context.get_state(
-            ValueStateDescriptor("pre2", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("pre2", Types.STRING()))

     def map1(self, value):
         if value[0] == 'b':
@@ -409,7 +409,7 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))

     def map(self, value):
         state_value = self.state.value()
@@ -453,7 +453,7 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))

     def flat_map(self, value):
         state_value = self.state.value()
@@ -497,7 +497,7 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))

     def filter(self, value):
         state_value = self.state.value()
@@ -694,15 +694,11 @@ def __init__(self):
         self.map_state = None

     def open(self, runtime_context: RuntimeContext):
-        value_state_descriptor = ValueStateDescriptor('value_state',
-                                                      Types.PICKLED_BYTE_ARRAY())
+        value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
         self.value_state = runtime_context.get_state(value_state_descriptor)
-        list_state_descriptor = ListStateDescriptor('list_state',
-                                                    Types.PICKLED_BYTE_ARRAY())
+        list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
         self.list_state = runtime_context.get_list_state(list_state_descriptor)
-        map_state_descriptor = MapStateDescriptor('map_state',
-                                                  Types.PICKLED_BYTE_ARRAY(),
-                                                  Types.PICKLED_BYTE_ARRAY())
+        map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
         self.map_state = runtime_context.get_map_state(map_state_descriptor)

     def process_element(self, value, ctx):
@@ -770,7 +766,7 @@ def __init__(self):
     def open(self, runtime_context: RuntimeContext):
         self.reducing_state = runtime_context.get_reducing_state(
             ReducingStateDescriptor(
-                'reducing_state', lambda i, i2: i + i2, Types.PICKLED_BYTE_ARRAY()))
+                'reducing_state', lambda i, i2: i + i2, Types.INT()))

     def process_element(self, value, ctx):
         self.reducing_state.add(value[0])
@@ -814,7 +810,7 @@ def __init__(self):
     def open(self, runtime_context: RuntimeContext):
         self.aggregating_state = runtime_context.get_aggregating_state(
             AggregatingStateDescriptor(
-                'aggregating_state', MyAggregateFunction(), Types.PICKLED_BYTE_ARRAY()))
+                'aggregating_state', MyAggregateFunction(), Types.INT()))

     def process_element(self, value, ctx):
         self.aggregating_state.add(value[0])
@@ -1356,7 +1352,7 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.map_state = runtime_context.get_map_state(
-            MapStateDescriptor("map", Types.PICKLED_BYTE_ARRAY(), Types.PICKLED_BYTE_ARRAY()))
+            MapStateDescriptor("map", Types.STRING(), Types.BOOLEAN()))

     def flat_map1(self, value):
         yield str(value[0] + 1)
@@ -1376,8 +1372,7 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.timer_registered = False
-        self.count_state = runtime_context.get_state(ValueStateDescriptor(
-            "count", Types.PICKLED_BYTE_ARRAY()))
+        self.count_state = runtime_context.get_state(ValueStateDescriptor("count", Types.INT()))

     def process_element1(self, value, ctx: 'KeyedCoProcessFunction.Context'):
         if not self.timer_registered:
@@ -1411,7 +1406,7 @@ def __init__(self):

     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))

     def reduce(self, value1, value2):
         state_value = self.state.value()
@@ -1439,7 +1434,7 @@ class SimpleCountWindowTrigger(Trigger[tuple, CountWindow]):
     def __init__(self):
         self._window_size = 3
         self._count_state_descriptor = ReducingStateDescriptor(
-            "trigger_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY())
+            "trigger_counter", lambda a, b: a + b, Types.BIG_INT())

     def on_element(self,
                    element: tuple,
@@ -1479,7 +1474,7 @@ def __init__(self):
         self._window_id = 0
         self._window_size = 3
         self._counter_state_descriptor = ReducingStateDescriptor(
-            "assigner_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY())
+            "assigner_counter", lambda a, b: a + b, Types.BIG_INT())

     def assign_windows(self,
                        element: tuple,
22 changes: 8 additions & 14 deletions flink-python/pyflink/datastream/window.py
@@ -153,22 +153,19 @@ def __init__(self):

     def serialize(self, element: TimeWindow, stream: BytesIO) -> None:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = self._underlying_coder.encode(element)
         stream.write(bytes_data)

     def deserialize(self, stream: BytesIO) -> TimeWindow:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = stream.read(16)
         return self._underlying_coder.decode(bytes_data)

     def _get_coder(self):
-        try:
-            from pyflink.fn_execution import coder_impl_fast as coder_impl
-        except:
-            from pyflink.fn_execution import coder_impl_slow as coder_impl
-        return coder_impl.TimeWindowCoderImpl()
+        from pyflink.fn_execution import coders
+        return coders.TimeWindowCoder()


 class CountWindowSerializer(TypeSerializer[CountWindow]):
@@ -178,22 +175,19 @@ def __init__(self):

     def serialize(self, element: CountWindow, stream: BytesIO) -> None:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = self._underlying_coder.encode(element)
         stream.write(bytes_data)

     def deserialize(self, stream: BytesIO) -> CountWindow:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = stream.read(8)
         return self._underlying_coder.decode(bytes_data)

     def _get_coder(self):
-        try:
-            from pyflink.fn_execution import coder_impl_fast as coder_impl
-        except:
-            from pyflink.fn_execution import coder_impl_slow as coder_impl
-        return coder_impl.CountWindowCoderImpl()
+        from pyflink.fn_execution import coders
+        return coders.CountWindowCoder()


 T = TypeVar('T')
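`_get_coder()` now resolves through `pyflink.fn_execution.coders`, which owns the choice between the Cython and pure-Python implementations behind `get_impl()`, instead of each serializer repeating the try/except import dance. A small round-trip sketch, assuming a PyFlink build where these classes are importable as in this PR:

```python
from io import BytesIO

from pyflink.datastream.window import TimeWindow, TimeWindowSerializer

serializer = TimeWindowSerializer()
buf = BytesIO()
serializer.serialize(TimeWindow(0, 1000), buf)  # fixed 16-byte encoding

buf.seek(0)
window = serializer.deserialize(buf)
print(window.start, window.end)  # 0 1000
```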
@@ -25,7 +25,7 @@ from apache_beam.coders.coder_impl cimport InputStream as BInputStream
 from apache_beam.coders.coder_impl cimport OutputStream as BOutputStream
 from apache_beam.coders.coder_impl cimport StreamCoderImpl

-from pyflink.fn_execution.beam.beam_stream cimport BeamInputStream
+from pyflink.fn_execution.beam.beam_stream_fast cimport BeamInputStream
 from pyflink.fn_execution.stream_fast cimport InputStream

 cdef class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
@@ -59,9 +59,11 @@ cdef class PassThroughPrefixCoderImpl(StreamCoderImpl):
         # create InputStream
         data_input_stream = InputStream()
         data_input_stream._input_data = <char*?>in_stream.allc
-        in_stream.pos = size
+        data_input_stream._input_pos = in_stream.pos

-        return self._value_coder.decode_from_stream(data_input_stream, size)
+        result = self._value_coder.decode_from_stream(data_input_stream, size)
+        in_stream.pos = data_input_stream._input_pos
+        return result

     cdef void _write_data_output_stream(self, BOutputStream out_stream):
         cdef OutputStream data_out_stream
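The fast-path fix stops the decoder from unconditionally marking the whole Beam buffer as consumed (`in_stream.pos = size`); instead the inner stream starts at the outer stream's current position and writes the final position back after decoding. A pure-Python sketch of that hand-off — all class and method names here are illustrative, not the Cython API:

```python
class OuterStream:
    """Stands in for Beam's InputStream: a buffer plus a read position."""
    def __init__(self, data):
        self.data = data
        self.pos = 0

class InnerStream:
    """Stands in for PyFlink's InputStream over the same buffer."""
    def __init__(self, data, pos):
        self.data = data
        self.pos = pos

    def read_byte(self):
        value = self.data[self.pos]
        self.pos += 1
        return value

def decode_one(outer):
    inner = InnerStream(outer.data, outer.pos)  # resume where outer stopped
    result = inner.read_byte()
    outer.pos = inner.pos  # propagate how much was actually consumed
    return result

outer = OuterStream(bytes([7, 8]))
assert decode_one(outer) == 7
assert decode_one(outer) == 8  # would be lost if pos jumped straight to the end
```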
@@ -19,7 +19,8 @@

 from apache_beam.coders.coder_impl import StreamCoderImpl, create_InputStream, create_OutputStream

-from pyflink.fn_execution.stream_slow import OutputStream, InputStream
+from pyflink.fn_execution.stream_slow import OutputStream
+from pyflink.fn_execution.beam.beam_stream_slow import BeamInputStream


 class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
@@ -50,7 +51,7 @@ def encode_to_stream(self, value, out_stream: create_OutputStream, nested):
         self._data_output_stream.clear()

     def decode_from_stream(self, in_stream: create_InputStream, nested):
-        data_input_stream = InputStream(in_stream.read_all(False))
+        data_input_stream = BeamInputStream(in_stream)
         return self._value_coder.decode_from_stream(data_input_stream)

     def __repr__(self):
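On the slow path, `in_stream.read_all(False)` copied every remaining byte into a fresh `InputStream` before decoding; `BeamInputStream` instead delegates reads to the underlying Beam stream. A rough sketch of the difference — the class shapes are assumptions for illustration, not the actual beam_stream_slow code:

```python
class EagerInputStream:
    """Old shape: snapshot the remaining bytes up front."""
    def __init__(self, beam_stream):
        self._data = beam_stream.read_all(False)  # full copy per bundle
        self._pos = 0

    def read(self, size):
        chunk = self._data[self._pos:self._pos + size]
        self._pos += size
        return chunk

class LazyBeamInputStream:
    """New shape: forward each read to the Beam stream, no copy."""
    def __init__(self, beam_stream):
        self._beam_stream = beam_stream

    def read(self, size):
        return self._beam_stream.read(size)
```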
29 changes: 1 addition & 28 deletions flink-python/pyflink/fn_execution/beam/beam_coders.py
@@ -15,10 +15,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 ################################################################################
-import pickle
-from typing import Any
-
-from apache_beam.coders import Coder, coder_impl
+from apache_beam.coders import Coder
 from apache_beam.coders.coders import FastCoder, LengthPrefixCoder
 from apache_beam.portability import common_urns
 from apache_beam.typehints import typehints
@@ -93,27 +90,3 @@ def __ne__(self, other):

     def __hash__(self):
         return hash(self._internal_coder)
-
-
-class DataViewFilterCoder(FastCoder):
-
-    def to_type_hint(self):
-        return Any
-
-    def __init__(self, udf_data_view_specs):
-        self._udf_data_view_specs = udf_data_view_specs
-
-    def filter_data_views(self, row):
-        i = 0
-        for specs in self._udf_data_view_specs:
-            for spec in specs:
-                row[i][spec.field_index] = None
-            i += 1
-        return row
-
-    def _create_impl(self):
-        filter_data_views = self.filter_data_views
-        dumps = pickle.dumps
-        HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL
-        return coder_impl.CallbackCoderImpl(
-            lambda x: dumps(filter_data_views(x), HIGHEST_PROTOCOL), pickle.loads)
@@ -23,7 +23,7 @@ from libc.stdint cimport *
 from apache_beam.utils.windowed_value cimport WindowedValue

 from pyflink.fn_execution.coder_impl_fast cimport LengthPrefixBaseCoderImpl
-from pyflink.fn_execution.beam.beam_stream cimport BeamInputStream, BeamOutputStream
+from pyflink.fn_execution.beam.beam_stream_fast cimport BeamInputStream, BeamOutputStream
 from pyflink.fn_execution.beam.beam_coder_impl_fast cimport InputStreamWrapper, BeamCoderImpl
 from pyflink.fn_execution.table.operations import BundleOperation