[CALCITE-1758] Push order and limit to Druid #433

Closed · wants to merge 2 commits
Changes from 1 commit
@@ -60,6 +60,7 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Pair;
@@ -71,13 +72,15 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;

import java.io.IOException;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

import static org.apache.calcite.sql.SqlKind.INPUT_REF;
@@ -100,6 +103,7 @@ public class DruidQuery extends AbstractRelNode implements BindableRel {

private static final Pattern VALID_SIG = Pattern.compile("sf?p?a?l?");
private static final String EXTRACT_COLUMN_NAME_PREFIX = "extract";
private static final String FLOOR_COLUMN_NAME_PREFIX = "floor";
protected static final String DRUID_QUERY_FETCH = "druid.query.fetch";

/**
@@ -377,9 +381,10 @@ public DruidTable getDruidTable() {
// A plan where all extra columns are pruned will be preferred.
.multiplyBy(
RelMdUtil.linear(querySpec.fieldNames.size(), 2, 100, 1d, 2d))
.multiplyBy(getQueryTypeCostMultiplier());
.multiplyBy(getQueryTypeCostMultiplier())
// A plan with the sort pushed to Druid is better than sorting outside Druid
.multiplyBy(Util.last(rels) instanceof Bindables.BindableSort ? 0.1 : 0.2);
Contributor:
We should use the superclass in case other rules do not generate a Sort with bindable convention, i.e., use instanceof Sort instead of BindableSort.

Further, if the last operator is not a SortLimit, we should just leave the cost as it is, i.e., multiply by 1 instead of 0.2.

Contributor Author:
ok
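
For clarity, a minimal sketch of the change the reviewer is asking for (hypothetical; the committed fix may differ). Sort here is the superclass org.apache.calcite.rel.core.Sort:

```java
// Hypothetical revision of the tail of computeSelfCost, per the review:
// match the superclass Sort so non-bindable sorts are covered too, and
// multiply by 1 (i.e., leave the cost unchanged) when no sort was pushed.
.multiplyBy(getQueryTypeCostMultiplier())
// A plan with the sort pushed to Druid is better than sorting outside Druid.
.multiplyBy(Util.last(rels) instanceof Sort ? 0.1 : 1.0);
```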

}

private double getQueryTypeCostMultiplier() {
// Cost of Select > GroupBy > Timeseries > TopN
switch (querySpec.queryType) {
@@ -491,6 +496,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
QueryType queryType = QueryType.SELECT;
final Translator translator = new Translator(druidTable, rowType);
List<String> fieldNames = rowType.getFieldNames();
Set<String> usedFieldNames = Sets.newHashSet(fieldNames);

// Handle filter
Json jsonFilter = null;
@@ -515,7 +521,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
// executed as a Timeseries, TopN, or GroupBy in Druid
final List<DimensionSpec> dimensions = new ArrayList<>();
final List<JsonAggregation> aggregations = new ArrayList<>();
Granularity granularity = Granularity.ALL;
Granularity finalGranularity = Granularity.ALL;
Direction timeSeriesDirection = null;
JsonLimit limit = null;
TimeExtractionDimensionSpec timeExtractionDimensionSpec = null;
@@ -525,63 +531,67 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
assert aggCalls.size() == aggNames.size();

int timePositionIdx = -1;
int extractNumber = -1;
final ImmutableList.Builder<String> builder = ImmutableList.builder();
if (projects != null) {
for (int groupKey : groupSet) {
final String s = fieldNames.get(groupKey);
final String fieldName = fieldNames.get(groupKey);
final RexNode project = projects.get(groupKey);
if (project instanceof RexInputRef) {
// Reference could be to the timestamp or a Druid dimension, but not a Druid metric
final RexInputRef ref = (RexInputRef) project;
final String origin = druidTable.getRowType(getCluster().getTypeFactory())
final String originalFieldName = druidTable.getRowType(getCluster().getTypeFactory())
.getFieldList().get(ref.getIndex()).getName();
if (origin.equals(druidTable.timestampFieldName)) {
granularity = Granularity.ALL;
// Generate unique name as timestampFieldName is taken
String extractColumnName = EXTRACT_COLUMN_NAME_PREFIX + "_" + (++extractNumber);
while (fieldNames.contains(extractColumnName)) {
extractColumnName = EXTRACT_COLUMN_NAME_PREFIX + "_" + (++extractNumber);
}
if (originalFieldName.equals(druidTable.timestampFieldName)) {
finalGranularity = Granularity.ALL;
String extractColumnName = SqlValidatorUtil.uniquify(EXTRACT_COLUMN_NAME_PREFIX
+ "_full_time", usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER);

Contributor:
Cool!
timeExtractionDimensionSpec = TimeExtractionDimensionSpec.makeFullTimeExtract(
extractColumnName);
dimensions.add(timeExtractionDimensionSpec);
builder.add(extractColumnName);
assert timePositionIdx == -1;
timePositionIdx = groupKey;
} else {
dimensions.add(new DefaultDimensionSpec(s));
builder.add(s);
dimensions.add(new DefaultDimensionSpec(fieldName));
builder.add(fieldName);
}
} else if (project instanceof RexCall) {
// Call, check if we should infer granularity
final RexCall call = (RexCall) project;
final Granularity funcGranularity =
DruidDateTimeUtils.extractGranularity(call);
final Granularity funcGranularity = DruidDateTimeUtils.extractGranularity(call);
if (funcGranularity != null) {
if (call.getKind().equals(SqlKind.EXTRACT)) {
// case extract on time
granularity = Granularity.ALL;
// Generate unique name as timestampFieldName is taken
String extractColumnName = EXTRACT_COLUMN_NAME_PREFIX + "_" + (++extractNumber);
while (fieldNames.contains(extractColumnName)) {
extractColumnName = EXTRACT_COLUMN_NAME_PREFIX + "_" + (++extractNumber);
}
// case extract field from time column
finalGranularity = Granularity.ALL;
String extractColumnName = SqlValidatorUtil.uniquify(EXTRACT_COLUMN_NAME_PREFIX
+ "_" + funcGranularity.value, usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER);
timeExtractionDimensionSpec = TimeExtractionDimensionSpec.makeExtract(
funcGranularity, extractColumnName);
dimensions.add(timeExtractionDimensionSpec);
builder.add(extractColumnName);
} else {
// case floor by granularity
granularity = funcGranularity;
builder.add(s);
// case floor time column
if (groupSet.cardinality() > 1) {
// More than one group-by key: Druid will run this as a groupBy query
String extractColumnName = SqlValidatorUtil.uniquify(FLOOR_COLUMN_NAME_PREFIX
+ "_" + funcGranularity.value, usedFieldNames, SqlValidatorUtil
.EXPR_SUGGESTER);
dimensions.add(
TimeExtractionDimensionSpec.makeFloor(funcGranularity, extractColumnName));
finalGranularity = Granularity.ALL;
builder.add(extractColumnName);
} else {
// Timeseries case: we cannot use an extraction function
finalGranularity = funcGranularity;
builder.add(fieldName);
}
assert timePositionIdx == -1;
timePositionIdx = groupKey;
}

} else {
dimensions.add(new DefaultDimensionSpec(s));
builder.add(s);
dimensions.add(new DefaultDimensionSpec(fieldName));
builder.add(fieldName);
}
} else {
throw new AssertionError("incompatible project expression: " + project);
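
The floor branch above is the heart of this hunk: with more than one group-by key the query must run as a Druid groupBy, so the floor is expressed as a time-extraction dimension and the query granularity stays ALL; with a single floored time key the query can run as a timeseries whose granularity is the floor unit itself. A condensed restatement of that dispatch (sketch only, mirroring the diff above):

```java
// Sketch: pushing FLOOR over the time column, depending on key count.
if (groupSet.cardinality() > 1) {
  // groupBy query: the floor becomes a time-extraction dimension.
  dimensions.add(
      TimeExtractionDimensionSpec.makeFloor(funcGranularity, extractColumnName));
  finalGranularity = Granularity.ALL;
} else {
  // Timeseries query: the floor becomes the query granularity itself.
  finalGranularity = funcGranularity;
}
```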
@@ -591,12 +601,10 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
for (int groupKey : groupSet) {
final String s = fieldNames.get(groupKey);
if (s.equals(druidTable.timestampFieldName)) {
granularity = Granularity.ALL;
finalGranularity = Granularity.ALL;
// Generate unique name as timestampFieldName is taken
String extractColumnName = EXTRACT_COLUMN_NAME_PREFIX + "_" + (++extractNumber);
while (fieldNames.contains(extractColumnName)) {
extractColumnName = EXTRACT_COLUMN_NAME_PREFIX + "_" + (++extractNumber);
}
String extractColumnName = SqlValidatorUtil.uniquify(EXTRACT_COLUMN_NAME_PREFIX,
usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER);
timeExtractionDimensionSpec = TimeExtractionDimensionSpec.makeFullTimeExtract(
extractColumnName);
dimensions.add(timeExtractionDimensionSpec);
@@ -645,7 +653,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
queryType = QueryType.TIMESERIES;
assert fetch == null;
} else if (dimensions.size() == 1
&& granularity == Granularity.ALL
&& finalGranularity == Granularity.ALL
&& sortsMetric
&& collations.size() == 1
&& fetch != null
@@ -680,7 +688,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
generator.writeStringField("dataSource", druidTable.dataSource);
generator.writeBooleanField("descending", timeSeriesDirection != null
&& timeSeriesDirection == Direction.DESCENDING);
generator.writeStringField("granularity", granularity.value);
generator.writeStringField("granularity", finalGranularity.value);
writeFieldIf(generator, "filter", jsonFilter);
writeField(generator, "aggregations", aggregations);
writeFieldIf(generator, "postAggregations", null);
@@ -700,7 +708,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>

generator.writeStringField("queryType", "topN");
generator.writeStringField("dataSource", druidTable.dataSource);
generator.writeStringField("granularity", granularity.value);
generator.writeStringField("granularity", finalGranularity.value);
writeField(generator, "dimension", dimensions.get(0));
generator.writeStringField("metric", fieldNames.get(collationIndexes.get(0)));
writeFieldIf(generator, "filter", jsonFilter);
@@ -716,7 +724,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
generator.writeStartObject();
generator.writeStringField("queryType", "groupBy");
generator.writeStringField("dataSource", druidTable.dataSource);
generator.writeStringField("granularity", granularity.value);
generator.writeStringField("granularity", finalGranularity.value);
writeField(generator, "dimensions", dimensions);
writeFieldIf(generator, "limitSpec", limit);
writeFieldIf(generator, "filter", jsonFilter);
@@ -738,7 +746,7 @@ protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode>
writeFieldIf(generator, "filter", jsonFilter);
writeField(generator, "dimensions", translator.dimensions);
writeField(generator, "metrics", translator.metrics);
generator.writeStringField("granularity", granularity.value);
generator.writeStringField("granularity", finalGranularity.value);

generator.writeFieldName("pagingSpec");
generator.writeStartObject();
@@ -565,10 +565,9 @@ public void onMatch(RelOptRuleCall call) {
return;
}
// Either it is:
// - a sort and limit on a dimension/metric that is part of the Druid groupBy query, or
// - a sort without limit on the time column on top of
// Agg operator (transformable to timeseries query), or
// - it is a sort w/o limit on columns that do not include
// the time column on top of Agg operator, or
// - a simple limit on top of an operator other than Agg
if (!validSortLimit(sort, query)) {
return;
@@ -590,32 +589,21 @@ private static boolean validSortLimit(Sort sort, DruidQuery query) {
int metricsRefs = 0;
for (RelFieldCollation col : sort.collation.getFieldCollations()) {
int idx = col.getFieldIndex();
// Computes the number of metrics in the sort
if (idx >= topAgg.getGroupCount()) {
metricsRefs++;
Contributor:
metricsRefs is not used anymore (and it seems it is not needed). If it became useless, could we get rid of it?

Contributor Author:
ok
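
A sketch of the loop after the suggested cleanup (hypothetical; the actual follow-up may differ):

```java
// Hypothetical cleanup with metricsRefs dropped: only the positions of
// group-key sort columns need to be recorded.
for (RelFieldCollation col : sort.collation.getFieldCollations()) {
  final int idx = col.getFieldIndex();
  if (idx >= topAgg.getGroupCount()) {
    continue; // sort key is a metric; no group-key position to record
  }
  positionsReferenced.set(topAgg.getGroupSet().nth(idx));
}
```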

continue;
}
// Records the indexes of the columns used in sorts
positionsReferenced.set(topAgg.getGroupSet().nth(idx));
}
boolean refsTimestamp =
checkTimestampRefOnQuery(positionsReferenced.build(), topAgg.getInput(), query);
if (refsTimestamp && metricsRefs != 0) {
// Metrics reference timestamp too
return false;
}
// If the aggregate is grouping by timestamp (or a function of the
// timestamp such as month) then we cannot push Sort to Druid.
// Druid's topN and groupBy operators would sort only within the
// granularity, whereas we want global sort.
final boolean aggregateRefsTimestamp =
checkTimestampRefOnQuery(topAgg.getGroupSet(), topAgg.getInput(), query);
if (aggregateRefsTimestamp && metricsRefs != 0) {
return false;
}
if (refsTimestamp
&& sort.collation.getFieldCollations().size() == 1
// Case it is a timeseries query
if (checkIsFlooringTimestampRefOnQuery(topAgg.getGroupSet(), topAgg.getInput(), query)
&& topAgg.getGroupCount() == 1) {
// Timeseries query: if it has a limit, we cannot push
return !RelOptUtil.isLimit(sort);
// Do not push if it has a limit, more than one sort key, or a sort by
// metric/dimension
return !RelOptUtil.isLimit(sort) && sort.collation.getFieldCollations().size() == 1
&& checkTimestampRefOnQuery(positionsReferenced.build(), topAgg.getInput(), query);
}
return true;
}
@@ -625,6 +613,36 @@ private static boolean validSortLimit(Sort sort, DruidQuery query) {
}
}

/** Returns true if any of the grouping keys is a floor operator over the timestamp column. */
private static boolean checkIsFlooringTimestampRefOnQuery(ImmutableBitSet set, RelNode top,
DruidQuery query) {
if (top instanceof Project) {
ImmutableBitSet.Builder newSet = ImmutableBitSet.builder();
final Project project = (Project) top;
for (int index : set) {
RexNode node = project.getProjects().get(index);
if (node instanceof RexCall) {
RexCall call = (RexCall) node;
assert DruidDateTimeUtils.extractGranularity(call) != null;
if (call.getKind().equals(SqlKind.FLOOR)) {
newSet.addAll(RelOptUtil.InputFinder.bits(call));
}
}
}
top = project.getInput();
set = newSet.build();
}
// Check if any references the timestamp column
for (int index : set) {
if (query.druidTable.timestampFieldName.equals(
top.getRowType().getFieldNames().get(index))) {
return true;
}
}

return false;
}
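
To see how the check walks the plan, consider a hypothetical trace (plan shape and column names assumed for illustration):

```java
// Hypothetical trace through checkIsFlooringTimestampRefOnQuery:
//
//   top: LogicalProject(FLOOR(Reinterpret($0), FLAG(DAY)), $2)
//   set: {0}   // the aggregate groups on projected column 0
//
// Projected column 0 is a FLOOR RexCall, so its input bits ({0}) become
// the new set and top becomes the project's input. Field 0 of that input
// is "__time" (the assumed timestampFieldName), so the method returns
// true and the sort may be pushed as a timeseries query.
```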

/** Checks whether any of the references leads to the timestamp column. */
private static boolean checkTimestampRefOnQuery(ImmutableBitSet set, RelNode top,
DruidQuery query) {
@@ -638,7 +656,8 @@ private static boolean checkTimestampRefOnQuery(ImmutableBitSet set, RelNode top
} else if (node instanceof RexCall) {
RexCall call = (RexCall) node;
assert DruidDateTimeUtils.extractGranularity(call) != null;
newSet.set(((RexInputRef) call.getOperands().get(0)).getIndex());
// When we extract from the time column, the RexCall has the form Reinterpret($0)
newSet.addAll(RelOptUtil.InputFinder.bits(call));
}
}
top = project.getInput();
@@ -51,11 +51,11 @@ public static ExtractionFunction buildExtraction(RexNode rexNode) {
}
switch (timeUnit) {
case YEAR:
return TimeExtractionFunction.createFromGranularity(Granularity.YEAR);
return TimeExtractionFunction.createExtractFromGranularity(Granularity.YEAR);
case MONTH:
return TimeExtractionFunction.createFromGranularity(Granularity.MONTH);
return TimeExtractionFunction.createExtractFromGranularity(Granularity.MONTH);
case DAY:
return TimeExtractionFunction.createFromGranularity(Granularity.DAY);
return TimeExtractionFunction.createExtractFromGranularity(Granularity.DAY);
default:
return null;
}
@@ -31,6 +31,7 @@ public TimeExtractionDimensionSpec(
* to the given name.
*
* @param outputName name of the output column
*
* @return the time extraction DimensionSpec instance
*/
public static TimeExtractionDimensionSpec makeFullTimeExtract(String outputName) {
@@ -44,27 +45,41 @@ public static TimeExtractionDimensionSpec makeFullTimeExtract(String outputName)
* name. Only YEAR, MONTH, and DAY granularity are supported.
*
* @param granularity granularity to apply to the column
* @param outputName name of the output column
* @return the time extraction DimensionSpec instance or null if granularity
* @param outputName name of the output column
*
* @return time field extraction DimensionSpec instance or null if granularity
* is not supported
*/
public static TimeExtractionDimensionSpec makeExtract(
Granularity granularity, String outputName) {
switch (granularity) {
case YEAR:
return new TimeExtractionDimensionSpec(
TimeExtractionFunction.createFromGranularity(granularity), outputName);
TimeExtractionFunction.createExtractFromGranularity(granularity), outputName);
case MONTH:
return new TimeExtractionDimensionSpec(
TimeExtractionFunction.createFromGranularity(granularity), outputName);
TimeExtractionFunction.createExtractFromGranularity(granularity), outputName);
case DAY:
return new TimeExtractionDimensionSpec(
TimeExtractionFunction.createFromGranularity(granularity), outputName);
TimeExtractionFunction.createExtractFromGranularity(granularity), outputName);
// TODO: Support other granularities
default:
return null;
}
}


/**
* Creates a floor time extraction dimension spec from the granularity with a given output name.
*
* @param granularity granularity to apply to the time column
* @param outputName name of the output column
*
* @return floor time extraction DimensionSpec instance.
*/
public static TimeExtractionDimensionSpec makeFloor(Granularity granularity, String outputName) {
ExtractionFunction fn = TimeExtractionFunction.createFloorFromGranularity(granularity);
return new TimeExtractionDimensionSpec(fn, outputName);
}
}

// End TimeExtractionDimensionSpec.java
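
A usage sketch for the new factory method (the serialized output described in the comment is an assumption based on Druid's extraction-function format):

```java
// Hypothetical usage: bucket the time column by day under a chosen name.
TimeExtractionDimensionSpec daySpec =
    TimeExtractionDimensionSpec.makeFloor(Granularity.DAY, "floor_day");
// When written into a groupBy query, this is expected to emit a dimension
// spec over the time column whose extraction function floors timestamps
// to day granularity (exact JSON shape is an assumption).
```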