From dc020009a1b4f10f4332091663df4acfb12c3fb9 Mon Sep 17 00:00:00 2001
From: uncleGen
Date: Mon, 13 Feb 2023 22:23:12 +0800
Subject: [PATCH] fix constant propagate in grouping sets

---
 .../rules/HiveReduceExpressionsRule.java      | 90 ++++++++++++++++++-
 .../constant_prop_in_groupingsets.q           | 12 ++++
 .../constant_prop_in_groupingsets.q.out       | 43 ++++++++++++
 3 files changed, 143 insertions(+), 2 deletions(-)
 create mode 100644 ql/src/test/queries/clientpositive/constant_prop_in_groupingsets.q
 create mode 100644 ql/src/test/results/clientpositive/constant_prop_in_groupingsets.q.out

diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveReduceExpressionsRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveReduceExpressionsRule.java
index 0521dc3e32a3..4a343b0553cb 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveReduceExpressionsRule.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveReduceExpressionsRule.java
@@ -16,13 +16,27 @@
  */
 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
 
+import com.google.common.collect.Lists;
+import org.apache.calcite.plan.RelOptPredicateList;
 import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.hep.HepRelVertex;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.rules.ReduceExpressionsRule;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSemiJoin;
+import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * Collection of planner rules that apply various simplifying transformations on
@@ -58,10 +72,10 @@ private HiveReduceExpressionsRule() {
    * {@link org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject}.
    */
   public static final RelOptRule PROJECT_INSTANCE =
-      ReduceExpressionsRule.ProjectReduceExpressionsRule.Config.DEFAULT
+      HiveProjectReduceExpressionsRule.Config.DEFAULT
           .withOperandFor(HiveProject.class)
           .withRelBuilderFactory(HiveRelFactories.HIVE_BUILDER)
-          .as(ReduceExpressionsRule.ProjectReduceExpressionsRule.Config.class)
+          .as(HiveProjectReduceExpressionsRule.Config.class)
           .toRule();
 
   /**
@@ -88,6 +102,78 @@ private HiveReduceExpressionsRule() {
           .as(ReduceExpressionsRule.JoinReduceExpressionsRule.Config.class)
           .toRule();
 
+  /**
+   * Variant of {@link ReduceExpressionsRule.ProjectReduceExpressionsRule} that leaves
+   * {@link RexCall} projections untouched when the project's input is an aggregate with
+   * grouping sets. With grouping sets a grouping column is emitted as NULL for the sets that
+   * omit it, so a call over such a column (for example {@code nvl(col, 'x')}) must survive
+   * instead of being folded into a constant.
+   */
+  public static class HiveProjectReduceExpressionsRule extends ReduceExpressionsRule.ProjectReduceExpressionsRule {
+    protected HiveProjectReduceExpressionsRule(ProjectReduceExpressionsRule.Config config) {
+      super(config);
+    }
+
+    /**
+     * Returns whether the input of {@code project} is a {@link HiveAggregate} whose group type
+     * is not {@link Aggregate.Group#SIMPLE}. The input is unwrapped first when it is a
+     * {@link HepRelVertex}.
+     */
+    private boolean hasGroupingSets(Project project) {
+      RelNode input = project.getInput();
+      if (input instanceof HepRelVertex) {
+        input = ((HepRelVertex) input).getCurrentRel();
+      }
+      return input instanceof HiveAggregate
+          && ((HiveAggregate) input).getGroupType() != Aggregate.Group.SIMPLE;
+    }
+
+    @Override public void onMatch(RelOptRuleCall call) {
+      final Project project = call.rel(0);
+      final RelMetadataQuery mq = call.getMetadataQuery();
+      final RelOptPredicateList predicates = mq.getPulledUpPredicates(project.getInput());
+      final List<RexNode> projects = project.getProjects();
+      final boolean groupingSets = hasGroupingSets(project);
+
+      // Expressions handed to the reducer: everything, or only the non-call ones over grouping sets.
+      final List<RexNode> candidates = groupingSets
+          ? Lists.newArrayList(
+              projects.stream().filter(e -> !(e instanceof RexCall)).collect(Collectors.toList()))
+          : Lists.newArrayList(projects);
+      if (!reduceExpressions(project, candidates, predicates, false, config.matchNullability())) {
+        return;
+      }
+
+      // Put the reduced candidates back at their original positions so that the new project keeps
+      // one expression per output field; skipped calls are carried over unchanged.
+      final List<RexNode> expList = Lists.newArrayListWithCapacity(projects.size());
+      int next = 0;
+      for (RexNode e : projects) {
+        expList.add((groupingSets && e instanceof RexCall) ? e : candidates.get(next++));
+      }
+      assert !projects.equals(expList) : "Reduced expressions should be different from original expressions";
+
+      call.transformTo(
+          call.builder().push(project.getInput()).project(expList, project.getRowType().getFieldNames()).build()
+      );
+
+      // New plan is absolutely better than old plan.
+      call.getPlanner().prune(project);
+    }
+
+    /** Rule configuration; {@link #toRule()} creates a {@link HiveProjectReduceExpressionsRule}. */
+    public interface Config extends ProjectReduceExpressionsRule.Config {
+      HiveProjectReduceExpressionsRule.Config DEFAULT = EMPTY.as(HiveProjectReduceExpressionsRule.Config.class)
+          .withMatchNullability(true)
+          .withOperandFor(LogicalProject.class)
+          .withDescription("HiveProjectReduceExpressionsRule(Project)")
+          .as(HiveProjectReduceExpressionsRule.Config.class);
+
+      @Override default HiveProjectReduceExpressionsRule toRule() {
+        return new HiveProjectReduceExpressionsRule(this);
+      }
+    }
+  }
 }
 
 // End HiveReduceExpressionsRule.java
