Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32688][SQL][TEST] Add special values to LiteralGenerator for float and double #29515

Closed
wants to merge 8 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,27 @@ object LiteralGenerator {
lazy val longLiteralGen: Gen[Literal] =
for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType)

// The floatLiteralGen and doubleLiteralGen will 50% of the time yield arbitrary values
// and 50% of the time will yield some special values that are more likely to reveal
// corner cases. This behavior is similar to the integral value generators.
lazy val floatLiteralGen: Gen[Literal] =
for {
f <- Gen.chooseNum(Float.MinValue / 2, Float.MaxValue / 2,
Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity)
tanelk marked this conversation as resolved.
Show resolved Hide resolved
f <- Gen.oneOf(
Gen.oneOf(
Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity, Float.MinPositiveValue,
tanelk marked this conversation as resolved.
Show resolved Hide resolved
Float.MaxValue, -Float.MaxValue, 0.0f, -0.0f, 1.0f, -1.0f),
Arbitrary.arbFloat.arbitrary
tanelk marked this conversation as resolved.
Show resolved Hide resolved
)
} yield Literal.create(f, FloatType)

lazy val doubleLiteralGen: Gen[Literal] =
for {
f <- Gen.chooseNum(Double.MinValue / 2, Double.MaxValue / 2,
Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity)
f <- Gen.oneOf(
Gen.oneOf(
Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity, Double.MinPositiveValue,
Double.MaxValue, -Double.MaxValue, 0.0, -0.0, 1.0, -1.0),
Arbitrary.arbDouble.arbitrary
)
} yield Literal.create(f, DoubleType)

// TODO cache the generated data
Expand Down Expand Up @@ -167,6 +178,8 @@ object LiteralGenerator {
case BinaryType => binaryLiteralGen
case CalendarIntervalType => calendarIntervalLiterGen
case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale)
case ArrayType(et, _) => randomGen(et).map(
lit => Literal.create(Array(lit.value), ArrayType(et)))
case dt => throw new IllegalArgumentException(s"not supported type $dt")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
Expand Down Expand Up @@ -91,6 +91,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
DataTypeTestUtils.propertyCheckSupported.foreach { dt =>
checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt)
checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt)

val arrayType = ArrayType(dt)
checkConsistencyBetweenInterpretedAndCodegen(EqualTo, arrayType, arrayType)
checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, arrayType, arrayType)
}
}

Expand Down Expand Up @@ -496,6 +500,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualTo(infinity, infinity), true)
}

private def testEquality(literals: Seq[Literal]): Unit = {
literals.foreach(left => {
literals.foreach(right => {
checkEvaluation(EqualTo(left, right), true)
checkEvaluation(EqualNullSafe(left, right), true)

val leftArray = Literal.create(Array(left.value), ArrayType(left.dataType))
val rightArray = Literal.create(Array(right.value), ArrayType(right.dataType))
checkEvaluation(EqualTo(leftArray, rightArray), true)
checkEvaluation(EqualNullSafe(leftArray, rightArray), true)

val leftStruct = Literal.create(
Row(left.value), new StructType().add("a", left.dataType))
val rightStruct = Literal.create(
Row(right.value), new StructType().add("a", right.dataType))
checkEvaluation(EqualTo(leftStruct, rightStruct), true)
checkEvaluation(EqualNullSafe(leftStruct, rightStruct), true)
})
})
}

test("SPARK-32688: 0.0 and -0.0 should be equal") {
testEquality(Seq(Literal(0.0), Literal(-0.0)))
testEquality(Seq(Literal(0.0f), Literal(-0.0f)))
}

test("SPARK-22693: InSet should not use global variables") {
val ctx = new CodegenContext
InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
Expand Down