diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewWindowRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewWindowRule.java index 5be68758ef96d8..b2e962bcb82d16 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewWindowRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewWindowRule.java @@ -17,19 +17,29 @@ package org.apache.doris.nereids.rules.exploration.mv; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; +import org.apache.doris.nereids.rules.exploration.mv.rollup.AggFunctionRollUpHandler; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WindowExpression; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; +import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -96,6 +106,9 @@ protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInf viewToQuerySlotMapping, tempRewrittenPlan.treeString())); return null; } + Map mvExprToMvScanExprQueryBased = + materializationContext.getShuttledExprToScanExprMapping().keyPermute(viewToQuerySlotMapping) + .flattenMap().get(0); // Rewrite top projects, represent the query projects by view List expressionsRewritten = rewriteExpression( queryStructInfo.getExpressions(), @@ -104,15 +117,19 @@ protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInf viewToQuerySlotMapping, ImmutableMap.of(), cascadesContext ); - // Can not rewrite, bail out + // If generic rewrite fails, try roll up from query expressions. if (expressionsRewritten.isEmpty()) { - materializationContext.recordFailReason(queryStructInfo, - "Rewrite expressions by view in window scan fail", - () -> String.format("expressionToRewritten is %s,\n mvExprToMvScanExprMapping is %s,\n" - + "targetToSourceMapping = %s", queryStructInfo.getExpressions(), - materializationContext.getShuttledExprToScanExprMapping(), - viewToQuerySlotMapping)); - return null; + expressionsRewritten = rollupWindowAggregateFunctions(queryStructInfo.getExpressions(), + queryStructInfo.getTopPlan(), mvExprToMvScanExprQueryBased, true, false); + if (expressionsRewritten.isEmpty()) { + materializationContext.recordFailReason(queryStructInfo, + "Rewrite expressions by view in window scan fail", + () -> String.format("expressionToRewritten is %s,\n mvExprToMvScanExprMapping is %s,\n" + + "targetToSourceMapping = %s", queryStructInfo.getExpressions(), + materializationContext.getShuttledExprToScanExprMapping(), + viewToQuerySlotMapping)); + return null; + } } return new LogicalProject<>( expressionsRewritten.stream() @@ -120,4 +137,121 @@ protected Plan rewriteQueryByView(MatchMode matchMode, StructInfo queryStructInf .map(NamedExpression.class::cast) .collect(Collectors.toList()), tempRewrittenPlan); } + + private static List rollupWindowAggregateFunctions(List expressions, + Plan queryTopPlan, Map mvExprToMvScanExprQueryBased, + boolean needShuttle, boolean strictSlotRewrite) { + WindowAggregateRollupContext context = new WindowAggregateRollupContext(queryTopPlan, + mvExprToMvScanExprQueryBased, strictSlotRewrite); + List inputExpressions = needShuttle + ? ExpressionUtils.shuttleExpressionWithLineage(expressions, queryTopPlan) + : expressions; + List rewrittenExpressions = inputExpressions.stream() + .map(expression -> expression.accept(WindowAggregateRollupRewriter.INSTANCE, context)) + .collect(Collectors.toList()); + return context.isValid() ? rewrittenExpressions : ImmutableList.of(); + } + + private static Function rollupWindowAggregateFunction(AggregateFunction queryAggregateFunction, + Expression queryAggregateFunctionShuttled, Map mvExprToMvScanExprQueryBased) { + for (Map.Entry expressionEntry : mvExprToMvScanExprQueryBased.entrySet()) { + Expression viewExpression = expressionEntry.getKey(); + // Window mapping keys may be full WindowExpression while rollup handlers match aggregate functions. + if (viewExpression instanceof WindowExpression) { + viewExpression = ((WindowExpression) viewExpression).getFunction(); + } + Pair mvExprToMvScanExprQueryBasedPair = Pair.of(viewExpression, + expressionEntry.getValue()); + for (AggFunctionRollUpHandler rollUpHandler : AbstractMaterializedViewAggregateRule.ROLL_UP_HANDLERS) { + if (!rollUpHandler.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled, + mvExprToMvScanExprQueryBasedPair, mvExprToMvScanExprQueryBased)) { + continue; + } + Function rollupFunction = rollUpHandler.doRollup(queryAggregateFunction, + queryAggregateFunctionShuttled, mvExprToMvScanExprQueryBasedPair, + mvExprToMvScanExprQueryBased); + if (rollupFunction != null) { + return rollupFunction; + } + } + } + return null; + } + + private static class WindowAggregateRollupRewriter + extends DefaultExpressionRewriter { + + private static final WindowAggregateRollupRewriter INSTANCE = new WindowAggregateRollupRewriter(); + + @Override + public Expression visitWindow(WindowExpression windowExpression, WindowAggregateRollupContext context) { + if (!context.isValid()) { + return windowExpression; + } + Expression rewrittenWindowExpr = context.getMvExprToMvScanExprQueryBased().get(windowExpression); + if (rewrittenWindowExpr != null) { + return rewrittenWindowExpr; + } + Expression function = windowExpression.getFunction(); + if (!(function instanceof AggregateFunction)) { + return super.visitWindow(windowExpression, context); + } + Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(function, + context.getQueryTopPlan()); + Function rewrittenFunction = rollupWindowAggregateFunction((AggregateFunction) function, + queryFunctionShuttled, context.getMvExprToMvScanExprQueryBased()); + if (rewrittenFunction == null) { + context.setValid(false); + return windowExpression; + } + return super.visitWindow(windowExpression.withFunction(rewrittenFunction), context); + } + + @Override + public Expression visitSlot(Slot slot, WindowAggregateRollupContext context) { + if (!context.isValid()) { + return slot; + } + Expression rewritten = context.getMvExprToMvScanExprQueryBased().get(slot); + if (rewritten == null && context.isStrictSlotRewrite()) { + context.setValid(false); + return slot; + } + return rewritten == null ? slot : rewritten; + } + } + + private static class WindowAggregateRollupContext { + private boolean valid = true; + private final Plan queryTopPlan; + private final Map mvExprToMvScanExprQueryBased; + private final boolean strictSlotRewrite; + + private WindowAggregateRollupContext(Plan queryTopPlan, + Map mvExprToMvScanExprQueryBased, boolean strictSlotRewrite) { + this.queryTopPlan = queryTopPlan; + this.mvExprToMvScanExprQueryBased = mvExprToMvScanExprQueryBased; + this.strictSlotRewrite = strictSlotRewrite; + } + + public boolean isValid() { + return valid; + } + + public void setValid(boolean valid) { + this.valid = valid; + } + + public Plan getQueryTopPlan() { + return queryTopPlan; + } + + public Map getMvExprToMvScanExprQueryBased() { + return mvExprToMvScanExprQueryBased; + } + + public boolean isStrictSlotRewrite() { + return strictSlotRewrite; + } + } }