77 changes: 63 additions & 14 deletions flink-python/pyflink/datastream/data_stream.py
@@ -23,7 +23,7 @@
 from pyflink.common.typeinfo import TypeInformation
 from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \
     MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, \
-    KeySelectorFunctionWrapper, KeySelector
+    KeySelectorFunctionWrapper, KeySelector, ReduceFunction, ReduceFunctionWrapper
 from pyflink.java_gateway import get_gateway


@@ -191,15 +191,15 @@ def map(self, func: Union[Callable, MapFunction], type_info: TypeInformation = None)
raise TypeError("The input must be a MapFunction or a callable function")
func_name = str(func)
from pyflink.fn_execution import flink_fn_execution_pb2
j_python_data_stream_scalar_function_operator, output_type_info = \
j_python_data_stream_scalar_function_operator, j_output_type_info = \
self._get_java_python_function_operator(func,
type_info,
func_name,
flink_fn_execution_pb2
.UserDefinedDataStreamFunction.MAP)
return DataStream(self._j_data_stream.transform(
"Map",
output_type_info.get_java_type_info(),
j_output_type_info,
j_python_data_stream_scalar_function_operator
))

@@ -222,15 +222,15 @@ def flat_map(self, func: Union[Callable, FlatMapFunction], type_info: TypeInformation = None)
raise TypeError("The input must be a FlatMapFunction or a callable function")
func_name = str(func)
from pyflink.fn_execution import flink_fn_execution_pb2
j_python_data_stream_scalar_function_operator, output_type_info = \
j_python_data_stream_scalar_function_operator, j_output_type_info = \
self._get_java_python_function_operator(func,
type_info,
func_name,
flink_fn_execution_pb2
.UserDefinedDataStreamFunction.FLAT_MAP)
return DataStream(self._j_data_stream.transform(
"FLAT_MAP",
output_type_info.get_java_type_info(),
j_output_type_info,
j_python_data_stream_scalar_function_operator
))

@@ -314,15 +314,30 @@ def _get_java_python_function_operator(self, func: Union[Function, FunctionWrapper],
         PythonConfigUtil = gateway.jvm.org.apache.flink.python.util.PythonConfigUtil
         j_conf = PythonConfigUtil.getMergedConfig(j_env)

-        DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
-            .operators.python.DataStreamPythonStatelessFunctionOperator
+        # Set the max bundle size to 1 to force synchronous processing for the reduce function.
+        from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction
+        if func_type == UserDefinedDataStreamFunction.REDUCE:
+            j_conf.setInteger(gateway.jvm.org.apache.flink.python.PythonOptions.MAX_BUNDLE_SIZE, 1)
+            DataStreamPythonReduceFunctionOperator = gateway.jvm.org.apache.flink.datastream \
+                .runtime.operators.python.DataStreamPythonReduceFunctionOperator
+
+            j_output_type_info = j_input_types.getTypeAt(1)
+            j_python_data_stream_function_operator = DataStreamPythonReduceFunctionOperator(
+                j_conf,
+                j_input_types,
+                j_output_type_info,
+                j_python_data_stream_function_info)
+            return j_python_data_stream_function_operator, j_output_type_info
+        else:
+            DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+                .operators.python.DataStreamPythonStatelessFunctionOperator
+            j_python_data_stream_function_operator = DataStreamPythonFunctionOperator(
+                j_conf,
+                j_input_types,
+                output_type_info.get_java_type_info(),
+                j_python_data_stream_function_info)

-        j_python_data_stream_scalar_function_operator = DataStreamPythonFunctionOperator(
-            j_conf,
-            j_input_types,
-            output_type_info.get_java_type_info(),
-            j_python_data_stream_function_info)
-        return j_python_data_stream_scalar_function_operator, output_type_info
+            return j_python_data_stream_function_operator, output_type_info.get_java_type_info()

     def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink':
         """
@@ -434,7 +449,41 @@ def flat_map(self, func: Union[Callable, FlatMapFunction], type_info: TypeInformation = None)
             -> 'DataStream':
         return self._values().flat_map(func, type_info)

-    def _values(self) -> 'DataStream':
+    def reduce(self, func: Union[Callable, ReduceFunction]) -> 'DataStream':
+        """
+        Applies a reduce transformation on the data stream grouped by the given key. The
+        `ReduceFunction` will receive input values based on the key value. Only input
+        values with the same key will go to the same reducer.
+
+        Example:
+        ::
+
+            >>> ds = env.from_collection([(1, 'a'), (2, 'a'), (3, 'a'), (4, 'b')])
+            >>> ds.key_by(lambda x: x[1]).reduce(lambda a, b: (a[0] + b[0], b[1]))
+
+        :param func: The ReduceFunction that is called for each element of the DataStream.
+        :return: The transformed DataStream.
+        """
+        if not isinstance(func, ReduceFunction):
+            if callable(func):
+                func = ReduceFunctionWrapper(func)
+            else:
+                raise TypeError("The input must be a ReduceFunction or a callable function!")
+
+        from pyflink.fn_execution.flink_fn_execution_pb2 import UserDefinedDataStreamFunction
+        func_name = "m_reduce_" + str(func)
+        j_python_data_stream_scalar_function_operator, j_output_type_info = \
+            self._get_java_python_function_operator(func,
+                                                    None,
+                                                    func_name,
+                                                    UserDefinedDataStreamFunction.REDUCE)
+        return DataStream(self._j_data_stream.transform(
+            "Keyed Reduce",
+            j_output_type_info,
+            j_python_data_stream_scalar_function_operator
+        ))
+
+    def _values(self):
         """
         Since python KeyedStream is in the format of Row(key_value, original_data), it is used for
         getting the original_data.
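End to end, the new KeyedStream.reduce() can be driven roughly as below; a minimal sketch that mirrors the docstring example and the test added in this PR (a sink, as in the test file, is still needed to materialize results):

# Minimal sketch of the new API; environment and sink wiring kept to what the
# tests in this PR actually use.
from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()
ds = env.from_collection(
    [(1, 'a'), (2, 'a'), (3, 'a'), (4, 'b')],
    type_info=Types.ROW([Types.INT(), Types.STRING()]))
# Running sum per key; a lambda is accepted and wrapped in ReduceFunctionWrapper.
reduced = ds.key_by(lambda x: x[1]).reduce(lambda a, b: (a[0] + b[0], b[1]))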
52 changes: 52 additions & 0 deletions flink-python/pyflink/datastream/functions.py
@@ -86,6 +86,33 @@ def flat_map(self, value):
         pass


+class ReduceFunction(Function):
+    """
+    Base interface for Reduce functions. Reduce functions combine groups of elements into a
+    single value by repeatedly taking two elements and combining them into one. Reduce
+    functions may be used on entire data sets or on grouped data sets; in the latter case,
+    each group is reduced individually.
+
+    The basic syntax for using a ReduceFunction is as follows:
+    ::
+
+        >>> ds = ...
+        >>> new_ds = ds.key_by(lambda x: x[1]).reduce(MyReduceFunction())
+    """
+
+    @abc.abstractmethod
+    def reduce(self, value1, value2):
+        """
+        The core method of ReduceFunction, combining two values into one value of the same
+        type. The reduce function is applied consecutively to all values of a group until
+        only a single value remains.
+
+        :param value1: The first value to combine.
+        :param value2: The second value to combine.
+        :return: The combined value of both input values.
+        """
+        pass
+
+
 class KeySelector(Function):
     """
     The KeySelector allows to use deterministic objects for operations such as reduce, reduceGroup,
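For reference, a hypothetical concrete implementation matching the docstring example above; the class name SumReduce is illustrative and not part of this PR:

# Sums the first field of (int, str) records and keeps the key, returning a
# value of the same type as its inputs.
class SumReduce(ReduceFunction):
    def reduce(self, value1, value2):
        return value1[0] + value2[0], value2[1]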
@@ -161,6 +188,31 @@ def flat_map(self, value):
         return self._func(value)


+class ReduceFunctionWrapper(FunctionWrapper):
+    """
+    A wrapper class for ReduceFunction. It's used to wrap up a user-defined function in a
+    ReduceFunction when the user does not implement a ReduceFunction but directly passes a
+    function object or a lambda function to the reduce() method.
+    """
+
+    def __init__(self, func):
+        """
+        The constructor of ReduceFunctionWrapper.
+
+        :param func: user defined function object.
+        """
+        super(ReduceFunctionWrapper, self).__init__(func)
+
+    def reduce(self, value1, value2):
+        """
+        A delegated reduce function to invoke the user-defined function.
+
+        :param value1: The first value to combine.
+        :param value2: The second value to combine.
+        :return: The combined value of both input values.
+        """
+        return self._func(value1, value2)
+
+
 class KeySelectorFunctionWrapper(FunctionWrapper):
     """
     A wrapper class for KeySelector. It's used for wrapping up user defined function in a
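Since reduce() wraps plain callables in ReduceFunctionWrapper automatically, the two calls below should behave the same; MyReduceFunction stands for any user-defined ReduceFunction subclass:

# Both dispatch to the same reduce(value1, value2) contract at runtime.
ds.key_by(lambda x: x[1]).reduce(lambda a, b: (a[0] + b[0], b[1]))
ds.key_by(lambda x: x[1]).reduce(MyReduceFunction())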
11 changes: 11 additions & 0 deletions flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -60,6 +60,17 @@ def test_force_non_parallel(self):
         plan = eval(str(self.env.get_execution_plan()))
         self.assertEqual(1, plan['nodes'][0]['parallelism'])

+    def test_reduce_function_without_data_types(self):
+        ds = self.env.from_collection([(1, 'a'), (2, 'a'), (3, 'a'), (4, 'b')],
+                                      type_info=Types.ROW([Types.INT(), Types.STRING()]))
+        ds.key_by(lambda a: a[1]).reduce(lambda a, b: (a[0] + b[0], b[1])).add_sink(self.test_sink)
+        self.env.execute('reduce_function_test')
+        result = self.test_sink.get_results()
+        expected = ["1,a", "3,a", "6,a", "4,b"]
+        expected.sort()
+        result.sort()
+        self.assertEqual(expected, result)
+
     def test_map_function_without_data_types(self):
         self.env.set_parallelism(1)
         ds = self.env.from_collection([('ab', decimal.Decimal(1)),
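The expected list in the new test encodes streaming reduce semantics: each arriving element emits the updated aggregate for its key, hence three outputs for key 'a'. A pure-Python model of what the test asserts:

# Pure-Python model of a per-key running reduce; mirrors the expected output.
state, out = {}, []
for record in [(1, 'a'), (2, 'a'), (3, 'a'), (4, 'b')]:
    key = record[1]
    acc = state.get(key)
    state[key] = record if acc is None else (acc[0] + record[0], key)
    out.append('%d,%s' % state[key])
assert out == ['1,a', '3,a', '6,a', '4,b']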