From 27c75b319381100832851b8b1faaa9935be15b68 Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Mon, 20 May 2024 14:08:13 +0800 Subject: [PATCH] [opt](nereids)new way to set pre-agg status (#34738) --- .../doris/nereids/jobs/executor/Rewriter.java | 4 + .../apache/doris/nereids/rules/RuleType.java | 15 + .../nereids/rules/analysis/BindRelation.java | 2 +- .../rules/rewrite/AdjustPreAggStatus.java | 748 ++++++++++++++++++ .../AbstractSelectMaterializedIndexRule.java | 12 +- .../SelectMaterializedIndexWithAggregate.java | 18 +- ...lectMaterializedIndexWithoutAggregate.java | 20 +- .../nereids/trees/plans/PreAggStatus.java | 15 +- .../trees/plans/logical/LogicalOlapScan.java | 14 +- .../rewrite/mv/SelectRollupIndexTest.java | 30 +- .../nereids/trees/plans/PlanToStringTest.java | 2 +- 11 files changed, 823 insertions(+), 57 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 2bc61b3b6fe1a9..e2095248298145 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -36,6 +36,7 @@ import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit; import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType; import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus; import org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction; import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion; import org.apache.doris.nereids.rules.rewrite.CTEInline; @@ -391,6 +392,9 @@ public class Rewriter extends AbstractBatchJobExecutor { bottomUp(RuleSet.PUSH_DOWN_FILTERS), 
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new) ), + topic("adjust preagg status", + topDown(new AdjustPreAggStatus()) + ), topic("topn optimize", topDown(new DeferMaterializeTopNResult()) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 9d7d8e2d62176c..3d950b5781fc98 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -242,6 +242,21 @@ public enum RuleType { MATERIALIZED_INDEX_PROJECT_SCAN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_PROJECT_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE), + PREAGG_STATUS_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE), REDUCE_AGGREGATE_CHILD_OUTPUT_ROWS(RuleTypeClass.REWRITE), OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java index 
0e6d940891ebce..df3743928a9b96 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java @@ -206,7 +206,7 @@ private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation, } PreAggStatus preAggStatus = olapTable.getIndexMetaByIndexId(indexId).getKeysType().equals(KeysType.DUP_KEYS) - ? PreAggStatus.on() + ? PreAggStatus.unset() : PreAggStatus.off("For direct index scan."); scan = new LogicalOlapScan(unboundRelation.getRelationId(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java new file mode 100644 index 00000000000000..a0c0b56dd71c99 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java @@ -0,0 +1,748 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.KeysType; +import org.apache.doris.catalog.MaterializedIndexMeta; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.PreAggStatus; +import org.apache.doris.nereids.trees.plans.algebra.Project; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * AdjustPreAggStatus + */ +@Developing +public class AdjustPreAggStatus implements RewriteRuleFactory { + /////////////////////////////////////////////////////////////////////////// + // All the patterns + /////////////////////////////////////////////////////////////////////////// + @Override + public List buildRules() { + return ImmutableList.of( + // Aggregate(Scan) + logicalAggregate(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)) + .thenApplyNoThrow(ctx -> { + LogicalAggregate agg = ctx.root; + LogicalOlapScan scan = agg.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = agg.getGroupByExpressions(); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(scan.withPreAggStatus(preAggStatus)); + }).toRule(RuleType.PREAGG_STATUS_AGG_SCAN), + + // Aggregate(Filter(Scan)) + logicalAggregate( + logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate> agg = ctx.root; + LogicalFilter filter = agg.child(); + LogicalOlapScan scan = filter.child(); 
+ PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + agg.getGroupByExpressions(); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(filter + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_SCAN), + + // Aggregate(Project(Scan)) + logicalAggregate(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate> agg = + ctx.root; + LogicalProject project = agg.child(); + LogicalOlapScan scan = project.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, + Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_SCAN), + + // Aggregate(Project(Filter(Scan))) + logicalAggregate(logicalProject(logicalFilter( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalProject> project = agg.child(); + LogicalFilter filter = project.child(); + LogicalOlapScan scan = filter.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + 
ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(project.withChildren(filter + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN), + + // Aggregate(Filter(Project(Scan))) + logicalAggregate(logicalFilter(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalFilter> filter = + agg.child(); + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(agg.getGroupByExpressions(), + project.getAliasToProducer()); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(filter.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN), + + // Aggregate(Repeat(Scan)) + logicalAggregate( + logicalRepeat(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate> agg = ctx.root; + LogicalRepeat repeat = agg.child(); + LogicalOlapScan scan = repeat.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = nonVirtualGroupByExprs(agg); + Set predicates = ImmutableSet.of(); + preAggStatus = 
checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(repeat + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_SCAN), + + // Aggregate(Repeat(Filter(Scan))) + logicalAggregate(logicalRepeat(logicalFilter( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalRepeat> repeat = agg.child(); + LogicalFilter filter = repeat.child(); + LogicalOlapScan scan = filter.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + nonVirtualGroupByExprs(agg); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(repeat.withChildren(filter + .withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN), + + // Aggregate(Repeat(Project(Scan))) + logicalAggregate(logicalRepeat(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalRepeat> repeat = agg.child(); + LogicalProject project = repeat.child(); + LogicalOlapScan scan = project.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(repeat.withChildren(project + 
.withChildren(scan.withPreAggStatus(preAggStatus)))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN), + + // Aggregate(Repeat(Project(Filter(Scan)))) + logicalAggregate(logicalRepeat(logicalProject(logicalFilter( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>>> agg + = ctx.root; + LogicalRepeat>> repeat = agg.child(); + LogicalProject> project = repeat.child(); + LogicalFilter filter = project.child(); + LogicalOlapScan scan = filter.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.empty()); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(repeat + .withChildren(project.withChildren(filter.withChildren( + scan.withPreAggStatus(preAggStatus))))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN), + + // Aggregate(Repeat(Filter(Project(Scan)))) + logicalAggregate(logicalRepeat(logicalFilter(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))) + .thenApplyNoThrow(ctx -> { + LogicalAggregate>>> agg + = ctx.root; + LogicalRepeat>> repeat = agg.child(); + LogicalFilter> filter = repeat.child(); + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = + extractAggFunctionAndReplaceSlot(agg, Optional.of(project)); + List groupByExpressions = + ExpressionUtils.replace(nonVirtualGroupByExprs(agg), + project.getAliasToProducer()); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + 
preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return agg.withChildren(repeat + .withChildren(filter.withChildren(project.withChildren( + scan.withPreAggStatus(preAggStatus))))); + }).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN), + + // Filter(Project(Scan)) + logicalFilter(logicalProject( + logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))) + .thenApplyNoThrow(ctx -> { + LogicalFilter> filter = ctx.root; + LogicalProject project = filter.child(); + LogicalOlapScan scan = project.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = ExpressionUtils.replace( + filter.getConjuncts(), project.getAliasToProducer()); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return filter.withChildren(project + .withChildren(scan.withPreAggStatus(preAggStatus))); + }).toRule(RuleType.PREAGG_STATUS_FILTER_PROJECT_SCAN), + + // Filter(Scan) + logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)) + .thenApplyNoThrow(ctx -> { + LogicalFilter filter = ctx.root; + LogicalOlapScan scan = filter.child(); + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = filter.getConjuncts(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return filter.withChildren(scan.withPreAggStatus(preAggStatus)); + }).toRule(RuleType.PREAGG_STATUS_FILTER_SCAN), + + // only scan. 
+ logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet) + .thenApplyNoThrow(ctx -> { + LogicalOlapScan scan = ctx.root; + PreAggStatus preAggStatus = checkKeysType(scan); + if (preAggStatus == PreAggStatus.unset()) { + List aggregateFunctions = ImmutableList.of(); + List groupByExpressions = ImmutableList.of(); + Set predicates = ImmutableSet.of(); + preAggStatus = checkPreAggStatus(scan, predicates, + aggregateFunctions, groupByExpressions); + } + return scan.withPreAggStatus(preAggStatus); + }).toRule(RuleType.PREAGG_STATUS_SCAN)); + } + + /////////////////////////////////////////////////////////////////////////// + // Set pre-aggregation status. + /////////////////////////////////////////////////////////////////////////// + + /** + * Do aggregate function extraction and replace aggregate function's input slots by underlying project. + *

