
Commit 98e5b65

fix comments
HuangXingBo committed Jul 7, 2021
1 parent f634192 commit 98e5b65
Showing 17 changed files with 82 additions and 94 deletions.
4 changes: 4 additions & 0 deletions flink-python/pyflink/common/serializer.py
@@ -61,6 +61,10 @@ def _get_coder(self):
         deserialize_func = self.deserialize
 
         class CoderAdapter(object):
+            def get_impl(self):
+                return CoderAdapterIml()
+
+        class CoderAdapterIml(object):
 
             def encode_nested(self, element):
                 bytes_io = BytesIO()
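
The added get_impl() hook gives the user-defined serializer adapter the same two-level shape as the coders in pyflink.fn_execution: a cheap, comparable coder object whose get_impl() returns the object that actually encodes. A minimal, self-contained sketch of that split (names here are illustrative, not from the commit):

    from io import BytesIO
    import pickle

    class SketchCoder(object):
        # comparable description; creation is cheap
        def get_impl(self):
            return SketchCoderImpl()

    class SketchCoderImpl(object):
        # does the actual work, mirroring CoderAdapterIml above
        def encode_nested(self, element):
            bytes_io = BytesIO()
            pickle.dump(element, bytes_io)
            return bytes_io.getvalue()

        def decode_nested(self, data):
            return pickle.load(BytesIO(data))

    impl = SketchCoder().get_impl()
    assert impl.decode_nested(impl.encode_nested((1, 'a'))) == (1, 'a')
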
5 changes: 2 additions & 3 deletions flink-python/pyflink/datastream/data_stream.py
@@ -932,8 +932,7 @@ def __init__(self, reduce_function: ReduceFunction):
 
     def open(self, runtime_context: RuntimeContext):
         self._reduce_value_state = runtime_context.get_state(
-            ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()),
-                                 Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), output_type))
         self._reduce_function.open(runtime_context)
         from pyflink.fn_execution.datastream.runtime_context import StreamingRuntimeContext
         self._in_batch_execution_mode = \
@@ -1201,7 +1200,7 @@ def _get_result_data_stream(
             self.get_execution_environment())
         window_serializer = self._window_assigner.get_window_serializer()
         window_state_descriptor = ListStateDescriptor(
-            "window-contents", Types.PICKLED_BYTE_ARRAY())
+            "window-contents", self.get_input_type())
         window_operation_descriptor = WindowOperationDescriptor(
             self._window_assigner,
             self._window_trigger,
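
Both hunks replace the hard-coded Types.PICKLED_BYTE_ARRAY() with the type the state actually holds: the reduce state takes the stream's output type, and the window-contents list takes the input type. From the user's side the pattern now looks like this (illustrative usage; the variable names are mine, not the commit's):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream.state import ValueStateDescriptor, ListStateDescriptor

    # state declared with its real element type, so the runtime can pick a
    # matching coder instead of pickling opaque bytes
    reduce_state = ValueStateDescriptor("reduce_state", Types.INT())
    window_contents = ListStateDescriptor("window-contents", Types.STRING())
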
26 changes: 1 addition & 25 deletions flink-python/pyflink/datastream/state.py
@@ -19,7 +19,7 @@
 
 from typing import TypeVar, Generic, Iterable, List, Iterator, Dict, Tuple
 
-from pyflink.common.typeinfo import TypeInformation, Types, PickledBytesTypeInfo
+from pyflink.common.typeinfo import TypeInformation, Types
 
 __all__ = [
     'ValueStateDescriptor',
@@ -316,10 +316,6 @@ def __init__(self, name: str, value_type_info: TypeInformation):
         :param name: The name of the state.
         :param value_type_info: the type information of the state.
         """
-        if not isinstance(value_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the value could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s."
-                             % type(value_type_info))
         super(ValueStateDescriptor, self).__init__(name, value_type_info)
 
 
@@ -336,10 +332,6 @@ def __init__(self, name: str, elem_type_info: TypeInformation):
         :param name: The name of the state.
         :param elem_type_info: the type information of the state element.
         """
-        if not isinstance(elem_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the element could only be "
-                             "PickledBytesTypeInfo (created via Types.PICKLED_BYTE_ARRAY()) "
-                             "currently, got %s" % type(elem_type_info))
         super(ListStateDescriptor, self).__init__(name, Types.LIST(elem_type_info))
 
 
@@ -357,14 +349,6 @@ def __init__(self, name: str, key_type_info: TypeInformation, value_type_info: TypeInformation):
         :param key_type_info: The type information of the key.
         :param value_type_info: the type information of the value.
         """
-        if not isinstance(key_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the key could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(key_type_info))
-        if not isinstance(value_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the value could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(value_type_info))
         super(MapStateDescriptor, self).__init__(name, Types.MAP(key_type_info, value_type_info))
 
 
@@ -392,10 +376,6 @@ def __init__(self,
             reduce_function = ReduceFunctionWrapper(reduce_function)  # type: ignore
         else:
             raise TypeError("The input must be a ReduceFunction or a callable function!")
-        if not isinstance(type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the state could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(type_info))
         self._reduce_function = reduce_function
 
     def get_reduce_function(self):
@@ -418,10 +398,6 @@ def __init__(self,
         from pyflink.datastream.functions import AggregateFunction
         if not isinstance(agg_function, AggregateFunction):
             raise TypeError("The input must be a pyflink.datastream.functions.AggregateFunction!")
-        if not isinstance(state_type_info, PickledBytesTypeInfo):
-            raise ValueError("The type information of the state could only be PickledBytesTypeInfo "
-                             "(created via Types.PICKLED_BYTE_ARRAY()) currently, got %s"
-                             % type(state_type_info))
         self._agg_function = agg_function
 
     def get_agg_function(self):
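
With the isinstance guards gone, every descriptor accepts arbitrary TypeInformation rather than only Types.PICKLED_BYTE_ARRAY(). For example, a map state with distinct key and value types, which the removed checks used to reject (illustrative usage):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream.state import MapStateDescriptor, ReducingStateDescriptor

    counts = MapStateDescriptor("counts", Types.STRING(), Types.LONG())
    total = ReducingStateDescriptor("total", lambda a, b: a + b, Types.INT())
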
35 changes: 15 additions & 20 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -214,9 +214,9 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.pre1 = runtime_context.get_state(
-            ValueStateDescriptor("pre1", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("pre1", Types.STRING()))
         self.pre2 = runtime_context.get_state(
-            ValueStateDescriptor("pre2", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("pre2", Types.STRING()))
 
     def map1(self, value):
         if value[0] == 'b':
@@ -409,7 +409,7 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))
 
     def map(self, value):
         state_value = self.state.value()
@@ -453,7 +453,7 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))
 
     def flat_map(self, value):
         state_value = self.state.value()
@@ -497,7 +497,7 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))
 
     def filter(self, value):
         state_value = self.state.value()
@@ -694,15 +694,11 @@ def __init__(self):
         self.map_state = None
 
     def open(self, runtime_context: RuntimeContext):
-        value_state_descriptor = ValueStateDescriptor('value_state',
-                                                      Types.PICKLED_BYTE_ARRAY())
+        value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
         self.value_state = runtime_context.get_state(value_state_descriptor)
-        list_state_descriptor = ListStateDescriptor('list_state',
-                                                    Types.PICKLED_BYTE_ARRAY())
+        list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
         self.list_state = runtime_context.get_list_state(list_state_descriptor)
-        map_state_descriptor = MapStateDescriptor('map_state',
-                                                  Types.PICKLED_BYTE_ARRAY(),
-                                                  Types.PICKLED_BYTE_ARRAY())
+        map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
         self.map_state = runtime_context.get_map_state(map_state_descriptor)
 
     def process_element(self, value, ctx):
@@ -770,7 +766,7 @@ def __init__(self):
     def open(self, runtime_context: RuntimeContext):
         self.reducing_state = runtime_context.get_reducing_state(
             ReducingStateDescriptor(
-                'reducing_state', lambda i, i2: i + i2, Types.PICKLED_BYTE_ARRAY()))
+                'reducing_state', lambda i, i2: i + i2, Types.INT()))
 
     def process_element(self, value, ctx):
         self.reducing_state.add(value[0])
@@ -814,7 +810,7 @@ def __init__(self):
     def open(self, runtime_context: RuntimeContext):
         self.aggregating_state = runtime_context.get_aggregating_state(
             AggregatingStateDescriptor(
-                'aggregating_state', MyAggregateFunction(), Types.PICKLED_BYTE_ARRAY()))
+                'aggregating_state', MyAggregateFunction(), Types.INT()))
 
     def process_element(self, value, ctx):
         self.aggregating_state.add(value[0])
@@ -1356,7 +1352,7 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.map_state = runtime_context.get_map_state(
-            MapStateDescriptor("map", Types.PICKLED_BYTE_ARRAY(), Types.PICKLED_BYTE_ARRAY()))
+            MapStateDescriptor("map", Types.STRING(), Types.BOOLEAN()))
 
     def flat_map1(self, value):
         yield str(value[0] + 1)
@@ -1376,8 +1372,7 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.timer_registered = False
-        self.count_state = runtime_context.get_state(ValueStateDescriptor(
-            "count", Types.PICKLED_BYTE_ARRAY()))
+        self.count_state = runtime_context.get_state(ValueStateDescriptor("count", Types.INT()))
 
     def process_element1(self, value, ctx: 'KeyedCoProcessFunction.Context'):
         if not self.timer_registered:
@@ -1411,7 +1406,7 @@ def __init__(self):
 
     def open(self, runtime_context: RuntimeContext):
         self.state = runtime_context.get_state(
-            ValueStateDescriptor("test_state", Types.PICKLED_BYTE_ARRAY()))
+            ValueStateDescriptor("test_state", Types.INT()))
 
     def reduce(self, value1, value2):
         state_value = self.state.value()
@@ -1439,7 +1434,7 @@ class SimpleCountWindowTrigger(Trigger[tuple, CountWindow]):
     def __init__(self):
         self._window_size = 3
         self._count_state_descriptor = ReducingStateDescriptor(
-            "trigger_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY())
+            "trigger_counter", lambda a, b: a + b, Types.BIG_INT())
 
     def on_element(self,
                    element: tuple,
@@ -1479,7 +1474,7 @@ def __init__(self):
         self._window_id = 0
         self._window_size = 3
         self._counter_state_descriptor = ReducingStateDescriptor(
-            "assigner_counter", lambda a, b: a + b, Types.PICKLED_BYTE_ARRAY())
+            "assigner_counter", lambda a, b: a + b, Types.BIG_INT())
 
     def assign_windows(self,
                        element: tuple,
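
Every test fixture follows the same substitution: the descriptor now names the concrete type the test stores (Types.INT(), Types.STRING(), Types.BIG_INT() for the window counters). A condensed version of the pattern the updated tests exercise (a sketch, not a test from the commit):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream.functions import MapFunction, RuntimeContext
    from pyflink.datastream.state import ValueStateDescriptor

    class RunningSum(MapFunction):

        def open(self, runtime_context: RuntimeContext):
            # typed state: the backend serializes ints, not pickled bytes
            self.sum_state = runtime_context.get_state(
                ValueStateDescriptor("sum", Types.INT()))

        def map(self, value):
            current = self.sum_state.value() or 0
            self.sum_state.update(current + value[1])
            return value[0], current + value[1]
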
22 changes: 8 additions & 14 deletions flink-python/pyflink/datastream/window.py
@@ -153,22 +153,19 @@ def __init__(self):
 
     def serialize(self, element: TimeWindow, stream: BytesIO) -> None:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = self._underlying_coder.encode(element)
         stream.write(bytes_data)
 
     def deserialize(self, stream: BytesIO) -> TimeWindow:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = stream.read(16)
         return self._underlying_coder.decode(bytes_data)
 
     def _get_coder(self):
-        try:
-            from pyflink.fn_execution import coder_impl_fast as coder_impl
-        except:
-            from pyflink.fn_execution import coder_impl_slow as coder_impl
-        return coder_impl.TimeWindowCoderImpl()
+        from pyflink.fn_execution import coders
+        return coders.TimeWindowCoder()
 
 
 class CountWindowSerializer(TypeSerializer[CountWindow]):
@@ -178,22 +175,19 @@ def __init__(self):
 
     def serialize(self, element: CountWindow, stream: BytesIO) -> None:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = self._underlying_coder.encode(element)
         stream.write(bytes_data)
 
     def deserialize(self, stream: BytesIO) -> CountWindow:
         if self._underlying_coder is None:
-            self._underlying_coder = self._get_coder()
+            self._underlying_coder = self._get_coder().get_impl()
         bytes_data = stream.read(8)
         return self._underlying_coder.decode(bytes_data)
 
     def _get_coder(self):
-        try:
-            from pyflink.fn_execution import coder_impl_fast as coder_impl
-        except:
-            from pyflink.fn_execution import coder_impl_slow as coder_impl
-        return coder_impl.CountWindowCoderImpl()
+        from pyflink.fn_execution import coders
+        return coders.CountWindowCoder()
 
 
 T = TypeVar('T')
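
Previously each serializer chose between coder_impl_fast (Cython) and coder_impl_slow by hand with a try/except import; now both go through pyflink.fn_execution.coders and resolve the impl once via get_impl(). The lazy-cache shape in isolation (a sketch, assuming the coder object is cheap to build and the impl is what actually encodes):

    class LazyImplSerializer(object):
        """Resolve the coder impl on first use and reuse it afterwards."""

        def __init__(self, coder):
            self._coder = coder
            self._underlying_coder = None

        def serialize(self, element, stream):
            if self._underlying_coder is None:
                # get_impl() hides the fast/slow implementation choice
                self._underlying_coder = self._coder.get_impl()
            stream.write(self._underlying_coder.encode(element))
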
21 changes: 21 additions & 0 deletions flink-python/pyflink/fn_execution/coders.py
@@ -156,6 +156,9 @@ class FieldCoder(ABC):
     def get_impl(self) -> coder_impl.FieldCoderImpl:
         pass
 
+    def __eq__(self, other):
+        return type(self) == type(other)
+
 
 class IterableCoder(LengthPrefixBaseCoder):
     """
@@ -434,6 +437,11 @@ def __init__(self, precision, scale):
     def get_impl(self):
         return coder_impl.DecimalCoderImpl(self.precision, self.scale)
 
+    def __eq__(self, other: 'DecimalCoder'):
+        return (self.__class__ == other.__class__ and
+                self.precision == other.precision and
+                self.scale == other.scale)
+
 
 class BigDecimalCoder(FieldCoder):
     """
@@ -491,6 +499,9 @@ def __init__(self, precision):
     def get_impl(self):
         return coder_impl.TimestampCoderImpl(self.precision)
 
+    def __eq__(self, other: 'TimestampCoder'):
+        return self.__class__ == other.__class__ and self.precision == other.precision
+
 
 class LocalZonedTimestampCoder(FieldCoder):
     """
@@ -504,6 +515,11 @@ def __init__(self, precision, timezone):
     def get_impl(self):
         return coder_impl.LocalZonedTimestampCoderImpl(self.precision, self.timezone)
 
+    def __eq__(self, other: 'LocalZonedTimestampCoder'):
+        return (self.__class__ == other.__class__ and
+                self.precision == other.precision and
+                self.timezone == other.timezone)
+
 
 class CloudPickleCoder(FieldCoder):
     """
@@ -537,6 +553,11 @@ def get_impl(self):
     def __repr__(self):
         return 'TupleCoder[%s]' % ', '.join(str(c) for c in self._field_coders)
 
+    def __eq__(self, other: 'TupleCoder'):
+        return (self.__class__ == other.__class__ and
+                [self._field_coders[i] == other._field_coders[i]
+                 for i in range(len(self._field_coders))])
+
 
 class TimeWindowCoder(FieldCoder):
     """
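
The new __eq__ methods give coders structural equality: two coders compare equal when their class and parameters match. Note that TupleCoder.__eq__ above returns a list comprehension inside `and`, and a non-empty list is always truthy regardless of its elements; an element-wise check would normally go through all(), as in this sketch (not from the commit):

    def tuple_coders_equal(left, right):
        # structural equality over field coders, element by element
        return (left.__class__ == right.__class__ and
                len(left._field_coders) == len(right._field_coders) and
                all(a == b for a, b in zip(left._field_coders,
                                           right._field_coders)))
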
4 changes: 3 additions & 1 deletion flink-python/pyflink/fn_execution/datastream/operations.py
@@ -289,7 +289,9 @@ def process_element(normal_data, timestamp: int):
     window_state_descriptor = window_operation_descriptor.window_state_descriptor
     internal_window_function = window_operation_descriptor.internal_window_function
     window_serializer = window_operation_descriptor.window_serializer
-    keyed_state_backend._namespace_coder_impl = window_serializer._get_coder()
+    window_coder = window_serializer._get_coder()
+    keyed_state_backend.namespace_coder = window_coder
+    keyed_state_backend._namespace_coder_impl = window_coder.get_impl()
     window_operator = WindowOperator(
         window_assigner,
         keyed_state_backend,
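
The keyed state backend now keeps both views of the namespace coder, matching the coder/impl split introduced in window.py above (annotated restatement; the comments are my reading of the diff):

    window_coder = window_serializer._get_coder()   # e.g. coders.TimeWindowCoder()
    keyed_state_backend.namespace_coder = window_coder           # comparable description
    keyed_state_backend._namespace_coder_impl = window_coder.get_impl()  # encodes/decodes namespaces
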
flink-python/pyflink/fn_execution/datastream/runtime_context.py
@@ -21,7 +21,7 @@
 from pyflink.datastream.state import ValueStateDescriptor, ValueState, ListStateDescriptor, \
     ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, ReducingState, \
     AggregatingStateDescriptor, AggregatingState
-from pyflink.fn_execution.coders import from_type_info, MapCoder, BasicArrayCoder
+from pyflink.fn_execution.coders import from_type_info, MapCoder, GenericArrayCoder
 from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
 from pyflink.metrics import MetricGroup
 
@@ -111,7 +111,7 @@ def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState:
 
     def get_list_state(self, state_descriptor: ListStateDescriptor) -> ListState:
         if self._keyed_state_backend:
-            array_coder = from_type_info(state_descriptor.type_info)  # type: BasicArrayCoder
+            array_coder = from_type_info(state_descriptor.type_info)  # type: GenericArrayCoder
             return self._keyed_state_backend.get_list_state(
                 state_descriptor.name, array_coder._elem_coder)
         else:
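
The type comment changes from BasicArrayCoder to GenericArrayCoder to match what from_type_info now returns for a list state's type info (my reading of the diff). The surrounding code shows the path from a descriptor to the element coder the backend needs; condensed (illustrative; _elem_coder is the array coder's per-element coder, as used above):

    from pyflink.common.typeinfo import Types
    from pyflink.datastream.state import ListStateDescriptor
    from pyflink.fn_execution.coders import from_type_info

    descriptor = ListStateDescriptor("ids", Types.INT())  # wraps elem type in Types.LIST(...)
    array_coder = from_type_info(descriptor.type_info)    # coder for the list type
    elem_coder = array_coder._elem_coder                  # coder handed to the backend
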
@@ -16,10 +16,10 @@
 # limitations under the License.
 ################################################################################
 from abc import ABC, abstractmethod
-from typing import TypeVar, Generic, Tuple, Collection, Iterable
+from typing import TypeVar, Generic, Collection, Iterable
 
 from pyflink.datastream import MergingWindowAssigner
-from pyflink.datastream.state import ListState
+from pyflink.datastream.state import MapState
 
 W = TypeVar("W")
 
@@ -45,11 +45,11 @@ def merge(self,
               merged_state_windows: Collection[W]):
         pass
 
-    def __init__(self, assigner: MergingWindowAssigner, state: ListState[Tuple[W, W]]):
+    def __init__(self, assigner: MergingWindowAssigner, state: MapState[W, W]):
         self._window_assigner = assigner
         self._mapping = dict()
 
-        for window_for_user, window_in_state in state:
+        for window_for_user, window_in_state in state.items():
             self._mapping[window_for_user] = window_in_state
 
         self._state = state
@@ -59,7 +59,7 @@ def persist(self) -> None:
         if self._mapping != self._initial_mapping:
             self._state.clear()
             for window_for_user, window_in_state in self._mapping.items():
-                self._state.add((window_for_user, window_in_state))
+                self._state.put(window_for_user, window_in_state)
 
     def get_state_window(self, window: W) -> W:
         if window in self._mapping:
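
The merging window set's user-window to state-window mapping moves from a ListState of pairs to a MapState, so it is read back with items() and rewritten with put(). A dict-backed stand-in shows the persist() contract (a sketch; a real MapState is backend-managed and scoped per record key):

    class FakeMapState(dict):
        """Dict-backed stand-in exposing the MapState calls persist() uses."""

        def put(self, key, value):
            self[key] = value

    state = FakeMapState()
    mapping = {"user-window-1": "state-window-1",
               "user-window-2": "state-window-2"}

    # persist(): clear, then write the in-memory mapping back entry by entry
    state.clear()
    for window_for_user, window_in_state in mapping.items():
        state.put(window_for_user, window_in_state)

    assert dict(state) == mapping
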