[FLINK-34702][planner] When converting Rank to Deduplicate, it is necessary to consider whether the upstream produces changelog
liuyongvs committed Mar 20, 2024
1 parent 841f23c commit 77f13a9
Showing 9 changed files with 210 additions and 109 deletions.
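The crux of the fix: a Rank node may only be rewritten into a Deduplicate when its input is insert-only. An upstream operator such as a group aggregate emits updates, and the deduplicate operator cannot consume such a changelog, so the conversion is moved out of the physical conversion rules into a dedicated rewrite phase that runs after changelog mode inference. A minimal sketch of the resulting gate, mirroring the matches() method of the new rule below (canRewriteToDeduplicate is an illustrative name, not part of the commit):

// Both predicates must hold before a Rank is rewritten to a Deduplicate;
// the insert-only check on the input is the guard this commit introduces.
def canRewriteToDeduplicate(rank: StreamPhysicalRank): Boolean =
  ChangelogPlanUtils.inputInsertOnly(rank) && RankUtil.canConvertToDeduplicate(rank)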
New file: StreamPhysicalDeduplicateRule.java
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.rules.physical.stream;

import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDeduplicate;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRank;
import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils;
import org.apache.flink.table.planner.plan.utils.RankUtil;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.type.RelDataType;
import org.immutables.value.Value;

/**
 * Rule that matches a {@link StreamPhysicalRank} that is sorted on a time attribute, limits the
 * rank to 1, has rank type ROW_NUMBER, and whose input doesn't produce a changelog, and converts
 * it to a {@link StreamPhysicalDeduplicate}.
 *
 * <p>NOTES: Queries that can be converted to {@link StreamPhysicalDeduplicate} could also be
 * converted to {@link StreamPhysicalRank}. {@link StreamPhysicalDeduplicate} is more efficient
 * than {@link StreamPhysicalRank} due to mini-batch support and less state access.
 *
 * <p>e.g. 1. {@code SELECT a, b, c FROM ( SELECT a, b, c, proctime, ROW_NUMBER() OVER (PARTITION
 * BY a ORDER BY proctime ASC) as row_num FROM MyTable ) WHERE row_num <= 1 } will be converted to
 * StreamExecDeduplicate, which keeps the first row in proctime.
 *
 * <p>2. {@code SELECT a, b, c FROM ( SELECT a, b, c, rowtime, ROW_NUMBER() OVER (PARTITION BY a
 * ORDER BY rowtime DESC) as row_num FROM MyTable ) WHERE row_num <= 1 } will be converted to
 * StreamExecDeduplicate, which keeps the last row in rowtime.
 */
@Value.Enclosing
public class StreamPhysicalDeduplicateRule
        extends RelRule<StreamPhysicalDeduplicateRule.StreamPhysicalDeduplicateRuleConfig> {

    public static final StreamPhysicalDeduplicateRule INSTANCE =
            StreamPhysicalDeduplicateRule.StreamPhysicalDeduplicateRuleConfig.DEFAULT.toRule();

    private StreamPhysicalDeduplicateRule(StreamPhysicalDeduplicateRuleConfig config) {
        super(config);
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        StreamPhysicalRank rank = call.rel(0);
        return ChangelogPlanUtils.inputInsertOnly(rank) && RankUtil.canConvertToDeduplicate(rank);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        StreamPhysicalRank rank = call.rel(0);

        // order by timeIndicator desc ==> lastRow, otherwise firstRow
        RelFieldCollation fieldCollation = rank.orderKey().getFieldCollations().get(0);
        boolean isLastRow = fieldCollation.direction.isDescending();

        RelDataType fieldType =
                rank.getInput()
                        .getRowType()
                        .getFieldList()
                        .get(fieldCollation.getFieldIndex())
                        .getType();
        boolean isRowtime = FlinkTypeFactory.isRowtimeIndicatorType(fieldType);

        StreamPhysicalDeduplicate deduplicate =
                new StreamPhysicalDeduplicate(
                        rank.getCluster(),
                        rank.getTraitSet(),
                        rank.getInput(),
                        rank.partitionKey().toArray(),
                        isRowtime,
                        isLastRow);
        call.transformTo(deduplicate);
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface StreamPhysicalDeduplicateRuleConfig extends RelRule.Config {
        StreamPhysicalDeduplicateRule.StreamPhysicalDeduplicateRuleConfig DEFAULT =
                ImmutableStreamPhysicalDeduplicateRule.StreamPhysicalDeduplicateRuleConfig.builder()
                        .build()
                        .withOperandSupplier(b0 -> b0.operand(StreamPhysicalRank.class).anyInputs())
                        .withDescription("StreamPhysicalDeduplicateRule");

        @Override
        default StreamPhysicalDeduplicateRule toRule() {
            return new StreamPhysicalDeduplicateRule(this);
        }
    }
}
@@ -296,6 +296,14 @@ object FlinkStreamProgram {
          "watermark transpose"
        )
        .addProgram(new FlinkChangelogModeInferenceProgram, "Changelog mode inference")
        .addProgram(
          FlinkHepRuleSetProgramBuilder.newBuilder
            .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_COLLECTION)
            .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
            .add(FlinkStreamRuleSets.PHYSICAL_DEDUPLICATE_REWRITE)
            .build(),
          "physical rewrite rank with deduplicate"
        )
        .addProgram(
          new FlinkMiniBatchIntervalTraitInitProgram,
          "Initialization for mini-batch interval inference")
@@ -441,7 +441,6 @@ object FlinkStreamRuleSets {
    StreamPhysicalTemporalSortRule.INSTANCE,
    // rank
    StreamPhysicalRankRule.INSTANCE,
    StreamPhysicalDeduplicateRule.INSTANCE,
    // expand
    StreamPhysicalExpandRule.INSTANCE,
    // group agg
@@ -487,6 +486,15 @@
    WatermarkAssignerChangelogNormalizeTransposeRule.WITHOUT_CALC
  )

  /**
   * RuleSet to rewrite rank to deduplicate. It must be applied after changelog mode inference and
   * before the mini-batch interval rules.
   */
  val PHYSICAL_DEDUPLICATE_REWRITE: RuleSet = RuleSets.ofList(
    // rewrite rank to deduplicate
    StreamPhysicalDeduplicateRule.INSTANCE
  )

  /** RuleSet related to mini-batch. */
  val MINI_BATCH_RULES: RuleSet = RuleSets.ofList(
    // mini-batch interval infer rule

This file was deleted.

@@ -34,11 +34,6 @@ import org.apache.calcite.rel.convert.ConverterRule.Config
 */
class StreamPhysicalRankRule(config: Config) extends ConverterRule(config) {

  override def matches(call: RelOptRuleCall): Boolean = {
    val rank: FlinkLogicalRank = call.rel(0)
    !RankUtil.canConvertToDeduplicate(rank)
  }

  override def convert(rel: RelNode): RelNode = {
    val rank = rel.asInstanceOf[FlinkLogicalRank]
    val input = rank.getInput
@@ -23,6 +23,7 @@ import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalR
import org.apache.flink.table.planner.plan.optimize.program.FlinkChangelogModeInferenceProgram
import org.apache.flink.types.RowKind

import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.rel.RelNode

import scala.collection.JavaConversions._
@@ -46,7 +47,11 @@ object ChangelogPlanUtils {
* <p>Note: this method must be called after [[FlinkChangelogModeInferenceProgram]] is applied.
*/
  def inputInsertOnly(node: StreamPhysicalRel): Boolean = {
    node.getInputs.forall {
      case input: StreamPhysicalRel => isInsertOnly(input)
      // inside a HEP program the inputs are wrapped in HepRelVertex,
      // so unwrap the current rel before checking its changelog mode
      case hepRelVertex: HepRelVertex =>
        isInsertOnly(hepRelVertex.getCurrentRel.asInstanceOf[StreamPhysicalRel])
    }
  }
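Note: this method was previously only invoked on finalized physical plans, where every input is a StreamPhysicalRel. The new rewrite rule invokes it from inside a HEP program, where Calcite wraps each input in a HepRelVertex, hence the added case.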

/**
@@ -22,17 +22,15 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.ExpressionReducer
import org.apache.flink.table.planner.plan.nodes.calcite.Rank
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRank
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalDeduplicate
import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalDeduplicate, StreamPhysicalRank}
import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, ConstantRankRangeWithoutEnd, RankRange, RankType, VariableRankRange}

import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.RelCollation
import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexLiteral, RexNode, RexUtil}
import org.apache.calcite.sql.SqlKind

import java.util

import scala.collection.JavaConversions._

/** Util for [[Rank]]s. */
@@ -357,6 +355,35 @@ object RankUtil {
    !rank.outputRankNumber && isLimit1 && isSortOnTimeAttribute && isRowNumberType
  }

  /**
   * Whether the given rank could be converted to [[StreamPhysicalDeduplicate]].
   *
   * Returns true if the given rank is sorted on a time attribute, limits the rank to 1, and its
   * rank function is ROW_NUMBER; false otherwise.
   *
   * @param rank
   *   The [[StreamPhysicalRank]] node
   * @return
   *   True if the input rank could be converted to [[StreamPhysicalDeduplicate]]
   */
  def canConvertToDeduplicate(rank: StreamPhysicalRank): Boolean = {
    val sortCollation = rank.orderKey
    val rankRange = rank.rankRange

    val isRowNumberType = rank.rankType == RankType.ROW_NUMBER

    val isLimit1 = rankRange match {
      case rankRange: ConstantRankRange =>
        rankRange.getRankStart == 1 && rankRange.getRankEnd == 1
      case _ => false
    }

    val inputRowType = rank.getInput.getRowType
    val isSortOnTimeAttribute = sortOnTimeAttribute(sortCollation, inputRowType)

    !rank.outputRankNumber && isLimit1 && isSortOnTimeAttribute && isRowNumberType
  }
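Note: RankUtil now carries two variants of canConvertToDeduplicate: the pre-existing one for FlinkLogicalRank, whose closing lines appear in the hunk context above, and this new overload for StreamPhysicalRank, which the new StreamPhysicalDeduplicateRule calls from its matches() method.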

  private def sortOnTimeAttribute(
      sortCollation: RelCollation,
      inputRowType: RelDataType): Boolean = {
@@ -360,6 +360,41 @@ Calc(select=[a, b, c, PROCTIME_MATERIALIZE(proctime) AS proctime, rowtime, 1 AS
+- Deduplicate(keep=[LastRow], key=[a], order=[PROCTIME])
   +- Exchange(distribution=[hash[a]])
      +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
    </Resource>
  </TestCase>
  <TestCase name="testRankConsumeChangelog">
    <Resource name="sql">
      <![CDATA[
SELECT *
FROM (
  SELECT *,
      ROW_NUMBER() OVER (PARTITION BY a ORDER BY PROCTIME() ASC) as rowNum
  FROM (SELECT a, count(b) as b FROM MyTable GROUP BY a)
)
WHERE rowNum = 1
]]>
    </Resource>
    <Resource name="ast">
      <![CDATA[
LogicalProject(a=[$0], b=[$1], rowNum=[$2])
+- LogicalFilter(condition=[=($2, 1)])
   +- LogicalProject(a=[$0], b=[$1], rowNum=[ROW_NUMBER() OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST)])
      +- LogicalAggregate(group=[{0}], b=[COUNT($1)])
         +- LogicalProject(a=[$0], b=[$1])
            +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
    </Resource>
    <Resource name="optimized exec plan">
      <![CDATA[
Calc(select=[a, b, 1 AS $2])
+- Rank(strategy=[RetractStrategy], rankType=[ROW_NUMBER], rankRange=[rankStart=1, rankEnd=1], partitionBy=[a], orderBy=[$2 ASC], select=[a, b, $2])
   +- Exchange(distribution=[hash[a]])
      +- Calc(select=[a, b, PROCTIME() AS $2])
         +- GroupAggregate(groupBy=[a], select=[a, COUNT(b) AS b])
            +- Exchange(distribution=[hash[a]])
               +- Calc(select=[a, b])
                  +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
    </Resource>
  </TestCase>
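This optimized plan is the point of the new test: the GroupAggregate upstream produces an update stream, so the planner now keeps the Rank (with RetractStrategy) instead of incorrectly rewriting it into a Deduplicate.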
@@ -272,4 +272,18 @@ class DeduplicateTest extends TableTestBase {
    util.verifyExecPlan(sqlQuery)
  }

  @Test
  def testRankConsumeChangelog(): Unit = {
    val sqlQuery =
      """
        |SELECT *
        |FROM (
        |  SELECT *,
        |      ROW_NUMBER() OVER (PARTITION BY a ORDER BY PROCTIME() ASC) as rowNum
        |  FROM (SELECT a, count(b) as b FROM MyTable GROUP BY a)
        |)
        |WHERE rowNum = 1
      """.stripMargin
    util.verifyExecPlan(sqlQuery)
  }
}
