Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderLimit;
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
import org.apache.doris.nereids.rules.rewrite.PushCountIntoUnionAll;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownAggWithDistinctThroughJoinOneSide;
Expand Down Expand Up @@ -174,6 +173,7 @@
import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply;
import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply;
import org.apache.doris.nereids.rules.rewrite.eageraggregation.PushDownAggregation;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
Expand Down Expand Up @@ -658,19 +658,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
new MergeAggregate()
)
),
topic("Eager aggregation",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(
LogicalAggregate.class, LogicalJoin.class
),
costBased(topDown(
new PushDownAggWithDistinctThroughJoinOneSide(),
new PushDownAggThroughJoinOneSide(),
new PushDownAggThroughJoin()
)),
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)),
topDown(new PushCountIntoUnionAll())
),

// this rule should invoke after infer predicate and push down distinct, and before push down limit
topic("eliminate join according unique or foreign key",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(LogicalJoin.class),
Expand All @@ -687,7 +674,19 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new PushDownAggThroughJoinOnPkFk()),
topDown(new PullUpJoinFromUnionAll())
),
topic("Eager aggregation",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(
LogicalAggregate.class, LogicalJoin.class
),
costBased(topDown(
new PushDownAggWithDistinctThroughJoinOneSide(),
new PushDownAggThroughJoinOneSide()
)),

costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)),
custom(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN, PushDownAggregation::new),
topDown(new PushCountIntoUnionAll())
),
topic("Limit optimization",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(LogicalLimit.class)
|| cascadesContext.rewritePlanContainsTypes(LogicalTopN.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,35 +130,34 @@ public List<Rule> buildRules() {
.toRule(RuleType.NORMALIZE_AGGREGATE));
}

/**
* The LogicalAggregate node may contain window agg functions and usual agg functions
* we call window agg functions as window-agg and usual agg functions as trivial-agg for short
* This rule simplify LogicalAggregate node by:
* 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
* 2. create a new LogicalAggregate with normalized group by exprs and trivial-aggs
* 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
* Push down exprs:
* 1. all group by exprs
* 2. child contains subquery expr in trivial-agg
* 3. child contains window expr in trivial-agg
* 4. all input slots of trivial-agg
* 5. expr(including subquery) in distinct trivial-agg
* Normalize LogicalAggregate's output.
* 1. normalize group by exprs by outputs of bottom LogicalProject
* 2. normalize trivial-aggs by outputs of bottom LogicalProject
* 3. build normalized agg outputs
* Pull up exprs:
* normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
* 1. simple slots
* 2. aliases
* a. alias with no aggs child
* b. alias with trivial-agg child
* c. alias with window-agg
*/
@SuppressWarnings("checkstyle:UnusedLocalVariable")
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
public LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
CascadesContext ctx) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trivial-agg for short
// This rule simplify LogicalAggregate node by:
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
// 2. create a new LogicalAggregate with normalized group by exprs and trivial-aggs
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
// Push down exprs:
// 1. all group by exprs
// 2. child contains subquery expr in trivial-agg
// 3. child contains window expr in trivial-agg
// 4. all input slots of trivial-agg
// 5. expr(including subquery) in distinct trivial-agg
// Normalize LogicalAggregate's output.
// 1. normalize group by exprs by outputs of bottom LogicalProject
// 2. normalize trivial-aggs by outputs of bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
// 1. simple slots
// 2. aliases
// a. alias with no aggs child
// b. alias with trivial-agg child
// c. alias with window-agg

// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions());

