Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
/**
* Rule for pushdown project through left-semi/anti join
* Just push down project inside join to avoid to push the top of Join-Cluster.
* Note this rule is only used to push down project between join for join ordering.
* <pre>
* Join Join
* | |
Expand All @@ -61,6 +62,9 @@ public List<Rule> buildRules() {
.whenNot(j -> j.left().child().hasJoinHint())
.then(topJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.left();
if (projectBothJoinSide(project)) {
return null;
}
Plan newLeft = pushdownProject(project);
return topJoin.withChildren(newLeft, topJoin.right());
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_SEMI_JOIN_LEFT),
Expand All @@ -72,12 +76,27 @@ public List<Rule> buildRules() {
.whenNot(j -> j.right().child().hasJoinHint())
.then(topJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.right();
if (projectBothJoinSide(project)) {
return null;
}
Plan newRight = pushdownProject(project);
return topJoin.withChildren(topJoin.left(), newRight);
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_SEMI_JOIN_RIGHT)
);
}

private boolean projectBothJoinSide(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
// if project contains both side of join, it can't be pushed.
// such as:
// Project(l, null as r)
// ------ L(l) left anti join R(r)
LogicalJoin<?, ?> join = project.child();
Set<Slot> projectOutput = project.getOutputSet();
boolean containLeft = join.left().getOutput().stream().anyMatch(projectOutput::contains);
boolean containRight = join.right().getOutput().stream().anyMatch(projectOutput::contains);
return containRight && containLeft;
}

private Plan pushdownProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
Set<Slot> conditionLeftSlots = CBOUtils.joinChildConditionSlots(join, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.TypeUtils;

import com.google.common.collect.ImmutableSet;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -45,23 +46,15 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {

@Override
public Rule build() {
return logicalProject(logicalFilter(logicalJoin()
.when(join -> join.getJoinType().isOuterJoin())))
return logicalFilter(logicalJoin()
.when(join -> join.getJoinType().isOuterJoin()))
.then(this::toAntiJoin)
.toRule(RuleType.CONVERT_OUTER_JOIN_TO_ANTI);
}

private Plan toAntiJoin(LogicalProject<LogicalFilter<LogicalJoin<Plan, Plan>>> project) {
LogicalFilter<LogicalJoin<Plan, Plan>> filter = project.child();
private Plan toAntiJoin(LogicalFilter<LogicalJoin<Plan, Plan>> filter) {
LogicalJoin<Plan, Plan> join = filter.child();

boolean leftOutput = join.left().getOutputSet().containsAll(project.getInputSlots());
boolean rightOutput = join.right().getOutputSet().containsAll(project.getInputSlots());

if (!leftOutput && !rightOutput) {
return null;
}

Set<Slot> alwaysNullSlots = filter.getConjuncts().stream()
.filter(p -> TypeUtils.isNull(p).isPresent())
.flatMap(p -> p.getInputSlots().stream())
Expand All @@ -73,36 +66,33 @@ private Plan toAntiJoin(LogicalProject<LogicalFilter<LogicalJoin<Plan, Plan>>> p
.filter(s -> alwaysNullSlots.contains(s) && !s.nullable())
.collect(Collectors.toSet());

Plan res = project;
if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty() && leftOutput) {
// When there is right slot always null, we can turn left outer join to left anti join
Set<Expression> predicates = filter.getExpressions().stream()
.filter(p -> !(TypeUtils.isNull(p).isPresent()
&& rightAlwaysNullSlots.containsAll(p.getInputSlots())))
.collect(ImmutableSet.toImmutableSet());
boolean containRightSlot = predicates.stream()
.flatMap(p -> p.getInputSlots().stream())
.anyMatch(join.right().getOutputSet()::contains);
if (!containRightSlot) {
res = join.withJoinType(JoinType.LEFT_ANTI_JOIN);
res = predicates.isEmpty() ? res : filter.withConjuncts(predicates).withChildren(res);
res = project.withChildren(res);
}
Plan newJoin = null;
if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) {
newJoin = join.withJoinType(JoinType.LEFT_ANTI_JOIN);
}
if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty() && rightOutput) {
Set<Expression> predicates = filter.getExpressions().stream()
.filter(p -> !(TypeUtils.isNull(p).isPresent()
&& leftAlwaysNullSlots.containsAll(p.getInputSlots())))
.collect(ImmutableSet.toImmutableSet());
boolean containLeftSlot = predicates.stream()
.flatMap(p -> p.getInputSlots().stream())
.anyMatch(join.left().getOutputSet()::contains);
if (!containLeftSlot) {
res = join.withJoinType(JoinType.RIGHT_ANTI_JOIN);
res = predicates.isEmpty() ? res : filter.withConjuncts(predicates).withChildren(res);
res = project.withChildren(res);
}
if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty()) {
newJoin = join.withJoinType(JoinType.RIGHT_ANTI_JOIN);
}
if (newJoin == null) {
return null;
}

if (!newJoin.getOutputSet().containsAll(filter.getInputSlots())) {
// if there are slots that don't belong to join output, we use null alias to replace them
// such as:
// project(A.id, null as B.id)
// - (A left anti join B)
Set<Slot> joinOutput = newJoin.getOutputSet();
List<NamedExpression> projects = filter.getOutput().stream()
.map(s -> {
if (joinOutput.contains(s)) {
return s;
} else {
return new Alias(s.getExprId(), new NullLiteral(s.getDataType()), s.getName());
}
}).collect(Collectors.toList());
newJoin = new LogicalProject<>(projects, newJoin);
}
return res;
return filter.withChildren(newJoin);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
Expand Down Expand Up @@ -133,4 +134,35 @@ void pushComplexProject() {
)
);
}

@Test
void testProjectLiteral() {
List<NamedExpression> projectExprs = ImmutableList.of(
new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
new Alias(scan2.getOutput().get(0).getExprId(), new NullLiteral(), scan2.getOutput().get(0).getName())
);
// complex projection contain ti.id, which isn't in Join Condition
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 1))
.projectExprs(projectExprs)
.join(scan3, JoinType.INNER_JOIN, Pair.of(1, 1))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.buildRules())
.nonMatch(logicalJoin(logicalJoin(logicalProject(), group()), group()));

projectExprs = ImmutableList.of(
new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"),
new Alias(scan1.getOutput().get(0).getExprId(), new NullLiteral(), scan2.getOutput().get(0).getName())
);
// complex projection contain ti.id, which isn't in Join Condition
plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(1, 1))
.projectExprs(projectExprs)
.join(scan3, JoinType.INNER_JOIN, Pair.of(1, 1))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.buildRules())
.nonMatch(logicalJoin(logicalJoin(logicalProject(), group()), group()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void testEliminateLeftWithRightPredicate() {
.applyTopDown(new InferFilterNotNull())
.applyTopDown(new ConvertOuterJoinToAntiJoin())
.printlnTree()
.matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin()));
.matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin()));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,25 @@ suite("transform_outer_join_to_anti") {
sql("select eliminate_outer_join_B.* from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_A.null_a is null")
contains "OUTER JOIN"
}

explain {
sql("select eliminate_outer_join_A.* from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null or eliminate_outer_join_A.null_a is null")
contains "OUTER JOIN"
}

explain {
sql("select * from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null and eliminate_outer_join_A.null_a is null")
contains "ANTI JOIN"
}

explain {
sql("select * from eliminate_outer_join_A left outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_B.b is null and eliminate_outer_join_B.null_b is null")
contains "ANTI JOIN"
}

explain {
sql("select * from eliminate_outer_join_A right outer join eliminate_outer_join_B on eliminate_outer_join_B.b = eliminate_outer_join_A.a where eliminate_outer_join_A.a is null and eliminate_outer_join_B.null_b is null and eliminate_outer_join_A.null_a is null")
contains "ANTI JOIN"
}
}