Skip to content

Commit

Permalink
[FLINK-20887][table-planner] Step2: fix incorrect calc merge via rele…
Browse files Browse the repository at this point in the history
…ated rules
  • Loading branch information
lincoln-lil committed Jun 20, 2023
1 parent b199cc0 commit 590a207
Show file tree
Hide file tree
Showing 15 changed files with 258 additions and 75 deletions.
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.planner.plan.utils.InputRefVisitor;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.logical.LogicalCalc;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.rules.ProjectCalcMergeRule;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Flink-specific variant of Calcite's {@link ProjectCalcMergeRule}.
 *
 * <p>The merge is skipped whenever the top project references an output field of the underlying
 * calc whose expression is non-deterministic. In every other case the decision is delegated to the
 * Calcite implementation.
 */
public class FlinkProjectCalcMergeRule extends ProjectCalcMergeRule {

    public static final RelOptRule INSTANCE = new FlinkProjectCalcMergeRule(Config.DEFAULT);

    protected FlinkProjectCalcMergeRule(Config config) {
        super(config);
    }

    /**
     * Applies the inherited merge only if none of the calc output fields referenced by the
     * project (rel 0) is produced by a non-deterministic expression in the calc (rel 1).
     *
     * @param call the rule call holding the matched project and calc
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        LogicalProject topProject = call.rel(0);
        LogicalCalc bottomCalc = call.rel(1);

        // Gather the input fields that the top project's expressions refer to.
        InputRefVisitor refCollector = new InputRefVisitor();
        for (RexNode projectExpr : topProject.getProjects()) {
            projectExpr.accept(refCollector);
        }

        // Fully expanded output expressions of the calc; looked up by referenced field index.
        List<RexNode> calcOutputs =
                bottomCalc.getProgram().getProjectList().stream()
                        .map(localRef -> bottomCalc.getProgram().expandLocalRef(localRef))
                        .collect(Collectors.toList());

        boolean onlyDeterministicRefs =
                Arrays.stream(refCollector.getFields())
                        .allMatch(idx -> RexUtil.isDeterministic(calcOutputs.get(idx)));
        if (!onlyDeterministicRefs) {
            // A referenced field is non-deterministic: leave the plan untouched.
            return;
        }

        super.onMatch(call);
    }
}
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.ProjectMergeRule;

/**
 * Flink-specific variant of Calcite's {@link ProjectMergeRule}.
 *
 * <p>Two stacked projects are only merged when {@link FlinkRexUtil#isMergeable} accepts the pair;
 * otherwise the rule does nothing and both projects are kept.
 */
public class FlinkProjectMergeRule extends ProjectMergeRule {

    public static final RelOptRule INSTANCE = new FlinkProjectMergeRule(Config.DEFAULT);

    protected FlinkProjectMergeRule(Config config) {
        super(config);
    }

