From bf3b51a9b08cddcc6f6d600434ef652ae290272e Mon Sep 17 00:00:00 2001
From: Jark Wu
Date: Tue, 18 Oct 2016 11:15:07 +0800
Subject: [PATCH 1/4] [FLINK-4469] [table] Add support for user defined table
 function in Table API & SQL

Remove CROSS/OUTER APPLY support in SQL.

Make ExpressionParser support parsing a string into a LogicalNode.

minor changes

support POJO types

address review comments, which mostly include:

1. Forbid TableFunctions implemented as Scala objects, since `collect` is
   called on a singleton, which would be error-prone in concurrent cases.
2. Clean up `UserDefinedFunction`, and move the eval-related functions and
   `createSqlFunction` to utils.
3. Rename `ScalarFunctions` to `SqlFunctionUtils` because it also contains
   TableFunction logic.
4. Restructure the tests. Test the Java Table API by comparing the RelNodes
   of two tables, and check the SQL API's DataSetRel and DataStreamRel via
   the `TableTestBase` utils.
5. The Scala Table API implicitly converts a `TableFunction` into a
   `TableFunctionCall`. A `TableFunctionCall` is not an `Expression` or
   `LogicalNode`; like `GroupWindow`, it is visible to users (it exposes
   API such as `as(...)`).
6. Fix the hierarchy type extraction problem.
7. Check for correct errors if an invalid table function is called or
   registered.

Remove UserDefinedFunction. Make TableFunction and ScalarFunction clean.

Rebase the code and fix conflicts

minor fixes
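
For illustration only, a minimal sketch of the new user-facing API, assuming
a `Split` UDTF like the one documented in the TableFunction scaladoc:

    // hypothetical UDTF that splits a string into one row per token
    class Split extends TableFunction[String] {
      def eval(str: String): Unit = str.split(" ").foreach(collect)
    }

    val split = new Split
    // Scala Table API: cross-apply the UDTF to each row of the input table
    table.crossApply(split('a) as 's).select('a, 's)

    // SQL: register the function first, then use LATERAL TABLE
    tEnv.registerFunction("split", new Split)
    tEnv.sql("SELECT a, s FROM MyTable, LATERAL TABLE(split(a)) AS T(s)")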
---
 .../java/table/BatchTableEnvironment.scala    |  15 +
 .../java/table/StreamTableEnvironment.scala   |  15 +
 .../scala/table/BatchTableEnvironment.scala   |  12 +
 .../scala/table/StreamTableEnvironment.scala  |  11 +
 .../flink/api/scala/table/expressionDsl.scala |   4 +
 .../flink/api/table/FlinkTypeFactory.scala    |  14 +-
 .../flink/api/table/TableEnvironment.scala    |  44 +-
 .../flink/api/table/TableFunctionCall.scala   | 110 +++++
 .../api/table/codegen/CodeGenerator.scala     |  99 +++--
 .../codegen/calls/ScalarFunctionCallGen.scala |   4 +-
 ...Functions.scala => SqlFunctionUtils.scala} |  16 +-
 .../codegen/calls/TableFunctionCallGen.scala  |  78 ++++
 .../table/expressions/ExpressionParser.scala  |  23 +
 .../flink/api/table/expressions/call.scala    |  14 +-
 .../api/table/functions/ScalarFunction.scala  |  46 +-
 .../api/table/functions/TableFunction.scala   | 121 ++++++
 .../table/functions/UserDefinedFunction.scala |  61 ---
 .../functions/utils/ScalarSqlFunction.scala   |  18 +-
 .../functions/utils/TableSqlFunction.scala    | 119 ++++++
 .../utils/UserDefinedFunctionUtils.scala      | 292 +++++++++---
 .../flink/api/table/plan/logical/call.scala   | 117 +++++
 .../api/table/plan/logical/operators.scala    |  23 +-
 .../api/table/plan/nodes/FlinkCorrelate.scala | 152 +++++++
 .../plan/nodes/dataset/DataSetCorrelate.scala | 141 +++++++
 .../datastream/DataStreamCorrelate.scala      | 133 ++++++
 .../api/table/plan/rules/FlinkRuleSets.scala  |   2 +
 .../rules/dataSet/DataSetCorrelateRule.scala  |  89 ++++
 .../datastream/DataStreamCorrelateRule.scala  |  89 ++++
 .../plan/schema/FlinkTableFunctionImpl.scala  |  84 ++++
 .../org/apache/flink/api/table/table.scala    | 123 +++++-
 .../api/table/validate/FunctionCatalog.scala  |  43 +-
 .../UserDefinedTableFunctionITCase.scala      | 212 ++++++++++
 .../batch/UserDefinedTableFunctionTest.scala  | 320 ++++++++++++++
 .../UserDefinedTableFunctionITCase.scala      | 181 ++++++++
 .../stream/UserDefinedTableFunctionTest.scala | 399 ++++++++++++++++++
 .../UserDefinedScalarFunctionTest.scala       |   4 +-
 .../utils/ExpressionTestBase.scala            |   4 +-
 .../utils/UserDefinedTableFunctions.scala     | 116 +++++
 .../flink/api/table/utils/TableTestBase.scala |  32 ++
 39 files changed, 3123 insertions(+), 257 deletions(-)
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala
 rename flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/{ScalarFunctions.scala => SqlFunctionUtils.scala} (96%)
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
 delete mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
index a4f40d5619204..9df646f5943e4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala
@@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.TypeExtractor
 import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
 import org.apache.flink.api.table.expressions.ExpressionParser
+import org.apache.flink.api.table.functions.TableFunction
 import org.apache.flink.api.table.{Table, TableConfig}
 
 /**
@@ -162,4 +163,18 @@ class BatchTableEnvironment(
     translate[T](table)(typeInfo)
   }
 
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register
+    */
+  def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
+    implicit val typeInfo: TypeInformation[T] = TypeExtractor
+      .createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
+      .asInstanceOf[TypeInformation[T]]
+
+    registerTableFunctionInternal[T](name, tf)
+  }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
index f8dbc37d7850f..c6b5cb9b19805 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala
@@ -19,6 +19,7 @@ package org.apache.flink.api.java.table
 
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.functions.TableFunction
 import org.apache.flink.api.table.{TableConfig, Table}
 import org.apache.flink.api.table.expressions.ExpressionParser
 import org.apache.flink.streaming.api.datastream.DataStream
@@ -164,4 +165,18 @@ class StreamTableEnvironment(
     translate[T](table)(typeInfo)
   }
 
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register
+    */
+  def registerFunction[T](name: String, tf: TableFunction[T]): Unit = {
+    implicit val typeInfo: TypeInformation[T] = TypeExtractor
+      .createTypeInfo(tf, classOf[TableFunction[_]], tf.getClass, 0)
+      .asInstanceOf[TypeInformation[T]]
+
+    registerTableFunctionInternal[T](name, tf)
+  }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
index adb444bf5ac68..36885d2ccce2c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/BatchTableEnvironment.scala
@@ -20,6 +20,7 @@ package org.apache.flink.api.scala.table
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.scala._
 import org.apache.flink.api.table.expressions.Expression
+import org.apache.flink.api.table.functions.TableFunction
 import org.apache.flink.api.table.{TableConfig, Table}
 
 import scala.reflect.ClassTag
@@ -139,4 +140,15 @@ class BatchTableEnvironment(
     wrap[T](translate(table))(ClassTag.AnyRef.asInstanceOf[ClassTag[T]])
   }
 
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register
+    */
+  def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
+    registerTableFunctionInternal(name, tf)
+  }
+
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
index e1061786e984a..dde69d508934f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/StreamTableEnvironment.scala
@@ -18,6 +18,7 @@ package org.apache.flink.api.scala.table
 
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.TableFunction
 import org.apache.flink.api.table.{TableConfig, Table}
 import org.apache.flink.api.table.expressions.Expression
 import org.apache.flink.streaming.api.scala.{StreamExecutionEnvironment, DataStream}
@@ -142,4 +143,14 @@ class StreamTableEnvironment(
     asScalaStream(translate(table))
   }
 
+  /**
+    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
+    * Registered functions can be referenced in SQL queries.
+    *
+    * @param name The name under which the function is registered.
+    * @param tf The TableFunction to register
+    */
+  def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
+    registerTableFunctionInternal(name, tf)
+  }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index fee43d8ba48cc..922621079d74a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -21,9 +21,11 @@ import java.sql.{Date, Time, Timestamp}
 
 import org.apache.calcite.avatica.util.DateTimeUtils._
 import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.table.TableFunctionCallBuilder
 import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval}
 import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
 import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.functions.TableFunction
 
 import scala.language.implicitConversions
 
@@ -539,6 +541,8 @@ trait ImplicitExpressionConversions {
   implicit def sqlDate2Literal(sqlDate: Date): Expression = Literal(sqlDate)
   implicit def sqlTime2Literal(sqlTime: Time): Expression = Literal(sqlTime)
   implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression = Literal(sqlTimestamp)
+  implicit def UDTF2TableFunctionCall[T: TypeInformation](udtf: TableFunction[T]):
+    TableFunctionCallBuilder[T] = TableFunctionCallBuilder(udtf)
 }
 
 // ------------------------------------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
index ee71ce9dd7842..dd11fd2e000b7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/FlinkTypeFactory.scala
@@ -26,7 +26,7 @@ import org.apache.calcite.sql.`type`.SqlTypeName
 import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.calcite.sql.parser.SqlParserPos
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
-import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeinfo.{NothingTypeInfo, SqlTimeTypeInfo, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.typeutils.ValueTypeInfo._
 import org.apache.flink.api.table.FlinkTypeFactory.typeInfoToSqlTypeName
@@ -95,9 +95,9 @@ class FlinkTypeFactory(typeSystem: RelDataTypeSystem) extends JavaTypeFactoryImp
   }
 
   override def createTypeWithNullability(
-    relDataType: RelDataType,
-    nullable: Boolean)
-  : RelDataType = relDataType match {
+      relDataType: RelDataType,
+      nullable: Boolean)
+    : RelDataType = relDataType match {
     case composite: CompositeRelDataType =>
       // at the moment we do not care about nullability
       composite
@@ -152,8 +152,7 @@ object FlinkTypeFactory {
     case typeName if DAY_INTERVAL_TYPES.contains(typeName) => TimeIntervalTypeInfo.INTERVAL_MILLIS
 
     case NULL =>
-      throw TableException("Type NULL is not supported. " +
-        "Null values must have a supported type.")
+      throw TableException("Type NULL is not supported. Null values must have a supported type.")
 
     // symbol for special flags e.g. TRIM's BOTH, LEADING, TRAILING
     // are represented as integer
@@ -168,6 +167,9 @@ object FlinkTypeFactory {
       val compositeRelDataType = relDataType.asInstanceOf[CompositeRelDataType]
       compositeRelDataType.compositeType
 
+    // ROW and CURSOR are for the UDTF case; this type info is never used, just a placeholder
+    case ROW | CURSOR => new NothingTypeInfo
+
     case _@t =>
       throw TableException(s"Type is not supported: $t")
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
index 7b2b73842a492..c7f8f265cb43a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala
@@ -40,7 +40,8 @@ import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
 import org.apache.flink.api.scala.{ExecutionEnvironment => ScalaBatchExecEnv}
 import org.apache.flink.api.table.codegen.ExpressionReducer
 import org.apache.flink.api.table.expressions.{Alias, Expression, UnresolvedFieldReference}
-import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{checkForInstantiation, checkNotSingleton, createTableSqlFunctions, createScalarSqlFunction}
+import org.apache.flink.api.table.functions.{TableFunction, ScalarFunction}
 import org.apache.flink.api.table.plan.cost.DataSetCostFactory
 import org.apache.flink.api.table.plan.schema.RelTable
 import org.apache.flink.api.table.sinks.TableSink
@@ -153,21 +154,42 @@ abstract class TableEnvironment(val config: TableConfig) {
   protected def getBuiltInRuleSet: RuleSet
 
   /**
-    * Registers a [[UserDefinedFunction]] under a unique name. Replaces already existing
+    * Registers a [[ScalarFunction]] under a unique name. Replaces already existing
     * user-defined functions under this name.
     */
-  def registerFunction(name: String, function: UserDefinedFunction): Unit = {
-    function match {
-      case sf: ScalarFunction =>
-        // register in Table API
-        functionCatalog.registerFunction(name, function.getClass)
+  def registerFunction(name: String, function: ScalarFunction): Unit = {
+    // check that the function can be instantiated
+    checkForInstantiation(function.getClass)
 
-        // register in SQL API
-        functionCatalog.registerSqlFunction(sf.getSqlFunction(name, typeFactory))
+    // register in Table API
+    functionCatalog.registerFunction(name, function.getClass)
 
-      case _ =>
-        throw new TableException("Unsupported user-defined function type.")
+    // register in SQL API
+    functionCatalog.registerSqlFunction(createScalarSqlFunction(name, function, typeFactory))
+  }
+
+  /**
+    * Registers a [[TableFunction]] under a unique name. Replaces already existing
+    * user-defined functions under this name.
+    */
+  private[flink] def registerTableFunctionInternal[T: TypeInformation](
+      name: String, function: TableFunction[T]): Unit = {
+    // check that the function is not a Scala object (a singleton)
+    checkNotSingleton(function.getClass)
+    // check that the function can be instantiated
+    checkForInstantiation(function.getClass)
+
+    val typeInfo: TypeInformation[_] = if (function.getResultType != null) {
+      function.getResultType
+    } else {
+      implicitly[TypeInformation[T]]
     }
+
+    // register in Table API
+    functionCatalog.registerFunction(name, function.getClass)
+    // register in SQL API
+    val sqlFunctions = createTableSqlFunctions(name, function, typeInfo, typeFactory)
+    functionCatalog.registerSqlFunctions(sqlFunctions)
   }
 
   /**
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala
new file mode 100644
index 0000000000000..4843567629358
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.expressions.{Expression, UnresolvedFieldReference}
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.getFieldInfo
+import org.apache.flink.api.table.plan.logical.{LogicalNode, LogicalTableFunctionCall}
+
+
+/**
+  * A [[TableFunctionCall]] represents a call to a [[TableFunction]] with actual parameters.
+  *
+  * For Scala users, Flink implicitly converts a [[TableFunction]] into a [[TableFunctionCall]].
+  * For Java users, Flink parses a string expression into a [[TableFunctionCall]].
+  * So users do not need to create a [[TableFunctionCall]] manually.
+  *
+  * @param functionName function name
+  * @param tableFunction user-defined table function
+  * @param parameters actual parameters of function
+  * @param resultType type information of returned table
+  */
+case class TableFunctionCall(
+    functionName: String,
+    tableFunction: TableFunction[_],
+    parameters: Seq[Expression],
+    resultType: TypeInformation[_]) {
+
+  private var aliases: Option[Seq[Expression]] = None
+
+  /**
+    * Assigns aliases to the fields returned by this table function so that the following
+    * `select()` clause can refer to them.
+    *
+    * @param aliasList aliases for the fields returned by this table function
+    * @return this table function call
+    */
+  def as(aliasList: Expression*): TableFunctionCall = {
+    this.aliases = Some(aliasList)
+    this
+  }
+
+  /**
+    * Converts an API class to a logical node for planning.
+    */
+  private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = {
+    val originNames = getFieldInfo(resultType)._1
+
+    // determine the final field names
+    val fieldNames = if (aliases.isDefined) {
+      val aliasList = aliases.get
+      if (aliasList.length != originNames.length) {
+        throw ValidationException(
+          s"List of column aliases must have same degree as table; " +
+            s"the returned table of function '$functionName' has ${originNames.length} " +
+            s"columns (${originNames.mkString(",")}), " +
+            s"whereas alias list has ${aliasList.length} columns")
+      } else if (!aliasList.forall(_.isInstanceOf[UnresolvedFieldReference])) {
+        throw ValidationException("Alias only accepts name expressions as arguments")
+      } else {
+        aliasList.map(_.asInstanceOf[UnresolvedFieldReference].name).toArray
+      }
+    } else {
+      originNames
+    }
+
+    LogicalTableFunctionCall(
+      functionName,
+      tableFunction,
+      parameters,
+      resultType,
+      fieldNames,
+      child)
+  }
+}
+
+
+case class TableFunctionCallBuilder[T: TypeInformation](udtf: TableFunction[T]) {
+  /**
+    * Creates a call to a [[TableFunction]] in the Scala Table API.
+    *
+    * @param params actual parameters of function
+    * @return [[TableFunctionCall]]
+    */
+  def apply(params: Expression*): TableFunctionCall = {
+    val resultType = if (udtf.getResultType == null) {
+      implicitly[TypeInformation[T]]
+    } else {
+      udtf.getResultType
+    }
+    TableFunctionCall(udtf.getClass.getSimpleName, udtf, params, resultType)
+  }
+}
+
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index 2a8ef449ba996..082137d42720c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -33,9 +33,8 @@ import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, Tuple
 import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
 import org.apache.flink.api.table.codegen.CodeGenUtils._
 import org.apache.flink.api.table.codegen.Indenter.toISC
-import org.apache.flink.api.table.codegen.calls.ScalarFunctions
+import org.apache.flink.api.table.codegen.calls.SqlFunctionUtils
 import org.apache.flink.api.table.codegen.calls.ScalarOperators._
-import org.apache.flink.api.table.functions.UserDefinedFunction
 import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter}
 import org.apache.flink.api.table.typeutils.TypeCheckUtils._
 import org.apache.flink.api.table.{FlinkTypeFactory, TableConfig}
@@ -50,16 +49,19 @@ import scala.collection.mutable
   * @param nullableInput input(s) can be null.
   * @param input1 type information about the first input of the Function
   * @param input2 type information about the second input if the Function is binary
-  * @param inputPojoFieldMapping additional mapping information if input1 is a POJO (POJO types
-  *                              have no deterministic field order). We assume that input2 is
-  *                              converted before and thus is never a POJO.
+  * @param input1PojoFieldMapping additional mapping information if input1 is a POJO (POJO types
+  *                               have no deterministic field order).
+  * @param input2PojoFieldMapping additional mapping information if input2 is a POJO (POJO types
+  *                               have no deterministic field order).
+  *
   */
 class CodeGenerator(
     config: TableConfig,
     nullableInput: Boolean,
     input1: TypeInformation[Any],
     input2: Option[TypeInformation[Any]] = None,
-    inputPojoFieldMapping: Option[Array[Int]] = None)
+    input1PojoFieldMapping: Option[Array[Int]] = None,
+    input2PojoFieldMapping: Option[Array[Int]] = None)
   extends RexVisitor[GeneratedExpression] {
 
   // check if nullCheck is enabled when inputs can be null
@@ -67,18 +69,19 @@ class CodeGenerator(
     throw new CodeGenException("Null check must be enabled if entire rows can be null.")
   }
 
-  // check for POJO input mapping
+  // check for POJO input1 mapping
   input1 match {
     case pt: PojoTypeInfo[_] =>
-      inputPojoFieldMapping.getOrElse(
-        throw new CodeGenException("No input mapping is specified for input of type POJO."))
+      input1PojoFieldMapping.getOrElse(
+        throw new CodeGenException("No input mapping is specified for input1 of type POJO."))
     case _ => // ok
   }
 
-  // check that input2 is never a POJO
+  // check for POJO input2 mapping
   input2 match {
     case Some(pt: PojoTypeInfo[_]) =>
-      throw new CodeGenException("Second input must not be a POJO type.")
+      input2PojoFieldMapping.getOrElse(
+        throw new CodeGenException("No input mapping is specified for input2 of type POJO."))
     case _ => // ok
   }
@@ -334,17 +337,32 @@
       resultFieldNames: Seq[String])
     : GeneratedExpression = {
     val input1AccessExprs = for (i <- 0 until input1.getArity)
-      yield generateInputAccess(input1, input1Term, i)
+      yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
 
     val input2AccessExprs = input2 match {
       case Some(ti) => for (i <- 0 until ti.getArity)
-        yield generateInputAccess(ti, input2Term, i)
+        yield generateInputAccess(ti, input2Term, i, input2PojoFieldMapping)
       case None => Seq() // add nothing
     }
 
     generateResultExpression(input1AccessExprs ++ input2AccessExprs, returnType, resultFieldNames)
   }
 
+  /**
+    * Generates access expressions for the left input and the right table function.
+    */
+  def generateCorrelateAccessExprs: (Seq[GeneratedExpression], Seq[GeneratedExpression]) = {
+    val input1AccessExprs = for (i <- 0 until input1.getArity)
+      yield generateInputAccess(input1, input1Term, i, input1PojoFieldMapping)
+
+    val input2AccessExprs = input2 match {
+      case Some(ti) => for (i <- 0 until ti.getArity)
+        yield generateFieldAccess(ti, input2Term, i, input2PojoFieldMapping)
+      case None => throw new CodeGenException("Type information of input2 must not be null.")
+    }
+    (input1AccessExprs, input2AccessExprs)
+  }
+
   /**
     * Generates an expression from a sequence of RexNode. If objects or variables can be reused,
     * they will be added to reusable code sections internally. The evaluation result
@@ -594,9 +612,11 @@
   override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
     // if inputRef index is within size of input1 we work with input1, input2 otherwise
     val input = if (inputRef.getIndex < input1.getArity) {
-      (input1, input1Term)
+      (input1, input1Term, input1PojoFieldMapping)
     } else {
-      (input2.getOrElse(throw new CodeGenException("Invalid input access.")), input2Term)
+      (input2.getOrElse(throw new CodeGenException("Invalid input access.")),
+        input2Term,
+        input2PojoFieldMapping)
     }
 
     val index = if (input._2 == input1Term) {
@@ -605,13 +625,17 @@
       inputRef.getIndex - input1.getArity
     }
 
-    generateInputAccess(input._1, input._2, index)
+    generateInputAccess(input._1, input._2, index, input._3)
   }
 
   override def visitFieldAccess(rexFieldAccess: RexFieldAccess): GeneratedExpression = {
     val refExpr = rexFieldAccess.getReferenceExpr.accept(this)
     val index = rexFieldAccess.getField.getIndex
-    val fieldAccessExpr = generateFieldAccess(refExpr.resultType, refExpr.resultTerm, index)
+    val fieldAccessExpr = generateFieldAccess(
+      refExpr.resultType,
+      refExpr.resultTerm,
+      index,
+      input1PojoFieldMapping)
 
     val resultTerm = newName("result")
     val nullTerm = newName("isNull")
@@ -753,8 +777,9 @@
     }
   }
 
-  override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression =
-    throw new CodeGenException("Correlating variables are not supported yet.")
+  override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = {
+    GeneratedExpression(input1Term, "false", "", input1)
+  }
 
   override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression =
     throw new CodeGenException("Local variables are not supported yet.")
@@ -948,7 +973,7 @@
 
       // advanced scalar functions
       case sqlOperator: SqlOperator =>
-        val callGen = ScalarFunctions.getCallGenerator(
+        val callGen = SqlFunctionUtils.getCallGenerator(
           sqlOperator,
           operands.map(_.resultType),
           resultType)
@@ -977,7 +1002,8 @@
   private def generateInputAccess(
       inputType: TypeInformation[Any],
       inputTerm: String,
-      index: Int)
+      index: Int,
+      pojoFieldMapping: Option[Array[Int]])
     : GeneratedExpression = {
     // if input has been used before, we can reuse the code that
     // has already been generated
@@ -989,10 +1015,10 @@
       // generate input access and unboxing if necessary
       case None =>
         val expr = if (nullableInput) {
-          generateNullableInputFieldAccess(inputType, inputTerm, index)
+          generateNullableInputFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
         } else {
-          generateFieldAccess(inputType, inputTerm, index)
+          generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
         }
 
         reusableInputUnboxingExprs((inputTerm, index)) = expr
@@ -1005,7 +1031,8 @@
   private def generateNullableInputFieldAccess(
       inputType: TypeInformation[Any],
       inputTerm: String,
-      index: Int)
+      index: Int,
+      pojoFieldMapping: Option[Array[Int]])
    : GeneratedExpression = {
     val resultTerm = newName("result")
     val nullTerm = newName("isNull")
@@ -1013,7 +1040,7 @@
     val fieldType = inputType match {
       case ct: CompositeType[_] =>
         val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]]) {
-          inputPojoFieldMapping.get(index)
+          pojoFieldMapping.get(index)
         }
         else {
           index
@@ -1024,7 +1051,7 @@
     }
     val resultTypeTerm = primitiveTypeTermForTypeInfo(fieldType)
     val defaultValue = primitiveDefaultValue(fieldType)
-    val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index)
+    val fieldAccessExpr = generateFieldAccess(inputType, inputTerm, index, pojoFieldMapping)
 
     val inputCheckCode =
       s"""
@@ -1047,12 +1074,13 @@
   private def generateFieldAccess(
       inputType: TypeInformation[_],
       inputTerm: String,
-      index: Int)
+      index: Int,
+      pojoFieldMapping: Option[Array[Int]])
     : GeneratedExpression = {
     inputType match {
       case ct: CompositeType[_] =>
-        val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && inputPojoFieldMapping.nonEmpty) {
-          inputPojoFieldMapping.get(index)
+        val fieldIndex = if (ct.isInstanceOf[PojoTypeInfo[_]] && pojoFieldMapping.nonEmpty) {
+          pojoFieldMapping.get(index)
         }
         else {
           index
@@ -1332,16 +1360,17 @@
   }
 
   /**
-    * Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]].
-    * The [[UserDefinedFunction]] must have a default constructor, however, it does not have
+    * Adds a reusable instance (a [[org.apache.flink.api.table.functions.TableFunction]] or
+    * [[org.apache.flink.api.table.functions.ScalarFunction]]) to the member area of the generated
+    * [[Function]]. The instance class must have a default constructor, however, it does not have
     * to be public.
     *
-    * @param function [[UserDefinedFunction]] object to be instantiated during runtime
+    * @param instance object to be instantiated during runtime
     * @return member variable term
     */
-  def addReusableFunction(function: UserDefinedFunction): String = {
-    val classQualifier = function.getClass.getCanonicalName
-    val fieldTerm = s"function_${classQualifier.replace('.', '$')}"
+  def addReusableInstance(instance: Any): String = {
+    val classQualifier = instance.getClass.getCanonicalName
+    val fieldTerm = s"instance_${classQualifier.replace('.', '$')}"
 
     val fieldFunction =
       s"""
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala
index b6ef8ad863c21..62b6842a3b558 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala
@@ -42,7 +42,7 @@ class ScalarFunctionCallGen(
       operands: Seq[GeneratedExpression])
     : GeneratedExpression = {
     // determine function signature and result class
-    val matchingSignature = getSignature(scalarFunction, signature)
+    val matchingSignature = getSignature(scalarFunction.getClass, signature)
       .getOrElse(throw new CodeGenException("No matching signature found."))
     val resultClass = getResultTypeClass(scalarFunction, matchingSignature)
 
@@ -65,7 +65,7 @@ class ScalarFunctionCallGen(
     }
 
     // generate function call
-    val functionReference = codeGenerator.addReusableFunction(scalarFunction)
+    val functionReference = codeGenerator.addReusableInstance(scalarFunction)
     val resultTypeTerm = if (resultClass.isPrimitive) {
       primitiveTypeTermForTypeInfo(returnType)
     } else {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/SqlFunctionUtils.scala
similarity index 96%
rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/SqlFunctionUtils.scala
index e7c436aeb6bc6..8f11c876e4111 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/SqlFunctionUtils.scala
@@ -28,14 +28,14 @@ import org.apache.calcite.util.BuiltInMethod
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, SqlTimeTypeInfo, TypeInformation}
 import org.apache.flink.api.java.typeutils.GenericTypeInfo
-import org.apache.flink.api.table.functions.utils.ScalarSqlFunction
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, ScalarSqlFunction}
 
 import scala.collection.mutable
 
 /**
-  * Global hub for user-defined and built-in advanced SQL scalar functions.
+  * Global hub for user-defined and built-in advanced SQL functions.
   */
-object ScalarFunctions {
+object SqlFunctionUtils {
 
   private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] =
     mutable.Map()
@@ -317,6 +317,16 @@ object ScalarFunctions {
       )
     )
 
+    // user-defined table function
+    case tsf: TableSqlFunction =>
+      Some(
+        new TableFunctionCallGen(
+          tsf.getTableFunction,
+          operandTypes,
+          resultType
+        )
+      )
+
     // built-in scalar function
     case _ =>
       sqlFunctions.get((sqlOperator, operandTypes))
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
new file mode 100644
index 0000000000000..802d8a4c78784
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.codegen.calls
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.codegen.CodeGenUtils._
+import org.apache.flink.api.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._
+
+/**
+  * Generates a call to a user-defined [[TableFunction]].
+  *
+  * @param tableFunction user-defined [[TableFunction]] that might be overloaded
+  * @param signature actual signature with which the function is called
+  * @param returnType actual return type required by the surrounding call
+  */
+class TableFunctionCallGen(
+    tableFunction: TableFunction[_],
+    signature: Seq[TypeInformation[_]],
+    returnType: TypeInformation[_])
+  extends CallGenerator {
+
+  override def generate(
+      codeGenerator: CodeGenerator,
+      operands: Seq[GeneratedExpression])
+    : GeneratedExpression = {
+    // determine function signature
+    val matchingSignature = getSignature(tableFunction.getClass, signature)
+      .getOrElse(throw new CodeGenException("No matching signature found."))
+
+    // convert parameters for function (output boxing)
+    val parameters = matchingSignature
+      .zip(operands)
+      .map { case (paramClass, operandExpr) =>
+        if (paramClass.isPrimitive) {
+          operandExpr
+        } else {
+          val boxedTypeTerm = boxedTypeTermForTypeInfo(operandExpr.resultType)
+          val boxedExpr = codeGenerator.generateOutputFieldBoxing(operandExpr)
+          val exprOrNull: String = if (codeGenerator.nullCheck) {
+            s"${boxedExpr.nullTerm} ? null : ($boxedTypeTerm) ${boxedExpr.resultTerm}"
+          } else {
+            boxedExpr.resultTerm
+          }
+          boxedExpr.copy(resultTerm = exprOrNull)
+        }
+      }
+
+    // generate function call
+    val functionReference = codeGenerator.addReusableInstance(tableFunction)
+    val functionCallCode =
+      s"""
+        |${parameters.map(_.code).mkString("\n")}
+        |$functionReference.clear();
+        |$functionReference.eval(${parameters.map(_.resultTerm).mkString(", ")});
+        |""".stripMargin
+
+    // the call itself has no direct result; emitted rows are buffered in the function instance
+    GeneratedExpression(functionReference, "false", functionCallCode, returnType)
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
index 6b6c1294927ce..c995b2bfc7a08 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval,
 import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
 import org.apache.flink.api.table.expressions.TimePointUnit.TimePointUnit
 import org.apache.flink.api.table.expressions.TrimMode.TrimMode
+import org.apache.flink.api.table.plan.logical.{AliasNode, LogicalNode, UnresolvedTableFunctionCall}
 import org.apache.flink.api.table.typeutils.TimeIntervalTypeInfo
 
 import scala.language.implicitConversions
@@ -472,6 +473,28 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     }
   }
 
+  lazy val tableFunctionCall: PackratParser[LogicalNode] =
+    functionIdent ~ "(" ~ repsep(expression, ",") ~ ")" ^^ {
+      case name ~ _ ~ args ~ _ => UnresolvedTableFunctionCall(name.toUpperCase, args)
+    }
+
+  lazy val aliasNode: PackratParser[LogicalNode] =
+    tableFunctionCall ~ AS ~ "(" ~ repsep(fieldReference, ",") ~ ")" ^^ {
+      case e ~ _ ~ _ ~ names ~ _ => AliasNode(names, e)
+    } | tableFunctionCall
+
+  lazy val logicalNode: PackratParser[LogicalNode] = aliasNode |
+    failure("Invalid expression.")
+
+  def parseLogicalNode(nodeString: String): LogicalNode = {
+    parseAll(logicalNode, nodeString) match {
+      case Success(lst, _) => lst
+
+      case NoSuccess(msg, next) =>
+        throwError(msg, next)
+    }
+  }
+
   private def throwError(msg: String, next: Input): Nothing = {
     val improvedMsg = msg.replace("string matching regex `\\z'", "End of expression")
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
index 39367be28df1d..6df6bfeffff44 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
@@ -20,7 +20,7 @@ package org.apache.flink.api.table.expressions
 import org.apache.calcite.rex.RexNode
 import org.apache.calcite.tools.RelBuilder
 import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString, createScalarSqlFunction}
 import org.apache.flink.api.table.validate.{ValidationResult, ValidationFailure, ValidationSuccess}
 import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException}
 
@@ -63,22 +63,26 @@ case class ScalarFunctionCall(
   override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
     val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
     relBuilder.call(
-      scalarFunction.getSqlFunction(scalarFunction.toString, typeFactory),
+      createScalarSqlFunction(
+        scalarFunction.getClass.getCanonicalName,
+        scalarFunction,
+        typeFactory),
       parameters.map(_.toRexNode): _*)
   }
 
-  override def toString = s"$scalarFunction(${parameters.mkString(", ")})"
+  override def toString =
+    s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"
 
   override private[flink] def resultType = getResultType(scalarFunction, foundSignature.get)
 
   override private[flink] def validateInput(): ValidationResult = {
     val signature = children.map(_.resultType)
     // look for a signature that matches the input types
-    foundSignature = getSignature(scalarFunction, signature)
+    foundSignature = getSignature(scalarFunction.getClass, signature)
     if (foundSignature.isEmpty) {
       ValidationFailure(s"Given parameters do not match any signature. \n" +
         s"Actual: ${signatureToString(signature)} \n" +
-        s"Expected: ${signaturesToString(scalarFunction)}")
+        s"Expected: ${signaturesToString(scalarFunction.getClass)}")
     } else {
       ValidationSuccess
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
index 5f9d834d7b0a3..06adfd9be2d5e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala
@@ -48,7 +48,7 @@ import org.apache.flink.api.table.{FlinkTypeFactory, ValidationException}
   * recommended to declare parameters and result types as primitive types instead of their boxed
   * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long.
   */
-abstract class ScalarFunction extends UserDefinedFunction {
+abstract class ScalarFunction {
 
   /**
     * Creates a call to a [[ScalarFunction]] in Scala Table API.
@@ -60,47 +60,6 @@ abstract class ScalarFunction extends UserDefinedFunction {
     ScalarFunctionCall(this, params)
   }
 
-  // ----------------------------------------------------------------------------------------------
-
-  private val evalMethods = checkAndExtractEvalMethods()
-  private lazy val signatures = evalMethods.map(_.getParameterTypes)
-
-  /**
-    * Extracts evaluation methods and throws a [[ValidationException]] if no implementation
-    * can be found.
-    */
-  private def checkAndExtractEvalMethods(): Array[Method] = {
-    val methods = getClass.asSubclass(classOf[ScalarFunction])
-      .getDeclaredMethods
-      .filter { m =>
-        val modifiers = m.getModifiers
-        m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
-      }
-
-    if (methods.isEmpty) {
-      throw new ValidationException(s"Scalar function class '$this' does not implement at least " +
-        s"one method named 'eval' which is public and not abstract.")
-    } else {
-      methods
-    }
-  }
-
-  /**
-    * Returns all found evaluation methods of the possibly overloaded function.
-    */
-  private[flink] final def getEvalMethods: Array[Method] = evalMethods
-
-  /**
-    * Returns all found signature of the possibly overloaded function.
-    */
-  private[flink] final def getSignatures: Array[Array[Class[_]]] = signatures
-
-  override private[flink] final def createSqlFunction(
-      name: String,
-      typeFactory: FlinkTypeFactory)
-    : SqlFunction = {
-    new ScalarSqlFunction(name, this, typeFactory)
-  }
 
   // ----------------------------------------------------------------------------------------------
 
@@ -135,7 +94,8 @@ abstract class ScalarFunction extends UserDefinedFunction {
       TypeExtractor.getForClass(c)
     } catch {
       case ite: InvalidTypesException =>
-        throw new ValidationException(s"Parameter types of scalar function '$this' cannot be " +
+        throw new ValidationException(
+          s"Parameter types of scalar function '${this.getClass.getCanonicalName}' cannot be " +
           s"automatically determined. Please provide type information manually.")
     }
   }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
new file mode 100644
index 0000000000000..d3548af77c843
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.functions
+
+import java.util
+
+import org.apache.flink.api.common.functions.InvalidTypesException
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.table.ValidationException
+
+/**
+  * Base class for a user-defined table function (UDTF). A user-defined table function works on
+  * zero, one, or multiple scalar values as input and returns multiple rows as output.
+  *
+  * The behavior of a [[TableFunction]] can be defined by implementing a custom evaluation
+  * method. An evaluation method must be declared publicly and named "eval". Evaluation methods
+  * can also be overloaded by implementing multiple methods named "eval".
+  *
+  * User-defined functions must have a default constructor and must be instantiable during runtime.
+  *
+  * By default the result type of an evaluation method is determined by Flink's type extraction
+  * facilities. This is sufficient for basic types or simple POJOs but might be wrong for more
+  * complex, custom, or composite types. In these cases [[TypeInformation]] of the result type
+  * can be manually defined by overriding [[getResultType()]].
+  *
+  * Internally, the Table/SQL API code generation works with primitive values as much as possible.
+  * If a user-defined table function should not introduce much overhead during runtime, it is
+  * recommended to declare parameters and result types as primitive types instead of their boxed
+  * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long.
+  *
+  * Example:
+  *
+  * {{{
+  *
+  *   public class Split extends TableFunction<String> {
+  *
+  *     // implement an "eval" method with as many parameters as you want
+  *     public void eval(String str) {
+  *       for (String s : str.split(" ")) {
+  *         collect(s);   // use collect(...) to emit an output row
+  *       }
+  *     }
+  *
+  *     // you can overload the eval method here ...
+  *   }
+  *
+  *   val tEnv: TableEnvironment = ...
+  *   val table: Table = ...    // schema: [a: String]
+  *
+  *   // for Scala users
+  *   val split = new Split()
+  *   table.crossApply(split('a) as ('s)).select('a, 's)
+  *
+  *   // for Java users
+  *   tEnv.registerFunction("split", new Split())   // register table function first
+  *   table.crossApply("split(a) as (s)").select("a, s")
+  *
+  *   // for SQL users
+  *   tEnv.registerFunction("split", new Split())   // register table function first
+  *   tEnv.sql("SELECT a, s FROM MyTable, LATERAL TABLE(split(a)) as T(s)")
+  *
+  * }}}
+  *
+  * @tparam T The type of the output row
+  */
+abstract class TableFunction[T] {
+
+  private val rows: util.ArrayList[T] = new util.ArrayList[T]()
+
+  /**
+    * Emits an output row.
+    *
+    * @param row the output row
+    */
+  protected def collect(row: T): Unit = {
+    // cache rows for now, they may be processed further later
+    rows.add(row)
+  }
+
+  /**
+    * Internal use. Returns an iterator over the buffered rows.
+    */
+  def getRowsIterator = rows.iterator()
+
+  /**
+    * Internal use. Clears the buffered rows.
+    */
+  def clear() = rows.clear()
+
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Returns the result type of the evaluation method with a given signature.
+    *
+    * This method needs to be overridden in case Flink's type extraction facilities are not
+    * sufficient to extract the [[TypeInformation]] based on the return type of the evaluation
+    * method. Flink's type extraction facilities can handle basic types or
+    * simple POJOs but might be wrong for more complex, custom, or composite types.
+    *
+    * @return [[TypeInformation]] of result type or null if Flink should determine the type
+    */
+  def getResultType: TypeInformation[T] = null
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
deleted file mode 100644
index 62afef04d20b8..0000000000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.table.functions
-
-import org.apache.calcite.sql.SqlFunction
-import org.apache.flink.api.table.FlinkTypeFactory
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.checkForInstantiation
-
-import scala.collection.mutable
-
-/**
-  * Base class for all user-defined functions such as scalar functions, table functions,
-  * or aggregation functions.
-  *
-  * User-defined functions must have a default constructor and must be instantiable during runtime.
-  */
-abstract class UserDefinedFunction {
-
-  // we cache SQL functions to reduce amount of created objects
-  // (i.e. for type inference, validation, etc.)
-  private val cachedSqlFunctions = mutable.HashMap[String, SqlFunction]()
-
-  // check if function can be instantiated
-  checkForInstantiation(this.getClass)
-
-  /**
-    * Returns the corresponding [[SqlFunction]]. Creates an instance if not already created.
-    */
-  private[flink] final def getSqlFunction(
-      name: String,
-      typeFactory: FlinkTypeFactory)
-    : SqlFunction = {
-    cachedSqlFunctions.getOrElseUpdate(name, createSqlFunction(name, typeFactory))
-  }
-
-  /**
-    * Creates corresponding [[SqlFunction]].
-    */
-  private[flink] def createSqlFunction(
-      name: String,
-      typeFactory: FlinkTypeFactory)
-    : SqlFunction
-
-  override def toString = getClass.getCanonicalName
-}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
index 531313e36904c..bbc33ed79c7a6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala
@@ -26,7 +26,7 @@ import org.apache.calcite.sql.parser.SqlParserPos
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.table.functions.ScalarFunction
 import org.apache.flink.api.table.functions.utils.ScalarSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString}
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, getSignatures, signatureToString, signaturesToString}
 import org.apache.flink.api.table.{FlinkTypeFactory, ValidationException}
 
 import scala.collection.JavaConverters._
@@ -76,12 +76,12 @@ object ScalarSqlFunction {
             FlinkTypeFactory.toTypeInfo(operandType)
           }
         }
-        val foundSignature = getSignature(scalarFunction, parameters)
+        val foundSignature = getSignature(scalarFunction.getClass, parameters)
         if (foundSignature.isEmpty) {
           throw new ValidationException(
             s"Given parameters of function '$name' do not match any signature. \n" +
               s"Actual: ${signatureToString(parameters)} \n" +
-              s"Expected: ${signaturesToString(scalarFunction)}")
+              s"Expected: ${signaturesToString(scalarFunction.getClass)}")
         }
         val resultType = getResultType(scalarFunction, foundSignature.get)
         typeFactory.createTypeFromTypeInfo(resultType)
@@ -104,7 +104,7 @@ object ScalarSqlFunction {
 
         val operandTypeInfo = getOperandTypeInfo(callBinding)
 
-        val foundSignature = getSignature(scalarFunction, operandTypeInfo)
+        val foundSignature = getSignature(scalarFunction.getClass, operandTypeInfo)
           .getOrElse(throw new ValidationException(s"Operand types of could not be inferred."))
 
         val inferredTypes = scalarFunction
@@ -123,16 +123,18 @@ object ScalarSqlFunction {
       name: String,
       scalarFunction: ScalarFunction)
     : SqlOperandTypeChecker = {
+
+    val signatures = getSignatures(scalarFunction.getClass)
 
     /**
       * Operand type checker based on [[ScalarFunction]] given information.
      */
     new SqlOperandTypeChecker {
       override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
-        s"$opName[${signaturesToString(scalarFunction)}]"
+        s"$opName[${signaturesToString(scalarFunction.getClass)}]"
       }
 
       override def getOperandCountRange: SqlOperandCountRange = {
-        val signatureLengths = scalarFunction.getSignatures.map(_.length)
+        val signatureLengths = signatures.map(_.length)
         SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
       }
 
@@ -142,14 +144,14 @@ object ScalarSqlFunction {
         : Boolean = {
         val operandTypeInfo = getOperandTypeInfo(callBinding)
 
-        val foundSignature = getSignature(scalarFunction, operandTypeInfo)
+        val foundSignature = getSignature(scalarFunction.getClass, operandTypeInfo)
 
         if (foundSignature.isEmpty) {
           if (throwOnFailure) {
             throw new ValidationException(
               s"Given parameters of function '$name' do not match any signature. \n" +
                 s"Actual: ${signatureToString(operandTypeInfo)} \n" +
-                s"Expected: ${signaturesToString(scalarFunction)}")
+                s"Expected: ${signaturesToString(scalarFunction.getClass)}")
           } else {
             false
           }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
new file mode 100644
index 0000000000000..6eadfbc5fa317
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/TableSqlFunction.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.functions.utils
+
+import com.google.common.base.Predicate
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql._
+import org.apache.calcite.sql.`type`._
+import org.apache.calcite.sql.parser.SqlParserPos
+import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction
+import org.apache.calcite.util.Util
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.FlinkTypeFactory
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
+
+import scala.collection.JavaConverters._
+import java.util
+
+
+/**
+  * Calcite wrapper for user-defined table functions.
+ */
+class TableSqlFunction(
+    name: String,
+    udtf: TableFunction[_],
+    rowTypeInfo: TypeInformation[_],
+    returnTypeInference: SqlReturnTypeInference,
+    operandTypeInference: SqlOperandTypeInference,
+    operandTypeChecker: SqlOperandTypeChecker,
+    paramTypes: util.List[RelDataType],
+    functionImpl: FlinkTableFunctionImpl[_])
+  extends SqlUserDefinedTableFunction(
+    new SqlIdentifier(name, SqlParserPos.ZERO),
+    returnTypeInference,
+    operandTypeInference,
+    operandTypeChecker,
+    paramTypes,
+    functionImpl) {
+
+  /**
+    * Returns the user-defined table function.
+    */
+  def getTableFunction = udtf
+
+  /**
+    * Returns the type information of the table returned by the table function.
+    */
+  def getRowTypeInfo = rowTypeInfo
+
+  /**
+    * Returns additional mapping information if the returned table type is a POJO
+    * (POJO types have no deterministic field order).
+    */
+  def getPojoFieldMapping = functionImpl.fieldIndexes
+
+}
+
+object TableSqlFunction {
+
+  /**
+    * Utility function to create a [[TableSqlFunction]].
+    * @param name function name (used by SQL parser)
+    * @param udtf user-defined table function to be called
+    * @param rowTypeInfo the row type information generated by the table function
+    * @param typeFactory type factory for converting between Flink's and Calcite's types
+    * @param functionImpl Calcite table function schema
+    * @return [[TableSqlFunction]]
+    */
+  def apply(
+    name: String,
+    udtf: TableFunction[_],
+    rowTypeInfo: TypeInformation[_],
+    typeFactory: FlinkTypeFactory,
+    functionImpl: FlinkTableFunctionImpl[_]): TableSqlFunction = {
+
+    val argTypes: util.List[RelDataType] = new util.ArrayList[RelDataType]
+    val typeFamilies: util.List[SqlTypeFamily] = new util.ArrayList[SqlTypeFamily]
+    // derive the operands' data types and type families
+    functionImpl.getParameters.asScala.foreach{ o =>
+      val relType: RelDataType = o.getType(typeFactory)
+      argTypes.add(relType)
+      typeFamilies.add(Util.first(relType.getSqlTypeName.getFamily, SqlTypeFamily.ANY))
+    }
+    // derives whether the 'input'th parameter of a method is optional.
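+    // (for parameters derived from eval methods via Calcite's reflective function
+    //  base this is expected to be false for every position)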
+ val optional: Predicate[Integer] = new Predicate[Integer]() { + def apply(input: Integer): Boolean = { + functionImpl.getParameters.get(input).isOptional + } + } + // create type check for the operands + val typeChecker: FamilyOperandTypeChecker = OperandTypes.family(typeFamilies, optional) + + new TableSqlFunction( + name, + udtf, + rowTypeInfo, + ReturnTypes.CURSOR, + InferTypes.explicit(argTypes), + typeChecker, + argTypes, + functionImpl) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala index e7416f710bc68..1c1b2cb0696f9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala @@ -19,144 +19,215 @@ package org.apache.flink.api.table.functions.utils +import java.lang.reflect.{Method, Modifier} import java.sql.{Date, Time, Timestamp} import com.google.common.primitives.Primitives +import org.apache.calcite.sql.SqlFunction import org.apache.flink.api.common.functions.InvalidTypesException -import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} +import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.TypeExtractor -import org.apache.flink.api.table.ValidationException -import org.apache.flink.api.table.functions.{ScalarFunction, UserDefinedFunction} +import org.apache.flink.api.table.{FlinkTypeFactory, TableException, ValidationException} +import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction} +import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl import org.apache.flink.util.InstantiationUtil object UserDefinedFunctionUtils { /** - * Instantiates a user-defined function. + * Instantiates a class. */ - def instantiate[T <: UserDefinedFunction](clazz: Class[T]): T = { + def instantiate[T](clazz: Class[T]): T = { val constructor = clazz.getDeclaredConstructor() constructor.setAccessible(true) constructor.newInstance() } /** - * Checks if a user-defined function can be easily instantiated. + * Checks if a class can be easily instantiated. 
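+    *
+    * For example, an abstract class, an interface, a non-static inner class, or
+    * a class without a default constructor is rejected with a
+    * [[ValidationException]].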
   */
  def checkForInstantiation(clazz: Class[_]): Unit = {
    if (!InstantiationUtil.isPublic(clazz)) {
-      throw ValidationException("Function class is not public.")
+      throw ValidationException(s"Function class ${clazz.getCanonicalName} is not public.")
    }
    else if (!InstantiationUtil.isProperClass(clazz)) {
-      throw ValidationException("Function class is no proper class, it is either abstract," +
-        " an interface, or a primitive type.")
+      throw ValidationException(s"Function class ${clazz.getCanonicalName} is not a proper " +
+        s"class, it is either abstract, an interface, or a primitive type.")
    }
    else if (InstantiationUtil.isNonStaticInnerClass(clazz)) {
-      throw ValidationException("The class is an inner class, but not statically accessible.")
+      throw ValidationException(s"The class ${clazz.getCanonicalName} is an inner class, " +
+        s"but not statically accessible.")
    }

    // check for default constructor (can be private)
    clazz
      .getDeclaredConstructors
      .find(_.getParameterTypes.isEmpty)
-      .getOrElse(throw ValidationException("Function class needs a default constructor."))
+      .getOrElse(throw ValidationException(
+        s"Function class ${clazz.getCanonicalName} needs a default constructor."))
+  }
+
+  /**
+    * Checks whether the given class is a Scala object. Implementing a [[TableFunction]] as a
+    * Scala object is forbidden because of concurrency risks.
+    */
+  def checkNotSingleton(clazz: Class[_]): Unit = {
+    // TODO checking for the MODULE$ field is not a reliable singleton test. Improve it further.
+    if (clazz.getFields.map(_.getName) contains "MODULE$") {
+      throw new ValidationException(
+        s"TableFunction implemented by class ${clazz.getCanonicalName} " +
+          s"is a Scala object. This is forbidden because of concurrency risks.")
+    }
  }

  // ----------------------------------------------------------------------------------------------
-  // Utilities for ScalarFunction
+  // Utilities for eval methods
  // ----------------------------------------------------------------------------------------------

  /**
-    * Prints one signature consisting of classes.
+    * Returns a signature matching the given signature of [[TypeInformation]].
+    * Elements of the signature can be null (and then act as a wildcard).
    */
-  def signatureToString(signature: Array[Class[_]]): String =
-  "(" + signature.map { clazz =>
-    if (clazz == null) {
-      "null"
-    } else {
-      clazz.getCanonicalName
-    }
-  }.mkString(", ") + ")"
+  def getSignature(
+      function: Class[_],
+      signature: Seq[TypeInformation[_]])
+    : Option[Array[Class[_]]] = {
+    // We compare the raw Java classes not the TypeInformation.
+    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+    val actualSignature = typeInfoToClass(signature)
+    val signatures = getSignatures(function)
+
+    signatures
+      // go over all signatures and find one matching actual signature
+      .find { curSig =>
+      // match parameters of signature to actual parameters
+      actualSignature.length == curSig.length &&
+        curSig.zipWithIndex.forall { case (clazz, i) =>
+          parameterTypeEquals(actualSignature(i), clazz)
+        }
+    }
+  }

  /**
-    * Prints one signature consisting of TypeInformation.
+    * Returns the eval method matching the given signature of [[TypeInformation]].
    */
-  def signatureToString(signature: Seq[TypeInformation[_]]): String = {
-    signatureToString(typeInfoToClass(signature))
+  def getEvalMethod(
+      function: Class[_],
+      signature: Seq[TypeInformation[_]])
+    : Option[Method] = {
+    // We compare the raw Java classes not the TypeInformation.
+    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
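+    // Example: an actual signature Seq(STRING_TYPE_INFO, INT_TYPE_INFO) is
+    // converted to Array(classOf[String], classOf[Int]) and resolves to a
+    // matching eval(String, Int) method; null elements act as wildcards.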
+    val actualSignature = typeInfoToClass(signature)
+    val evalMethods = checkAndExtractEvalMethods(function)
+
+    evalMethods
+      // go over all eval methods and find one matching
+      .find { cur =>
+      val signatures = cur.getParameterTypes
+      // match parameters of signature to actual parameters
+      actualSignature.length == signatures.length &&
+        signatures.zipWithIndex.forall { case (clazz, i) =>
+          parameterTypeEquals(actualSignature(i), clazz)
+        }
+    }
  }

  /**
-    * Extracts type classes of [[TypeInformation]] in a null-aware way.
+    * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
+    * can be found.
    */
-  def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
-  typeInfos.map { typeInfo =>
-    if (typeInfo == null) {
-      null
-    } else {
-      typeInfo.getTypeClass
+  def checkAndExtractEvalMethods(function: Class[_]): Array[Method] = {
+    val methods = function
+      .getDeclaredMethods
+      .filter { m =>
+        val modifiers = m.getModifiers
+        m.getName == "eval" && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers)
      }
-  }.toArray
+
+    if (methods.isEmpty) {
+      throw new ValidationException(
+        s"Function class '${function.getCanonicalName}' does not implement at least " +
+          s"one method named 'eval' which is public and not abstract.")
+    } else {
+      methods
+    }
+  }
+
+  def getSignatures(function: Class[_]): Array[Array[Class[_]]] = {
+    checkAndExtractEvalMethods(function).map(_.getParameterTypes)
+  }
+
+  // ----------------------------------------------------------------------------------------------
+  // Utilities for SQL functions
+  // ----------------------------------------------------------------------------------------------

  /**
-    * Compares parameter candidate classes with expected classes. If true, the parameters match.
-    * Candidate can be null (acts as a wildcard).
+    * Creates a [[SqlFunction]] for a [[ScalarFunction]].
+    * @param name function name
+    * @param function scalar function
+    * @param typeFactory type factory
+    * @return the ScalarSqlFunction
    */
-  def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
-  candidate == null ||
-    candidate == expected ||
-    expected.isPrimitive && Primitives.wrap(expected) == candidate ||
-    candidate == classOf[Date] && expected == classOf[Int] ||
-    candidate == classOf[Time] && expected == classOf[Int] ||
-    candidate == classOf[Timestamp] && expected == classOf[Long]
+  def createScalarSqlFunction(
+      name: String,
+      function: ScalarFunction,
+      typeFactory: FlinkTypeFactory)
+    : SqlFunction = {
+    new ScalarSqlFunction(name, function, typeFactory)
+  }

  /**
-    * Returns signatures matching the given signature of [[TypeInformation]].
-    * Elements of the signature can be null (act as a wildcard).
+    * Creates a [[SqlFunction]] for every eval method of a [[TableFunction]].
+    * @param name function name
+    * @param tableFunction table function
+    * @param resultType the type information of the returned table
+    * @param typeFactory type factory
+    * @return one TableSqlFunction per eval method
    */
-  def getSignature(
-      scalarFunction: ScalarFunction,
-      signature: Seq[TypeInformation[_]])
-    : Option[Array[Class[_]]] = {
-    // We compare the raw Java classes not the TypeInformation.
-    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
-    val actualSignature = typeInfoToClass(signature)
+  def createTableSqlFunctions(
+      name: String,
+      tableFunction: TableFunction[_],
+      resultType: TypeInformation[_],
+      typeFactory: FlinkTypeFactory)
+    : Seq[SqlFunction] = {
+    val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
+    val evalMethods = checkAndExtractEvalMethods(tableFunction.getClass)

-    scalarFunction
-      .getSignatures
-      // go over all signatures and find one matching actual signature
-      .find { curSig =>
-      // match parameters of signature to actual parameters
-      actualSignature.length == curSig.length &&
-        curSig.zipWithIndex.forall { case (clazz, i) =>
-          parameterTypeEquals(actualSignature(i), clazz)
-        }
-    }
+    evalMethods.map { method =>
+      val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
+      TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
+    }
  }

+  // ----------------------------------------------------------------------------------------------
+  // Utilities for scalar functions
+  // ----------------------------------------------------------------------------------------------
+
  /**
    * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
    * [[TypeExtractor]] as default return type inference.
    */
  def getResultType(
-    scalarFunction: ScalarFunction,
+    function: ScalarFunction,
    signature: Array[Class[_]])
  : TypeInformation[_] = {
    // find method for signature
-    val evalMethod = scalarFunction.getEvalMethods
+    val evalMethod = checkAndExtractEvalMethods(function.getClass)
      .find(m => signature.sameElements(m.getParameterTypes))
      .getOrElse(throw new ValidationException("Given signature is invalid."))

-    val userDefinedTypeInfo = scalarFunction.getResultType(signature)
+    val userDefinedTypeInfo = function.getResultType(signature)
    if (userDefinedTypeInfo != null) {
-      userDefinedTypeInfo
+      userDefinedTypeInfo
    } else {
      try {
        TypeExtractor.getForClass(evalMethod.getReturnType)
      } catch {
        case ite: InvalidTypesException =>
-          throw new ValidationException(s"Return type of scalar function '$this' cannot be " +
-            s"automatically determined. Please provide type information manually.")
+          throw new ValidationException(
+            s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
+              s"automatically determined. Please provide type information manually.")
      }
    }
  }
@@ -165,21 +236,100 @@ object UserDefinedFunctionUtils {
    * Returns the return type of the evaluation method matching the given signature.
    */
  def getResultTypeClass(
-    scalarFunction: ScalarFunction,
+    function: ScalarFunction,
    signature: Array[Class[_]])
  : Class[_] = {
    // find method for signature
-    val evalMethod = scalarFunction.getEvalMethods
+    val evalMethod = checkAndExtractEvalMethods(function.getClass)
      .find(m => signature.sameElements(m.getParameterTypes))
      .getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
    evalMethod.getReturnType
  }

+  // ----------------------------------------------------------------------------------------------
+  // Miscellaneous
+  // ----------------------------------------------------------------------------------------------
+
  /**
-    * Prints all signatures of a [[ScalarFunction]].
+    * Returns field names, field indexes, and field types for a given [[TypeInformation]].
+    *
+    * Field names are automatically extracted for
+    * [[org.apache.flink.api.common.typeutils.CompositeType]].
+    *
+    * @param inputType The TypeInformation to extract the field names and positions from.
+    * @return A tuple of three arrays holding the field names, field indexes, and field types.
    */
-  def signaturesToString(scalarFunction: ScalarFunction): String = {
-    scalarFunction.getSignatures.map(signatureToString).mkString(", ")
+  def getFieldInfo(inputType: TypeInformation[_])
+    : (Array[String], Array[Int], Array[TypeInformation[_]]) = {
+
+    val fieldNames: Array[String] = inputType match {
+      case t: CompositeType[_] => t.getFieldNames
+      case a: AtomicType[_] => Array("f0")
+      case tpe =>
+        throw new TableException(s"Currently, only CompositeType and AtomicType are supported. " +
+          s"Type $tpe lacks explicit field naming.")
+    }
+    val fieldIndexes = fieldNames.indices.toArray
+    val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
+      inputType match {
+        case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
+        case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
+        case tpe =>
+          throw new TableException(s"Currently, only CompositeType and AtomicType are supported.")
+      }
+    }
+    (fieldNames, fieldIndexes, fieldTypes)
  }

+  /**
+    * Prints one signature consisting of classes.
+    */
+  def signatureToString(signature: Array[Class[_]]): String =
+    signature.map { clazz =>
+      if (clazz == null) {
+        "null"
+      } else {
+        clazz.getCanonicalName
+      }
+    }.mkString("(", ", ", ")")
+
+  /**
+    * Prints one signature consisting of TypeInformation.
+    */
+  def signatureToString(signature: Seq[TypeInformation[_]]): String = {
+    signatureToString(typeInfoToClass(signature))
+  }
+
+  /**
+    * Prints all eval method signatures of a class.
+    */
+  def signaturesToString(function: Class[_]): String = {
+    getSignatures(function).map(signatureToString).mkString(", ")
+  }
+
+  /**
+    * Extracts type classes of [[TypeInformation]] in a null-aware way.
+    */
+  private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
+    typeInfos.map { typeInfo =>
+      if (typeInfo == null) {
+        null
+      } else {
+        typeInfo.getTypeClass
+      }
+    }.toArray
+
+
+  /**
+    * Compares parameter candidate classes with expected classes. If true, the parameters match.
+    * Candidate can be null (acts as a wildcard).
+    */
+  private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
+    candidate == null ||
+      candidate == expected ||
+      expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+      candidate == classOf[Date] && expected == classOf[Int] ||
+      candidate == classOf[Time] && expected == classOf[Int] ||
+      candidate == classOf[Timestamp] && expected == classOf[Long]

}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala
new file mode 100644
index 0000000000000..50f9373e84a21
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.logical
+
+import java.lang.reflect.Method
+import java.util
+
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.tools.RelBuilder
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table._
+import org.apache.flink.api.table.expressions.{Attribute, Expression, ResolvedFieldReference}
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getEvalMethod, signaturesToString, signatureToString, getFieldInfo, checkNotSingleton, checkForInstantiation}
+import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl
+
+import scala.collection.JavaConverters._
+
+/**
+  * General logical node for unresolved user-defined table function calls.
+  */
+case class UnresolvedTableFunctionCall(functionName: String, args: Seq[Expression])
+  extends LogicalNode {
+
+  override def output: Seq[Attribute] =
+    throw UnresolvedException("Invalid call to output on UnresolvedTableFunctionCall")
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder =
+    throw UnresolvedException("Invalid call to construct on UnresolvedTableFunctionCall")
+
+  override private[flink] def children: Seq[LogicalNode] =
+    throw UnresolvedException("Invalid call to children on UnresolvedTableFunctionCall")
+}
+
+/**
+  * LogicalNode for calling a user-defined table function.
+  * @param functionName function name
+  * @param tableFunction table function to be called (might be overloaded)
+  * @param parameters actual parameters
+  * @param resultType type information of the table produced by the table function
+  * @param fieldNames output field names
+  * @param child child logical node
+  */
+case class LogicalTableFunctionCall(
+    functionName: String,
+    tableFunction: TableFunction[_],
+    parameters: Seq[Expression],
+    resultType: TypeInformation[_],
+    fieldNames: Array[String],
+    child: LogicalNode)
+  extends UnaryNode {
+
+  val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType)
+  var evalMethod: Method = _
+
+  override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map {
+    case (n, t) => ResolvedFieldReference(n, t)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall]
+    // check that the table function is not a Scala object
+    checkNotSingleton(tableFunction.getClass)
+    // check that the table function can be instantiated
+    checkForInstantiation(tableFunction.getClass)
+    // look for a signature that matches the input types
+    val signature = node.parameters.map(_.resultType)
+    val foundMethod = getEvalMethod(tableFunction.getClass, signature)
+    if (foundMethod.isEmpty) {
+      failValidation(
+        s"Given parameters of function '$functionName' do not match any signature. 
\n" + + s"Actual: ${signatureToString(signature)} \n" + + s"Expected: ${signaturesToString(tableFunction.getClass)}") + } else { + node.evalMethod = foundMethod.get + } + node + } + + override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { + val fieldIndexes = getFieldInfo(resultType)._2 + val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod) + val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val sqlFunction = TableSqlFunction( + tableFunction.toString, + tableFunction, + resultType, + typeFactory, + function) + + val scan = LogicalTableFunctionScan.create( + relBuilder.peek().getCluster, + new util.ArrayList[RelNode](), + relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava), + function.getElementType(null), + function.getRowType(relBuilder.getTypeFactory, null), + null) + + relBuilder.push(scan) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala index ecf1996cc921f..1b6a8da504816 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala @@ -17,8 +17,11 @@ */ package org.apache.flink.api.table.plan.logical + +import com.google.common.collect.Sets import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.CorrelationId import org.apache.calcite.rel.logical.LogicalProject import org.apache.calcite.rex.{RexInputRef, RexNode} import org.apache.calcite.tools.RelBuilder @@ -27,6 +30,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.table._ import org.apache.flink.api.table.expressions._ + import org.apache.flink.api.table.typeutils.TypeConverter import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess} @@ -361,7 +365,8 @@ case class Join( left: LogicalNode, right: LogicalNode, joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + condition: Option[Expression], + correlated: Boolean) extends BinaryNode { override def output: Seq[Attribute] = { left.output ++ right.output @@ -411,22 +416,31 @@ case class Join( right) } val resolvedCondition = node.condition.map(_.postOrderTransform(partialFunction)) - Join(node.left, node.right, node.joinType, resolvedCondition) + Join(node.left, node.right, node.joinType, resolvedCondition, correlated) } override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { left.construct(relBuilder) right.construct(relBuilder) + + val corSet = Sets.newHashSet[CorrelationId]() + + if (correlated) { + corSet.add(relBuilder.peek().getCluster.createCorrel()) + } + relBuilder.join( TypeConverter.flinkJoinTypeToRelType(joinType), - condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true))) + condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)), + corSet) } private def ambiguousName: Set[String] = left.output.map(_.name).toSet.intersect(right.output.map(_.name).toSet) override def validate(tableEnv: TableEnvironment): LogicalNode = { - if (tableEnv.isInstanceOf[StreamTableEnvironment]) { + if (tableEnv.isInstanceOf[StreamTableEnvironment] + && 
!right.isInstanceOf[LogicalTableFunctionCall]) { failValidation(s"Join on stream tables is currently not supported.") } @@ -605,3 +619,4 @@ case class WindowAggregate( resolvedWindowAggregate } } + diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala new file mode 100644 index 0000000000000..821c55529a1ce --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.plan.nodes + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rex.{RexCall, RexNode} +import org.apache.calcite.sql.SemiJoinType +import org.apache.flink.api.common.functions.FlatMapFunction +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction} +import org.apache.flink.api.table.codegen.CodeGenUtils.primitiveDefaultValue +import org.apache.flink.api.table.functions.utils.TableSqlFunction +import org.apache.flink.api.table.runtime.FlatMapRunner +import org.apache.flink.api.table.typeutils.TypeConverter._ +import org.apache.flink.api.table.{TableConfig, TableException} + +import scala.collection.JavaConverters._ + +/** + * cross/outer apply a user-defined table function + */ +trait FlinkCorrelate { + + private[flink] def functionBody( + generator: CodeGenerator, + udtfTypeInfo: TypeInformation[Any], + rowType: RelDataType, + rexCall: RexCall, + condition: Option[RexNode], + config: TableConfig, + joinType: SemiJoinType, + expectedType: Option[TypeInformation[Any]]): String = { + + val returnType = determineReturnType( + rowType, + expectedType, + config.getNullCheck, + config.getEfficientTypeUsage) + + val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs + val crossResultExpr = generator.generateResultExpression(input1AccessExprs ++ input2AccessExprs, + returnType, rowType.getFieldNames.asScala) + + val call = generator.generateExpression(rexCall) + var body = + s""" + |${call.code} + |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator(); + """.stripMargin + + if (joinType == SemiJoinType.INNER) { + // cross apply + body += + s""" + |if (!iter.hasNext()) { + | return; + |} + """.stripMargin + } else if (joinType == SemiJoinType.LEFT) { + // outer apply + val input2NullExprs = input2AccessExprs.map( + x => GeneratedExpression(primitiveDefaultValue(x.resultType), "true", "", x.resultType)) + val outerResultExpr = generator.generateResultExpression( + input1AccessExprs ++ 
input2NullExprs, returnType, rowType.getFieldNames.asScala) + body += + s""" + |if (!iter.hasNext()) { + | ${outerResultExpr.code} + | ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm}); + | return; + |} + """.stripMargin + } else { + throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.") + } + + val projection = if (condition.isEmpty) { + s""" + |${crossResultExpr.code} + |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); + """.stripMargin + } else { + val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo) { + override def input1Term: String = input2Term + } + val filterCondition = filterGenerator.generateExpression(condition.get) + s""" + |${filterGenerator.reuseInputUnboxingCode()} + |${filterCondition.code} + |if (${filterCondition.resultTerm}) { + | ${crossResultExpr.code} + | ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); + |} + |""".stripMargin + } + + val outputTypeClass = udtfTypeInfo.getTypeClass.getCanonicalName + body += + s""" + |while (iter.hasNext()) { + | $outputTypeClass ${generator.input2Term} = ($outputTypeClass) iter.next(); + | $projection + |} + """.stripMargin + body + } + + private[flink] def correlateMapFunction( + genFunction: GeneratedFunction[FlatMapFunction[Any, Any]]) + : FlatMapRunner[Any, Any] = { + + new FlatMapRunner[Any, Any]( + genFunction.name, + genFunction.code, + genFunction.returnType) + } + + private[flink] def selectToString(rowType: RelDataType): String = { + rowType.getFieldNames.asScala.mkString(",") + } + + private[flink] def correlateOpName( + rexCall: RexCall, + sqlFunction: TableSqlFunction, + rowType: RelDataType) + : String = { + + s"correlate: ${correlateToString(rexCall, sqlFunction)}, select: ${selectToString(rowType)}" + } + + private[flink] def correlateToString(rexCall: RexCall, sqlFunction: TableSqlFunction): String = { + val udtfName = sqlFunction.getName + val operands = rexCall.getOperands.asScala.map(_.toString).mkString(",") + s"table($udtfName($operands))" + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala new file mode 100644 index 0000000000000..d6715ff9a13f8 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.flink.api.table.plan.nodes.dataset
+
+import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexNode, RexCall}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.table.BatchTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+
+/**
+  * Flink RelNode that executes a cross/outer apply of a user-defined table function.
+  */
+class DataSetCorrelate(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputNode: RelNode,
+    scan: LogicalTableFunctionScan,
+    condition: Option[RexNode],
+    relRowType: RelDataType,
+    joinRowType: RelDataType,
+    joinType: SemiJoinType,
+    ruleDescription: String)
+  extends SingleRel(cluster, traitSet, inputNode)
+  with FlinkCorrelate
+  with DataSetRel {
+
+  override def deriveRowType() = relRowType
+
+  override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    val rowCnt = metadata.getRowCount(getInput) * 1.5
+    planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * 0.5)
+  }
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataSetCorrelate(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      scan,
+      condition,
+      relRowType,
+      joinRowType,
+      joinType,
+      ruleDescription)
+  }
+
+  override def toString: String = {
+    val rexCall = scan.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    correlateToString(rexCall, sqlFunction)
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    val rexCall = scan.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    super.explainTerms(pw)
+      .item("invocation", scan.getCall)
+      .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName)
+      .item("rowType", relRowType)
+      .item("joinType", joinType)
+      .itemIf("condition", condition.orNull, condition.isDefined)
+  }
+
+  override def translateToPlan(
+      tableEnv: BatchTableEnvironment,
+      expectedType: Option[TypeInformation[Any]])
+    : DataSet[Any] = {
+
+    val config = tableEnv.getConfig
+    val returnType = determineReturnType(
+      getRowType,
+      expectedType,
+      config.getNullCheck,
+      config.getEfficientTypeUsage)
+
+    // do not need to specify input type
+    val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+
+    val funcRel = scan.asInstanceOf[LogicalTableFunctionScan]
+    val rexCall = funcRel.getCall.asInstanceOf[RexCall]
+    val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+    val pojoFieldMapping = sqlFunction.getPojoFieldMapping
+    val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+
+    val generator = new CodeGenerator(
+      config,
+      false,
+      inputDS.getType,
+      Some(udtfTypeInfo),
+      None,
+      Some(pojoFieldMapping))
+
+    val body = functionBody(
+      generator,
+      udtfTypeInfo,
+      getRowType,
+      rexCall,
+      condition,
+      config,
+      joinType,
+      expectedType)
+
+    val genFunction = generator.generateFunction(
+      ruleDescription,
+      classOf[FlatMapFunction[Any, Any]],
+      body,
+      returnType)
+
+    val mapFunc = correlateMapFunction(genFunction)
+
+    inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+  }
+
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
new file mode 100644
index 0000000000000..b0bc48abe73b0
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.sql.SemiJoinType
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.StreamTableEnvironment
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.functions.utils.TableSqlFunction
+import org.apache.flink.api.table.plan.nodes.FlinkCorrelate
+import org.apache.flink.api.table.typeutils.TypeConverter._
+import org.apache.flink.streaming.api.datastream.DataStream
+
+/**
+  * Flink RelNode that executes a cross/outer apply of a user-defined table function.
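+  *
+  * For example, a (hypothetical) Table API call such as
+  * {{{
+  *   streamTable.crossApply(split('c) as ('s)).select('a, 's)
+  * }}}
+  * is translated into a DataStreamCorrelate that flat-maps every input row
+  * through the code-generated function wrapping the table function.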
+ */ +class DataStreamCorrelate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputNode: RelNode, + scan: LogicalTableFunctionScan, + condition: Option[RexNode], + relRowType: RelDataType, + joinRowType: RelDataType, + joinType: SemiJoinType, + ruleDescription: String) + extends SingleRel(cluster, traitSet, inputNode) + with FlinkCorrelate + with DataStreamRel { + override def deriveRowType() = relRowType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new DataStreamCorrelate( + cluster, + traitSet, + inputs.get(0), + scan, + condition, + relRowType, + joinRowType, + joinType, + ruleDescription) + } + + override def toString: String = { + val rexCall = scan.getCall.asInstanceOf[RexCall] + val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] + correlateToString(rexCall, sqlFunction) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + val rexCall = scan.getCall.asInstanceOf[RexCall] + val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] + super.explainTerms(pw) + .item("invocation", scan.getCall) + .item("function", sqlFunction.getTableFunction.getClass.getCanonicalName) + .item("rowType", relRowType) + .item("joinType", joinType) + .itemIf("condition", condition.orNull, condition.isDefined) + } + + override def translateToPlan( + tableEnv: StreamTableEnvironment, + expectedType: Option[TypeInformation[Any]]) + : DataStream[Any] = { + + val config = tableEnv.getConfig + val returnType = determineReturnType( + getRowType, + expectedType, + config.getNullCheck, + config.getEfficientTypeUsage) + + // do not need to specify input type + val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + + val funcRel = scan.asInstanceOf[LogicalTableFunctionScan] + val rexCall = funcRel.getCall.asInstanceOf[RexCall] + val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] + val pojoFieldMapping = sqlFunction.getPojoFieldMapping + val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] + + val generator = new CodeGenerator( + config, + false, + inputDS.getType, + Some(udtfTypeInfo), + None, + Some(pojoFieldMapping)) + + val body = functionBody( + generator, + udtfTypeInfo, + getRowType, + rexCall, + condition, + config, + joinType, + expectedType) + + val genFunction = generator.generateFunction( + ruleDescription, + classOf[FlatMapFunction[Any, Any]], + body, + returnType) + + val mapFunc = correlateMapFunction(genFunction) + + inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) + } + +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala index 26c025eb1382a..fafc4ba59c917 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala @@ -107,6 +107,7 @@ object FlinkRuleSets { DataSetMinusRule.INSTANCE, DataSetSortRule.INSTANCE, DataSetValuesRule.INSTANCE, + DataSetCorrelateRule.INSTANCE, BatchTableSourceScanRule.INSTANCE ) @@ -150,6 +151,7 @@ object FlinkRuleSets { DataStreamScanRule.INSTANCE, DataStreamUnionRule.INSTANCE, DataStreamValuesRule.INSTANCE, + DataStreamCorrelateRule.INSTANCE, StreamTableSourceScanRule.INSTANCE ) diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala new file mode 100644 index 0000000000000..bccb2578a467f --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.plan.rules.dataSet + +import org.apache.calcite.plan.volcano.RelSubset +import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan} +import org.apache.calcite.rex.RexNode +import org.apache.flink.api.table.plan.nodes.dataset.{DataSetConvention, DataSetCorrelate} + +/** + * Rule to convert a LogicalCorrelate into a DataSetCorrelate. 
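+ *
+ * The rule matches a LogicalCorrelate whose right input is a
+ * LogicalTableFunctionScan, optionally with a LogicalFilter pushed on top of
+ * the scan, e.g. (schematic plan shapes):
+ * {{{
+ *   LogicalCorrelate(input, LogicalTableFunctionScan(...))
+ *   LogicalCorrelate(input, LogicalFilter(LogicalTableFunctionScan(...)))
+ * }}}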
+ */ +class DataSetCorrelateRule + extends ConverterRule( + classOf[LogicalCorrelate], + Convention.NONE, + DataSetConvention.INSTANCE, + "DataSetCorrelateRule") + { + + override def matches(call: RelOptRuleCall): Boolean = { + val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate] + val right = join.getRight.asInstanceOf[RelSubset].getOriginal + + + right match { + // right node is a table function + case scan: LogicalTableFunctionScan => true + // a filter is pushed above the table function + case filter: LogicalFilter => + filter.getInput.asInstanceOf[RelSubset].getOriginal + .isInstanceOf[LogicalTableFunctionScan] + case _ => false + } + } + + override def convert(rel: RelNode): RelNode = { + val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE) + val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE) + val right: RelNode = join.getInput(1) + + def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataSetCorrelate = { + relNode match { + case rel: RelSubset => + convertToCorrelate(rel.getRelList.get(0), condition) + + case filter: LogicalFilter => + convertToCorrelate(filter.getInput.asInstanceOf[RelSubset].getOriginal, + Some(filter.getCondition)) + + case scan: LogicalTableFunctionScan => + new DataSetCorrelate( + rel.getCluster, + traitSet, + convInput, + scan, + condition, + rel.getRowType, + join.getRowType, + join.getJoinType, + description) + } + } + convertToCorrelate(right, None) + } + } + +object DataSetCorrelateRule { + val INSTANCE: RelOptRule = new DataSetCorrelateRule +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala new file mode 100644 index 0000000000000..bb52fd773915c --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamCorrelateRule.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.table.plan.rules.datastream + +import org.apache.calcite.plan.volcano.RelSubset +import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.logical.{LogicalFilter, LogicalCorrelate, LogicalTableFunctionScan} +import org.apache.calcite.rex.RexNode +import org.apache.flink.api.table.plan.nodes.datastream.{DataStreamCorrelate, DataStreamConvention} + +/** + * Rule to convert a LogicalCorrelate into a DataStreamCorrelate. + */ +class DataStreamCorrelateRule + extends ConverterRule( + classOf[LogicalCorrelate], + Convention.NONE, + DataStreamConvention.INSTANCE, + "DataStreamCorrelateRule") +{ + + override def matches(call: RelOptRuleCall): Boolean = { + val join: LogicalCorrelate = call.rel(0).asInstanceOf[LogicalCorrelate] + val right = join.getRight.asInstanceOf[RelSubset].getOriginal + + right match { + // right node is a table function + case scan: LogicalTableFunctionScan => true + // a filter is pushed above the table function + case filter: LogicalFilter => + filter.getInput.asInstanceOf[RelSubset].getOriginal + .isInstanceOf[LogicalTableFunctionScan] + case _ => false + } + } + + override def convert(rel: RelNode): RelNode = { + val join: LogicalCorrelate = rel.asInstanceOf[LogicalCorrelate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE) + val convInput: RelNode = RelOptRule.convert(join.getInput(0), DataStreamConvention.INSTANCE) + val right: RelNode = join.getInput(1) + + def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): DataStreamCorrelate = { + relNode match { + case rel: RelSubset => + convertToCorrelate(rel.getRelList.get(0), condition) + + case filter: LogicalFilter => + convertToCorrelate(filter.getInput.asInstanceOf[RelSubset].getOriginal, + Some(filter.getCondition)) + + case scan: LogicalTableFunctionScan => + new DataStreamCorrelate( + rel.getCluster, + traitSet, + convInput, + scan, + condition, + rel.getRowType, + join.getRowType, + join.getJoinType, + description) + } + } + convertToCorrelate(right, None) + } + +} + +object DataStreamCorrelateRule { + val INSTANCE: RelOptRule = new DataStreamCorrelateRule +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala new file mode 100644 index 0000000000000..540a5c8382706 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/schema/FlinkTableFunctionImpl.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.plan.schema + +import java.lang.reflect.{Method, Type} +import java.util + +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory} +import org.apache.calcite.schema.TableFunction +import org.apache.calcite.schema.impl.ReflectiveFunctionBase +import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.table.{FlinkTypeFactory, TableException} + +/** + * This is heavily inspired by Calcite's [[org.apache.calcite.schema.impl.TableFunctionImpl]]. + * We need it in order to create a [[org.apache.flink.api.table.functions.utils.TableSqlFunction]]. + * The main difference is that we override the [[getRowType()]] and [[getElementType()]]. + */ +class FlinkTableFunctionImpl[T]( + val typeInfo: TypeInformation[T], + val fieldIndexes: Array[Int], + val fieldNames: Array[String], + val evalMethod: Method) + extends ReflectiveFunctionBase(evalMethod) + with TableFunction { + + if (fieldIndexes.length != fieldNames.length) { + throw new TableException( + "Number of field indexes and field names must be equal.") + } + + // check uniqueness of field names + if (fieldNames.length != fieldNames.toSet.size) { + throw new TableException( + "Table field names must be unique.") + } + + val fieldTypes: Array[TypeInformation[_]] = + typeInfo match { + case cType: CompositeType[T] => + if (fieldNames.length != cType.getArity) { + throw new TableException( + s"Arity of type (" + cType.getFieldNames.deep + ") " + + "not equal to number of field names " + fieldNames.deep + ".") + } + fieldIndexes.map(cType.getTypeAt(_).asInstanceOf[TypeInformation[_]]) + case aType: AtomicType[T] => + if (fieldIndexes.length != 1 || fieldIndexes(0) != 0) { + throw new TableException( + "Non-composite input type may have only a single field and its index must be 0.") + } + Array(aType) + } + + override def getElementType(arguments: util.List[AnyRef]): Type = classOf[Array[Object]] + + override def getRowType(typeFactory: RelDataTypeFactory, + arguments: util.List[AnyRef]): RelDataType = { + val flinkTypeFactory = typeFactory.asInstanceOf[FlinkTypeFactory] + val builder = flinkTypeFactory.builder + fieldNames + .zip(fieldTypes) + .foreach { f => + builder.add(f._1, flinkTypeFactory.createTypeFromTypeInfo(f._2)).nullable(true) + } + builder.build + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala index c45e8719e1f4a..3fe3b988c4f1f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala @@ -20,6 +20,7 @@ package org.apache.flink.api.table import org.apache.calcite.rel.RelNode import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType +import org.apache.flink.api.table.plan.logical.Minus import org.apache.flink.api.table.expressions.{Asc, Expression, ExpressionParser, Ordering} import org.apache.flink.api.table.plan.ProjectionTranslator._ import org.apache.flink.api.table.plan.logical._ @@ -400,7 +401,8 @@ class Table( } new Table( tableEnv, - Join(this.logicalPlan, right.logicalPlan, joinType, joinPredicate).validate(tableEnv)) + Join(this.logicalPlan, 
right.logicalPlan, joinType, joinPredicate, correlated = false)
+        .validate(tableEnv))
  }

  /**
@@ -608,6 +610,125 @@ class Table(
    new Table(tableEnv, Limit(offset, fetch, logicalPlan).validate(tableEnv))
  }

+  /**
+    * Cross Apply returns rows from the outer table (the table on the left of the
+    * Apply operator) for which the table-valued function (on the right side of
+    * the operator) produces matching values.
+    *
+    * Cross Apply is equivalent to an inner join, but it works with a table-valued function.
+    *
+    * Example:
+    *
+    * {{{
+    *   class MySplitUDTF extends TableFunction[String] {
+    *     def eval(str: String): Unit = {
+    *       str.split("#").foreach(collect)
+    *     }
+    *   }
+    *
+    *   val split = new MySplitUDTF()
+    *   table.crossApply(split('c) as ('s)).select('a,'b,'c,'s)
+    * }}}
+    */
+  def crossApply(udtf: TableFunctionCall): Table = {
+    applyInternal(udtf, JoinType.INNER)
+  }
+
+  /**
+    * Cross Apply returns rows from the outer table (the table on the left of the
+    * Apply operator) for which the table-valued function (on the right side of
+    * the operator) produces matching values.
+    *
+    * Cross Apply is equivalent to an inner join, but it works with a table-valued function.
+    *
+    * Example:
+    *
+    * {{{
+    *   class MySplitUDTF extends TableFunction[String] {
+    *     def eval(str: String): Unit = {
+    *       str.split("#").foreach(collect)
+    *     }
+    *   }
+    *
+    *   val split = new MySplitUDTF()
+    *   table.crossApply("split(c) as (s)").select("a, b, c, s")
+    * }}}
+    */
+  def crossApply(udtf: String): Table = {
+    applyInternal(udtf, JoinType.INNER)
+  }
+
+  /**
+    * Outer Apply returns all rows from the outer table (the table on the left of
+    * the Apply operator); where the table-valued function (on the right side of
+    * the operator) returns no matching rows, its columns are filled with NULL
+    * values.
+    *
+    * Outer Apply is equivalent to a left outer join, but it works with a table-valued function.
+    *
+    * Example:
+    *
+    * {{{
+    *   class MySplitUDTF extends TableFunction[String] {
+    *     def eval(str: String): Unit = {
+    *       str.split("#").foreach(collect)
+    *     }
+    *   }
+    *
+    *   val split = new MySplitUDTF()
+    *   table.outerApply(split('c) as ('s)).select('a,'b,'c,'s)
+    * }}}
+    */
+  def outerApply(udtf: TableFunctionCall): Table = {
+    applyInternal(udtf, JoinType.LEFT_OUTER)
+  }
+
+  /**
+    * Outer Apply returns all rows from the outer table (the table on the left of
+    * the Apply operator); where the table-valued function (on the right side of
+    * the operator) returns no matching rows, its columns are filled with NULL
+    * values.
+    *
+    * Outer Apply is equivalent to a left outer join, but it works with a table-valued function.
+    *
+    * Example:
+    *
+    * {{{
+    *   val split = new MySplitUDTF()
+    *   table.outerApply("split(c) as (s)").select("a, b, c, s")
+    * }}}
+    */
+  def outerApply(udtf: String): Table = {
+    applyInternal(udtf, JoinType.LEFT_OUTER)
+  }
+
+  private def applyInternal(udtfString: String, joinType: JoinType): Table = {
+    val node = ExpressionParser.parseLogicalNode(udtfString)
+    var alias: Option[Seq[Expression]] = None
+    val functionCall = node match {
+      case AliasNode(aliasList, child) =>
+        alias = Some(aliasList)
+        child
+      case _ => node
+    }
+
+    functionCall match {
+      case call @ UnresolvedTableFunctionCall(name, args) =>
+        val udtfCall = tableEnv.getFunctionCatalog.lookupTableFunction(name, args)
+        if (alias.isDefined) {
+          applyInternal(udtfCall.as(alias.get: _*), joinType)
+        } else {
+          applyInternal(udtfCall, joinType)
+        }
+      case _ => throw new TableException("Cross/Outer Apply only accepts a TableFunction")
+    }
+  }
+
+  private def applyInternal(node: TableFunctionCall, joinType: JoinType): Table = {
+    val logicalCall = node.toLogicalTableFunctionCall(this.logicalPlan).validate(tableEnv)
+    new Table(
+      tableEnv,
+      Join(this.logicalPlan, logicalCall, joinType, None, correlated = true).validate(tableEnv))
+  }
+
  /**
    * Writes the [[Table]] to a [[TableSink]]. A [[TableSink]] defines an external storage location.
    *
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
index 679733c9389ea..4721208d6ed2a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala
@@ -21,10 +21,10 @@ package org.apache.flink.api.table.validate
 import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTable, ReflectiveSqlOperatorTable}
 import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable}
-import org.apache.flink.api.table.ValidationException
+import org.apache.flink.api.table.{TableFunctionCall, ValidationException}
 import org.apache.flink.api.table.expressions._
-import org.apache.flink.api.table.functions.ScalarFunction
-import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
+import org.apache.flink.api.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -47,12 +47,49 @@ class FunctionCatalog {
     sqlFunctions += sqlFunction
   }
 
+  /** Registers multiple SQL functions at once. All functions must share the same name. **/
+  def registerSqlFunctions(functions: Seq[SqlFunction]): Unit = {
+    if (functions.nonEmpty) {
+      val name = functions.head.getName
+      // check that all functions share the same name
+      if (functions.forall(_.getName == name)) {
+        sqlFunctions --= sqlFunctions.filter(_.getName == name)
+        sqlFunctions ++= functions
+      } else {
+        throw ValidationException("The SQL functions to be registered have different names.")
+      }
+    }
+  }
+
   def getSqlOperatorTable: SqlOperatorTable =
     ChainedSqlOperatorTable.of(
       new BasicOperatorTable(),
       new ListSqlOperatorTable(sqlFunctions)
     )
 
+  /**
+    * Looks up a table function and creates a TableFunctionCall if we find a match.
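+    *
+    * For example, after registering a hypothetical `Split` table function via
+    * `registerFunction("split", new Split)`, `lookupTableFunction("split", Seq('c))`
+    * returns a TableFunctionCall bound to the registered function and its row
+    * type information.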
+ */ + def lookupTableFunction[T](name: String, children: Seq[Expression]): TableFunctionCall = { + val funcClass = functionBuilders + .getOrElse(name.toLowerCase, throw ValidationException(s"Undefined table function: $name")) + funcClass match { + // user-defined table function call + case tf if classOf[TableFunction[T]].isAssignableFrom(tf) => + val tableSqlFunction = sqlFunctions + .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction]) + .getOrElse(throw ValidationException(s"Unregistered table sql function: $name")) + .asInstanceOf[TableSqlFunction] + val typeInfo = tableSqlFunction.getRowTypeInfo + val function = tableSqlFunction.getTableFunction + TableFunctionCall(name, function, children, typeInfo) + + case _ => + throw ValidationException(s"The registered function under name '$name' " + + s"is not a TableFunction") + } + } + /** * Lookup and create an expression if we find a match. */ diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala new file mode 100644 index 0000000000000..7e0d0ffb7478e --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionITCase.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.batch + +import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase +import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.expressions.utils._ +import org.apache.flink.api.table.{Row, Table, TableEnvironment} +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +@RunWith(classOf[Parameterized]) +class UserDefinedTableFunctionITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + @Test + def testSQLCrossApply(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + tableEnv.registerTable("MyTable", in) + tableEnv.registerFunction("split", new TableFunc1) + + val sqlQuery = "SELECT MyTable.c, t.s FROM MyTable, LATERAL TABLE(split(c)) AS t(s)" + + val result = tableEnv.sql(sqlQuery).toDataSet[Row] + val results = result.collect() + val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" + + "Anna#44,Anna\n" + "Anna#44,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testSQLOuterApply(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + tableEnv.registerTable("MyTable", in) + tableEnv.registerFunction("split", new TableFunc2) + + val sqlQuery = "SELECT MyTable.c, t.a, t.b FROM MyTable LEFT JOIN LATERAL TABLE(split(c)) " + + "AS t(a,b) ON TRUE" + + val result = tableEnv.sql(sqlQuery).toDataSet[Row] + val results = result.collect() + val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" + + "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testTableAPICrossApply(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + + val func1 = new TableFunc1 + val result = in.crossApply(func1('c) as ('s)).select('c, 's).toDataSet[Row] + val results = result.collect() + val expected: String = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" + + "Anna#44,Anna\n" + "Anna#44,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + + // with overloading + val result2 = in.crossApply(func1('c, "$") as ('s)).select('c, 's).toDataSet[Row] + val results2 = result2.collect() + val expected2: String = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" + + "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n" + TestBaseUtils.compareResultAsText(results2.asJava, expected2) + } + + + @Test + def testTableAPIOuterApply(): Unit = { + val env: ExecutionEnvironment = 
ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + val func2 = new TableFunc2 + val result = in.outerApply(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row] + val results = result.collect() + val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" + + "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + + @Test + def testCustomReturnType(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + val func2 = new TableFunc2 + + val result = in + .crossApply(func2('c) as ('name, 'len)) + .select('c, 'name, 'len) + .toDataSet[Row] + + val results = result.collect() + val expected: String = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" + + "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testHierarchyType(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + + val hierarchy = new HierarchyTableFunction + val result = in + .crossApply(hierarchy('c) as ('name, 'adult, 'len)) + .select('c, 'name, 'adult, 'len) + .toDataSet[Row] + + val results = result.collect() + val expected: String = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" + + "Anna#44,Anna,true,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testPojoType(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + + val pojo = new PojoTableFunc() + val result = in + .crossApply(pojo('c)) + .select('c, 'name, 'age) + .toDataSet[Row] + + val results = result.collect() + val expected: String = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + + @Test + def testTableAPIWithFilter(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = in + .crossApply(func0('c) as ('name, 'age)) + .select('c, 'name, 'age) + .filter('age > 20) + .toDataSet[Row] + + val results = result.collect() + val expected: String = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + + @Test + def testUDTFWithScalarFunction(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tableEnv: BatchTableEnvironment = TableEnvironment.getTableEnvironment(env) + val in: Table = getSmall3TupleDataSet(env).toTable(tableEnv).as('a, 'b, 'c) + val func1 = new TableFunc1 + + val result = in + .crossApply(func1('c.substring(2)) as 's) + .select('c, 's) + 
.toDataSet[Row] + + val results = result.collect() + val expected: String = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" + + "Anna#44,nna\n" + "Anna#44,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + + private def getSmall3TupleDataSet(env: ExecutionEnvironment): DataSet[(Int, Long, String)] = { + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Jack#22")) + data.+=((2, 2L, "John#19")) + data.+=((3, 2L, "Anna#44")) + data.+=((4, 3L, "nosharp")) + env.fromCollection(data) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala new file mode 100644 index 0000000000000..7e236d12725f5 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/UserDefinedTableFunctionTest.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.batch + +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv, _} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv} +import org.apache.flink.api.table.expressions.utils.{HierarchyTableFunction, PojoTableFunc, TableFunc1, TableFunc2} +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.utils.TableTestBase +import org.apache.flink.api.table.utils.TableTestUtil._ +import org.apache.flink.api.table.{Row, TableEnvironment, Types} +import org.junit.Test +import org.mockito.Mockito._ + + +class UserDefinedTableFunctionTest extends TableTestBase { + + @Test + def testTableAPI(): Unit = { + // mock + val ds = mock(classOf[DataSet[Row]]) + val jDs = mock(classOf[JDataSet[Row]]) + val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING)) + when(ds.javaSet).thenReturn(jDs) + when(jDs.getType).thenReturn(typeInfo) + + // Scala environment + val env = mock(classOf[ScalaExecutionEnv]) + val tableEnv = TableEnvironment.getTableEnvironment(env) + val in1 = ds.toTable(tableEnv).as('a, 'b, 'c) + + // Java environment + val javaEnv = mock(classOf[JavaExecutionEnv]) + val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv) + val in2 = javaTableEnv.fromDataSet(jDs).as("a, b, c") + javaTableEnv.registerTable("MyTable", in2) + + // test cross apply + val func1 = new TableFunc1 + javaTableEnv.registerFunction("func1", func1) + var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's) + var javaTable = in2.crossApply("func1(c) as (s)").select("c, s") + verifyTableEquals(scalaTable, javaTable) + + // test outer apply + scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's) + javaTable = in2.outerApply("func1(c) as (s)").select("c, s") + verifyTableEquals(scalaTable, javaTable) + + // test overloading + scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's) + javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s") + verifyTableEquals(scalaTable, javaTable) + + // test custom result type + val func2 = new TableFunc2 + javaTableEnv.registerFunction("func2", func2) + scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len) + javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len") + verifyTableEquals(scalaTable, javaTable) + + // test hierarchy generic type + val hierarchy = new HierarchyTableFunction + javaTableEnv.registerFunction("hierarchy", hierarchy) + scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len)) + .select('c, 'name, 'len, 'adult) + javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)") + .select("c, name, len, adult") + verifyTableEquals(scalaTable, javaTable) + + // test pojo type + val pojo = new PojoTableFunc + javaTableEnv.registerFunction("pojo", pojo) + scalaTable = in1.crossApply(pojo('c)) + .select('c, 'name, 'age) + javaTable = in2.crossApply("pojo(c)") + .select("c, name, age") + verifyTableEquals(scalaTable, javaTable) + + // test with filter + scalaTable = in1.crossApply(func2('c) as ('name, 'len)) + .select('c, 'name, 'len).filter('len > 2) + javaTable = in2.crossApply("func2(c) as (name, len)") + .select("c, name, len").filter("len > 2") + verifyTableEquals(scalaTable, javaTable) + + // test with scalar function + scalaTable = in1.crossApply(func1('c.substring(2)) as ('s)) + 
.select('a, 'c, 's) + javaTable = in2.crossApply("func1(substring(c, 2)) as (s)") + .select("a, c, s") + verifyTableEquals(scalaTable, javaTable) + } + + @Test + def testSQLWithCrossApply(): Unit = { + val util = batchTestUtil() + val func1 = new TableFunc1 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func1", func1) + + val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "func1($cor0.c)"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery, expected) + + // test overloading + + val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)" + + val expected2 = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "func1($cor0.c, '$')"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery2, expected2) + } + + @Test + def testSQLWithOuterApply(): Unit = { + val util = batchTestUtil() + val func1 = new TableFunc1 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func1", func1) + + val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "func1($cor0.c)"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "LEFT") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithCustomType(): Unit = { + val util = batchTestUtil() + val func2 = new TableFunc2 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func2", func2) + + val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "func2($cor0.c)"), + term("function", func2.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + + "VARCHAR(2147483647) f0, INTEGER f1)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS name", "f1 AS len") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithHierarchyType(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val function = new HierarchyTableFunction + util.addFunction("hierarchy", function) + + val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "hierarchy($cor0.c)"), + term("function", function.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," + + " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS 
name", "f1 AS adult", "f2 AS len") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithPojoType(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val function = new PojoTableFunc + util.addFunction("pojo", function) + + val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "pojo($cor0.c)"), + term("function", function.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," + + " INTEGER age, VARCHAR(2147483647) name)"), + term("joinType", "INNER") + ), + term("select", "c", "name", "age") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithFilter(): Unit = { + val util = batchTestUtil() + val func2 = new TableFunc2 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func2", func2) + + val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " + + "WHERE len > 2" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "func2($cor0.c)"), + term("function", func2.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + + "VARCHAR(2147483647) f0, INTEGER f1)"), + term("joinType", "INNER"), + term("condition", ">($1, 2)") + ), + term("select", "c", "f0 AS name", "f1 AS len") + ) + + util.verifySql(sqlQuery, expected) + } + + + @Test + def testSQLWithScalarFunction(): Unit = { + val util = batchTestUtil() + val func1 = new TableFunc1 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func1", func1) + + val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)" + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetCorrelate", + batchTableNode(0), + term("invocation", "func1(SUBSTRING($cor0.c, 2))"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery, expected) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala new file mode 100644 index 0000000000000..f19f7f91786c1 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.stream + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.stream.utils.StreamITCase +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.expressions.utils.{TableFunc0, TableFunc1} +import org.apache.flink.api.table.{Row, TableEnvironment} +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase +import org.junit.Assert._ +import org.junit.Test + +import scala.collection.mutable + +class UserDefinedTableFunctionITCase extends StreamingMultipleProgramsTestBase { + + @Test + def testSQLCrossApply(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("MyTable", t) + + tEnv.registerFunction("split", new TableFunc0) + + val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable, LATERAL TABLE(split(c)) AS t(n,a)" + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testSQLOuterApply(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("MyTable", t) + + tEnv.registerFunction("split", new TableFunc0) + + val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable " + + "LEFT JOIN LATERAL TABLE(split(c)) AS t(n,a) ON TRUE" + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "nosharp,null,null", "Jack#22,Jack,22", + "John#19,John,19", "Anna#44,Anna,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testTableAPICrossApply(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = t + .crossApply(func0('c) as('d, 'e)) + .select('c, 'd, 'e) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testTableAPIOuterApply(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = t + .outerApply(func0('c) as('d, 'e)) + .select('c, 'd, 'e) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "nosharp,null,null", "Jack#22,Jack,22", + "John#19,John,19", "Anna#44,Anna,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def 
testTableAPIWithFilter(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = t + .crossApply(func0('c) as('name, 'age)) + .select('c, 'name, 'age) + .filter('age > 20) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Jack#22,Jack,22", "Anna#44,Anna,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testTableAPIWithScalarFunction(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + val func1 = new TableFunc1 + + val result = t + .crossApply(func1('c.substring(2)) as 's) + .select('c, 's) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Jack#22,ack", "Jack#22,22", "John#19,ohn", + "John#19,19", "Anna#44,nna", "Anna#44,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + private def getSmall3TupleDataStream( + env: StreamExecutionEnvironment) + : DataStream[(Int, Long, String)] = { + + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Jack#22")) + data.+=((2, 2L, "John#19")) + data.+=((3, 2L, "Anna#44")) + data.+=((4, 3L, "nosharp")) + env.fromCollection(data) + } + +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala new file mode 100644 index 0000000000000..fa3da6b8ca793 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.scala.stream + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.expressions.utils._ +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.utils.TableTestBase +import org.apache.flink.api.table.utils.TableTestUtil._ +import org.apache.flink.api.table._ +import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv} +import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream} +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv} +import org.junit.Assert.{assertTrue, fail} +import org.junit.Test +import org.mockito.Mockito._ + +class UserDefinedTableFunctionTest extends TableTestBase { + + @Test + def testTableAPI(): Unit = { + // mock + val ds = mock(classOf[DataStream[Row]]) + val jDs = mock(classOf[JDataStream[Row]]) + val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING)) + when(ds.javaStream).thenReturn(jDs) + when(jDs.getType).thenReturn(typeInfo) + + // Scala environment + val env = mock(classOf[ScalaExecutionEnv]) + val tableEnv = TableEnvironment.getTableEnvironment(env) + val in1 = ds.toTable(tableEnv).as('a, 'b, 'c) + + // Java environment + val javaEnv = mock(classOf[JavaExecutionEnv]) + val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv) + val in2 = javaTableEnv.fromDataStream(jDs).as("a, b, c") + + // test cross apply + val func1 = new TableFunc1 + javaTableEnv.registerFunction("func1", func1) + var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's) + var javaTable = in2.crossApply("func1(c) as (s)").select("c, s") + verifyTableEquals(scalaTable, javaTable) + + // test outer apply + scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's) + javaTable = in2.outerApply("func1(c) as (s)").select("c, s") + verifyTableEquals(scalaTable, javaTable) + + // test overloading + scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's) + javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s") + verifyTableEquals(scalaTable, javaTable) + + // test custom result type + val func2 = new TableFunc2 + javaTableEnv.registerFunction("func2", func2) + scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len) + javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len") + verifyTableEquals(scalaTable, javaTable) + + // test hierarchy generic type + val hierarchy = new HierarchyTableFunction + javaTableEnv.registerFunction("hierarchy", hierarchy) + scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len)) + .select('c, 'name, 'len, 'adult) + javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)") + .select("c, name, len, adult") + verifyTableEquals(scalaTable, javaTable) + + // test pojo type + val pojo = new PojoTableFunc + javaTableEnv.registerFunction("pojo", pojo) + scalaTable = in1.crossApply(pojo('c)) + .select('c, 'name, 'age) + javaTable = in2.crossApply("pojo(c)") + .select("c, name, age") + verifyTableEquals(scalaTable, javaTable) + + // test with filter + scalaTable = in1.crossApply(func2('c) as ('name, 'len)) + .select('c, 'name, 'len).filter('len > 2) + javaTable = in2.crossApply("func2(c) as (name, len)") + .select("c, name, len").filter("len > 2") + verifyTableEquals(scalaTable, javaTable) + + // test with scalar function + scalaTable = 
in1.crossApply(func1('c.substring(2)) as ('s)) + .select('a, 'c, 's) + javaTable = in2.crossApply("func1(substring(c, 2)) as (s)") + .select("a, c, s") + verifyTableEquals(scalaTable, javaTable) + + // check that a Scala object is forbidden + expectExceptionThrown( + tableEnv.registerFunction("func3", ObjectTableFunction), "Scala object") + expectExceptionThrown( + javaTableEnv.registerFunction("func3", ObjectTableFunction), "Scala object") + expectExceptionThrown( + in1.crossApply(ObjectTableFunction('a, 1)), "Scala object") + + } + + + @Test + def testInvalidTableFunction(): Unit = { + // mock + val util = streamTestUtil() + val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val tEnv = TableEnvironment.getTableEnvironment(mock(classOf[JavaExecutionEnv])) + + //=================== check that a Scala object is forbidden ===================== + // register via the Scala table environment + expectExceptionThrown(util.addFunction("udtf", ObjectTableFunction), "Scala object") + // register via the Java table environment + expectExceptionThrown(tEnv.registerFunction("udtf", ObjectTableFunction), "Scala object") + // direct call in the Scala Table API + expectExceptionThrown(t.crossApply(ObjectTableFunction('a, 1)), "Scala object") + + + //============ throw exception when table function is not registered ========= + // Java Table API call + expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined table function: NONEXIST") + // SQL API call + expectExceptionThrown( + util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(nonexist(a))"), + "No match found for function signature nonexist()") + + + //========= throw exception when the called function is a scalar function ==== + util.addFunction("func0", Func0) + // Java Table API call + expectExceptionThrown(t.crossApply("func0(a)"), "is not a TableFunction") + // SQL API call + // NOTE: it doesn't throw an exception but an AssertionError, maybe a Calcite bug + expectExceptionThrown( + util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func0(a))"), + null, + classOf[AssertionError]) + + //========== throw exception when the parameters are not correct =============== + // Java Table API call + util.addFunction("func2", new TableFunc2) + expectExceptionThrown( + t.crossApply("func2(c, c)"), + "Given parameters of function 'FUNC2' do not match any signature") + // SQL API call + expectExceptionThrown( + util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func2(c, c))"), + "No match found for function signature func2(, )") + } + + private def expectExceptionThrown( + function: => Unit, + keywords: String, + clazz: Class[_ <: Throwable] = classOf[ValidationException]) + : Unit = { + try { + function + fail(s"Expected a $clazz, but no exception was thrown.") + } catch { + case e if e.getClass == clazz => + if (keywords != null) { + assertTrue( + s"The exception message '${e.getMessage}' doesn't contain keyword '$keywords'", + e.getMessage.contains(keywords)) + } + case e: Throwable => fail(s"Expected a ${clazz.getSimpleName} to be thrown, but $e was thrown instead.") + } + } + + @Test + def testSQLWithCrossApply(): Unit = { + val util = streamTestUtil() + val func1 = new TableFunc1 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func1", func1) + + val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "func1($cor0.c)"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, 
BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery, expected) + + // test overloading + + val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)" + + val expected2 = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "func1($cor0.c, '$')"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery2, expected2) + } + + @Test + def testSQLWithOuterApply(): Unit = { + val util = streamTestUtil() + val func1 = new TableFunc1 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func1", func1) + + val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON TRUE" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "func1($cor0.c)"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "LEFT") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithCustomType(): Unit = { + val util = streamTestUtil() + val func2 = new TableFunc2 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func2", func2) + + val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len)" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "func2($cor0.c)"), + term("function", func2.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + + "VARCHAR(2147483647) f0, INTEGER f1)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS name", "f1 AS len") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithHierarchyType(): Unit = { + val util = streamTestUtil() + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val function = new HierarchyTableFunction + util.addFunction("hierarchy", function) + + val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult, len)" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "hierarchy($cor0.c)"), + term("function", function.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," + + " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithPojoType(): Unit = { + val util = streamTestUtil() + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + val function = new PojoTableFunc + util.addFunction("pojo", function) + + val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "pojo($cor0.c)"), + term("function", function.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, 
VARCHAR(2147483647) c," + + " INTEGER age, VARCHAR(2147483647) name)"), + term("joinType", "INNER") + ), + term("select", "c", "name", "age") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testSQLWithFilter(): Unit = { + val util = streamTestUtil() + val func2 = new TableFunc2 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func2", func2) + + val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name, len) " + + "WHERE len > 2" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "func2($cor0.c)"), + term("function", func2.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + + "VARCHAR(2147483647) f0, INTEGER f1)"), + term("joinType", "INNER"), + term("condition", ">($1, 2)") + ), + term("select", "c", "f0 AS name", "f1 AS len") + ) + + util.verifySql(sqlQuery, expected) + } + + + @Test + def testSQLWithScalarFunction(): Unit = { + val util = streamTestUtil() + val func1 = new TableFunc1 + util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addFunction("func1", func1) + + val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)" + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamCorrelate", + streamTableNode(0), + term("invocation", "func1(SUBSTRING($cor0.c, 2))"), + term("function", func1.getClass.getCanonicalName), + term("rowType", + "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) f0)"), + term("joinType", "INNER") + ), + term("select", "c", "f0 AS s") + ) + + util.verifySql(sqlQuery, expected) + } + +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala index 95cb331de98fc..ffe3cd30fb4bd 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala @@ -24,7 +24,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.scala.table._ import org.apache.flink.api.table.expressions.utils._ -import org.apache.flink.api.table.functions.UserDefinedFunction +import org.apache.flink.api.table.functions.ScalarFunction import org.apache.flink.api.table.typeutils.RowTypeInfo import org.apache.flink.api.table.{Row, Types} import org.junit.Test @@ -208,7 +208,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { )).asInstanceOf[TypeInformation[Any]] } - override def functions: Map[String, UserDefinedFunction] = Map( + override def functions: Map[String, ScalarFunction] = Map( "Func0" -> Func0, "Func1" -> Func1, "Func2" -> Func2, diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala index 84b61da22ce75..958fd259e6a39 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala @@ -30,7 +30,7 @@ import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} import org.apache.flink.api.table._ import org.apache.flink.api.table.codegen.{CodeGenerator, Compiler, GeneratedFunction} import org.apache.flink.api.table.expressions.{Expression, ExpressionParser} -import org.apache.flink.api.table.functions.UserDefinedFunction +import org.apache.flink.api.table.functions.ScalarFunction import org.apache.flink.api.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention} import org.apache.flink.api.table.plan.rules.FlinkRuleSets import org.apache.flink.api.table.typeutils.RowTypeInfo @@ -79,7 +79,7 @@ abstract class ExpressionTestBase { def typeInfo: TypeInformation[Any] - def functions: Map[String, UserDefinedFunction] = Map() + def functions: Map[String, ScalarFunction] = Map() @Before def resetTestExprs() = { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala new file mode 100644 index 0000000000000..1e6bdb809f16c --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.table.expressions.utils + +import java.lang.Boolean +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.tuple.Tuple3 +import org.apache.flink.api.table.Row +import org.apache.flink.api.table.functions.TableFunction +import org.apache.flink.api.table.typeutils.RowTypeInfo + + +case class SimpleUser(name: String, age: Int) + +class TableFunc0 extends TableFunction[SimpleUser] { + // only emits a result if the input is in the format "name#age" + def eval(user: String): Unit = { + if (user.contains("#")) { + val splits = user.split("#") + collect(SimpleUser(splits(0), splits(1).toInt)) + } + } +} + +class TableFunc1 extends TableFunction[String] { + def eval(str: String): Unit = { + if (str.contains("#")) { + str.split("#").foreach(collect) + } + } + + def eval(str: String, prefix: String): Unit = { + if (str.contains("#")) { + str.split("#").foreach(s => collect(prefix + s)) + } + } +} + + +class TableFunc2 extends TableFunction[Row] { + def eval(str: String): Unit = { + if (str.contains("#")) { + str.split("#").foreach({ s => + val row = new Row(2) + row.setField(0, s) + row.setField(1, s.length) + collect(row) + }) + } + } + + override def getResultType: TypeInformation[Row] = { + new RowTypeInfo(Seq(BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO)) + } +} + +class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] { + def eval(user: String) { + if (user.contains("#")) { + val splits = user.split("#") + val age = splits(1).toInt + collect(new Tuple3[String, Boolean, Integer](splits(0), age >= 20, age)) + } + } +} + +abstract class SplittableTableFunction[A, B] extends TableFunction[Tuple3[String, A, B]] {} + +class PojoTableFunc extends TableFunction[PojoUser] { + def eval(user: String) { + if (user.contains("#")) { + val splits = user.split("#") + collect(new PojoUser(splits(0), splits(1).toInt)) + } + } +} + +class PojoUser() { + var name: String = _ + var age: Int = 0 + + def this(name: String, age: Int) { + this() + this.name = name + this.age = age + } +} + +// ---------------------------------------------------------------------------------------------- +// Invalid Table Functions +// ---------------------------------------------------------------------------------------------- + + +// used to check that a TableFunction implemented as a Scala object is forbidden +object ObjectTableFunction extends TableFunction[Integer] { + def eval(a: Int, b: Int): Unit = { + collect(a) + collect(b) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala index 539bb61fc2a8c..73f50f54457b7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala @@ -24,6 +24,7 @@ import org.apache.flink.api.java.{DataSet => JDataSet} import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} import org.apache.flink.api.table.expressions.Expression +import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction} import org.apache.flink.api.table.{Table, TableEnvironment} import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream} import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} @@ -43,6 +44,12 @@ class TableTestBase 
{ StreamTableTestUtil() } + def verifyTableEquals(expected: Table, actual: Table): Unit = { + assertEquals("Logical plans do not match", + RelOptUtil.toString(expected.getRelNode), + RelOptUtil.toString(actual.getRelNode)) + } + } abstract class TableTestUtil { @@ -54,6 +61,9 @@ abstract class TableTestUtil { } def addTable[T: TypeInformation](name: String, fields: Expression*): Table + def addFunction[T: TypeInformation](name: String, function: TableFunction[T]): Unit + def addFunction(name: String, function: ScalarFunction): Unit + def verifySql(query: String, expected: String): Unit def verifyTable(resultTable: Table, expected: String): Unit @@ -119,6 +129,17 @@ case class BatchTableTestUtil() extends TableTestUtil { t } + def addFunction[T: TypeInformation]( + name: String, + function: TableFunction[T]) + : Unit = { + tEnv.registerFunction(name, function) + } + + def addFunction(name: String, function: ScalarFunction): Unit = { + tEnv.registerFunction(name, function) + } + def verifySql(query: String, expected: String): Unit = { verifyTable(tEnv.sql(query), expected) } @@ -164,6 +185,17 @@ case class StreamTableTestUtil() extends TableTestUtil { t } + def addFunction[T: TypeInformation]( + name: String, + function: TableFunction[T]) + : Unit = { + tEnv.registerFunction(name, function) + } + + def addFunction(name: String, function: ScalarFunction): Unit = { + tEnv.registerFunction(name, function) + } + def verifySql(query: String, expected: String): Unit = { verifyTable(tEnv.sql(query), expected) } From 2cd9cf4df8eb68789f82ce0fa96eec913ff7932b Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Sun, 4 Dec 2016 19:38:41 +0800 Subject: [PATCH 2/4] Address review comments: 1. extend AS to support multiple names 2. make .crossApply accept an Expression parameter, and parse table function call string to Expression 3. rename class SqlFunctionUtils to FunctionGenerator 4. 
and other minor changes --- .../java/table/BatchTableEnvironment.scala | 2 +- .../java/table/StreamTableEnvironment.scala | 2 +- .../table/TableFunctionCallBuilder.scala | 39 +++++++ .../flink/api/scala/table/expressionDsl.scala | 3 +- .../flink/api/table/TableEnvironment.scala | 6 +- .../flink/api/table/TableFunctionCall.scala | 110 ------------------ .../api/table/codegen/CodeGenerator.scala | 16 +-- ...ionUtils.scala => FunctionGenerator.scala} | 2 +- .../codegen/calls/TableFunctionCallGen.scala | 6 +- .../table/expressions/ExpressionParser.scala | 27 +---- .../flink/api/table/expressions/call.scala | 75 +++++++++++- .../table/expressions/fieldExpression.scala | 6 +- .../api/table/plan/ProjectionTranslator.scala | 4 +- .../api/table/plan/logical/operators.scala | 14 +-- .../api/table/plan/nodes/FlinkCorrelate.scala | 24 ++-- .../org/apache/flink/api/table/table.scala | 49 ++++---- .../api/table/validate/FunctionCatalog.scala | 50 +++----- .../stream/UserDefinedTableFunctionTest.scala | 7 +- 18 files changed, 209 insertions(+), 233 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala delete mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala rename flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/{SqlFunctionUtils.scala => FunctionGenerator.scala} (99%) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala index 9df646f5943e4..b353377c54f28 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/BatchTableEnvironment.scala @@ -165,7 +165,7 @@ class BatchTableEnvironment( /** * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog. - * Registered functions can be referenced in SQL queries. + * Registered functions can be referenced in Table API and SQL queries. * * @param name The name under which the function is registered. * @param tf The TableFunction to register diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala index c6b5cb9b19805..367cb82454114 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/java/table/StreamTableEnvironment.scala @@ -167,7 +167,7 @@ class StreamTableEnvironment( /** * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog. - * Registered functions can be referenced in SQL queries. + * Registered functions can be referenced in Table API and SQL queries. * * @param name The name under which the function is registered. 
* @param tf The TableFunction to register diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala new file mode 100644 index 0000000000000..2261b702ec81b --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/TableFunctionCallBuilder.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.table + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.table.expressions.{Expression, TableFunctionCall} +import org.apache.flink.api.table.functions.TableFunction + +case class TableFunctionCallBuilder[T: TypeInformation](udtf: TableFunction[T]) { + /** + * Creates a call to a [[TableFunction]] in the Scala Table API. + * + * @param params the actual parameters of the function + * @return [[TableFunctionCall]] + */ + def apply(params: Expression*): Expression = { + val resultType = if (udtf.getResultType == null) { + implicitly[TypeInformation[T]] + } else { + udtf.getResultType + } + TableFunctionCall(udtf.getClass.getSimpleName, udtf, params, resultType) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala index 922621079d74a..cc4c68d897b8c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Time, Timestamp} import org.apache.calcite.avatica.util.DateTimeUtils._ import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} -import org.apache.flink.api.table.TableFunctionCallBuilder import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, toMonthInterval, toRowInterval} import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit import org.apache.flink.api.table.expressions._ @@ -99,7 +98,7 @@ trait ImplicitExpressionOperations { def cast(toType: TypeInformation[_]) = Cast(expr, toType) - def as(name: Symbol) = Alias(expr, name.name) + def as(name: Symbol, extraNames: Symbol*) = Alias(expr, name.name, extraNames.map(_.name)) def asc = Asc(expr) def desc = Desc(expr) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala index c7f8f265cb43a..8cabadb1b6743 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableEnvironment.scala @@ -386,7 +386,7 @@ abstract class TableEnvironment(val config: TableConfig) { case t: TupleTypeInfo[A] => exprs.zipWithIndex.map { case (UnresolvedFieldReference(name), idx) => (idx, name) - case (Alias(UnresolvedFieldReference(origName), name), _) => + case (Alias(UnresolvedFieldReference(origName), name, _), _) => val idx = t.getFieldIndex(origName) if (idx < 0) { throw new TableException(s"$origName is not a field of type $t") @@ -398,7 +398,7 @@ abstract class TableEnvironment(val config: TableConfig) { case c: CaseClassTypeInfo[A] => exprs.zipWithIndex.map { case (UnresolvedFieldReference(name), idx) => (idx, name) - case (Alias(UnresolvedFieldReference(origName), name), _) => + case (Alias(UnresolvedFieldReference(origName), name, _), _) => val idx = c.getFieldIndex(origName) if (idx < 0) { throw new TableException(s"$origName is not a field of type $c") @@ -415,7 +415,7 @@ abstract class TableEnvironment(val config: TableConfig) { throw new TableException(s"$name is not a field of type $p") } (idx, name) - case Alias(UnresolvedFieldReference(origName), name) => + case Alias(UnresolvedFieldReference(origName), name, _) => val idx = p.getFieldIndex(origName) if (idx < 0) { throw new TableException(s"$origName is not a field of type $p") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala deleted file mode 100644 index 4843567629358..0000000000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/TableFunctionCall.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.table - -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.table.expressions.{Expression, UnresolvedFieldReference} -import org.apache.flink.api.table.functions.TableFunction -import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.getFieldInfo -import org.apache.flink.api.table.plan.logical.{LogicalNode, LogicalTableFunctionCall} - - -/** - * A [[TableFunctionCall]] represents a call to a [[TableFunction]] with actual parameters. - * - * For Scala users, Flink will help to parse a [[TableFunction]] to [[TableFunctionCall]] - * implicitly. For Java users, Flink will help to parse a string expression to - * [[TableFunctionCall]]. So users do not need to create a [[TableFunctionCall]] manually. 
- * - * @param functionName function name - * @param tableFunction user-defined table function - * @param parameters actual parameters of function - * @param resultType type information of returned table - */ -case class TableFunctionCall( - functionName: String, - tableFunction: TableFunction[_], - parameters: Seq[Expression], - resultType: TypeInformation[_]) { - - private var aliases: Option[Seq[Expression]] = None - - /** - * Assigns an alias for this table function returned fields that the following `select()` clause - * can refer to. - * - * @param aliasList alias for this table function returned fields - * @return this table function call - */ - def as(aliasList: Expression*): TableFunctionCall = { - this.aliases = Some(aliasList) - this - } - - /** - * Converts an API class to a logical node for planning. - */ - private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = { - val originNames = getFieldInfo(resultType)._1 - - // determine the final field names - val fieldNames = if (aliases.isDefined) { - val aliasList = aliases.get - if (aliasList.length != originNames.length) { - throw ValidationException( - s"List of column aliases must have same degree as table; " + - s"the returned table of function '$functionName' has ${originNames.length} " + - s"columns (${originNames.mkString(",")}), " + - s"whereas alias list has ${aliasList.length} columns") - } else if (!aliasList.forall(_.isInstanceOf[UnresolvedFieldReference])) { - throw ValidationException("Alias only accept name expressions as arguments") - } else { - aliasList.map(_.asInstanceOf[UnresolvedFieldReference].name).toArray - } - } else { - originNames - } - - LogicalTableFunctionCall( - functionName, - tableFunction, - parameters, - resultType, - fieldNames, - child) - } -} - - -case class TableFunctionCallBuilder[T: TypeInformation](udtf: TableFunction[T]) { - /** - * Creates a call to a [[TableFunction]] in Scala Table API. 
- * - * @param params actual parameters of function - * @return [[TableFunctionCall]] - */ - def apply(params: Expression*): TableFunctionCall = { - val resultType = if (udtf.getResultType == null) { - implicitly[TypeInformation[T]] - } else { - udtf.getResultType - } - TableFunctionCall(udtf.getClass.getSimpleName, udtf, params, resultType) - } -} - diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala index 082137d42720c..8d6eea9774676 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala @@ -33,7 +33,7 @@ import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, Tuple import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.api.table.codegen.CodeGenUtils._ import org.apache.flink.api.table.codegen.Indenter.toISC -import org.apache.flink.api.table.codegen.calls.SqlFunctionUtils +import org.apache.flink.api.table.codegen.calls.FunctionGenerator import org.apache.flink.api.table.codegen.calls.ScalarOperators._ import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter} import org.apache.flink.api.table.typeutils.TypeCheckUtils._ @@ -159,22 +159,22 @@ class CodeGenerator( /** * @return term of the (casted and possibly boxed) first input */ - def input1Term = "in1" + var input1Term = "in1" /** * @return term of the (casted and possibly boxed) second input */ - def input2Term = "in2" + var input2Term = "in2" /** * @return term of the (casted) output collector */ - def collectorTerm = "c" + var collectorTerm = "c" /** * @return term of the output record (possibly defined in the member area e.g. 
Row, Tuple) */ - def outRecordTerm = "out" + var outRecordTerm = "out" /** * @return returns if null checking is enabled @@ -357,6 +357,8 @@ class CodeGenerator( val input2AccessExprs = input2 match { case Some(ti) => for (i <- 0 until ti.getArity) + // use generateFieldAccess instead of generateInputAccess so that the table function's + // field access code is generated inside the while loop rather than at the top of the function body yield generateFieldAccess(ti, input2Term, i, input2PojoFieldMapping) case None => throw new CodeGenException("type information of input2 must not be null") } @@ -778,7 +780,7 @@ } override def visitCorrelVariable(correlVariable: RexCorrelVariable): GeneratedExpression = { - GeneratedExpression(input1Term, "false", "", input1) + GeneratedExpression(input1Term, GeneratedExpression.NEVER_NULL, "", input1) } override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = @@ -973,7 +975,7 @@ // advanced scalar functions case sqlOperator: SqlOperator => - val callGen = SqlFunctionUtils.getCallGenerator( + val callGen = FunctionGenerator.getCallGenerator( sqlOperator, operands.map(_.resultType), resultType) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/SqlFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala similarity index 99% rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/SqlFunctionUtils.scala rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala index 8f11c876e4111..9b144ba67c8db 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/SqlFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/FunctionGenerator.scala @@ -35,7 +35,7 @@ import scala.collection.mutable /** * Global hub for user-defined and built-in advanced SQL functions.
*/ -object SqlFunctionUtils { +object FunctionGenerator { private val sqlFunctions: mutable.Map[(SqlOperator, Seq[TypeInformation[_]]), CallGenerator] = mutable.Map() diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala index 802d8a4c78784..78d7fdb9713a4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala @@ -73,6 +73,10 @@ class TableFunctionCallGen( |""".stripMargin // has no result - GeneratedExpression(functionReference, "false", functionCallCode, returnType) + GeneratedExpression( + functionReference, + GeneratedExpression.NEVER_NULL, + functionCallCode, + returnType) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala index c995b2bfc7a08..6cd63fffdc994 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala @@ -24,7 +24,6 @@ import org.apache.flink.api.table.expressions.ExpressionUtils.{toMilliInterval, import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit import org.apache.flink.api.table.expressions.TimePointUnit.TimePointUnit import org.apache.flink.api.table.expressions.TrimMode.TrimMode -import org.apache.flink.api.table.plan.logical.{AliasNode, LogicalNode, UnresolvedTableFunctionCall} import org.apache.flink.api.table.typeutils.TimeIntervalTypeInfo import scala.language.implicitConversions @@ -448,7 +447,9 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { lazy val alias: PackratParser[Expression] = logic ~ AS ~ fieldReference ^^ { case e ~ _ ~ name => Alias(e, name.name) - } | logic + } | logic ~ AS ~ "(" ~ rep1sep(fieldReference, ",") ~ ")" ^^ { + case e ~ _ ~ _ ~ names ~ _ => Alias(e, names.head.name, names.drop(1).map(_.name)) + } | logic lazy val expression: PackratParser[Expression] = alias | failure("Invalid expression.") @@ -473,28 +474,6 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers { } } - lazy val tableFunctionCall: PackratParser[LogicalNode] = - functionIdent ~ "(" ~ repsep(expression, ",") ~ ")" ^^ { - case name ~ _ ~ args ~ _ => UnresolvedTableFunctionCall(name.toUpperCase, args) - } - - lazy val aliasNode: PackratParser[LogicalNode] = - tableFunctionCall ~ AS ~ "(" ~ repsep(fieldReference, ",") ~ ")" ^^ { - case e ~ _ ~ _ ~ names ~ _ => AliasNode(names, e) - } | tableFunctionCall - - lazy val logicalNode: PackratParser[LogicalNode] = aliasNode | - failure("Invalid expression.") - - def parseLogicalNode(nodeString: String): LogicalNode = { - parseAll(logicalNode, nodeString) match { - case Success(lst, _) => lst - - case NoSuccess(msg, next) => - throwError(msg, next) - } - } - private def throwError(msg: String, next: Input): Nothing = { val improvedMsg = msg.replace("string matching regex `\\z'", "End of expression") diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala index 6df6bfeffff44..3b53bef465571 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala @@ -19,10 +19,12 @@ package org.apache.flink.api.table.expressions import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder -import org.apache.flink.api.table.functions.ScalarFunction -import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, signatureToString, signaturesToString, createScalarSqlFunction} -import org.apache.flink.api.table.validate.{ValidationResult, ValidationFailure, ValidationSuccess} -import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction} +import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._ +import org.apache.flink.api.table.plan.logical.{LogicalNode, LogicalTableFunctionCall} +import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} +import org.apache.flink.api.table.{FlinkTypeFactory, UnresolvedException, ValidationException} /** * General expression for unresolved function calls. The function can be a built-in @@ -89,3 +91,68 @@ case class ScalarFunctionCall( } } + + +/** + * + * Expression for calling a user-defined table function with actual parameters. + * + * @param functionName function name + * @param tableFunction user-defined table function + * @param parameters actual parameters of function + * @param resultType type information of returned table + */ +case class TableFunctionCall( + functionName: String, + tableFunction: TableFunction[_], + parameters: Seq[Expression], + resultType: TypeInformation[_]) + extends Expression { + + private var aliases: Option[Seq[String]] = None + + override private[flink] def children: Seq[Expression] = parameters + + /** + * Assigns aliases to the fields returned by this table function, which the following + * `select()` clause can refer to. + * + * @param aliasList aliases for the fields returned by this table function + * @return this table function call + */ + private[flink] def as(aliasList: Option[Seq[String]]): TableFunctionCall = { + this.aliases = aliasList + this + } + + /** + * Converts an API class to a logical node for planning.
+ */ + private[flink] def toLogicalTableFunctionCall(child: LogicalNode): LogicalTableFunctionCall = { + val originNames = getFieldInfo(resultType)._1 + + // determine the final field names + val fieldNames = if (aliases.isDefined) { + val aliasList = aliases.get + if (aliasList.length != originNames.length) { + throw ValidationException( + s"List of column aliases must have same degree as table; " + + s"the returned table of function '$functionName' has ${originNames.length} " + + s"columns (${originNames.mkString(",")}), " + + s"whereas alias list has ${aliasList.length} columns") + } else { + aliasList.toArray + } + } else { + originNames + } + + LogicalTableFunctionCall( + functionName, + tableFunction, + parameters, + resultType, + fieldNames, + child) + } +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala index c7817bf1bac34..e651bb3f2c0a8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/fieldExpression.scala @@ -67,7 +67,7 @@ case class ResolvedFieldReference( } } -case class Alias(child: Expression, name: String) +case class Alias(child: Expression, name: String, extraNames: Seq[String] = Seq()) extends UnaryExpression with NamedExpression { override def toString = s"$child as '$name" @@ -80,7 +80,7 @@ case class Alias(child: Expression, name: String) override private[flink] def makeCopy(anyRefs: Array[AnyRef]): this.type = { val child: Expression = anyRefs.head.asInstanceOf[Expression] - copy(child, name).asInstanceOf[this.type] + copy(child, name, extraNames).asInstanceOf[this.type] } override private[flink] def toAttribute: Attribute = { @@ -94,6 +94,8 @@ case class Alias(child: Expression, name: String) override private[flink] def validateInput(): ValidationResult = { if (name == "*") { ValidationFailure("Alias can not accept '*' as name.") + } else if (extraNames.nonEmpty) { + ValidationFailure("Invalid call to Alias with multiple names.") } else { ValidationSuccess } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala index cd22f6a61cd71..f6ddeeffc6914 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala @@ -122,10 +122,10 @@ object ProjectionTranslator { case prop: WindowProperty => val name = propNames(prop) Alias(UnresolvedFieldReference(name), tableEnv.createUniqueAttributeName()) - case n @ Alias(agg: Aggregation, name) => + case n @ Alias(agg: Aggregation, name, _) => val aName = aggNames(agg) Alias(UnresolvedFieldReference(aName), name) - case n @ Alias(prop: WindowProperty, name) => + case n @ Alias(prop: WindowProperty, name, _) => val pName = propNames(prop) Alias(UnresolvedFieldReference(pName), name) case l: LeafExpression => l diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala index 1b6a8da504816..66ccc7ed6a16d 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala @@ -17,8 +17,6 @@ */ package org.apache.flink.api.table.plan.logical - -import com.google.common.collect.Sets import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.CorrelationId @@ -220,7 +218,7 @@ case class Aggregate( relBuilder.aggregate( relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava), aggregateExpressions.map { - case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder) + case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) case _ => throw new RuntimeException("This should never happen.") }.asJava) } @@ -423,16 +421,16 @@ case class Join( left.construct(relBuilder) right.construct(relBuilder) - val corSet = Sets.newHashSet[CorrelationId]() + val corSet = mutable.Set[CorrelationId]() if (correlated) { - corSet.add(relBuilder.peek().getCluster.createCorrel()) + corSet += relBuilder.peek().getCluster.createCorrel() } relBuilder.join( TypeConverter.flinkJoinTypeToRelType(joinType), condition.map(_.toRexNode(relBuilder)).getOrElse(relBuilder.literal(true)), - corSet) + corSet.asJava) } private def ambiguousName: Set[String] = @@ -565,11 +563,11 @@ case class WindowAggregate( window, relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava), propertyExpressions.map { - case Alias(prop: WindowProperty, name) => prop.toNamedWindowProperty(name)(relBuilder) + case Alias(prop: WindowProperty, name, _) => prop.toNamedWindowProperty(name)(relBuilder) case _ => throw new RuntimeException("This should never happen.") }, aggregateExpressions.map { - case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder) + case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder) case _ => throw new RuntimeException("This should never happen.") }.asJava) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala index 821c55529a1ce..9745be1c93a1c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkCorrelate.scala @@ -53,8 +53,6 @@ trait FlinkCorrelate { config.getEfficientTypeUsage) val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs - val crossResultExpr = generator.generateResultExpression(input1AccessExprs ++ input2AccessExprs, - returnType, rowType.getFieldNames.asScala) val call = generator.generateExpression(rexCall) var body = @@ -73,8 +71,16 @@ trait FlinkCorrelate { """.stripMargin } else if (joinType == SemiJoinType.LEFT) { // outer apply - val input2NullExprs = input2AccessExprs.map( - x => GeneratedExpression(primitiveDefaultValue(x.resultType), "true", "", x.resultType)) + + // in case of outer apply and the returned row of table function is empty, + // fill null to all fields of the row + val input2NullExprs = input2AccessExprs.map { x => + GeneratedExpression( + primitiveDefaultValue(x.resultType), + GeneratedExpression.ALWAYS_NULL, + "", + x.resultType) + } val outerResultExpr = generator.generateResultExpression( input1AccessExprs ++ input2NullExprs, returnType, 
rowType.getFieldNames.asScala) body += @@ -89,15 +95,19 @@ throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.") } + val crossResultExpr = generator.generateResultExpression( + input1AccessExprs ++ input2AccessExprs, + returnType, + rowType.getFieldNames.asScala) + val projection = if (condition.isEmpty) { s""" |${crossResultExpr.code} |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); """.stripMargin } else { - val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo) { - override def input1Term: String = input2Term - } + val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo) + filterGenerator.input1Term = filterGenerator.input2Term val filterCondition = filterGenerator.generateExpression(condition.get) s""" |${filterGenerator.reuseInputUnboxingCode()} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala index 3fe3b988c4f1f..a75f2fc70f76d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala @@ -21,7 +21,7 @@ import org.apache.calcite.rel.RelNode import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.table.plan.logical.Minus -import org.apache.flink.api.table.expressions.{Asc, Expression, ExpressionParser, Ordering} +import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall} import org.apache.flink.api.table.plan.ProjectionTranslator._ import org.apache.flink.api.table.plan.logical._ import org.apache.flink.api.table.sinks.TableSink @@ -630,7 +630,7 @@ * table.crossApply(split('c) as ('s)).select('a,'b,'c,'s) * }}} */ - def crossApply(udtf: TableFunctionCall): Table = { + def crossApply(udtf: Expression): Table = { applyInternal(udtf, JoinType.INNER) } @@ -678,7 +678,7 @@ * table.outerApply(split('c) as ('s)).select('a,'b,'c,'s) * }}} */ - def outerApply(udtf: TableFunctionCall): Table = { + def outerApply(udtf: Expression): Table = { applyInternal(udtf, JoinType.LEFT_OUTER) } @@ -701,32 +701,33 @@ } private def applyInternal(udtfString: String, joinType: JoinType): Table = { - val node = ExpressionParser.parseLogicalNode(udtfString) - var alias: Option[Seq[Expression]] = None - val functionCall = node match { - case AliasNode(aliasList, child) => - alias = Some(aliasList) - child - case _ => node - } - - functionCall match { - case call @ UnresolvedTableFunctionCall(name, args) => - val udtfCall = tableEnv.getFunctionCatalog.lookupTableFunction(name, args) - if (alias.isDefined) { - applyInternal(udtfCall.as(alias.get: _*), joinType) - } else { - applyInternal(udtfCall, joinType) - } + val udtf = ExpressionParser.parseExpression(udtfString) + applyInternal(udtf, joinType) + } + + private def applyInternal(udtf: Expression, joinType: JoinType): Table = { + var alias: Option[Seq[String]] = None + + // unwrap the Expression until a TableFunctionCall is reached + def unwrap(expr: Expression): TableFunctionCall = expr match { + case Alias(child, name, extraNames) => + alias = Some(Seq(name) ++ extraNames) + unwrap(child) + case Call(name, args) => + val function = tableEnv.getFunctionCatalog.lookupFunction(name, args) + unwrap(function) + case c: TableFunctionCall => c
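+      // any other expression cannot yield a table function and is rejected below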
case _ => throw new TableException("Cross/Outer Apply only accept TableFunction") } - } - private def applyInternal(node: TableFunctionCall, joinType: JoinType): Table = { - val logicalCall = node.toLogicalTableFunctionCall(this.logicalPlan).validate(tableEnv) + val call = unwrap(udtf) + .as(alias) + .toLogicalTableFunctionCall(this.logicalPlan) + .validate(tableEnv) + new Table( tableEnv, - Join(this.logicalPlan, logicalCall, joinType, None, correlated = true).validate(tableEnv)) + Join(this.logicalPlan, call, joinType, None, correlated = true).validate(tableEnv)) } /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala index 4721208d6ed2a..4029a7d532717 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/validate/FunctionCatalog.scala @@ -21,7 +21,7 @@ package org.apache.flink.api.table.validate import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTable, ReflectiveSqlOperatorTable} import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable} -import org.apache.flink.api.table.{TableFunctionCall, ValidationException} +import org.apache.flink.api.table.ValidationException import org.apache.flink.api.table.expressions._ import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction} import org.apache.flink.api.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils} @@ -67,43 +67,15 @@ class FunctionCatalog { new ListSqlOperatorTable(sqlFunctions) ) - /** - * Lookup table function and create an TableFunctionCall if we find a match. - */ - def lookupTableFunction[T](name: String, children: Seq[Expression]): TableFunctionCall = { - val funcClass = functionBuilders - .getOrElse(name.toLowerCase, throw ValidationException(s"Undefined table function: $name")) - funcClass match { - // user-defined table function call - case tf if classOf[TableFunction[T]].isAssignableFrom(tf) => - val tableSqlFunction = sqlFunctions - .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction]) - .getOrElse(throw ValidationException(s"Unregistered table sql function: $name")) - .asInstanceOf[TableSqlFunction] - val typeInfo = tableSqlFunction.getRowTypeInfo - val function = tableSqlFunction.getTableFunction - TableFunctionCall(name, function, children, typeInfo) - - case _ => - throw ValidationException(s"The registered function under name '$name' " + - s"is not a TableFunction") - } - } - /** * Lookup and create an expression if we find a match. */ def lookupFunction(name: String, children: Seq[Expression]): Expression = { val funcClass = functionBuilders .getOrElse(name.toLowerCase, throw ValidationException(s"Undefined function: $name")) - withChildren(funcClass, children) - } - /** - * Instantiate a function using the provided `children`. 
- */ - private def withChildren(func: Class[_], children: Seq[Expression]): Expression = { - func match { + // Instantiate a function using the provided `children` + funcClass match { // user-defined scalar function call case sf if classOf[ScalarFunction].isAssignableFrom(sf) => @@ -112,10 +84,20 @@ class FunctionCatalog { case Failure(e) => throw ValidationException(e.getMessage) } + // user-defined table function call + case tf if classOf[TableFunction[_]].isAssignableFrom(tf) => + val tableSqlFunction = sqlFunctions + .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[TableSqlFunction]) + .getOrElse(throw ValidationException(s"Unregistered table sql function: $name")) + .asInstanceOf[TableSqlFunction] + val typeInfo = tableSqlFunction.getRowTypeInfo + val function = tableSqlFunction.getTableFunction + TableFunctionCall(name, function, children, typeInfo) + // general expression call case expression if classOf[Expression].isAssignableFrom(expression) => // try to find a constructor accepts `Seq[Expression]` - Try(func.getDeclaredConstructor(classOf[Seq[_]])) match { + Try(funcClass.getDeclaredConstructor(classOf[Seq[_]])) match { case Success(seqCtor) => Try(seqCtor.newInstance(children).asInstanceOf[Expression]) match { case Success(expr) => expr @@ -124,14 +106,14 @@ class FunctionCatalog { case Failure(e) => val childrenClass = Seq.fill(children.length)(classOf[Expression]) // try to find a constructor matching the exact number of children - Try(func.getDeclaredConstructor(childrenClass: _*)) match { + Try(funcClass.getDeclaredConstructor(childrenClass: _*)) match { case Success(ctor) => Try(ctor.newInstance(children: _*).asInstanceOf[Expression]) match { case Success(expr) => expr case Failure(exception) => throw ValidationException(exception.getMessage) } case Failure(exception) => - throw ValidationException(s"Invalid number of arguments for function $func") + throw ValidationException(s"Invalid number of arguments for function $funcClass") } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala index fa3da6b8ca793..bc01819c0489e 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala @@ -138,7 +138,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { //============ throw exception when table function is not registered ========= // Java Table API call - expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined table function: NONEXIST") + expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined function: NONEXIST") // SQL API call expectExceptionThrown( util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(nonexist(a))"), @@ -148,7 +148,10 @@ class UserDefinedTableFunctionTest extends TableTestBase { //========= throw exception when the called function is a scalar function ==== util.addFunction("func0", Func0) // Java Table API call - expectExceptionThrown(t.crossApply("func0(a)"), "is not a TableFunction") + expectExceptionThrown( + t.crossApply("func0(a)"), + "only accept TableFunction", + classOf[TableException]) // SQL API call // NOTE: it doesn't throw an exception but an AssertionError, maybe a Calcite bug expectExceptionThrown( From ca0abcb29be11896eed6be26f6bff3e5ba5655ab Mon 
Sep 17 00:00:00 2001 From: Jark Wu Date: Mon, 5 Dec 2016 22:08:16 +0800 Subject: [PATCH 3/4] revert removing UserDefinedFunction --- .../api/table/codegen/CodeGenerator.scala | 14 +++---- .../codegen/calls/ScalarFunctionCallGen.scala | 4 +- .../codegen/calls/TableFunctionCallGen.scala | 4 +- .../flink/api/table/expressions/call.scala | 4 +- .../api/table/functions/ScalarFunction.scala | 2 +- .../api/table/functions/TableFunction.scala | 2 +- .../table/functions/UserDefinedFunction.scala | 27 +++++++++++++ .../functions/utils/ScalarSqlFunction.scala | 14 +++---- .../utils/UserDefinedFunctionUtils.scala | 39 +++++++++---------- .../flink/api/table/plan/logical/call.scala | 4 +- .../api/table/plan/logical/operators.scala | 1 - .../plan/nodes/dataset/DataSetCorrelate.scala | 4 +- .../rules/dataSet/DataSetCorrelateRule.scala | 5 ++- 13 files changed, 74 insertions(+), 50 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala index 8d6eea9774676..9e4f5691ca7b2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala @@ -35,6 +35,7 @@ import org.apache.flink.api.table.codegen.CodeGenUtils._ import org.apache.flink.api.table.codegen.Indenter.toISC import org.apache.flink.api.table.codegen.calls.FunctionGenerator import org.apache.flink.api.table.codegen.calls.ScalarOperators._ +import org.apache.flink.api.table.functions.UserDefinedFunction import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter} import org.apache.flink.api.table.typeutils.TypeCheckUtils._ import org.apache.flink.api.table.{FlinkTypeFactory, TableConfig} @@ -1362,17 +1363,16 @@ class CodeGenerator( } /** - * Adds a reusable instance (a [[org.apache.flink.api.table.functions.TableFunction]] or - * [[org.apache.flink.api.table.functions.ScalarFunction]]) to the member area of the generated - * [[Function]]. The instance class must have a default constructor, however, it does not have + * Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]]. + * The [[UserDefinedFunction]] must have a default constructor, however, it does not have * to be public. 
* - * @param instance object to be instantiated during runtime + * @param function [[UserDefinedFunction]] object to be instantiated during runtime * @return member variable term */ - def addReusableInstance(instance: Any): String = { - val classQualifier = instance.getClass.getCanonicalName - val fieldTerm = s"instance_${classQualifier.replace('.', '$')}" + def addReusableFunction(function: UserDefinedFunction): String = { + val classQualifier = function.getClass.getCanonicalName + val fieldTerm = s"function_${classQualifier.replace('.', '$')}" val fieldFunction = s""" diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala index 62b6842a3b558..b6ef8ad863c21 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarFunctionCallGen.scala @@ -42,7 +42,7 @@ class ScalarFunctionCallGen( operands: Seq[GeneratedExpression]) : GeneratedExpression = { // determine function signature and result class - val matchingSignature = getSignature(scalarFunction.getClass, signature) + val matchingSignature = getSignature(scalarFunction, signature) .getOrElse(throw new CodeGenException("No matching signature found.")) val resultClass = getResultTypeClass(scalarFunction, matchingSignature) @@ -65,7 +65,7 @@ class ScalarFunctionCallGen( } // generate function call - val functionReference = codeGenerator.addReusableInstance(scalarFunction) + val functionReference = codeGenerator.addReusableFunction(scalarFunction) val resultTypeTerm = if (resultClass.isPrimitive) { primitiveTypeTermForTypeInfo(returnType) } else { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala index 78d7fdb9713a4..27cb43fb40190 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/TableFunctionCallGen.scala @@ -42,7 +42,7 @@ class TableFunctionCallGen( operands: Seq[GeneratedExpression]) : GeneratedExpression = { // determine function signature - val matchingSignature = getSignature(tableFunction.getClass, signature) + val matchingSignature = getSignature(tableFunction, signature) .getOrElse(throw new CodeGenException("No matching signature found.")) // convert parameters for function (output boxing) @@ -64,7 +64,7 @@ class TableFunctionCallGen( } // generate function call - val functionReference = codeGenerator.addReusableInstance(tableFunction) + val functionReference = codeGenerator.addReusableFunction(tableFunction) val functionCallCode = s""" |${parameters.map(_.code).mkString("\n")} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala index 3b53bef465571..3e8d8b10d3db2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala @@ -80,11 +80,11 @@ case class 
ScalarFunctionCall( override private[flink] def validateInput(): ValidationResult = { val signature = children.map(_.resultType) // look for a signature that matches the input types - foundSignature = getSignature(scalarFunction.getClass, signature) + foundSignature = getSignature(scalarFunction, signature) if (foundSignature.isEmpty) { ValidationFailure(s"Given parameters do not match any signature. \n" + s"Actual: ${signatureToString(signature)} \n" + - s"Expected: ${signaturesToString(scalarFunction.getClass)}") + s"Expected: ${signaturesToString(scalarFunction)}") } else { ValidationSuccess } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala index 06adfd9be2d5e..86d9d66128677 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/ScalarFunction.scala @@ -48,7 +48,7 @@ import org.apache.flink.api.table.{FlinkTypeFactory, ValidationException} * recommended to declare parameters and result types as primitive types instead of their boxed * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long. */ -abstract class ScalarFunction { +abstract class ScalarFunction extends UserDefinedFunction { /** * Creates a call to a [[ScalarFunction]] in Scala Table API. diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala index d3548af77c843..98a29210690ba 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/TableFunction.scala @@ -80,7 +80,7 @@ import org.apache.flink.api.table.ValidationException * * @tparam T The type of the output row */ -abstract class TableFunction[T] { +abstract class TableFunction[T] extends UserDefinedFunction { private val rows: util.ArrayList[T] = new util.ArrayList[T]() diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala new file mode 100644 index 0000000000000..cdf6b070495bd --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/UserDefinedFunction.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.table.functions + +/** + * Base class for all user-defined functions such as scalar functions, table functions, + * or aggregation functions. + * + * User-defined functions must have a default constructor and must be instantiable during runtime. + */ +trait UserDefinedFunction { +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala index bbc33ed79c7a6..0a987aaa3b3cf 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/ScalarSqlFunction.scala @@ -76,12 +76,12 @@ object ScalarSqlFunction { FlinkTypeFactory.toTypeInfo(operandType) } } - val foundSignature = getSignature(scalarFunction.getClass, parameters) + val foundSignature = getSignature(scalarFunction, parameters) if (foundSignature.isEmpty) { throw new ValidationException( s"Given parameters of function '$name' do not match any signature. \n" + s"Actual: ${signatureToString(parameters)} \n" + - s"Expected: ${signaturesToString(scalarFunction.getClass)}") + s"Expected: ${signaturesToString(scalarFunction)}") } val resultType = getResultType(scalarFunction, foundSignature.get) typeFactory.createTypeFromTypeInfo(resultType) @@ -104,7 +104,7 @@ object ScalarSqlFunction { val operandTypeInfo = getOperandTypeInfo(callBinding) - val foundSignature = getSignature(scalarFunction.getClass, operandTypeInfo) + val foundSignature = getSignature(scalarFunction, operandTypeInfo) .getOrElse(throw new ValidationException(s"Operand types of could not be inferred.")) val inferredTypes = scalarFunction @@ -124,13 +124,13 @@ object ScalarSqlFunction { scalarFunction: ScalarFunction) : SqlOperandTypeChecker = { - val signatures = getSignatures(scalarFunction.getClass) + val signatures = getSignatures(scalarFunction) /** * Operand type checker based on [[ScalarFunction]] given information. */ new SqlOperandTypeChecker { override def getAllowedSignatures(op: SqlOperator, opName: String): String = { - s"$opName[${signaturesToString(scalarFunction.getClass)}]" + s"$opName[${signaturesToString(scalarFunction)}]" } override def getOperandCountRange: SqlOperandCountRange = { @@ -144,14 +144,14 @@ object ScalarSqlFunction { : Boolean = { val operandTypeInfo = getOperandTypeInfo(callBinding) - val foundSignature = getSignature(scalarFunction.getClass, operandTypeInfo) + val foundSignature = getSignature(scalarFunction, operandTypeInfo) if (foundSignature.isEmpty) { if (throwOnFailure) { throw new ValidationException( s"Given parameters of function '$name' do not match any signature. 
\n" + s"Actual: ${signatureToString(operandTypeInfo)} \n" + - s"Expected: ${signaturesToString(scalarFunction.getClass)}") + s"Expected: ${signaturesToString(scalarFunction)}") } else { false } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala index 1c1b2cb0696f9..932baebcf1655 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/functions/utils/UserDefinedFunctionUtils.scala @@ -29,43 +29,41 @@ import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.TypeExtractor import org.apache.flink.api.table.{FlinkTypeFactory, TableException, ValidationException} -import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction} +import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction} import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl import org.apache.flink.util.InstantiationUtil object UserDefinedFunctionUtils { /** - * Instantiates a class. + * Instantiates a user-defined function. */ - def instantiate[T](clazz: Class[T]): T = { + def instantiate[T <: UserDefinedFunction](clazz: Class[T]): T = { val constructor = clazz.getDeclaredConstructor() constructor.setAccessible(true) constructor.newInstance() } /** - * Checks if a class can be easily instantiated. + * Checks if a user-defined function can be easily instantiated. */ def checkForInstantiation(clazz: Class[_]): Unit = { if (!InstantiationUtil.isPublic(clazz)) { - throw ValidationException(s"Function class ${clazz.getCanonicalName} is not public.") + throw ValidationException("Function class is not public.") } else if (!InstantiationUtil.isProperClass(clazz)) { - throw ValidationException(s"Function class ${clazz.getCanonicalName} is no proper class, " + - s"it is either abstract, an interface, or a primitive type.") + throw ValidationException("Function class is no proper class, it is either abstract," + + " an interface, or a primitive type.") } else if (InstantiationUtil.isNonStaticInnerClass(clazz)) { - throw ValidationException(s"The class ${clazz.getCanonicalName} is an inner class, " + - s"but not statically accessible.") + throw ValidationException("The class is an inner class, but not statically accessible.") } // check for default constructor (can be private) clazz .getDeclaredConstructors .find(_.getParameterTypes.isEmpty) - .getOrElse(throw ValidationException( - s"Function class ${clazz.getCanonicalName} needs a default constructor.")) + .getOrElse(throw ValidationException("Function class needs a default constructor.")) } /** @@ -90,7 +88,7 @@ object UserDefinedFunctionUtils { * Elements of the signature can be null (act as a wildcard). */ def getSignature( - function: Class[_], + function: UserDefinedFunction, signature: Seq[TypeInformation[_]]) : Option[Array[Class[_]]] = { // We compare the raw Java classes not the TypeInformation. @@ -113,7 +111,7 @@ object UserDefinedFunctionUtils { * Returns eval method matching the given signature of [[TypeInformation]]. 
*/ def getEvalMethod( - function: Class[_], + function: UserDefinedFunction, signature: Seq[TypeInformation[_]]) : Option[Method] = { // We compare the raw Java classes not the TypeInformation. @@ -137,8 +135,9 @@ object UserDefinedFunctionUtils { * Extracts "eval" methods and throws a [[ValidationException]] if no implementation * can be found. */ - def checkAndExtractEvalMethods(function: Class[_]): Array[Method] = { + def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = { val methods = function + .getClass .getDeclaredMethods .filter { m => val modifiers = m.getModifiers @@ -147,14 +146,14 @@ object UserDefinedFunctionUtils { if (methods.isEmpty) { throw new ValidationException( - s"Function class '${function.getCanonicalName}' does not implement at least " + + s"Function class '${function.getClass.getCanonicalName}' does not implement at least " + s"one method named 'eval' which is public and not abstract.") } else { methods } } - def getSignatures(function: Class[_]): Array[Array[Class[_]]] = { + def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = { checkAndExtractEvalMethods(function).map(_.getParameterTypes) } @@ -192,7 +191,7 @@ object UserDefinedFunctionUtils { typeFactory: FlinkTypeFactory) : Seq[SqlFunction] = { val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType) - val evalMethods = checkAndExtractEvalMethods(tableFunction.getClass) + val evalMethods = checkAndExtractEvalMethods(tableFunction) evalMethods.map { method => val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method) @@ -213,7 +212,7 @@ object UserDefinedFunctionUtils { signature: Array[Class[_]]) : TypeInformation[_] = { // find method for signature - val evalMethod = checkAndExtractEvalMethods(function.getClass) + val evalMethod = checkAndExtractEvalMethods(function) .find(m => signature.sameElements(m.getParameterTypes)) .getOrElse(throw new ValidationException("Given signature is invalid.")) @@ -240,7 +239,7 @@ object UserDefinedFunctionUtils { signature: Array[Class[_]]) : Class[_] = { // find method for signature - val evalMethod = checkAndExtractEvalMethods(function.getClass) + val evalMethod = checkAndExtractEvalMethods(function) .find(m => signature.sameElements(m.getParameterTypes)) .getOrElse(throw new IllegalArgumentException("Given signature is invalid.")) evalMethod.getReturnType @@ -303,7 +302,7 @@ object UserDefinedFunctionUtils { /** * Prints all eval methods signatures of a class. 
*/ - def signaturesToString(function: Class[_]): String = { + def signaturesToString(function: UserDefinedFunction): String = { getSignatures(function).map(signatureToString).mkString(", ") } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala index 50f9373e84a21..edb2e2af77ac6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala @@ -81,12 +81,12 @@ case class LogicalTableFunctionCall( checkForInstantiation(tableFunction.getClass) // look for a signature that matches the input types val signature = node.parameters.map(_.resultType) - val foundMethod = getEvalMethod(tableFunction.getClass, signature) + val foundMethod = getEvalMethod(tableFunction, signature) if (foundMethod.isEmpty) { failValidation( s"Given parameters of function '$functionName' do not match any signature. \n" + s"Actual: ${signatureToString(signature)} \n" + - s"Expected: ${signaturesToString(tableFunction.getClass)}") + s"Expected: ${signaturesToString(tableFunction)}") } else { node.evalMethod = foundMethod.get } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala index 66ccc7ed6a16d..d3696ec225409 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala @@ -28,7 +28,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.table._ import org.apache.flink.api.table.expressions._ - import org.apache.flink.api.table.typeutils.TypeConverter import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala index d6715ff9a13f8..4aa7fea650f75 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -49,8 +49,8 @@ class DataSetCorrelate( extends SingleRel(cluster, traitSet, inputNode) with FlinkCorrelate with DataSetRel { - override def deriveRowType() = relRowType + override def deriveRowType() = relRowType override def computeSelfCost(planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { val rowCnt = metadata.getRowCount(getInput) * 1.5 @@ -87,7 +87,6 @@ class DataSetCorrelate( .itemIf("condition", condition.orNull, condition.isDefined) } - override def translateToPlan( tableEnv: BatchTableEnvironment, expectedType: Option[TypeInformation[Any]]) @@ -137,5 +136,4 @@ class DataSetCorrelate( inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } - } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala index bccb2578a467f..e6cf0cfcae5d0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetCorrelateRule.scala @@ -64,8 +64,9 @@ class DataSetCorrelateRule convertToCorrelate(rel.getRelList.get(0), condition) case filter: LogicalFilter => - convertToCorrelate(filter.getInput.asInstanceOf[RelSubset].getOriginal, - Some(filter.getCondition)) + convertToCorrelate( + filter.getInput.asInstanceOf[RelSubset].getOriginal, + Some(filter.getCondition)) case scan: LogicalTableFunctionScan => new DataSetCorrelate( From 7b837aab12920f8d7d831f51ec8534fa95f959e6 Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Tue, 6 Dec 2016 11:33:20 +0800 Subject: [PATCH 4/4] remove UnresolvedTableFunctionCall --- .../flink/api/table/plan/logical/call.scala | 117 ------------------ .../api/table/plan/logical/operators.scala | 76 +++++++++++- 2 files changed, 75 insertions(+), 118 deletions(-) delete mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala deleted file mode 100644 index edb2e2af77ac6..0000000000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/call.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.api.table.plan.logical - -import java.lang.reflect.Method -import java.util - -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.logical.LogicalTableFunctionScan -import org.apache.calcite.tools.RelBuilder -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.table._ -import org.apache.flink.api.table.expressions.{Attribute, Expression, ResolvedFieldReference} -import org.apache.flink.api.table.functions.TableFunction -import org.apache.flink.api.table.functions.utils.TableSqlFunction -import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils.{getEvalMethod, signaturesToString, signatureToString, getFieldInfo, checkNotSingleton, checkForInstantiation} -import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl - -import scala.collection.JavaConverters._ - -/** - * General logical node for unresolved user-defined table function calls. 
- */ -case class UnresolvedTableFunctionCall(functionName: String, args: Seq[Expression]) - extends LogicalNode { - - override def output: Seq[Attribute] = - throw UnresolvedException("Invalid call to output on UnresolvedTableFunctionCall") - - override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = - throw UnresolvedException("Invalid call to construct on UnresolvedTableFunctionCall") - - override private[flink] def children: Seq[LogicalNode] = - throw UnresolvedException("Invalid call to children on UnresolvedTableFunctionCall") -} - -/** - * LogicalNode for calling a user-defined table functions. - * @param functionName function name - * @param tableFunction table function to be called (might be overloaded) - * @param parameters actual parameters - * @param fieldNames output field names - * @param child child logical node - */ -case class LogicalTableFunctionCall( - functionName: String, - tableFunction: TableFunction[_], - parameters: Seq[Expression], - resultType: TypeInformation[_], - fieldNames: Array[String], - child: LogicalNode) - extends UnaryNode { - - val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType) - var evalMethod: Method = _ - - override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map { - case (n, t) => ResolvedFieldReference(n, t) - } - - override def validate(tableEnv: TableEnvironment): LogicalNode = { - val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall] - // check not Scala object - checkNotSingleton(tableFunction.getClass) - // check could be instantiated - checkForInstantiation(tableFunction.getClass) - // look for a signature that matches the input types - val signature = node.parameters.map(_.resultType) - val foundMethod = getEvalMethod(tableFunction, signature) - if (foundMethod.isEmpty) { - failValidation( - s"Given parameters of function '$functionName' do not match any signature. 
\n" + - s"Actual: ${signatureToString(signature)} \n" + - s"Expected: ${signaturesToString(tableFunction)}") - } else { - node.evalMethod = foundMethod.get - } - node - } - - override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { - val fieldIndexes = getFieldInfo(resultType)._2 - val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod) - val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] - val sqlFunction = TableSqlFunction( - tableFunction.toString, - tableFunction, - resultType, - typeFactory, - function) - - val scan = LogicalTableFunctionScan.create( - relBuilder.peek().getCluster, - new util.ArrayList[RelNode](), - relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava), - function.getElementType(null), - function.getRowType(relBuilder.getTypeFactory, null), - null) - - relBuilder.push(scan) - } -} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala index d3696ec225409..4dc2ab7f511cf 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala @@ -17,10 +17,13 @@ */ package org.apache.flink.api.table.plan.logical +import java.lang.reflect.Method +import java.util + import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.CorrelationId -import org.apache.calcite.rel.logical.LogicalProject +import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableFunctionScan} import org.apache.calcite.rex.{RexInputRef, RexNode} import org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ @@ -28,6 +31,10 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.table._ import org.apache.flink.api.table.expressions._ +import org.apache.flink.api.table.functions.TableFunction +import org.apache.flink.api.table.functions.utils.TableSqlFunction +import org.apache.flink.api.table.functions.utils.UserDefinedFunctionUtils._ +import org.apache.flink.api.table.plan.schema.FlinkTableFunctionImpl import org.apache.flink.api.table.typeutils.TypeConverter import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess} @@ -617,3 +624,70 @@ case class WindowAggregate( } } + +/** + * LogicalNode for calling a user-defined table function.
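+ * When the logical plan is constructed, this call is translated into a Calcite + * [[LogicalTableFunctionScan]] (see `construct()` below). + *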
+ * @param functionName function name + * @param tableFunction table function to be called (might be overloaded) + * @param parameters actual parameters + * @param fieldNames output field names + * @param child child logical node + */ +case class LogicalTableFunctionCall( + functionName: String, + tableFunction: TableFunction[_], + parameters: Seq[Expression], + resultType: TypeInformation[_], + fieldNames: Array[String], + child: LogicalNode) + extends UnaryNode { + + val (_, fieldIndexes, fieldTypes) = getFieldInfo(resultType) + var evalMethod: Method = _ + + override def output: Seq[Attribute] = fieldNames.zip(fieldTypes).map { + case (n, t) => ResolvedFieldReference(n, t) + } + + override def validate(tableEnv: TableEnvironment): LogicalNode = { + val node = super.validate(tableEnv).asInstanceOf[LogicalTableFunctionCall] + // check not Scala object + checkNotSingleton(tableFunction.getClass) + // check could be instantiated + checkForInstantiation(tableFunction.getClass) + // look for a signature that matches the input types + val signature = node.parameters.map(_.resultType) + val foundMethod = getEvalMethod(tableFunction, signature) + if (foundMethod.isEmpty) { + failValidation( + s"Given parameters of function '$functionName' do not match any signature. \n" + + s"Actual: ${signatureToString(signature)} \n" + + s"Expected: ${signaturesToString(tableFunction)}") + } else { + node.evalMethod = foundMethod.get + } + node + } + + override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = { + val fieldIndexes = getFieldInfo(resultType)._2 + val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod) + val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] + val sqlFunction = TableSqlFunction( + tableFunction.toString, + tableFunction, + resultType, + typeFactory, + function) + + val scan = LogicalTableFunctionScan.create( + relBuilder.peek().getCluster, + new util.ArrayList[RelNode](), + relBuilder.call(sqlFunction, parameters.map(_.toRexNode(relBuilder)).asJava), + function.getElementType(null), + function.getRowType(relBuilder.getTypeFactory, null), + null) + + relBuilder.push(scan) + } +}
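
A minimal usage sketch of the API this series adds, for illustration only and not part of the patches themselves: the `Split` class and the `registerFunction`/`registerTable` method names are assumptions, while `crossApply`/`outerApply`, the string-based `as (...)` alias syntax, and `LATERAL TABLE` follow the code and tests above.

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.table._
    import org.apache.flink.api.table.Table
    import org.apache.flink.api.table.functions.TableFunction

    // A hypothetical table function emitting one row per '#'-separated token.
    // It must be a class rather than a Scala object: collect() keeps per-call
    // state, so singletons would be error-prone (hence checkNotSingleton above).
    class Split extends TableFunction[String] {
      def eval(str: String): Unit = str.split("#").foreach(collect)
    }

    def example(tableEnv: BatchTableEnvironment, table: Table): Unit = {
      val split = new Split

      // Scala Table API: split('c) is implicitly converted into a
      // TableFunctionCall expression; as('s) wraps it in an Alias that
      // applyInternal unwraps again to recover the field aliases
      table.crossApply(split('c) as ('s)).select('a, 'b, 'c, 's)

      // Java Table API: the string form is parsed by ExpressionParser, whose
      // extended alias rule accepts the "as (s)" field list
      table.outerApply("split(c) as (s)").select("a, b, c, s")

      // SQL: the table function must be registered in the catalog first
      tableEnv.registerFunction("split", split)
      tableEnv.registerTable("MyTable", table)
      tableEnv.sql("SELECT a, b, c, s FROM MyTable, LATERAL TABLE(split(c)) AS T(s)")
    }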