From f0e129740dc2442a21dfa7fbd97360df87291095 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 14 Jul 2015 23:30:41 -0700 Subject: [PATCH] [SPARK-8279][SQL]Add math function round JIRA: https://issues.apache.org/jira/browse/SPARK-8279 Author: Yijie Shen Closes #6938 from yijieshen/udf_round_3 and squashes the following commits: 07a124c [Yijie Shen] remove useless def children 392b65b [Yijie Shen] add negative scale test in DecimalSuite 61760ee [Yijie Shen] address reviews 302a78a [Yijie Shen] Add dataframe function test 31dfe7c [Yijie Shen] refactor round to make it readable 8c7a949 [Yijie Shen] rebase & inputTypes update 9555e35 [Yijie Shen] tiny style fix d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast c3b9839 [Yijie Shen] rely on implict cast to handle string input b0bff79 [Yijie Shen] make round's inner method's name more meaningful 9bd6930 [Yijie Shen] revert accidental change e6f44c4 [Yijie Shen] refactor eval and genCode 1b87540 [Yijie Shen] modify checkInputDataTypes using foldable 5486b2d [Yijie Shen] DataFrame API modification 2077888 [Yijie Shen] codegen versioned eval 6cd9a64 [Yijie Shen] refactor Round's constructor 9be894e [Yijie Shen] add round functions in o.a.s.sql.functions 7c83e13 [Yijie Shen] more tests on round 56db4bb [Yijie Shen] Add decimal support to Round 7e163ae [Yijie Shen] style fix 653d047 [Yijie Shen] Add math function round --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 203 +++++++++++++++++- .../ExpressionTypeCheckingSuite.scala | 17 ++ .../expressions/MathFunctionsSuite.scala | 44 ++++ .../sql/types/decimal/DecimalSuite.scala | 23 +- .../org/apache/spark/sql/functions.scala | 32 +++ .../spark/sql/MathExpressionsSuite.scala | 15 ++ .../execution/HiveCompatibilitySuite.scala | 7 +- 8 files changed, 329 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6b1a94e4b2ad4..ec75f51d5e4ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -117,6 +117,7 @@ object FunctionRegistry { expression[Pow]("power"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[Round]("round"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 4b7fe05dd4980..a7ad452ef4943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression) """ } } + +/** + * Round the `child`'s result to `scale` decimal place when `scale` >= 0 + * or round at integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * + * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to scale update in DecimalType's [[PrecisionInfo]] + * + * @param child expr to be round, all [[NumericType]] is allowed as Input + * @param scale new scale to be round to, this should be a constant int at runtime + */ +case class Round(child: Expression, scale: Expression) + extends BinaryExpression with ExpectsInputTypes { + + import BigDecimal.RoundingMode.HALF_UP + + def this(child: Expression) = this(child, Literal(0)) + + override def left: Expression = child + override def right: Expression = scale + + // round of Decimal would eval to null if it fails to `changePrecision` + override def nullable: Boolean = true + + override def foldable: Boolean = child.foldable + + override lazy val dataType: DataType = child.dataType match { + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) + case t => t + } + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f + } + } + + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] + + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) + } + } + } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { + child.dataType match { + case _: DecimalType => + val decimal = input1.asInstanceOf[Decimal] + if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + case ByteType => + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + case ShortType => + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + case IntegerType => + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + case LongType => + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + case FloatType => + val f = input1.asInstanceOf[Float] + if (f.isNaN || f.isInfinite) { + f + } else { + BigDecimal(f).setScale(_scale, HALF_UP).toFloat + } + case DoubleType => + val d = input1.asInstanceOf[Double] + if (d.isNaN || d.isInfinite) { + d + } else { + BigDecimal(d).setScale(_scale, HALF_UP).toDouble + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val ce = child.gen(ctx) + + val evaluationCode = child.dataType match { + case _: DecimalType => + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + }""" + case ByteType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case ShortType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case IntegerType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case LongType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } + } + + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } + """ + } + } + + override def prettyName: String = "round" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 5958acbe009ca..e885a18254ea0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { s"differing types in '${expr.prettyString}' (int and boolean)") } + def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains(errorMessage)) + } + test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") @@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Odd position only allow foldable and not-null StringType expressions") } + + test("check types for ROUND") { + assertErrorWithImplicitCast(Round(Literal(null), 'booleanField), + "data type mismatch: argument 2 is expected to be of type int") + assertErrorWithImplicitCast(Round(Literal(null), 'complexField), + "data type mismatch: argument 2 is expected to be of type int") + assertSuccess(Round(Literal(null), Literal(null))) + assertError(Round('booleanField, 'intField), + "data type mismatch: argument 1 is expected to be of type numeric") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7ca9e30b2bcd5..52a874a9d89ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math.BigDecimal.RoundingMode + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite @@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) } + + test("round") { + val domain = -6 to 6 + val doublePi: Double = math.Pi + val shortPi: Short = 31415 + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + } + + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + } + (8 to 10).foreach { scale => + checkEvaluation(Round(bdPi, scale), null, EmptyRow) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 030bb6d21b18b..f0c849d1a1564 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester import scala.language.postfixOps class DecimalSuite extends SparkFunSuite with PrivateMethodTester { - test("creating decimals") { - /** Check that a Decimal has the given string representation, precision and scale */ - def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { - assert(d.toString === string) - assert(d.precision === precision) - assert(d.scale === scale) - } + /** Check that a Decimal has the given string representation, precision and scale */ + private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } + test("creating decimals with negative scale") { + checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3) + checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10) + checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10) + checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10) + } + test("double and long values") { /** Check that a Decimal converts to the given double and long values */ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d4e160ed8057..5119ee31d852d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1389,6 +1389,38 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Returns the value of the given column rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String): Column = round(Column(columnName), 0) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Returns the value of the given column rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index b30b9f12258b9..087126bb2e513 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c884c399281a8..4ada64bc21966 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", - // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive // is src(key STRING, value STRING), and in the reflect.q, it failed in // Integer.valueOf, which expect the first argument passed as STRING type not INT. @@ -918,8 +915,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", - // "udf_round_3", TODO: FIX THIS failed due to cast exception + // "udf_round", turn this on after we figure out null vs nan vs infinity + "udf_round_3", "udf_rpad", "udf_rtrim", "udf_second",