Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8279][SQL]Add math function round #6938

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -116,6 +116,7 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
Expand Down
Expand Up @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions

import java.{lang => jl}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression)
"""
}
}

/**
* Round the `child`'s result to `scale` decimal place when `scale` >= 0
* or round at integral part when `scale` < 0.
* For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
*
* Child of IntegralType would eval to itself when `scale` >= 0.
* Child of FractionalType whose value is NaN or Infinite would always eval to itself.
*
* Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
* which leads to scale update in DecimalType's [[PrecisionInfo]]
*
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
*/
case class Round(child: Expression, scale: Expression)
extends BinaryExpression with ExpectsInputTypes {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExpectsInputTypes -> ImplicitCastInputTypes


import BigDecimal.RoundingMode.HALF_UP

def this(child: Expression) = this(child, Literal(0))

override def left: Expression = child
override def right: Expression = scale

// round of Decimal would eval to null if it fails to `changePrecision`
override def nullable: Boolean = true

override def foldable: Boolean = child.foldable

override lazy val dataType: DataType = child.dataType match {
// if the new scale is bigger which means we are scaling up,
// keep the original scale as `Decimal` does
case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
case t => t
}

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckSuccess =>
if (scale.foldable) {
TypeCheckSuccess
} else {
TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
}
case f => f
}
}

// Avoid repeated evaluation since `scale` is a constant int,
// avoid unnecessary `child` evaluation in both codegen and non-codegen eval
// by checking if scaleV == null as well.
private lazy val scaleV: Any = scale.eval(EmptyRow)
private lazy val _scale: Int = scaleV.asInstanceOf[Int]

override def eval(input: InternalRow): Any = {
if (scaleV == null) { // if scale is null, no need to eval its child at all
null
} else {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
nullSafeEval(evalE)
}
}
}

// not overriding since _scale is a constant int at runtime
def nullSafeEval(input1: Any): Any = {
child.dataType match {
case _: DecimalType =>
val decimal = input1.asInstanceOf[Decimal]
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
case ShortType =>
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
case IntegerType =>
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
case LongType =>
BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
case FloatType =>
val f = input1.asInstanceOf[Float]
if (f.isNaN || f.isInfinite) {
f
} else {
BigDecimal(f).setScale(_scale, HALF_UP).toFloat
}
case DoubleType =>
val d = input1.asInstanceOf[Double]
if (d.isNaN || d.isInfinite) {
d
} else {
BigDecimal(d).setScale(_scale, HALF_UP).toDouble
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val ce = child.gen(ctx)

val evaluationCode = child.dataType match {
case _: DecimalType =>
s"""
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
${ev.primitive} = ${ce.primitive};
} else {
${ev.isNull} = true;
}"""
case ByteType =>
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case ShortType =>
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case IntegerType =>
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case LongType =>
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case FloatType => // if child eval to NaN or Infinity, just return it.
if (_scale == 0) {
s"""
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = Math.round(${ce.primitive});
}"""
} else {
s"""
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
}"""
}
case DoubleType => // if child eval to NaN or Infinity, just return it.
if (_scale == 0) {
s"""
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = Math.round(${ce.primitive});
}"""
} else {
s"""
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
}"""
}
}

if (scaleV == null) { // if scale is null, no need to eval its child at all
s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
"""
} else {
s"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
$evaluationCode
}
"""
}
}

override def prettyName: String = "round"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove this, since the expression is already named Round

}
Expand Up @@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).")
}

def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
assertSuccess(expr)
}
assert(e.getMessage.contains(errorMessage))
}

test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
assertError(Abs('stringField), "function abs accepts numeric type")
Expand Down Expand Up @@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
"Odd position only allow foldable and not-null StringType expressions")
}

test("check types for ROUND") {
assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
"data type mismatch: argument 2 is expected to be of type int")
assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
"data type mismatch: argument 2 is expected to be of type int")
assertSuccess(Round(Literal(null), Literal(null)))
assertError(Round('booleanField, 'intField),
"data type mismatch: argument 1 is expected to be of type numeric")
}
}
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import scala.math.BigDecimal.RoundingMode

import com.google.common.math.LongMath

import org.apache.spark.SparkFunSuite
Expand Down Expand Up @@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null,
create_row(null))
}

test("round") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the case of round, i think we should explicitly write out the test cases, rather than relying on conversions.

val domain = -6 to 6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

domain -> scales

val doublePi: Double = math.Pi
val shortPi: Short = 31415
val intPi: Int = 314159265
val longPi: Long = 31415926535897932L
val bdPi: BigDecimal = BigDecimal(31415927L, 7)

val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
3.1416, 3.14159, 3.141593)

val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
Seq.fill[Short](7)(31415)

val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
314159270) ++ Seq.fill(7)(314159265)

val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
Seq.fill(7)(31415926535897932L)

val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
BigDecimal(3.141593), BigDecimal(3.1415927))

domain.zipWithIndex.foreach { case (scale, i) =>
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
}

// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
(0 to 7).foreach { i =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i -> scale

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also using i for array index here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok - can you at least move bdResults closer to this loop?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
}
(8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
}
}
}
Expand Up @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester
import scala.language.postfixOps

class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
test("creating decimals") {
/** Check that a Decimal has the given string representation, precision and scale */
def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
assert(d.toString === string)
assert(d.precision === precision)
assert(d.scale === scale)
}
/** Check that a Decimal has the given string representation, precision and scale */
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
assert(d.toString === string)
assert(d.precision === precision)
assert(d.scale === scale)
}

test("creating decimals") {
checkDecimal(new Decimal(), "0", 1, 0)
checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
Expand All @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
}

test("creating decimals with negative scale") {
checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
}

test("double and long values") {
/** Check that a Decimal converts to the given double and long values */
def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
Expand Down
32 changes: 32 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Expand Up @@ -1385,6 +1385,38 @@ object functions {
*/
def rint(columnName: String): Column = rint(Column(columnName))

/**
* Returns the value of the column `e` rounded to 0 decimal places.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should document all the behaviors you documented for Round expression here.

*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column): Column = round(e.expr, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add

def round(columnName: String): Column = round(columnName, 0)


/**
* Returns the value of the given column rounded to 0 decimal places.
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String): Column = round(Column(columnName), 0)

/**
* Returns the value of `e` rounded to `scale` decimal places.
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))

/**
* Returns the value of the given column rounded to `scale` decimal places.
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)

/**
* Shift the the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
Expand Down