Skip to content

Commit

Permalink
address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 14, 2015
1 parent 302a78a commit 61760ee
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,20 @@ case class Logarithm(left: Expression, right: Expression)
}
}

/**
* Round the `child`'s result to `scale` decimal place when `scale` >= 0
* or round at integral part when `scale` < 0.
* For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
*
* Child of IntegralType would eval to itself when `scale` >= 0.
* Child of FractionalType whose value is NaN or Infinite would always eval to itself.
*
* Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
* which leads to scale update in DecimalType's [[PrecisionInfo]]
*
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
*/
case class Round(child: Expression, scale: Expression)
extends BinaryExpression with ExpectsInputTypes {

Expand Down Expand Up @@ -559,10 +573,27 @@ case class Round(child: Expression, scale: Expression)
}
}

private lazy val scaleV = scale.eval(EmptyRow)
private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0
// Avoid repeated evaluation since `scale` is a constant int,
// avoid unnecessary `child` evaluation in both codegen and non-codegen eval
// by checking if scaleV == null as well.
private lazy val scaleV: Any = scale.eval(EmptyRow)
private lazy val _scale: Int = scaleV.asInstanceOf[Int]

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
override def eval(input: InternalRow): Any = {
if (scaleV == null) { // if scale is null, no need to eval its child at all
null
} else {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
nullSafeEval(evalE)
}
}
}

// not overriding since _scale is a constant int at runtime
def nullSafeEval(input1: Any): Any = {
child.dataType match {
case _: DecimalType =>
val decimal = input1.asInstanceOf[Decimal]
Expand Down Expand Up @@ -604,45 +635,89 @@ case class Round(child: Expression, scale: Expression)
${ev.isNull} = true;
}"""
case ByteType =>
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case ShortType =>
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case IntegerType =>
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
} else {
s"${ev.primitive} = ${ce.primitive};"
}
case LongType =>
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
case FloatType =>
s"""
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
if (_scale < 0) {
s"""
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
} else {
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
}"""
case DoubleType =>
s"""
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
s"${ev.primitive} = ${ce.primitive};"
}
case FloatType => // if child eval to NaN or Infinity, just return it.
if (_scale == 0) {
s"""
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = Math.round(${ce.primitive});
}"""
} else {
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
}"""
s"""
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
}"""
}
case DoubleType => // if child eval to NaN or Infinity, just return it.
if (_scale == 0) {
s"""
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = Math.round(${ce.primitive});
}"""
} else {
s"""
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
${ev.primitive} = ${ce.primitive};
} else {
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
}"""
}
}

ce.code + s"""
boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${evaluationCode}
}
if (scaleV == null) { // if scale is null, no need to eval its child at all
s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
"""
} else {
s"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
$evaluationCode
}
"""
}
}

override def prettyName: String = "round"
}
Original file line number Diff line number Diff line change
Expand Up @@ -340,32 +340,43 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("round") {
val domain = -16 to 16
val doublePi = math.Pi
val domain = -6 to 6
val doublePi: Double = math.Pi
val shortPi: Short = 31415
val intPi = 314159265
val longPi = 31415926535897932L
val bdPi = BigDecimal(31415926535897932L, 10)

domain.foreach { scale =>
checkEvaluation(Round(doublePi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(shortPi, scale),
BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow)
checkEvaluation(Round(intPi, scale),
BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow)
checkEvaluation(Round(longPi, scale),
BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow)
val intPi: Int = 314159265
val longPi: Long = 31415926535897932L
val bdPi: BigDecimal = BigDecimal(31415927L, 7)

val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
3.1416, 3.14159, 3.141593)

val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
Seq.fill[Short](7)(31415)

val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
314159270) ++ Seq.fill(7)(314159265)

val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
Seq.fill(7)(31415926535897932L)

val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
BigDecimal(3.141593), BigDecimal(3.1415927))

domain.zipWithIndex.foreach { case (scale, i) =>
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
}

// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
val (validScales, nullScales) = domain.splitAt(27)
validScales.foreach { scale =>
checkEvaluation(Round(bdPi, scale),
Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow)
(0 to 7).foreach { i =>
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
}
nullScales.foreach { scale =>
(8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
}
}
Expand Down

0 comments on commit 61760ee

Please sign in to comment.