Skip to content

Commit

Permalink
refactor Round's constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 14, 2015
1 parent 9be894e commit 6cd9a64
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[ShiftLeft]("shiftleft"),
expression[ShiftRight]("shiftright"),
expression[ShiftRightUnsigned]("shiftrightunsigned"),
Expand All @@ -127,7 +128,6 @@ object FunctionRegistry {
expression[Tanh]("tanh"),
expression[ToDegrees]("degrees"),
expression[ToRadians]("radians"),
expression[Round]("round"),

// misc functions
expression[Md5]("md5"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,74 +524,74 @@ case class Logarithm(left: Expression, right: Expression)
}
}

case class Round(children: Seq[Expression]) extends Expression {
case class Round(child: Expression, scale: Expression) extends Expression {

def this(child: Expression) = {
this(child, Literal(0))
}

def children: Seq[Expression] = Seq(child, scale)

def nullable: Boolean = true

private lazy val evalE2 = if (children.size == 2) children(1).eval(EmptyRow) else null
private lazy val _scale = if (evalE2 != null) evalE2.asInstanceOf[Int] else 0
private lazy val scaleV = scale.asInstanceOf[Literal].value
private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0

override lazy val dataType: DataType = {
children(0).dataType match {
child.dataType match {
case StringType | BinaryType => DoubleType
case DecimalType.Fixed(p, s) => DecimalType(p, _scale)
case t => t
}
}

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size < 1 || children.size > 2) {
return TypeCheckFailure(s"ROUND require one or two arguments, got ${children.size}")
}
children(0).dataType match {
child.dataType match {
case _: NumericType | NullType | BinaryType | StringType => // satisfy requirement
case dt =>
return TypeCheckFailure(s"Only numeric, string or binary data types" +
s" are allowed for ROUND function, got $dt")
}
if (children.size == 2) {
children(1) match {
case Literal(value, LongType) =>
if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
return TypeCheckFailure("ROUND scale argument out of allowed range")
}
case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement
case child =>
if (child.find { case _: AttributeReference => true; case _ => false } != None) {
return TypeCheckFailure("Only Integral Literal or Null Literal " +
s"are allowed for ROUND scale arguments, got ${child.dataType}")
}
}
scale match {
case Literal(value, LongType) =>
if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
return TypeCheckFailure("ROUND scale argument out of allowed range")
}
case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement
case child =>
if (child.find { case _: AttributeReference => true; case _ => false } != None) {
return TypeCheckFailure("Only Integral Literal or Null Literal " +
s"are allowed for ROUND scale arguments, got ${child.dataType}")
}
}
TypeCheckSuccess
}

def eval(input: InternalRow): Any = {
val evalE1 = children(0).eval(input)
val evalE = child.eval(input)

if (evalE1 == null) return null
if (children.size == 2 && evalE2 == null) return null
if (evalE == null || scaleV == null) return null

children(0).dataType match {
case decimalType: DecimalType =>
val decimal = evalE1.asInstanceOf[Decimal]
val decimal = evalE.asInstanceOf[Decimal]
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
case ByteType =>
round(evalE1.asInstanceOf[Byte], _scale)
round(evalE.asInstanceOf[Byte], _scale)
case ShortType =>
round(evalE1.asInstanceOf[Short], _scale)
round(evalE.asInstanceOf[Short], _scale)
case IntegerType =>
round(evalE1.asInstanceOf[Int], _scale)
round(evalE.asInstanceOf[Int], _scale)
case LongType =>
round(evalE1.asInstanceOf[Long], _scale)
round(evalE.asInstanceOf[Long], _scale)
case FloatType =>
round(evalE1.asInstanceOf[Float], _scale)
round(evalE.asInstanceOf[Float], _scale)
case DoubleType =>
round(evalE1.asInstanceOf[Double], _scale)
round(evalE.asInstanceOf[Double], _scale)
case StringType =>
round(evalE1.asInstanceOf[UTF8String].toString, _scale)
round(evalE.asInstanceOf[UTF8String].toString, _scale)
case BinaryType =>
round(UTF8String.fromBytes(evalE1.asInstanceOf[Array[Byte]]).toString, _scale)
round(UTF8String.fromBytes(evalE.asInstanceOf[Array[Byte]]).toString, _scale)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}

test("check types for ROUND") {
assertError(Round(Seq()), "ROUND require one or two arguments")
assertError(Round(Seq(Literal(null), 'booleanField)),
assertError(Round(Literal(null), 'booleanField),
"Only Integral Literal or Null Literal are allowed for ROUND scale argument")
assertError(Round(Seq(Literal(null), 'complexField)),
assertError(Round(Literal(null), 'complexField),
"Only Integral Literal or Null Literal are allowed for ROUND scale argument")
assertSuccess(Round(Seq(Literal(null), Literal(null))))
assertError(Round(Seq('booleanField, 'intField)),
assertSuccess(Round(Literal(null), Literal(null)))
assertError(Round('booleanField, 'intField),
"Only numeric, string or binary data types are allowed for ROUND function")
assertError(Round(Seq(Literal(null), Literal(1L + Int.MaxValue))),
assertError(Round(Literal(null), Literal(1L + Int.MaxValue)),
"ROUND scale argument out of allowed range")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -347,24 +347,24 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val bdPi = BigDecimal(31415926535897932L, 10)

domain.foreach { scale =>
checkEvaluation(Round(Seq(doublePi, scale)),
checkEvaluation(Round(doublePi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(Seq(stringPi, scale)),
checkEvaluation(Round(stringPi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(Seq(intPi, scale)),
checkEvaluation(Round(intPi, scale),
BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow)
}
checkEvaluation(Round(Seq("invalid input")), null, EmptyRow)
checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow)

// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
val (validScales, nullScales) = domain.splitAt(27)
validScales.foreach { scale =>
checkEvaluation(Round(Seq(bdPi, scale)),
checkEvaluation(Round(bdPi, scale),
Decimal(bdPi.setScale(scale, RoundingMode.HALF_UP)), EmptyRow)
}
nullScales.foreach { scale =>
checkEvaluation(Round(Seq(bdPi, scale)), null, EmptyRow)
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
}
}
}
16 changes: 12 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1386,20 +1386,20 @@ object functions {
def rint(columnName: String): Column = rint(Column(columnName))

/**
* Returns the value of the `e` rounded to 0 decimal places.
* Returns the value of the column `e` rounded to 0 decimal places.
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column): Column = Round(Seq(e.expr))
def round(e: Column): Column = round(e.expr, 0)

/**
* Returns the value of `e` rounded to the value of `scale` decimal places.
* Returns the value of the given column `e` rounded to the value of `scale` decimal places.
*
* @group math_funcs
* @since 1.5.0
*/
def round(e: Column, scale: Column): Column = Round(Seq(e.expr, scale.expr))
def round(e: Column, scale: Column): Column = Round(e.expr, scale.expr)

/**
* Returns the value of `e` rounded to `scale` decimal places.
Expand All @@ -1409,6 +1409,14 @@ object functions {
*/
def round(e: Column, scale: Int): Column = round(e, lit(scale))

/**
* Returns the value of the given column rounded to `scale` decimal places.
*
* @group math_funcs
* @since 1.5.0
*/
def round(columnName: String, scale: Int): Column = round(Column(columnName), scale)

/**
* Shift the the given value numBits left. If the given value is a long value, this function
* will return a long value else it will return an integer value.
Expand Down

0 comments on commit 6cd9a64

Please sign in to comment.