diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ed69c42dcb825..471e8bd68b9cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -127,6 +127,7 @@ object FunctionRegistry {
     expression[Tanh]("tanh"),
     expression[ToDegrees]("degrees"),
     expression[ToRadians]("radians"),
+    expression[Round]("round"),
 
     // misc functions
     expression[Md5]("md5"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index c31890e27fb54..ca31009c99419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.{lang => jl}
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.BigDecimalConverter
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -520,3 +523,95 @@ case class Logarithm(left: Expression, right: Expression)
     """
   }
 }
+
+case class Round(children: Seq[Expression]) extends Expression {
+
+  def nullable: Boolean = true
+
+  def dataType: DataType = {
+    children(0).dataType match {
+      case StringType | BinaryType => DoubleType
+      case t => t
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size < 1 || children.size > 2) {
+      return TypeCheckFailure(s"ROUND requires one or two arguments, got ${children.size}")
+    }
+    children(0).dataType match {
+      case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement
+      case dt =>
+        return TypeCheckFailure("Only numeric, string or binary data types" +
+          s" are allowed for ROUND function, got $dt")
+    }
+    if (children.size == 2) {
+      children(1) match {
+        case Literal(value, LongType) =>
+          if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
+            return TypeCheckFailure("ROUND scale argument out of allowed range")
+          }
+        case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement
+        case child =>
+          if (child.find { case _: AttributeReference => true; case _ => false }.isDefined) {
+            return TypeCheckFailure("Only Integral Literal or Null Literal " +
+              s"are allowed for ROUND scale argument, got ${child.dataType}")
+          }
+      }
+    }
+    TypeCheckSuccess
+  }
+
+  def eval(input: InternalRow): Any = {
+    val evalE1 = children(0).eval(input)
+    if (evalE1 == null) {
+      return null
+    }
+
+    var _scale: Int = 0
+    if (children.size == 2) {
+      val evalE2 = children(1).eval(input)
+      if (evalE2 == null) {
+        return null
+      } else {
+        _scale = evalE2.asInstanceOf[Int]
+      }
+    }
+
+    children(0).dataType match {
+      case _: DecimalType =>
+        null // TODO: Support Decimal Round
+      case ByteType =>
+        round(evalE1.asInstanceOf[Byte], _scale)
+      case ShortType =>
+        round(evalE1.asInstanceOf[Short], _scale)
+      case IntegerType =>
+        round(evalE1.asInstanceOf[Int], _scale)
+      case LongType =>
+        round(evalE1.asInstanceOf[Long], _scale)
+      case FloatType =>
+        round(evalE1.asInstanceOf[Float], _scale)
+      case DoubleType =>
+        round(evalE1.asInstanceOf[Double], _scale)
+      case StringType =>
+        round(evalE1.asInstanceOf[UTF8String].toString, _scale)
+      case BinaryType =>
+        round(UTF8String.fromBytes(evalE1.asInstanceOf[Array[Byte]]).toString, _scale)
+    }
+  }
+
+  private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
+    input match {
+      case f: Float if (f.isNaN || f.isInfinite) => return input
+      case d: Double if (d.isNaN || d.isInfinite) => return input
+      case _ =>
+    }
+    bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
+  }
+
+  private def round(input: String, scale: Int): Any = {
+    try round(input.toDouble, scale) catch {
+      case _: NumberFormatException => null
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala
new file mode 100644
index 0000000000000..1320680925c80
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+trait BigDecimalConverter[T] {
+  def toBigDecimal(in: T): BigDecimal
+  def fromBigDecimal(bd: BigDecimal): T
+}
+
+/**
+ * Helper type converters for working with BigDecimal,
+ * from http://stackoverflow.com/a/30979266/1115193
+ */
+object BigDecimalConverter {
+
+  implicit object ByteConverter extends BigDecimalConverter[Byte] {
+    def toBigDecimal(in: Byte): BigDecimal = BigDecimal(in)
+    def fromBigDecimal(bd: BigDecimal): Byte = bd.toByte
+  }
+
+  implicit object ShortConverter extends BigDecimalConverter[Short] {
+    def toBigDecimal(in: Short): BigDecimal = BigDecimal(in)
+    def fromBigDecimal(bd: BigDecimal): Short = bd.toShort
+  }
+
+  implicit object IntConverter extends BigDecimalConverter[Int] {
+    def toBigDecimal(in: Int): BigDecimal = BigDecimal(in)
+    def fromBigDecimal(bd: BigDecimal): Int = bd.toInt
+  }
+
+  implicit object LongConverter extends BigDecimalConverter[Long] {
+    def toBigDecimal(in: Long): BigDecimal = BigDecimal(in)
+    def fromBigDecimal(bd: BigDecimal): Long = bd.toLong
+  }
+
+  implicit object FloatConverter extends BigDecimalConverter[Float] {
+    def toBigDecimal(in: Float): BigDecimal = BigDecimal(in)
+    def fromBigDecimal(bd: BigDecimal): Float = bd.toFloat
+  }
+
+  implicit object DoubleConverter extends BigDecimalConverter[Double] {
+    def toBigDecimal(in: Double): BigDecimal = BigDecimal(in)
+    def fromBigDecimal(bd: BigDecimal): Double = bd.toDouble
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8e0551b23eea6..fcefa8f891265 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -171,4 +171,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
       "Odd position only allow foldable and not-null StringType expressions")
   }
+
+  test("check types for ROUND") {
+    assertError(Round(Seq()), "ROUND requires one or two arguments")
+    assertError(Round(Seq(Literal(null), 'booleanField)),
+      "Only Integral Literal or Null Literal are allowed for ROUND scale argument")
+    assertError(Round(Seq(Literal(null), 'complexField)),
+      "Only Integral Literal or Null Literal are allowed for ROUND scale argument")
+    assertSuccess(Round(Seq(Literal(null), Literal(null))))
+    assertError(Round(Seq('booleanField, 'intField)),
+      "Only numeric, string or binary data types are allowed for ROUND function")
+    assertError(Round(Seq(Literal(null), Literal(1L + Int.MaxValue))),
+      "ROUND scale argument out of allowed range")
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 7ca9e30b2bcd5..c79b9ca0a3340 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -336,4 +336,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       null,
       create_row(null))
   }
+
+  test("round test") {
+    val piRounds = Seq(
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0,
+      3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593, 3.1415927, 3.14159265, 3.141592654,
+      3.1415926536, 3.14159265359, 3.14159265359, 3.1415926535898, 3.14159265358979,
+      3.141592653589793, 3.141592653589793)
+    (-16 to 16).zipWithIndex.foreach {
+      case (scale, i) =>
+        checkEvaluation(Round(Seq(3.141592653589793, scale)), piRounds(i), EmptyRow)
+    }
+  }
 }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index c884c399281a8..65a6a5023ea62 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -919,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_repeat",
     "udf_rlike",
     "udf_round",
-    // "udf_round_3", TODO: FIX THIS failed due to cast exception
+    "udf_round_3",
     "udf_rpad",
     "udf_rtrim",
     "udf_second",
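For reviewers, a minimal standalone sketch of the rounding mechanism this patch introduces: an implicit BigDecimalConverter[T] type class lets a single generic helper lift any supported numeric type to BigDecimal, apply setScale with HALF_UP semantics, and convert back. RoundSketch and its main are hypothetical names added here for illustration only; the trait, converters, and round helper follow the patch above, trimmed to two converters.

object RoundSketch {

  // Type class mirroring the patch's BigDecimalConverter.
  trait BigDecimalConverter[T] {
    def toBigDecimal(in: T): BigDecimal
    def fromBigDecimal(bd: BigDecimal): T
  }

  implicit object IntConverter extends BigDecimalConverter[Int] {
    def toBigDecimal(in: Int): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Int = bd.toInt
  }

  implicit object DoubleConverter extends BigDecimalConverter[Double] {
    def toBigDecimal(in: Double): BigDecimal = BigDecimal(in)
    def fromBigDecimal(bd: BigDecimal): Double = bd.toDouble
  }

  // Generic HALF_UP rounding, as in Round's private round[T] helper.
  def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T =
    bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))

  def main(args: Array[String]): Unit = {
    println(round(3.141592653589793, 4)) // 3.1416
    println(round(3.141592653589793, 0)) // 3.0 (result keeps the input type)
    println(round(12345, -2))            // 12300 (negative scale rounds to hundreds)
  }
}

The negative-scale case is what the new MathFunctionsSuite test exercises with scales -16 to -1, all of which round pi down to 0.0.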