    /**
     * Delegates to the inherited merge logic if, and only if, the matched top project (rel 0) and
     * bottom project (rel 1) are reported as mergeable by {@code FlinkRexUtil}.
     *
     * @param call the rule call holding the two matched projects
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Project upper = call.rel(0);
        final Project lower = call.rel(1);
        if (!FlinkRexUtil.isMergeable(upper, lower)) {
            // Not safe to merge: keep both projects as they are.
            return;
        }
        super.onMatch(call);
    }
}
Expand Up @@ -196,7 +196,7 @@ object FlinkBatchRuleSets {
// push a projection to the children of a semi/anti Join
ProjectSemiAntiJoinTransposeRule.INSTANCE,
// merge projections
CoreRules.PROJECT_MERGE,
FlinkProjectMergeRule.INSTANCE,
// remove identity project
CoreRules.PROJECT_REMOVE,
// removes constant keys from an Agg
Expand Down Expand Up @@ -288,7 +288,7 @@ object FlinkBatchRuleSets {

// calc rules
FlinkFilterCalcMergeRule.INSTANCE,
CoreRules.PROJECT_CALC_MERGE,
FlinkProjectCalcMergeRule.INSTANCE,
CoreRules.FILTER_TO_CALC,
CoreRules.PROJECT_TO_CALC,
FlinkCalcMergeRule.INSTANCE,
Expand Down
Expand Up @@ -106,7 +106,7 @@ object FlinkStreamRuleSets {
// fix: FLINK-17553 unsupported call error when constant exists in group window key
// this rule will merge the project generated by AggregateProjectPullUpConstantsRule and
// make sure window aggregate can be correctly rewritten by StreamLogicalWindowAggregateRule
CoreRules.PROJECT_MERGE,
FlinkProjectMergeRule.INSTANCE,
StreamLogicalWindowAggregateRule.INSTANCE,
// slices a project into sections which contain window agg functions
// and sections which do not.
Expand Down Expand Up @@ -200,7 +200,7 @@ object FlinkStreamRuleSets {
// push a projection to the children of a semi/anti Join
ProjectSemiAntiJoinTransposeRule.INSTANCE,
// merge projections
CoreRules.PROJECT_MERGE,
FlinkProjectMergeRule.INSTANCE,
// remove identity project
CoreRules.PROJECT_REMOVE,
// removes constant keys from an Agg
Expand Down Expand Up @@ -281,7 +281,7 @@ object FlinkStreamRuleSets {

// calc rules
FlinkFilterCalcMergeRule.INSTANCE,
CoreRules.PROJECT_CALC_MERGE,
FlinkProjectCalcMergeRule.INSTANCE,
CoreRules.FILTER_TO_CALC,
CoreRules.PROJECT_TO_CALC,
FlinkCalcMergeRule.INSTANCE,
Expand Down
Expand Up @@ -22,11 +22,8 @@ import org.apache.flink.table.planner.plan.utils.FlinkRexUtil

import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.RelOptUtil.InputFinder
import org.apache.calcite.rel.core.{Calc, RelFactories}
import org.apache.calcite.rex.{RexNode, RexOver, RexProgramBuilder, RexUtil}

import scala.collection.JavaConversions._
import org.apache.calcite.rex.{RexNode, RexOver, RexProgramBuilder}

/**
* This rule is copied from Calcite's [[org.apache.calcite.rel.rules.CalcMergeRule]].
Expand Down Expand Up @@ -61,49 +58,7 @@ class FlinkCalcMergeRule[C <: Calc](calcClass: Class[C])
return false
}

isMergeable(topCalc, bottomCalc)
}

/**
* Return two neighbouring [[Calc]] can merge into one [[Calc]] or not. If the two [[Calc]] can
* merge into one, each non-deterministic [[RexNode]] of bottom [[Calc]] should appear at most
* once in the project list and filter list of top [[Calc]].
*/
private def isMergeable(topCalc: Calc, bottomCalc: Calc): Boolean = {
val topProgram = topCalc.getProgram
val bottomProgram = bottomCalc.getProgram

val topProjectInputIndices = topProgram.getProjectList
.map(r => topProgram.expandLocalRef(r))
.map(r => InputFinder.bits(r).toArray)

val topFilterInputIndices = if (topProgram.getCondition != null) {
InputFinder.bits(topProgram.expandLocalRef(topProgram.getCondition)).toArray
} else {
new Array[Int](0)
}

val bottomProjectList = bottomProgram.getProjectList
.map(r => bottomProgram.expandLocalRef(r))
.toArray

val topInputIndices = topProjectInputIndices :+ topFilterInputIndices

bottomProjectList.zipWithIndex.forall {
case (project: RexNode, index: Int) => {
var nonDeterministicRexRefCnt = 0
if (!RexUtil.isDeterministic(project)) {
topInputIndices.foreach(
indices =>
indices.foreach(
ref =>
if (ref == index) {
nonDeterministicRexRefCnt += 1
}))
}
nonDeterministicRexRefCnt <= 1
}
}
FlinkRexUtil.isMergeable(topCalc, bottomCalc)
}

override def onMatch(call: RelOptRuleCall): Unit = {
Expand Down
Expand Up @@ -20,6 +20,7 @@ package org.apache.flink.table.planner.plan.utils
import org.apache.flink.annotation.Experimental
import org.apache.flink.configuration.ConfigOption
import org.apache.flink.configuration.ConfigOptions.key
import org.apache.flink.table.planner.JList
import org.apache.flink.table.planner.functions.sql.SqlTryCastFunction
import org.apache.flink.table.planner.plan.utils.ExpressionDetail.ExpressionDetail
import org.apache.flink.table.planner.plan.utils.ExpressionFormat.ExpressionFormat
Expand All @@ -31,6 +32,7 @@ import org.apache.calcite.avatica.util.ByteString
import org.apache.calcite.plan.{RelOptPredicateList, RelOptUtil}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.{Calc, Project}
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.{SqlAsOperator, SqlKind, SqlOperator}
Expand Down Expand Up @@ -312,6 +314,7 @@ object FlinkRexUtil {

/**
* Find all inputRefs.
*
* @return
* InputRef HashSet.
*/
Expand Down Expand Up @@ -612,6 +615,7 @@ object FlinkRexUtil {

/**
* Returns whether a given [[RexProgram]] is deterministic.
*
* @return
* false if any expression of the program is not deterministic
*/
Expand Down Expand Up @@ -650,6 +654,75 @@ object FlinkRexUtil {
rexBuilder,
converter);
}

/**
 * Tells whether two adjacent [[Project]] nodes may be collapsed into a single [[Project]]. The
 * merge is allowed only if every non-deterministic [[RexNode]] produced by the bottom [[Project]]
 * is referenced at most once by the project list of the top [[Project]].
 */
def isMergeable(topProject: Project, bottomProject: Project): Boolean = {
  // One zero-initialised counter slot per field of the top project's input row.
  val inputFieldCount = topProject.getInput.getRowType.getFieldCount
  val refCounts = new Array[Int](inputFieldCount)

  mergeable(refCounts, topProject.getProjects, bottomProject.getProjects)
}

/**
 * A visitor that tallies how often each input field is referenced. Every single occurrence is
 * counted, e.g., '$0 + 1' & '$0 + 2' yield a count of 2 for field 0 rather than 1.
 *
 * @param deep
 *   forwarded unchanged to the [[RexVisitorImpl]] constructor
 * @param refCounts
 *   array indexed by input field index; the slot of a field is incremented in place each time a
 *   reference to that field is visited
 */
private class InputRefCounter(deep: Boolean, val refCounts: Array[Int])
  extends RexVisitorImpl[Void](deep) {
  override def visitInputRef(inputRef: RexInputRef): Void = {
    refCounts(inputRef.getIndex) += 1
    null
  }
}

/**
 * Shared merge check: counts the input references made by `topProjects` into
 * `topInputRefCounter`, then verifies that no non-deterministic expression of `bottomProjects`
 * is referenced more than once.
 */
private def mergeable(
    topInputRefCounter: Array[Int],
    topProjects: JList[RexNode],
    bottomProjects: JList[RexNode]): Boolean = {
  // Accumulate the per-field reference counts of the top expressions into the given array.
  RexUtil.apply(new InputRefCounter(true, topInputRefCounter), topProjects, null)

  // Reference counts of the bottom fields whose expression is non-deterministic;
  // deterministic fields can never block the merge and are left out.
  val nonDeterministicRefCounts = bottomProjects.zipWithIndex.collect {
    case (expr, fieldIdx) if !RexUtil.isDeterministic(expr) => topInputRefCounter(fieldIdx)
  }

  !nonDeterministicRefCounts.exists(cnt => cnt > 1)
}

/**
 * Tells whether two adjacent [[Calc]] nodes may be collapsed into a single [[Calc]]. The merge is
 * allowed only if every non-deterministic [[RexNode]] produced by the bottom [[Calc]] is
 * referenced at most once across the project list and the filter condition of the top [[Calc]].
 */
def isMergeable(topCalc: Calc, bottomCalc: Calc): Boolean = {
  val upperProgram = topCalc.getProgram
  val lowerProgram = bottomCalc.getProgram

  // Expanded top-level expressions: all projections, plus the condition when one is present.
  val upperProjects = upperProgram.getProjectList.map(upperProgram.expandLocalRef)
  val upperCondition = upperProgram.getCondition
  val upperExprs =
    if (upperCondition == null) {
      upperProjects
    } else {
      upperProjects :+ upperProgram.expandLocalRef(upperCondition)
    }

  val lowerProjects = lowerProgram.getProjectList.map(lowerProgram.expandLocalRef)

  // One zero-initialised counter slot per field of the top calc's input row.
  val refCounts = new Array[Int](topCalc.getInput.getRowType.getFieldCount)

  mergeable(refCounts, upperExprs.toList, lowerProjects.toList)
}
}

/**
Expand Down
Expand Up @@ -360,16 +360,16 @@ public void testProjectionIncludingOnlyMetadata() {

private void replaceProgramWithProjectMergeRule() {
FlinkChainedProgram programs = new FlinkChainedProgram<BatchOptimizeContext>();
programs
.addLast(
"rules",
FlinkHepRuleSetProgramBuilder.<BatchOptimizeContext>newBuilder()
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE())
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
.add(RuleSets.ofList(
programs.addLast(
"rules",
FlinkHepRuleSetProgramBuilder.<BatchOptimizeContext>newBuilder()
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE())
.setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
.add(
RuleSets.ofList(
CoreRules.PROJECT_MERGE,
PushProjectIntoTableSourceScanRule.INSTANCE))
.build());
.build());
util().replaceBatchProgram(programs);
}

Expand Down
Expand Up @@ -30,8 +30,9 @@ LogicalProject(a2=[random_udf($0)], a3=[random_udf($0)])
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[random_udf(random_udf(a)) AS a2, random_udf(random_udf(a)) AS a3], where=[(b > 10)])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable, filter=[], project=[a, b], metadata=[]]], fields=[a, b])
Calc(select=[random_udf(a1) AS a2, random_udf(a1) AS a3])
+- Calc(select=[random_udf(a) AS a1, b], where=[(b > 10)])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable, filter=[], project=[a, b], metadata=[]]], fields=[a, b])
]]>
</Resource>
</TestCase>
Expand Down

0 comments on commit 590a207

Please sign in to comment.