diff --git a/.gitignore b/.gitignore
index 88c448814000a9..4e3a5ae2f2e50c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,8 @@ flink-python/dev/download
 flink-python/dev/.conda/
 flink-python/dev/log/
 flink-python/dev/.stage.txt
+flink-python/.eggs/
+flink-python/pyflink/fn_execution/*_pb2.py
 atlassian-ide-plugin.xml
 out/
 /docs/api
diff --git a/flink-python/pom.xml b/flink-python/pom.xml
index ed16daa09e3a89..ffeb4864c5e5aa 100644
--- a/flink-python/pom.xml
+++ b/flink-python/pom.xml
@@ -327,6 +327,26 @@ under the License.
+            <plugin>
+                <artifactId>exec-maven-plugin</artifactId>
+                <groupId>org.codehaus.mojo</groupId>
+                <version>1.5.0</version>
+                <executions>
+                    <execution>
+                        <id>Protos Generation</id>
+                        <phase>generate-sources</phase>
+                        <goals>
+                            <goal>exec</goal>
+                        </goals>
+                        <configuration>
+                            <executable>python</executable>
+                            <arguments>
+                                <argument>${basedir}/pyflink/gen_protos.py</argument>
+                            </arguments>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
diff --git a/flink-python/pyflink/fn_execution/boot.py b/flink-python/pyflink/fn_execution/boot.py
index e42ad2eea58857..bf913846fb2bfc 100644
--- a/flink-python/pyflink/fn_execution/boot.py
+++ b/flink-python/pyflink/fn_execution/boot.py
@@ -145,5 +145,5 @@ def check_not_empty(check_str, error_message):
 if "FLINK_BOOT_TESTING" in os.environ and os.environ["FLINK_BOOT_TESTING"] == "1":
     exit(0)
 
-call([sys.executable, "-m", "apache_beam.runners.worker.sdk_worker_main"],
+call([sys.executable, "-m", "pyflink.fn_execution.sdk_worker_main"],
      stdout=sys.stdout, stderr=sys.stderr, env=env)
diff --git a/flink-python/pyflink/fn_execution/coder_impl.py b/flink-python/pyflink/fn_execution/coder_impl.py
new file mode 100644
index 00000000000000..2f7c5d3ab6536f
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/coder_impl.py
@@ -0,0 +1,80 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################ +import sys + +from apache_beam.coders.coder_impl import StreamCoderImpl + +if sys.version > '3': + xrange = range + + +class RowCoderImpl(StreamCoderImpl): + + def __init__(self, field_coders): + self._field_coders = field_coders + + def encode_to_stream(self, value, out_stream, nested): + self.write_null_mask(value, out_stream) + for i in xrange(len(self._field_coders)): + self._field_coders[i].encode_to_stream(value[i], out_stream, nested) + + def decode_from_stream(self, in_stream, nested): + from pyflink.table import Row + null_mask = self.read_null_mask(len(self._field_coders), in_stream) + assert len(null_mask) == len(self._field_coders) + return Row(*[None if null_mask[idx] else self._field_coders[idx].decode_from_stream( + in_stream, nested) for idx in xrange(0, len(null_mask))]) + + @staticmethod + def write_null_mask(value, out_stream): + field_pos = 0 + field_count = len(value) + while field_pos < field_count: + b = 0x00 + # set bits in byte + num_pos = min(8, field_count - field_pos) + byte_pos = 0 + while byte_pos < num_pos: + b = b << 1 + # set bit if field is null + if value[field_pos + byte_pos] is None: + b |= 0x01 + byte_pos += 1 + field_pos += num_pos + # shift bits if last byte is not completely filled + b <<= (8 - byte_pos) + # write byte + out_stream.write_byte(b) + + @staticmethod + def read_null_mask(field_count, in_stream): + null_mask = [] + field_pos = 0 + while field_pos < field_count: + b = in_stream.read_byte() + num_pos = min(8, field_count - field_pos) + byte_pos = 0 + while byte_pos < num_pos: + null_mask.append((b & 0x80) > 0) + b = b << 1 + byte_pos += 1 + field_pos += num_pos + return null_mask + + def __repr__(self): + return 'RowCoderImpl[%s]' % ', '.join(str(c) for c in self._field_coders) diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py new file mode 100644 index 00000000000000..c6add86477aba0 --- /dev/null +++ b/flink-python/pyflink/fn_execution/coders.py @@ -0,0 +1,86 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import sys + +from apache_beam.coders import Coder, VarIntCoder +from apache_beam.coders.coders import FastCoder + +from pyflink.fn_execution import coder_impl +from pyflink.fn_execution import flink_fn_execution_pb2 + +FLINK_SCHEMA_CODER_URN = "flink:coder:schema:v1" + +if sys.version > '3': + xrange = range + + +__all__ = ['RowCoder'] + + +class RowCoder(FastCoder): + """ + Coder for Row. 
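The null-mask layout produced by write_null_mask and consumed by read_null_mask above packs one flag bit per field, most significant bit first, and right-pads the last byte with zeros. A standalone sketch of that bit layout, using plain Python containers instead of Beam stream types (the helper names here are illustrative only):

    def pack_null_mask(row):
        # one bit per field, MSB first; a 1 bit marks a None field
        mask = bytearray()
        for start in range(0, len(row), 8):
            chunk = row[start:start + 8]
            b = 0
            for field in chunk:
                b = (b << 1) | (1 if field is None else 0)
            b <<= 8 - len(chunk)  # right-pad the last byte, as write_null_mask does
            mask.append(b)
        return mask

    def unpack_null_mask(mask, field_count):
        flags = []
        for b in mask:
            for _ in range(8):
                if len(flags) == field_count:
                    break
                flags.append((b & 0x80) > 0)
                b = (b << 1) & 0xFF
        return flags

    row = [1, None, 3, None, None, 6, 7, 8, 9, None]
    mask = pack_null_mask(row)  # bytearray([0x58, 0x40]), i.e. bits 01011000 01000000
    assert unpack_null_mask(mask, len(row)) == [f is None for f in row]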
+ """ + + def __init__(self, field_coders): + self._field_coders = field_coders + + def _create_impl(self): + return coder_impl.RowCoderImpl([c.get_impl() for c in self._field_coders]) + + def is_deterministic(self): + return all(c.is_deterministic() for c in self._field_coders) + + def to_type_hint(self): + from pyflink.table import Row + return Row + + def __repr__(self): + return 'RowCoder[%s]' % ', '.join(str(c) for c in self._field_coders) + + def __eq__(self, other): + return (self.__class__ == other.__class__ + and len(self._field_coders) == len(other._field_coders) + and [self._field_coders[i] == other._field_coders[i] for i in + xrange(len(self._field_coders))]) + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(self._field_coders) + + +@Coder.register_urn(FLINK_SCHEMA_CODER_URN, flink_fn_execution_pb2.Schema) +def _pickle_from_runner_api_parameter(schema_proto, unused_components, unused_context): + return RowCoder([from_proto(f.type) for f in schema_proto.fields]) + + +def from_proto(field_type): + """ + Creates the corresponding :class:`Coder` given the protocol representation of the field type. + + :param field_type: the protocol representation of the field type + :return: :class:`Coder` + """ + if field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.BIGINT: + return VarIntCoder() + elif field_type.type_name == flink_fn_execution_pb2.Schema.TypeName.ROW: + return RowCoder([from_proto(f.type) for f in field_type.row_schema.fields]) + else: + raise ValueError("field_type %s is not supported." % field_type) diff --git a/flink-python/pyflink/fn_execution/operations.py b/flink-python/pyflink/fn_execution/operations.py new file mode 100644 index 00000000000000..13e179bb53c3fd --- /dev/null +++ b/flink-python/pyflink/fn_execution/operations.py @@ -0,0 +1,261 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from abc import abstractmethod, ABCMeta + +from apache_beam.runners.worker import operation_specs +from apache_beam.runners.worker import bundle_processor +from apache_beam.runners.worker.operations import Operation + +from pyflink.fn_execution import flink_fn_execution_pb2 + +SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1" + + +class InputGetter(object): + """ + Base class for get an input argument for a :class:`UserDefinedFunction`. + """ + __metaclass__ = ABCMeta + + def open(self): + pass + + def close(self): + pass + + @abstractmethod + def get(self, value): + pass + + +class OffsetInputGetter(InputGetter): + """ + InputGetter for the input argument which is a column of the input row. 
+ + :param input_offset: the offset of the column in the input row + """ + + def __init__(self, input_offset): + self.input_offset = input_offset + + def get(self, value): + return value[self.input_offset] + + +class ScalarFunctionInputGetter(InputGetter): + """ + InputGetter for the input argument which is a Python :class:`ScalarFunction`. This is used for + chaining Python functions. + + :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction` + """ + + def __init__(self, scalar_function_proto): + self.scalar_function_invoker = create_scalar_function_invoker(scalar_function_proto) + + def open(self): + self.scalar_function_invoker.invoke_open() + + def close(self): + self.scalar_function_invoker.invoke_close() + + def get(self, value): + return self.scalar_function_invoker.invoke_eval(value) + + +class ScalarFunctionInvoker(object): + """ + An abstraction that can be used to execute :class:`ScalarFunction` methods. + + A ScalarFunctionInvoker describes a particular way for invoking methods of a + :class:`ScalarFunction`. + + :param scalar_function: the :class:`ScalarFunction` to execute + :param inputs: the input arguments for the :class:`ScalarFunction` + """ + + def __init__(self, scalar_function, inputs): + self.scalar_function = scalar_function + self.input_getters = [] + for input in inputs: + if input.HasField("udf"): + # for chaining Python UDF input: the input argument is a Python ScalarFunction + self.input_getters.append(ScalarFunctionInputGetter(input.udf)) + else: + # the input argument is a column of the input row + self.input_getters.append(OffsetInputGetter(input.inputOffset)) + + def invoke_open(self): + """ + Invokes the ScalarFunction.open() function. + """ + for input_getter in self.input_getters: + input_getter.open() + # set the FunctionContext to None for now + self.scalar_function.open(None) + + def invoke_close(self): + """ + Invokes the ScalarFunction.close() function. + """ + for input_getter in self.input_getters: + input_getter.close() + self.scalar_function.close() + + def invoke_eval(self, value): + """ + Invokes the ScalarFunction.eval() function. + + :param value: the input element for which eval() method should be invoked + """ + args = [input_getter.get(value) for input_getter in self.input_getters] + return self.scalar_function.eval(*args) + + +def create_scalar_function_invoker(scalar_function_proto): + """ + Creates :class:`ScalarFunctionInvoker` from the proto representation of a + :class:`ScalarFunction`. + + :param scalar_function_proto: the proto representation of the Python :class:`ScalarFunction` + :return: :class:`ScalarFunctionInvoker`. + """ + import cloudpickle + scalar_function = cloudpickle.loads(scalar_function_proto.payload) + return ScalarFunctionInvoker(scalar_function, scalar_function_proto.inputs) + + +class ScalarFunctionRunner(object): + """ + The runner which is responsible for executing the scalar functions and send the + execution results back to the remote Java operator. + + :param udfs_proto: protocol representation for the scalar functions to execute + """ + + def __init__(self, udfs_proto): + self.scalar_function_invokers = [create_scalar_function_invoker(f) for f in + udfs_proto] + + def setup(self, main_receivers): + """ + Set up the ScalarFunctionRunner. 
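The getter/invoker chain above is what lets Python UDF calls nest; a tiny stand-in sketch (hypothetical classes, no protos or Beam involved) that evaluates add(add_one(a), subtract_one(b)) over a single row:

    class Offset(object):
        # stands in for OffsetInputGetter: read one column of the input row
        def __init__(self, i):
            self.i = i

        def get(self, row):
            return row[self.i]

    class Chained(object):
        # stands in for ScalarFunctionInputGetter + ScalarFunctionInvoker:
        # evaluate a UDF whose arguments come from other getters
        def __init__(self, func, getters):
            self.func = func
            self.getters = getters

        def get(self, row):
            return self.func(*[g.get(row) for g in self.getters])

    add_one = Chained(lambda i: i + 1, [Offset(0)])
    subtract_one = Chained(lambda i: i - 1, [Offset(1)])
    add = Chained(lambda i, j: i + j, [add_one, subtract_one])
    assert add.get((1, 2)) == 3  # add(add_one(1), subtract_one(2))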
+ + :param main_receivers: Receiver objects which is responsible for sending the execution + results back the the remote Java operator + """ + from apache_beam.runners.common import _OutputProcessor + self.output_processor = _OutputProcessor( + window_fn=None, + main_receivers=main_receivers, + tagged_receivers=None, + per_element_output_counter=None) + + def open(self): + for invoker in self.scalar_function_invokers: + invoker.invoke_open() + + def close(self): + for invoker in self.scalar_function_invokers: + invoker.invoke_close() + + def process(self, windowed_value): + results = [invoker.invoke_eval(windowed_value.value) for invoker in + self.scalar_function_invokers] + from pyflink.table import Row + result = Row(*results) + # send the execution results back + self.output_processor.process_outputs(windowed_value, [result]) + + +class ScalarFunctionOperation(Operation): + """ + An operation that will execute ScalarFunctions for each input element. + """ + + def __init__(self, name, spec, counter_factory, sampler, consumers): + super(ScalarFunctionOperation, self).__init__(name, spec, counter_factory, sampler) + for tag, op_consumers in consumers.items(): + for consumer in op_consumers: + self.add_receiver(consumer, 0) + + self.scalar_function_runner = ScalarFunctionRunner(self.spec.serialized_fn) + self.scalar_function_runner.open() + + def setup(self): + with self.scoped_start_state: + super(ScalarFunctionOperation, self).setup() + self.scalar_function_runner.setup(self.receivers[0]) + + def start(self): + with self.scoped_start_state: + super(ScalarFunctionOperation, self).start() + + def process(self, o): + with self.scoped_process_state: + self.scalar_function_runner.process(o) + + def finish(self): + with self.scoped_finish_state: + super(ScalarFunctionOperation, self).finish() + + def needs_finalization(self): + return False + + def reset(self): + super(ScalarFunctionOperation, self).reset() + + def teardown(self): + with self.scoped_finish_state: + self.scalar_function_runner.close() + + def progress_metrics(self): + metrics = super(ScalarFunctionOperation, self).progress_metrics() + metrics.processed_elements.measured.output_element_counts.clear() + tag = None + receiver = self.receivers[0] + metrics.processed_elements.measured.output_element_counts[ + str(tag)] = receiver.opcounter.element_counter.value() + return metrics + + +@bundle_processor.BeamTransformFactory.register_urn( + SCALAR_FUNCTION_URN, flink_fn_execution_pb2.UserDefinedFunctions) +def create(factory, transform_id, transform_proto, parameter, consumers): + return _create_user_defined_function_operation( + factory, transform_proto, consumers, parameter.udfs) + + +def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto, + operation_cls=ScalarFunctionOperation): + output_tags = list(transform_proto.outputs.keys()) + output_coders = factory.get_output_coders(transform_proto) + spec = operation_specs.WorkerDoFn( + serialized_fn=udfs_proto, + output_tags=output_tags, + input=None, + side_inputs=None, + output_coders=[output_coders[tag] for tag in output_tags]) + + return operation_cls( + transform_proto.unique_name, + spec, + factory.counter_factory, + factory.state_sampler, + consumers) diff --git a/flink-python/pyflink/fn_execution/sdk_worker_main.py b/flink-python/pyflink/fn_execution/sdk_worker_main.py new file mode 100644 index 00000000000000..82d091c52953b7 --- /dev/null +++ b/flink-python/pyflink/fn_execution/sdk_worker_main.py @@ -0,0 +1,30 @@ 
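Stripped of the Beam plumbing, the per-element work done by ScalarFunctionRunner.process above amounts to evaluating every UDF against the same input row and emitting one result row (plain tuples used for illustration; the values mirror test_udf.py further down):

    udfs = [lambda row: row[0] + 1,       # add_one(a)
            lambda row: row[1] - 1,       # subtract_one(b)
            lambda row: row[0] + row[2]]  # add(a, c)

    input_row = (1, 2, 3)
    result_row = tuple(f(input_row) for f in udfs)
    assert result_row == (2, 1, 4)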
+################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import sys + +# force to register the operations to SDK Harness +import pyflink.fn_execution.operations # noqa # pylint: disable=unused-import + +# force to register the coders to SDK Harness +import pyflink.fn_execution.coders # noqa # pylint: disable=unused-import + +import apache_beam.runners.worker.sdk_worker_main + +if __name__ == '__main__': + apache_beam.runners.worker.sdk_worker_main.main(sys.argv) diff --git a/flink-python/pyflink/gen_protos.py b/flink-python/pyflink/gen_protos.py new file mode 100644 index 00000000000000..ce4ddd4010a405 --- /dev/null +++ b/flink-python/pyflink/gen_protos.py @@ -0,0 +1,146 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from __future__ import absolute_import +from __future__ import print_function + +import glob +import logging +import multiprocessing +import os +import platform +import shutil +import subprocess +import sys +import time +import warnings + +import pkg_resources + +# latest grpcio-tools incompatible with latest protobuf 3.6.1. 
+GRPC_TOOLS = 'grpcio-tools>=1.3.5,<=1.14.2' + +PROTO_PATHS = [ + os.path.join('proto'), +] + +PYTHON_OUTPUT_PATH = os.path.join('fn_execution') + + +def generate_proto_files(force=False): + try: + import grpc_tools # noqa # pylint: disable=unused-import + except ImportError: + warnings.warn('Installing grpcio-tools is recommended for development.') + + py_sdk_root = os.path.dirname(os.path.abspath(__file__)) + proto_dirs = [os.path.join(py_sdk_root, path) for path in PROTO_PATHS] + proto_files = sum( + [glob.glob(os.path.join(d, '*.proto')) for d in proto_dirs], []) + out_dir = os.path.join(py_sdk_root, PYTHON_OUTPUT_PATH) + out_files = [path for path in glob.glob(os.path.join(out_dir, '*_pb2.py'))] + + if out_files and not proto_files and not force: + # We have out_files but no protos; assume they're up to date. + # This is actually the common case (e.g. installation from an sdist). + logging.info('No proto files; using existing generated files.') + return + + elif not out_files and not proto_files: + raise RuntimeError( + 'No proto files found in %s.' % proto_dirs) + + # Regenerate iff the proto files or this file are newer. + elif force or not out_files or len(out_files) < len(proto_files) or ( + min(os.path.getmtime(path) for path in out_files) + <= max(os.path.getmtime(path) + for path in proto_files + [os.path.realpath(__file__)])): + try: + from grpc_tools import protoc + except ImportError: + if platform.system() == 'Windows': + # For Windows, grpcio-tools has to be installed manually. + raise RuntimeError( + 'Cannot generate protos for Windows since grpcio-tools package is ' + 'not installed. Please install this package manually ' + 'using \'pip install grpcio-tools\'.') + + # Use a subprocess to avoid messing with this process' path and imports. + # Note that this requires a separate module from setup.py for Windows: + # https://docs.python.org/2/library/multiprocessing.html#windows + p = multiprocessing.Process( + target=_install_grpcio_tools_and_generate_proto_files) + p.start() + p.join() + if p.exitcode: + raise ValueError("Proto generation failed (see log for details).") + else: + logging.info('Regenerating out-of-date Python proto definitions.') + builtin_protos = pkg_resources.resource_filename('grpc_tools', '_proto') + args = ( + [sys.executable] + # expecting to be called from command line + ['--proto_path=%s' % builtin_protos] + + ['--proto_path=%s' % d for d in proto_dirs] + + ['--python_out=%s' % out_dir] + + proto_files) + ret_code = protoc.main(args) + if ret_code: + raise RuntimeError( + 'Protoc returned non-zero status (see logs for details): ' + '%s' % ret_code) + + +# Though wheels are available for grpcio-tools, setup_requires uses +# easy_install which doesn't understand them. This means that it is +# compiled from scratch (which is expensive as it compiles the full +# protoc compiler). Instead, we attempt to install a wheel in a temporary +# directory and add it to the path as needed. 
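A usage note for generate_proto_files above: during development it can simply be called directly (this is what test_case_utils.py and the setup.py commands below do). If grpcio-tools is not importable, it falls back to the temporary wheel install described next:

    from pyflink import gen_protos

    gen_protos.generate_proto_files()            # regenerate only when the .proto files are newer
    gen_protos.generate_proto_files(force=True)  # always regenerate, as the __main__ block does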
+# See https://github.com/pypa/setuptools/issues/377 +def _install_grpcio_tools_and_generate_proto_files(): + install_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), '..', '.eggs', 'grpcio-wheels') + build_path = install_path + '-build' + if os.path.exists(build_path): + shutil.rmtree(build_path) + logging.warning('Installing grpcio-tools into %s', install_path) + try: + start = time.time() + subprocess.check_call( + [sys.executable, '-m', 'pip', 'install', + '--prefix', install_path, '--build', build_path, + '--upgrade', GRPC_TOOLS, "-I"]) + from distutils.dist import Distribution + install_obj = Distribution().get_command_obj('install', create=True) + install_obj.prefix = install_path + install_obj.finalize_options() + logging.warning( + 'Installing grpcio-tools took %0.2f seconds.', time.time() - start) + finally: + sys.stderr.flush() + shutil.rmtree(build_path, ignore_errors=True) + sys.path.append(install_obj.install_purelib) + if install_obj.install_purelib != install_obj.install_platlib: + sys.path.append(install_obj.install_platlib) + try: + generate_proto_files() + finally: + sys.stderr.flush() + + +if __name__ == '__main__': + generate_proto_files(force=True) diff --git a/flink-python/pyflink/table/__init__.py b/flink-python/pyflink/table/__init__.py index e69a9b7ecb0cd8..82ce28a34063c6 100644 --- a/flink-python/pyflink/table/__init__.py +++ b/flink-python/pyflink/table/__init__.py @@ -52,6 +52,11 @@ from a registered :class:`pyflink.table.catalog.Catalog`. - :class:`pyflink.table.TableSchema` Represents a table's structure with field names and data types. + - :class:`pyflink.table.FunctionContext` + Used to obtain global runtime information about the context in which the + user-defined function is executed, such as the metric group, and global job parameters, etc. + - :class:`pyflink.table.ScalarFunction` + Base interface for user-defined scalar function. """ from __future__ import absolute_import @@ -65,6 +70,7 @@ from pyflink.table.sources import TableSource, CsvTableSource from pyflink.table.types import DataTypes, UserDefinedType, Row from pyflink.table.table_schema import TableSchema +from pyflink.table.udf import FunctionContext, ScalarFunction __all__ = [ 'TableEnvironment', @@ -85,5 +91,7 @@ 'DataTypes', 'UserDefinedType', 'Row', - 'TableSchema' + 'TableSchema', + 'FunctionContext', + 'ScalarFunction' ] diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py index 33160e0cdbef5d..940aafcf62b12d 100644 --- a/flink-python/pyflink/table/table_environment.py +++ b/flink-python/pyflink/table/table_environment.py @@ -542,6 +542,36 @@ def register_java_function(self, name, function_class_name): .loadClass(function_class_name).newInstance() self._j_tenv.registerFunction(name, java_function) + def register_function(self, name, function): + """ + Registers a python user-defined function under a unique name. Replaces already existing + user-defined function under this name. + + Example: + :: + + >>> table_env.register_function( + ... "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + + >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], + ... result_type=DataTypes.BIGINT()) + ... def add(i, j): + ... return i + j + >>> table_env.register_function("add", add) + + >>> class SubtractOne(ScalarFunction): + ... def eval(self, i): + ... return i - 1 + >>> table_env.register_function( + ... 
"subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + + :param name: The name under which the function is registered. + :type name: str + :param function: The python user-defined function to register. + :type function: UserDefinedFunctionWrapper + """ + self._j_tenv.registerFunction(name, function._judf) + def execute(self, job_name): """ Triggers the program execution. The environment will execute all parts of diff --git a/flink-python/pyflink/table/tests/test_environment_completeness.py b/flink-python/pyflink/table/tests/test_environment_completeness.py index 5f73c3c88b5670..89459891f8f3e5 100644 --- a/flink-python/pyflink/table/tests/test_environment_completeness.py +++ b/flink-python/pyflink/table/tests/test_environment_completeness.py @@ -41,8 +41,7 @@ def excluded_methods(cls): # registerCatalog, getCatalog and listTables should be supported when catalog supported in # python. getCompletionHints has been deprecated. It will be removed in the next release. # TODO add TableEnvironment#create method with EnvironmentSettings as a parameter - return {'registerCatalog', 'getCatalog', 'registerFunction', 'listTables', - 'getCompletionHints', 'create'} + return {'registerCatalog', 'getCatalog', 'listTables', 'getCompletionHints', 'create'} if __name__ == '__main__': diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py new file mode 100644 index 00000000000000..bf46ea9a5d3fd8 --- /dev/null +++ b/flink-python/pyflink/table/tests/test_udf.py @@ -0,0 +1,245 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ +from pyflink.table import DataTypes +from pyflink.table.udf import ScalarFunction, udf +from pyflink.testing import source_sink_utils +from pyflink.testing.test_case_utils import PyFlinkStreamTableTestCase + + +class UserDefinedFunctionTests(PyFlinkStreamTableTestCase): + + def test_scalar_function(self): + # test lambda function + self.t_env.register_function( + "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + + # test Python ScalarFunction + self.t_env.register_function( + "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + + # test Python function + self.t_env.register_function("add", add) + + # test callable function + self.t_env.register_function( + "add_one_callable", udf(CallablePlus(), DataTypes.BIGINT(), DataTypes.BIGINT())) + + def partial_func(col, param): + return col + param + + # test partial function + import functools + self.t_env.register_function( + "add_one_partial", + udf(functools.partial(partial_func, param=1), DataTypes.BIGINT(), DataTypes.BIGINT())) + + table_sink = source_sink_utils.TestAppendSink( + ['a', 'b', 'c', 'd', 'e'], + [DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(), DataTypes.BIGINT(), + DataTypes.BIGINT()]) + self.t_env.register_table_sink("Results", table_sink) + + t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c']) + t.where("add_one(b) <= 3") \ + .select("add_one(a), subtract_one(b), add(a, c), add_one_callable(a), " + "add_one_partial(a)") \ + .insert_into("Results") + self.t_env.execute("test") + actual = source_sink_utils.results() + self.assert_equals(actual, ["2,1,4,2,2", "4,0,12,4,4"]) + + def test_chaining_scalar_function(self): + self.t_env.register_function( + "add_one", udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())) + self.t_env.register_function( + "subtract_one", udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())) + self.t_env.register_function("add", add) + + table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()]) + self.t_env.register_table_sink("Results", table_sink) + + t = self.t_env.from_elements([(1, 2), (2, 5), (3, 1)], ['a', 'b']) + t.select("add(add_one(a), subtract_one(b))") \ + .insert_into("Results") + self.t_env.execute("test") + actual = source_sink_utils.results() + self.assert_equals(actual, ["3", "7", "4"]) + + def test_udf_in_join_condition(self): + t1 = self.t_env.from_elements([(2, "Hi")], ['a', 'b']) + t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd']) + + self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT())) + + table_sink = source_sink_utils.TestAppendSink( + ['a', 'b', 'c', 'd'], + [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()]) + self.t_env.register_table_sink("Results", table_sink) + + t1.join(t2).where("f(a) = c").insert_into("Results") + self.t_env.execute("test") + actual = source_sink_utils.results() + self.assert_equals(actual, ["2,Hi,2,Flink"]) + + def test_udf_in_join_condition_2(self): + t1 = self.t_env.from_elements([(1, "Hi"), (2, "Hi")], ['a', 'b']) + t2 = self.t_env.from_elements([(2, "Flink")], ['c', 'd']) + + self.t_env.register_function("f", udf(lambda i: i, DataTypes.BIGINT(), DataTypes.BIGINT())) + + table_sink = source_sink_utils.TestAppendSink( + ['a', 'b', 'c', 'd'], + [DataTypes.BIGINT(), DataTypes.STRING(), DataTypes.BIGINT(), DataTypes.STRING()]) + self.t_env.register_table_sink("Results", 
table_sink) + + t1.join(t2).where("f(a) = f(c)").insert_into("Results") + self.t_env.execute("test") + actual = source_sink_utils.results() + self.assert_equals(actual, ["2,Hi,2,Flink"]) + + def test_overwrite_builtin_function(self): + self.t_env.register_function( + "plus", udf(lambda i, j: i + j - 1, + [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())) + + table_sink = source_sink_utils.TestAppendSink(['a'], [DataTypes.BIGINT()]) + self.t_env.register_table_sink("Results", table_sink) + + t = self.t_env.from_elements([(1, 2, 3), (2, 5, 6), (3, 1, 9)], ['a', 'b', 'c']) + t.select("plus(a, b)").insert_into("Results") + self.t_env.execute("test") + actual = source_sink_utils.results() + self.assert_equals(actual, ["2", "6", "3"]) + + def test_open(self): + self.t_env.register_function( + "subtract", udf(Subtract(), DataTypes.BIGINT(), DataTypes.BIGINT())) + table_sink = source_sink_utils.TestAppendSink( + ['a', 'b'], [DataTypes.BIGINT(), DataTypes.BIGINT()]) + self.t_env.register_table_sink("Results", table_sink) + + t = self.t_env.from_elements([(1, 2), (2, 5), (3, 4)], ['a', 'b']) + t.select("a, subtract(b)").insert_into("Results") + self.t_env.execute("test") + actual = source_sink_utils.results() + self.assert_equals(actual, ["1,1", "2,4", "3,3"]) + + def test_deterministic(self): + add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()) + self.assertTrue(add_one._deterministic) + + add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False) + self.assertFalse(add_one._deterministic) + + subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()) + self.assertTrue(subtract_one._deterministic) + + with self.assertRaises(ValueError, msg="Inconsistent deterministic: False and True"): + udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(), deterministic=False) + + self.assertTrue(add._deterministic) + + @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), deterministic=False) + def non_deterministic_udf(i): + return i + + self.assertFalse(non_deterministic_udf._deterministic) + + def test_name(self): + add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()) + self.assertEqual("", add_one._name) + + add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT(), name="add_one") + self.assertEqual("add_one", add_one._name) + + subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()) + self.assertEqual("SubtractOne", subtract_one._name) + + subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT(), + name="subtract_one") + self.assertEqual("subtract_one", subtract_one._name) + + self.assertEqual("add", add._name) + + @udf(input_types=DataTypes.BIGINT(), result_type=DataTypes.BIGINT(), name="named") + def named_udf(i): + return i + + self.assertEqual("named", named_udf._name) + + def test_abc(self): + class UdfWithoutEval(ScalarFunction): + def open(self, function_context): + pass + + with self.assertRaises( + TypeError, + msg="Can't instantiate abstract class UdfWithoutEval with abstract methods eval"): + UdfWithoutEval() + + def test_invalid_udf(self): + class Plus(object): + def eval(self, col): + return col + 1 + + with self.assertRaises( + TypeError, + msg="Invalid function: not a function or callable (__call__ is not defined)"): + # test non-callable function + self.t_env.register_function( + "non-callable-udf", udf(Plus(), DataTypes.BIGINT(), DataTypes.BIGINT())) + + +@udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], 
+     result_type=DataTypes.BIGINT())
+def add(i, j):
+    return i + j
+
+
+class SubtractOne(ScalarFunction):
+
+    def eval(self, i):
+        return i - 1
+
+
+class Subtract(ScalarFunction):
+
+    def __init__(self):
+        self.subtracted_value = 0
+
+    def open(self, function_context):
+        self.subtracted_value = 1
+
+    def eval(self, i):
+        return i - self.subtracted_value
+
+
+class CallablePlus(object):
+
+    def __call__(self, col):
+        return col + 1
+
+
+if __name__ == '__main__':
+    import unittest
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py
index dbf92f88bb8c25..09186240a62904 100644
--- a/flink-python/pyflink/table/types.py
+++ b/flink-python/pyflink/table/types.py
@@ -2005,7 +2005,7 @@ def __repr__(self):
             return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                          for k, v in zip(self._fields, tuple(self)))
         else:
-            return "<Row(%s)>" % ", ".join(self)
+            return "<Row(%s)>" % ", ".join("%r" % field for field in self)
 
 _acceptable_types = {
diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py
new file mode 100644
index 00000000000000..97ac3910708e68
--- /dev/null
+++ b/flink-python/pyflink/table/udf.py
@@ -0,0 +1,229 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import abc
+import collections
+import functools
+import inspect
+import sys
+
+from pyflink.java_gateway import get_gateway
+from pyflink.table.types import DataType, _to_java_type
+from pyflink.util import utils
+
+__all__ = ['FunctionContext', 'ScalarFunction', 'udf']
+
+
+if sys.version_info >= (3, 4):
+    ABC = abc.ABC
+else:
+    ABC = abc.ABCMeta('ABC', (), {})
+
+
+class FunctionContext(object):
+    """
+    Used to obtain global runtime information about the context in which the
+    user-defined function is executed. The information includes the metric group,
+    global job parameters, etc.
+    """
+    pass
+
+
+class UserDefinedFunction(ABC):
+    """
+    Base interface for user-defined function.
+    """
+
+    def open(self, function_context):
+        """
+        Initialization method for the function. It is called before the actual working methods
+        and thus suitable for one-time setup work.
+
+        :param function_context: the context of the function
+        :type function_context: FunctionContext
+        """
+        pass
+
+    def close(self):
+        """
+        Tear-down method for the user code. It is called after the last call to the main
+        working methods.
+        """
+        pass
+
+    def is_deterministic(self):
+        """
+        Returns information about the determinism of the function's results.
+ It returns true if and only if a call to this function is guaranteed to + always return the same result given the same parameters. true is assumed by default. + If the function is not pure functional like random(), date(), now(), + this method must return false. + + :return: the determinism of the function's results. + :rtype: bool + """ + return True + + +class ScalarFunction(UserDefinedFunction): + """ + Base interface for user-defined scalar function. A user-defined scalar functions maps zero, one, + or multiple scalar values to a new scalar value. + """ + + @abc.abstractmethod + def eval(self, *args): + """ + Method which defines the logic of the scalar function. + """ + pass + + +class DelegatingScalarFunction(ScalarFunction): + """ + Helper scalar function implementation for lambda expression and python function. It's for + internal use only. + """ + + def __init__(self, func): + self.func = func + + def eval(self, *args): + return self.func(*args) + + +class UserDefinedFunctionWrapper(object): + """ + Wrapper for Python user-defined function. It handles things like converting lambda + functions to user-defined functions, creating the Java user-defined function representation, + etc. It's for internal use only. + """ + + def __init__(self, func, input_types, result_type, deterministic=None, name=None): + if inspect.isclass(func) or ( + not isinstance(func, UserDefinedFunction) and not callable(func)): + raise TypeError( + "Invalid function: not a function or callable (__call__ is not defined): {0}" + .format(type(func))) + + if not isinstance(input_types, collections.Iterable): + input_types = [input_types] + + for input_type in input_types: + if not isinstance(input_type, DataType): + raise TypeError( + "Invalid input_type: input_type should be DataType but contains {}".format( + input_type)) + + if not isinstance(result_type, DataType): + raise TypeError( + "Invalid returnType: returnType should be DataType but is {}".format(result_type)) + + self._func = func + self._input_types = input_types + self._result_type = result_type + self._judf_placeholder = None + self._name = name or ( + func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) + + if deterministic is not None and isinstance(func, UserDefinedFunction) and deterministic \ + != func.is_deterministic(): + raise ValueError("Inconsistent deterministic: {} and {}".format( + deterministic, func.is_deterministic())) + + # default deterministic is True + self._deterministic = deterministic if deterministic is not None else ( + func.is_deterministic() if isinstance(func, UserDefinedFunction) else True) + + @property + def _judf(self): + if self._judf_placeholder is None: + self._judf_placeholder = self._create_judf() + return self._judf_placeholder + + def _create_judf(self): + func = self._func + if not isinstance(self._func, UserDefinedFunction): + func = DelegatingScalarFunction(self._func) + + import cloudpickle + serialized_func = cloudpickle.dumps(func) + + gateway = get_gateway() + j_input_types = utils.to_jarray(gateway.jvm.TypeInformation, + [_to_java_type(i) for i in self._input_types]) + j_result_type = _to_java_type(self._result_type) + return gateway.jvm.org.apache.flink.table.util.python.PythonTableUtils \ + .createPythonScalarFunction(self._name, + bytearray(serialized_func), + j_input_types, + j_result_type, + self._deterministic, + _get_python_env()) + + +# TODO: support to configure the python execution environment +def _get_python_env(): + gateway = get_gateway() + exec_type = 
gateway.jvm.org.apache.flink.table.functions.python.PythonEnv.ExecType.PROCESS + return gateway.jvm.org.apache.flink.table.functions.python.PythonEnv(exec_type) + + +def _create_udf(f, input_types, result_type, deterministic, name): + return UserDefinedFunctionWrapper(f, input_types, result_type, deterministic, name) + + +def udf(f=None, input_types=None, result_type=None, deterministic=None, name=None): + """ + Helper method for creating a user-defined function. + + Example: + :: + + >>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT()) + + >>> @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], + ... result_type=DataTypes.BIGINT()) + ... def add(i, j): + ... return i + j + + >>> class SubtractOne(ScalarFunction): + ... def eval(self, i): + ... return i - 1 + >>> subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT()) + + :param f: lambda function or user-defined function. + :type f: function or UserDefinedFunction or type + :param input_types: the input data types. + :type input_types: list[DataType] or DataType + :param result_type: the result data type. + :type result_type: DataType + :param name: the function name. + :type name: str + :param deterministic: the determinism of the function's results. True if and only if a call to + this function is guaranteed to always return the same result given the + same parameters. (default True) + :type deterministic: bool + :return: UserDefinedFunctionWrapper or function. + :rtype: UserDefinedFunctionWrapper or function + """ + # decorator + if f is None: + return functools.partial(_create_udf, input_types=input_types, result_type=result_type, + deterministic=deterministic, name=name) + else: + return _create_udf(f, input_types, result_type, deterministic, name) diff --git a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py index c1d484ed01c041..21d3f09e629aba 100644 --- a/flink-python/pyflink/testing/test_case_utils.py +++ b/flink-python/pyflink/testing/test_case_utils.py @@ -27,11 +27,10 @@ from py4j.java_gateway import JavaObject from py4j.protocol import Py4JJavaError +from pyflink import gen_protos from pyflink.table.sources import CsvTableSource - from pyflink.dataset import ExecutionEnvironment from pyflink.datastream import StreamExecutionEnvironment - from pyflink.find_flink_home import _find_flink_home from pyflink.table import BatchTableEnvironment, StreamTableEnvironment from pyflink.java_gateway import get_gateway @@ -76,6 +75,8 @@ class PyFlinkTestCase(unittest.TestCase): def setUpClass(cls): cls.tempdir = tempfile.mkdtemp() + gen_protos.generate_proto_files() + os.environ["FLINK_TESTING"] = "1" _find_flink_home() diff --git a/flink-python/setup.py b/flink-python/setup.py index 78a6edd6beb320..f5b26da98e7ee6 100644 --- a/flink-python/setup.py +++ b/flink-python/setup.py @@ -23,6 +23,12 @@ from shutil import copytree, copy, rmtree from setuptools import setup +from setuptools.command.install import install +from setuptools.command.build_py import build_py +from setuptools.command.develop import develop +from setuptools.command.egg_info import egg_info +from setuptools.command.sdist import sdist +from setuptools.command.test import test if sys.version_info < (2, 7): print("Python versions prior to 2.7 are not supported for PyFlink.", @@ -62,6 +68,24 @@ in_flink_source = os.path.isfile("../flink-java/src/main/java/org/apache/flink/api/java/" "ExecutionEnvironment.java") + +# We must generate protos after setup_requires are installed. 
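The serialized_func built by _create_judf above is the same cloudpickle payload that create_scalar_function_invoker in operations.py (earlier in this patch) loads inside the worker. A minimal round-trip sketch of that step, assuming cloudpickle is installed (it is added to install_requires below); the AddOne class is just an illustration:

    import cloudpickle

    class AddOne(object):
        def eval(self, i):
            return i + 1

    payload = cloudpickle.dumps(AddOne())  # what _create_judf hands to the Java side
    restored = cloudpickle.loads(payload)  # what the worker does before calling eval()
    assert restored.eval(41) == 42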
+def generate_protos_first(original_cmd): + try: + # pylint: disable=wrong-import-position + from pyflink import gen_protos + + class cmd(original_cmd, object): + def run(self): + gen_protos.generate_proto_files() + super(cmd, self).run() + return cmd + except ImportError: + import warnings + warnings.warn("Could not import gen_protos, skipping proto generation.") + return original_cmd + + try: if in_flink_source: @@ -184,7 +208,8 @@ license='https://www.apache.org/licenses/LICENSE-2.0', author='Flink Developers', author_email='dev@flink.apache.org', - install_requires=['py4j==0.10.8.1', 'python-dateutil', 'apache-beam==2.15.0'], + install_requires=['py4j==0.10.8.1', 'python-dateutil', 'apache-beam==2.15.0', + 'cloudpickle==1.2.2'], tests_require=['pytest==4.4.1'], description='Apache Flink Python API', long_description=long_description, @@ -195,7 +220,15 @@ 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7'] + 'Programming Language :: Python :: 3.7'], + cmdclass={ + 'build_py': generate_protos_first(build_py), + 'develop': generate_protos_first(develop), + 'egg_info': generate_protos_first(egg_info), + 'sdist': generate_protos_first(sdist), + 'test': generate_protos_first(test), + 'install': generate_protos_first(install), + }, ) finally: if in_flink_source: diff --git a/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java index 73c753544a0cee..53e6ef94fb1674 100644 --- a/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java +++ b/flink-python/src/main/java/org/apache/flink/python/AbstractPythonFunctionRunner.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; import org.apache.flink.table.functions.python.PythonEnv; @@ -50,8 +51,6 @@ import java.io.FileOutputStream; import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import java.util.Random; /** @@ -280,14 +279,15 @@ private static String randomString(Random random) { */ protected RunnerApi.Environment createPythonExecutionEnvironment() { if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) { - final Map env = new HashMap<>(2); - env.put("python", pythonEnv.getPythonExec()); + String flinkHomePath = System.getenv(ConfigConstants.ENV_FLINK_HOME_DIR); + String pythonWorkerCommand = + flinkHomePath + File.separator + "bin" + File.separator + "pyflink-udf-runner.sh"; return Environments.createProcessEnvironment( "", "", - pythonEnv.getPythonWorkerCmd(), - env); + pythonWorkerCommand, + null); } else { throw new UnsupportedOperationException(String.format( "Execution type '%s' is not supported.", pythonEnv.getExecType())); diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java index c53f69a06b37f4..4c28d6a8515d53 100644 --- a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java +++ 
b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonScalarFunctionRunnerTest.java @@ -121,7 +121,7 @@ public byte[] getSerializedPythonFunction() { @Override public PythonEnv getPythonEnv() { - return new PythonEnv("", "", PythonEnv.ExecType.PROCESS); + return new PythonEnv(PythonEnv.ExecType.PROCESS); } } } diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java index 02f5b7462ca8c6..dacd5092b54f47 100644 --- a/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java +++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/BaseRowPythonScalarFunctionRunnerTest.java @@ -116,7 +116,7 @@ public AbstractPythonScalarFunctionRunner createPythonScalarFu // ignore the execution results }; - final PythonEnv pythonEnv = new PythonEnv("", "", PythonEnv.ExecType.PROCESS); + final PythonEnv pythonEnv = new PythonEnv(PythonEnv.ExecType.PROCESS); return new BaseRowPythonScalarFunctionRunner( "testPythonRunner", diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java index 6a9398a202e758..212bf2a6e8e964 100644 --- a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java +++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonScalarFunctionRunnerTest.java @@ -244,7 +244,7 @@ public AbstractPythonScalarFunctionRunner createPythonScalarFunctionRu // ignore the execution results }; - final PythonEnv pythonEnv = new PythonEnv("", "", PythonEnv.ExecType.PROCESS); + final PythonEnv pythonEnv = new PythonEnv(PythonEnv.ExecType.PROCESS); return new PythonScalarFunctionRunner( "testPythonRunner", @@ -266,7 +266,7 @@ private AbstractPythonScalarFunctionRunner createUDFRunner( RowType rowType = new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType()))); - final PythonEnv pythonEnv = new PythonEnv("", "", PythonEnv.ExecType.PROCESS); + final PythonEnv pythonEnv = new PythonEnv(PythonEnv.ExecType.PROCESS); return new PythonScalarFunctionRunnerTestHarness( "testPythonRunner", diff --git a/flink-python/tox.ini b/flink-python/tox.ini index 4efde543a02054..873af202864577 100644 --- a/flink-python/tox.ini +++ b/flink-python/tox.ini @@ -38,4 +38,4 @@ commands = # up to 100 characters in length, not 79. 
ignore=E226,E241,E305,E402,E722,E731,E741,W503,W504 max-line-length=100 -exclude=.tox/*,dev/*,lib/*,target/*,build/*,dist/*,pyflink/shell.py,.eggs/*,pyflink/fn_execution/tests/process_mode_test_data.py +exclude=.tox/*,dev/*,lib/*,target/*,build/*,dist/*,pyflink/shell.py,.eggs/*,pyflink/fn_execution/tests/process_mode_test_data.py,pyflink/fn_execution/*_pb2.py diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java index 5d2be94cb9ca82..d01a08d95ccbf9 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonEnv.java @@ -31,38 +31,15 @@ public final class PythonEnv implements Serializable { private static final long serialVersionUID = 1L; - /** - * The path of the Python executable file used. - */ - private final String pythonExec; - - /** - * The command to start Python worker process. - */ - private final String pythonWorkerCmd; - /** * The execution type of the Python worker, it defines how to execute the Python functions. */ private final ExecType execType; - public PythonEnv( - String pythonExec, - String pythonWorkerCmd, - ExecType execType) { - this.pythonExec = Preconditions.checkNotNull(pythonExec); - this.pythonWorkerCmd = Preconditions.checkNotNull(pythonWorkerCmd); + public PythonEnv(ExecType execType) { this.execType = Preconditions.checkNotNull(execType); } - public String getPythonExec() { - return pythonExec; - } - - public String getPythonWorkerCmd() { - return pythonWorkerCmd; - } - public ExecType getExecType() { return execType; } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/PythonFunctionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/PythonFunctionCodeGenerator.scala new file mode 100644 index 00000000000000..e97e7b79eeb797 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/PythonFunctionCodeGenerator.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.codegen + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.codegen.CodeGenUtils.{primitiveDefaultValue, primitiveTypeTermForTypeInfo, newName} +import org.apache.flink.table.codegen.Indenter.toISC +import org.apache.flink.table.functions.{UserDefinedFunction, FunctionLanguage, ScalarFunction} +import org.apache.flink.table.functions.python.{PythonEnv, PythonFunction} +import org.apache.flink.table.utils.EncodingUtils + +/** + * A code generator for generating Python [[UserDefinedFunction]]s. + */ +object PythonFunctionCodeGenerator extends Compiler[UserDefinedFunction] { + + private val PYTHON_SCALAR_FUNCTION_NAME = "PythonScalarFunction" + + /** + * Generates a [[ScalarFunction]] for the specified Python user-defined function. + * + * @param name class name of the user-defined function. Must be a valid Java class identifier + * @param serializedScalarFunction serialized Python scalar function + * @param inputTypes input data types + * @param resultType expected result type + * @param deterministic the determinism of the function's results + * @param pythonEnv the Python execution environment + * @return instance of generated ScalarFunction + */ + def generateScalarFunction( + name: String, + serializedScalarFunction: Array[Byte], + inputTypes: Array[TypeInformation[_]], + resultType: TypeInformation[_], + deterministic: Boolean, + pythonEnv: PythonEnv): ScalarFunction = { + val funcName = newName(PYTHON_SCALAR_FUNCTION_NAME) + val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) + val defaultResultValue = primitiveDefaultValue(resultType) + val inputParamCode = inputTypes.zipWithIndex.map { case (inputType, index) => + s"${primitiveTypeTermForTypeInfo(inputType)} in$index" + }.mkString(", ") + + val encodingUtilsTypeTerm = classOf[EncodingUtils].getCanonicalName + val typeInfoTypeTerm = classOf[TypeInformation[_]].getCanonicalName + val inputTypesCode = inputTypes.map(EncodingUtils.encodeObjectToString).map { inputType => + s""" + |($typeInfoTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject( + | "$inputType", $typeInfoTypeTerm.class) + |""".stripMargin + }.mkString(", ") + + val encodedResultType = EncodingUtils.encodeObjectToString(resultType) + val encodedScalarFunction = EncodingUtils.encodeBytesToBase64(serializedScalarFunction) + val encodedPythonEnv = EncodingUtils.encodeObjectToString(pythonEnv) + val pythonEnvTypeTerm = classOf[PythonEnv].getCanonicalName + + val funcCode = j""" + |public class $funcName extends ${classOf[ScalarFunction].getCanonicalName} + | implements ${classOf[PythonFunction].getCanonicalName} { + | + | private static final long serialVersionUID = 1L; + | + | public $resultTypeTerm eval($inputParamCode) { + | return $defaultResultValue; + | } + | + | @Override + | public $typeInfoTypeTerm[] getParameterTypes(Class[] signature) { + | return new $typeInfoTypeTerm[]{$inputTypesCode}; + | } + | + | @Override + | public $typeInfoTypeTerm getResultType(Class[] signature) { + | return ($typeInfoTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject( + | "$encodedResultType", $typeInfoTypeTerm.class); + | } + | + | @Override + | public ${classOf[FunctionLanguage].getCanonicalName} getLanguage() { + | return ${classOf[FunctionLanguage].getCanonicalName}.PYTHON; + | } + | + | @Override + | public byte[] getSerializedPythonFunction() { + | return $encodingUtilsTypeTerm.decodeBase64ToBytes("$encodedScalarFunction"); + | } + | + | @Override + | public $pythonEnvTypeTerm 
getPythonEnv() { + | return ($pythonEnvTypeTerm) $encodingUtilsTypeTerm.decodeStringToObject( + | "$encodedPythonEnv", $pythonEnvTypeTerm.class); + | } + | + | @Override + | public boolean isDeterministic() { + | return $deterministic; + | } + | + | @Override + | public String toString() { + | return "$name"; + | } + |} + |""".stripMargin + + val clazz = compile( + Thread.currentThread().getContextClassLoader, + funcName, + funcCode) + clazz.newInstance().asInstanceOf[ScalarFunction] + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala index 094945c344bbfc..a84dad81599bce 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/util/python/PythonTableUtils.scala @@ -31,6 +31,9 @@ import org.apache.flink.api.java.io.CollectionInputFormat import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} import org.apache.flink.core.io.InputSplit import org.apache.flink.table.api.{TableSchema, Types} +import org.apache.flink.table.codegen.PythonFunctionCodeGenerator +import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.functions.python.PythonEnv import org.apache.flink.table.sources.InputFormatTableSource import org.apache.flink.types.Row @@ -38,6 +41,32 @@ import scala.collection.JavaConversions._ object PythonTableUtils { + /** + * Creates a [[ScalarFunction]] for the specified Python ScalarFunction. + * + * @param funcName class name of the user-defined function. Must be a valid Java class identifier + * @param serializedScalarFunction serialized Python scalar function + * @param inputTypes input data types + * @param resultType expected result type + * @param deterministic the determinism of the function's results + * @param pythonEnv the Python execution environment + * @return A generated Java ScalarFunction representation for the specified Python ScalarFunction + */ + def createPythonScalarFunction( + funcName: String, + serializedScalarFunction: Array[Byte], + inputTypes: Array[TypeInformation[_]], + resultType: TypeInformation[_], + deterministic: Boolean, + pythonEnv: PythonEnv): ScalarFunction = + PythonFunctionCodeGenerator.generateScalarFunction( + funcName, + serializedScalarFunction, + inputTypes, + resultType, + deterministic, + pythonEnv) + /** * Wrap the unpickled python data with an InputFormat. It will be passed to * PythonInputFormatTableSource later. diff --git a/pom.xml b/pom.xml index 5a8b5983917b0b..832e298d850f5e 100644 --- a/pom.xml +++ b/pom.xml @@ -1422,6 +1422,7 @@ under the License. flink-python/lib/** flink-python/dev/download/** flink-python/docs/_build/** + flink-python/pyflink/fn_execution/*_pb2.py