Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

import java.util.LinkedHashSet;
Expand All @@ -41,10 +42,11 @@ public class DistinctPredicatesRule extends AbstractExpressionRewriteRule {

@Override
public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) {
Preconditions.checkNotNull(expr);
List<Expression> extractExpressions = ExpressionUtils.extract(expr);
Set<Expression> distinctExpressions = new LinkedHashSet<>(extractExpressions);
if (distinctExpressions.size() != extractExpressions.size()) {
return ExpressionUtils.combine(expr.getClass(), Lists.newArrayList(distinctExpressions));
return ExpressionUtils.combine(expr.getClass(), Lists.newArrayList(distinctExpressions)).get();
}
return expr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand All @@ -45,7 +47,7 @@ public class ExtractCommonFactorRule extends AbstractExpressionRewriteRule {
public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) {

Expression rewrittenChildren = ExpressionUtils.combine(expr.getClass(), ExpressionUtils.extract(expr).stream()
.map(predicate -> rewrite(predicate, context)).collect(Collectors.toList()));
.map(predicate -> rewrite(predicate, context)).collect(Collectors.toList())).get();

if (!(rewrittenChildren instanceof CompoundPredicate)) {
return rewrittenChildren;
Expand All @@ -64,14 +66,16 @@ public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewri
.map(predicates -> predicates.stream().filter(p -> !commons.contains(p)).collect(Collectors.toList()))
.collect(Collectors.toList());

// TODO(wenjie): add BooleanLiteral for solving empty list is tricky.
Expression combineUncorrelated = ExpressionUtils.combine(compoundPredicate.getClass(),
uncorrelated.stream()
.map(predicates -> ExpressionUtils.combine(compoundPredicate.flipType(), predicates))
.collect(Collectors.toList()));
.map(option -> option.orElse(new BooleanLiteral(compoundPredicate.flipType() == And.class)))
.collect(Collectors.toList())).get();

List<Expression> finalCompound = Lists.newArrayList(commons);
finalCompound.add(combineUncorrelated);

return ExpressionUtils.combine(compoundPredicate.flipType(), finalCompound);
return ExpressionUtils.combine(compoundPredicate.flipType(), finalCompound).get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,29 @@

/**
* Push the predicate in the LogicalFilter or LogicalJoin to the join children.
* For example:
* select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
* Logical plan tree:
* project
* |
* filter (a.k1 > 1 and b.k1 > 2)
* |
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
* / \
* scan scan
* transformed:
* project
* |
* join (a.k1 = b.k1)
* / \
* filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
* | |
* scan scan
* todo: Now, only support eq on condition for inner join, support other case later
*/
public class PushPredicateThroughJoin extends OneRewriteRuleFactory {

/*
* For example:
* select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
* Logical plan tree:
* project
* |
* filter (a.k1 > 1 and b.k1 > 2)
* |
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
* / \
* scan scan
* transformed:
* project
* |
* join (a.k1 = b.k1)
* / \
* filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
* | |
* scan scan
*/
@Override
public Rule build() {
return logicalFilter(innerLogicalJoin()).then(filter -> {
Expand All @@ -83,13 +84,14 @@ public Rule build() {
List<Slot> leftInput = join.left().getOutput();
List<Slot> rightInput = join.right().getOutput();

ExpressionUtils.extractConjunct(ExpressionUtils.and(onPredicates, wherePredicates)).forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);
} else {
otherConditions.add(predicate);
}
});
ExpressionUtils.extractConjunct(ExpressionUtils.and(onPredicates, wherePredicates).get())
.forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);
} else {
otherConditions.add(predicate);
}
});

