Skip to content

Commit

Permalink
Support projection after sorting in SQL (#5788) (#6228)
Browse files Browse the repository at this point in the history
* Add sort project

* add more test

* address comments
  • Loading branch information
gianm authored and fjy committed Aug 26, 2018
1 parent 98234e8 commit bc07320
Show file tree
Hide file tree
Showing 9 changed files with 581 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.druid.sql.calcite.table.RowSignature;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -112,7 +113,7 @@ public static Aggregation create(final AggregatorFactory aggregatorFactory)

public static Aggregation create(final PostAggregator postAggregator)
{
return new Aggregation(ImmutableList.of(), ImmutableList.of(), postAggregator);
return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
}

public static Aggregation create(
Expand Down
190 changes: 138 additions & 52 deletions sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.TreeSet;
import java.util.stream.Collectors;

Expand All @@ -105,9 +106,11 @@ public class DruidQuery
private final DimFilter filter;
private final SelectProjection selectProjection;
private final Grouping grouping;
private final SortProject sortProject;
private final DefaultLimitSpec limitSpec;
private final RowSignature outputRowSignature;
private final RelDataType outputRowType;
private final DefaultLimitSpec limitSpec;

private final Query query;

public DruidQuery(
Expand All @@ -129,15 +132,22 @@ public DruidQuery(
this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder, finalizeAggregations);

final RowSignature sortingInputRowSignature;

if (this.selectProjection != null) {
this.outputRowSignature = this.selectProjection.getOutputRowSignature();
sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
} else if (this.grouping != null) {
this.outputRowSignature = this.grouping.getOutputRowSignature();
sortingInputRowSignature = this.grouping.getOutputRowSignature();
} else {
this.outputRowSignature = sourceRowSignature;
sortingInputRowSignature = sourceRowSignature;
}

this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);

// outputRowSignature is used only for scan and select query, and thus sort and grouping must be null
this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();

this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
this.query = computeQuery();
}

Expand Down Expand Up @@ -237,7 +247,7 @@ private static Grouping computeGrouping(
)
{
final Aggregate aggregate = partialQuery.getAggregate();
final Project postProject = partialQuery.getPostProject();
final Project aggregateProject = partialQuery.getAggregateProject();

if (aggregate == null) {
return null;
Expand Down Expand Up @@ -268,49 +278,27 @@ private static Grouping computeGrouping(
plannerContext
);

if (postProject == null) {
if (aggregateProject == null) {
return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
} else {
final List<String> rowOrder = new ArrayList<>();

int outputNameCounter = 0;
for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
// Attempt to convert to PostAggregator.
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
plannerContext,
aggregateRowSignature,
postAggregatorRexNode
);

if (postAggregatorExpression == null) {
throw new CannotBuildQueryException(postProject, postAggregatorRexNode);
}

if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
// Direct column access, without any type cast as far as Druid's runtime is concerned.
// (There might be a SQL-level type cast that we don't care about)
rowOrder.add(postAggregatorExpression.getDirectColumn());
} else {
final String postAggregatorName = "p" + outputNameCounter++;
final PostAggregator postAggregator = new ExpressionPostAggregator(
postAggregatorName,
postAggregatorExpression.getExpression(),
null,
plannerContext.getExprMacroTable()
);
aggregations.add(Aggregation.create(postAggregator));
rowOrder.add(postAggregator.getName());
}
}
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
aggregateRowSignature,
aggregateProject,
0
);
projectRowOrderAndPostAggregations.postAggregations.forEach(
postAggregator -> aggregations.add(Aggregation.create(postAggregator))
);

// Remove literal dimensions that did not appear in the projection. This is useful for queries
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
// actually want to include a dimension 'dummy'.
final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
.isLiteral() && !postProjectBits.get(i)) {
.isLiteral() && !aggregateProjectBits.get(i)) {
dimensions.remove(i);
}
}
Expand All @@ -319,11 +307,98 @@ private static Grouping computeGrouping(
dimensions,
aggregations,
havingFilter,
RowSignature.from(rowOrder, postProject.getRowType())
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
);
}
}

@Nullable
private SortProject computeSortProject(
PartialDruidQuery partialQuery,
PlannerContext plannerContext,
RowSignature sortingInputRowSignature,
Grouping grouping
)
{
final Project sortProject = partialQuery.getSortProject();
if (sortProject == null) {
return null;
} else {
final List<PostAggregator> postAggregators = grouping.getPostAggregators();
final OptionalInt maybeMaxCounter = postAggregators
.stream()
.mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
.max();

final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
sortingInputRowSignature,
sortProject,
maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
);

return new SortProject(
sortingInputRowSignature,
projectRowOrderAndPostAggregations.postAggregations,
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
);
}
}

