Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support projection after sorting in SQL #5788

Merged
merged 3 commits into from
Jun 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.druid.sql.calcite.table.RowSignature;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -112,7 +113,7 @@ public static Aggregation create(final AggregatorFactory aggregatorFactory)

/**
 * Creates an {@code Aggregation} that wraps a single post-aggregator.
 *
 * <p>The two leading list arguments of the {@code Aggregation} constructor are passed as empty lists.
 *
 * @param postAggregator the post-aggregator to wrap
 *
 * @return a new aggregation built from {@code postAggregator} and two empty lists
 */
public static Aggregation create(final PostAggregator postAggregator)
{
  // A single return: the earlier ImmutableList.of() form was superseded by Collections.emptyList(),
  // and keeping both would leave an unreachable statement.
  return new Aggregation(Collections.emptyList(), Collections.emptyList(), postAggregator);
}

public static Aggregation create(
Expand Down
190 changes: 138 additions & 52 deletions sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.TreeSet;
import java.util.stream.Collectors;

Expand All @@ -105,9 +106,11 @@ public class DruidQuery
private final DimFilter filter;
private final SelectProjection selectProjection;
private final Grouping grouping;
private final SortProject sortProject;
private final DefaultLimitSpec limitSpec;
private final RowSignature outputRowSignature;
private final RelDataType outputRowType;
private final DefaultLimitSpec limitSpec;

private final Query query;

public DruidQuery(
Expand All @@ -128,15 +131,22 @@ public DruidQuery(
this.selectProjection = computeSelectProjection(partialQuery, plannerContext, sourceRowSignature);
this.grouping = computeGrouping(partialQuery, plannerContext, sourceRowSignature, rexBuilder);

final RowSignature sortingInputRowSignature;

if (this.selectProjection != null) {
this.outputRowSignature = this.selectProjection.getOutputRowSignature();
sortingInputRowSignature = this.selectProjection.getOutputRowSignature();
} else if (this.grouping != null) {
this.outputRowSignature = this.grouping.getOutputRowSignature();
sortingInputRowSignature = this.grouping.getOutputRowSignature();
} else {
this.outputRowSignature = sourceRowSignature;
sortingInputRowSignature = sourceRowSignature;
}

this.limitSpec = computeLimitSpec(partialQuery, this.outputRowSignature);
this.sortProject = computeSortProject(partialQuery, plannerContext, sortingInputRowSignature, grouping);

// outputRowSignature is used only for scan and select queries, and thus sort and grouping must be null
this.outputRowSignature = sortProject == null ? sortingInputRowSignature : sortProject.getOutputRowSignature();

this.limitSpec = computeLimitSpec(partialQuery, sortingInputRowSignature);
this.query = computeQuery();
}

Expand Down Expand Up @@ -235,7 +245,7 @@ private static Grouping computeGrouping(
)
{
final Aggregate aggregate = partialQuery.getAggregate();
final Project postProject = partialQuery.getPostProject();
final Project aggregateProject = partialQuery.getAggregateProject();

if (aggregate == null) {
return null;
Expand Down Expand Up @@ -265,49 +275,27 @@ private static Grouping computeGrouping(
plannerContext
);

if (postProject == null) {
if (aggregateProject == null) {
return Grouping.create(dimensions, aggregations, havingFilter, aggregateRowSignature);
} else {
final List<String> rowOrder = new ArrayList<>();

int outputNameCounter = 0;
for (final RexNode postAggregatorRexNode : postProject.getChildExps()) {
// Attempt to convert to PostAggregator.
final DruidExpression postAggregatorExpression = Expressions.toDruidExpression(
plannerContext,
aggregateRowSignature,
postAggregatorRexNode
);

if (postAggregatorExpression == null) {
throw new CannotBuildQueryException(postProject, postAggregatorRexNode);
}

if (postAggregatorDirectColumnIsOk(aggregateRowSignature, postAggregatorExpression, postAggregatorRexNode)) {
// Direct column access, without any type cast as far as Druid's runtime is concerned.
// (There might be a SQL-level type cast that we don't care about)
rowOrder.add(postAggregatorExpression.getDirectColumn());
} else {
final String postAggregatorName = "p" + outputNameCounter++;
final PostAggregator postAggregator = new ExpressionPostAggregator(
postAggregatorName,
postAggregatorExpression.getExpression(),
null,
plannerContext.getExprMacroTable()
);
aggregations.add(Aggregation.create(postAggregator));
rowOrder.add(postAggregator.getName());
}
}
final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
plannerContext,
aggregateRowSignature,
aggregateProject,
0
);
projectRowOrderAndPostAggregations.postAggregations.forEach(
postAggregator -> aggregations.add(Aggregation.create(postAggregator))
);

// Remove literal dimensions that did not appear in the projection. This is useful for queries
// like "SELECT COUNT(*) FROM tbl GROUP BY 'dummy'" which some tools can generate, and for which we don't
// actually want to include a dimension 'dummy'.
final ImmutableBitSet postProjectBits = RelOptUtil.InputFinder.bits(postProject.getChildExps(), null);
final ImmutableBitSet aggregateProjectBits = RelOptUtil.InputFinder.bits(aggregateProject.getChildExps(), null);
for (int i = dimensions.size() - 1; i >= 0; i--) {
final DimensionExpression dimension = dimensions.get(i);
if (Parser.parse(dimension.getDruidExpression().getExpression(), plannerContext.getExprMacroTable())
.isLiteral() && !postProjectBits.get(i)) {
.isLiteral() && !aggregateProjectBits.get(i)) {
dimensions.remove(i);
}
}
Expand All @@ -316,11 +304,98 @@ private static Grouping computeGrouping(
dimensions,
aggregations,
havingFilter,
RowSignature.from(rowOrder, postProject.getRowType())
RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, aggregateProject.getRowType())
);
}
}

/**
 * Computes the projection that is applied after sorting, if the partial query has one.
 *
 * <p>The projection's expressions are translated by {@code computePostAggregations}. Names for newly
 * created post-aggregators are numbered starting one past the largest numeric suffix found among the
 * grouping's existing post-aggregators, or from 0 if the grouping has none.
 *
 * @param partialQuery             partial query that may carry a sort projection
 * @param plannerContext           planner context used when translating expressions
 * @param sortingInputRowSignature signature of the rows the projection reads from
 * @param grouping                 grouping of this query; must be non-null when a sort projection exists
 *
 * @return the sort projection, or null if {@code partialQuery} has no sort projection
 *
 * @throws IllegalStateException if a sort projection exists but {@code grouping} is null
 */
@Nullable
private static SortProject computeSortProject(
    PartialDruidQuery partialQuery,
    PlannerContext plannerContext,
    RowSignature sortingInputRowSignature,
    Grouping grouping
)
{
  final Project sortProject = partialQuery.getSortProject();
  if (sortProject == null) {
    return null;
  }

  if (grouping == null) {
    // Fail with a descriptive message rather than a bare NullPointerException on the next statement.
    throw new IllegalStateException("Cannot compute a sort projection without a grouping");
  }

  // NOTE(review): this assumes every existing post-aggregator name is a one-character prefix followed
  // by an integer (like the "p" + counter names produced by computePostAggregations). Any other name
  // makes Integer.parseInt throw NumberFormatException -- confirm against how Grouping names them.
  final List<PostAggregator> postAggregators = grouping.getPostAggregators();
  final OptionalInt maybeMaxCounter = postAggregators
      .stream()
      .mapToInt(postAggregator -> Integer.parseInt(postAggregator.getName().substring(1)))
      .max();

  final ProjectRowOrderAndPostAggregations projectRowOrderAndPostAggregations = computePostAggregations(
      plannerContext,
      sortingInputRowSignature,
      sortProject,
      maybeMaxCounter.orElse(-1) + 1 // 0 if max doesn't exist
  );

  return new SortProject(
      sortingInputRowSignature,
      projectRowOrderAndPostAggregations.postAggregations,
      RowSignature.from(projectRowOrderAndPostAggregations.rowOrder, sortProject.getRowType())
  );
}

/**
 * Immutable holder for the two results of translating a projection: the output column names in
 * projection order, and the post-aggregators that were created along the way.
 *
 * <p>The lists are stored by reference; no defensive copies are made.
 */
private static class ProjectRowOrderAndPostAggregations
{
  /** Output column names, one per projected expression, in projection order. */
  private final List<String> rowOrder;

  /** Post-aggregators created for the projection. */
  private final List<PostAggregator> postAggregations;

  ProjectRowOrderAndPostAggregations(
      final List<String> outputNames,
      final List<PostAggregator> createdPostAggregators
  )
  {
    rowOrder = outputNames;
    postAggregations = createdPostAggregators;
  }
}

/**
 * Translates every expression of {@code project} into an output column name.
 *
 * <p>An expression accepted by {@code postAggregatorDirectColumnIsOk} contributes its direct column
 * name. Any other expression is wrapped in a new {@link ExpressionPostAggregator} named
 * {@code "p" + N}, where N starts at {@code outputNameCounter} and grows by one per created
 * post-aggregator.
 *
 * @param plannerContext    planner context used for expression translation and the macro table
 * @param inputRowSignature signature of the rows the projection reads from
 * @param project           projection whose child expressions are translated
 * @param outputNameCounter first numeric suffix to use for generated post-aggregator names
 *
 * @return the output names in projection order together with the created post-aggregators
 *
 * @throws CannotBuildQueryException if an expression cannot be converted to a Druid expression
 */
private static ProjectRowOrderAndPostAggregations computePostAggregations(
    PlannerContext plannerContext,
    RowSignature inputRowSignature,
    Project project,
    int outputNameCounter
)
{
  final List<String> outputNames = new ArrayList<>();
  final List<PostAggregator> createdPostAggregators = new ArrayList<>();
  int nextSuffix = outputNameCounter;

  for (final RexNode rexNode : project.getChildExps()) {
    final DruidExpression expression = Expressions.toDruidExpression(plannerContext, inputRowSignature, rexNode);

    // Bail out early when the expression has no Druid equivalent.
    if (expression == null) {
      throw new CannotBuildQueryException(project, rexNode);
    }

    if (!postAggregatorDirectColumnIsOk(inputRowSignature, expression, rexNode)) {
      // Not a plain column reference: materialize it as an expression post-aggregator.
      final String generatedName = "p" + nextSuffix;
      nextSuffix++;

      final PostAggregator expressionPostAggregator = new ExpressionPostAggregator(
          generatedName,
          expression.getExpression(),
          null,
          plannerContext.getExprMacroTable()
      );
      createdPostAggregators.add(expressionPostAggregator);
      outputNames.add(expressionPostAggregator.getName());
      continue;
    }

    // Direct column access, with no type cast as far as Druid's runtime is concerned
    // (a SQL-level cast may exist, but it is irrelevant here).
    outputNames.add(expression.getDirectColumn());
  }

  return new ProjectRowOrderAndPostAggregations(outputNames, createdPostAggregators);
}

/**
* Returns dimensions corresponding to {@code aggregate.getGroupSet()}, in the same order.
*
Expand Down Expand Up @@ -540,18 +615,20 @@ public VirtualColumns getVirtualColumns(final ExprMacroTable macroTable, final b
{
final List<VirtualColumn> retVal = new ArrayList<>();

if (grouping != null) {
if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
} else {
if (grouping != null) {
if (includeDimensions) {
for (DimensionExpression dimensionExpression : grouping.getDimensions()) {
retVal.addAll(dimensionExpression.getVirtualColumns(macroTable));
}
}
}

for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns());
for (Aggregation aggregation : grouping.getAggregations()) {
retVal.addAll(aggregation.getVirtualColumns());
}
}
} else if (selectProjection != null) {
retVal.addAll(selectProjection.getVirtualColumns());
}

return VirtualColumns.create(retVal);
Expand All @@ -567,6 +644,11 @@ public DefaultLimitSpec getLimitSpec()
return limitSpec;
}

/**
 * Returns the projection applied after sorting.
 *
 * @return the sort projection, or null when the query has none (the constructor explicitly
 *         checks this field against null when deriving the output row signature)
 */
public SortProject getSortProject()
{
return sortProject;
}

public RelDataType getOutputRowType()
{
return outputRowType;
Expand Down Expand Up @@ -667,7 +749,6 @@ public TimeseriesQuery toTimeseriesQuery()

if (limitSpec != null) {
// If there is a limit spec, timeseries cannot LIMIT; and must be ORDER BY time (or nothing).

if (limitSpec.isLimited()) {
return null;
}
Expand Down Expand Up @@ -797,6 +878,11 @@ public GroupByQuery toGroupByQuery()

final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature);

final List<PostAggregator> postAggregators = new ArrayList<>(grouping.getPostAggregators());
if (sortProject != null) {
postAggregators.addAll(sortProject.getPostAggregators());
}

return new GroupByQuery(
dataSource,
filtration.getQuerySegmentSpec(),
Expand All @@ -805,7 +891,7 @@ public GroupByQuery toGroupByQuery()
Granularities.ALL,
grouping.getDimensionSpecs(),
grouping.getAggregatorFactories(),
grouping.getPostAggregators(),
postAggregators,
grouping.getHavingFilter() != null ? new DimFilterHavingSpec(grouping.getHavingFilter(), true) : null,
limitSpec,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
Expand Down
8 changes: 6 additions & 2 deletions sql/src/main/java/io/druid/sql/calcite/rel/DruidQueryRel.java
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,18 @@ public RelOptCost computeSelfCost(final RelOptPlanner planner, final RelMetadata
cost += COST_PER_COLUMN * partialQuery.getAggregate().getAggCallList().size();
}

if (partialQuery.getPostProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getPostProject().getChildExps().size();
if (partialQuery.getAggregateProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getAggregateProject().getChildExps().size();
}

if (partialQuery.getSort() != null && partialQuery.getSort().fetch != null) {
cost *= COST_LIMIT_MULTIPLIER;
}

if (partialQuery.getSortProject() != null) {
cost += COST_PER_COLUMN * partialQuery.getSortProject().getChildExps().size();
}

if (partialQuery.getHavingFilter() != null) {
cost *= COST_HAVING_MULTIPLIER;
}
Expand Down
8 changes: 6 additions & 2 deletions sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,12 @@ public List<RexNode> accumulate(final List<RexNode> theConditions, final Object[
newPartialQuery = newPartialQuery.withHavingFilter(leftPartialQuery.getHavingFilter());
}

if (leftPartialQuery.getPostProject() != null) {
newPartialQuery = newPartialQuery.withPostProject(leftPartialQuery.getPostProject());
if (leftPartialQuery.getAggregateProject() != null) {
newPartialQuery = newPartialQuery.withAggregateProject(leftPartialQuery.getAggregateProject());
}

if (leftPartialQuery.getSortProject() != null) {
newPartialQuery = newPartialQuery.withSortProject(leftPartialQuery.getSortProject());
}

if (leftPartialQuery.getSort() != null) {
Expand Down
Loading