New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-8279][SQL]Add math function round #6938
Changes from all commits
653d047
7e163ae
56db4bb
7c83e13
9be894e
6cd9a64
2077888
5486b2d
1b87540
e6f44c4
9bd6930
b0bff79
c3b9839
d10be4a
9555e35
8c7a949
31dfe7c
302a78a
61760ee
392b65b
07a124c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions | |
|
||
import java.{lang => jl} | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.types.UTF8String | ||
|
||
|
@@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression) | |
""" | ||
} | ||
} | ||
|
||
/** | ||
* Round the `child`'s result to `scale` decimal place when `scale` >= 0 | ||
* or round at integral part when `scale` < 0. | ||
* For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. | ||
* | ||
* Child of IntegralType would eval to itself when `scale` >= 0. | ||
* Child of FractionalType whose value is NaN or Infinite would always eval to itself. | ||
* | ||
* Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], | ||
* which leads to scale update in DecimalType's [[PrecisionInfo]] | ||
* | ||
* @param child expr to be round, all [[NumericType]] is allowed as Input | ||
* @param scale new scale to be round to, this should be a constant int at runtime | ||
*/ | ||
case class Round(child: Expression, scale: Expression) | ||
extends BinaryExpression with ExpectsInputTypes { | ||
|
||
import BigDecimal.RoundingMode.HALF_UP | ||
|
||
def this(child: Expression) = this(child, Literal(0)) | ||
|
||
override def left: Expression = child | ||
override def right: Expression = scale | ||
|
||
// round of Decimal would eval to null if it fails to `changePrecision` | ||
override def nullable: Boolean = true | ||
|
||
override def foldable: Boolean = child.foldable | ||
|
||
override lazy val dataType: DataType = child.dataType match { | ||
// if the new scale is bigger which means we are scaling up, | ||
// keep the original scale as `Decimal` does | ||
case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) | ||
case t => t | ||
} | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) | ||
|
||
override def checkInputDataTypes(): TypeCheckResult = { | ||
super.checkInputDataTypes() match { | ||
case TypeCheckSuccess => | ||
if (scale.foldable) { | ||
TypeCheckSuccess | ||
} else { | ||
TypeCheckFailure("Only foldable Expression is allowed for scale arguments") | ||
} | ||
case f => f | ||
} | ||
} | ||
|
||
// Avoid repeated evaluation since `scale` is a constant int, | ||
// avoid unnecessary `child` evaluation in both codegen and non-codegen eval | ||
// by checking if scaleV == null as well. | ||
private lazy val scaleV: Any = scale.eval(EmptyRow) | ||
private lazy val _scale: Int = scaleV.asInstanceOf[Int] | ||
|
||
override def eval(input: InternalRow): Any = { | ||
if (scaleV == null) { // if scale is null, no need to eval its child at all | ||
null | ||
} else { | ||
val evalE = child.eval(input) | ||
if (evalE == null) { | ||
null | ||
} else { | ||
nullSafeEval(evalE) | ||
} | ||
} | ||
} | ||
|
||
// not overriding since _scale is a constant int at runtime | ||
def nullSafeEval(input1: Any): Any = { | ||
child.dataType match { | ||
case _: DecimalType => | ||
val decimal = input1.asInstanceOf[Decimal] | ||
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null | ||
case ByteType => | ||
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte | ||
case ShortType => | ||
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort | ||
case IntegerType => | ||
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt | ||
case LongType => | ||
BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong | ||
case FloatType => | ||
val f = input1.asInstanceOf[Float] | ||
if (f.isNaN || f.isInfinite) { | ||
f | ||
} else { | ||
BigDecimal(f).setScale(_scale, HALF_UP).toFloat | ||
} | ||
case DoubleType => | ||
val d = input1.asInstanceOf[Double] | ||
if (d.isNaN || d.isInfinite) { | ||
d | ||
} else { | ||
BigDecimal(d).setScale(_scale, HALF_UP).toDouble | ||
} | ||
} | ||
} | ||
|
||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
val ce = child.gen(ctx) | ||
|
||
val evaluationCode = child.dataType match { | ||
case _: DecimalType => | ||
s""" | ||
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { | ||
${ev.primitive} = ${ce.primitive}; | ||
} else { | ||
${ev.isNull} = true; | ||
}""" | ||
case ByteType => | ||
if (_scale < 0) { | ||
s""" | ||
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). | ||
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" | ||
} else { | ||
s"${ev.primitive} = ${ce.primitive};" | ||
} | ||
case ShortType => | ||
if (_scale < 0) { | ||
s""" | ||
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). | ||
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" | ||
} else { | ||
s"${ev.primitive} = ${ce.primitive};" | ||
} | ||
case IntegerType => | ||
if (_scale < 0) { | ||
s""" | ||
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). | ||
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" | ||
} else { | ||
s"${ev.primitive} = ${ce.primitive};" | ||
} | ||
case LongType => | ||
if (_scale < 0) { | ||
s""" | ||
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). | ||
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" | ||
} else { | ||
s"${ev.primitive} = ${ce.primitive};" | ||
} | ||
case FloatType => // if child eval to NaN or Infinity, just return it. | ||
if (_scale == 0) { | ||
s""" | ||
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ | ||
${ev.primitive} = ${ce.primitive}; | ||
} else { | ||
${ev.primitive} = Math.round(${ce.primitive}); | ||
}""" | ||
} else { | ||
s""" | ||
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ | ||
${ev.primitive} = ${ce.primitive}; | ||
} else { | ||
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). | ||
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); | ||
}""" | ||
} | ||
case DoubleType => // if child eval to NaN or Infinity, just return it. | ||
if (_scale == 0) { | ||
s""" | ||
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ | ||
${ev.primitive} = ${ce.primitive}; | ||
} else { | ||
${ev.primitive} = Math.round(${ce.primitive}); | ||
}""" | ||
} else { | ||
s""" | ||
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ | ||
${ev.primitive} = ${ce.primitive}; | ||
} else { | ||
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). | ||
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); | ||
}""" | ||
} | ||
} | ||
|
||
if (scaleV == null) { // if scale is null, no need to eval its child at all | ||
s""" | ||
boolean ${ev.isNull} = true; | ||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
""" | ||
} else { | ||
s""" | ||
${ce.code} | ||
boolean ${ev.isNull} = ${ce.isNull}; | ||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; | ||
if (!${ev.isNull}) { | ||
$evaluationCode | ||
} | ||
""" | ||
} | ||
} | ||
|
||
override def prettyName: String = "round" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can remove this, since the expression is already named Round |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ | |
|
||
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import scala.math.BigDecimal.RoundingMode | ||
|
||
import com.google.common.math.LongMath | ||
|
||
import org.apache.spark.SparkFunSuite | ||
|
@@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { | |
null, | ||
create_row(null)) | ||
} | ||
|
||
test("round") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the case of round, i think we should explicitly write out the test cases, rather than relying on conversions. |
||
val domain = -6 to 6 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. domain -> scales |
||
val doublePi: Double = math.Pi | ||
val shortPi: Short = 31415 | ||
val intPi: Int = 314159265 | ||
val longPi: Long = 31415926535897932L | ||
val bdPi: BigDecimal = BigDecimal(31415927L, 7) | ||
|
||
val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, | ||
3.1416, 3.14159, 3.141593) | ||
|
||
val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ | ||
Seq.fill[Short](7)(31415) | ||
|
||
val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, | ||
314159270) ++ Seq.fill(7)(314159265) | ||
|
||
val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, | ||
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ | ||
Seq.fill(7)(31415926535897932L) | ||
|
||
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), | ||
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), | ||
BigDecimal(3.141593), BigDecimal(3.1415927)) | ||
|
||
domain.zipWithIndex.foreach { case (scale, i) => | ||
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) | ||
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) | ||
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) | ||
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) | ||
} | ||
|
||
// round_scale > current_scale would result in precision increase | ||
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null | ||
(0 to 7).foreach { i => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i -> scale There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was also using i for array index here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok - can you at least move bdResults closer to this loop? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) | ||
} | ||
(8 to 10).foreach { scale => | ||
checkEvaluation(Round(bdPi, scale), null, EmptyRow) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1385,6 +1385,38 @@ object functions { | |
*/ | ||
def rint(columnName: String): Column = rint(Column(columnName)) | ||
|
||
/** | ||
* Returns the value of the column `e` rounded to 0 decimal places. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should document all the behaviors you documented for Round expression here. |
||
* | ||
* @group math_funcs | ||
* @since 1.5.0 | ||
*/ | ||
def round(e: Column): Column = round(e.expr, 0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add def round(columnName: String): Column = round(columnName, 0) |
||
|
||
/** | ||
* Returns the value of the given column rounded to 0 decimal places. | ||
* | ||
* @group math_funcs | ||
* @since 1.5.0 | ||
*/ | ||
def round(columnName: String): Column = round(Column(columnName), 0) | ||
|
||
/** | ||
* Returns the value of `e` rounded to `scale` decimal places. | ||
* | ||
* @group math_funcs | ||
* @since 1.5.0 | ||
*/ | ||
def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) | ||
|
||
/** | ||
* Returns the value of the given column rounded to `scale` decimal places. | ||
* | ||
* @group math_funcs | ||
* @since 1.5.0 | ||
*/ | ||
def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) | ||
|
||
/** | ||
* Shift the the given value numBits left. If the given value is a long value, this function | ||
* will return a long value else it will return an integer value. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ExpectsInputTypes -> ImplicitCastInputTypes