<p> + * 1. extract aggregate functions in aggregate plan. + * <p>
+ * 2. replace aggregate function's input slot by underlying project expression if project is present. + * <p>
+ * For example: + * <pre>
+     * input arguments:
+     * agg: Aggregate(sum(v) as sum_value)
+     * underlying project: Project(a + b as v)
+     *
+     * output:
+     * sum(a + b)
+     * </pre>
+ */ + private List extractAggFunctionAndReplaceSlot(LogicalAggregate agg, + Optional> project) { + Optional> slotToProducerOpt = + project.map(Project::getAliasToProducer); + return agg.getOutputExpressions().stream() + // extract aggregate functions. + .flatMap(e -> e.>collect(AggregateFunction.class::isInstance) + .stream()) + // replace aggregate function's input slot by its producing expression. + .map(expr -> slotToProducerOpt + .map(slotToExpressions -> (AggregateFunction) ExpressionUtils.replace(expr, + slotToExpressions)) + .orElse(expr)) + .collect(Collectors.toList()); + } + + private PreAggStatus checkKeysType(LogicalOlapScan olapScan) { + long selectIndexId = olapScan.getSelectedIndexId(); + MaterializedIndexMeta meta = olapScan.getTable().getIndexMetaByIndexId(selectIndexId); + if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType() == KeysType.UNIQUE_KEYS + && olapScan.getTable().getEnableUniqueKeyMergeOnWrite())) { + return PreAggStatus.on(); + } else { + return PreAggStatus.unset(); + } + } + + private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan, Set predicates, + List aggregateFuncs, List groupingExprs) { + Set outputSlots = olapScan.getOutputSet(); + Pair, Set> splittedSlots = splitSlots(outputSlots); + Set keySlots = splittedSlots.first; + Set valueSlots = splittedSlots.second; + Preconditions.checkState(outputSlots.size() == keySlots.size() + valueSlots.size(), + "output slots contains no key or value slots"); + + Set groupingExprsInputSlots = ExpressionUtils.getInputSlotSet(groupingExprs); + if (groupingExprsInputSlots.retainAll(keySlots)) { + return PreAggStatus + .off(String.format("Grouping expression %s contains non-key column %s", + groupingExprs, groupingExprsInputSlots)); + } + + Set predicateInputSlots = ExpressionUtils.getInputSlotSet(predicates); + if (predicateInputSlots.retainAll(keySlots)) { + return PreAggStatus.off(String.format("Predicate %s contains non-key column %s", + predicates, 
predicateInputSlots)); + } + + return checkAggregateFunctions(aggregateFuncs, groupingExprsInputSlots); + } + + private Pair, Set> splitSlots(Set slots) { + Set keySlots = Sets.newHashSetWithExpectedSize(slots.size()); + Set valueSlots = Sets.newHashSetWithExpectedSize(slots.size()); + for (Slot slot : slots) { + if (slot instanceof SlotReference && ((SlotReference) slot).getColumn().isPresent()) { + if (((SlotReference) slot).getColumn().get().isKey()) { + keySlots.add((SlotReference) slot); + } else { + valueSlots.add((SlotReference) slot); + } + } + } + return Pair.of(keySlots, valueSlots); + } + + private static Expression removeCast(Expression expression) { + while (expression instanceof Cast) { + expression = ((Cast) expression).child(); + } + return expression; + } + + private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction aggFunc, + Set keySlots, Set valueSlots) { + Expression child = aggFunc.child(0); + List conditionExps = new ArrayList<>(); + List returnExps = new ArrayList<>(); + + // ignore cast + while (child instanceof Cast) { + if (!((Cast) child).getDataType().isNumericType()) { + return PreAggStatus.off(String.format("%s is not numeric CAST.", child.toSql())); + } + child = child.child(0); + } + // step 1: extract all condition exprs and return exprs + if (child instanceof If) { + conditionExps.add(child.child(0)); + returnExps.add(removeCast(child.child(1))); + returnExps.add(removeCast(child.child(2))); + } else if (child instanceof CaseWhen) { + CaseWhen caseWhen = (CaseWhen) child; + // WHEN THEN + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + conditionExps.add(whenClause.getOperand()); + returnExps.add(removeCast(whenClause.getResult())); + } + // ELSE + returnExps.add(removeCast(caseWhen.getDefaultValue().orElse(new NullLiteral()))); + } else { + // currently, only IF and CASE WHEN are supported + returnExps.add(removeCast(child)); + } + + // step 2: check condition expressions + Set inputSlots = 
ExpressionUtils.getInputSlotSet(conditionExps); + inputSlots.retainAll(valueSlots); + if (!inputSlots.isEmpty()) { + return PreAggStatus + .off(String.format("some columns in condition %s is not key.", conditionExps)); + } + + return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, returnExps); + } + + private PreAggStatus checkAggregateFunctions(List aggregateFuncs, + Set groupingExprsInputSlots) { + PreAggStatus preAggStatus = aggregateFuncs.isEmpty() && groupingExprsInputSlots.isEmpty() + ? PreAggStatus.off("No aggregate on scan.") + : PreAggStatus.on(); + for (AggregateFunction aggFunc : aggregateFuncs) { + if (aggFunc.children().size() == 1 && aggFunc.child(0) instanceof Slot) { + Slot aggSlot = (Slot) aggFunc.child(0); + if (aggSlot instanceof SlotReference + && ((SlotReference) aggSlot).getColumn().isPresent()) { + if (((SlotReference) aggSlot).getColumn().get().isKey()) { + preAggStatus = OneKeySlotAggChecker.INSTANCE.check(aggFunc); + } else { + preAggStatus = OneValueSlotAggChecker.INSTANCE.check(aggFunc, + ((SlotReference) aggSlot).getColumn().get().getAggregationType()); + } + } else { + preAggStatus = PreAggStatus.off( + String.format("aggregate function %s use unknown slot %s from scan", + aggFunc, aggSlot)); + } + } else { + Set aggSlots = aggFunc.getInputSlots(); + Pair, Set> splitSlots = splitSlots(aggSlots); + preAggStatus = + checkAggWithKeyAndValueSlots(aggFunc, splitSlots.first, splitSlots.second); + } + if (preAggStatus.isOff()) { + return preAggStatus; + } + } + return preAggStatus; + } + + private List nonVirtualGroupByExprs(LogicalAggregate agg) { + return agg.getGroupByExpressions().stream() + .filter(expr -> !(expr instanceof VirtualSlotReference)) + .collect(ImmutableList.toImmutableList()); + } + + private static class OneValueSlotAggChecker + extends ExpressionVisitor { + public static final OneValueSlotAggChecker INSTANCE = new OneValueSlotAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun, AggregateType 
aggregateType) { + return aggFun.accept(INSTANCE, aggregateType); + } + + @Override + public PreAggStatus visit(Expression expr, AggregateType aggregateType) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + AggregateType aggregateType) { + return PreAggStatus + .off(String.format("%s is not supported.", aggregateFunction.toSql())); + } + + @Override + public PreAggStatus visitMax(Max max, AggregateType aggregateType) { + if (aggregateType == AggregateType.MAX && !max.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + max.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitMin(Min min, AggregateType aggregateType) { + if (aggregateType == AggregateType.MIN && !min.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + min.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) { + if (aggregateType == AggregateType.SUM && !sum.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus + .off(String.format("%s is not match agg mode %s or has distinct param", + sum.toSql(), aggregateType)); + } + } + + @Override + public PreAggStatus visitBitmapUnionCount(BitmapUnionCount bitmapUnionCount, + AggregateType aggregateType) { + if (aggregateType == AggregateType.BITMAP_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid bitmap_union_count: " + bitmapUnionCount.toSql()); + } + } + + @Override + public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion, AggregateType aggregateType) { + if (aggregateType == AggregateType.BITMAP_UNION) { + return PreAggStatus.on(); + } else { + return 
PreAggStatus.off("invalid bitmapUnion: " + bitmapUnion.toSql()); + } + } + + @Override + public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg, AggregateType aggregateType) { + if (aggregateType == AggregateType.HLL_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid hllUnionAgg: " + hllUnionAgg.toSql()); + } + } + + @Override + public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType aggregateType) { + if (aggregateType == AggregateType.HLL_UNION) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off("invalid hllUnion: " + hllUnion.toSql()); + } + } + } + + private static class OneKeySlotAggChecker extends ExpressionVisitor { + public static final OneKeySlotAggChecker INSTANCE = new OneKeySlotAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun) { + return aggFun.accept(INSTANCE, null); + } + + @Override + public PreAggStatus visit(Expression expr, Void context) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + Void context) { + return PreAggStatus.off(String.format("Aggregate function %s contains key column %s", + aggregateFunction.toSql(), aggregateFunction.child(0).toSql())); + } + + @Override + public PreAggStatus visitMax(Max max, Void context) { + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMin(Min min, Void context) { + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitCount(Count count, Void context) { + if (count.isDistinct()) { + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format("%s is not distinct.", count.toSql())); + } + } + } + + private static class KeyAndValueSlotsAggChecker + extends ExpressionVisitor> { + public static final KeyAndValueSlotsAggChecker INSTANCE = new KeyAndValueSlotsAggChecker(); + + public PreAggStatus check(AggregateFunction aggFun, List 
returnValues) { + return aggFun.accept(INSTANCE, returnValues); + } + + @Override + public PreAggStatus visit(Expression expr, List returnValues) { + return PreAggStatus.off(String.format("%s is not aggregate function.", expr.toSql())); + } + + @Override + public PreAggStatus visitAggregateFunction(AggregateFunction aggregateFunction, + List returnValues) { + return PreAggStatus + .off(String.format("%s is not supported.", aggregateFunction.toSql())); + } + + @Override + public PreAggStatus visitSum(Sum sum, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.SUM) || value.isZeroLiteral() + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", sum.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMax(Max max, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.MAX) || isKeySlot(value) + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", max.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitMin(Min min, List returnValues) { + for (Expression value : returnValues) { + if (!(isAggTypeMatched(value, AggregateType.MIN) || isKeySlot(value) + || value.isNullLiteral())) { + return PreAggStatus.off(String.format("%s is not supported.", min.toSql())); + } + } + return PreAggStatus.on(); + } + + @Override + public PreAggStatus visitCount(Count count, List returnValues) { + if (count.isDistinct()) { + for (Expression value : returnValues) { + if (!(isKeySlot(value) || value.isZeroLiteral() || value.isNullLiteral())) { + return PreAggStatus + .off(String.format("%s is not supported.", count.toSql())); + } + } + return PreAggStatus.on(); + } else { + return PreAggStatus.off(String.format("%s is not supported.", count.toSql())); + } + } + + private boolean isKeySlot(Expression expression) { + return 
expression instanceof SlotReference + && ((SlotReference) expression).getColumn().isPresent() + && ((SlotReference) expression).getColumn().get().isKey(); + } + + private boolean isAggTypeMatched(Expression expression, AggregateType aggregateType) { + return expression instanceof SlotReference + && ((SlotReference) expression).getColumn().isPresent() + && ((SlotReference) expression).getColumn().get() + .getAggregationType() == aggregateType; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java index b5773a7571d24e..1124c141416f3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java @@ -88,17 +88,7 @@ protected boolean shouldSelectIndexWithAgg(LogicalOlapScan scan) { case AGG_KEYS: case UNIQUE_KEYS: case DUP_KEYS: - // SelectMaterializedIndexWithAggregate(R1) run before SelectMaterializedIndexWithoutAggregate(R2) - // if R1 selects baseIndex and preAggStatus is off - // we should give a chance to R2 to check if some prefix-index can be selected - // so if R1 selects baseIndex and preAggStatus is off, we keep scan's index unselected in order to - // let R2 to get a chance to do its work - // at last, after R1, the scan may be the 4 status - // 1. preAggStatus is ON and baseIndex is selected, it means select baseIndex is correct. - // 2. preAggStatus is ON and some other Index is selected, this is correct, too. - // 3. 
preAggStatus is OFF, no index is selected, it means R2 could get a chance to run - // so we check the preAggStatus and if some index is selected to make sure R1 can be run only once - return scan.getPreAggStatus().isOn() && !scan.isIndexSelected(); + return !scan.isIndexSelected(); default: return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index b8aae5066862af..b221637f18794a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -211,7 +211,7 @@ public List buildRules() { result.exprRewriteMap.projectExprMap); LogicalProject newProject = new LogicalProject<>( generateNewOutputsWithMvOutputs(mvPlan, newProjectList), - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId)); + scan.withMaterializedIndexSelected(result.indexId)); return new LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext), new ReplaceExpressions(slotContext) .replace( @@ -259,9 +259,6 @@ public List buildRules() { filter.getExpressions(), project.getExpressions() )) ); - if (mvPlanWithoutAgg.getSelectedIndexId() == result.indexId) { - mvPlanWithoutAgg = mvPlanWithoutAgg.withPreAggStatus(result.preAggStatus); - } SlotContext slotContextWithoutAgg = generateBaseScanExprToMvExpr(mvPlanWithoutAgg); return agg.withChildren(new LogicalProject( @@ -535,7 +532,7 @@ public List buildRules() { result.exprRewriteMap.projectExprMap); LogicalProject newProject = new LogicalProject<>( generateNewOutputsWithMvOutputs(mvPlan, newProjectList), - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId)); + scan.withMaterializedIndexSelected(result.indexId)); return new 
LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext), new ReplaceExpressions(slotContext).replace(new LogicalAggregate<>( @@ -552,16 +549,7 @@ public List buildRules() { } private static LogicalOlapScan createLogicalOlapScan(LogicalOlapScan scan, SelectResult result) { - LogicalOlapScan mvPlan; - if (result.preAggStatus.isOff()) { - // we only set preAggStatus and make index unselected to let SelectMaterializedIndexWithoutAggregate - // have a chance to run and select proper index - mvPlan = scan.withPreAggStatus(result.preAggStatus); - } else { - mvPlan = - scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId); - } - return mvPlan; + return scan.withMaterializedIndexSelected(result.indexId); } /////////////////////////////////////////////////////////////////////////// diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java index acffdc3b258052..5e4e1ce44c92dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -185,7 +184,7 @@ public static LogicalOlapScan select( break; case DUP_KEYS: if (table.getIndexIdToMeta().size() == 1) { - return scan.withMaterializedIndexSelected(PreAggStatus.on(), baseIndexId); 
+ return scan.withMaterializedIndexSelected(baseIndexId); } break; default: @@ -210,19 +209,10 @@ public static LogicalOlapScan select( // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? bestIndex : baseIndexId; - return scan.withMaterializedIndexSelected(PreAggStatus.on(), bestIndex); + return scan.withMaterializedIndexSelected(bestIndex); } else { - final PreAggStatus preAggStatus; - if (preAggEnabledByHint(scan)) { - // PreAggStatus could be enabled by pre-aggregation hint for agg-keys and unique-keys. - preAggStatus = PreAggStatus.on(); - } else { - // if PreAggStatus is OFF, we use the message from SelectMaterializedIndexWithAggregate - preAggStatus = scan.getPreAggStatus().isOff() ? scan.getPreAggStatus() - : PreAggStatus.off("No aggregate on scan."); - } if (table.getIndexIdToMeta().size() == 1) { - return scan.withMaterializedIndexSelected(preAggStatus, baseIndexId); + return scan.withMaterializedIndexSelected(baseIndexId); } int baseIndexKeySize = table.getKeyColumnsByIndexId(table.getBaseIndexId()).size(); // No aggregate on scan. @@ -235,13 +225,13 @@ public static LogicalOlapScan select( if (candidates.size() == 1) { // `candidates` only have base index. - return scan.withMaterializedIndexSelected(preAggStatus, baseIndexId); + return scan.withMaterializedIndexSelected(baseIndexId); } else { long bestIndex = selectBestIndex(candidates, scan, predicatesSupplier.get(), requiredExpr.get()); // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? 
bestIndex : baseIndexId; - return scan.withMaterializedIndexSelected(preAggStatus, bestIndex); + return scan.withMaterializedIndexSelected(bestIndex); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java index 7affac49b2bc09..8ba99c2c07f0eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PreAggStatus.java @@ -26,10 +26,11 @@ public class PreAggStatus { private enum Status { - ON, OFF + ON, OFF, UNSET } private static final PreAggStatus PRE_AGG_ON = new PreAggStatus(Status.ON, ""); + private static final PreAggStatus PRE_AGG_UNSET = new PreAggStatus(Status.UNSET, ""); private final Status status; private final String offReason; @@ -46,6 +47,10 @@ public boolean isOff() { return status == Status.OFF; } + public boolean isUnset() { + return status == Status.UNSET; + } + public String getOffReason() { return offReason; } @@ -58,6 +63,10 @@ public PreAggStatus offOrElse(Supplier supplier) { } } + public static PreAggStatus unset() { + return PRE_AGG_UNSET; + } + public static PreAggStatus on() { return PRE_AGG_ON; } @@ -70,8 +79,10 @@ public static PreAggStatus off(String reason) { public String toString() { if (status == Status.ON) { return "ON"; - } else { + } else if (status == Status.OFF) { return "OFF, " + offReason; + } else { + return "UNSET"; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index d0d91f1cf8dafb..714f540524f1a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -126,7 +126,7 @@ public 
LogicalOlapScan(RelationId id, OlapTable table, List qualifier) { this(id, table, qualifier, Optional.empty(), Optional.empty(), table.getPartitionIds(), false, ImmutableList.of(), - -1, false, PreAggStatus.on(), ImmutableList.of(), ImmutableList.of(), + -1, false, PreAggStatus.unset(), ImmutableList.of(), ImmutableList.of(), Maps.newHashMap(), Optional.empty(), false, false); } @@ -134,7 +134,7 @@ public LogicalOlapScan(RelationId id, OlapTable table, List qualifier, L List hints, Optional tableSample) { this(id, table, qualifier, Optional.empty(), Optional.empty(), table.getPartitionIds(), false, tabletIds, - -1, false, PreAggStatus.on(), ImmutableList.of(), hints, Maps.newHashMap(), + -1, false, PreAggStatus.unset(), ImmutableList.of(), hints, Maps.newHashMap(), tableSample, false, false); } @@ -143,7 +143,7 @@ public LogicalOlapScan(RelationId id, OlapTable table, List qualifier, L this(id, table, qualifier, Optional.empty(), Optional.empty(), // must use specifiedPartitions here for prune partition by sql like 'select * from t partition p1' specifiedPartitions, false, tabletIds, - -1, false, PreAggStatus.on(), specifiedPartitions, hints, Maps.newHashMap(), + -1, false, PreAggStatus.unset(), specifiedPartitions, hints, Maps.newHashMap(), tableSample, false, false); } @@ -275,11 +275,11 @@ public LogicalOlapScan withSelectedPartitionIds(List selectedPartitionIds) hints, cacheSlotWithSlotName, tableSample, directMvScan, projectPulledUp); } - public LogicalOlapScan withMaterializedIndexSelected(PreAggStatus preAgg, long indexId) { + public LogicalOlapScan withMaterializedIndexSelected(long indexId) { return new LogicalOlapScan(relationId, (Table) table, qualifier, Optional.empty(), Optional.of(getLogicalProperties()), selectedPartitionIds, partitionPruned, selectedTabletIds, - indexId, true, preAgg, manuallySpecifiedPartitions, hints, cacheSlotWithSlotName, + indexId, true, PreAggStatus.unset(), manuallySpecifiedPartitions, hints, cacheSlotWithSlotName, 
tableSample, directMvScan, projectPulledUp); } @@ -432,6 +432,10 @@ public boolean isDirectMvScan() { return directMvScan; } + public boolean isPreAggStatusUnSet() { + return preAggStatus.isUnset(); + } + private List createSlotsVectorized(List columns) { List qualified = qualified(); Object[] slots = new Object[columns.size()]; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java index 0686edba64e01e..45552bfc2fae6e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java @@ -19,6 +19,7 @@ import org.apache.doris.common.FeConstants; import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject; +import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus; import org.apache.doris.nereids.rules.rewrite.MergeProjects; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; import org.apache.doris.nereids.trees.plans.PreAggStatus; @@ -110,6 +111,7 @@ public void testMatchingBase() { PlanChecker.from(connectContext) .analyze(" select k1, sum(v1) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("t", scan.getSelectedMaterializedIndexName().get()); @@ -122,6 +124,7 @@ void testAggFilterScan() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3=0 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", 
scan.getSelectedMaterializedIndexName().get()); @@ -139,8 +142,7 @@ void testTranslate() { public void testTranslateWhenPreAggIsOff() { singleTableTest("select k2, min(v1) from t group by k2", scan -> { Assertions.assertFalse(scan.isPreAggregation()); - Assertions.assertEquals("Aggregate operator don't match, " - + "aggregate function: min(v1), column aggregate type: SUM", + Assertions.assertEquals("min(v1) is not match agg mode SUM or has distinct param", scan.getReasonOfPreAggregation()); }); } @@ -150,6 +152,7 @@ public void testWithEqualFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3=0 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); @@ -162,6 +165,7 @@ public void testWithNonEqualFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k3>0 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); @@ -174,6 +178,7 @@ public void testWithFilter() { PlanChecker.from(connectContext) .analyze("select k2, sum(v1) from t where k2>3 group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r1", scan.getSelectedMaterializedIndexName().get()); @@ -193,6 +198,7 @@ public void testWithFilterAndProject() { .applyBottomUp(new MergeProjects()) .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + 
.applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { Assertions.assertTrue(scan.getPreAggStatus().isOn()); Assertions.assertEquals("r2", scan.getSelectedMaterializedIndexName().get()); @@ -210,6 +216,7 @@ public void testNoAggregate() { .analyze("select k1, v1 from t") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); @@ -224,11 +231,11 @@ public void testAggregateTypeNotMatch() { .analyze("select k1, min(v1) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("Aggregate operator don't match, " - + "aggregate function: min(v1), column aggregate type: SUM", preAgg.getOffReason()); + Assertions.assertEquals("min(v1) is not match agg mode SUM or has distinct param", preAgg.getOffReason()); return true; })); } @@ -239,10 +246,11 @@ public void testInvalidSlotInAggFunction() { .analyze("select k1, sum(v1 + 1) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("do not support compound expression [(v1 + 1)] in SUM.", + Assertions.assertEquals("sum((v1 + 1)) is not supported.", preAgg.getOffReason()); return true; })); @@ -254,10 +262,11 @@ public void testKeyColumnInAggFunction() { .analyze("select k1, sum(k2) from t group by k1") .applyTopDown(new 
SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOff()); - Assertions.assertEquals("Aggregate function sum(k2) contains key column k2.", + Assertions.assertEquals("Aggregate function sum(k2) contains key column k2", preAgg.getOffReason()); return true; })); @@ -269,6 +278,7 @@ public void testMaxCanUseKeyColumn() { .analyze("select k2, max(k3) from t group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -283,6 +293,7 @@ public void testMinCanUseKeyColumn() { .analyze("select k2, min(k3) from t group by k2") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -297,6 +308,7 @@ public void testMinMaxCanUseKeyColumnWithBaseTable() { .analyze("select k1, min(k2), max(k2) from t group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -311,6 +323,8 @@ public void testFilterAggWithBaseTable() { .analyze("select k1 from t where k1 = 0 group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new MergeProjects()) + .applyTopDown(new AdjustPreAggStatus()) 
.matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -325,6 +339,7 @@ public void testDuplicatePreAggOn() { .analyze("select k1, sum(k1) from duplicate_tbl group by k1") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -338,6 +353,7 @@ public void testDuplicatePreAggOnEvenWithoutAggregate() { .analyze("select k1, v1 from duplicate_tbl") .applyTopDown(new SelectMaterializedIndexWithAggregate()) .applyTopDown(new SelectMaterializedIndexWithoutAggregate()) + .applyTopDown(new AdjustPreAggStatus()) .matches(logicalOlapScan().when(scan -> { PreAggStatus preAgg = scan.getPreAggStatus(); Assertions.assertTrue(preAgg.isOn()); @@ -402,7 +418,7 @@ public void testCountDistinctKeyColumn() { public void testCountDistinctValueColumn() { singleTableTest("select k1, count(distinct v1) from t group by k1", scan -> { Assertions.assertFalse(scan.isPreAggregation()); - Assertions.assertEquals("Count distinct is only valid for key columns, but meet count(DISTINCT v1).", + Assertions.assertEquals("count(DISTINCT v1) is not supported.", scan.getReasonOfPreAggregation()); Assertions.assertEquals("t", scan.getSelectedIndexName()); }); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java index 44cb6c296af30d..0a6eb7e0c592ef 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java @@ -84,7 +84,7 @@ public void testLogicalOlapScan() { Assertions.assertTrue( plan.toString().matches("LogicalOlapScan \\( qualified=db\\.table, " 
+ "indexName=, " - + "selectedIndexId=-1, preAgg=ON \\)")); + + "selectedIndexId=-1, preAgg=UNSET \\)")); } @Test