From ededb896018b6385576bce5075e9e169b0c0c61f Mon Sep 17 00:00:00 2001
From: minghong
Date: Tue, 26 May 2026 18:42:07 +0800
Subject: [PATCH] [opt](fe) Bound not-null inference cost (#63318)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What problem does this PR solve?

Issue Number: close #xxx

Related PR: #xxx

Problem Summary:

Not-null inference replaces candidate slots with NULL and folds each predicate, so wide, deep, or many-slot expressions can make rewrite rules spend excessive time in repeated replace-and-fold work. Aggregate not-null inference also needed to handle multiple aggregate outputs conservatively instead of inferring from all aggregate arguments as one set.

This change adds a shared bounded guard for not-null inference, reuses it from aggregate inference, and lets general callers skip only expensive predicates while preserving the original query predicates. It also reworks join inference to compute nullable-rejecting slots once and reuse them for both sides, and makes aggregate inference require a common inferred not-null predicate across supported aggregate functions.
Optimized rules: InferAggNotNull, InferFilterNotNull, InferJoinNotNull, EliminateNotNull --- .../rules/rewrite/InferAggNotNull.java | 66 ++++++++++++--- .../rules/rewrite/InferJoinNotNull.java | 29 +++++-- .../trees/plans/algebra/Aggregate.java | 18 ++++- .../doris/nereids/util/ExpressionUtils.java | 55 ++++++++++++- .../rules/rewrite/EliminateNotNullTest.java | 77 ++++++++++++++++++ .../rules/rewrite/InferAggNotNullTest.java | 81 +++++++++++++++++++ .../rules/rewrite/InferFilterNotNullTest.java | 28 +++++++ .../rules/rewrite/InferJoinNotNullTest.java | 33 ++++++++ 8 files changed, 363 insertions(+), 24 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java index e30190592a6b3b..4daf320811a796 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; @@ -36,8 +37,8 @@ import com.google.common.collect.ImmutableSet; import java.util.Collections; +import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; /** * InferNotNull from Agg count(distinct); @@ -47,19 +48,10 @@ public class InferAggNotNull extends OneRewriteRuleFactory { public Rule build() { return logicalAggregate() .when(agg -> agg.getGroupByExpressions().size() == 0) - .when(agg -> agg.getAggregateFunctions().size() == 1) - .when(agg -> { - Set funcs = agg.getAggregateFunctions(); - return funcs.stream().allMatch(f -> f instanceof Count) - || 
funcs.stream().allMatch(f -> f instanceof Avg) - || funcs.stream().allMatch(f -> f instanceof Sum) - || funcs.stream().allMatch(f -> f instanceof Max) - || funcs.stream().allMatch(f -> f instanceof Min); - }).thenApply(ctx -> { + .thenApply(ctx -> { LogicalAggregate agg = ctx.root; - Set exprs = agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream()) - .collect(Collectors.toSet()); - Set isNotNulls = ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext); + Set aggregateFunctions = agg.getAggregateFunctions(); + Set isNotNulls = inferCommonNotNulls(aggregateFunctions, ctx.cascadesContext); Set predicates = Collections.emptySet(); if ((agg.child() instanceof Filter)) { predicates = ((Filter) agg.child()).getConjuncts(); @@ -80,4 +72,52 @@ public Rule build() { return agg.withChildren(PlanUtils.filter(needGenerateNotNulls, agg.child()).get()); }).toRule(RuleType.INFER_AGG_NOT_NULL); } + + private Set inferCommonNotNulls( + Set aggregateFunctions, CascadesContext cascadesContext) { + if (aggregateFunctions.isEmpty()) { + return Collections.emptySet(); + } + for (AggregateFunction aggregateFunction : aggregateFunctions) { + if (!canInferFunctionNotNull(aggregateFunction)) { + return Collections.emptySet(); + } + } + Set commonNotNulls = null; + for (AggregateFunction aggregateFunction : aggregateFunctions) { + Set functionNotNulls = inferFunctionNotNulls(aggregateFunction, cascadesContext); + if (functionNotNulls.isEmpty()) { + return Collections.emptySet(); + } + if (commonNotNulls == null) { + commonNotNulls = new HashSet<>(functionNotNulls); + } else { + commonNotNulls.retainAll(functionNotNulls); + if (commonNotNulls.isEmpty()) { + return Collections.emptySet(); + } + } + } + return commonNotNulls == null ? 
Collections.emptySet() : commonNotNulls; + } + + private Set inferFunctionNotNulls( + AggregateFunction aggregateFunction, CascadesContext cascadesContext) { + return ExpressionUtils.inferNotNull(ImmutableSet.copyOf(aggregateFunction.children()), cascadesContext); + } + + private boolean canInferFunctionNotNull(AggregateFunction aggregateFunction) { + return isSupportedAggregateFunction(aggregateFunction) + && !aggregateFunction.children().isEmpty() + && ExpressionUtils.isCheapEnoughToInferNotNull(aggregateFunction.children()); + } + + private boolean isSupportedAggregateFunction(AggregateFunction aggregateFunction) { + return aggregateFunction instanceof Count + || aggregateFunction instanceof Avg + || aggregateFunction instanceof Sum + || aggregateFunction instanceof Max + || aggregateFunction instanceof Min; + } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java index 5f87a3ec940592..c3fb0e068cba57 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java @@ -20,12 +20,17 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; +import com.google.common.collect.ImmutableSet; + import java.util.LinkedHashSet; import java.util.Set; @@ -50,23 +55,21 @@ public Rule 
build() { Set conjuncts = new LinkedHashSet<>(); conjuncts.addAll(join.getHashJoinConjuncts()); conjuncts.addAll(join.getOtherJoinConjuncts()); + Set notNullSlots = ExpressionUtils.inferNotNullSlots( + conjuncts, ctx.cascadesContext); Plan left = join.left(); Plan right = join.right(); if (join.getJoinType().isInnerJoin() || join.getJoinType().isAsofInnerJoin()) { - Set leftNotNull = ExpressionUtils.inferNotNull( - conjuncts, join.left().getOutputSet(), ctx.cascadesContext); - Set rightNotNull = ExpressionUtils.inferNotNull( - conjuncts, join.right().getOutputSet(), ctx.cascadesContext); + Set leftNotNull = inferNotNull(notNullSlots, join.left().getOutputSet()); + Set rightNotNull = inferNotNull(notNullSlots, join.right().getOutputSet()); left = PlanUtils.filterOrSelf(leftNotNull, join.left()); right = PlanUtils.filterOrSelf(rightNotNull, join.right()); } else if (join.getJoinType() == JoinType.LEFT_SEMI_JOIN) { - Set leftNotNull = ExpressionUtils.inferNotNull( - conjuncts, join.left().getOutputSet(), ctx.cascadesContext); + Set leftNotNull = inferNotNull(notNullSlots, join.left().getOutputSet()); left = PlanUtils.filterOrSelf(leftNotNull, join.left()); } else { - Set rightNotNull = ExpressionUtils.inferNotNull( - conjuncts, join.right().getOutputSet(), ctx.cascadesContext); + Set rightNotNull = inferNotNull(notNullSlots, join.right().getOutputSet()); right = PlanUtils.filterOrSelf(rightNotNull, join.right()); } @@ -76,4 +79,14 @@ public Rule build() { return join.withChildren(left, right); }).toRule(RuleType.INFER_JOIN_NOT_NULL); } + + private Set inferNotNull(Set notNullSlots, Set outputSlots) { + ImmutableSet.Builder predicates = ImmutableSet.builderWithExpectedSize(notNullSlots.size()); + for (Slot slot : notNullSlots) { + if (outputSlots.contains(slot)) { + predicates.add(new Not(new IsNull(slot), true)); + } + } + return predicates.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 12fc9608fd3936..f67703c23c9f52 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.UnaryPlan; import org.apache.doris.nereids.trees.plans.logical.OutputPrunable; -import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableSet; @@ -60,8 +59,23 @@ default Aggregate pruneOutputs(List prunedOutputs) return withAggOutput(prunedOutputs); } + /** + * get aggregate functions + * aggregate functions cannot be nested, so we stop recursion when we find an aggregate function, + * and do not need to traverse its children. + */ default Set getAggregateFunctions() { - return ExpressionUtils.collect(getOutputExpressions(), AggregateFunction.class::isInstance); + ImmutableSet.Builder aggregateFunctions = ImmutableSet.builder(); + for (Expression outputExpression : getOutputExpressions()) { + outputExpression.foreach(expression -> { + if (expression instanceof AggregateFunction) { + aggregateFunctions.add((AggregateFunction) expression); + return true; + } + return false; + }); + } + return aggregateFunctions.build(); } /**getAggregateFunctionWithGuardExpr*/ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 78299b1d7d7a26..b63bfd921762d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -105,6 +105,9 @@ public class ExpressionUtils { public static final List EMPTY_CONDITION = ImmutableList.of(); + private static final int 
MAX_INFER_NOT_NULL_EXPR_WIDTH = 256; + private static final int MAX_INFER_NOT_NULL_EXPR_DEPTH = 64; + private static final int MAX_INFER_NOT_NULL_INPUT_SLOTS = 32; public static List extractConjunction(Expression expr) { return extract(And.class, expr); @@ -767,7 +770,7 @@ private static boolean isNullOrFalse(Expression expression) { */ public static Set inferNotNullSlots(Set predicates, CascadesContext cascadesContext) { ImmutableSet.Builder notNullSlots = ImmutableSet.builderWithExpectedSize(predicates.size()); - for (Expression predicate : predicates) { + for (Expression predicate : filterCheapPredicatesForNotNull(predicates)) { for (Slot slot : predicate.getInputSlots()) { Map replaceMap = new HashMap<>(); Literal nullLiteral = new NullLiteral(slot.getDataType()); @@ -784,6 +787,56 @@ public static Set inferNotNullSlots(Set predicates, CascadesCo return notNullSlots.build(); } + /** + * Return whether all predicates are cheap enough for not-null inference. + */ + public static boolean isCheapEnoughToInferNotNull(Collection predicates) { + Set inputSlots = new HashSet<>(); + for (Expression predicate : predicates) { + Optional> mergedInputSlots = mergeInputSlotsIfCheap(predicate, inputSlots); + if (!mergedInputSlots.isPresent()) { + return false; + } + inputSlots = mergedInputSlots.get(); + } + return true; + } + + /** + * Filter predicates that are cheap enough for not-null inference. 
+ */ + public static Set filterCheapPredicatesForNotNull( + Collection predicates) { + Set inputSlots = new HashSet<>(); + Set cheapPredicates = Sets.newLinkedHashSet(); + for (Expression predicate : predicates) { + Optional> mergedInputSlots = mergeInputSlotsIfCheap(predicate, inputSlots); + if (!mergedInputSlots.isPresent()) { + continue; + } + inputSlots = mergedInputSlots.get(); + cheapPredicates.add(predicate); + } + return cheapPredicates; + } + + private static Optional> mergeInputSlotsIfCheap(Expression predicate, Set inputSlots) { + if (predicate.getWidth() > MAX_INFER_NOT_NULL_EXPR_WIDTH + || predicate.getDepth() > MAX_INFER_NOT_NULL_EXPR_DEPTH) { + return Optional.empty(); + } + Set predicateInputSlots = predicate.getInputSlots(); + if (predicateInputSlots.size() > MAX_INFER_NOT_NULL_INPUT_SLOTS) { + return Optional.empty(); + } + Set mergedInputSlots = new HashSet<>(inputSlots); + mergedInputSlots.addAll(predicateInputSlots); + if (mergedInputSlots.size() > MAX_INFER_NOT_NULL_INPUT_SLOTS) { + return Optional.empty(); + } + return Optional.of(mergedInputSlots); + } + /** * infer notNulls slot from predicate */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java new file mode 100644 index 00000000000000..5486aae91f77ef --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.junit.jupiter.api.Test; + +class EliminateNotNullTest implements MemoPatternMatchSupported { + private final SlotReference slot = new SlotReference("nullable_col", IntegerType.INSTANCE, true); + private final LogicalOneRowRelation relation = new LogicalOneRowRelation(new RelationId(1), ImmutableList.of(slot)); + + @Test + void testEliminateNotNullForSimplePredicate() { + Expression simplePredicate = new EqualTo(slot, Literal.of(1)); + Expression explicitNotNull = new Not(new IsNull(slot)); + 
LogicalPlan plan = new LogicalPlanBuilder(relation) + .filter(ImmutableSet.of(simplePredicate, explicitNotNull)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateNotNull()) + .matches(logicalFilter().when(filter -> filter.getConjuncts().size() == 1)); + } + + @Test + void testKeepNotNullWhenOnlyWidePredicateCanProveIt() { + Expression widePredicate = new EqualTo(repeatAdd(slot, 257), Literal.of(1)); + Expression explicitNotNull = new Not(new IsNull(slot)); + LogicalPlan plan = new LogicalPlanBuilder(relation) + .filter(ImmutableSet.of(widePredicate, explicitNotNull)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateNotNull()) + .matches(logicalFilter().when(filter -> filter.getConjuncts().size() == 2)); + } + + private Expression repeatAdd(Expression expression, int width) { + if (width == 1) { + return expression; + } + int leftWidth = width / 2; + return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, width - leftWidth)); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java index 7d20c2f22a6f2b..23b108c8347412 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java @@ -19,7 +19,11 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; 
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.LogicalPlanBuilder; @@ -29,8 +33,11 @@ import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.util.Set; + class InferAggNotNullTest implements MemoPatternMatchSupported { private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); @@ -51,6 +58,62 @@ void testInfer() { ); } + @Test + void testInferMultipleAggregateSameInput() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggGroupUsingIndex(ImmutableList.of(), + ImmutableList.of( + new Alias(new Avg(scan1.getOutput().get(1)), "avg_k"), + new Alias(new Sum(scan1.getOutput().get(1)), "sum_k"))) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferAggNotNull()) + .matches( + logicalAggregate( + logicalFilter().when(filter -> filter.getConjuncts().size() == 1 + && filter.getConjuncts().stream() + .allMatch(e -> ((Not) e).isGeneratedIsNotNull())) + ) + ); + } + + @Test + void testNotInferMultipleAggregateDifferentInputs() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggGroupUsingIndex(ImmutableList.of(), + ImmutableList.of( + new Alias(new Avg(scan1.getOutput().get(1)), "avg_k1"), + new Alias(new Sum(scan1.getOutput().get(0)), "sum_k2"))) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferAggNotNull()) + .matches( + logicalAggregate( + logicalOlapScan() + ) + ); + } + + @Test + void testNotInferMultipleAggregateWithCountStar() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggGroupUsingIndex(ImmutableList.of(), + ImmutableList.of( + new Alias(new Avg(scan1.getOutput().get(1)), "avg_k"), + new Alias(new Count(), "count_star"))) + .build(); + + 
PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferAggNotNull()) + .matches( + logicalAggregate( + logicalOlapScan() + ) + ); + } + @Test void testCountStar() { LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -66,4 +129,22 @@ void testCountStar() { ) ); } + + @Test + void testGetAggregateFunctionsStopsAtAggregateFunction() { + // Use different agg function types for inner (Avg) and outer (Count), + // so we can verify by instanceof regardless of how the plan builder + // clones/transforms expressions internally. + Avg inner = new Avg(scan1.getOutput().get(1)); + Count outer = new Count(false, inner); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggGroupUsingIndex(ImmutableList.of(), ImmutableList.of(new Alias(outer, "cnt"))) + .build(); + + Set aggregateFunctions = ((LogicalAggregate) plan).getAggregateFunctions(); + System.out.println("aggregateFunctions: " + aggregateFunctions); + Assertions.assertEquals(1, aggregateFunctions.size()); + Assertions.assertTrue(aggregateFunctions.stream().allMatch(f -> f instanceof Count), + "should collect only the outer Count, got: " + aggregateFunctions); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java index 9e4335db3ede40..bf9d1d31f712dd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java @@ -19,7 +19,9 @@ import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Or; import 
org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -30,6 +32,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableSet; import org.junit.jupiter.api.Test; class InferFilterNotNullTest implements MemoPatternMatchSupported { @@ -77,4 +80,29 @@ void testInferFailOr() { logicalFilter().when(filter -> filter.getConjuncts().size() == 1) ); } + + @Test + void testSkipWidePredicateButKeepSimplePredicate() { + Expression widePredicate = new EqualTo(repeatAdd(scan1.getOutput().get(0), 257), Literal.of(1)); + Expression simplePredicate = new EqualTo(scan1.getOutput().get(1), Literal.of(1)); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .filter(ImmutableSet.of(widePredicate, simplePredicate)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferFilterNotNull()) + .matches( + logicalFilter().when(filter -> filter.getConjuncts().stream() + .filter(e -> e instanceof Not && ((Not) e).isGeneratedIsNotNull()) + .count() == 1) + ); + } + + private Expression repeatAdd(Expression expression, int width) { + if (width == 1) { + return expression; + } + int leftWidth = width / 2; + return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, width - leftWidth)); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java index d963363a3793ab..4ea2466edb2584 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java @@ -18,6 +18,10 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Add; 
+import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -27,6 +31,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -92,6 +97,27 @@ void testInferIsNotNull() { ); } + @Test + void testSkipWideOtherConjunctButKeepHashConjunct() { + Expression widePredicate = new EqualTo(repeatAdd(scan1.getOutput().get(1), 257), Literal.of(1)); + LogicalPlan innerJoin = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, + ImmutableList.of(new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0))), + ImmutableList.of(widePredicate)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), innerJoin) + .applyTopDown(new InferJoinNotNull()) + .matches( + innerLogicalJoin( + logicalFilter().when(f -> f.getPredicate().toString() + .equals("( not id#10000 IS NULL)")), + logicalFilter().when(f -> f.getPredicate().toString() + .equals("( not id#10002 IS NULL)")) + ) + ); + } + @Test void testInferAndEliminate() { LogicalPlan plan = new LogicalPlanBuilder(scan1) @@ -109,4 +135,11 @@ void testInferAndEliminate() { ); } + private Expression repeatAdd(Expression expression, int width) { + if (width == 1) { + return expression; + } + int leftWidth = width / 2; + return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, width - leftWidth)); + } }