Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,17 @@ public Map<CTEId, LogicalPlan> getRewrittenCteConsumer() {
return rewrittenCteConsumer;
}

/**
 * Empties every CTE-related collection tracked by this context so the state can be
 * rebuilt from a new plan tree.
 */
public void clearCteEnvironment() {
    // Rewrite caches.
    rewrittenCteConsumer.clear();
    rewrittenCteProducer.clear();

    // Per-CTE producer / consumer bookkeeping.
    cteIdToProducer.clear();
    cteIdToConsumers.clear();
    cteIdToConsumerGroup.clear();

    // Derived per-CTE / per-consumer information.
    cteIdToOutputIds.clear();
    consumerIdToFilters.clear();
}

/**
* Snapshot current CTE-related environment for temporary rewrite/optimization.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.executor;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
Expand All @@ -36,6 +37,9 @@
import org.apache.doris.nereids.rules.rewrite.MergeProjectable;
import org.apache.doris.nereids.rules.rewrite.PushDownExpressionsInHashCondition;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.util.MoreFieldsThread;
import org.apache.doris.qe.ConnectContext;
Expand Down Expand Up @@ -66,6 +70,12 @@ public Optimizer(CascadesContext cascadesContext) {
*/
public void execute() {
MoreFieldsThread.keepFunctionSignature(() -> {
Plan rewritePlan = cascadesContext.getRewritePlan();
if (containsCte(rewritePlan)) {
Comment thread
englefly marked this conversation as resolved.
Plan normalizedPlan = normalizeCtePlan(rewritePlan);
cascadesContext.setRewritePlan(normalizedPlan);
refreshCteContext(normalizedPlan);
}
// generate inlined CTE alternative for CBO comparison
Plan cboInlinedPlan = generateCTEInlineAlternative();
// init memo
Expand Down Expand Up @@ -195,9 +205,9 @@ private Plan generateCTEInlineAlternative() {
private Plan generateFullCTEInline() {
Plan rewritePlan = cascadesContext.getRewritePlan();
CTEInliner cteInliner = new CTEInliner(cascadesContext.getStatementContext());
Plan inlinedPlan = cteInliner.generateInlinedPlan(rewritePlan);
if (inlinedPlan != null) {
return rewriteInlinedPlan(inlinedPlan);
Plan pushedDownInlinedPlan = generateFilterPushedDownInlinedPlan(cteInliner, rewritePlan);
if (pushedDownInlinedPlan != null) {
return normalizeCtePlan(pushedDownInlinedPlan);
}
return null;
}
Expand All @@ -208,22 +218,60 @@ private Plan generateFullCTEInline() {
private Plan generateSelectiveCTEInline() {
Plan rewritePlan = cascadesContext.getRewritePlan();
CTEInliner cteInliner = new CTEInliner(cascadesContext.getStatementContext(), true);
Plan inlinedPlan = cteInliner.generateInlinedPlan(rewritePlan);
if (inlinedPlan != null) {
inlinedPlan = rewriteInlinedPlan(inlinedPlan);
if (inlinedPlan.anyMatch(p -> p instanceof LogicalEmptyRelation)) {
inlinedPlan = eliminateEmptyRelation(inlinedPlan);
cascadesContext.setRewritePlan(inlinedPlan);
Plan pushedDownInlinedPlan = generateFilterPushedDownInlinedPlan(cteInliner, rewritePlan);
if (pushedDownInlinedPlan != null) {
if (pushedDownInlinedPlan.anyMatch(p -> p instanceof LogicalEmptyRelation)) {
pushedDownInlinedPlan = normalizeCtePlan(pushedDownInlinedPlan);
cascadesContext.setRewritePlan(pushedDownInlinedPlan);
refreshCteContext(pushedDownInlinedPlan);
return null;
}
}
return null;
}

/**
 * Repeatedly eliminates empty relations and applies consumer-count-driven CTE
 * inlining until an inlining round reports that nothing changed.
 *
 * @param plan the plan to normalize
 * @return the plan produced by the last (unchanged) inlining round
 */
private Plan normalizeCtePlan(Plan plan) {
    Plan working = plan;
    CTEInliner.InlineResult round;
    do {
        boolean hasEmptyRelation = working.anyMatch(node -> node instanceof LogicalEmptyRelation);
        if (hasEmptyRelation) {
            working = eliminateEmptyRelation(working);
        }
        round = new CTEInliner(cascadesContext.getStatementContext())
                .inlineByCurrentConsumerCount(working);
        working = round.getPlan();
        // The loop condition relies on the inliner's explicit "changed" flag rather
        // than Plan.equals(): some logical nodes (e.g. LogicalCTEAnchor and
        // LogicalSubQueryAlias) intentionally ignore children in equals(), so a
        // rewrite of a child CTE under a kept parent could go unnoticed and stop
        // cascading consumer-count-based inlining too early.
    } while (round.isChanged());
    return working;
}

/**
 * Tells whether the plan tree contains any CTE anchor or CTE consumer node.
 *
 * @param plan the plan tree to inspect
 * @return true if at least one LogicalCTEAnchor or LogicalCTEConsumer is found
 */
private boolean containsCte(Plan plan) {
    return plan.anyMatch(node -> {
        if (node instanceof LogicalCTEAnchor) {
            return true;
        }
        return node instanceof LogicalCTEConsumer;
    });
}

/**
 * Clears the CTE state of the statement context and re-registers the producers and
 * consumers found in the given plan.
 *
 * @param plan the plan whose CTE anchors and consumers become the new CTE state
 */
private void refreshCteContext(Plan plan) {
    StatementContext stmtContext = cascadesContext.getStatementContext();
    stmtContext.clearCteEnvironment();
    plan.foreach(node -> {
        if (node instanceof LogicalCTEAnchor) {
            // The anchor's left child is registered as the producer of its CTE id.
            LogicalCTEAnchor<?, ?> cteAnchor = (LogicalCTEAnchor<?, ?>) node;
            LogicalCTEProducer<?> producer = (LogicalCTEProducer<?>) cteAnchor.left();
            stmtContext.setCteProducer(cteAnchor.getCteId(), producer);
            return false;
        }
        if (node instanceof LogicalCTEConsumer) {
            cascadesContext.putCTEIdToConsumer((LogicalCTEConsumer) node);
        }
        // NOTE(review): false presumably tells foreach to keep descending into
        // children — confirm against TreeNode.foreach.
        return false;
    });
}

private Plan eliminateEmptyRelation(Plan plan) {
CascadesContext ctx = CascadesContext.initContext(
cascadesContext.getStatementContext(), plan, PhysicalProperties.ANY);
// Use getCteChildrenRewriter for the same reason as rewriteInlinedPlan:
// Use getCteChildrenRewriter for the same reason as pushDownFilterAndPruneInlinedPlan:
// getWholeTreeRewriterWithCustomJobs would invoke RewriteCteChildren which
// reads stale rewrittenCteConsumer cache from the main Rewriter phase,
// reverting the inlined CTE subtrees back to the original structure.
Expand All @@ -234,6 +282,14 @@ private Plan eliminateEmptyRelation(Plan plan) {
return ctx.getRewritePlan();
}

/**
 * Inlines CTEs with the given inliner and, when an inlined plan is produced, runs
 * filter pushdown and pruning on it.
 *
 * @param cteInliner the inliner used to generate the inlined plan
 * @param rewritePlan the plan to inline
 * @return the pushed-down inlined plan, or null when the inliner produced no plan
 */
private Plan generateFilterPushedDownInlinedPlan(CTEInliner cteInliner, Plan rewritePlan) {
    Plan inlined = cteInliner.generateInlinedPlan(rewritePlan);
    return inlined == null ? null : pushDownFilterAndPruneInlinedPlan(inlined);
}

/**
* Run filter pushdown and column pruning on the inlined plan using a temporary
* CascadesContext.
Expand All @@ -246,7 +302,7 @@ private Plan eliminateEmptyRelation(Plan plan) {
* phase. That cached outer query still contains LogicalCTEConsumer nodes for the inlined CTE,
* preventing the filter from ever reaching the inlined union body.
*/
private Plan rewriteInlinedPlan(Plan inlinedPlan) {
private Plan pushDownFilterAndPruneInlinedPlan(Plan inlinedPlan) {
CascadesContext inlinedContext = CascadesContext.initContext(
cascadesContext.getStatementContext(), inlinedPlan, PhysicalProperties.ANY);
Rewriter.getCteChildrenRewriter(inlinedContext, ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
Expand All @@ -41,8 +42,10 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Generate an inlined alternative plan for CTE optimization.
Expand All @@ -65,6 +68,7 @@ public class CTEInliner extends DefaultPlanRewriter<Void> {
private final StatementContext statementContext;
// Map from CTEId to the CTE producer node (extracted from CTEAnchor.left())
private final Map<CTEId, LogicalCTEProducer<?>> cteProducers = new HashMap<>();
private final Set<CTEId> cteIdsToRemove = new HashSet<>();
private final boolean unionAllOnly;

public CTEInliner(StatementContext statementContext) {
Expand All @@ -81,6 +85,7 @@ public CTEInliner(StatementContext statementContext, boolean unionAllOnly) {
* Returns null if no CTEs can be inlined.
*/
public Plan generateInlinedPlan(Plan plan) {
clearRewriteCandidates();
// First pass: collect all CTE producers that can be inlined
collectCTEProducers(plan);

Expand All @@ -92,6 +97,44 @@ public Plan generateInlinedPlan(Plan plan) {
return plan.accept(this, null);
}

/**
 * Keeps rewriting the plan while consumer-count-driven candidates are found:
 * unused CTE anchors are removed and CTEs whose live consumer count is small
 * enough are inlined, re-evaluating the counts after every rewrite.
 *
 * @param plan the plan to normalize
 * @return the final plan together with a flag telling whether any rewrite round ran
 */
public InlineResult inlineByCurrentConsumerCount(Plan plan) {
    Plan rewritten = plan;
    boolean anyRoundApplied = false;
    for (;;) {
        if (!collectConsumerDrivenCandidates(rewritten)) {
            return new InlineResult(rewritten, anyRoundApplied);
        }
        anyRoundApplied = true;
        rewritten = rewritten.accept(this, null);
    }
}

/**
 * Outcome of a consumer-count-driven CTE normalization: the resulting plan plus a
 * flag recording whether the plan was rewritten.
 */
public static class InlineResult {
    private final boolean modified;
    private final Plan resultPlan;

    /**
     * @param plan the plan after normalization
     * @param changed whether normalization rewrote the plan
     */
    public InlineResult(Plan plan, boolean changed) {
        this.resultPlan = plan;
        this.modified = changed;
    }

    /** @return whether normalization rewrote the plan */
    public boolean isChanged() {
        return modified;
    }

    /** @return the plan after normalization */
    public Plan getPlan() {
        return resultPlan;
    }
}

/** Forgets the CTE ids selected for removal and the producers selected for inlining. */
private void clearRewriteCandidates() {
    cteIdsToRemove.clear();
    cteProducers.clear();
}

private void collectCTEProducers(Plan plan) {
plan.foreach(p -> {
if (p instanceof LogicalCTEAnchor) {
Expand All @@ -113,6 +156,40 @@ private void collectCTEProducers(Plan plan) {
});
}

/**
 * Recomputes the rewrite candidates from the current plan shape.
 *
 * <p>CTEs without any consumer are marked for removal. A CTE is marked for inlining
 * when its producer's child is a LogicalEmptyRelation, or when its consumer count does
 * not exceed the session's inlineCTEReferencedThreshold and {@code canInline} accepts it.
 *
 * @param plan the plan to scan for CTE anchors and consumers
 * @return true if at least one CTE was marked for removal or inlining
 */
private boolean collectConsumerDrivenCandidates(Plan plan) {
    clearRewriteCandidates();
    Map<CTEId, LogicalCTEProducer<?>> producersById = new HashMap<>();
    Map<CTEId, Integer> liveConsumerCounts = new HashMap<>();
    plan.foreach(node -> {
        if (node instanceof LogicalCTEAnchor) {
            LogicalCTEAnchor<?, ?> cteAnchor = (LogicalCTEAnchor<?, ?>) node;
            producersById.put(cteAnchor.getCteId(), (LogicalCTEProducer<?>) cteAnchor.left());
        } else if (node instanceof LogicalCTEConsumer) {
            liveConsumerCounts.merge(((LogicalCTEConsumer) node).getCteId(), 1, Integer::sum);
        }
    });

    int threshold = statementContext.getConnectContext().getSessionVariable().inlineCTEReferencedThreshold;
    producersById.forEach((cteId, producer) -> {
        int references = liveConsumerCounts.getOrDefault(cteId, 0);
        if (references == 0) {
            // Nobody reads this CTE any more: drop its anchor.
            cteIdsToRemove.add(cteId);
            return;
        }
        boolean emptyBody = producer.child() instanceof LogicalEmptyRelation;
        if (emptyBody || (references <= threshold && canInline(producer))) {
            cteProducers.put(cteId, producer);
        }
    });
    return !(cteProducers.isEmpty() && cteIdsToRemove.isEmpty());
}

/**
 * Tells whether a producer may be inlined: it must not be force-materialized and
 * must not contain a nondeterministic function.
 *
 * @param producer the CTE producer to check
 * @return true if the producer is eligible for inlining
 */
private boolean canInline(LogicalCTEProducer<?> producer) {
    if (statementContext.isForceMaterializeCTE(producer.getCteId())) {
        return false;
    }
    return !containsNondeterministicFunction(producer);
}

private boolean containsNondeterministicFunction(LogicalCTEProducer<?> producer) {
List<Expression> nondeterministicFunctions = new ArrayList<>();
producer.accept(NondeterministicFunctionCollector.INSTANCE, nondeterministicFunctions);
Expand All @@ -127,13 +204,14 @@ private boolean containsUnionAll(LogicalCTEProducer<?> producer) {
@Override
public Plan visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? extends Plan> cteAnchor, Void context) {
CTEId cteId = cteAnchor.getCteId();
if (cteProducers.containsKey(cteId)) {
// Inline: skip anchor and producer, process the right (consumer) subtree
if (cteProducers.containsKey(cteId) || cteIdsToRemove.contains(cteId)) {
// Inline or remove: skip anchor and producer, process the right (consumer) subtree
return cteAnchor.right().accept(this, null);
} else {
// Force materialize: keep the structure, only process the right subtree
// Keep the structure and continue trimming nested CTEs in both children.
Plan left = cteAnchor.left().accept(this, null);
Plan right = cteAnchor.right().accept(this, null);
return cteAnchor.withChildren(cteAnchor.left(), right);
return cteAnchor.withChildren(left, right);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ public class ClearContextStatus implements CustomRewriter {

/**
 * Clears CTE-related state held by the statement context and returns the plan untouched.
 *
 * @param plan the plan being rewritten; it is returned as-is
 * @param jobContext gives access to the CascadesContext whose StatementContext is cleared
 * @return the same {@code plan} instance that was passed in
 */
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
// NOTE(review): the four per-collection clear() calls below look redundant with the
// clearCteEnvironment() call that follows them; presumably clearCteEnvironment()
// supersedes them — confirm against StatementContext.
jobContext.getCascadesContext().getStatementContext().getRewrittenCteConsumer().clear();
jobContext.getCascadesContext().getStatementContext().getRewrittenCteProducer().clear();
jobContext.getCascadesContext().getStatementContext().getCteIdToOutputIds().clear();
jobContext.getCascadesContext().getStatementContext().getConsumerIdToFilters().clear();
jobContext.getCascadesContext().getStatementContext().clearCteEnvironment();
return plan;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ public static boolean isEagerAggregationOnJoin() {
@VarAttrDef.VarAttr(name = ENABLE_ORDERED_SCAN_RANGE_LOCATIONS)
public boolean enableOrderedScanRangeLocations = false;

@VarAttrDef.VarAttr(name = CTE_INLINE_MODE, alias = "cbo_cte_inline_mode", description = {
@VarAttrDef.VarAttr(name = CTE_INLINE_MODE, description = {
Comment thread
englefly marked this conversation as resolved.
"CTE内联模式。<0:禁用; =0:仅当CTE体含UNION ALL且filter可消除部分分支时内联; >=1:CBO比较物化与内联",
"CTE inline mode. <0: disable; =0: only inline when CTE body contains UNION ALL "
+ "and consumer filters can eliminate some union branches; "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,34 @@
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Map;
import java.util.Set;

public class CTEInlineTest extends TestWithFeService implements MemoPatternMatchSupported {
/**
 * Creates the {@code test} database, selects it on the connect context, and creates
 * the {@code cte_inline_tbl} table used by the tests in this class.
 */
@Override
protected void runBeforeAll() throws Exception {
    String dbName = "test";
    createDatabase(dbName);
    connectContext.setDatabase(dbName);

    String createTableSql = "CREATE TABLE cte_inline_tbl (\n"
            + " id int NULL,\n"
            + " val int NULL\n"
            + ") ENGINE=OLAP\n"
            + "DUPLICATE KEY(id)\n"
            + "DISTRIBUTED BY HASH(id) BUCKETS 1\n"
            + "PROPERTIES (\"replication_num\" = \"1\")";
    createTable(createTableSql);
}

@Test
Expand Down Expand Up @@ -81,4 +94,29 @@ public void recCteInline() {
).when(cte -> cte.getCteName().equals("yy"))
);
}

/**
 * Plans a query with three references to one CTE, one of them under an always-false
 * predicate, and checks that the statement context ends up tracking exactly one CTE
 * with two consumers.
 */
@Test
public void refreshCteConsumersAfterNormalizeEliminatesEmptyBranch() {
    // Remember the session settings overridden below so they can be restored afterwards.
    int savedInlineMode = connectContext.getSessionVariable().cteInlineMode;
    int savedReferencedThreshold = connectContext.getSessionVariable().inlineCTEReferencedThreshold;
    connectContext.getSessionVariable().cteInlineMode = 0;
    connectContext.getSessionVariable().inlineCTEReferencedThreshold = 1;
    connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
    String query = "with cte as (select id, val from cte_inline_tbl) "
            + "select * from cte where id = 1 "
            + "union all select * from cte where id = 2 "
            + "union all select * from cte where 1 = 0";
    try {
        PlanChecker.from(connectContext).checkPlannerResult(query, planner -> {
            Map<CTEId, Set<LogicalCTEConsumer>> consumersByCte =
                    planner.getCascadesContext().getStatementContext().getCteIdToConsumers();
            // Exactly one CTE is still tracked ...
            Assertions.assertEquals(1, consumersByCte.size());
            // ... and it has two registered consumers.
            Set<LogicalCTEConsumer> remainingConsumers = consumersByCte.values().iterator().next();
            Assertions.assertEquals(2, remainingConsumers.size());
        });
    } finally {
        connectContext.getSessionVariable().cteInlineMode = savedInlineMode;
        connectContext.getSessionVariable().inlineCTEReferencedThreshold = savedReferencedThreshold;
        connectContext.getSessionVariable().setDisableNereidsRules("");
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----PhysicalProject[a + random(), a + random() + 0, abs(a + random()), sum(a + random() + 0), sum(a + random())]
------PhysicalQuickSort[MERGE_SORT]
--------PhysicalQuickSort[LOCAL_SORT]
----------PhysicalProject[(a + random() + 1.0) AS `(a + random() + 1.0)`, a + random(), a + random() + 0, abs(a + random()) AS `abs(a + random())`, sum(a + random() + 0), sum(a + random())]
------------PhysicalUnion
--------------PhysicalEmptyRelation
--------------filter((.a + random() + 0 > 0.01))
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
----------PhysicalProject[(a + random() + 1.0) AS `(a + random() + 1.0)`, a + random() + 0 AS `a + random() + 0`, a + random() AS `a + random()`, abs(a + random()) AS `abs(a + random())`, sum(a + random() + 0) AS `sum(a + random() + 0)`, sum(a + random()) AS `sum(a + random())`]
------------filter((.a + random() + 0 > 0.01))
--------------PhysicalCteConsumer ( cteId=CTEId#0 )

Loading
Loading