Skip to content

Commit

Permalink
[Fix](nereids) fix merge aggregate rule, rules should not have mutabl…
Browse files Browse the repository at this point in the history
…e members (#36223)

cherry-pick #36145  to branch-2.1
  • Loading branch information
feiniaofeiafei committed Jun 13, 2024
1 parent d70751a commit e2f7e0d
Showing 1 changed file with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
public class MergeAggregate implements RewriteRuleFactory {
private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
ImmutableSet.of("min", "max", "sum", "any_value");
private Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = new HashMap<>();

@Override
public List<Rule> buildRules() {
Expand All @@ -75,7 +74,7 @@ public List<Rule> buildRules() {
*/
private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
LogicalAggregate<Plan> innerAgg = outerAgg.child();

Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
Expand All @@ -97,6 +96,7 @@ private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
project.getProjects(), (List) outputExpressions);
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
// rewrite agg function. e.g. max(max)
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
Expand Down Expand Up @@ -152,10 +152,7 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e,

private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
(List<AggregateFunction>) (List) PlanUtils.replaceExpressionByProjections(
Expand Down Expand Up @@ -225,4 +222,11 @@ private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<Log
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}

private Map<ExprId, AggregateFunction> getInnerAggExprIdToAggFuncMap(LogicalAggregate<Plan> innerAgg) {
return innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
}
}

0 comments on commit e2f7e0d

Please sign in to comment.