From 5e126a97bb26654048efa735c21e0daff136a6a2 Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Fri, 13 Jan 2017 21:53:49 +0800 Subject: [PATCH 1/3] [FLINK-5224] [table] Improve UDTF: emit rows directly instead of buffering them --- .../flink/table/codegen/CodeGenerator.scala | 93 +++++++++++++++++++ .../codegen/calls/TableFunctionCallGen.scala | 1 - .../flink/table/functions/TableFunction.scala | 22 ++--- .../table/plan/nodes/FlinkCorrelate.scala | 64 ++++++++----- .../plan/nodes/dataset/DataSetCorrelate.scala | 19 +--- .../datastream/DataStreamCorrelate.scala | 19 +--- .../runtime/CorrelateFlatMapRunner.scala | 65 +++++++++++++ .../runtime/TableFunctionCollector.scala | 81 ++++++++++++++++ 8 files changed, 296 insertions(+), 68 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 13fe4c374cf3b..0368d10b9f603 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -39,6 +39,7 @@ import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.codegen.calls.FunctionGenerator import org.apache.flink.table.codegen.calls.ScalarOperators._ import org.apache.flink.table.functions.UserDefinedFunction +import org.apache.flink.table.runtime.TableFunctionCollector import org.apache.flink.table.typeutils.TypeConverter import org.apache.flink.table.typeutils.TypeCheckUtils._ @@ -129,6 +130,10 @@ class CodeGenerator( // (inputTerm, index) -> expr private val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]() + // set of constructor statements that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableConstructorStatements = mutable.LinkedHashSet[(String, String)]() + /** * @return code block of statements that need to be placed in the member area of the Function * (e.g. member variables and their initialization) @@ -159,6 +164,19 @@ class CodeGenerator( reusableInputUnboxingExprs.values.map(_.code).mkString("", "\n", "\n") } + /** + * @return code block of constructor statements of the function + */ + def reuseConstructorCode(funcName: String): String = { + reusableConstructorStatements.map { case (params, body) => + j""" + public $funcName($params) throws Exception { + $body + } + """.stripMargin + }.mkString("", "\n", "\n") + } + /** * @return term of the (casted and possibly boxed) first input */ @@ -257,6 +275,8 @@ class CodeGenerator( ${reuseInitCode()} } + ${reuseConstructorCode(funcName)} + @Override public ${samHeader._1} throws Exception { ${samHeader._2.mkString("\n")} @@ -325,6 +345,52 @@ class CodeGenerator( GeneratedFunction[GenericInputFormat[T]](funcName, returnType, funcCode) } + /** + * Generates a [[TableFunctionCollector]] that can be passed to Java compiler. + * + * @param name Class name of the table function collector. Must not be unique but has to be a + * valid Java class identifier. + * @param bodyCode body code for the collector method + * @param returnType The type information of the element collected by the collector + * @return instance of GeneratedFunction + */ + def generateTableFunctionCollector( + name: String, + bodyCode: String, + returnType: TypeInformation[Any]) + : GeneratedFunction[TableFunctionCollector[Any]] = { + + val funcName = newName(name) + val input1TypeClass = boxedTypeTermForTypeInfo(input1) + val input2TypeClass = boxedTypeTermForTypeInfo(returnType) + + val funcCode = j""" + public class $funcName extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { + + ${reuseMemberCode()} + + public $funcName() throws Exception { + ${reuseInitCode()} + } + + @Override + public void collect(Object record) throws Exception { + super.collect(record); + $input1TypeClass $input1Term = ($input1TypeClass) getInput(); + $input2TypeClass $input2Term = ($input2TypeClass) record; + ${reuseInputUnboxingCode()} + $bodyCode + } + + @Override + public void close() { + } + } + """.stripMargin + + GeneratedFunction[TableFunctionCollector[Any]](funcName, returnType, funcCode) + } + /** * Generates an expression that converts the first input (and second input) into the given type. * If two inputs are converted, the second input is appended. If objects or variables can @@ -1415,6 +1481,33 @@ class CodeGenerator( fieldTerm } + + /** + * Adds a reusable constructor statement with the given parameter types. + * + * @param parameterTypes The parameter types to construct the function + * @return member variable terms + */ + def addReusableConstructor(parameterTypes: Class[_]*): Array[String] = { + val parameters = mutable.ListBuffer[String]() + val fieldTerms = mutable.ListBuffer[String]() + var body = "this();\n" + + parameterTypes.zipWithIndex.foreach { case (t, index) => + val classQualifier = t.getCanonicalName + val fieldTerm = newName(s"instance_${classQualifier.replace('.', '$')}") + val field = s"transient $classQualifier $fieldTerm = null;" + reusableMemberStatements.add(field) + fieldTerms += fieldTerm + parameters += s"$classQualifier arg$index" + body += s"$fieldTerm = arg$index;\n" + } + + reusableConstructorStatements.add((parameters.mkString(","), body)) + + fieldTerms.toArray + } + /** * Adds a reusable array to the member area of the generated [[Function]]. */ diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala index 50c569f0009c6..6e44f559f792b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala @@ -69,7 +69,6 @@ class TableFunctionCallGen( val functionCallCode = s""" |${parameters.map(_.code).mkString("\n")} - |$functionReference.clear(); |$functionReference.eval(${parameters.map(_.resultTerm).mkString(", ")}); |""".stripMargin diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala index 653793e26c27c..dd611084faadd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala @@ -18,10 +18,9 @@ package org.apache.flink.table.functions -import java.util - import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.expressions.{Expression, TableFunctionCall} +import org.apache.flink.util.Collector /** * Base class for a user-defined table function (UDTF). A user-defined table functions works on @@ -99,27 +98,26 @@ abstract class TableFunction[T] extends UserDefinedFunction { // ---------------------------------------------------------------------------------------------- - private val rows: util.ArrayList[T] = new util.ArrayList[T]() - /** * Emit an output row. * * @param row the output row */ protected def collect(row: T): Unit = { - // cache rows for now, maybe immediately process them further - rows.add(row) + collector.collect(row) } - /** - * Internal use. Get an iterator of the buffered rows. - */ - def getRowsIterator = rows.iterator() + // ---------------------------------------------------------------------------------------------- + + /** The code generated collector used to emit row. */ + private var collector: Collector[T] = _ /** - * Internal use. Clear buffered rows. + * Internal use. Sets the current collector. */ - def clear() = rows.clear() + def setCollector(collector: Collector[T]): Unit = { + this.collector = collector + } // ---------------------------------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala index fc6949312a30a..a0f18417e5485 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala @@ -26,7 +26,7 @@ import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression, Gener import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE} import org.apache.flink.table.functions.utils.TableSqlFunction -import org.apache.flink.table.runtime.FlatMapRunner +import org.apache.flink.table.runtime.{CorrelateFlatMapRunner, TableFunctionCollector} import org.apache.flink.table.typeutils.TypeConverter._ import org.apache.flink.table.api.{TableConfig, TableException} @@ -37,7 +37,7 @@ import scala.collection.JavaConverters._ */ trait FlinkCorrelate { - private[flink] def functionBody( + private[flink] def generateFunction( generator: CodeGenerator, udtfTypeInfo: TypeInformation[Any], rowType: RelDataType, @@ -45,7 +45,10 @@ trait FlinkCorrelate { condition: Option[RexNode], config: TableConfig, joinType: SemiJoinType, - expectedType: Option[TypeInformation[Any]]): String = { + expectedType: Option[TypeInformation[Any]], + ruleDescription: String) + : (GeneratedFunction[FlatMapFunction[Any, Any]], + GeneratedFunction[TableFunctionCollector[Any]]) = { val returnType = determineReturnType( rowType, @@ -55,18 +58,21 @@ trait FlinkCorrelate { val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs + val collectorTerm = generator.addReusableConstructor(classOf[TableFunctionCollector[_]]).head + val call = generator.generateExpression(rexCall) var body = s""" + |${call.resultTerm}.setCollector($collectorTerm); |${call.code} - |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator(); + |boolean hasOutput = $collectorTerm.isCollected(); """.stripMargin if (joinType == SemiJoinType.INNER) { // cross join body += s""" - |if (!iter.hasNext()) { + |if (!hasOutput) { | return; |} """.stripMargin @@ -86,7 +92,7 @@ trait FlinkCorrelate { input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala) body += s""" - |if (!iter.hasNext()) { + |if (!hasOutput) { | ${outerResultExpr.code} | ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm}); | return; @@ -96,15 +102,23 @@ trait FlinkCorrelate { throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.") } + val flatMapFunction = generator.generateFunction( + ruleDescription, + classOf[FlatMapFunction[Any, Any]], + body, + returnType) + + // -------------------------- generate table function collector ----------------------- + val crossResultExpr = generator.generateResultExpression( input1AccessExprs ++ input2AccessExprs, returnType, rowType.getFieldNames.asScala) - val projection = if (condition.isEmpty) { + val collectorCode = if (condition.isEmpty) { s""" |${crossResultExpr.code} - |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); + |getCollector().collect(${crossResultExpr.resultTerm}); """.stripMargin } else { val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo) @@ -115,30 +129,30 @@ trait FlinkCorrelate { |${filterCondition.code} |if (${filterCondition.resultTerm}) { | ${crossResultExpr.code} - | ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); + | getCollector().collect(${crossResultExpr.resultTerm}); |} |""".stripMargin } - val outputTypeClass = udtfTypeInfo.getTypeClass.getCanonicalName - body += - s""" - |while (iter.hasNext()) { - | $outputTypeClass ${generator.input2Term} = ($outputTypeClass) iter.next(); - | $projection - |} - """.stripMargin - body + val collectorFunction = generator.generateTableFunctionCollector( + "TableFunctionCollector", + collectorCode, + udtfTypeInfo) + + (flatMapFunction, collectorFunction) } private[flink] def correlateMapFunction( - genFunction: GeneratedFunction[FlatMapFunction[Any, Any]]) - : FlatMapRunner[Any, Any] = { - - new FlatMapRunner[Any, Any]( - genFunction.name, - genFunction.code, - genFunction.returnType) + flatMap: GeneratedFunction[FlatMapFunction[Any, Any]], + collector: GeneratedFunction[TableFunctionCollector[Any]]) + : CorrelateFlatMapRunner[Any, Any] = { + + new CorrelateFlatMapRunner[Any, Any]( + flatMap.name, + flatMap.code, + collector.name, + collector.code, + flatMap.returnType) } private[flink] def selectToString(rowType: RelDataType): String = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala index fa1afc3d500b0..c26c4affd134d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -24,7 +24,6 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType -import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment @@ -93,11 +92,6 @@ class DataSetCorrelate( : DataSet[Any] = { val config = tableEnv.getConfig - val returnType = determineReturnType( - getRowType, - expectedType, - config.getNullCheck, - config.getEfficientTypeUsage) // we do not need to specify input type val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv) @@ -116,7 +110,7 @@ class DataSetCorrelate( None, Some(pojoFieldMapping)) - val body = functionBody( + val (flatMap, collector) = generateFunction( generator, udtfTypeInfo, getRowType, @@ -124,15 +118,10 @@ class DataSetCorrelate( condition, config, joinType, - expectedType) - - val genFunction = generator.generateFunction( - ruleDescription, - classOf[FlatMapFunction[Any, Any]], - body, - returnType) + expectedType, + ruleDescription) - val mapFunc = correlateMapFunction(genFunction) + val mapFunc = correlateMapFunction(flatMap, collector) inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index a2d167bb63251..8f7fef0c72609 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -23,7 +23,6 @@ import org.apache.calcite.rel.logical.LogicalTableFunctionScan import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType -import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.functions.utils.TableSqlFunction @@ -87,11 +86,6 @@ class DataStreamCorrelate( : DataStream[Any] = { val config = tableEnv.getConfig - val returnType = determineReturnType( - getRowType, - expectedType, - config.getNullCheck, - config.getEfficientTypeUsage) // we do not need to specify input type val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) @@ -110,7 +104,7 @@ class DataStreamCorrelate( None, Some(pojoFieldMapping)) - val body = functionBody( + val (flatMap, collector) = generateFunction( generator, udtfTypeInfo, getRowType, @@ -118,15 +112,10 @@ class DataStreamCorrelate( condition, config, joinType, - expectedType) - - val genFunction = generator.generateFunction( - ruleDescription, - classOf[FlatMapFunction[Any, Any]], - body, - returnType) + expectedType, + ruleDescription) - val mapFunc = correlateMapFunction(genFunction) + val mapFunc = correlateMapFunction(flatMap, collector) inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala new file mode 100644 index 0000000000000..4e803da0cae14 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime + +import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ResultTypeQueryable +import org.apache.flink.configuration.Configuration +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.util.Collector +import org.slf4j.{Logger, LoggerFactory} + +class CorrelateFlatMapRunner[IN, OUT]( + flatMapName: String, + flatMapCode: String, + collectorName: String, + collectorCode: String, + @transient returnType: TypeInformation[OUT]) + extends RichFlatMapFunction[IN, OUT] + with ResultTypeQueryable[OUT] + with Compiler[Any] { + + val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + private var function: FlatMapFunction[IN, OUT] = _ + private var collector: TableFunctionCollector[_] = _ + + override def open(parameters: Configuration): Unit = { + LOG.debug(s"Compiling TableFunctionCollector: $collectorName \n\n Code:\n$collectorCode") + val clazz = compile(getRuntimeContext.getUserCodeClassLoader, collectorName, collectorCode) + LOG.debug("Instantiating TableFunctionCollector.") + collector = clazz.newInstance().asInstanceOf[TableFunctionCollector[_]] + + LOG.debug(s"Compiling FlatMapFunction: $flatMapName \n\n Code:\n$flatMapCode") + val flatMapClazz = compile(getRuntimeContext.getUserCodeClassLoader, flatMapName, flatMapCode) + val constructor = flatMapClazz.getConstructor(classOf[TableFunctionCollector[_]]) + LOG.debug("Instantiating FlatMapFunction.") + function = constructor.newInstance(collector).asInstanceOf[FlatMapFunction[IN, OUT]] + } + + override def flatMap(in: IN, out: Collector[OUT]): Unit = { + collector.setCollector(out) + collector.setInput(in) + collector.reset() + function.flatMap(in, out) + } + + override def getProducedType: TypeInformation[OUT] = returnType +} diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala new file mode 100644 index 0000000000000..07516c26b75d2 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime + +import org.apache.flink.util.Collector + +/** + * The basic implementation of collector for [[org.apache.calcite.schema.TableFunction]]. + */ +abstract class TableFunctionCollector[T] extends Collector[T] { + + var input: Any = _ + var collector: Collector[_] = _ + var collected: Boolean = _ + + /** + * Gets the input row from left table, + * which will be used to cross join with the result of table function. + */ + def setInput(input: Any): Unit = { + this.input = input + } + + /** + * Gets the input value from left table, + * which will be used to cross join with the result of table function. + */ + def getInput: Any = { + input + } + + /** + * Sets the current collector, which used to emit the final row. + */ + def setCollector(collector: Collector[_]): Unit = { + this.collector = collector + } + + /** + * Gets the internal collector which used to emit the final row. + */ + def getCollector: Collector[_] = { + this.collector + } + + /** + * Resets the flag to indicate whether [[collect(T)]] has been called. + */ + def reset(): Unit = { + collected = false + } + + /** + * Whether [[collect(T)]] has been called. + * @return True if [[collect(T)]] has been called. + */ + def isCollected: Boolean = collected + + + override def collect(record: T): Unit = { + collected = true + } + +} + + From 3b822703a847831fbd919d033ef9d4025381becb Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Tue, 24 Jan 2017 11:08:55 +0800 Subject: [PATCH 2/3] addressed review comment --- .../flink/table/codegen/CodeGenerator.scala | 26 ++-- .../flink/table/codegen/generated.scala | 7 + .../flink/table/functions/TableFunction.scala | 2 +- .../table/plan/nodes/FlinkCorrelate.scala | 120 +++++++++++++----- .../plan/nodes/dataset/DataSetCorrelate.scala | 16 +-- .../datastream/DataStreamCorrelate.scala | 16 +-- .../runtime/TableFunctionCollector.scala | 10 +- 7 files changed, 122 insertions(+), 75 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 0368d10b9f603..948e2e253c81d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -167,12 +167,12 @@ class CodeGenerator( /** * @return code block of constructor statements of the function */ - def reuseConstructorCode(funcName: String): String = { + def reuseConstructorCode(className: String): String = { reusableConstructorStatements.map { case (params, body) => - j""" - public $funcName($params) throws Exception { - $body - } + s""" + |public $className($params) throws Exception { + | $body + |} """.stripMargin }.mkString("", "\n", "\n") } @@ -355,21 +355,21 @@ class CodeGenerator( * @return instance of GeneratedFunction */ def generateTableFunctionCollector( - name: String, - bodyCode: String, - returnType: TypeInformation[Any]) - : GeneratedFunction[TableFunctionCollector[Any]] = { + name: String, + bodyCode: String, + returnType: TypeInformation[Any]) + : GeneratedCollector = { - val funcName = newName(name) + val className = newName(name) val input1TypeClass = boxedTypeTermForTypeInfo(input1) val input2TypeClass = boxedTypeTermForTypeInfo(returnType) val funcCode = j""" - public class $funcName extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { + public class $className extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { ${reuseMemberCode()} - public $funcName() throws Exception { + public $className() throws Exception { ${reuseInitCode()} } @@ -388,7 +388,7 @@ class CodeGenerator( } """.stripMargin - GeneratedFunction[TableFunctionCollector[Any]](funcName, returnType, funcCode) + GeneratedCollector(className, funcCode) } /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala index 0d60dc186e29c..73b3cd9686daa 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala @@ -41,3 +41,10 @@ object GeneratedExpression { } case class GeneratedFunction[T](name: String, returnType: TypeInformation[Any], code: String) + +/** + * Describes a generated [[org.apache.flink.util.Collector]]. + * @param name The class name of the generated Collector. + * @param code The code of the generated Collector. + */ +case class GeneratedCollector(name: String, code: String) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala index dd611084faadd..dbfa51836b1f1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala @@ -115,7 +115,7 @@ abstract class TableFunction[T] extends UserDefinedFunction { /** * Internal use. Sets the current collector. */ - def setCollector(collector: Collector[T]): Unit = { + private[flink] final def setCollector(collector: Collector[T]): Unit = { this.collector = collector } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala index a0f18417e5485..8f78f7a3d14ac 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala @@ -22,7 +22,8 @@ import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction} +import org.apache.flink.table.codegen.{CodeGenerator, GeneratedCollector, GeneratedExpression, +GeneratedFunction} import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE} import org.apache.flink.table.functions.utils.TableSqlFunction @@ -37,18 +38,20 @@ import scala.collection.JavaConverters._ */ trait FlinkCorrelate { - private[flink] def generateFunction( - generator: CodeGenerator, + /** Creates the [[CorrelateFlatMapRunner]] to execute the join of input table + * and user-defined table function */ + private[flink] def correlateMapFunction( + config: TableConfig, + inputTypeInfo: TypeInformation[Any], udtfTypeInfo: TypeInformation[Any], rowType: RelDataType, + joinType: SemiJoinType, rexCall: RexCall, condition: Option[RexNode], - config: TableConfig, - joinType: SemiJoinType, expectedType: Option[TypeInformation[Any]], + pojoFieldMapping: Option[Array[Int]], // udtf return type pojo field mapping ruleDescription: String) - : (GeneratedFunction[FlatMapFunction[Any, Any]], - GeneratedFunction[TableFunctionCollector[Any]]) = { + : CorrelateFlatMapRunner[Any, Any] = { val returnType = determineReturnType( rowType, @@ -56,11 +59,63 @@ trait FlinkCorrelate { config.getNullCheck, config.getEfficientTypeUsage) - val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs + val flatMap = generateFunction( + config, + inputTypeInfo, + udtfTypeInfo, + returnType, + rowType, + joinType, + rexCall, + pojoFieldMapping, + ruleDescription) + + val collector = generateCollector( + config, + inputTypeInfo, + udtfTypeInfo, + returnType, + rowType, + condition, + pojoFieldMapping) - val collectorTerm = generator.addReusableConstructor(classOf[TableFunctionCollector[_]]).head + new CorrelateFlatMapRunner[Any, Any]( + flatMap.name, + flatMap.code, + collector.name, + collector.code, + flatMap.returnType) + + } - val call = generator.generateExpression(rexCall) + /** Generates the flat map function to run user defined table function */ + private def generateFunction( + config: TableConfig, + inputTypeInfo: TypeInformation[Any], + udtfTypeInfo: TypeInformation[Any], + returnType: TypeInformation[Any], + rowType: RelDataType, + joinType: SemiJoinType, + rexCall: RexCall, + pojoFieldMapping: Option[Array[Int]], + ruleDescription: String) + : GeneratedFunction[FlatMapFunction[Any, Any]] = { + + val functionGenerator = new CodeGenerator( + config, + false, + inputTypeInfo, + Some(udtfTypeInfo), + None, + pojoFieldMapping) + + val (input1AccessExprs, input2AccessExprs) = functionGenerator.generateCorrelateAccessExprs + + val collectorTerm = functionGenerator + .addReusableConstructor(classOf[TableFunctionCollector[_]]) + .head + + val call = functionGenerator.generateExpression(rexCall) var body = s""" |${call.resultTerm}.setCollector($collectorTerm); @@ -88,13 +143,13 @@ trait FlinkCorrelate { NO_CODE, x.resultType) } - val outerResultExpr = generator.generateResultExpression( + val outerResultExpr = functionGenerator.generateResultExpression( input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala) body += s""" |if (!hasOutput) { | ${outerResultExpr.code} - | ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm}); + | ${functionGenerator.collectorTerm}.collect(${outerResultExpr.resultTerm}); | return; |} """.stripMargin @@ -102,13 +157,33 @@ trait FlinkCorrelate { throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.") } - val flatMapFunction = generator.generateFunction( + functionGenerator.generateFunction( ruleDescription, classOf[FlatMapFunction[Any, Any]], body, returnType) + } + + /** Generates table function collector */ + private[flink] def generateCollector( + config: TableConfig, + inputTypeInfo: TypeInformation[Any], + udtfTypeInfo: TypeInformation[Any], + returnType: TypeInformation[Any], + rowType: RelDataType, + condition: Option[RexNode], + pojoFieldMapping: Option[Array[Int]]) + : GeneratedCollector = { + + val generator = new CodeGenerator( + config, + false, + inputTypeInfo, + Some(udtfTypeInfo), + None, + pojoFieldMapping) - // -------------------------- generate table function collector ----------------------- + val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs val crossResultExpr = generator.generateResultExpression( input1AccessExprs ++ input2AccessExprs, @@ -134,25 +209,10 @@ trait FlinkCorrelate { |""".stripMargin } - val collectorFunction = generator.generateTableFunctionCollector( + generator.generateTableFunctionCollector( "TableFunctionCollector", collectorCode, udtfTypeInfo) - - (flatMapFunction, collectorFunction) - } - - private[flink] def correlateMapFunction( - flatMap: GeneratedFunction[FlatMapFunction[Any, Any]], - collector: GeneratedFunction[TableFunctionCollector[Any]]) - : CorrelateFlatMapRunner[Any, Any] = { - - new CorrelateFlatMapRunner[Any, Any]( - flatMap.name, - flatMap.code, - collector.name, - collector.code, - flatMap.returnType) } private[flink] def selectToString(rowType: RelDataType): String = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala index c26c4affd134d..5a75e5ded59b0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -27,7 +27,6 @@ import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment -import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.FlinkCorrelate import org.apache.flink.table.typeutils.TypeConverter._ @@ -102,27 +101,18 @@ class DataSetCorrelate( val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val generator = new CodeGenerator( + val mapFunc = correlateMapFunction( config, - false, inputDS.getType, - Some(udtfTypeInfo), - None, - Some(pojoFieldMapping)) - - val (flatMap, collector) = generateFunction( - generator, udtfTypeInfo, getRowType, + joinType, rexCall, condition, - config, - joinType, expectedType, + Some(pojoFieldMapping), ruleDescription) - val mapFunc = correlateMapFunction(flatMap, collector) - inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index 8f7fef0c72609..bd65954210236 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -24,7 +24,6 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.FlinkCorrelate import org.apache.flink.table.typeutils.TypeConverter._ @@ -96,27 +95,18 @@ class DataStreamCorrelate( val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val generator = new CodeGenerator( + val mapFunc = correlateMapFunction( config, - false, inputDS.getType, - Some(udtfTypeInfo), - None, - Some(pojoFieldMapping)) - - val (flatMap, collector) = generateFunction( - generator, udtfTypeInfo, getRowType, + joinType, rexCall, condition, - config, - joinType, expectedType, + Some(pojoFieldMapping), ruleDescription) - val mapFunc = correlateMapFunction(flatMap, collector) - inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala index 07516c26b75d2..71cb3f1be72b1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala @@ -20,16 +20,16 @@ package org.apache.flink.table.runtime import org.apache.flink.util.Collector /** - * The basic implementation of collector for [[org.apache.calcite.schema.TableFunction]]. + * The basic implementation of collector for [[org.apache.flink.table.functions.TableFunction]]. */ abstract class TableFunctionCollector[T] extends Collector[T] { - var input: Any = _ - var collector: Collector[_] = _ - var collected: Boolean = _ + private var input: Any = _ + private var collector: Collector[_] = _ + private var collected: Boolean = _ /** - * Gets the input row from left table, + * Sets the input row from left table, * which will be used to cross join with the result of table function. */ def setInput(input: Any): Unit = { From 6c23bb7e8ea721414017115b7ac2a6828f2df2b8 Mon Sep 17 00:00:00 2001 From: Jark Wu Date: Tue, 24 Jan 2017 11:20:22 +0800 Subject: [PATCH 3/3] minor fix --- .../org/apache/flink/table/codegen/CodeGenerator.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 948e2e253c81d..74376f30c9a38 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -171,6 +171,7 @@ class CodeGenerator( reusableConstructorStatements.map { case (params, body) => s""" |public $className($params) throws Exception { + | this(); | $body |} """.stripMargin @@ -1491,7 +1492,7 @@ class CodeGenerator( def addReusableConstructor(parameterTypes: Class[_]*): Array[String] = { val parameters = mutable.ListBuffer[String]() val fieldTerms = mutable.ListBuffer[String]() - var body = "this();\n" + val body = mutable.ListBuffer[String]() parameterTypes.zipWithIndex.foreach { case (t, index) => val classQualifier = t.getCanonicalName @@ -1500,10 +1501,11 @@ class CodeGenerator( reusableMemberStatements.add(field) fieldTerms += fieldTerm parameters += s"$classQualifier arg$index" - body += s"$fieldTerm = arg$index;\n" + body += s"$fieldTerm = arg$index;" } - reusableConstructorStatements.add((parameters.mkString(","), body)) + reusableConstructorStatements.add( + (parameters.mkString(","), body.mkString("", "\n", "\n"))) fieldTerms.toArray }