Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9549][SQL] fix bugs in expressions #7882

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,11 @@ object HiveTypeCoercion {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}

case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
NaNvl(Cast(l, DoubleType), r)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
// for example, we could not write --9223372036854775808L in code
s"""
${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
${ev.primitive} = (${ctx.javaType(dt)})(-($originValue));
"""})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
if (ctx.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our codegen can already handle the boolean type, can't it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but without the check against BooleanType, it would use

// faster version
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")

rather than the else branch; therefore, genComp would not be used.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. Makes sense, good catch!

&& left.dataType != FloatType
&& left.dataType != DoubleType) {
// faster version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,9 @@ object DateTimeUtils {
*/
def dateAddMonths(days: Int, months: Int): Int = {
val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
val currentMonthInYear = absoluteMonth % 12
val currentYear = absoluteMonth / 12
val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
val currentMonthInYear = nonNegativeMonth % 12
val currentYear = nonNegativeMonth / 12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two statements above can be replaced with:
val (currentYear, currentMonthInYear) = nonNegativeMonth /% 12

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is that possible?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/% is not defined for Int.

I read about the notation in a Scala book that I have since returned. I will look into it further once I have the book back.

val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay

Expand All @@ -626,7 +627,7 @@ object DateTimeUtils {
} else {
dayOfMonth
}
firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}

test("nanvl casts") {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
}

test("type coercion for If") {
val rule = HiveTypeCoercion.IfCoercion
ruleTest(rule,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types._

class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -56,6 +56,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(input), convert(-1))
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
}
checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue)
checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
}

test("- (Minus)") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null)
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
checkEvaluation(
AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498)
}

test("months_between") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,60 +136,60 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
}

private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
private val largeValues =
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))

private val equalValues1 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
private val equalValues2 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))

test("BinaryComparison: <") {
test("BinaryComparison: lessThan") {
for (i <- 0 until smallValues.length) {
checkEvaluation(smallValues(i) < largeValues(i), true)
checkEvaluation(equalValues1(i) < equalValues2(i), false)
checkEvaluation(largeValues(i) < smallValues(i), false)
checkEvaluation(LessThan(smallValues(i), largeValues(i)), true)
checkEvaluation(LessThan(equalValues1(i), equalValues2(i)), false)
checkEvaluation(LessThan(largeValues(i), smallValues(i)), false)
}
}

test("BinaryComparison: <=") {
test("BinaryComparison: LessThanOrEqual") {
for (i <- 0 until smallValues.length) {
checkEvaluation(smallValues(i) <= largeValues(i), true)
checkEvaluation(equalValues1(i) <= equalValues2(i), true)
checkEvaluation(largeValues(i) <= smallValues(i), false)
checkEvaluation(LessThanOrEqual(smallValues(i), largeValues(i)), true)
checkEvaluation(LessThanOrEqual(equalValues1(i), equalValues2(i)), true)
checkEvaluation(LessThanOrEqual(largeValues(i), smallValues(i)), false)
}
}

test("BinaryComparison: >") {
test("BinaryComparison: GreaterThan") {
for (i <- 0 until smallValues.length) {
checkEvaluation(smallValues(i) > largeValues(i), false)
checkEvaluation(equalValues1(i) > equalValues2(i), false)
checkEvaluation(largeValues(i) > smallValues(i), true)
checkEvaluation(GreaterThan(smallValues(i), largeValues(i)), false)
checkEvaluation(GreaterThan(equalValues1(i), equalValues2(i)), false)
checkEvaluation(GreaterThan(largeValues(i), smallValues(i)), true)
}
}

test("BinaryComparison: >=") {
test("BinaryComparison: GreaterThanOrEqual") {
for (i <- 0 until smallValues.length) {
checkEvaluation(smallValues(i) >= largeValues(i), false)
checkEvaluation(equalValues1(i) >= equalValues2(i), true)
checkEvaluation(largeValues(i) >= smallValues(i), true)
checkEvaluation(GreaterThanOrEqual(smallValues(i), largeValues(i)), false)
checkEvaluation(GreaterThanOrEqual(equalValues1(i), equalValues2(i)), true)
checkEvaluation(GreaterThanOrEqual(largeValues(i), smallValues(i)), true)
}
}

test("BinaryComparison: ===") {
test("BinaryComparison: EqualTo") {
for (i <- 0 until smallValues.length) {
checkEvaluation(smallValues(i) === largeValues(i), false)
checkEvaluation(equalValues1(i) === equalValues2(i), true)
checkEvaluation(largeValues(i) === smallValues(i), false)
checkEvaluation(EqualTo(smallValues(i), largeValues(i)), false)
checkEvaluation(EqualTo(equalValues1(i), equalValues2(i)), true)
checkEvaluation(EqualTo(largeValues(i), smallValues(i)), false)
}
}

test("BinaryComparison: <=>") {
test("BinaryComparison: EqualNullSafe") {
for (i <- 0 until smallValues.length) {
checkEvaluation(smallValues(i) <=> largeValues(i), false)
checkEvaluation(equalValues1(i) <=> equalValues2(i), true)
checkEvaluation(largeValues(i) <=> smallValues(i), false)
checkEvaluation(EqualNullSafe(smallValues(i), largeValues(i)), false)
checkEvaluation(EqualNullSafe(equalValues1(i), equalValues2(i)), true)
checkEvaluation(EqualNullSafe(largeValues(i), smallValues(i)), false)
}
}

Expand All @@ -209,8 +209,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
nullTest(GreaterThanOrEqual)
nullTest(EqualTo)

checkEvaluation(normalInt <=> nullInt, false)
checkEvaluation(nullInt <=> normalInt, false)
checkEvaluation(nullInt <=> nullInt, true)
checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,24 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {

test("nanvl") {
val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil),
Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
StructField("c", DoubleType), StructField("d", DoubleType))))
StructField("c", DoubleType), StructField("d", DoubleType),
StructField("e", FloatType), StructField("f", IntegerType))))

checkAnswer(
testData.select(
nanvl($"a", lit(5)), nanvl($"b", lit(10)),
nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))),
Row(null, 3.0, null, Double.PositiveInfinity)
nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"),
nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)),
nanvl($"b", $"e"), nanvl($"e", $"f")),
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
)
testData.registerTempTable("t")
checkAnswer(
ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"),
Row(null, 3.0, null, Double.PositiveInfinity)
ctx.sql(
"select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
" nanvl(b, e), nanvl(e, f) from t"),
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
)
}

Expand Down