Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -36,8 +37,8 @@
import com.google.common.collect.ImmutableSet;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

/**
* InferNotNull from Agg count(distinct);
Expand All @@ -47,19 +48,10 @@ public class InferAggNotNull extends OneRewriteRuleFactory {
public Rule build() {
return logicalAggregate()
.when(agg -> agg.getGroupByExpressions().size() == 0)
.when(agg -> agg.getAggregateFunctions().size() == 1)
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return funcs.stream().allMatch(f -> f instanceof Count)
|| funcs.stream().allMatch(f -> f instanceof Avg)
|| funcs.stream().allMatch(f -> f instanceof Sum)
|| funcs.stream().allMatch(f -> f instanceof Max)
|| funcs.stream().allMatch(f -> f instanceof Min);
}).thenApply(ctx -> {
.thenApply(ctx -> {
LogicalAggregate<Plan> agg = ctx.root;
Set<Expression> exprs = agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream())
.collect(Collectors.toSet());
Set<Expression> isNotNulls = ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext);
Set<AggregateFunction> aggregateFunctions = agg.getAggregateFunctions();
Set<Expression> isNotNulls = inferCommonNotNulls(aggregateFunctions, ctx.cascadesContext);
Set<Expression> predicates = Collections.emptySet();
if ((agg.child() instanceof Filter)) {
predicates = ((Filter) agg.child()).getConjuncts();
Expand All @@ -80,4 +72,52 @@ public Rule build() {
return agg.withChildren(PlanUtils.filter(needGenerateNotNulls, agg.child()).get());
}).toRule(RuleType.INFER_AGG_NOT_NULL);
}

/**
 * Infer the not-null predicates that every given aggregate function agrees on.
 *
 * <p>All functions are first checked with {@code canInferFunctionNotNull}; only when each one
 * passes is inference run per function, and the per-function results are intersected.
 *
 * @param aggregateFunctions aggregate functions to inspect
 * @param cascadesContext context forwarded to the per-function inference
 * @return the intersection of the per-function not-null predicates, or an empty set when there
 *         are no functions, any function is not eligible, or the intersection is empty
 */
private Set<Expression> inferCommonNotNulls(
        Set<AggregateFunction> aggregateFunctions, CascadesContext cascadesContext) {
    if (aggregateFunctions.isEmpty()) {
        return Collections.emptySet();
    }
    // Check eligibility of every function up front, before any inference is attempted.
    boolean everyFunctionEligible = aggregateFunctions.stream().allMatch(this::canInferFunctionNotNull);
    if (!everyFunctionEligible) {
        return Collections.emptySet();
    }

    Set<Expression> shared = null;
    for (AggregateFunction function : aggregateFunctions) {
        Set<Expression> inferred = inferFunctionNotNulls(function, cascadesContext);
        if (shared == null) {
            // first function seeds the running intersection
            shared = new HashSet<>(inferred);
        } else {
            shared.retainAll(inferred);
        }
        // An empty intersection can never grow again, so stop early.
        if (shared.isEmpty()) {
            return Collections.emptySet();
        }
    }
    return shared;
}

/**
 * Run not-null inference over the direct children of a single aggregate function.
 *
 * @param aggregateFunction function whose children are used as the inference input
 * @param cascadesContext context forwarded to {@code ExpressionUtils.inferNotNull}
 * @return whatever {@code ExpressionUtils.inferNotNull} yields for the function's children
 */
private Set<Expression> inferFunctionNotNulls(
        AggregateFunction aggregateFunction, CascadesContext cascadesContext) {
    ImmutableSet<Expression> arguments = ImmutableSet.copyOf(aggregateFunction.children());
    return ExpressionUtils.inferNotNull(arguments, cascadesContext);
}

/**
 * Decide whether not-null inference may be attempted for an aggregate function.
 *
 * @param aggregateFunction function to check
 * @return true only when the function is one of the supported kinds, has at least one child,
 *         and its children pass {@code ExpressionUtils.isCheapEnoughToInferNotNull}
 */