// collect all trivial-agg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand Down Expand Up @@ -52,6 +50,7 @@
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -485,14 +484,9 @@ private static Expression doUpdateExpression(AtomicBoolean changed, Expression i
// repeat may check fail.
if (!slotReference.nullable() && newSlotReference.nullable()
&& check && ConnectContext.get() != null) {
if (ConnectContext.get().getSessionVariable().feDebug) {
throw new AnalysisException("AdjustNullable convert slot " + slotReference
+ " from not-nullable to nullable. You can disable check by set fe_debug = false.");
} else {
LOG.warn("adjust nullable convert slot '" + slotReference
+ "' from not-nullable to nullable for query "
+ DebugUtil.printId(ConnectContext.get().queryId()));
}
SessionVariable.throwAnalysisExceptionWhenFeDebug("AdjustNullable convert slot "
+ slotReference
+ " from not-nullable to nullable. You can disable check by set fe_debug = false.");
}
return newSlotReference;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.roaringbitmap.RoaringBitmap;
Expand All @@ -69,7 +68,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;

/**
* ColumnPruning.
Expand Down Expand Up @@ -221,28 +219,15 @@ public Plan visitLogicalUnion(LogicalUnion union, PruneContext context) {
}
LogicalUnion prunedOutputUnion = pruneUnionOutput(union, context);
// start prune children of union
List<Slot> originOutput = union.getOutput();
Set<Slot> prunedOutput = prunedOutputUnion.getOutputSet();
List<Integer> prunedOutputIndexes = IntStream.range(0, originOutput.size())
.filter(index -> prunedOutput.contains(originOutput.get(index)))
.boxed()
.collect(ImmutableList.toImmutableList());

ImmutableList.Builder<Plan> prunedChildren = ImmutableList.builder();
ImmutableList.Builder<List<SlotReference>> prunedChildrenOutputs = ImmutableList.builder();
for (int i = 0; i < prunedOutputUnion.arity(); i++) {
List<SlotReference> regularChildOutputs = prunedOutputUnion.getRegularChildOutput(i);

RoaringBitmap prunedChildOutputExprIds = new RoaringBitmap();
Builder<SlotReference> prunedChildOutputBuilder
= ImmutableList.builderWithExpectedSize(regularChildOutputs.size());
for (Integer index : prunedOutputIndexes) {
SlotReference slot = regularChildOutputs.get(index);
prunedChildOutputBuilder.add(slot);
prunedChildOutputExprIds.add(slot.getExprId().asInt());
}

List<SlotReference> prunedChildOutput = prunedChildOutputBuilder.build();
regularChildOutputs.forEach(col -> prunedChildOutputExprIds.add(col.getExprId().asInt()));
List<SlotReference> prunedChildOutput = regularChildOutputs;
Plan prunedChild = doPruneChild(
prunedOutputUnion, prunedOutputUnion.child(i), prunedChildOutputExprIds,
prunedChildOutput, true
Expand Down Expand Up @@ -423,12 +408,13 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context)

ImmutableList.Builder<List<NamedExpression>> prunedConstantExprsList
= ImmutableList.builderWithExpectedSize(constantExprsList.size());
List<List<SlotReference>> prunedRegularChildrenOutputs =
Lists.newArrayListWithCapacity(regularChildrenOutputs.size());
if (prunedOutputs.isEmpty()) {
// process prune all columns
NamedExpression originSlot = originOutput.get(0);
prunedOutputs = ImmutableList.of(new SlotReference(originSlot.getExprId(), originSlot.getName(),
TinyIntType.INSTANCE, false, originSlot.getQualifier()));
regularChildrenOutputs = Lists.newArrayListWithCapacity(regularChildrenOutputs.size());
children = Lists.newArrayListWithCapacity(children.size());
for (int i = 0; i < union.getArity(); i++) {
Plan child = union.child(i);
Expand All @@ -442,28 +428,35 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context)
} else {
project = new LogicalProject<>(newProjectOutput, child);
}
regularChildrenOutputs.add((List) project.getOutput());
prunedRegularChildrenOutputs.add((List) project.getOutput());
children.add(project);
}
for (int i = 0; i < constantExprsList.size(); i++) {
prunedConstantExprsList.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))));
}
} else {
int len = extractColumnIndex.size();
int prunedOutputSize = extractColumnIndex.size();
for (List<NamedExpression> row : constantExprsList) {
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(prunedOutputSize);
for (int idx : extractColumnIndex) {
newRow.add(row.get(idx));
}
prunedConstantExprsList.add(newRow.build());
}
for (int childIdx = 0; childIdx < union.getRegularChildrenOutputs().size(); childIdx++) {
List<SlotReference> regular = Lists.newArrayListWithExpectedSize(prunedOutputSize);
for (int colIdx : extractColumnIndex) {
regular.add(regularChildrenOutputs.get(childIdx).get(colIdx));
}
prunedRegularChildrenOutputs.add(regular);
}
}

if (prunedOutputs.equals(originOutput) && !context.requiredSlotsIds.isEmpty()) {
return union;
} else {
return union.withNewOutputsChildrenAndConstExprsList(prunedOutputs, children,
regularChildrenOutputs, prunedConstantExprsList.build());
prunedRegularChildrenOutputs, prunedConstantExprsList.build());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,26 @@ private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,

LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
leftProducer.getCteId(), "", leftProducer);
List<NamedExpression> leftOutput = new ArrayList<>();
for (Slot producerOutputSlot : leftProducer.getOutput()) {
for (Slot consumerSlot : left.getProducerToConsumerOutputMap().get(producerOutputSlot)) {
if (!leftOutput.contains(consumerSlot)) {
leftOutput.add(consumerSlot);
break;
}
}
}
LogicalCTEConsumer right = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
rightProducer.getCteId(), "", rightProducer);
List<NamedExpression> rightOutput = new ArrayList<>();
for (Slot producerOutputSlot : rightProducer.getOutput()) {
for (Slot consumerSlot : right.getProducerToConsumerOutputMap().get(producerOutputSlot)) {
if (!rightOutput.contains(consumerSlot)) {
rightOutput.add(consumerSlot);
break;
}
}
}
ctx.putCTEIdToConsumer(left);
ctx.putCTEIdToConsumer(right);

Expand All @@ -329,7 +347,10 @@ private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,

LogicalJoin<? extends Plan, ? extends Plan> newJoin = new LogicalJoin<>(
JoinType.INNER_JOIN, hashCond, otherCond, join.getDistributeHint(),
join.getMarkJoinSlotReference(), left, right, null);
join.getMarkJoinSlotReference(),
new LogicalProject<>(leftOutput, left),
new LogicalProject<>(rightOutput, right),
null);
if (newJoin.getHashJoinConjuncts().stream()
.anyMatch(equalTo -> equalTo.children().stream().anyMatch(e -> !(e instanceof Slot)))) {
Plan plan = PushDownExpressionsInHashCondition.pushDownHashExpression(newJoin);
Expand Down
Loading
Loading