From e2ea6049aa8c495dda6b78395435170fa0c56b8c Mon Sep 17 00:00:00 2001 From: huangxingbo Date: Fri, 7 Feb 2020 16:07:29 +0800 Subject: [PATCH] [FLINK-15913][python] Add Python Table Function Runner And Operator In Legacy Planner-fix-1 --- .../fn_execution/flink_fn_execution_pb2.py | 8 +- .../pyflink/proto/flink-fn-execution.proto | 2 +- .../AbstractPythonScalarFunctionOperator.java | 139 +--------- .../AbstractPythonTableFunctionOperator.java | 136 +--------- .../AbstractStatelessFunctionOperator.java | 251 ++++++++++++++++++ .../BaseRowPythonScalarFunctionOperator.java | 29 -- .../python/PythonScalarFunctionOperator.java | 35 --- .../python/PythonTableFunctionOperator.java | 92 +++---- .../AbstractPythonScalarFunctionRunner.java | 7 +- .../AbstractPythonTableFunctionRunner.java | 33 ++- .../python/PythonTableFunctionRunner.java | 13 +- .../runtime/typeutils/PythonTypeUtils.java | 43 +-- .../python/RowTableSerializer.java | 60 ----- .../serializers/python/TableSerializer.java | 152 ----------- ...AbstractPythonTableFunctionRunnerTest.java | 29 +- .../python/PythonTableFunctionRunnerTest.java | 29 +- .../functions/python/PythonTypeUtilsTest.java | 15 +- ...tPassThroughPythonTableFunctionRunner.java | 64 ----- .../PassThroughPythonTableFunctionRunner.java | 61 ++++- .../PythonTableFunctionOperatorTest.java | 21 +- .../PythonTableFunctionOperatorTestBase.java | 22 +- 21 files changed, 454 insertions(+), 787 deletions(-) create mode 100644 flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java delete mode 100644 flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java delete mode 100644 flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java delete mode 100644 flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py index 5093057cf8cb4..4ace2334aaa47 100644 --- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py +++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py @@ -36,7 +36,7 @@ name='flink-fn-execution.proto', package='org.apache.flink.fn_execution.v1', syntax='proto3', - serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"[\n\x14UserDefinedFunctions\x12\x43\n\x04udfs\x18\x01 \x03(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunction\"\x8b\t\n\x06Schema\x12>\n\x06\x66ields\x18\x01 \x03(\x0b\x32..org.apache.flink.fn_execution.v1.Schema.Field\x1a\x97\x01\n\x07MapType\x12\x44\n\x08key_type\x18\x01 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x12\x46\n\nvalue_type\x18\x02 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x1a!\n\x0c\x44\x61teTimeType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x1a/\n\x0b\x44\x65\x63imalType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x12\r\n\x05scale\x18\x02 
\x01(\x05\x1a\xec\x03\n\tFieldType\x12\x44\n\ttype_name\x18\x01 \x01(\x0e\x32\x31.org.apache.flink.fn_execution.v1.Schema.TypeName\x12\x10\n\x08nullable\x18\x02 \x01(\x08\x12U\n\x17\x63ollection_element_type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldTypeH\x00\x12\x44\n\x08map_type\x18\x04 \x01(\x0b\x32\x30.org.apache.flink.fn_execution.v1.Schema.MapTypeH\x00\x12>\n\nrow_schema\x18\x05 \x01(\x0b\x32(.org.apache.flink.fn_execution.v1.SchemaH\x00\x12O\n\x0e\x64\x61te_time_type\x18\x06 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.Schema.DateTimeTypeH\x00\x12L\n\x0c\x64\x65\x63imal_type\x18\x07 \x01(\x0b\x32\x34.org.apache.flink.fn_execution.v1.Schema.DecimalTypeH\x00\x42\x0b\n\ttype_info\x1al\n\x05\x46ield\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12@\n\x04type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\"\xf5\x01\n\x08TypeName\x12\x07\n\x03ROW\x10\x00\x12\x0b\n\x07TINYINT\x10\x01\x12\x0c\n\x08SMALLINT\x10\x02\x12\x07\n\x03INT\x10\x03\x12\n\n\x06\x42IGINT\x10\x04\x12\x0b\n\x07\x44\x45\x43IMAL\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07\x12\x08\n\x04\x44\x41TE\x10\x08\x12\x08\n\x04TIME\x10\t\x12\x0c\n\x08\x44\x41TETIME\x10\n\x12\x0b\n\x07\x42OOLEAN\x10\x0b\x12\n\n\x06\x42INARY\x10\x0c\x12\r\n\tVARBINARY\x10\r\x12\x08\n\x04\x43HAR\x10\x0e\x12\x0b\n\x07VARCHAR\x10\x0f\x12\t\n\x05\x41RRAY\x10\x10\x12\x07\n\x03MAP\x10\x11\x12\x0c\n\x08MULTISET\x10\x12\x12\t\n\x05TABLE\x10\x13\x42-\n\x1forg.apache.flink.fnexecution.v1B\nFlinkFnApib\x06proto3') + serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"[\n\x14UserDefinedFunctions\x12\x43\n\x04udfs\x18\x01 \x03(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunction\"\x96\t\n\x06Schema\x12>\n\x06\x66ields\x18\x01 \x03(\x0b\x32..org.apache.flink.fn_execution.v1.Schema.Field\x1a\x97\x01\n\x07MapType\x12\x44\n\x08key_type\x18\x01 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x12\x46\n\nvalue_type\x18\x02 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x1a!\n\x0c\x44\x61teTimeType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x1a/\n\x0b\x44\x65\x63imalType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x12\r\n\x05scale\x18\x02 \x01(\x05\x1a\xec\x03\n\tFieldType\x12\x44\n\ttype_name\x18\x01 \x01(\x0e\x32\x31.org.apache.flink.fn_execution.v1.Schema.TypeName\x12\x10\n\x08nullable\x18\x02 \x01(\x08\x12U\n\x17\x63ollection_element_type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldTypeH\x00\x12\x44\n\x08map_type\x18\x04 \x01(\x0b\x32\x30.org.apache.flink.fn_execution.v1.Schema.MapTypeH\x00\x12>\n\nrow_schema\x18\x05 \x01(\x0b\x32(.org.apache.flink.fn_execution.v1.SchemaH\x00\x12O\n\x0e\x64\x61te_time_type\x18\x06 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.Schema.DateTimeTypeH\x00\x12L\n\x0c\x64\x65\x63imal_type\x18\x07 \x01(\x0b\x32\x34.org.apache.flink.fn_execution.v1.Schema.DecimalTypeH\x00\x42\x0b\n\ttype_info\x1al\n\x05\x46ield\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 
\x01(\t\x12@\n\x04type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\"\x80\x02\n\x08TypeName\x12\x07\n\x03ROW\x10\x00\x12\x0b\n\x07TINYINT\x10\x01\x12\x0c\n\x08SMALLINT\x10\x02\x12\x07\n\x03INT\x10\x03\x12\n\n\x06\x42IGINT\x10\x04\x12\x0b\n\x07\x44\x45\x43IMAL\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07\x12\x08\n\x04\x44\x41TE\x10\x08\x12\x08\n\x04TIME\x10\t\x12\x0c\n\x08\x44\x41TETIME\x10\n\x12\x0b\n\x07\x42OOLEAN\x10\x0b\x12\n\n\x06\x42INARY\x10\x0c\x12\r\n\tVARBINARY\x10\r\x12\x08\n\x04\x43HAR\x10\x0e\x12\x0b\n\x07VARCHAR\x10\x0f\x12\t\n\x05\x41RRAY\x10\x10\x12\x07\n\x03MAP\x10\x11\x12\x0c\n\x08MULTISET\x10\x12\x12\x14\n\x10TABLEFUNCTIONROW\x10\x13\x42-\n\x1forg.apache.flink.fnexecution.v1B\nFlinkFnApib\x06proto3') ) @@ -124,14 +124,14 @@ options=None, type=None), _descriptor.EnumValueDescriptor( - name='TABLE', index=19, number=19, + name='TABLEFUNCTIONROW', index=19, number=19, options=None, type=None), ], containing_type=None, options=None, serialized_start=1329, - serialized_end=1574, + serialized_end=1585, ) _sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME) @@ -503,7 +503,7 @@ oneofs=[ ], serialized_start=411, - serialized_end=1574, + serialized_end=1585, ) _USERDEFINEDFUNCTION_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto index fd91d65ad1f63..b8bef7d02b015 100644 --- a/flink-python/pyflink/proto/flink-fn-execution.proto +++ b/flink-python/pyflink/proto/flink-fn-execution.proto @@ -73,7 +73,7 @@ message Schema { ARRAY = 16; MAP = 17; MULTISET = 18; - TABLE = 19; + TABLEFUNCTIONROW = 19; } message MapType { diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonScalarFunctionOperator.java index 0861f6b06bf88..9ec24b6360c1e 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonScalarFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonScalarFunctionOperator.java @@ -20,25 +20,12 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.configuration.Configuration; -import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.python.PythonFunctionRunner; -import org.apache.flink.python.env.PythonEnvironmentManager; -import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.functions.python.PythonEnv; import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.util.Preconditions; -import org.apache.beam.sdk.fn.data.FnDataReceiver; - -import java.io.IOException; -import java.util.Arrays; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.stream.Collectors; - /** * Base class for all stream operators to execute Python {@link ScalarFunction}s. It executes the Python * {@link ScalarFunction}s in separate Python execution environment. 
@@ -67,7 +54,7 @@ */ @Internal public abstract class AbstractPythonScalarFunctionOperator - extends AbstractPythonFunctionOperator { + extends AbstractStatelessFunctionOperator { private static final long serialVersionUID = 1L; @@ -76,57 +63,11 @@ public abstract class AbstractPythonScalarFunctionOperator */ protected final PythonFunctionInfo[] scalarFunctions; - /** - * The input logical type. - */ - protected final RowType inputType; - - /** - * The output logical type. - */ - protected final RowType outputType; - - /** - * The offsets of udf inputs. - */ - protected final int[] udfInputOffsets; - /** * The offset of the fields which should be forwarded. */ protected final int[] forwardedFields; - /** - * The udf input logical type. - */ - protected transient RowType udfInputType; - - /** - * The udf output logical type. - */ - protected transient RowType udfOutputType; - - /** - * The queue holding the input elements for which the execution results have not been received. - */ - protected transient LinkedBlockingQueue forwardedInputQueue; - - /** - * The queue holding the user-defined function execution results. The execution results are in - * the same order as the input elements. - */ - protected transient LinkedBlockingQueue udfResultQueue; - - /** - * Reusable InputStream used to holding the execution results to be deserialized. - */ - protected transient ByteArrayInputStreamWithPos bais; - - /** - * InputStream Wrapper. - */ - protected transient DataInputViewStreamWrapper baisWrapper; - AbstractPythonScalarFunctionOperator( Configuration config, PythonFunctionInfo[] scalarFunctions, @@ -134,96 +75,20 @@ public abstract class AbstractPythonScalarFunctionOperator RowType outputType, int[] udfInputOffsets, int[] forwardedFields) { - super(config); + super(config, inputType, outputType, udfInputOffsets); this.scalarFunctions = Preconditions.checkNotNull(scalarFunctions); - this.inputType = Preconditions.checkNotNull(inputType); - this.outputType = Preconditions.checkNotNull(outputType); - this.udfInputOffsets = Preconditions.checkNotNull(udfInputOffsets); this.forwardedFields = Preconditions.checkNotNull(forwardedFields); } @Override public void open() throws Exception { - forwardedInputQueue = new LinkedBlockingQueue<>(); - udfResultQueue = new LinkedBlockingQueue<>(); - udfInputType = new RowType( - Arrays.stream(udfInputOffsets) - .mapToObj(i -> inputType.getFields().get(i)) - .collect(Collectors.toList())); udfOutputType = new RowType(outputType.getFields().subList(forwardedFields.length, outputType.getFieldCount())); - bais = new ByteArrayInputStreamWithPos(); - baisWrapper = new DataInputViewStreamWrapper(bais); super.open(); } - @Override - public void processElement(StreamRecord element) throws Exception { - bufferInput(element.getValue()); - super.processElement(element); - emitResults(); - } - @Override public PythonEnv getPythonEnv() { return scalarFunctions[0].getPythonFunction().getPythonEnv(); } - @Override - public PythonFunctionRunner createPythonFunctionRunner() throws IOException { - final FnDataReceiver udfResultReceiver = input -> { - // handover to queue, do not block the result receiver thread - udfResultQueue.put(input); - }; - - return new ProjectUdfInputPythonScalarFunctionRunner( - createPythonFunctionRunner( - udfResultReceiver, - createPythonEnvironmentManager())); - } - - /** - * Buffers the specified input, it will be used to construct - * the operator result together with the udf execution result. 
- */ - public abstract void bufferInput(IN input); - - public abstract UDFIN getUdfInput(IN element); - - public abstract PythonFunctionRunner createPythonFunctionRunner( - FnDataReceiver resultReceiver, - PythonEnvironmentManager pythonEnvironmentManager); - - private class ProjectUdfInputPythonScalarFunctionRunner implements PythonFunctionRunner { - - private final PythonFunctionRunner pythonFunctionRunner; - - ProjectUdfInputPythonScalarFunctionRunner(PythonFunctionRunner pythonFunctionRunner) { - this.pythonFunctionRunner = pythonFunctionRunner; - } - - @Override - public void open() throws Exception { - pythonFunctionRunner.open(); - } - - @Override - public void close() throws Exception { - pythonFunctionRunner.close(); - } - - @Override - public void startBundle() throws Exception { - pythonFunctionRunner.startBundle(); - } - - @Override - public void finishBundle() throws Exception { - pythonFunctionRunner.finishBundle(); - } - - @Override - public void processElement(IN element) throws Exception { - pythonFunctionRunner.processElement(getUdfInput(element)); - } - } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java index c9328239b223a..3d5830141b5b2 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java @@ -18,33 +18,25 @@ package org.apache.flink.table.runtime.operators.python; +import org.apache.flink.annotation.Internal; import org.apache.flink.configuration.Configuration; -import org.apache.flink.python.PythonFunctionRunner; -import org.apache.flink.python.env.PythonEnvironmentManager; -import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.functions.python.PythonEnv; import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.util.Preconditions; -import org.apache.beam.sdk.fn.data.FnDataReceiver; - import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.stream.Collectors; /** - * @param Type of the input elements. - * @param Type of the output elements. - * @param Type of the UDF input type. - * @param Type of the UDF input type. + * @param Type of the input elements. + * @param Type of the output elements. + * @param Type of the UDTF input type. */ -public abstract class AbstractPythonTableFunctionOperator - extends AbstractPythonFunctionOperator { +@Internal +public abstract class AbstractPythonTableFunctionOperator + extends AbstractStatelessFunctionOperator { private static final long serialVersionUID = 1L; @@ -53,137 +45,33 @@ public abstract class AbstractPythonTableFunctionOperator forwardedInputQueue; - - /** - * The queue holding the user-defined table function execution results. The execution results - * are in the same order as the input elements. 
- */ - protected transient LinkedBlockingQueue udtfResultQueue; - public AbstractPythonTableFunctionOperator( Configuration config, PythonFunctionInfo tableFunction, RowType inputType, RowType outputType, int[] udtfInputOffsets) { - super(config); + super(config, inputType, outputType, udtfInputOffsets); this.tableFunction = Preconditions.checkNotNull(tableFunction); - this.inputType = Preconditions.checkNotNull(inputType); - this.outputType = Preconditions.checkNotNull(outputType); - this.udtfInputOffsets = Preconditions.checkNotNull(udtfInputOffsets); } @Override public void open() throws Exception { - forwardedInputQueue = new LinkedBlockingQueue<>(); - udtfResultQueue = new LinkedBlockingQueue<>(); - udtfInputType = new RowType( - Arrays.stream(udtfInputOffsets) - .mapToObj(i -> inputType.getFields().get(i)) - .collect(Collectors.toList())); List udtfOutputDataFields = new ArrayList<>( outputType.getFields().subList(inputType.getFieldCount(), outputType.getFieldCount())); - udtfOutputType = new RowType(udtfOutputDataFields); + udfOutputType = new RowType(udtfOutputDataFields); super.open(); } - @Override - public void processElement(StreamRecord element) throws Exception { - bufferInput(element.getValue()); - super.processElement(element); - emitResults(); - } - - @Override - public PythonFunctionRunner createPythonFunctionRunner() throws Exception { - final FnDataReceiver udtfResultReceiver = input -> { - // handover to queue, do not block the result receiver thread - udtfResultQueue.put(input); - }; - - return new ProjectUdfInputPythonTableFunctionRunner( - createPythonFunctionRunner( - udtfResultReceiver, - createPythonEnvironmentManager())); - } - @Override public PythonEnv getPythonEnv() { return tableFunction.getPythonFunction().getPythonEnv(); } /** - * Buffers the specified input, it will be used to construct - * the operator result together with the udtf execution result. + * The received udtf execution result is a finish message when it is a byte 0x00. 
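+ * For each input row, the Python worker first sends every produced table function
+ * result row as a separate serialized message and then sends this single-byte finish
+ * message; emitResults() relies on it to advance to the next buffered input row.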
*/ - public abstract void bufferInput(IN input); - - public abstract UDTFIN getUdtfInput(IN element); - - public abstract PythonFunctionRunner createPythonFunctionRunner( - FnDataReceiver resultReceiver, - PythonEnvironmentManager pythonEnvironmentManager); - - private class ProjectUdfInputPythonTableFunctionRunner implements PythonFunctionRunner { - - private final PythonFunctionRunner pythonFunctionRunner; - - ProjectUdfInputPythonTableFunctionRunner(PythonFunctionRunner pythonFunctionRunner) { - this.pythonFunctionRunner = pythonFunctionRunner; - } - - @Override - public void open() throws Exception { - pythonFunctionRunner.open(); - } - - @Override - public void close() throws Exception { - pythonFunctionRunner.close(); - } - - @Override - public void startBundle() throws Exception { - pythonFunctionRunner.startBundle(); - } - - @Override - public void finishBundle() throws Exception { - pythonFunctionRunner.finishBundle(); - } - - @Override - public void processElement(IN element) throws Exception { - pythonFunctionRunner.processElement(getUdtfInput(element)); - } + protected boolean isFinishResult(byte[] rawUdtfResult) { + return rawUdtfResult.length == 1 && rawUdtfResult[0] == 0x00; } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java new file mode 100644 index 0000000000000..eb695d286ee9f --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.runtime.types.CRow;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for all stream operators to execute Python stateless functions.
+ *
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the output elements.
+ * @param <UDFIN> Type of the UDF input type.
+ */
+@Internal
+public abstract class AbstractStatelessFunctionOperator<IN, OUT, UDFIN>
+ extends AbstractPythonFunctionOperator<IN, OUT> {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * The input logical type.
+ */
+ protected final RowType inputType;
+
+ /**
+ * The output logical type.
+ */
+ protected final RowType outputType;
+
+ /**
+ * The offsets of udf inputs.
+ */
+ protected final int[] udfInputOffsets;
+
+ /**
+ * The udf input logical type.
+ */
+ protected transient RowType udfInputType;
+
+ /**
+ * The udf output logical type.
+ */
+ protected transient RowType udfOutputType;
+
+ /**
+ * The queue holding the input elements for which the execution results have not been received.
+ */
+ protected transient LinkedBlockingQueue<IN> forwardedInputQueue;
+
+ /**
+ * The queue holding the user-defined function execution results. The execution results
+ * are in the same order as the input elements.
+ */
+ protected transient LinkedBlockingQueue<byte[]> udfResultQueue;
+
+ /**
+ * Reusable InputStream used to hold the execution results to be deserialized.
+ */
+ protected transient ByteArrayInputStreamWithPos bais;
+
+ /**
+ * InputStream wrapper of {@code bais}, exposing it as a DataInputView for deserializers. 
+ */
+ protected transient DataInputViewStreamWrapper baisWrapper;
+
+ public AbstractStatelessFunctionOperator(
+ Configuration config,
+ RowType inputType,
+ RowType outputType,
+ int[] udfInputOffsets) {
+ super(config);
+ this.inputType = Preconditions.checkNotNull(inputType);
+ this.outputType = Preconditions.checkNotNull(outputType);
+ this.udfInputOffsets = Preconditions.checkNotNull(udfInputOffsets);
+ }
+
+ @Override
+ public void open() throws Exception {
+ forwardedInputQueue = new LinkedBlockingQueue<>();
+ udfResultQueue = new LinkedBlockingQueue<>();
+ udfInputType = new RowType(
+ Arrays.stream(udfInputOffsets)
+ .mapToObj(i -> inputType.getFields().get(i))
+ .collect(Collectors.toList()));
+ bais = new ByteArrayInputStreamWithPos();
+ baisWrapper = new DataInputViewStreamWrapper(bais);
+ super.open();
+ }
+
+ @Override
+ public void processElement(StreamRecord<IN> element) throws Exception {
+ bufferInput(element.getValue());
+ super.processElement(element);
+ emitResults();
+ }
+
+ @Override
+ public PythonFunctionRunner<IN> createPythonFunctionRunner() throws IOException {
+ final FnDataReceiver<byte[]> udfResultReceiver = input -> {
+ // handover to queue, do not block the result receiver thread
+ udfResultQueue.put(input);
+ };
+
+ return new ProjectUdfInputPythonScalarFunctionRunner(
+ createPythonFunctionRunner(
+ udfResultReceiver,
+ createPythonEnvironmentManager()));
+ }
+
+ /**
+ * Buffers the specified input; it will be used to construct
+ * the operator result together with the udf execution result.
+ */
+ public abstract void bufferInput(IN input);
+
+ public abstract UDFIN getUdfInput(IN element);
+
+ public abstract PythonFunctionRunner<UDFIN> createPythonFunctionRunner(
+ FnDataReceiver<byte[]> resultReceiver,
+ PythonEnvironmentManager pythonEnvironmentManager);
+
+ private class ProjectUdfInputPythonScalarFunctionRunner implements PythonFunctionRunner<IN> {
+
+ private final PythonFunctionRunner<UDFIN> pythonFunctionRunner;
+
+ ProjectUdfInputPythonScalarFunctionRunner(PythonFunctionRunner<UDFIN> pythonFunctionRunner) {
+ this.pythonFunctionRunner = pythonFunctionRunner;
+ }
+
+ @Override
+ public void open() throws Exception {
+ pythonFunctionRunner.open();
+ }
+
+ @Override
+ public void close() throws Exception {
+ pythonFunctionRunner.close();
+ }
+
+ @Override
+ public void startBundle() throws Exception {
+ pythonFunctionRunner.startBundle();
+ }
+
+ @Override
+ public void finishBundle() throws Exception {
+ pythonFunctionRunner.finishBundle();
+ }
+
+ @Override
+ public void processElement(IN element) throws Exception {
+ pythonFunctionRunner.processElement(getUdfInput(element));
+ }
+ }
+
+ /**
+ * The collector is used to convert a {@link Row} to a {@link CRow}.
+ */
+ protected static class StreamRecordCRowWrappingCollector implements Collector<Row> {
+
+ private final Collector<StreamRecord<CRow>> out;
+ private final CRow reuseCRow = new CRow();
+
+ /**
+ * For Table API & SQL jobs, the timestamp field is not used.
+ */
+ private final StreamRecord<CRow> reuseStreamRecord = new StreamRecord<>(reuseCRow);
+
+ StreamRecordCRowWrappingCollector(Collector<StreamRecord<CRow>> out) {
+ this.out = out;
+ }
+
+ public void setChange(boolean change) {
+ this.reuseCRow.change_$eq(change);
+ }
+
+ @Override
+ public void collect(Row record) {
+ reuseCRow.row_$eq(record);
+ out.collect(reuseStreamRecord);
+ }
+
+ @Override
+ public void close() {
+ out.close();
+ }
+ }
+
+ /**
+ * The collector is used to convert a {@link BaseRow} to a {@link StreamRecord}. 
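+ * Like the CRow variant above, it reuses a single {@link StreamRecord} instance
+ * across collect() calls to avoid per-record allocations.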
+ */
+ protected static class StreamRecordBaseRowWrappingCollector implements Collector<BaseRow> {
+
+ private final Collector<StreamRecord<BaseRow>> out;
+
+ /**
+ * For Table API & SQL jobs, the timestamp field is not used.
+ */
+ private final StreamRecord<BaseRow> reuseStreamRecord = new StreamRecord<>(null);
+
+ StreamRecordBaseRowWrappingCollector(Collector<StreamRecord<BaseRow>> out) {
+ this.out = out;
+ }
+
+ @Override
+ public void collect(BaseRow record) {
+ out.collect(reuseStreamRecord.replace(record));
+ }
+
+ @Override
+ public void close() {
+ out.close();
+ }
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/BaseRowPythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/BaseRowPythonScalarFunctionOperator.java
index 7bb03bedfbf21..79284ea8861ac 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/BaseRowPythonScalarFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/BaseRowPythonScalarFunctionOperator.java
@@ -23,7 +23,6 @@
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.python.PythonFunctionRunner;
 import org.apache.flink.python.env.PythonEnvironmentManager;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.TableConfig;
 import org.apache.flink.table.dataformat.BaseRow;
 import org.apache.flink.table.dataformat.BinaryRow;
@@ -37,7 +36,6 @@
 import org.apache.flink.table.runtime.runners.python.BaseRowPythonScalarFunctionRunner;
 import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
 import org.apache.flink.table.types.logical.RowType;
-import org.apache.flink.util.Collector;
 
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 
@@ -165,31 +163,4 @@ private Projection createForwardedFieldProjection() {
 // noinspection unchecked
 return generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
 }
-
- /**
- * The collector is used to convert a {@link BaseRow} to a {@link StreamRecord}.
- */
- private static class StreamRecordBaseRowWrappingCollector implements Collector<BaseRow> {
-
- private final Collector<StreamRecord<BaseRow>> out;
-
- /**
- * For Table API & SQL jobs, the timestamp field is not used. 
- */ - private final StreamRecord reuseStreamRecord = new StreamRecord<>(null); - - StreamRecordBaseRowWrappingCollector(Collector> out) { - this.out = out; - } - - @Override - public void collect(BaseRow record) { - out.collect(reuseStreamRecord.replace(record)); - } - - @Override - public void close() { - out.close(); - } - } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonScalarFunctionOperator.java index f30cd9c8a6626..e191d8eafedc2 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonScalarFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonScalarFunctionOperator.java @@ -25,7 +25,6 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.python.PythonFunctionRunner; import org.apache.flink.python.env.PythonEnvironmentManager; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.runtime.runners.python.PythonScalarFunctionRunner; @@ -35,7 +34,6 @@ import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.utils.TypeConversions; import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; import org.apache.beam.sdk.fn.data.FnDataReceiver; @@ -131,37 +129,4 @@ public PythonFunctionRunner createPythonFunctionRunner( udfInputType, udfOutputType); } - - /** - * The collector is used to convert a {@link Row} to a {@link CRow}. - */ - private static class StreamRecordCRowWrappingCollector implements Collector { - - private final Collector> out; - private final CRow reuseCRow = new CRow(); - - /** - * For Table API & SQL jobs, the timestamp field is not used. 
- */ - private final StreamRecord reuseStreamRecord = new StreamRecord<>(reuseCRow); - - StreamRecordCRowWrappingCollector(Collector> out) { - this.out = out; - } - - public void setChange(boolean change) { - this.reuseCRow.change_$eq(change); - } - - @Override - public void collect(Row record) { - reuseCRow.row_$eq(record); - out.collect(reuseStreamRecord); - } - - @Override - public void close() { - out.close(); - } - } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java index 248b11ba08304..87f878373ae0d 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java @@ -18,24 +18,31 @@ package org.apache.flink.table.runtime.operators.python; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.configuration.Configuration; import org.apache.flink.python.PythonFunctionRunner; import org.apache.flink.python.env.PythonEnvironmentManager; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.runtime.runners.python.PythonTableFunctionRunner; import org.apache.flink.table.runtime.types.CRow; +import org.apache.flink.table.runtime.types.CRowTypeInfo; +import org.apache.flink.table.runtime.typeutils.PythonTypeUtils; import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.utils.TypeConversions; import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import java.io.IOException; + /** * The Python {@link TableFunction} operator for the legacy planner. */ -public class PythonTableFunctionOperator extends AbstractPythonTableFunctionOperator { +@Internal +public class PythonTableFunctionOperator extends AbstractPythonTableFunctionOperator { private static final long serialVersionUID = 1L; @@ -44,6 +51,16 @@ public class PythonTableFunctionOperator extends AbstractPythonTableFunctionOper */ private transient StreamRecordCRowWrappingCollector cRowWrapper; + /** + * The type serializer for the forwarded fields. + */ + private transient TypeSerializer forwardedInputSerializer; + + /** + * The TypeSerializer for udf execution results. 
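+ * It is used by emitResults() to deserialize each raw byte[] message from the
+ * Python worker back into a {@link Row} before joining it with the input row.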
+ */ + private transient TypeSerializer udtfOutputTypeSerializer; + public PythonTableFunctionOperator( Configuration config, PythonFunctionInfo tableFunction, @@ -54,27 +71,32 @@ public PythonTableFunctionOperator( } @Override + @SuppressWarnings("unchecked") public void open() throws Exception { super.open(); this.cRowWrapper = new StreamRecordCRowWrappingCollector(output); - } - - private boolean isFinishResult(Row result) { - return result.getArity() == 0; + CRowTypeInfo forwardedInputTypeInfo = new CRowTypeInfo( + new RowTypeInfo(TypeConversions.fromDataTypeToLegacyInfo( + TypeConversions.fromLogicalToDataType(inputType)))); + forwardedInputSerializer = forwardedInputTypeInfo.createSerializer(getExecutionConfig()); + udtfOutputTypeSerializer = PythonTypeUtils.toFlinkTypeSerializer(udfOutputType); } @Override - public void emitResults() { - Row udtfResult; + public void emitResults() throws IOException { CRow input = null; - while ((udtfResult = udtfResultQueue.poll()) != null) { + byte[] rawUdtfResult; + while ((rawUdtfResult = udfResultQueue.poll()) != null) { if (input == null) { input = forwardedInputQueue.poll(); } - if (isFinishResult(udtfResult)) { + boolean isFinishResult = isFinishResult(rawUdtfResult); + if (isFinishResult) { input = forwardedInputQueue.poll(); } - if (input != null && !isFinishResult(udtfResult)) { + if (input != null && !isFinishResult) { + bais.setBuffer(rawUdtfResult, 0, rawUdtfResult.length); + Row udtfResult = udtfOutputTypeSerializer.deserialize(baisWrapper); cRowWrapper.setChange(input.change()); cRowWrapper.collect(Row.join(input.row(), udtfResult)); } @@ -83,57 +105,27 @@ public void emitResults() { @Override public void bufferInput(CRow input) { + if (getExecutionConfig().isObjectReuseEnabled()) { + input = forwardedInputSerializer.copy(input); + } forwardedInputQueue.add(input); } @Override - public Row getUdtfInput(CRow element) { - return Row.project(element.row(), udtfInputOffsets); + public Row getUdfInput(CRow element) { + return Row.project(element.row(), udfInputOffsets); } @Override public PythonFunctionRunner createPythonFunctionRunner( - FnDataReceiver resultReceiver, + FnDataReceiver resultReceiver, PythonEnvironmentManager pythonEnvironmentManager) { return new PythonTableFunctionRunner( getRuntimeContext().getTaskName(), resultReceiver, tableFunction, pythonEnvironmentManager, - udtfInputType, - udtfOutputType); - } - - /** - * The collector is used to convert a {@link Row} to a {@link CRow}. - */ - private static class StreamRecordCRowWrappingCollector implements Collector { - - private final Collector> out; - private final CRow reuseCRow = new CRow(); - - /** - * For Table API & SQL jobs, the timestamp field is not used. 
- */ - private final StreamRecord reuseStreamRecord = new StreamRecord<>(reuseCRow); - - StreamRecordCRowWrappingCollector(Collector> out) { - this.out = out; - } - - public void setChange(boolean change) { - this.reuseCRow.change_$eq(change); - } - - @Override - public void collect(Row record) { - reuseCRow.row_$eq(record); - out.collect(reuseStreamRecord); - } - - @Override - public void close() { - out.close(); - } + udfInputType, + udfOutputType); } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java index b413ff0370152..9a56a4c95fa54 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java @@ -26,6 +26,7 @@ import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.runtime.typeutils.PythonTypeUtils; +import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.util.Preconditions; @@ -90,8 +91,12 @@ private RunnerApi.Coder getRowCoderProto(RowType rowType) { RunnerApi.FunctionSpec.newBuilder() .setUrn(SCHEMA_CODER_URN) .setPayload(org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString.copyFrom( - PythonTypeUtils.toProtoType(rowType).getRowSchema().toByteArray())) + toProtoType(rowType).getRowSchema().toByteArray())) .build()) .build(); } + + private FlinkFnApi.Schema.FieldType toProtoType(LogicalType logicalType) { + return logicalType.accept(new PythonTypeUtils.LogicalTypeToProtoTypeConverter()); + } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java index da2122e5e75e6..eeee0ce7bdce9 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.runners.python; +import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.fnexecution.v1.FlinkFnApi; import org.apache.flink.python.PythonFunctionRunner; @@ -26,6 +27,7 @@ import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.runtime.typeutils.PythonTypeUtils; import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.utils.LogicalTypeDefaultVisitor; import org.apache.flink.util.Preconditions; import org.apache.beam.model.pipeline.v1.RunnerApi; @@ -34,10 +36,10 @@ /** * Abstract {@link PythonFunctionRunner} used to execute Python {@link TableFunction}. * - * @param Type of the input elements. - * @param Type of the execution results. + * @param Type of the input elements. 
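+ * <p>Execution results arrive from the Python worker as raw byte[] messages;
+ * deserializing them into rows is left to the operator, not to this runner.
+ *
+ * @param <IN> Type of the input elements.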
*/ -public abstract class AbstractPythonTableFunctionRunner extends AbstractPythonStatelessFunctionRunner { +@Internal +public abstract class AbstractPythonTableFunctionRunner extends AbstractPythonStatelessFunctionRunner { private static final String SCHEMA_CODER_URN = "flink:coder:schema:v1"; private static final String TABLE_FUNCTION_URN = "flink:transform:table_function:v1"; @@ -46,7 +48,7 @@ public abstract class AbstractPythonTableFunctionRunner extends Abstrac public AbstractPythonTableFunctionRunner( String taskName, - FnDataReceiver resultReceiver, + FnDataReceiver resultReceiver, PythonFunctionInfo tableFunction, PythonEnvironmentManager environmentManager, RowType inputType, @@ -87,8 +89,29 @@ private RunnerApi.Coder getTableCoderProto(RowType rowType) { RunnerApi.FunctionSpec.newBuilder() .setUrn(SCHEMA_CODER_URN) .setPayload(org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString.copyFrom( - PythonTypeUtils.toTableProtoType(rowType).toByteArray())) + toTableFunctionProtoType(rowType).toByteArray())) .build()) .build(); } + + private FlinkFnApi.Schema.FieldType toTableFunctionProtoType(RowType rowType) { + FlinkFnApi.Schema.FieldType.Builder builder = + FlinkFnApi.Schema.FieldType.newBuilder() + .setTypeName(FlinkFnApi.Schema.TypeName.TABLEFUNCTIONROW) + .setNullable(rowType.isNullable()); + + LogicalTypeDefaultVisitor converter = + new PythonTypeUtils.LogicalTypeToProtoTypeConverter(); + FlinkFnApi.Schema.Builder schemaBuilder = FlinkFnApi.Schema.newBuilder(); + for (RowType.RowField field : rowType.getFields()) { + schemaBuilder.addFields( + FlinkFnApi.Schema.Field.newBuilder() + .setName(field.getName()) + .setDescription(field.getDescription().orElse("")) + .setType(field.getType().accept(converter)) + .build()); + } + builder.setRowSchema(schemaBuilder.build()); + return builder.build(); + } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java index f9762c68c2c81..93fbd97d67ff7 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java @@ -20,12 +20,12 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.typeutils.runtime.RowSerializer; import org.apache.flink.python.PythonFunctionRunner; import org.apache.flink.python.env.PythonEnvironmentManager; import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.functions.python.PythonFunctionInfo; import org.apache.flink.table.runtime.typeutils.PythonTypeUtils; -import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.types.Row; @@ -36,11 +36,11 @@ * It takes {@link Row} as the input and output type. 
*/ @Internal -public class PythonTableFunctionRunner extends AbstractPythonTableFunctionRunner { +public class PythonTableFunctionRunner extends AbstractPythonTableFunctionRunner { public PythonTableFunctionRunner( String taskName, - FnDataReceiver resultReceiver, + FnDataReceiver resultReceiver, PythonFunctionInfo tableFunction, PythonEnvironmentManager environmentManager, RowType inputType, @@ -50,11 +50,6 @@ public PythonTableFunctionRunner( @Override public TypeSerializer getInputTypeSerializer() { - return (RowTableSerializer) PythonTypeUtils.toFlinkTableTypeSerializer(getInputType()); - } - - @Override - public TypeSerializer getOutputTypeSerializer() { - return (RowTableSerializer) PythonTypeUtils.toFlinkTableTypeSerializer(getOutputType()); + return (RowSerializer) PythonTypeUtils.toFlinkTypeSerializer(getInputType()); } } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java index ac31ba8f4056e..383f75f481ab7 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java @@ -38,7 +38,6 @@ import org.apache.flink.table.runtime.typeutils.serializers.python.BigDecSerializer; import org.apache.flink.table.runtime.typeutils.serializers.python.DateSerializer; import org.apache.flink.table.runtime.typeutils.serializers.python.DecimalSerializer; -import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer; import org.apache.flink.table.runtime.typeutils.serializers.python.StringSerializer; import org.apache.flink.table.runtime.typeutils.serializers.python.TimeSerializer; import org.apache.flink.table.runtime.typeutils.serializers.python.TimestampSerializer; @@ -87,43 +86,6 @@ public static TypeSerializer toBlinkTypeSerializer(LogicalType logicalType) { return logicalType.accept(new LogicalTypeToBlinkTypeSerializerConverter()); } - public static FlinkFnApi.Schema.FieldType toProtoType(LogicalType logicalType) { - return logicalType.accept(new LogicalTypeToProtoTypeConverter()); - } - - public static TypeSerializer toFlinkTableTypeSerializer(LogicalType logicalType) { - RowType rowType = (RowType) logicalType; - LogicalTypeDefaultVisitor converter = - new LogicalTypeToTypeSerializerConverter(); - final TypeSerializer[] fieldTypeSerializers = rowType.getFields() - .stream() - .map(f -> f.getType().accept(converter)) - .toArray(TypeSerializer[]::new); - return new RowTableSerializer(fieldTypeSerializers); - } - - public static FlinkFnApi.Schema.FieldType toTableProtoType(LogicalType logicalType) { - RowType rowType = (RowType) logicalType; - FlinkFnApi.Schema.FieldType.Builder builder = - FlinkFnApi.Schema.FieldType.newBuilder() - .setTypeName(FlinkFnApi.Schema.TypeName.TABLE) - .setNullable(rowType.isNullable()); - - LogicalTypeDefaultVisitor converter = - new LogicalTypeToProtoTypeConverter(); - FlinkFnApi.Schema.Builder schemaBuilder = FlinkFnApi.Schema.newBuilder(); - for (RowType.RowField field : rowType.getFields()) { - schemaBuilder.addFields( - FlinkFnApi.Schema.Field.newBuilder() - .setName(field.getName()) - .setDescription(field.getDescription().orElse(EMPTY_STRING)) - .setType(field.getType().accept(converter)) - .build()); - } - builder.setRowSchema(schemaBuilder.build()); - return builder.build(); - } - /** * Convert LogicalType to conversion class for flink planner. 
*/ @@ -327,7 +289,10 @@ public TypeSerializer visit(DecimalType decimalType) { } } - private static class LogicalTypeToProtoTypeConverter extends LogicalTypeDefaultVisitor { + /** + * Converter That convert the logicalType to the related Prototype. + */ + public static class LogicalTypeToProtoTypeConverter extends LogicalTypeDefaultVisitor { @Override public FlinkFnApi.Schema.FieldType visit(BooleanType booleanType) { return FlinkFnApi.Schema.FieldType.newBuilder() diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java deleted file mode 100644 index 5f5d2f48a45c9..0000000000000 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.typeutils.serializers.python; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.typeutils.runtime.RowSerializer; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.types.Row; - -import java.io.IOException; - -/** - * The implementation of TableSerializer in legacy planner. - */ -@Internal -public final class RowTableSerializer extends TableSerializer { - - private final RowSerializer rowSerializer; - - public RowTableSerializer(TypeSerializer[] fieldSerializers) { - super(fieldSerializers); - this.rowSerializer = new RowSerializer(fieldSerializers); - } - - @Override - public Row createResult(int len) { - return new Row(len); - } - - @Override - public void setField(Row result, int index, Object value) { - result.setField(index, value); - } - - @Override - public void serialize(Row record, DataOutputView target) throws IOException { - rowSerializer.serialize(record, target); - } - - public RowSerializer getRowSerializer() { - return rowSerializer; - } -} diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java deleted file mode 100644 index a53a390653814..0000000000000 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.runtime.typeutils.serializers.python; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; - -import java.io.DataInputStream; -import java.io.IOException; -import java.util.Arrays; - -import static org.apache.flink.api.java.typeutils.runtime.NullMaskUtils.readIntoNullMask; - -/** - * Base Table Serializer for Table Function. - */ -@Internal -abstract class TableSerializer extends TypeSerializer { - - private final TypeSerializer[] fieldSerializers; - - private transient boolean[] nullMask; - - TableSerializer(TypeSerializer[] fieldSerializers) { - this.fieldSerializers = fieldSerializers; - this.nullMask = new boolean[Math.max(fieldSerializers.length - 8, 0)]; - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public TypeSerializer duplicate() { - throw new RuntimeException("This method duplicate() should not be called"); - } - - @Override - public T createInstance() { - return unwantedMethodCall("createInstance()"); - } - - @Override - public T copy(T from) { - return unwantedMethodCall("copy(T from)"); - } - - @Override - public T copy(T from, T reuse) { - return unwantedMethodCall("copy(T from, T reuse)"); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public T deserialize(T reuse, DataInputView source) throws IOException { - return unwantedMethodCall("deserialize(T reuse, DataInputView source)"); - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - unwantedMethodCall("copy(DataInputView source, DataOutputView target)"); - } - - private T unwantedMethodCall(String methodName) { - throw new RuntimeException(String.format("The method %s should not be called", methodName)); - } - - public abstract T createResult(int len); - - public abstract void setField(T result, int index, Object value); - - @Override - public T deserialize(DataInputView source) throws IOException { - int len = fieldSerializers.length; - int b = source.readUnsignedByte() & 0xff; - DataInputStream inputStream = (DataInputStream) source; - if (b == 0x00 && inputStream.available() == 0) { - return createResult(0); - } - T result = createResult(len); - int minLen = Math.min(8, len); - readIntoNullMask(len - 8, source, nullMask); - for (int i = 0; i < minLen; i++) { - if ((b & 0x80) > 0) { - setField(result, i, null); - } else { - setField(result, i, fieldSerializers[i].deserialize(source)); - } - b = b << 1; - } - for (int i = 0, j = minLen; j < len; i++, j++) { - if (nullMask[i]) { - setField(result, j, null); - } else { - setField(result, j, fieldSerializers[j].deserialize(source)); - } - } - - return result; - } - - @Override - public boolean equals(Object 
obj) { - if (obj instanceof TableSerializer) { - TableSerializer other = (TableSerializer) obj; - if (this.fieldSerializers.length == other.fieldSerializers.length) { - for (int i = 0; i < this.fieldSerializers.length; i++) { - if (!this.fieldSerializers[i].equals(other.fieldSerializers[i])) { - return false; - } - } - return true; - } - } - - return false; - } - - @Override - public int hashCode() { - return Arrays.hashCode(fieldSerializers); - } - - @Override - public TypeSerializerSnapshot snapshotConfiguration() { - throw new RuntimeException("The method snapshotConfiguration() should not be called"); - } -} diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java index f0dcb5e145f88..f22d4108ede55 100644 --- a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java +++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java @@ -28,38 +28,17 @@ * Base class for PythonTableFunctionRunner and BaseRowPythonTableFunctionRunner test. * * @param Type of the input elements. - * @param Type of the output elements. */ -public abstract class AbstractPythonTableFunctionRunnerTest { - AbstractPythonTableFunctionRunner createUDTFRunner() throws Exception { +public abstract class AbstractPythonTableFunctionRunnerTest { + AbstractPythonTableFunctionRunner createUDTFRunner() throws Exception { PythonFunctionInfo pythonFunctionInfo = new PythonFunctionInfo( - DummyPythonFunction.INSTANCE, + AbstractPythonScalarFunctionRunnerTest.DummyPythonFunction.INSTANCE, new Integer[]{0}); RowType rowType = new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType()))); return createPythonTableFunctionRunner(pythonFunctionInfo, rowType, rowType); } - public abstract AbstractPythonTableFunctionRunner createPythonTableFunctionRunner( + public abstract AbstractPythonTableFunctionRunner createPythonTableFunctionRunner( PythonFunctionInfo pythonFunctionInfo, RowType inputType, RowType outputType) throws Exception; - - /** - * Dummy PythonFunction. 
- */ - public static class DummyPythonFunction implements PythonFunction { - - private static final long serialVersionUID = 1L; - - public static final PythonFunction INSTANCE = new DummyPythonFunction(); - - @Override - public byte[] getSerializedPythonFunction() { - return new byte[0]; - } - - @Override - public PythonEnv getPythonEnv() { - return new PythonEnv(PythonEnv.ExecType.PROCESS); - } - } } diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java index 9532fdc90051b..4a76c3af44d9e 100644 --- a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java +++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java @@ -19,13 +19,13 @@ package org.apache.flink.table.functions.python; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.typeutils.runtime.RowSerializer; import org.apache.flink.fnexecution.v1.FlinkFnApi; import org.apache.flink.python.env.ProcessPythonEnvironmentManager; import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.python.env.PythonEnvironmentManager; import org.apache.flink.table.runtime.runners.python.AbstractPythonTableFunctionRunner; import org.apache.flink.table.runtime.runners.python.PythonTableFunctionRunner; -import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer; import org.apache.flink.table.types.logical.BigIntType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.types.Row; @@ -50,27 +50,22 @@ *
  *     <li>The UDTF proto is properly constructed</li>
  * </ul>
  */
-public class PythonTableFunctionRunnerTest extends AbstractPythonTableFunctionRunnerTest<Row, Row> {
+public class PythonTableFunctionRunnerTest extends AbstractPythonTableFunctionRunnerTest<Row> {
 
 	@Test
 	public void testInputOutputDataTypeConstructedProperlyForSingleUDTF() throws Exception {
-		final AbstractPythonTableFunctionRunner<Row, Row> runner = createUDTFRunner();
+		final AbstractPythonTableFunctionRunner<Row> runner = createUDTFRunner();
 
 		// check input TypeSerializer
 		TypeSerializer inputTypeSerializer = runner.getInputTypeSerializer();
-		assertTrue(inputTypeSerializer instanceof RowTableSerializer);
+		assertTrue(inputTypeSerializer instanceof RowSerializer);
 
-		assertEquals(1, ((RowTableSerializer) inputTypeSerializer).getRowSerializer().getArity());
-
-		// check output TypeSerializer
-		TypeSerializer outputTypeSerializer = runner.getOutputTypeSerializer();
-		assertTrue(outputTypeSerializer instanceof RowTableSerializer);
-		assertEquals(1, ((RowTableSerializer) outputTypeSerializer).getRowSerializer().getArity());
+		assertEquals(1, ((RowSerializer) inputTypeSerializer).getArity());
 	}
 
 	@Test
 	public void testUDFnProtoConstructedProperlyForSingleUTDF() throws Exception {
-		final AbstractPythonTableFunctionRunner<Row, Row> runner = createUDTFRunner();
+		final AbstractPythonTableFunctionRunner<Row> runner = createUDTFRunner();
 		FlinkFnApi.UserDefinedFunctions udtfs = runner.getUserDefinedFunctionsProto();
 		assertEquals(1, udtfs.getUdfsCount());
@@ -81,11 +76,11 @@ public void testUDFnProtoConstructedProperlyForSingleUTDF() throws Exception {
 	}
 
 	@Override
-	public AbstractPythonTableFunctionRunner<Row, Row> createPythonTableFunctionRunner(
+	public AbstractPythonTableFunctionRunner<Row> createPythonTableFunctionRunner(
 		PythonFunctionInfo pythonFunctionInfo,
 		RowType inputType,
 		RowType outputType) throws Exception {
-		final FnDataReceiver<Row> dummyReceiver = input -> {
+		final FnDataReceiver<byte[]> dummyReceiver = input -> {
 			// ignore the execution results
 		};
 
@@ -104,10 +99,10 @@ public AbstractPythonTableFunctionRunner<Row, Row> createPythonTableFunctionRunn
 			outputType);
 	}
 
-	private AbstractPythonTableFunctionRunner<Row, Row> createUDTFRunner(
-			JobBundleFactory jobBundleFactory, FnDataReceiver<Row> receiver) throws IOException {
+	private AbstractPythonTableFunctionRunner<Row> createUDTFRunner(
+			JobBundleFactory jobBundleFactory, FnDataReceiver<byte[]> receiver) throws IOException {
 		PythonFunctionInfo pythonFunctionInfo = new PythonFunctionInfo(
-			DummyPythonFunction.INSTANCE,
+			AbstractPythonScalarFunctionRunnerTest.DummyPythonFunction.INSTANCE,
 			new Integer[]{0});
 		RowType rowType = new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType())));
 
@@ -134,7 +129,7 @@ private static class PythonTableFunctionRunnerTestHarness extends PythonTableFun
 
 		PythonTableFunctionRunnerTestHarness(
 			String taskName,
-			FnDataReceiver<Row> resultReceiver,
+			FnDataReceiver<byte[]> resultReceiver,
 			PythonFunctionInfo tableFunction,
 			PythonEnvironmentManager environmentManager,
 			RowType inputType,
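[Not part of the patch — because the runner now delivers serialized byte[] chunks instead of Rows, a test receiver reduces to collecting byte arrays. A minimal sketch; the class name is hypothetical.]

    import org.apache.beam.sdk.fn.data.FnDataReceiver;

    import java.util.ArrayList;
    import java.util.List;

    public class CollectingReceiverSketch {
        public static void main(String[] args) throws Exception {
            List<byte[]> results = new ArrayList<>();

            // Stands in for the dummyReceiver above: each accept() call delivers one
            // serialized result, or the end-of-results marker byte.
            FnDataReceiver<byte[]> receiver = results::add;

            receiver.accept(new byte[]{1, 2, 3});  // a serialized row
            receiver.accept(new byte[]{0});        // the marker emitted after each element
            System.out.println(results.size());   // 2
        }
    }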
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
index 22e3837a521fd..53ae2881d4e04 100644
--- a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
@@ -24,7 +24,6 @@
 import org.apache.flink.table.catalog.UnresolvedIdentifier;
 import org.apache.flink.table.runtime.typeutils.BaseRowSerializer;
 import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
-import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer;
 import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.BigIntType;
 import org.apache.flink.table.types.logical.DateType;
@@ -58,17 +57,6 @@ public void testLogicalTypeToFlinkTypeSerializer() {
 		assertEquals(1, ((RowSerializer) rowSerializer).getArity());
 	}
 
-	@Test
-	public void testLogicalTypeToFlinkTableTypeSerializer() {
-		List<RowType.RowField> rowFields = new ArrayList<>();
-		rowFields.add(new RowType.RowField("f1", new BigIntType()));
-		RowType rowType = new RowType(rowFields);
-		TypeSerializer rowTableSerializer = PythonTypeUtils.toFlinkTableTypeSerializer(rowType);
-		assertTrue(rowTableSerializer instanceof RowTableSerializer);
-
-		assertEquals(1, ((RowTableSerializer) rowTableSerializer).getRowSerializer().getArity());
-	}
-
 	@Test
 	public void testLogicalTypeToBlinkTypeSerializer() {
 		List<RowType.RowField> rowFields = new ArrayList<>();
@@ -85,7 +73,8 @@ public void testLogicalTypeToProto() {
 		List<RowType.RowField> rowFields = new ArrayList<>();
 		rowFields.add(new RowType.RowField("f1", new BigIntType()));
 		RowType rowType = new RowType(rowFields);
-		FlinkFnApi.Schema.FieldType protoType = PythonTypeUtils.toProtoType(rowType);
+		FlinkFnApi.Schema.FieldType protoType =
+			rowType.accept(new PythonTypeUtils.LogicalTypeToProtoTypeConverter());
 		FlinkFnApi.Schema schema = protoType.getRowSchema();
 		assertEquals(1, schema.getFieldsCount());
 		assertEquals("f1", schema.getFields(0).getName());
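[Not part of the patch — the proto conversion is now phrased as a LogicalType visitor instead of the removed static toProtoType helper. A minimal sketch of the call pattern tested above; the wrapper class is illustrative only.]

    import org.apache.flink.fnexecution.v1.FlinkFnApi;
    import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
    import org.apache.flink.table.types.logical.BigIntType;
    import org.apache.flink.table.types.logical.RowType;

    import java.util.Collections;

    public class ProtoConversionSketch {
        public static void main(String[] args) {
            RowType rowType = new RowType(
                Collections.singletonList(new RowType.RowField("f1", new BigIntType())));

            // The visitor walks the logical type and emits the schema shipped to the Python worker.
            FlinkFnApi.Schema.FieldType protoType =
                rowType.accept(new PythonTypeUtils.LogicalTypeToProtoTypeConverter());

            // For a ROW type, the nested schema carries the single field "f1".
            System.out.println(protoType.getRowSchema().getFields(0).getName());
        }
    }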
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java
deleted file mode 100644
index bbc25da11c420..0000000000000
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.runtime.operators.python;
-
-import org.apache.flink.python.PythonFunctionRunner;
-import org.apache.flink.util.Preconditions;
-
-import org.apache.beam.sdk.fn.data.FnDataReceiver;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * A {@link PythonFunctionRunner} that just emit each input element for Python UDTF.
- *
- * @param <IN> Type of the input elements.
- */
-public abstract class AbstractPassThroughPythonTableFunctionRunner<IN> implements PythonFunctionRunner<IN> {
-
-	protected boolean bundleStarted;
-	protected final List<IN> bufferedElements;
-	protected final FnDataReceiver<IN> resultReceiver;
-
-	AbstractPassThroughPythonTableFunctionRunner(FnDataReceiver<IN> resultReceiver) {
-		this.resultReceiver = Preconditions.checkNotNull(resultReceiver);
-		bundleStarted = false;
-		bufferedElements = new ArrayList<>();
-	}
-
-	@Override
-	public void open() {}
-
-	@Override
-	public void close() {}
-
-	@Override
-	public void startBundle() {
-		Preconditions.checkState(!bundleStarted);
-		bundleStarted = true;
-	}
-
-	@Override
-	public void processElement(IN element) {
-		bufferedElements.add(copy(element));
-	}
-
-	public abstract IN copy(IN element);
-}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java
index 2ecccf9fc129c..057760069b54b 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java
@@ -18,23 +18,55 @@
 package org.apache.flink.table.runtime.operators.python;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.python.PythonFunctionRunner;
 import org.apache.flink.table.runtime.runners.python.PythonTableFunctionRunner;
-import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;
 
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 
+import java.util.ArrayList;
+import java.util.List;
+
 /**
  * A {@link PythonTableFunctionRunner} that just emit each input element.
  */
-public class PassThroughPythonTableFunctionRunner extends AbstractPassThroughPythonTableFunctionRunner<Row> {
-	PassThroughPythonTableFunctionRunner(FnDataReceiver<Row> resultReceiver) {
-		super(resultReceiver);
+public abstract class PassThroughPythonTableFunctionRunner<IN> implements PythonFunctionRunner<IN> {
+	private boolean bundleStarted;
+	private final List<IN> bufferedElements;
+	private final FnDataReceiver<byte[]> resultReceiver;
+
+	/**
+	 * Reusable OutputStream used to hold the serialized input elements.
+	 */
+	private transient ByteArrayOutputStreamWithPos baos;
+
+	/**
+	 * OutputStream wrapper.
+	 */
+	private transient DataOutputViewStreamWrapper baosWrapper;
+
+	PassThroughPythonTableFunctionRunner(FnDataReceiver<byte[]> resultReceiver) {
+		this.resultReceiver = Preconditions.checkNotNull(resultReceiver);
+		bundleStarted = false;
+		bufferedElements = new ArrayList<>();
 	}
 
 	@Override
-	public Row copy(Row element) {
-		return Row.copy(element);
+	public void open() {
+		baos = new ByteArrayOutputStreamWithPos();
+		baosWrapper = new DataOutputViewStreamWrapper(baos);
+	}
+
+	@Override
+	public void close() {}
+
+	@Override
+	public void startBundle() {
+		Preconditions.checkState(!bundleStarted);
+		bundleStarted = true;
 	}
 
 	@Override
@@ -42,10 +74,21 @@ public void finishBundle() throws Exception {
 		Preconditions.checkState(bundleStarted);
 		bundleStarted = false;
 
-		for (Row element : bufferedElements) {
-			resultReceiver.accept(element);
-			resultReceiver.accept(new Row(0));
+		for (IN element : bufferedElements) {
+			baos.reset();
+			getInputTypeSerializer().serialize(element, baosWrapper);
+			resultReceiver.accept(baos.toByteArray());
+			resultReceiver.accept(new byte[]{0});
 		}
 		bufferedElements.clear();
 	}
+
+	@Override
+	public void processElement(IN element) {
+		bufferedElements.add(copy(element));
+	}
+
+	public abstract IN copy(IN element);
+
+	public abstract TypeSerializer<IN> getInputTypeSerializer();
 }
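[Not part of the patch — a minimal sketch of driving one bundle through the reworked pass-through runner: elements are buffered in processElement, then serialized in finishBundle, each payload followed by a single zero byte standing in for the end-of-results marker. Assumes same-package access to the package-private constructor; the wrapper class is hypothetical. The anonymous-subclass wiring mirrors PythonTableFunctionOperatorTest below.]

    package org.apache.flink.table.runtime.operators.python;

    import org.apache.flink.api.common.typeutils.TypeSerializer;
    import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
    import org.apache.flink.table.types.logical.BigIntType;
    import org.apache.flink.table.types.logical.RowType;
    import org.apache.flink.types.Row;

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;

    public class PassThroughBundleSketch {
        public static void main(String[] args) throws Exception {
            List<byte[]> outputs = new ArrayList<>();
            RowType rowType = new RowType(
                Collections.singletonList(new RowType.RowField("f1", new BigIntType())));

            PassThroughPythonTableFunctionRunner<Row> runner =
                new PassThroughPythonTableFunctionRunner<Row>(outputs::add) {
                    @Override
                    public Row copy(Row element) {
                        return Row.copy(element);
                    }

                    @Override
                    @SuppressWarnings("unchecked")
                    public TypeSerializer<Row> getInputTypeSerializer() {
                        return (TypeSerializer<Row>) PythonTypeUtils.toFlinkTypeSerializer(rowType);
                    }
                };

            runner.open();
            runner.startBundle();
            runner.processElement(Row.of(1L));
            runner.finishBundle();

            // One serialized row plus one 0x00 marker.
            System.out.println(outputs.size()); // 2
        }
    }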
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java
index ffaaeec293dfd..1d9dfa1a6d186 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java
@@ -18,12 +18,14 @@
 package org.apache.flink.table.runtime.operators.python;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.python.PythonFunctionRunner;
 import org.apache.flink.python.env.PythonEnvironmentManager;
 import org.apache.flink.streaming.util.TestHarnessUtil;
 import org.apache.flink.table.functions.python.PythonFunctionInfo;
 import org.apache.flink.table.runtime.types.CRow;
+import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.types.Row;
 
@@ -35,9 +37,9 @@
 /**
  * Tests for {@link PythonTableFunctionOperator}.
  */
-public class PythonTableFunctionOperatorTest extends PythonTableFunctionOperatorTestBase<CRow, CRow, Row, Row> {
+public class PythonTableFunctionOperatorTest extends PythonTableFunctionOperatorTestBase<CRow, CRow, Row> {
 
 	@Override
-	public AbstractPythonTableFunctionOperator<CRow, CRow, Row, Row> getTestOperator(
+	public AbstractPythonTableFunctionOperator<CRow, CRow, Row> getTestOperator(
 		Configuration config,
 		PythonFunctionInfo tableFunction,
 		RowType inputType,
@@ -70,9 +72,20 @@ private static class PassThroughPythonTableFunctionOperator extends PythonTableF
 
 		@Override
 		public PythonFunctionRunner<Row> createPythonFunctionRunner(
-				FnDataReceiver<Row> resultReceiver,
+				FnDataReceiver<byte[]> resultReceiver,
 				PythonEnvironmentManager pythonEnvironmentManager) {
-			return new PassThroughPythonTableFunctionRunner(resultReceiver);
+			return new PassThroughPythonTableFunctionRunner<Row>(resultReceiver) {
+				@Override
+				public Row copy(Row element) {
+					return Row.copy(element);
+				}
+
+				@Override
+				@SuppressWarnings("unchecked")
+				public TypeSerializer<Row> getInputTypeSerializer() {
+					return PythonTypeUtils.toFlinkTypeSerializer(udfInputType);
+				}
+			};
 		}
 	}
 }
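[Not part of the patch — the base-class diff that follows splits the former single dataType into distinct input and output row types: the operator emits the input row joined with the UDTF result, so the output gains one extra BIGINT field f4. A small sketch of that arity arithmetic, assuming Flink's Row.join utility reflects the operator's join semantics; the wrapper class is illustrative.]

    import org.apache.flink.types.Row;

    public class JoinedAritySketch {
        public static void main(String[] args) {
            Row input = Row.of("a", "b", 1L); // f1, f2, f3 from the input type
            Row udtfResult = Row.of(2L);      // f4 produced by the table function

            // Joining the input row with the UDTF output yields the four-field output type.
            Row output = Row.join(input, udtfResult);
            System.out.println(output.getArity()); // 4
        }
    }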
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java
index 93ed9da956f4d..4752a7b975523 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java
@@ -22,7 +22,7 @@
 import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
-import org.apache.flink.table.functions.python.AbstractPythonTableFunctionRunnerTest;
+import org.apache.flink.table.functions.python.AbstractPythonScalarFunctionRunnerTest;
 import org.apache.flink.table.functions.python.PythonFunctionInfo;
 import org.apache.flink.table.types.logical.BigIntType;
 import org.apache.flink.table.types.logical.RowType;
@@ -46,9 +46,8 @@
  * @param <IN> Type of the input elements.
  * @param <OUT> Type of the output elements.
  * @param <UDFIN> Type of the UDTF input type.
- * @param <UDFOUT> Type of the UDTF input type.
  */
-public abstract class PythonTableFunctionOperatorTestBase<IN, OUT, UDFIN, UDFOUT> {
+public abstract class PythonTableFunctionOperatorTestBase<IN, OUT, UDFIN> {
 
 	@Test
 	public void testRetractionFieldKept() throws Exception {
@@ -139,17 +138,22 @@ public void testFinishBundleTriggeredByTime() throws Exception {
 	}
 
 	private OneInputStreamOperatorTestHarness<IN, OUT> getTestHarness(Configuration config) throws Exception {
-		RowType dataType = new RowType(Arrays.asList(
+		RowType inputType = new RowType(Arrays.asList(
 			new RowType.RowField("f1", new VarCharType()),
 			new RowType.RowField("f2", new VarCharType()),
 			new RowType.RowField("f3", new BigIntType())));
-		AbstractPythonTableFunctionOperator<IN, OUT, UDFIN, UDFOUT> operator = getTestOperator(
+		RowType outputType = new RowType(Arrays.asList(
+			new RowType.RowField("f1", new VarCharType()),
+			new RowType.RowField("f2", new VarCharType()),
+			new RowType.RowField("f3", new BigIntType()),
+			new RowType.RowField("f4", new BigIntType())));
+		AbstractPythonTableFunctionOperator<IN, OUT, UDFIN> operator = getTestOperator(
 			config,
 			new PythonFunctionInfo(
-				AbstractPythonTableFunctionRunnerTest.DummyPythonFunction.INSTANCE,
+				AbstractPythonScalarFunctionRunnerTest.DummyPythonFunction.INSTANCE,
 				new Integer[]{0}),
-			dataType,
-			dataType,
+			inputType,
+			outputType,
 			new int[]{2}
 		);
 
@@ -163,7 +167,7 @@ private OneInputStreamOperatorTestHarness<IN, OUT> getTestHarness(Configuration
 
 	public abstract void assertOutputEquals(String message, Collection<Object> expected, Collection<Object> actual);
 
-	public abstract AbstractPythonTableFunctionOperator<IN, OUT, UDFIN, UDFOUT> getTestOperator(
+	public abstract AbstractPythonTableFunctionOperator<IN, OUT, UDFIN> getTestOperator(
 		Configuration config,
 		PythonFunctionInfo tableFunction,
 		RowType inputType,