Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,28 +80,28 @@ public class PushDownJoinOnAssertNumRows extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalJoin()
.when(topJoin -> pattenCheck(topJoin))
.then(topJoin -> pushDownAssertNumRowsJoin(topJoin))
.when(this::pattenCheck)
.then(this::pushDownAssertNumRowsJoin)
.toRule(RuleType.PUSH_DOWN_JOIN_ON_ASSERT_NUM_ROWS);
}

private boolean pattenCheck(LogicalJoin topJoin) {
private boolean pattenCheck(LogicalJoin<?, ?> topJoin) {
// 1. right is LogicalAssertNumRows or LogicalProject->LogicalAssertNumRows
// 2. left is join or project->join
// 3. only one join condition.
if (!topJoin.getJoinType().isInnerOrCrossJoin()) {
return false;
}
LogicalJoin bottomJoin;
LogicalJoin<?, ?> bottomJoin;
Plan left = topJoin.left();
Plan right = topJoin.right();
if (!isAssertOneRowEqOrProjectAssertOneRowEq(right)) {
return false;
}
if (left instanceof LogicalJoin) {
bottomJoin = (LogicalJoin) left;
bottomJoin = (LogicalJoin<?, ?>) left;
} else if (left instanceof LogicalProject && left.child(0) instanceof LogicalJoin) {
bottomJoin = (LogicalJoin) left.child(0);
bottomJoin = (LogicalJoin<?, ?>) left.child(0);
} else {
return false;
}
Expand All @@ -125,7 +125,7 @@ private boolean isAssertOneRowEqOrProjectAssertOneRowEq(Plan plan) {
plan = plan.child(0);
}
if (plan instanceof LogicalAssertNumRows) {
AssertNumRowsElement assertNumRowsElement = ((LogicalAssertNumRows) plan).getAssertNumRowsElement();
AssertNumRowsElement assertNumRowsElement = ((LogicalAssertNumRows<?>) plan).getAssertNumRowsElement();
if (assertNumRowsElement.getAssertion() == AssertNumRowsElement.Assertion.EQ
|| assertNumRowsElement.getDesiredNumOfRows() == 1L) {
return true;
Expand All @@ -134,14 +134,14 @@ private boolean isAssertOneRowEqOrProjectAssertOneRowEq(Plan plan) {
return false;
}

private boolean joinOnAssertOneRowEq(LogicalJoin join) {
private boolean joinOnAssertOneRowEq(LogicalJoin<?, ?> join) {
return isAssertOneRowEqOrProjectAssertOneRowEq(join.right())
|| isAssertOneRowEqOrProjectAssertOneRowEq(join.left());
}

private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) {
private Plan pushDownAssertNumRowsJoin(LogicalJoin<?, ?> topJoin) {
Plan assertBranch = topJoin.right();
Expression condition = (Expression) topJoin.getOtherJoinConjuncts().get(0);
Expression condition = topJoin.getOtherJoinConjuncts().get(0);
List<Alias> aliasUsedInConditionFromLeftProject = new ArrayList<>();
LogicalJoin<? extends Plan, ? extends Plan> bottomJoin;
if (topJoin.left() instanceof LogicalProject) {
Expand All @@ -160,59 +160,49 @@ private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) {
Plan bottomRight = bottomJoin.right();

List<Slot> conditionSlotsFromTopLeft = condition.getInputSlots().stream()
.filter(slot -> topJoin.left().getOutputSet().contains(slot))
.filter(slot -> bottomJoin.getOutputSet().contains(slot))
.collect(Collectors.toList());
// Nothing from the bottom join participates in this scalar-subquery condition.
if (conditionSlotsFromTopLeft.isEmpty()) {
return null;
}
if (bottomLeft.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
// push to bottomLeft
Plan newBottomLeft;
if (aliasUsedInConditionFromLeftProject.isEmpty()) {
newBottomLeft = bottomLeft;
} else {
newBottomLeft = projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomLeft);
}
LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new LogicalJoin<>(
topJoin.getJoinType(),
topJoin.getHashJoinConjuncts(),
topJoin.getOtherJoinConjuncts(),
newBottomLeft,
assertBranch,
topJoin.getJoinReorderContext());
LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = (LogicalJoin<? extends Plan, ? extends Plan>)
bottomJoin.withChildren(newBottomJoin, bottomRight);
if (topJoin.left() instanceof LogicalProject) {
LogicalProject<? extends Plan> upperProject = projectAliasOnPlan(
aliasUsedInConditionFromLeftProject, topJoin.left());
return upperProject.withChildren(newTopJoin);
} else {
return newTopJoin;
}
return assembleNewJoin(bottomLeft, topJoin, bottomJoin, bottomRight,
assertBranch, aliasUsedInConditionFromLeftProject, true);
} else if (bottomRight.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
Plan newBottomRight;
if (aliasUsedInConditionFromLeftProject.isEmpty()) {
newBottomRight = bottomRight;
} else {
newBottomRight = projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomRight);
}
LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new LogicalJoin<>(
topJoin.getJoinType(),
topJoin.getHashJoinConjuncts(),
topJoin.getOtherJoinConjuncts(),
newBottomRight,
assertBranch,
topJoin.getJoinReorderContext());
LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = (LogicalJoin<? extends Plan, ? extends Plan>)
bottomJoin.withChildren(bottomLeft, newBottomJoin);
if (topJoin.left() instanceof LogicalProject) {
LogicalProject<? extends Plan> upperProject = projectAliasOnPlan(
aliasUsedInConditionFromLeftProject, topJoin.left());
return upperProject.withChildren(newTopJoin);
} else {
return newTopJoin;
}
return assembleNewJoin(bottomRight, topJoin, bottomJoin, bottomLeft,
assertBranch, aliasUsedInConditionFromLeftProject, false);
}
return null;
}

private Plan assembleNewJoin(Plan bottom, LogicalJoin<?, ?> topJoin, LogicalJoin<?, ?> bottomJoin, Plan newTopChild,
Plan assertBranch, List<Alias> aliasUsedInConditionFromLeftProject, boolean pushLeft) {
Plan newBottomChild;
if (aliasUsedInConditionFromLeftProject.isEmpty()) {
newBottomChild = bottom;
} else {
newBottomChild = projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottom);
}
LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new LogicalJoin<>(
topJoin.getJoinType(),
topJoin.getHashJoinConjuncts(),
topJoin.getOtherJoinConjuncts(),
newBottomChild,
assertBranch,
topJoin.getJoinReorderContext());
LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = (LogicalJoin<? extends Plan, ? extends Plan>)
(pushLeft ? bottomJoin.withChildren(newBottomJoin, newTopChild)
: bottomJoin.withChildren(newTopChild, newBottomJoin));
if (topJoin.left() instanceof LogicalProject) {
LogicalProject<? extends Plan> upperProject = projectAliasOnPlan(
aliasUsedInConditionFromLeftProject, topJoin.left());
return upperProject.withChildren(newTopJoin);
} else {
return newTopJoin;
}
}

@VisibleForTesting
LogicalProject<? extends Plan> projectAliasOnPlan(List<Alias> projections, Plan child) {
if (child instanceof LogicalProject) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand Down Expand Up @@ -253,6 +254,71 @@ void testPushDownWithProjectNode() {
logicalOlapScan())));
}

/**
* Test push down when the top join condition uses an alias from the right child
* of the bottom join. This covers the following shape:
*
* Before:
* topJoin(rhs_score < x)
* |-- Project(T1.id, T2.cid + 1 as rhs_score, ...)
* | `-- bottomJoin(T1.id = T2.sid)
* | |-- Scan(T1)
* | `-- Scan(T2)
* `-- LogicalAssertNumRows(output=(x, ...))
*
* After:
* Project(...)
Comment thread
morrySnow marked this conversation as resolved.
* `-- bottomJoin(T1.id = T2.sid)
* |-- Scan(T1)
* `-- topJoin(rhs_score < x)
* |-- Project(T2.cid + 1 as rhs_score, ...)
* | `-- Scan(T2)
* `-- LogicalAssertNumRows(output=(x, ...))
*/
@Test
void testPushDownWithProjectAliasFromRightChild() {
Plan oneRowRelation = new LogicalPlanBuilder(t3)
.limit(1)
.build();

AssertNumRowsElement assertElement = new AssertNumRowsElement(1, "", Assertion.EQ);
LogicalAssertNumRows<Plan> assertNumRows = new LogicalAssertNumRows<>(assertElement, oneRowRelation);

Expression bottomJoinCondition = new EqualTo(t1Slots.get(0), t2Slots.get(0));

LogicalPlan bottomJoin = new LogicalPlanBuilder(t1)
.join(t2, JoinType.INNER_JOIN, ImmutableList.of(bottomJoinCondition),
ImmutableList.of())
.build();

Expression addExpr = new Add(t2Slots.get(1), Literal.of(1));
Alias rhsScore = new Alias(addExpr, "rhs_score");

ImmutableList.Builder<NamedExpression> projectListBuilder = ImmutableList.builder();
projectListBuilder.add(t1Slots.get(0));
projectListBuilder.add(t1Slots.get(1));
projectListBuilder.add(t2Slots.get(0));
projectListBuilder.add(rhsScore);

LogicalProject<Plan> project = new LogicalProject<>(projectListBuilder.build(), bottomJoin);

Expression topJoinCondition = new LessThan(rhsScore.toSlot(), t3Slots.get(0));

LogicalPlan root = new LogicalPlanBuilder(project)
.join(assertNumRows, JoinType.INNER_JOIN, ImmutableList.of(),
ImmutableList.of(topJoinCondition))
.build();

PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyTopDown(new PushDownJoinOnAssertNumRows())
.matches(logicalProject(
logicalJoin(
logicalOlapScan(),
logicalJoin(
logicalProject(logicalOlapScan()),
logicalAssertNumRows()))));
}

/**
* Test with CROSS JOIN type.
*/
Expand Down
Loading