diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py index 3ca7fa5a46d659..5093057cf8cb4e 100644 --- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py +++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py @@ -36,7 +36,7 @@ name='flink-fn-execution.proto', package='org.apache.flink.fn_execution.v1', syntax='proto3', - serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"[\n\x14UserDefinedFunctions\x12\x43\n\x04udfs\x18\x01 \x03(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunction\"\x80\t\n\x06Schema\x12>\n\x06\x66ields\x18\x01 \x03(\x0b\x32..org.apache.flink.fn_execution.v1.Schema.Field\x1a\x97\x01\n\x07MapType\x12\x44\n\x08key_type\x18\x01 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x12\x46\n\nvalue_type\x18\x02 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x1a!\n\x0c\x44\x61teTimeType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x1a/\n\x0b\x44\x65\x63imalType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x12\r\n\x05scale\x18\x02 \x01(\x05\x1a\xec\x03\n\tFieldType\x12\x44\n\ttype_name\x18\x01 \x01(\x0e\x32\x31.org.apache.flink.fn_execution.v1.Schema.TypeName\x12\x10\n\x08nullable\x18\x02 \x01(\x08\x12U\n\x17\x63ollection_element_type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldTypeH\x00\x12\x44\n\x08map_type\x18\x04 \x01(\x0b\x32\x30.org.apache.flink.fn_execution.v1.Schema.MapTypeH\x00\x12>\n\nrow_schema\x18\x05 \x01(\x0b\x32(.org.apache.flink.fn_execution.v1.SchemaH\x00\x12O\n\x0e\x64\x61te_time_type\x18\x06 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.Schema.DateTimeTypeH\x00\x12L\n\x0c\x64\x65\x63imal_type\x18\x07 \x01(\x0b\x32\x34.org.apache.flink.fn_execution.v1.Schema.DecimalTypeH\x00\x42\x0b\n\ttype_info\x1al\n\x05\x46ield\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12@\n\x04type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\"\xea\x01\n\x08TypeName\x12\x07\n\x03ROW\x10\x00\x12\x0b\n\x07TINYINT\x10\x01\x12\x0c\n\x08SMALLINT\x10\x02\x12\x07\n\x03INT\x10\x03\x12\n\n\x06\x42IGINT\x10\x04\x12\x0b\n\x07\x44\x45\x43IMAL\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07\x12\x08\n\x04\x44\x41TE\x10\x08\x12\x08\n\x04TIME\x10\t\x12\x0c\n\x08\x44\x41TETIME\x10\n\x12\x0b\n\x07\x42OOLEAN\x10\x0b\x12\n\n\x06\x42INARY\x10\x0c\x12\r\n\tVARBINARY\x10\r\x12\x08\n\x04\x43HAR\x10\x0e\x12\x0b\n\x07VARCHAR\x10\x0f\x12\t\n\x05\x41RRAY\x10\x10\x12\x07\n\x03MAP\x10\x11\x12\x0c\n\x08MULTISET\x10\x12\x42-\n\x1forg.apache.flink.fnexecution.v1B\nFlinkFnApib\x06proto3') + serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 
\x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"[\n\x14UserDefinedFunctions\x12\x43\n\x04udfs\x18\x01 \x03(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunction\"\x8b\t\n\x06Schema\x12>\n\x06\x66ields\x18\x01 \x03(\x0b\x32..org.apache.flink.fn_execution.v1.Schema.Field\x1a\x97\x01\n\x07MapType\x12\x44\n\x08key_type\x18\x01 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x12\x46\n\nvalue_type\x18\x02 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\x1a!\n\x0c\x44\x61teTimeType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x1a/\n\x0b\x44\x65\x63imalType\x12\x11\n\tprecision\x18\x01 \x01(\x05\x12\r\n\x05scale\x18\x02 \x01(\x05\x1a\xec\x03\n\tFieldType\x12\x44\n\ttype_name\x18\x01 \x01(\x0e\x32\x31.org.apache.flink.fn_execution.v1.Schema.TypeName\x12\x10\n\x08nullable\x18\x02 \x01(\x08\x12U\n\x17\x63ollection_element_type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldTypeH\x00\x12\x44\n\x08map_type\x18\x04 \x01(\x0b\x32\x30.org.apache.flink.fn_execution.v1.Schema.MapTypeH\x00\x12>\n\nrow_schema\x18\x05 \x01(\x0b\x32(.org.apache.flink.fn_execution.v1.SchemaH\x00\x12O\n\x0e\x64\x61te_time_type\x18\x06 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.Schema.DateTimeTypeH\x00\x12L\n\x0c\x64\x65\x63imal_type\x18\x07 \x01(\x0b\x32\x34.org.apache.flink.fn_execution.v1.Schema.DecimalTypeH\x00\x42\x0b\n\ttype_info\x1al\n\x05\x46ield\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12@\n\x04type\x18\x03 \x01(\x0b\x32\x32.org.apache.flink.fn_execution.v1.Schema.FieldType\"\xf5\x01\n\x08TypeName\x12\x07\n\x03ROW\x10\x00\x12\x0b\n\x07TINYINT\x10\x01\x12\x0c\n\x08SMALLINT\x10\x02\x12\x07\n\x03INT\x10\x03\x12\n\n\x06\x42IGINT\x10\x04\x12\x0b\n\x07\x44\x45\x43IMAL\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07\x12\x08\n\x04\x44\x41TE\x10\x08\x12\x08\n\x04TIME\x10\t\x12\x0c\n\x08\x44\x41TETIME\x10\n\x12\x0b\n\x07\x42OOLEAN\x10\x0b\x12\n\n\x06\x42INARY\x10\x0c\x12\r\n\tVARBINARY\x10\r\x12\x08\n\x04\x43HAR\x10\x0e\x12\x0b\n\x07VARCHAR\x10\x0f\x12\t\n\x05\x41RRAY\x10\x10\x12\x07\n\x03MAP\x10\x11\x12\x0c\n\x08MULTISET\x10\x12\x12\t\n\x05TABLE\x10\x13\x42-\n\x1forg.apache.flink.fnexecution.v1B\nFlinkFnApib\x06proto3') ) @@ -123,11 +123,15 @@ name='MULTISET', index=18, number=18, options=None, type=None), + _descriptor.EnumValueDescriptor( + name='TABLE', index=19, number=19, + options=None, + type=None), ], containing_type=None, options=None, serialized_start=1329, - serialized_end=1563, + serialized_end=1574, ) _sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME) @@ -499,7 +503,7 @@ oneofs=[ ], serialized_start=411, - serialized_end=1563, + serialized_end=1574, ) _USERDEFINEDFUNCTION_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto index 0cca33ef2784bb..fd91d65ad1f636 100644 --- a/flink-python/pyflink/proto/flink-fn-execution.proto +++ b/flink-python/pyflink/proto/flink-fn-execution.proto @@ -73,6 +73,7 @@ message Schema { ARRAY = 16; MAP = 17; MULTISET = 18; + TABLE = 19; } message MapType { diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java new file mode 100644 index 00000000000000..c9328239b223ab --- /dev/null +++ 
b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractPythonTableFunctionOperator.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.table.functions.python.PythonEnv;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for all stream operators that execute a Python {@link TableFunction}.
+ *
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the output elements.
+ * @param <UDTFIN> Type of the UDTF input type.
+ * @param <UDTFOUT> Type of the UDTF output type.
+ */
+public abstract class AbstractPythonTableFunctionOperator<IN, OUT, UDTFIN, UDTFOUT>
+		extends AbstractPythonFunctionOperator<IN> {
+
+	private static final long serialVersionUID = 1L;
+
+	/**
+	 * The Python {@link TableFunction} to be executed.
+	 */
+	protected final PythonFunctionInfo tableFunction;
+
+	/**
+	 * The input logical type.
+	 */
+	protected final RowType inputType;
+
+	/**
+	 * The output logical type.
+	 */
+	protected final RowType outputType;
+
+	/**
+	 * The offsets of the UDTF inputs.
+	 */
+	protected final int[] udtfInputOffsets;
+
+	/**
+	 * The UDTF input logical type.
+	 */
+	protected transient RowType udtfInputType;
+
+	/**
+	 * The UDTF output logical type.
+	 */
+	protected transient RowType udtfOutputType;
+
+	/**
+	 * The queue holding the input elements for which the execution results have not been received.
+	 */
+	protected transient LinkedBlockingQueue<IN> forwardedInputQueue;
+
+	/**
+	 * The queue holding the user-defined table function execution results. The execution results
+	 * are in the same order as the input elements.
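+	 * A result row with arity zero is treated as the marker for the end of the results
+	 * of one input element (see {@code isFinishResult} in the concrete operator implementations).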
+	 */
+	protected transient LinkedBlockingQueue<UDTFOUT> udtfResultQueue;
+
+	public AbstractPythonTableFunctionOperator(
+			Configuration config,
+			PythonFunctionInfo tableFunction,
+			RowType inputType,
+			RowType outputType,
+			int[] udtfInputOffsets) {
+		super(config);
+		this.tableFunction = Preconditions.checkNotNull(tableFunction);
+		this.inputType = Preconditions.checkNotNull(inputType);
+		this.outputType = Preconditions.checkNotNull(outputType);
+		this.udtfInputOffsets = Preconditions.checkNotNull(udtfInputOffsets);
+	}
+
+	@Override
+	public void open() throws Exception {
+		forwardedInputQueue = new LinkedBlockingQueue<>();
+		udtfResultQueue = new LinkedBlockingQueue<>();
+		udtfInputType = new RowType(
+			Arrays.stream(udtfInputOffsets)
+				.mapToObj(i -> inputType.getFields().get(i))
+				.collect(Collectors.toList()));
+		List<RowType.RowField> udtfOutputDataFields = new ArrayList<>(
+			outputType.getFields().subList(inputType.getFieldCount(), outputType.getFieldCount()));
+		udtfOutputType = new RowType(udtfOutputDataFields);
+		super.open();
+	}
+
+	@Override
+	public void processElement(StreamRecord<IN> element) throws Exception {
+		bufferInput(element.getValue());
+		super.processElement(element);
+		emitResults();
+	}
+
+	@Override
+	public PythonFunctionRunner<IN> createPythonFunctionRunner() throws Exception {
+		final FnDataReceiver<UDTFOUT> udtfResultReceiver = input -> {
+			// handover to queue, do not block the result receiver thread
+			udtfResultQueue.put(input);
+		};
+
+		return new ProjectUdfInputPythonTableFunctionRunner(
+			createPythonFunctionRunner(
+				udtfResultReceiver,
+				createPythonEnvironmentManager()));
+	}
+
+	@Override
+	public PythonEnv getPythonEnv() {
+		return tableFunction.getPythonFunction().getPythonEnv();
+	}
+
+	/**
+	 * Buffers the specified input; it will be used to construct
+	 * the operator result together with the UDTF execution result.
+	 */
+	public abstract void bufferInput(IN input);
+
+	/**
+	 * Extracts the UDTF input fields from the specified input element.
+	 */
+	public abstract UDTFIN getUdtfInput(IN element);
+
+	public abstract PythonFunctionRunner<UDTFIN> createPythonFunctionRunner(
+		FnDataReceiver<UDTFOUT> resultReceiver,
+		PythonEnvironmentManager pythonEnvironmentManager);
+
+	private class ProjectUdfInputPythonTableFunctionRunner implements PythonFunctionRunner<IN> {
+
+		private final PythonFunctionRunner<UDTFIN> pythonFunctionRunner;
+
+		ProjectUdfInputPythonTableFunctionRunner(PythonFunctionRunner<UDTFIN> pythonFunctionRunner) {
+			this.pythonFunctionRunner = pythonFunctionRunner;
+		}
+
+		@Override
+		public void open() throws Exception {
+			pythonFunctionRunner.open();
+		}
+
+		@Override
+		public void close() throws Exception {
+			pythonFunctionRunner.close();
+		}
+
+		@Override
+		public void startBundle() throws Exception {
+			pythonFunctionRunner.startBundle();
+		}
+
+		@Override
+		public void finishBundle() throws Exception {
+			pythonFunctionRunner.finishBundle();
+		}
+
+		@Override
+		public void processElement(IN element) throws Exception {
+			pythonFunctionRunner.processElement(getUdtfInput(element));
+		}
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java
new file mode 100644
index 00000000000000..248b11ba08304f
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperator.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.runners.python.PythonTableFunctionRunner;
+import org.apache.flink.table.runtime.types.CRow;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+/**
+ * The Python {@link TableFunction} operator for the legacy planner.
+ */
+public class PythonTableFunctionOperator
+		extends AbstractPythonTableFunctionOperator<CRow, CRow, Row, Row> {
+
+	private static final long serialVersionUID = 1L;
+
+	/**
+	 * The collector used to collect records.
+	 */
+	private transient StreamRecordCRowWrappingCollector cRowWrapper;
+
+	public PythonTableFunctionOperator(
+			Configuration config,
+			PythonFunctionInfo tableFunction,
+			RowType inputType,
+			RowType outputType,
+			int[] udtfInputOffsets) {
+		super(config, tableFunction, inputType, outputType, udtfInputOffsets);
+	}
+
+	@Override
+	public void open() throws Exception {
+		super.open();
+		this.cRowWrapper = new StreamRecordCRowWrappingCollector(output);
+	}
+
+	private boolean isFinishResult(Row result) {
+		return result.getArity() == 0;
+	}
+
+	@Override
+	public void emitResults() {
+		Row udtfResult;
+		CRow input = null;
+		while ((udtfResult = udtfResultQueue.poll()) != null) {
+			if (input == null) {
+				input = forwardedInputQueue.poll();
+			}
+			if (isFinishResult(udtfResult)) {
+				// the results of the current input element are exhausted, advance to the next one
+				input = forwardedInputQueue.poll();
+			}
+			if (input != null && !isFinishResult(udtfResult)) {
+				cRowWrapper.setChange(input.change());
+				cRowWrapper.collect(Row.join(input.row(), udtfResult));
+			}
+		}
+	}
+
+	@Override
+	public void bufferInput(CRow input) {
+		forwardedInputQueue.add(input);
+	}
+
+	@Override
+	public Row getUdtfInput(CRow element) {
+		return Row.project(element.row(), udtfInputOffsets);
+	}
+
+	@Override
+	public PythonFunctionRunner<Row> createPythonFunctionRunner(
+			FnDataReceiver<Row> resultReceiver,
+			PythonEnvironmentManager pythonEnvironmentManager) {
+		return new PythonTableFunctionRunner(
+			getRuntimeContext().getTaskName(),
+			resultReceiver,
+			tableFunction,
+			pythonEnvironmentManager,
+			udtfInputType,
+			udtfOutputType);
+	}
+
+	/**
+	 * The collector is used to convert a {@link Row} to a {@link CRow}.
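+	 * The wrapped CRow and StreamRecord instances are reused across invocations so that
+	 * no new record objects are created for each emitted row.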
+	 */
+	private static class StreamRecordCRowWrappingCollector implements Collector<Row> {
+
+		private final Collector<StreamRecord<CRow>> out;
+		private final CRow reuseCRow = new CRow();
+
+		/**
+		 * For Table API & SQL jobs, the timestamp field is not used.
+		 */
+		private final StreamRecord<CRow> reuseStreamRecord = new StreamRecord<>(reuseCRow);
+
+		StreamRecordCRowWrappingCollector(Collector<StreamRecord<CRow>> out) {
+			this.out = out;
+		}
+
+		public void setChange(boolean change) {
+			this.reuseCRow.change_$eq(change);
+		}
+
+		@Override
+		public void collect(Row record) {
+			reuseCRow.row_$eq(record);
+			out.collect(reuseStreamRecord);
+		}
+
+		@Override
+		public void close() {
+			out.close();
+		}
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java
index d85817acee540b..3f290810f40521 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonScalarFunctionRunner.java
@@ -21,7 +21,6 @@
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.fnexecution.v1.FlinkFnApi;
-import org.apache.flink.python.AbstractPythonFunctionRunner;
 import org.apache.flink.python.PythonFunctionRunner;
 import org.apache.flink.python.env.PythonEnvironmentManager;
 import org.apache.flink.table.functions.ScalarFunction;
@@ -30,21 +29,9 @@
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.util.Preconditions;
 
-import com.google.protobuf.ByteString;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
-import org.apache.beam.runners.core.construction.ModelCoders;
-import org.apache.beam.runners.core.construction.graph.ExecutableStage;
-import org.apache.beam.runners.core.construction.graph.ImmutableExecutableStage;
-import org.apache.beam.runners.core.construction.graph.PipelineNode;
-import org.apache.beam.runners.core.construction.graph.SideInputReference;
-import org.apache.beam.runners.core.construction.graph.TimerReference;
-import org.apache.beam.runners.core.construction.graph.UserStateReference;
-import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 
-import java.util.Collections;
-import java.util.List;
-
@@ -52,27 +39,12 @@
 /**
  * Abstract {@link PythonFunctionRunner} used to execute Python {@link ScalarFunction}s.
  *
  * @param <IN> Type of the input elements.
  * @param <OUT> Type of the execution results.
  */
 @Internal
-public abstract class AbstractPythonScalarFunctionRunner<IN, OUT> extends AbstractPythonFunctionRunner<IN, OUT> {
+public abstract class AbstractPythonScalarFunctionRunner<IN, OUT> extends AbstractPythonStatelessFunctionRunner<IN, OUT> {
 
 	private static final String SCHEMA_CODER_URN = "flink:coder:schema:v1";
 	private static final String SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1";
 
-	private static final String INPUT_ID = "input";
-	private static final String OUTPUT_ID = "output";
-	private static final String TRANSFORM_ID = "transform";
-
-	private static final String MAIN_INPUT_NAME = "input";
-	private static final String MAIN_OUTPUT_NAME = "output";
-
-	private static final String INPUT_CODER_ID = "input_coder";
-	private static final String OUTPUT_CODER_ID = "output_coder";
-	private static final String WINDOW_CODER_ID = "window_coder";
-
-	private static final String WINDOW_STRATEGY = "windowing_strategy";
-
 	private final PythonFunctionInfo[] scalarFunctions;
 
-	private final RowType inputType;
-	private final RowType outputType;
-
 	public AbstractPythonScalarFunctionRunner(
 		String taskName,
@@ -81,85 +53,8 @@ public AbstractPythonScalarFunctionRunner(
 		PythonEnvironmentManager environmentManager,
 		RowType inputType,
 		RowType outputType) {
-		super(taskName, resultReceiver, environmentManager, StateRequestHandler.unsupported());
+		super(taskName, resultReceiver, environmentManager, inputType, outputType, SCALAR_FUNCTION_URN);
 		this.scalarFunctions = Preconditions.checkNotNull(scalarFunctions);
-		this.inputType = Preconditions.checkNotNull(inputType);
-		this.outputType = Preconditions.checkNotNull(outputType);
-	}
-
-	/**
-	 * Gets the logical type of the input elements of the Python user-defined functions.
-	 */
-	public RowType getInputType() {
-		return inputType;
-	}
-
-	/**
-	 * Gets the logical type of the execution results of the Python user-defined functions.
-	 */
-	public RowType getOutputType() {
-		return outputType;
-	}
-
-	@Override
-	@SuppressWarnings("unchecked")
-	public ExecutableStage createExecutableStage() throws Exception {
-		RunnerApi.Components components =
-			RunnerApi.Components.newBuilder()
-				.putPcollections(
-					INPUT_ID,
-					RunnerApi.PCollection.newBuilder()
-						.setWindowingStrategyId(WINDOW_STRATEGY)
-						.setCoderId(INPUT_CODER_ID)
-						.build())
-				.putPcollections(
-					OUTPUT_ID,
-					RunnerApi.PCollection.newBuilder()
-						.setWindowingStrategyId(WINDOW_STRATEGY)
-						.setCoderId(OUTPUT_CODER_ID)
-						.build())
-				.putTransforms(
-					TRANSFORM_ID,
-					RunnerApi.PTransform.newBuilder()
-						.setUniqueName(TRANSFORM_ID)
-						.setSpec(RunnerApi.FunctionSpec.newBuilder()
-							.setUrn(SCALAR_FUNCTION_URN)
-							.setPayload(
-								org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString.copyFrom(
-									getUserDefinedFunctionsProto().toByteArray()))
-							.build())
-						.putInputs(MAIN_INPUT_NAME, INPUT_ID)
-						.putOutputs(MAIN_OUTPUT_NAME, OUTPUT_ID)
-						.build())
-				.putWindowingStrategies(
-					WINDOW_STRATEGY,
-					RunnerApi.WindowingStrategy.newBuilder()
-						.setWindowCoderId(WINDOW_CODER_ID)
-						.build())
-				.putCoders(
-					INPUT_CODER_ID,
-					getInputCoderProto())
-				.putCoders(
-					OUTPUT_CODER_ID,
-					getOutputCoderProto())
-				.putCoders(
-					WINDOW_CODER_ID,
-					getWindowCoderProto())
-				.build();
-
-		PipelineNode.PCollectionNode input =
-			PipelineNode.pCollection(INPUT_ID, components.getPcollectionsOrThrow(INPUT_ID));
-		List<SideInputReference> sideInputs = Collections.EMPTY_LIST;
-		List<UserStateReference> userStates = Collections.EMPTY_LIST;
-		List<TimerReference> timers = Collections.EMPTY_LIST;
-		List<PipelineNode.PTransformNode> transforms =
-			Collections.singletonList(
-				PipelineNode.pTransform(TRANSFORM_ID, components.getTransformsOrThrow(TRANSFORM_ID)));
-		List<PipelineNode.PCollectionNode> outputs =
-			Collections.singletonList(
-				PipelineNode.pCollection(OUTPUT_ID, components.getPcollectionsOrThrow(OUTPUT_ID)));
-		return ImmutableExecutableStage.of(
-			components, createPythonExecutionEnvironment(), input, sideInputs, userStates, timers, transforms, outputs);
 	}
 
 	/**
@@ -174,36 +69,20 @@ public FlinkFnApi.UserDefinedFunctions getUserDefinedFunctionsProto() {
 		return builder.build();
 	}
 
-	private FlinkFnApi.UserDefinedFunction getUserDefinedFunctionProto(PythonFunctionInfo pythonFunctionInfo) {
-		FlinkFnApi.UserDefinedFunction.Builder builder = FlinkFnApi.UserDefinedFunction.newBuilder();
-		builder.setPayload(ByteString.copyFrom(pythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
-		for (Object input : pythonFunctionInfo.getInputs()) {
-			FlinkFnApi.UserDefinedFunction.Input.Builder inputProto =
-				FlinkFnApi.UserDefinedFunction.Input.newBuilder();
-			if (input instanceof PythonFunctionInfo) {
-				inputProto.setUdf(getUserDefinedFunctionProto((PythonFunctionInfo) input));
-			} else if (input instanceof Integer) {
-				inputProto.setInputOffset((Integer) input);
-			} else {
-				inputProto.setInputConstant(ByteString.copyFrom((byte[]) input));
-			}
-			builder.addInputs(inputProto);
-		}
-		return builder.build();
-	}
-
 	/**
 	 * Gets the proto representation of the input coder.
 	 */
-	private RunnerApi.Coder getInputCoderProto() {
-		return getRowCoderProto(inputType);
+	@Override
+	RunnerApi.Coder getInputCoderProto() {
+		return getRowCoderProto(getInputType());
 	}
 
 	/**
 	 * Gets the proto representation of the output coder.
 	 */
-	private RunnerApi.Coder getOutputCoderProto() {
-		return getRowCoderProto(outputType);
+	@Override
+	RunnerApi.Coder getOutputCoderProto() {
+		return getRowCoderProto(getOutputType());
 	}
 
 	private RunnerApi.Coder getRowCoderProto(RowType rowType) {
@@ -216,16 +95,4 @@ private RunnerApi.Coder getRowCoderProto(RowType rowType) {
 				.build())
 			.build();
 	}
-
-	/**
-	 * Gets the proto representation of the window coder.
-	 */
-	private RunnerApi.Coder getWindowCoderProto() {
-		return RunnerApi.Coder.newBuilder()
-			.setSpec(
-				RunnerApi.FunctionSpec.newBuilder()
-					.setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN)
-					.build())
-			.build();
-	}
 }
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonStatelessFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonStatelessFunctionRunner.java
new file mode 100644
index 00000000000000..0e8363317c8255
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonStatelessFunctionRunner.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.runners.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.AbstractPythonFunctionRunner;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import com.google.protobuf.ByteString;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.core.construction.graph.ImmutableExecutableStage;
+import org.apache.beam.runners.core.construction.graph.PipelineNode;
+import org.apache.beam.runners.core.construction.graph.SideInputReference;
+import org.apache.beam.runners.core.construction.graph.TimerReference;
+import org.apache.beam.runners.core.construction.graph.UserStateReference;
+import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Abstract {@link PythonFunctionRunner} used to execute Python stateless functions.
+ *
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the execution results.
+ */
+@Internal
+public abstract class AbstractPythonStatelessFunctionRunner<IN, OUT> extends AbstractPythonFunctionRunner<IN, OUT> {
+
+	private static final String INPUT_ID = "input";
+	private static final String OUTPUT_ID = "output";
+	private static final String TRANSFORM_ID = "transform";
+
+	private static final String MAIN_INPUT_NAME = "input";
+	private static final String MAIN_OUTPUT_NAME = "output";
+
+	private static final String INPUT_CODER_ID = "input_coder";
+	private static final String OUTPUT_CODER_ID = "output_coder";
+	private static final String WINDOW_CODER_ID = "window_coder";
+
+	private static final String WINDOW_STRATEGY = "windowing_strategy";
+
+	private final String functionUrn;
+
+	private final RowType inputType;
+	private final RowType outputType;
+
+	public AbstractPythonStatelessFunctionRunner(
+			String taskName,
+			FnDataReceiver<OUT> resultReceiver,
+			PythonEnvironmentManager environmentManager,
+			RowType inputType,
+			RowType outputType,
+			String functionUrn) {
+		super(taskName, resultReceiver, environmentManager, StateRequestHandler.unsupported());
+		this.functionUrn = functionUrn;
+		this.inputType = Preconditions.checkNotNull(inputType);
+		this.outputType = Preconditions.checkNotNull(outputType);
+	}
+
+	@Override
+	@SuppressWarnings("unchecked")
+	public ExecutableStage createExecutableStage() throws Exception {
+		RunnerApi.Components components =
+			RunnerApi.Components.newBuilder()
+				.putPcollections(
+					INPUT_ID,
+					RunnerApi.PCollection.newBuilder()
+						.setWindowingStrategyId(WINDOW_STRATEGY)
+						.setCoderId(INPUT_CODER_ID)
+						.build())
+				.putPcollections(
+					OUTPUT_ID,
+					RunnerApi.PCollection.newBuilder()
+						.setWindowingStrategyId(WINDOW_STRATEGY)
+						.setCoderId(OUTPUT_CODER_ID)
+						.build())
+				.putTransforms(
+					TRANSFORM_ID,
+					RunnerApi.PTransform.newBuilder()
+						.setUniqueName(TRANSFORM_ID)
+						.setSpec(RunnerApi.FunctionSpec.newBuilder()
+							.setUrn(functionUrn)
+							.setPayload(
+								org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString.copyFrom(
+									getUserDefinedFunctionsProto().toByteArray()))
+							.build())
+						.putInputs(MAIN_INPUT_NAME, INPUT_ID)
+						.putOutputs(MAIN_OUTPUT_NAME, OUTPUT_ID)
+						.build())
+				.putWindowingStrategies(
+					WINDOW_STRATEGY,
+					RunnerApi.WindowingStrategy.newBuilder()
+						.setWindowCoderId(WINDOW_CODER_ID)
+						.build())
+				.putCoders(
+					INPUT_CODER_ID,
+					getInputCoderProto())
+				.putCoders(
+					OUTPUT_CODER_ID,
+					getOutputCoderProto())
+				.putCoders(
+					WINDOW_CODER_ID,
+					getWindowCoderProto())
+				.build();
+
+		PipelineNode.PCollectionNode input =
+			PipelineNode.pCollection(INPUT_ID, components.getPcollectionsOrThrow(INPUT_ID));
+		List<SideInputReference> sideInputs = Collections.EMPTY_LIST;
+		List<UserStateReference> userStates = Collections.EMPTY_LIST;
+		List<TimerReference> timers = Collections.EMPTY_LIST;
+		List<PipelineNode.PTransformNode> transforms =
+			Collections.singletonList(
+				PipelineNode.pTransform(TRANSFORM_ID, components.getTransformsOrThrow(TRANSFORM_ID)));
+		List<PipelineNode.PCollectionNode> outputs =
+			Collections.singletonList(
+				PipelineNode.pCollection(OUTPUT_ID, components.getPcollectionsOrThrow(OUTPUT_ID)));
+		return ImmutableExecutableStage.of(
+			components, createPythonExecutionEnvironment(), input, sideInputs, userStates, timers, transforms, outputs);
+	}
+
+	FlinkFnApi.UserDefinedFunction getUserDefinedFunctionProto(PythonFunctionInfo pythonFunctionInfo) {
+		FlinkFnApi.UserDefinedFunction.Builder builder = FlinkFnApi.UserDefinedFunction.newBuilder();
+		builder.setPayload(ByteString.copyFrom(pythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
+		for (Object input : pythonFunctionInfo.getInputs()) {
+			FlinkFnApi.UserDefinedFunction.Input.Builder inputProto =
+				FlinkFnApi.UserDefinedFunction.Input.newBuilder();
+			if (input instanceof PythonFunctionInfo) {
+				inputProto.setUdf(getUserDefinedFunctionProto((PythonFunctionInfo) input));
+			} else if (input instanceof Integer) {
+				inputProto.setInputOffset((Integer) input);
+			} else {
+				inputProto.setInputConstant(ByteString.copyFrom((byte[]) input));
+			}
+			builder.addInputs(inputProto);
+		}
+		return builder.build();
+	}
+
+	/**
+	 * Gets the logical type of the input elements of the Python user-defined functions.
+	 */
+	public RowType getInputType() {
+		return inputType;
+	}
+
+	/**
+	 * Gets the logical type of the execution results of the Python user-defined functions.
+	 */
+	public RowType getOutputType() {
+		return outputType;
+	}
+
+	/**
+	 * Gets the proto representation of the input coder.
+	 */
+	abstract RunnerApi.Coder getInputCoderProto();
+
+	/**
+	 * Gets the proto representation of the output coder.
+	 */
+	abstract RunnerApi.Coder getOutputCoderProto();
+
+	/**
+	 * Gets the proto representation of the window coder.
+	 */
+	private RunnerApi.Coder getWindowCoderProto() {
+		return RunnerApi.Coder.newBuilder()
+			.setSpec(
+				RunnerApi.FunctionSpec.newBuilder()
+					.setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN)
+					.build())
+			.build();
+	}
+
+	/**
+	 * Gets the proto representation of the Python user-defined functions to be executed.
+	 */
+	@VisibleForTesting
+	public abstract FlinkFnApi.UserDefinedFunctions getUserDefinedFunctionsProto();
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java
new file mode 100644
index 00000000000000..da2122e5e75e67
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/AbstractPythonTableFunctionRunner.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.runners.python;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+/**
+ * Abstract {@link PythonFunctionRunner} used to execute a Python {@link TableFunction}.
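+ * Unlike a scalar function, a table function may emit zero, one, or multiple rows for
+ * each input row.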
+ *
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the execution results.
+ */
+public abstract class AbstractPythonTableFunctionRunner<IN, OUT> extends AbstractPythonStatelessFunctionRunner<IN, OUT> {
+
+	private static final String SCHEMA_CODER_URN = "flink:coder:schema:v1";
+	private static final String TABLE_FUNCTION_URN = "flink:transform:table_function:v1";
+
+	private final PythonFunctionInfo tableFunction;
+
+	public AbstractPythonTableFunctionRunner(
+			String taskName,
+			FnDataReceiver<OUT> resultReceiver,
+			PythonFunctionInfo tableFunction,
+			PythonEnvironmentManager environmentManager,
+			RowType inputType,
+			RowType outputType) {
+		super(taskName, resultReceiver, environmentManager, inputType, outputType, TABLE_FUNCTION_URN);
+		this.tableFunction = Preconditions.checkNotNull(tableFunction);
+	}
+
+	/**
+	 * Gets the proto representation of the Python user-defined functions to be executed.
+	 */
+	@VisibleForTesting
+	public FlinkFnApi.UserDefinedFunctions getUserDefinedFunctionsProto() {
+		FlinkFnApi.UserDefinedFunctions.Builder builder = FlinkFnApi.UserDefinedFunctions.newBuilder();
+		builder.addUdfs(getUserDefinedFunctionProto(tableFunction));
+		return builder.build();
+	}
+
+	/**
+	 * Gets the proto representation of the input coder.
+	 */
+	@Override
+	RunnerApi.Coder getInputCoderProto() {
+		return getTableCoderProto(getInputType());
+	}
+
+	/**
+	 * Gets the proto representation of the output coder.
+	 */
+	@Override
+	RunnerApi.Coder getOutputCoderProto() {
+		return getTableCoderProto(getOutputType());
+	}
+
+	private RunnerApi.Coder getTableCoderProto(RowType rowType) {
+		return RunnerApi.Coder.newBuilder()
+			.setSpec(
+				RunnerApi.FunctionSpec.newBuilder()
+					.setUrn(SCHEMA_CODER_URN)
+					.setPayload(org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString.copyFrom(
+						PythonTypeUtils.toTableProtoType(rowType).toByteArray()))
+					.build())
+			.build();
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java
new file mode 100644
index 00000000000000..f9762c68c2c819
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/PythonTableFunctionRunner.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.runners.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.table.functions.TableFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
+import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.Row;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+/**
+ * A {@link PythonFunctionRunner} used to execute a Python {@link TableFunction}.
+ * It takes {@link Row} as the input and output type.
+ */
+@Internal
+public class PythonTableFunctionRunner extends AbstractPythonTableFunctionRunner<Row, Row> {
+
+	public PythonTableFunctionRunner(
+			String taskName,
+			FnDataReceiver<Row> resultReceiver,
+			PythonFunctionInfo tableFunction,
+			PythonEnvironmentManager environmentManager,
+			RowType inputType,
+			RowType outputType) {
+		super(taskName, resultReceiver, tableFunction, environmentManager, inputType, outputType);
+	}
+
+	@Override
+	public TypeSerializer<Row> getInputTypeSerializer() {
+		return (RowTableSerializer) PythonTypeUtils.toFlinkTableTypeSerializer(getInputType());
+	}
+
+	@Override
+	public TypeSerializer<Row> getOutputTypeSerializer() {
+		return (RowTableSerializer) PythonTypeUtils.toFlinkTableTypeSerializer(getOutputType());
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java
index 03f60bb0915692..ac31ba8f4056e2 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/PythonTypeUtils.java
@@ -38,6 +38,7 @@
 import org.apache.flink.table.runtime.typeutils.serializers.python.BigDecSerializer;
 import org.apache.flink.table.runtime.typeutils.serializers.python.DateSerializer;
 import org.apache.flink.table.runtime.typeutils.serializers.python.DecimalSerializer;
+import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer;
 import org.apache.flink.table.runtime.typeutils.serializers.python.StringSerializer;
 import org.apache.flink.table.runtime.typeutils.serializers.python.TimeSerializer;
 import org.apache.flink.table.runtime.typeutils.serializers.python.TimestampSerializer;
@@ -90,6 +91,39 @@ public static FlinkFnApi.Schema.FieldType toProtoType(LogicalType logicalType) {
 		return logicalType.accept(new LogicalTypeToProtoTypeConverter());
 	}
 
+	public static TypeSerializer toFlinkTableTypeSerializer(LogicalType logicalType) {
+		RowType rowType = (RowType) logicalType;
+		LogicalTypeDefaultVisitor<TypeSerializer> converter =
+			new LogicalTypeToTypeSerializerConverter();
+		final TypeSerializer[] fieldTypeSerializers = rowType.getFields()
+			.stream()
+			.map(f -> f.getType().accept(converter))
+			.toArray(TypeSerializer[]::new);
+		return new RowTableSerializer(fieldTypeSerializers);
+	}
+
+	public static FlinkFnApi.Schema.FieldType toTableProtoType(LogicalType logicalType) {
+		RowType rowType = (RowType) logicalType;
+		FlinkFnApi.Schema.FieldType.Builder builder =
+			FlinkFnApi.Schema.FieldType.newBuilder()
+				.setTypeName(FlinkFnApi.Schema.TypeName.TABLE)
+				.setNullable(rowType.isNullable());
+
+		LogicalTypeDefaultVisitor<FlinkFnApi.Schema.FieldType> converter =
+			new LogicalTypeToProtoTypeConverter();
+		FlinkFnApi.Schema.Builder schemaBuilder = FlinkFnApi.Schema.newBuilder();
+		for (RowType.RowField field : rowType.getFields()) {
+			schemaBuilder.addFields(
+				FlinkFnApi.Schema.Field.newBuilder()
+					.setName(field.getName())
+					.setDescription(field.getDescription().orElse(EMPTY_STRING))
+					.setType(field.getType().accept(converter))
+					.build());
+		}
+		builder.setRowSchema(schemaBuilder.build());
+		return builder.build();
+	}
+
 	/**
 	 * Convert LogicalType to conversion class for flink planner.
 	 */
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java
new file mode 100644
index 00000000000000..5f5d2f48a45c9f
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/RowTableSerializer.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils.serializers.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.typeutils.runtime.RowSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+
+/**
+ * The implementation of TableSerializer for the legacy planner.
+ */
+@Internal
+public final class RowTableSerializer extends TableSerializer<Row> {
+
+	private final RowSerializer rowSerializer;
+
+	public RowTableSerializer(TypeSerializer[] fieldSerializers) {
+		super(fieldSerializers);
+		this.rowSerializer = new RowSerializer(fieldSerializers);
+	}
+
+	@Override
+	public Row createResult(int len) {
+		return new Row(len);
+	}
+
+	@Override
+	public void setField(Row result, int index, Object value) {
+		result.setField(index, value);
+	}
+
+	@Override
+	public void serialize(Row record, DataOutputView target) throws IOException {
+		rowSerializer.serialize(record, target);
+	}
+
+	public RowSerializer getRowSerializer() {
+		return rowSerializer;
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java
new file mode 100644
index 00000000000000..a53a390653814b
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/typeutils/serializers/python/TableSerializer.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils.serializers.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+import static org.apache.flink.api.java.typeutils.runtime.NullMaskUtils.readIntoNullMask;
+
+/**
+ * Base {@link TypeSerializer} for the table function execution results.
+ */
+@Internal
+abstract class TableSerializer<T> extends TypeSerializer<T> {
+
+	private final TypeSerializer[] fieldSerializers;
+
+	private transient boolean[] nullMask;
+
+	TableSerializer(TypeSerializer[] fieldSerializers) {
+		this.fieldSerializers = fieldSerializers;
+		this.nullMask = new boolean[Math.max(fieldSerializers.length - 8, 0)];
+	}
+
+	@Override
+	public boolean isImmutableType() {
+		return false;
+	}
+
+	@Override
+	public TypeSerializer<T> duplicate() {
+		throw new RuntimeException("The method duplicate() should not be called");
+	}
+
+	@Override
+	public T createInstance() {
+		return unwantedMethodCall("createInstance()");
+	}
+
+	@Override
+	public T copy(T from) {
+		return unwantedMethodCall("copy(T from)");
+	}
+
+	@Override
+	public T copy(T from, T reuse) {
+		return unwantedMethodCall("copy(T from, T reuse)");
+	}
+
+	@Override
+	public int getLength() {
+		return -1;
+	}
+
+	@Override
+	public T deserialize(T reuse, DataInputView source) throws IOException {
+		return unwantedMethodCall("deserialize(T reuse, DataInputView source)");
+	}
+
+	@Override
+	public void copy(DataInputView source, DataOutputView target) throws IOException {
+		unwantedMethodCall("copy(DataInputView source, DataOutputView target)");
+	}
+
+	private T unwantedMethodCall(String methodName) {
+		throw new RuntimeException(String.format("The method %s should not be called", methodName));
+	}
+
+	public abstract T createResult(int len);
+
+	public abstract void setField(T result, int index, Object value);
+
+	@Override
+	public T deserialize(DataInputView source) throws IOException {
+		int len = fieldSerializers.length;
+		// the first byte encodes the null mask of the first (up to 8) fields
+		int b = source.readUnsignedByte() & 0xff;
+		DataInputStream inputStream = (DataInputStream) source;
+		if (b == 0x00 && inputStream.available() == 0) {
+			// an empty result marks the end of the results of the current input element
+			return createResult(0);
+		}
+		T result = createResult(len);
+		int minLen = Math.min(8, len);
+		readIntoNullMask(len - 8, source, nullMask);
+		for (int i = 0; i < minLen; i++) {
+			if ((b & 0x80) > 0) {
+				setField(result, i, null);
+			} else {
+				setField(result, i, fieldSerializers[i].deserialize(source));
+			}
+			b = b << 1;
+		}
+		for (int i = 0, j = minLen; j < len; i++, j++) {
+			if (nullMask[i]) {
+				setField(result, j, null);
+			} else {
+				setField(result, j, fieldSerializers[j].deserialize(source));
+			}
+		}
+
+		return result;
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (obj instanceof TableSerializer) {
+			TableSerializer other = (TableSerializer) obj;
+			if (this.fieldSerializers.length == other.fieldSerializers.length) {
+				for (int i = 0; i < this.fieldSerializers.length; i++) {
+					if (!this.fieldSerializers[i].equals(other.fieldSerializers[i])) {
+						return false;
+					}
+				}
+				return true;
+			}
+		}
+
+		return false;
+	}
+
+	@Override
+	public int hashCode() {
+		return Arrays.hashCode(fieldSerializers);
+	}
+
+	@Override
+	public TypeSerializerSnapshot<T> snapshotConfiguration() {
+		throw new RuntimeException("The method snapshotConfiguration() should not be called");
+	}
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java
new file mode 100644
index 00000000000000..f0dcb5e145f885
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/AbstractPythonTableFunctionRunnerTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.python;
+
+import org.apache.flink.table.runtime.runners.python.AbstractPythonTableFunctionRunner;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.RowType;
+
+import java.util.Collections;
+
+/**
+ * Base class for PythonTableFunctionRunner and BaseRowPythonTableFunctionRunner tests.
+ *
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the output elements.
+ */
+public abstract class AbstractPythonTableFunctionRunnerTest<IN, OUT> {
+
+	AbstractPythonTableFunctionRunner<IN, OUT> createUDTFRunner() throws Exception {
+		PythonFunctionInfo pythonFunctionInfo = new PythonFunctionInfo(
+			DummyPythonFunction.INSTANCE,
+			new Integer[]{0});
+
+		RowType rowType = new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType())));
+		return createPythonTableFunctionRunner(pythonFunctionInfo, rowType, rowType);
+	}
+
+	public abstract AbstractPythonTableFunctionRunner<IN, OUT> createPythonTableFunctionRunner(
+		PythonFunctionInfo pythonFunctionInfo, RowType inputType, RowType outputType) throws Exception;
+
+	/**
+	 * Dummy PythonFunction.
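+	 * It returns an empty serialized payload and a process-based {@link PythonEnv}, which is
+	 * sufficient for tests that only inspect the generated protos and serializers.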
+	 */
+	public static class DummyPythonFunction implements PythonFunction {
+
+		private static final long serialVersionUID = 1L;
+
+		public static final PythonFunction INSTANCE = new DummyPythonFunction();
+
+		@Override
+		public byte[] getSerializedPythonFunction() {
+			return new byte[0];
+		}
+
+		@Override
+		public PythonEnv getPythonEnv() {
+			return new PythonEnv(PythonEnv.ExecType.PROCESS);
+		}
+	}
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java
new file mode 100644
index 00000000000000..9532fdc90051be
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTableFunctionRunnerTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.python;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.env.ProcessPythonEnvironmentManager;
+import org.apache.flink.python.env.PythonDependencyInfo;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.table.runtime.runners.python.AbstractPythonTableFunctionRunner;
+import org.apache.flink.table.runtime.runners.python.PythonTableFunctionRunner;
+import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.Row;
+
+import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.Struct;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link PythonTableFunctionRunner}. These test that:
+ *
+ * <ul>
+ *     <li>The input data type and output data type are properly constructed</li>
+ *     <li>The UDTF proto is properly constructed</li>
+ * </ul>
+ */
+public class PythonTableFunctionRunnerTest extends AbstractPythonTableFunctionRunnerTest<Row, Row> {
+
+	@Test
+	public void testInputOutputDataTypeConstructedProperlyForSingleUDTF() throws Exception {
+		final AbstractPythonTableFunctionRunner<Row, Row> runner = createUDTFRunner();
+
+		// check input TypeSerializer
+		TypeSerializer<Row> inputTypeSerializer = runner.getInputTypeSerializer();
+		assertTrue(inputTypeSerializer instanceof RowTableSerializer);
+
+		assertEquals(1, ((RowTableSerializer) inputTypeSerializer).getRowSerializer().getArity());
+
+		// check output TypeSerializer
+		TypeSerializer<Row> outputTypeSerializer = runner.getOutputTypeSerializer();
+		assertTrue(outputTypeSerializer instanceof RowTableSerializer);
+		assertEquals(1, ((RowTableSerializer) outputTypeSerializer).getRowSerializer().getArity());
+	}
+
+	@Test
+	public void testUDTFProtoConstructedProperlyForSingleUDTF() throws Exception {
+		final AbstractPythonTableFunctionRunner<Row, Row> runner = createUDTFRunner();
+
+		FlinkFnApi.UserDefinedFunctions udtfs = runner.getUserDefinedFunctionsProto();
+		assertEquals(1, udtfs.getUdfsCount());
+
+		FlinkFnApi.UserDefinedFunction udtf = udtfs.getUdfs(0);
+		assertEquals(1, udtf.getInputsCount());
+		assertEquals(0, udtf.getInputs(0).getInputOffset());
+	}
+
+	@Override
+	public AbstractPythonTableFunctionRunner<Row, Row> createPythonTableFunctionRunner(
+			PythonFunctionInfo pythonFunctionInfo,
+			RowType inputType,
+			RowType outputType) throws Exception {
+		final FnDataReceiver<Row> dummyReceiver = input -> {
+			// ignore the execution results
+		};
+
+		final PythonEnvironmentManager environmentManager =
+			new ProcessPythonEnvironmentManager(
+				new PythonDependencyInfo(new HashMap<>(), null, null, new HashMap<>(), null),
+				new String[]{System.getProperty("java.io.tmpdir")},
+				new HashMap<>());
+
+		return new PythonTableFunctionRunner(
+			"testPythonRunner",
+			dummyReceiver,
+			pythonFunctionInfo,
+			environmentManager,
+			inputType,
+			outputType);
+	}
+
+	private AbstractPythonTableFunctionRunner<Row, Row> createUDTFRunner(
+			JobBundleFactory jobBundleFactory, FnDataReceiver<Row> receiver) throws IOException {
+		PythonFunctionInfo pythonFunctionInfo = new PythonFunctionInfo(
+			DummyPythonFunction.INSTANCE,
+			new Integer[]{0});
+
+		RowType rowType = new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType())));
+
+		final PythonEnvironmentManager environmentManager =
+			new ProcessPythonEnvironmentManager(
+				new PythonDependencyInfo(new HashMap<>(), null, null, new HashMap<>(), null),
+				new String[]{System.getProperty("java.io.tmpdir")},
+				new HashMap<>());
+
+		return new PythonTableFunctionRunnerTestHarness(
+			"testPythonRunner",
+			receiver,
+			pythonFunctionInfo,
+			environmentManager,
+			rowType,
+			rowType,
+			jobBundleFactory);
+	}
+
+	private static class PythonTableFunctionRunnerTestHarness extends PythonTableFunctionRunner {
+
+		private final JobBundleFactory jobBundleFactory;
+
+		PythonTableFunctionRunnerTestHarness(
+				String taskName,
+				FnDataReceiver<Row> resultReceiver,
+				PythonFunctionInfo tableFunction,
+				PythonEnvironmentManager environmentManager,
+				RowType inputType,
+				RowType outputType,
+				JobBundleFactory jobBundleFactory) {
+			super(taskName, resultReceiver, tableFunction, environmentManager, inputType, outputType);
+			this.jobBundleFactory = jobBundleFactory;
+		}
+
+		@Override
+		public JobBundleFactory createJobBundleFactory(Struct pipelineOptions) throws Exception {
+			return jobBundleFactory;
+		}
+	}
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
diff --git a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
index e697ef4a0c7223..22e3837a521fd9 100644
--- a/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/functions/python/PythonTypeUtilsTest.java
@@ -24,6 +24,7 @@
 import org.apache.flink.table.catalog.UnresolvedIdentifier;
 import org.apache.flink.table.runtime.typeutils.BaseRowSerializer;
 import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
+import org.apache.flink.table.runtime.typeutils.serializers.python.RowTableSerializer;
 import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.BigIntType;
 import org.apache.flink.table.types.logical.DateType;
@@ -57,6 +58,17 @@ public void testLogicalTypeToFlinkTypeSerializer() {
         assertEquals(1, ((RowSerializer) rowSerializer).getArity());
     }
 
+    @Test
+    public void testLogicalTypeToFlinkTableTypeSerializer() {
+        List<RowType.RowField> rowFields = new ArrayList<>();
+        rowFields.add(new RowType.RowField("f1", new BigIntType()));
+        RowType rowType = new RowType(rowFields);
+        TypeSerializer rowTableSerializer = PythonTypeUtils.toFlinkTableTypeSerializer(rowType);
+        assertTrue(rowTableSerializer instanceof RowTableSerializer);
+
+        assertEquals(1, ((RowTableSerializer) rowTableSerializer).getRowSerializer().getArity());
+    }
+
     @Test
     public void testLogicalTypeToBlinkTypeSerializer() {
         List<RowType.RowField> rowFields = new ArrayList<>();
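The new test documents the contract of PythonTypeUtils.toFlinkTableTypeSerializer: for a row type it yields a RowTableSerializer wrapping a row serializer of the same arity. A minimal usage sketch, assuming only what the test itself asserts (and the test class's imports):

    // For a one-field row type, the table serializer wraps a row serializer of arity 1.
    RowType rowType = new RowType(
        Collections.singletonList(new RowType.RowField("f1", new BigIntType())));
    TypeSerializer serializer = PythonTypeUtils.toFlinkTableTypeSerializer(rowType);
    int arity = ((RowTableSerializer) serializer).getRowSerializer().getArity();  // 1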
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java
new file mode 100644
index 00000000000000..bbc25da11c420a
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/AbstractPassThroughPythonTableFunctionRunner.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A {@link PythonFunctionRunner} that just emits each input element for Python UDTF.
+ *
+ * @param <IN> Type of the input elements.
+ */
+public abstract class AbstractPassThroughPythonTableFunctionRunner<IN> implements PythonFunctionRunner<IN> {
+
+    protected boolean bundleStarted;
+    protected final List<IN> bufferedElements;
+    protected final FnDataReceiver<IN> resultReceiver;
+
+    AbstractPassThroughPythonTableFunctionRunner(FnDataReceiver<IN> resultReceiver) {
+        this.resultReceiver = Preconditions.checkNotNull(resultReceiver);
+        bundleStarted = false;
+        bufferedElements = new ArrayList<>();
+    }
+
+    @Override
+    public void open() {}
+
+    @Override
+    public void close() {}
+
+    @Override
+    public void startBundle() {
+        Preconditions.checkState(!bundleStarted);
+        bundleStarted = true;
+    }
+
+    @Override
+    public void processElement(IN element) {
+        bufferedElements.add(copy(element));
+    }
+
+    public abstract IN copy(IN element);
+}
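processElement buffers a copy rather than the element itself because the surrounding operator may reuse one mutable record instance across invocations; what "copy" means is deferred to the concrete subclass. A small illustration with Row (the values are made up):

    Row reused = Row.of("a");
    Row snapshot = Row.copy(reused);   // capture the current contents
    reused.setField(0, "b");           // later reuse of the same instance
    // snapshot still holds "a"; buffering `reused` directly would now observe "b"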
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java
new file mode 100644
index 00000000000000..2ecccf9fc129c3
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PassThroughPythonTableFunctionRunner.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.table.runtime.runners.python.PythonTableFunctionRunner;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+/**
+ * A {@link PythonTableFunctionRunner} that just emits each input element.
+ */
+public class PassThroughPythonTableFunctionRunner extends AbstractPassThroughPythonTableFunctionRunner<Row> {
+
+    PassThroughPythonTableFunctionRunner(FnDataReceiver<Row> resultReceiver) {
+        super(resultReceiver);
+    }
+
+    @Override
+    public Row copy(Row element) {
+        return Row.copy(element);
+    }
+
+    @Override
+    public void finishBundle() throws Exception {
+        Preconditions.checkState(bundleStarted);
+        bundleStarted = false;
+
+        for (Row element : bufferedElements) {
+            resultReceiver.accept(element);
+            resultReceiver.accept(new Row(0));
+        }
+        bufferedElements.clear();
+    }
+}
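Together, the two classes give a deterministic stand-in for a real Python runner: rows are held back during a bundle and replayed on finishBundle(), each followed by Row(0), which presumably mirrors the end-of-results marker a real UDTF run emits per input row. Driving it by hand might look like the following sketch (assuming same-package access for the package-private constructor, inside a method that declares throws Exception):

    List<Row> received = new ArrayList<>();
    PassThroughPythonTableFunctionRunner runner =
        new PassThroughPythonTableFunctionRunner(received::add);
    runner.open();
    runner.startBundle();
    runner.processElement(Row.of("c1", "c2", 0L));
    // nothing reaches the receiver until the bundle is finished
    runner.finishBundle();   // received now holds the row plus the Row(0) marker
    runner.close();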
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java
new file mode 100644
index 00000000000000..ffaaeec293dfd5
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.types.CRow;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.Row;
+
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+
+import java.util.Collection;
+import java.util.Queue;
+
+/**
+ * Tests for {@link PythonTableFunctionOperator}.
+ */
+public class PythonTableFunctionOperatorTest extends PythonTableFunctionOperatorTestBase<CRow, CRow, Row, Row> {
+
+    @Override
+    public AbstractPythonTableFunctionOperator<CRow, CRow, Row, Row> getTestOperator(
+            Configuration config,
+            PythonFunctionInfo tableFunction,
+            RowType inputType,
+            RowType outputType,
+            int[] udfInputOffsets) {
+        return new PassThroughPythonTableFunctionOperator(
+            config, tableFunction, inputType, outputType, udfInputOffsets);
+    }
+
+    @Override
+    public CRow newRow(boolean accumulateMsg, Object... fields) {
+        return new CRow(Row.of(fields), accumulateMsg);
+    }
+
+    @Override
+    public void assertOutputEquals(String message, Collection expected, Collection actual) {
+        TestHarnessUtil.assertOutputEquals(message, (Queue) expected, (Queue) actual);
+    }
+
+    private static class PassThroughPythonTableFunctionOperator extends PythonTableFunctionOperator {
+
+        PassThroughPythonTableFunctionOperator(
+                Configuration config,
+                PythonFunctionInfo tableFunction,
+                RowType inputType,
+                RowType outputType,
+                int[] udfInputOffsets) {
+            super(config, tableFunction, inputType, outputType, udfInputOffsets);
+        }
+
+        @Override
+        public PythonFunctionRunner<Row> createPythonFunctionRunner(
+                FnDataReceiver<Row> resultReceiver,
+                PythonEnvironmentManager pythonEnvironmentManager) {
+            return new PassThroughPythonTableFunctionRunner(resultReceiver);
+        }
+    }
+}
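A note on the CRow wrapper used by this concrete test: the boolean passed to newRow becomes the CRow change flag, distinguishing accumulate from retract messages, and testRetractionFieldKept in the base class below checks that the flag passes through the operator untouched. Illustrative values only:

    CRow accumulate = new CRow(Row.of("c1", "c2", 0L), true);   // accumulate message
    CRow retract    = new CRow(Row.of("c3", "c4", 1L), false);  // retract message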
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java
new file mode 100644
index 00000000000000..93ed9da956f4d3
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/PythonTableFunctionOperatorTestBase.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonOptions;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.table.functions.python.AbstractPythonTableFunctionRunnerTest;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/**
+ * Base class for Python table function operator tests. These verify that:
+ *
+ * <ul>
+ *     <li>Retraction flag is correctly forwarded to the downstream</li>
+ *     <li>FinishBundle is called when a checkpoint is encountered</li>
+ *     <li>Watermarks are buffered and only sent downstream when finishBundle is triggered</li>
+ * </ul>
+ *
+ * @param <IN> Type of the input elements.
+ * @param <OUT> Type of the output elements.
+ * @param <UDTFIN> Type of the UDTF input type.
+ * @param <UDTFOUT> Type of the UDTF output type.
+ */
+public abstract class PythonTableFunctionOperatorTestBase<IN, OUT, UDTFIN, UDTFOUT> {
+
+    @Test
+    public void testRetractionFieldKept() throws Exception {
+        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(new Configuration());
+        long initialTime = 0L;
+        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+        testHarness.open();
+
+        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+        testHarness.processElement(new StreamRecord<>(newRow(false, "c3", "c4", 1L), initialTime + 2));
+        testHarness.processElement(new StreamRecord<>(newRow(false, "c5", "c6", 2L), initialTime + 3));
+        testHarness.close();
+
+        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L)));
+        expectedOutput.add(new StreamRecord<>(newRow(false, "c3", "c4", 1L, 1L)));
+        expectedOutput.add(new StreamRecord<>(newRow(false, "c5", "c6", 2L, 2L)));
+
+        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+    }
+
+    @Test
+    public void testFinishBundleTriggeredOnCheckpoint() throws Exception {
+        Configuration conf = new Configuration();
+        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
+
+        long initialTime = 0L;
+        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+        testHarness.open();
+
+        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+
+        // checkpoint trigger finishBundle
+        testHarness.prepareSnapshotPreBarrier(0L);
+
+        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L)));
+
+        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+        testHarness.close();
+    }
+
+    @Test
+    public void testFinishBundleTriggeredByCount() throws Exception {
+        Configuration conf = new Configuration();
+        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 2);
+        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
+
+        long initialTime = 0L;
+        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+        testHarness.open();
+
+        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+        assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());
+
+        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), initialTime + 2));
+        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L)));
+        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 1L, 1L)));
+
+        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+        testHarness.close();
+    }
+
+    @Test
+    public void testFinishBundleTriggeredByTime() throws Exception {
+        Configuration conf = new Configuration();
+        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+        conf.setLong(PythonOptions.MAX_BUNDLE_TIME_MILLS, 1000L);
+        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
+
+        long initialTime = 0L;
+        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+        testHarness.open();
+
+        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+        assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());
+
+        testHarness.setProcessingTime(1000L);
+        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L)));
+        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+        testHarness.close();
+    }
+
+    private OneInputStreamOperatorTestHarness<IN, OUT> getTestHarness(Configuration config) throws Exception {
+        RowType dataType = new RowType(Arrays.asList(
+            new RowType.RowField("f1", new VarCharType()),
+            new RowType.RowField("f2", new VarCharType()),
+            new RowType.RowField("f3", new BigIntType())));
+        AbstractPythonTableFunctionOperator<IN, OUT, UDTFIN, UDTFOUT> operator = getTestOperator(
+            config,
+            new PythonFunctionInfo(
+                AbstractPythonTableFunctionRunnerTest.DummyPythonFunction.INSTANCE,
+                new Integer[]{0}),
+            dataType,
+            dataType,
+            new int[]{2}
+        );
+
+        OneInputStreamOperatorTestHarness<IN, OUT> testHarness =
+            new OneInputStreamOperatorTestHarness<>(operator);
+        testHarness.getStreamConfig().setManagedMemoryFraction(0.5);
+        return testHarness;
+    }
+
+    public abstract IN newRow(boolean accumulateMsg, Object... fields);
+
+    public abstract void assertOutputEquals(String message, Collection expected, Collection actual);
+
+    public abstract AbstractPythonTableFunctionOperator<IN, OUT, UDTFIN, UDTFOUT> getTestOperator(
+        Configuration config,
+        PythonFunctionInfo tableFunction,
+        RowType inputType,
+        RowType outputType,
+        int[] udfInputOffsets);
+}
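One detail worth spelling out: the input rows carry three fields while each expected row carries four. Reading the tests together with the pass-through runner, the operator is configured with udfInputOffsets {2}, so the stand-in "UDTF" receives column f3 and echoes it back, and the operator joins that echoed value onto the original row; this is an inference from the fixtures above, not a statement of the production protocol. A sketch of the shape, written as it would appear inside the base class:

    // With udfInputOffsets = {2}, column f3 is handed to the pass-through runner,
    // echoed back, and joined onto the input row, so each expected record repeats
    // its f3 value as the appended UDTF column.
    IN input = newRow(true, "c1", "c2", 0L);         // three input fields
    IN expected = newRow(true, "c1", "c2", 0L, 0L);  // input fields + echoed UDTF column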