Skip to content

Commit

Permalink
modify checkInputDataTypes using foldable
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 14, 2015
1 parent 5486b2d commit 1b87540
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ case class Round(child: Expression, scale: Expression) extends Expression {
return TypeCheckFailure("ROUND scale argument out of allowed range")
}
case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement
case child =>
if (child.find { case _: AttributeReference => true; case _ => false } != None) {
case _ =>
if (!scale.foldable) {
return TypeCheckFailure("Only Integral Literal or Null Literal " +
s"are allowed for ROUND scale arguments, got ${child.dataType}")
}
Expand Down Expand Up @@ -595,6 +595,21 @@ case class Round(child: Expression, scale: Expression) extends Expression {
}
}

private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
input match {
case f: Float if (f.isNaN || f.isInfinite) => return input
case d: Double if (d.isNaN || d.isInfinite) => return input
case _ =>
}
bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
}

private def round(input: String, scale: Int): Any = {
try round(input.toDouble, scale) catch {
case _ : NumberFormatException => null
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val ce = child.gen(ctx)

Expand Down Expand Up @@ -672,19 +687,4 @@ case class Round(child: Expression, scale: Expression) extends Expression {
}
"""
}

private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
input match {
case f: Float if (f.isNaN || f.isInfinite) => return input
case d: Double if (d.isNaN || d.isInfinite) => return input
case _ =>
}
bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
}

private def round(input: String, scale: Int): Any = {
try round(input.toDouble, scale) catch {
case _ : NumberFormatException => null
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
create_row(null))
}

test("round test") {
test("round") {
val domain = -16 to 16
val doublePi = math.Pi
val stringPi = "3.141592653589793"
Expand Down

0 comments on commit 1b87540

Please sign in to comment.