diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index e1036af92088c5..8d58732eef9c94 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -483,6 +483,14 @@ public Map> getConsumerIdToFilters() { return this.statementContext.getConsumerIdToFilters(); } + public void putConsumerIdToLimitRows(RelationId id, long rows) { + this.statementContext.getConsumerIdToLimitRows().merge(id, rows, Math::max); + } + + public Map getConsumerIdToLimitRows() { + return this.statementContext.getConsumerIdToLimitRows(); + } + public void addCTEConsumerGroup(CTEId cteId, Group g, Multimap producerSlotToConsumerSlot) { List, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().computeIfAbsent(cteId, k -> new ArrayList<>()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 2af64752cc47aa..f7d6d0c44e3e42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -168,6 +168,7 @@ public enum TableFrom { private final Map> cteIdToProducer = new HashMap<>(); private final Map> consumerIdToFilters = new HashMap<>(); + private final Map consumerIdToLimitRows = new HashMap<>(); // Used to update consumer's stats private final Map, Group>>> cteIdToConsumerGroup = new HashMap<>(); private final Map rewrittenCteProducer = new HashMap<>(); @@ -643,6 +644,10 @@ public Map> getConsumerIdToFilters() { return consumerIdToFilters; } + public Map getConsumerIdToLimitRows() { + return consumerIdToLimitRows; + } + public PlaceholderId getNextPlaceholderId() { return placeHolderIdGenerator.getNextId(); } @@ -673,6 +678,7 @@ public void clearCteEnvironment() { cteIdToOutputIds.clear(); cteIdToProducer.clear(); consumerIdToFilters.clear(); + consumerIdToLimitRows.clear(); cteIdToConsumerGroup.clear(); rewrittenCteProducer.clear(); rewrittenCteConsumer.clear(); @@ -687,6 +693,7 @@ public CteEnvironmentSnapshot cacheCteEnvironment() { copyMapOfSets(cteIdToOutputIds), new HashMap<>(cteIdToProducer), copyMapOfSets(consumerIdToFilters), + new HashMap<>(consumerIdToLimitRows), copyMapOfLists(cteIdToConsumerGroup), new HashMap<>(rewrittenCteProducer), new HashMap<>(rewrittenCteConsumer)); @@ -706,6 +713,9 @@ public void restoreCteEnvironment(CteEnvironmentSnapshot snapshot) { consumerIdToFilters.clear(); consumerIdToFilters.putAll(snapshot.consumerIdToFilters); + consumerIdToLimitRows.clear(); + consumerIdToLimitRows.putAll(snapshot.consumerIdToLimitRows); + cteIdToConsumerGroup.clear(); cteIdToConsumerGroup.putAll(snapshot.cteIdToConsumerGroup); @@ -738,6 +748,7 @@ public static class CteEnvironmentSnapshot { private final Map> cteIdToOutputIds; private final Map> cteIdToProducer; private final Map> consumerIdToFilters; + private final Map consumerIdToLimitRows; private final Map, Group>>> cteIdToConsumerGroup; private final Map rewrittenCteProducer; private final Map rewrittenCteConsumer; @@ -750,6 +761,7 @@ public CteEnvironmentSnapshot( Map> cteIdToOutputIds, Map> cteIdToProducer, Map> consumerIdToFilters, + Map consumerIdToLimitRows, Map, Group>>> cteIdToConsumerGroup, Map rewrittenCteProducer, Map rewrittenCteConsumer) { @@ -757,6 +769,7 @@ public CteEnvironmentSnapshot( this.cteIdToOutputIds = cteIdToOutputIds; this.cteIdToProducer = cteIdToProducer; this.consumerIdToFilters = consumerIdToFilters; + this.consumerIdToLimitRows = consumerIdToLimitRows; this.cteIdToConsumerGroup = cteIdToConsumerGroup; this.rewrittenCteProducer = rewrittenCteProducer; this.rewrittenCteConsumer = rewrittenCteConsumer; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index e118b21b7e969a..f82ff9eeab491a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -52,6 +52,7 @@ import org.apache.doris.nereids.rules.rewrite.ClearContextStatus; import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput; import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer; +import org.apache.doris.nereids.rules.rewrite.CollectLimitAboveConsumer; import org.apache.doris.nereids.rules.rewrite.CollectPredicateOnScan; import org.apache.doris.nereids.rules.rewrite.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.ConstantPropagation; @@ -395,6 +396,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Push project and filter on cte consumer to cte producer", topDown( new CollectFilterAboveConsumer(), + new CollectLimitAboveConsumer(), new CollectCteConsumerOutput()) ), topic("eliminate join according unique or foreign key", @@ -775,6 +777,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Push project and filter on cte consumer to cte producer", topDown( new CollectFilterAboveConsumer(), + new CollectLimitAboveConsumer(), new CollectCteConsumerOutput() ) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 71e7514025e7db..26d40c59d1a6ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -400,6 +400,7 @@ public enum RuleType { CTE_INLINE(RuleTypeClass.REWRITE), REWRITE_CTE_CHILDREN(RuleTypeClass.REWRITE), COLLECT_FILTER_ABOVE_CTE_CONSUMER(RuleTypeClass.REWRITE), + COLLECT_LIMIT_ABOVE_CTE_CONSUMER(RuleTypeClass.REWRITE), INLINE_VIEW(RuleTypeClass.REWRITE), CHECK_PRIVILEGES(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CollectLimitAboveConsumer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CollectLimitAboveConsumer.java new file mode 100644 index 00000000000000..00b342f4a14989 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CollectLimitAboveConsumer.java @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * Collect limit rows needed by CTE consumers. + */ +public class CollectLimitAboveConsumer implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of( + logicalLimit(logicalCTEConsumer()).thenApply(ctx -> { + LogicalLimit limit = ctx.root; + collectLimitRows(ctx.cascadesContext, limit, limit.child()); + return ctx.root; + }).toRule(RuleType.COLLECT_LIMIT_ABOVE_CTE_CONSUMER), + logicalLimit(logicalProject(logicalCTEConsumer())) + .when(limit -> isRowPreservingProject(limit.child())) + .thenApply(ctx -> { + LogicalLimit> limit = ctx.root; + collectLimitRows(ctx.cascadesContext, limit, limit.child().child()); + return ctx.root; + }).toRule(RuleType.COLLECT_LIMIT_ABOVE_CTE_CONSUMER) + ); + } + + private void collectLimitRows(CascadesContext cascadesContext, LogicalLimit limit, + LogicalCTEConsumer cteConsumer) { + cascadesContext.putConsumerIdToLimitRows( + cteConsumer.getRelationId(), limit.getLimit() + limit.getOffset()); + } + + private boolean isRowPreservingProject(LogicalProject project) { + if (project.isDistinct()) { + return false; + } + for (NamedExpression expression : project.getProjects()) { + if (expression.containsType(TableGeneratingFunction.class)) { + return false; + } + } + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java index 29c83e24bf5e07..c957a9e853f68d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildren.java @@ -28,12 +28,14 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.LimitPhase; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; @@ -124,6 +126,7 @@ public Plan visitLogicalCTEProducer(LogicalCTEProducer cteProduc } else { child = (LogicalPlan) cteProducer.child(); child = tryToConstructFilter(cascadesContext, cteProducer.getCteId(), child); + child = tryToConstructLimit(cascadesContext, cteProducer.getCteId(), child); Set producerOutputs = cascadesContext.getStatementContext() .getCteIdToOutputIds().get(cteProducer.getCteId()); if (producerOutputs != null && producerOutputs.size() < child.getOutput().size()) { @@ -162,6 +165,19 @@ private LogicalPlan pushPlanUnderAnchor(LogicalPlan plan) { return plan; } + private LogicalPlan tryToConstructLimit(CascadesContext cascadesContext, CTEId cteId, LogicalPlan child) { + Set consumers = cascadesContext.getCteIdToConsumers().get(cteId); + long limit = 0; + for (LogicalCTEConsumer consumer : consumers) { + Long rowsNeeded = cascadesContext.getConsumerIdToLimitRows().get(consumer.getRelationId()); + if (rowsNeeded == null) { + return child; + } + limit = Math.max(limit, rowsNeeded); + } + return pushPlanUnderAnchor(new LogicalLimit<>(limit, 0, LimitPhase.ORIGIN, child)); + } + /* * An expression can only be pushed down if it has filter expressions on all consumers that reference the slot. * For example, let's assume a producer has two consumers, consumer1 and consumer2: diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CollectLimitAboveConsumerTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CollectLimitAboveConsumerTest.java new file mode 100644 index 00000000000000..85d6f1126da0fe --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CollectLimitAboveConsumerTest.java @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CTEId; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.generator.Unnest; +import org.apache.doris.nereids.trees.plans.LimitPhase; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +/** + * Tests for {@link CollectLimitAboveConsumer}. + */ +class CollectLimitAboveConsumerTest { + + @Test + void testCollectDirectLimitRowsNeeded() { + LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), new CTEId(1), "cte1", producerPlan); + LogicalLimit limit = new LogicalLimit<>(10, 5, LimitPhase.ORIGIN, consumer); + + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), limit); + Rule rule = new CollectLimitAboveConsumer().buildRules().get(0); + rule.transform(limit, cascadesContext); + + Map collected = cascadesContext.getStatementContext().getConsumerIdToLimitRows(); + Assertions.assertEquals(15L, collected.get(consumer.getRelationId())); + } + + @Test + void testCollectLocalLimitRowsNeededWithoutAddingOffsetAgain() { + LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), new CTEId(2), "cte2", producerPlan); + LogicalLimit limit = new LogicalLimit<>(15, 0, LimitPhase.LOCAL, consumer); + + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), limit); + Rule rule = new CollectLimitAboveConsumer().buildRules().get(0); + rule.transform(limit, cascadesContext); + + Map collected = cascadesContext.getStatementContext().getConsumerIdToLimitRows(); + Assertions.assertEquals(15L, collected.get(consumer.getRelationId())); + } + + @Test + void testKeepMaxRowsNeededWhenConsumerIsCollectedMultipleTimes() { + LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(10, "t_merge", 0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), new CTEId(10), "cte_merge", producerPlan); + LogicalLimit highLimit = new LogicalLimit<>(20, 0, LimitPhase.ORIGIN, consumer); + LogicalLimit lowLimit = new LogicalLimit<>(3, 0, LimitPhase.ORIGIN, consumer); + + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), highLimit); + Rule rule = new CollectLimitAboveConsumer().buildRules().get(0); + rule.transform(highLimit, cascadesContext); + rule.transform(lowLimit, cascadesContext); + + Map collected = cascadesContext.getStatementContext().getConsumerIdToLimitRows(); + Assertions.assertEquals(20L, collected.get(consumer.getRelationId())); + } + + @Test + void testCollectLimitAboveProjectRowsNeeded() { + LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), new CTEId(3), "cte3", producerPlan); + LogicalProject project = new LogicalProject<>( + ImmutableList.copyOf(consumer.getOutput()), consumer); + LogicalLimit> limit = new LogicalLimit<>( + 7, 0, LimitPhase.LOCAL, project); + + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), limit); + List rules = new CollectLimitAboveConsumer().buildRules(); + rules.get(1).transform(limit, cascadesContext); + + Map collected = cascadesContext.getStatementContext().getConsumerIdToLimitRows(); + Assertions.assertEquals(7L, collected.get(consumer.getRelationId())); + } + + @Test + void testSkipLimitAboveProjectWithUnnest() { + LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(3, "t4", 0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), new CTEId(4), "cte4", producerPlan); + SlotReference arr = new SlotReference("arr", ArrayType.of(IntegerType.INSTANCE)); + LogicalProject project = new LogicalProject<>( + ImmutableList.of(new Alias(new Unnest(arr), "a")), consumer); + LogicalLimit> limit = new LogicalLimit<>( + 3, 0, LimitPhase.ORIGIN, project); + + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), limit); + List rules = new CollectLimitAboveConsumer().buildRules(); + Assertions.assertFalse(rules.get(1).getPattern().matchPlanTree(limit)); + + Map collected = cascadesContext.getStatementContext().getConsumerIdToLimitRows(); + Assertions.assertFalse(collected.containsKey(consumer.getRelationId())); + } + + @Test + void testSkipLimitAboveDistinctProject() { + LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(4, "t5", 0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), new CTEId(5), "cte5", producerPlan); + LogicalProject project = new LogicalProject<>( + ImmutableList.copyOf(consumer.getOutput()), true, consumer); + LogicalLimit> limit = new LogicalLimit<>( + 3, 0, LimitPhase.ORIGIN, project); + + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), limit); + List rules = new CollectLimitAboveConsumer().buildRules(); + Assertions.assertFalse(rules.get(1).getPattern().matchPlanTree(limit)); + + Map collected = cascadesContext.getStatementContext().getConsumerIdToLimitRows(); + Assertions.assertFalse(collected.containsKey(consumer.getRelationId())); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CteLimitPushdownPlanTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CteLimitPushdownPlanTest.java new file mode 100644 index 00000000000000..fae1be1f6d437d --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CteLimitPushdownPlanTest.java @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +/** + * Planner-level tests for CTE limit pushdown. + */ +class CteLimitPushdownPlanTest extends TestWithFeService implements MemoPatternMatchSupported { + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + useDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + + createTable("CREATE TABLE cte_limit_pushdown_t (\n" + + " k1 int NULL,\n" + + " k2 int NULL\n" + + ") ENGINE=OLAP\n" + + "DISTRIBUTED BY HASH(k1) BUCKETS 1\n" + + "PROPERTIES (\n" + + " \"replication_allocation\" = \"tag.location.default: 1\"\n" + + ");"); + } + + @Override + protected void runBeforeEach() throws Exception { + StatementScopeIdGenerator.clear(); + } + + @Test + void testPushLimitWithOffsetToProducer() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT * FROM cte LIMIT 10 OFFSET 5) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 3)"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalCTEProducer( + logicalLimit().when(limit -> limit.getLimit() == 15 && limit.getOffset() == 0))); + } + + @Test + void testPushLimitBeforeProducerOutputPruning() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT k1 FROM cte LIMIT 7) " + + "UNION ALL " + + "(SELECT k1 FROM cte LIMIT 3)"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalCTEProducer( + logicalLimit().when(limit -> limit.getLimit() == 7 && limit.getOffset() == 0))); + } + + @Test + void testPushMaxLimitForAllLimitedConsumers() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT * FROM cte LIMIT 10 OFFSET 5) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 20)"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalCTEProducer( + logicalLimit().when(limit -> limit.getLimit() == 20 && limit.getOffset() == 0))); + } + + @Test + void testSkipProducerLimitWhenAnyConsumerNeedsFullRows() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT * FROM cte LIMIT 10) " + + "UNION ALL " + + "(SELECT * FROM cte)"; + + assertNoProducerLimit(sql); + } + + @Test + void testSkipProducerLimitWhenLimitIsAboveFilter() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT * FROM cte WHERE k1 > 1 LIMIT 10) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 3)"; + + assertNoProducerLimit(sql); + } + + @Test + void testSkipProducerLimitWhenConsumerUsesTopN() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT * FROM (SELECT * FROM cte ORDER BY k1 LIMIT 10) topn_branch) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 3)"; + + assertNoProducerLimit(sql); + } + + @Test + void testSkipProducerLimitWhenLimitIsAboveJoin() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT c.k1, c.k2 FROM cte c " + + "JOIN cte_limit_pushdown_t t ON c.k1 = t.k1 LIMIT 10) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 3)"; + + assertNoProducerLimit(sql); + } + + @Test + void testSkipProducerLimitWhenLimitIsAboveAggregate() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT k1, COUNT(*) FROM cte GROUP BY k1 LIMIT 10) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 3)"; + + assertNoProducerLimit(sql); + } + + @Test + void testSkipProducerLimitWhenLimitIsAboveWindow() { + String sql = "WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_t) " + + "(SELECT k1, rn FROM (" + + "SELECT k1, ROW_NUMBER() OVER (ORDER BY k1) AS rn FROM cte" + + ") window_branch LIMIT 10) " + + "UNION ALL " + + "(SELECT * FROM cte LIMIT 3)"; + + assertNoProducerLimit(sql); + } + + private void assertNoProducerLimit(String sql) { + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalCTEProducer()) + .nonMatch(logicalCTEProducer(logicalLimit())) + .nonMatch(logicalCTEProducer(logicalLimit(logicalProject()))) + .nonMatch(logicalCTEProducer(logicalProject(logicalLimit()))); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildrenLimitPushdownTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildrenLimitPushdownTest.java new file mode 100644 index 00000000000000..fb7a07e47b6227 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteCteChildrenLimitPushdownTest.java @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.CTEId; +import org.apache.doris.nereids.trees.plans.LimitPhase; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Tests producer-side CTE limit construction in {@link RewriteCteChildren}. + */ +class RewriteCteChildrenLimitPushdownTest { + + @Test + void testPushMaxConsumerLimitToProducer() { + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + CTEId cteId = new CTEId(1); + LogicalCTEProducer producer = new LogicalCTEProducer<>(cteId, scan); + LogicalCTEConsumer consumer1 = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), cteId, "cte1", producer); + LogicalCTEConsumer consumer2 = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), cteId, "cte1", producer); + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), producer); + cascadesContext.getCteIdToConsumers().put(cteId, ImmutableSet.of(consumer1, consumer2)); + cascadesContext.putConsumerIdToLimitRows(consumer1.getRelationId(), 10L); + cascadesContext.putConsumerIdToLimitRows(consumer2.getRelationId(), 20L); + + Plan rewritten = new RewriteCteChildren(ImmutableList.of(), false) + .visitLogicalCTEProducer(producer, cascadesContext); + + LogicalCTEProducer rewrittenProducer = (LogicalCTEProducer) rewritten; + Assertions.assertInstanceOf(LogicalLimit.class, rewrittenProducer.child()); + LogicalLimit limit = (LogicalLimit) rewrittenProducer.child(); + Assertions.assertEquals(20L, limit.getLimit()); + Assertions.assertEquals(0L, limit.getOffset()); + Assertions.assertEquals(LimitPhase.ORIGIN, limit.getPhase()); + } + + @Test + void testSkipProducerLimitWhenAnyConsumerHasNoLimit() { + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + CTEId cteId = new CTEId(2); + LogicalCTEProducer producer = new LogicalCTEProducer<>(cteId, scan); + LogicalCTEConsumer consumer1 = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), cteId, "cte2", producer); + LogicalCTEConsumer consumer2 = new LogicalCTEConsumer( + PlanConstructor.getNextRelationId(), cteId, "cte2", producer); + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(new ConnectContext(), producer); + cascadesContext.getCteIdToConsumers().put(cteId, ImmutableSet.of(consumer1, consumer2)); + cascadesContext.putConsumerIdToLimitRows(consumer1.getRelationId(), 10L); + + Plan rewritten = new RewriteCteChildren(ImmutableList.of(), false) + .visitLogicalCTEProducer(producer, cascadesContext); + + LogicalCTEProducer rewrittenProducer = (LogicalCTEProducer) rewritten; + Assertions.assertSame(scan, rewrittenProducer.child()); + } +} diff --git a/regression-test/suites/nereids_rules_p0/cte_limit_pushdown/test_cte_limit_pushdown.groovy b/regression-test/suites/nereids_rules_p0/cte_limit_pushdown/test_cte_limit_pushdown.groovy new file mode 100644 index 00000000000000..863b1b1a3f1b31 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/cte_limit_pushdown/test_cte_limit_pushdown.groovy @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_cte_limit_pushdown") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql "SET disable_nereids_rules='PRUNE_EMPTY_PARTITION'" + + sql "DROP TABLE IF EXISTS cte_limit_pushdown_regression_t" + sql """ + CREATE TABLE cte_limit_pushdown_regression_t ( + k1 int NULL, + k2 int NULL + ) + DUPLICATE KEY(k1) + DISTRIBUTED BY HASH(k1) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + + sql """ + INSERT INTO cte_limit_pushdown_regression_t VALUES + (1, 10), (2, 20), (3, 30), (4, 40), (5, 50), (6, 60) + """ + + def cteProducerFragment = { explainString -> + int multicast = explainString.indexOf("MultiCastDataSinks") + assert multicast >= 0 + int fragmentStart = explainString.lastIndexOf("PLAN FRAGMENT", multicast) + assert fragmentStart >= 0 + int fragmentEnd = explainString.indexOf("PLAN FRAGMENT", multicast + 1) + if (fragmentEnd < 0) { + fragmentEnd = explainString.length() + } + return explainString.substring(fragmentStart, fragmentEnd) + } + + def cteProducerSourceBlock = { explainString -> + int multicast = explainString.indexOf("MultiCastDataSinks") + assert multicast >= 0 + int scanStart = explainString.indexOf(":VOlapScanNode", multicast) + assert scanStart >= 0 + int nextFragment = explainString.indexOf("PLAN FRAGMENT", scanStart + 1) + if (nextFragment < 0) { + nextFragment = explainString.length() + } + return explainString.substring(scanStart, nextFragment) + } + + def hasExactLimit = { planBlock, expectedLimit -> + return planBlock.readLines().any { line -> line.trim() == "limit: ${expectedLimit}" } + } + + def hasAnyLimit = { planBlock -> + return planBlock.readLines().any { line -> line.trim().startsWith("limit: ") } + } + + def assertProducerLimit = { explainString, expectedLimit -> + String producerFragment = cteProducerFragment(explainString) + String producerSource = cteProducerSourceBlock(explainString) + assert producerFragment.contains("MultiCastDataSinks") + assert producerSource.contains("cte_limit_pushdown_regression_t") + assert hasExactLimit(producerFragment, expectedLimit) + assert hasExactLimit(producerSource, expectedLimit) + return true + } + + def assertNoProducerLimit = { explainString -> + String producerFragment = cteProducerFragment(explainString) + String producerSource = cteProducerSourceBlock(explainString) + assert producerFragment.contains("MultiCastDataSinks") + assert producerSource.contains("cte_limit_pushdown_regression_t") + assert !hasAnyLimit(producerFragment) + assert !hasAnyLimit(producerSource) + return true + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT * FROM cte LIMIT 10 OFFSET 5) + UNION ALL + (SELECT * FROM cte LIMIT 3) + """ + check { explainString -> assertProducerLimit(explainString, 15) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT k1 FROM cte LIMIT 7) + UNION ALL + (SELECT k1 FROM cte LIMIT 3) + """ + check { explainString -> assertProducerLimit(explainString, 7) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT * FROM cte LIMIT 10) + UNION ALL + (SELECT * FROM cte) + """ + check { explainString -> assertNoProducerLimit(explainString) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT * FROM cte WHERE k1 > 1 LIMIT 10) + UNION ALL + (SELECT * FROM cte LIMIT 3) + """ + check { explainString -> assertNoProducerLimit(explainString) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT * FROM (SELECT * FROM cte ORDER BY k1 LIMIT 10) topn_branch) + UNION ALL + (SELECT * FROM cte LIMIT 3) + """ + check { explainString -> assertNoProducerLimit(explainString) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT c.k1, c.k2 FROM cte c + JOIN cte_limit_pushdown_regression_t t ON c.k1 = t.k1 LIMIT 10) + UNION ALL + (SELECT * FROM cte LIMIT 3) + """ + check { explainString -> assertNoProducerLimit(explainString) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT k1, COUNT(*) FROM cte GROUP BY k1 LIMIT 10) + UNION ALL + (SELECT * FROM cte LIMIT 3) + """ + check { explainString -> assertNoProducerLimit(explainString) } + } + + explain { + sql """ + WITH cte AS (SELECT k1, k2 FROM cte_limit_pushdown_regression_t) + (SELECT k1, rn FROM ( + SELECT k1, ROW_NUMBER() OVER (ORDER BY k1) AS rn FROM cte + ) window_branch LIMIT 10) + UNION ALL + (SELECT * FROM cte LIMIT 3) + """ + check { explainString -> assertNoProducerLimit(explainString) } + } +}