Skip to content

Commit

Permalink
[enhancement](nereids) only push having as agg's parent if having jus…
Browse files Browse the repository at this point in the history
…t use slots from agg's output (apache#32414)

1. Only push HAVING down as the aggregate's parent if HAVING uses only slots from the aggregate's output.
2. Show a user-friendly error message when an item in the select list is not in the aggregate node's output.
  • Loading branch information
starocean999 authored Mar 26, 2024
1 parent fdc62fa commit 0dd27dc
Show file tree
Hide file tree
Showing 11 changed files with 402 additions and 336 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
Expand All @@ -43,6 +44,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -244,10 +246,33 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi

// create a parent project node
LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
// verify project used slots are all coming from agg's output
List<Slot> slots = collectAllUsedSlots(upperProjects);
if (!slots.isEmpty()) {
Set<ExprId> aggOutputExprIds = new HashSet<>(slots.size());
for (NamedExpression expression : normalizedAggOutput) {
aggOutputExprIds.add(expression.getExprId());
}
List<Slot> errorSlots = new ArrayList<>(slots.size());
for (Slot slot : slots) {
if (!aggOutputExprIds.contains(slot.getExprId())) {
errorSlots.add(slot);
}
}
if (!errorSlots.isEmpty()) {
throw new AnalysisException(String.format("%s not in aggregate's output", errorSlots
.stream().map(NamedExpression::getName).collect(Collectors.joining(", "))));
}
}
if (having.isPresent()) {
if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) {
// when project contains window functions, in order to get the correct result
// push having through project to make it the parent node of logicalAgg
Set<Slot> havingUsedSlots = ExpressionUtils.getInputSlotSet(having.get().getExpressions());
Set<ExprId> havingUsedExprIds = new HashSet<>(havingUsedSlots.size());
for (Slot slot : havingUsedSlots) {
havingUsedExprIds.add(slot.getExprId());
}
Set<ExprId> aggOutputExprIds = newAggregate.getOutputExprIdSet();
if (aggOutputExprIds.containsAll(havingUsedExprIds)) {
// when having just use output slots from agg, we push down having as parent of agg
return project.withChildren(ImmutableList.of(
new LogicalHaving<>(
ExpressionUtils.replace(having.get().getConjuncts(), project.getAliasToProducer()),
Expand Down Expand Up @@ -287,4 +312,15 @@ private List<NamedExpression> normalizeOutput(List<NamedExpression> aggregateOut
}
return builder.build();
}

/**
 * Collects the slots used by the given expressions.
 *
 * <p>The result contains the correlate slots of every {@link SubqueryExpr} found in
 * {@code expressions}, followed by the input slots of the expressions themselves.
 * The returned list is not de-duplicated: a slot that is both a correlate slot and
 * an input slot appears more than once.
 *
 * @param expressions the expressions to inspect
 * @return a new mutable list holding the subquery correlate slots first and the input slots after
 */
private List<Slot> collectAllUsedSlots(List<NamedExpression> expressions) {
    Set<Slot> inputSlots = ExpressionUtils.getInputSlotSet(expressions);
    List<SubqueryExpr> subqueries = ExpressionUtils.collectAll(expressions, SubqueryExpr.class::isInstance);
    // the capacity is only a lower-bound hint; each subquery may contribute several correlate slots
    List<Slot> slots = new ArrayList<>(inputSlots.size() + subqueries.size());
    for (SubqueryExpr subqueryExpr : subqueries) {
        slots.addAll(subqueryExpr.getCorrelateSlots());
    }
    // reuse the already-computed input slot set rather than traversing the expressions again
    slots.addAll(inputSlots);
    return slots;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public void testCTEInHavingAndSubquery() {
logicalFilter(
logicalProject(
logicalJoin(
logicalProject(logicalAggregate()),
logicalAggregate(),
logicalProject()
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ public void testHavingGroupBySlot() {
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1))))));
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1)))));

sql = "SELECT a1 as value FROM t1 GROUP BY a1 HAVING a1 > 0";
SlotReference value = new SlotReference(new ExprId(3), "value", TinyIntType.INSTANCE, true,
Expand Down Expand Up @@ -134,10 +133,9 @@ public void testHavingGroupBySlot() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(a1, new TinyIntLiteral((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(sumA2.toSlot()))));
}
Expand All @@ -158,10 +156,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));

Expand All @@ -171,13 +168,12 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan()
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));
logicalAggregate(
logicalProject(
logicalOlapScan()
)
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2))))
.when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA2.toSlot(), Literal.of(0L)))))));

sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING sum(a2) > 0";
a1 = new SlotReference(
Expand All @@ -193,22 +189,20 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L)))))));

sql = "SELECT a1, sum(a2) as value FROM t1 GROUP BY a1 HAVING value > 0";
PlanChecker.from(connectContext).analyze(sql)
.matches(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value))))
logicalAggregate(
logicalProject(
logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, value)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(value.toSlot(), Literal.of(0L))))));

sql = "SELECT a1, sum(a2) FROM t1 GROUP BY a1 HAVING MIN(pk) > 0";
Expand All @@ -230,10 +224,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, minPK)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(minPK.toSlot(), Literal.of((byte) 0)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));

Expand All @@ -243,10 +236,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A2.toSlot(), Literal.of(0L)))))));

sql = "SELECT a1, sum(a1 + a2) FROM t1 GROUP BY a1 HAVING sum(a1 + a2 + 3) > 0";
Expand All @@ -256,10 +248,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA1A2, sumA1A23)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(sumA1A23.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA1A2.toSlot()))));

Expand All @@ -269,10 +260,9 @@ public void testHavingAggregateFunction() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
logicalAggregate(
logicalProject(logicalOlapScan())
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(countStar.toSlot(), Literal.of(0L)))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
}
Expand All @@ -298,17 +288,16 @@ void testJoinWithHaving() {
.matches(
logicalProject(
logicalFilter(
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
)).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
logicalAggregate(
logicalProject(
logicalFilter(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
))
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, sumA2, sumB1)))
).when(FieldChecker.check("conjuncts", ImmutableSet.of(new GreaterThan(new Cast(a1, BigIntType.INSTANCE),
sumB1.toSlot()))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot(), sumA2.toSlot()))));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@
2 1 23.0000000000
2 2 23.0000000000

-- !select5 --
1 1 3.0000000000
1 2 3.0000000000

Loading

0 comments on commit 0dd27dc

Please sign in to comment.