From 1b87540358f9195bf4a43ab3ade309bd43357f6c Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 24 Jun 2015 15:22:59 +0800 Subject: [PATCH] modify checkInputDataTypes using foldable --- .../spark/sql/catalyst/expressions/math.scala | 34 +++++++++---------- .../expressions/MathFunctionsSuite.scala | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 10460c0b2ff20..6f4db69d9e4f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -558,8 +558,8 @@ case class Round(child: Expression, scale: Expression) extends Expression { return TypeCheckFailure("ROUND scale argument out of allowed range") } case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement - case child => - if (child.find { case _: AttributeReference => true; case _ => false } != None) { + case _ => + if (!scale.foldable) { return TypeCheckFailure("Only Integral Literal or Null Literal " + s"are allowed for ROUND scale arguments, got ${child.dataType}") } @@ -595,6 +595,21 @@ case class Round(child: Expression, scale: Expression) extends Expression { } } + private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { + input match { + case f: Float if (f.isNaN || f.isInfinite) => return input + case d: Double if (d.isNaN || d.isInfinite) => return input + case _ => + } + bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) + } + + private def round(input: String, scale: Int): Any = { + try round(input.toDouble, scale) catch { + case _ : NumberFormatException => null + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx) @@ -672,19 +687,4 @@ case class Round(child: Expression, scale: Expression) extends Expression { } """ } - - private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { - input match { - case f: Float if (f.isNaN || f.isInfinite) => return input - case d: Double if (d.isNaN || d.isInfinite) => return input - case _ => - } - bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) - } - - private def round(input: String, scale: Int): Any = { - try round(input.toDouble, scale) catch { - case _ : NumberFormatException => null - } - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 9d95ef5cae35d..477ae969240e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -339,7 +339,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { create_row(null)) } - test("round test") { + test("round") { val domain = -16 to 16 val doublePi = math.Pi val stringPi = "3.141592653589793"