Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -54,6 +53,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* eager aggregation
Expand Down Expand Up @@ -133,34 +133,32 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, P
}
}

// Do not push count(*)/count(literal)/count(preserved_side_col) to the nullable side of outer joins.
// count(*) counts all physical rows, including null-extended rows from the outer join.
// After pushdown to the nullable side, unmatched rows produce NULL for the pre-aggregated count,
// and ifnull(sum(NULL), 0) = 0, which loses the count of unmatched rows.
// However, count(nullable_side_col) is safe to push down because for unmatched rows,
// nullable_side_col IS NULL, so the original count is 0, matching ifnull(sum(NULL), 0) = 0.
// Do not push agg(literal) or agg(preserved_side_col) to the nullable side of outer joins.
// Aggregates like count(*), sum(2), min(1) etc. aggregate over all physical rows,
// including null-extended rows from the outer join.
// After pushdown to the nullable side, unmatched rows produce NULL for the pre-aggregated value,
// losing the contribution of those rows (e.g. sum(2) should add 2 per unmatched row,
// but sum(NULL) skips them).
// However, agg(nullable_side_col) is safe to push down because for unmatched rows,
// nullable_side_col IS NULL, and the aggregate naturally handles NULL values correctly.
if (!join.getJoinType().isInnerJoin() && !join.getJoinType().isCrossJoin()) {
JoinType joinType = join.getJoinType();
boolean leftIsNullable = joinType.isRightOuterJoin() || joinType.isFullOuterJoin();
boolean rightIsNullable = joinType.isLeftOuterJoin() || joinType.isFullOuterJoin();
for (AggregateFunction aggFunc : context.getAggFunctions()) {
if (aggFunc instanceof Count) {
Set<Slot> countInputSlots = aggFunc.getInputSlots();
// Determine which side is nullable
boolean leftIsNullable = joinType.isRightOuterJoin() || joinType.isFullOuterJoin();
boolean rightIsNullable = joinType.isLeftOuterJoin() || joinType.isFullOuterJoin();
// Check if we're pushing to a nullable side without referencing its columns
if (toLeft && leftIsNullable) {
boolean hasLeftInput = countInputSlots.stream()
.anyMatch(slot -> join.left().getOutputSet().contains(slot));
if (!hasLeftInput) {
toLeft = false;
}
Set<Slot> inputSlots = aggFunc.getInputSlots();
if (toLeft && leftIsNullable) {
boolean hasLeftInput = inputSlots.stream()
.anyMatch(slot -> join.left().getOutputSet().contains(slot));
if (!hasLeftInput) {
toLeft = false;
}
if (toRight && rightIsNullable) {
boolean hasRightInput = countInputSlots.stream()
.anyMatch(slot -> join.right().getOutputSet().contains(slot));
if (!hasRightInput) {
toRight = false;
}
}
if (toRight && rightIsNullable) {
boolean hasRightInput = inputSlots.stream()
.anyMatch(slot -> join.right().getOutputSet().contains(slot));
if (!hasRightInput) {
toRight = false;
}
}
}
Expand Down Expand Up @@ -505,6 +503,28 @@ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, PushDown

/**
 * Tries to push the aggregation described by {@code context} below a filter.
 *
 * <p>The aggregate is generated directly on top of the filter (via {@code genAggregate}) when:
 * <ul>
 *   <li>the filter's child is a {@code LogicalRelation};</li>
 *   <li>any conjunct contains a unique function (the tests exercise this with {@code random()});</li>
 *   <li>the context extended with the filter's input slots is not valid;</li>
 *   <li>visiting the child returned the very same child instance, i.e. nothing was rewritten below.</li>
 * </ul>
 * Otherwise the filter is rebuilt on top of the rewritten child.
 *
 * @param filter the filter being visited
 * @param context the push-down context carrying the group keys collected so far
 * @return the filter over a rewritten child, or the result of {@code genAggregate(filter, context)}
 */
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, PushDownAggContext context) {
    boolean childIsRelation = filter.child() instanceof LogicalRelation;
    if (childIsRelation
            || filter.getConjuncts().stream().anyMatch(Expression::containsUniqueFunction)) {
        return genAggregate(filter, context);
    }

    // Group keys for the child: the current keys followed by every slot the filter reads,
    // de-duplicated while keeping first-seen order. Presumably this keeps the filter's
    // inputs available above any aggregate created below it.
    List<SlotReference> keysBelowFilter = Stream.concat(
                    context.getGroupKeys().stream(),
                    filter.getInputSlots().stream().map(slot -> (SlotReference) slot))
            .distinct()
            .collect(Collectors.toList());
    PushDownAggContext contextBelowFilter = context.withGroupKeys(keysBelowFilter);
    if (!contextBelowFilter.isValid()) {
        return genAggregate(filter, context);
    }

    Plan rewrittenChild = filter.child().accept(this, contextBelowFilter);
    if (rewrittenChild == filter.child()) {
        // Reference equality on purpose: an untouched child means no push-down happened.
        return genAggregate(filter, context);
    }
    return filter.withChildren(rewrittenChild);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

package org.apache.doris.nereids.rules.rewrite.eageraggregation;

import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

class EagerAggRewriterTest extends TestWithFeService implements MemoPatternMatchSupported {
Expand Down Expand Up @@ -311,4 +315,128 @@ void testAsofJoinNotPushAgg() {
connectContext.getSessionVariable().setDisableJoinReorder(false);
}
}

/**
 * Aggregates over a literal (sum(2), min(1), max(3)) see every physical row of an outer join,
 * null-extended rows included. Once pushed to the nullable side, unmatched rows would only
 * contribute NULL (e.g. sum(2) should add 2 for an unmatched row, but sum(NULL) skips it),
 * so such aggregates must stay above the join. An aggregate over a column of the nullable
 * side is still expected to be pushed.
 */
@Test
void testNotPushAggLiteralToNullableSideOfOuterJoin() {
    connectContext.getSessionVariable().setEagerAggregationMode(1);
    connectContext.getSessionVariable().setDisableJoinReorder(true);
    try {
        // RIGHT JOIN makes t1 (left input) nullable: sum(2) must not end up below the join on t1.
        String sumLiteralRightJoin = "select sum(2), t2.id2 from t1 right join t2"
                + " on t1.id1 = t2.id2 group by t2.id2";
        PlanChecker.from(connectContext)
                .analyze(sumLiteralRightJoin)
                .rewrite()
                .nonMatch(logicalJoin(logicalAggregate(), any()))
                .printlnTree();

        // LEFT JOIN makes t2 (right input) nullable: min(1) must not end up below the join on t2.
        String minLiteralLeftJoin = "select min(1), t1.id1 from t1 left join t2"
                + " on t1.id1 = t2.id2 group by t1.id1";
        PlanChecker.from(connectContext)
                .analyze(minLiteralLeftJoin)
                .rewrite()
                .nonMatch(logicalJoin(any(), logicalAggregate()))
                .printlnTree();

        // Same RIGHT JOIN shape with max(3): still no aggregate on the nullable left input.
        String maxLiteralRightJoin = "select max(3), t2.id2 from t1 right join t2"
                + " on t1.id1 = t2.id2 group by t2.id2";
        PlanChecker.from(connectContext)
                .analyze(maxLiteralRightJoin)
                .rewrite()
                .nonMatch(logicalJoin(logicalAggregate(), any()))
                .printlnTree();

        // Regression guard: max(t1.name) reads a column of the nullable side, so the
        // aggregate is expected below the join on the left input.
        String maxColumnRightJoin = "select max(t1.name), t2.id2 from t1 right join t2"
                + " on t1.id1 = t2.id2 group by t2.id2";
        PlanChecker.from(connectContext)
                .analyze(maxColumnRightJoin)
                .rewrite()
                .matches(logicalAggregate(logicalProject(logicalJoin(logicalAggregate(), any()))))
                .printlnTree();
    } finally {
        // Restore the session settings changed at the top of the test.
        connectContext.getSessionVariable().setEagerAggregationMode(0);
        connectContext.getSessionVariable().setDisableJoinReorder(false);
    }
}

/**
 * A filter whose predicate calls random() must not have an aggregate placed underneath it:
 * the rewritten plan holds exactly two aggregates and none of them is below the first filter found.
 */
@Test
void testUniqueFunctionFilterBlocksPushDownThroughFilter() {
    connectContext.getSessionVariable().setEagerAggregationMode(1);
    connectContext.getSessionVariable().setDisableJoinReorder(true);
    try {
        String query = "select count(s.name1), t2.id2"
                + " from (select * from (select id1, name as name1 from t1) s1 where random() < 0.5) s"
                + " join t2 on s.id1 = t2.id2 group by t2.id2";
        Plan rewritten = PlanChecker.from(connectContext)
                .analyze(query)
                .rewrite()
                .getPlan();
        // The plan tree doubles as the failure message for every assertion below.
        String tree = rewritten.treeString();

        Assertions.assertEquals(2, countPlans(rewritten, LogicalAggregate.class), tree);
        LogicalFilter<?> firstFilter = findFirstPlan(rewritten, LogicalFilter.class);
        Assertions.assertNotNull(firstFilter, tree);
        Assertions.assertFalse(containsPlan(firstFilter.child(), LogicalAggregate.class), tree);
    } finally {
        // Restore the session settings changed at the top of the test.
        connectContext.getSessionVariable().setEagerAggregationMode(0);
        connectContext.getSessionVariable().setDisableJoinReorder(false);
    }
}

/**
 * With an "is not null" filter over the aliased subquery, the aggregate is expected to stay
 * above the filter: the rewritten plan holds exactly two aggregates and none of them is below
 * the first filter found.
 */
@Test
void testInvalidFilterContextFallsBackToCurrentFilter() {
    connectContext.getSessionVariable().setEagerAggregationMode(1);
    connectContext.getSessionVariable().setDisableJoinReorder(true);
    try {
        String query = "select count(s.name1), t2.id2"
                + " from (select * from (select id1, name as name1 from t1) s1 where s1.name1 is not null) s"
                + " join t2 on s.id1 = t2.id2 group by t2.id2";
        Plan rewritten = PlanChecker.from(connectContext)
                .analyze(query)
                .rewrite()
                .getPlan();
        // The plan tree doubles as the failure message for every assertion below.
        String tree = rewritten.treeString();

        Assertions.assertEquals(2, countPlans(rewritten, LogicalAggregate.class), tree);
        LogicalFilter<?> firstFilter = findFirstPlan(rewritten, LogicalFilter.class);
        Assertions.assertNotNull(firstFilter, tree);
        Assertions.assertFalse(containsPlan(firstFilter.child(), LogicalAggregate.class), tree);
    } finally {
        // Restore the session settings changed at the top of the test.
        connectContext.getSessionVariable().setEagerAggregationMode(0);
        connectContext.getSessionVariable().setDisableJoinReorder(false);
    }
}

/**
 * Counts the nodes in the tree rooted at {@code plan} (root included) that are instances of {@code clazz}.
 *
 * @param plan root of the plan tree to walk
 * @param clazz plan type to look for
 * @return number of matching nodes
 */
private int countPlans(Plan plan, Class<? extends Plan> clazz) {
    int inChildren = plan.children().stream()
            .mapToInt(child -> countPlans(child, clazz))
            .sum();
    return clazz.isInstance(plan) ? inChildren + 1 : inChildren;
}

/**
 * Tells whether the tree rooted at {@code plan} (root included) holds at least one instance of {@code clazz}.
 *
 * @param plan root of the plan tree to inspect
 * @param clazz plan type to look for
 * @return {@code true} if a matching node exists
 */
private boolean containsPlan(Plan plan, Class<? extends Plan> clazz) {
    return clazz.isInstance(plan)
            || plan.children().stream().anyMatch(child -> containsPlan(child, clazz));
}

/**
 * Returns the first node of type {@code clazz} met in a pre-order, left-to-right walk of the tree
 * rooted at {@code plan}.
 *
 * @param plan root of the plan tree to search
 * @param clazz plan type to look for
 * @return the first matching node, or {@code null} when the tree has none
 */
private <T extends Plan> T findFirstPlan(Plan plan, Class<T> clazz) {
    if (clazz.isInstance(plan)) {
        return clazz.cast(plan);
    }
    // The stream is lazy, so later siblings are not searched once one subtree yields a match.
    return plan.children().stream()
            .map(child -> findFirstPlan(child, clazz))
            .filter(found -> found != null)
            .findFirst()
            .orElse(null);
}
}
73 changes: 73 additions & 0 deletions regression-test/data/nereids_p0/eager_agg/eager_agg.out
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,76 @@ Used:
UnUsed:
SyntaxError: leading({ ss broadcast dt } broadcast ws) Msg:can not find table: ws

-- !check_sum_literal_right_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
----------PhysicalOlapScan[eager_agg_t1]
----------PhysicalOlapScan[eager_agg_t2]
--------PhysicalOlapScan[eager_agg_t3]

-- !check_sum_literal_left_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalOlapScan[store_sales]
--------PhysicalOlapScan[date_dim]

-- !check_min_literal_right_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
----------PhysicalOlapScan[eager_agg_t1]
----------PhysicalOlapScan[eager_agg_t2]
--------PhysicalOlapScan[eager_agg_t3]

-- !check_max_literal_left_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalOlapScan[store_sales]
--------PhysicalOlapScan[date_dim]

-- !sum_literal_right_join_eager_off --
\N 4
10 2

-- !sum_literal_right_join_eager_on --
\N 4
10 2

-- !min_literal_right_join_eager_on --
\N 1
10 1

-- !max_literal_right_join_eager_on --
\N 3
10 3

-- !check_filter_slots_preserved_pushdown --
PhysicalResultSink
--hashAgg[GLOBAL]
----filter(OR[( not (id = 1)),id IS NULL])
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
--------hashAgg[GLOBAL]
----------PhysicalOlapScan[eager_agg_filter_t1]
--------PhysicalOlapScan[eager_agg_filter_t2]

Hint log:
Used: [broadcast]_1
UnUsed:
SyntaxError:

-- !filter_slots_preserved_eager_on --
2 20

Loading
Loading