diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkCalcMergeRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkCalcMergeRule.java index 76bc4f3444ec4..5ba2598f5690f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkCalcMergeRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkCalcMergeRule.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.rules.logical; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc; import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCalc; import org.apache.flink.table.planner.plan.utils.FlinkRelUtil; @@ -52,6 +53,8 @@ public class FlinkCalcMergeRule extends RelRule + b0.operand(BatchPhysicalCalc.class) + .inputs( + b1 -> + b1.operand(BatchPhysicalCalc.class) + .anyInputs())) + .withRelBuilderFactory(RelFactories.LOGICAL_BUILDER) + .withDescription("FlinkCalcMergeRule"); + @Override default FlinkCalcMergeRule toRule() { return new FlinkCalcMergeRule(this); diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 7ff8a16a70e4a..ab563a9a737e2 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -417,6 +417,7 @@ object FlinkBatchRuleSets { /** RuleSet to do physical optimize for batch */ val PHYSICAL_OPT_RULES: RuleSet = RuleSets.ofList( + FlinkCalcMergeRule.BATCH_PHYSICAL_INSTANCE, FlinkExpandConversionRule.BATCH_INSTANCE, // source BatchPhysicalBoundedStreamScanRule.INSTANCE, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalCorrelateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalCorrelateRule.scala index 471f40cce18e1..ec92b278c35e2 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalCorrelateRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalCorrelateRule.scala @@ -17,17 +17,24 @@ */ package org.apache.flink.table.planner.plan.rules.physical.batch +import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan} -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCorrelate +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCalc, BatchPhysicalCorrelate} import org.apache.flink.table.planner.plan.utils.PythonUtil -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelTraitSet} import org.apache.calcite.plan.volcano.RelSubset +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule import org.apache.calcite.rel.convert.ConverterRule.Config -import org.apache.calcite.rex.RexNode +import org.apache.calcite.rex.{RexNode, RexProgram, RexUtil} +import org.apache.calcite.sql.validate.SqlValidatorUtil + +import java.util.Collections + +import scala.collection.JavaConverters._ class BatchPhysicalCorrelateRule(config: Config) extends ConverterRule(config) { @@ -51,35 +58,72 @@ class BatchPhysicalCorrelateRule(config: Config) extends ConverterRule(config) { } override def convert(rel: RelNode): RelNode = { - val join = rel.asInstanceOf[FlinkLogicalCorrelate] + val correlate = rel.asInstanceOf[FlinkLogicalCorrelate] + val cluster = correlate.getCluster val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) - val convInput: RelNode = RelOptRule.convert(join.getInput(0), FlinkConventions.BATCH_PHYSICAL) - val right: RelNode = join.getInput(1) + val convInput: RelNode = + RelOptRule.convert(correlate.getInput(0), FlinkConventions.BATCH_PHYSICAL) - def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): BatchPhysicalCorrelate = { + // matches() guarantees the right side is either a TableFunctionScan, or a single Calc + // whose immediate input is a TableFunctionScan. + @scala.annotation.tailrec + def unwrap( + relNode: RelNode): (FlinkLogicalTableFunctionScan, Option[Seq[RexNode]], Option[RexNode]) = relNode match { - case rel: RelSubset => - convertToCorrelate(rel.getRelList.get(0), condition) - + case rel: RelSubset => unwrap(rel.getRelList.get(0)) + case scan: FlinkLogicalTableFunctionScan => (scan, None, None) case calc: FlinkLogicalCalc => - convertToCorrelate( - calc.getInput.asInstanceOf[RelSubset].getOriginal, - if (calc.getProgram.getCondition == null) None - else Some(calc.getProgram.expandLocalRef(calc.getProgram.getCondition)) - ) - - case scan: FlinkLogicalTableFunctionScan => - new BatchPhysicalCorrelate( - rel.getCluster, - traitSet, - convInput, - scan, - condition, - rel.getRowType, - join.getJoinType) + val scan = calc.getInput + .asInstanceOf[RelSubset] + .getOriginal + .asInstanceOf[FlinkLogicalTableFunctionScan] + val program = calc.getProgram + val condition = + if (program.getCondition == null) None + else Some(program.expandLocalRef(program.getCondition)) + val projects = + if (program.projectsOnlyIdentity()) None + else Some(program.getProjectList.asScala.map(program.expandLocalRef).toSeq) + (scan, projects, condition) } + + val (scan, projectsOpt, condition) = unwrap(correlate.getInput(1)) + + projectsOpt match { + case None => + new BatchPhysicalCorrelate( + cluster, + traitSet, + convInput, + scan, + condition, + correlate.getRowType, + correlate.getJoinType) + case Some(projects) => + val innerRowType = SqlValidatorUtil.deriveJoinRowType( + correlate.getLeft.getRowType, + scan.getRowType, + correlate.getJoinType, + cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory], + null, + Collections.emptyList[RelDataTypeField]() + ) + val innerCorrelate = new BatchPhysicalCorrelate( + cluster, + traitSet, + convInput, + scan, + condition, + innerRowType, + correlate.getJoinType) + val outerProgram = BatchPhysicalCorrelateRule.buildOuterProgram( + cluster, + correlate.getLeft.getRowType.getFieldCount, + innerRowType, + correlate.getRowType, + projects) + new BatchPhysicalCalc(cluster, traitSet, innerCorrelate, outerProgram, correlate.getRowType) } - convertToCorrelate(right, None) } } @@ -90,4 +134,26 @@ object BatchPhysicalCorrelateRule { FlinkConventions.LOGICAL, FlinkConventions.BATCH_PHYSICAL, "BatchPhysicalCorrelateRule")) + + /** + * Builds the outer Calc program that sits on top of the inner correlate: passes the left input + * through unchanged, then appends the right-side projections shifted by the left field count. + */ + def buildOuterProgram( + cluster: RelOptCluster, + leftFieldCount: Int, + innerRowType: RelDataType, + outputRowType: RelDataType, + rightProjects: Seq[RexNode]): RexProgram = { + val rexBuilder = cluster.getRexBuilder + val outerProjects = new java.util.ArrayList[RexNode]() + val innerFields = innerRowType.getFieldList + var i = 0 + while (i < leftFieldCount) { + outerProjects.add(rexBuilder.makeInputRef(innerFields.get(i).getType, i)) + i += 1 + } + rightProjects.foreach(p => outerProjects.add(RexUtil.shift(p, leftFieldCount))) + RexProgram.create(innerRowType, outerProjects, null, outputRowType, rexBuilder) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalCorrelateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalCorrelateRule.scala index f08696e5e865b..97ed1a6c10dc1 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalCorrelateRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalCorrelateRule.scala @@ -18,19 +18,26 @@ package org.apache.flink.table.planner.plan.rules.physical.stream import org.apache.flink.table.api.TableException +import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan} -import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCorrelate +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalCalc, StreamPhysicalCorrelate} import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule.{getMergedCalc, getTableScan} -import org.apache.flink.table.planner.plan.utils.{AsyncUtil, PythonUtil} +import org.apache.flink.table.planner.plan.utils.{AsyncUtil, FlinkRelUtil, PythonUtil} -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelTraitSet} import org.apache.calcite.plan.hep.HepRelVertex import org.apache.calcite.plan.volcano.RelSubset +import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule import org.apache.calcite.rel.convert.ConverterRule.Config -import org.apache.calcite.rex.{RexNode, RexProgram, RexProgramBuilder} +import org.apache.calcite.rex.{RexNode, RexProgram, RexUtil} +import org.apache.calcite.sql.validate.SqlValidatorUtil + +import java.util.Collections + +import scala.collection.JavaConverters._ /** Rule that converts [[FlinkLogicalCorrelate]] to [[StreamPhysicalCorrelate]]. */ class StreamPhysicalCorrelateRule(config: Config) extends ConverterRule(config) { @@ -63,40 +70,75 @@ class StreamPhysicalCorrelateRule(config: Config) extends ConverterRule(config) override def convert(rel: RelNode): RelNode = { val correlate = rel.asInstanceOf[FlinkLogicalCorrelate] + val cluster = correlate.getCluster val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL) val convInput: RelNode = RelOptRule.convert(correlate.getInput(0), FlinkConventions.STREAM_PHYSICAL) val right: RelNode = correlate.getInput(1) @scala.annotation.tailrec - def convertToCorrelate( - relNode: RelNode, - condition: Option[RexNode]): StreamPhysicalCorrelate = { + def unwrap(relNode: RelNode) + : (FlinkLogicalTableFunctionScan, Option[Seq[RexNode]], Option[RexNode]) = { relNode match { - case rel: RelSubset => - convertToCorrelate(rel.getRelList.get(0), condition) - + case rel: RelSubset => unwrap(rel.getRelList.get(0)) case calc: FlinkLogicalCalc => val tableScan = getTableScan(calc) val newCalc = getMergedCalc(calc) - convertToCorrelate( - tableScan, - if (newCalc.getProgram.getCondition == null) None - else Some(newCalc.getProgram.expandLocalRef(newCalc.getProgram.getCondition)) - ) - + val program = newCalc.getProgram + val condition = + if (program.getCondition == null) None + else Some(program.expandLocalRef(program.getCondition)) + val projects = + if (program.projectsOnlyIdentity()) None + else Some(program.getProjectList.asScala.map(program.expandLocalRef).toSeq) + (tableScan, projects, condition) case scan: FlinkLogicalTableFunctionScan => - new StreamPhysicalCorrelate( - rel.getCluster, - traitSet, - convInput, - scan, - condition, - rel.getRowType, - correlate.getJoinType) + (scan, None, None) } } - convertToCorrelate(right, None) + + val (scan, projectsOpt, condition) = unwrap(right) + + projectsOpt match { + case None => + new StreamPhysicalCorrelate( + cluster, + traitSet, + convInput, + scan, + condition, + correlate.getRowType, + correlate.getJoinType) + case Some(projects) => + val innerRowType = SqlValidatorUtil.deriveJoinRowType( + correlate.getLeft.getRowType, + scan.getRowType, + correlate.getJoinType, + cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory], + null, + Collections.emptyList[RelDataTypeField]() + ) + val innerCorrelate = new StreamPhysicalCorrelate( + cluster, + traitSet, + convInput, + scan, + condition, + innerRowType, + correlate.getJoinType) + val outerProgram = StreamPhysicalCorrelateRule.buildOuterProgram( + cluster, + correlate.getLeft.getRowType.getFieldCount, + innerRowType, + correlate.getRowType, + projects) + new StreamPhysicalCalc( + cluster, + traitSet, + innerCorrelate, + outerProgram, + correlate.getRowType) + } } } @@ -117,17 +159,7 @@ object StreamPhysicalCorrelateRule { child match { case calc1: FlinkLogicalCalc => val bottomCalc = getMergedCalc(calc1) - val topCalc = calc - val topProgram: RexProgram = topCalc.getProgram - val mergedProgram: RexProgram = RexProgramBuilder - .mergePrograms( - topCalc.getProgram, - bottomCalc.getProgram, - topCalc.getCluster.getRexBuilder) - assert(mergedProgram.getOutputRowType eq topProgram.getOutputRowType) - topCalc - .copy(topCalc.getTraitSet, bottomCalc.getInput, mergedProgram) - .asInstanceOf[FlinkLogicalCalc] + FlinkRelUtil.merge(calc, bottomCalc).asInstanceOf[FlinkLogicalCalc] case _ => calc } @@ -145,4 +177,22 @@ object StreamPhysicalCorrelateRule { case _ => throw new TableException("This must be a bug, could not find table scan") } } + + def buildOuterProgram( + cluster: RelOptCluster, + leftFieldCount: Int, + innerRowType: RelDataType, + outputRowType: RelDataType, + rightProjects: Seq[RexNode]): RexProgram = { + val rexBuilder = cluster.getRexBuilder + val builder = new java.util.ArrayList[RexNode]() + val leftFields = innerRowType.getFieldList + var i = 0 + while (i < leftFieldCount) { + builder.add(rexBuilder.makeInputRef(leftFields.get(i).getType, i)) + i += 1 + } + rightProjects.foreach(p => builder.add(RexUtil.shift(p, leftFieldCount))) + RexProgram.create(innerRowType, builder, null, outputRowType, rexBuilder) + } } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/UnnestTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/UnnestTest.xml index 5e0da95bb0719..b1fbca4397a00 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/UnnestTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/UnnestTest.xml @@ -91,6 +91,30 @@ Calc(select=[a, b, _1, _2], where=[(_1 > a)]) +- Correlate(invocation=[$UNNEST_ROWS$1($cor0.b)], correlate=[table($UNNEST_ROWS$1($cor0.b))], select=[a,b,_1,_2], rowType=[RecordType(INTEGER a, RecordType:peek_no_expand(INTEGER _1, VARCHAR(2147483647) _2) ARRAY b, INTEGER _1, VARCHAR(2147483647) _2)], joinType=[INNER]) +- Calc(select=[a, b], where=[(a < 3)]) +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b]) +]]> + + + + + + + + + + + @@ -688,8 +712,8 @@ LogicalProject(a=[$0], number=[$1], ordinality=[$2]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnnestTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnnestTest.xml index 138d56d8c12a7..f5b25d0a85e65 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnnestTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnnestTest.xml @@ -91,6 +91,30 @@ Calc(select=[a, b, _1, _2], where=[(_1 > a)]) +- Correlate(invocation=[$UNNEST_ROWS$1($cor0.b)], correlate=[table($UNNEST_ROWS$1($cor0.b))], select=[a,b,_1,_2], rowType=[RecordType(INTEGER a, RecordType:peek_no_expand(INTEGER _1, VARCHAR(2147483647) _2) ARRAY b, INTEGER _1, VARCHAR(2147483647) _2)], joinType=[INNER]) +- Calc(select=[a, b], where=[(a < 3)]) +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b]) +]]> + + + + + + + + + + + @@ -700,8 +724,8 @@ LogicalProject(a=[$0], number=[$1], ordinality=[$2]) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/common/UnnestTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/common/UnnestTestBase.scala index 2634b0dc30164..884c912702bb5 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/common/UnnestTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/common/UnnestTestBase.scala @@ -302,6 +302,14 @@ abstract class UnnestTestBase(withExecPlan: Boolean) extends TableTestBase { "ON v.bd_name <> 'debug'") } + @Test + def testLateralProjectionFromUnnest(): Unit = { + util.addTableSource[(Int, Array[Int])]("MyTable", 'a, 'b) + util.verifyRelPlan( + "SELECT a, doubled FROM MyTable, " + + "LATERAL (SELECT s * 2 FROM UNNEST(MyTable.b) AS T(s)) AS R(doubled)") + } + def verifyPlan(sql: String): Unit = { if (withExecPlan) { util.verifyExecPlan(sql) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CorrelateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CorrelateITCase.scala index 99a11bfe4d576..213eb3c797ddd 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CorrelateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CorrelateITCase.scala @@ -434,6 +434,30 @@ class CorrelateITCase extends StreamingTestBase { .sorted).isEqualTo(expected.sorted) } + @Test + def testCorrelateWithRightSideProjection(): Unit = { + val data = List((1, 2, "a|b"), (3, 4, "c|d")) + + val t1 = StreamingEnvUtil + .fromCollection(env, data) + .toTable(tEnv, 'a, 'b, 'c) + tEnv.createTemporaryView("T1", t1) + + val sql = + "SELECT a, s FROM T1, " + + "LATERAL (SELECT CONCAT(v, '_x') FROM TABLE(STRING_SPLIT(c, '|')) AS T(v)) AS R(s)" + + val result = tEnv.sqlQuery(sql) + TestSinkUtil.addValuesSink(tEnv, "MySink", result, ChangelogMode.insertOnly()) + result.executeInsert("MySink").await() + + val expected = List("+I[1, a_x]", "+I[1, b_x]", "+I[3, c_x]", "+I[3, d_x]") + assertThat( + TestValuesTableFactory + .getResultsAsStrings("MySink") + .sorted).isEqualTo(expected.sorted) + } + @Test def testLateralCrossJoin(): Unit = { val data = List((1, 2, "x|y"))