Skip to content

Commit

Permalink
[SPARK-8279][SQL]Add math function round
Browse files Browse the repository at this point in the history
JIRA: https://issues.apache.org/jira/browse/SPARK-8279

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #6938 from yijieshen/udf_round_3 and squashes the following commits:

07a124c [Yijie Shen] remove useless def children
392b65b [Yijie Shen] add negative scale test in DecimalSuite
61760ee [Yijie Shen] address reviews
302a78a [Yijie Shen] Add dataframe function test
31dfe7c [Yijie Shen] refactor round to make it readable
8c7a949 [Yijie Shen] rebase & inputTypes update
9555e35 [Yijie Shen] tiny style fix
d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit cast
c3b9839 [Yijie Shen] rely on implicit cast to handle string input
b0bff79 [Yijie Shen] make round's inner method's name more meaningful
9bd6930 [Yijie Shen] revert accidental change
e6f44c4 [Yijie Shen] refactor eval and genCode
1b87540 [Yijie Shen] modify checkInputDataTypes using foldable
5486b2d [Yijie Shen] DataFrame API modification
2077888 [Yijie Shen] codegen versioned eval
6cd9a64 [Yijie Shen] refactor Round's constructor
9be894e [Yijie Shen] add round functions in o.a.s.sql.functions
7c83e13 [Yijie Shen] more tests on round
56db4bb [Yijie Shen] Add decimal support to Round
7e163ae [Yijie Shen] style fix
653d047 [Yijie Shen] Add math function round
  • Loading branch information
yjshen authored and rxin committed Jul 15, 2015
1 parent 3f6296f commit f0e1297
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 13 deletions.
Expand Up @@ -117,6 +117,7 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
Expand Down
Expand Up @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions

