diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java index e52def0723cccf..d45cc5676fe656 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java @@ -80,28 +80,28 @@ public class PushDownJoinOnAssertNumRows extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin() - .when(topJoin -> pattenCheck(topJoin)) - .then(topJoin -> pushDownAssertNumRowsJoin(topJoin)) + .when(this::pattenCheck) + .then(this::pushDownAssertNumRowsJoin) .toRule(RuleType.PUSH_DOWN_JOIN_ON_ASSERT_NUM_ROWS); } - private boolean pattenCheck(LogicalJoin topJoin) { + private boolean pattenCheck(LogicalJoin topJoin) { // 1. right is LogicalAssertNumRows or LogicalProject->LogicalAssertNumRows // 2. left is join or project->join // 3. only one join condition. if (!topJoin.getJoinType().isInnerOrCrossJoin()) { return false; } - LogicalJoin bottomJoin; + LogicalJoin bottomJoin; Plan left = topJoin.left(); Plan right = topJoin.right(); if (!isAssertOneRowEqOrProjectAssertOneRowEq(right)) { return false; } if (left instanceof LogicalJoin) { - bottomJoin = (LogicalJoin) left; + bottomJoin = (LogicalJoin) left; } else if (left instanceof LogicalProject && left.child(0) instanceof LogicalJoin) { - bottomJoin = (LogicalJoin) left.child(0); + bottomJoin = (LogicalJoin) left.child(0); } else { return false; } @@ -125,7 +125,7 @@ private boolean isAssertOneRowEqOrProjectAssertOneRowEq(Plan plan) { plan = plan.child(0); } if (plan instanceof LogicalAssertNumRows) { - AssertNumRowsElement assertNumRowsElement = ((LogicalAssertNumRows) plan).getAssertNumRowsElement(); + AssertNumRowsElement assertNumRowsElement = ((LogicalAssertNumRows) plan).getAssertNumRowsElement(); if (assertNumRowsElement.getAssertion() == AssertNumRowsElement.Assertion.EQ || assertNumRowsElement.getDesiredNumOfRows() == 1L) { return true; @@ -134,14 +134,14 @@ private boolean isAssertOneRowEqOrProjectAssertOneRowEq(Plan plan) { return false; } - private boolean joinOnAssertOneRowEq(LogicalJoin join) { + private boolean joinOnAssertOneRowEq(LogicalJoin join) { return isAssertOneRowEqOrProjectAssertOneRowEq(join.right()) || isAssertOneRowEqOrProjectAssertOneRowEq(join.left()); } - private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) { + private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) { Plan assertBranch = topJoin.right(); - Expression condition = (Expression) topJoin.getOtherJoinConjuncts().get(0); + Expression condition = topJoin.getOtherJoinConjuncts().get(0); List aliasUsedInConditionFromLeftProject = new ArrayList<>(); LogicalJoin bottomJoin; if (topJoin.left() instanceof LogicalProject) { @@ -160,59 +160,49 @@ private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) { Plan bottomRight = bottomJoin.right(); List conditionSlotsFromTopLeft = condition.getInputSlots().stream() - .filter(slot -> topJoin.left().getOutputSet().contains(slot)) + .filter(slot -> bottomJoin.getOutputSet().contains(slot)) .collect(Collectors.toList()); + // Nothing from the bottom join participates in this scalar-subquery condition. + if (conditionSlotsFromTopLeft.isEmpty()) { + return null; + } if (bottomLeft.getOutputSet().containsAll(conditionSlotsFromTopLeft)) { - // push to bottomLeft - Plan newBottomLeft; - if (aliasUsedInConditionFromLeftProject.isEmpty()) { - newBottomLeft = bottomLeft; - } else { - newBottomLeft = projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomLeft); - } - LogicalJoin newBottomJoin = new LogicalJoin<>( - topJoin.getJoinType(), - topJoin.getHashJoinConjuncts(), - topJoin.getOtherJoinConjuncts(), - newBottomLeft, - assertBranch, - topJoin.getJoinReorderContext()); - LogicalJoin newTopJoin = (LogicalJoin) - bottomJoin.withChildren(newBottomJoin, bottomRight); - if (topJoin.left() instanceof LogicalProject) { - LogicalProject upperProject = projectAliasOnPlan( - aliasUsedInConditionFromLeftProject, topJoin.left()); - return upperProject.withChildren(newTopJoin); - } else { - return newTopJoin; - } + return assembleNewJoin(bottomLeft, topJoin, bottomJoin, bottomRight, + assertBranch, aliasUsedInConditionFromLeftProject, true); } else if (bottomRight.getOutputSet().containsAll(conditionSlotsFromTopLeft)) { - Plan newBottomRight; - if (aliasUsedInConditionFromLeftProject.isEmpty()) { - newBottomRight = bottomRight; - } else { - newBottomRight = projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomRight); - } - LogicalJoin newBottomJoin = new LogicalJoin<>( - topJoin.getJoinType(), - topJoin.getHashJoinConjuncts(), - topJoin.getOtherJoinConjuncts(), - newBottomRight, - assertBranch, - topJoin.getJoinReorderContext()); - LogicalJoin newTopJoin = (LogicalJoin) - bottomJoin.withChildren(bottomLeft, newBottomJoin); - if (topJoin.left() instanceof LogicalProject) { - LogicalProject upperProject = projectAliasOnPlan( - aliasUsedInConditionFromLeftProject, topJoin.left()); - return upperProject.withChildren(newTopJoin); - } else { - return newTopJoin; - } + return assembleNewJoin(bottomRight, topJoin, bottomJoin, bottomLeft, + assertBranch, aliasUsedInConditionFromLeftProject, false); } return null; } + private Plan assembleNewJoin(Plan bottom, LogicalJoin topJoin, LogicalJoin bottomJoin, Plan newTopChild, + Plan assertBranch, List aliasUsedInConditionFromLeftProject, boolean pushLeft) { + Plan newBottomChild; + if (aliasUsedInConditionFromLeftProject.isEmpty()) { + newBottomChild = bottom; + } else { + newBottomChild = projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottom); + } + LogicalJoin newBottomJoin = new LogicalJoin<>( + topJoin.getJoinType(), + topJoin.getHashJoinConjuncts(), + topJoin.getOtherJoinConjuncts(), + newBottomChild, + assertBranch, + topJoin.getJoinReorderContext()); + LogicalJoin newTopJoin = (LogicalJoin) + (pushLeft ? bottomJoin.withChildren(newBottomJoin, newTopChild) + : bottomJoin.withChildren(newTopChild, newBottomJoin)); + if (topJoin.left() instanceof LogicalProject) { + LogicalProject upperProject = projectAliasOnPlan( + aliasUsedInConditionFromLeftProject, topJoin.left()); + return upperProject.withChildren(newTopJoin); + } else { + return newTopJoin; + } + } + @VisibleForTesting LogicalProject projectAliasOnPlan(List projections, Plan child) { if (child instanceof LogicalProject) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java index aded31bd18fcbf..d241433a2191ef 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; @@ -253,6 +254,71 @@ void testPushDownWithProjectNode() { logicalOlapScan()))); } + /** + * Test push down when the top join condition uses an alias from the right child + * of the bottom join. This covers the following shape: + * + * Before: + * topJoin(rhs_score < x) + * |-- Project(T1.id, T2.cid + 1 as rhs_score, ...) + * | `-- bottomJoin(T1.id = T2.sid) + * | |-- Scan(T1) + * | `-- Scan(T2) + * `-- LogicalAssertNumRows(output=(x, ...)) + * + * After: + * Project(...) + * `-- bottomJoin(T1.id = T2.sid) + * |-- Scan(T1) + * `-- topJoin(rhs_score < x) + * |-- Project(T2.cid + 1 as rhs_score, ...) + * | `-- Scan(T2) + * `-- LogicalAssertNumRows(output=(x, ...)) + */ + @Test + void testPushDownWithProjectAliasFromRightChild() { + Plan oneRowRelation = new LogicalPlanBuilder(t3) + .limit(1) + .build(); + + AssertNumRowsElement assertElement = new AssertNumRowsElement(1, "", Assertion.EQ); + LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>(assertElement, oneRowRelation); + + Expression bottomJoinCondition = new EqualTo(t1Slots.get(0), t2Slots.get(0)); + + LogicalPlan bottomJoin = new LogicalPlanBuilder(t1) + .join(t2, JoinType.INNER_JOIN, ImmutableList.of(bottomJoinCondition), + ImmutableList.of()) + .build(); + + Expression addExpr = new Add(t2Slots.get(1), Literal.of(1)); + Alias rhsScore = new Alias(addExpr, "rhs_score"); + + ImmutableList.Builder projectListBuilder = ImmutableList.builder(); + projectListBuilder.add(t1Slots.get(0)); + projectListBuilder.add(t1Slots.get(1)); + projectListBuilder.add(t2Slots.get(0)); + projectListBuilder.add(rhsScore); + + LogicalProject project = new LogicalProject<>(projectListBuilder.build(), bottomJoin); + + Expression topJoinCondition = new LessThan(rhsScore.toSlot(), t3Slots.get(0)); + + LogicalPlan root = new LogicalPlanBuilder(project) + .join(assertNumRows, JoinType.INNER_JOIN, ImmutableList.of(), + ImmutableList.of(topJoinCondition)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new PushDownJoinOnAssertNumRows()) + .matches(logicalProject( + logicalJoin( + logicalOlapScan(), + logicalJoin( + logicalProject(logicalOlapScan()), + logicalAssertNumRows())))); + } + /** * Test with CROSS JOIN type. */