private static class ProjectRowOrderAndPostAggregations
{
private final List<String> rowOrder;
private final List<PostAggregator> postAggregations;

ProjectRowOrderAndPostAggregations(List<String> rowOrder, List<PostAggregator> postAggregations)
{
this.rowOrder = rowOrder;
this.postAggregations = postAggregations;
}
}

private static ProjectRowOrderAndPostAggregations computePostAggregations(
PlannerContext plannerContext,
RowSignature inputRowSignature,
Project project,
int outputNameCounter
)
{
final List<String> rowOrder = new ArrayList<>();
final List<PostAggregator> aggregations = new ArrayList<>();

for (final RexNode postAggregatorRexNode : project.getChildExps()) {
// Attempt to convert to PostAggregator.
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
plannerContext,
inputRowSignature,
postAggregatorRexNode
);

if (postAggregatorExpression == null) {
throw new CannotBuildQueryException(project, postAggregatorRexNode);
}

if (postAggregatorDirectColumnIsOk(inputRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
// Direct column access, without any type cast as far as Druid's runtime is concerned.
// (There might be a SQL-level type cast that we don't care about)
rowOrder.add(postAggregatorExpression.getDirectColumn());
} else {
final String postAggregatorName = "p" + outputNameCounter++;
final PostAggregator postAggregator = new ExpressionPostAggregator(
postAggregatorName,
postAggregatorExpression.getExpression(),
null,
plannerContext.getExprMacroTable()
);
aggregations.add(postAggregator);
rowOrder.add(postAggregator.getName());
}
}

return new ProjectRowOrderAndPostAggregations(rowOrder, aggregations);
}

/**
* Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order.
*
Expand Down Expand Up @@ -548,18 +623,20 @@ public VirtualColumns getVirtualColumns(final ExprMacroTable macroTable, final b
{
final List<VirtualColumn> retVal = new ArrayList<>();

if (grouping != null) {
if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
} else {
if (grouping != null) {
if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
}
}
}

for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns());
for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns());
}
}
} else if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
}

return VirtualColumns.create(retVal);
Expand All @@ -575,6 +652,11 @@ public DefaultLimitSpec getLimitSpec()
return limitSpec;
}

public SortProject getSortProject()
{
return sortProject;
}

public RelDataType getOutputRowType()
{
return outputRowType;
Expand Down Expand Up @@ -675,7 +757,6 @@ public TimeseriesQuery toTimeseriesQuery()

if (limitSpec != null) {
// If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).

if (limitSpec.isLimited()) {
return null;
}
Expand Down Expand Up @@ -805,6 +886,11 @@ public GroupByQuery toGroupByQuery()

final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);

final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
if (sortProject != null) {
postAggregators.addAll(sortProject.getPostAggregators());
}

return new GroupByQuery(
dataSource,
filtration.getQuerySegmentSpec(),
Expand All @@ -813,7 +899,7 @@ public GroupByQuery toGroupByQuery()
Granularities.ALL,
grouping.getDimensionSpecs(),
grouping.getAggregatorFactories(),
grouping.getPostAggregators(),
postAggregators,
grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
limitSpec,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
Expand Down
8 changes: 6 additions & 2 deletions sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,18 @@ public RelOptCost computeSelfCost(final RelOptPlanner planner, final RelMetadata
cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
}

if (partialQuery.getPostProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size();
if (partialQuery.getAggregateProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
}

if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
cost *= COST_LIMIT_MULTIPLIER;
}

if (partialQuery.getSortProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
}

if (partialQuery.getHavingFilter() != null) {
cost *= COST_HAVING_MULTIPLIER;
}
Expand Down
8 changes: 6 additions & 2 deletions sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,12 @@ public List<RexNode> accumulate(final List<RexNode> theConditions, final Object[
newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
}

if (leftPartialQuery.getPostProject() != null) {
newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
if (leftPartialQuery.getAggregateProject() != null) {
newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
}

if (leftPartialQuery.getSortProject() != null) {
newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
}

if (leftPartialQuery.getSort() != null) {
Expand Down

0 comments on commit bc07320

Please sign in to comment.