import java.{lang => jl}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression)
"""
}
}

/**
 * Round the `child`'s result to `scale` decimal places when `scale` >= 0,
 * or round at the integral part when `scale` < 0, always using HALF_UP rounding.
 * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
 *
 * Child of IntegralType would eval to itself when `scale` >= 0.
 * Child of FractionalType whose value is NaN or Infinite would always eval to itself.
 * The result is null when `scale` or `child` evaluates to null, or when a Decimal value
 * cannot be changed to the requested scale.
 *
 * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
 * which leads to scale update in DecimalType's [[PrecisionInfo]]
 *
 * @param child expr to be rounded, all [[NumericType]] is allowed as input
 * @param scale new scale to be rounded to, this should be a constant int at runtime
 */
case class Round(child: Expression, scale: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  import BigDecimal.RoundingMode.HALF_UP

  def this(child: Expression) = this(child, Literal(0))

  override def left: Expression = child
  override def right: Expression = scale

  // round of Decimal would eval to null if it fails to `changePrecision`
  override def nullable: Boolean = true

  // `scale` must be foldable to pass checkInputDataTypes, so only `child` decides here.
  override def foldable: Boolean = child.foldable

  override lazy val dataType: DataType = child.dataType match {
    // if the new scale is bigger which means we are scaling up,
    // keep the original scale as `Decimal` does
    case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
    case t => t
  }

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

  override def checkInputDataTypes(): TypeCheckResult = {
    super.checkInputDataTypes() match {
      case TypeCheckSuccess =>
        if (scale.foldable) {
          TypeCheckSuccess
        } else {
          TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
        }
      case f => f
    }
  }

  // Avoid repeated evaluation since `scale` is a constant int,
  // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
  // by checking if scaleV == null as well.
  private lazy val scaleV: Any = scale.eval(EmptyRow)
  // When `scaleV` is null the cast yields 0; callers check `scaleV == null` before rounding.
  private lazy val _scale: Int = scaleV.asInstanceOf[Int]

  override def eval(input: InternalRow): Any = {
    if (scaleV == null) { // if scale is null, no need to eval its child at all
      null
    } else {
      val evalE = child.eval(input)
      if (evalE == null) {
        null
      } else {
        nullSafeEval(evalE)
      }
    }
  }

  /**
   * Rounds a non-null value of `child` to the constant `_scale`.
   *
   * Not an override, since `_scale` is a constant int at runtime and is not passed in.
   *
   * @param input1 the non-null result of evaluating `child`
   * @return the rounded value, or null when a Decimal cannot be changed to `_scale`
   */
  def nullSafeEval(input1: Any): Any = {
    child.dataType match {
      case _: DecimalType =>
        // `changePrecision` updates its receiver in place, so round a copy and leave the
        // child's value (which other expressions may also read) untouched.
        val rounded = input1.asInstanceOf[Decimal].clone()
        if (rounded.changePrecision(rounded.precision, _scale)) rounded else null
      case ByteType =>
        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
      case ShortType =>
        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
      case IntegerType =>
        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
      case LongType =>
        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
      case FloatType =>
        val f = input1.asInstanceOf[Float]
        if (f.isNaN || f.isInfinite) {
          f
        } else {
          BigDecimal(f).setScale(_scale, HALF_UP).toFloat
        }
      case DoubleType =>
        val d = input1.asInstanceOf[Double]
        if (d.isNaN || d.isInfinite) {
          d
        } else {
          BigDecimal(d).setScale(_scale, HALF_UP).toDouble
        }
    }
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val ce = child.gen(ctx)

    // Integral children are returned unchanged for non-negative scales; a negative scale
    // rounds at the integral part through java.math.BigDecimal.
    def integralCode(valueMethod: String): String = {
      if (_scale < 0) {
        s"""
          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).$valueMethod();"""
      } else {
        s"${ev.primitive} = ${ce.primitive};"
      }
    }

    // NaN and Infinity are returned as is. Every other value, including the scale == 0 case,
    // goes through BigDecimal HALF_UP so generated code agrees with `nullSafeEval`:
    // `Math.round` rounds negative ties toward positive infinity and saturates at the
    // int/long range, which would give different results from the interpreted path.
    def fractionalCode(boxedType: String, valueMethod: String): String = {
      s"""
        if ($boxedType.isNaN(${ce.primitive}) || $boxedType.isInfinite(${ce.primitive})) {
          ${ev.primitive} = ${ce.primitive};
        } else {
          ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
            setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).$valueMethod();
        }"""
    }

    val evaluationCode = child.dataType match {
      case _: DecimalType =>
        // Round a copy: `changePrecision` mutates its receiver in place.
        s"""
          ${ev.primitive} = ${ce.primitive}.clone();
          if (!${ev.primitive}.changePrecision(${ev.primitive}.precision(), ${_scale})) {
            ${ev.primitive} = null;
            ${ev.isNull} = true;
          }"""
      case ByteType => integralCode("byteValue")
      case ShortType => integralCode("shortValue")
      case IntegerType => integralCode("intValue")
      case LongType => integralCode("longValue")
      case FloatType => fractionalCode("Float", "floatValue")
      case DoubleType => fractionalCode("Double", "doubleValue")
    }

    if (scaleV == null) { // if scale is null, no need to eval its child at all
      s"""
        boolean ${ev.isNull} = true;
        ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      """
    } else {
      s"""
        ${ce.code}
        boolean ${ev.isNull} = ${ce.isNull};
        ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
        if (!${ev.isNull}) {
          $evaluationCode
        }
      """
    }
  }

  override def prettyName: String = "round"
}
Expand Up @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
s"differing types in '${expr.prettyString}' (int and boolean)")
}

/**
 * Expects `assertSuccess(expr)` to throw an [[AnalysisException]] whose message
 * contains `errorMessage`.
 */
def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
  val thrown = intercept[AnalysisException](assertSuccess(expr))
  val message = thrown.getMessage
  assert(message.contains(errorMessage))
}

test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
assertError(Abs('stringField), "function abs accepts numeric type")
Expand Down Expand Up @@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
"Odd position only allow foldable and not-null StringType expressions")
}

test("check types for ROUND") {
  // A boolean or complex scale argument must be reported as a type mismatch.
  val scaleMismatch = "data type mismatch: argument 2 is expected to be of type int"
  Seq[Expression]('booleanField, 'complexField).foreach { badScale =>
    assertErrorWithImplicitCast(Round(Literal(null), badScale), scaleMismatch)
  }

  // Null literals are accepted for both arguments.
  assertSuccess(Round(Literal(null), Literal(null)))

  // A boolean value to round is rejected.
  val childMismatch = "data type mismatch: argument 1 is expected to be of type numeric"
  assertError(Round('booleanField, 'intField), childMismatch)
}
}
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import scala.math.BigDecimal.RoundingMode

import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
Expand Down Expand Up @@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))
}

test("round") {
  // Inputs: the digits of pi in several numeric representations.
  val piDouble: Double = math.Pi
  val piShort: Short = 31415
  val piInt: Int = 314159265
  val piLong: Long = 31415926535897932L
  val piDecimal: BigDecimal = BigDecimal(31415927L, 7)

  // Expected values, one entry per scale in `-6 to 6`, in that order.
  val expectedDoubles: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
    3.1416, 3.14159, 3.141593)

  val expectedShorts: Seq[Short] =
    Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ Seq.fill[Short](7)(31415)

  val expectedInts: Seq[Int] =
    Seq(314000000, 314200000, 314160000, 314159000, 314159300, 314159270) ++
      Seq.fill(7)(314159265)

  val expectedLongs: Seq[Long] =
    Seq(31415926536000000L, 31415926535900000L, 31415926535900000L, 31415926535898000L,
      31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L)

  // Expected decimal values, indexed directly by scale (0 to 7).
  val expectedDecimals: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1),
    BigDecimal(3.14), BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
    BigDecimal(3.141593), BigDecimal(3.1415927))

  for ((scale, idx) <- (-6 to 6).zipWithIndex) {
    checkEvaluation(Round(piDouble, scale), expectedDoubles(idx), EmptyRow)
    checkEvaluation(Round(piShort, scale), expectedShorts(idx), EmptyRow)
    checkEvaluation(Round(piInt, scale), expectedInts(idx), EmptyRow)
    checkEvaluation(Round(piLong, scale), expectedLongs(idx), EmptyRow)
  }

  for (scale <- 0 to 7) {
    checkEvaluation(Round(piDecimal, scale), expectedDecimals(scale), EmptyRow)
  }

  // round_scale > current_scale would result in precision increase
  // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
  for (scale <- 8 to 10) {
    checkEvaluation(Round(piDecimal, scale), null, EmptyRow)
  }
}
}
Expand Up @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester
import scala.language.postfixOps

class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
test("creating decimals") {
/** Check that a Decimal has the given string representation, precision and scale */
def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
assert(d.toString === string)
assert(d.precision === precision)
assert(d.scale === scale)
}
/**
 * Asserts that `d` has the expected string representation, precision and scale,
 * checked in that order.
 */
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
  val rendered = d.toString
  assert(rendered === string)
  val actualPrecision = d.precision
  assert(actualPrecision === precision)
  val actualScale = d.scale
  assert(actualScale === scale)
}

test("creating decimals") {
checkDecimal(new Decimal(), "0", 1, 0)
checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
Expand All @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
}

test("creating decimals with negative scale") {
  // Each entry: (decimal under test, expected string, expected precision, expected scale).
  val cases = Seq(
    (Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3),
    (Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2),
    (Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9),
    (Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10),
    (Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10),
    (Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10))

  cases.foreach { case (decimal, expectedString, expectedPrecision, expectedScale) =>
    checkDecimal(decimal, expectedString, expectedPrecision, expectedScale)
  }
}

test("double and long values") {
/** Check that a Decimal converts to the given double and long values */
def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
Expand Down
32 changes: 32 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Expand Up @@ -1389,6 +1389,38 @@ object functions {
*/
def rint(columnName: String): Column = rint(Column(columnName))

/**
 * Returns the value of the column `e` rounded to 0 decimal places.
 *
 * Delegates to `round(e, 0)`.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(e: Column): Column = round(e, 0)

/**
 * Returns the value of the column named `columnName` rounded to 0 decimal places.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(columnName: String): Column = round(columnName, 0)

/**
 * Returns the value of the column `e` rounded to `scale` decimal places.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(e: Column, scale: Int): Column = {
  val scaleLiteral = Literal(scale)
  Round(e.expr, scaleLiteral)
}

/**
 * Returns the value of the column named `columnName` rounded to `scale` decimal places.
 *
 * @group math_funcs
 * @since 1.5.0
 */
def round(columnName: String, scale: Int): Column = {
  val column = Column(columnName)
  round(column, scale)
}

/**
* Shift the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
Expand Down

0 comments on commit f0e1297

Please sign in to comment.