Skip to content

Commit

Permalink
codegen versioned eval
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Jul 14, 2015
1 parent 6cd9a64 commit 2077888
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@ case class Round(child: Expression, scale: Expression) extends Expression {

if (evalE == null || scaleV == null) return null

children(0).dataType match {
case decimalType: DecimalType =>
child.dataType match {
case _: DecimalType =>
val decimal = evalE.asInstanceOf[Decimal]
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
case ByteType =>
Expand All @@ -595,6 +595,84 @@ case class Round(child: Expression, scale: Expression) extends Expression {
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val ce = child.gen(ctx)

def integralRound(primitive: String): String = {
s"""
${ev.primitive} = new java.math.BigDecimal(${primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
}

def fractionalRound(primitive: String): String = {
s"""
${ev.primitive} = java.math.BigDecimal.valueOf(${primitive}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
}

def check(primitive: String, function: String): String = {
s"""
if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){
${ev.primitive} = ${primitive};
} else {
${fractionalRound(primitive)}.${function};
}"""
}

def convert(primitive: String): String = {
val dName = ctx.freshName("converter")
s"""
Double $dName = 0.0;
try {
$dName = Double.valueOf(${primitive}.toString());
} catch (NumberFormatException e) {
${ev.isNull} = true;
}
${check(dName, "doubleValue()")}
"""
}

def decimalRound(): String = {
s"""
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
${ev.primitive} = ${ce.primitive};
} else {
${ev.isNull} = true;
}
"""
}

val roundCode = child.dataType match {
case NullType => ";"
case _: DecimalType =>
decimalRound()
case ByteType =>
integralRound(ce.primitive) + ".byteValue();"
case ShortType =>
integralRound(ce.primitive) + ".shortValue();"
case IntegerType =>
integralRound(ce.primitive) + ".intValue();"
case LongType =>
integralRound(ce.primitive) + ".longValue();"
case FloatType =>
check(ce.primitive, "floatValue()")
case DoubleType =>
check(ce.primitive, "doubleValue()")
case StringType =>
convert(ce.primitive)
case BinaryType =>
convert(s"${ctx.stringType}.fromBytes(${ce.primitive})")
}

ce.code + s"""
boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${roundCode}
}
"""
}

private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
input match {
case f: Float if (f.isNaN || f.isInfinite) => return input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,16 +343,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val domain = -16 to 16
val doublePi = math.Pi
val stringPi = "3.141592653589793"
val arrayPi: Array[Byte] = stringPi.toCharArray.map(_.toByte)
val shortPi: Short = 31415
val intPi = 314159265
val longPi = 31415926535897932L
val bdPi = BigDecimal(31415926535897932L, 10)

domain.foreach { scale =>
checkEvaluation(Round(doublePi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(stringPi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(arrayPi, scale),
BigDecimal.valueOf(doublePi).setScale(scale, RoundingMode.HALF_UP).toDouble, EmptyRow)
checkEvaluation(Round(shortPi, scale),
BigDecimal.valueOf(shortPi).setScale(scale, RoundingMode.HALF_UP).toShort, EmptyRow)
checkEvaluation(Round(intPi, scale),
BigDecimal.valueOf(intPi).setScale(scale, RoundingMode.HALF_UP).toInt, EmptyRow)
checkEvaluation(Round(longPi, scale),
BigDecimal.valueOf(longPi).setScale(scale, RoundingMode.HALF_UP).toLong, EmptyRow)
}
checkEvaluation(new Round(Literal("invalid input")), null, EmptyRow)

Expand Down

0 comments on commit 2077888

Please sign in to comment.