From 2af1cf0ed3c1f509fd71d07bcf8f00a2f8c7e70f Mon Sep 17 00:00:00 2001 From: Volodymyr Vysotskyi Date: Wed, 7 Nov 2018 20:03:40 +0200 Subject: [PATCH] DRILL-3610: Fix TIMESTAMPADD and TIMESTAMPDIFF functions --- .../main/codegen/data/DateIntervalFunc.tdd | 1 + .../main/codegen/data/MathFunctionTypes.tdd | 23 +- .../IntervalNumericArithmetic.java | 18 +- .../TimestampDiff.java | 109 +++++++ .../exec/planner/logical/DrillOptiq.java | 274 +++++++++++------- .../planner/sql/DrillConvertletTable.java | 33 ++- .../exec/planner/sql/TypeInferenceUtils.java | 73 ++++- .../exec/fn/impl/TestDateAddFunctions.java | 159 ++++++++++ 8 files changed, 556 insertions(+), 134 deletions(-) create mode 100644 exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/TimestampDiff.java diff --git a/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd b/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd index 96e16070aa5..bfa44c8df09 100644 --- a/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd +++ b/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd @@ -18,6 +18,7 @@ {intervals: ["Interval", "IntervalDay", "IntervalYear", "Int", "BigInt"] }, {truncInputTypes: ["Date", "TimeStamp", "Time", "Interval", "IntervalDay", "IntervalYear"] }, {truncUnits : ["Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter", "Decade", "Century", "Millennium" ] }, + {timestampDiffUnits : ["Nanosecond", "Microsecond", "Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter"] }, { varCharToDate: [ diff --git a/exec/java-exec/src/main/codegen/data/MathFunctionTypes.tdd b/exec/java-exec/src/main/codegen/data/MathFunctionTypes.tdd index 1724373a8d6..6cb48246ebd 100644 --- a/exec/java-exec/src/main/codegen/data/MathFunctionTypes.tdd +++ b/exec/java-exec/src/main/codegen/data/MathFunctionTypes.tdd @@ -28,7 +28,7 @@ {input1: "UInt4", input2: "UInt4", outputType: "UInt4", castType: "int"}, {input1: "UInt8", input2: "UInt8", outputType: "UInt8", castType: "long"} ] - }, + }, {className: "Subtract", funcName: "subtract", op: "-", types: [ {input1: "Int", input2: "Int", outputType: "Int", castType: "int"}, {input1: "BigInt", input2: "BigInt", outputType: "BigInt", castType: "long"}, @@ -42,7 +42,7 @@ {input1: "UInt8", input2: "UInt8", outputType: "UInt8", castType: "long"} ] }, - {className: "Multiply", funcName: "multiply", op: "*", types: [ + {className: "Multiply", funcName: "multiply", op: "*", types: [ {input1: "Int", input2: "Int", outputType: "Int", castType: "int"}, {input1: "BigInt", input2: "BigInt", outputType: "BigInt", castType: "long"}, {input1: "Float4", input2: "Float4", outputType: "Float4", castType: "float"}, @@ -54,8 +54,8 @@ {input1: "UInt4", input2: "UInt4", outputType: "UInt4", castType: "int"}, {input1: "UInt8", input2: "UInt8", outputType: "UInt8", castType: "long"} ] - }, - {className: "Divide", funcName: "divide", op: "/", types: [ + }, + {className: "Divide", funcName: "divide", op: "/", types: [ {input1: "Int", input2: "Int", outputType: "Int", castType: "int"}, {input1: "BigInt", input2: "BigInt", outputType: "BigInt", castType: "long"}, {input1: "Float4", input2: "Float4", outputType: "Float4", castType: "float"}, @@ -67,6 +67,19 @@ {input1: "UInt4", input2: "UInt4", outputType: "UInt4", castType: "int"}, {input1: "UInt8", input2: "UInt8", outputType: "UInt8", castType: "long"} ] - } + }, + {className: "DivideInt", funcName: "/int", op: "/", types: [ + {input1: "Int", input2: "Int", outputType: "Int", castType: "int"}, + {input1: "BigInt", input2: "BigInt", outputType: "Int", castType: "int"}, + {input1: "Float4", input2: "Float4", outputType: "Int", castType: "int"}, + {input1: "Float8", input2: "Float8", outputType: "Int", castType: "int"}, + {input1: "SmallInt", input2: "SmallInt", outputType: "Int", castType: "int"}, + {input1: "TinyInt", input2: "TinyInt", outputType: "Int", castType: "int"}, + {input1: "UInt1", input2: "UInt1", outputType: "Int", castType: "int"}, + {input1: "UInt2", input2: "UInt2", outputType: "Int", castType: "int"}, + {input1: "UInt4", input2: "UInt4", outputType: "Int", castType: "int"}, + {input1: "UInt8", input2: "UInt8", outputType: "Int", castType: "int"} + ] + } ] } \ No newline at end of file diff --git a/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/IntervalNumericArithmetic.java b/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/IntervalNumericArithmetic.java index 66e754c3606..f410601383f 100644 --- a/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/IntervalNumericArithmetic.java +++ b/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/IntervalNumericArithmetic.java @@ -129,20 +129,22 @@ public void eval() { } } - @SuppressWarnings("unused") - @FunctionTemplate(names = {"divide", "div"}, scope = FunctionTemplate.FunctionScope.SIMPLE, nulls=NullHandling.NULL_IF_NULL) - public static class ${intervaltype}${numerictype}DivideFunction implements DrillSimpleFunc { + @SuppressWarnings("unused") + @FunctionTemplate(names = {"divide", "div"<#if numerictype == "Int">, "/int"}, + scope = FunctionTemplate.FunctionScope.SIMPLE, + nulls = NullHandling.NULL_IF_NULL) + public static class ${intervaltype}${numerictype}DivideFunction implements DrillSimpleFunc { @Param ${intervaltype}Holder left; @Param ${numerictype}Holder right; @Output IntervalHolder out; - public void setup() { - } + public void setup() { + } - public void eval() { - <@intervalNumericArithmeticBlock left="left" right="right" temp = "temp" op = "/" out = "out" intervaltype=intervaltype /> - } + public void eval() { + <@intervalNumericArithmeticBlock left="left" right="right" temp = "temp" op = "/" out = "out" intervaltype=intervaltype /> } + } } \ No newline at end of file diff --git a/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/TimestampDiff.java b/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/TimestampDiff.java new file mode 100644 index 00000000000..54232e26825 --- /dev/null +++ b/exec/java-exec/src/main/codegen/templates/DateIntervalFunctionTemplates/TimestampDiff.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +<@pp.dropOutputFile /> +<#assign className="GTimestampDiff"/> + +<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/${className}.java"/> + +<#include "/@includes/license.ftl"/> + +package org.apache.drill.exec.expr.fn.impl; + +import org.apache.drill.exec.expr.DrillSimpleFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Workspace; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.holders.*; +import org.apache.drill.exec.record.RecordBatch; + +/* + * This class is generated using freemarker and the ${.template_name} template. + */ + +public class ${className} { + +<#list dateIntervalFunc.timestampDiffUnits as unit> + +<#list dateIntervalFunc.dates as fromUnit> +<#list dateIntervalFunc.dates as toUnit> + + @FunctionTemplate(name = "timestampdiff${unit}", + scope = FunctionTemplate.FunctionScope.SIMPLE, + nulls = FunctionTemplate.NullHandling.NULL_IF_NULL) + public static class TimestampDiff${unit}${fromUnit}To${toUnit} implements DrillSimpleFunc { + + @Param ${fromUnit}Holder left; + @Param ${toUnit}Holder right; + @Output BigIntHolder out; + + public void setup() { + } + + public void eval() { + <#if unit == "Nanosecond"> + out.value = (right.value - left.value) * 1000000; + <#elseif unit == "Microsecond"> + out.value = (right.value - left.value) * 1000; + <#elseif unit == "Second"> + out.value = (right.value - left.value) / org.apache.drill.exec.vector.DateUtilities.secondsToMillis; + <#elseif unit == "Minute"> + out.value = (right.value - left.value) / org.apache.drill.exec.vector.DateUtilities.minutesToMillis; + <#elseif unit == "Hour"> + out.value = (right.value - left.value) / org.apache.drill.exec.vector.DateUtilities.hoursToMillis; + <#elseif unit == "Day"> + out.value = (right.value - left.value) / org.apache.drill.exec.vector.DateUtilities.daysToStandardMillis; + <#elseif unit == "Week"> + out.value = (right.value - left.value) / 604800000; // 7 * 24 * 60 * 60 * 1000 + <#elseif unit == "Month" || unit == "Quarter" || unit == "Year"> + long timeMilliseconds = left.value % org.apache.drill.exec.vector.DateUtilities.daysToStandardMillis + - right.value % org.apache.drill.exec.vector.DateUtilities.daysToStandardMillis; + + java.time.Period between = java.time.Period.between( + java.time.Instant.ofEpochMilli(left.value).atZone(java.time.ZoneOffset.UTC).toLocalDate(), + java.time.Instant.ofEpochMilli(right.value).atZone(java.time.ZoneOffset.UTC).toLocalDate()); + int days = between.getDays(); + if (timeMilliseconds < 0 && days > 0) { + // in the case of negative time value increases left operand days value + between = java.time.Period.between( + java.time.Instant.ofEpochMilli(left.value + org.apache.drill.exec.vector.DateUtilities.daysToStandardMillis).atZone(java.time.ZoneOffset.UTC).toLocalDate(), + java.time.Instant.ofEpochMilli(right.value).atZone(java.time.ZoneOffset.UTC).toLocalDate()); + } else if (timeMilliseconds > 0 && days < 0) { + // in the case of negative days value decreases it for the right operand + between = java.time.Period.between( + java.time.Instant.ofEpochMilli(left.value - org.apache.drill.exec.vector.DateUtilities.daysToStandardMillis).atZone(java.time.ZoneOffset.UTC).toLocalDate(), + java.time.Instant.ofEpochMilli(right.value).atZone(java.time.ZoneOffset.UTC).toLocalDate()); + } + int months = between.getMonths() + between.getYears() * org.apache.drill.exec.vector.DateUtilities.yearsToMonths; + + <#if unit == "Month"> + out.value = months; + <#elseif unit == "Quarter"> + out.value = months / 4; + <#elseif unit == "Year"> + out.value = months / org.apache.drill.exec.vector.DateUtilities.yearsToMonths; + + + } + } + + + + +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java index 477b03c561d..63ce90f12ae 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java @@ -18,11 +18,14 @@ package org.apache.drill.exec.planner.logical; import java.math.BigDecimal; +import java.util.ArrayList; import java.util.GregorianCalendar; import java.util.LinkedList; import java.util.List; +import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.rel.type.RelDataType; +import org.apache.commons.lang3.StringUtils; import org.apache.drill.common.exceptions.UserException; import org.apache.drill.common.expression.ExpressionPosition; import org.apache.drill.common.expression.FieldReference; @@ -58,6 +61,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.NlsString; +import org.apache.drill.shaded.guava.com.google.common.base.Preconditions; import org.apache.drill.shaded.guava.com.google.common.collect.Lists; import org.apache.drill.exec.planner.physical.PlannerSettings; import org.apache.drill.exec.work.ExecErrorConstants; @@ -395,7 +399,7 @@ private LogicalExpression getDrillCastFunctionFromOptiq(RexCall call){ } private LogicalExpression getDrillFunctionFromOptiqCall(RexCall call) { - List args = Lists.newArrayList(); + List args = new ArrayList<>(); for(RexNode n : call.getOperands()){ args.add(n.accept(this)); @@ -408,114 +412,158 @@ private LogicalExpression getDrillFunctionFromOptiqCall(RexCall call) { /* Rewrite extract functions in the following manner * extract(year, date '2008-2-23') ---> extractYear(date '2008-2-23') */ - if (functionName.equals("extract")) { - - // Assert that the first argument to extract is a QuotedString - assert args.get(0) instanceof ValueExpressions.QuotedString; - - // Get the unit of time to be extracted - String timeUnitStr = ((ValueExpressions.QuotedString)args.get(0)).value; - - switch (timeUnitStr){ - case ("YEAR"): - case ("MONTH"): - case ("DAY"): - case ("HOUR"): - case ("MINUTE"): - case ("SECOND"): - String functionPostfix = timeUnitStr.substring(0, 1).toUpperCase() + timeUnitStr.substring(1).toLowerCase(); - functionName += functionPostfix; - return FunctionCallFactory.createExpression(functionName, args.subList(1, 2)); - default: - throw new UnsupportedOperationException("extract function supports the following time units: YEAR, MONTH, DAY, HOUR, MINUTE, SECOND"); + switch (functionName) { + case "extract": { + + // Assert that the first argument to extract is a QuotedString + assert args.get(0) instanceof ValueExpressions.QuotedString; + + // Get the unit of time to be extracted + String timeUnitStr = ((ValueExpressions.QuotedString) args.get(0)).value; + + TimeUnit timeUnit = TimeUnit.valueOf(timeUnitStr); + + switch (timeUnit) { + case YEAR: + case MONTH: + case DAY: + case HOUR: + case MINUTE: + case SECOND: + String functionPostfix = StringUtils.capitalize(timeUnitStr.toLowerCase()); + functionName += functionPostfix; + return FunctionCallFactory.createExpression(functionName, args.subList(1, 2)); + default: + throw new UnsupportedOperationException("extract function supports the following time units: YEAR, MONTH, DAY, HOUR, MINUTE, SECOND"); + } } - } else if (functionName.equals("trim")) { - String trimFunc = null; - List trimArgs = Lists.newArrayList(); - - assert args.get(0) instanceof ValueExpressions.QuotedString; - switch (((ValueExpressions.QuotedString)args.get(0)).value.toUpperCase()) { - case "LEADING": - trimFunc = "ltrim"; - break; - case "TRAILING": - trimFunc = "rtrim"; - break; - case "BOTH": - trimFunc = "btrim"; - break; - default: - assert 1 == 0; - } - - trimArgs.add(args.get(2)); - trimArgs.add(args.get(1)); - - return FunctionCallFactory.createExpression(trimFunc, trimArgs); - } else if (functionName.equals("date_part")) { - // Rewrite DATE_PART functions as extract functions - // assert that the function has exactly two arguments - assert argsSize == 2; - - /* Based on the first input to the date_part function we rewrite the function as the - * appropriate extract function. For example - * date_part('year', date '2008-2-23') ------> extractYear(date '2008-2-23') - */ - assert args.get(0) instanceof QuotedString; - - QuotedString extractString = (QuotedString) args.get(0); - String functionPostfix = extractString.value.substring(0, 1).toUpperCase() + extractString.value.substring(1).toLowerCase(); - return FunctionCallFactory.createExpression("extract" + functionPostfix, args.subList(1, 2)); - } else if (functionName.equals("concat")) { - - if (argsSize == 1) { - /* - * We treat concat with one argument as a special case. Since we don't have a function - * implementation of concat that accepts one argument. We simply add another dummy argument - * (empty string literal) to the list of arguments. - */ - List concatArgs = new LinkedList<>(args); - concatArgs.add(QuotedString.EMPTY_STRING); - - return FunctionCallFactory.createExpression(functionName, concatArgs); + case "timestampdiff": { + + // Assert that the first argument to extract is a QuotedString + Preconditions.checkArgument(args.get(0) instanceof ValueExpressions.QuotedString, + "The first argument of TIMESTAMPDIFF function should be QuotedString"); + + String timeUnitStr = ((ValueExpressions.QuotedString) args.get(0)).value; + + TimeUnit timeUnit = TimeUnit.valueOf(timeUnitStr); + + switch (timeUnit) { + case YEAR: + case MONTH: + case DAY: + case HOUR: + case MINUTE: + case SECOND: + case QUARTER: + case WEEK: + case MICROSECOND: + case NANOSECOND: + String functionPostfix = StringUtils.capitalize(timeUnitStr.toLowerCase()); + functionName += functionPostfix; + return FunctionCallFactory.createExpression(functionName, args.subList(1, 3)); + default: + throw new UnsupportedOperationException("TIMESTAMPDIFF function supports the following time units: " + + "YEAR, MONTH, DAY, HOUR, MINUTE, SECOND, QUARTER, WEEK, MICROSECOND, NANOSECOND"); + } + } + case "trim": { + String trimFunc; + List trimArgs = new ArrayList<>(); + + assert args.get(0) instanceof ValueExpressions.QuotedString; + switch (((ValueExpressions.QuotedString) args.get(0)).value.toUpperCase()) { + case "LEADING": + trimFunc = "ltrim"; + break; + case "TRAILING": + trimFunc = "rtrim"; + break; + case "BOTH": + trimFunc = "btrim"; + break; + default: + throw new UnsupportedOperationException("Invalid argument for TRIM function. " + + "Expected one of the following: LEADING, TRAILING, BOTH"); + } - } else if (argsSize > 2) { - List concatArgs = Lists.newArrayList(); + trimArgs.add(args.get(2)); + trimArgs.add(args.get(1)); - /* stack concat functions on top of each other if we have more than two arguments - * Eg: concat(col1, col2, col3) => concat(concat(col1, col2), col3) + return FunctionCallFactory.createExpression(trimFunc, trimArgs); + } + case "date_part": { + // Rewrite DATE_PART functions as extract functions + // assert that the function has exactly two arguments + assert argsSize == 2; + + /* Based on the first input to the date_part function we rewrite the function as the + * appropriate extract function. For example + * date_part('year', date '2008-2-23') ------> extractYear(date '2008-2-23') */ - concatArgs.add(args.get(0)); - concatArgs.add(args.get(1)); - - LogicalExpression first = FunctionCallFactory.createExpression(functionName, concatArgs); + assert args.get(0) instanceof QuotedString; - for (int i = 2; i < argsSize; i++) { - concatArgs = Lists.newArrayList(); - concatArgs.add(first); - concatArgs.add(args.get(i)); - first = FunctionCallFactory.createExpression(functionName, concatArgs); + QuotedString extractString = (QuotedString) args.get(0); + String functionPostfix = StringUtils.capitalize(extractString.value.toLowerCase()); + return FunctionCallFactory.createExpression("extract" + functionPostfix, args.subList(1, 2)); + } + case "concat": { + + if (argsSize == 1) { + /* + * We treat concat with one argument as a special case. Since we don't have a function + * implementation of concat that accepts one argument. We simply add another dummy argument + * (empty string literal) to the list of arguments. + */ + List concatArgs = new LinkedList<>(args); + concatArgs.add(QuotedString.EMPTY_STRING); + + return FunctionCallFactory.createExpression(functionName, concatArgs); + + } else if (argsSize > 2) { + List concatArgs = new ArrayList<>(); + + /* stack concat functions on top of each other if we have more than two arguments + * Eg: concat(col1, col2, col3) => concat(concat(col1, col2), col3) + */ + concatArgs.add(args.get(0)); + concatArgs.add(args.get(1)); + + LogicalExpression first = FunctionCallFactory.createExpression(functionName, concatArgs); + + for (int i = 2; i < argsSize; i++) { + concatArgs = new ArrayList<>(); + concatArgs.add(first); + concatArgs.add(args.get(i)); + first = FunctionCallFactory.createExpression(functionName, concatArgs); + } + + return first; } - - return first; + break; } - } else if (functionName.equals("length")) { - + case "length": { if (argsSize == 2) { - // Second argument should always be a literal specifying the encoding format - assert args.get(1) instanceof ValueExpressions.QuotedString; + // Second argument should always be a literal specifying the encoding format + assert args.get(1) instanceof ValueExpressions.QuotedString; - String encodingType = ((ValueExpressions.QuotedString) args.get(1)).value; - functionName += encodingType.substring(0, 1).toUpperCase() + encodingType.substring(1).toLowerCase(); + String encodingType = ((ValueExpressions.QuotedString) args.get(1)).value; + functionName += StringUtils.capitalize(encodingType.toLowerCase()); - return FunctionCallFactory.createExpression(functionName, args.subList(0, 1)); + return FunctionCallFactory.createExpression(functionName, args.subList(0, 1)); } - } else if ((functionName.equals("convert_from") || functionName.equals("convert_to")) - && args.get(1) instanceof QuotedString) { - return FunctionCallFactory.createConvert(functionName, ((QuotedString)args.get(1)).value, args.get(0), ExpressionPosition.UNKNOWN); - } else if (functionName.equals("date_trunc")) { - return handleDateTruncFunction(args); + break; + } + case "convert_from": + case "convert_to": { + if (args.get(1) instanceof QuotedString) { + return FunctionCallFactory.createConvert(functionName, ((QuotedString) args.get(1)).value, args.get(0), ExpressionPosition.UNKNOWN); + } + break; + } + case "date_trunc": { + return handleDateTruncFunction(args); + } } return FunctionCallFactory.createExpression(functionName, args); @@ -526,21 +574,23 @@ private LogicalExpression handleDateTruncFunction(final List assert args.get(0) instanceof ValueExpressions.QuotedString; // Get the unit of time to be extracted - String timeUnitStr = ((ValueExpressions.QuotedString)args.get(0)).value.toUpperCase(); - - switch (timeUnitStr){ - case ("YEAR"): - case ("MONTH"): - case ("DAY"): - case ("HOUR"): - case ("MINUTE"): - case ("SECOND"): - case ("WEEK"): - case ("QUARTER"): - case ("DECADE"): - case ("CENTURY"): - case ("MILLENNIUM"): - final String functionPostfix = timeUnitStr.substring(0, 1).toUpperCase() + timeUnitStr.substring(1).toLowerCase(); + String timeUnitStr = ((ValueExpressions.QuotedString) args.get(0)).value.toUpperCase(); + + TimeUnit timeUnit = TimeUnit.valueOf(timeUnitStr); + + switch (timeUnit) { + case YEAR: + case MONTH: + case DAY: + case HOUR: + case MINUTE: + case SECOND: + case WEEK: + case QUARTER: + case DECADE: + case CENTURY: + case MILLENNIUM: + final String functionPostfix = StringUtils.capitalize(timeUnitStr.toLowerCase()); return FunctionCallFactory.createExpression("date_trunc_" + functionPostfix, args.subList(1, 2)); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java index 5f3b95e9739..8d2899f6212 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/DrillConvertletTable.java @@ -18,21 +18,29 @@ package org.apache.drill.exec.planner.sql; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIntervalQualifier; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlRexConvertlet; import org.apache.calcite.sql2rel.SqlRexConvertletTable; import org.apache.calcite.sql2rel.StandardConvertletTable; import org.apache.drill.exec.planner.sql.parser.DrillCalciteWrapperUtility; -public class DrillConvertletTable implements SqlRexConvertletTable{ +public class DrillConvertletTable implements SqlRexConvertletTable { public static HashMap map = new HashMap<>(); @@ -61,6 +69,28 @@ public class DrillConvertletTable implements SqlRexConvertletTable{ } }; + // Custom convertlet to avoid rewriting TIMESTAMP_DIFF by Calcite. + private static final SqlRexConvertlet TIMESTAMP_DIFF_CONVERTLET = (cx, call) -> { + SqlLiteral unitLiteral = call.operand(0); + SqlIntervalQualifier qualifier = + new SqlIntervalQualifier(unitLiteral.symbolValue(TimeUnit.class), null, SqlParserPos.ZERO); + + List operands = Arrays.asList( + cx.convertExpression(qualifier), + cx.convertExpression(call.operand(1)), + cx.convertExpression(call.operand(2))); + + RelDataTypeFactory typeFactory = cx.getTypeFactory(); + + RelDataType returnType = typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BIGINT), + cx.getValidator().getValidatedNodeType(call.operand(1)).isNullable() + || cx.getValidator().getValidatedNodeType(call.operand(2)).isNullable()); + + return cx.getRexBuilder().makeCall(returnType, + SqlStdOperatorTable.TIMESTAMP_DIFF, operands); + }; + static { // Use custom convertlet for EXTRACT function map.put(SqlStdOperatorTable.EXTRACT, DrillExtractConvertlet.INSTANCE); @@ -68,6 +98,7 @@ public class DrillConvertletTable implements SqlRexConvertletTable{ // which is not suitable for Infinity value case map.put(SqlStdOperatorTable.SQRT, SQRT_CONVERTLET); map.put(SqlStdOperatorTable.COALESCE, COALESCE_CONVERTLET); + map.put(SqlStdOperatorTable.TIMESTAMP_DIFF, TIMESTAMP_DIFF_CONVERTLET); map.put(SqlStdOperatorTable.AVG, new DrillAvgVarianceConvertlet(SqlKind.AVG)); map.put(SqlStdOperatorTable.STDDEV_POP, new DrillAvgVarianceConvertlet(SqlKind.STDDEV_POP)); map.put(SqlStdOperatorTable.STDDEV_SAMP, new DrillAvgVarianceConvertlet(SqlKind.STDDEV_SAMP)); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java index 016473c03db..23406e06937 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java @@ -17,6 +17,7 @@ */ package org.apache.drill.exec.planner.sql; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableMap; import org.apache.drill.shaded.guava.com.google.common.collect.ImmutableSet; import org.apache.drill.shaded.guava.com.google.common.collect.Lists; @@ -140,6 +141,7 @@ public class TypeInferenceUtils { private static final ImmutableMap funcNameToInference = ImmutableMap. builder() .put("DATE_PART", DrillDatePartSqlReturnTypeInference.INSTANCE) + .put(SqlStdOperatorTable.TIMESTAMP_ADD.getName(), DrillTimestampAddTypeInference.INSTANCE) .put(SqlKind.SUM.name(), DrillSumSqlReturnTypeInference.INSTANCE) .put(SqlKind.COUNT.name(), DrillCountSqlReturnTypeInference.INSTANCE) .put("CONCAT", DrillConcatSqlReturnTypeInference.INSTANCE_CONCAT) @@ -555,6 +557,60 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { } } + private static class DrillTimestampAddTypeInference implements SqlReturnTypeInference { + private static final SqlReturnTypeInference INSTANCE = new DrillTimestampAddTypeInference(); + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + RelDataTypeFactory factory = opBinding.getTypeFactory(); + // operands count ond order is checked at parsing stage + RelDataType inputType = opBinding.getOperandType(2); + boolean isNullable = inputType.isNullable() || opBinding.getOperandType(1).isNullable(); + + SqlTypeName inputTypeName = inputType.getSqlTypeName(); + + TimeUnit qualifier = ((SqlLiteral) ((SqlCallBinding) opBinding).operand(0)).getValueAs(TimeUnit.class); + + SqlTypeName sqlTypeName; + + // follow up with type inference of reduced expression + switch (qualifier) { + case DAY: + case WEEK: + case MONTH: + case QUARTER: + case YEAR: + case NANOSECOND: // NANOSECOND is not supported by Calcite SqlTimestampAddFunction. + // Once it is fixed, NANOSECOND should be moved to the group below. + sqlTypeName = inputTypeName; + break; + case MICROSECOND: + case MILLISECOND: + // for MICROSECOND and MILLISECOND should be specified precision + return factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.TIMESTAMP, 3), + isNullable); + case SECOND: + case MINUTE: + case HOUR: + sqlTypeName = SqlTypeName.TIMESTAMP; + break; + default: + sqlTypeName = SqlTypeName.ANY; + } + + // preserves precision of input type if it was specified + if (inputType.getSqlTypeName().allowsPrecNoScale()) { + RelDataType type = factory.createSqlType(sqlTypeName, inputType.getPrecision()); + return factory.createTypeWithNullability(type, isNullable); + } + return createCalciteTypeWithNullability( + opBinding.getTypeFactory(), + sqlTypeName, + isNullable); + } + } + private static class DrillSubstringSqlReturnTypeInference implements SqlReturnTypeInference { private static final DrillSubstringSqlReturnTypeInference INSTANCE = new DrillSubstringSqlReturnTypeInference(); @@ -823,15 +879,16 @@ private static DrillFuncHolder resolveDrillFuncHolder(final SqlOperatorBinding o /** * For Extract and date_part functions, infer the return types based on timeUnit */ - public static SqlTypeName getSqlTypeNameForTimeUnit(String timeUnit) { - switch (timeUnit.toUpperCase()){ - case "YEAR": - case "MONTH": - case "DAY": - case "HOUR": - case "MINUTE": + public static SqlTypeName getSqlTypeNameForTimeUnit(String timeUnitStr) { + TimeUnit timeUnit = TimeUnit.valueOf(timeUnitStr); + switch (timeUnit) { + case YEAR: + case MONTH: + case DAY: + case HOUR: + case MINUTE: return SqlTypeName.BIGINT; - case "SECOND": + case SECOND: return SqlTypeName.DOUBLE; default: throw UserException diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestDateAddFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestDateAddFunctions.java index 20f19954c34..eb7b1ed98b9 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestDateAddFunctions.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestDateAddFunctions.java @@ -19,6 +19,11 @@ import java.time.LocalDate; import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.apache.drill.categories.SqlFunctionTest; import org.apache.drill.categories.UnlikelyTest; @@ -30,6 +35,29 @@ @Category({UnlikelyTest.class, SqlFunctionTest.class}) public class TestDateAddFunctions extends BaseTestQuery { + private final List QUALIFIERS = Arrays.asList( + "FRAC_SECOND", + "MICROSECOND", + "NANOSECOND", + "SQL_TSI_FRAC_SECOND", + "SQL_TSI_MICROSECOND", + "SECOND", + "SQL_TSI_SECOND", + "MINUTE", + "SQL_TSI_MINUTE", + "HOUR", + "SQL_TSI_HOUR", + "DAY", + "SQL_TSI_DAY", + "WEEK", + "SQL_TSI_WEEK", + "MONTH", + "SQL_TSI_MONTH", + "QUARTER", + "SQL_TSI_QUARTER", + "YEAR", + "SQL_TSI_YEAR"); + @Test public void testDateAddIntervalDay() throws Exception { String query = "select date_add(timestamp '2015-01-24 07:27:05.0', interval '3' day) as col1,\n" + @@ -84,4 +112,135 @@ public void testDateAddIntegerAsDay() throws Exception { LocalDate.parse("2015-01-29")) .go(); } + + @Test // DRILL-3610 + public void testTimestampAddDiffLiteralTypeInference() throws Exception { + Map dateTypes = new HashMap<>(); + dateTypes.put("DATE", "2013-03-31"); + dateTypes.put("TIME", "00:02:03.123"); + dateTypes.put("TIMESTAMP", "2013-03-31 00:02:03"); + + for (String qualifier : QUALIFIERS) { + for (Map.Entry typeResultPair : dateTypes.entrySet()) { + String dateTimeLiteral = typeResultPair.getValue(); + String type = typeResultPair.getKey(); + + test("SELECT TIMESTAMPADD(%s, 0, CAST('%s' AS %s)) col1", + qualifier, dateTimeLiteral, type); + + // TIMESTAMPDIFF with args of different types + for (Map.Entry secondArg : dateTypes.entrySet()) { + test("SELECT TIMESTAMPDIFF(%s, CAST('%s' AS %s), CAST('%s' AS %s)) col1", + qualifier, dateTimeLiteral, type, secondArg.getValue(), secondArg.getKey()); + } + } + } + } + + @Test // DRILL-3610 + public void testTimestampAddDiffTypeInference() throws Exception { + for (String qualifier : QUALIFIERS) { + test( + "SELECT TIMESTAMPADD(%1$s, 0, `date`) col1," + + "TIMESTAMPADD(%1$s, 0, `time`) timeReq," + + "TIMESTAMPADD(%1$s, 0, `timestamp`) timestampReq," + + "TIMESTAMPADD(%1$s, 0, t.time_map.`date`) dateOpt," + + "TIMESTAMPADD(%1$s, 0, t.time_map.`time`) timeOpt," + + "TIMESTAMPADD(%1$s, 0, t.time_map.`timestamp`) timestampOpt\n" + + "FROM cp.`datetime.parquet` t", qualifier); + + test( + "SELECT TIMESTAMPDIFF(%1$s, `date`, `date`) col1," + + "TIMESTAMPDIFF(%1$s, `time`, `time`) timeReq," + + "TIMESTAMPDIFF(%1$s, `timestamp`, `timestamp`) timestampReq," + + "TIMESTAMPDIFF(%1$s, `timestamp`, t.time_map.`date`) timestampReqTimestampOpt," + + "TIMESTAMPDIFF(%1$s, `timestamp`, t.time_map.`timestamp`) timestampReqTimestampOpt," + + "TIMESTAMPDIFF(%1$s, `date`, `time`) timeDate," + + "TIMESTAMPDIFF(%1$s, `time`, `date`) Datetime," + + "TIMESTAMPDIFF(%1$s, t.time_map.`date`, t.time_map.`date`) dateOpt," + + "TIMESTAMPDIFF(%1$s, t.time_map.`time`, t.time_map.`time`) timeOpt," + + "TIMESTAMPDIFF(%1$s, t.time_map.`timestamp`, t.time_map.`timestamp`) timestampOpt\n" + + "FROM cp.`datetime.parquet` t", qualifier); + } + } + + @Test // DRILL-3610 + public void testTimestampAddParquet() throws Exception { + String query = + "SELECT TIMESTAMPADD(SECOND, 1, `date`) dateReq," + + "TIMESTAMPADD(QUARTER, 1, `time`) timeReq," + + "TIMESTAMPADD(DAY, 1, `timestamp`) timestampReq," + + "TIMESTAMPADD(MONTH, 1, t.time_map.`date`) dateOpt," + + "TIMESTAMPADD(HOUR, 1, t.time_map.`time`) timeOpt," + + "TIMESTAMPADD(YEAR, 1, t.time_map.`timestamp`) timestampOpt\n" + + "FROM cp.`datetime.parquet` t"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("dateReq", "timeReq", "timestampReq", "dateOpt", "timeOpt", "timestampOpt") + .baselineValues( + LocalDateTime.parse("1970-01-11T00:00:01"), LocalTime.parse("00:00:03.600"), LocalDateTime.parse("2018-03-24T17:40:52.123"), + LocalDateTime.parse("1970-02-11T00:00"), LocalTime.parse("01:00:03.600"), LocalDateTime.parse("2019-03-23T17:40:52.123")) + .go(); + } + + @Test // DRILL-3610 + public void testTimestampDiffParquet() throws Exception { + String query = + "SELECT TIMESTAMPDIFF(SECOND, DATE '1970-01-15', `date`) dateReq," + + "TIMESTAMPDIFF(QUARTER, TIME '12:00:03.600', `time`) timeReq," + + "TIMESTAMPDIFF(DAY, TIMESTAMP '2018-03-24 17:40:52.123', `timestamp`) timestampReq," + + "TIMESTAMPDIFF(MONTH, DATE '1971-10-30', t.time_map.`date`) dateOpt," + + "TIMESTAMPDIFF(HOUR, TIME '18:00:03.600', t.time_map.`time`) timeOpt," + + "TIMESTAMPDIFF(YEAR, TIMESTAMP '2020-03-24 17:40:52.123', t.time_map.`timestamp`) timestampOpt\n" + + "FROM cp.`datetime.parquet` t"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("dateReq", "timeReq", "timestampReq", "dateOpt", "timeOpt", "timestampOpt") + .baselineValues(-345600L, 0L, -1L, -21L, -18L, -2L) + .go(); + } + + @Test // DRILL-3610 + public void testTimestampAddDiffNull() throws Exception { + String query = + "SELECT TIMESTAMPDIFF(SECOND, DATE '1970-01-15', a) col1," + + "TIMESTAMPDIFF(QUARTER, a, DATE '1970-01-15') col2," + + "TIMESTAMPDIFF(DAY, a, a) col3," + + "TIMESTAMPADD(MONTH, 1, a) col4," + + "TIMESTAMPADD(MONTH, b, DATE '1970-01-15') col5," + + "TIMESTAMPADD(MONTH, b, a) col6\n" + + "FROM" + + "(SELECT CASE WHEN FALSE THEN TIME '12:00:03.600' ELSE null END AS a," + + "CASE WHEN FALSE THEN 2 ELSE null END AS b)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("col1", "col2", "col3", "col4", "col5", "col6") + .baselineValues(null, null, null, null, null, null) + .go(); + } + + @Test // DRILL-3610 + public void testTimestampDiffTimeDateTransition() throws Exception { + String query = + "SELECT TIMESTAMPDIFF(SECOND, time '12:30:00.123', time '12:30:00') col1," + + "TIMESTAMPDIFF(DAY, TIMESTAMP '1970-01-15 15:30:00', TIMESTAMP '1970-01-16 12:30:00') col2," + + "TIMESTAMPDIFF(DAY, TIMESTAMP '1970-01-16 12:30:00', TIMESTAMP '1970-01-15 15:30:00') col3," + + "TIMESTAMPDIFF(MONTH, TIMESTAMP '1970-01-16 12:30:00', TIMESTAMP '1970-03-15 15:30:00') col4," + + "TIMESTAMPDIFF(MONTH, TIMESTAMP '1970-03-15 15:30:00', TIMESTAMP '1970-01-16 12:30:00') col5," + + "TIMESTAMPDIFF(DAY, DATE '2012-01-01', DATE '2013-01-01') col6," + + "TIMESTAMPDIFF(DAY, DATE '2013-01-01', DATE '2014-01-01') col7"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("col1", "col2", "col3", "col4", "col5", "col6", "col7") + .baselineValues(0L, 0L, 0L, 1L, -1L, 366L, 365L) + .go(); + } }