From b0bff7950969f575ae3756177b5d7acab419042f Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 3 Jul 2015 10:49:30 +0800 Subject: [PATCH] make round's inner method's name more meaningful --- .../spark/sql/catalyst/expressions/math.scala | 45 ++++++++++--------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c9e8b72cdb486..92d8118c67252 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -536,9 +536,6 @@ case class Round(child: Expression, scale: Expression) extends Expression { override def foldable: Boolean = child.foldable - private lazy val scaleV = scale.eval(EmptyRow) - private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 - override lazy val dataType: DataType = child.dataType match { case StringType | BinaryType => DoubleType case DecimalType.Fixed(p, s) => DecimalType(p, _scale) @@ -570,33 +567,43 @@ case class Round(child: Expression, scale: Expression) extends Expression { TypeCheckSuccess } - private lazy val rounding: (Any) => (Any) = roundGen(child.dataType) + private lazy val scaleV = scale.eval(EmptyRow) + private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0 + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null || scaleV == null) return null + round(evalE) + } + + private lazy val round: (Any) => (Any) = typedRound(child.dataType) - def roundGen(dt: DataType)(x: Any): Any = { + // Using dataType info to find an appropriate round method + private def typedRound(dt: DataType)(x: Any): Any = { dt match { case _: DecimalType => val decimal = x.asInstanceOf[Decimal] if (decimal.changePrecision(decimal.precision, _scale)) decimal else null case ByteType => - round(x.asInstanceOf[Byte], _scale) + numericRound(x.asInstanceOf[Byte], _scale) case ShortType => - round(x.asInstanceOf[Short], _scale) + numericRound(x.asInstanceOf[Short], _scale) case IntegerType => - round(x.asInstanceOf[Int], _scale) + numericRound(x.asInstanceOf[Int], _scale) case LongType => - round(x.asInstanceOf[Long], _scale) + numericRound(x.asInstanceOf[Long], _scale) case FloatType => - round(x.asInstanceOf[Float], _scale) + numericRound(x.asInstanceOf[Float], _scale) case DoubleType => - round(x.asInstanceOf[Double], _scale) + numericRound(x.asInstanceOf[Double], _scale) case StringType => - round(x.asInstanceOf[UTF8String].toString, _scale) + stringLikeRound(x.asInstanceOf[UTF8String].toString, _scale) case BinaryType => - round(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) + stringLikeRound(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale) } } - private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { + private def numericRound[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = { input match { case f: Float if (f.isNaN || f.isInfinite) => return input case d: Double if (d.isNaN || d.isInfinite) => return input @@ -605,18 +612,12 @@ case class Round(child: Expression, scale: Expression) extends Expression { bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP)) } - private def round(input: String, scale: Int): Any = { - try round(input.toDouble, scale) catch { + private def stringLikeRound(input: String, scale: Int): Any = { + try numericRound(input.toDouble, scale) catch { case _: NumberFormatException => null } } - def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null || scaleV == null) return null - rounding(evalE) - } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val ce = child.gen(ctx)