Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
Expand Down Expand Up @@ -133,34 +132,32 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, P
}
}

// Do not push count(*)/count(literal)/count(preserved_side_col) to the nullable side of outer joins.
// count(*) counts all physical rows, including null-extended rows from the outer join.
// After pushdown to the nullable side, unmatched rows produce NULL for the pre-aggregated count,
// and ifnull(sum(NULL), 0) = 0, which loses the count of unmatched rows.
// However, count(nullable_side_col) is safe to push down because for unmatched rows,
// nullable_side_col IS NULL, so the original count is 0, matching ifnull(sum(NULL), 0) = 0.
// Do not push agg(literal) or agg(preserved_side_col) to the nullable side of outer joins.
// Aggregates like count(*), sum(2), min(1) etc. aggregate over all physical rows,
// including null-extended rows from the outer join.
// After pushdown to the nullable side, unmatched rows produce NULL for the pre-aggregated value,
// losing the contribution of those rows (e.g. sum(2) should add 2 per unmatched row,
// but sum(NULL) skips them).
// However, agg(nullable_side_col) is safe to push down because for unmatched rows,
// nullable_side_col IS NULL, and the aggregate naturally handles NULL values correctly.
if (!join.getJoinType().isInnerJoin() && !join.getJoinType().isCrossJoin()) {
JoinType joinType = join.getJoinType();
boolean leftIsNullable = joinType.isRightOuterJoin() || joinType.isFullOuterJoin();
boolean rightIsNullable = joinType.isLeftOuterJoin() || joinType.isFullOuterJoin();
for (AggregateFunction aggFunc : context.getAggFunctions()) {
if (aggFunc instanceof Count) {
Set<Slot> countInputSlots = aggFunc.getInputSlots();
// Determine which side is nullable
boolean leftIsNullable = joinType.isRightOuterJoin() || joinType.isFullOuterJoin();
boolean rightIsNullable = joinType.isLeftOuterJoin() || joinType.isFullOuterJoin();
// Check if we're pushing to a nullable side without referencing its columns
if (toLeft && leftIsNullable) {
boolean hasLeftInput = countInputSlots.stream()
.anyMatch(slot -> join.left().getOutputSet().contains(slot));
if (!hasLeftInput) {
toLeft = false;
}
Set<Slot> inputSlots = aggFunc.getInputSlots();
if (toLeft && leftIsNullable) {
boolean hasLeftInput = inputSlots.stream()
.anyMatch(slot -> join.left().getOutputSet().contains(slot));
if (!hasLeftInput) {
toLeft = false;
}
if (toRight && rightIsNullable) {
boolean hasRightInput = countInputSlots.stream()
.anyMatch(slot -> join.right().getOutputSet().contains(slot));
if (!hasRightInput) {
toRight = false;
}
}
if (toRight && rightIsNullable) {
boolean hasRightInput = inputSlots.stream()
.anyMatch(slot -> join.right().getOutputSet().contains(slot));
if (!hasRightInput) {
toRight = false;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,59 @@ void testAsofJoinNotPushAgg() {
connectContext.getSessionVariable().setDisableJoinReorder(false);
}
}

@Test
void testNotPushAggLiteralToNullableSideOfOuterJoin() {
// sum(literal), min(literal), max(literal) aggregate over all physical rows,
// including null-extended rows from the outer join.
// Pushing to the nullable side loses the contribution of unmatched rows:
// original: sum(2) on unmatched row = 2
// pushed: sum(NULL) skips the row (wrong!)
// So agg(literal) must NOT be pushed to the nullable side.
connectContext.getSessionVariable().setEagerAggregationMode(1);
connectContext.getSessionVariable().setDisableJoinReorder(true);
try {
// RIGHT JOIN: t1 is the nullable side (left side of RIGHT JOIN)
// sum(2) should NOT be pushed to t1
String sql = "select sum(2), t2.id2 from t1 right join t2"
+ " on t1.id1 = t2.id2 group by t2.id2";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(logicalJoin(logicalAggregate(), any()))
.printlnTree();

// LEFT JOIN: t2 is the nullable side (right side of LEFT JOIN)
// min(1) should NOT be pushed to t2
sql = "select min(1), t1.id1 from t1 left join t2"
+ " on t1.id1 = t2.id2 group by t1.id1";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(logicalJoin(any(), logicalAggregate()))
.printlnTree();

// RIGHT JOIN: max(3) should NOT be pushed to nullable left side
sql = "select max(3), t2.id2 from t1 right join t2"
+ " on t1.id1 = t2.id2 group by t2.id2";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(logicalJoin(logicalAggregate(), any()))
.printlnTree();

// Verify agg(nullable_side_col) is still safe to push (no regression)
// max(t1.name) references the left (nullable) side, so push is allowed
sql = "select max(t1.name), t2.id2 from t1 right join t2"
+ " on t1.id1 = t2.id2 group by t2.id2";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalAggregate(logicalProject(logicalJoin(logicalAggregate(), any()))))
.printlnTree();
} finally {
connectContext.getSessionVariable().setEagerAggregationMode(0);
connectContext.getSessionVariable().setDisableJoinReorder(false);
}
}
}
132 changes: 94 additions & 38 deletions regression-test/data/nereids_p0/eager_agg/eager_agg.out
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -28,9 +28,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -49,9 +49,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -68,9 +68,9 @@ PhysicalResultSink
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
----------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -87,11 +87,11 @@ PhysicalResultSink
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[store_sales(ss)]
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -110,9 +110,9 @@ PhysicalResultSink
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -131,9 +131,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -152,9 +152,9 @@ PhysicalResultSink
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[web_sales]
--------PhysicalOlapScan[date_dim]
--------------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[web_sales(ws)]
--------PhysicalOlapScan[date_dim(dt)]

Hint log:
Used: leading({ ss broadcast ws } broadcast dt )
Expand All @@ -173,9 +173,9 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
--------------PhysicalOlapScan[store_sales]
--------------PhysicalOlapScan[date_dim]
--------PhysicalOlapScan[web_sales]
--------------PhysicalOlapScan[store_sales(ss)]
--------------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand All @@ -197,9 +197,9 @@ PhysicalResultSink
--------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[date_dim]
--------PhysicalOlapScan[web_sales]
--------------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand Down Expand Up @@ -266,11 +266,11 @@ PhysicalResultSink
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
----------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
------------PhysicalOlapScan[store_sales]
------------PhysicalOlapScan[store_sales(ss)]
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalOlapScan[date_dim]
----------PhysicalOlapScan[web_sales]
----------------PhysicalOlapScan[date_dim(dt)]
----------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand All @@ -287,9 +287,9 @@ PhysicalResultSink
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((ss.ss_item_sk = ws.ws_item_sk)) otherCondition=()
--------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[date_dim]
--------PhysicalOlapScan[web_sales]
----------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[web_sales(ws)]

Hint log:
Used: leading({ ss broadcast dt } broadcast ws )
Expand All @@ -302,12 +302,68 @@ PhysicalResultSink
----hashAgg[LOCAL]
------PhysicalUnion
--------hashJoin[INNER_JOIN] hashCondition=((dt.d_date_sk = ss.ss_sold_date_sk)) otherCondition=()
----------PhysicalOlapScan[store_sales]
----------PhysicalOlapScan[date_dim]
----------PhysicalOlapScan[store_sales(ss)]
----------PhysicalOlapScan[date_dim(dt)]
--------PhysicalOlapScan[date_dim]

Hint log:
Used:
UnUsed:
SyntaxError: leading({ ss broadcast dt } broadcast ws) Msg:can not find table: ws

-- !check_sum_literal_right_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
----------PhysicalOlapScan[eager_agg_t1(a)]
----------PhysicalOlapScan[eager_agg_t2(b)]
--------PhysicalOlapScan[eager_agg_t3(c)]

-- !check_sum_literal_left_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalOlapScan[store_sales]
--------PhysicalOlapScan[date_dim]

-- !check_min_literal_right_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.val = c.val) and (b.id2 = c.id2)) otherCondition=()
--------hashJoin[RIGHT_OUTER_JOIN] hashCondition=((a.id = b.id)) otherCondition=()
----------PhysicalOlapScan[eager_agg_t1(a)]
----------PhysicalOlapScan[eager_agg_t2(b)]
--------PhysicalOlapScan[eager_agg_t3(c)]

-- !check_max_literal_left_join_not_push --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[LEFT_OUTER_JOIN] hashCondition=((date_dim.d_date_sk = store_sales.ss_sold_date_sk)) otherCondition=()
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalOlapScan[store_sales]
--------PhysicalOlapScan[date_dim]

-- !sum_literal_right_join_eager_off --
\N 4
10 2

-- !sum_literal_right_join_eager_on --
\N 4
10 2

-- !min_literal_right_join_eager_on --
\N 1
10 1

-- !max_literal_right_join_eager_on --
\N 3
10 3

Loading
Loading