From c9f8621d228cca803f967ae91c277f74c6e8e748 Mon Sep 17 00:00:00 2001 From: Hsuan-Yi Chu Date: Tue, 8 Mar 2016 17:57:36 -0800 Subject: [PATCH] DRILL-4372: (continued) Add option to disable/enable function output type inference --- .../exec/expr/fn/HiveFunctionRegistry.java | 4 +- .../planner/sql/HiveUDFOperatorNotInfer.java | 44 ++++ .../exec/expr/fn/DrillFunctionRegistry.java | 89 +++++++- .../apache/drill/exec/ops/QueryContext.java | 2 +- .../logical/DrillReduceAggregatesRule.java | 211 ++++++++---------- .../planner/physical/PlannerSettings.java | 7 + .../sql/DrillAvgVarianceConvertlet.java | 14 +- .../exec/planner/sql/DrillOperatorTable.java | 86 +++++-- .../exec/planner/sql/DrillSqlAggOperator.java | 56 ++++- .../sql/DrillSqlAggOperatorNotInfer.java | 43 ++++ .../exec/planner/sql/DrillSqlOperator.java | 99 +++++++- .../planner/sql/DrillSqlOperatorNotInfer.java | 76 +++++++ .../drill/exec/planner/sql/SqlConverter.java | 70 +++++- .../exec/planner/sql/TypeInferenceUtils.java | 13 +- .../server/options/SystemOptionManager.java | 1 + .../TestFunctionsWithTypeExpoQueries.java | 188 +++++++++++++++- 16 files changed, 837 insertions(+), 166 deletions(-) create mode 100644 contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java index 9a4e2101b62..52bd05b4055 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java @@ -43,6 +43,7 @@ import org.apache.drill.exec.expr.fn.impl.hive.ObjectInspectorHelper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.HiveUDFOperator; +import org.apache.drill.exec.planner.sql.HiveUDFOperatorNotInfer; import org.apache.drill.exec.planner.sql.TypeInferenceUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; @@ -84,7 +85,8 @@ public HiveFunctionRegistry(DrillConfig config) { @Override public void register(DrillOperatorTable operatorTable) { for (String name : Sets.union(methodsGenericUDF.asMap().keySet(), methodsUDF.asMap().keySet())) { - operatorTable.add(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); + operatorTable.addDefault(name, new HiveUDFOperatorNotInfer(name.toUpperCase())); + operatorTable.addInference(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); } } diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java new file mode 100644 index 00000000000..0c718f61f0b --- /dev/null +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; + +public class HiveUDFOperatorNotInfer extends HiveUDFOperator { + public HiveUDFOperatorNotInfer(String name) { + super(name, DynamicReturnType.INSTANCE); + } + + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + RelDataTypeFactory factory = validator.getTypeFactory(); + return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + RelDataTypeFactory factory = opBinding.getTypeFactory(); + return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java index 76ec90dde5d..f6bc666f8c1 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java @@ -23,10 +23,13 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.tuple.Pair; import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor; @@ -35,9 +38,11 @@ import org.apache.drill.exec.planner.logical.DrillConstExecutor; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.DrillSqlAggOperator; +import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorNotInfer; import org.apache.drill.exec.planner.sql.DrillSqlOperator; import com.google.common.collect.ArrayListMultimap; +import org.apache.drill.exec.planner.sql.DrillSqlOperatorNotInfer; /** * Registry of Drill functions. @@ -122,6 +127,13 @@ public List getMethods(String name) { } public void register(DrillOperatorTable operatorTable) { + registerForInference(operatorTable); + registerForDefault(operatorTable); + } + + public void registerForInference(DrillOperatorTable operatorTable) { + final Map map = Maps.newHashMap(); + final Map mapAgg = Maps.newHashMap(); for (Entry> function : registeredFunctions.asMap().entrySet()) { final ArrayListMultimap, DrillFuncHolder> functions = ArrayListMultimap.create(); final ArrayListMultimap aggregateFunctions = ArrayListMultimap.create(); @@ -146,20 +158,79 @@ public void register(DrillOperatorTable operatorTable) { } } for (Entry, Collection> entry : functions.asMap().entrySet()) { - final DrillSqlOperator drillSqlOperator; final Pair range = entry.getKey(); final int max = range.getRight(); final int min = range.getLeft(); - drillSqlOperator = new DrillSqlOperator( - name, - Lists.newArrayList(entry.getValue()), - min, - max, - isDeterministic); - operatorTable.add(name, drillSqlOperator); + if(map.containsKey(name)) { + final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name); + drillSqlOperatorBuilder + .addFunctions(entry.getValue()) + .setArgumentCount(min, max) + .setDeterministic(isDeterministic); + } else { + final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = new DrillSqlOperator.DrillSqlOperatorBuilder(); + drillSqlOperatorBuilder + .setName(name) + .addFunctions(entry.getValue()) + .setArgumentCount(min, max) + .setDeterministic(isDeterministic); + + map.put(name, drillSqlOperatorBuilder); + } } for (Entry> entry : aggregateFunctions.asMap().entrySet()) { - operatorTable.add(name, new DrillSqlAggOperator(name, Lists.newArrayList(entry.getValue()), entry.getKey())); + if(mapAgg.containsKey(name)) { + final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name); + drillSqlAggOperatorBuilder + .addFunctions(entry.getValue()) + .setArgumentCount(entry.getKey(), entry.getKey()); + } else { + final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = new DrillSqlAggOperator.DrillSqlAggOperatorBuilder(); + drillSqlAggOperatorBuilder + .setName(name) + .addFunctions(entry.getValue()) + .setArgumentCount(entry.getKey(), entry.getKey()); + + mapAgg.put(name, drillSqlAggOperatorBuilder); + } + } + } + + for(final Entry entry : map.entrySet()) { + operatorTable.addInference( + entry.getKey(), + entry.getValue().build()); + } + + for(final Entry entry : mapAgg.entrySet()) { + operatorTable.addInference( + entry.getKey(), + entry.getValue().build()); + } + } + + public void registerForDefault(DrillOperatorTable operatorTable) { + SqlOperator op; + for (Entry> function : registeredFunctions.asMap().entrySet()) { + Set argCounts = Sets.newHashSet(); + String name = function.getKey().toUpperCase(); + for (DrillFuncHolder func : function.getValue()) { + if (argCounts.add(func.getParamCount())) { + if (func.isAggregating()) { + op = new DrillSqlAggOperatorNotInfer(name, func.getParamCount()); + } else { + boolean isDeterministic; + // prevent Drill from folding constant functions with types that cannot be materialized + // into literals + if (DrillConstExecutor.NON_REDUCIBLE_TYPES.contains(func.getReturnType().getMinorType())) { + isDeterministic = false; + } else { + isDeterministic = func.isDeterministic(); + } + op = new DrillSqlOperatorNotInfer(name, func.getParamCount(), func.getReturnType(), isDeterministic); + } + operatorTable.addDefault(function.getKey(), op); + } } } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java b/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java index 51a581a5369..3ce0633305e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java @@ -86,7 +86,7 @@ public QueryContext(final UserSession session, final DrillbitContext drillbitCon executionControls = new ExecutionControls(queryOptions, drillbitContext.getEndpoint()); plannerSettings = new PlannerSettings(queryOptions, getFunctionRegistry()); plannerSettings.setNumEndPoints(drillbitContext.getBits().size()); - table = new DrillOperatorTable(getFunctionRegistry()); + table = new DrillOperatorTable(getFunctionRegistry(), drillbitContext.getOptionManager()); queryContextInfo = Utilities.createQueryContextInfo(session.getDefaultSchemaName()); contextInformation = new ContextInformation(session.getCredentials(), queryContextInfo); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java index 3a2510e02ec..8975e9fe8bc 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java @@ -18,7 +18,6 @@ package org.apache.drill.exec.planner.logical; -import com.google.common.collect.ImmutableList; import java.math.BigDecimal; import java.util.ArrayList; @@ -33,7 +32,16 @@ import org.apache.calcite.rel.InvalidRelException; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.trace.CalciteTrace; +import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; +import org.apache.drill.exec.planner.sql.DrillSqlOperator; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.RelNode; import org.apache.calcite.plan.RelOptRule; @@ -51,15 +59,12 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlSumAggFunction; import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.CompositeList; import org.apache.calcite.util.ImmutableIntList; import org.apache.calcite.util.Util; -import org.apache.calcite.util.trace.CalciteTrace; -import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper; -import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; -import org.apache.drill.exec.planner.sql.DrillSqlOperator; +import com.google.common.collect.ImmutableList; +import org.apache.drill.exec.planner.sql.TypeInferenceUtils; /** * Rule to reduce aggregates to simpler forms. Currently only AVG(x) to @@ -71,13 +76,21 @@ public class DrillReduceAggregatesRule extends RelOptRule { /** * The singleton. */ - public static final DrillReduceAggregatesRule INSTANCE = new DrillReduceAggregatesRule(operand(LogicalAggregate.class, any())); public static final DrillConvertSumToSumZero INSTANCE_SUM = - new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any())); - - private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false); + new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any())); + + private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false, + new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return TypeInferenceUtils.createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + SqlTypeName.ANY, + opBinding.getOperandType(0).isNullable()); + } + }); //~ Constructors ----------------------------------------------------------- @@ -222,7 +235,6 @@ private RexNode reduceAgg( // case COUNT(x) when 0 then null else SUM0(x) end return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping); } - if (sqlAggFunction instanceof SqlAvgAggFunction) { final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) sqlAggFunction).getSubtype(); switch (subtype) { @@ -292,7 +304,8 @@ private RexNode reduceAvg( AggregateCall oldCall, List newCalls, Map aggCallMapping) { - final boolean isWrapper = useWrapper(oldCall); + final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext(); + final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled(); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); @@ -302,25 +315,12 @@ private RexNode reduceAvg( getFieldType( oldAggRel.getInput(), iAvgInput); - - final RelDataType sumType; - if(isWrapper) { - sumType = oldCall.getType(); - } else { - sumType = - typeFactory.createTypeWithNullability( - avgInputType, - avgInputType.isNullable() || nGroups == 0); - } + RelDataType sumType = + typeFactory.createTypeWithNullability( + avgInputType, + avgInputType.isNullable() || nGroups == 0); // SqlAggFunction sumAgg = new SqlSumAggFunction(sumType); - SqlAggFunction sumAgg; - if(isWrapper) { - sumAgg = new DrillCalciteSqlAggFunctionWrapper( - new SqlSumEmptyIsZeroAggFunction(), sumType); - } else { - sumAgg = new SqlSumEmptyIsZeroAggFunction(); - } - + SqlAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction(); AggregateCall sumCall = new AggregateCall( sumAgg, @@ -385,15 +385,21 @@ private RexNode reduceAvg( newCalls, aggCallMapping, ImmutableList.of(avgInputType)); - final RexNode divideRef = - rexBuilder.makeCall( - SqlStdOperatorTable.DIVIDE, - numeratorRef, - denominatorRef); - - if(isWrapper) { - return divideRef; + if(isInferenceEnabled) { + return rexBuilder.makeCall( + new DrillSqlOperator( + "divide", + 2, + true, + oldCall.getType()), + numeratorRef, + denominatorRef); } else { + final RexNode divideRef = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, + numeratorRef, + denominatorRef); return rexBuilder.makeCast( typeFactory.createSqlType(SqlTypeName.ANY), divideRef); } @@ -404,34 +410,29 @@ private RexNode reduceSum( AggregateCall oldCall, List newCalls, Map aggCallMapping) { - final boolean isWrapper = useWrapper(oldCall); + final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext(); + final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled(); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); - - final RelDataType argType; - if(isWrapper) { - argType = oldCall.getType(); - } else { - int arg = oldCall.getArgList().get(0); - argType = - getFieldType( - oldAggRel.getInput(), - arg); - } - + int arg = oldCall.getArgList().get(0); + RelDataType argType = + getFieldType( + oldAggRel.getInput(), + arg); final RelDataType sumType; final SqlAggFunction sumZeroAgg; - if(isWrapper) { + if(isInferenceEnabled) { sumType = oldCall.getType(); sumZeroAgg = new DrillCalciteSqlAggFunctionWrapper( new SqlSumEmptyIsZeroAggFunction(), sumType); } else { - sumType = typeFactory.createTypeWithNullability(argType, argType.isNullable()); + sumType = + typeFactory.createTypeWithNullability( + argType, argType.isNullable()); sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(); } - AggregateCall sumZeroCall = new AggregateCall( sumZeroAgg, @@ -488,7 +489,6 @@ private RexNode reduceStddev( List newCalls, Map aggCallMapping, List inputExprs) { - final boolean isWrapper = useWrapper(oldCall); // stddev_pop(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) @@ -500,6 +500,8 @@ private RexNode reduceStddev( // (sum(x * x) - sum(x) * sum(x) / count(x)) // / nullif(count(x) - 1, 0), // .5) + final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext(); + final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled(); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); @@ -525,26 +527,13 @@ private RexNode reduceStddev( typeFactory.createTypeWithNullability( argType, true); - final AggregateCall sumArgSquaredAggCall; - if(isWrapper) { - sumArgSquaredAggCall = - new AggregateCall( - new DrillCalciteSqlAggFunctionWrapper( - new SqlSumAggFunction(sumType), sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argSquaredOrdinal), - sumType, - null); - } else { - sumArgSquaredAggCall = - new AggregateCall( - new SqlSumAggFunction(sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argSquaredOrdinal), - sumType, - null); - } - + final AggregateCall sumArgSquaredAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argSquaredOrdinal), + sumType, + null); final RexNode sumArgSquared = rexBuilder.addAggCall( sumArgSquaredAggCall, @@ -554,26 +543,13 @@ private RexNode reduceStddev( aggCallMapping, ImmutableList.of(argType)); - final AggregateCall sumArgAggCall; - if(isWrapper) { - sumArgAggCall = - new AggregateCall( - new DrillCalciteSqlAggFunctionWrapper( - new SqlSumAggFunction(sumType), sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argOrdinal), - sumType, - null); - } else { - sumArgAggCall = - new AggregateCall( - new SqlSumAggFunction(sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argOrdinal), - sumType, - null); - } - + final AggregateCall sumArgAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argOrdinal), + sumType, + null); final RexNode sumArg = rexBuilder.addAggCall( sumArgAggCall, @@ -635,9 +611,20 @@ private RexNode reduceStddev( countEqOne, nul, countMinusOne); } + final SqlOperator divide; + if(isInferenceEnabled) { + divide = new DrillSqlOperator( + "divide", + 2, + true, + oldCall.getType()); + } else { + divide = SqlStdOperatorTable.DIVIDE; + } + final RexNode div = rexBuilder.makeCall( - SqlStdOperatorTable.DIVIDE, diff, denominator); + divide, diff, denominator); RexNode result = div; if (sqrt) { @@ -648,17 +635,17 @@ private RexNode reduceStddev( SqlStdOperatorTable.POWER, div, half); } - /* - * Currently calcite's strategy to infer the return type of aggregate functions - * is wrong because it uses the first known argument to determine output type. For - * instance if we are performing stddev on an integer column then it interprets the - * output type to be integer which is incorrect as it should be double. So based on - * this if we add cast after rewriting the aggregate we add an additional cast which - * would cause wrong results. So we simply add a cast to ANY. - */ - if(isWrapper) { + if(isInferenceEnabled) { return result; } else { + /* + * Currently calcite's strategy to infer the return type of aggregate functions + * is wrong because it uses the first known argument to determine output type. For + * instance if we are performing stddev on an integer column then it interprets the + * output type to be integer which is incorrect as it should be double. So based on + * this if we add cast after rewriting the aggregate we add an additional cast which + * would cause wrong results. So we simply add a cast to ANY. + */ return rexBuilder.makeCast( typeFactory.createSqlType(SqlTypeName.ANY), result); } @@ -704,10 +691,6 @@ private RelDataType getFieldType(RelNode relNode, int i) { return inputField.getType(); } - private boolean useWrapper(AggregateCall aggregateCall) { - return aggregateCall.getAggregation() instanceof DrillCalciteSqlWrapper; - } - private static class DrillConvertSumToSumZero extends RelOptRule { protected static final Logger tracer = CalciteTrace.getPlannerTracer(); @@ -756,11 +739,11 @@ public void onMatch(RelOptRuleCall call) { new SqlSumEmptyIsZeroAggFunction(), sumType); AggregateCall sumZeroCall = new AggregateCall( - sumZeroAgg, - oldAggregateCall.isDistinct(), - oldAggregateCall.getArgList(), - sumType, - null); + sumZeroAgg, + oldAggregateCall.isDistinct(), + oldAggregateCall.getArgList(), + sumType, + null); oldAggRel.getCluster().getRexBuilder() .addAggCall(sumZeroCall, oldAggRel.getGroupCount(), diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java index 3eb50383071..a98619ce6fa 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java @@ -81,6 +81,9 @@ public class PlannerSettings implements Context{ new RangeLongValidator("planner.identifier_max_length", 128 /* A minimum length is needed because option names are identifiers themselves */, Integer.MAX_VALUE, DEFAULT_IDENTIFIER_MAX_LENGTH); + public static final String TYPE_INFERENCE_KEY = "planner.type_inference.enable"; + public static final BooleanValidator TYPE_INFERENCE = new BooleanValidator(TYPE_INFERENCE_KEY, true); + public OptionManager options = null; public FunctionImplementationRegistry functionImplementationRegistry = null; @@ -209,6 +212,10 @@ public static long getInitialPlanningMemorySize() { return INITIAL_OFF_HEAP_ALLOCATION_IN_BYTES; } + public boolean isTypeInferenceEnabled() { + return options.getOption(TYPE_INFERENCE.getOptionName()).bool_val; + } + @Override public T unwrap(Class clazz) { if(clazz == PlannerSettings.class){ diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java index 97317be8471..068423e01c6 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java @@ -23,9 +23,12 @@ import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNumericLiteral; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlAvgAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexContext; import org.apache.calcite.sql2rel.SqlRexConvertlet; import org.apache.calcite.util.Util; @@ -40,7 +43,16 @@ public class DrillAvgVarianceConvertlet implements SqlRexConvertlet { private final SqlAvgAggFunction.Subtype subtype; - private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false); + private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false, + new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return TypeInferenceUtils.createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + SqlTypeName.ANY, + opBinding.getOperandType(0).isNullable()); + } + }); public DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype subtype) { this.subtype = subtype; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java index 7fe6020b772..de18f029872 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java @@ -24,6 +24,7 @@ import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlPrefixOperator; import org.apache.drill.common.expression.FunctionCallFactory; +import org.apache.drill.exec.ExecConstants; import org.apache.drill.exec.expr.fn.DrillFuncHolder; import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry; import org.apache.calcite.sql.SqlFunctionCategory; @@ -32,6 +33,8 @@ import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.server.options.SystemOptionManager; import java.util.List; import java.util.Map; @@ -43,24 +46,49 @@ public class DrillOperatorTable extends SqlStdOperatorTable { // private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillOperatorTable.class); private static final SqlOperatorTable inner = SqlStdOperatorTable.instance(); - private List operators = Lists.newArrayList(); + private final List operatorsCalcite = Lists.newArrayList(); + private final List operatorsDefault = Lists.newArrayList(); + private final List operatorsInferernce = Lists.newArrayList(); private final Map calciteToWrapper = Maps.newIdentityHashMap(); - private ArrayListMultimap opMap = ArrayListMultimap.create(); + + private final ArrayListMultimap opMapDefault = ArrayListMultimap.create(); + private final ArrayListMultimap opMapInferernce = ArrayListMultimap.create(); + + private final SystemOptionManager systemOptionManager; public DrillOperatorTable(FunctionImplementationRegistry registry) { + this(registry, null); + } + + public DrillOperatorTable(FunctionImplementationRegistry registry, SystemOptionManager systemOptionManager) { registry.register(this); - operators.addAll(inner.getOperatorList()); + operatorsCalcite.addAll(inner.getOperatorList()); populateWrappedCalciteOperators(); + this.systemOptionManager = systemOptionManager; + } + + public void addDefault(String name, SqlOperator op) { + operatorsDefault.add(op); + opMapDefault.put(name.toLowerCase(), op); } - public void add(String name, SqlOperator op) { - operators.add(op); - opMap.put(name.toLowerCase(), op); + public void addInference(String name, SqlOperator op) { + operatorsInferernce.add(op); + opMapInferernce.put(name.toLowerCase(), op); } @Override public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory category, SqlSyntax syntax, List operatorList) { + if(isEnableInference()) { + populateFromTypeInference(opName, category, syntax, operatorList); + } else { + populateFromDefault(opName, category, syntax, operatorList); + } + } + + private void populateFromTypeInference(SqlIdentifier opName, SqlFunctionCategory category, + SqlSyntax syntax, List operatorList) { final List calciteOperatorList = Lists.newArrayList(); inner.lookupOperatorOverloads(opName, category, syntax, calciteOperatorList); if(!calciteOperatorList.isEmpty()) { @@ -74,7 +102,7 @@ public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory ca } else { // if no function is found, check in Drill UDFs if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { - List drillOps = opMap.get(opName.getSimple().toLowerCase()); + List drillOps = opMapInferernce.get(opName.getSimple().toLowerCase()); if (drillOps != null && !drillOps.isEmpty()) { operatorList.addAll(drillOps); } @@ -82,14 +110,37 @@ public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory ca } } + private void populateFromDefault(SqlIdentifier opName, SqlFunctionCategory category, + SqlSyntax syntax, List operatorList) { + inner.lookupOperatorOverloads(opName, category, syntax, operatorList); + if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { + List drillOps = opMapDefault.get(opName.getSimple().toLowerCase()); + if (drillOps != null) { + operatorList.addAll(drillOps); + } + } + } + @Override public List getOperatorList() { - return operators; + final List sqlOperators = Lists.newArrayList(); + sqlOperators.addAll(operatorsCalcite); + if(isEnableInference()) { + sqlOperators.addAll(operatorsInferernce); + } else { + sqlOperators.addAll(operatorsDefault); + } + + return sqlOperators; } // Get the list of SqlOperator's with the given name. public List getSqlOperator(String name) { - return opMap.get(name.toLowerCase()); + if(isEnableInference()) { + return opMapInferernce.get(name.toLowerCase()); + } else { + return opMapDefault.get(name.toLowerCase()); + } } private void populateWrappedCalciteOperators() { @@ -97,14 +148,14 @@ private void populateWrappedCalciteOperators() { final SqlOperator wrapper; if(calciteOperator instanceof SqlAggFunction) { wrapper = new DrillCalciteSqlAggFunctionWrapper((SqlAggFunction) calciteOperator, - getFunctionList(calciteOperator.getName())); + getFunctionListWithInference(calciteOperator.getName())); } else if(calciteOperator instanceof SqlFunction) { wrapper = new DrillCalciteSqlFunctionWrapper((SqlFunction) calciteOperator, - getFunctionList(calciteOperator.getName())); + getFunctionListWithInference(calciteOperator.getName())); } else { final String drillOpName = FunctionCallFactory.replaceOpWithFuncName(calciteOperator.getName()); - final List drillFuncHolders = getFunctionList(drillOpName); - if(drillFuncHolders.isEmpty() || calciteOperator == SqlStdOperatorTable.UNARY_MINUS) { + final List drillFuncHolders = getFunctionListWithInference(drillOpName); + if(drillFuncHolders.isEmpty() || calciteOperator == SqlStdOperatorTable.UNARY_MINUS || calciteOperator == SqlStdOperatorTable.UNARY_PLUS) { continue; } @@ -114,9 +165,9 @@ private void populateWrappedCalciteOperators() { } } - private List getFunctionList(String name) { + private List getFunctionListWithInference(String name) { final List functions = Lists.newArrayList(); - for(SqlOperator sqlOperator : opMap.get(name.toLowerCase())) { + for(SqlOperator sqlOperator : opMapInferernce.get(name.toLowerCase())) { if(sqlOperator instanceof DrillSqlOperator) { final List list = ((DrillSqlOperator) sqlOperator).getFunctions(); if(list != null) { @@ -133,4 +184,9 @@ private List getFunctionList(String name) { } return functions; } + + private boolean isEnableInference() { + return systemOptionManager != null + && systemOptionManager.getOption(PlannerSettings.TYPE_INFERENCE); + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java index 81c744c2fce..044f5b05ee5 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java @@ -17,36 +17,76 @@ */ package org.apache.drill.exec.planner.sql; -import org.apache.calcite.rel.type.RelDataType; +import com.google.common.collect.Lists; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.drill.exec.expr.fn.DrillFuncHolder; -import java.util.ArrayList; +import java.util.Collection; import java.util.List; public class DrillSqlAggOperator extends SqlAggFunction { // private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillSqlAggOperator.class); private final List functions; - public DrillSqlAggOperator(String name, List functions, int argCount) { + protected DrillSqlAggOperator(String name, List functions, int argCountMin, int argCountMax, SqlReturnTypeInference sqlReturnTypeInference) { super(name, new SqlIdentifier(name, SqlParserPos.ZERO), SqlKind.OTHER_FUNCTION, - TypeInferenceUtils.getDrillSqlReturnTypeInference( - name, - functions), + sqlReturnTypeInference, null, - Checker.getChecker(argCount, argCount), + Checker.getChecker(argCountMin, argCountMax), SqlFunctionCategory.USER_DEFINED_FUNCTION); this.functions = functions; } + private DrillSqlAggOperator(String name, List functions, int argCountMin, int argCountMax) { + this(name, + functions, + argCountMin, + argCountMax, + TypeInferenceUtils.getDrillSqlReturnTypeInference( + name, + functions)); + } + public List getFunctions() { return functions; } + + public static class DrillSqlAggOperatorBuilder { + private String name; + private final List functions = Lists.newArrayList(); + private int argCountMin = Integer.MAX_VALUE; + private int argCountMax = Integer.MIN_VALUE; + private boolean isDeterministic = true; + + public DrillSqlAggOperatorBuilder setName(final String name) { + this.name = name; + return this; + } + + public DrillSqlAggOperatorBuilder addFunctions(Collection functions) { + this.functions.addAll(functions); + return this; + } + + public DrillSqlAggOperatorBuilder setArgumentCount(final int argCountMin, final int argCountMax) { + this.argCountMin = Math.min(this.argCountMin, argCountMin); + this.argCountMax = Math.max(this.argCountMax, argCountMax); + return this; + } + + public DrillSqlAggOperator build() { + return new DrillSqlAggOperator( + name, + functions, + argCountMin, + argCountMax); + } + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java new file mode 100644 index 00000000000..592c23edce7 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.ArrayList; + +public class DrillSqlAggOperatorNotInfer extends DrillSqlAggOperator { + public DrillSqlAggOperatorNotInfer(String name, int argCount) { + super(name, new ArrayList(), argCount, argCount, DynamicReturnType.INSTANCE); + } + + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + return getAny(validator.getTypeFactory()); + } + + private RelDataType getAny(RelDataTypeFactory factory){ + return factory.createSqlType(SqlTypeName.ANY); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java index 0873c8df3ea..1bb62f3963a 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java @@ -19,12 +19,17 @@ package org.apache.drill.exec.planner.sql; import java.util.ArrayList; +import java.util.Collection; import java.util.List; +import com.google.common.collect.Lists; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.drill.exec.expr.fn.DrillFuncHolder; public class DrillSqlOperator extends SqlFunction { @@ -39,15 +44,54 @@ public class DrillSqlOperator extends SqlFunction { * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. */ @Deprecated - public DrillSqlOperator(String name, int argCount, boolean isDeterministic) { - this(name, new ArrayList(), argCount, argCount, isDeterministic); + public DrillSqlOperator(final String name, final int argCount, final boolean isDeterministic) { + this(name, + argCount, + isDeterministic, + DynamicReturnType.INSTANCE); } - public DrillSqlOperator(String name, List functions, int argCountMin, int argCountMax, boolean isDeterministic) { + /** + * This constructor exists for the legacy reason. + * + * It is because Drill cannot access to DrillOperatorTable at the place where this constructor is being called. + * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. + */ + @Deprecated + public DrillSqlOperator(final String name, final int argCount, final boolean isDeterministic, + final SqlReturnTypeInference sqlReturnTypeInference) { + this(name, + new ArrayList(), + argCount, + argCount, + isDeterministic, + sqlReturnTypeInference); + } + + /** + * This constructor exists for the legacy reason. + * + * It is because Drill cannot access to DrillOperatorTable at the place where this constructor is being called. + * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. + */ + @Deprecated + public DrillSqlOperator(final String name, final int argCount, final boolean isDeterministic, final RelDataType type) { + this(name, + new ArrayList(), + argCount, + argCount, + isDeterministic, new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return type; + } + }); + } + + protected DrillSqlOperator(String name, List functions, int argCountMin, int argCountMax, boolean isDeterministic, + SqlReturnTypeInference sqlReturnTypeInference) { super(new SqlIdentifier(name, SqlParserPos.ZERO), - TypeInferenceUtils.getDrillSqlReturnTypeInference( - name, - functions), + sqlReturnTypeInference, null, Checker.getChecker(argCountMin, argCountMax), null, @@ -64,4 +108,47 @@ public boolean isDeterministic() { public List getFunctions() { return functions; } + + public static class DrillSqlOperatorBuilder { + private String name; + private final List functions = Lists.newArrayList(); + private int argCountMin = Integer.MAX_VALUE; + private int argCountMax = Integer.MIN_VALUE; + private boolean isDeterministic = true; + + public DrillSqlOperatorBuilder setName(final String name) { + this.name = name; + return this; + } + + public DrillSqlOperatorBuilder addFunctions(Collection functions) { + this.functions.addAll(functions); + return this; + } + + public DrillSqlOperatorBuilder setArgumentCount(final int argCountMin, final int argCountMax) { + this.argCountMin = Math.min(this.argCountMin, argCountMin); + this.argCountMax = Math.max(this.argCountMax, argCountMax); + return this; + } + + public DrillSqlOperatorBuilder setDeterministic(boolean isDeterministic) { + if(this.isDeterministic) { + this.isDeterministic = isDeterministic; + } + return this; + } + + public DrillSqlOperator build() { + return new DrillSqlOperator( + name, + functions, + argCountMin, + argCountMax, + isDeterministic, + TypeInferenceUtils.getDrillSqlReturnTypeInference( + name, + functions)); + } + } } \ No newline at end of file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java new file mode 100644 index 00000000000..a7394bd0a81 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import com.google.common.base.Preconditions; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.common.types.TypeProtos; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.ArrayList; + +public class DrillSqlOperatorNotInfer extends DrillSqlOperator { + private static final TypeProtos.MajorType NONE = TypeProtos.MajorType.getDefaultInstance(); + private final TypeProtos.MajorType returnType; + + public DrillSqlOperatorNotInfer(String name, int argCount, TypeProtos.MajorType returnType, boolean isDeterminisitic) { + super(name, + new ArrayList< DrillFuncHolder>(), + argCount, + argCount, + isDeterminisitic, + DynamicReturnType.INSTANCE); + this.returnType = Preconditions.checkNotNull(returnType); + } + + protected RelDataType getReturnDataType(final RelDataTypeFactory factory) { + if (TypeProtos.MinorType.BIT.equals(returnType.getMinorType())) { + return factory.createSqlType(SqlTypeName.BOOLEAN); + } + return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); + } + + private RelDataType getNullableReturnDataType(final RelDataTypeFactory factory) { + return factory.createTypeWithNullability(getReturnDataType(factory), true); + } + + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + if (NONE.equals(returnType)) { + return validator.getTypeFactory().createSqlType(SqlTypeName.ANY); + } + /* + * We return a nullable output type both in validation phase and in + * Sql to Rel phase. We don't know the type of the output until runtime + * hence have to choose the least restrictive type to avoid any wrong + * results. + */ + return getNullableReturnDataType(validator.getTypeFactory()); + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return getNullableReturnDataType(opBinding.getTypeFactory()); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java index 2e0afeac438..fc63276661e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java @@ -19,7 +19,9 @@ import java.util.Arrays; import java.util.List; +import java.util.Set; +import com.google.common.collect.Sets; import org.apache.calcite.adapter.java.JavaTypeFactory; import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.avatica.util.Quoting; @@ -37,18 +39,23 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystemImpl; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserImplFactory; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.util.ChainedSqlOperatorTable; +import org.apache.calcite.sql.validate.AggregatingSelectScope; import org.apache.calcite.sql.validate.SqlConformance; import org.apache.calcite.sql.validate.SqlValidatorCatalogReader; import org.apache.calcite.sql.validate.SqlValidatorImpl; +import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.calcite.sql2rel.RelDecorrelator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.drill.common.exceptions.UserException; @@ -183,10 +190,40 @@ public SchemaPlus getDefaultSchema() { } private class DrillValidator extends SqlValidatorImpl { + private final Set identitySet = Sets.newIdentityHashSet(); + protected DrillValidator(SqlOperatorTable opTab, SqlValidatorCatalogReader catalogReader, RelDataTypeFactory typeFactory, SqlConformance conformance) { super(opTab, catalogReader, typeFactory, conformance); } + + @Override + public SqlValidatorScope getSelectScope(final SqlSelect select) { + final SqlValidatorScope sqlValidatorScope = super.getSelectScope(select); + if(needsValidation(sqlValidatorScope)) { + final AggregatingSelectScope aggregatingSelectScope = ((AggregatingSelectScope) sqlValidatorScope); + for(SqlNode sqlNode : aggregatingSelectScope.groupExprList) { + if(sqlNode instanceof SqlCall) { + final SqlCall sqlCall = (SqlCall) sqlNode; + sqlCall.getOperator().deriveType(this, sqlValidatorScope, sqlCall); + } + } + identitySet.add(sqlValidatorScope); + } + return sqlValidatorScope; + } + + // Due to the deep-copy of AggregatingSelectScope in the following two commits in the Forked Drill-Calcite: + // 1. [StarColumn] Reverse one change in CALCITE-356, which regresses AggChecker logic, after * query in schema-less table is added. + // 2. [StarColumn] When group-by a column, projecting on a star which cannot be expanded at planning time, + // use ITEM operator to wrap this column + private boolean needsValidation(final SqlValidatorScope sqlValidatorScope) { + if(sqlValidatorScope instanceof AggregatingSelectScope) { + return !identitySet.contains(sqlValidatorScope); + } else { + return false; + } + } } private static class DrillTypeSystem extends RelDataTypeSystemImpl { @@ -218,7 +255,7 @@ public int getMaxNumericPrecision() { public RelNode toRel( final SqlNode validatedNode) { - final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RexBuilder rexBuilder = new DrillRexBuilder(typeFactory); if (planner == null) { planner = new VolcanoPlanner(costFactory, settings); planner.setExecutor(new DrillConstExecutor(functions, util, settings)); @@ -364,4 +401,35 @@ private static SchemaPlus rootSchema(SchemaPlus schema) { } } + private static class DrillRexBuilder extends RexBuilder { + private DrillRexBuilder(RelDataTypeFactory typeFactory) { + super(typeFactory); + } + + @Override + public RexNode ensureType( + RelDataType type, + RexNode node, + boolean matchNullability) { + RelDataType targetType = type; + if (matchNullability) { + targetType = matchNullability(type, node); + } + if (targetType.getSqlTypeName() == SqlTypeName.ANY) { + return node; + } + if (!node.getType().equals(targetType)) { + if(!targetType.isStruct()) { + final RelDataType anyType = TypeInferenceUtils.createCalciteTypeWithNullability( + getTypeFactory(), + SqlTypeName.ANY, + targetType.isNullable()); + return makeCast(anyType, node); + } else { + return makeCast(targetType, node); + } + } + return node; + } + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java index 8914b1133d4..9af6fa302d0 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java @@ -227,7 +227,9 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { final DrillFuncHolder func = resolveDrillFuncHolder(opBinding, functions); final RelDataType returnType = getReturnType(opBinding, func); - return returnType; + return returnType.getSqlTypeName() == SqlTypeName.VARBINARY + ? createCalciteTypeWithNullability(factory, SqlTypeName.ANY, returnType.isNullable()) + : returnType; } private static RelDataType getReturnType(final SqlOperatorBinding opBinding, final DrillFuncHolder func) { @@ -512,19 +514,18 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { RelDataType ret = factory.createTypeWithNullability( opBinding.getOperandType(1), isNullable); - if (opBinding instanceof SqlCallBinding) { SqlCallBinding callBinding = (SqlCallBinding) opBinding; SqlNode operand0 = callBinding.operand(0); // dynamic parameters and null constants need their types assigned // to them using the type they are casted to. - if (((operand0 instanceof SqlLiteral) - && (((SqlLiteral) operand0).getValue() == null)) + if(((operand0 instanceof SqlLiteral) + && (((SqlLiteral) operand0).getValue() == null)) || (operand0 instanceof SqlDynamicParam)) { callBinding.getValidator().setValidatedNodeType( - operand0, - ret); + operand0, + ret); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java b/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java index 1e54e5c02ce..cbc5c095e00 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java @@ -83,6 +83,7 @@ public class SystemOptionManager extends BaseOptionManager implements AutoClosea PlannerSettings.HEP_OPT, PlannerSettings.PLANNER_MEMORY_LIMIT, PlannerSettings.HEP_PARTITION_PRUNING, + PlannerSettings.TYPE_INFERENCE, ExecConstants.CAST_TO_NULLABLE_NUMERIC_OPTION, ExecConstants.OUTPUT_FORMAT_VALIDATOR, ExecConstants.PARQUET_BLOCK_SIZE_VALIDATOR, diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java index 81d093c88b4..ad9a2053172 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java @@ -29,7 +29,7 @@ public class TestFunctionsWithTypeExpoQueries extends BaseTestQuery { @Test public void testConcatWithMoreThanTwoArgs() throws Exception { - final String query = "select concat(r_name, r_name, r_name) as col \n" + + final String query = "select concat(r_name, r_name, r_name, 'f') as col \n" + "from cp.`tpch/region.parquet` limit 0"; List> expectedSchema = Lists.newArrayList(); @@ -58,7 +58,6 @@ public void testRow_NumberInView() throws Exception { " over(order by position_id) as rnum " + " from cp.`employee.json`)"; - final String view2 = "create view TestFunctionsWithTypeExpoQueries_testViewShield2 as \n" + "select row_number() over(order by position_id) as rnum, " + @@ -68,7 +67,6 @@ public void testRow_NumberInView() throws Exception { test(view1); test(view2); - testBuilder() .sqlQuery("select * from TestFunctionsWithTypeExpoQueries_testViewShield1") .ordered() @@ -113,6 +111,38 @@ public void testLRBTrimOneArg() throws Exception { .run(); } + @Test + public void testTrim() throws Exception { + final String query1 = "SELECT trim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query2 = "SELECT trim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query3 = "SELECT trim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.VARCHAR) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query1) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query2) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query3) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + @Test public void testTrimOneArg() throws Exception { final String query1 = "SELECT trim(leading 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; @@ -294,7 +324,7 @@ public void testSumRequiredType() throws Exception { } @Test - public void testSQRT() throws Exception { + public void testSQRTDecimalLiteral() throws Exception { final String query = "SELECT sqrt(5.1) as col \n" + "from cp.`tpch/nation.parquet` \n" + "limit 0"; @@ -313,6 +343,26 @@ public void testSQRT() throws Exception { .run(); } + @Test + public void testSQRTIntegerLiteral() throws Exception { + final String query = "SELECT sqrt(4) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + @Test public void testTimestampDiff() throws Exception { final String query = "select to_timestamp('2014-02-13 00:30:30','YYYY-MM-dd HH:mm:ss') - to_timestamp('2014-02-13 00:30:30','YYYY-MM-dd HH:mm:ss') as col \n" + @@ -333,6 +383,27 @@ public void testTimestampDiff() throws Exception { .run(); } + @Test + public void testEqualBetweenIntervalAndTimestampDiff() throws Exception { + final String query = "select to_timestamp('2016-11-02 10:00:00','YYYY-MM-dd HH:mm:ss') + interval '10-11' year to month as col \n" + + "from cp.`tpch/region.parquet` \n" + + "where (to_timestamp('2016-11-02 10:00:00','YYYY-MM-dd HH:mm:ss') - to_timestamp('2016-01-01 10:00:00','YYYY-MM-dd HH:mm:ss') < interval '5 10:00:00' day to second) \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.TIMESTAMP) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + @Test public void testAvgAndSUM() throws Exception { final String query = "SELECT AVG(cast(r_regionkey as float)) AS `col1`, \n" + @@ -368,4 +439,113 @@ public void testAvgAndSUM() throws Exception { .build() .run(); } + + @Test + public void testAvgCountStar() throws Exception { + final String query = "select avg(distinct cast(r_regionkey as bigint)) + avg(cast(r_regionkey as integer)) as col1, \n" + + "sum(distinct cast(r_regionkey as bigint)) + 100 as col2, count(*) as col3 \n" + + "from cp.`tpch/region.parquet` alltypes_v \n" + + "where cast(r_regionkey as bigint) = 100000000000000000 \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test // explain plan including all attributes for + public void testUDFInGroupBy() throws Exception { + final String query = "select count(*) as col1, substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2) as col2, \n" + + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) as col3 \n" + + "from cp.`tpch/region.parquet` t1 \n" + + "left outer join cp.`tpch/nation.parquet` t2 on cast(t1.r_regionkey as Integer) = cast(t2.n_nationkey as Integer) \n" + + "left outer join cp.`employee.json` t3 on cast(t1.r_regionkey as Integer) = cast(t3.employee_id as Integer) \n" + + "group by substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2), \n" + + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) \n" + + "order by substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2),\n" + + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.VARCHAR) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testWindowSumAvg() throws Exception { + final String query = "with query as ( \n" + + "select sum(cast(employee_id as integer)) over w as col1, cast(avg(cast(employee_id as bigint)) over w as double precision) as col2, count(*) over w as col3 \n" + + "from cp.`tpch/region.parquet` \n" + + "window w as (partition by cast(full_name as varchar(10)) order by cast(full_name as varchar(10)) nulls first)) \n" + + "select * \n" + + "from query \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } }