From c0293354ec79b42ff27ce4ad2113a2ff52a934bd Mon Sep 17 00:00:00 2001 From: Hsuan-Yi Chu Date: Thu, 3 Mar 2016 22:38:04 -0800 Subject: [PATCH 1/5] DRILL-4372: Expose the functions' return type to Drill - Drill-Calcite version update: This commit needs to have Calcite's patch (CALCITE-1062) to plug in the customized SqlOperator. - FunctionTemplate Add FunctionArgumentNumber annotation. This annotation element tells if the number of argument(s) is fixed or arbitrary (e.g., String concatenation function). Due to this modification, there are some minor changes in DrillFuncHolder, DrillFunctionRegistry and FunctionAttributes. - Checker Add a new Checker (which Calcite uses to validate the legitimacy of the number of argument(s) for a function) to allow functions with arbitrary arguments to pass Calcite's validation. - Type conversion between Drill and Calcite DrillConstExecutor is given a static method getDrillTypeFromCalcite() to convert Calcite types to Drill's. - Extract function's return type inference Unlike other functions, the Extract function's return type can be determined solely based on the first argument. Logic is added to allow this inference to happen. - DrillCalcite wrapper: From the aspects of return type inference and argument type checks, Calcite's mechanism is very different from Drill's. In addition, currently, there is no straightforward way for Drill to plug in customized mechanisms to Calcite. Thus, wrappers are provided to serve the objective. Except for the mechanisms of type inference and argument type checks, these wrappers just forward any method calls to the wrapped SqlOperator, SqlFunction or SqlAggFunction to respond. An interface DrillCalciteSqlWrapper is also added for the callers of the three wrappers to get the wrapped objects more easily. Due to these wrappers, UnsupportedOperatorsVisitor is modified in a minor manner. 
- Calcite's SqlOperator, SqlFunction or SqlAggFunction are wrapped in DrillOperatorTable Instead of returning Calcite's native SqlOperator, SqlFunction or SqlAggFunction, return the wrapped ones to ensure customized behaviors can be adopted. - Type inference mechanism This mechanism is used across all SqlOperator, SqlFunction or SqlAggFunction. Thus, it is factored out as its own method in TypeInferenceUtils - Upgrade Drill-Calcite Bump version number to 1.4.0-drill-test-r16 - Implement two-argument version of lpad, rpad - Implement one-argument version of ltrim, rtrim, btrim --- .../drill/exec/expr/fn/DrillFuncHolder.java | 26 +- .../exec/expr/fn/DrillFunctionRegistry.java | 107 ++- .../fn/FunctionImplementationRegistry.java | 6 + .../exec/expr/fn/impl/StringFunctions.java | 225 ++++++ .../drill/exec/planner/PlannerPhase.java | 10 + .../planner/logical/DrillConstExecutor.java | 81 +-- .../exec/planner/logical/DrillOptiq.java | 12 - .../logical/DrillReduceAggregatesRule.java | 250 +++++-- .../planner/logical/PreProcessLogicalRel.java | 75 +- .../visitor/InsertLocalExchangeVisitor.java | 2 +- .../drill/exec/planner/sql/Checker.java | 35 +- .../DrillCalciteSqlAggFunctionWrapper.java | 162 +++++ .../sql/DrillCalciteSqlFunctionWrapper.java | 147 ++++ .../sql/DrillCalciteSqlOperatorWrapper.java | 140 ++++ .../planner/sql/DrillCalciteSqlWrapper.java | 33 + .../planner/sql/DrillConvertletTable.java | 14 +- .../planner/sql/DrillExtractConvertlet.java | 8 +- .../exec/planner/sql/DrillOperatorTable.java | 92 ++- .../exec/planner/sql/DrillSqlAggOperator.java | 55 +- .../exec/planner/sql/DrillSqlOperator.java | 81 +-- .../exec/planner/sql/TypeInferenceUtils.java | 649 ++++++++++++++++++ .../sql/handlers/CreateTableHandler.java | 2 +- .../sql/handlers/DefaultSqlHandler.java | 10 +- .../parser/UnsupportedOperatorsVisitor.java | 21 +- .../resolver/DefaultFunctionResolver.java | 13 +- .../exec/resolver/ExactFunctionResolver.java | 10 +- .../drill/exec/resolver/FunctionResolver.java | 17 
+- .../resolver/FunctionResolverFactory.java | 3 - .../drill/exec/resolver/TypeCastRules.java | 18 +- .../drill/TestDisabledFunctionality.java | 10 - .../TestFunctionsWithTypeExpoQueries.java | 281 +++++++- .../expr/fn/impl/TestStringFunctions.java | 85 +++ .../testConcatWithMoreThanTwoArgs.tsv | 5 + .../typeExposure/metadata_caching/a.parquet | Bin 0 -> 439 bytes .../typeExposure/metadata_caching/b.parquet | Bin 0 -> 474 bytes .../expression/FunctionCallFactory.java | 2 +- .../MajorTypeInLogicalExpression.java | 63 ++ pom.xml | 2 +- 38 files changed, 2424 insertions(+), 328 deletions(-) create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlAggFunctionWrapper.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlFunctionWrapper.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlWrapper.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java create mode 100644 exec/java-exec/src/test/resources/testframework/testFunctionsWithTypeExpoQueries/testConcatWithMoreThanTwoArgs.tsv create mode 100644 exec/java-exec/src/test/resources/typeExposure/metadata_caching/a.parquet create mode 100644 exec/java-exec/src/test/resources/typeExposure/metadata_caching/b.parquet create mode 100644 logical/src/main/java/org/apache/drill/common/expression/MajorTypeInLogicalExpression.java diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java index a9cdbc7ea83..869a4acb476 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFuncHolder.java @@ -19,10 +19,8 
@@ import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.Set; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.exceptions.UserException; @@ -264,27 +262,31 @@ public boolean isFieldReader(int i) { return this.parameters[i].isFieldReader; } - public MajorType getReturnType(List args) { + public MajorType getReturnType(final List logicalExpressions) { if (returnValue.type.getMinorType() == MinorType.UNION) { - Set subTypes = Sets.newHashSet(); - for (ValueReference ref : parameters) { + final Set subTypes = Sets.newHashSet(); + for(final ValueReference ref : parameters) { subTypes.add(ref.getType().getMinorType()); } - MajorType.Builder builder = MajorType.newBuilder().setMinorType(MinorType.UNION).setMode(DataMode.OPTIONAL); - for (MinorType subType : subTypes) { + + final MajorType.Builder builder = MajorType.newBuilder() + .setMinorType(MinorType.UNION) + .setMode(DataMode.OPTIONAL); + + for(final MinorType subType : subTypes) { builder.addSubType(subType); } return builder.build(); } - if (nullHandling == NullHandling.NULL_IF_NULL) { + + if(nullHandling == NullHandling.NULL_IF_NULL) { // if any one of the input types is nullable, then return nullable return type - for (LogicalExpression e : args) { - if (e.getMajorType().getMode() == TypeProtos.DataMode.OPTIONAL) { + for(final LogicalExpression logicalExpression : logicalExpressions) { + if(logicalExpression.getMajorType().getMode() == TypeProtos.DataMode.OPTIONAL) { return Types.optional(returnValue.type.getMinorType()); } } } - return returnValue.type; } @@ -405,7 +407,6 @@ public Class getType() { public String getName() { return name; } - } public boolean checkPrecisionRange() { @@ -419,5 +420,4 @@ public MajorType getReturnType() { public ValueReference getReturnValue() { return returnValue; } - } diff --git 
a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java index 05439b36ec3..76ec90dde5d 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java @@ -17,40 +17,61 @@ */ package org.apache.drill.exec.expr.fn; -import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Map.Entry; -import java.util.Set; -import org.apache.calcite.sql.SqlOperator; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.commons.lang3.tuple.Pair; import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor; import org.apache.drill.common.scanner.persistence.ScanResult; -import org.apache.drill.exec.expr.DrillFunc; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.exec.planner.logical.DrillConstExecutor; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.DrillSqlAggOperator; import org.apache.drill.exec.planner.sql.DrillSqlOperator; import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.Sets; +/** + * Registry of Drill functions. 
+ */ public class DrillFunctionRegistry { - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillFunctionRegistry.class); + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillFunctionRegistry.class); + + // key: function name (lowercase) value: list of functions with that name + private final ArrayListMultimap registeredFunctions = ArrayListMultimap.create(); - private ArrayListMultimap methods = ArrayListMultimap.create(); + private static final ImmutableMap> drillFuncToRange = ImmutableMap.> builder() + // CONCAT is allowed to take [1, infinity) number of arguments. + // Currently, this flexibility is offered by DrillOptiq to rewrite it as + // a nested structure + .put("CONCAT", Pair.of(1, Integer.MAX_VALUE)) - /* Hash map to prevent registering functions with exactly matching signatures - * key: Function Name + Input's Major Type - * Value: Class name where function is implemented - */ - private HashMap functionSignatureMap = new HashMap<>(); + // When LENGTH is given two arguments, this function relies on DrillOptiq to rewrite it as + // another function based on the second argument (encodingType) + .put("LENGTH", Pair.of(1, 2)) + + // Dummy functions + .put("CONVERT_TO", Pair.of(2, 2)) + .put("CONVERT_FROM", Pair.of(2, 2)) + .put("FLATTEN", Pair.of(1, 1)).build(); public DrillFunctionRegistry(ScanResult classpathScan) { FunctionConverter converter = new FunctionConverter(); List providerClasses = classpathScan.getAnnotatedClasses(); + + // Hash map to prevent registering functions with exactly matching signatures + // key: Function Name + Input's Major Type + // value: Class name where function is implemented + // + final Map functionSignatureMap = new HashMap<>(); for (AnnotatedClassDescriptor func : providerClasses) { DrillFuncHolder holder = converter.getHolder(func); if (holder != null) { @@ -64,7 +85,7 @@ public DrillFunctionRegistry(ScanResult classpathScan) { } for (String name : names) { 
String functionName = name.toLowerCase(); - methods.put(functionName, holder); + registeredFunctions.put(functionName, holder); String functionSignature = functionName + functionInput; String existingImplementation; if ((existingImplementation = functionSignatureMap.get(functionSignature)) != null) { @@ -84,7 +105,7 @@ public DrillFunctionRegistry(ScanResult classpathScan) { } if (logger.isTraceEnabled()) { StringBuilder allFunctions = new StringBuilder(); - for (DrillFuncHolder method: methods.values()) { + for (DrillFuncHolder method: registeredFunctions.values()) { allFunctions.append(method.toString()).append("\n"); } logger.trace("Registered functions: [\n{}]", allFunctions); @@ -92,38 +113,54 @@ public DrillFunctionRegistry(ScanResult classpathScan) { } public int size(){ - return methods.size(); + return registeredFunctions.size(); } /** Returns functions with given name. Function name is case insensitive. */ public List getMethods(String name) { - return this.methods.get(name.toLowerCase()); + return this.registeredFunctions.get(name.toLowerCase()); } public void register(DrillOperatorTable operatorTable) { - SqlOperator op; - for (Entry> function : methods.asMap().entrySet()) { - Set argCounts = Sets.newHashSet(); - String name = function.getKey().toUpperCase(); + for (Entry> function : registeredFunctions.asMap().entrySet()) { + final ArrayListMultimap, DrillFuncHolder> functions = ArrayListMultimap.create(); + final ArrayListMultimap aggregateFunctions = ArrayListMultimap.create(); + final String name = function.getKey().toUpperCase(); + boolean isDeterministic = true; for (DrillFuncHolder func : function.getValue()) { - if (argCounts.add(func.getParamCount())) { - if (func.isAggregating()) { - op = new DrillSqlAggOperator(name, func.getParamCount()); + final int paramCount = func.getParamCount(); + if(func.isAggregating()) { + aggregateFunctions.put(paramCount, func); + } else { + final Pair argNumberRange; + if(drillFuncToRange.containsKey(name)) { + 
argNumberRange = drillFuncToRange.get(name); } else { - boolean isDeterministic; - // prevent Drill from folding constant functions with types that cannot be materialized - // into literals - if (DrillConstExecutor.NON_REDUCIBLE_TYPES.contains(func.getReturnType().getMinorType())) { - isDeterministic = false; - } else { - isDeterministic = func.isDeterministic(); - } - op = new DrillSqlOperator(name, func.getParamCount(), func.getReturnType(), isDeterministic); + argNumberRange = Pair.of(func.getParamCount(), func.getParamCount()); } - operatorTable.add(function.getKey(), op); + functions.put(argNumberRange, func); } + + if(!func.isDeterministic()) { + isDeterministic = false; + } + } + for (Entry, Collection> entry : functions.asMap().entrySet()) { + final DrillSqlOperator drillSqlOperator; + final Pair range = entry.getKey(); + final int max = range.getRight(); + final int min = range.getLeft(); + drillSqlOperator = new DrillSqlOperator( + name, + Lists.newArrayList(entry.getValue()), + min, + max, + isDeterministic); + operatorTable.add(name, drillSqlOperator); + } + for (Entry> entry : aggregateFunctions.asMap().entrySet()) { + operatorTable.add(name, new DrillSqlAggOperator(name, Lists.newArrayList(entry.getValue()), entry.getKey())); } } } - } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionImplementationRegistry.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionImplementationRegistry.java index 5985f0eb6eb..2feac1a405a 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionImplementationRegistry.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionImplementationRegistry.java @@ -19,6 +19,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -40,6 +41,11 @@ import com.google.common.base.Stopwatch; 
import com.google.common.collect.Lists; +/** + * This class offers the registry for functions. Notably, in addition to Drill its functions + * (in {@link DrillFunctionRegistry}), other PluggableFunctionRegistry (e.g., {@link org.apache.drill.exec.expr.fn.HiveFunctionRegistry}) + * is also registered in this class + */ public class FunctionImplementationRegistry implements FunctionLookupContext { static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FunctionImplementationRegistry.class); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java index f5aeaf68b2d..112f5fdc1a5 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java @@ -770,6 +770,65 @@ public void eval() { } // end of eval } + /* + * Fill up the string to length 'length' by prepending the character ' ' in the beginning of 'text'. + * If the string is already longer than length, then it is truncated (on the right). + */ + @FunctionTemplate(name = "lpad", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL) + public static class LpadTwoArg implements DrillSimpleFunc { + @Param VarCharHolder text; + @Param BigIntHolder length; + @Inject DrillBuf buffer; + + @Output VarCharHolder out; + @Workspace byte spaceInByte; + + @Override + public void setup() { + spaceInByte = 32; + } + + @Override + public void eval() { + final long theLength = length.value; + final int lengthNeeded = (int) (theLength <= 0 ? 0 : theLength * 2); + buffer = buffer.reallocIfNeeded(lengthNeeded); + //get the char length of text. + int textCharCount = org.apache.drill.exec.expr.fn.impl.StringFunctionUtil.getUTF8CharLength(text.buffer, text.start, text.end); + + if (theLength <= 0) { + //case 1: target length is <=0, then return an empty string. 
+ out.buffer = buffer; + out.start = out.end = 0; + } else if (theLength == textCharCount) { + //case 2: target length is same as text's length. + out.buffer = text.buffer; + out.start = text.start; + out.end = text.end; + } else if (theLength < textCharCount) { + //case 3: truncate text on the right side. It's same as substring(text, 1, length). + out.buffer = text.buffer; + out.start = text.start; + out.end = org.apache.drill.exec.expr.fn.impl.StringFunctionUtil.getUTF8CharPosition(text.buffer, text.start, text.end, (int) theLength); + } else if (theLength > textCharCount) { + //case 4: copy " " on left. Total # of char to copy : theLength - textCharCount + int count = 0; + out.buffer = buffer; + out.start = out.end = 0; + + while (count < theLength - textCharCount) { + out.buffer.setByte(out.end++, spaceInByte); + ++count; + } // end of while + + //copy "text" into "out" + for (int id = text.start; id < text.end; id++) { + out.buffer.setByte(out.end++, text.buffer.getByte(id)); + } + } + } // end of eval + } + /** * Fill up the string to length "length" by appending the characters 'fill' at the end of 'text' * If the string is already longer than length then it is truncated. @@ -848,6 +907,68 @@ public void eval() { } // end of eval } + /** + * Fill up the string to length "length" by appending the characters ' ' at the end of 'text' + * If the string is already longer than length then it is truncated. + */ + @FunctionTemplate(name = "rpad", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL) + public static class RpadTwoArg implements DrillSimpleFunc { + @Param VarCharHolder text; + @Param BigIntHolder length; + @Inject DrillBuf buffer; + + @Output VarCharHolder out; + @Workspace byte spaceInByte; + + @Override + public void setup() { + spaceInByte = 32; + } + + @Override + public void eval() { + final long theLength = length.value; + final int lengthNeeded = (int) (theLength <= 0 ? 
0 : theLength * 2); + buffer = buffer.reallocIfNeeded(lengthNeeded); + + //get the char length of text. + int textCharCount = org.apache.drill.exec.expr.fn.impl.StringFunctionUtil.getUTF8CharLength(text.buffer, text.start, text.end); + + if (theLength <= 0) { + //case 1: target length is <=0, then return an empty string. + out.buffer = buffer; + out.start = out.end = 0; + } else if (theLength == textCharCount) { + //case 2: target length is same as text's length. + out.buffer = text.buffer; + out.start = text.start; + out.end = text.end; + } else if (theLength < textCharCount) { + //case 3: truncate text on the right side. It's same as substring(text, 1, length). + out.buffer = text.buffer; + out.start = text.start; + out.end = org.apache.drill.exec.expr.fn.impl.StringFunctionUtil.getUTF8CharPosition(text.buffer, text.start, text.end, (int) theLength); + } else if (theLength > textCharCount) { + //case 4: copy "text" into "out", then copy " " on the right. + out.buffer = buffer; + out.start = out.end = 0; + + for (int id = text.start; id < text.end; id++) { + out.buffer.setByte(out.end++, text.buffer.getByte(id)); + } + + //copy " " on right. 
Total # of char to copy : theLength - textCharCount + int count = 0; + + while (count < theLength - textCharCount) { + out.buffer.setByte(out.end++, spaceInByte); + ++count; + } // end of while + + } + } // end of eval + } + /** * Remove the longest string containing only characters from "from" from the start of "text" */ @@ -881,6 +1002,36 @@ public void eval() { } // end of eval } + /** + * Remove the longest string containing only character " " from the start of "text" + */ + @FunctionTemplate(name = "ltrim", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL) + public static class LtrimOneArg implements DrillSimpleFunc { + @Param VarCharHolder text; + + @Output VarCharHolder out; + @Workspace byte spaceInByte; + + @Override + public void setup() { + spaceInByte = 32; + } + + @Override + public void eval() { + out.buffer = text.buffer; + out.start = out.end = text.end; + + //Scan from left of "text", stop until find a char not " " + for (int id = text.start; id < text.end; ++id) { + if (text.buffer.getByte(id) != spaceInByte) { // Found the 1st char not " ", stop + out.start = id; + break; + } + } + } // end of eval + } + /** * Remove the longest string containing only characters from "from" from the end of "text" */ @@ -917,6 +1068,39 @@ public void eval() { } // end of eval } + /** + * Remove the longest string containing only character " " from the end of "text" + */ + @FunctionTemplate(name = "rtrim", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL) + public static class RtrimOneArg implements DrillSimpleFunc { + @Param VarCharHolder text; + + @Output VarCharHolder out; + @Workspace byte spaceInByte; + + @Override + public void setup() { + spaceInByte = 32; + } + + @Override + public void eval() { + out.buffer = text.buffer; + out.start = out.end = text.start; + + //Scan from right of "text", stop until find a char not in " " + for (int id = text.end - 1; id >= text.start; --id) { + while ((text.buffer.getByte(id) & 0xC0) == 0x80 
&& id >= text.start) { + id--; + } + if (text.buffer.getByte(id) != spaceInByte) { // Found the 1st char not in " ", stop + out.end = id + 1; + break; + } + } + } // end of eval + } + /** * Remove the longest string containing only characters from "from" from the start of "text" */ @@ -964,6 +1148,47 @@ public void eval() { } // end of eval } + /** + * Remove the longest string containing only character " " from the start of "text" + */ + @FunctionTemplate(name = "btrim", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL) + public static class BtrimOneArg implements DrillSimpleFunc { + @Param VarCharHolder text; + + @Output VarCharHolder out; + @Workspace byte spaceInByte; + + @Override + public void setup() { + spaceInByte = 32; + } + + @Override + public void eval() { + out.buffer = text.buffer; + out.start = out.end = text.start; + + //Scan from left of "text", stop until find a char not " " + for (int id = text.start; id < text.end; ++id) { + if (text.buffer.getByte(id) != spaceInByte) { // Found the 1st char not " ", stop + out.start = id; + break; + } + } + + //Scan from right of "text", stop until find a char not " " + for (int id = text.end - 1; id >= text.start; --id) { + while ((text.buffer.getByte(id) & 0xC0) == 0x80 && id >= text.start) { + id--; + } + if (text.buffer.getByte(id) != spaceInByte) { // Found the 1st char not in " ", stop + out.end = id + 1; + break; + } + } + } // end of eval + } + @FunctionTemplate(name = "concatOperator", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL) public static class ConcatOperator implements DrillSimpleFunc { @Param VarCharHolder left; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java index 7cbed772c24..7ab7faf5213 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java +++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java @@ -135,6 +135,16 @@ public RuleSet getRules(OptimizerRulesContext context, Collection } }, + SUM_CONVERSION("Convert SUM to $SUM0") { + public RuleSet getRules(OptimizerRulesContext context, Collection plugins) { + return PlannerPhase.mergedRuleSets( + RuleSets.ofList( + DrillReduceAggregatesRule.INSTANCE_SUM), + getStorageRules(context, plugins, this) + ); + } + }, + PARTITION_PRUNING("Partition Prune Planning") { public RuleSet getRules(OptimizerRulesContext context, Collection plugins) { return PlannerPhase.mergedRuleSets(getPruneScanRules(context), getStorageRules(context, plugins, this)); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillConstExecutor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillConstExecutor.java index 78d27018c1a..96579dbea50 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillConstExecutor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillConstExecutor.java @@ -58,6 +58,7 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.NlsString; import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.planner.sql.TypeInferenceUtils; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -70,34 +71,6 @@ public class DrillConstExecutor implements RelOptPlanner.Executor { private final PlannerSettings plannerSettings; - public static ImmutableMap DRILL_TO_CALCITE_TYPE_MAPPING = - ImmutableMap. 
builder() - .put(TypeProtos.MinorType.INT, SqlTypeName.INTEGER) - .put(TypeProtos.MinorType.BIGINT, SqlTypeName.BIGINT) - .put(TypeProtos.MinorType.FLOAT4, SqlTypeName.FLOAT) - .put(TypeProtos.MinorType.FLOAT8, SqlTypeName.DOUBLE) - .put(TypeProtos.MinorType.VARCHAR, SqlTypeName.VARCHAR) - .put(TypeProtos.MinorType.BIT, SqlTypeName.BOOLEAN) - .put(TypeProtos.MinorType.DATE, SqlTypeName.DATE) - .put(TypeProtos.MinorType.DECIMAL9, SqlTypeName.DECIMAL) - .put(TypeProtos.MinorType.DECIMAL18, SqlTypeName.DECIMAL) - .put(TypeProtos.MinorType.DECIMAL28SPARSE, SqlTypeName.DECIMAL) - .put(TypeProtos.MinorType.DECIMAL38SPARSE, SqlTypeName.DECIMAL) - .put(TypeProtos.MinorType.TIME, SqlTypeName.TIME) - .put(TypeProtos.MinorType.TIMESTAMP, SqlTypeName.TIMESTAMP) - .put(TypeProtos.MinorType.VARBINARY, SqlTypeName.VARBINARY) - .put(TypeProtos.MinorType.INTERVALYEAR, SqlTypeName.INTERVAL_YEAR_MONTH) - .put(TypeProtos.MinorType.INTERVALDAY, SqlTypeName.INTERVAL_DAY_TIME) - .put(TypeProtos.MinorType.MAP, SqlTypeName.MAP) - .put(TypeProtos.MinorType.LIST, SqlTypeName.ARRAY) - .put(TypeProtos.MinorType.LATE, SqlTypeName.ANY) - // These are defined in the Drill type system but have been turned off for now - .put(TypeProtos.MinorType.TINYINT, SqlTypeName.TINYINT) - .put(TypeProtos.MinorType.SMALLINT, SqlTypeName.SMALLINT) - // Calcite types currently not supported by Drill, nor defined in the Drill type list: - // - CHAR, SYMBOL, MULTISET, DISTINCT, STRUCTURED, ROW, OTHER, CURSOR, COLUMN_LIST - .build(); - // This is a list of all types that cannot be folded at planning time for various reasons, most of the types are // currently not supported at all. 
The reasons for the others can be found in the evaluation code in the reduce method public static final List NON_REDUCIBLE_TYPES = ImmutableList.builder().add( @@ -132,30 +105,6 @@ public DrillConstExecutor(FunctionImplementationRegistry funcImplReg, UdfUtiliti this.plannerSettings = plannerSettings; } - private RelDataType createCalciteTypeWithNullability(RelDataTypeFactory typeFactory, - SqlTypeName sqlTypeName, - boolean isNullable) { - RelDataType type; - if (sqlTypeName == SqlTypeName.INTERVAL_DAY_TIME) { - type = typeFactory.createSqlIntervalType( - new SqlIntervalQualifier( - TimeUnit.DAY, - TimeUnit.MINUTE, - SqlParserPos.ZERO)); - } else if (sqlTypeName == SqlTypeName.INTERVAL_YEAR_MONTH) { - type = typeFactory.createSqlIntervalType( - new SqlIntervalQualifier( - TimeUnit.YEAR, - TimeUnit.MONTH, - SqlParserPos.ZERO)); - } else if (sqlTypeName == SqlTypeName.VARCHAR) { - type = typeFactory.createSqlType(sqlTypeName, TypeHelper.VARCHAR_DEFAULT_CAST_LEN); - } else { - type = typeFactory.createSqlType(sqlTypeName); - } - return typeFactory.createTypeWithNullability(type, isNullable); - } - @Override public void reduce(RexBuilder rexBuilder, List constExps, List reducedValues) { for (RexNode newCall : constExps) { @@ -183,7 +132,7 @@ public void reduce(RexBuilder rexBuilder, List constExps, List RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory(); if (materializedExpr.getMajorType().getMode() == TypeProtos.DataMode.OPTIONAL && TypeHelper.isNull(output)) { - SqlTypeName sqlTypeName = DRILL_TO_CALCITE_TYPE_MAPPING.get(materializedExpr.getMajorType().getMinorType()); + SqlTypeName sqlTypeName = TypeInferenceUtils.getCalciteTypeFromDrillType(materializedExpr.getMajorType().getMinorType()); if (sqlTypeName == null) { String message = String.format("Error reducing constant expression, unsupported type: %s.", materializedExpr.getMajorType().getMinorType()); @@ -198,25 +147,25 @@ public void reduce(RexBuilder rexBuilder, List constExps, List case INT: 
reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(((IntHolder)output).value), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.INTEGER, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.INTEGER, newCall.getType().isNullable()), false)); break; case BIGINT: reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(((BigIntHolder)output).value), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.BIGINT, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.BIGINT, newCall.getType().isNullable()), false)); break; case FLOAT4: reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(((Float4Holder)output).value), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.FLOAT, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.FLOAT, newCall.getType().isNullable()), false)); break; case FLOAT8: reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(((Float8Holder)output).value), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.DOUBLE, newCall.getType().isNullable()), false)); break; case VARCHAR: @@ -226,25 +175,25 @@ public void reduce(RexBuilder rexBuilder, List constExps, List case BIT: reducedValues.add(rexBuilder.makeLiteral( ((BitHolder)output).value == 1 ? 
true : false, - createCalciteTypeWithNullability(typeFactory, SqlTypeName.BOOLEAN, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.BOOLEAN, newCall.getType().isNullable()), false)); break; case DATE: reducedValues.add(rexBuilder.makeLiteral( new DateTime(((DateHolder) output).value, DateTimeZone.UTC).toCalendar(null), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.DATE, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.DATE, newCall.getType().isNullable()), false)); break; case DECIMAL9: reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(BigInteger.valueOf(((Decimal9Holder) output).value), ((Decimal9Holder)output).scale), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), false)); break; case DECIMAL18: reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(BigInteger.valueOf(((Decimal18Holder) output).value), ((Decimal18Holder)output).scale), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), false)); break; case DECIMAL28SPARSE: @@ -255,7 +204,7 @@ public void reduce(RexBuilder rexBuilder, List constExps, List decimal28Out.start * 20, 5, decimal28Out.scale), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), false )); break; @@ -267,14 +216,14 @@ public void reduce(RexBuilder rexBuilder, List constExps, List decimal38Out.start * 24, 6, decimal38Out.scale), - createCalciteTypeWithNullability(typeFactory, 
SqlTypeName.DECIMAL, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.DECIMAL, newCall.getType().isNullable()), false)); break; case TIME: reducedValues.add(rexBuilder.makeLiteral( new DateTime(((TimeHolder)output).value, DateTimeZone.UTC).toCalendar(null), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.TIME, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.TIME, newCall.getType().isNullable()), false)); break; case TIMESTAMP: @@ -284,14 +233,14 @@ public void reduce(RexBuilder rexBuilder, List constExps, List case INTERVALYEAR: reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(((IntervalYearHolder)output).value), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.INTERVAL_YEAR_MONTH, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.INTERVAL_YEAR_MONTH, newCall.getType().isNullable()), false)); break; case INTERVALDAY: IntervalDayHolder intervalDayOut = (IntervalDayHolder) output; reducedValues.add(rexBuilder.makeLiteral( new BigDecimal(intervalDayOut.days * DateUtility.daysToStandardMillis + intervalDayOut.milliseconds), - createCalciteTypeWithNullability(typeFactory, SqlTypeName.INTERVAL_DAY_TIME, newCall.getType().isNullable()), + TypeInferenceUtils.createCalciteTypeWithNullability(typeFactory, SqlTypeName.INTERVAL_DAY_TIME, newCall.getType().isNullable()), false)); break; // The list of known unsupported types is used to trigger this behavior of re-using the input expression diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java index bd029bf5110..cae67969cb6 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java +++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java @@ -360,11 +360,6 @@ private LogicalExpression getDrillFunctionFromOptiqCall(RexCall call) { trimArgs.add(args.get(1)); return FunctionCallFactory.createExpression(trimFunc, trimArgs); - } else if (functionName.equals("ltrim") || functionName.equals("rtrim") || functionName.equals("btrim")) { - if (argsSize == 1) { - args.add(ValueExpressions.getChar(" ")); - } - return FunctionCallFactory.createExpression(functionName, args); } else if (functionName.equals("date_part")) { // Rewrite DATE_PART functions as extract functions // assert that the function has exactly two arguments @@ -427,13 +422,6 @@ private LogicalExpression getDrillFunctionFromOptiqCall(RexCall call) { } else if ((functionName.equals("convert_from") || functionName.equals("convert_to")) && args.get(1) instanceof QuotedString) { return FunctionCallFactory.createConvert(functionName, ((QuotedString)args.get(1)).value, args.get(0), ExpressionPosition.UNKNOWN); - } else if ((functionName.equalsIgnoreCase("rpad")) || functionName.equalsIgnoreCase("lpad")) { - // If we have only two arguments for rpad/lpad append a default QuotedExpression as an argument which will be used to pad the string - if (argsSize == 2) { - String spaceFill = " "; - LogicalExpression fill = ValueExpressions.getChar(spaceFill); - args.add(fill); - } } return FunctionCallFactory.createExpression(functionName, args); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java index 9ba01a84719..3a2510e02ec 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java @@ -18,6 +18,7 @@ package org.apache.drill.exec.planner.logical; +import 
com.google.common.collect.ImmutableList; import java.math.BigDecimal; import java.util.ArrayList; @@ -25,12 +26,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.logging.Logger; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.calcite.rel.InvalidRelException; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.sql.fun.SqlCountAggFunction; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.drill.exec.planner.sql.DrillSqlOperator; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.RelNode; import org.apache.calcite.plan.RelOptRule; @@ -48,12 +51,15 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlSumAggFunction; import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; -import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.CompositeList; import org.apache.calcite.util.ImmutableIntList; import org.apache.calcite.util.Util; -import com.google.common.collect.ImmutableList; +import org.apache.calcite.util.trace.CalciteTrace; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; +import org.apache.drill.exec.planner.sql.DrillSqlOperator; /** * Rule to reduce aggregates to simpler forms. Currently only AVG(x) to @@ -65,8 +71,11 @@ public class DrillReduceAggregatesRule extends RelOptRule { /** * The singleton. 
*/ + public static final DrillReduceAggregatesRule INSTANCE = new DrillReduceAggregatesRule(operand(LogicalAggregate.class, any())); + public static final DrillConvertSumToSumZero INSTANCE_SUM = + new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any())); private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false); @@ -100,8 +109,13 @@ public void onMatch(RelOptRuleCall ruleCall) { */ private boolean containsAvgStddevVarCall(List aggCallList) { for (AggregateCall call : aggCallList) { - if (call.getAggregation() instanceof SqlAvgAggFunction - || call.getAggregation() instanceof SqlSumAggFunction) { + SqlAggFunction sqlAggFunction = call.getAggregation(); + if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { + sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); + } + + if (sqlAggFunction instanceof SqlAvgAggFunction + || sqlAggFunction instanceof SqlSumAggFunction) { return true; } } @@ -198,15 +212,19 @@ private RexNode reduceAgg( List newCalls, Map aggCallMapping, List inputExprs) { - if (oldCall.getAggregation() instanceof SqlSumAggFunction) { + SqlAggFunction sqlAggFunction = oldCall.getAggregation(); + if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { + sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); + } + + if (sqlAggFunction instanceof SqlSumAggFunction) { // replace original SUM(x) with // case COUNT(x) when 0 then null else SUM0(x) end return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping); } - if (oldCall.getAggregation() instanceof SqlAvgAggFunction) { - final SqlAvgAggFunction.Subtype subtype = - ((SqlAvgAggFunction) oldCall.getAggregation()).getSubtype(); + if (sqlAggFunction instanceof SqlAvgAggFunction) { + final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) sqlAggFunction).getSubtype(); switch (subtype) { case AVG: // replace original AVG(x) with SUM(x) / COUNT(x) @@ -274,6 +292,7 @@ private 
RexNode reduceAvg( AggregateCall oldCall, List newCalls, Map aggCallMapping) { + final boolean isWrapper = useWrapper(oldCall); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); @@ -283,12 +302,25 @@ private RexNode reduceAvg( getFieldType( oldAggRel.getInput(), iAvgInput); - RelDataType sumType = - typeFactory.createTypeWithNullability( - avgInputType, - avgInputType.isNullable() || nGroups == 0); + + final RelDataType sumType; + if(isWrapper) { + sumType = oldCall.getType(); + } else { + sumType = + typeFactory.createTypeWithNullability( + avgInputType, + avgInputType.isNullable() || nGroups == 0); + } // SqlAggFunction sumAgg = new SqlSumAggFunction(sumType); - SqlAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction(); + SqlAggFunction sumAgg; + if(isWrapper) { + sumAgg = new DrillCalciteSqlAggFunctionWrapper( + new SqlSumEmptyIsZeroAggFunction(), sumType); + } else { + sumAgg = new SqlSumEmptyIsZeroAggFunction(); + } + AggregateCall sumCall = new AggregateCall( sumAgg, @@ -358,8 +390,13 @@ private RexNode reduceAvg( SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef); - return rexBuilder.makeCast( - typeFactory.createSqlType(SqlTypeName.ANY), divideRef); + + if(isWrapper) { + return divideRef; + } else { + return rexBuilder.makeCast( + typeFactory.createSqlType(SqlTypeName.ANY), divideRef); + } } private RexNode reduceSum( @@ -367,19 +404,34 @@ private RexNode reduceSum( AggregateCall oldCall, List newCalls, Map aggCallMapping) { + final boolean isWrapper = useWrapper(oldCall); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); - int arg = oldCall.getArgList().get(0); - RelDataType argType = - getFieldType( - oldAggRel.getInput(), - arg); - RelDataType sumType = - typeFactory.createTypeWithNullability( - argType, argType.isNullable()); - SqlAggFunction 
sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(); + + final RelDataType argType; + if(isWrapper) { + argType = oldCall.getType(); + } else { + int arg = oldCall.getArgList().get(0); + argType = + getFieldType( + oldAggRel.getInput(), + arg); + } + + final RelDataType sumType; + final SqlAggFunction sumZeroAgg; + if(isWrapper) { + sumType = oldCall.getType(); + sumZeroAgg = new DrillCalciteSqlAggFunctionWrapper( + new SqlSumEmptyIsZeroAggFunction(), sumType); + } else { + sumType = typeFactory.createTypeWithNullability(argType, argType.isNullable()); + sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(); + } + AggregateCall sumZeroCall = new AggregateCall( sumZeroAgg, @@ -436,6 +488,7 @@ private RexNode reduceStddev( List newCalls, Map aggCallMapping, List inputExprs) { + final boolean isWrapper = useWrapper(oldCall); // stddev_pop(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) @@ -472,13 +525,26 @@ private RexNode reduceStddev( typeFactory.createTypeWithNullability( argType, true); - final AggregateCall sumArgSquaredAggCall = - new AggregateCall( - new SqlSumAggFunction(sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argSquaredOrdinal), - sumType, - null); + final AggregateCall sumArgSquaredAggCall; + if(isWrapper) { + sumArgSquaredAggCall = + new AggregateCall( + new DrillCalciteSqlAggFunctionWrapper( + new SqlSumAggFunction(sumType), sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argSquaredOrdinal), + sumType, + null); + } else { + sumArgSquaredAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argSquaredOrdinal), + sumType, + null); + } + final RexNode sumArgSquared = rexBuilder.addAggCall( sumArgSquaredAggCall, @@ -488,13 +554,26 @@ private RexNode reduceStddev( aggCallMapping, ImmutableList.of(argType)); - final AggregateCall sumArgAggCall = - new AggregateCall( - new SqlSumAggFunction(sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argOrdinal), - sumType, - null); + 
final AggregateCall sumArgAggCall; + if(isWrapper) { + sumArgAggCall = + new AggregateCall( + new DrillCalciteSqlAggFunctionWrapper( + new SqlSumAggFunction(sumType), sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argOrdinal), + sumType, + null); + } else { + sumArgAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argOrdinal), + sumType, + null); + } + final RexNode sumArg = rexBuilder.addAggCall( sumArgAggCall, @@ -577,8 +656,12 @@ private RexNode reduceStddev( * this if we add cast after rewriting the aggregate we add an additional cast which * would cause wrong results. So we simply add a cast to ANY. */ - return rexBuilder.makeCast( - typeFactory.createSqlType(SqlTypeName.ANY), result); + if(isWrapper) { + return result; + } else { + return rexBuilder.makeCast( + typeFactory.createSqlType(SqlTypeName.ANY), result); + } } /** @@ -621,5 +704,88 @@ private RelDataType getFieldType(RelNode relNode, int i) { return inputField.getType(); } + private boolean useWrapper(AggregateCall aggregateCall) { + return aggregateCall.getAggregation() instanceof DrillCalciteSqlWrapper; + } + + private static class DrillConvertSumToSumZero extends RelOptRule { + protected static final Logger tracer = CalciteTrace.getPlannerTracer(); + + public DrillConvertSumToSumZero(RelOptRuleOperand operand) { + super(operand); + } + + @Override + public boolean matches(RelOptRuleCall call) { + DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0]; + for (AggregateCall aggregateCall : oldAggRel.getAggCallList()) { + SqlAggFunction sqlAggFunction = aggregateCall.getAggregation(); + if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { + sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); + } + + if(sqlAggFunction instanceof SqlSumAggFunction + && !aggregateCall.getType().isNullable()) { + // If SUM(x) is not nullable, the validator must have determined that + // nulls are 
impossible (because the group is never empty and x is never + // null). Therefore we translate to SUM0(x). + return true; + } + } + return false; + } + + @Override + public void onMatch(RelOptRuleCall call) { + final DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0]; + + final Map aggCallMapping = Maps.newHashMap(); + final List newAggregateCalls = Lists.newArrayList(); + for (AggregateCall oldAggregateCall : oldAggRel.getAggCallList()) { + SqlAggFunction sqlAggFunction = oldAggregateCall.getAggregation(); + if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { + sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); + } + + if(sqlAggFunction instanceof SqlSumAggFunction + && !oldAggregateCall.getType().isNullable()) { + final RelDataType argType = oldAggregateCall.getType(); + final RelDataType sumType = oldAggRel.getCluster().getTypeFactory() + .createTypeWithNullability(argType, argType.isNullable()); + final SqlAggFunction sumZeroAgg = new DrillCalciteSqlAggFunctionWrapper( + new SqlSumEmptyIsZeroAggFunction(), sumType); + AggregateCall sumZeroCall = + new AggregateCall( + sumZeroAgg, + oldAggregateCall.isDistinct(), + oldAggregateCall.getArgList(), + sumType, + null); + oldAggRel.getCluster().getRexBuilder() + .addAggCall(sumZeroCall, + oldAggRel.getGroupCount(), + oldAggRel.indicator, + newAggregateCalls, + aggCallMapping, + ImmutableList.of(argType)); + } else { + newAggregateCalls.add(oldAggregateCall); + } + } + + try { + call.transformTo(new DrillAggregateRel( + oldAggRel.getCluster(), + oldAggRel.getTraitSet(), + oldAggRel.getInput(), + oldAggRel.indicator, + oldAggRel.getGroupSet(), + oldAggRel.getGroupSets(), + newAggregateCalls)); + } catch (InvalidRelException e) { + tracer.warning(e.toString()); + } + } + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java 
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java index 8d4c1b4197f..1585a5616a8 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java @@ -18,12 +18,17 @@ package org.apache.drill.exec.planner.logical; import java.util.ArrayList; -import java.util.Collections; import java.util.List; +import com.google.common.collect.Lists; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; import org.apache.drill.common.exceptions.UserException; import org.apache.drill.exec.exception.UnsupportedOperatorCollector; import org.apache.drill.exec.planner.StarColumnHelper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.util.ApproximateStringMatcher; import org.apache.drill.exec.work.foreman.SqlUnsupportedException; @@ -33,7 +38,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelShuttleImpl; import org.apache.calcite.rel.logical.LogicalUnion; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; @@ -60,16 +64,18 @@ public class PreProcessLogicalRel extends RelShuttleImpl { private RelDataTypeFactory factory; private DrillOperatorTable table; private UnsupportedOperatorCollector unsupportedOperatorCollector; + private final UnwrappingExpressionVisitor unwrappingExpressionVisitor; - public static PreProcessLogicalRel createVisitor(RelDataTypeFactory factory, DrillOperatorTable table) { - return new PreProcessLogicalRel(factory, table); + public static PreProcessLogicalRel 
createVisitor(RelDataTypeFactory factory, DrillOperatorTable table, RexBuilder rexBuilder) { + return new PreProcessLogicalRel(factory, table, rexBuilder); } - private PreProcessLogicalRel(RelDataTypeFactory factory, DrillOperatorTable table) { + private PreProcessLogicalRel(RelDataTypeFactory factory, DrillOperatorTable table, RexBuilder rexBuilder) { super(); this.factory = factory; this.table = table; this.unsupportedOperatorCollector = new UnsupportedOperatorCollector(); + this.unwrappingExpressionVisitor = new UnwrappingExpressionVisitor(rexBuilder); } @Override @@ -82,12 +88,21 @@ public RelNode visit(LogicalAggregate aggregate) { throw new UnsupportedOperationException(); } } - return visitChild(aggregate, 0, aggregate.getInput()); } @Override public RelNode visit(LogicalProject project) { + final List projExpr = Lists.newArrayList(); + for(RexNode rexNode : project.getChildExps()) { + projExpr.add(rexNode.accept(unwrappingExpressionVisitor)); + } + + project = project.copy(project.getTraitSet(), + project.getInput(), + projExpr, + project.getRowType()); + List exprList = new ArrayList<>(); boolean rewrite = false; @@ -161,6 +176,29 @@ public RelNode visit(LogicalProject project) { return visitChild(project, 0, project.getInput()); } + @Override + public RelNode visit(LogicalFilter filter) { + final RexNode condition = filter.getCondition().accept(unwrappingExpressionVisitor); + filter = filter.copy( + filter.getTraitSet(), + filter.getInput(), + condition); + return visitChild(filter, 0, filter.getInput()); + } + + @Override + public RelNode visit(LogicalJoin join) { + final RexNode conditionExpr = join.getCondition().accept(unwrappingExpressionVisitor); + join = join.copy(join.getTraitSet(), + conditionExpr, + join.getLeft(), + join.getRight(), + join.getJoinType(), + join.isSemiJoinDone()); + + return visitChildren(join); + } + @Override public RelNode visit(LogicalUnion union) { for(RelNode child : union.getInputs()) { @@ -214,4 +252,29 @@ private 
UserException getConvertFunctionException(final String functionName, fin public void convertException() throws SqlUnsupportedException { unsupportedOperatorCollector.convertException(); } + + private static class UnwrappingExpressionVisitor extends RexShuttle { + private final RexBuilder rexBuilder; + + private UnwrappingExpressionVisitor(RexBuilder rexBuilder) { + this.rexBuilder = rexBuilder; + } + + @Override + public RexNode visitCall(final RexCall call) { + final List clonedOperands = visitList(call.operands, new boolean[]{true}); + final SqlOperator sqlOperator; + if(call.getOperator() instanceof DrillCalciteSqlWrapper) { + sqlOperator = ((DrillCalciteSqlWrapper) call.getOperator()).getOperator(); + } else { + sqlOperator = call.getOperator(); + } + + return RexUtil.flatten(rexBuilder, + rexBuilder.makeCall( + call.getType(), + sqlOperator, + clonedOperands)); + } + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/visitor/InsertLocalExchangeVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/visitor/InsertLocalExchangeVisitor.java index ad64ed8fc72..a2f44f4757d 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/visitor/InsertLocalExchangeVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/visitor/InsertLocalExchangeVisitor.java @@ -58,7 +58,7 @@ public RexNodeBasedHashExpressionCreatorHelper(RexBuilder rexBuilder) { @Override public RexNode createCall(String funcName, List inputFields) { final DrillSqlOperator op = - new DrillSqlOperator(funcName, inputFields.size(), MajorType.getDefaultInstance(), true); + new DrillSqlOperator(funcName, inputFields.size(), true); return rexBuilder.makeCall(op, inputFields); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/Checker.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/Checker.java index c274d2deaab..c130e1424c4 100644 --- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/Checker.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/Checker.java @@ -17,18 +17,51 @@ */ package org.apache.drill.exec.planner.sql; +import com.google.common.collect.Maps; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.commons.lang3.tuple.Pair; + +import java.util.Map; class Checker implements SqlOperandTypeChecker { private SqlOperandCountRange range; - public Checker(int size) { + public static final Checker ANY_CHECKER = new Checker(); + private static final Map, Checker> checkerMap = Maps.newHashMap(); + + public static Checker getChecker(int min, int max) { + final Pair range = Pair.of(min, max); + if(checkerMap.containsKey(range)) { + return checkerMap.get(range); + } + + final Checker newChecker; + if(min == max) { + newChecker = new Checker(min); + } else { + newChecker = new Checker(min, max); + } + + checkerMap.put(range, newChecker); + return newChecker; + } + + private Checker(int size) { range = new FixedRange(size); } + private Checker(int min, int max) { + range = SqlOperandCountRanges.between(min, max); + } + + private Checker() { + range = SqlOperandCountRanges.any(); + } + @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { return true; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlAggFunctionWrapper.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlAggFunctionWrapper.java new file mode 100644 index 00000000000..3795dd4651b --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlAggFunctionWrapper.java @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.validate.SqlMonotonicity; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; + +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.List; + +/** + * This class serves as a wrapper class for SqlAggFunction. The motivation is to plug-in the return type inference and operand + * type check algorithms of Drill into Calcite's sql validation procedure. + * + * Except for the methods which are relevant to the return type inference and operand type check algorithms, the wrapper + * simply forwards the method calls to the wrapped SqlAggFunction. 
+ */ +public class DrillCalciteSqlAggFunctionWrapper extends SqlAggFunction implements DrillCalciteSqlWrapper { + private final SqlAggFunction operator; + + @Override + public SqlOperator getOperator() { + return operator; + } + + private DrillCalciteSqlAggFunctionWrapper( + SqlAggFunction sqlAggFunction, + SqlReturnTypeInference sqlReturnTypeInference) { + super(sqlAggFunction.getName(), + sqlAggFunction.getSqlIdentifier(), + sqlAggFunction.getKind(), + sqlReturnTypeInference, + sqlAggFunction.getOperandTypeInference(), + Checker.ANY_CHECKER, + sqlAggFunction.getFunctionType(), + sqlAggFunction.requiresOrder(), + sqlAggFunction.requiresOver()); + this.operator = sqlAggFunction; + } + + public DrillCalciteSqlAggFunctionWrapper( + SqlAggFunction sqlAggFunction, + List functions) { + this(sqlAggFunction, + TypeInferenceUtils.getDrillSqlReturnTypeInference( + sqlAggFunction.getName(), + functions)); + } + + public DrillCalciteSqlAggFunctionWrapper( + final SqlAggFunction sqlAggFunction, + final RelDataType relDataType) { + this(sqlAggFunction, new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return relDataType; + } + }); + } + + @Override + public boolean validRexOperands(int count, boolean fail) { + return true; + } + + @Override + public String getAllowedSignatures(String opNameToUse) { + return operator.getAllowedSignatures(opNameToUse); + } + + @Override + public boolean isAggregator() { + return operator.isAggregator(); + } + + @Override + public boolean allowsFraming() { + return operator.allowsFraming(); + } + + @Override + public SqlMonotonicity getMonotonicity(SqlOperatorBinding call) { + return operator.getMonotonicity(call); + } + + @Override + public boolean isDeterministic() { + return operator.isDeterministic(); + } + + @Override + public boolean isDynamicFunction() { + return operator.isDynamicFunction(); + } + + @Override + public boolean requiresDecimalExpansion() { + return 
operator.requiresDecimalExpansion(); + } + + @Override + public boolean argumentMustBeScalar(int ordinal) { + return operator.argumentMustBeScalar(ordinal); + } + + @Override + public boolean checkOperandTypes( + SqlCallBinding callBinding, + boolean throwOnFailure) { + return true; + } + + @Override + public SqlSyntax getSyntax() { + return operator.getSyntax(); + } + + @Override + public List getParamNames() { + return operator.getParamNames(); + } + + @Override + public String getSignatureTemplate(final int operandsCount) { + return operator.getSignatureTemplate(operandsCount); + } + + @Override + public RelDataType deriveType( + SqlValidator validator, + SqlValidatorScope scope, + SqlCall call) { + return operator.deriveType(validator, + scope, + call); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlFunctionWrapper.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlFunctionWrapper.java new file mode 100644 index 00000000000..1c61d085ff7 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlFunctionWrapper.java @@ -0,0 +1,147 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.validate.SqlMonotonicity; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; + +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.List; + +/** + * This class serves as a wrapper class for SqlFunction. The motivation is to plug-in the return type inference and operand + * type check algorithms of Drill into Calcite's sql validation procedure. + * + * Except for the methods which are relevant to the return type inference and operand type check algorithms, the wrapper + * simply forwards the method calls to the wrapped SqlFunction. 
+ */ +public class DrillCalciteSqlFunctionWrapper extends SqlFunction implements DrillCalciteSqlWrapper { + private final SqlFunction operator; + + public DrillCalciteSqlFunctionWrapper( + final SqlFunction wrappedFunction, + final List functions) { + super(wrappedFunction.getName(), + wrappedFunction.getSqlIdentifier(), + wrappedFunction.getKind(), + TypeInferenceUtils.getDrillSqlReturnTypeInference( + wrappedFunction.getName(), + functions), + wrappedFunction.getOperandTypeInference(), + Checker.ANY_CHECKER, + wrappedFunction.getParamTypes(), + wrappedFunction.getFunctionType()); + this.operator = wrappedFunction; + } + + @Override + public SqlOperator getOperator() { + return operator; + } + + @Override + public boolean validRexOperands(int count, boolean fail) { + return true; + } + + @Override + public String getAllowedSignatures(String opNameToUse) { + return operator.getAllowedSignatures(opNameToUse); + } + + @Override + public SqlMonotonicity getMonotonicity(SqlOperatorBinding call) { + return operator.getMonotonicity(call); + } + + @Override + public boolean isDeterministic() { + return operator.isDeterministic(); + } + + @Override + public boolean isDynamicFunction() { + return operator.isDynamicFunction(); + } + + @Override + public boolean requiresDecimalExpansion() { + return operator.requiresDecimalExpansion(); + } + + @Override + public boolean argumentMustBeScalar(int ordinal) { + return operator.argumentMustBeScalar(ordinal); + } + + @Override + public boolean checkOperandTypes( + SqlCallBinding callBinding, + boolean throwOnFailure) { + return true; + } + + @Override + public SqlSyntax getSyntax() { + return operator.getSyntax(); + } + + @Override + public List getParamNames() { + return operator.getParamNames(); + } + + @Override + public String getSignatureTemplate(final int operandsCount) { + return operator.getSignatureTemplate(operandsCount); + } + + @Override + public RelDataType deriveType( + SqlValidator validator, + SqlValidatorScope 
scope, + SqlCall call) { + return operator.deriveType(validator, + scope, + call); + } + + @Override + public String toString() { + return operator.toString(); + } + + @Override + public void unparse( + SqlWriter writer, + SqlCall call, + int leftPrec, + int rightPrec) { + operator.unparse(writer, call, leftPrec, rightPrec); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java new file mode 100644 index 00000000000..28c1cecb251 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java @@ -0,0 +1,140 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlSyntax; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.validate.SqlMonotonicity; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.List; + +/** + * This class serves as a wrapper class for SqlOperator. The motivation is to plug-in the return type inference and operand + * type check algorithms of Drill into Calcite's sql validation procedure. + * + * Except for the methods which are relevant to the return type inference and operand type check algorithms, the wrapper + * simply forwards the method calls to the wrapped SqlOperator. + */ +public class DrillCalciteSqlOperatorWrapper extends SqlOperator implements DrillCalciteSqlWrapper { + public final SqlOperator operator; + + public DrillCalciteSqlOperatorWrapper(SqlOperator operator, final String rename, final List functions) { + super( + operator.getName(), + operator.getKind(), + operator.getLeftPrec(), + operator.getRightPrec(), + TypeInferenceUtils.getDrillSqlReturnTypeInference( + rename, + functions), + operator.getOperandTypeInference(), + Checker.ANY_CHECKER); + this.operator = operator; + } + + @Override + public SqlOperator getOperator() { + return operator; + } + + @Override + public SqlSyntax getSyntax() { + return operator.getSyntax(); + } + + @Override + public SqlCall createCall( + SqlLiteral functionQualifier, + SqlParserPos pos, + SqlNode... 
operands) { + return operator.createCall(functionQualifier, pos, operands); + } + + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + return operator.rewriteCall(validator, call); + } + + + @Override + public boolean checkOperandTypes( + SqlCallBinding callBinding, + boolean throwOnFailure) { + return true; + } + + @Override + public boolean validRexOperands(int count, boolean fail) { + return true; + } + + @Override + public String getSignatureTemplate(final int operandsCount) { + return operator.getSignatureTemplate(operandsCount); + } + + @Override + public String getAllowedSignatures(String opNameToUse) { + return operator.getAllowedSignatures(opNameToUse); + } + + @Override + public SqlMonotonicity getMonotonicity(SqlOperatorBinding call) { + return operator.getMonotonicity(call); + } + + @Override + public boolean isDeterministic() { + return operator.isDeterministic(); + } + + @Override + public boolean requiresDecimalExpansion() { + return operator.requiresDecimalExpansion(); + } + + @Override + public boolean argumentMustBeScalar(int ordinal) { + return operator.argumentMustBeScalar(ordinal); + } + + @Override + public String toString() { + return operator.toString(); + } + + @Override + public void unparse( + SqlWriter writer, + SqlCall call, + int leftPrec, + int rightPrec) { + operator.unparse(writer, call, leftPrec, rightPrec); + } +} \ No newline at end of file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlWrapper.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlWrapper.java new file mode 100644 index 00000000000..8410e6778a7 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlWrapper.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.sql.SqlOperator; +/** + * This interface is meant for the users of the wrappers, {@link DrillCalciteSqlOperatorWrapper}, + * {@link DrillCalciteSqlFunctionWrapper} and {@link DrillCalciteSqlAggFunctionWrapper}, to access the wrapped Calcite + * {@link SqlOperator} without knowing exactly which wrapper it is. 
+ */ +public interface DrillCalciteSqlWrapper { + /** + * Get the wrapped {@link SqlOperator} + * + * @return SqlOperator get the wrapped {@link SqlOperator} + */ + SqlOperator getOperator(); +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java index 4ade513fc10..6b81bf0fc35 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java @@ -19,6 +19,7 @@ import java.util.HashMap; +import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlAvgAggFunction; @@ -49,8 +50,19 @@ public class DrillConvertletTable implements SqlRexConvertletTable{ */ @Override public SqlRexConvertlet get(SqlCall call) { - SqlRexConvertlet convertlet; + if(call.getOperator() instanceof DrillCalciteSqlWrapper) { + final SqlOperator wrapper = call.getOperator(); + final SqlOperator wrapped = ((DrillCalciteSqlWrapper) call.getOperator()).getOperator(); + if ((convertlet = map.get(wrapped)) != null) { + return convertlet; + } + + ((SqlBasicCall) call).setOperator(wrapped); + SqlRexConvertlet sqlRexConvertlet = StandardConvertletTable.INSTANCE.get(call); + ((SqlBasicCall) call).setOperator(wrapper); + return sqlRexConvertlet; + } if ((convertlet = map.get(call.getOperator())) != null) { return convertlet; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java index 7bf2584c245..61e1e07e817 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java +++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java @@ -25,6 +25,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIntervalQualifier; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexContext; @@ -50,6 +51,8 @@ public RexNode convertCall(SqlRexContext cx, SqlCall call) { final List operands = call.getOperandList(); final List exprs = new LinkedList<>(); + String timeUnit = ((SqlIntervalQualifier) operands.get(0)).timeUnitRange.toString(); + RelDataTypeFactory typeFactory = cx.getTypeFactory(); //RelDataType nullableReturnType = @@ -59,7 +62,10 @@ public RexNode convertCall(SqlRexContext cx, SqlCall call) { } // Determine NULL-able using 2nd argument's Null-able. - RelDataType returnType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), exprs.get(1).getType().isNullable()); + RelDataType returnType = typeFactory.createTypeWithNullability( + typeFactory.createSqlType( + TypeInferenceUtils.getSqlTypeNameForTimeUnit(timeUnit)), + exprs.get(1).getType().isNullable()); return rexBuilder.makeCall(returnType, call.getOperator(), exprs); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java index 4dd796374e0..7fe6020b772 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java @@ -19,6 +19,12 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunction; +import 
org.apache.calcite.sql.SqlPrefixOperator; +import org.apache.drill.common.expression.FunctionCallFactory; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; @@ -28,34 +34,50 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import java.util.List; +import java.util.Map; +/** + * Implementation of {@link SqlOperatorTable} that contains standard operators and functions provided through + * {@link #inner SqlStdOperatorTable}, and Drill User Defined Functions. + */ public class DrillOperatorTable extends SqlStdOperatorTable { - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillOperatorTable.class); - +// private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillOperatorTable.class); private static final SqlOperatorTable inner = SqlStdOperatorTable.instance(); - private List operators; + private List operators = Lists.newArrayList(); + private final Map calciteToWrapper = Maps.newIdentityHashMap(); private ArrayListMultimap opMap = ArrayListMultimap.create(); public DrillOperatorTable(FunctionImplementationRegistry registry) { - operators = Lists.newArrayList(); - operators.addAll(inner.getOperatorList()); - registry.register(this); + operators.addAll(inner.getOperatorList()); + populateWrappedCalciteOperators(); } public void add(String name, SqlOperator op) { operators.add(op); - opMap.put(name, op); + opMap.put(name.toLowerCase(), op); } @Override - public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory category, SqlSyntax syntax, List operatorList) { - inner.lookupOperatorOverloads(opName, category, syntax, operatorList); - - if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { - List drillOps = opMap.get(opName.getSimple().toLowerCase()); - if (drillOps != null) { - 
operatorList.addAll(drillOps); + public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory category, + SqlSyntax syntax, List operatorList) { + final List calciteOperatorList = Lists.newArrayList(); + inner.lookupOperatorOverloads(opName, category, syntax, calciteOperatorList); + if(!calciteOperatorList.isEmpty()) { + for(SqlOperator calciteOperator : calciteOperatorList) { + if(calciteToWrapper.containsKey(calciteOperator)) { + operatorList.add(calciteToWrapper.get(calciteOperator)); + } else { + operatorList.add(calciteOperator); + } + } + } else { + // if no function is found, check in Drill UDFs + if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { + List drillOps = opMap.get(opName.getSimple().toLowerCase()); + if (drillOps != null && !drillOps.isEmpty()) { + operatorList.addAll(drillOps); + } } } } @@ -69,4 +91,46 @@ public List getOperatorList() { public List getSqlOperator(String name) { return opMap.get(name.toLowerCase()); } + + private void populateWrappedCalciteOperators() { + for(SqlOperator calciteOperator : inner.getOperatorList()) { + final SqlOperator wrapper; + if(calciteOperator instanceof SqlAggFunction) { + wrapper = new DrillCalciteSqlAggFunctionWrapper((SqlAggFunction) calciteOperator, + getFunctionList(calciteOperator.getName())); + } else if(calciteOperator instanceof SqlFunction) { + wrapper = new DrillCalciteSqlFunctionWrapper((SqlFunction) calciteOperator, + getFunctionList(calciteOperator.getName())); + } else { + final String drillOpName = FunctionCallFactory.replaceOpWithFuncName(calciteOperator.getName()); + final List drillFuncHolders = getFunctionList(drillOpName); + if(drillFuncHolders.isEmpty() || calciteOperator == SqlStdOperatorTable.UNARY_MINUS) { + continue; + } + + wrapper = new DrillCalciteSqlOperatorWrapper(calciteOperator, drillOpName, drillFuncHolders); + } + calciteToWrapper.put(calciteOperator, wrapper); + } + } + + private List getFunctionList(String name) { + final 
List functions = Lists.newArrayList(); + for(SqlOperator sqlOperator : opMap.get(name.toLowerCase())) { + if(sqlOperator instanceof DrillSqlOperator) { + final List list = ((DrillSqlOperator) sqlOperator).getFunctions(); + if(list != null) { + functions.addAll(list); + } + } + + if(sqlOperator instanceof DrillSqlAggOperator) { + final List list = ((DrillSqlAggOperator) sqlOperator).getFunctions(); + if(list != null) { + functions.addAll(list); + } + } + } + return functions; + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java index 213620162d2..81c744c2fce 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + *

* http://www.apache.org/licenses/LICENSE-2.0 - * + *

* Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,47 +17,36 @@ */ package org.apache.drill.exec.planner.sql; -import java.util.List; - import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; -import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; public class DrillSqlAggOperator extends SqlAggFunction { - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillSqlAggOperator.class); - - - public DrillSqlAggOperator(String name, int argCount) { - super(name, new SqlIdentifier(name, SqlParserPos.ZERO), SqlKind.OTHER_FUNCTION, DynamicReturnType.INSTANCE, null, new Checker(argCount), SqlFunctionCategory.USER_DEFINED_FUNCTION); - } - - @Override - public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { - return getAny(validator.getTypeFactory()); + // private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillSqlAggOperator.class); + private final List functions; + + public DrillSqlAggOperator(String name, List functions, int argCount) { + super(name, + new SqlIdentifier(name, SqlParserPos.ZERO), + SqlKind.OTHER_FUNCTION, + TypeInferenceUtils.getDrillSqlReturnTypeInference( + name, + functions), 
+ null, + Checker.getChecker(argCount, argCount), + SqlFunctionCategory.USER_DEFINED_FUNCTION); + this.functions = functions; } - private RelDataType getAny(RelDataTypeFactory factory){ - return factory.createSqlType(SqlTypeName.ANY); -// return new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory); + public List getFunctions() { + return functions; } - -// @Override -// public List getParameterTypes(RelDataTypeFactory typeFactory) { -// return ImmutableList.of(typeFactory.createSqlType(SqlTypeName.ANY)); -// } -// -// @Override -// public RelDataType getReturnType(RelDataTypeFactory typeFactory) { -// return getAny(typeFactory); -// } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java index 7b5a99dbc2e..0873c8df3ea 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + *

* http://www.apache.org/licenses/LICENSE-2.0 - * + *

* Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,69 +18,50 @@ package org.apache.drill.exec.planner.sql; -import com.google.common.base.Preconditions; -import org.apache.drill.common.types.TypeProtos.MajorType; -import org.apache.drill.common.types.TypeProtos.MinorType; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.sql.SqlCall; +import java.util.ArrayList; +import java.util.List; + import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; public class DrillSqlOperator extends SqlFunction { - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillSqlOperator.class); - - private static final MajorType NONE = MajorType.getDefaultInstance(); - private final MajorType returnType; + // static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillSqlOperator.class); private final boolean isDeterministic; + private final List functions; + /** + * This constructor exists for legacy reasons. + * + * It is because Drill cannot access DrillOperatorTable at the place where this constructor is being called. + * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. 
+ */ + @Deprecated public DrillSqlOperator(String name, int argCount, boolean isDeterministic) { - this(name, argCount, MajorType.getDefaultInstance(), isDeterministic); + this(name, new ArrayList(), argCount, argCount, isDeterministic); } - public DrillSqlOperator(String name, int argCount, MajorType returnType, boolean isDeterminisitic) { - super(new SqlIdentifier(name, SqlParserPos.ZERO), DynamicReturnType.INSTANCE, null, new Checker(argCount), null, SqlFunctionCategory.USER_DEFINED_FUNCTION); - this.returnType = Preconditions.checkNotNull(returnType); - this.isDeterministic = isDeterminisitic; + public DrillSqlOperator(String name, List functions, int argCountMin, int argCountMax, boolean isDeterministic) { + super(new SqlIdentifier(name, SqlParserPos.ZERO), + TypeInferenceUtils.getDrillSqlReturnTypeInference( + name, + functions), + null, + Checker.getChecker(argCountMin, argCountMax), + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + this.functions = functions; + this.isDeterministic = isDeterministic; } + @Override public boolean isDeterministic() { return isDeterministic; } - protected RelDataType getReturnDataType(final RelDataTypeFactory factory) { - if (MinorType.BIT.equals(returnType.getMinorType())) { - return factory.createSqlType(SqlTypeName.BOOLEAN); - } - return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); - } - - private RelDataType getNullableReturnDataType(final RelDataTypeFactory factory) { - return factory.createTypeWithNullability(getReturnDataType(factory), true); - } - - @Override - public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { - if (NONE.equals(returnType)) { - return validator.getTypeFactory().createSqlType(SqlTypeName.ANY); - } - /* - * We return a nullable output type both in validation phase and in - * Sql to Rel phase. 
We don't know the type of the output until runtime - * hence have to choose the least restrictive type to avoid any wrong - * results. - */ - return getNullableReturnDataType(validator.getTypeFactory()); - } - - @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - return getNullableReturnDataType(opBinding.getTypeFactory()); + public List getFunctions() { + return functions; } -} +} \ No newline at end of file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java new file mode 100644 index 00000000000..8914b1133d4 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java @@ -0,0 +1,649 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.drill.exec.planner.sql; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; + +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlCharStringLiteral; +import org.apache.calcite.sql.SqlDynamicParam; +import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; + +import org.apache.drill.common.expression.ExpressionPosition; +import org.apache.drill.common.expression.FunctionCall; +import org.apache.drill.common.expression.FunctionCallFactory; +import org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.expression.MajorTypeInLogicalExpression; +import org.apache.drill.common.exceptions.UserException; +import org.apache.drill.common.types.TypeProtos; +import org.apache.drill.common.types.Types; +import org.apache.drill.exec.expr.TypeHelper; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; +import org.apache.drill.exec.resolver.FunctionResolver; +import org.apache.drill.exec.resolver.FunctionResolverFactory; +import org.apache.drill.exec.resolver.TypeCastRules; + +import java.util.List; + +public class TypeInferenceUtils { + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TypeInferenceUtils.class); + + public static final TypeProtos.MajorType UNKNOWN_TYPE = TypeProtos.MajorType.getDefaultInstance(); + private static final ImmutableMap DRILL_TO_CALCITE_TYPE_MAPPING = ImmutableMap. 
builder() + .put(TypeProtos.MinorType.INT, SqlTypeName.INTEGER) + .put(TypeProtos.MinorType.BIGINT, SqlTypeName.BIGINT) + .put(TypeProtos.MinorType.FLOAT4, SqlTypeName.FLOAT) + .put(TypeProtos.MinorType.FLOAT8, SqlTypeName.DOUBLE) + .put(TypeProtos.MinorType.VARCHAR, SqlTypeName.VARCHAR) + .put(TypeProtos.MinorType.BIT, SqlTypeName.BOOLEAN) + .put(TypeProtos.MinorType.DATE, SqlTypeName.DATE) + .put(TypeProtos.MinorType.DECIMAL9, SqlTypeName.DECIMAL) + .put(TypeProtos.MinorType.DECIMAL18, SqlTypeName.DECIMAL) + .put(TypeProtos.MinorType.DECIMAL28SPARSE, SqlTypeName.DECIMAL) + .put(TypeProtos.MinorType.DECIMAL38SPARSE, SqlTypeName.DECIMAL) + .put(TypeProtos.MinorType.TIME, SqlTypeName.TIME) + .put(TypeProtos.MinorType.TIMESTAMP, SqlTypeName.TIMESTAMP) + .put(TypeProtos.MinorType.VARBINARY, SqlTypeName.VARBINARY) + .put(TypeProtos.MinorType.INTERVALYEAR, SqlTypeName.INTERVAL_YEAR_MONTH) + .put(TypeProtos.MinorType.INTERVALDAY, SqlTypeName.INTERVAL_DAY_TIME) + .put(TypeProtos.MinorType.MAP, SqlTypeName.MAP) + .put(TypeProtos.MinorType.LIST, SqlTypeName.ARRAY) + .put(TypeProtos.MinorType.LATE, SqlTypeName.ANY) + + // These are defined in the Drill type system but have been turned off for now + // .put(TypeProtos.MinorType.TINYINT, SqlTypeName.TINYINT) + // .put(TypeProtos.MinorType.SMALLINT, SqlTypeName.SMALLINT) + // Calcite types currently not supported by Drill, nor defined in the Drill type list: + // - CHAR, SYMBOL, MULTISET, DISTINCT, STRUCTURED, ROW, OTHER, CURSOR, COLUMN_LIST + .build(); + + private static final ImmutableMap CALCITE_TO_DRILL_MAPPING = ImmutableMap. 
builder() + .put(SqlTypeName.INTEGER, TypeProtos.MinorType.INT) + .put(SqlTypeName.BIGINT, TypeProtos.MinorType.BIGINT) + .put(SqlTypeName.FLOAT, TypeProtos.MinorType.FLOAT4) + .put(SqlTypeName.DOUBLE, TypeProtos.MinorType.FLOAT8) + .put(SqlTypeName.VARCHAR, TypeProtos.MinorType.VARCHAR) + .put(SqlTypeName.BOOLEAN, TypeProtos.MinorType.BIT) + .put(SqlTypeName.DATE, TypeProtos.MinorType.DATE) + .put(SqlTypeName.TIME, TypeProtos.MinorType.TIME) + .put(SqlTypeName.TIMESTAMP, TypeProtos.MinorType.TIMESTAMP) + .put(SqlTypeName.VARBINARY, TypeProtos.MinorType.VARBINARY) + .put(SqlTypeName.INTERVAL_YEAR_MONTH, TypeProtos.MinorType.INTERVALYEAR) + .put(SqlTypeName.INTERVAL_DAY_TIME, TypeProtos.MinorType.INTERVALDAY) + + // SqlTypeName.CHAR is the type for Literals in Calcite, Drill treats Literals as VARCHAR also + .put(SqlTypeName.CHAR, TypeProtos.MinorType.VARCHAR) + + // The following types are not added due to a variety of reasons: + // (1) Disabling decimal type + //.put(SqlTypeName.DECIMAL, TypeProtos.MinorType.DECIMAL9) + //.put(SqlTypeName.DECIMAL, TypeProtos.MinorType.DECIMAL18) + //.put(SqlTypeName.DECIMAL, TypeProtos.MinorType.DECIMAL28SPARSE) + //.put(SqlTypeName.DECIMAL, TypeProtos.MinorType.DECIMAL38SPARSE) + + // (2) These 2 types are defined in the Drill type system but have been turned off for now + // .put(SqlTypeName.TINYINT, TypeProtos.MinorType.TINYINT) + // .put(SqlTypeName.SMALLINT, TypeProtos.MinorType.SMALLINT) + + // (3) Calcite types currently not supported by Drill, nor defined in the Drill type list: + // - SYMBOL, MULTISET, DISTINCT, STRUCTURED, ROW, OTHER, CURSOR, COLUMN_LIST + // .put(SqlTypeName.MAP, TypeProtos.MinorType.MAP) + // .put(SqlTypeName.ARRAY, TypeProtos.MinorType.LIST) + .build(); + + private static final ImmutableMap funcNameToInference = ImmutableMap. 
builder() + .put("DATE_PART", DrillDatePartSqlReturnTypeInference.INSTANCE) + .put("SUM", DrillSumSqlReturnTypeInference.INSTANCE) + .put("COUNT", DrillCountSqlReturnTypeInference.INSTANCE) + .put("CONCAT", DrillConcatSqlReturnTypeInference.INSTANCE) + .put("LENGTH", DrillLengthSqlReturnTypeInference.INSTANCE) + .put("LPAD", DrillPadTrimSqlReturnTypeInference.INSTANCE) + .put("RPAD", DrillPadTrimSqlReturnTypeInference.INSTANCE) + .put("LTRIM", DrillPadTrimSqlReturnTypeInference.INSTANCE) + .put("RTRIM", DrillPadTrimSqlReturnTypeInference.INSTANCE) + .put("BTRIM", DrillPadTrimSqlReturnTypeInference.INSTANCE) + .put("TRIM", DrillPadTrimSqlReturnTypeInference.INSTANCE) + .put("CONVERT_TO", DrillConvertToSqlReturnTypeInference.INSTANCE) + .put("EXTRACT", DrillExtractSqlReturnTypeInference.INSTANCE) + .put("SQRT", DrillSqrtSqlReturnTypeInference.INSTANCE) + .put("CAST", DrillCastSqlReturnTypeInference.INSTANCE) + .put("FLATTEN", DrillDeferToExecSqlReturnTypeInference.INSTANCE) + .put("KVGEN", DrillDeferToExecSqlReturnTypeInference.INSTANCE) + .put("CONVERT_FROM", DrillDeferToExecSqlReturnTypeInference.INSTANCE) + .build(); + + /** + * Given a Drill's TypeProtos.MinorType, return a Calcite's corresponding SqlTypeName + */ + public static SqlTypeName getCalciteTypeFromDrillType(final TypeProtos.MinorType type) { + if(!DRILL_TO_CALCITE_TYPE_MAPPING.containsKey(type)) { + return SqlTypeName.ANY; + } + + return DRILL_TO_CALCITE_TYPE_MAPPING.get(type); + } + + /** + * Given a Calcite's RelDataType, return a Drill's corresponding TypeProtos.MinorType + */ + public static TypeProtos.MinorType getDrillTypeFromCalciteType(final RelDataType relDataType) { + final SqlTypeName sqlTypeName = relDataType.getSqlTypeName(); + return getDrillTypeFromCalciteType(sqlTypeName); + } + + /** + * Given a Calcite's SqlTypeName, return a Drill's corresponding TypeProtos.MinorType + */ + public static TypeProtos.MinorType getDrillTypeFromCalciteType(final SqlTypeName sqlTypeName) { + 
if(!CALCITE_TO_DRILL_MAPPING.containsKey(sqlTypeName)) { + return TypeProtos.MinorType.LATE; + } + + return CALCITE_TO_DRILL_MAPPING.get(sqlTypeName); + } + + /** + * Give the name and DrillFuncHolder list, return the inference mechanism. + */ + public static SqlReturnTypeInference getDrillSqlReturnTypeInference( + final String name, + final List functions) { + + final String nameCap = name.toUpperCase(); + if(funcNameToInference.containsKey(nameCap)) { + return funcNameToInference.get(nameCap); + } else { + return new DrillDefaultSqlReturnTypeInference(functions); + } + } + + private static class DrillDefaultSqlReturnTypeInference implements SqlReturnTypeInference { + private final List functions; + + public DrillDefaultSqlReturnTypeInference(List functions) { + this.functions = functions; + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + if (functions.isEmpty()) { + return factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.ANY), + true); + } + + // The following logic is just a safe play: + // Even if any of the input arguments has ANY type, + // it "might" still be possible to determine the return type based on other non-ANY types + for (RelDataType type : opBinding.collectOperandTypes()) { + if (getDrillTypeFromCalciteType(type) == TypeProtos.MinorType.LATE) { + // This code for boolean output type is added for addressing DRILL-1729 + // In summary, if we have a boolean output function in the WHERE-CLAUSE, + // this logic can validate and execute user queries seamlessly + boolean allBooleanOutput = true; + for (DrillFuncHolder function : functions) { + if (function.getReturnType().getMinorType() != TypeProtos.MinorType.BIT) { + allBooleanOutput = false; + break; + } + } + + if(allBooleanOutput) { + return factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.BOOLEAN), true); + } else { + return 
factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.ANY), + true); + } + } + } + + final DrillFuncHolder func = resolveDrillFuncHolder(opBinding, functions); + final RelDataType returnType = getReturnType(opBinding, func); + return returnType; + } + + private static RelDataType getReturnType(final SqlOperatorBinding opBinding, final DrillFuncHolder func) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + + // least restrictive type (nullable ANY type) + final RelDataType nullableAnyType = factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.ANY), + true); + + final TypeProtos.MajorType returnType = func.getReturnType(); + if (UNKNOWN_TYPE.equals(returnType)) { + return nullableAnyType; + } + + final TypeProtos.MinorType minorType = returnType.getMinorType(); + final SqlTypeName sqlTypeName = getCalciteTypeFromDrillType(minorType); + if (sqlTypeName == null) { + return nullableAnyType; + } + + final boolean isNullable; + switch (returnType.getMode()) { + case REPEATED: + case OPTIONAL: + isNullable = true; + break; + + case REQUIRED: + switch (func.getNullHandling()) { + case INTERNAL: + isNullable = false; + break; + + case NULL_IF_NULL: + boolean isNull = false; + for (int i = 0; i < opBinding.getOperandCount(); ++i) { + if (opBinding.getOperandType(i).isNullable()) { + isNull = true; + break; + } + } + + isNullable = isNull; + break; + default: + throw new UnsupportedOperationException(); + } + break; + + default: + throw new UnsupportedOperationException(); + } + + return createCalciteTypeWithNullability( + factory, + sqlTypeName, + isNullable); + } + } + + private static class DrillDeferToExecSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillDeferToExecSqlReturnTypeInference INSTANCE = new DrillDeferToExecSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = 
opBinding.getTypeFactory(); + return factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.ANY), + true); + } + } + + private static class DrillSumSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillSumSqlReturnTypeInference INSTANCE = new DrillSumSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + // If there is group-by and the input type is Non-nullable, + // the output is Non-nullable; + // Otherwise, the output is nullable. + final boolean isNullable = opBinding.getGroupCount() == 0 + || opBinding.getOperandType(0).isNullable(); + + if(getDrillTypeFromCalciteType(opBinding.getOperandType(0)) == TypeProtos.MinorType.LATE) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.ANY, + isNullable); + } + + final RelDataType operandType = opBinding.getOperandType(0); + final TypeProtos.MinorType inputMinorType = getDrillTypeFromCalciteType(operandType); + if(TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.BIGINT)) + == TypeProtos.MinorType.BIGINT) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.BIGINT, + isNullable); + } else if(TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.FLOAT8)) + == TypeProtos.MinorType.FLOAT8) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.DOUBLE, + isNullable); + } else { + throw UserException + .functionError() + .message(String.format("%s does not support operand types (%s)", + opBinding.getOperator().getName(), + opBinding.getOperandType(0).getSqlTypeName())) + .build(logger); + } + } + } + + private static class DrillCountSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillCountSqlReturnTypeInference INSTANCE = new DrillCountSqlReturnTypeInference(); + + @Override + public 
RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final SqlTypeName type = SqlTypeName.BIGINT; + return createCalciteTypeWithNullability( + factory, + type, + false); + } + } + + private static class DrillConcatSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillConcatSqlReturnTypeInference INSTANCE = new DrillConcatSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + + boolean isNullable = true; + int precision = 0; + for(RelDataType relDataType : opBinding.collectOperandTypes()) { + if(!relDataType.isNullable()) { + isNullable = false; + } + + // If the underlying columns cannot offer information regarding the precision (i.e., the length) of the VarChar, + // Drill uses the largest to represent it + if(relDataType.getPrecision() == TypeHelper.VARCHAR_DEFAULT_CAST_LEN + || relDataType.getPrecision() == RelDataType.PRECISION_NOT_SPECIFIED) { + precision = TypeHelper.VARCHAR_DEFAULT_CAST_LEN; + } else { + precision += relDataType.getPrecision(); + } + } + + return factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.VARCHAR, precision), + isNullable); + } + } + + private static class DrillLengthSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillLengthSqlReturnTypeInference INSTANCE = new DrillLengthSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final SqlTypeName sqlTypeName = SqlTypeName.BIGINT; + + // We need to check only the first argument because + // the second one is used to represent encoding type + final boolean isNullable = opBinding.getOperandType(0).isNullable(); + return createCalciteTypeWithNullability( + factory, + sqlTypeName, + isNullable); 
+ } + } + + private static class DrillPadTrimSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillPadTrimSqlReturnTypeInference INSTANCE = new DrillPadTrimSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final SqlTypeName sqlTypeName = SqlTypeName.VARCHAR; + + for(int i = 0; i < opBinding.getOperandCount(); ++i) { + if(opBinding.getOperandType(i).isNullable()) { + return createCalciteTypeWithNullability( + factory, sqlTypeName, true); + } + } + + return createCalciteTypeWithNullability( + factory, sqlTypeName, false); + } + } + + private static class DrillConvertToSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillConvertToSqlReturnTypeInference INSTANCE = new DrillConvertToSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final SqlTypeName type = SqlTypeName.VARBINARY; + + return createCalciteTypeWithNullability( + factory, type, opBinding.getOperandType(0).isNullable()); + } + } + + private static class DrillExtractSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillExtractSqlReturnTypeInference INSTANCE = new DrillExtractSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final TimeUnit timeUnit = opBinding.getOperandType(0).getIntervalQualifier().getStartUnit(); + final boolean isNullable = opBinding.getOperandType(1).isNullable(); + + final SqlTypeName sqlTypeName = getSqlTypeNameForTimeUnit(timeUnit.name()); + return createCalciteTypeWithNullability( + factory, + sqlTypeName, + isNullable); + } + } + + private static class DrillSqrtSqlReturnTypeInference implements 
SqlReturnTypeInference { + private static final DrillSqrtSqlReturnTypeInference INSTANCE = new DrillSqrtSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final boolean isNullable = opBinding.getOperandType(0).isNullable(); + return createCalciteTypeWithNullability( + factory, + SqlTypeName.DOUBLE, + isNullable); + } + } + + private static class DrillDatePartSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillDatePartSqlReturnTypeInference INSTANCE = new DrillDatePartSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + + final SqlNode firstOperand = ((SqlCallBinding) opBinding).operand(0); + if(!(firstOperand instanceof SqlCharStringLiteral)) { + return createCalciteTypeWithNullability(factory, + SqlTypeName.ANY, + opBinding.getOperandType(1).isNullable()); + } + + final String part = ((SqlCharStringLiteral) firstOperand) + .getNlsString() + .getValue() + .toUpperCase(); + + final SqlTypeName sqlTypeName = getSqlTypeNameForTimeUnit(part); + final boolean isNullable = opBinding.getOperandType(1).isNullable(); + return createCalciteTypeWithNullability( + factory, + sqlTypeName, + isNullable); + } + } + + private static class DrillCastSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillCastSqlReturnTypeInference INSTANCE = new DrillCastSqlReturnTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + final boolean isNullable = opBinding + .getOperandType(0) + .isNullable(); + + RelDataType ret = factory.createTypeWithNullability( + opBinding.getOperandType(1), + isNullable); + + if (opBinding instanceof SqlCallBinding) { + SqlCallBinding 
callBinding = (SqlCallBinding) opBinding; + SqlNode operand0 = callBinding.operand(0); + + // dynamic parameters and null constants need their types assigned + // to them using the type they are casted to. + if (((operand0 instanceof SqlLiteral) + && (((SqlLiteral) operand0).getValue() == null)) + || (operand0 instanceof SqlDynamicParam)) { + callBinding.getValidator().setValidatedNodeType( + operand0, + ret); + } + } + + return ret; + } + } + + private static DrillFuncHolder resolveDrillFuncHolder(final SqlOperatorBinding opBinding, final List functions) { + final FunctionCall functionCall = convertSqlOperatorBindingToFunctionCall(opBinding); + final FunctionResolver functionResolver = FunctionResolverFactory.getResolver(functionCall); + final DrillFuncHolder func = functionResolver.getBestMatch(functions, functionCall); + + // Throw an exception + // if no DrillFuncHolder matched for the given list of operand types + if(func == null) { + String operandTypes = ""; + for(int i = 0; i < opBinding.getOperandCount(); ++i) { + operandTypes += opBinding.getOperandType(i).getSqlTypeName(); + if(i < opBinding.getOperandCount() - 1) { + operandTypes += ","; + } + } + + throw UserException + .functionError() + .message(String.format("%s does not support operand types (%s)", + opBinding.getOperator().getName(), + operandTypes)) + .build(logger); + } + return func; + } + + /** + * For Extract and date_part functions, infer the return types based on timeUnit + */ + public static SqlTypeName getSqlTypeNameForTimeUnit(String timeUnit) { + switch (timeUnit.toUpperCase()){ + case "YEAR": + case "MONTH": + case "DAY": + case "HOUR": + case "MINUTE": + return SqlTypeName.BIGINT; + case "SECOND": + return SqlTypeName.DOUBLE; + default: + throw UserException + .functionError() + .message("extract function supports the following time units: YEAR, MONTH, DAY, HOUR, MINUTE, SECOND") + .build(logger); + } + } + + /** + * Given a {@link SqlTypeName} and nullability, create a RelDataType 
from the RelDataTypeFactory + * + * @param typeFactory RelDataTypeFactory used to create the RelDataType + * @param sqlTypeName the given SqlTypeName + * @param isNullable the nullability of the created RelDataType + * @return RelDataType Type of call + */ + public static RelDataType createCalciteTypeWithNullability(RelDataTypeFactory typeFactory, + SqlTypeName sqlTypeName, + boolean isNullable) { + RelDataType type; + if (sqlTypeName == SqlTypeName.INTERVAL_DAY_TIME) { + type = typeFactory.createSqlIntervalType( + new SqlIntervalQualifier( + TimeUnit.DAY, + TimeUnit.MINUTE, + SqlParserPos.ZERO)); + } else if (sqlTypeName == SqlTypeName.INTERVAL_YEAR_MONTH) { + type = typeFactory.createSqlIntervalType( + new SqlIntervalQualifier( + TimeUnit.YEAR, + TimeUnit.MONTH, + SqlParserPos.ZERO)); + } else if (sqlTypeName == SqlTypeName.VARCHAR) { + type = typeFactory.createSqlType(sqlTypeName, TypeHelper.VARCHAR_DEFAULT_CAST_LEN); + } else { + type = typeFactory.createSqlType(sqlTypeName); + } + return typeFactory.createTypeWithNullability(type, isNullable); + } + + /** + * Given a SqlOperatorBinding, convert it to FunctionCall + * @param opBinding the given SqlOperatorBinding + * @return FunctionCall the converted FunctionCall + */ + public static FunctionCall convertSqlOperatorBindingToFunctionCall(final SqlOperatorBinding opBinding) { + final List args = Lists.newArrayList(); + + for (int i = 0; i < opBinding.getOperandCount(); ++i) { + final RelDataType type = opBinding.getOperandType(i); + final TypeProtos.MinorType minorType = getDrillTypeFromCalciteType(type); + final TypeProtos.MajorType majorType; + if (type.isNullable()) { + majorType = Types.optional(minorType); + } else { + majorType = Types.required(minorType); + } + + args.add(new MajorTypeInLogicalExpression(majorType)); + } + + final String drillFuncName = FunctionCallFactory.replaceOpWithFuncName(opBinding.getOperator().getName()); + final FunctionCall functionCall = new FunctionCall( + drillFuncName, + 
args, + ExpressionPosition.UNKNOWN); + return functionCall; + } + + /** + * This class is not intended to be instantiated + */ + private TypeInferenceUtils() { + + } +} \ No newline at end of file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/CreateTableHandler.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/CreateTableHandler.java index 0ebb557a943..b6ffde67adf 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/CreateTableHandler.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/CreateTableHandler.java @@ -234,7 +234,7 @@ private RexNode createPartitionColComparator(final RexBuilder rexBuilder, List compFuncs) { final DrillSqlOperator booleanOrFunc - = new DrillSqlOperator("orNoShortCircuit", 2, MajorType.getDefaultInstance(), true); + = new DrillSqlOperator("orNoShortCircuit", 2, true); RexNode node = compFuncs.remove(0); while (!compFuncs.isEmpty()) { node = rexBuilder.makeCall(booleanOrFunc, node, compFuncs.remove(0)); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/DefaultSqlHandler.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/DefaultSqlHandler.java index 5152fa609dc..4ca9fe426fa 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/DefaultSqlHandler.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/handlers/DefaultSqlHandler.java @@ -234,14 +234,17 @@ protected DrillRel convertToDrel(final RelNode relNode) throws SqlUnsupportedExc convertedRelNode = transform(PlannerType.HEP_BOTTOM_UP, PlannerPhase.JOIN_PLANNING, intermediateNode2); } - final DrillRel drillRel = (DrillRel) convertedRelNode; + // Convert SUM to $SUM0 + final RelNode convertedRelNodeWithSum0 = transform(PlannerType.HEP_BOTTOM_UP, PlannerPhase.SUM_CONVERSION, convertedRelNode); + + final DrillRel drillRel = (DrillRel) 
convertedRelNodeWithSum0; if (drillRel instanceof DrillStoreRel) { throw new UnsupportedOperationException(); } else { // If the query contains a limit 0 clause, disable distributed mode since it is overkill for determining schema. - if (FindLimit0Visitor.containsLimit0(convertedRelNode)) { + if (FindLimit0Visitor.containsLimit0(convertedRelNodeWithSum0)) { context.getPlannerSettings().forceSingleMode(); } @@ -613,7 +616,8 @@ private RelNode preprocessNode(RelNode rel) throws SqlUnsupportedException { */ PreProcessLogicalRel visitor = PreProcessLogicalRel.createVisitor(config.getConverter().getTypeFactory(), - context.getDrillOperatorTable()); + context.getDrillOperatorTable(), + rel.getCluster().getRexBuilder()); try { rel = rel.accept(visitor); } catch (UnsupportedOperationException ex) { diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java index e90aa3b5684..528cadc4727 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java @@ -25,6 +25,7 @@ import org.apache.drill.exec.exception.UnsupportedOperatorCollector; import org.apache.drill.exec.ops.QueryContext; import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; import org.apache.drill.exec.work.foreman.SqlUnsupportedException; import org.apache.calcite.sql.SqlSelectKeyword; @@ -351,7 +352,7 @@ public SqlNode visit(SqlCall sqlCall) { } } - if(sqlCall.getOperator() instanceof SqlCountAggFunction) { + if(extractSqlOperatorFromWrapper(sqlCall.getOperator()) instanceof SqlCountAggFunction) { for(SqlNode sqlNode : sqlCall.getOperandList()) { if(containsFlatten(sqlNode)) { 
unsupportedOperatorCollector.setException(SqlUnsupportedException.ExceptionType.FUNCTION, @@ -415,7 +416,7 @@ private interface SqlNodeCondition { @Override public boolean test(SqlNode sqlNode) { if (sqlNode instanceof SqlCall) { - final SqlOperator operator = ((SqlCall) sqlNode).getOperator(); + final SqlOperator operator = extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); if (operator == SqlStdOperatorTable.ROLLUP || operator == SqlStdOperatorTable.CUBE || operator == SqlStdOperatorTable.GROUPING_SETS) { @@ -433,10 +434,10 @@ public boolean test(SqlNode sqlNode) { @Override public boolean test(SqlNode sqlNode) { if (sqlNode instanceof SqlCall) { - final SqlOperator operator = ((SqlCall) sqlNode).getOperator(); - if (operator == SqlStdOperatorTable.GROUPING - || operator == SqlStdOperatorTable.GROUPING_ID - || operator == SqlStdOperatorTable.GROUP_ID) { + final SqlOperator operator = extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); + if (operator == SqlStdOperatorTable.GROUPING + || operator == SqlStdOperatorTable.GROUPING_ID + || operator == SqlStdOperatorTable.GROUP_ID) { return true; } } @@ -554,4 +555,12 @@ private void detectMultiplePartitions(SqlSelect sqlSelect) { } } } + + private SqlOperator extractSqlOperatorFromWrapper(SqlOperator sqlOperator) { + if(sqlOperator instanceof DrillCalciteSqlWrapper) { + return ((DrillCalciteSqlWrapper) sqlOperator).getOperator(); + } else { + return sqlOperator; + } + } } \ No newline at end of file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/DefaultFunctionResolver.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/DefaultFunctionResolver.java index 69042720bdf..bf125d7c884 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/DefaultFunctionResolver.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/DefaultFunctionResolver.java @@ -21,7 +21,10 @@ import java.util.LinkedList; import java.util.List; +import 
com.google.common.collect.Lists; import org.apache.drill.common.expression.FunctionCall; +import org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.exec.expr.fn.DrillFuncHolder; import org.apache.drill.exec.util.AssertionUtil; @@ -30,7 +33,7 @@ public class DefaultFunctionResolver implements FunctionResolver { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DefaultFunctionResolver.class); @Override - public DrillFuncHolder getBestMatch(List methods,FunctionCall call) { + public DrillFuncHolder getBestMatch(List methods, FunctionCall call) { int bestcost = Integer.MAX_VALUE; int currcost = Integer.MAX_VALUE; @@ -38,8 +41,11 @@ public DrillFuncHolder getBestMatch(List methods,FunctionCall c final List bestMatchAlternatives = new LinkedList<>(); for (DrillFuncHolder h : methods) { - - currcost = TypeCastRules.getCost(call, h); + final List argumentTypes = Lists.newArrayList(); + for (LogicalExpression expression : call.args) { + argumentTypes.add(expression.getMajorType()); + } + currcost = TypeCastRules.getCost(argumentTypes, h); // if cost is lower than 0, func implementation is not matched, either w/ or w/o implicit casts if (currcost < 0 ) { @@ -79,5 +85,4 @@ public DrillFuncHolder getBestMatch(List methods,FunctionCall c return bestmatch; } } - } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/ExactFunctionResolver.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/ExactFunctionResolver.java index ab6be331435..72e27efb4d9 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/ExactFunctionResolver.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/ExactFunctionResolver.java @@ -17,7 +17,10 @@ */ package org.apache.drill.exec.resolver; +import com.google.common.collect.Lists; import org.apache.drill.common.expression.FunctionCall; +import 
org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.exec.expr.fn.DrillFuncHolder; import java.util.List; @@ -37,8 +40,11 @@ public DrillFuncHolder getBestMatch(List methods, FunctionCall int currcost; for (DrillFuncHolder h : methods) { - - currcost = TypeCastRules.getCost(call, h); + final List argumentTypes = Lists.newArrayList(); + for (LogicalExpression expression : call.args) { + argumentTypes.add(expression.getMajorType()); + } + currcost = TypeCastRules.getCost(argumentTypes, h); // Return if we found a function that has an exact match with the input arguments if (currcost == 0){ diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolver.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolver.java index 14d46c9d88c..e2a9622d6b7 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolver.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolver.java @@ -23,8 +23,19 @@ import org.apache.drill.common.expression.FunctionCall; import org.apache.drill.exec.expr.fn.DrillFuncHolder; +/** + * An implementing class of FunctionResolver provide their own algorithm to choose a DrillFuncHolder from a given list of + * candidates, with respect to a given FunctionCall + */ public interface FunctionResolver { - - public DrillFuncHolder getBestMatch(List methods, FunctionCall call); - + /** + * Creates a placeholder SqlFunction for an invocation of a function with a + * possibly qualified name. This name must be resolved into either a builtin + * function or a user-defined function. 
+ * + * @param methods a list of candidates of DrillFuncHolder to be chosen from + * @param call a given function call whose DrillFuncHolder is to be determined via this method + * @return DrillFuncHolder the chosen DrillFuncHolder + */ + DrillFuncHolder getBestMatch(List methods, FunctionCall call); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolverFactory.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolverFactory.java index b9070afb4bd..ab6934bbfd3 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolverFactory.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/FunctionResolverFactory.java @@ -21,7 +21,6 @@ import org.apache.drill.common.expression.FunctionCall; public class FunctionResolverFactory { - public static FunctionResolver getResolver(FunctionCall call) { return new DefaultFunctionResolver(); } @@ -29,6 +28,4 @@ public static FunctionResolver getResolver(FunctionCall call) { public static FunctionResolver getExactResolver(FunctionCall call) { return new ExactFunctionResolver(); } - - } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java index 7ee8ebed33a..ae429379579 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/resolver/TypeCastRules.java @@ -26,7 +26,8 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import org.apache.drill.common.expression.FunctionCall; +import org.apache.drill.common.expression.MajorTypeInLogicalExpression; +import org.apache.drill.common.expression.LogicalExpression; import org.apache.drill.common.types.TypeProtos.DataMode; import org.apache.drill.common.types.TypeProtos.MajorType; import org.apache.drill.common.types.TypeProtos.MinorType; @@ -834,10 
+835,10 @@ public static MinorType getLeastRestrictiveType(List types) { * implicit cast > 0: cost associated with implicit cast. ==0: parms are * exactly same type of arg. No need of implicit. */ - public static int getCost(FunctionCall call, DrillFuncHolder holder) { + public static int getCost(List argumentTypes, DrillFuncHolder holder) { int cost = 0; - if (call.args.size() != holder.getParamCount()) { + if (argumentTypes.size() != holder.getParamCount()) { return -1; } @@ -852,14 +853,21 @@ public static int getCost(FunctionCall call, DrillFuncHolder holder) { * the function can fit the precision that we need based on the input types. */ if (holder.checkPrecisionRange() == true) { - if (DecimalUtility.getMaxPrecision(holder.getReturnType().getMinorType()) < holder.getReturnType(call.args).getPrecision()) { + List logicalExpressions = Lists.newArrayList(); + for(MajorType majorType : argumentTypes) { + logicalExpressions.add( + new MajorTypeInLogicalExpression(majorType)); + } + + if (DecimalUtility.getMaxPrecision(holder.getReturnType().getMinorType()) < + holder.getReturnType(logicalExpressions).getPrecision()) { return -1; } } final int numOfArgs = holder.getParamCount(); for (int i = 0; i < numOfArgs; i++) { - final MajorType argType = call.args.get(i).getMajorType(); + final MajorType argType = argumentTypes.get(i); final MajorType parmType = holder.getParmMajorType(i); //@Param FieldReader will match any type diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java b/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java index fdf39c18111..5f6cd9cce97 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java @@ -28,16 +28,6 @@ public class TestDisabledFunctionality extends BaseTestQuery{ static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TestExampleQueries.class); - 
@Test(expected = UserException.class) // see DRILL-2054 - public void testBooleanORExpression() throws Exception { - test("select (1 = 1) || (1 > 0) from cp.`tpch/nation.parquet` "); - } - - @Test(expected = UserException.class) // see DRILL-2054 - public void testBooleanORSelectClause() throws Exception { - test("select true || true from cp.`tpch/nation.parquet` "); - } - @Test(expected = UserException.class) // see DRILL-2054 public void testBooleanORWhereClause() throws Exception { test("select * from cp.`tpch/nation.parquet` where (true || true) "); diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java index 7d3f6d07bd2..81d093c88b4 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java @@ -18,16 +18,13 @@ package org.apache.drill; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import org.apache.commons.lang3.tuple.Pair; import org.apache.drill.common.expression.SchemaPath; import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.util.FileUtils; -import org.junit.Ignore; import org.junit.Test; import java.util.List; -import java.util.Map; public class TestFunctionsWithTypeExpoQueries extends BaseTestQuery { @Test @@ -50,10 +47,45 @@ public void testConcatWithMoreThanTwoArgs() throws Exception { } @Test - public void testTrimOnlyOneArg() throws Exception { - final String query1 = "SELECT ltrim('drill') as col FROM (VALUES(1)) limit 0"; - final String query2 = "SELECT rtrim('drill') as col FROM (VALUES(1)) limit 0"; - final String query3 = "SELECT btrim('drill') as col FROM (VALUES(1)) limit 0"; + public void testRow_NumberInView() throws Exception { + try { + test("use dfs_test.tmp;"); + final String view1 = + "create view 
TestFunctionsWithTypeExpoQueries_testViewShield1 as \n" + + "select rnum, position_id, " + + " ntile(4) over(order by position_id) " + + " from (select position_id, row_number() " + + " over(order by position_id) as rnum " + + " from cp.`employee.json`)"; + + + final String view2 = + "create view TestFunctionsWithTypeExpoQueries_testViewShield2 as \n" + + "select row_number() over(order by position_id) as rnum, " + + " position_id, " + + " ntile(4) over(order by position_id) " + + " from cp.`employee.json`"; + + test(view1); + test(view2); + + testBuilder() + .sqlQuery("select * from TestFunctionsWithTypeExpoQueries_testViewShield1") + .ordered() + .sqlBaselineQuery("select * from TestFunctionsWithTypeExpoQueries_testViewShield2") + .build() + .run(); + } finally { + test("drop view TestFunctionsWithTypeExpoQueries_testViewShield1;"); + test("drop view TestFunctionsWithTypeExpoQueries_testViewShield2;"); + } + } + + @Test + public void testLRBTrimOneArg() throws Exception { + final String query1 = "SELECT ltrim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query2 = "SELECT rtrim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query3 = "SELECT btrim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; List> expectedSchema = Lists.newArrayList(); TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() @@ -82,18 +114,64 @@ public void testTrimOnlyOneArg() throws Exception { } @Test - public void testExtract() throws Exception { - final String query = "select extract(second from time '02:30:45.100') as col \n" + - "from cp.`employee.json` limit 0"; + public void testTrimOneArg() throws Exception { + final String query1 = "SELECT trim(leading 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query2 = "SELECT trim(trailing 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query3 = "SELECT trim(both 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; + List> 
expectedSchema = Lists.newArrayList(); TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() - .setMinorType(TypeProtos.MinorType.FLOAT8) - .setMode(TypeProtos.DataMode.OPTIONAL) + .setMinorType(TypeProtos.MinorType.VARCHAR) + .setMode(TypeProtos.DataMode.REQUIRED) .build(); expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); testBuilder() - .sqlQuery(query) + .sqlQuery(query1) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query2) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query3) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testTrimTwoArg() throws Exception { + final String query1 = "SELECT trim(leading ' ' from 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query2 = "SELECT trim(trailing ' ' from 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query3 = "SELECT trim(both ' ' from 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.VARCHAR) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query1) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query2) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query3) .schemaBaseLine(expectedSchema) .build() .run(); @@ -105,7 +183,7 @@ public void tesIsNull() throws Exception { List> expectedSchema = Lists.newArrayList(); TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() .setMinorType(TypeProtos.MinorType.BIT) - .setMode(TypeProtos.DataMode.OPTIONAL) + .setMode(TypeProtos.DataMode.REQUIRED) .build(); expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); @@ 
-115,4 +193,179 @@ public void tesIsNull() throws Exception { .build() .run(); } + + /** + * In the following query, the extract function would be borrowed from Calcite, + * which asserts the return type as be BIG-INT + */ + @Test + public void testExtractSecond() throws Exception { + String query = "select extract(second from time '02:30:45.100') as col \n" + + "from cp.`tpch/region.parquet` \n" + + "limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testMetaDataExposeType() throws Exception { + final String root = FileUtils.getResourceAsFile("/typeExposure/metadata_caching").toURI().toString(); + final String query = String.format("select count(*) as col \n" + + "from dfs_test.`%s` \n" + + "where concat(a, 'asdf') = 'asdf'", root); + + // Validate the plan + final String[] expectedPlan = {"Scan.*a.parquet.*numFiles=1"}; + final String[] excludedPlan = {"Filter"}; + PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPlan); + + // Validate the result + testBuilder() + .sqlQuery(query) + .ordered() + .baselineColumns("col") + .baselineValues(1l) + .build() + .run(); + } + + @Test + public void testDate_Part() throws Exception { + final String query = "select date_part('year', date '2008-2-23') as col \n" + + "from cp.`tpch/region.parquet` \n" + + "limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + 
.schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testNegativeByInterpreter() throws Exception { + final String query = "select * from cp.`tpch/region.parquet` \n" + + "where r_regionkey = negative(-1)"; + + // Validate the plan + final String[] expectedPlan = {"Filter.*condition=\\[=\\(.*, 1\\)\\]\\)"}; + final String[] excludedPlan = {}; + PlanTestBase.testPlanMatchingPatterns(query, expectedPlan, excludedPlan); + } + + @Test + public void testSumRequiredType() throws Exception { + final String query = "SELECT \n" + + "SUM(CASE WHEN (CAST(n_regionkey AS INT) = 1) THEN 1 ELSE 0 END) AS col \n" + + "FROM cp.`tpch/nation.parquet` \n" + + "GROUP BY CAST(n_regionkey AS INT) \n" + + "limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testSQRT() throws Exception { + final String query = "SELECT sqrt(5.1) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testTimestampDiff() throws Exception { + final String query = "select to_timestamp('2014-02-13 00:30:30','YYYY-MM-dd HH:mm:ss') - to_timestamp('2014-02-13 00:30:30','YYYY-MM-dd HH:mm:ss') as col \n" + + "from cp.`tpch/region.parquet` \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + 
final TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.INTERVALDAY) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testAvgAndSUM() throws Exception { + final String query = "SELECT AVG(cast(r_regionkey as float)) AS `col1`, \n" + + "SUM(cast(r_regionkey as float)) AS `col2`, \n" + + "SUM(1) AS `col3` \n" + + "FROM cp.`tpch/region.parquet` \n" + + "GROUP BY CAST(r_regionkey AS INTEGER) \n" + + "LIMIT 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } } diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java index e1ef7c9da62..0ff789aab17 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java @@ 
-56,4 +56,89 @@ public void testSubstr() throws Exception { .build() .run(); } + + @Test + public void testLpadTwoArgConvergeToLpad() throws Exception { + final String query_1 = "SELECT lpad(r_name, 25) \n" + + "FROM cp.`tpch/region.parquet`"; + + + final String query_2 = "SELECT lpad(r_name, 25, ' ') \n" + + "FROM cp.`tpch/region.parquet`"; + + testBuilder() + .sqlQuery(query_1) + .unOrdered() + .sqlBaselineQuery(query_2) + .build() + .run(); + } + + @Test + public void testRpadTwoArgConvergeToRpad() throws Exception { + final String query_1 = "SELECT rpad(r_name, 25) \n" + + "FROM cp.`tpch/region.parquet`"; + + + final String query_2 = "SELECT rpad(r_name, 25, ' ') \n" + + "FROM cp.`tpch/region.parquet`"; + + testBuilder() + .sqlQuery(query_1) + .unOrdered() + .sqlBaselineQuery(query_2) + .build() + .run(); + } + + @Test + public void testLtrimOneArgConvergeToLtrim() throws Exception { + final String query_1 = "SELECT ltrim(concat(' ', r_name, ' ')) \n" + + "FROM cp.`tpch/region.parquet`"; + + + final String query_2 = "SELECT ltrim(concat(' ', r_name, ' '), ' ') \n" + + "FROM cp.`tpch/region.parquet`"; + + testBuilder() + .sqlQuery(query_1) + .unOrdered() + .sqlBaselineQuery(query_2) + .build() + .run(); + } + + @Test + public void testRtrimOneArgConvergeToRtrim() throws Exception { + final String query_1 = "SELECT rtrim(concat(' ', r_name, ' ')) \n" + + "FROM cp.`tpch/region.parquet`"; + + + final String query_2 = "SELECT rtrim(concat(' ', r_name, ' '), ' ') \n" + + "FROM cp.`tpch/region.parquet`"; + + testBuilder() + .sqlQuery(query_1) + .unOrdered() + .sqlBaselineQuery(query_2) + .build() + .run(); + } + + @Test + public void testBtrimOneArgConvergeToBtrim() throws Exception { + final String query_1 = "SELECT btrim(concat(' ', r_name, ' ')) \n" + + "FROM cp.`tpch/region.parquet`"; + + + final String query_2 = "SELECT btrim(concat(' ', r_name, ' '), ' ') \n" + + "FROM cp.`tpch/region.parquet`"; + + testBuilder() + .sqlQuery(query_1) + .unOrdered() + 
.sqlBaselineQuery(query_2) + .build() + .run(); + } } diff --git a/exec/java-exec/src/test/resources/testframework/testFunctionsWithTypeExpoQueries/testConcatWithMoreThanTwoArgs.tsv b/exec/java-exec/src/test/resources/testframework/testFunctionsWithTypeExpoQueries/testConcatWithMoreThanTwoArgs.tsv new file mode 100644 index 00000000000..887c45f94cf --- /dev/null +++ b/exec/java-exec/src/test/resources/testframework/testFunctionsWithTypeExpoQueries/testConcatWithMoreThanTwoArgs.tsv @@ -0,0 +1,5 @@ +AFRICAAFRICAAFRICA +AMERICAAMERICAAMERICA +ASIAASIAASIA +EUROPEEUROPEEUROPE +MIDDLE EASTMIDDLE EASTMIDDLE EAST \ No newline at end of file diff --git a/exec/java-exec/src/test/resources/typeExposure/metadata_caching/a.parquet b/exec/java-exec/src/test/resources/typeExposure/metadata_caching/a.parquet new file mode 100644 index 0000000000000000000000000000000000000000..bdbba9e830cb77b776428eb846ece9d32189eda3 GIT binary patch literal 439 zcmb7Oc2cGn0b2chK$L$L@^KfGA1F+c#cTG9RQeOdPDDGQvNB#)d68f z;+{UlRXMKS@28(0#b_~o(e`!MY?CL8KAg&WIPZTh0XgO}631Pty-W1o*tfQSjOe^zPNFj2e&Bn+=NG&%_U0kBC-BBAf2I3^_nSiWP8h>sCp6jGn0H=y@t9 zp)K_$oyF%f^9SBv-mgjmNJ{vOz+hs6(*&7ay*VibVXh?%dKICd{b~4cEADiHdL2CU zO9I*A0r{@m69mY30+Yy0z}ea0C}b&{WjG~cRskHd5Fn63`bcet7XzYf`zvX`liQH| zgF-(d3c+;rV*KmD?eixO-b0OL-}`@5F}{@Qn7o8)^+UzqoZmIwGcV_QUgouL+BWwE ppV~vy)?8UtR7N^eR8}ch8f7 T accept(ExprVisitor visitor, V value) throws E { + throw new UnsupportedOperationException(); + } + + @Override + public ExpressionPosition getPosition() { + throw new UnsupportedOperationException(); + } + + public int getSelfCost() { + throw new UnsupportedOperationException(); + } + + @Override + public int getCumulativeCost() { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() { + throw new UnsupportedOperationException(); + } +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index 1379cfe021d..4dfa682b99e 100644 --- a/pom.xml +++ b/pom.xml @@ -1279,7 +1279,7 @@ org.apache.calcite 
calcite-core - 1.4.0-drill-r10 + 1.4.0-drill-test-r16 org.jgrapht From 9ecf4a484e2cc03f73aacd1b4f3801bb1909b71f Mon Sep 17 00:00:00 2001 From: Hsuan-Yi Chu Date: Thu, 3 Mar 2016 20:14:59 -0800 Subject: [PATCH 2/5] DRILL-4372: (continued) Type inference for HiveUDFs --- .../exec/expr/fn/HiveFunctionRegistry.java | 58 ++++++++++++++++++- .../exec/planner/sql/HiveUDFOperator.java | 28 ++------- .../exec/fn/hive/TestInbuiltHiveUDFs.java | 28 +++++++++ 3 files changed, 89 insertions(+), 25 deletions(-) diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java index 728954d8068..9a4e2101b62 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java @@ -18,18 +18,32 @@ package org.apache.drill.exec.expr.fn; import java.util.HashSet; +import java.util.List; import java.util.Set; +import java.util.Collection; +import com.google.common.collect.Lists; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.drill.common.config.DrillConfig; +import org.apache.drill.common.exceptions.UserException; +import org.apache.drill.common.expression.ExpressionPosition; import org.apache.drill.common.expression.FunctionCall; +import org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.expression.MajorTypeInLogicalExpression; import org.apache.drill.common.scanner.ClassPathScanner; import org.apache.drill.common.scanner.persistence.ScanResult; +import org.apache.drill.common.types.TypeProtos; import 
org.apache.drill.common.types.TypeProtos.MajorType; import org.apache.drill.common.types.TypeProtos.MinorType; import org.apache.drill.common.types.Types; import org.apache.drill.exec.expr.fn.impl.hive.ObjectInspectorHelper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.HiveUDFOperator; +import org.apache.drill.exec.planner.sql.TypeInferenceUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.udf.UDFType; @@ -70,7 +84,7 @@ public HiveFunctionRegistry(DrillConfig config) { @Override public void register(DrillOperatorTable operatorTable) { for (String name : Sets.union(methodsGenericUDF.asMap().keySet(), methodsUDF.asMap().keySet())) { - operatorTable.add(name, new HiveUDFOperator(name.toUpperCase())); + operatorTable.add(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); } } @@ -204,4 +218,46 @@ private HiveFuncHolder matchAndCreateUDFHolder(String udfName, return null; } + public class HiveSqlReturnTypeInference implements SqlReturnTypeInference { + private HiveSqlReturnTypeInference() { + + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + for (RelDataType type : opBinding.collectOperandTypes()) { + final TypeProtos.MinorType minorType = TypeInferenceUtils.getDrillTypeFromCalciteType(type); + if(minorType == TypeProtos.MinorType.LATE) { + return opBinding.getTypeFactory() + .createTypeWithNullability( + opBinding.getTypeFactory().createSqlType(SqlTypeName.ANY), + true); + } + } + + final FunctionCall functionCall = TypeInferenceUtils.convertSqlOperatorBindingToFunctionCall(opBinding); + final HiveFuncHolder hiveFuncHolder = getFunction(functionCall); + if(hiveFuncHolder == null) { + String operandTypes = ""; + for(int j = 0; j < opBinding.getOperandCount(); ++j) { + operandTypes += opBinding.getOperandType(j).getSqlTypeName(); + if(j < opBinding.getOperandCount() 
- 1) { + operandTypes += ","; + } + } + + throw UserException + .functionError() + .message(String.format("%s does not support operand types (%s)", + opBinding.getOperator().getName(), + operandTypes)) + .build(logger); + } + + return TypeInferenceUtils.createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + TypeInferenceUtils.getCalciteTypeFromDrillType(hiveFuncHolder.getReturnType().getMinorType()), + hiveFuncHolder.getReturnType().getMode() != TypeProtos.DataMode.REQUIRED); + } + } } diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java index a9647bd0ce7..90c4135e33e 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java @@ -18,28 +18,20 @@ package org.apache.drill.exec.planner.sql; -import com.fasterxml.jackson.databind.type.TypeFactory; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.calcite.sql.type.SqlReturnTypeInference; public class HiveUDFOperator extends SqlFunction { - - public 
HiveUDFOperator(String name) { - super(new SqlIdentifier(name, SqlParserPos.ZERO), DynamicReturnType.INSTANCE, null, new ArgChecker(), null, + public HiveUDFOperator(String name, SqlReturnTypeInference sqlReturnTypeInference) { + super(new SqlIdentifier(name, SqlParserPos.ZERO), sqlReturnTypeInference, null, new ArgChecker(), null, SqlFunctionCategory.USER_DEFINED_FUNCTION); } @@ -51,19 +43,7 @@ public boolean isDeterministic() { return false; } - @Override - public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { - RelDataTypeFactory factory = validator.getTypeFactory(); - return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); - } - - @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - RelDataTypeFactory factory = opBinding.getTypeFactory(); - return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); - } - - /** Argument Checker for variable number of arguments */ + /** Argument Checker for variable number of arguments */ public static class ArgChecker implements SqlOperandTypeChecker { public static ArgChecker INSTANCE = new ArgChecker(); diff --git a/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java b/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java index aba7573764c..14390626b24 100644 --- a/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java +++ b/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java @@ -43,4 +43,32 @@ public void testEncode() throws Exception { .baselineValues(new Object[] { null }) .go(); } + + @Test + public void testReflect() throws Exception { + final String query = "select reflect('java.lang.Math', 'round', cast(2 as float)) as col \n" + + "from hive.kv \n" + + "limit 1"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + 
.baselineColumns("col") + .baselineValues("2") + .go(); + } + + @Test + public void testAbs() throws Exception { + final String query = "select reflect('java.lang.Math', 'abs', cast(-2 as double)) as col \n" + + "from hive.kv \n" + + "limit 1"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("col") + .baselineValues("2.0") + .go(); + } } From c9f8621d228cca803f967ae91c277f74c6e8e748 Mon Sep 17 00:00:00 2001 From: Hsuan-Yi Chu Date: Tue, 8 Mar 2016 17:57:36 -0800 Subject: [PATCH 3/5] DRILL-4372: (continued) Add option to disable/enable function output type inference --- .../exec/expr/fn/HiveFunctionRegistry.java | 4 +- .../planner/sql/HiveUDFOperatorNotInfer.java | 44 ++++ .../exec/expr/fn/DrillFunctionRegistry.java | 89 +++++++- .../apache/drill/exec/ops/QueryContext.java | 2 +- .../logical/DrillReduceAggregatesRule.java | 211 ++++++++---------- .../planner/physical/PlannerSettings.java | 7 + .../sql/DrillAvgVarianceConvertlet.java | 14 +- .../exec/planner/sql/DrillOperatorTable.java | 86 +++++-- .../exec/planner/sql/DrillSqlAggOperator.java | 56 ++++- .../sql/DrillSqlAggOperatorNotInfer.java | 43 ++++ .../exec/planner/sql/DrillSqlOperator.java | 99 +++++++- .../planner/sql/DrillSqlOperatorNotInfer.java | 76 +++++++ .../drill/exec/planner/sql/SqlConverter.java | 70 +++++- .../exec/planner/sql/TypeInferenceUtils.java | 13 +- .../server/options/SystemOptionManager.java | 1 + .../TestFunctionsWithTypeExpoQueries.java | 188 +++++++++++++++- 16 files changed, 837 insertions(+), 166 deletions(-) create mode 100644 contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java 
b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java index 9a4e2101b62..52bd05b4055 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java @@ -43,6 +43,7 @@ import org.apache.drill.exec.expr.fn.impl.hive.ObjectInspectorHelper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.HiveUDFOperator; +import org.apache.drill.exec.planner.sql.HiveUDFOperatorNotInfer; import org.apache.drill.exec.planner.sql.TypeInferenceUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; @@ -84,7 +85,8 @@ public HiveFunctionRegistry(DrillConfig config) { @Override public void register(DrillOperatorTable operatorTable) { for (String name : Sets.union(methodsGenericUDF.asMap().keySet(), methodsUDF.asMap().keySet())) { - operatorTable.add(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); + operatorTable.addDefault(name, new HiveUDFOperatorNotInfer(name.toUpperCase())); + operatorTable.addInference(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); } } diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java new file mode 100644 index 00000000000..0c718f61f0b --- /dev/null +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; + +public class HiveUDFOperatorNotInfer extends HiveUDFOperator { + public HiveUDFOperatorNotInfer(String name) { + super(name, DynamicReturnType.INSTANCE); + } + + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + RelDataTypeFactory factory = validator.getTypeFactory(); + return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + RelDataTypeFactory factory = opBinding.getTypeFactory(); + return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java index 76ec90dde5d..f6bc666f8c1 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java +++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java @@ -23,10 +23,13 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.tuple.Pair; import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor; @@ -35,9 +38,11 @@ import org.apache.drill.exec.planner.logical.DrillConstExecutor; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.DrillSqlAggOperator; +import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorNotInfer; import org.apache.drill.exec.planner.sql.DrillSqlOperator; import com.google.common.collect.ArrayListMultimap; +import org.apache.drill.exec.planner.sql.DrillSqlOperatorNotInfer; /** * Registry of Drill functions. 
@@ -122,6 +127,13 @@ public List getMethods(String name) { } public void register(DrillOperatorTable operatorTable) { + registerForInference(operatorTable); + registerForDefault(operatorTable); + } + + public void registerForInference(DrillOperatorTable operatorTable) { + final Map map = Maps.newHashMap(); + final Map mapAgg = Maps.newHashMap(); for (Entry> function : registeredFunctions.asMap().entrySet()) { final ArrayListMultimap, DrillFuncHolder> functions = ArrayListMultimap.create(); final ArrayListMultimap aggregateFunctions = ArrayListMultimap.create(); @@ -146,20 +158,79 @@ public void register(DrillOperatorTable operatorTable) { } } for (Entry, Collection> entry : functions.asMap().entrySet()) { - final DrillSqlOperator drillSqlOperator; final Pair range = entry.getKey(); final int max = range.getRight(); final int min = range.getLeft(); - drillSqlOperator = new DrillSqlOperator( - name, - Lists.newArrayList(entry.getValue()), - min, - max, - isDeterministic); - operatorTable.add(name, drillSqlOperator); + if(map.containsKey(name)) { + final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name); + drillSqlOperatorBuilder + .addFunctions(entry.getValue()) + .setArgumentCount(min, max) + .setDeterministic(isDeterministic); + } else { + final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = new DrillSqlOperator.DrillSqlOperatorBuilder(); + drillSqlOperatorBuilder + .setName(name) + .addFunctions(entry.getValue()) + .setArgumentCount(min, max) + .setDeterministic(isDeterministic); + + map.put(name, drillSqlOperatorBuilder); + } } for (Entry> entry : aggregateFunctions.asMap().entrySet()) { - operatorTable.add(name, new DrillSqlAggOperator(name, Lists.newArrayList(entry.getValue()), entry.getKey())); + if(mapAgg.containsKey(name)) { + final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name); + drillSqlAggOperatorBuilder + .addFunctions(entry.getValue()) + 
.setArgumentCount(entry.getKey(), entry.getKey()); + } else { + final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = new DrillSqlAggOperator.DrillSqlAggOperatorBuilder(); + drillSqlAggOperatorBuilder + .setName(name) + .addFunctions(entry.getValue()) + .setArgumentCount(entry.getKey(), entry.getKey()); + + mapAgg.put(name, drillSqlAggOperatorBuilder); + } + } + } + + for(final Entry entry : map.entrySet()) { + operatorTable.addInference( + entry.getKey(), + entry.getValue().build()); + } + + for(final Entry entry : mapAgg.entrySet()) { + operatorTable.addInference( + entry.getKey(), + entry.getValue().build()); + } + } + + public void registerForDefault(DrillOperatorTable operatorTable) { + SqlOperator op; + for (Entry> function : registeredFunctions.asMap().entrySet()) { + Set argCounts = Sets.newHashSet(); + String name = function.getKey().toUpperCase(); + for (DrillFuncHolder func : function.getValue()) { + if (argCounts.add(func.getParamCount())) { + if (func.isAggregating()) { + op = new DrillSqlAggOperatorNotInfer(name, func.getParamCount()); + } else { + boolean isDeterministic; + // prevent Drill from folding constant functions with types that cannot be materialized + // into literals + if (DrillConstExecutor.NON_REDUCIBLE_TYPES.contains(func.getReturnType().getMinorType())) { + isDeterministic = false; + } else { + isDeterministic = func.isDeterministic(); + } + op = new DrillSqlOperatorNotInfer(name, func.getParamCount(), func.getReturnType(), isDeterministic); + } + operatorTable.addDefault(function.getKey(), op); + } } } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java b/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java index 51a581a5369..3ce0633305e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/ops/QueryContext.java @@ -86,7 +86,7 @@ public QueryContext(final 
UserSession session, final DrillbitContext drillbitCon executionControls = new ExecutionControls(queryOptions, drillbitContext.getEndpoint()); plannerSettings = new PlannerSettings(queryOptions, getFunctionRegistry()); plannerSettings.setNumEndPoints(drillbitContext.getBits().size()); - table = new DrillOperatorTable(getFunctionRegistry()); + table = new DrillOperatorTable(getFunctionRegistry(), drillbitContext.getOptionManager()); queryContextInfo = Utilities.createQueryContextInfo(session.getDefaultSchemaName()); contextInformation = new ContextInformation(session.getCredentials(), queryContextInfo); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java index 3a2510e02ec..8975e9fe8bc 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java @@ -18,7 +18,6 @@ package org.apache.drill.exec.planner.logical; -import com.google.common.collect.ImmutableList; import java.math.BigDecimal; import java.util.ArrayList; @@ -33,7 +32,16 @@ import org.apache.calcite.rel.InvalidRelException; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.trace.CalciteTrace; +import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; +import org.apache.drill.exec.planner.sql.DrillSqlOperator; import 
org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.RelNode; import org.apache.calcite.plan.RelOptRule; @@ -51,15 +59,12 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlSumAggFunction; import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.CompositeList; import org.apache.calcite.util.ImmutableIntList; import org.apache.calcite.util.Util; -import org.apache.calcite.util.trace.CalciteTrace; -import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper; -import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; -import org.apache.drill.exec.planner.sql.DrillSqlOperator; +import com.google.common.collect.ImmutableList; +import org.apache.drill.exec.planner.sql.TypeInferenceUtils; /** * Rule to reduce aggregates to simpler forms. Currently only AVG(x) to @@ -71,13 +76,21 @@ public class DrillReduceAggregatesRule extends RelOptRule { /** * The singleton. 
*/ - public static final DrillReduceAggregatesRule INSTANCE = new DrillReduceAggregatesRule(operand(LogicalAggregate.class, any())); public static final DrillConvertSumToSumZero INSTANCE_SUM = - new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any())); - - private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false); + new DrillConvertSumToSumZero(operand(DrillAggregateRel.class, any())); + + private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false, + new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return TypeInferenceUtils.createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + SqlTypeName.ANY, + opBinding.getOperandType(0).isNullable()); + } + }); //~ Constructors ----------------------------------------------------------- @@ -222,7 +235,6 @@ private RexNode reduceAgg( // case COUNT(x) when 0 then null else SUM0(x) end return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping); } - if (sqlAggFunction instanceof SqlAvgAggFunction) { final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) sqlAggFunction).getSubtype(); switch (subtype) { @@ -292,7 +304,8 @@ private RexNode reduceAvg( AggregateCall oldCall, List newCalls, Map aggCallMapping) { - final boolean isWrapper = useWrapper(oldCall); + final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext(); + final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled(); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); @@ -302,25 +315,12 @@ private RexNode reduceAvg( getFieldType( oldAggRel.getInput(), iAvgInput); - - final RelDataType sumType; - if(isWrapper) { - sumType = oldCall.getType(); - } else { - sumType = - typeFactory.createTypeWithNullability( - avgInputType, - avgInputType.isNullable() || nGroups == 0); - } + 
RelDataType sumType = + typeFactory.createTypeWithNullability( + avgInputType, + avgInputType.isNullable() || nGroups == 0); // SqlAggFunction sumAgg = new SqlSumAggFunction(sumType); - SqlAggFunction sumAgg; - if(isWrapper) { - sumAgg = new DrillCalciteSqlAggFunctionWrapper( - new SqlSumEmptyIsZeroAggFunction(), sumType); - } else { - sumAgg = new SqlSumEmptyIsZeroAggFunction(); - } - + SqlAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction(); AggregateCall sumCall = new AggregateCall( sumAgg, @@ -385,15 +385,21 @@ private RexNode reduceAvg( newCalls, aggCallMapping, ImmutableList.of(avgInputType)); - final RexNode divideRef = - rexBuilder.makeCall( - SqlStdOperatorTable.DIVIDE, - numeratorRef, - denominatorRef); - - if(isWrapper) { - return divideRef; + if(isInferenceEnabled) { + return rexBuilder.makeCall( + new DrillSqlOperator( + "divide", + 2, + true, + oldCall.getType()), + numeratorRef, + denominatorRef); } else { + final RexNode divideRef = + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, + numeratorRef, + denominatorRef); return rexBuilder.makeCast( typeFactory.createSqlType(SqlTypeName.ANY), divideRef); } @@ -404,34 +410,29 @@ private RexNode reduceSum( AggregateCall oldCall, List newCalls, Map aggCallMapping) { - final boolean isWrapper = useWrapper(oldCall); + final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext(); + final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled(); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder(); - - final RelDataType argType; - if(isWrapper) { - argType = oldCall.getType(); - } else { - int arg = oldCall.getArgList().get(0); - argType = - getFieldType( - oldAggRel.getInput(), - arg); - } - + int arg = oldCall.getArgList().get(0); + RelDataType argType = + getFieldType( + oldAggRel.getInput(), + arg); final 
RelDataType sumType; final SqlAggFunction sumZeroAgg; - if(isWrapper) { + if(isInferenceEnabled) { sumType = oldCall.getType(); sumZeroAgg = new DrillCalciteSqlAggFunctionWrapper( new SqlSumEmptyIsZeroAggFunction(), sumType); } else { - sumType = typeFactory.createTypeWithNullability(argType, argType.isNullable()); + sumType = + typeFactory.createTypeWithNullability( + argType, argType.isNullable()); sumZeroAgg = new SqlSumEmptyIsZeroAggFunction(); } - AggregateCall sumZeroCall = new AggregateCall( sumZeroAgg, @@ -488,7 +489,6 @@ private RexNode reduceStddev( List newCalls, Map aggCallMapping, List inputExprs) { - final boolean isWrapper = useWrapper(oldCall); // stddev_pop(x) ==> // power( // (sum(x * x) - sum(x) * sum(x) / count(x)) @@ -500,6 +500,8 @@ private RexNode reduceStddev( // (sum(x * x) - sum(x) * sum(x) / count(x)) // / nullif(count(x) - 1, 0), // .5) + final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext(); + final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled(); final int nGroups = oldAggRel.getGroupCount(); RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); @@ -525,26 +527,13 @@ private RexNode reduceStddev( typeFactory.createTypeWithNullability( argType, true); - final AggregateCall sumArgSquaredAggCall; - if(isWrapper) { - sumArgSquaredAggCall = - new AggregateCall( - new DrillCalciteSqlAggFunctionWrapper( - new SqlSumAggFunction(sumType), sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argSquaredOrdinal), - sumType, - null); - } else { - sumArgSquaredAggCall = - new AggregateCall( - new SqlSumAggFunction(sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argSquaredOrdinal), - sumType, - null); - } - + final AggregateCall sumArgSquaredAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argSquaredOrdinal), + sumType, + null); final RexNode sumArgSquared = rexBuilder.addAggCall( 
sumArgSquaredAggCall, @@ -554,26 +543,13 @@ private RexNode reduceStddev( aggCallMapping, ImmutableList.of(argType)); - final AggregateCall sumArgAggCall; - if(isWrapper) { - sumArgAggCall = - new AggregateCall( - new DrillCalciteSqlAggFunctionWrapper( - new SqlSumAggFunction(sumType), sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argOrdinal), - sumType, - null); - } else { - sumArgAggCall = - new AggregateCall( - new SqlSumAggFunction(sumType), - oldCall.isDistinct(), - ImmutableIntList.of(argOrdinal), - sumType, - null); - } - + final AggregateCall sumArgAggCall = + new AggregateCall( + new SqlSumAggFunction(sumType), + oldCall.isDistinct(), + ImmutableIntList.of(argOrdinal), + sumType, + null); final RexNode sumArg = rexBuilder.addAggCall( sumArgAggCall, @@ -635,9 +611,20 @@ private RexNode reduceStddev( countEqOne, nul, countMinusOne); } + final SqlOperator divide; + if(isInferenceEnabled) { + divide = new DrillSqlOperator( + "divide", + 2, + true, + oldCall.getType()); + } else { + divide = SqlStdOperatorTable.DIVIDE; + } + final RexNode div = rexBuilder.makeCall( - SqlStdOperatorTable.DIVIDE, diff, denominator); + divide, diff, denominator); RexNode result = div; if (sqrt) { @@ -648,17 +635,17 @@ private RexNode reduceStddev( SqlStdOperatorTable.POWER, div, half); } - /* - * Currently calcite's strategy to infer the return type of aggregate functions - * is wrong because it uses the first known argument to determine output type. For - * instance if we are performing stddev on an integer column then it interprets the - * output type to be integer which is incorrect as it should be double. So based on - * this if we add cast after rewriting the aggregate we add an additional cast which - * would cause wrong results. So we simply add a cast to ANY. 
- */ - if(isWrapper) { + if(isInferenceEnabled) { return result; } else { + /* + * Currently calcite's strategy to infer the return type of aggregate functions + * is wrong because it uses the first known argument to determine output type. For + * instance if we are performing stddev on an integer column then it interprets the + * output type to be integer which is incorrect as it should be double. So based on + * this if we add cast after rewriting the aggregate we add an additional cast which + * would cause wrong results. So we simply add a cast to ANY. + */ return rexBuilder.makeCast( typeFactory.createSqlType(SqlTypeName.ANY), result); } @@ -704,10 +691,6 @@ private RelDataType getFieldType(RelNode relNode, int i) { return inputField.getType(); } - private boolean useWrapper(AggregateCall aggregateCall) { - return aggregateCall.getAggregation() instanceof DrillCalciteSqlWrapper; - } - private static class DrillConvertSumToSumZero extends RelOptRule { protected static final Logger tracer = CalciteTrace.getPlannerTracer(); @@ -756,11 +739,11 @@ public void onMatch(RelOptRuleCall call) { new SqlSumEmptyIsZeroAggFunction(), sumType); AggregateCall sumZeroCall = new AggregateCall( - sumZeroAgg, - oldAggregateCall.isDistinct(), - oldAggregateCall.getArgList(), - sumType, - null); + sumZeroAgg, + oldAggregateCall.isDistinct(), + oldAggregateCall.getArgList(), + sumType, + null); oldAggRel.getCluster().getRexBuilder() .addAggCall(sumZeroCall, oldAggRel.getGroupCount(), diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java index 3eb50383071..a98619ce6fa 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java @@ -81,6 +81,9 @@ public class PlannerSettings implements Context{ new 
RangeLongValidator("planner.identifier_max_length", 128 /* A minimum length is needed because option names are identifiers themselves */, Integer.MAX_VALUE, DEFAULT_IDENTIFIER_MAX_LENGTH); + public static final String TYPE_INFERENCE_KEY = "planner.type_inference.enable"; + public static final BooleanValidator TYPE_INFERENCE = new BooleanValidator(TYPE_INFERENCE_KEY, true); + public OptionManager options = null; public FunctionImplementationRegistry functionImplementationRegistry = null; @@ -209,6 +212,10 @@ public static long getInitialPlanningMemorySize() { return INITIAL_OFF_HEAP_ALLOCATION_IN_BYTES; } + public boolean isTypeInferenceEnabled() { + return options.getOption(TYPE_INFERENCE.getOptionName()).bool_val; + } + @Override public T unwrap(Class clazz) { if(clazz == PlannerSettings.class){ diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java index 97317be8471..068423e01c6 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillAvgVarianceConvertlet.java @@ -23,9 +23,12 @@ import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNumericLiteral; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.fun.SqlAvgAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexContext; import org.apache.calcite.sql2rel.SqlRexConvertlet; import org.apache.calcite.util.Util; @@ -40,7 +43,16 @@ public class DrillAvgVarianceConvertlet implements SqlRexConvertlet { private final SqlAvgAggFunction.Subtype subtype; - 
private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false); + private static final DrillSqlOperator CastHighOp = new DrillSqlOperator("CastHigh", 1, false, + new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return TypeInferenceUtils.createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + SqlTypeName.ANY, + opBinding.getOperandType(0).isNullable()); + } + }); public DrillAvgVarianceConvertlet(SqlAvgAggFunction.Subtype subtype) { this.subtype = subtype; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java index 7fe6020b772..de18f029872 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java @@ -24,6 +24,7 @@ import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlPrefixOperator; import org.apache.drill.common.expression.FunctionCallFactory; +import org.apache.drill.exec.ExecConstants; import org.apache.drill.exec.expr.fn.DrillFuncHolder; import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry; import org.apache.calcite.sql.SqlFunctionCategory; @@ -32,6 +33,8 @@ import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.server.options.SystemOptionManager; import java.util.List; import java.util.Map; @@ -43,24 +46,49 @@ public class DrillOperatorTable extends SqlStdOperatorTable { // private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillOperatorTable.class); private static final SqlOperatorTable inner = SqlStdOperatorTable.instance(); - private List 
operators = Lists.newArrayList(); + private final List operatorsCalcite = Lists.newArrayList(); + private final List operatorsDefault = Lists.newArrayList(); + private final List operatorsInferernce = Lists.newArrayList(); private final Map calciteToWrapper = Maps.newIdentityHashMap(); - private ArrayListMultimap opMap = ArrayListMultimap.create(); + + private final ArrayListMultimap opMapDefault = ArrayListMultimap.create(); + private final ArrayListMultimap opMapInferernce = ArrayListMultimap.create(); + + private final SystemOptionManager systemOptionManager; public DrillOperatorTable(FunctionImplementationRegistry registry) { + this(registry, null); + } + + public DrillOperatorTable(FunctionImplementationRegistry registry, SystemOptionManager systemOptionManager) { registry.register(this); - operators.addAll(inner.getOperatorList()); + operatorsCalcite.addAll(inner.getOperatorList()); populateWrappedCalciteOperators(); + this.systemOptionManager = systemOptionManager; + } + + public void addDefault(String name, SqlOperator op) { + operatorsDefault.add(op); + opMapDefault.put(name.toLowerCase(), op); } - public void add(String name, SqlOperator op) { - operators.add(op); - opMap.put(name.toLowerCase(), op); + public void addInference(String name, SqlOperator op) { + operatorsInferernce.add(op); + opMapInferernce.put(name.toLowerCase(), op); } @Override public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory category, SqlSyntax syntax, List operatorList) { + if(isEnableInference()) { + populateFromTypeInference(opName, category, syntax, operatorList); + } else { + populateFromDefault(opName, category, syntax, operatorList); + } + } + + private void populateFromTypeInference(SqlIdentifier opName, SqlFunctionCategory category, + SqlSyntax syntax, List operatorList) { final List calciteOperatorList = Lists.newArrayList(); inner.lookupOperatorOverloads(opName, category, syntax, calciteOperatorList); if(!calciteOperatorList.isEmpty()) { @@ -74,7 
+102,7 @@ public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory ca } else { // if no function is found, check in Drill UDFs if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { - List drillOps = opMap.get(opName.getSimple().toLowerCase()); + List drillOps = opMapInferernce.get(opName.getSimple().toLowerCase()); if (drillOps != null && !drillOps.isEmpty()) { operatorList.addAll(drillOps); } @@ -82,14 +110,37 @@ public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory ca } } + private void populateFromDefault(SqlIdentifier opName, SqlFunctionCategory category, + SqlSyntax syntax, List operatorList) { + inner.lookupOperatorOverloads(opName, category, syntax, operatorList); + if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { + List drillOps = opMapDefault.get(opName.getSimple().toLowerCase()); + if (drillOps != null) { + operatorList.addAll(drillOps); + } + } + } + @Override public List getOperatorList() { - return operators; + final List sqlOperators = Lists.newArrayList(); + sqlOperators.addAll(operatorsCalcite); + if(isEnableInference()) { + sqlOperators.addAll(operatorsInferernce); + } else { + sqlOperators.addAll(operatorsDefault); + } + + return sqlOperators; } // Get the list of SqlOperator's with the given name. 
public List getSqlOperator(String name) { - return opMap.get(name.toLowerCase()); + if(isEnableInference()) { + return opMapInferernce.get(name.toLowerCase()); + } else { + return opMapDefault.get(name.toLowerCase()); + } } private void populateWrappedCalciteOperators() { @@ -97,14 +148,14 @@ private void populateWrappedCalciteOperators() { final SqlOperator wrapper; if(calciteOperator instanceof SqlAggFunction) { wrapper = new DrillCalciteSqlAggFunctionWrapper((SqlAggFunction) calciteOperator, - getFunctionList(calciteOperator.getName())); + getFunctionListWithInference(calciteOperator.getName())); } else if(calciteOperator instanceof SqlFunction) { wrapper = new DrillCalciteSqlFunctionWrapper((SqlFunction) calciteOperator, - getFunctionList(calciteOperator.getName())); + getFunctionListWithInference(calciteOperator.getName())); } else { final String drillOpName = FunctionCallFactory.replaceOpWithFuncName(calciteOperator.getName()); - final List drillFuncHolders = getFunctionList(drillOpName); - if(drillFuncHolders.isEmpty() || calciteOperator == SqlStdOperatorTable.UNARY_MINUS) { + final List drillFuncHolders = getFunctionListWithInference(drillOpName); + if(drillFuncHolders.isEmpty() || calciteOperator == SqlStdOperatorTable.UNARY_MINUS || calciteOperator == SqlStdOperatorTable.UNARY_PLUS) { continue; } @@ -114,9 +165,9 @@ private void populateWrappedCalciteOperators() { } } - private List getFunctionList(String name) { + private List getFunctionListWithInference(String name) { final List functions = Lists.newArrayList(); - for(SqlOperator sqlOperator : opMap.get(name.toLowerCase())) { + for(SqlOperator sqlOperator : opMapInferernce.get(name.toLowerCase())) { if(sqlOperator instanceof DrillSqlOperator) { final List list = ((DrillSqlOperator) sqlOperator).getFunctions(); if(list != null) { @@ -133,4 +184,9 @@ private List getFunctionList(String name) { } return functions; } + + private boolean isEnableInference() { + return systemOptionManager != null + && 
systemOptionManager.getOption(PlannerSettings.TYPE_INFERENCE); + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java index 81c744c2fce..044f5b05ee5 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java @@ -17,36 +17,76 @@ */ package org.apache.drill.exec.planner.sql; -import org.apache.calcite.rel.type.RelDataType; +import com.google.common.collect.Lists; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.drill.exec.expr.fn.DrillFuncHolder; -import java.util.ArrayList; +import java.util.Collection; import java.util.List; public class DrillSqlAggOperator extends SqlAggFunction { // private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillSqlAggOperator.class); private final List functions; - public DrillSqlAggOperator(String name, List functions, int argCount) { + protected DrillSqlAggOperator(String name, List functions, int argCountMin, int argCountMax, SqlReturnTypeInference sqlReturnTypeInference) { super(name, new SqlIdentifier(name, SqlParserPos.ZERO), SqlKind.OTHER_FUNCTION, - TypeInferenceUtils.getDrillSqlReturnTypeInference( - name, - functions), + sqlReturnTypeInference, null, - Checker.getChecker(argCount, argCount), + Checker.getChecker(argCountMin, argCountMax), SqlFunctionCategory.USER_DEFINED_FUNCTION); this.functions = functions; } + private DrillSqlAggOperator(String name, List functions, int argCountMin, int argCountMax) { + this(name, + 
functions, + argCountMin, + argCountMax, + TypeInferenceUtils.getDrillSqlReturnTypeInference( + name, + functions)); + } + public List getFunctions() { return functions; } + + public static class DrillSqlAggOperatorBuilder { + private String name; + private final List functions = Lists.newArrayList(); + private int argCountMin = Integer.MAX_VALUE; + private int argCountMax = Integer.MIN_VALUE; + private boolean isDeterministic = true; + + public DrillSqlAggOperatorBuilder setName(final String name) { + this.name = name; + return this; + } + + public DrillSqlAggOperatorBuilder addFunctions(Collection functions) { + this.functions.addAll(functions); + return this; + } + + public DrillSqlAggOperatorBuilder setArgumentCount(final int argCountMin, final int argCountMax) { + this.argCountMin = Math.min(this.argCountMin, argCountMin); + this.argCountMax = Math.max(this.argCountMax, argCountMax); + return this; + } + + public DrillSqlAggOperator build() { + return new DrillSqlAggOperator( + name, + functions, + argCountMin, + argCountMax); + } + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java new file mode 100644 index 00000000000..592c23edce7 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.ArrayList; + +public class DrillSqlAggOperatorNotInfer extends DrillSqlAggOperator { + public DrillSqlAggOperatorNotInfer(String name, int argCount) { + super(name, new ArrayList(), argCount, argCount, DynamicReturnType.INSTANCE); + } + + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) { + return getAny(validator.getTypeFactory()); + } + + private RelDataType getAny(RelDataTypeFactory factory){ + return factory.createSqlType(SqlTypeName.ANY); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java index 0873c8df3ea..1bb62f3963a 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java @@ -19,12 +19,17 @@ package org.apache.drill.exec.planner.sql; import java.util.ArrayList; +import java.util.Collection; import java.util.List; +import com.google.common.collect.Lists; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlFunction; import 
org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.drill.exec.expr.fn.DrillFuncHolder; public class DrillSqlOperator extends SqlFunction { @@ -39,15 +44,54 @@ public class DrillSqlOperator extends SqlFunction { * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. */ @Deprecated - public DrillSqlOperator(String name, int argCount, boolean isDeterministic) { - this(name, new ArrayList(), argCount, argCount, isDeterministic); + public DrillSqlOperator(final String name, final int argCount, final boolean isDeterministic) { + this(name, + argCount, + isDeterministic, + DynamicReturnType.INSTANCE); } - public DrillSqlOperator(String name, List functions, int argCountMin, int argCountMax, boolean isDeterministic) { + /** + * This constructor exists for the legacy reason. + * + * It is because Drill cannot access to DrillOperatorTable at the place where this constructor is being called. + * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. + */ + @Deprecated + public DrillSqlOperator(final String name, final int argCount, final boolean isDeterministic, + final SqlReturnTypeInference sqlReturnTypeInference) { + this(name, + new ArrayList(), + argCount, + argCount, + isDeterministic, + sqlReturnTypeInference); + } + + /** + * This constructor exists for the legacy reason. + * + * It is because Drill cannot access to DrillOperatorTable at the place where this constructor is being called. + * In principle, if Drill needs a DrillSqlOperator, it is supposed to go to DrillOperatorTable for pickup. 
+ */ + @Deprecated + public DrillSqlOperator(final String name, final int argCount, final boolean isDeterministic, final RelDataType type) { + this(name, + new ArrayList(), + argCount, + argCount, + isDeterministic, new SqlReturnTypeInference() { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return type; + } + }); + } + + protected DrillSqlOperator(String name, List functions, int argCountMin, int argCountMax, boolean isDeterministic, + SqlReturnTypeInference sqlReturnTypeInference) { super(new SqlIdentifier(name, SqlParserPos.ZERO), - TypeInferenceUtils.getDrillSqlReturnTypeInference( - name, - functions), + sqlReturnTypeInference, null, Checker.getChecker(argCountMin, argCountMax), null, @@ -64,4 +108,47 @@ public boolean isDeterministic() { public List getFunctions() { return functions; } + + public static class DrillSqlOperatorBuilder { + private String name; + private final List functions = Lists.newArrayList(); + private int argCountMin = Integer.MAX_VALUE; + private int argCountMax = Integer.MIN_VALUE; + private boolean isDeterministic = true; + + public DrillSqlOperatorBuilder setName(final String name) { + this.name = name; + return this; + } + + public DrillSqlOperatorBuilder addFunctions(Collection functions) { + this.functions.addAll(functions); + return this; + } + + public DrillSqlOperatorBuilder setArgumentCount(final int argCountMin, final int argCountMax) { + this.argCountMin = Math.min(this.argCountMin, argCountMin); + this.argCountMax = Math.max(this.argCountMax, argCountMax); + return this; + } + + public DrillSqlOperatorBuilder setDeterministic(boolean isDeterministic) { + if(this.isDeterministic) { + this.isDeterministic = isDeterministic; + } + return this; + } + + public DrillSqlOperator build() { + return new DrillSqlOperator( + name, + functions, + argCountMin, + argCountMax, + isDeterministic, + TypeInferenceUtils.getDrillSqlReturnTypeInference( + name, + functions)); + } + } } \ No newline at end of 
file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java new file mode 100644 index 00000000000..a7394bd0a81 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.sql; + +import com.google.common.base.Preconditions; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.apache.drill.common.types.TypeProtos; +import org.apache.drill.exec.expr.fn.DrillFuncHolder; + +import java.util.ArrayList; + +public class DrillSqlOperatorNotInfer extends DrillSqlOperator { + private static final TypeProtos.MajorType NONE = TypeProtos.MajorType.getDefaultInstance(); + private final TypeProtos.MajorType returnType; + + public DrillSqlOperatorNotInfer(String name, int argCount, TypeProtos.MajorType returnType, boolean isDeterminisitic) { + super(name, + new ArrayList< DrillFuncHolder>(), + argCount, + argCount, + isDeterminisitic, + DynamicReturnType.INSTANCE); + this.returnType = Preconditions.checkNotNull(returnType); + } + + protected RelDataType getReturnDataType(final RelDataTypeFactory factory) { + if (TypeProtos.MinorType.BIT.equals(returnType.getMinorType())) { + return factory.createSqlType(SqlTypeName.BOOLEAN); + } + return factory.createTypeWithNullability(factory.createSqlType(SqlTypeName.ANY), true); + } + + private RelDataType getNullableReturnDataType(final RelDataTypeFactory factory) { + return factory.createTypeWithNullability(getReturnDataType(factory), true); + } + + @Override + public RelDataType deriveType(SqlValidator validator, SqlValidatorScope 
scope, SqlCall call) { + if (NONE.equals(returnType)) { + return validator.getTypeFactory().createSqlType(SqlTypeName.ANY); + } + /* + * We return a nullable output type both in validation phase and in + * Sql to Rel phase. We don't know the type of the output until runtime + * hence have to choose the least restrictive type to avoid any wrong + * results. + */ + return getNullableReturnDataType(validator.getTypeFactory()); + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return getNullableReturnDataType(opBinding.getTypeFactory()); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java index 2e0afeac438..fc63276661e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java @@ -19,7 +19,9 @@ import java.util.Arrays; import java.util.List; +import java.util.Set; +import com.google.common.collect.Sets; import org.apache.calcite.adapter.java.JavaTypeFactory; import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.avatica.util.Quoting; @@ -37,18 +39,23 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystemImpl; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserImplFactory; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlTypeName; import 
org.apache.calcite.sql.util.ChainedSqlOperatorTable; +import org.apache.calcite.sql.validate.AggregatingSelectScope; import org.apache.calcite.sql.validate.SqlConformance; import org.apache.calcite.sql.validate.SqlValidatorCatalogReader; import org.apache.calcite.sql.validate.SqlValidatorImpl; +import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.calcite.sql2rel.RelDecorrelator; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.drill.common.exceptions.UserException; @@ -183,10 +190,40 @@ public SchemaPlus getDefaultSchema() { } private class DrillValidator extends SqlValidatorImpl { + private final Set identitySet = Sets.newIdentityHashSet(); + protected DrillValidator(SqlOperatorTable opTab, SqlValidatorCatalogReader catalogReader, RelDataTypeFactory typeFactory, SqlConformance conformance) { super(opTab, catalogReader, typeFactory, conformance); } + + @Override + public SqlValidatorScope getSelectScope(final SqlSelect select) { + final SqlValidatorScope sqlValidatorScope = super.getSelectScope(select); + if(needsValidation(sqlValidatorScope)) { + final AggregatingSelectScope aggregatingSelectScope = ((AggregatingSelectScope) sqlValidatorScope); + for(SqlNode sqlNode : aggregatingSelectScope.groupExprList) { + if(sqlNode instanceof SqlCall) { + final SqlCall sqlCall = (SqlCall) sqlNode; + sqlCall.getOperator().deriveType(this, sqlValidatorScope, sqlCall); + } + } + identitySet.add(sqlValidatorScope); + } + return sqlValidatorScope; + } + + // Due to the deep-copy of AggregatingSelectScope in the following two commits in the Forked Drill-Calcite: + // 1. [StarColumn] Reverse one change in CALCITE-356, which regresses AggChecker logic, after * query in schema-less table is added. + // 2. 
[StarColumn] When group-by a column, projecting on a star which cannot be expanded at planning time, + // use ITEM operator to wrap this column + private boolean needsValidation(final SqlValidatorScope sqlValidatorScope) { + if(sqlValidatorScope instanceof AggregatingSelectScope) { + return !identitySet.contains(sqlValidatorScope); + } else { + return false; + } + } } private static class DrillTypeSystem extends RelDataTypeSystemImpl { @@ -218,7 +255,7 @@ public int getMaxNumericPrecision() { public RelNode toRel( final SqlNode validatedNode) { - final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RexBuilder rexBuilder = new DrillRexBuilder(typeFactory); if (planner == null) { planner = new VolcanoPlanner(costFactory, settings); planner.setExecutor(new DrillConstExecutor(functions, util, settings)); @@ -364,4 +401,35 @@ private static SchemaPlus rootSchema(SchemaPlus schema) { } } + private static class DrillRexBuilder extends RexBuilder { + private DrillRexBuilder(RelDataTypeFactory typeFactory) { + super(typeFactory); + } + + @Override + public RexNode ensureType( + RelDataType type, + RexNode node, + boolean matchNullability) { + RelDataType targetType = type; + if (matchNullability) { + targetType = matchNullability(type, node); + } + if (targetType.getSqlTypeName() == SqlTypeName.ANY) { + return node; + } + if (!node.getType().equals(targetType)) { + if(!targetType.isStruct()) { + final RelDataType anyType = TypeInferenceUtils.createCalciteTypeWithNullability( + getTypeFactory(), + SqlTypeName.ANY, + targetType.isNullable()); + return makeCast(anyType, node); + } else { + return makeCast(targetType, node); + } + } + return node; + } + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java index 8914b1133d4..9af6fa302d0 100644 --- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java @@ -227,7 +227,9 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { final DrillFuncHolder func = resolveDrillFuncHolder(opBinding, functions); final RelDataType returnType = getReturnType(opBinding, func); - return returnType; + return returnType.getSqlTypeName() == SqlTypeName.VARBINARY + ? createCalciteTypeWithNullability(factory, SqlTypeName.ANY, returnType.isNullable()) + : returnType; } private static RelDataType getReturnType(final SqlOperatorBinding opBinding, final DrillFuncHolder func) { @@ -512,19 +514,18 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { RelDataType ret = factory.createTypeWithNullability( opBinding.getOperandType(1), isNullable); - if (opBinding instanceof SqlCallBinding) { SqlCallBinding callBinding = (SqlCallBinding) opBinding; SqlNode operand0 = callBinding.operand(0); // dynamic parameters and null constants need their types assigned // to them using the type they are casted to. 
- if (((operand0 instanceof SqlLiteral) - && (((SqlLiteral) operand0).getValue() == null)) + if(((operand0 instanceof SqlLiteral) + && (((SqlLiteral) operand0).getValue() == null)) || (operand0 instanceof SqlDynamicParam)) { callBinding.getValidator().setValidatedNodeType( - operand0, - ret); + operand0, + ret); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java b/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java index 1e54e5c02ce..cbc5c095e00 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java @@ -83,6 +83,7 @@ public class SystemOptionManager extends BaseOptionManager implements AutoClosea PlannerSettings.HEP_OPT, PlannerSettings.PLANNER_MEMORY_LIMIT, PlannerSettings.HEP_PARTITION_PRUNING, + PlannerSettings.TYPE_INFERENCE, ExecConstants.CAST_TO_NULLABLE_NUMERIC_OPTION, ExecConstants.OUTPUT_FORMAT_VALIDATOR, ExecConstants.PARQUET_BLOCK_SIZE_VALIDATOR, diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java index 81d093c88b4..ad9a2053172 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java @@ -29,7 +29,7 @@ public class TestFunctionsWithTypeExpoQueries extends BaseTestQuery { @Test public void testConcatWithMoreThanTwoArgs() throws Exception { - final String query = "select concat(r_name, r_name, r_name) as col \n" + + final String query = "select concat(r_name, r_name, r_name, 'f') as col \n" + "from cp.`tpch/region.parquet` limit 0"; List> expectedSchema = Lists.newArrayList(); @@ -58,7 +58,6 @@ public void testRow_NumberInView() throws Exception { " over(order by 
position_id) as rnum " + " from cp.`employee.json`)"; - final String view2 = "create view TestFunctionsWithTypeExpoQueries_testViewShield2 as \n" + "select row_number() over(order by position_id) as rnum, " + @@ -68,7 +67,6 @@ public void testRow_NumberInView() throws Exception { test(view1); test(view2); - testBuilder() .sqlQuery("select * from TestFunctionsWithTypeExpoQueries_testViewShield1") .ordered() @@ -113,6 +111,38 @@ public void testLRBTrimOneArg() throws Exception { .run(); } + @Test + public void testTrim() throws Exception { + final String query1 = "SELECT trim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query2 = "SELECT trim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + final String query3 = "SELECT trim('drill') as col FROM cp.`tpch/region.parquet` limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.VARCHAR) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query1) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query2) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(query3) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + @Test public void testTrimOneArg() throws Exception { final String query1 = "SELECT trim(leading 'drill') as col FROM cp.`tpch/region.parquet` limit 0"; @@ -294,7 +324,7 @@ public void testSumRequiredType() throws Exception { } @Test - public void testSQRT() throws Exception { + public void testSQRTDecimalLiteral() throws Exception { final String query = "SELECT sqrt(5.1) as col \n" + "from cp.`tpch/nation.parquet` \n" + "limit 0"; @@ -313,6 +343,26 @@ public void testSQRT() throws Exception { .run(); } + @Test + public void testSQRTIntegerLiteral() throws Exception { + final String query = 
"SELECT sqrt(4) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + List> expectedSchema = Lists.newArrayList(); + TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + @Test public void testTimestampDiff() throws Exception { final String query = "select to_timestamp('2014-02-13 00:30:30','YYYY-MM-dd HH:mm:ss') - to_timestamp('2014-02-13 00:30:30','YYYY-MM-dd HH:mm:ss') as col \n" + @@ -333,6 +383,27 @@ public void testTimestampDiff() throws Exception { .run(); } + @Test + public void testEqualBetweenIntervalAndTimestampDiff() throws Exception { + final String query = "select to_timestamp('2016-11-02 10:00:00','YYYY-MM-dd HH:mm:ss') + interval '10-11' year to month as col \n" + + "from cp.`tpch/region.parquet` \n" + + "where (to_timestamp('2016-11-02 10:00:00','YYYY-MM-dd HH:mm:ss') - to_timestamp('2016-01-01 10:00:00','YYYY-MM-dd HH:mm:ss') < interval '5 10:00:00' day to second) \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.TIMESTAMP) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + @Test public void testAvgAndSUM() throws Exception { final String query = "SELECT AVG(cast(r_regionkey as float)) AS `col1`, \n" + @@ -368,4 +439,113 @@ public void testAvgAndSUM() throws Exception { .build() .run(); } + + @Test + public void testAvgCountStar() throws Exception { + final String query = "select avg(distinct cast(r_regionkey as bigint)) + avg(cast(r_regionkey as 
integer)) as col1, \n" + + "sum(distinct cast(r_regionkey as bigint)) + 100 as col2, count(*) as col3 \n" + + "from cp.`tpch/region.parquet` alltypes_v \n" + + "where cast(r_regionkey as bigint) = 100000000000000000 \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test // explain plan including all attributes for + public void testUDFInGroupBy() throws Exception { + final String query = "select count(*) as col1, substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2) as col2, \n" + + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) as col3 \n" + + "from cp.`tpch/region.parquet` t1 \n" + + "left outer join cp.`tpch/nation.parquet` t2 on cast(t1.r_regionkey as Integer) = cast(t2.n_nationkey as Integer) \n" + + "left outer join cp.`employee.json` t3 on cast(t1.r_regionkey as Integer) = cast(t3.employee_id as Integer) \n" + + "group by substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2), \n" + + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) \n" + + "order by substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2),\n" + + 
"char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.VARCHAR) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testWindowSumAvg() throws Exception { + final String query = "with query as ( \n" + + "select sum(cast(employee_id as integer)) over w as col1, cast(avg(cast(employee_id as bigint)) over w as double precision) as col2, count(*) over w as col3 \n" + + "from cp.`tpch/region.parquet` \n" + + "window w as (partition by cast(full_name as varchar(10)) order by cast(full_name as varchar(10)) nulls first)) \n" + + "select * \n" + + "from query \n" + + "limit 0"; + + final List> expectedSchema = Lists.newArrayList(); + final TypeProtos.MajorType majorType1 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType2 = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final TypeProtos.MajorType majorType3 = TypeProtos.MajorType.newBuilder() + 
.setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col1"), majorType1)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col2"), majorType2)); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col3"), majorType3)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } } From 488ba1aeca7c4bf83ef048a6b3515539c12fdaa9 Mon Sep 17 00:00:00 2001 From: Hsuan-Yi Chu Date: Mon, 14 Mar 2016 16:11:10 -0700 Subject: [PATCH 4/5] DRILL-4372: (continued) Support for Window functions: - CUME_DIST - DENSE_RANK - PERCENT_RANK - RANK - ROW_NUMBER - NTILE - LEAD - LAG - FIRST_VALUE - LAST_VALUE --- .../exec/expr/fn/HiveFunctionRegistry.java | 19 +- .../exec/planner/sql/HiveUDFOperator.java | 2 +- ...a => HiveUDFOperatorWithoutInference.java} | 4 +- .../exec/fn/hive/TestInbuiltHiveUDFs.java | 37 ++-- .../exec/expr/fn/DrillFunctionRegistry.java | 76 ++++---- .../expr/fn/PluggableFunctionRegistry.java | 6 +- .../logical/DrillReduceAggregatesRule.java | 26 +-- .../planner/logical/PreProcessLogicalRel.java | 9 +- .../planner/physical/PlannerSettings.java | 4 +- .../sql/DrillCalciteSqlOperatorWrapper.java | 2 - .../planner/sql/DrillConvertletTable.java | 3 +- .../planner/sql/DrillExtractConvertlet.java | 21 ++- .../exec/planner/sql/DrillOperatorTable.java | 72 ++++---- .../exec/planner/sql/DrillSqlAggOperator.java | 6 +- ... DrillSqlAggOperatorWithoutInference.java} | 4 +- .../exec/planner/sql/DrillSqlOperator.java | 12 ++ ... 
=> DrillSqlOperatorWithoutInference.java} | 4 +- .../drill/exec/planner/sql/SqlConverter.java | 50 +----- .../exec/planner/sql/TypeInferenceUtils.java | 96 +++++++++- .../parser/DrillCalciteWrapperUtility.java | 76 ++++++++ .../parser/UnsupportedOperatorsVisitor.java | 14 +- .../java/org/apache/drill/PlanningBase.java | 2 +- .../TestFunctionsWithTypeExpoQueries.java | 168 +++++++++++++++++- 23 files changed, 493 insertions(+), 220 deletions(-) rename contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/{HiveUDFOperatorNotInfer.java => HiveUDFOperatorWithoutInference.java} (93%) rename exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/{DrillSqlAggOperatorNotInfer.java => DrillSqlAggOperatorWithoutInference.java} (91%) rename exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/{DrillSqlOperatorNotInfer.java => DrillSqlOperatorWithoutInference.java} (93%) create mode 100644 exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/DrillCalciteWrapperUtility.java diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java index 52bd05b4055..5e74f6f3276 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/expr/fn/HiveFunctionRegistry.java @@ -18,22 +18,15 @@ package org.apache.drill.exec.expr.fn; import java.util.HashSet; -import java.util.List; import java.util.Set; -import java.util.Collection; -import com.google.common.collect.Lists; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.drill.common.config.DrillConfig; 
import org.apache.drill.common.exceptions.UserException; -import org.apache.drill.common.expression.ExpressionPosition; import org.apache.drill.common.expression.FunctionCall; -import org.apache.drill.common.expression.LogicalExpression; -import org.apache.drill.common.expression.MajorTypeInLogicalExpression; import org.apache.drill.common.scanner.ClassPathScanner; import org.apache.drill.common.scanner.persistence.ScanResult; import org.apache.drill.common.types.TypeProtos; @@ -43,7 +36,7 @@ import org.apache.drill.exec.expr.fn.impl.hive.ObjectInspectorHelper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.HiveUDFOperator; -import org.apache.drill.exec.planner.sql.HiveUDFOperatorNotInfer; +import org.apache.drill.exec.planner.sql.HiveUDFOperatorWithoutInference; import org.apache.drill.exec.planner.sql.TypeInferenceUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDF; @@ -85,8 +78,8 @@ public HiveFunctionRegistry(DrillConfig config) { @Override public void register(DrillOperatorTable operatorTable) { for (String name : Sets.union(methodsGenericUDF.asMap().keySet(), methodsUDF.asMap().keySet())) { - operatorTable.addDefault(name, new HiveUDFOperatorNotInfer(name.toUpperCase())); - operatorTable.addInference(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); + operatorTable.addOperatorWithoutInference(name, new HiveUDFOperatorWithoutInference(name.toUpperCase())); + operatorTable.addOperatorWithInference(name, new HiveUDFOperator(name.toUpperCase(), new HiveSqlReturnTypeInference())); } } @@ -240,11 +233,11 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { final FunctionCall functionCall = TypeInferenceUtils.convertSqlOperatorBindingToFunctionCall(opBinding); final HiveFuncHolder hiveFuncHolder = getFunction(functionCall); if(hiveFuncHolder == null) { - String operandTypes = ""; + final StringBuilder operandTypes = 
new StringBuilder(); for(int j = 0; j < opBinding.getOperandCount(); ++j) { - operandTypes += opBinding.getOperandType(j).getSqlTypeName(); + operandTypes.append(opBinding.getOperandType(j).getSqlTypeName()); if(j < opBinding.getOperandCount() - 1) { - operandTypes += ","; + operandTypes.append(","); } } diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java index 90c4135e33e..8ed72df9121 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperator.java @@ -31,7 +31,7 @@ public class HiveUDFOperator extends SqlFunction { public HiveUDFOperator(String name, SqlReturnTypeInference sqlReturnTypeInference) { - super(new SqlIdentifier(name, SqlParserPos.ZERO), sqlReturnTypeInference, null, new ArgChecker(), null, + super(new SqlIdentifier(name, SqlParserPos.ZERO), sqlReturnTypeInference, null, ArgChecker.INSTANCE, null, SqlFunctionCategory.USER_DEFINED_FUNCTION); } diff --git a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorWithoutInference.java similarity index 93% rename from contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java rename to contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorWithoutInference.java index 0c718f61f0b..f8358120ecd 100644 --- a/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorNotInfer.java +++ b/contrib/storage-hive/core/src/main/java/org/apache/drill/exec/planner/sql/HiveUDFOperatorWithoutInference.java @@ -25,8 +25,8 @@ import org.apache.calcite.sql.validate.SqlValidator; import 
org.apache.calcite.sql.validate.SqlValidatorScope; -public class HiveUDFOperatorNotInfer extends HiveUDFOperator { - public HiveUDFOperatorNotInfer(String name) { +public class HiveUDFOperatorWithoutInference extends HiveUDFOperator { + public HiveUDFOperatorWithoutInference(String name) { super(name, DynamicReturnType.INSTANCE); } diff --git a/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java b/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java index 14390626b24..a287c89d4ba 100644 --- a/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java +++ b/contrib/storage-hive/core/src/test/java/org/apache/drill/exec/fn/hive/TestInbuiltHiveUDFs.java @@ -17,9 +17,15 @@ */ package org.apache.drill.exec.fn.hive; +import com.google.common.collect.Lists; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.drill.common.expression.SchemaPath; +import org.apache.drill.common.types.TypeProtos; import org.apache.drill.exec.hive.HiveTestBase; import org.junit.Test; +import java.util.List; + public class TestInbuiltHiveUDFs extends HiveTestBase { @Test // DRILL-3273 @@ -45,30 +51,23 @@ public void testEncode() throws Exception { } @Test - public void testReflect() throws Exception { - final String query = "select reflect('java.lang.Math', 'round', cast(2 as float)) as col \n" + + public void testXpath_Double() throws Exception { + final String query = "select xpath_double ('2040', 'a/b * a/c') as col \n" + "from hive.kv \n" + - "limit 1"; + "limit 0"; - testBuilder() - .sqlQuery(query) - .unOrdered() - .baselineColumns("col") - .baselineValues("2") - .go(); - } + final TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); - @Test - public void testAbs() throws Exception { - final String query = "select reflect('java.lang.Math', 'abs', 
cast(-2 as double)) as col \n" + - "from hive.kv \n" + - "limit 1"; + final List> expectedSchema = Lists.newArrayList(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); testBuilder() .sqlQuery(query) - .unOrdered() - .baselineColumns("col") - .baselineValues("2.0") - .go(); + .schemaBaseLine(expectedSchema) + .build() + .run(); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java index f6bc666f8c1..f58d5a549fb 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillFunctionRegistry.java @@ -18,7 +18,6 @@ package org.apache.drill.exec.expr.fn; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -26,23 +25,20 @@ import java.util.Set; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.tuple.Pair; import org.apache.drill.common.scanner.persistence.AnnotatedClassDescriptor; import org.apache.drill.common.scanner.persistence.ScanResult; -import org.apache.drill.common.types.TypeProtos; import org.apache.drill.exec.planner.logical.DrillConstExecutor; import org.apache.drill.exec.planner.sql.DrillOperatorTable; import org.apache.drill.exec.planner.sql.DrillSqlAggOperator; -import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorNotInfer; +import org.apache.drill.exec.planner.sql.DrillSqlAggOperatorWithoutInference; import org.apache.drill.exec.planner.sql.DrillSqlOperator; import com.google.common.collect.ArrayListMultimap; -import 
org.apache.drill.exec.planner.sql.DrillSqlOperatorNotInfer; +import org.apache.drill.exec.planner.sql.DrillSqlOperatorWithoutInference; /** * Registry of Drill functions. @@ -53,7 +49,7 @@ public class DrillFunctionRegistry { // key: function name (lowercase) value: list of functions with that name private final ArrayListMultimap registeredFunctions = ArrayListMultimap.create(); - private static final ImmutableMap> drillFuncToRange = ImmutableMap.> builder() + private static final ImmutableMap> registeredFuncNameToArgRange = ImmutableMap.> builder() // CONCAT is allowed to take [1, infinity) number of arguments. // Currently, this flexibility is offered by DrillOptiq to rewrite it as // a nested structure @@ -127,11 +123,11 @@ public List getMethods(String name) { } public void register(DrillOperatorTable operatorTable) { - registerForInference(operatorTable); - registerForDefault(operatorTable); + registerOperatorsWithInference(operatorTable); + registerOperatorsWithoutInference(operatorTable); } - public void registerForInference(DrillOperatorTable operatorTable) { + private void registerOperatorsWithInference(DrillOperatorTable operatorTable) { final Map map = Maps.newHashMap(); final Map mapAgg = Maps.newHashMap(); for (Entry> function : registeredFunctions.asMap().entrySet()) { @@ -145,8 +141,8 @@ public void registerForInference(DrillOperatorTable operatorTable) { aggregateFunctions.put(paramCount, func); } else { final Pair argNumberRange; - if(drillFuncToRange.containsKey(name)) { - argNumberRange = drillFuncToRange.get(name); + if(registeredFuncNameToArgRange.containsKey(name)) { + argNumberRange = registeredFuncNameToArgRange.get(name); } else { argNumberRange = Pair.of(func.getParamCount(), func.getParamCount()); } @@ -161,55 +157,43 @@ public void registerForInference(DrillOperatorTable operatorTable) { final Pair range = entry.getKey(); final int max = range.getRight(); final int min = range.getLeft(); - if(map.containsKey(name)) { - final 
DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name); - drillSqlOperatorBuilder - .addFunctions(entry.getValue()) - .setArgumentCount(min, max) - .setDeterministic(isDeterministic); - } else { - final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = new DrillSqlOperator.DrillSqlOperatorBuilder(); - drillSqlOperatorBuilder - .setName(name) - .addFunctions(entry.getValue()) - .setArgumentCount(min, max) - .setDeterministic(isDeterministic); - - map.put(name, drillSqlOperatorBuilder); + if(!map.containsKey(name)) { + map.put(name, new DrillSqlOperator.DrillSqlOperatorBuilder() + .setName(name)); } + + final DrillSqlOperator.DrillSqlOperatorBuilder drillSqlOperatorBuilder = map.get(name); + drillSqlOperatorBuilder + .addFunctions(entry.getValue()) + .setArgumentCount(min, max) + .setDeterministic(isDeterministic); } for (Entry> entry : aggregateFunctions.asMap().entrySet()) { - if(mapAgg.containsKey(name)) { - final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name); - drillSqlAggOperatorBuilder - .addFunctions(entry.getValue()) - .setArgumentCount(entry.getKey(), entry.getKey()); - } else { - final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = new DrillSqlAggOperator.DrillSqlAggOperatorBuilder(); - drillSqlAggOperatorBuilder - .setName(name) - .addFunctions(entry.getValue()) - .setArgumentCount(entry.getKey(), entry.getKey()); - - mapAgg.put(name, drillSqlAggOperatorBuilder); + if(!mapAgg.containsKey(name)) { + mapAgg.put(name, new DrillSqlAggOperator.DrillSqlAggOperatorBuilder().setName(name)); } + + final DrillSqlAggOperator.DrillSqlAggOperatorBuilder drillSqlAggOperatorBuilder = mapAgg.get(name); + drillSqlAggOperatorBuilder + .addFunctions(entry.getValue()) + .setArgumentCount(entry.getKey(), entry.getKey()); } } for(final Entry entry : map.entrySet()) { - operatorTable.addInference( + operatorTable.addOperatorWithInference( entry.getKey(), 
entry.getValue().build()); } for(final Entry entry : mapAgg.entrySet()) { - operatorTable.addInference( + operatorTable.addOperatorWithInference( entry.getKey(), entry.getValue().build()); } } - public void registerForDefault(DrillOperatorTable operatorTable) { + private void registerOperatorsWithoutInference(DrillOperatorTable operatorTable) { SqlOperator op; for (Entry> function : registeredFunctions.asMap().entrySet()) { Set argCounts = Sets.newHashSet(); @@ -217,7 +201,7 @@ public void registerForDefault(DrillOperatorTable operatorTable) { for (DrillFuncHolder func : function.getValue()) { if (argCounts.add(func.getParamCount())) { if (func.isAggregating()) { - op = new DrillSqlAggOperatorNotInfer(name, func.getParamCount()); + op = new DrillSqlAggOperatorWithoutInference(name, func.getParamCount()); } else { boolean isDeterministic; // prevent Drill from folding constant functions with types that cannot be materialized @@ -227,9 +211,9 @@ public void registerForDefault(DrillOperatorTable operatorTable) { } else { isDeterministic = func.isDeterministic(); } - op = new DrillSqlOperatorNotInfer(name, func.getParamCount(), func.getReturnType(), isDeterministic); + op = new DrillSqlOperatorWithoutInference(name, func.getParamCount(), func.getReturnType(), isDeterministic); } - operatorTable.addDefault(function.getKey(), op); + operatorTable.addOperatorWithoutInference(function.getKey(), op); } } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/PluggableFunctionRegistry.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/PluggableFunctionRegistry.java index 547e65fec93..6ad4388a985 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/PluggableFunctionRegistry.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/PluggableFunctionRegistry.java @@ -23,7 +23,11 @@ public interface PluggableFunctionRegistry { /** - * Register functions in given operator table. 
+ * Register functions in given operator table. There are two methods to add operators. + * One is addOperatorWithInference whose added operators will be used + * when planner.enable_type_inference is set to true; + * The other is addOperatorWithoutInference whose added operators will be used + * when planner.enable_type_inference is set to false; * @param operatorTable */ public void register(DrillOperatorTable operatorTable); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java index 8975e9fe8bc..dd2fc14399e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java @@ -65,6 +65,7 @@ import com.google.common.collect.ImmutableList; import org.apache.drill.exec.planner.sql.TypeInferenceUtils; +import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility; /** * Rule to reduce aggregates to simpler forms. 
Currently only AVG(x) to @@ -122,11 +123,7 @@ public void onMatch(RelOptRuleCall ruleCall) { */ private boolean containsAvgStddevVarCall(List aggCallList) { for (AggregateCall call : aggCallList) { - SqlAggFunction sqlAggFunction = call.getAggregation(); - if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { - sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); - } - + SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(call.getAggregation()); if (sqlAggFunction instanceof SqlAvgAggFunction || sqlAggFunction instanceof SqlSumAggFunction) { return true; @@ -225,11 +222,7 @@ private RexNode reduceAgg( List newCalls, Map aggCallMapping, List inputExprs) { - SqlAggFunction sqlAggFunction = oldCall.getAggregation(); - if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { - sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); - } - + final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation()); if (sqlAggFunction instanceof SqlSumAggFunction) { // replace original SUM(x) with // case COUNT(x) when 0 then null else SUM0(x) end @@ -702,11 +695,7 @@ public DrillConvertSumToSumZero(RelOptRuleOperand operand) { public boolean matches(RelOptRuleCall call) { DrillAggregateRel oldAggRel = (DrillAggregateRel) call.rels[0]; for (AggregateCall aggregateCall : oldAggRel.getAggCallList()) { - SqlAggFunction sqlAggFunction = aggregateCall.getAggregation(); - if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { - sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); - } - + final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(aggregateCall.getAggregation()); if(sqlAggFunction instanceof SqlSumAggFunction && !aggregateCall.getType().isNullable()) { // If SUM(x) is not nullable, the validator must have determined that @@ -725,11 
+714,8 @@ public void onMatch(RelOptRuleCall call) { final Map aggCallMapping = Maps.newHashMap(); final List newAggregateCalls = Lists.newArrayList(); for (AggregateCall oldAggregateCall : oldAggRel.getAggCallList()) { - SqlAggFunction sqlAggFunction = oldAggregateCall.getAggregation(); - if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { - sqlAggFunction = (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); - } - + final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper( + oldAggregateCall.getAggregation()); if(sqlAggFunction instanceof SqlSumAggFunction && !oldAggregateCall.getType().isNullable()) { final RelDataType argType = oldAggregateCall.getType(); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java index 1585a5616a8..10c131d56e8 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/PreProcessLogicalRel.java @@ -30,6 +30,7 @@ import org.apache.drill.exec.planner.StarColumnHelper; import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; import org.apache.drill.exec.planner.sql.DrillOperatorTable; +import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility; import org.apache.drill.exec.util.ApproximateStringMatcher; import org.apache.drill.exec.work.foreman.SqlUnsupportedException; import org.apache.calcite.rel.core.AggregateCall; @@ -263,13 +264,7 @@ private UnwrappingExpressionVisitor(RexBuilder rexBuilder) { @Override public RexNode visitCall(final RexCall call) { final List clonedOperands = visitList(call.operands, new boolean[]{true}); - final SqlOperator sqlOperator; - if(call.getOperator() instanceof DrillCalciteSqlWrapper) { - sqlOperator = ((DrillCalciteSqlWrapper) 
call.getOperator()).getOperator(); - } else { - sqlOperator = call.getOperator(); - } - + final SqlOperator sqlOperator = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(call.getOperator()); return RexUtil.flatten(rexBuilder, rexBuilder.makeCall( call.getType(), diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java index a98619ce6fa..ff36d47bc1d 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java @@ -81,7 +81,7 @@ public class PlannerSettings implements Context{ new RangeLongValidator("planner.identifier_max_length", 128 /* A minimum length is needed because option names are identifiers themselves */, Integer.MAX_VALUE, DEFAULT_IDENTIFIER_MAX_LENGTH); - public static final String TYPE_INFERENCE_KEY = "planner.type_inference.enable"; + public static final String TYPE_INFERENCE_KEY = "planner.enable_type_inference"; public static final BooleanValidator TYPE_INFERENCE = new BooleanValidator(TYPE_INFERENCE_KEY, true); public OptionManager options = null; @@ -213,7 +213,7 @@ public static long getInitialPlanningMemorySize() { } public boolean isTypeInferenceEnabled() { - return options.getOption(TYPE_INFERENCE.getOptionName()).bool_val; + return options.getOption(TYPE_INFERENCE); } @Override diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java index 28c1cecb251..825812059bc 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillCalciteSqlOperatorWrapper.java @@ -17,7 +17,6 @@ */ package 
org.apache.drill.exec.planner.sql; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlLiteral; @@ -29,7 +28,6 @@ import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlMonotonicity; import org.apache.calcite.sql.validate.SqlValidator; -import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.drill.exec.expr.fn.DrillFuncHolder; import java.util.List; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java index 6b81bf0fc35..511eed7aee9 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java @@ -27,6 +27,7 @@ import org.apache.calcite.sql2rel.SqlRexConvertlet; import org.apache.calcite.sql2rel.SqlRexConvertletTable; import org.apache.calcite.sql2rel.StandardConvertletTable; +import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility; public class DrillConvertletTable implements SqlRexConvertletTable{ @@ -53,7 +54,7 @@ public SqlRexConvertlet get(SqlCall call) { SqlRexConvertlet convertlet; if(call.getOperator() instanceof DrillCalciteSqlWrapper) { final SqlOperator wrapper = call.getOperator(); - final SqlOperator wrapped = ((DrillCalciteSqlWrapper) call.getOperator()).getOperator(); + final SqlOperator wrapped = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(call.getOperator()); if ((convertlet = map.get(wrapped)) != null) { return convertlet; } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java index 61e1e07e817..5a85369aec0 100644 --- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillExtractConvertlet.java @@ -27,6 +27,7 @@ import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlIntervalQualifier; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexContext; import org.apache.calcite.sql2rel.SqlRexConvertlet; @@ -61,11 +62,21 @@ public RexNode convertCall(SqlRexContext cx, SqlCall call) { exprs.add(cx.convertExpression(node)); } - // Determine NULL-able using 2nd argument's Null-able. - RelDataType returnType = typeFactory.createTypeWithNullability( - typeFactory.createSqlType( - TypeInferenceUtils.getSqlTypeNameForTimeUnit(timeUnit)), - exprs.get(1).getType().isNullable()); + final RelDataType returnType; + if(call.getOperator() == SqlStdOperatorTable.EXTRACT) { + // Legacy code: + // The return type is wrong! + // Legacy code choose SqlTypeName.BIGINT simply to avoid conflicting against Calcite's inference mechanism + // (, which chose BIGINT in validation phase already) + // Determine NULL-able using 2nd argument's Null-able. + returnType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), exprs.get(1).getType().isNullable()); + } else { + // Determine NULL-able using 2nd argument's Null-able. 
+ returnType = typeFactory.createTypeWithNullability( + typeFactory.createSqlType( + TypeInferenceUtils.getSqlTypeNameForTimeUnit(timeUnit)), + exprs.get(1).getType().isNullable()); + } return rexBuilder.makeCall(returnType, call.getOperator(), exprs); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java index de18f029872..5f489b423c7 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillOperatorTable.java @@ -34,6 +34,7 @@ import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.drill.exec.planner.physical.PlannerSettings; +import org.apache.drill.exec.server.options.OptionManager; import org.apache.drill.exec.server.options.SystemOptionManager; import java.util.List; @@ -46,44 +47,48 @@ public class DrillOperatorTable extends SqlStdOperatorTable { // private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DrillOperatorTable.class); private static final SqlOperatorTable inner = SqlStdOperatorTable.instance(); - private final List operatorsCalcite = Lists.newArrayList(); - private final List operatorsDefault = Lists.newArrayList(); - private final List operatorsInferernce = Lists.newArrayList(); + private final List calciteOperators = Lists.newArrayList(); + private final List drillOperatorsWithoutInference = Lists.newArrayList(); + private final List drillOperatorsWithInference = Lists.newArrayList(); private final Map calciteToWrapper = Maps.newIdentityHashMap(); - private final ArrayListMultimap opMapDefault = ArrayListMultimap.create(); - private final ArrayListMultimap opMapInferernce = ArrayListMultimap.create(); + private final ArrayListMultimap drillOperatorsWithoutInferenceMap = ArrayListMultimap.create(); + private final 
ArrayListMultimap drillOperatorsWithInferenceMap = ArrayListMultimap.create(); - private final SystemOptionManager systemOptionManager; + private final OptionManager systemOptionManager; - public DrillOperatorTable(FunctionImplementationRegistry registry) { - this(registry, null); - } - - public DrillOperatorTable(FunctionImplementationRegistry registry, SystemOptionManager systemOptionManager) { + public DrillOperatorTable(FunctionImplementationRegistry registry, OptionManager systemOptionManager) { registry.register(this); - operatorsCalcite.addAll(inner.getOperatorList()); + calciteOperators.addAll(inner.getOperatorList()); populateWrappedCalciteOperators(); this.systemOptionManager = systemOptionManager; } - public void addDefault(String name, SqlOperator op) { - operatorsDefault.add(op); - opMapDefault.put(name.toLowerCase(), op); + /** + * When the option planner.type_inference.enable is turned off, the operators which are added via this method + * will be used. + */ + public void addOperatorWithoutInference(String name, SqlOperator op) { + drillOperatorsWithoutInference.add(op); + drillOperatorsWithoutInferenceMap.put(name.toLowerCase(), op); } - public void addInference(String name, SqlOperator op) { - operatorsInferernce.add(op); - opMapInferernce.put(name.toLowerCase(), op); + /** + * When the option planner.type_inference.enable is turned on, the operators which are added via this method + * will be used. 
+ */ + public void addOperatorWithInference(String name, SqlOperator op) { + drillOperatorsWithInference.add(op); + drillOperatorsWithInferenceMap.put(name.toLowerCase(), op); } @Override public void lookupOperatorOverloads(SqlIdentifier opName, SqlFunctionCategory category, SqlSyntax syntax, List operatorList) { - if(isEnableInference()) { + if(isInferenceEnabled()) { populateFromTypeInference(opName, category, syntax, operatorList); } else { - populateFromDefault(opName, category, syntax, operatorList); + populateFromWithoutTypeInference(opName, category, syntax, operatorList); } } @@ -102,7 +107,7 @@ private void populateFromTypeInference(SqlIdentifier opName, SqlFunctionCategory } else { // if no function is found, check in Drill UDFs if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { - List drillOps = opMapInferernce.get(opName.getSimple().toLowerCase()); + List drillOps = drillOperatorsWithInferenceMap.get(opName.getSimple().toLowerCase()); if (drillOps != null && !drillOps.isEmpty()) { operatorList.addAll(drillOps); } @@ -110,11 +115,11 @@ private void populateFromTypeInference(SqlIdentifier opName, SqlFunctionCategory } } - private void populateFromDefault(SqlIdentifier opName, SqlFunctionCategory category, + private void populateFromWithoutTypeInference(SqlIdentifier opName, SqlFunctionCategory category, SqlSyntax syntax, List operatorList) { inner.lookupOperatorOverloads(opName, category, syntax, operatorList); if (operatorList.isEmpty() && syntax == SqlSyntax.FUNCTION && opName.isSimple()) { - List drillOps = opMapDefault.get(opName.getSimple().toLowerCase()); + List drillOps = drillOperatorsWithoutInferenceMap.get(opName.getSimple().toLowerCase()); if (drillOps != null) { operatorList.addAll(drillOps); } @@ -124,11 +129,11 @@ private void populateFromDefault(SqlIdentifier opName, SqlFunctionCategory categ @Override public List getOperatorList() { final List sqlOperators = Lists.newArrayList(); - 
sqlOperators.addAll(operatorsCalcite); - if(isEnableInference()) { - sqlOperators.addAll(operatorsInferernce); + sqlOperators.addAll(calciteOperators); + if(isInferenceEnabled()) { + sqlOperators.addAll(drillOperatorsWithInference); } else { - sqlOperators.addAll(operatorsDefault); + sqlOperators.addAll(drillOperatorsWithoutInference); } return sqlOperators; @@ -136,10 +141,10 @@ public List getOperatorList() { // Get the list of SqlOperator's with the given name. public List getSqlOperator(String name) { - if(isEnableInference()) { - return opMapInferernce.get(name.toLowerCase()); + if(isInferenceEnabled()) { + return drillOperatorsWithInferenceMap.get(name.toLowerCase()); } else { - return opMapDefault.get(name.toLowerCase()); + return drillOperatorsWithoutInferenceMap.get(name.toLowerCase()); } } @@ -167,7 +172,7 @@ private void populateWrappedCalciteOperators() { private List getFunctionListWithInference(String name) { final List functions = Lists.newArrayList(); - for(SqlOperator sqlOperator : opMapInferernce.get(name.toLowerCase())) { + for(SqlOperator sqlOperator : drillOperatorsWithInferenceMap.get(name.toLowerCase())) { if(sqlOperator instanceof DrillSqlOperator) { final List list = ((DrillSqlOperator) sqlOperator).getFunctions(); if(list != null) { @@ -185,8 +190,7 @@ private List getFunctionListWithInference(String name) { return functions; } - private boolean isEnableInference() { - return systemOptionManager != null - && systemOptionManager.getOption(PlannerSettings.TYPE_INFERENCE); + private boolean isInferenceEnabled() { + return systemOptionManager.getOption(PlannerSettings.TYPE_INFERENCE); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java index 044f5b05ee5..73ff2007300 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java +++ 
b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperator.java @@ -59,11 +59,10 @@ public List getFunctions() { } public static class DrillSqlAggOperatorBuilder { - private String name; + private String name = null; private final List functions = Lists.newArrayList(); private int argCountMin = Integer.MAX_VALUE; private int argCountMax = Integer.MIN_VALUE; - private boolean isDeterministic = true; public DrillSqlAggOperatorBuilder setName(final String name) { this.name = name; @@ -82,6 +81,9 @@ public DrillSqlAggOperatorBuilder setArgumentCount(final int argCountMin, final } public DrillSqlAggOperator build() { + if(name == null || functions.isEmpty()) { + throw new AssertionError("The fields, name and functions, need to be set before build DrillSqlAggOperator"); + } return new DrillSqlAggOperator( name, functions, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorWithoutInference.java similarity index 91% rename from exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java rename to exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorWithoutInference.java index 592c23edce7..6e53a561baf 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorNotInfer.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlAggOperatorWithoutInference.java @@ -27,8 +27,8 @@ import java.util.ArrayList; -public class DrillSqlAggOperatorNotInfer extends DrillSqlAggOperator { - public DrillSqlAggOperatorNotInfer(String name, int argCount) { +public class DrillSqlAggOperatorWithoutInference extends DrillSqlAggOperator { + public DrillSqlAggOperatorWithoutInference(String name, int argCount) { super(name, new ArrayList(), argCount, argCount, DynamicReturnType.INSTANCE); } diff --git 
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java index 1bb62f3963a..e5942014083 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperator.java @@ -133,6 +133,14 @@ public DrillSqlOperatorBuilder setArgumentCount(final int argCountMin, final int } public DrillSqlOperatorBuilder setDeterministic(boolean isDeterministic) { + /* By the logic here, we will group the entire Collection as a DrillSqlOperator. and claim it is non-deterministic. + * Add if there is a non-deterministic DrillFuncHolder, then we claim this DrillSqlOperator is non-deterministic. + * + * In fact, in this case, separating all DrillFuncHolder into two DrillSqlOperator + * (one being deterministic and the other being non-deterministic does not help) since in DrillOperatorTable.lookupOperatorOverloads(), + * parameter list is not passed in. So even if we have two DrillSqlOperator, DrillOperatorTable.lookupOperatorOverloads() + * does not have enough information to pick the one matching the argument list. 
+ */ if(this.isDeterministic) { this.isDeterministic = isDeterministic; } @@ -140,6 +148,10 @@ public DrillSqlOperatorBuilder setDeterministic(boolean isDeterministic) { } public DrillSqlOperator build() { + if(name == null || functions.isEmpty()) { + throw new AssertionError("The fields, name and functions, need to be set before build DrillSqlAggOperator"); + } + return new DrillSqlOperator( name, functions, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorWithoutInference.java similarity index 93% rename from exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java rename to exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorWithoutInference.java index a7394bd0a81..155a7a662fd 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorNotInfer.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillSqlOperatorWithoutInference.java @@ -30,11 +30,11 @@ import java.util.ArrayList; -public class DrillSqlOperatorNotInfer extends DrillSqlOperator { +public class DrillSqlOperatorWithoutInference extends DrillSqlOperator { private static final TypeProtos.MajorType NONE = TypeProtos.MajorType.getDefaultInstance(); private final TypeProtos.MajorType returnType; - public DrillSqlOperatorNotInfer(String name, int argCount, TypeProtos.MajorType returnType, boolean isDeterminisitic) { + public DrillSqlOperatorWithoutInference(String name, int argCount, TypeProtos.MajorType returnType, boolean isDeterminisitic) { super(name, new ArrayList< DrillFuncHolder>(), argCount, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java index fc63276661e..3dfea6f3eb8 100644 --- 
a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/SqlConverter.java @@ -196,34 +196,6 @@ protected DrillValidator(SqlOperatorTable opTab, SqlValidatorCatalogReader catal RelDataTypeFactory typeFactory, SqlConformance conformance) { super(opTab, catalogReader, typeFactory, conformance); } - - @Override - public SqlValidatorScope getSelectScope(final SqlSelect select) { - final SqlValidatorScope sqlValidatorScope = super.getSelectScope(select); - if(needsValidation(sqlValidatorScope)) { - final AggregatingSelectScope aggregatingSelectScope = ((AggregatingSelectScope) sqlValidatorScope); - for(SqlNode sqlNode : aggregatingSelectScope.groupExprList) { - if(sqlNode instanceof SqlCall) { - final SqlCall sqlCall = (SqlCall) sqlNode; - sqlCall.getOperator().deriveType(this, sqlValidatorScope, sqlCall); - } - } - identitySet.add(sqlValidatorScope); - } - return sqlValidatorScope; - } - - // Due to the deep-copy of AggregatingSelectScope in the following two commits in the Forked Drill-Calcite: - // 1. [StarColumn] Reverse one change in CALCITE-356, which regresses AggChecker logic, after * query in schema-less table is added. - // 2. [StarColumn] When group-by a column, projecting on a star which cannot be expanded at planning time, - // use ITEM operator to wrap this column - private boolean needsValidation(final SqlValidatorScope sqlValidatorScope) { - if(sqlValidatorScope instanceof AggregatingSelectScope) { - return !identitySet.contains(sqlValidatorScope); - } else { - return false; - } - } } private static class DrillTypeSystem extends RelDataTypeSystemImpl { @@ -406,29 +378,15 @@ private DrillRexBuilder(RelDataTypeFactory typeFactory) { super(typeFactory); } + /** + * Since Drill has different mechanism and rules for implicit casting, + * ensureType() is overridden to avoid conflicting cast functions being added to the expressions. 
+ */ @Override public RexNode ensureType( RelDataType type, RexNode node, boolean matchNullability) { - RelDataType targetType = type; - if (matchNullability) { - targetType = matchNullability(type, node); - } - if (targetType.getSqlTypeName() == SqlTypeName.ANY) { - return node; - } - if (!node.getType().equals(targetType)) { - if(!targetType.isStruct()) { - final RelDataType anyType = TypeInferenceUtils.createCalciteTypeWithNullability( - getTypeFactory(), - SqlTypeName.ANY, - targetType.isNullable()); - return makeCast(anyType, node); - } else { - return makeCast(targetType, node); - } - } return node; } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java index 9af6fa302d0..b7942ed86b8 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java @@ -27,9 +27,13 @@ import org.apache.calcite.sql.SqlCharStringLiteral; import org.apache.calcite.sql.SqlDynamicParam; import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlRankFunction; +import org.apache.calcite.sql.fun.SqlAvgAggFunction; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeName; @@ -54,7 +58,8 @@ public class TypeInferenceUtils { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TypeInferenceUtils.class); public static final TypeProtos.MajorType UNKNOWN_TYPE = TypeProtos.MajorType.getDefaultInstance(); - private static final ImmutableMap DRILL_TO_CALCITE_TYPE_MAPPING 
= ImmutableMap. builder() + private static final ImmutableMap DRILL_TO_CALCITE_TYPE_MAPPING + = ImmutableMap. builder() .put(TypeProtos.MinorType.INT, SqlTypeName.INTEGER) .put(TypeProtos.MinorType.BIGINT, SqlTypeName.BIGINT) .put(TypeProtos.MinorType.FLOAT4, SqlTypeName.FLOAT) @@ -82,7 +87,8 @@ public class TypeInferenceUtils { // - CHAR, SYMBOL, MULTISET, DISTINCT, STRUCTURED, ROW, OTHER, CURSOR, COLUMN_LIST .build(); - private static final ImmutableMap CALCITE_TO_DRILL_MAPPING = ImmutableMap. builder() + private static final ImmutableMap CALCITE_TO_DRILL_MAPPING + = ImmutableMap. builder() .put(SqlTypeName.INTEGER, TypeProtos.MinorType.INT) .put(SqlTypeName.BIGINT, TypeProtos.MinorType.BIGINT) .put(SqlTypeName.FLOAT, TypeProtos.MinorType.FLOAT4) @@ -135,6 +141,32 @@ public class TypeInferenceUtils { .put("FLATTEN", DrillDeferToExecSqlReturnTypeInference.INSTANCE) .put("KVGEN", DrillDeferToExecSqlReturnTypeInference.INSTANCE) .put("CONVERT_FROM", DrillDeferToExecSqlReturnTypeInference.INSTANCE) + + // Window Functions + // RANKING + .put(SqlKind.CUME_DIST.name(), DrillRankingSqlReturnTypeInference.INSTANCE_DOUBLE) + .put(SqlKind.DENSE_RANK.name(), DrillRankingSqlReturnTypeInference.INSTANCE_BIGINT) + .put(SqlKind.PERCENT_RANK.name(), DrillRankingSqlReturnTypeInference.INSTANCE_DOUBLE) + .put(SqlKind.RANK.name(), DrillRankingSqlReturnTypeInference.INSTANCE_BIGINT) + .put(SqlKind.ROW_NUMBER.name(), DrillRankingSqlReturnTypeInference.INSTANCE_BIGINT) + + // NTILE + .put("NTILE", DrillNTILESqlReturnTypeInference.INSTANCE) + + // LEAD, LAG + .put("LEAD", DrillLeadLagSqlReturnTypeInference.INSTANCE) + .put("LAG", DrillLeadLagSqlReturnTypeInference.INSTANCE) + + // FIRST_VALUE, LAST_VALUE + .put("FIRST_VALUE", DrillFirstLastValueSqlReturnTypeInference.INSTANCE) + .put("LAST_VALUE", DrillFirstLastValueSqlReturnTypeInference.INSTANCE) + + // Functions rely on DrillReduceAggregatesRule for expression simplification as opposed to getting evaluated directly + 
.put(SqlAvgAggFunction.Subtype.AVG.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) + .put(SqlAvgAggFunction.Subtype.STDDEV_POP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) + .put(SqlAvgAggFunction.Subtype.STDDEV_SAMP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) + .put(SqlAvgAggFunction.Subtype.VAR_POP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) + .put(SqlAvgAggFunction.Subtype.VAR_SAMP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) .build(); /** @@ -533,6 +565,66 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { } } + private static class DrillRankingSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillRankingSqlReturnTypeInference INSTANCE_BIGINT = new DrillRankingSqlReturnTypeInference(SqlTypeName.BIGINT); + private static final DrillRankingSqlReturnTypeInference INSTANCE_DOUBLE = new DrillRankingSqlReturnTypeInference(SqlTypeName.DOUBLE); + + private final SqlTypeName returnType; + private DrillRankingSqlReturnTypeInference(final SqlTypeName returnType) { + this.returnType = returnType; + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + returnType, + false); + } + } + + private static class DrillNTILESqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillNTILESqlReturnTypeInference INSTANCE = new DrillNTILESqlReturnTypeInference(); + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + SqlTypeName.INTEGER, + opBinding.getOperandType(0).isNullable()); + } + } + + private static class DrillLeadLagSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillLeadLagSqlReturnTypeInference INSTANCE = new DrillLeadLagSqlReturnTypeInference(); + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) 
{ + return createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + opBinding.getOperandType(0).getSqlTypeName(), + true); + } + } + + private static class DrillFirstLastValueSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillFirstLastValueSqlReturnTypeInference INSTANCE = new DrillFirstLastValueSqlReturnTypeInference(); + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return opBinding.getOperandType(0); + } + } + + private static class DrillAvgAggSqlReturnTypeInference implements SqlReturnTypeInference { + private static final DrillAvgAggSqlReturnTypeInference INSTANCE = new DrillAvgAggSqlReturnTypeInference(); + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final boolean isNullable = opBinding.getGroupCount() == 0 || opBinding.hasFilter() || opBinding.getOperandType(0).isNullable(); + return createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + SqlTypeName.DOUBLE, + isNullable); + } + } + private static DrillFuncHolder resolveDrillFuncHolder(final SqlOperatorBinding opBinding, final List functions) { final FunctionCall functionCall = convertSqlOperatorBindingToFunctionCall(opBinding); final FunctionResolver functionResolver = FunctionResolverFactory.getResolver(functionCall); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/DrillCalciteWrapperUtility.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/DrillCalciteWrapperUtility.java new file mode 100644 index 00000000000..265f361975a --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/DrillCalciteWrapperUtility.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.drill.exec.planner.sql.parser; + +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlOperator; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlFunctionWrapper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlOperatorWrapper; +import org.apache.drill.exec.planner.sql.DrillCalciteSqlWrapper; + +/** + * + * This utility contains the static functions to manipulate {@link DrillCalciteSqlWrapper}, {@link DrillCalciteSqlOperatorWrapper} + * {@link DrillCalciteSqlFunctionWrapper} and {@link DrillCalciteSqlAggFunctionWrapper}. + */ +public class DrillCalciteWrapperUtility { + /** + * This static method will extract the SqlOperator inside the given SqlOperator if the given SqlOperator is wrapped + * in DrillCalciteSqlWrapper and will just return the given SqlOperator if it is not wrapped. 
+ */ + public static SqlOperator extractSqlOperatorFromWrapper(final SqlOperator sqlOperator) { + if(sqlOperator instanceof DrillCalciteSqlWrapper) { + return ((DrillCalciteSqlWrapper) sqlOperator).getOperator(); + } else { + return sqlOperator; + } + } + + /** + * This static method will extract the SqlFunction inside the given SqlFunction if the given SqlFunction is wrapped + * in DrillCalciteSqlFunctionWrapper and will just return the given SqlFunction if it is not wrapped. + */ + public static SqlFunction extractSqlOperatorFromWrapper(final SqlFunction sqlFunction) { + if(sqlFunction instanceof DrillCalciteSqlWrapper) { + return (SqlFunction) ((DrillCalciteSqlWrapper) sqlFunction).getOperator(); + } else { + return sqlFunction; + } + } + + /** + * This static method will extract the SqlAggFunction inside the given SqlAggFunction if the given SqlFunction is wrapped + * in DrillCalciteSqlAggFunctionWrapper and will just return the given SqlAggFunction if it is not wrapped. + */ + public static SqlAggFunction extractSqlOperatorFromWrapper(final SqlAggFunction sqlAggFunction) { + if(sqlAggFunction instanceof DrillCalciteSqlWrapper) { + return (SqlAggFunction) ((DrillCalciteSqlWrapper) sqlAggFunction).getOperator(); + } else { + return sqlAggFunction; + } + } + + /** + * This class is not intended to be instantiated + */ + private DrillCalciteWrapperUtility() { + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java index 528cadc4727..917353ee3b2 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java @@ -352,7 +352,7 @@ public SqlNode visit(SqlCall sqlCall) { } } - if(extractSqlOperatorFromWrapper(sqlCall.getOperator()) instanceof 
SqlCountAggFunction) { + if(DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(sqlCall.getOperator()) instanceof SqlCountAggFunction) { for(SqlNode sqlNode : sqlCall.getOperandList()) { if(containsFlatten(sqlNode)) { unsupportedOperatorCollector.setException(SqlUnsupportedException.ExceptionType.FUNCTION, @@ -416,7 +416,7 @@ private interface SqlNodeCondition { @Override public boolean test(SqlNode sqlNode) { if (sqlNode instanceof SqlCall) { - final SqlOperator operator = extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); + final SqlOperator operator = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); if (operator == SqlStdOperatorTable.ROLLUP || operator == SqlStdOperatorTable.CUBE || operator == SqlStdOperatorTable.GROUPING_SETS) { @@ -434,7 +434,7 @@ public boolean test(SqlNode sqlNode) { @Override public boolean test(SqlNode sqlNode) { if (sqlNode instanceof SqlCall) { - final SqlOperator operator = extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); + final SqlOperator operator = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); if (operator == SqlStdOperatorTable.GROUPING || operator == SqlStdOperatorTable.GROUPING_ID || operator == SqlStdOperatorTable.GROUP_ID) { @@ -555,12 +555,4 @@ private void detectMultiplePartitions(SqlSelect sqlSelect) { } } } - - private SqlOperator extractSqlOperatorFromWrapper(SqlOperator sqlOperator) { - if(sqlOperator instanceof DrillCalciteSqlWrapper) { - return ((DrillCalciteSqlWrapper) sqlOperator).getOperator(); - } else { - return sqlOperator; - } - } } \ No newline at end of file diff --git a/exec/java-exec/src/test/java/org/apache/drill/PlanningBase.java b/exec/java-exec/src/test/java/org/apache/drill/PlanningBase.java index 6a038f1a21f..ad9cc648a0d 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/PlanningBase.java +++ b/exec/java-exec/src/test/java/org/apache/drill/PlanningBase.java @@ -109,7 
+109,7 @@ protected void testSqlPlan(String sqlCommands) throws Exception { final StoragePluginRegistry registry = new StoragePluginRegistryImpl(dbContext); registry.init(); final FunctionImplementationRegistry functionRegistry = new FunctionImplementationRegistry(config); - final DrillOperatorTable table = new DrillOperatorTable(functionRegistry); + final DrillOperatorTable table = new DrillOperatorTable(functionRegistry, systemOptions); final SchemaPlus root = SimpleCalciteSchema.createRootSchema(false); registry.getSchemaFactory().registerSchemas(SchemaConfig.newBuilder("foo", context).build(), root); diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java index ad9a2053172..5df71f0f1db 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java @@ -22,6 +22,7 @@ import org.apache.drill.common.expression.SchemaPath; import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.util.FileUtils; +import org.junit.Ignore; import org.junit.Test; import java.util.List; @@ -474,7 +475,12 @@ public void testAvgCountStar() throws Exception { .run(); } - @Test // explain plan including all attributes for + @Test + @Ignore // This is temporarily turned off due to + // [1] [StarColumn] Reverse one change in CALCITE-356, + // which regresses AggChecker logic, after * query in schema-less table is added. 
+ // [2] [StarColumn] + // When group-by a column, projecting on a star which cannot be expanded at planning time, use ITEM operator to wrap this column public void testUDFInGroupBy() throws Exception { final String query = "select count(*) as col1, substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2) as col2, \n" + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) as col3 \n" + @@ -548,4 +554,164 @@ public void testWindowSumAvg() throws Exception { .build() .run(); } + + @Test + public void testWindowRanking() throws Exception { + final String queryCUME_DIST = "select CUME_DIST() over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + final String queryDENSE_RANK = "select DENSE_RANK() over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + final String queryPERCENT_RANK = "select PERCENT_RANK() over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + final String queryRANK = "select RANK() over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + final String queryROW_NUMBER = "select ROW_NUMBER() over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + final TypeProtos.MajorType majorTypeDouble = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.FLOAT8) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + + final TypeProtos.MajorType majorTypeBigInt = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + + final List> expectedSchemaCUME_DIST = Lists.newArrayList(); + expectedSchemaCUME_DIST.add(Pair.of(SchemaPath.getSimplePath("col"), majorTypeDouble)); + + final List> expectedSchemaDENSE_RANK = Lists.newArrayList(); + expectedSchemaDENSE_RANK.add(Pair.of(SchemaPath.getSimplePath("col"), majorTypeBigInt)); + + final List> 
expectedSchemaPERCENT_RANK = Lists.newArrayList(); + expectedSchemaPERCENT_RANK.add(Pair.of(SchemaPath.getSimplePath("col"), majorTypeDouble)); + + final List> expectedSchemaRANK = Lists.newArrayList(); + expectedSchemaRANK.add(Pair.of(SchemaPath.getSimplePath("col"), majorTypeBigInt)); + + final List> expectedSchemaROW_NUMBER = Lists.newArrayList(); + expectedSchemaROW_NUMBER.add(Pair.of(SchemaPath.getSimplePath("col"), majorTypeBigInt)); + + testBuilder() + .sqlQuery(queryCUME_DIST) + .schemaBaseLine(expectedSchemaCUME_DIST) + .build() + .run(); + + testBuilder() + .sqlQuery(queryDENSE_RANK) + .schemaBaseLine(expectedSchemaDENSE_RANK) + .build() + .run(); + + testBuilder() + .sqlQuery(queryPERCENT_RANK) + .schemaBaseLine(expectedSchemaPERCENT_RANK) + .build() + .run(); + + testBuilder() + .sqlQuery(queryRANK) + .schemaBaseLine(expectedSchemaRANK) + .build() + .run(); + + testBuilder() + .sqlQuery(queryROW_NUMBER) + .schemaBaseLine(expectedSchemaROW_NUMBER) + .build() + .run(); + } + + @Test + public void testWindowNTILE() throws Exception { + final String query = "select ntile(1) over(order by position_id) as col \n" + + "from cp.`employee.json` \n" + + "limit 0"; + + final TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.INT) + .setMode(TypeProtos.DataMode.REQUIRED) + .build(); + + final List> expectedSchema = Lists.newArrayList(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(query) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testLeadLag() throws Exception { + final String queryLEAD = "select lead(cast(n_nationkey as BigInt)) over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + final String queryLAG = "select lag(cast(n_nationkey as BigInt)) over(order by n_nationkey) as col \n" + + "from cp.`tpch/nation.parquet` \n" + + "limit 0"; + + final TypeProtos.MajorType 
majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.BIGINT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final List> expectedSchema = Lists.newArrayList(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(queryLEAD) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(queryLAG) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } + + @Test + public void testFirst_Last_Value() throws Exception { + final String queryFirst = "select first_value(cast(position_id as integer)) over(order by position_id) as col \n" + + "from cp.`employee.json` \n" + + "limit 0"; + + final String queryLast = "select first_value(cast(position_id as integer)) over(order by position_id) as col \n" + + "from cp.`employee.json` \n" + + "limit 0"; + + final TypeProtos.MajorType majorType = TypeProtos.MajorType.newBuilder() + .setMinorType(TypeProtos.MinorType.INT) + .setMode(TypeProtos.DataMode.OPTIONAL) + .build(); + + final List> expectedSchema = Lists.newArrayList(); + expectedSchema.add(Pair.of(SchemaPath.getSimplePath("col"), majorType)); + + testBuilder() + .sqlQuery(queryFirst) + .schemaBaseLine(expectedSchema) + .build() + .run(); + + testBuilder() + .sqlQuery(queryLast) + .schemaBaseLine(expectedSchema) + .build() + .run(); + } } From 600ba9ee1d7f321036a6390c0ff9d9872b1d80f0 Mon Sep 17 00:00:00 2001 From: Hsuan-Yi Chu Date: Thu, 17 Mar 2016 21:54:05 -0700 Subject: [PATCH 5/5] Bump calcite version to 1.4.0-drill-r11 --- .../org/apache/drill/TestFunctionsWithTypeExpoQueries.java | 5 ----- pom.xml | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java index 5df71f0f1db..5d16edda1b0 100644 --- 
a/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestFunctionsWithTypeExpoQueries.java @@ -476,11 +476,6 @@ public void testAvgCountStar() throws Exception { } @Test - @Ignore // This is temporarily turned off due to - // [1] [StarColumn] Reverse one change in CALCITE-356, - // which regresses AggChecker logic, after * query in schema-less table is added. - // [2] [StarColumn] - // When group-by a column, projecting on a star which cannot be expanded at planning time, use ITEM operator to wrap this column public void testUDFInGroupBy() throws Exception { final String query = "select count(*) as col1, substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2) as col2, \n" + "char_length(substr(lower(UPPER(cast(t3.full_name as varchar(100)))), 5, 2)) as col3 \n" + diff --git a/pom.xml b/pom.xml index 4dfa682b99e..058036b6e11 100644 --- a/pom.xml +++ b/pom.xml @@ -1279,7 +1279,7 @@ org.apache.calcite calcite-core - 1.4.0-drill-test-r16 + 1.4.0-drill-r11 org.jgrapht