List<Expression> leftPredicates = Lists.newArrayList();
List<Expression> rightPredicates = Lists.newArrayList();
Expand All @@ -111,7 +113,7 @@ public Rule build() {
otherConditions.removeAll(leftPredicates);
otherConditions.removeAll(rightPredicates);
otherConditions.addAll(eqConditions);
Expression joinConditions = ExpressionUtils.and(otherConditions);
Expression joinConditions = ExpressionUtils.and(otherConditions).get();

return pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates);
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
Expand All @@ -120,18 +122,18 @@ public Rule build() {
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> joinPlan,
Expression joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {

Expression left = ExpressionUtils.and(leftPredicates);
Expression right = ExpressionUtils.and(rightPredicates);
//todo expr should optimize again using expr rewrite
ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
Plan leftPlan = joinPlan.left();
Plan rightPlan = joinPlan.right();
if (!left.equals(BooleanLiteral.TRUE)) {
leftPlan = new LogicalFilter(exprRewriter.rewrite(left), leftPlan);
}

if (!right.equals(BooleanLiteral.TRUE)) {
rightPlan = new LogicalFilter(exprRewriter.rewrite(right), rightPlan);
Optional<Expression> leftOption = ExpressionUtils.and(leftPredicates);
Optional<Expression> rightOption = ExpressionUtils.and(rightPredicates);
if (leftOption.isPresent()) {
leftPlan = new LogicalFilter(exprRewriter.rewrite(leftOption.get()), leftPlan);
}
if (rightOption.isPresent()) {
rightPlan = new LogicalFilter(exprRewriter.rewrite(rightOption.get()), rightPlan);
}

return new LogicalJoin<>(joinPlan.getJoinType(), Optional.of(joinConditions), leftPlan, rightPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expre
if (joinConditions.isEmpty()) {
cond = Optional.empty();
} else {
cond = Optional.of(ExpressionUtils.and(joinConditions));
cond = ExpressionUtils.and(joinConditions);
}

LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, joinInputs.get(0), joinInputs.get(1));
if (nonJoinConditions.isEmpty()) {
return join;
} else {
return new LogicalFilter(ExpressionUtils.and(nonJoinConditions), join);
return new LogicalFilter(ExpressionUtils.and(nonJoinConditions).get(), join);
}
} else {
Plan left = joinInputs.get(0);
Expand Down Expand Up @@ -150,7 +150,7 @@ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expre
if (joinConditions.isEmpty()) {
cond = Optional.empty();
} else {
cond = Optional.of(ExpressionUtils.and(joinConditions));
cond = ExpressionUtils.and(joinConditions);
}

LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, cond, left, right);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Expression rewrite helper class.
Expand Down Expand Up @@ -64,42 +66,42 @@ private static void extract(Class<? extends Expression> type, Expression expr, L
}
}

public static Expression and(List<Expression> expressions) {
public static Optional<Expression> and(List<Expression> expressions) {
return combine(And.class, expressions);
}

public static Expression and(Expression... expressions) {
public static Optional<Expression> and(Expression... expressions) {
return combine(And.class, Lists.newArrayList(expressions));
}

public static Expression or(Expression... expressions) {
public static Optional<Expression> or(Expression... expressions) {
return combine(Or.class, Lists.newArrayList(expressions));
}

public static Expression or(List<Expression> expressions) {
public static Optional<Expression> or(List<Expression> expressions) {
return combine(Or.class, expressions);
}

/**
* Use AND/OR to combine expressions together.
*/
public static Expression combine(Class<? extends Expression> type, List<Expression> expressions) {
public static Optional<Expression> combine(Class<? extends Expression> type, List<Expression> inputExpressions) {
Preconditions.checkArgument(type == And.class || type == Or.class);
Objects.requireNonNull(expressions, "expressions is null");
Objects.requireNonNull(inputExpressions, "expressions is null");

List<Expression> expressions = inputExpressions.stream().filter(Objects::nonNull).collect(Collectors.toList());

Expression shortCircuit = (type == And.class ? BooleanLiteral.FALSE : BooleanLiteral.TRUE);
Expression skip = (type == And.class ? BooleanLiteral.TRUE : BooleanLiteral.FALSE);
Set<Expression> distinctExpressions = Sets.newLinkedHashSetWithExpectedSize(expressions.size());
for (Expression expression : expressions) {
if (expression.equals(shortCircuit)) {
return shortCircuit;
return Optional.of(shortCircuit);
} else if (!expression.equals(skip)) {
distinctExpressions.add(expression);
}
}

return distinctExpressions.stream()
.reduce(type == And.class ? And::new : Or::new)
.orElse(new BooleanLiteral(type == And.class));
return distinctExpressions.stream().reduce(type == And.class ? And::new : Or::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ public void pushDownPredicateIntoScanTest1() {
Expression onCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0));
Expression onCondition2 = new GreaterThan(rStudent.getOutput().get(0), Literal.of(1));
Expression onCondition3 = new GreaterThan(rScore.getOutput().get(0), Literal.of(2));
Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3);
Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3).get();

Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition2 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2);
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2).get();


Plan join = new LogicalJoin(JoinType.INNER_JOIN, Optional.of(onCondition), rStudent, rScore);
Expand Down Expand Up @@ -136,8 +136,8 @@ public void pushDownPredicateIntoScanTest1() {
LogicalFilter filter2 = (LogicalFilter) op3;

Assertions.assertEquals(join1.getCondition().get(), onCondition1);
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.and(onCondition2, whereCondition1));
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.and(onCondition3, whereCondition2));
Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.and(onCondition2, whereCondition1).get());
Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.and(onCondition3, whereCondition2).get());
}

@Test
Expand All @@ -148,7 +148,7 @@ public void pushDownPredicateIntoScanTest3() {
new Subtract(rScore.getOutput().get(0), Literal.of(2)));
Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18));
Expression whereCondition3 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3);
Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3).get();

Plan join = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), rStudent, rScore);
Plan filter = new LogicalFilter(whereCondition, join);
Expand Down Expand Up @@ -207,7 +207,7 @@ public void pushDownPredicateIntoScanTest4() {
Expression whereCondition4 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));

Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3,
whereCondition4);
whereCondition4).get();

Plan join = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), rStudent, rScore);
Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), join, rCourse);
Expand Down