From bc07320ae73a0ba03512fb4b647bf98dfdce6aef Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Sun, 26 Aug 2018 16:15:27 -0700 Subject: [PATCH] Support projection after sorting in SQL (#5788) (#6228) * Add sort project * add more test * address comments --- .../sql/calcite/aggregation/Aggregation.java | 3 +- .../io/druid/sql/calcite/rel/DruidQuery.java | 190 +++++++++++++----- .../druid/sql/calcite/rel/DruidQueryRel.java | 8 +- .../druid/sql/calcite/rel/DruidSemiJoin.java | 8 +- .../sql/calcite/rel/PartialDruidQuery.java | 120 +++++++---- .../io/druid/sql/calcite/rel/SortProject.java | 112 +++++++++++ .../io/druid/sql/calcite/rule/DruidRules.java | 38 +++- .../sql/calcite/rule/DruidSemiJoinRule.java | 10 +- .../druid/sql/calcite/CalciteQueryTest.java | 189 +++++++++++++++++ 9 files changed, 581 insertions(+), 97 deletions(-) create mode 100644 sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java diff --git a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java index 2532c8d7f82e..09436b96e9da 100644 --- a/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java +++ b/sql/src/main/java/io/druid/sql/calcite/aggregation/Aggregation.java @@ -36,6 +36,7 @@ import io.druid.sql.calcite.table.RowSignature; import javax.annotation.Nullable; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; @@ -112,7 +113,7 @@ public static Aggregation create(final AggregatorFactory aggregatorFactory) public static Aggregation create(final PostAggregator postAggregator) { - return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator); + return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator); } public static Aggregation create( diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java index 9740f6815514..2f6fde564358 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java @@ -89,6 +89,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.OptionalInt; import java.util.TreeSet; import java.util.stream.Collectors; @@ -105,9 +106,11 @@ public class DruidQuery private final DimFilter filter; private final SelectProjection selectProjection; private final Grouping grouping; + private final SortProject sortProject; + private final DefaultLimitSpec limitSpec; private final RowSignature outputRowSignature; private final RelDataType outputRowType; - private final DefaultLimitSpec limitSpec; + private final Query query; public DruidQuery( @@ -129,15 +132,22 @@ public DruidQuery( this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature); this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder, finalizeAggregations); + final RowSignature sortingInputRowSignature; + if (this.selectProjection != null) { - this.outputRowSignature = this.selectProjection.getOutputRowSignature(); + sortingInputRowSignature = this.selectProjection.getOutputRowSignature(); } else if (this.grouping != null) { - this.outputRowSignature = this.grouping.getOutputRowSignature(); + sortingInputRowSignature = this.grouping.getOutputRowSignature(); } else { - this.outputRowSignature = sourceRowSignature; + sortingInputRowSignature = sourceRowSignature; } - this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature); + this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping); + + // outputRowSignature is used only for scan and select query, and thus sort and grouping must be null + this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature(); + + this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature); this.query = computeQuery(); } @@ -237,7 +247,7 @@ private static Grouping computeGrouping( ) { final Aggregate aggregate = partialQuery.getAggregate(); - final Project postProject = partialQuery.getPostProject(); + final Project aggregateProject = partialQuery.getAggregateProject(); if (aggregate == null) { return null; @@ -268,49 +278,27 @@ private static Grouping computeGrouping( plannerContext ); - if (postProject == null) { + if (aggregateProject == null) { return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature); } else { - final List rowOrder = new ArrayList<>(); - - int outputNameCounter = 0; - for (final RexNode postAggregatorRexNode : postProject.getChildExps()) { - // Attempt to convert to PostAggregator. - final DruidExpression postAggregatorExpression = Expressions.toDruidExpression( - plannerContext, - aggregateRowSignature, - postAggregatorRexNode - ); - - if (postAggregatorExpression == null) { - throw new CannotBuildQueryException(postProject, postAggregatorRexNode); - } - - if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) { - // Direct column access, without any type cast as far as Druid's runtime is concerned. - // (There might be a SQL-level type cast that we don't care about) - rowOrder.add(postAggregatorExpression.getDirectColumn()); - } else { - final String postAggregatorName = "p" + outputNameCounter++; - final PostAggregator postAggregator = new ExpressionPostAggregator( - postAggregatorName, - postAggregatorExpression.getExpression(), - null, - plannerContext.getExprMacroTable() - ); - aggregations.add(Aggregation.create(postAggregator)); - rowOrder.add(postAggregator.getName()); - } - } + final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations( + plannerContext, + aggregateRowSignature, + aggregateProject, + 0 + ); + projectRowOrderAndPostAggregations.postAggregations.forEach( + postAggregator -> aggregations.add(Aggregation.create(postAggregator)) + ); // Remove literal dimensions that did not appear in the projection. This is useful for queries // like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't // actually want to include a dimension 'dummy'. - final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null); + final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null); for (int i = dimensions.size() - 1; i >= 0; i--) { final DimensionExpression dimension = dimensions.get(i); if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable()) - .isLiteral() && !postProjectBits.get(i)) { + .isLiteral() && !aggregateProjectBits.get(i)) { dimensions.remove(i); } } @@ -319,11 +307,98 @@ private static Grouping computeGrouping( dimensions, aggregations, havingFilter, - RowSignature.from(rowOrder, postProject.getRowType()) + RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType()) ); } } + @Nullable + private SortProject computeSortProject( + PartialDruidQuery partialQuery, + PlannerContext plannerContext, + RowSignature sortingInputRowSignature, + Grouping grouping + ) + { + final Project sortProject = partialQuery.getSortProject(); + if (sortProject == null) { + return null; + } else { + final List postAggregators = grouping.getPostAggregators(); + final OptionalInt maybeMaxCounter = postAggregators + .stream() + .mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1))) + .max(); + + final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations( + plannerContext, + sortingInputRowSignature, + sortProject, + maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist + ); + + return new SortProject( + sortingInputRowSignature, + projectRowOrderAndPostAggregations.postAggregations, + RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType()) + ); + } + } + + private static class ProjectRowOrderAndPostAggregations + { + private final List rowOrder; + private final List postAggregations; + + ProjectRowOrderAndPostAggregations(List rowOrder, List postAggregations) + { + this.rowOrder = rowOrder; + this.postAggregations = postAggregations; + } + } + + private static ProjectRowOrderAndPostAggregations computePostAggregations( + PlannerContext plannerContext, + RowSignature inputRowSignature, + Project project, + int outputNameCounter + ) + { + final List rowOrder = new ArrayList<>(); + final List aggregations = new ArrayList<>(); + + for (final RexNode postAggregatorRexNode : project.getChildExps()) { + // Attempt to convert to PostAggregator. + final DruidExpression postAggregatorExpression = Expressions.toDruidExpression( + plannerContext, + inputRowSignature, + postAggregatorRexNode + ); + + if (postAggregatorExpression == null) { + throw new CannotBuildQueryException(project, postAggregatorRexNode); + } + + if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) { + // Direct column access, without any type cast as far as Druid's runtime is concerned. + // (There might be a SQL-level type cast that we don't care about) + rowOrder.add(postAggregatorExpression.getDirectColumn()); + } else { + final String postAggregatorName = "p" + outputNameCounter++; + final PostAggregator postAggregator = new ExpressionPostAggregator( + postAggregatorName, + postAggregatorExpression.getExpression(), + null, + plannerContext.getExprMacroTable() + ); + aggregations.add(postAggregator); + rowOrder.add(postAggregator.getName()); + } + } + + return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations); + } + /** * Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order. * @@ -548,18 +623,20 @@ public VirtualColumns getVirtualColumns(final ExprMacroTable macroTable, final b { final List retVal = new ArrayList<>(); - if (grouping != null) { - if (includeDimensions) { - for (DimensionExpression dimensionExpression : grouping.getDimensions()) { - retVal.addAll(dimensionExpression.getVirtualColumns(macroTable)); + if (selectProjection != null) { + retVal.addAll(selectProjection.getVirtualColumns()); + } else { + if (grouping != null) { + if (includeDimensions) { + for (DimensionExpression dimensionExpression : grouping.getDimensions()) { + retVal.addAll(dimensionExpression.getVirtualColumns(macroTable)); + } } - } - for (Aggregation aggregation : grouping.getAggregations()) { - retVal.addAll(aggregation.getVirtualColumns()); + for (Aggregation aggregation : grouping.getAggregations()) { + retVal.addAll(aggregation.getVirtualColumns()); + } } - } else if (selectProjection != null) { - retVal.addAll(selectProjection.getVirtualColumns()); } return VirtualColumns.create(retVal); @@ -575,6 +652,11 @@ public DefaultLimitSpec getLimitSpec() return limitSpec; } + public SortProject getSortProject() + { + return sortProject; + } + public RelDataType getOutputRowType() { return outputRowType; @@ -675,7 +757,6 @@ public TimeseriesQuery toTimeseriesQuery() if (limitSpec != null) { // If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing). - if (limitSpec.isLimited()) { return null; } @@ -805,6 +886,11 @@ public GroupByQuery toGroupByQuery() final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature); + final List postAggregators = new ArrayList<>(grouping.getPostAggregators()); + if (sortProject != null) { + postAggregators.addAll(sortProject.getPostAggregators()); + } + return new GroupByQuery( dataSource, filtration.getQuerySegmentSpec(), @@ -813,7 +899,7 @@ public GroupByQuery toGroupByQuery() Granularities.ALL, grouping.getDimensionSpecs(), grouping.getAggregatorFactories(), - grouping.getPostAggregators(), + postAggregators, grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null, limitSpec, ImmutableSortedMap.copyOf(plannerContext.getQueryContext()) diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java index c304e5babd7e..a62096a5e62a 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java @@ -225,14 +225,18 @@ public RelOptCost computeSelfCost(final RelOptPlanner planner, final RelMetadata cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size(); } - if (partialQuery.getPostProject() != null) { - cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size(); + if (partialQuery.getAggregateProject() != null) { + cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size(); } if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) { cost *= COST_LIMIT_MULTIPLIER; } + if (partialQuery.getSortProject() != null) { + cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size(); + } + if (partialQuery.getHavingFilter() != null) { cost *= COST_HAVING_MULTIPLIER; } diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java index ecfd8bbb2b67..7c0d8b62f662 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java @@ -358,8 +358,12 @@ public List accumulate(final List theConditions, final Object[ newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter()); } - if (leftPartialQuery.getPostProject() != null) { - newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject()); + if (leftPartialQuery.getAggregateProject() != null) { + newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject()); + } + + if (leftPartialQuery.getSortProject() != null) { + newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject()); } if (leftPartialQuery.getSort() != null) { diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java index 01c960c918f2..d7d0f77e059d 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/PartialDruidQuery.java @@ -46,8 +46,9 @@ public class PartialDruidQuery private final Sort selectSort; private final Aggregate aggregate; private final Filter havingFilter; - private final Project postProject; + private final Project aggregateProject; private final Sort sort; + private final Project sortProject; public enum Stage { @@ -57,8 +58,9 @@ public enum Stage SELECT_SORT, AGGREGATE, HAVING_FILTER, - POST_PROJECT, - SORT + AGGREGATE_PROJECT, + SORT, + SORT_PROJECT } public PartialDruidQuery( @@ -67,9 +69,10 @@ public PartialDruidQuery( final Project selectProject, final Sort selectSort, final Aggregate aggregate, - final Project postProject, + final Project aggregateProject, final Filter havingFilter, - final Sort sort + final Sort sort, + final Project sortProject ) { this.scan = Preconditions.checkNotNull(scan, "scan"); @@ -77,14 +80,15 @@ public PartialDruidQuery( this.selectProject = selectProject; this.selectSort = selectSort; this.aggregate = aggregate; - this.postProject = postProject; + this.aggregateProject = aggregateProject; this.havingFilter = havingFilter; this.sort = sort; + this.sortProject = sortProject; } public static PartialDruidQuery create(final RelNode scanRel) { - return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null); + return new PartialDruidQuery(scanRel, null, null, null, null, null, null, null, null); } public RelNode getScan() @@ -117,9 +121,9 @@ public Filter getHavingFilter() return havingFilter; } - public Project getPostProject() + public Project getAggregateProject() { - return postProject; + return aggregateProject; } public Sort getSort() @@ -127,6 +131,11 @@ public Sort getSort() return sort; } + public Project getSortProject() + { + return sortProject; + } + public PartialDruidQuery withWhereFilter(final Filter newWhereFilter) { validateStage(Stage.WHERE_FILTER); @@ -136,9 +145,10 @@ public PartialDruidQuery withWhereFilter(final Filter newWhereFilter) selectProject, selectSort, aggregate, - postProject, + aggregateProject, havingFilter, - sort + sort, + sortProject ); } @@ -151,9 +161,10 @@ public PartialDruidQuery withSelectProject(final Project newSelectProject) newSelectProject, selectSort, aggregate, - postProject, + aggregateProject, havingFilter, - sort + sort, + sortProject ); } @@ -166,9 +177,10 @@ public PartialDruidQuery withSelectSort(final Sort newSelectSort) selectProject, newSelectSort, aggregate, - postProject, + aggregateProject, havingFilter, - sort + sort, + sortProject ); } @@ -181,9 +193,10 @@ public PartialDruidQuery withAggregate(final Aggregate newAggregate) selectProject, selectSort, newAggregate, - postProject, + aggregateProject, havingFilter, - sort + sort, + sortProject ); } @@ -196,24 +209,26 @@ public PartialDruidQuery withHavingFilter(final Filter newHavingFilter) selectProject, selectSort, aggregate, - postProject, + aggregateProject, newHavingFilter, - sort + sort, + sortProject ); } - public PartialDruidQuery withPostProject(final Project newPostProject) + public PartialDruidQuery withAggregateProject(final Project newAggregateProject) { - validateStage(Stage.POST_PROJECT); + validateStage(Stage.AGGREGATE_PROJECT); return new PartialDruidQuery( scan, whereFilter, selectProject, selectSort, aggregate, - newPostProject, + newAggregateProject, havingFilter, - sort + sort, + sortProject ); } @@ -226,9 +241,26 @@ public PartialDruidQuery withSort(final Sort newSort) selectProject, selectSort, aggregate, - postProject, + aggregateProject, + havingFilter, + newSort, + sortProject + ); + } + + public PartialDruidQuery withSortProject(final Project newSortProject) + { + validateStage(Stage.SORT_PROJECT); + return new PartialDruidQuery( + scan, + whereFilter, + selectProject, + selectSort, + aggregate, + aggregateProject, havingFilter, - newSort + sort, + newSortProject ); } @@ -266,6 +298,9 @@ public boolean canAccept(final Stage stage) } else if (stage.compareTo(Stage.AGGREGATE) >= 0 && selectSort != null) { // Cannot do any aggregations after a select + sort. return false; + } else if (stage.compareTo(Stage.SORT) > 0 && sort == null) { + // Cannot add sort project without a sort + return false; } else { // Looks good. return true; @@ -278,12 +313,15 @@ public boolean canAccept(final Stage stage) * * @return stage */ + @SuppressWarnings("VariableNotUsedInsideIf") public Stage stage() { - if (sort != null) { + if (sortProject != null) { + return Stage.SORT_PROJECT; + } else if (sort != null) { return Stage.SORT; - } else if (postProject != null) { - return Stage.POST_PROJECT; + } else if (aggregateProject != null) { + return Stage.AGGREGATE_PROJECT; } else if (havingFilter != null) { return Stage.HAVING_FILTER; } else if (aggregate != null) { @@ -309,10 +347,12 @@ public RelNode leafRel() final Stage currentStage = stage(); switch (currentStage) { + case SORT_PROJECT: + return sortProject; case SORT: return sort; - case POST_PROJECT: - return postProject; + case AGGREGATE_PROJECT: + return aggregateProject; case HAVING_FILTER: return havingFilter; case AGGREGATE: @@ -353,14 +393,25 @@ public boolean equals(final Object o) Objects.equals(selectSort, that.selectSort) && Objects.equals(aggregate, that.aggregate) && Objects.equals(havingFilter, that.havingFilter) && - Objects.equals(postProject, that.postProject) && - Objects.equals(sort, that.sort); + Objects.equals(aggregateProject, that.aggregateProject) && + Objects.equals(sort, that.sort) && + Objects.equals(sortProject, that.sortProject); } @Override public int hashCode() { - return Objects.hash(scan, whereFilter, selectProject, selectSort, aggregate, havingFilter, postProject, sort); + return Objects.hash( + scan, + whereFilter, + selectProject, + selectSort, + aggregate, + havingFilter, + aggregateProject, + sort, + sortProject + ); } @Override @@ -373,8 +424,9 @@ public String toString() ", selectSort=" + selectSort + ", aggregate=" + aggregate + ", havingFilter=" + havingFilter + - ", postProject=" + postProject + + ", aggregateProject=" + aggregateProject + ", sort=" + sort + + ", sortProject=" + sortProject + '}'; } } diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java new file mode 100644 index 000000000000..c00aff97ee5b --- /dev/null +++ b/sql/src/main/java/io/druid/sql/calcite/rel/SortProject.java @@ -0,0 +1,112 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package io.druid.sql.calcite.rel; + +import com.google.common.base.Preconditions; +import io.druid.java.util.common.ISE; +import io.druid.query.aggregation.PostAggregator; +import io.druid.sql.calcite.table.RowSignature; + +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +public class SortProject +{ + private final RowSignature inputRowSignature; + private final List postAggregators; + private final RowSignature outputRowSignature; + + SortProject( + RowSignature inputRowSignature, + List postAggregators, + RowSignature outputRowSignature + ) + { + this.inputRowSignature = Preconditions.checkNotNull(inputRowSignature, "inputRowSignature"); + this.postAggregators = Preconditions.checkNotNull(postAggregators, "postAggregators"); + this.outputRowSignature = Preconditions.checkNotNull(outputRowSignature, "outputRowSignature"); + + // Verify no collisions. + final Set seen = new HashSet<>(); + inputRowSignature.getRowOrder().forEach(field -> { + if (!seen.add(field)) { + throw new ISE("Duplicate field name: %s", field); + } + }); + + for (PostAggregator postAggregator : postAggregators) { + if (postAggregator == null) { + throw new ISE("aggregation[%s] is not a postAggregator", postAggregator); + } + if (!seen.add(postAggregator.getName())) { + throw new ISE("Duplicate field name: %s", postAggregator.getName()); + } + } + + // Verify that items in the output signature exist. + outputRowSignature.getRowOrder().forEach(field -> { + if (!seen.contains(field)) { + throw new ISE("Missing field in rowOrder: %s", field); + } + }); + } + + public List getPostAggregators() + { + return postAggregators; + } + + public RowSignature getOutputRowSignature() + { + return outputRowSignature; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SortProject sortProject = (SortProject) o; + return Objects.equals(inputRowSignature, sortProject.inputRowSignature) && + Objects.equals(postAggregators, sortProject.postAggregators) && + Objects.equals(outputRowSignature, sortProject.outputRowSignature); + } + + @Override + public int hashCode() + { + return Objects.hash(inputRowSignature, postAggregators, outputRowSignature); + } + + @Override + public String toString() + { + return "SortProject{" + + "inputRowSignature=" + inputRowSignature + + ", postAggregators=" + postAggregators + + ", outputRowSignature=" + outputRowSignature + + '}'; + } +} diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java index b565aba995aa..c2c6208416cb 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidRules.java @@ -68,8 +68,8 @@ public static List rules() ), new DruidQueryRule<>( Project.class, - PartialDruidQuery.Stage.POST_PROJECT, - PartialDruidQuery::withPostProject + PartialDruidQuery.Stage.AGGREGATE_PROJECT, + PartialDruidQuery::withAggregateProject ), new DruidQueryRule<>( Filter.class, @@ -81,10 +81,16 @@ public static List rules() PartialDruidQuery.Stage.SORT, PartialDruidQuery::withSort ), + new DruidQueryRule<>( + Project.class, + PartialDruidQuery.Stage.SORT_PROJECT, + PartialDruidQuery::withSortProject + ), DruidOuterQueryRule.AGGREGATE, DruidOuterQueryRule.FILTER_AGGREGATE, DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE, - DruidOuterQueryRule.PROJECT_AGGREGATE + DruidOuterQueryRule.PROJECT_AGGREGATE, + DruidOuterQueryRule.AGGREGATE_SORT_PROJECT ); } @@ -227,6 +233,32 @@ public void onMatch(final RelOptRuleCall call) } }; + public static RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule( + operand(Project.class, operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, any())))), + "AGGREGATE_SORT_PROJECT" + ) + { + @Override + public void onMatch(RelOptRuleCall call) + { + final Project sortProject = call.rel(0); + final Sort sort = call.rel(1); + final Aggregate aggregate = call.rel(2); + final DruidRel druidRel = call.rel(3); + + final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create( + druidRel, + PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel()) + .withAggregate(aggregate) + .withSort(sort) + .withSortProject(sortProject) + ); + if (outerQueryRel.isValidDruidQuery()) { + call.transformTo(outerQueryRel); + } + } + }; + public DruidOuterQueryRule(final RelOptRuleOperand op, final String description) { super(op, StringUtils.format("%s(%s)", DruidOuterQueryRel.class.getSimpleName(), description)); diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java index 5376ff124f1d..9ef0430932b3 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/DruidSemiJoinRule.java @@ -24,6 +24,7 @@ import io.druid.sql.calcite.planner.PlannerConfig; import io.druid.sql.calcite.rel.DruidRel; import io.druid.sql.calcite.rel.DruidSemiJoin; +import io.druid.sql.calcite.rel.PartialDruidQuery; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptUtil; @@ -115,15 +116,18 @@ public void onMatch(RelOptRuleCall call) return; } - final Project rightPostProject = right.getPartialDruidQuery().getPostProject(); + final PartialDruidQuery rightQuery = right.getPartialDruidQuery(); + final Project rightProject = rightQuery.getSortProject() != null ? + rightQuery.getSortProject() : + rightQuery.getAggregateProject(); int i = 0; for (int joinRef : joinInfo.rightSet()) { final int aggregateRef; - if (rightPostProject == null) { + if (rightProject == null) { aggregateRef = joinRef; } else { - final RexNode projectExp = rightPostProject.getChildExps().get(joinRef); + final RexNode projectExp = rightProject.getChildExps().get(joinRef); if (projectExp.isA(SqlKind.INPUT_REF)) { aggregateRef = ((RexInputRef) projectExp).getIndex(); } else { diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 5bfa8a3bfab5..dccdc5f3ef08 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -73,6 +73,7 @@ import io.druid.query.groupby.having.DimFilterHavingSpec; import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.query.groupby.orderby.OrderByColumnSpec; +import io.druid.query.groupby.orderby.OrderByColumnSpec.Direction; import io.druid.query.lookup.RegisteredLookupExtractionFn; import io.druid.query.ordering.StringComparator; import io.druid.query.ordering.StringComparators; @@ -123,6 +124,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -6446,6 +6448,193 @@ public void testUnicodeFilterAndGroupBy() throws Exception ); } + @Test + public void testProjectAfterSort() throws Exception + { + testQuery( + "select dim1 from (select dim1, dim2, count(*) cnt from druid.foo group by dim1, dim2 order by cnt)", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + DIMS( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + ) + .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0"))) + .setLimitSpec( + new DefaultLimitSpec( + Collections.singletonList( + new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{""}, + new Object[]{"1"}, + new Object[]{"10.1"}, + new Object[]{"2"}, + new Object[]{"abc"}, + new Object[]{"def"} + ) + ); + } + + @Test + public void testProjectAfterSort2() throws Exception + { + testQuery( + "select s / cnt, dim1, dim2, s from (select dim1, dim2, count(*) cnt, sum(m2) s from druid.foo group by dim1, dim2 order by cnt)", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + DIMS( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("dim2", "d1") + ) + ) + .setAggregatorSpecs( + AGGS(new CountAggregatorFactory("a0"), new DoubleSumAggregatorFactory("a1", "m2")) + ) + .setPostAggregatorSpecs(Collections.singletonList(EXPRESSION_POST_AGG("p0", "(\"a1\" / \"a0\")"))) + .setLimitSpec( + new DefaultLimitSpec( + Collections.singletonList( + new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1.0, "", "a", 1.0}, + new Object[]{4.0, "1", "a", 4.0}, + new Object[]{2.0, "10.1", "", 2.0}, + new Object[]{3.0, "2", "", 3.0}, + new Object[]{6.0, "abc", "", 6.0}, + new Object[]{5.0, "def", "abc", 5.0} + ) + ); + } + + @Test + public void testProjectAfterSort3() throws Exception + { + testQuery( + "select dim1 from (select dim1, dim1, count(*) cnt from druid.foo group by dim1, dim1 order by cnt)", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + DIMS( + new DefaultDimensionSpec("dim1", "d0") + ) + ) + .setAggregatorSpecs(AGGS(new CountAggregatorFactory("a0"))) + .setLimitSpec( + new DefaultLimitSpec( + Collections.singletonList( + new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{""}, + new Object[]{"1"}, + new Object[]{"10.1"}, + new Object[]{"2"}, + new Object[]{"abc"}, + new Object[]{"def"} + ) + ); + } + + @Test + public void testSortProjectAfterNestedGroupBy() throws Exception + { + testQuery( + "SELECT " + + " cnt " + + "FROM (" + + " SELECT " + + " __time, " + + " dim1, " + + " COUNT(m2) AS cnt " + + " FROM (" + + " SELECT " + + " __time, " + + " m2, " + + " dim1 " + + " FROM druid.foo " + + " GROUP BY __time, m2, dim1 " + + " ) " + + " GROUP BY __time, dim1 " + + " ORDER BY cnt" + + ")", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(DIMS( + new DefaultDimensionSpec("__time", "d0", ValueType.LONG), + new DefaultDimensionSpec("dim1", "d1"), + new DefaultDimensionSpec("m2", "d2", ValueType.DOUBLE) + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(DIMS( + new DefaultDimensionSpec("d0", "_d0", ValueType.LONG), + new DefaultDimensionSpec("d1", "_d1", ValueType.STRING) + )) + .setAggregatorSpecs(AGGS( + new CountAggregatorFactory("a0") + )) + .setLimitSpec( + new DefaultLimitSpec( + Collections.singletonList( + new OrderByColumnSpec("a0", Direction.ASCENDING, StringComparators.NUMERIC) + ), + Integer.MAX_VALUE + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1L}, + new Object[]{1L}, + new Object[]{1L}, + new Object[]{1L}, + new Object[]{1L}, + new Object[]{1L} + ) + ); + } + private void testQuery( final String sql, final List expectedQueries,