private boolean canInferFunctionNotNull(AggregateFunction aggregateFunction) {
    if (!isSupportedAggregateFunction(aggregateFunction)) {
        return false;
    }
    if (aggregateFunction.children().isEmpty()) {
        return false;
    }
    return ExpressionUtils.isCheapEnoughToInferNotNull(aggregateFunction.children());
}

/**
 * Check whether the aggregate function is one of Count, Avg, Sum, Max or Min.
 *
 * @param aggregateFunction function to classify
 * @return true when the function is an instance of one of the five supported types
 */
private boolean isSupportedAggregateFunction(AggregateFunction aggregateFunction) {
    if (aggregateFunction instanceof Count
            || aggregateFunction instanceof Avg
            || aggregateFunction instanceof Sum) {
        return true;
    }
    return aggregateFunction instanceof Max || aggregateFunction instanceof Min;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableSet;

import java.util.LinkedHashSet;
import java.util.Set;

Expand All @@ -50,23 +55,21 @@ public Rule build() {
Set<Expression> conjuncts = new LinkedHashSet<>();
conjuncts.addAll(join.getHashJoinConjuncts());
conjuncts.addAll(join.getOtherJoinConjuncts());
Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(
conjuncts, ctx.cascadesContext);

Plan left = join.left();
Plan right = join.right();
if (join.getJoinType().isInnerJoin() || join.getJoinType().isAsofInnerJoin()) {
Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.left().getOutputSet(), ctx.cascadesContext);
Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.right().getOutputSet(), ctx.cascadesContext);
Set<Expression> leftNotNull = inferNotNull(notNullSlots, join.left().getOutputSet());
Set<Expression> rightNotNull = inferNotNull(notNullSlots, join.right().getOutputSet());
left = PlanUtils.filterOrSelf(leftNotNull, join.left());
right = PlanUtils.filterOrSelf(rightNotNull, join.right());
} else if (join.getJoinType() == JoinType.LEFT_SEMI_JOIN) {
Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.left().getOutputSet(), ctx.cascadesContext);
Set<Expression> leftNotNull = inferNotNull(notNullSlots, join.left().getOutputSet());
left = PlanUtils.filterOrSelf(leftNotNull, join.left());
} else {
Set<Expression> rightNotNull = ExpressionUtils.inferNotNull(
conjuncts, join.right().getOutputSet(), ctx.cascadesContext);
Set<Expression> rightNotNull = inferNotNull(notNullSlots, join.right().getOutputSet());
right = PlanUtils.filterOrSelf(rightNotNull, join.right());
}

Expand All @@ -76,4 +79,14 @@ public Rule build() {
return join.withChildren(left, right);
}).toRule(RuleType.INFER_JOIN_NOT_NULL);
}

/**
 * Build {@code NOT (slot IS NULL)} predicates for the not-null slots produced by one join side.
 *
 * @param notNullSlots slots already inferred to be not null
 * @param outputSlots output slots of the join child being filtered
 * @return an immutable set holding one predicate per slot of {@code notNullSlots} that is also
 *         contained in {@code outputSlots}
 */
