From 9b8f5a54c54a0bdc05b11df573fbc93a2caf6540 Mon Sep 17 00:00:00 2001 From: beyond1920 Date: Fri, 2 Dec 2016 11:33:12 +0800 Subject: [PATCH 1/3] push project down into BatchTableSourceScan --- .../table/plan/nodes/dataset/BatchScan.scala | 3 +- .../api/table/plan/rules/FlinkRuleSets.scala | 4 +- ...hProjectIntoBatchTableSourceScanRule.scala | 95 +++++ .../rules/util/DataSetCalcConverter.scala | 111 ++++++ .../sources/ProjectableTableSource.scala | 38 ++ ...rojectIntoBatchTableSourceScanITCase.scala | 370 ++++++++++++++++++ 6 files changed, 619 insertions(+), 2 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala index a6de2378548ff..1ebc443daf239 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala @@ -45,7 +45,8 @@ abstract class BatchScan( override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { val rowCnt = metadata.getRowCount(this) - planner.getCostFactory.makeCost(rowCnt, rowCnt, 0) + val columnCnt = getRowType.getFieldCount + planner.getCostFactory.makeCost(rowCnt * columnCnt, rowCnt, 0) } protected def convertToExpectedType( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala index 684742567811e..183065c974cf4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala @@ -109,7 +109,9 @@ object FlinkRuleSets { DataSetSortRule.INSTANCE, DataSetValuesRule.INSTANCE, DataSetCorrelateRule.INSTANCE, - BatchTableSourceScanRule.INSTANCE + BatchTableSourceScanRule.INSTANCE, + // project pushdown optimization + PushProjectIntoBatchTableSourceScanRule.INSTANCE ) /** diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala new file mode 100644 index 0000000000000..9f1636e48b487 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.plan.RelOptRule.{none, operand}
+import org.apache.calcite.rex.{RexProgram, RexUtil}
+import org.apache.flink.api.table.plan.nodes.dataset.{BatchTableSourceScan, DataSetCalc}
+import org.apache.flink.api.table.plan.rules.util.DataSetCalcConverter._
+import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}
+import scala.collection.JavaConverters._
+
+/**
+ * This rule pushes a projection into a BatchTableSourceScan node.
+ */
+class PushProjectIntoBatchTableSourceScanRule extends RelOptRule(
+  operand(classOf[DataSetCalc],
+    operand(classOf[BatchTableSourceScan], none)),
+  "PushProjectIntoBatchTableSourceScanRule") {
+
+  override def matches(call: RelOptRuleCall) = {
+    val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]
+    scan.tableSource match {
+      case _: ProjectableTableSource[_] => true
+      case _ => false
+    }
+  }
+
+  override def onMatch(call: RelOptRuleCall) {
+    val calc: DataSetCalc = call.rel(0).asInstanceOf[DataSetCalc]
+    val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]
+
+    val usedFields: Array[Int] = extractRefInputFields(calc)
+
+    // if no fields can be projected, there is no need to transform subtree
+    if (scan.tableSource.getNumberOfFields == usedFields.length) {
+      return
+    }
+
+    val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
+
+    val newTableSource = originTableSource.projectFields(usedFields)
+
+    val newScan = new BatchTableSourceScan(
+      scan.getCluster,
+      scan.getTraitSet,
+      scan.getTable,
+      newTableSource.asInstanceOf[BatchTableSource[_]])
+
+    val (newProjectExprs, newConditionExpr) = rewriteCalcExprs(calc, usedFields)
+
+    // if the projection only forwards its input and no filter exists, remove the DataSetCalc node
+    val newProjectExprsList = newProjectExprs.asJava
+    if (RexUtil.isIdentity(newProjectExprsList, newScan.getRowType)
+      && !newConditionExpr.isDefined) {
+      call.transformTo(newScan)
+    } else {
+      val newCalcProgram = RexProgram.create(
+        newScan.getRowType,
+        newProjectExprsList,
+        newConditionExpr.getOrElse(null),
+        calc.calcProgram.getOutputRowType,
+        calc.getCluster.getRexBuilder)
+
+      val newCalc = new DataSetCalc(calc.getCluster,
+        calc.getTraitSet,
+        newScan,
+        calc.getRowType,
+        newCalcProgram,
+        description)
+
+      call.transformTo(newCalc)
+    }
+  }
+}
+
+object PushProjectIntoBatchTableSourceScanRule {
+  val INSTANCE: RelOptRule = new PushProjectIntoBatchTableSourceScanRule
+}
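For intuition, a minimal usage sketch of the rule's effect (it assumes a table backed by the
TestProjectableTableSource defined in the ITCase below; the plan rendering is illustrative,
not exact optimizer output):

    val env = ExecutionEnvironment.getExecutionEnvironment
    val tableEnv = TableEnvironment.getTableEnvironment(env)
    tableEnv.registerTableSource("MyTable", new TestProjectableTableSource)
    // the calc only references id, name and amount, so after the rule fires the
    // scan reads three of the four columns:
    //   DataSetCalc(select: id, name, where: amount < 4)
    //     +- BatchTableSourceScan(table: MyTable, fields: name, id, amount)
    val result = tableEnv.scan("MyTable").where("amount < 4").select("id, name")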
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala
new file mode 100644
index 0000000000000..d01bbd969ce74
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.util
+
+import org.apache.calcite.rex.{RexCall, RexInputRef, RexLocalRef, RexNode, RexShuttle, RexSlot, RexVisitorImpl}
+import org.apache.flink.api.table.plan.nodes.dataset.DataSetCalc
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+
+object DataSetCalcConverter {
+
+  /**
+   * Extracts the indices of the input fields that a DataSetCalc RelNode references.
+   *
+   * @param calc the DataSetCalc to analyze
+   * @return the indices of the used input fields
+   */
+  def extractRefInputFields(calc: DataSetCalc): Array[Int] = {
+    val visitor = new RefFieldsVisitor
+    val calcProgram = calc.calcProgram
+    // extract input fields from project expressions
+    calcProgram.getProjectList.foreach(exp => calcProgram.expandLocalRef(exp).accept(visitor))
+    val condition = calcProgram.getCondition
+    // extract input fields from condition expression
+    if (condition != null) {
+      calcProgram.expandLocalRef(condition).accept(visitor)
+    }
+    visitor.getFields
+  }
+
+  /**
+   * Rewrites the DataSetCalc's project and condition expressions based on the new input fields.
+   *
+   * @param calc the DataSetCalc to rewrite
+   * @param usedInputFields the indices of the used input fields
+   * @return a tuple of two elements: the rewritten project expressions and the rewritten
+   *         condition expression. Note: if the original condition expression is null,
+   *         the second element is None.
+   */
+  def rewriteCalcExprs(
+      calc: DataSetCalc,
+      usedInputFields: Array[Int]): (List[RexNode], Option[RexNode]) = {
+    val inputRewriter = new InputRewriter(usedInputFields)
+    val calcProgram = calc.calcProgram
+    val newProjectExpressions = calcProgram.getProjectList.map(
+      exp => calcProgram.expandLocalRef(exp).accept(inputRewriter)
+    ).toList
+
+    val oldCondition = calcProgram.getCondition
+    val newConditionExpression = {
+      oldCondition match {
+        case ref: RexLocalRef => Some(calcProgram.expandLocalRef(ref).accept(inputRewriter))
+        case _ => None // null does not match any type
+      }
+    }
+    (newProjectExpressions, newConditionExpression)
+  }
+}
+
+/**
+ * A RexVisitor to extract used input fields
+ */
+class RefFieldsVisitor extends RexVisitorImpl[Unit](true) {
+  private var fields = mutable.LinkedHashSet[Int]()
+
+  def getFields: Array[Int] = fields.toArray
+
+  override def visitInputRef(inputRef: RexInputRef): Unit = fields += inputRef.getIndex
+
+  override def visitCall(call: RexCall): Unit =
+    call.operands.foreach(operand => operand.accept(this))
+}
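+
+// For example, with usedFields = Array(2, 1, 3) over an input row of
+// (name, id, amount, price), the InputRewriter below builds the mapping
+// Map(2 -> 0, 1 -> 1, 3 -> 2), so an expression *($3, $2), i.e. price * amount,
+// over the original row type is rewritten to *($2, $0) over the projected
+// row type.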
+
+/**
+ * A RexShuttle that rewrites field accesses of an expression according to a fields mapping.
+ *
+ * @param fields the old field indices that remain after projection, in their new order
+ */
+class InputRewriter(fields: Array[Int]) extends RexShuttle {
+
+  /** old input fields ref index -> new input fields ref index mappings */
+  private val fieldMap: Map[Int, Int] =
+    fields.zipWithIndex.toMap
+
+  override def visitInputRef(inputRef: RexInputRef): RexNode =
+    new RexInputRef(relNodeIndex(inputRef), inputRef.getType)
+
+  override def visitLocalRef(localRef: RexLocalRef): RexNode =
+    new RexInputRef(relNodeIndex(localRef), localRef.getType)
+
+  private def relNodeIndex(ref: RexSlot): Int =
+    fieldMap.getOrElse(ref.getIndex,
+      throw new IllegalArgumentException("input field contains invalid index"))
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
new file mode 100644
index 0000000000000..35a95c7ebc587
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.sources
+
+/**
+ * Defines a TableSource that supports projection push-down, e.g. a TestBatchTableSource
+ * that supports projection is declared as
+ * class TestBatchTableSource extends BatchTableSource[Row] with ProjectableTableSource[Row]
+ *
+ * @tparam T The return type of the [[ProjectableTableSource]].
+ */
+trait ProjectableTableSource[T] {
+
+  /**
+   * Creates a copy of this ProjectableTableSource that contains only the given fields.
+   *
+   * @param fields the indices of the fields to project
+   * @return a copy of this ProjectableTableSource restricted to the projected fields
+   */
+  def projectFields(fields: Array[Int]): ProjectableTableSource[T]
+
+}
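A minimal sketch of the projectFields contract (the SampleSource class and its constructor
are illustrative, not part of this patch): keep exactly the requested field indices, in the
requested order, for both types and names.

    // assumes: import org.apache.flink.api.common.typeinfo.TypeInformation
    //          import org.apache.flink.api.table.Row
    class SampleSource(
        types: Array[TypeInformation[_]],
        names: Array[String])
      extends ProjectableTableSource[Row] {

      // e.g. with names = (name, id, amount, price), projectFields(Array(2, 1))
      // yields a source with fields (amount, id)
      override def projectFields(fields: Array[Int]): SampleSource =
        new SampleSource(fields.map(types(_)), fields.map(names(_)))
    }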
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala
new file mode 100644
index 0000000000000..01a1e5744b0af
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import collection.JavaConversions._
+import org.apache.calcite.rel.{RelNode, RelVisitor}
+import org.apache.calcite.rel.core.TableScan
+import org.apache.flink.api.common.io.GenericInputFormat
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.table.BatchTableEnvironment
+import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.api.table.{Row, TableEnvironment}
+import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.junit.{Assert, Before, Ignore, Test}
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.mutable
+
+/**
+ * Test push project down to batchTableSourceScan optimization
+ *
+ * @param mode
+ * @param configMode
+ */
+@RunWith(classOf[Parameterized])
+class PushProjectIntoBatchTableSourceScanITCase(mode: TestExecutionMode,
+    configMode: TableConfigMode)
+  extends TableProgramsTestBase(mode, configMode) {
+
+  private val tableName = "MyTable"
+  private var tableEnv: BatchTableEnvironment = null
+
+  @Before
+  def initTableEnv(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    tableEnv = TableEnvironment.getTableEnvironment(env, config)
+    tableEnv.registerTableSource(tableName, new TestProjectableTableSource)
+  }
+
+  @Test
+  def testProjectOnFilterTableAPI(): Unit = {
+    val table = tableEnv.scan(tableName).where("amount < 4").select("id, name")
+    val expectedSelectedFields = Array[String]("name", "id", "amount")
+    assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields)
+  }
+
+  @Test
+  def testProjectOnFilterSql(): Unit = {
+    val table = tableEnv.sql(s"select id, name from $tableName where amount < 4 ")
+    val expectedSelectedFields = Array[String]("name", "id", "amount")
+    assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields)
+  }
+
+  @Ignore
+  def testProjectWithWindowTableAPI(): Unit = {
+    val table = tableEnv.scan(tableName).select("id, amount.avg over (partition by name)")
+    val expectedSelectedFields = Array[String]("name", "id", "amount")
+    assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields)
+  }
+
+  @Test
+  def testProjectWithWindowSql(): Unit = {
+    val table = tableEnv.sql(s"select id, avg(amount) over (partition by name) from $tableName")
+    val expectedSelectedFields = Array[String]("name", "id", "amount")
+    assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields)
+  }
+
+  @Test
+  def testMultipleProjectTableAPI(): Unit = {
+    val table = tableEnv.scan(tableName)
+      .where("amount < 4")
+      .select("amount, id, name")
+      .select("name")
+    val expectedSelectedFields = Array[String]("amount", "name")
+
assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testMultipleProjectSql(): Unit = { + val table = tableEnv.sql( + s"select name from (select amount, id, name from $tableName where amount < 4 ) t1") + val expectedSelectedFields = Array[String]("name", "amount") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testExpressionTableAPI(): Unit = { + val table = tableEnv.scan(tableName).select("id - 1, amount * 2") + val expectedSelectedFields = Array[String]("id", "amount") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testExpressionSql(): Unit = { + val table = tableEnv.sql("select id - 1, amount * 2 from MyTable") + val expectedSelectedFields = Array[String]("id", "amount") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testDuplicateFieldTableAPI(): Unit = { + val table = tableEnv.scan(tableName).select("amount, id - 1, amount as amount1, amount * 2") + val expectedSelectedFields = Array[String]("amount", "id") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testDuplicateFieldSql(): Unit = { + val table = tableEnv.sql( + s"select amount, id - 1, amount as amount1, amount * 2 from $tableName") + val expectedSelectedFields = Array[String]("amount", "id") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testSelectStarTableAPI(): Unit = { + val table = tableEnv.scan(tableName).select("amount as amount1, *") + val expectedSelectedFields = Array[String]("amount", "name", "id", "price") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test def testSelectStarSql(): Unit = { + val table = tableEnv.sql(s"select * from $tableName") + val expectedSelectedFields = Array[String]("name", "id", "amount", "price") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Ignore + def testAggregateOnScanTableAPI(): Unit = { + val table = tableEnv.scan(tableName).select("price.avg, amount.max") + val expectedSelectedFields = Array[String]("amount", "price") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testAggregateOnScanSql(): Unit = { + val table = tableEnv.sql(s"select avg(price), max(amount) from $tableName") + val expectedSelectedFields = Array[String]("amount", "price") + assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) + } + + @Test + def testJoinOnScanTableAPI(): Unit = { + val tableName1 = "MyTable1" + tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) + val in = tableEnv.scan(tableName) + val in1 = tableEnv + .scan(tableName1) + .select("name as name1,id as id1, amount as amount1, price as price1") + val result = in.join(in1).where("id === id1 && amount < 2").select("name, amount1") + val expectedSelectedFields = Array[String]("id", "amount", "name") + assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) + val expectedSelectedFields1 = Array[String]("id", "amount") + assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) + } + + @Test + def testJoinOnScanSql(): Unit = { + val tableName1 = "MyTable1" + tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) + val result = tableEnv.sql( + s"select $tableName.name, $tableName1.amount from $tableName, $tableName1 " + + s"where 
$tableName.id = $tableName1.id and $tableName.amount < 2") + val expectedSelectedFields = Array[String]("name", "id", "amount") + assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) + val expectedSelectedFields1 = Array[String]("id", "amount") + assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) + } + + @Ignore + def testAggregateOnJoinTableAPI(): Unit = { + val tableName1 = "MyTable1" + tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) + val in = tableEnv.scan(tableName) + val in1 = tableEnv + .scan(tableName1) + .select("name as name1,id as id1, amount as amount1, price as price1") + val result = in + .join(in1) + .where("id === id1 && amount < 2") + .select("id.count, price1.max") + val expectedSelectedFields = Array[String]("id", "amount") + assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) + val expectedSelectedFields1 = Array[String]("id", "price") + assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) + } + + @Test + def testAggregateOnJoinSql(): Unit = { + val tableName1 = "MyTable1" + tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) + val result = tableEnv.sql( + s"select count($tableName.id), max($tableName1.price) from $tableName, $tableName1 " + + s"where $tableName.id = $tableName1.id and $tableName.amount < 2") + val expectedSelectedFields = Array[String]("id", "amount") + assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) + val expectedSelectedFields1 = Array[String]("id", "price") + assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) + } + + /** + * visit RelNode tree to find the leaf TableScan node which contain specify table, + * then get its projection fieldsName + * + * @param tableName table name of which table to get + * @param logicalRoot logical RelNode tree which is not optimized yet + * @param tEnv table environment + * @return projection fieldNames of the specify table + */ + private def extractProjectedFieldNames( + tableName: String, + logicalRoot: RelNode, + tEnv: BatchTableEnvironment) + : Option[Array[String]] = { + val optimizedRelNode = tEnv.optimize(logicalRoot) + val tableVisitor = new TableVisitor + tableVisitor.run(optimizedRelNode) + tableVisitor.fieldsName(tableName) + } + + private def assertUsedFieldsEquals( + root: RelNode, + tableName: String, + expectUsedFields: Array[String]) + : Unit = { + val optionalUsedFields = extractProjectedFieldNames(tableName, root, tableEnv) + optionalUsedFields match { + case Some(usedFields) => + Assert.assertTrue( + usedFields.size == expectUsedFields.size && usedFields.toSet == expectUsedFields.toSet) + case None => Assert.fail(s"cannot find table $tableName in optimized RelNode tree") + } + } +} + +/** + * This class is responsible for collect table fields names of every underlying TableScan node + * Note: if RelNode tree contains same table for more than one time, this visitor would complain by + * throwing an IllegalArgumentException. 
For example, e.g, + * 'select t.name, t.amount from t, t as t1 where t.id = t1.id and t1.amount < 2' + */ +class TableVisitor extends RelVisitor { + private val tableToUsedFieldsMapping = mutable.HashMap[String, Array[String]]() + + /** + * get fields name of table + * + * @param tableName + * @return + */ + def fieldsName(tableName: String): Option[Array[String]] = { + tableToUsedFieldsMapping.get(tableName) + } + + def run(input: RelNode) { + go(input) + } + + override def visit(node: RelNode, ordinal: Int, parent: RelNode) = { + node match { + case ts: TableScan => + val usedFieldsName = ts.getRowType.getFieldNames.toArray(Array[String]()) + ts.getTable.getQualifiedName.foreach( + tableName => { + tableToUsedFieldsMapping.get(tableName) match { + case Some(_) => + throw new IllegalArgumentException( + s"there already exists table $tableName in RelNode tree") + case None => tableToUsedFieldsMapping += (tableName -> usedFieldsName) + } + + }) + case _ => + } + super.visit(node, ordinal, parent) + } +} + +class TestProjectableTableSource( + fieldTypes: Array[TypeInformation[_]], + fieldNames: Array[String]) + extends BatchTableSource[Row] with ProjectableTableSource[Row] { + + def this() = this( + fieldTypes = Array( + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO), + fieldNames = Array[String]("name", "id", "amount", "price") + ) + + /** Returns the data of the table as a [[org.apache.flink.api.java.DataSet]]. */ + override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = { + execEnv.createInput(new ProjectableInputFormat(33, fieldNames), getReturnType).setParallelism(1) + } + + /** Returns the types of the table fields. */ + override def getFieldTypes: Array[TypeInformation[_]] = fieldTypes + + /** Returns the names of the table fields. */ + override def getFieldsNames: Array[String] = fieldNames + + /** Returns the [[TypeInformation]] for the return type. */ + override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes) + + /** Returns the number of fields of the table. 
*/
+  override def getNumberOfFields: Int = fieldNames.length
+
+  override def projectFields(fields: Array[Int]): TestProjectableTableSource = {
+    val projectedFieldTypes = new Array[TypeInformation[_]](fields.length)
+    val projectedFieldNames = new Array[String](fields.length)
+
+    fields.zipWithIndex.foreach(f => {
+      projectedFieldTypes(f._2) = fieldTypes(f._1)
+      projectedFieldNames(f._2) = fieldNames(f._1)})
+    new TestProjectableTableSource(projectedFieldTypes, projectedFieldNames)
+  }
+}
+
+class ProjectableInputFormat(
+    num: Int, fieldNames: Array[String]) extends GenericInputFormat[Row] {
+
+  val possibleFieldsName = Set("name", "id", "amount", "price")
+  var cnt = 0L
+  require(num > 0, "the num must be positive")
+  require(fieldNames.toSet.subsetOf(possibleFieldsName), "input field names contain illegal name")
+
+  override def reachedEnd(): Boolean = cnt >= num
+
+  override def nextRecord(reuse: Row): Row = {
+    fieldNames.zipWithIndex.foreach(f =>
+      f._1 match {
+        case "name" =>
+          reuse.setField(f._2, "Record_" + cnt)
+        case "id" =>
+          reuse.setField(f._2, cnt)
+        case "amount" =>
+          reuse.setField(f._2, cnt.toInt % 16)
+        case "price" =>
+          reuse.setField(f._2, cnt.toDouble / 3)
+        case _ =>
+          throw new IllegalArgumentException("unknown field name")
+      }
+    )
+    cnt += 1
+    reuse
+  }
+}

From bb4db2b666336eb73c51d28c859eec6a2e0a0d50 Mon Sep 17 00:00:00 2001
From: beyond1920
Date: Thu, 8 Dec 2016 09:35:59 +0800
Subject: [PATCH 2/3] 1. modify DataSetCalcConverter to RexProgramProjectExtractor
 2. add testcases 3. modify cost model

---
 .../table/plan/nodes/dataset/BatchScan.scala  |   3 +-
 .../nodes/dataset/BatchTableSourceScan.scala  |  13 +-
 ...hProjectIntoBatchTableSourceScanRule.scala |  70 ++--
 ...scala => RexProgramProjectExtractor.scala} |  62 +--
 .../batch/ProjectableTableSourceITCase.scala  | 154 ++++++++
 ...rojectIntoBatchTableSourceScanITCase.scala | 370 ------------------
 .../util/RexProgramProjectExtractorTest.scala | 121 ++++++
 7 files changed, 353 insertions(+), 440 deletions(-)
 rename flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/{DataSetCalcConverter.scala => RexProgramProjectExtractor.scala} (60%)
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
 delete mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala
index 1ebc443daf239..a6de2378548ff 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchScan.scala
@@ -45,8 +45,7 @@ abstract class BatchScan(
   override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
 
     val rowCnt = metadata.getRowCount(this)
-    val columnCnt = getRowType.getFieldCount
-    planner.getCostFactory.makeCost(rowCnt * columnCnt, rowCnt, 0)
+    planner.getCostFactory.makeCost(rowCnt, rowCnt, 0)
   }
 
   protected def convertToExpectedType(
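The cost change in this patch removes the column-count factor from the generic BatchScan above
and instead charges BatchTableSourceScan (below) an io cost of rowCnt * estimateRowSize(getRowType),
so a projected scan becomes strictly cheaper than a full scan of the same source. A rough
illustration (the byte sizes are assumptions for the example, not actual estimateRowSize output):
with rowCnt = 1000, an estimated 28 bytes for the full (name, id, amount, price) row and 12 bytes
after projecting to (id, amount), the io component drops from 1000 * 28 = 28000 to
1000 * 12 = 12000, while the row-count and cpu components stay the same.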
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
index 14da86296e816..e368219b552b3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/BatchTableSourceScan.scala
@@ -19,7 +19,8 @@
 package org.apache.flink.api.table.plan.nodes.dataset
 
 import org.apache.calcite.plan._
-import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.metadata.RelMetadataQuery
+import org.apache.calcite.rel.{RelNode, RelWriter}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.table.{BatchTableEnvironment, FlinkTypeFactory}
@@ -39,6 +40,11 @@ class BatchTableSourceScan(
     flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
   }
 
+  override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
+    val rowCnt = metadata.getRowCount(this)
+    planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * estimateRowSize(getRowType))
+  }
+
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
     new BatchTableSourceScan(
       cluster,
@@ -48,6 +54,11 @@ class BatchTableSourceScan(
     )
   }
 
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw)
+      .item("fields", tableSource.getFieldsNames.mkString(", "))
+  }
+
   override def translateToPlan(
       tableEnv: BatchTableEnvironment,
       expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
index 9f1636e48b487..63c03ab8d50b6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
@@ -20,11 +20,9 @@ package org.apache.flink.api.table.plan.rules.dataSet
 
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
 import org.apache.calcite.plan.RelOptRule.{none, operand}
-import org.apache.calcite.rex.{RexProgram, RexUtil}
 import org.apache.flink.api.table.plan.nodes.dataset.{BatchTableSourceScan, DataSetCalc}
-import org.apache.flink.api.table.plan.rules.util.DataSetCalcConverter._
+import org.apache.flink.api.table.plan.rules.util.RexProgramProjectExtractor._
 import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}
-import scala.collection.JavaConverters._
 
 /**
  * This rule pushes a projection into a BatchTableSourceScan node.
@@ -46,48 +44,40 @@ class PushProjectIntoBatchTableSourceScanRule extends RelOptRule(
     val calc: DataSetCalc = call.rel(0).asInstanceOf[DataSetCalc]
     val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan]
 
-    val usedFields: Array[Int] = extractRefInputFields(calc)
+    val usedFields: Array[Int] = extractRefInputFields(calc.calcProgram)
 
     // if no fields can be projected, there is no need to transform subtree
-    if (scan.tableSource.getNumberOfFields == usedFields.length) {
-      return
-    }
-
-    val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
-
-    val newTableSource = originTableSource.projectFields(usedFields)
-
-    val newScan = new BatchTableSourceScan(
-      scan.getCluster,
-      scan.getTraitSet,
-      scan.getTable,
-      newTableSource.asInstanceOf[BatchTableSource[_]])
+    scan.tableSource.getNumberOfFields match {
+      case fieldNums if fieldNums == usedFields.length =>
+      case _ =>
+        val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
+        val newTableSource = originTableSource.projectFields(usedFields)
+        val newScan = new BatchTableSourceScan(
+          scan.getCluster,
+          scan.getTraitSet,
+          scan.getTable,
+          newTableSource.asInstanceOf[BatchTableSource[_]])
 
-    val (newProjectExprs, newConditionExpr) = rewriteCalcExprs(calc, usedFields)
+        val newCalcProgram = rewriteRexProgram(
+          calc.calcProgram,
+          newScan.getRowType,
+          usedFields,
+          calc.getCluster.getRexBuilder)
 
-    // if the projection only forwards its input and no filter exists, remove the DataSetCalc node
-    val newProjectExprsList = newProjectExprs.asJava
-    if (RexUtil.isIdentity(newProjectExprsList, newScan.getRowType)
-      && !newConditionExpr.isDefined) {
-      call.transformTo(newScan)
-    } else {
-      val newCalcProgram = RexProgram.create(
-        newScan.getRowType,
-        newProjectExprsList,
-        newConditionExpr.getOrElse(null),
-        calc.calcProgram.getOutputRowType,
-        calc.getCluster.getRexBuilder)
-
-      val newCalc = new DataSetCalc(calc.getCluster,
-        calc.getTraitSet,
-        newScan,
-        calc.getRowType,
-        newCalcProgram,
-        description)
-
-      call.transformTo(newCalc)
+        // if the projection only forwards its input and no filter exists, remove the DataSetCalc node
+        if (newCalcProgram.isTrivial) {
+          call.transformTo(newScan)
+        } else {
+          val newCalc = new DataSetCalc(calc.getCluster,
+            calc.getTraitSet,
+            newScan,
+            calc.getRowType,
+            newCalcProgram,
+            description)
+          call.transformTo(newCalc)
+        }
+    }
   }
-  }
 }
 
 object PushProjectIntoBatchTableSourceScanRule {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala
similarity index 60%
rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala
rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala
index d01bbd969ce74..9f92e9085abad 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/DataSetCalcConverter.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractor.scala
@@ -18,59 +18,67 @@
 
 package org.apache.flink.api.table.plan.rules.util
 
-import org.apache.calcite.rex.{RexCall, RexInputRef, RexLocalRef, RexNode, RexShuttle, RexSlot, RexVisitorImpl}
-import org.apache.flink.api.table.plan.nodes.dataset.DataSetCalc
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rex._
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 
-object DataSetCalcConverter {
+object RexProgramProjectExtractor {
 
   /**
-   * Extracts the indices of the input fields that a DataSetCalc RelNode references.
+   * Extracts the indices of the input fields that a RexProgram references.
    *
-   * @param calc the DataSetCalc to analyze
+   * @param rexProgram the RexProgram to analyze
    * @return the indices of the used input fields
   */
-  def extractRefInputFields(calc: DataSetCalc): Array[Int] = {
+  def extractRefInputFields(rexProgram: RexProgram): Array[Int] = {
     val visitor = new RefFieldsVisitor
-    val calcProgram = calc.calcProgram
     // extract input fields from project expressions
-    calcProgram.getProjectList.foreach(exp => calcProgram.expandLocalRef(exp).accept(visitor))
-    val condition = calcProgram.getCondition
+    rexProgram.getProjectList.foreach(exp => rexProgram.expandLocalRef(exp).accept(visitor))
+    val condition = rexProgram.getCondition
     // extract input fields from condition expression
     if (condition != null) {
-      calcProgram.expandLocalRef(condition).accept(visitor)
+      rexProgram.expandLocalRef(condition).accept(visitor)
     }
     visitor.getFields
   }
 
   /**
-   * Rewrites the DataSetCalc's project and condition expressions based on the new input fields.
+   * Generates a new RexProgram based on the new input fields.
    *
-   * @param calc the DataSetCalc to rewrite
-   * @param usedInputFields the indices of the used input fields
-   * @return a tuple of two elements: the rewritten project expressions and the rewritten
-   *         condition expression. Note: if the original condition expression is null,
-   *         the second element is None.
+   * @param oldRexProgram the old RexProgram
+   * @param inputRowType the input row type
+   * @param usedInputFields the indices of the used input fields
+   * @param rexBuilder the builder for Rex expressions
+   * @return a new RexProgram that contains the rewritten project expressions and
+   *         the rewritten condition expression
    */
-  def rewriteCalcExprs(
-      calc: DataSetCalc,
-      usedInputFields: Array[Int]): (List[RexNode], Option[RexNode]) = {
+  def rewriteRexProgram(
+      oldRexProgram: RexProgram,
+      inputRowType: RelDataType,
+      usedInputFields: Array[Int],
+      rexBuilder: RexBuilder): RexProgram = {
     val inputRewriter = new InputRewriter(usedInputFields)
-    val calcProgram = calc.calcProgram
-    val newProjectExpressions = calcProgram.getProjectList.map(
-      exp => calcProgram.expandLocalRef(exp).accept(inputRewriter)
-    ).toList
+    val newProjectExpressions = oldRexProgram.getProjectList.map(
+      exp => oldRexProgram.expandLocalRef(exp).accept(inputRewriter)
+    ).toList.asJava
 
-    val oldCondition = calcProgram.getCondition
+    val oldCondition = oldRexProgram.getCondition
     val newConditionExpression = {
       oldCondition match {
-        case ref: RexLocalRef => Some(calcProgram.expandLocalRef(ref).accept(inputRewriter))
-        case _ => None // null does not match any type
+        case ref: RexLocalRef => oldRexProgram.expandLocalRef(ref).accept(inputRewriter)
+        case _ => null // null does not match any type
       }
     }
-    (newProjectExpressions, newConditionExpression)
+    RexProgram.create(
+      inputRowType,
+      newProjectExpressions,
+      newConditionExpression,
+      oldRexProgram.getOutputRowType,
+      rexBuilder
+    )
   }
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
new file mode 100644
index 0000000000000..b1792c6a4d7d9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.scala.batch + +import org.apache.flink.api.common.io.GenericInputFormat +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.java.{DataSet => JavaSet, ExecutionEnvironment => JavaExecEnv} +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase +import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource} +import org.apache.flink.api.table.typeutils.RowTypeInfo +import org.apache.flink.api.table.{Row, TableEnvironment} +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit.{Before, Test} +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class ProjectableTableSourceITCase(mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + private val tableName = "MyTable" + private var tableEnv: BatchTableEnvironment = null + + @Before + def initTableEnv(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + tableEnv = TableEnvironment.getTableEnvironment(env, config) + tableEnv.registerTableSource(tableName, new TestProjectableTableSource) + } + + @Test + def testTableAPI(): Unit = { + val results = tableEnv + .scan(tableName) + .where("amount < 4") + .select("id, name") + .collect() + + val expected = Seq( + "0,Record_0", "1,Record_1", "2,Record_2", "3,Record_3", "16,Record_16", + "17,Record_17", "18,Record_18", "19,Record_19", "32,Record_32").mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + + @Test + def testSQL(): Unit = { + val results = tableEnv + .sql(s"select id, name from $tableName where amount < 4 ") + .collect() + + val expected = Seq( + "0,Record_0", "1,Record_1", "2,Record_2", "3,Record_3", "16,Record_16", + "17,Record_17", "18,Record_18", "19,Record_19", "32,Record_32").mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + +} + +class TestProjectableTableSource( + fieldTypes: Array[TypeInformation[_]], + fieldNames: Array[String]) + extends BatchTableSource[Row] with ProjectableTableSource[Row] { + + def this() = this( + fieldTypes = Array( + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.DOUBLE_TYPE_INFO), + fieldNames = Array[String]("name", "id", "amount", "price") + ) + + /** Returns the data of the table as a [[org.apache.flink.api.java.DataSet]]. 
*/ + override def getDataSet(execEnv: JavaExecEnv): JavaSet[Row] = { + execEnv.createInput(new ProjectableInputFormat(33, fieldNames), getReturnType).setParallelism(1) + } + + /** Returns the types of the table fields. */ + override def getFieldTypes: Array[TypeInformation[_]] = fieldTypes + + /** Returns the names of the table fields. */ + override def getFieldsNames: Array[String] = fieldNames + + /** Returns the [[TypeInformation]] for the return type. */ + override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes) + + /** Returns the number of fields of the table. */ + override def getNumberOfFields: Int = fieldNames.length + + override def projectFields(fields: Array[Int]): TestProjectableTableSource = { + val projectedFieldTypes = new Array[TypeInformation[_]](fields.length) + val projectedFieldNames = new Array[String](fields.length) + + fields.zipWithIndex.foreach(f => { + projectedFieldTypes(f._2) = fieldTypes(f._1) + projectedFieldNames(f._2) = fieldNames(f._1) + }) + new TestProjectableTableSource(projectedFieldTypes, projectedFieldNames) + } +} + +class ProjectableInputFormat( + num: Int, fieldNames: Array[String]) extends GenericInputFormat[Row] { + + val possibleFieldsName = Set("name", "id", "amount", "price") + var cnt = 0L + require(num > 0, "the num must be positive") + require(fieldNames.toSet.subsetOf(possibleFieldsName), "input field names contain illegal name") + + override def reachedEnd(): Boolean = cnt >= num + + override def nextRecord(reuse: Row): Row = { + fieldNames.zipWithIndex.foreach(f => + f._1 match { + case "name" => + reuse.setField(f._2, "Record_" + cnt) + case "id" => + reuse.setField(f._2, cnt) + case "amount" => + reuse.setField(f._2, cnt.toInt % 16) + case "price" => + reuse.setField(f._2, cnt.toDouble / 3) + case _ => + throw new IllegalArgumentException("unknown field name") + } + ) + cnt += 1 + reuse + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala deleted file mode 100644 index 01a1e5744b0af..0000000000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanITCase.scala +++ /dev/null @@ -1,370 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.api.table.plan.rules.dataSet - -import collection.JavaConversions._ -import org.apache.calcite.rel.{RelNode, RelVisitor} -import org.apache.calcite.rel.core.TableScan -import org.apache.flink.api.common.io.GenericInputFormat -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.table.BatchTableEnvironment -import org.apache.flink.api.java.{DataSet, ExecutionEnvironment} -import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase -import org.apache.flink.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.api.table.{Row, TableEnvironment} -import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource} -import org.apache.flink.api.table.typeutils.RowTypeInfo -import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode -import org.junit.{Assert, Before, Ignore, Test} -import org.junit.runner.RunWith -import org.junit.runners.Parameterized - -import scala.collection.mutable - -/** - * Test push project down to batchTableSourceScan optimization - * - * @param mode - * @param configMode - */ -@RunWith(classOf[Parameterized]) -class PushProjectIntoBatchTableSourceScanITCase(mode: TestExecutionMode, - configMode: TableConfigMode) - extends TableProgramsTestBase(mode, configMode) { - - private val tableName = "MyTable" - private var tableEnv: BatchTableEnvironment = null - - @Before - def initTableEnv(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - tableEnv = TableEnvironment.getTableEnvironment(env, config) - tableEnv.registerTableSource(tableName, new TestProjectableTableSource) - } - - @Test - def testProjectOnFilterTableAPI(): Unit = { - val table = tableEnv.scan(tableName).where("amount < 4").select("id, name") - val expectedSelectedFields = Array[String]("name", "id", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testProjectOnFilterSql(): Unit = { - val table = tableEnv.sql(s"select id, name from $tableName where amount < 4 ") - val expectedSelectedFields = Array[String]("name", "id", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Ignore - def testProjectWithWindowTableAPI(): Unit = { - val table = tableEnv.scan(tableName).select("id, amount.avg over (partition by name)") - val expectedSelectedFields = Array[String]("name", "id", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testProjectWithWindowSql(): Unit = { - val table = tableEnv.sql(s"select id, avg(amount) over (partition by name) from $tableName") - val expectedSelectedFields = Array[String]("name", "id", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testMultipleProjectTableAPI(): Unit = { - val table = tableEnv.scan(tableName) - .where("amount < 4") - .select("amount, id, name") - .select("name") - val expectedSelectedFields = Array[String]("amount", "name") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testMultipleProjectSql(): Unit = { - val table = tableEnv.sql( - s"select name from (select amount, id, name from $tableName where amount < 4 ) t1") - val expectedSelectedFields = Array[String]("name", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testExpressionTableAPI(): Unit = { - val table = 
tableEnv.scan(tableName).select("id - 1, amount * 2") - val expectedSelectedFields = Array[String]("id", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testExpressionSql(): Unit = { - val table = tableEnv.sql("select id - 1, amount * 2 from MyTable") - val expectedSelectedFields = Array[String]("id", "amount") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testDuplicateFieldTableAPI(): Unit = { - val table = tableEnv.scan(tableName).select("amount, id - 1, amount as amount1, amount * 2") - val expectedSelectedFields = Array[String]("amount", "id") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testDuplicateFieldSql(): Unit = { - val table = tableEnv.sql( - s"select amount, id - 1, amount as amount1, amount * 2 from $tableName") - val expectedSelectedFields = Array[String]("amount", "id") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testSelectStarTableAPI(): Unit = { - val table = tableEnv.scan(tableName).select("amount as amount1, *") - val expectedSelectedFields = Array[String]("amount", "name", "id", "price") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test def testSelectStarSql(): Unit = { - val table = tableEnv.sql(s"select * from $tableName") - val expectedSelectedFields = Array[String]("name", "id", "amount", "price") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Ignore - def testAggregateOnScanTableAPI(): Unit = { - val table = tableEnv.scan(tableName).select("price.avg, amount.max") - val expectedSelectedFields = Array[String]("amount", "price") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testAggregateOnScanSql(): Unit = { - val table = tableEnv.sql(s"select avg(price), max(amount) from $tableName") - val expectedSelectedFields = Array[String]("amount", "price") - assertUsedFieldsEquals(table.getRelNode, tableName, expectedSelectedFields) - } - - @Test - def testJoinOnScanTableAPI(): Unit = { - val tableName1 = "MyTable1" - tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) - val in = tableEnv.scan(tableName) - val in1 = tableEnv - .scan(tableName1) - .select("name as name1,id as id1, amount as amount1, price as price1") - val result = in.join(in1).where("id === id1 && amount < 2").select("name, amount1") - val expectedSelectedFields = Array[String]("id", "amount", "name") - assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) - val expectedSelectedFields1 = Array[String]("id", "amount") - assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) - } - - @Test - def testJoinOnScanSql(): Unit = { - val tableName1 = "MyTable1" - tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) - val result = tableEnv.sql( - s"select $tableName.name, $tableName1.amount from $tableName, $tableName1 " + - s"where $tableName.id = $tableName1.id and $tableName.amount < 2") - val expectedSelectedFields = Array[String]("name", "id", "amount") - assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) - val expectedSelectedFields1 = Array[String]("id", "amount") - assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) - } - - @Ignore - def testAggregateOnJoinTableAPI(): Unit = { - val tableName1 = "MyTable1" - 
tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) - val in = tableEnv.scan(tableName) - val in1 = tableEnv - .scan(tableName1) - .select("name as name1,id as id1, amount as amount1, price as price1") - val result = in - .join(in1) - .where("id === id1 && amount < 2") - .select("id.count, price1.max") - val expectedSelectedFields = Array[String]("id", "amount") - assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) - val expectedSelectedFields1 = Array[String]("id", "price") - assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) - } - - @Test - def testAggregateOnJoinSql(): Unit = { - val tableName1 = "MyTable1" - tableEnv.registerTableSource(tableName1, new TestProjectableTableSource) - val result = tableEnv.sql( - s"select count($tableName.id), max($tableName1.price) from $tableName, $tableName1 " + - s"where $tableName.id = $tableName1.id and $tableName.amount < 2") - val expectedSelectedFields = Array[String]("id", "amount") - assertUsedFieldsEquals(result.getRelNode, tableName, expectedSelectedFields) - val expectedSelectedFields1 = Array[String]("id", "price") - assertUsedFieldsEquals(result.getRelNode, tableName1, expectedSelectedFields1) - } - - /** - * visit RelNode tree to find the leaf TableScan node which contain specify table, - * then get its projection fieldsName - * - * @param tableName table name of which table to get - * @param logicalRoot logical RelNode tree which is not optimized yet - * @param tEnv table environment - * @return projection fieldNames of the specify table - */ - private def extractProjectedFieldNames( - tableName: String, - logicalRoot: RelNode, - tEnv: BatchTableEnvironment) - : Option[Array[String]] = { - val optimizedRelNode = tEnv.optimize(logicalRoot) - val tableVisitor = new TableVisitor - tableVisitor.run(optimizedRelNode) - tableVisitor.fieldsName(tableName) - } - - private def assertUsedFieldsEquals( - root: RelNode, - tableName: String, - expectUsedFields: Array[String]) - : Unit = { - val optionalUsedFields = extractProjectedFieldNames(tableName, root, tableEnv) - optionalUsedFields match { - case Some(usedFields) => - Assert.assertTrue( - usedFields.size == expectUsedFields.size && usedFields.toSet == expectUsedFields.toSet) - case None => Assert.fail(s"cannot find table $tableName in optimized RelNode tree") - } - } -} - -/** - * This class is responsible for collect table fields names of every underlying TableScan node - * Note: if RelNode tree contains same table for more than one time, this visitor would complain by - * throwing an IllegalArgumentException. 
For example, e.g, - * 'select t.name, t.amount from t, t as t1 where t.id = t1.id and t1.amount < 2' - */ -class TableVisitor extends RelVisitor { - private val tableToUsedFieldsMapping = mutable.HashMap[String, Array[String]]() - - /** - * get fields name of table - * - * @param tableName - * @return - */ - def fieldsName(tableName: String): Option[Array[String]] = { - tableToUsedFieldsMapping.get(tableName) - } - - def run(input: RelNode) { - go(input) - } - - override def visit(node: RelNode, ordinal: Int, parent: RelNode) = { - node match { - case ts: TableScan => - val usedFieldsName = ts.getRowType.getFieldNames.toArray(Array[String]()) - ts.getTable.getQualifiedName.foreach( - tableName => { - tableToUsedFieldsMapping.get(tableName) match { - case Some(_) => - throw new IllegalArgumentException( - s"there already exists table $tableName in RelNode tree") - case None => tableToUsedFieldsMapping += (tableName -> usedFieldsName) - } - - }) - case _ => - } - super.visit(node, ordinal, parent) - } -} - -class TestProjectableTableSource( - fieldTypes: Array[TypeInformation[_]], - fieldNames: Array[String]) - extends BatchTableSource[Row] with ProjectableTableSource[Row] { - - def this() = this( - fieldTypes = Array( - BasicTypeInfo.STRING_TYPE_INFO, - BasicTypeInfo.LONG_TYPE_INFO, - BasicTypeInfo.INT_TYPE_INFO, - BasicTypeInfo.DOUBLE_TYPE_INFO), - fieldNames = Array[String]("name", "id", "amount", "price") - ) - - /** Returns the data of the table as a [[org.apache.flink.api.java.DataSet]]. */ - override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = { - execEnv.createInput(new ProjectableInputFormat(33, fieldNames), getReturnType).setParallelism(1) - } - - /** Returns the types of the table fields. */ - override def getFieldTypes: Array[TypeInformation[_]] = fieldTypes - - /** Returns the names of the table fields. */ - override def getFieldsNames: Array[String] = fieldNames - - /** Returns the [[TypeInformation]] for the return type. */ - override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes) - - /** Returns the number of fields of the table. 
-  /** Returns the number of fields of the table. */
-  override def getNumberOfFields: Int = fieldNames.length
-
-  override def projectFields(fields: Array[Int]): TestProjectableTableSource = {
-    val projectedFieldTypes = new Array[TypeInformation[_]](fields.length)
-    val projectedFieldNames = new Array[String](fields.length)
-
-    fields.zipWithIndex.foreach(f => {
-      projectedFieldTypes(f._2) = fieldTypes(f._1)
-      projectedFieldNames(f._2) = fieldNames(f._1)})
-    new TestProjectableTableSource(projectedFieldTypes, projectedFieldNames)
-  }
-}
-
-class ProjectableInputFormat(
-    num: Int, fieldNames: Array[String]) extends GenericInputFormat[Row] {
-
-  val possibleFieldsName = Set("name", "id", "amount", "price")
-  var cnt = 0L
-  require(num > 0, "the num must be positive")
-  require(fieldNames.toSet.subsetOf(possibleFieldsName), "input field names contain illegal name")
-
-  override def reachedEnd(): Boolean = cnt >= num
-
-  override def nextRecord(reuse: Row): Row = {
-    fieldNames.zipWithIndex.foreach(f =>
-      f._1 match {
-        case "name" =>
-          reuse.setField(f._2, "Record_" + cnt)
-        case "id" =>
-          reuse.setField(f._2, cnt)
-        case "amount" =>
-          reuse.setField(f._2, cnt.toInt % 16)
-        case "price" =>
-          reuse.setField(f._2, cnt.toDouble / 3)
-        case _ =>
-          throw new IllegalArgumentException("unknown field name")
-      }
-    )
-    cnt += 1
-    reuse
-  }
-}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
new file mode 100644
index 0000000000000..87dcf256d529a
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.rules.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
+import org.apache.calcite.rex.{RexBuilder, RexInputRef, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+
+import scala.collection.JavaConverters._
+import org.apache.flink.api.table.plan.rules.util.RexProgramProjectExtractor._
+import org.junit.{Assert, Before, Test}
+
+/**
+  * This class is responsible for testing RexProgramProjectExtractor
+  */
+class RexProgramProjectExtractorTest {
+  private var typeFactory: JavaTypeFactory = null
+  private var rexBuilder: RexBuilder = null
+  private var allFieldTypes: Seq[RelDataType] = null
+  private val allFieldNames = List("name", "id", "amount", "price")
+
+  @Before
+  def setUp: Unit = {
+    typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+    rexBuilder = new RexBuilder(typeFactory)
+    allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_))
+  }
+
+  @Test
+  def testExtractRefInputFields: Unit = {
+    val usedFields = extractRefInputFields(buildRexProgram)
+    Assert.assertArrayEquals(usedFields, Array(2, 1, 3))
+  }
+
+  @Test
+  def testRewriteRexProgram: Unit = {
+    val originRexProgram = buildRexProgram
+    Assert.assertTrue(extractExprStrList(originRexProgram).sameElements(Array(
+      "$0",
+      "$1",
+      "$2",
+      "$3",
+      "*($t2, $t3)",
+      "100",
+      "<($t4, $t5)",
+      "6",
+      ">($t2, $t7)",
+      "AND($t6, $t8)")))
+    // use amount, id, price fields to create a new RexProgram
+    val usedFields = Array(2, 1, 3)
+    val types = usedFields.map(allFieldTypes(_)).toList.asJava
+    val names = usedFields.map(allFieldNames(_)).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder)
+    Assert.assertTrue(extractExprStrList(newRexProgram).sameElements(Array(
+      "$0",
+      "$1",
+      "$2",
+      "*($t0, $t2)",
+      "100",
+      "<($t3, $t4)",
+      "6",
+      ">($t0, $t6)",
+      "AND($t5, $t7)")))
+  }
+
+  private def buildRexProgram: RexProgram = {
+    val types = allFieldTypes.asJava
+    val names = allFieldNames.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+    val t0 = rexBuilder.makeInputRef(types.get(2), 2)
+    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+    val t2 = rexBuilder.makeInputRef(types.get(3), 3)
+    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+    // project: amount, id, amount * price
+    builder.addProject(t0, "amount")
+    builder.addProject(t1, "id")
+    builder.addProject(t3, "total")
+    // condition: amount * price < 100 and amount > 6
+    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t5))
+    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+    builder.addCondition(t8)
+    builder.getProgram
+  }
+
+  /**
+    * extracts the string representation of each expression in the expression
+    * list of the input RexProgram
+    *
+    * @param rexProgram input RexProgram instance to analyze
+    * @return string representations of all expressions of the input RexProgram
+    */
+  private def extractExprStrList(rexProgram: RexProgram) = {
+    rexProgram.getExprList.asScala.map(_.toString)
+  }
+
+}
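The assertions above pin down the contract of extractRefInputFields: it returns the indexes of exactly those input fields that are referenced from the program's projections or its optional condition (note that unused input refs such as $0/name do not appear in the result). A minimal sketch of how such an extractor can be written with Calcite's RexVisitorImpl; the object and method names and the first-seen ordering are illustrative assumptions, not the patch's actual implementation:

import org.apache.calcite.rex.{RexInputRef, RexProgram, RexVisitorImpl}

import scala.collection.JavaConverters._
import scala.collection.mutable

object InputRefCollector {
  // collects the index of every RexInputRef reachable from the program's
  // outputs; ordering is first-seen, which need not match the order the
  // tests above happen to expect
  def collectRefInputFields(program: RexProgram): Array[Int] = {
    val refs = mutable.LinkedHashSet[Int]()
    val visitor = new RexVisitorImpl[Unit](true) {
      override def visitInputRef(inputRef: RexInputRef): Unit = {
        refs += inputRef.getIndex
      }
    }
    // only expressions actually used by the projection or the condition
    // count; expandLocalRef inlines the RexLocalRefs into full expressions
    program.getProjectList.asScala.foreach(p => program.expandLocalRef(p).accept(visitor))
    Option(program.getCondition).foreach(c => program.expandLocalRef(c).accept(visitor))
    refs.toArray
  }
}

On the program built by buildRexProgram this sketch visits amount ($2), id ($1) and then amount * price ($2, $3), yielding Array(2, 1, 3), which matches the expectation in testExtractRefInputFields.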
From 8c10c97642941534a5324e0b5aca448ff34c3f3a Mon Sep 17 00:00:00 2001
From: beyond1920
Date: Sun, 11 Dec 2016 13:18:04 +0800
Subject: [PATCH 3/3] modify code based on review comments

---
 ...hProjectIntoBatchTableSourceScanRule.scala | 55 ++++++++--------
 .../sources/ProjectableTableSource.scala      |  4 +-
 .../batch/ProjectableTableSourceITCase.scala  | 63 ++++++++-----------
 .../util/RexProgramProjectExtractorTest.scala | 19 +++---
 4 files changed, 65 insertions(+), 76 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
index 63c03ab8d50b6..f6a3ddfeed094 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
@@ -29,7 +29,7 @@
  */
 class PushProjectIntoBatchTableSourceScanRule extends RelOptRule(
   operand(classOf[DataSetCalc],
-          operand(classOf[BatchTableSourceScan], none)),
+    operand(classOf[BatchTableSourceScan], none)),
   "PushProjectIntoBatchTableSourceScanRule") {
 
   override def matches(call: RelOptRuleCall) = {
@@ -47,37 +47,36 @@ class PushProjectIntoBatchTableSourceScanRule extends RelOptRule(
     val usedFields: Array[Int] = extractRefInputFields(calc.calcProgram)
 
     // if no fields can be projected, there is no need to transform subtree
-    scan.tableSource.getNumberOfFields match {
-      case fieldNums if fieldNums == usedFields.length =>
-      case _ =>
-        val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
-        val newTableSource = originTableSource.projectFields(usedFields)
-        val newScan = new BatchTableSourceScan(
-          scan.getCluster,
-          scan.getTraitSet,
-          scan.getTable,
-          newTableSource.asInstanceOf[BatchTableSource[_]])
+    if (scan.tableSource.getNumberOfFields != usedFields.length) {
+      val originTableSource = scan.tableSource.asInstanceOf[ProjectableTableSource[_]]
+      val newTableSource = originTableSource.projectFields(usedFields)
+      val newScan = new BatchTableSourceScan(
+        scan.getCluster,
+        scan.getTraitSet,
+        scan.getTable,
+        newTableSource.asInstanceOf[BatchTableSource[_]])
 
-        val newCalcProgram = rewriteRexProgram(
-          calc.calcProgram,
-          newScan.getRowType,
-          usedFields,
-          calc.getCluster.getRexBuilder)
+      val newCalcProgram = rewriteRexProgram(
+        calc.calcProgram,
+        newScan.getRowType,
+        usedFields,
+        calc.getCluster.getRexBuilder)
 
-        // if project merely returns its input and doesn't exist filter, remove datasetCalc nodes
-        if (newCalcProgram.isTrivial) {
-          call.transformTo(newScan)
-        } else {
-          val newCal = new DataSetCalc(calc.getCluster,
-            calc.getTraitSet,
-            newScan,
-            calc.getRowType,
-            newCalcProgram,
-            description)
-          call.transformTo(newCal)
-        }
+      // if the projection merely forwards its input and there is no filter,
+      // remove the DataSetCalc node entirely
+      if (newCalcProgram.isTrivial) {
+        call.transformTo(newScan)
+      } else {
+        val newCalc = new DataSetCalc(
+          calc.getCluster,
+          calc.getTraitSet,
+          newScan,
+          calc.getRowType,
+          newCalcProgram,
+          description)
+        call.transformTo(newCalc)
       }
     }
+  }
 }
 
 object PushProjectIntoBatchTableSourceScanRule {
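The rewritten onMatch hinges on RexProgram.isTrivial: once the projection has been pushed into the source, a calc whose program has no condition and simply forwards its (already narrowed) input adds nothing and can be dropped. A minimal sketch of that property, assuming standard Calcite semantics for isTrivial; the demo object and its field names are illustrative only:

import org.apache.calcite.jdbc.JavaTypeFactoryImpl
import org.apache.calcite.rel.`type`.RelDataTypeSystem
import org.apache.calcite.rex.{RexBuilder, RexProgramBuilder}
import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, VARCHAR}

import scala.collection.JavaConverters._

object TrivialProgramDemo extends App {
  val typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
  val rexBuilder = new RexBuilder(typeFactory)
  val rowType = typeFactory.createStructType(
    List(typeFactory.createSqlType(BIGINT), typeFactory.createSqlType(VARCHAR)).asJava,
    List("id", "name").asJava)

  // project every input field onto itself and add no condition
  val builder = new RexProgramBuilder(rowType, rexBuilder)
  rowType.getFieldList.asScala.zipWithIndex.foreach { case (field, i) =>
    builder.addProject(rexBuilder.makeInputRef(field.getType, i), field.getName)
  }
  // an identity projection without a filter is exactly the case in which the
  // rule replaces DataSetCalc(newScan) with newScan directly
  println(builder.getProgram.isTrivial) // expected: true
}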
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
index 35a95c7ebc587..dd8f68595b932 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/sources/ProjectableTableSource.scala
@@ -20,8 +20,8 @@ package org.apache.flink.api.table.sources
 
 /**
   * Defines TableSource which supports project pushdown.
-  * E.g A definition of TestBatchTableSource which supports project
-  * class TestBatchTableSource extends BatchTableSource[Row] with ProjectableTableSource[Row]
+  * e.g. a definition of a ParquetTableSource that supports projection:
+  * class ParquetTableSource extends BatchTableSource[Row] with ProjectableTableSource[Row]
   *
   * @tparam T The return type of the [[ProjectableTableSource]].
   */
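To make that scaladoc concrete: a source participating in this optimization typically stores its field arrays and returns a narrowed copy of itself from projectFields. The class below is a hypothetical sketch following the shape of TestProjectableTableSource in the ITCase that follows; MyParquetTableSource, its path parameter and the unimplemented getDataSet body are illustrative placeholders, and the imports assume the pre-1.2 org.apache.flink.api.table package layout used throughout these patches:

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.{DataSet => JavaSet, ExecutionEnvironment => JavaExecEnv}
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.sources.{BatchTableSource, ProjectableTableSource}
import org.apache.flink.api.table.typeutils.RowTypeInfo

class MyParquetTableSource(
    path: String,
    fieldTypes: Array[TypeInformation[_]],
    fieldNames: Array[String])
  extends BatchTableSource[Row] with ProjectableTableSource[Row] {

  // read only the columns listed in fieldNames from the file at path
  override def getDataSet(execEnv: JavaExecEnv): JavaSet[Row] = ???

  override def getFieldTypes: Array[TypeInformation[_]] = fieldTypes

  override def getFieldsNames: Array[String] = fieldNames

  override def getNumberOfFields: Int = fieldNames.length

  override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes)

  // the optimizer passes the 0-based indexes of the fields that are actually
  // used; return a copy of this source that exposes only those fields
  override def projectFields(fields: Array[Int]): MyParquetTableSource =
    new MyParquetTableSource(
      path,
      fields.map(i => fieldTypes(i)),
      fields.map(i => fieldNames(i)))
}

Returning a fresh, narrowed instance instead of mutating the source keeps the original scan intact while the planner compares the projected and unprojected alternatives.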
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
index b1792c6a4d7d9..42b9de0ff3294 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/ProjectableTableSourceITCase.scala
@@ -18,7 +18,6 @@
 
 package org.apache.flink.api.scala.batch
 
-import org.apache.flink.api.common.io.GenericInputFormat
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.{DataSet => JavaSet, ExecutionEnvironment => JavaExecEnv}
 import org.apache.flink.api.scala.ExecutionEnvironment
@@ -54,10 +53,10 @@ class ProjectableTableSourceITCase(mode: TestExecutionMode,
   @Test
   def testTableAPI(): Unit = {
     val results = tableEnv
-        .scan(tableName)
-        .where("amount < 4")
-        .select("id, name")
-        .collect()
+      .scan(tableName)
+      .where("amount < 4")
+      .select("id, name")
+      .collect()
 
     val expected = Seq(
       "0,Record_0", "1,Record_1", "2,Record_2", "3,Record_3", "16,Record_16",
@@ -69,15 +68,14 @@
   @Test
   def testSQL(): Unit = {
     val results = tableEnv
-        .sql(s"select id, name from $tableName where amount < 4 ")
-        .collect()
+      .sql(s"select id, name from $tableName where amount < 4 ")
+      .collect()
 
     val expected = Seq(
       "0,Record_0", "1,Record_1", "2,Record_2", "3,Record_3", "16,Record_16",
       "17,Record_17", "18,Record_18", "19,Record_19", "32,Record_32").mkString("\n")
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
-
 }
 
 class TestProjectableTableSource(
@@ -96,7 +94,7 @@ class TestProjectableTableSource(
 
   /** Returns the data of the table as a [[org.apache.flink.api.java.DataSet]]. */
   override def getDataSet(execEnv: JavaExecEnv): JavaSet[Row] = {
-    execEnv.createInput(new ProjectableInputFormat(33, fieldNames), getReturnType).setParallelism(1)
+    execEnv.fromCollection(generateDynamicCollection(33, fieldNames).asJava, getReturnType)
   }
 
   /** Returns the types of the table fields. */
@@ -121,34 +119,27 @@ class TestProjectableTableSource(
     })
     new TestProjectableTableSource(projectedFieldTypes, projectedFieldNames)
   }
-}
 
-class ProjectableInputFormat(
-    num: Int, fieldNames: Array[String]) extends GenericInputFormat[Row] {
-
-  val possibleFieldsName = Set("name", "id", "amount", "price")
-  var cnt = 0L
-  require(num > 0, "the num must be positive")
-  require(fieldNames.toSet.subsetOf(possibleFieldsName), "input field names contain illegal name")
-
-  override def reachedEnd(): Boolean = cnt >= num
-
-  override def nextRecord(reuse: Row): Row = {
-    fieldNames.zipWithIndex.foreach(f =>
-      f._1 match {
-        case "name" =>
-          reuse.setField(f._2, "Record_" + cnt)
-        case "id" =>
-          reuse.setField(f._2, cnt)
-        case "amount" =>
-          reuse.setField(f._2, cnt.toInt % 16)
-        case "price" =>
-          reuse.setField(f._2, cnt.toDouble / 3)
-        case _ =>
-          throw new IllegalArgumentException("unknown field name")
+  private def generateDynamicCollection(num: Int, fieldNames: Array[String]): Seq[Row] = {
+    for {cnt <- 0 until num}
+      yield {
+        val row = new Row(fieldNames.length)
+        fieldNames.zipWithIndex.foreach(
+          f =>
+            f._1 match {
+              case "name" =>
+                row.setField(f._2, "Record_" + cnt)
+              case "id" =>
+                row.setField(f._2, cnt.toLong)
+              case "amount" =>
+                row.setField(f._2, cnt.toInt % 16)
+              case "price" =>
+                row.setField(f._2, cnt.toDouble / 3)
+              case _ =>
+                throw new IllegalArgumentException(s"unknown field name ${f._1}")
+            }
+        )
+        row
       }
-    )
-    cnt += 1
-    reuse
   }
 }
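For orientation when reading the rewritten generator: each produced Row is laid out in the order of the (possibly projected) fieldNames array, not in the original schema order. A hypothetical call, ignoring the method's private visibility, with values as produced by the generator above:

val rows = generateDynamicCollection(3, Array("id", "price"))
// field 0 is id, field 1 is price:
// Row(0L, 0.0), Row(1L, 0.333...), Row(2L, 0.666...)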
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
index 87dcf256d529a..156f281812933 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/rules/util/RexProgramProjectExtractorTest.scala
@@ -24,7 +24,7 @@
 import org.apache.calcite.adapter.java.JavaTypeFactory
 import org.apache.calcite.jdbc.JavaTypeFactoryImpl
 import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
 import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
-import org.apache.calcite.rex.{RexBuilder, RexInputRef, RexProgram, RexProgramBuilder}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
 import org.apache.calcite.sql.fun.SqlStdOperatorTable
 
 import scala.collection.JavaConverters._
@@ -50,7 +50,7 @@ class RexProgramProjectExtractorTest {
   @Test
   def testExtractRefInputFields: Unit = {
     val usedFields = extractRefInputFields(buildRexProgram)
-    Assert.assertArrayEquals(usedFields, Array(2, 1, 3))
+    Assert.assertArrayEquals(usedFields, Array(2, 3, 1))
   }
 
   @Test
@@ -65,10 +65,10 @@ class RexProgramProjectExtractorTest {
       "100",
       "<($t4, $t5)",
       "6",
-      ">($t2, $t7)",
+      ">($t1, $t7)",
       "AND($t6, $t8)")))
     // use amount, id, price fields to create a new RexProgram
-    val usedFields = Array(2, 1, 3)
+    val usedFields = Array(2, 3, 1)
     val types = usedFields.map(allFieldTypes(_)).toList.asJava
     val names = usedFields.map(allFieldNames(_)).toList.asJava
     val inputRowType = typeFactory.createStructType(types, names)
@@ -77,11 +77,11 @@ class RexProgramProjectExtractorTest {
       "$0",
       "$1",
       "$2",
-      "*($t0, $t2)",
+      "*($t0, $t1)",
       "100",
       "<($t3, $t4)",
       "6",
-      ">($t0, $t6)",
+      ">($t2, $t6)",
       "AND($t5, $t7)")))
   }
 
@@ -96,13 +96,12 @@ class RexProgramProjectExtractorTest {
     val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
     val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
     val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
-    // project: amount, id, amount * price
+    // project: amount, amount * price
     builder.addProject(t0, "amount")
-    builder.addProject(t1, "id")
     builder.addProject(t3, "total")
-    // condition: amount * price < 100 and amount > 6
+    // condition: amount * price < 100 and id > 6
    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
-    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t0, t5))
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
     val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
     builder.addCondition(t8)
     builder.getProgram
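The changed expectations in this last test follow mechanically from the new field order. With usedFields = Array(2, 3, 1) (amount, price, id), each original input index is remapped to its position in the projected row type, which is exactly why *($t2, $t3) becomes *($t0, $t1) and >($t1, $t7) becomes >($t2, $t6). A two-line illustration of the mapping that rewriteRexProgram applies to the input refs:

val usedFields = Array(2, 3, 1)            // amount, price, id
val remap = usedFields.zipWithIndex.toMap  // Map(2 -> 0, 3 -> 1, 1 -> 2)
// e.g. the old ref $2 (amount) now resolves to field 0 of the projected scan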