diff --git a/.gitignore b/.gitignore
index 88c448814000a9..4e3a5ae2f2e50c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,8 @@ flink-python/dev/download
flink-python/dev/.conda/
flink-python/dev/log/
flink-python/dev/.stage.txt
+flink-python/.eggs/
+flink-python/pyflink/fn_execution/*_pb2.py
atlassian-ide-plugin.xml
out/
/docs/api
diff --git a/flink-python/pom.xml b/flink-python/pom.xml
index ed16daa09e3a89..ffeb4864c5e5aa 100644
--- a/flink-python/pom.xml
+++ b/flink-python/pom.xml
@@ -327,6 +327,26 @@ under the License.
+<plugin>
+	<artifactId>exec-maven-plugin</artifactId>
+	<groupId>org.codehaus.mojo</groupId>
+	<version>1.5.0</version>
+	<executions>
+		<execution>
+			<id>Protos Generation</id>
+			<phase>generate-sources</phase>
+			<goals>
+				<goal>exec</goal>
+			</goals>
+			<configuration>
+				<executable>python</executable>
+				<arguments>
+					<argument>${basedir}/pyflink/gen_protos.py</argument>
+				</arguments>
+			</configuration>
+		</execution>
+	</executions>
+</plugin>
diff --git a/flink-python/pyflink/fn_execution/boot.py b/flink-python/pyflink/fn_execution/boot.py
index e42ad2eea58857..bf913846fb2bfc 100644
--- a/flink-python/pyflink/fn_execution/boot.py
+++ b/flink-python/pyflink/fn_execution/boot.py
@@ -145,5 +145,5 @@ def check_not_empty(check_str, error_message):
if "FLINK_BOOT_TESTING" in os.environ and os.environ["FLINK_BOOT_TESTING"] == "1":
exit(0)
-call([sys.executable, "-m", "apache_beam.runners.worker.sdk_worker_main"],
+call([sys.executable, "-m", "pyflink.fn_execution.sdk_worker_main"],
stdout=sys.stdout, stderr=sys.stderr, env=env)
diff --git a/flink-python/pyflink/fn_execution/coder_impl.py b/flink-python/pyflink/fn_execution/coder_impl.py
new file mode 100644
index 00000000000000..2f7c5d3ab6536f
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/coder_impl.py
@@ -0,0 +1,80 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import sys
+
+from apache_beam.coders.coder_impl import StreamCoderImpl
+
+if sys.version > '3':
+ xrange = range
+
+
+class RowCoderImpl(StreamCoderImpl):
+
+ def __init__(self, field_coders):
+ self._field_coders = field_coders
+
+ def encode_to_stream(self, value, out_stream, nested):
+ self.write_null_mask(value, out_stream)
+ for i in xrange(len(self._field_coders)):
+ self._field_coders[i].encode_to_stream(value[i], out_stream, nested)
+
+ def decode_from_stream(self, in_stream, nested):
+ from pyflink.table import Row
+ null_mask = self.read_null_mask(len(self._field_coders), in_stream)
+ assert len(null_mask) == len(self._field_coders)
+ return Row(*[None if null_mask[idx] else self._field_coders[idx].decode_from_stream(
+ in_stream, nested) for idx in xrange(0, len(null_mask))])
+
+ @staticmethod
+ def write_null_mask(value, out_stream):
+ field_pos = 0
+ field_count = len(value)
+ while field_pos < field_count:
+ b = 0x00
+ # set bits in byte
+ num_pos = min(8, field_count - field_pos)
+ byte_pos = 0
+ while byte_pos < num_pos:
+ b = b << 1
+ # set bit if field is null
+ if value[field_pos + byte_pos] is None:
+ b |= 0x01
+ byte_pos += 1
+ field_pos += num_pos
+ # shift bits if last byte is not completely filled
+ b <<= (8 - byte_pos)
+ # write byte
+ out_stream.write_byte(b)
+
+ @staticmethod
+ def read_null_mask(field_count, in_stream):
+ null_mask = []
+ field_pos = 0
+ while field_pos < field_count:
+ b = in_stream.read_byte()
+ num_pos = min(8, field_count - field_pos)
+ byte_pos = 0
+ while byte_pos < num_pos:
+ null_mask.append((b & 0x80) > 0)
+ b = b << 1
+ byte_pos += 1
+ field_pos += num_pos
+ return null_mask
+
+ def __repr__(self):
+ return 'RowCoderImpl[%s]' % ', '.join(str(c) for c in self._field_coders)
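+
+
+# A minimal sketch of the null-mask round trip (illustrative comment only;
+# create_InputStream/create_OutputStream are Beam's stream helpers from
+# apache_beam.coders.coder_impl):
+#
+#     from apache_beam.coders.coder_impl import create_InputStream, create_OutputStream
+#     out = create_OutputStream()
+#     RowCoderImpl.write_null_mask(('a', None, 42), out)
+#     # one bit per field, most-significant bit first -> the single byte 0b01000000
+#     assert RowCoderImpl.read_null_mask(3, create_InputStream(out.get())) == [False, True, False]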
diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
new file mode 100644
index 00000000000000..c6add86477aba0
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -0,0 +1,86 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import sys
+
+from apache_beam.coders import Coder, VarIntCoder
+from apache_beam.coders.coders import FastCoder
+
+from pyflink.fn_execution import coder_impl
+from pyflink.fn_execution import flink_fn_execution_pb2
+
+FLINK_SCHEMA_CODER_URN = "flink:coder:schema:v1"
+
+if sys.version > '3':
+ xrange = range
+
+
+__all__ = ['RowCoder']
+
+
+class RowCoder(FastCoder):
+ """
+ Coder for Row.
+ """
+
+ def __init__(self, field_coders):
+ self._field_coders = field_coders
+
+ def _create_impl(self):
+ return coder_impl.RowCoderImpl([c.get_impl() for c in self._field_coders])
+
+ def is_deterministic(self):
+ return all(c.is_deterministic() for c in self._field_coders)
+
+ def to_type_hint(self):
+ from pyflink.table import Row
+ return Row
+
+ def __repr__(self):
+ return 'RowCoder[%s]' % ', '.join(str(c) for c in self._field_coders)
+
+ def __eq__(self, other):
+ return (self.__class__ == other.__class__
+ and len(self._field_coders) == len(other._field_coders)
+ and all(self._field_coders[i] == other._field_coders[i] for i in
+ xrange(len(self._field_coders))))
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash(tuple(self._field_coders))
+
+
+@Coder.register_urn(FLINK_SCHEMA_CODER_URN, flink_fn_execution_pb2.Schema)
+def _pickle_from_runner_api_parameter(schema_proto, unused_components, unused_context):
+ return RowCoder([from_proto(f.type) for f in schema_proto.fields])
+
+
+def from_proto(field_type):
+ """
+ Creates the corresponding :class:`Coder` given the protocol representation of the field type.
+
+ :param field_type: the protocol representation of the field type
+ :return: :class:`Coder`
+ """
+ if field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.BIGINT:
+ return VarIntCoder()
+ elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.ROW:
+ return RowCoder([from_proto(f.type) for f in field_type.row_schema.fields])
+ else:
+ raise ValueError("field_type %s is not supported." % field_type)
diff --git a/flink-python/pyflink/fn_execution/operations.py b/flink-python/pyflink/fn_execution/operations.py
new file mode 100644
index 00000000000000..13e179bb53c3fd
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -0,0 +1,261 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from abc import abstractmethod, ABCMeta
+
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import bundle_processor
+from apache_beam.runners.worker.operations import Operation
+
+from pyflink.fn_execution import flink_fn_execution_pb2
+
+SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1"
+
+
+class InputGetter(object):
+ """
+ Base class for getting an input argument of a :class:`UserDefinedFunction`.
+ """
+ __metaclass__ = ABCMeta
+
+ def open(self):
+ pass
+
+ def close(self):
+ pass
+
+ @abstractmethod
+ def get(self, value):
+ pass
+
+
+class OffsetInputGetter(InputGetter):
+ """
+ InputGetter for the input argument which is a column of the input row.
+
+ :param input_offset: the offset of the column in the input row
+ """
+
+ def __init__(self, input_offset):
+ self.input_offset = input_offset
+
+ def get(self, value):
+ return value[self.input_offset]
+
+
+class ScalarFunctionInputGetter(InputGetter):
+ """
+ InputGetter for the input argument which is a Python :class:`ScalarFunction`. This is used for
+ chaining Python functions.
+
+ :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
+ """
+
+ def __init__(self, scalar_function_proto):
+ self.scalar_function_invoker = create_scalar_function_invoker(scalar_function_proto)
+
+ def open(self):
+ self.scalar_function_invoker.invoke_open()
+
+ def close(self):
+ self.scalar_function_invoker.invoke_close()
+
+ def get(self, value):
+ return self.scalar_function_invoker.invoke_eval(value)
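+
+
+# How chaining plays out (illustrative comment only): for an expression such
+# as add(add_one(a), b), the proto for "add" carries one input of kind "udf"
+# (the nested add_one call) and one input whose inputOffset points at column
+# b, so the invoker built below ends up with the input getters
+#     [ScalarFunctionInputGetter(<add_one proto>), OffsetInputGetter(1)]
+# and evaluates the nested add_one before invoking add itself.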
+
+
+class ScalarFunctionInvoker(object):
+ """
+ An abstraction that can be used to execute :class:`ScalarFunction` methods.
+
+ A ScalarFunctionInvoker describes a particular way for invoking methods of a
+ :class:`ScalarFunction`.
+
+ :param scalar_function: the :class:`ScalarFunction` to execute
+ :param inputs: the input arguments for the :class:`ScalarFunction`
+ """
+
+ def __init__(self, scalar_function, inputs):
+ self.scalar_function = scalar_function
+ self.input_getters = []
+ for input in inputs:
+ if input.HasField("udf"):
+ # for chaining Python UDF input: the input argument is a Python ScalarFunction
+ self.input_getters.append(ScalarFunctionInputGetter(input.udf))
+ else:
+ # the input argument is a column of the input row
+ self.input_getters.append(OffsetInputGetter(input.inputOffset))
+
+ def invoke_open(self):
+ """
+ Invokes the ScalarFunction.open() function.
+ """
+ for input_getter in self.input_getters:
+ input_getter.open()
+ # set the FunctionContext to None for now
+ self.scalar_function.open(None)
+
+ def invoke_close(self):
+ """
+ Invokes the ScalarFunction.close() function.
+ """
+ for input_getter in self.input_getters:
+ input_getter.close()
+ self.scalar_function.close()
+
+ def invoke_eval(self, value):
+ """
+ Invokes the ScalarFunction.eval() function.
+
+ :param value: the input element for which the eval() method should be invoked
+ """
+ args = [input_getter.get(value) for input_getter in self.input_getters]
+ return self.scalar_function.eval(*args)
+
+
+def create_scalar_function_invoker(scalar_function_proto):
+ """
+ Creates :class:`ScalarFunctionInvoker` from the proto representation of a
+ :class:`ScalarFunction`.
+
+ :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction`
+ :return: :class:`ScalarFunctionInvoker`.
+ """
+ import cloudpickle
+ scalar_function = cloudpickle.loads(scalar_function_proto.payload)
+ return ScalarFunctionInvoker(scalar_function, scalar_function_proto.inputs)
+
+
+class ScalarFunctionRunner(object):
+ """
+ The runner which is responsible for executing the scalar functions and sending the
+ execution results back to the remote Java operator.
+
+ :param udfs_proto: protocol representation for the scalar functions to execute
+ """
+
+ def __init__(self, udfs_proto):
+ self.scalar_function_invokers = [create_scalar_function_invoker(f) for f in
+ udfs_proto]
+
+ def setup(self, main_receivers):
+ """
+ Set up the ScalarFunctionRunner.
+
+ :param main_receivers: Receiver objects which are responsible for sending the execution
+ results back to the remote Java operator
+ """
+ from apache_beam.runners.common import _OutputProcessor
+ self.output_processor = _OutputProcessor(
+ window_fn=None,
+ main_receivers=main_receivers,
+ tagged_receivers=None,
+ per_element_output_counter=None)
+
+ def open(self):
+ for invoker in self.scalar_function_invokers:
+ invoker.invoke_open()
+
+ def close(self):
+ for invoker in self.scalar_function_invokers:
+ invoker.invoke_close()
+
+ def process(self, windowed_value):
+ results = [invoker.invoke_eval(windowed_value.value) for invoker in
+ self.scalar_function_invokers]
+ from pyflink.table import Row
+ result = Row(*results)
+ # send the execution results back
+ self.output_processor.process_outputs(windowed_value, [result])
+
+
+class ScalarFunctionOperation(Operation):
+ """
+ An operation that will execute ScalarFunctions for each input element.
+ """
+
+ def __init__(self, name, spec, counter_factory, sampler, consumers):
+ super(ScalarFunctionOperation, self).__init__(name, spec, counter_factory, sampler)
+ for tag, op_consumers in consumers.items():
+ for consumer in op_consumers:
+ self.add_receiver(consumer, 0)
+
+ self.scalar_function_runner = ScalarFunctionRunner(self.spec.serialized_fn)
+ self.scalar_function_runner.open()
+
+ def setup(self):
+ with self.scoped_start_state:
+ super(ScalarFunctionOperation, self).setup()
+ self.scalar_function_runner.setup(self.receivers[0])
+
+ def start(self):
+ with self.scoped_start_state:
+ super(ScalarFunctionOperation, self).start()
+
+ def process(self, o):
+ with self.scoped_process_state:
+ self.scalar_function_runner.process(o)
+
+ def finish(self):
+ with self.scoped_finish_state:
+ super(ScalarFunctionOperation, self).finish()
+
+ def needs_finalization(self):
+ return False
+
+ def reset(self):
+ super(ScalarFunctionOperation, self).reset()
+
+ def teardown(self):
+ with self.scoped_finish_state:
+ self.scalar_function_runner.close()
+
+ def progress_metrics(self):
+ metrics = super(ScalarFunctionOperation, self).progress_metrics()
+ metrics.processed_elements.measured.output_element_counts.clear()
+ tag = None
+ receiver = self.receivers[0]
+ metrics.processed_elements.measured.output_element_counts[
+ str(tag)] = receiver.opcounter.element_counter.value()
+ return metrics
+
+
+@bundle_processor.BeamTransformFactory.register_urn(
+ SCALAR_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ return _create_user_defined_function_operation(
+ factory, transform_proto, consumers, parameter.udfs)
+
+
+def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
+ operation_cls=ScalarFunctionOperation):
+ output_tags = list(transform_proto.outputs.keys())
+ output_coders = factory.get_output_coders(transform_proto)
+ spec = operation_specs.WorkerDoFn(
+ serialized_fn=udfs_proto,
+ output_tags=output_tags,
+ input=None,
+ side_inputs=None,
+ output_coders=[output_coders[tag] for tag in output_tags])
+
+ return operation_cls(
+ transform_proto.unique_name,
+ spec,
+ factory.counter_factory,
+ factory.state_sampler,
+ consumers)
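+
+
+# End-to-end wiring (illustrative comment only): when the Java operator ships
+# a transform whose spec URN is "flink:transform:scalar_function:v1", Beam's
+# BeamTransformFactory resolves it through the registration above and calls
+# create(), which turns the UserDefinedFunctions proto into a
+# ScalarFunctionOperation wired to the stage's output receivers.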
diff --git a/flink-python/pyflink/fn_execution/sdk_worker_main.py b/flink-python/pyflink/fn_execution/sdk_worker_main.py
new file mode 100644
index 00000000000000..82d091c52953b7
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/sdk_worker_main.py
@@ -0,0 +1,30 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import sys
+
+ # imported to force registration of the operations with the SDK Harness
+import pyflink.fn_execution.operations # noqa # pylint: disable=unused-import
+
+ # imported to force registration of the coders with the SDK Harness
+import pyflink.fn_execution.coders # noqa # pylint: disable=unused-import
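+
+# boot.py launches this module via
+#     python -m pyflink.fn_execution.sdk_worker_main
+# so that both imports above take effect before Beam's worker main loop starts.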
+
+import apache_beam.runners.worker.sdk_worker_main
+
+if __name__ == '__main__':
+ apache_beam.runners.worker.sdk_worker_main.main(sys.argv)
diff --git a/flink-python/pyflink/gen_protos.py b/flink-python/pyflink/gen_protos.py
new file mode 100644
index 00000000000000..ce4ddd4010a405
--- /dev/null
+++ b/flink-python/pyflink/gen_protos.py
@@ -0,0 +1,146 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+import glob
+import logging
+import multiprocessing
+import os
+import platform
+import shutil
+import subprocess
+import sys
+import time
+import warnings
+
+import pkg_resources
+
+ # The latest grpcio-tools is incompatible with protobuf 3.6.1, hence the upper bound.
+GRPC_TOOLS = 'grpcio-tools>=1.3.5,<=1.14.2'
+
+PROTO_PATHS = [
+ os.path.join('proto'),
+]
+
+PYTHON_OUTPUT_PATH = os.path.join('fn_execution')
+
+
+def generate_proto_files(force=False):
+ try:
+ import grpc_tools # noqa # pylint: disable=unused-import
+ except ImportError:
+ warnings.warn('Installing grpcio-tools is recommended for development.')
+
+ py_sdk_root = os.path.dirname(os.path.abspath(__file__))
+ proto_dirs = [os.path.join(py_sdk_root, path) for path in PROTO_PATHS]
+ proto_files = sum(
+ [glob.glob(os.path.join(d, '*.proto')) for d in proto_dirs], [])
+ out_dir = os.path.join(py_sdk_root, PYTHON_OUTPUT_PATH)
+ out_files = [path for path in glob.glob(os.path.join(out_dir, '*_pb2.py'))]
+
+ if out_files and not proto_files and not force:
+ # We have out_files but no protos; assume they're up to date.
+ # This is actually the common case (e.g. installation from an sdist).
+ logging.info('No proto files; using existing generated files.')
+ return
+
+ elif not out_files and not proto_files:
+ raise RuntimeError(
+ 'No proto files found in %s.' % proto_dirs)
+
+ # Regenerate iff the proto files or this file are newer.
+ elif force or not out_files or len(out_files) < len(proto_files) or (
+ min(os.path.getmtime(path) for path in out_files)
+ <= max(os.path.getmtime(path)
+ for path in proto_files + [os.path.realpath(__file__)])):
+ try:
+ from grpc_tools import protoc
+ except ImportError:
+ if platform.system() == 'Windows':
+ # For Windows, grpcio-tools has to be installed manually.
+ raise RuntimeError(
+ 'Cannot generate protos for Windows since grpcio-tools package is '
+ 'not installed. Please install this package manually '
+ 'using \'pip install grpcio-tools\'.')
+
+ # Use a subprocess to avoid messing with this process' path and imports.
+ # Note that this requires a separate module from setup.py for Windows:
+ # https://docs.python.org/2/library/multiprocessing.html#windows
+ p = multiprocessing.Process(
+ target=_install_grpcio_tools_and_generate_proto_files)
+ p.start()
+ p.join()
+ if p.exitcode:
+ raise ValueError("Proto generation failed (see log for details).")
+ else:
+ logging.info('Regenerating out-of-date Python proto definitions.')
+ builtin_protos = pkg_resources.resource_filename('grpc_tools', '_proto')
+ args = (
+ [sys.executable] + # expecting to be called from command line
+ ['--proto_path=%s' % builtin_protos] +
+ ['--proto_path=%s' % d for d in proto_dirs] +
+ ['--python_out=%s' % out_dir] +
+ proto_files)
+ ret_code = protoc.main(args)
+ if ret_code:
+ raise RuntimeError(
+ 'Protoc returned non-zero status (see logs for details): '
+ '%s' % ret_code)
+
+
+# Though wheels are available for grpcio-tools, setup_requires uses
+# easy_install which doesn't understand them. This means that it is
+# compiled from scratch (which is expensive as it compiles the full
+# protoc compiler). Instead, we attempt to install a wheel in a temporary
+# directory and add it to the path as needed.
+# See https://github.com/pypa/setuptools/issues/377
+def _install_grpcio_tools_and_generate_proto_files():
+ install_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), '..', '.eggs', 'grpcio-wheels')
+ build_path = install_path + '-build'
+ if os.path.exists(build_path):
+ shutil.rmtree(build_path)
+ logging.warning('Installing grpcio-tools into %s', install_path)
+ try:
+ start = time.time()
+ subprocess.check_call(
+ [sys.executable, '-m', 'pip', 'install',
+ '--prefix', install_path, '--build', build_path,
+ '--upgrade', GRPC_TOOLS, "-I"])
+ from distutils.dist import Distribution
+ install_obj = Distribution().get_command_obj('install', create=True)
+ install_obj.prefix = install_path
+ install_obj.finalize_options()
+ logging.warning(
+ 'Installing grpcio-tools took %0.2f seconds.', time.time() - start)
+ finally:
+ sys.stderr.flush()
+ shutil.rmtree(build_path, ignore_errors=True)
+ sys.path.append(install_obj.install_purelib)
+ if install_obj.install_purelib != install_obj.install_platlib:
+ sys.path.append(install_obj.install_platlib)
+ try:
+ generate_proto_files()
+ finally:
+ sys.stderr.flush()
+
+
+if __name__ == '__main__':
+ generate_proto_files(force=True)
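+
+# The exec-maven-plugin entry in flink-python/pom.xml runs this module during
+# the generate-sources phase; it can also be run manually from flink-python/:
+#
+#     python pyflink/gen_protos.py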
diff --git a/flink-python/pyflink/table/__init__.py b/flink-python/pyflink/table/__init__.py
index e69a9b7ecb0cd8..82ce28a34063c6 100644
--- a/flink-python/pyflink/table/__init__.py
+++ b/flink-python/pyflink/table/__init__.py
@@ -52,6 +52,11 @@
from a registered :class:`pyflink.table.catalog.Catalog`.
- :class:`pyflink.table.TableSchema`
Represents a table's structure with field names and data types.
+ - :class:`pyflink.table.FunctionContext`
+ Used to obtain global runtime information about the context in which the
+ user-defined function is executed, such as the metric group and global job parameters.
+ - :class:`pyflink.table.ScalarFunction`
+ Base interface for user-defined scalar function.
"""
from __future__ import absolute_import
@@ -65,6 +70,7 @@
from pyflink.table.sources import TableSource, CsvTableSource
from pyflink.table.types import DataTypes, UserDefinedType, Row
from pyflink.table.table_schema import TableSchema
+from pyflink.table.udf import FunctionContext, ScalarFunction
__all__ = [
'TableEnvironment',
@@ -85,5 +91,7 @@
'DataTypes',
'UserDefinedType',
'Row',
- 'TableSchema'
+ 'TableSchema',
+ 'FunctionContext',
+ 'ScalarFunction'
]
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 33160e0cdbef5d..940aafcf62b12d 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -542,6 +542,36 @@ def register_java_function(self, name, function_class_name):
.loadClass(function_class_name).newInstance()
self._j_tenv.registerFunction(name, java_function)
+ def register_function(self, name, function):
+ """
+ Registers a Python user-defined function under a unique name. Replaces an already existing
+ user-defined function under this name.
+
+ Example:
+ ::
+
+ >>> table_env.register_function(
+ ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
+ ... result_type=DataTypes.BIGINT())
+ ... def add(i, j):
+ ... return i + j
+ >>> table_env.register_function("add", add)
+
+ >>> class SubtractOne(ScalarFunction):
+ ... def eval(self, i):
+ ... return i - 1
+ >>> table_env.register_function(
+ ... "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ :param name: The name under which the function is registered.
+ :type name: str
+ :param function: The Python user-defined function to register.
+ :type function: UserDefinedFunctionWrapper
+ """
+ self._j_tenv.registerFunction(name, function._judf)
+
def execute(self, job_name):
"""
Triggers the program execution. The environment will execute all parts of
diff --git a/flink-python/pyflink/table/tests/test_environment_completeness.py b/flink-python/pyflink/table/tests/test_environment_completeness.py
index 5f73c3c88b5670..89459891f8f3e5 100644
--- a/flink-python/pyflink/table/tests/test_environment_completeness.py
+++ b/flink-python/pyflink/table/tests/test_environment_completeness.py
@@ -41,8 +41,7 @@ def excluded_methods(cls):
# registerCatalog, getCatalog and listTables should be supported when catalog supported in
# python. getCompletionHints has been deprecated. It will be removed in the next release.
# TODO add TableEnvironment#create method with EnvironmentSettings as a parameter
- return {'registerCatalog', 'getCatalog', 'registerFunction', 'listTables',
- 'getCompletionHints', 'create'}
+ return {'registerCatalog', 'getCatalog', 'listTables', 'getCompletionHints', 'create'}
if __name__ == '__main__':
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
new file mode 100644
index 00000000000000..bf46ea9a5d3fd8
--- /dev/null
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -0,0 +1,245 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.table import DataTypes
+from pyflink.table.udf import ScalarFunction, udf
+from pyflink.testing import source_sink_utils
+from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase
+
+
+class UserDefinedFunctionTests(PyFlinkStreamTableTestCase):
+
+ def test_scalar_function(self):
+ # test lambda function
+ self.t_env.register_function(
+ "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ # test Python ScalarFunction
+ self.t_env.register_function(
+ "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ # test Python function
+ self.t_env.register_function("add", add)
+
+ # test callable function
+ self.t_env.register_function(
+ "add_one_callable", udf(CallablePlus(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ def partial_func(col, param):
+ return col + param
+
+ # test partial function
+ import functools
+ self.t_env.register_function(
+ "add_one_partial",
+ udf(functools.partial(partial_func, param=1), DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ table_sink = source_sink_utils.TestAppendSink(
+ ['a', 'b', 'c', 'd', 'e'],
+ [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(),
+ DataTypes.BIGINT()])
+ self.t_env.register_table_sink("Results", table_sink)
+
+ t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
+ t.where("add_one(b) <= 3") \
+ .select("add_one(a), subtract_one(b), add(a, c), add_one_callable(a), "
+ "add_one_partial(a)") \
+ .insert_into("Results")
+ self.t_env.execute("test")
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["2,1,4,2,2", "4,0,12,4,4"])
+
+ def test_chaining_scalar_function(self):
+ self.t_env.register_function(
+ "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()))
+ self.t_env.register_function(
+ "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ self.t_env.register_function("add", add)
+
+ table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()])
+ self.t_env.register_table_sink("Results", table_sink)
+
+ t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b'])
+ t.select("add(add_one(a), subtract_one(b))") \
+ .insert_into("Results")
+ self.t_env.execute("test")
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["3", "7", "4"])
+
+ def test_udf_in_join_condition(self):
+ t1 = self.t_env.from_elements([(2, "Hi")], ['a', 'b'])
+ t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd'])
+
+ self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ table_sink = source_sink_utils.TestAppendSink(
+ ['a', 'b', 'c', 'd'],
+ [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()])
+ self.t_env.register_table_sink("Results", table_sink)
+
+ t1.join(t2).where("f(a) = c").insert_into("Results")
+ self.t_env.execute("test")
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["2,Hi,2,Flink"])
+
+ def test_udf_in_join_condition_2(self):
+ t1 = self.t_env.from_elements([(1, "Hi"), (2, "Hi")], ['a', 'b'])
+ t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd'])
+
+ self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+ table_sink = source_sink_utils.TestAppendSink(
+ ['a', 'b', 'c', 'd'],
+ [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()])
+ self.t_env.register_table_sink("Results", table_sink)
+
+ t1.join(t2).where("f(a) = f(c)").insert_into("Results")
+ self.t_env.execute("test")
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["2,Hi,2,Flink"])
+
+ def test_overwrite_builtin_function(self):
+ self.t_env.register_function(
+ "plus", udf(lambda i, j: i + j - 1,
+ [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT()))
+
+ table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()])
+ self.t_env.register_table_sink("Results", table_sink)
+
+ t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c'])
+ t.select("plus(a, b)").insert_into("Results")
+ self.t_env.execute("test")
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["2", "6", "3"])
+
+ def test_open(self):
+ self.t_env.register_function(
+ "subtract", udf(Subtract(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+ table_sink = source_sink_utils.TestAppendSink(
+ ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()])
+ self.t_env.register_table_sink("Results", table_sink)
+
+ t = self.t_env.from_elements([(1, 2), (2, 5), (3, 4)], ['a', 'b'])
+ t.select("a, subtract(b)").insert_into("Results")
+ self.t_env.execute("test")
+ actual = source_sink_utils.results()
+ self.assert_equals(actual, ["1,1", "2,4", "3,3"])
+
+ def test_deterministic(self):
+ add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
+ self.assertTrue(add_one._deterministic)
+
+ add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False)
+ self.assertFalse(add_one._deterministic)
+
+ subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
+ self.assertTrue(subtract_one._deterministic)
+
+ with self.assertRaises(ValueError, msg="Inconsistent deterministic: False and True"):
+ udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False)
+
+ self.assertTrue(add._deterministic)
+
+ @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), deterministic=False)
+ def non_deterministic_udf(i):
+ return i
+
+ self.assertFalse(non_deterministic_udf._deterministic)
+
+ def test_name(self):
+ add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
+ self.assertEqual("", add_one._name)
+
+ add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), name="add_one")
+ self.assertEqual("add_one", add_one._name)
+
+ subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
+ self.assertEqual("SubtractOne", subtract_one._name)
+
+ subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(),
+ name="subtract_one")
+ self.assertEqual("subtract_one", subtract_one._name)
+
+ self.assertEqual("add", add._name)
+
+ @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), name="named")
+ def named_udf(i):
+ return i
+
+ self.assertEqual("named", named_udf._name)
+
+ def test_abc(self):
+ class UdfWithoutEval(ScalarFunction):
+ def open(self, function_context):
+ pass
+
+ with self.assertRaises(
+ TypeError,
+ msg="Can't instantiate abstract class UdfWithoutEval with abstract methods eval"):
+ UdfWithoutEval()
+
+ def test_invalid_udf(self):
+ class Plus(object):
+ def eval(self, col):
+ return col + 1
+
+ with self.assertRaises(
+ TypeError,
+ msg="Invalid function: not a function or callable (__call__ is not defined)"):
+ # test non-callable function
+ self.t_env.register_function(
+ "non-callable-udf", udf(Plus(), DataTypes.BIGINT(), DataTypes.BIGINT()))
+
+
+@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
+def add(i, j):
+ return i + j
+
+
+class SubtractOne(ScalarFunction):
+
+ def eval(self, i):
+ return i - 1
+
+
+class Subtract(ScalarFunction):
+
+ def __init__(self):
+ self.subtracted_value = 0
+
+ def open(self, function_context):
+ self.subtracted_value = 1
+
+ def eval(self, i):
+ return i - self.subtracted_value
+
+
+class CallablePlus(object):
+
+ def __call__(self, col):
+ return col + 1
+
+
+if __name__ == '__main__':
+ import unittest
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index dbf92f88bb8c25..09186240a62904 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2005,7 +2005,7 @@ def __repr__(self):
return "Row(%s)" % ", ".join("%s=%r" % (k, v)
for k, v in zip(self._fields, tuple(self)))
else:
- return "" % ", ".join(self)
+ return "" % ", ".join("%r" % field for field in self)
_acceptable_types = {
diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py
new file mode 100644
index 00000000000000..97ac3910708e68
--- /dev/null
+++ b/flink-python/pyflink/table/udf.py
@@ -0,0 +1,229 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import abc
+import collections
+import functools
+import inspect
+import sys
+
+from pyflink.java_gateway import get_gateway
+from pyflink.table.types import DataType, _to_java_type
+from pyflink.util import utils
+
+__all__ = ['FunctionContext', 'ScalarFunction', 'udf']
+
+
+if sys.version_info >= (3, 4):
+ ABC = abc.ABC
+else:
+ ABC = abc.ABCMeta('ABC', (), {})
+
+
+class FunctionContext(object):
+ """
+ Used to obtain global runtime information about the context in which the
+ user-defined function is executed. The information includes the metric group
+ and global job parameters.
+ """
+ pass
+
+
+class UserDefinedFunction(ABC):
+ """
+ Base interface for user-defined function.
+ """
+
+ def open(self, function_context):
+ """
+ Initialization method for the function. It is called before the actual working methods
+ and thus suitable for one-time setup work.
+
+ :param function_context: the context of the function
+ :type function_context: FunctionContext
+ """
+ pass
+
+ def close(self):
+ """
+ Tear-down method for the user code. It is called after the last call to the main
+ working methods.
+ """
+ pass
+
+ def is_deterministic(self):
+ """
+ Returns information about the determinism of the function's results.
+ It returns True if and only if a call to this function is guaranteed to
+ always return the same result given the same parameters. True is assumed by default.
+ If the function is not purely functional (e.g. random(), date(), now()),
+ this method must return False.
+
+ :return: the determinism of the function's results.
+ :rtype: bool
+ """
+ return True
+
+
+class ScalarFunction(UserDefinedFunction):
+ """
+ Base interface for user-defined scalar function. A user-defined scalar function maps zero, one,
+ or multiple scalar values to a new scalar value.
+ """
+
+ @abc.abstractmethod
+ def eval(self, *args):
+ """
+ Method which defines the logic of the scalar function.
+ """
+ pass
+
+
+class DelegatingScalarFunction(ScalarFunction):
+ """
+ Helper scalar function implementation for lambda expressions and Python functions. It's for
+ internal use only.
+ """
+
+ def __init__(self, func):
+ self.func = func
+
+ def eval(self, *args):
+ return self.func(*args)
+
+
+class UserDefinedFunctionWrapper(object):
+ """
+ Wrapper for Python user-defined function. It handles things like converting lambda
+ functions to user-defined functions, creating the Java user-defined function representation,
+ etc. It's for internal use only.
+ """
+
+ def __init__(self, func, input_types, result_type, deterministic=None, name=None):
+ if inspect.isclass(func) or (
+ not isinstance(func, UserDefinedFunction) and not callable(func)):
+ raise TypeError(
+ "Invalid function: not a function or callable (__call__ is not defined): {0}"
+ .format(type(func)))
+
+ if not isinstance(input_types, collections.Iterable):
+ input_types = [input_types]
+
+ for input_type in input_types:
+ if not isinstance(input_type, DataType):
+ raise TypeError(
+ "Invalid input_type: input_type should be DataType but contains {}".format(
+ input_type))
+
+ if not isinstance(result_type, DataType):
+ raise TypeError(
+ "Invalid returnType: returnType should be DataType but is {}".format(result_type))
+
+ self._func = func
+ self._input_types = input_types
+ self._result_type = result_type
+ self._judf_placeholder = None
+ self._name = name or (
+ func.__name__ if hasattr(func, '__name__') else func.__class__.__name__)
+
+ if deterministic is not None and isinstance(func, UserDefinedFunction) and deterministic \
+ != func.is_deterministic():
+ raise ValueError("Inconsistent deterministic: {} and {}".format(
+ deterministic, func.is_deterministic()))
+
+ # default deterministic is True
+ self._deterministic = deterministic if deterministic is not None else (
+ func.is_deterministic() if isinstance(func, UserDefinedFunction) else True)
+
+ @property
+ def _judf(self):
+ if self._judf_placeholder is None:
+ self._judf_placeholder = self._create_judf()
+ return self._judf_placeholder
+
+ def _create_judf(self):
+ func = self._func
+ if not isinstance(self._func, UserDefinedFunction):
+ func = DelegatingScalarFunction(self._func)
+
+ import cloudpickle
+ serialized_func = cloudpickle.dumps(func)
+
+ gateway = get_gateway()
+ j_input_types = utils.to_jarray(gateway.jvm.TypeInformation,
+ [_to_java_type(i) for i in self._input_types])
+ j_result_type = _to_java_type(self._result_type)
+ return gateway.jvm.org.apache.flink.table.util.python.PythonTableUtils \
+ .createPythonScalarFunction(self._name,
+ bytearray(serialized_func),
+ j_input_types,
+ j_result_type,
+ self._deterministic,
+ _get_python_env())
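+
+ # Serialization handoff (illustrative comment only): the cloudpickle bytes
+ # produced above are embedded, base64-encoded, in the generated Java
+ # ScalarFunction by PythonFunctionCodeGenerator, and the Python worker
+ # (operations.create_scalar_function_invoker) restores the function with
+ # cloudpickle.loads on the other side of the Beam channel.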
+
+
+# TODO: support to configure the python execution environment
+def _get_python_env():
+ gateway = get_gateway()
+ exec_type = gateway.jvm.org.apache.flink.table.functions.python.PythonEnv.ExecType.PROCESS
+ return gateway.jvm.org.apache.flink.table.functions.python.PythonEnv(exec_type)
+
+
+def _create_udf(f, input_types, result_type, deterministic, name):
+ return UserDefinedFunctionWrapper(f, input_types, result_type, deterministic, name)
+
+
+def udf(f=None, input_types=None, result_type=None, deterministic=None, name=None):
+ """
+ Helper method for creating a user-defined function.
+
+ Example:
+ ::
+
+ >>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
+
+ >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
+ ... result_type=DataTypes.BIGINT())
+ ... def add(i, j):
+ ... return i + j
+
+ >>> class SubtractOne(ScalarFunction):
+ ... def eval(self, i):
+ ... return i - 1
+ >>> subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())
+
+ :param f: lambda function or user-defined function.
+ :type f: function or UserDefinedFunction or type
+ :param input_types: the input data types.
+ :type input_types: list[DataType] or DataType
+ :param result_type: the result data type.
+ :type result_type: DataType
+ :param name: the function name.
+ :type name: str
+ :param deterministic: the determinism of the function's results. True if and only if a call to
+ this function is guaranteed to always return the same result given the
+ same parameters. (default True)
+ :type deterministic: bool
+ :return: UserDefinedFunctionWrapper or function.
+ :rtype: UserDefinedFunctionWrapper or function
+ """
+ # decorator
+ if f is None:
+ return functools.partial(_create_udf, input_types=input_types, result_type=result_type,
+ deterministic=deterministic, name=name)
+ else:
+ return _create_udf(f, input_types, result_type, deterministic, name)
diff --git a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py
index c1d484ed01c041..21d3f09e629aba 100644
--- a/flink-python/pyflink/testing/test_case_utils.py
+++ b/flink-python/pyflink/testing/test_case_utils.py
@@ -27,11 +27,10 @@
from py4j.java_gateway import JavaObject
from py4j.protocol import Py4JJavaError
+from pyflink import gen_protos
from pyflink.table.sources import CsvTableSource
-
from pyflink.dataset import ExecutionEnvironment
from pyflink.datastream import StreamExecutionEnvironment
-
from pyflink.find_flink_home import _find_flink_home
from pyflink.table import BatchTableEnvironment, StreamTableEnvironment
from pyflink.java_gateway import get_gateway
@@ -76,6 +75,8 @@ class PyFlinkTestCase(unittest.TestCase):
def setUpClass(cls):
cls.tempdir = tempfile.mkdtemp()
+ gen_protos.generate_proto_files()
+
os.environ["FLINK_TESTING"] = "1"
_find_flink_home()
diff --git a/flink-python/setup.py b/flink-python/setup.py
index 78a6edd6beb320..f5b26da98e7ee6 100644
--- a/flink-python/setup.py
+++ b/flink-python/setup.py
@@ -23,6 +23,12 @@
from shutil import copytree, copy, rmtree
from setuptools import setup
+from setuptools.command.install import install
+from setuptools.command.build_py import build_py
+from setuptools.command.develop import develop
+from setuptools.command.egg_info import egg_info
+from setuptools.command.sdist import sdist
+from setuptools.command.test import test
if sys.version_info < (2, 7):
print("Python versions prior to 2.7 are not supported for PyFlink.",
@@ -62,6 +68,24 @@
in_flink_source = os.path.isfile("../flink-java/src/main/java/org/apache/flink/api/java/"
"ExecutionEnvironment.java")
+
+# We must generate protos after setup_requires are installed.
+def generate_protos_first(original_cmd):
+ try:
+ # pylint: disable=wrong-import-position
+ from pyflink import gen_protos
+
+ class cmd(original_cmd, object):
+ def run(self):
+ gen_protos.generate_proto_files()
+ super(cmd, self).run()
+ return cmd
+ except ImportError:
+ import warnings
+ warnings.warn("Could not import gen_protos, skipping proto generation.")
+ return original_cmd
+
+
try:
if in_flink_source:
@@ -184,7 +208,8 @@
license='https://www.apache.org/licenses/LICENSE-2.0',
author='Flink Developers',
author_email='dev@flink.apache.org',
- install_requires=['py4j==0.10.8.1', 'python-dateutil', 'apache-beam==2.15.0'],
+ install_requires=['py4j==0.10.8.1', 'python-dateutil', 'apache-beam==2.15.0',
+ 'cloudpickle==1.2.2'],
tests_require=['pytest==4.4.1'],
description='Apache Flink Python API',
long_description=long_description,
@@ -195,7 +220,15 @@
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7']
+ 'Programming Language :: Python :: 3.7'],
+ cmdclass={
+ 'build_py': generate_protos_first(build_py),
+ 'develop': generate_protos_first(develop),
+ 'egg_info': generate_protos_first(egg_info),
+ 'sdist': generate_protos_first(sdist),
+ 'test': generate_protos_first(test),
+ 'install': generate_protos_first(install),
+ },
)
finally:
if in_flink_source:
diff --git a/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java
index 73c753544a0cee..53e6ef94fb1674 100644
--- a/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java
+++ b/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java
@@ -20,6 +20,7 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.table.functions.python.PythonEnv;
@@ -50,8 +51,6 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
import java.util.Random;
/**
@@ -280,14 +279,15 @@ private static String randomString(Random random) {
*/
protected RunnerApi.Environment createPythonExecutionEnvironment() {
if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) {
- final Map<String, String> env = new HashMap<>(2);
- env.put("python", pythonEnv.getPythonExec());
+ String flinkHomePath = System.getenv(ConfigConstants.ENV_FLINK_HOME_DIR);
+ String pythonWorkerCommand =
+ flinkHomePath + File.separator + "bin" + File.separator + "pyflink-udf-runner.sh";
return Environments.createProcessEnvironment(
"",
"",
- pythonEnv.getPythonWorkerCmd(),
- env);
+ pythonWorkerCommand,
+ null);
} else {
throw new UnsupportedOperationException(String.format(
"Execution type '%s' is not supported.", pythonEnv.getExecType()));
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java
index c53f69a06b37f4..4c28d6a8515d53 100644
--- a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java
@@ -121,7 +121,7 @@ public byte[] getSerializedPythonFunction() {
@Override
public PythonEnv getPythonEnv() {
- return new PythonEnv("", "", PythonEnv.ExecType.PROCESS);
+ return new PythonEnv(PythonEnv.ExecType.PROCESS);
}
}
}
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java
index 02f5b7462ca8c6..dacd5092b54f47 100644
--- a/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java
@@ -116,7 +116,7 @@ public AbstractPythonScalarFunctionRunner<BaseRow> createPythonScalarFunctionRunner(
// ignore the execution results
};
- final PythonEnv pythonEnv = new PythonEnv("", "", PythonEnv.ExecType.PROCESS);
+ final PythonEnv pythonEnv = new PythonEnv(PythonEnv.ExecType.PROCESS);
return new BaseRowPythonScalarFunctionRunner(
"testPythonRunner",
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java
index 6a9398a202e758..212bf2a6e8e964 100644
--- a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java
@@ -244,7 +244,7 @@ public AbstractPythonScalarFunctionRunner<Row> createPythonScalarFunctionRunner(
// ignore the execution results
};
- final PythonEnv pythonEnv = new PythonEnv("", "", PythonEnv.ExecType.PROCESS);
+ final PythonEnv pythonEnv = new PythonEnv(PythonEnv.ExecType.PROCESS);
return new PythonScalarFunctionRunner(
"testPythonRunner",
@@ -266,7 +266,7 @@ private AbstractPythonScalarFunctionRunner<Row> createUDFRunner(
RowType rowType = new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType())));
- final PythonEnv pythonEnv = new PythonEnv("", "", PythonEnv.ExecType.PROCESS);
+ final PythonEnv pythonEnv = new PythonEnv(PythonEnv.ExecType.PROCESS);
return new PythonScalarFunctionRunnerTestHarness(
"testPythonRunner",
diff --git a/flink-python/tox.ini b/flink-python/tox.ini
index 4efde543a02054..873af202864577 100644
--- a/flink-python/tox.ini
+++ b/flink-python/tox.ini
@@ -38,4 +38,4 @@ commands =
# up to 100 characters in length, not 79.
ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504
max-line-length=100
-exclude=.tox/*,dev/*,lib/*,target/*,build/*,dist/*,pyflink/shell.py,.eggs/*,pyflink/fn_execution/tests/process_mode_test_data.py
+exclude=.tox/*,dev/*,lib/*,target/*,build/*,dist/*,pyflink/shell.py,.eggs/*,pyflink/fn_execution/tests/process_mode_test_data.py,pyflink/fn_execution/*_pb2.py
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java
index 5d2be94cb9ca82..d01a08d95ccbf9 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java
@@ -31,38 +31,15 @@ public final class PythonEnv implements Serializable {
private static final long serialVersionUID = 1L;
- /**
- * The path of the Python executable file used.
- */
- private final String pythonExec;
-
- /**
- * The command to start Python worker process.
- */
- private final String pythonWorkerCmd;
-
/**
* The execution type of the Python worker, it defines how to execute the Python functions.
*/
private final ExecType execType;
- public PythonEnv(
- String pythonExec,
- String pythonWorkerCmd,
- ExecType execType) {
- this.pythonExec = Preconditions.checkNotNull(pythonExec);
- this.pythonWorkerCmd = Preconditions.checkNotNull(pythonWorkerCmd);
+ public PythonEnv(ExecType execType) {
this.execType = Preconditions.checkNotNull(execType);
}
- public String getPythonExec() {
- return pythonExec;
- }
-
- public String getPythonWorkerCmd() {
- return pythonWorkerCmd;
- }
-
public ExecType getExecType() {
return execType;
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/PythonFunctionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/PythonFunctionCodeGenerator.scala
new file mode 100644
index 00000000000000..e97e7b79eeb797
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/PythonFunctionCodeGenerator.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.codegen
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.codegen.CodeGenUtils.{primitiveDefaultValue, primitiveTypeTermForTypeInfo, newName}
+import org.apache.flink.table.codegen.Indenter.toISC
+import org.apache.flink.table.functions.{UserDefinedFunction, FunctionLanguage, ScalarFunction}
+import org.apache.flink.table.functions.python.{PythonEnv, PythonFunction}
+import org.apache.flink.table.utils.EncodingUtils
+
+/**
+ * A code generator for generating Python [[UserDefinedFunction]]s.
+ */
+object PythonFunctionCodeGenerator extends Compiler[UserDefinedFunction] {
+
+ private val PYTHON_SCALAR_FUNCTION_NAME = "PythonScalarFunction"
+
+ /**
+ * Generates a [[ScalarFunction]] for the specified Python user-defined function.
+ *
+ * @param name class name of the user-defined function. Must be a valid Java class identifier
+ * @param serializedScalarFunction serialized Python scalar function
+ * @param inputTypes input data types
+ * @param resultType expected result type
+ * @param deterministic the determinism of the function's results
+ * @param pythonEnv the Python execution environment
+ * @return instance of generated ScalarFunction
+ */
+ def generateScalarFunction(
+ name: String,
+ serializedScalarFunction: Array[Byte],
+ inputTypes: Array[TypeInformation[_]],
+ resultType: TypeInformation[_],
+ deterministic: Boolean,
+ pythonEnv: PythonEnv): ScalarFunction = {
+ val funcName = newName(PYTHON_SCALAR_FUNCTION_NAME)
+ val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
+ val defaultResultValue = primitiveDefaultValue(resultType)
+ val inputParamCode = inputTypes.zipWithIndex.map { case (inputType, index) =>
+ s"${primitiveTypeTermForTypeInfo(inputType)} in$index"
+ }.mkString(", ")
+
+ val encodingUtilsTypeTerm = classOf[EncodingUtils].getCanonicalName
+ val typeInfoTypeTerm = classOf[TypeInformation[_]].getCanonicalName
+ val inputTypesCode = inputTypes.map(EncodingUtils.encodeObjectToString).map { inputType =>
+ s"""
+ |($typeInfoTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject(
+ | "$inputType", $typeInfoTypeTerm.class)
+ |""".stripMargin
+ }.mkString(", ")
+
+ val encodedResultType = EncodingUtils.encodeObjectToString(resultType)
+ val encodedScalarFunction = EncodingUtils.encodeBytesToBase64(serializedScalarFunction)
+ val encodedPythonEnv = EncodingUtils.encodeObjectToString(pythonEnv)
+ val pythonEnvTypeTerm = classOf[PythonEnv].getCanonicalName
+
+ val funcCode = j"""
+ |public class $funcName extends ${classOf[ScalarFunction].getCanonicalName}
+ | implements ${classOf[PythonFunction].getCanonicalName} {
+ |
+ | private static final long serialVersionUID = 1L;
+ |
+ | public $resultTypeTerm eval($inputParamCode) {
+ | return $defaultResultValue;
+ | }
+ |
+ | @Override
+ | public $typeInfoTypeTerm[] getParameterTypes(Class<?>[] signature) {
+ | return new $typeInfoTypeTerm[]{$inputTypesCode};
+ | }
+ |
+ | @Override
+ | public $typeInfoTypeTerm getResultType(Class<?>[] signature) {
+ | return ($typeInfoTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject(
+ | "$encodedResultType", $typeInfoTypeTerm.class);
+ | }
+ |
+ | @Override
+ | public ${classOf[FunctionLanguage].getCanonicalName} getLanguage() {
+ | return ${classOf[FunctionLanguage].getCanonicalName}.PYTHON;
+ | }
+ |
+ | @Override
+ | public byte[] getSerializedPythonFunction() {
+ | return $encodingUtilsTypeTerm.decodeBase64ToBytes("$encodedScalarFunction");
+ | }
+ |
+ | @Override
+ | public $pythonEnvTypeTerm getPythonEnv() {
+ | return ($pythonEnvTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject(
+ | "$encodedPythonEnv", $pythonEnvTypeTerm.class);
+ | }
+ |
+ | @Override
+ | public boolean isDeterministic() {
+ | return $deterministic;
+ | }
+ |
+ | @Override
+ | public String toString() {
+ | return "$name";
+ | }
+ |}
+ |""".stripMargin
+
+ val clazz = compile(
+ Thread.currentThread().getContextClassLoader,
+ funcName,
+ funcCode)
+ clazz.newInstance().asInstanceOf[ScalarFunction]
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
index 094945c344bbfc..a84dad81599bce 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala
@@ -31,6 +31,9 @@ import org.apache.flink.api.java.io.CollectionInputFormat
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
import org.apache.flink.core.io.InputSplit
import org.apache.flink.table.api.{TableSchema, Types}
+import org.apache.flink.table.codegen.PythonFunctionCodeGenerator
+import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.functions.python.PythonEnv
import org.apache.flink.table.sources.InputFormatTableSource
import org.apache.flink.types.Row
@@ -38,6 +41,32 @@ import scala.collection.JavaConversions._
object PythonTableUtils {
+ /**
+ * Creates a [[ScalarFunction]] for the specified Python ScalarFunction.
+ *
+ * @param funcName class name of the user-defined function. Must be a valid Java class identifier
+ * @param serializedScalarFunction serialized Python scalar function
+ * @param inputTypes input data types
+ * @param resultType expected result type
+ * @param deterministic the determinism of the function's results
+ * @param pythonEnv the Python execution environment
+ * @return A generated Java ScalarFunction representation for the specified Python ScalarFunction
+ */
+ def createPythonScalarFunction(
+ funcName: String,
+ serializedScalarFunction: Array[Byte],
+ inputTypes: Array[TypeInformation[_]],
+ resultType: TypeInformation[_],
+ deterministic: Boolean,
+ pythonEnv: PythonEnv): ScalarFunction =
+ PythonFunctionCodeGenerator.generateScalarFunction(
+ funcName,
+ serializedScalarFunction,
+ inputTypes,
+ resultType,
+ deterministic,
+ pythonEnv)
+
/**
* Wrap the unpickled python data with an InputFormat. It will be passed to
* PythonInputFormatTableSource later.
diff --git a/pom.xml b/pom.xml
index 5a8b5983917b0b..832e298d850f5e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1422,6 +1422,7 @@ under the License.
<exclude>flink-python/lib/**</exclude>
<exclude>flink-python/dev/download/**</exclude>
<exclude>flink-python/docs/_build/**</exclude>
+<exclude>flink-python/pyflink/fn_execution/*_pb2.py</exclude>