private Set<Expression> inferNotNull(Set<Slot> notNullSlots, Set<Slot> outputSlots) {
    return notNullSlots.stream()
            .filter(outputSlots::contains)
            .<Expression>map(slot -> new Not(new IsNull(slot), true))
            .collect(ImmutableSet.toImmutableSet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.UnaryPlan;
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -60,8 +59,23 @@ default Aggregate<CHILD_TYPE> pruneOutputs(List<NamedExpression> prunedOutputs)
return withAggOutput(prunedOutputs);
}

/**
 * Collect the aggregate functions that appear in the output expressions.
 *
 * <p>Aggregate functions cannot be nested, so once an aggregate function is found there is no
 * need to traverse its children.
 *
 * @return an immutable set of the aggregate functions found in {@code getOutputExpressions()}
 */
default Set<AggregateFunction> getAggregateFunctions() {
    ImmutableSet.Builder<AggregateFunction> aggregateFunctions = ImmutableSet.builder();
    for (Expression outputExpression : getOutputExpressions()) {
        outputExpression.foreach(expression -> {
            if (expression instanceof AggregateFunction) {
                aggregateFunctions.add((AggregateFunction) expression);
                // NOTE(review): returning true is expected to make foreach skip this node's
                // children; confirm against the foreach contract.
                return true;
            }
            return false;
        });
    }
    return aggregateFunctions.build();
}

/**getAggregateFunctionWithGuardExpr*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@
public class ExpressionUtils {

public static final List<Expression> EMPTY_CONDITION = ImmutableList.of();
private static final int MAX_INFER_NOT_NULL_EXPR_WIDTH = 256;
private static final int MAX_INFER_NOT_NULL_EXPR_DEPTH = 64;
private static final int MAX_INFER_NOT_NULL_INPUT_SLOTS = 32;

public static List<Expression> extractConjunction(Expression expr) {
return extract(And.class, expr);
Expand Down Expand Up @@ -767,7 +770,7 @@ private static boolean isNullOrFalse(Expression expression) {
*/
public static Set<Slot> inferNotNullSlots(Set<Expression> predicates, CascadesContext cascadesContext) {
ImmutableSet.Builder<Slot> notNullSlots = ImmutableSet.builderWithExpectedSize(predicates.size());
for (Expression predicate : predicates) {
for (Expression predicate : filterCheapPredicatesForNotNull(predicates)) {
for (Slot slot : predicate.getInputSlots()) {
Map<Expression, Expression> replaceMap = new HashMap<>();
Literal nullLiteral = new NullLiteral(slot.getDataType());
Expand All @@ -784,6 +787,56 @@ public static Set<Slot> inferNotNullSlots(Set<Expression> predicates, CascadesCo
return notNullSlots.build();
}

/**
 * Tell whether every predicate can be accepted for not-null inference.
 *
 * <p>Predicates are checked in iteration order with {@code mergeInputSlotsIfCheap}, accumulating
 * the input slots of the predicates accepted so far; the first rejection ends the check.
 *
 * @param predicates expressions to examine
 * @return true when no predicate is rejected (including when {@code predicates} is empty)
 */
public static boolean isCheapEnoughToInferNotNull(Collection<? extends Expression> predicates) {
    Set<Slot> accumulatedSlots = new HashSet<>();
    for (Expression candidate : predicates) {
        Optional<Set<Slot>> merged = mergeInputSlotsIfCheap(candidate, accumulatedSlots);
        if (merged.isPresent()) {
            accumulatedSlots = merged.get();
        } else {
            return false;
        }
    }
    return true;
}

/**
 * Keep only the predicates that are accepted for not-null inference.
 *
 * <p>Predicates are visited in iteration order. Each one is offered to
 * {@code mergeInputSlotsIfCheap} together with the input slots of the predicates kept so far;
 * a rejected predicate is dropped and does not contribute its slots.
 *
 * @param predicates expressions to filter
 * @return the accepted predicates, in their original iteration order
 */
public static Set<Expression> filterCheapPredicatesForNotNull(
        Collection<? extends Expression> predicates) {
    Set<Expression> accepted = Sets.newLinkedHashSet();
    Set<Slot> accumulatedSlots = new HashSet<>();
    for (Expression candidate : predicates) {
        Optional<Set<Slot>> merged = mergeInputSlotsIfCheap(candidate, accumulatedSlots);
        if (merged.isPresent()) {
            accumulatedSlots = merged.get();
            accepted.add(candidate);
        }
    }
    return accepted;
}

/**
 * Merge a predicate's input slots into the slots collected so far, if the predicate is cheap.
 *
 * <p>A predicate is rejected when its width exceeds {@code MAX_INFER_NOT_NULL_EXPR_WIDTH}, its
 * depth exceeds {@code MAX_INFER_NOT_NULL_EXPR_DEPTH}, or either its own input slots or the
 * merged slots number more than {@code MAX_INFER_NOT_NULL_INPUT_SLOTS}.
 *
 * @param predicate expression to examine
 * @param inputSlots slots collected so far; this set is copied, not modified
 * @return a new set holding {@code inputSlots} plus the predicate's input slots, or an empty
 *         Optional when the predicate is rejected
 */
private static Optional<Set<Slot>> mergeInputSlotsIfCheap(Expression predicate, Set<Slot> inputSlots) {
    if (predicate.getWidth() > MAX_INFER_NOT_NULL_EXPR_WIDTH) {
        return Optional.empty();
    }
    if (predicate.getDepth() > MAX_INFER_NOT_NULL_EXPR_DEPTH) {
        return Optional.empty();
    }
    Set<Slot> ownSlots = predicate.getInputSlots();
    // Reject before copying when the predicate alone already exceeds the slot limit.
    if (ownSlots.size() > MAX_INFER_NOT_NULL_INPUT_SLOTS) {
        return Optional.empty();
    }
    Set<Slot> union = new HashSet<>(inputSlots);
    union.addAll(ownSlots);
    return union.size() <= MAX_INFER_NOT_NULL_INPUT_SLOTS
            ? Optional.of(union)
            : Optional.empty();
}

/**
* infer notNulls slot from predicate
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link EliminateNotNull} applied to a filter that holds an explicit
 * {@code NOT (slot IS NULL)} conjunct next to another predicate over the same nullable slot.
 */
class EliminateNotNullTest implements MemoPatternMatchSupported {
    private final SlotReference slot = new SlotReference("nullable_col", IntegerType.INSTANCE, true);
    private final LogicalOneRowRelation relation = new LogicalOneRowRelation(new RelationId(1), ImmutableList.of(slot));

    /** With a simple equality predicate, only one conjunct is expected to remain. */
    @Test
    void testEliminateNotNullForSimplePredicate() {
        Expression simplePredicate = new EqualTo(slot, Literal.of(1));
        Expression explicitNotNull = new Not(new IsNull(slot));
        LogicalPlan plan = new LogicalPlanBuilder(relation)
                .filter(ImmutableSet.of(simplePredicate, explicitNotNull))
                .build();

        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                .applyTopDown(new EliminateNotNull())
                .matches(logicalFilter().when(filter -> filter.getConjuncts().size() == 1));
    }

    /** With a predicate built from 257 slot operands, both conjuncts are expected to remain. */
    @Test
    void testKeepNotNullWhenOnlyWidePredicateCanProveIt() {
        // NOTE(review): 257 operands presumably pushes the predicate past the width limit used
        // for not-null inference, so it is ignored as a proof of not-null; confirm the limit.
        Expression widePredicate = new EqualTo(repeatAdd(slot, 257), Literal.of(1));
        Expression explicitNotNull = new Not(new IsNull(slot));
        LogicalPlan plan = new LogicalPlanBuilder(relation)
                .filter(ImmutableSet.of(widePredicate, explicitNotNull))
                .build();

        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                .applyTopDown(new EliminateNotNull())
                .matches(logicalFilter().when(filter -> filter.getConjuncts().size() == 2));
    }

    /**
     * Build a balanced tree of {@link Add} nodes whose leaves are {@code width} copies of the
     * given expression.
     *
     * @param expression leaf expression to repeat
     * @param width number of leaves; must be at least 1
     * @return {@code expression} itself when {@code width} is 1, otherwise an {@link Add} tree
     * @throws IllegalArgumentException if {@code width} is less than 1
     */
    private Expression repeatAdd(Expression expression, int width) {
        // Without this guard a non-positive width would never reach the base case below.
        if (width < 1) {
            throw new IllegalArgumentException("width must be at least 1: " + width);
        }
        if (width == 1) {
            return expression;
        }
        int leftWidth = width / 2;
        return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, width - leftWidth));
    }
}
Loading