Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,14 @@ public Map<RelationId, Set<Expression>> getConsumerIdToFilters() {
return this.statementContext.getConsumerIdToFilters();
}

/**
 * Records how many rows the CTE consumer identified by {@code id} needs.
 * When a value is already recorded for that consumer, the larger of the two is kept.
 *
 * @param id relation id of the CTE consumer
 * @param rows number of rows the consumer needs
 */
public void putConsumerIdToLimitRows(RelationId id, long rows) {
    Map<RelationId, Long> limitRowsByConsumer = this.statementContext.getConsumerIdToLimitRows();
    limitRowsByConsumer.merge(id, rows, (recorded, incoming) -> Math.max(recorded, incoming));
}

/**
 * Exposes the per-consumer row requirements held by the statement context.
 *
 * @return the map from CTE consumer relation id to rows needed, exactly as the statement context returns it
 */
public Map<RelationId, Long> getConsumerIdToLimitRows() {
    Map<RelationId, Long> limitRowsByConsumer = statementContext.getConsumerIdToLimitRows();
    return limitRowsByConsumer;
}

public void addCTEConsumerGroup(CTEId cteId, Group g, Multimap<Slot, Slot> producerSlotToConsumerSlot) {
List<Pair<Multimap<Slot, Slot>, Group>> consumerGroups =
this.statementContext.getCteIdToConsumerGroup().computeIfAbsent(cteId, k -> new ArrayList<>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ public enum TableFrom {
private final Map<CTEId, LogicalCTEProducer<? extends Plan>> cteIdToProducer = new HashMap<>();

private final Map<RelationId, Set<Expression>> consumerIdToFilters = new HashMap<>();
// Rows each CTE consumer needs, keyed by the consumer's relation id.
// CascadesContext#putConsumerIdToLimitRows keeps the largest value recorded per consumer.
private final Map<RelationId, Long> consumerIdToLimitRows = new HashMap<>();
// Used to update consumer's stats
private final Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();
private final Map<CTEId, LogicalPlan> rewrittenCteProducer = new HashMap<>();
Expand Down Expand Up @@ -643,6 +644,10 @@ public Map<RelationId, Set<Expression>> getConsumerIdToFilters() {
return consumerIdToFilters;
}

/**
 * Returns the limit rows recorded per CTE consumer.
 *
 * @return the backing map from consumer relation id to rows needed (the field itself, not a copy)
 */
public Map<RelationId, Long> getConsumerIdToLimitRows() {
    return this.consumerIdToLimitRows;
}

/**
 * Asks the placeholder id generator for its next id.
 *
 * @return the id handed out by the generator
 */
public PlaceholderId getNextPlaceholderId() {
    PlaceholderId nextId = this.placeHolderIdGenerator.getNextId();
    return nextId;
}
Expand Down Expand Up @@ -673,6 +678,7 @@ public void clearCteEnvironment() {
cteIdToOutputIds.clear();
cteIdToProducer.clear();
consumerIdToFilters.clear();
consumerIdToLimitRows.clear();
cteIdToConsumerGroup.clear();
rewrittenCteProducer.clear();
rewrittenCteConsumer.clear();
Expand All @@ -687,6 +693,7 @@ public CteEnvironmentSnapshot cacheCteEnvironment() {
copyMapOfSets(cteIdToOutputIds),
new HashMap<>(cteIdToProducer),
copyMapOfSets(consumerIdToFilters),
new HashMap<>(consumerIdToLimitRows),
copyMapOfLists(cteIdToConsumerGroup),
new HashMap<>(rewrittenCteProducer),
new HashMap<>(rewrittenCteConsumer));
Expand All @@ -706,6 +713,9 @@ public void restoreCteEnvironment(CteEnvironmentSnapshot snapshot) {
consumerIdToFilters.clear();
consumerIdToFilters.putAll(snapshot.consumerIdToFilters);

consumerIdToLimitRows.clear();
consumerIdToLimitRows.putAll(snapshot.consumerIdToLimitRows);

cteIdToConsumerGroup.clear();
cteIdToConsumerGroup.putAll(snapshot.cteIdToConsumerGroup);

Expand Down Expand Up @@ -738,6 +748,7 @@ public static class CteEnvironmentSnapshot {
private final Map<CTEId, Set<Slot>> cteIdToOutputIds;
private final Map<CTEId, LogicalCTEProducer<? extends Plan>> cteIdToProducer;
private final Map<RelationId, Set<Expression>> consumerIdToFilters;
private final Map<RelationId, Long> consumerIdToLimitRows;
private final Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> cteIdToConsumerGroup;
private final Map<CTEId, LogicalPlan> rewrittenCteProducer;
private final Map<CTEId, LogicalPlan> rewrittenCteConsumer;
Expand All @@ -750,13 +761,15 @@ public CteEnvironmentSnapshot(
Map<CTEId, Set<Slot>> cteIdToOutputIds,
Map<CTEId, LogicalCTEProducer<? extends Plan>> cteIdToProducer,
Map<RelationId, Set<Expression>> consumerIdToFilters,
Map<RelationId, Long> consumerIdToLimitRows,
Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> cteIdToConsumerGroup,
Map<CTEId, LogicalPlan> rewrittenCteProducer,
Map<CTEId, LogicalPlan> rewrittenCteConsumer) {
this.cteIdToConsumers = cteIdToConsumers;
this.cteIdToOutputIds = cteIdToOutputIds;
this.cteIdToProducer = cteIdToProducer;
this.consumerIdToFilters = consumerIdToFilters;
this.consumerIdToLimitRows = consumerIdToLimitRows;
this.cteIdToConsumerGroup = cteIdToConsumerGroup;
this.rewrittenCteProducer = rewrittenCteProducer;
this.rewrittenCteConsumer = rewrittenCteConsumer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.apache.doris.nereids.rules.rewrite.ClearContextStatus;
import org.apache.doris.nereids.rules.rewrite.CollectCteConsumerOutput;
import org.apache.doris.nereids.rules.rewrite.CollectFilterAboveConsumer;
import org.apache.doris.nereids.rules.rewrite.CollectLimitAboveConsumer;
import org.apache.doris.nereids.rules.rewrite.CollectPredicateOnScan;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.ConstantPropagation;
Expand Down Expand Up @@ -395,6 +396,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Push project and filter on cte consumer to cte producer",
topDown(
new CollectFilterAboveConsumer(),
new CollectLimitAboveConsumer(),
new CollectCteConsumerOutput())
),
topic("eliminate join according unique or foreign key",
Expand Down Expand Up @@ -775,6 +777,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Push project and filter on cte consumer to cte producer",
topDown(
new CollectFilterAboveConsumer(),
new CollectLimitAboveConsumer(),
new CollectCteConsumerOutput()
)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ public enum RuleType {
CTE_INLINE(RuleTypeClass.REWRITE),
REWRITE_CTE_CHILDREN(RuleTypeClass.REWRITE),
COLLECT_FILTER_ABOVE_CTE_CONSUMER(RuleTypeClass.REWRITE),
COLLECT_LIMIT_ABOVE_CTE_CONSUMER(RuleTypeClass.REWRITE),
INLINE_VIEW(RuleTypeClass.REWRITE),
CHECK_PRIVILEGES(RuleTypeClass.REWRITE),

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Collect limit rows needed by CTE consumers.
 *
 * <p>Two plan shapes are recognized:
 * <ul>
 *   <li>a limit directly above a CTE consumer;</li>
 *   <li>a limit above a project above a CTE consumer, provided the project is neither
 *       distinct nor contains a table generating function.</li>
 * </ul>
 * For each match, {@code limit + offset} is recorded for the consumer's relation id through
 * {@link CascadesContext#putConsumerIdToLimitRows}. The plan itself is returned unchanged.
 */
public class CollectLimitAboveConsumer implements RewriteRuleFactory {

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                logicalLimit(logicalCTEConsumer()).thenApply(ctx -> {
                    LogicalLimit<LogicalCTEConsumer> limit = ctx.root;
                    collectLimitRows(ctx.cascadesContext, limit, limit.child());
                    return ctx.root;
                }).toRule(RuleType.COLLECT_LIMIT_ABOVE_CTE_CONSUMER),
                logicalLimit(logicalProject(logicalCTEConsumer()))
                        .when(limit -> isRowPreservingProject(limit.child()))
                        .thenApply(ctx -> {
                            LogicalLimit<LogicalProject<LogicalCTEConsumer>> limit = ctx.root;
                            collectLimitRows(ctx.cascadesContext, limit, limit.child().child());
                            return ctx.root;
                        }).toRule(RuleType.COLLECT_LIMIT_ABOVE_CTE_CONSUMER)
        );
    }

    /**
     * Records the number of rows {@code limit} needs from {@code cteConsumer}.
     *
     * @param cascadesContext context that stores the per-consumer row requirement
     * @param limit the limit node found above the consumer
     * @param cteConsumer the consumer the limit reads from
     */
    private void collectLimitRows(CascadesContext cascadesContext, LogicalLimit<?> limit,
            LogicalCTEConsumer cteConsumer) {
        // The consumer must supply the skipped rows (offset) plus the returned rows (limit).
        // A plain '+' could wrap to a negative value, which would lose every Math.max
        // comparison downstream, so the sum is saturated instead.
        long rowsNeeded = saturatingAdd(limit.getLimit(), limit.getOffset());
        cascadesContext.putConsumerIdToLimitRows(cteConsumer.getRelationId(), rowsNeeded);
    }

    /**
     * Adds two longs, returning {@link Long#MAX_VALUE} instead of wrapping around on overflow.
     * {@code Long.MAX_VALUE} is the conservative answer here: it can only make a derived limit
     * larger, never smaller.
     */
    private static long saturatingAdd(long left, long right) {
        long sum = left + right;
        // Overflow happened iff both operands share a sign that the sum does not have.
        if (((left ^ sum) & (right ^ sum)) < 0) {
            return Long.MAX_VALUE;
        }
        return sum;
    }

    /**
     * Tells whether {@code project} is safe to look through when counting rows.
     *
     * @return {@code false} when the project is distinct or any projected expression contains a
     *         {@link TableGeneratingFunction}; {@code true} otherwise
     */
    private boolean isRowPreservingProject(LogicalProject<?> project) {
        if (project.isDistinct()) {
            return false;
        }
        for (NamedExpression expression : project.getProjects()) {
            if (expression.containsType(TableGeneratingFunction.class)) {
                return false;
            }
        }
        return true;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
Expand Down Expand Up @@ -124,6 +126,7 @@ public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProduc
} else {
child = (LogicalPlan) cteProducer.child();
child = tryToConstructFilter(cascadesContext, cteProducer.getCteId(), child);
child = tryToConstructLimit(cascadesContext, cteProducer.getCteId(), child);
Set<Slot> producerOutputs = cascadesContext.getStatementContext()
.getCteIdToOutputIds().get(cteProducer.getCteId());
if (producerOutputs != null && producerOutputs.size() < child.getOutput().size()) {
Expand Down Expand Up @@ -162,6 +165,19 @@ private LogicalPlan pushPlanUnderAnchor(LogicalPlan plan) {
return plan;
}

/**
 * Puts a limit on top of the producer's plan when every consumer of the CTE has a recorded
 * row requirement.
 *
 * <p>The limit is the largest requirement among the consumers, with offset 0. If any consumer
 * has no recorded requirement, the producer plan is returned untouched.
 *
 * @param cascadesContext context holding the consumers of each CTE and their row requirements
 * @param cteId id of the CTE whose producer is being rewritten
 * @param child current plan of the producer
 * @return {@code child}, possibly wrapped in a limit via {@code pushPlanUnderAnchor}
 */
private LogicalPlan tryToConstructLimit(CascadesContext cascadesContext, CTEId cteId, LogicalPlan child) {
    Set<LogicalCTEConsumer> consumers = cascadesContext.getCteIdToConsumers().get(cteId);
    // Without this guard a missing entry throws an NPE, and an empty set vacuously satisfies
    // "every consumer has a limit", which would push LIMIT 0 onto the producer.
    if (consumers == null || consumers.isEmpty()) {
        return child;
    }
    long limit = 0;
    for (LogicalCTEConsumer consumer : consumers) {
        Long rowsNeeded = cascadesContext.getConsumerIdToLimitRows().get(consumer.getRelationId());
        // A negative count is not a usable row budget (e.g. an overflowed limit + offset);
        // treat it like an unknown requirement rather than letting Math.max clamp it to 0.
        if (rowsNeeded == null || rowsNeeded < 0) {
            return child;
        }
        limit = Math.max(limit, rowsNeeded);
    }
    return pushPlanUnderAnchor(new LogicalLimit<>(limit, 0, LimitPhase.ORIGIN, child));
}

/*
* An expression can only be pushed down if it has filter expressions on all consumers that reference the slot.
* For example, let's assume a producer has two consumers, consumer1 and consumer2:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.generator.Unnest;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;

/**
 * Unit tests for {@link CollectLimitAboveConsumer}.
 */
class CollectLimitAboveConsumerTest {

    /** Index of the rule matching a limit directly above a CTE consumer. */
    private static final int DIRECT_LIMIT_RULE = 0;
    /** Index of the rule matching a limit above a project above a CTE consumer. */
    private static final int LIMIT_OVER_PROJECT_RULE = 1;

    /** Builds a CTE consumer over a fresh olap scan; the scan is created before the relation id is drawn. */
    private static LogicalCTEConsumer newConsumer(int tableId, String tableName, int cteId, String cteName) {
        LogicalOlapScan producerPlan = PlanConstructor.newLogicalOlapScan(tableId, tableName, 0);
        return new LogicalCTEConsumer(
                PlanConstructor.getNextRelationId(), new CTEId(cteId), cteName, producerPlan);
    }

    /** Creates a cascades context rooted at the given limit, using a new connect context. */
    private static CascadesContext contextFor(LogicalLimit<?> root) {
        return MemoTestUtils.createCascadesContext(new ConnectContext(), root);
    }

    /** Builds a fresh rule list from the factory under test. */
    private static List<Rule> rules() {
        return new CollectLimitAboveConsumer().buildRules();
    }

    /** Reads the limit rows collected so far from the statement context. */
    private static Map<RelationId, Long> collectedRows(CascadesContext context) {
        return context.getStatementContext().getConsumerIdToLimitRows();
    }

    @Test
    void testCollectDirectLimitRowsNeeded() {
        LogicalCTEConsumer cteConsumer = newConsumer(0, "t1", 1, "cte1");
        LogicalLimit<LogicalCTEConsumer> limitPlan = new LogicalLimit<>(10, 5, LimitPhase.ORIGIN, cteConsumer);
        CascadesContext context = contextFor(limitPlan);

        rules().get(DIRECT_LIMIT_RULE).transform(limitPlan, context);

        // limit 10 + offset 5
        Assertions.assertEquals(15L, collectedRows(context).get(cteConsumer.getRelationId()));
    }

    @Test
    void testCollectLocalLimitRowsNeededWithoutAddingOffsetAgain() {
        LogicalCTEConsumer cteConsumer = newConsumer(1, "t2", 2, "cte2");
        LogicalLimit<LogicalCTEConsumer> limitPlan = new LogicalLimit<>(15, 0, LimitPhase.LOCAL, cteConsumer);
        CascadesContext context = contextFor(limitPlan);

        rules().get(DIRECT_LIMIT_RULE).transform(limitPlan, context);

        Assertions.assertEquals(15L, collectedRows(context).get(cteConsumer.getRelationId()));
    }

    @Test
    void testKeepMaxRowsNeededWhenConsumerIsCollectedMultipleTimes() {
        LogicalCTEConsumer cteConsumer = newConsumer(10, "t_merge", 10, "cte_merge");
        LogicalLimit<LogicalCTEConsumer> larger = new LogicalLimit<>(20, 0, LimitPhase.ORIGIN, cteConsumer);
        LogicalLimit<LogicalCTEConsumer> smaller = new LogicalLimit<>(3, 0, LimitPhase.ORIGIN, cteConsumer);
        CascadesContext context = contextFor(larger);

        // Same rule instance, same consumer: the larger value is collected first, then the smaller.
        Rule directLimitRule = rules().get(DIRECT_LIMIT_RULE);
        directLimitRule.transform(larger, context);
        directLimitRule.transform(smaller, context);

        Assertions.assertEquals(20L, collectedRows(context).get(cteConsumer.getRelationId()));
    }

    @Test
    void testCollectLimitAboveProjectRowsNeeded() {
        LogicalCTEConsumer cteConsumer = newConsumer(2, "t3", 3, "cte3");
        LogicalProject<LogicalCTEConsumer> passThroughProject = new LogicalProject<>(
                ImmutableList.copyOf(cteConsumer.getOutput()), cteConsumer);
        LogicalLimit<LogicalProject<LogicalCTEConsumer>> limitPlan = new LogicalLimit<>(
                7, 0, LimitPhase.LOCAL, passThroughProject);
        CascadesContext context = contextFor(limitPlan);

        rules().get(LIMIT_OVER_PROJECT_RULE).transform(limitPlan, context);

        Assertions.assertEquals(7L, collectedRows(context).get(cteConsumer.getRelationId()));
    }

    @Test
    void testSkipLimitAboveProjectWithUnnest() {
        LogicalCTEConsumer cteConsumer = newConsumer(3, "t4", 4, "cte4");
        SlotReference arraySlot = new SlotReference("arr", ArrayType.of(IntegerType.INSTANCE));
        LogicalProject<LogicalCTEConsumer> unnestProject = new LogicalProject<>(
                ImmutableList.of(new Alias(new Unnest(arraySlot), "a")), cteConsumer);
        LogicalLimit<LogicalProject<LogicalCTEConsumer>> limitPlan = new LogicalLimit<>(
                3, 0, LimitPhase.ORIGIN, unnestProject);
        CascadesContext context = contextFor(limitPlan);

        Rule limitOverProjectRule = rules().get(LIMIT_OVER_PROJECT_RULE);

        Assertions.assertFalse(limitOverProjectRule.getPattern().matchPlanTree(limitPlan));
        Assertions.assertFalse(collectedRows(context).containsKey(cteConsumer.getRelationId()));
    }

    @Test
    void testSkipLimitAboveDistinctProject() {
        LogicalCTEConsumer cteConsumer = newConsumer(4, "t5", 5, "cte5");
        LogicalProject<LogicalCTEConsumer> distinctProject = new LogicalProject<>(
                ImmutableList.copyOf(cteConsumer.getOutput()), true, cteConsumer);
        LogicalLimit<LogicalProject<LogicalCTEConsumer>> limitPlan = new LogicalLimit<>(
                3, 0, LimitPhase.ORIGIN, distinctProject);
        CascadesContext context = contextFor(limitPlan);

        Rule limitOverProjectRule = rules().get(LIMIT_OVER_PROJECT_RULE);

        Assertions.assertFalse(limitOverProjectRule.getPattern().matchPlanTree(limitPlan));
        Assertions.assertFalse(collectedRows(context).containsKey(cteConsumer.getRelationId()));
    }
}
Loading
Loading