Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
Expand Down Expand Up @@ -55,6 +58,9 @@ public class RuleSet {
.add(JoinCommuteProject.LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
.add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(new PushdownFilterThroughProject())
.add(new MergeConsecutiveProjects())
.build();
Expand Down Expand Up @@ -140,6 +146,11 @@ public RuleFactories add(RuleFactory ruleFactory) {
return this;
}

/**
 * Appends every rule in the given list to the rules collected so far.
 *
 * @param rules the rules to append
 * @return this instance, so that calls can be chained
 */
public RuleFactories addAll(List<Rule> rules) {
// Delegates to addAll on the underlying rules collection.
// NOTE(review): the declaration of the rules field is not visible here — confirm its type.
this.rules.addAll(rules);
return this;
}

/**
 * Returns the list of rules produced by calling build() on the underlying rules collection.
 *
 * @return the collected rules
 */
public List<Rule> build() {
// NOTE(review): rules is presumably a Guava ImmutableList.Builder — confirm against the field declaration.
return rules.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ public enum RuleType {
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
// Pushdown filter
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
LOGICAL_LIMIT_TO_LOGICAL_EMPTY_RELATION_RULE(RuleTypeClass.REWRITE),
SWAP_LIMIT_PROJECT(RuleTypeClass.REWRITE),
PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),

// limit push down
Expand All @@ -122,7 +121,11 @@ public enum RuleType {
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_L_ASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),

// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@ public Rule build() {
return null;
}
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM_PROJECT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;

import java.util.List;
import java.util.Set;

/**
Expand All @@ -42,17 +44,36 @@
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {

public static final SemiJoinLogicalJoinTranspose LEFT_DEEP = new SemiJoinLogicalJoinTranspose(true);

public static final SemiJoinLogicalJoinTranspose ALL = new SemiJoinLogicalJoinTranspose(false);

private final boolean leftDeep;

/**
 * Creates the rule factory.
 *
 * @param leftDeep when true, conditionChecker rejects a top semi join whose hash join
 *                 conjuncts use slots output by the right child of the bottom join
 */
public SemiJoinLogicalJoinTranspose(boolean leftDeep) {
this.leftDeep = leftDeep;
}

@Override
public Rule build() {
return leftSemiLogicalJoin(logicalJoin(), group())
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();

boolean lasscom = bottomJoin.getOutputSet().containsAll(a.getOutput());
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
Set<Slot> aOutputSet = a.getOutputSet();

boolean lasscom = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
}

if (lasscom) {
/*
Expand Down Expand Up @@ -81,20 +102,27 @@ public Rule build() {
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
}
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE);
}

// bottomJoin must return only A or only B; otherwise return false.
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
Set<Slot> bottomOutputSet = topJoin.left().getOutputSet();

Set<Slot> aOutputSet = topJoin.left().left().getOutputSet();
Set<Slot> bOutputSet = topJoin.left().right().getOutputSet();
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topSemiJoin) {
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();

boolean isProjectA = !ExpressionUtils.isIntersecting(bottomOutputSet, aOutputSet);
boolean isProjectB = !ExpressionUtils.isIntersecting(bottomOutputSet, bOutputSet);
List<Slot> aOutput = topSemiJoin.left().left().getOutput();
List<Slot> bOutput = topSemiJoin.left().right().getOutput();

Preconditions.checkState(isProjectA || isProjectB, "join output must contain child");
return !(isProjectA && isProjectB);
boolean hashContainsA = false;
boolean hashContainsB = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA;
hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB;
}
if (leftDeep && hashContainsB) {
return false;
}
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
return !(hashContainsA && hashContainsB);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

Expand All @@ -45,9 +47,20 @@
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFactory {
public static final SemiJoinLogicalJoinTransposeProject LEFT_DEEP = new SemiJoinLogicalJoinTransposeProject(true);

public static final SemiJoinLogicalJoinTransposeProject ALL = new SemiJoinLogicalJoinTransposeProject(false);

private final boolean leftDeep;

/**
 * Creates the rule factory.
 *
 * @param leftDeep when true, conditionChecker rejects a top semi join whose hash join
 *                 conjuncts use slots output by the right child of the bottom join
 */
public SemiJoinLogicalJoinTransposeProject(boolean leftDeep) {
this.leftDeep = leftDeep;
}

@Override
public Rule build() {
return leftSemiLogicalJoin(logicalProject(logicalJoin()), group())
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
Expand All @@ -56,67 +69,77 @@ public Rule build() {
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();

boolean lasscom = a.getOutputSet().containsAll(project.getOutput());
Set<Slot> aOutputSet = a.getOutputSet();

List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();

boolean lasscom = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
}

if (lasscom) {
/*-
* topSemiJoin newTopProject
* / \ |
* topSemiJoin project
* / \ |
* project C newTopJoin
* | -> / \
* bottomJoin newBottomSemiJoin B
* bottomJoin newBottomSemiJoin B
* / \ / \
* A B aNewProject C
* |
* A
* A B A C
*/
List<NamedExpression> projects = project.getProjects();
LogicalProject<GroupPlan> aNewProject = new LogicalProject<>(projects, a);
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinCondition(), aNewProject, c);
LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>, GroupPlan> newTopJoin
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
return new LogicalProject<>(projects, newTopJoin);
topSemiJoin.getOtherJoinCondition(), a, c);

LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(),
newBottomSemiJoin, b);

return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
} else {
/*-
* topSemiJoin newTopProject
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B bNewProject C
* |
* B
* topSemiJoin project
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B B C
*/
List<NamedExpression> projects = project.getProjects();
LogicalProject<GroupPlan> bNewProject = new LogicalProject<>(projects, b);
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinCondition(), bNewProject, c);
topSemiJoin.getOtherJoinCondition(), b, c);

LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(),
a, newBottomSemiJoin);

LogicalJoin<GroupPlan, LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
return new LogicalProject<>(projects, newTopJoin);
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
}
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}

// bottomJoin just return A OR B, else return false.
// The project of bottomJoin must return only A or only B; otherwise return false.
private boolean conditionChecker(
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
Set<Slot> projectOutputSet = topJoin.left().getOutputSet();

Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet();
Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet();
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) {
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();

boolean isProjectA = !ExpressionUtils.isIntersecting(projectOutputSet, aOutputSet);
boolean isProjectB = !ExpressionUtils.isIntersecting(projectOutputSet, bOutputSet);
List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();

Preconditions.checkState(isProjectA || isProjectB, "project must contain child");
return !(isProjectA && isProjectB);
boolean hashContainsA = false;
boolean hashContainsB = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA;
hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB;
}
if (leftDeep && hashContainsB) {
return false;
}
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
return !(hashContainsA && hashContainsB);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
* LEFT-Semi/ANTI(X, LEFT-Semi/ANTI(Y, Z))
*/
public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
public static final SemiJoinSemiJoinTranspose INSTANCE = new SemiJoinSemiJoinTranspose();

public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of(
Copy link
Contributor

@morrySnow morrySnow Sep 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: rename typeSet to VALID_TYPE_PAIR_SET. Please fix it in the next PR.

Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN),
Expand Down Expand Up @@ -69,7 +70,7 @@ public Rule build() {
newBottomJoin, b);

return newTopJoin;
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
}

private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ public Rule build() {
return new LogicalLimit<LogicalProject<GroupPlan>>(logicalLimit.getLimit(),
logicalLimit.getOffset(), new LogicalProject<>(logicalProject.getProjects(),
logicalLimit.child()));
}).toRule(RuleType.SWAP_LIMIT_PROJECT);
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGHT_LIMIT);
}
}
Loading