Skip to content
Permalink
Browse files
DRILL-5188: Expand sub-queries using rules
- Add check for agg with group by literal
- Allow NLJ for limit 1
- Implement single_value aggregate function

closes #1321
  • Loading branch information
vvysotskyi authored and vdiravka committed Jun 22, 2018
1 parent b447260 commit 502d2977092eecda0a4aa0482b5f96459c315227
Showing 20 changed files with 504 additions and 70 deletions.
@@ -43,6 +43,7 @@ data: {
intervalNumericTypes: tdd(../data/IntervalNumericTypes.tdd),
extract: tdd(../data/ExtractTypes.tdd),
sumzero: tdd(../data/SumZero.tdd),
singleValue: tdd(../data/SingleValue.tdd),
numericTypes: tdd(../data/NumericTypes.tdd),
casthigh: tdd(../data/CastHigh.tdd),
countAggrTypes: tdd(../data/CountAggrTypes.tdd)
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

{
types: [
{inputType: "Bit", outputType: "NullableBit", runningType: "Bit", major: "primitive"},
{inputType: "TinyInt", outputType: "NullableTinyInt", runningType: "TinyInt", major: "primitive"},
{inputType: "NullableTinyInt", outputType: "NullableTinyInt", runningType: "TinyInt", major: "primitive"},
{inputType: "UInt1", outputType: "NullableUInt1", runningType: "UInt1", major: "primitive"},
{inputType: "NullableUInt1", outputType: "NullableUInt1", runningType: "UInt1", major: "primitive"},
{inputType: "UInt2", outputType: "NullableUInt2", runningType: "UInt2", major: "primitive"},
{inputType: "NullableUInt2", outputType: "NullableUInt2", runningType: "UInt2", major: "primitive"},
{inputType: "SmallInt", outputType: "NullableSmallInt", runningType: "SmallInt", major: "primitive"},
{inputType: "NullableSmallInt", outputType: "NullableSmallInt", runningType: "SmallInt", major: "primitive"},
{inputType: "UInt4", outputType: "NullableUInt4", runningType: "UInt4", major: "primitive"},
{inputType: "NullableUInt4", outputType: "NullableUInt4", runningType: "UInt4", major: "primitive"},
{inputType: "UInt8", outputType: "NullableUInt8", runningType: "UInt8", major: "primitive"},
{inputType: "NullableUInt8", outputType: "NullableUInt8", runningType: "UInt8", major: "primitive"},
{inputType: "Int", outputType: "NullableInt", runningType: "Int", major: "primitive"},
{inputType: "BigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "primitive"},
{inputType: "NullableBit", outputType: "NullableBit", runningType: "Bit", major: "primitive"},
{inputType: "NullableInt", outputType: "NullableInt", runningType: "Int", major: "primitive"},
{inputType: "NullableBigInt", outputType: "NullableBigInt", runningType: "BigInt", major: "primitive"},
{inputType: "Float4", outputType: "NullableFloat4", runningType: "Float4", major: "primitive"},
{inputType: "Float8", outputType: "NullableFloat8", runningType: "Float8", major: "primitive"},
{inputType: "NullableFloat4", outputType: "NullableFloat4", runningType: "Float4", major: "primitive"},
{inputType: "NullableFloat8", outputType: "NullableFloat8", runningType: "Float8", major: "primitive"},
{inputType: "Date", outputType: "NullableDate", runningType: "Date", major: "primitive"},
{inputType: "NullableDate", outputType: "NullableDate", runningType: "Date", major: "primitive"},
{inputType: "TimeStamp", outputType: "NullableTimeStamp", runningType: "TimeStamp", major: "primitive"},
{inputType: "NullableTimeStamp", outputType: "NullableTimeStamp", runningType: "TimeStamp", major: "primitive"},
{inputType: "Time", outputType: "NullableTime", runningType: "Time", major: "primitive"},
{inputType: "NullableTime", outputType: "NullableTime", runningType: "Time", major: "primitive"},
{inputType: "IntervalDay", outputType: "NullableIntervalDay", runningType: "IntervalDay", major: "IntervalDay"},
{inputType: "NullableIntervalDay", outputType: "NullableIntervalDay", runningType: "IntervalDay", major: "IntervalDay"},
{inputType: "IntervalYear", outputType: "NullableIntervalYear", runningType: "IntervalYear", major: "primitive"},
{inputType: "NullableIntervalYear", outputType: "NullableIntervalYear", runningType: "IntervalYear", major: "primitive"},
{inputType: "Interval", outputType: "NullableInterval", runningType: "Interval", major: "Interval"},
{inputType: "NullableInterval", outputType: "NullableInterval", runningType: "Interval", major: "Interval"},
{inputType: "VarDecimal", outputType: "NullableVarDecimal", runningType: "VarDecimal", major: "VarDecimal"},
{inputType: "NullableVarDecimal", outputType: "NullableVarDecimal", runningType: "VarDecimal", major: "VarDecimal"},
{inputType: "VarChar", outputType: "NullableVarChar", runningType: "VarChar", major: "bytes"},
{inputType: "NullableVarChar", outputType: "NullableVarChar", runningType: "VarChar", major: "bytes"},
{inputType: "Var16Char", outputType: "NullableVar16Char", runningType: "Var16Char", major: "bytes"},
{inputType: "NullableVar16Char", outputType: "NullableVar16Char", runningType: "Var16Char", major: "bytes"},
{inputType: "VarBinary", outputType: "NullableVarBinary", runningType: "VarBinary", major: "bytes"},
{inputType: "NullableVarBinary", outputType: "NullableVarBinary", runningType: "VarBinary", major: "bytes"}
]
}
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
<@pp.dropOutputFile />

<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gaggr/SingleValueFunctions.java" />

<#include "/@includes/license.ftl" />

package org.apache.drill.exec.expr.fn.impl.gaggr;

import org.apache.drill.exec.expr.DrillAggFunc;
import org.apache.drill.exec.expr.annotations.FunctionTemplate;
import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope;
import org.apache.drill.exec.expr.annotations.Output;
import org.apache.drill.exec.expr.annotations.Param;
import org.apache.drill.exec.expr.annotations.Workspace;
import org.apache.drill.exec.expr.holders.*;

import javax.inject.Inject;
import io.netty.buffer.DrillBuf;

/*
* This class is generated using freemarker and the ${.template_name} template.
*/
// Generated implementations of the single_value aggregate function, one class
// per type listed in SingleValue.tdd. single_value returns its sole input
// value and raises an error when the aggregated input has more than one row;
// the planner uses it when expanding scalar sub-queries (DRILL-5188).
@SuppressWarnings("unused")
public class SingleValueFunctions {
<#list singleValue.types as type>

@FunctionTemplate(name = "single_value",
<#if type.major == "VarDecimal">
returnType = FunctionTemplate.ReturnType.DECIMAL_AVG_AGGREGATE,
</#if>
scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE)
public static class ${type.inputType}SingleValue implements DrillAggFunc {
@Param ${type.inputType}Holder in; // current input row's value
@Workspace ${type.runningType}Holder value; // the single accepted value
@Output ${type.outputType}Holder out;
@Workspace BigIntHolder nonNullCount; // non-null rows seen; must stay <= 1
<#if type.major == "VarDecimal" || type.major == "bytes">
@Inject DrillBuf buffer; // output buffer for variable-length results
</#if>

public void setup() {
nonNullCount = new BigIntHolder();
nonNullCount.value = 0;
value = new ${type.runningType}Holder();
}

@Override
public void add() {
<#if type.inputType?starts_with("Nullable")>
sout: {
if (in.isSet == 0) {
// processing nullable input and the value is null, so don't do anything...
break sout;
}
</#if>
// single_value accepts exactly one non-null row; a second row is an error.
if (nonNullCount.value == 0) {
nonNullCount.value = 1;
} else {
throw org.apache.drill.common.exceptions.UserException.functionError()
.message("Input for single_value function has more than one row")
.build();
}
<#if type.major == "primitive">
value.value = in.value;
<#elseif type.major == "IntervalDay">
value.days = in.days;
value.milliseconds = in.milliseconds;
<#elseif type.major == "Interval">
value.days = in.days;
value.milliseconds = in.milliseconds;
value.months = in.months;
<#elseif type.major == "VarDecimal">
// NOTE(review): only a reference to the incoming buffer is retained here;
// the bytes are copied in output(). Assumes the input buffer stays valid
// until output() runs — confirm against the aggregate operator's lifecycle.
value.start = in.start;
value.end = in.end;
value.buffer = in.buffer;
value.scale = in.scale;
value.precision = in.precision;
<#elseif type.major == "bytes">
// NOTE(review): only a reference to the incoming buffer is retained here;
// the bytes are copied in output(). Assumes the input buffer stays valid
// until output() runs — confirm against the aggregate operator's lifecycle.
value.start = in.start;
value.end = in.end;
value.buffer = in.buffer;
</#if>
<#if type.inputType?starts_with("Nullable")>
} // end of sout block
</#if>
}

@Override
public void output() {
// The output is NULL when no non-null input row was seen.
if (nonNullCount.value > 0) {
out.isSet = 1;
<#if type.major == "primitive">
out.value = value.value;
<#elseif type.major == "IntervalDay">
out.days = value.days;
out.milliseconds = value.milliseconds;
<#elseif type.major == "Interval">
out.days = value.days;
out.milliseconds = value.milliseconds;
out.months = value.months;
<#elseif type.major == "VarDecimal">
// Copy the retained bytes into a buffer owned by this function. The copy
// lands at offset 0 of the (re)allocated buffer, so the output offsets are
// rebased to [0, length) instead of reusing the input offsets — the input
// offsets would point past the copied data whenever value.start != 0.
out.buffer = buffer.reallocIfNeeded(value.end - value.start);
out.buffer.setBytes(0, value.buffer, value.start, value.end - value.start);
out.start = 0;
out.end = value.end - value.start;
out.scale = value.scale;
out.precision = value.precision;
<#elseif type.major == "bytes">
// Copy the retained bytes into a buffer owned by this function. The copy
// lands at offset 0 of the (re)allocated buffer, so the output offsets are
// rebased to [0, length) instead of reusing the input offsets — the input
// offsets would point past the copied data whenever value.start != 0.
out.buffer = buffer.reallocIfNeeded(value.end - value.start);
out.buffer.setBytes(0, value.buffer, value.start, value.end - value.start);
out.start = 0;
out.end = value.end - value.start;
</#if>
} else {
out.isSet = 0;
}
}

@Override
public void reset() {
value = new ${type.runningType}Holder();
nonNullCount.value = 0;
}
}
</#list>
}

@@ -17,9 +17,22 @@
*/
package org.apache.drill.exec.physical.impl.join;

import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.Util;
import org.apache.drill.common.exceptions.UserException;
import org.apache.drill.common.logical.data.JoinCondition;
import org.apache.calcite.rel.RelNode;
@@ -35,9 +48,11 @@
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.exec.expr.ExpressionTreeMaterializer;
import org.apache.drill.exec.ops.FragmentContext;
import org.apache.drill.exec.planner.logical.DrillLimitRel;
import org.apache.drill.exec.record.VectorAccessible;
import org.apache.drill.exec.resolver.TypeCastRules;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

@@ -220,7 +235,13 @@ public static boolean isScalarSubquery(RelNode root) {
if (currentrel instanceof DrillAggregateRel) {
agg = (DrillAggregateRel)currentrel;
} else if (currentrel instanceof RelSubset) {
currentrel = ((RelSubset)currentrel).getBest() ;
currentrel = ((RelSubset) currentrel).getBest() ;
} else if (currentrel instanceof DrillLimitRel) {
// TODO: Improve this check when DRILL-5691 is fixed.
// The problem is that RelMdMaxRowCount currently cannot be used
// due to CALCITE-1048.
Integer fetchValue = ((RexLiteral) ((DrillLimitRel) currentrel).getFetch()).getValueAs(Integer.class);
return fetchValue != null && fetchValue <= 1;
} else if (currentrel.getInputs().size() == 1) {
// If the rel is not an aggregate or RelSubset, but is a single-input rel (could be Project,
// Filter, Sort etc.), check its input
@@ -234,6 +255,17 @@ public static boolean isScalarSubquery(RelNode root) {
if (agg.getGroupSet().isEmpty()) {
return true;
}
// Checks that expression in group by is a single and it is literal.
// When Calcite rewrites EXISTS sub-queries using SubQueryRemoveRule rules,
// it creates project with TRUE literal in expressions list and aggregate on top of it
// with empty call list and literal from project expression in group set.
if (agg.getAggCallList().isEmpty() && agg.getGroupSet().cardinality() == 1) {
ProjectExpressionsCollector expressionsCollector = new ProjectExpressionsCollector();
agg.accept(expressionsCollector);
List<RexNode> projectedExpressions = expressionsCollector.getProjectedExpressions();
return projectedExpressions.size() == 1
&& RexUtil.isLiteral(projectedExpressions.get(agg.getGroupSet().nth(0)), true);
}
}
return false;
}
@@ -267,4 +299,75 @@ public static boolean hasScalarSubqueryInput(RelNode left, RelNode right) {
return isScalarSubquery(left) || isScalarSubquery(right);
}

/**
* Collects expressions list from the input project.
* For the case when input rel node has single input, its input is taken.
*/
/**
 * Rel-tree visitor that gathers the expression list of the first
 * {@link Project} encountered on each walked branch.
 * Rels that can have several inputs or reorder/produce rows (joins,
 * correlates, set operations, sorts, exchanges, table functions) terminate
 * the traversal, so only a chain of single-input rels is inspected.
 */
private static class ProjectExpressionsCollector extends RelShuttleImpl {
  private final List<RexNode> expressions = new ArrayList<>();

  @Override
  public RelNode visit(RelNode other) {
    // RelShuttleImpl has no dedicated hooks for Project and RelSubset,
    // so dispatch to the private handlers by hand.
    if (other instanceof Project) {
      return visitProject((Project) other);
    }
    if (other instanceof RelSubset) {
      return visitSubset((RelSubset) other);
    }
    return super.visit(other);
  }

  @Override
  public RelNode visit(TableFunctionScan scan) {
    return scan;
  }

  @Override
  public RelNode visit(LogicalJoin join) {
    return join;
  }

  @Override
  public RelNode visit(LogicalCorrelate correlate) {
    return correlate;
  }

  @Override
  public RelNode visit(LogicalUnion union) {
    return union;
  }

  @Override
  public RelNode visit(LogicalIntersect intersect) {
    return intersect;
  }

  @Override
  public RelNode visit(LogicalMinus minus) {
    return minus;
  }

  @Override
  public RelNode visit(LogicalSort sort) {
    return sort;
  }

  @Override
  public RelNode visit(LogicalExchange exchange) {
    return exchange;
  }

  /** Records the project's expressions and stops descending this branch. */
  private RelNode visitProject(Project project) {
    for (RexNode expression : project.getProjects()) {
      expressions.add(expression);
    }
    return project;
  }

  /** Unwraps a subset to its best plan, falling back to the original rel. */
  private RelNode visitSubset(RelSubset subset) {
    RelNode best = subset.getBest();
    RelNode target = best != null ? best : subset.getOriginal();
    return target.accept(this);
  }

  public List<RexNode> getProjectedExpressions() {
    return expressions;
  }
}
}
@@ -113,6 +113,16 @@ public RuleSet getRules(OptimizerRulesContext context, Collection<StoragePlugin>
}
},

// Planner phase that applies Calcite's SubQueryRemoveRule variants to
// rewrite RexSubQuery expressions found in filters, projects and joins
// before logical planning runs.
SUBQUERY_REWRITE("Sub-queries rewrites") {
public RuleSet getRules(OptimizerRulesContext context, Collection<StoragePlugin> plugins) {
// Context and plugins are unused here: sub-query rewriting is independent
// of storage-plugin-specific rules.
return RuleSets.ofList(
RuleInstance.SUB_QUERY_FILTER_REMOVE_RULE,
RuleInstance.SUB_QUERY_PROJECT_REMOVE_RULE,
RuleInstance.SUB_QUERY_JOIN_REMOVE_RULE
);
}
},

LOGICAL_PRUNE("Logical Planning (with partition pruning)") {
public RuleSet getRules(OptimizerRulesContext context, Collection<StoragePlugin> plugins) {
return PlannerPhase.mergedRuleSets(
@@ -40,6 +40,7 @@
import org.apache.calcite.rel.rules.ProjectWindowTransposeRule;
import org.apache.calcite.rel.rules.ReduceExpressionsRule;
import org.apache.calcite.rel.rules.SortRemoveRule;
import org.apache.calcite.rel.rules.SubQueryRemoveRule;
import org.apache.calcite.rel.rules.UnionToDistinctRule;
import org.apache.drill.exec.planner.logical.DrillConditions;
import org.apache.drill.exec.planner.logical.DrillRelFactories;
@@ -130,4 +131,13 @@ public interface RuleInstance {

FilterRemoveIsNotDistinctFromRule REMOVE_IS_NOT_DISTINCT_FROM_RULE =
new FilterRemoveIsNotDistinctFromRule(DrillRelBuilder.proto(DrillRelFactories.DRILL_LOGICAL_FILTER_FACTORY));

// Instances of Calcite's SubQueryRemoveRule, each configured with Drill's
// logical rel builder so the rels they produce use Drill's factories.

/** Rewrites sub-queries appearing in Filter conditions. */
SubQueryRemoveRule SUB_QUERY_FILTER_REMOVE_RULE =
new SubQueryRemoveRule.SubQueryFilterRemoveRule(DrillRelFactories.LOGICAL_BUILDER);

/** Rewrites sub-queries appearing in Project expressions. */
SubQueryRemoveRule SUB_QUERY_PROJECT_REMOVE_RULE =
new SubQueryRemoveRule.SubQueryProjectRemoveRule(DrillRelFactories.LOGICAL_BUILDER);

/** Rewrites sub-queries appearing in Join conditions. */
SubQueryRemoveRule SUB_QUERY_JOIN_REMOVE_RULE =
new SubQueryRemoveRule.SubQueryJoinRemoveRule(DrillRelFactories.LOGICAL_BUILDER);
}

0 comments on commit 502d297

Please sign in to comment.