diff --git a/ql/src/test/queries/clientpositive/constant_prop_in_groupingsets.q b/ql/src/test/queries/clientpositive/constant_prop_in_groupingsets.q
new file mode 100644
index 000000000000..6375be63b753
--- /dev/null
+++ b/ql/src/test/queries/clientpositive/constant_prop_in_groupingsets.q
@@ -0,0 +1,12 @@
+drop table tb1;
+
+create table tb1 (key string, value string);
+
+insert into tb1 values("a", "b");
+
+with mid1 as (
+  select 'test_value' as test_field, * from tb1
+)
+select key, nvl(test_field, 'default_test_value')
+from mid1 group by key, test_field
+grouping sets(key, test_field, (key, test_field));
\ No newline at end of file
diff --git a/ql/src/test/results/clientpositive/constant_prop_in_groupingsets.q.out b/ql/src/test/results/clientpositive/constant_prop_in_groupingsets.q.out
new file mode 100644
index 000000000000..208e652e56da
--- /dev/null
+++ b/ql/src/test/results/clientpositive/constant_prop_in_groupingsets.q.out
@@ -0,0 +1,43 @@
+PREHOOK: query: drop table tb1
+PREHOOK: type: DROPTABLE
+POSTHOOK: query: drop table tb1
+POSTHOOK: type: DROPTABLE
+PREHOOK: query: create table tb1 (key string, value string)
+PREHOOK: type: CREATETABLE
+PREHOOK: Output: database:default
+PREHOOK: Output: default@tb1
+POSTHOOK: query: create table tb1 (key string, value string)
+POSTHOOK: type: CREATETABLE
+POSTHOOK: Output: database:default
+POSTHOOK: Output: default@tb1
+PREHOOK: query: insert into tb1 values("a", "b")
+PREHOOK: type: QUERY
+PREHOOK: Input: _dummy_database@_dummy_table
+PREHOOK: Output: default@tb1
+POSTHOOK: query: insert into tb1 values("a", "b")
+POSTHOOK: type: QUERY
+POSTHOOK: Input: _dummy_database@_dummy_table
+POSTHOOK: Output: default@tb1
+POSTHOOK: Lineage: tb1.key SCRIPT []
+POSTHOOK: Lineage: tb1.value SCRIPT []
+PREHOOK: query: with mid1 as (
+  select 'test_value' as test_field, * from tb1
+)
+select key, nvl(test_field, 'default_test_value')
+from mid1 group by key, test_field
+grouping sets(key, test_field, (key, test_field))
+PREHOOK: type: QUERY
+PREHOOK: Input: default@tb1
+#### A masked pattern was here ####
+POSTHOOK: query: with mid1 as (
+  select 'test_value' as test_field, * from tb1
+)
+select key, nvl(test_field, 'default_test_value')
+from mid1 group by key, test_field
+grouping sets(key, test_field, (key, test_field))
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@tb1
+#### A masked pattern was here ####
+a	test_value
+a	default_test_value
+NULL	test_value