Skip to content

Commit

Permalink
fix bugs in expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
yjshen committed Aug 3, 2015
1 parent 9d03ad9 commit 4fa5de0
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,12 @@ object HiveTypeCoercion {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}

case n @ NaNvl(l, r) if l.dataType != r.dataType =>
l.dataType match {
case DoubleType => NaNvl(l, Cast(r, DoubleType))
case FloatType => NaNvl(Cast(l, DoubleType), r)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
// for example, we could not write --9223372036854775808L in code
s"""
${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
${ev.primitive} = (${ctx.javaType(dt)})(-($originValue));
"""})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
&& left.dataType != FloatType
&& left.dataType != DoubleType) {
// faster version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,9 @@ object DateTimeUtils {
*/
def dateAddMonths(days: Int, months: Int): Int = {
val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
val currentMonthInYear = absoluteMonth % 12
val currentYear = absoluteMonth / 12
val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
val currentMonthInYear = nonNegativeMonth % 12
val currentYear = nonNegativeMonth / 12
val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay

Expand All @@ -626,7 +627,7 @@ object DateTimeUtils {
} else {
dayOfMonth
}
firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}

test("nanvl casts") {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
}

test("type coercion for If") {
val rule = HiveTypeCoercion.IfCoercion
ruleTest(rule,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types.{LongType, Decimal}

class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -56,6 +56,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(input), convert(-1))
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
}
checkEvaluation(UnaryMinus(Literal.create(Long.MinValue, LongType)), Long.MinValue)
}

test("- (Minus)") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null)
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
checkEvaluation(
AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498)
}

test("months_between") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(nullInt <=> normalInt, false)
checkEvaluation(nullInt <=> nullInt, true)
}

test("BinaryComparison: boolean input") {
checkEvaluation(LessThan(Literal(true), Literal(false)), false)
checkEvaluation(LessThanOrEqual(Literal(true), Literal(false)), false)
checkEvaluation(GreaterThan(Literal(true), Literal(false)), true)
checkEvaluation(GreaterThanOrEqual(Literal(true), Literal(false)), true)
}
}

0 comments on commit 4fa5de0

Please sign in to comment.