diff --git a/flink-python/pyflink/common/serializer.py b/flink-python/pyflink/common/serializer.py index 657b657174d7e..2533b8657ccac 100644 --- a/flink-python/pyflink/common/serializer.py +++ b/flink-python/pyflink/common/serializer.py @@ -61,6 +61,10 @@ def _get_coder(self): deserialize_func = self.deserialize class CoderAdapter(object): + def get_impl(self): + return CoderAdapterIml() + + class CoderAdapterIml(object): def encode_nested(self, element): bytes_io = BytesIO() diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py index 285067c0e0eb1..8b915e7ca86d3 100644 --- a/flink-python/pyflink/datastream/data_stream.py +++ b/flink-python/pyflink/datastream/data_stream.py @@ -932,8 +932,7 @@ def __init__(self, reduce_function: ReduceFunction): def open(self, runtime_context: RuntimeContext): self._reduce_value_state = runtime_context.get_state( - ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), - Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), output_type)) self._reduce_function.open(runtime_context) from pyflink.fn_execution.datastream.runtime_context import StreamingRuntimeContext self._in_batch_execution_mode = \ @@ -1137,7 +1136,7 @@ def get_execution_environment(self): return self._keyed_stream.get_execution_environment() def get_input_type(self): - return self._keyed_stream.get_type() + return _from_java_type(self._keyed_stream._original_data_type_info.get_java_type_info()) def trigger(self, trigger: Trigger): """ @@ -1201,7 +1200,7 @@ def _get_result_data_stream( self.get_execution_environment()) window_serializer = self._window_assigner.get_window_serializer() window_state_descriptor = ListStateDescriptor( - "window-contents", Types.PICKLED_BYTE_ARRAY()) + "window-contents", self.get_input_type()) window_operation_descriptor = WindowOperationDescriptor( self._window_assigner, self._window_trigger, diff --git a/flink-python/pyflink/datastream/state.py b/flink-python/pyflink/datastream/state.py index 76380fff97376..a576d410727f2 100644 --- a/flink-python/pyflink/datastream/state.py +++ b/flink-python/pyflink/datastream/state.py @@ -19,7 +19,7 @@ from typing import TypeVar, Generic, Iterable, List, Iterator, Dict, Tuple -from pyflink.common.typeinfo import TypeInformation, Types, PickledBytesTypeInfo +from pyflink.common.typeinfo import TypeInformation, Types __all__ = [ 'ValueStateDescriptor', @@ -316,10 +316,6 @@ def __init__(self, name: str, value_type_info: TypeInformation): :param name: The name of the state. :param value_type_info: the type information of the state. """ - if not isinstance(value_type_info, PickledBytesTypeInfo): - raise ValueError("The type information of the value could only be PickledBytesTypeInfo " - "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s." - % type(value_type_info)) super(ValueStateDescriptor, self).__init__(name, value_type_info) @@ -336,10 +332,6 @@ def __init__(self, name: str, elem_type_info: TypeInformation): :param name: The name of the state. :param elem_type_info: the type information of the state element. 
""" - if not isinstance(elem_type_info, PickledBytesTypeInfo): - raise ValueError("The type information of the element could only be " - "PickledBytesTypeInfo (created via Types.PICKLED_BYTE_ARRAY()) " - "currently, got %s" % type(elem_type_info)) super(ListStateDescriptor, self).__init__(name, Types.LIST(elem_type_info)) @@ -357,14 +349,6 @@ def __init__(self, name: str, key_type_info: TypeInformation, value_type_info: T :param key_type_info: The type information of the key. :param value_type_info: the type information of the value. """ - if not isinstance(key_type_info, PickledBytesTypeInfo): - raise ValueError("The type information of the key could only be PickledBytesTypeInfo " - "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s" - % type(key_type_info)) - if not isinstance(value_type_info, PickledBytesTypeInfo): - raise ValueError("The type information of the value could only be PickledBytesTypeInfo " - "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s" - % type(value_type_info)) super(MapStateDescriptor, self).__init__(name, Types.MAP(key_type_info, value_type_info)) @@ -392,10 +376,6 @@ def __init__(self, reduce_function = ReduceFunctionWrapper(reduce_function) # type: ignore else: raise TypeError("The input must be a ReduceFunction or a callable function!") - if not isinstance(type_info, PickledBytesTypeInfo): - raise ValueError("The type information of the state could only be PickledBytesTypeInfo " - "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s" - % type(type_info)) self._reduce_function = reduce_function def get_reduce_function(self): @@ -418,10 +398,6 @@ def __init__(self, from pyflink.datastream.functions import AggregateFunction if not isinstance(agg_function, AggregateFunction): raise TypeError("The input must be a pyflink.datastream.functions.AggregateFunction!") - if not isinstance(state_type_info, PickledBytesTypeInfo): - raise ValueError("The type information of the state could only be PickledBytesTypeInfo " - "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s" - % type(state_type_info)) self._agg_function = agg_function def get_agg_function(self): diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py index 9b40bb6b93ade..21eba87859a6a 100644 --- a/flink-python/pyflink/datastream/tests/test_data_stream.py +++ b/flink-python/pyflink/datastream/tests/test_data_stream.py @@ -214,9 +214,9 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.pre1 = runtime_context.get_state( - ValueStateDescriptor("pre1", Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("pre1", Types.STRING())) self.pre2 = runtime_context.get_state( - ValueStateDescriptor("pre2", Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("pre2", Types.STRING())) def map1(self, value): if value[0] == 'b': @@ -409,7 +409,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.state = runtime_context.get_state( - ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("test_state", Types.INT())) def map(self, value): state_value = self.state.value() @@ -453,7 +453,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.state = runtime_context.get_state( - ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("test_state", Types.INT())) def flat_map(self, value): state_value = self.state.value() @@ -497,7 +497,7 @@ def __init__(self): def open(self, runtime_context: 
RuntimeContext): self.state = runtime_context.get_state( - ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("test_state", Types.INT())) def filter(self, value): state_value = self.state.value() @@ -694,15 +694,11 @@ def __init__(self): self.map_state = None def open(self, runtime_context: RuntimeContext): - value_state_descriptor = ValueStateDescriptor('value_state', - Types.PICKLED_BYTE_ARRAY()) + value_state_descriptor = ValueStateDescriptor('value_state', Types.INT()) self.value_state = runtime_context.get_state(value_state_descriptor) - list_state_descriptor = ListStateDescriptor('list_state', - Types.PICKLED_BYTE_ARRAY()) + list_state_descriptor = ListStateDescriptor('list_state', Types.INT()) self.list_state = runtime_context.get_list_state(list_state_descriptor) - map_state_descriptor = MapStateDescriptor('map_state', - Types.PICKLED_BYTE_ARRAY(), - Types.PICKLED_BYTE_ARRAY()) + map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING()) self.map_state = runtime_context.get_map_state(map_state_descriptor) def process_element(self, value, ctx): @@ -770,7 +766,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.reducing_state = runtime_context.get_reducing_state( ReducingStateDescriptor( - 'reducing_state', lambda i, i2: i + i2, Types.PICKLED_BYTE_ARRAY())) + 'reducing_state', lambda i, i2: i + i2, Types.INT())) def process_element(self, value, ctx): self.reducing_state.add(value[0]) @@ -814,7 +810,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.aggregating_state = runtime_context.get_aggregating_state( AggregatingStateDescriptor( - 'aggregating_state', MyAggregateFunction(), Types.PICKLED_BYTE_ARRAY())) + 'aggregating_state', MyAggregateFunction(), Types.INT())) def process_element(self, value, ctx): self.aggregating_state.add(value[0]) @@ -1356,7 +1352,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.map_state = runtime_context.get_map_state( - MapStateDescriptor("map", Types.PICKLED_BYTE_ARRAY(), Types.PICKLED_BYTE_ARRAY())) + MapStateDescriptor("map", Types.STRING(), Types.BOOLEAN())) def flat_map1(self, value): yield str(value[0] + 1) @@ -1376,8 +1372,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.timer_registered = False - self.count_state = runtime_context.get_state(ValueStateDescriptor( - "count", Types.PICKLED_BYTE_ARRAY())) + self.count_state = runtime_context.get_state(ValueStateDescriptor("count", Types.INT())) def process_element1(self, value, ctx: 'KeyedCoProcessFunction.Context'): if not self.timer_registered: @@ -1411,7 +1406,7 @@ def __init__(self): def open(self, runtime_context: RuntimeContext): self.state = runtime_context.get_state( - ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY())) + ValueStateDescriptor("test_state", Types.INT())) def reduce(self, value1, value2): state_value = self.state.value() @@ -1439,7 +1434,7 @@ class SimpleCountWindowTrigger(Trigger[tuple, CountWindow]): def __init__(self): self._window_size = 3 self._count_state_descriptor = ReducingStateDescriptor( - "trigger_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY()) + "trigger_counter", lambda a, b: a + b, Types.BIG_INT()) def on_element(self, element: tuple, @@ -1479,7 +1474,7 @@ def __init__(self): self._window_id = 0 self._window_size = 3 self._counter_state_descriptor = ReducingStateDescriptor( - "assigner_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY()) + "assigner_counter", 
lambda a, b: a + b, Types.BIG_INT()) def assign_windows(self, element: tuple, diff --git a/flink-python/pyflink/datastream/window.py b/flink-python/pyflink/datastream/window.py index 058311bf6e758..53266e0051f8e 100644 --- a/flink-python/pyflink/datastream/window.py +++ b/flink-python/pyflink/datastream/window.py @@ -153,22 +153,19 @@ def __init__(self): def serialize(self, element: TimeWindow, stream: BytesIO) -> None: if self._underlying_coder is None: - self._underlying_coder = self._get_coder() + self._underlying_coder = self._get_coder().get_impl() bytes_data = self._underlying_coder.encode(element) stream.write(bytes_data) def deserialize(self, stream: BytesIO) -> TimeWindow: if self._underlying_coder is None: - self._underlying_coder = self._get_coder() + self._underlying_coder = self._get_coder().get_impl() bytes_data = stream.read(16) return self._underlying_coder.decode(bytes_data) def _get_coder(self): - try: - from pyflink.fn_execution import coder_impl_fast as coder_impl - except: - from pyflink.fn_execution import coder_impl_slow as coder_impl - return coder_impl.TimeWindowCoderImpl() + from pyflink.fn_execution import coders + return coders.TimeWindowCoder() class CountWindowSerializer(TypeSerializer[CountWindow]): @@ -178,22 +175,19 @@ def __init__(self): def serialize(self, element: CountWindow, stream: BytesIO) -> None: if self._underlying_coder is None: - self._underlying_coder = self._get_coder() + self._underlying_coder = self._get_coder().get_impl() bytes_data = self._underlying_coder.encode(element) stream.write(bytes_data) def deserialize(self, stream: BytesIO) -> CountWindow: if self._underlying_coder is None: - self._underlying_coder = self._get_coder() + self._underlying_coder = self._get_coder().get_impl() bytes_data = stream.read(8) return self._underlying_coder.decode(bytes_data) def _get_coder(self): - try: - from pyflink.fn_execution import coder_impl_fast as coder_impl - except: - from pyflink.fn_execution import coder_impl_slow as coder_impl - return coder_impl.CountWindowCoderImpl() + from pyflink.fn_execution import coders + return coders.CountWindowCoder() T = TypeVar('T') diff --git a/flink-python/pyflink/fn_execution/beam/beam_coder_impl_fast.pyx b/flink-python/pyflink/fn_execution/beam/beam_coder_impl_fast.pyx index 551ae0fcab102..da8da4cf9143a 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_coder_impl_fast.pyx +++ b/flink-python/pyflink/fn_execution/beam/beam_coder_impl_fast.pyx @@ -25,7 +25,7 @@ from apache_beam.coders.coder_impl cimport InputStream as BInputStream from apache_beam.coders.coder_impl cimport OutputStream as BOutputStream from apache_beam.coders.coder_impl cimport StreamCoderImpl -from pyflink.fn_execution.beam.beam_stream cimport BeamInputStream +from pyflink.fn_execution.beam.beam_stream_fast cimport BeamInputStream from pyflink.fn_execution.stream_fast cimport InputStream cdef class PassThroughLengthPrefixCoderImpl(StreamCoderImpl): @@ -59,9 +59,11 @@ cdef class PassThroughPrefixCoderImpl(StreamCoderImpl): # create InputStream data_input_stream = InputStream() data_input_stream._input_data = in_stream.allc - in_stream.pos = size + data_input_stream._input_pos = in_stream.pos - return self._value_coder.decode_from_stream(data_input_stream, size) + result = self._value_coder.decode_from_stream(data_input_stream, size) + in_stream.pos = data_input_stream._input_pos + return result cdef void _write_data_output_stream(self, BOutputStream out_stream): cdef OutputStream data_out_stream diff --git 
a/flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py b/flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py index 8e5a699fc3cbe..cf67708b90e59 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py +++ b/flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py @@ -19,7 +19,8 @@ from apache_beam.coders.coder_impl import StreamCoderImpl, create_InputStream, create_OutputStream -from pyflink.fn_execution.stream_slow import OutputStream, InputStream +from pyflink.fn_execution.stream_slow import OutputStream +from pyflink.fn_execution.beam.beam_stream_slow import BeamInputStream class PassThroughLengthPrefixCoderImpl(StreamCoderImpl): @@ -50,7 +51,7 @@ def encode_to_stream(self, value, out_stream: create_OutputStream, nested): self._data_output_stream.clear() def decode_from_stream(self, in_stream: create_InputStream, nested): - data_input_stream = InputStream(in_stream.read_all(False)) + data_input_stream = BeamInputStream(in_stream) return self._value_coder.decode_from_stream(data_input_stream) def __repr__(self): diff --git a/flink-python/pyflink/fn_execution/beam/beam_coders.py b/flink-python/pyflink/fn_execution/beam/beam_coders.py index 6d7962ec86f6c..53009623d2f61 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_coders.py +++ b/flink-python/pyflink/fn_execution/beam/beam_coders.py @@ -15,10 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ -import pickle -from typing import Any - -from apache_beam.coders import Coder, coder_impl +from apache_beam.coders import Coder from apache_beam.coders.coders import FastCoder, LengthPrefixCoder from apache_beam.portability import common_urns from apache_beam.typehints import typehints @@ -93,27 +90,3 @@ def __ne__(self, other): def __hash__(self): return hash(self._internal_coder) - - -class DataViewFilterCoder(FastCoder): - - def to_type_hint(self): - return Any - - def __init__(self, udf_data_view_specs): - self._udf_data_view_specs = udf_data_view_specs - - def filter_data_views(self, row): - i = 0 - for specs in self._udf_data_view_specs: - for spec in specs: - row[i][spec.field_index] = None - i += 1 - return row - - def _create_impl(self): - filter_data_views = self.filter_data_views - dumps = pickle.dumps - HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL - return coder_impl.CallbackCoderImpl( - lambda x: dumps(filter_data_views(x), HIGHEST_PROTOCOL), pickle.loads) diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx b/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx index 51b5b8ec0ba3c..4b363b4c0838f 100644 --- a/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx +++ b/flink-python/pyflink/fn_execution/beam/beam_operations_fast.pyx @@ -23,7 +23,7 @@ from libc.stdint cimport * from apache_beam.utils.windowed_value cimport WindowedValue from pyflink.fn_execution.coder_impl_fast cimport LengthPrefixBaseCoderImpl -from pyflink.fn_execution.beam.beam_stream cimport BeamInputStream, BeamOutputStream +from pyflink.fn_execution.beam.beam_stream_fast cimport BeamInputStream, BeamOutputStream from pyflink.fn_execution.beam.beam_coder_impl_fast cimport InputStreamWrapper, BeamCoderImpl from pyflink.fn_execution.table.operations import BundleOperation diff --git a/flink-python/pyflink/fn_execution/beam/beam_stream.pxd b/flink-python/pyflink/fn_execution/beam/beam_stream_fast.pxd similarity index 100% rename from 
flink-python/pyflink/fn_execution/beam/beam_stream.pxd rename to flink-python/pyflink/fn_execution/beam/beam_stream_fast.pxd diff --git a/flink-python/pyflink/fn_execution/beam/beam_stream.pyx b/flink-python/pyflink/fn_execution/beam/beam_stream_fast.pyx similarity index 100% rename from flink-python/pyflink/fn_execution/beam/beam_stream.pyx rename to flink-python/pyflink/fn_execution/beam/beam_stream_fast.pyx diff --git a/flink-python/pyflink/fn_execution/beam/beam_stream_slow.py b/flink-python/pyflink/fn_execution/beam/beam_stream_slow.py new file mode 100644 index 0000000000000..8dff78152a4b6 --- /dev/null +++ b/flink-python/pyflink/fn_execution/beam/beam_stream_slow.py @@ -0,0 +1,35 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +from apache_beam.coders.coder_impl import create_InputStream + +from pyflink.fn_execution.stream_slow import InputStream + + +class BeamInputStream(InputStream): + def __init__(self, input_stream: create_InputStream): + super(BeamInputStream, self).__init__([]) + self._input_stream = input_stream + + def read(self, size): + return self._input_stream.read(size) + + def read_byte(self): + return self._input_stream.read_byte() + + def size(self): + return self._input_stream.size() diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pxd b/flink-python/pyflink/fn_execution/coder_impl_fast.pxd index 3cb3cdb6cff03..ed4af67fc4201 100644 --- a/flink-python/pyflink/fn_execution/coder_impl_fast.pxd +++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pxd @@ -129,7 +129,10 @@ cdef class TimestampCoderImpl(FieldCoderImpl): cdef class LocalZonedTimestampCoderImpl(TimestampCoderImpl): cdef object _timezone -cdef class PickledBytesCoderImpl(FieldCoderImpl): +cdef class CloudPickleCoderImpl(FieldCoderImpl): + pass + +cdef class PickleCoderImpl(FieldCoderImpl): pass cdef class GenericArrayCoderImpl(FieldCoderImpl): @@ -151,3 +154,7 @@ cdef class TimeWindowCoderImpl(FieldCoderImpl): cdef class CountWindowCoderImpl(FieldCoderImpl): pass + +cdef class DataViewFilterCoderImpl(FieldCoderImpl): + cdef object _udf_data_view_specs + cdef PickleCoderImpl _pickle_coder diff --git a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx index 6ac4e6c17df67..0dafb13299a82 100644 --- a/flink-python/pyflink/fn_execution/coder_impl_fast.pyx +++ b/flink-python/pyflink/fn_execution/coder_impl_fast.pyx @@ -24,6 +24,7 @@ from libc.stdlib cimport free, malloc import datetime import decimal +import pickle from typing import List, Union from cloudpickle import cloudpickle @@ -631,9 +632,9 @@ cdef class 
LocalZonedTimestampCoderImpl(TimestampCoderImpl): cpdef decode_from_stream(self, InputStream in_stream, size_t size): return self._timezone.localize(self._decode_timestamp_data_from_stream(in_stream)) -cdef class PickledBytesCoderImpl(FieldCoderImpl): +cdef class CloudPickleCoderImpl(FieldCoderImpl): """ - A coder for all kinds of python object. + A coder used with cloudpickle for all kinds of Python objects. """ cpdef encode_to_stream(self, value, OutputStream out_stream): @@ -646,6 +647,22 @@ cdef class PickledBytesCoderImpl(FieldCoderImpl): pickled_bytes = in_stream.read_bytes() return cloudpickle.loads(pickled_bytes) + +cdef class PickleCoderImpl(FieldCoderImpl): + """ + A coder used with pickle for all kinds of Python objects. + """ + + cpdef encode_to_stream(self, value, OutputStream out_stream): + cdef bytes pickled_bytes + pickled_bytes = pickle.dumps(value) + out_stream.write_bytes(pickled_bytes, len(pickled_bytes)) + + cpdef decode_from_stream(self, InputStream in_stream, size_t size): + cdef bytes pickled_bytes + pickled_bytes = in_stream.read_bytes() + return pickle.loads(pickled_bytes) + cdef class GenericArrayCoderImpl(FieldCoderImpl): """ A coder for basic array value (the element of array could be null). @@ -775,3 +792,27 @@ cdef class CountWindowCoderImpl(FieldCoderImpl): cpdef decode_from_stream(self, InputStream in_stream, size_t size): return CountWindow(in_stream.read_int64()) + +cdef class DataViewFilterCoderImpl(FieldCoderImpl): + """ + A coder for data view filter. + """ + def __init__(self, udf_data_view_specs): + self._udf_data_view_specs = udf_data_view_specs + self._pickle_coder = PickleCoderImpl() + + cpdef encode_to_stream(self, value, OutputStream out_stream): + self._pickle_coder.encode_to_stream(self._filter_data_views(value), out_stream) + + cpdef decode_from_stream(self, InputStream in_stream, size_t size): + return self._pickle_coder.decode_from_stream(in_stream, size) + + def _filter_data_views(self, row): + i = 0 + for specs in self._udf_data_view_specs: + for spec in specs: + row[i][spec.field_index] = None + i += 1 + return row + + diff --git a/flink-python/pyflink/fn_execution/coder_impl_slow.py b/flink-python/pyflink/fn_execution/coder_impl_slow.py index 3f8d4e81682fb..89e871579549a 100644 --- a/flink-python/pyflink/fn_execution/coder_impl_slow.py +++ b/flink-python/pyflink/fn_execution/coder_impl_slow.py @@ -17,6 +17,7 @@ ################################################################################ import datetime import decimal +import pickle from abc import ABC, abstractmethod from typing import List @@ -585,9 +586,9 @@ def internal_to_timestamp(self, milliseconds, nanoseconds): milliseconds, nanoseconds)) -class PickledBytesCoderImpl(FieldCoderImpl): +class CloudPickleCoderImpl(FieldCoderImpl): """ - A coder for all kinds of python object. + A coder used with cloudpickle for all kinds of Python objects. """ def __init__(self): @@ -606,7 +607,31 @@ def _decode_one_value_from_stream(self, in_stream: InputStream): return value def __repr__(self) -> str: - return 'PickledBytesCoderImpl[%s]' % str(self.field_coder) + return 'CloudPickleCoderImpl[%s]' % str(self.field_coder) + + +class PickleCoderImpl(FieldCoderImpl): + """ + A coder used with pickle for all kinds of Python objects. 
+ """ + + def __init__(self): + self.field_coder = BinaryCoderImpl() + + def encode_to_stream(self, value, out_stream): + coded_data = pickle.dumps(value) + self.field_coder.encode_to_stream(coded_data, out_stream) + + def decode_from_stream(self, in_stream, length=0): + return self._decode_one_value_from_stream(in_stream) + + def _decode_one_value_from_stream(self, in_stream: InputStream): + real_data = self.field_coder.decode_from_stream(in_stream) + value = pickle.loads(real_data) + return value + + def __repr__(self) -> str: + return 'PickleCoderImpl[%s]' % str(self.field_coder) class TupleCoderImpl(FieldCoderImpl): @@ -743,3 +768,26 @@ def encode_to_stream(self, value, out_stream): def decode_from_stream(self, in_stream, length=0): return CountWindow(in_stream.read_int64()) + + +class DataViewFilterCoderImpl(FieldCoderImpl): + """ + A coder for data view filter. + """ + def __init__(self, udf_data_view_specs): + self._udf_data_view_specs = udf_data_view_specs + self._pickle_coder = PickleCoderImpl() + + def encode_to_stream(self, value, out_stream): + self._pickle_coder.encode_to_stream(self._filter_data_views(value), out_stream) + + def decode_from_stream(self, in_stream, length=0): + return self._pickle_coder.decode_from_stream(in_stream) + + def _filter_data_views(self, row): + i = 0 + for specs in self._udf_data_view_specs: + for spec in specs: + row[i][spec.field_index] = None + i += 1 + return row diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py index cd120caf23580..488741c8982e8 100644 --- a/flink-python/pyflink/fn_execution/coders.py +++ b/flink-python/pyflink/fn_execution/coders.py @@ -22,6 +22,9 @@ import pyarrow as pa import pytz +from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, DateTypeInfo, \ + TimeTypeInfo, TimestampTypeInfo, PrimitiveArrayTypeInfo, BasicArrayTypeInfo, TupleTypeInfo, \ + MapTypeInfo, ListTypeInfo, RowTypeInfo, PickledBytesTypeInfo, ObjectArrayTypeInfo from pyflink.fn_execution import flink_fn_execution_pb2 from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \ FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \ @@ -36,7 +39,8 @@ 'SmallIntCoder', 'IntCoder', 'FloatCoder', 'DoubleCoder', 'BinaryCoder', 'CharCoder', 'DateCoder', 'TimeCoder', 'TimestampCoder', 'LocalZonedTimestampCoder', 'GenericArrayCoder', 'PrimitiveArrayCoder', 'MapCoder', 'DecimalCoder', - 'BigDecimalCoder', 'TupleCoder', 'TimeWindowCoder', 'CountWindowCoder'] + 'BigDecimalCoder', 'TupleCoder', 'TimeWindowCoder', 'CountWindowCoder', + 'PickleCoder', 'CloudPickleCoder', 'DataViewFilterCoder'] # LengthPrefixBaseCoder will be used in Operations and other coders will be the field coder @@ -152,6 +156,9 @@ class FieldCoder(ABC): def get_impl(self) -> coder_impl.FieldCoderImpl: pass + def __eq__(self, other): + return type(self) == type(other) + class IterableCoder(LengthPrefixBaseCoder): """ @@ -430,6 +437,11 @@ def __init__(self, precision, scale): def get_impl(self): return coder_impl.DecimalCoderImpl(self.precision, self.scale) + def __eq__(self, other: 'DecimalCoder'): + return (self.__class__ == other.__class__ and + self.precision == other.precision and + self.scale == other.scale) + class BigDecimalCoder(FieldCoder): """ @@ -487,6 +499,9 @@ def __init__(self, precision): def get_impl(self): return coder_impl.TimestampCoderImpl(self.precision) + def __eq__(self, other: 'TimestampCoder'): + return self.__class__ == other.__class__ 
and self.precision == other.precision + class LocalZonedTimestampCoder(FieldCoder): """ @@ -500,11 +515,28 @@ def __init__(self, precision, timezone): def get_impl(self): return coder_impl.LocalZonedTimestampCoderImpl(self.precision, self.timezone) + def __eq__(self, other: 'LocalZonedTimestampCoder'): + return (self.__class__ == other.__class__ and + self.precision == other.precision and + self.timezone == other.timezone) + -class PickledBytesCoder(FieldCoder): +class CloudPickleCoder(FieldCoder): + """ + Coder used with cloudpickle to encode Python objects. + """ def get_impl(self): - return coder_impl.PickledBytesCoderImpl() + return coder_impl.CloudPickleCoderImpl() + + +class PickleCoder(FieldCoder): + """ + Coder used with pickle to encode Python objects. + """ + + def get_impl(self): + return coder_impl.PickleCoderImpl() class TupleCoder(FieldCoder): @@ -521,6 +553,11 @@ def get_impl(self): def __repr__(self): return 'TupleCoder[%s]' % ', '.join(str(c) for c in self._field_coders) + def __eq__(self, other: 'TupleCoder'): + return (self.__class__ == other.__class__ and + all(self._field_coders[i] == other._field_coders[i] + for i in range(len(self._field_coders)))) + class TimeWindowCoder(FieldCoder): """ @@ -540,6 +577,18 @@ def get_impl(self): return coder_impl.CountWindowCoderImpl() + +class DataViewFilterCoder(FieldCoder): + """ + Coder for data view filter. + """ + + def __init__(self, udf_data_view_specs): + self._udf_data_view_specs = udf_data_view_specs + + def get_impl(self): + return coder_impl.DataViewFilterCoderImpl(self._udf_data_view_specs) + + type_name = flink_fn_execution_pb2.Schema _type_name_mappings = { type_name.TINYINT: TinyIntCoder(), @@ -606,7 +655,7 @@ def from_proto(field_type): type_info_name.SQL_DATE: DateCoder(), type_info_name.SQL_TIME: TimeCoder(), type_info_name.SQL_TIMESTAMP: TimestampCoder(3), - type_info_name.PICKLED_BYTES: PickledBytesCoder() + type_info_name.PICKLED_BYTES: CloudPickleCoder() } @@ -635,3 +684,57 @@ def from_type_info_proto(type_info): from_type_info_proto(type_info.map_type_info.value_type)) else: raise ValueError("Unsupported type_info %s." 
% type_info) + + +_basic_type_info_mappings = { + BasicType.BYTE: TinyIntCoder(), + BasicType.BOOLEAN: BooleanCoder(), + BasicType.SHORT: SmallIntCoder(), + BasicType.INT: IntCoder(), + BasicType.LONG: BigIntCoder(), + BasicType.BIG_INT: BigIntCoder(), + BasicType.FLOAT: FloatCoder(), + BasicType.DOUBLE: DoubleCoder(), + BasicType.STRING: CharCoder(), + BasicType.CHAR: CharCoder(), + BasicType.BIG_DEC: BigDecimalCoder(), +} + + +def from_type_info(type_info: TypeInformation) -> FieldCoder: + """ + Mappings from type_info to Coder + """ + + if isinstance(type_info, PickledBytesTypeInfo): + return PickleCoder() + elif isinstance(type_info, BasicTypeInfo): + return _basic_type_info_mappings[type_info._basic_type] + elif isinstance(type_info, DateTypeInfo): + return DateCoder() + elif isinstance(type_info, TimeTypeInfo): + return TimeCoder() + elif isinstance(type_info, TimestampTypeInfo): + return TimestampCoder(3) + elif isinstance(type_info, PrimitiveArrayTypeInfo): + element_type = type_info._element_type + if isinstance(element_type, BasicTypeInfo) and element_type._basic_type == BasicType.BYTE: + return BinaryCoder() + else: + return PrimitiveArrayCoder(from_type_info(element_type)) + elif isinstance(type_info, (BasicArrayTypeInfo, ObjectArrayTypeInfo)): + return GenericArrayCoder(from_type_info(type_info._element_type)) + elif isinstance(type_info, ListTypeInfo): + return GenericArrayCoder(from_type_info(type_info.elem_type)) + elif isinstance(type_info, MapTypeInfo): + return MapCoder( + from_type_info(type_info._key_type_info), from_type_info(type_info._value_type_info)) + elif isinstance(type_info, TupleTypeInfo): + return TupleCoder([from_type_info(field_type) + for field_type in type_info.get_field_types()]) + elif isinstance(type_info, RowTypeInfo): + return RowCoder( + [from_type_info(f) for f in type_info.get_field_types()], + [f for f in type_info.get_field_names()]) + else: + raise ValueError("Unsupported type_info %s." 
% type_info) diff --git a/flink-python/pyflink/fn_execution/datastream/operations.py b/flink-python/pyflink/fn_execution/datastream/operations.py index ae79f20296a16..bb28e6d9e8548 100644 --- a/flink-python/pyflink/fn_execution/datastream/operations.py +++ b/flink-python/pyflink/fn_execution/datastream/operations.py @@ -289,7 +289,9 @@ def process_element(normal_data, timestamp: int): window_state_descriptor = window_operation_descriptor.window_state_descriptor internal_window_function = window_operation_descriptor.internal_window_function window_serializer = window_operation_descriptor.window_serializer - keyed_state_backend._namespace_coder_impl = window_serializer._get_coder() + window_coder = window_serializer._get_coder() + keyed_state_backend.namespace_coder = window_coder + keyed_state_backend._namespace_coder_impl = window_coder.get_impl() window_operator = WindowOperator( window_assigner, keyed_state_backend, diff --git a/flink-python/pyflink/fn_execution/datastream/runtime_context.py b/flink-python/pyflink/fn_execution/datastream/runtime_context.py index aaa9d53c84ed8..6ff71181be53b 100644 --- a/flink-python/pyflink/fn_execution/datastream/runtime_context.py +++ b/flink-python/pyflink/fn_execution/datastream/runtime_context.py @@ -17,12 +17,11 @@ ################################################################################ from typing import Dict, Union -from apache_beam.coders import PickleCoder - from pyflink.datastream import RuntimeContext from pyflink.datastream.state import ValueStateDescriptor, ValueState, ListStateDescriptor, \ ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, ReducingState, \ AggregatingStateDescriptor, AggregatingState +from pyflink.fn_execution.coders import from_type_info, MapCoder, GenericArrayCoder from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend from pyflink.metrics import MetricGroup @@ -105,27 +104,35 @@ def get_metrics_group(self) -> MetricGroup: def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState: if self._keyed_state_backend: - return self._keyed_state_backend.get_value_state(state_descriptor.name, PickleCoder()) + return self._keyed_state_backend.get_value_state( + state_descriptor.name, from_type_info(state_descriptor.type_info)) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") def get_list_state(self, state_descriptor: ListStateDescriptor) -> ListState: if self._keyed_state_backend: - return self._keyed_state_backend.get_list_state(state_descriptor.name, PickleCoder()) + array_coder = from_type_info(state_descriptor.type_info) # type: GenericArrayCoder + return self._keyed_state_backend.get_list_state( + state_descriptor.name, array_coder._elem_coder) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") def get_map_state(self, state_descriptor: MapStateDescriptor) -> MapState: if self._keyed_state_backend: - return self._keyed_state_backend.get_map_state(state_descriptor.name, PickleCoder(), - PickleCoder()) + map_coder = from_type_info(state_descriptor.type_info) # type: MapCoder + key_coder = map_coder._key_coder + value_coder = map_coder._value_coder + return self._keyed_state_backend.get_map_state( + state_descriptor.name, key_coder, value_coder) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") def get_reducing_state(self, state_descriptor: ReducingStateDescriptor) -> ReducingState: if self._keyed_state_backend: return 
self._keyed_state_backend.get_reducing_state( - state_descriptor.get_name(), PickleCoder(), state_descriptor.get_reduce_function()) + state_descriptor.get_name(), + from_type_info(state_descriptor.type_info), + state_descriptor.get_reduce_function()) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") @@ -133,7 +140,9 @@ def get_aggregating_state( self, state_descriptor: AggregatingStateDescriptor) -> AggregatingState: if self._keyed_state_backend: return self._keyed_state_backend.get_aggregating_state( - state_descriptor.get_name(), PickleCoder(), state_descriptor.get_agg_function()) + state_descriptor.get_name(), + from_type_info(state_descriptor.type_info), + state_descriptor.get_agg_function()) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") diff --git a/flink-python/pyflink/fn_execution/datastream/window/merging_window_set.py b/flink-python/pyflink/fn_execution/datastream/window/merging_window_set.py index 0bc33f8a5dc60..639a87e53db59 100644 --- a/flink-python/pyflink/fn_execution/datastream/window/merging_window_set.py +++ b/flink-python/pyflink/fn_execution/datastream/window/merging_window_set.py @@ -16,10 +16,10 @@ # limitations under the License. ################################################################################ from abc import ABC, abstractmethod -from typing import TypeVar, Generic, Tuple, Collection, Iterable +from typing import TypeVar, Generic, Collection, Iterable from pyflink.datastream import MergingWindowAssigner -from pyflink.datastream.state import ListState +from pyflink.datastream.state import MapState W = TypeVar("W") @@ -45,11 +45,11 @@ def merge(self, merged_state_windows: Collection[W]): pass - def __init__(self, assigner: MergingWindowAssigner, state: ListState[Tuple[W, W]]): + def __init__(self, assigner: MergingWindowAssigner, state: MapState[W, W]): self._window_assigner = assigner self._mapping = dict() - for window_for_user, window_in_state in state: + for window_for_user, window_in_state in state.items(): self._mapping[window_for_user] = window_in_state self._state = state @@ -59,7 +59,7 @@ def persist(self) -> None: if self._mapping != self._initial_mapping: self._state.clear() for window_for_user, window_in_state in self._mapping.items(): - self._state.add((window_for_user, window_in_state)) + self._state.put(window_for_user, window_in_state) def get_state_window(self, window: W) -> W: if window in self._mapping: diff --git a/flink-python/pyflink/fn_execution/datastream/window/window_operator.py b/flink-python/pyflink/fn_execution/datastream/window/window_operator.py index 3a9bb616fe62b..d01fc0f2285d5 100644 --- a/flink-python/pyflink/fn_execution/datastream/window/window_operator.py +++ b/flink-python/pyflink/fn_execution/datastream/window/window_operator.py @@ -18,7 +18,6 @@ import typing from typing import TypeVar, Iterable, Collection -from pyflink.common.typeinfo import Types from pyflink.datastream import WindowAssigner, Trigger, MergingWindowAssigner, TriggerResult from pyflink.datastream.functions import KeyedStateStore, RuntimeContext, InternalWindowFunction from pyflink.datastream.state import StateDescriptor, ListStateDescriptor, \ @@ -324,11 +323,12 @@ def open(self, runtime_context: RuntimeContext, internal_timer_service: Internal # TODO: the type info is just a placeholder currently. 
# it should be the real type serializer after supporting the user-defined state type # serializer - merging_sets_state_descriptor = ListStateDescriptor( - "merging-window-set", Types.PICKLED_BYTE_ARRAY()) + # merging_sets_state_descriptor = ListStateDescriptor( + # "merging-window-set", Types.PICKLED_BYTE_ARRAY()) - self.merging_sets_state = get_or_create_keyed_state( - runtime_context, merging_sets_state_descriptor) + window_coder = self.keyed_state_backend.namespace_coder + self.merging_sets_state = self.keyed_state_backend.get_map_state( + "merging-window-set", window_coder, window_coder) self.merge_function = WindowMergeFunction(self) diff --git a/flink-python/pyflink/fn_execution/state_impl.py b/flink-python/pyflink/fn_execution/state_impl.py index 590c1d7e60006..2d2785114bf75 100644 --- a/flink-python/pyflink/fn_execution/state_impl.py +++ b/flink-python/pyflink/fn_execution/state_impl.py @@ -1001,6 +1001,8 @@ def _get_internal_bag_state(self, name, namespace, element_coder): # at once. The internal state cache is only updated when the current key changes. # The reason is that the state cache size may be smaller that the count of activated # state (i.e. the state with current key). + if isinstance(element_coder, FieldCoder): + element_coder = FlinkCoder(element_coder) state_spec = userstate.BagStateSpec(name, element_coder) internal_state = self._create_bag_state(state_spec, encoded_namespace) return internal_state diff --git a/flink-python/pyflink/fn_execution/stream_slow.py b/flink-python/pyflink/fn_execution/stream_slow.py index 7d6a942e3ad22..143db4a95e38e 100644 --- a/flink-python/pyflink/fn_execution/stream_slow.py +++ b/flink-python/pyflink/fn_execution/stream_slow.py @@ -61,8 +61,7 @@ def read_var_int64(self): shift = 0 result = 0 while True: - self.pos += 1 - byte = self.data[self.pos - 1] + byte = self.read_byte() if byte < 0: raise RuntimeError('VarLong not terminated.') diff --git a/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx b/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx index 13efcda60a9c1..265552b276aea 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx +++ b/flink-python/pyflink/fn_execution/table/aggregate_fast.pyx @@ -22,9 +22,8 @@ from libc.stdlib cimport free, malloc from typing import List, Dict -from apache_beam.coders import PickleCoder, Coder - from pyflink.common import Row +from pyflink.fn_execution.coders import PickleCoder from pyflink.fn_execution.table.state_data_view import DataViewSpec, ListViewSpec, MapViewSpec, \ PerKeyStateDataViewStore from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend @@ -197,13 +196,13 @@ cdef class SimpleAggsHandleFunctionBase(AggsHandleFunctionBase): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_list_view( data_view_spec.state_id, - PickleCoder()) + data_view_spec.element_coder) elif isinstance(data_view_spec, MapViewSpec): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_map_view( data_view_spec.state_id, - PickleCoder(), - PickleCoder()) + data_view_spec.key_coder, + data_view_spec.value_coder) self._udf_data_views.append(data_views) for key in self._distinct_view_descriptors.keys(): self._distinct_data_views[key] = state_data_view_store.get_state_map_view( @@ -437,7 +436,7 @@ cdef class GroupAggFunctionBase: aggs_handle: AggsHandleFunctionBase, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, generate_update_before: bool, 
state_cleaning_enabled: bool, index_of_count_star: int): @@ -481,7 +480,7 @@ cdef class GroupAggFunction(GroupAggFunctionBase): aggs_handle, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, generate_update_before: bool, state_cleaning_enabled: bool, index_of_count_star: int): @@ -588,7 +587,7 @@ cdef class GroupTableAggFunction(GroupAggFunctionBase): aggs_handle, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, generate_update_before: bool, state_cleaning_enabled: bool, index_of_count_star: int): diff --git a/flink-python/pyflink/fn_execution/table/aggregate_slow.py b/flink-python/pyflink/fn_execution/table/aggregate_slow.py index 4ae7bfe5352bb..d847d53a02be4 100644 --- a/flink-python/pyflink/fn_execution/table/aggregate_slow.py +++ b/flink-python/pyflink/fn_execution/table/aggregate_slow.py @@ -18,9 +18,8 @@ from abc import ABC, abstractmethod from typing import List, Dict, Iterable -from apache_beam.coders import PickleCoder, Coder - from pyflink.common import Row, RowKind +from pyflink.fn_execution.coders import PickleCoder from pyflink.fn_execution.table.state_data_view import DataViewSpec, ListViewSpec, MapViewSpec, \ PerKeyStateDataViewStore from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend @@ -210,13 +209,13 @@ def open(self, state_data_view_store): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_list_view( data_view_spec.state_id, - PickleCoder()) + data_view_spec.element_coder) elif isinstance(data_view_spec, MapViewSpec): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_map_view( data_view_spec.state_id, - PickleCoder(), - PickleCoder()) + data_view_spec.key_coder, + data_view_spec.value_coder) self._udf_data_views.append(data_views) for key in self._distinct_view_descriptors.keys(): self._distinct_data_views[key] = state_data_view_store.get_state_map_view( @@ -411,7 +410,7 @@ def __init__(self, aggs_handle: AggsHandleFunctionBase, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, generate_update_before: bool, state_cleaning_enabled: bool, index_of_count_star: int): @@ -457,7 +456,7 @@ def __init__(self, aggs_handle: AggsHandleFunction, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, generate_update_before: bool, state_cleaning_enabled: bool, index_of_count_star: int): @@ -552,7 +551,7 @@ def __init__(self, aggs_handle: TableAggsHandleFunction, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, generate_update_before: bool, state_cleaning_enabled: bool, index_of_count_star: int): diff --git a/flink-python/pyflink/fn_execution/table/operations.py b/flink-python/pyflink/fn_execution/table/operations.py index 74adbe16f9c75..fc2a9a9fafcff 100644 --- a/flink-python/pyflink/fn_execution/table/operations.py +++ b/flink-python/pyflink/fn_execution/table/operations.py @@ -20,14 +20,12 @@ from itertools import chain from typing import Tuple -from apache_beam.coders import PickleCoder - +from pyflink.fn_execution.coders import DataViewFilterCoder, PickleCoder from pyflink.fn_execution.datastream.timerservice import InternalTimer from pyflink.fn_execution.datastream.operations import Operation from pyflink.fn_execution.datastream.timerservice_impl import InternalTimerImpl, 
TimerOperandType from pyflink.fn_execution import flink_fn_execution_pb2 from pyflink.fn_execution.table.state_data_view import extract_data_view_specs -from pyflink.fn_execution.beam.beam_coders import DataViewFilterCoder from pyflink.fn_execution.table.window_assigner import TumblingWindowAssigner, \ CountTumblingWindowAssigner, SlidingWindowAssigner, CountSlidingWindowAssigner, \ diff --git a/flink-python/pyflink/fn_execution/table/state_data_view.py b/flink-python/pyflink/fn_execution/table/state_data_view.py index 44e7fac57e828..0388e6337c04f 100644 --- a/flink-python/pyflink/fn_execution/table/state_data_view.py +++ b/flink-python/pyflink/fn_execution/table/state_data_view.py @@ -18,10 +18,8 @@ from abc import ABC, abstractmethod from typing import TypeVar, Generic, Union -from apache_beam.coders import PickleCoder - from pyflink.datastream.state import ListState, MapState -from pyflink.fn_execution.coders import from_proto +from pyflink.fn_execution.coders import from_proto, PickleCoder from pyflink.fn_execution.internal_state import InternalListState, InternalMapState from pyflink.fn_execution.utils.operation_utils import is_built_in_function, load_aggregate_function from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend diff --git a/flink-python/pyflink/fn_execution/table/window_aggregate_fast.pyx b/flink-python/pyflink/fn_execution/table/window_aggregate_fast.pyx index ebce8597bf7ea..978fd56f8ce1f 100644 --- a/flink-python/pyflink/fn_execution/table/window_aggregate_fast.pyx +++ b/flink-python/pyflink/fn_execution/table/window_aggregate_fast.pyx @@ -29,9 +29,9 @@ import sys from typing import List, Dict import pytz -from apache_beam.coders import PickleCoder, Coder from pyflink.fn_execution.datastream.timerservice_impl import InternalTimerServiceImpl +from pyflink.fn_execution.coders import PickleCoder from pyflink.fn_execution.table.state_data_view import DataViewSpec, ListViewSpec, MapViewSpec, \ PerWindowStateDataViewStore from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend @@ -169,13 +169,13 @@ cdef class SimpleNamespaceAggsHandleFunction(NamespaceAggsHandleFunction): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_list_view( data_view_spec.state_id, - PickleCoder()) + data_view_spec.element_coder) elif isinstance(data_view_spec, MapViewSpec): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_map_view( data_view_spec.state_id, - PickleCoder(), - PickleCoder()) + data_view_spec.key_coder, + data_view_spec.value_coder) self._udf_data_views.append(data_views) for key in self._distinct_view_descriptors.keys(): self._distinct_data_views[key] = state_data_view_store.get_state_map_view( @@ -330,7 +330,7 @@ cdef class GroupWindowAggFunctionBase: allowed_lateness: int, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, window_assigner: WindowAssigner[W], window_aggregator: NamespaceAggsHandleFunctionBase, trigger: Trigger[W], @@ -511,7 +511,7 @@ cdef class GroupWindowAggFunction(GroupWindowAggFunctionBase): allowed_lateness: int, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, window_assigner: WindowAssigner[W], window_aggregator: NamespaceAggsHandleFunction[W], trigger: Trigger[W], diff --git a/flink-python/pyflink/fn_execution/table/window_aggregate_slow.py b/flink-python/pyflink/fn_execution/table/window_aggregate_slow.py index b33c56ec4139c..1945d1a30dae0 100644 
--- a/flink-python/pyflink/fn_execution/table/window_aggregate_slow.py +++ b/flink-python/pyflink/fn_execution/table/window_aggregate_slow.py @@ -21,11 +21,11 @@ from typing import TypeVar, Generic, List, Dict import pytz -from apache_beam.coders import PickleCoder, Coder from pyflink.common import Row, RowKind from pyflink.fn_execution.datastream.timerservice import InternalTimer from pyflink.fn_execution.datastream.timerservice_impl import InternalTimerServiceImpl +from pyflink.fn_execution.coders import PickleCoder from pyflink.fn_execution.table.aggregate_slow import DistinctViewDescriptor, RowKeySelector from pyflink.fn_execution.table.state_data_view import DataViewSpec, ListViewSpec, MapViewSpec, \ PerWindowStateDataViewStore @@ -173,13 +173,13 @@ def open(self, state_data_view_store): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_list_view( data_view_spec.state_id, - PickleCoder()) + data_view_spec.element_coder) elif isinstance(data_view_spec, MapViewSpec): data_views[data_view_spec.field_index] = \ state_data_view_store.get_state_map_view( data_view_spec.state_id, - PickleCoder(), - PickleCoder()) + data_view_spec.key_coder, + data_view_spec.value_coder) self._udf_data_views.append(data_views) for key in self._distinct_view_descriptors.keys(): self._distinct_data_views[key] = state_data_view_store.get_state_map_view( @@ -290,7 +290,7 @@ def __init__(self, allowed_lateness: int, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, window_assigner: WindowAssigner[W], window_aggregator: NamespaceAggsHandleFunctionBase[W], trigger: Trigger[W], @@ -457,7 +457,7 @@ def __init__(self, allowed_lateness: int, key_selector: RowKeySelector, state_backend: RemoteKeyedStateBackend, - state_value_coder: Coder, + state_value_coder, window_assigner: WindowAssigner[W], window_aggregator: NamespaceAggsHandleFunction[W], trigger: Trigger[W], diff --git a/flink-python/pyflink/fn_execution/table/window_assigner.py b/flink-python/pyflink/fn_execution/table/window_assigner.py index a0a56c9310545..aebcc89d08225 100644 --- a/flink-python/pyflink/fn_execution/table/window_assigner.py +++ b/flink-python/pyflink/fn_execution/table/window_assigner.py @@ -144,8 +144,7 @@ def __init__(self, size: int): self._count = None # type: ValueState def open(self, ctx: Context[Any, CountWindow]): - value_state_descriptor = ValueStateDescriptor('tumble-count-assigner', - Types.PICKLED_BYTE_ARRAY()) + value_state_descriptor = ValueStateDescriptor('tumble-count-assigner', Types.LONG()) self._count = ctx.get_partitioned_state(value_state_descriptor) def assign_windows(self, element: List, timestamp: int) -> Iterable[CountWindow]: @@ -218,7 +217,7 @@ def __init__(self, size, slide): self._count = None # type: ValueState def open(self, ctx: Context[Any, CountWindow]): - count_descriptor = ValueStateDescriptor('slide-count-assigner', Types.PICKLED_BYTE_ARRAY()) + count_descriptor = ValueStateDescriptor('slide-count-assigner', Types.LONG()) self._count = ctx.get_partitioned_state(count_descriptor) def assign_windows(self, element: List, timestamp: int) -> Iterable[W]: diff --git a/flink-python/pyflink/fn_execution/table/window_context.py b/flink-python/pyflink/fn_execution/table/window_context.py index f40918a729461..35c9e323ddc5f 100644 --- a/flink-python/pyflink/fn_execution/table/window_context.py +++ b/flink-python/pyflink/fn_execution/table/window_context.py @@ -19,13 +19,14 @@ from abc import ABC, abstractmethod from typing 
import Generic, TypeVar, List, Iterable -from apache_beam.coders import Coder, PickleCoder +from apache_beam.coders import Coder from pyflink.datastream.state import StateDescriptor, State, ValueStateDescriptor, \ ListStateDescriptor, MapStateDescriptor from pyflink.datastream.window import TimeWindow, CountWindow from pyflink.fn_execution.datastream.timerservice import InternalTimerService from pyflink.fn_execution.datastream.timerservice_impl import InternalTimerServiceImpl +from pyflink.fn_execution.coders import from_type_info, MapCoder, GenericArrayCoder from pyflink.fn_execution.internal_state import InternalMergingState from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend @@ -228,12 +229,18 @@ def clear(self): def get_partitioned_state(self, state_descriptor: StateDescriptor) -> State: if isinstance(state_descriptor, ValueStateDescriptor): - state = self._state_backend.get_value_state(state_descriptor.name, PickleCoder()) + state = self._state_backend.get_value_state( + state_descriptor.name, from_type_info(state_descriptor.type_info)) elif isinstance(state_descriptor, ListStateDescriptor): - state = self._state_backend.get_list_state(state_descriptor.name, PickleCoder()) + array_coder = from_type_info(state_descriptor.type_info) # type: GenericArrayCoder + state = self._state_backend.get_list_state( + state_descriptor.name, array_coder._elem_coder) elif isinstance(state_descriptor, MapStateDescriptor): + map_coder = from_type_info(state_descriptor.type_info) # type: MapCoder + key_coder = map_coder._key_coder + value_coder = map_coder._value_coder state = self._state_backend.get_map_state( - state_descriptor.name, PickleCoder(), PickleCoder()) + state_descriptor.name, key_coder, value_coder) else: raise Exception("Unknown supported StateDescriptor %s" % state_descriptor) state.set_current_namespace(self.window) diff --git a/flink-python/pyflink/fn_execution/table/window_trigger.py b/flink-python/pyflink/fn_execution/table/window_trigger.py index 2d9d907d1cf80..71e9202cbbe2e 100644 --- a/flink-python/pyflink/fn_execution/table/window_trigger.py +++ b/flink-python/pyflink/fn_execution/table/window_trigger.py @@ -180,7 +180,7 @@ class CountTrigger(Trigger[CountWindow]): def __init__(self, count_elements: int): self._count_elements = count_elements self._count_state_desc = ValueStateDescriptor( - "trigger-count-%s" % count_elements, Types.PICKLED_BYTE_ARRAY()) + "trigger-count-%s" % count_elements, Types.LONG()) self._ctx = None # type: TriggerContext def open(self, ctx: TriggerContext): diff --git a/flink-python/setup.py b/flink-python/setup.py index 3aa7b1b3184d9..d6483cb9f5658 100644 --- a/flink-python/setup.py +++ b/flink-python/setup.py @@ -120,8 +120,8 @@ def extracted_output_files(base_dir, file_path, output_directory): sources=["pyflink/fn_execution/stream_fast.pyx"], include_dirs=["pyflink/fn_execution/"]), Extension( - name="pyflink.fn_execution.beam.beam_stream", - sources=["pyflink/fn_execution/beam/beam_stream.pyx"], + name="pyflink.fn_execution.beam.beam_stream_fast", + sources=["pyflink/fn_execution/beam/beam_stream_fast.pyx"], include_dirs=["pyflink/fn_execution/beam"]), Extension( name="pyflink.fn_execution.beam.beam_coder_impl_fast", @@ -152,8 +152,8 @@ def extracted_output_files(base_dir, file_path, output_directory): sources=["pyflink/fn_execution/stream_fast.c"], include_dirs=["pyflink/fn_execution/"]), Extension( - name="pyflink.fn_execution.beam.beam_stream", - sources=["pyflink/fn_execution/beam/beam_stream.c"], + 
name="pyflink.fn_execution.beam.beam_stream_fast", + sources=["pyflink/fn_execution/beam/beam_stream_fast.c"], include_dirs=["pyflink/fn_execution/beam"]), Extension( name="pyflink.fn_execution.beam.beam_coder_impl_fast",