diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionDefinition.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionDefinition.java index 16379285f2e287..e43e84121c3254 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionDefinition.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionDefinition.java @@ -40,6 +40,13 @@ public interface FunctionDefinition { */ FunctionKind getKind(); + /** + * Returns the language of the function this definition describes ({@link FunctionLanguage#JVM} by default). + */ + default FunctionLanguage getLanguage() { + return FunctionLanguage.JVM; + } + /** * Returns the set of requirements this definition demands. */ diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionLanguage.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionLanguage.java new file mode 100644 index 00000000000000..7b36b132a4719f --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/FunctionLanguage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.PublicEvolving; + +/** + * Categorizes the language of a {@link FunctionDefinition}. + */ +@PublicEvolving +public enum FunctionLanguage { + + JVM, + + PYTHON +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/SimplePythonFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/SimplePythonFunction.java new file mode 100644 index 00000000000000..116151ad1d5fce --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/SimplePythonFunction.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.python; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.util.Preconditions; + +/** + * A simple implementation of {@link PythonFunction}. 
+ */ +@Internal +public final class SimplePythonFunction implements PythonFunction { + + private static final long serialVersionUID = 1L; + + /** + * Serialized representation of the user-defined python function. + */ + private final byte[] serializedPythonFunction; + + /** + * Python execution environment. + */ + private final PythonEnv pythonEnv; + + public SimplePythonFunction(byte[] serializedPythonFunction, PythonEnv pythonEnv) { + this.serializedPythonFunction = Preconditions.checkNotNull(serializedPythonFunction); + this.pythonEnv = Preconditions.checkNotNull(pythonEnv); + } + + @Override + public byte[] getSerializedPythonFunction() { + return serializedPythonFunction; + } + + @Override + public PythonEnv getPythonEnv() { + return pythonEnv; + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala index 8064fa9b9aa43f..21cfff4d485ddd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala @@ -52,6 +52,12 @@ class CalciteConfigBuilder { private var replaceLogicalOptRules: Boolean = false private var logicalOptRuleSets: List[RuleSet] = Nil + /** + * Defines the Python logical optimization rule set. + */ + private var replacePythonLogicalOptRules: Boolean = false + private var pythonLogicalOptRuleSets: List[RuleSet] = Nil + /** * Defines the physical optimization rule set. 
*/ @@ -225,6 +231,8 @@ class CalciteConfigBuilder { replaceNormRules, getRuleSet(logicalOptRuleSets), replaceLogicalOptRules, + getRuleSet(pythonLogicalOptRuleSets), + replacePythonLogicalOptRules, getRuleSet(physicalOptRuleSets), replacePhysicalOptRules, getRuleSet(decoRuleSets), @@ -254,6 +262,10 @@ class CalciteConfig( val logicalOptRuleSet: Option[RuleSet], /** Whether this configuration replaces the built-in logical optimization rule set. */ val replacesLogicalOptRuleSet: Boolean, + /** A custom Python logical optimization rule set. */ + val pythonLogicalOptRuleSet: Option[RuleSet], + /** Whether this configuration replaces the built-in Python logical optimization rule set. */ + val replacesPythonLogicalOptRuleSet: Boolean, /** A custom physical optimization rule set. */ val physicalOptRuleSet: Option[RuleSet], /** Whether this configuration replaces the built-in physical optimization rule set. */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/BatchOptimizer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/BatchOptimizer.scala index 1bb8c8256c7ae7..4a713e8d9f3959 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/BatchOptimizer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/BatchOptimizer.scala @@ -54,7 +54,8 @@ class BatchOptimizer( val decorPlan = RelDecorrelator.decorrelateQuery(expandedPlan) val normalizedPlan = optimizeNormalizeLogicalPlan(decorPlan) val logicalPlan = optimizeLogicalPlan(normalizedPlan) - optimizePhysicalPlan(logicalPlan, FlinkConventions.DATASET) + val pythonizedLogicalPlan = optimizePythonLogicalPlan(logicalPlan) + optimizePhysicalPlan(pythonizedLogicalPlan, FlinkConventions.DATASET) } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/Optimizer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/Optimizer.scala 
index 8506749dbe751d..3f4245fe013088 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/Optimizer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/Optimizer.scala @@ -86,6 +86,25 @@ abstract class Optimizer( } } + /** + * Returns the Python logical optimization rule set for this optimizer: the built-in rules, + * extended or (if the Python replace flag is set) replaced by the custom Python RuleSet. + */ + protected def getPythonLogicalOptRuleSet: RuleSet = { + materializedConfig.pythonLogicalOptRuleSet match { + + case None => + getBuiltInPythonLogicalOptRuleSet + + case Some(ruleSet) => + if (materializedConfig.replacesPythonLogicalOptRuleSet) { + ruleSet + } else { + RuleSets.ofList((getBuiltInPythonLogicalOptRuleSet.asScala ++ ruleSet.asScala).asJava) + } + } + } + /** * Returns the physical optimization rule set for this optimizer * including a custom RuleSet configuration. @@ -117,6 +136,13 @@ abstract class Optimizer( FlinkRuleSets.LOGICAL_OPT_RULES } + /** + * Returns the built-in Python logical optimization rules that are defined by the optimizer. + */ + protected def getBuiltInPythonLogicalOptRuleSet: RuleSet = { + FlinkRuleSets.LOGICAL_PYTHON_OPT_RULES + } + /** * Returns the built-in physical optimization rules that are defined by the optimizer. 
*/ @@ -153,6 +179,19 @@ abstract class Optimizer( } } + protected def optimizePythonLogicalPlan(relNode: RelNode): RelNode = { + val logicalOptRuleSet = getPythonLogicalOptRuleSet + if (logicalOptRuleSet.iterator().hasNext) { + runHepPlannerSimultaneously( + HepMatchOrder.TOP_DOWN, + logicalOptRuleSet, + relNode, + relNode.getTraitSet) + } else { + relNode + } + } + protected def optimizeLogicalPlan(relNode: RelNode): RelNode = { val logicalOptRuleSet = getLogicalOptRuleSet val logicalOutputProps = relNode.getTraitSet.replace(FlinkConventions.LOGICAL).simplify() diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/StreamOptimizer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/StreamOptimizer.scala index 30ca4861ca6955..20da9fe1b72961 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/StreamOptimizer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/StreamOptimizer.scala @@ -64,8 +64,8 @@ class StreamOptimizer( RelTimeIndicatorConverter.convert(decorPlan, relBuilder.getRexBuilder) val normalizedPlan = optimizeNormalizeLogicalPlan(planWithMaterializedTimeAttributes) val logicalPlan = optimizeLogicalPlan(normalizedPlan) - - val physicalPlan = optimizePhysicalPlan(logicalPlan, FlinkConventions.DATASTREAM) + val pythonizedLogicalPlan = optimizePythonLogicalPlan(logicalPlan) + val physicalPlan = optimizePhysicalPlan(pythonizedLogicalPlan, FlinkConventions.DATASTREAM) optimizeDecoratePlan(physicalPlan, updatesAsRetraction) } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala index 36df67a0721a8c..34f4ba84bda289 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala +++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala @@ -34,7 +34,6 @@ trait CommonCalc { private[flink] def generateFunction[T <: Function]( generator: FunctionCodeGenerator, ruleDescription: String, - inputSchema: RowSchema, returnSchema: RowSchema, calcProjection: Seq[RexNode], calcCondition: Option[RexNode], diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonCalc.scala new file mode 100644 index 00000000000000..a495de7b668747 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonCalc.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.plan.nodes + +import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode} +import org.apache.flink.table.functions.FunctionLanguage +import org.apache.flink.table.functions.python.{PythonFunction, PythonFunctionInfo, SimplePythonFunction} +import org.apache.flink.table.functions.utils.ScalarSqlFunction + +import scala.collection.JavaConversions._ +import scala.collection.mutable + +trait CommonPythonCalc { + + private[flink] def extractPythonScalarFunctionInfos( + rexCalls: Array[RexCall]): (Array[Int], Array[PythonFunctionInfo]) = { + // using LinkedHashMap to keep the insert order + val inputNodes = new mutable.LinkedHashMap[RexNode, Integer]() + val pythonFunctionInfos = rexCalls.map(createPythonScalarFunctionInfo(_, inputNodes)) + + val udfInputOffsets = inputNodes.toArray.sortBy(_._2).map(_._1).map { + case inputRef: RexInputRef => inputRef.getIndex + } + (udfInputOffsets, pythonFunctionInfos) + } + + private[flink] def createPythonScalarFunctionInfo( + rexCall: RexCall, + inputNodes: mutable.Map[RexNode, Integer]): PythonFunctionInfo = rexCall.getOperator match { + case sfc: ScalarSqlFunction if sfc.getScalarFunction.getLanguage == FunctionLanguage.PYTHON => + val inputs = new mutable.ArrayBuffer[AnyRef]() + rexCall.getOperands.foreach { + case pythonRexCall: RexCall if pythonRexCall.getOperator.asInstanceOf[ScalarSqlFunction] + .getScalarFunction.getLanguage == FunctionLanguage.PYTHON => + // Continuous Python UDFs can be chained together + val argPythonInfo = createPythonScalarFunctionInfo(pythonRexCall, inputNodes) + inputs.append(argPythonInfo) + + case argNode: RexNode => + // For input arguments of RexInputRef, it's replaced with an offset into the input row; + inputNodes.get(argNode) match { + case Some(existing) => inputs.append(existing) + case None => + val inputOffset = Integer.valueOf(inputNodes.size) + inputs.append(inputOffset) + inputNodes.put(argNode, inputOffset) + } + } + + // Extracts the necessary 
information for Python function execution, such as + // the serialized Python function, the Python env, etc + val pythonFunction = new SimplePythonFunction( + sfc.getScalarFunction.asInstanceOf[PythonFunction].getSerializedPythonFunction, + sfc.getScalarFunction.asInstanceOf[PythonFunction].getPythonEnv) + new PythonFunctionInfo(pythonFunction, inputs.toArray) + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala index fd60bfe99cfc93..7d2aa58a6353ee 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala @@ -105,7 +105,6 @@ class DataSetCalc( val genFunction = generateFunction( generator, ruleDescription, - new RowSchema(getInput.getRowType), new RowSchema(getRowType), projection, condition, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala index 07b53eb9cb49d1..f3e6afa7e47913 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala @@ -18,18 +18,15 @@ package org.apache.flink.table.plan.nodes.datastream -import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.core.Calc -import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.rel.{RelNode, 
RelWriter} +import org.apache.calcite.rel.RelNode import org.apache.calcite.rex.RexProgram import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.calcite.RelTimeIndicatorConverter import org.apache.flink.table.codegen.FunctionCodeGenerator -import org.apache.flink.table.plan.nodes.CommonCalc import org.apache.flink.table.plan.schema.RowSchema import org.apache.flink.table.planner.StreamPlanner import org.apache.flink.table.runtime.CRowProcessRunner @@ -49,11 +46,14 @@ class DataStreamCalc( schema: RowSchema, calcProgram: RexProgram, ruleDescription: String) - extends Calc(cluster, traitSet, input, calcProgram) - with CommonCalc - with DataStreamRel { - - override def deriveRowType(): RelDataType = schema.relDataType + extends DataStreamCalcBase( + cluster, + traitSet, + input, + inputSchema, + schema, + calcProgram, + ruleDescription) { override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = { new DataStreamCalc( @@ -66,28 +66,6 @@ class DataStreamCalc( ruleDescription) } - override def toString: String = calcToString(calcProgram, getExpressionString) - - override def explainTerms(pw: RelWriter): RelWriter = { - pw.input("input", getInput) - .item("select", selectionToString(calcProgram, getExpressionString)) - .itemIf("where", - conditionToString(calcProgram, getExpressionString), - calcProgram.getCondition != null) - } - - override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { - val child = this.getInput - val rowCnt = metadata.getRowCount(child) - computeSelfCost(calcProgram, planner, rowCnt) - } - - override def estimateRowCount(metadata: RelMetadataQuery): Double = { - val child = this.getInput - val rowCnt = metadata.getRowCount(child) - estimateRowCount(calcProgram, rowCnt) - } - override def translateToPlan( planner: StreamPlanner, 
queryConfig: StreamQueryConfig): DataStream[CRow] = { @@ -117,7 +95,6 @@ class DataStreamCalc( val genFunction = generateFunction( generator, ruleDescription, - inputSchema, schema, projection, condition, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalcBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalcBase.scala new file mode 100644 index 00000000000000..5e62a9f2ddfceb --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalcBase.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.Calc +import org.apache.calcite.rel.metadata.RelMetadataQuery +import org.apache.calcite.rel.{RelNode, RelWriter} +import org.apache.calcite.rex.RexProgram +import org.apache.flink.table.plan.nodes.CommonCalc +import org.apache.flink.table.plan.schema.RowSchema + +/** + * Base RelNode for data stream calc. + */ +abstract class DataStreamCalcBase( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + inputSchema: RowSchema, + schema: RowSchema, + calcProgram: RexProgram, + ruleDescription: String) + extends Calc(cluster, traitSet, input, calcProgram) + with CommonCalc + with DataStreamRel { + + override def deriveRowType(): RelDataType = schema.relDataType + + override def toString: String = calcToString(calcProgram, getExpressionString) + + override def explainTerms(pw: RelWriter): RelWriter = { + pw.input("input", getInput) + .item("select", selectionToString(calcProgram, getExpressionString)) + .itemIf("where", + conditionToString(calcProgram, getExpressionString), + calcProgram.getCondition != null) + } + + override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { + val child = this.getInput + val rowCnt = metadata.getRowCount(child) + computeSelfCost(calcProgram, planner, rowCnt) + } + + override def estimateRowCount(metadata: RelMetadataQuery): Double = { + val child = this.getInput + val rowCnt = metadata.getRowCount(child) + estimateRowCount(calcProgram, rowCnt) + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCalc.scala new file mode 100644 index 
00000000000000..50cec46d46aa12 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCalc.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.core.Calc +import org.apache.calcite.rex.{RexCall, RexInputRef, RexProgram} +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.streaming.api.operators.OneInputStreamOperator +import org.apache.flink.table.api.StreamQueryConfig +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.codegen.FunctionCodeGenerator +import org.apache.flink.table.functions.python.PythonFunctionInfo +import org.apache.flink.table.plan.nodes.CommonPythonCalc +import org.apache.flink.table.plan.nodes.datastream.DataStreamPythonCalc._ +import org.apache.flink.table.plan.schema.RowSchema +import 
org.apache.flink.table.planner.StreamPlanner +import org.apache.flink.table.runtime.CRowProcessRunner +import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} +import org.apache.flink.table.types.logical.RowType +import org.apache.flink.table.types.utils.TypeConversions + +import scala.collection.JavaConversions._ + +class DataStreamPythonCalc( + cluster: RelOptCluster, + traitSet: RelTraitSet, + input: RelNode, + inputSchema: RowSchema, + schema: RowSchema, + calcProgram: RexProgram, + ruleDescription: String) + extends DataStreamCalcBase( + cluster, + traitSet, + input, + inputSchema, + schema, + calcProgram, + ruleDescription) + with CommonPythonCalc { + + override def copy(traitSet: RelTraitSet, child: RelNode, program: RexProgram): Calc = { + new DataStreamPythonCalc( + cluster, + traitSet, + child, + inputSchema, + schema, + program, + ruleDescription) + } + + private lazy val pythonRexCalls = calcProgram.getProjectList + .map(calcProgram.expandLocalRef) + .filter(_.isInstanceOf[RexCall]) + .map(_.asInstanceOf[RexCall]) + .toArray + + private lazy val (pythonUDFInputOffsets, pythonFunctionInfos) = + extractPythonScalarFunctionInfos(pythonRexCalls) + + private lazy val resultProjectList = { + var idx = 0 + calcProgram.getProjectList + .map(calcProgram.expandLocalRef) + .map { + case pythonCall: RexCall => + val inputRef = new RexInputRef(input.getRowType.getFieldCount + idx, pythonCall.getType) + idx += 1 + inputRef + case node => node + } + } + + override def translateToPlan( + planner: StreamPlanner, + queryConfig: StreamQueryConfig): DataStream[CRow] = { + val config = planner.getConfig + + val inputDataStream = + getInput.asInstanceOf[DataStreamRel].translateToPlan(planner, queryConfig) + + val inputParallelism = inputDataStream.getParallelism + + val pythonOperatorResultTypeInfo = new RowTypeInfo( + inputSchema.fieldTypeInfos ++ + pythonRexCalls.map(node => FlinkTypeFactory.toTypeInfo(node.getType)): _*) + + // Constructs the Python operator; its output row type is the input fields followed by the Python call results + 
val pythonOperatorInputRowType = TypeConversions.fromLegacyInfoToDataType( + inputSchema.typeInfo).getLogicalType.asInstanceOf[RowType] + val pythonOperatorOutputRowType = TypeConversions.fromLegacyInfoToDataType( + pythonOperatorResultTypeInfo).getLogicalType.asInstanceOf[RowType] + val pythonOperator = getPythonScalarFunctionOperator( + pythonOperatorInputRowType, pythonOperatorOutputRowType, pythonUDFInputOffsets) + + val pythonDataStream = inputDataStream + .transform( + calcOpName(calcProgram, getExpressionString), + CRowTypeInfo(pythonOperatorResultTypeInfo), + pythonOperator) + // keep parallelism to ensure order of accumulate and retract messages + .setParallelism(inputParallelism) + + val generator = new FunctionCodeGenerator( + config, false, pythonOperatorResultTypeInfo) + + val genFunction = generateFunction( + generator, + ruleDescription, + schema, + resultProjectList, + None, + config, + classOf[ProcessFunction[CRow, CRow]]) + + val processFunc = new CRowProcessRunner( + genFunction.name, + genFunction.code, + CRowTypeInfo(schema.typeInfo)) + + pythonDataStream + .process(processFunc) + .name(calcOpName(calcProgram, getExpressionString)) + // keep parallelism to ensure order of accumulate and retract messages + .setParallelism(inputParallelism) + } + + private[flink] def getPythonScalarFunctionOperator( + inputRowType: RowType, + outputRowType: RowType, + udfInputOffsets: Array[Int]) = { + val clazz = Class.forName(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME) + val ctor = clazz.getConstructor( + classOf[Array[PythonFunctionInfo]], + classOf[RowType], + classOf[RowType], + classOf[Array[Int]], + classOf[Int]) + ctor.newInstance( + pythonFunctionInfos, + inputRowType, + outputRowType, + udfInputOffsets, + Integer.valueOf(inputSchema.arity)) + .asInstanceOf[OneInputStreamOperator[CRow, CRow]] + } +} + +object DataStreamPythonCalc { + val PYTHON_SCALAR_FUNCTION_OPERATOR_NAME = + "org.apache.flink.table.runtime.operators.python.PythonScalarFunctionOperator" +} diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index b7701cdde07517..96d5015694865a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.plan.rules import org.apache.calcite.rel.core.RelFactories +import org.apache.calcite.rel.rules import org.apache.calcite.rel.rules._ import org.apache.calcite.tools.{RuleSet, RuleSets} import org.apache.flink.table.plan.nodes.logical._ @@ -143,6 +144,13 @@ object FlinkRuleSets { FlinkLogicalWindowTableAggregate.CONVERTER ) + /** + * RuleSet to optimize plans for Python UDF execution + */ + val LOGICAL_PYTHON_OPT_RULES: RuleSet = RuleSets.ofList( + PythonScalarFunctionSplitRule.INSTANCE + ) + /** * RuleSet to normalize plans for batch / DataSet execution */ @@ -233,8 +241,9 @@ object FlinkRuleSets { StreamTableSourceScanRule.INSTANCE, DataStreamMatchRule.INSTANCE, DataStreamTableAggregateRule.INSTANCE, - DataStreamGroupWindowTableAggregateRule.INSTANCE - ) + DataStreamGroupWindowTableAggregateRule.INSTANCE, + DataStreamPythonCalcRule.INSTANCE + ) /** * RuleSet to decorate plans for stream / DataStream execution diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCalcRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCalcRule.scala index 0a1a31a7a5f3e8..a7c16d9b840193 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCalcRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCalcRule.scala @@ -18,21 +18,30 @@ 
package org.apache.flink.table.plan.rules.datastream -import org.apache.calcite.plan.{RelOptRule, RelTraitSet} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule +import org.apache.flink.table.functions.FunctionLanguage import org.apache.flink.table.plan.nodes.FlinkConventions import org.apache.flink.table.plan.nodes.datastream.DataStreamCalc import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.table.plan.util.PythonUtil.containsFunctionOf + +import scala.collection.JavaConverters._ class DataStreamCalcRule extends ConverterRule( classOf[FlinkLogicalCalc], FlinkConventions.LOGICAL, FlinkConventions.DATASTREAM, - "DataStreamCalcRule") -{ + "DataStreamCalcRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc] + val program = calc.getProgram + !program.getExprList.asScala.exists(containsFunctionOf(_, FunctionLanguage.PYTHON)) + } def convert(rel: RelNode): RelNode = { val calc: FlinkLogicalCalc = rel.asInstanceOf[FlinkLogicalCalc] diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamPythonCalcRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamPythonCalcRule.scala new file mode 100644 index 00000000000000..bf28383a86b2f6 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamPythonCalcRule.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.flink.table.functions.FunctionLanguage +import org.apache.flink.table.plan.nodes.FlinkConventions +import org.apache.flink.table.plan.nodes.datastream.DataStreamPythonCalc +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc +import org.apache.flink.table.plan.schema.RowSchema +import org.apache.flink.table.plan.util.PythonUtil.containsFunctionOf + +import scala.collection.JavaConverters._ + +class DataStreamPythonCalcRule + extends ConverterRule( + classOf[FlinkLogicalCalc], + FlinkConventions.LOGICAL, + FlinkConventions.DATASTREAM, + "DataStreamPythonCalcRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc] + val program = calc.getProgram + program.getExprList.asScala.exists(containsFunctionOf(_, FunctionLanguage.PYTHON)) + } + + def convert(rel: RelNode): RelNode = { + val calc: FlinkLogicalCalc = + rel.asInstanceOf[FlinkLogicalCalc] + val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM) + val convInput: RelNode = RelOptRule.convert(calc.getInput, FlinkConventions.DATASTREAM) + + new DataStreamPythonCalc( + 
rel.getCluster, + traitSet, + convInput, + new RowSchema(convInput.getRowType), + new RowSchema(rel.getRowType), + calc.getProgram, + description) + } +} + +object DataStreamPythonCalcRule { + val INSTANCE: RelOptRule = new DataStreamPythonCalcRule +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonScalarFunctionSplitRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonScalarFunctionSplitRule.scala new file mode 100644 index 00000000000000..facc01f88f5079 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonScalarFunctionSplitRule.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand} +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexProgram} +import org.apache.calcite.sql.validate.SqlValidatorUtil +import org.apache.flink.table.functions.FunctionLanguage +import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.functions.utils.ScalarSqlFunction +import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc +import org.apache.flink.table.plan.util.PythonUtil.containsFunctionOf +import org.apache.flink.table.plan.util.RexDefaultVisitor + +import scala.collection.JavaConverters._ +import scala.collection.JavaConversions._ +import scala.collection.mutable + +/** + * Rule that splits a [[FlinkLogicalCalc]] into multiple [[FlinkLogicalCalc]]s. This is to ensure + * that the Python [[ScalarFunction]]s which could be executed in a batch are grouped into + * the same [[FlinkLogicalCalc]] node.
+ */ +class PythonScalarFunctionSplitRule extends RelOptRule( + operand(classOf[FlinkLogicalCalc], any), + "PythonScalarFunctionSplitRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc] + val program = calc.getProgram + program.getExprList.asScala.exists(containsFunctionOf(_, FunctionLanguage.PYTHON)) && + program.getExprList.asScala.exists(containsFunctionOf(_, FunctionLanguage.JVM)) + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val calc: FlinkLogicalCalc = call.rel(0).asInstanceOf[FlinkLogicalCalc] + val input = calc.getInput + val rexBuilder = call.builder().getRexBuilder + val program = calc.getProgram + val extractedRexCalls = new mutable.ArrayBuffer[RexCall]() + + val outerCallContainsJavaFuntion = + program.getProjectList + .map(program.expandLocalRef) + .exists(containsFunctionOf(_, FunctionLanguage.JVM, recursive = false)) || + Option(program.getCondition) + .map(program.expandLocalRef) + .exists(containsFunctionOf(_, FunctionLanguage.JVM, recursive = false)) + + val splitter = new ScalarFunctionSplitter( + input.getRowType.getFieldCount, + extractedRexCalls, + outerCallContainsJavaFuntion) + + val newProjects = program.getProjectList + .map(program.expandLocalRef) + .map(_.accept(splitter)) + + val newCondition = Option(program.getCondition) + .map(program.expandLocalRef) + .map(_.accept(splitter)) + + val bottomCalcProjects = + input.getRowType.getFieldList.indices.map(RexInputRef.of(_, input.getRowType)) ++ + extractedRexCalls + val bottomCalcFieldNames = SqlValidatorUtil.uniquify( + input.getRowType.getFieldNames ++ + extractedRexCalls.indices.map("f" + _), + rexBuilder.getTypeFactory.getTypeSystem.isSchemaCaseSensitive) + + val bottomCalc = new FlinkLogicalCalc( + calc.getCluster, + calc.getTraitSet, + input, + RexProgram.create( + input.getRowType, + bottomCalcProjects, + null, + bottomCalcFieldNames, + rexBuilder)) + + val topCalc = new 
FlinkLogicalCalc( + calc.getCluster, + calc.getTraitSet, + bottomCalc, + RexProgram.create( + bottomCalc.getRowType, + newProjects, + newCondition.orNull, + calc.getRowType, + rexBuilder)) + + call.transformTo(topCalc) + } +} + +private class ScalarFunctionSplitter( + pythonFunctionOffset: Int, + extractedRexCalls: mutable.ArrayBuffer[RexCall], + convertPythonFunction: Boolean) + extends RexDefaultVisitor[RexNode] { + + override def visitCall(call: RexCall): RexNode = { + call.getOperator match { + case sfc: ScalarSqlFunction if sfc.getScalarFunction.getLanguage == + FunctionLanguage.PYTHON => + visit(convertPythonFunction, call) + + case _ => + visit(!convertPythonFunction, call) + } + } + + override def visitNode(rexNode: RexNode): RexNode = rexNode + + private def visit(needConvert: Boolean, call: RexCall): RexNode = { + if (needConvert) { + val newNode = new RexInputRef( + pythonFunctionOffset + extractedRexCalls.length, call.getType) + extractedRexCalls.append(call) + newNode + } else { + call.clone( + call.getType, + call.getOperands.asScala.map(_.accept( + new ScalarFunctionSplitter( + pythonFunctionOffset, + extractedRexCalls, + convertPythonFunction)))) + } + } +} + +object PythonScalarFunctionSplitRule { + val INSTANCE: RelOptRule = new PythonScalarFunctionSplitRule +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/util/PythonUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/util/PythonUtil.scala new file mode 100644 index 00000000000000..af68dd84638109 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/util/PythonUtil.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.plan.util + +import org.apache.calcite.rex.{RexCall, RexNode} +import org.apache.flink.table.functions.FunctionLanguage +import org.apache.flink.table.functions.utils.ScalarSqlFunction + +import scala.collection.JavaConversions._ + +object PythonUtil { + + /** + * Checks whether the specified node contains a function of the specified language. + * + * @param node the RexNode to check + * @param language the language of the function to find + * @param recursive whether to also check the inputs of the specified node + * @return true if the specified node contains a function of the specified language. + */ + def containsFunctionOf( + node: RexNode, + language: FunctionLanguage, + recursive: Boolean = true): Boolean = { + node.accept(new FunctionFinder(language, recursive)) + } + + /** + * Checks whether there is a function of the expected language in a RexNode.
+ * + * @param expectedLanguage the expected kind of function to find + * @param recursive whether to check the inputs + */ + class FunctionFinder(expectedLanguage: FunctionLanguage, recursive: Boolean) + extends RexDefaultVisitor[Boolean] { + + override def visitCall(call: RexCall): Boolean = { + call.getOperator match { + case sfc: ScalarSqlFunction if sfc.getScalarFunction.getLanguage == + FunctionLanguage.PYTHON => + findInternal(FunctionLanguage.PYTHON, call) + case _ => + findInternal(FunctionLanguage.JVM, call) + } + } + + override def visitNode(rexNode: RexNode): Boolean = false + + private def findInternal(actualLanguage: FunctionLanguage, call: RexCall): Boolean = { + if (actualLanguage == expectedLanguage) { + true + } else if (recursive) { + call.getOperands.exists(_.accept(this)) + } else { + false + } + } + } +} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/PythonScalarFunctionSplitRuleTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/PythonScalarFunctionSplitRuleTest.scala new file mode 100644 index 00000000000000..db9538c66d4fb7 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/PythonScalarFunctionSplitRuleTest.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.functions.{FunctionLanguage, ScalarFunction} +import org.apache.flink.table.utils.TableTestUtil._ +import org.apache.flink.table.utils.TableTestBase +import org.junit.Test + +class PythonScalarFunctionSplitRuleTest extends TableTestBase { + + @Test + def testPythonFunctionAsInputOfJavaFunction(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c) + util.tableEnv.registerFunction("pyFunc1", new PythonScalarFunction("pyFunc1")) + + val resultTable = table + .select("pyFunc1(a, b) + 1") + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamPythonCalc", + streamTableNode(table), + term("select", "a", "b", "c", "pyFunc1(a, b) AS f0") + ), + term("select", "+(f0, 1) AS _c0") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testPythonFunctionMixWithJavaFunction(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c) + util.tableEnv.registerFunction("pyFunc1", new PythonScalarFunction("pyFunc1")) + + val resultTable = table + .select("pyFunc1(a, b), c + 1") + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamPythonCalc", + streamTableNode(table), + term("select", "a", "b", "c", "pyFunc1(a, b) AS f0") + ), + term("select", "f0 AS _c0", "+(c, 1) AS _c1") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testPythonFunctionInWhereClause(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c) + util.tableEnv.registerFunction("pyFunc1", new PythonScalarFunction("pyFunc1")) + 
util.tableEnv.registerFunction("pyFunc2", new PythonScalarFunction("pyFunc2")) + + val resultTable = table + .where("pyFunc2(a, c) > 0") + .select("pyFunc1(a, b), c + 1") + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamPythonCalc", + streamTableNode(table), + term("select", "a", "b", "c", "pyFunc1(a, b) AS f0", "pyFunc2(a, c) AS f1") + ), + term("select", "f0 AS _c0", "+(c, 1) AS _c1"), + term("where", ">(f1, 0)") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testChainingPythonFunction(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c) + util.tableEnv.registerFunction("pyFunc1", new PythonScalarFunction("pyFunc1")) + util.tableEnv.registerFunction("pyFunc2", new PythonScalarFunction("pyFunc2")) + util.tableEnv.registerFunction("pyFunc3", new PythonScalarFunction("pyFunc3")) + + val resultTable = table + .select("pyFunc3(pyFunc2(a + pyFunc1(a, c), b), c)") + + val expected = unaryNode( + "DataStreamPythonCalc", + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamPythonCalc", + streamTableNode(table), + term("select", "a", "b", "c", "pyFunc1(a, c) AS f0") + ), + term("select", "a", "b", "c", "+(a, f0) AS f0") + ), + term("select", "pyFunc3(pyFunc2(f0, b), c) AS _c0") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testOnlyOnePythonFunction(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c) + util.tableEnv.registerFunction("pyFunc1", new PythonScalarFunction("pyFunc1")) + + val resultTable = table + .select("pyFunc1(a, b)") + + val expected = unaryNode( + "DataStreamPythonCalc", + streamTableNode(table), + term("select", "pyFunc1(a, b) AS _c0") + ) + + util.verifyTable(resultTable, expected) + } +} + +class PythonScalarFunction(name: String) extends ScalarFunction { + def eval(i: Int, j: Int): Int = i + j + + override def getResultType(signature: Array[Class[_]]): 
TypeInformation[_] = + BasicTypeInfo.INT_TYPE_INFO + + override def getLanguage: FunctionLanguage = FunctionLanguage.PYTHON + + override def toString: String = name +}