diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a1f15e4f0f25a..72285c6a24199 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -52,6 +52,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort) checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe) + } } test("- (UnaryMinus)") { @@ -71,6 +75,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(UnaryMinus, tpe) + } } test("- (Minus)") { @@ -85,6 +93,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort - negativeShort).toShort) checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) + + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe) + } } test("* (Multiply)") { @@ -99,6 +111,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper (positiveShort * negativeShort).toShort) checkEvaluation(Multiply(positiveIntLit, 
negativeIntLit), positiveInt * negativeInt) checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe) + } } test("/ (Divide) basic") { @@ -111,6 +127,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) + } } test("/ (Divide) for integral type") { @@ -144,6 +164,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Remainder(negativeIntLit, negativeIntLit), 0) checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L) checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L) + + // TODO: the following lines would fail the test due to inconsistency result of interpret + // and codegen for remainder between giant values, seems like a numeric stability issue + // DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + // checkConsistencyBetweenInterpretedAndCodegen(Remainder, tpe, tpe) + // } } test("Abs") { @@ -161,6 +187,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Abs(negativeIntLit), - negativeInt) checkEvaluation(Abs(positiveLongLit), positiveLong) checkEvaluation(Abs(negativeLongLit), - negativeLong) + + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(Abs, tpe) + } } test("MaxOf basic") { @@ -175,6 +205,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MaxOf(positiveShortLit, negativeShortLit), (positiveShort).toShort) checkEvaluation(MaxOf(positiveIntLit, 
negativeIntLit), positiveInt) checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong) + + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MaxOf, tpe, tpe) + } } test("MaxOf for atomic type") { @@ -196,6 +230,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(positiveShortLit, negativeShortLit), (negativeShort).toShort) checkEvaluation(MinOf(positiveIntLit, negativeIntLit), negativeInt) checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong) + + DataTypeTestUtils.ordered.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) + } } test("MinOf for atomic type") { @@ -222,4 +260,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt) checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong) + + // NOTE(review): this exercises MinOf rather than Pmod, which looks like a copy-paste slip; + // confirm that Pmod is consistent between interpreted and codegen before switching it. + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegen(MinOf, tpe, tpe) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 4fc1c06153595..3a310c0e9a7a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -45,6 +45,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt) checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong) checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) + } } test("BitwiseAnd") { @@ -68,6 +72,10 @@ class
BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (positiveShort & negativeShort).toShort) checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt) checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) + } } test("BitwiseOr") { @@ -91,6 +99,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (positiveShort | negativeShort).toShort) checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt) checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) + } } test("BitwiseXor") { @@ -110,10 +122,13 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseXor(nullLit, Literal(1)), null) checkEvaluation(BitwiseXor(Literal(1), nullLit), null) checkEvaluation(BitwiseXor(nullLit, nullLit), null) - checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit), (positiveShort ^ negativeShort).toShort) checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt) checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index d26bcdb2902ab..0df673bb9fa02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -66,6 +66,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toLong, TimestampType) testIf(_.toString, StringType) + + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(If, BooleanType, dt, dt) + } } test("case when") { @@ -176,6 +180,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) + } } test("function greatest") { @@ -218,6 +226,9 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 08:00:00")), Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) - } + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f9b73f1a75e73..610d39e8493cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -60,6 +60,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType) } test("Year") { @@ -79,6 +80,7 @@ class DateExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Year, DateType) } test("Quarter") { @@ -98,6 +100,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType) } test("Month") { @@ -117,6 +120,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Month, DateType) } test("Day / DayOfMonth") { @@ -135,6 +139,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.DAY_OF_MONTH)) } } + checkConsistencyBetweenInterpretedAndCodegen(DayOfMonth, DateType) } test("Seconds") { @@ -149,6 +154,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } + checkConsistencyBetweenInterpretedAndCodegen(Second, TimestampType) } test("WeekOfYear") { @@ -157,6 +163,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) } test("DateFormat") { @@ -184,6 +191,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkConsistencyBetweenInterpretedAndCodegen(Hour, TimestampType) } test("Minute") { @@ -200,6 +208,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.get(Calendar.MINUTE)) } } + checkConsistencyBetweenInterpretedAndCodegen(Minute, TimestampType) } test("date_add") { @@ -218,6 +227,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627) checkEvaluation( 
DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910) + checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType) } test("date_sub") { @@ -236,6 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909) checkEvaluation( DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628) + checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, IntegerType) } test("time_add") { @@ -254,6 +265,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistencyBetweenInterpretedAndCodegen(TimeAdd, TimestampType, CalendarIntervalType) } test("time_sub") { @@ -277,6 +289,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), null) + checkConsistencyBetweenInterpretedAndCodegen(TimeSub, TimestampType, CalendarIntervalType) } test("add_months") { @@ -296,6 +309,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213) checkEvaluation( AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528) + checkConsistencyBetweenInterpretedAndCodegen(AddMonths, DateType, IntegerType) } test("months_between") { @@ -320,6 +334,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(MonthsBetween(t, tnull), null) checkEvaluation(MonthsBetween(tnull, t), null) checkEvaluation(MonthsBetween(tnull, tnull), null) + checkConsistencyBetweenInterpretedAndCodegen(MonthsBetween, TimestampType, TimestampType) } test("last_day") { @@ -337,6 +352,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { 
checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) checkEvaluation(LastDay(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(LastDay, DateType) } test("next_day") { @@ -370,6 +386,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ToDate(Literal(Date.valueOf("2015-07-22"))), DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) checkEvaluation(ToDate(Literal.create(null, DateType)), null) + checkConsistencyBetweenInterpretedAndCodegen(ToDate, DateType) } test("function trunc") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a41185b4d8754..465f7d08aa142 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.types.DataType /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. 
*/ -trait ExpressionEvalHelper { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -211,4 +214,111 @@ trait ExpressionEvalHelper { plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against unary expressions by feeding them arbitrary literals of `dataType`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Expression => Expression, + dataType: DataType): Unit = { + forAll (LiteralGenerator.randomGen(dataType)) { (l: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against binary expressions by feeding them arbitrary literals of `dataType1` + * and `dataType2`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2) + ) { (l1: Literal, l2: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against ternary expressions by feeding them arbitrary literals of `dataType1`, + * `dataType2` and `dataType3`. 
+ */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: (Expression, Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType, + dataType3: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2), + LiteralGenerator.randomGen(dataType3) + ) { (l1: Literal, l2: Literal, l3: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2, l3)) + } + } + + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. + * + * This method test against expressions take Seq[Expression] as input by feeding them + * arbitrary length Seq of arbitrary literal of `dataType`. + */ + def checkConsistencyBetweenInterpretedAndCodegen( + c: Seq[Expression] => Expression, + dataType: DataType, + minNumElements: Int = 0): Unit = { + forAll (Gen.listOf(LiteralGenerator.randomGen(dataType))) { (literals: Seq[Literal]) => + whenever(literals.size >= minNumElements) { + cmpInterpretWithCodegen(EmptyRow, c(literals)) + } + } + } + + private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { + val interpret = try { + evaluate(expr, inputRow) + } catch { + case e: Exception => fail(s"Exception evaluating $expr", e) + } + + val plan = generateProject( + GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil)(), + expr) + val codegen = plan(inputRow).get(0, expr.dataType) + + if (!compareResults(interpret, codegen)) { + fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen") + } + } + + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte] and Spread[Double]. 
+ */ + private[this] def compareResults(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double]) => + expected.isWithin(result) + case (result: Double, expected: Double) if result.isNaN && expected.isNaN => + true + case (result: Float, expected: Float) if result.isNaN && expected.isNaN => + true + case _ => result == expected + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala new file mode 100644 index 0000000000000..ee6d25157fc08 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} + +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.Matchers +import org.scalatest.prop.GeneratorDrivenPropertyChecks + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Property is a high-level specification of behavior that should hold for a range of data points. + * + * For example, while we are evaluating a deterministic expression for some input, we should always + * hold the property that the result never changes, regardless of how we get the result, + * via interpreted or codegen. + * + * In ScalaTest, properties are specified as functions and the data points used to check properties + * can be supplied by either tables or generators. + * + * Generator-driven property checks are performed via integration with ScalaCheck. + * + * @example {{{ + * def toTest(i: Int): Boolean = if (i % 2 == 0) true else false + * + * import org.scalacheck.Gen + * + * test ("true if param is even") { + * val evenInts = for (n <- Gen.choose(-1000, 1000)) yield 2 * n + * forAll(evenInts) { (i: Int) => + * assert (toTest(i) === true) + * } + * } + * }}} + * + */ +object LiteralGenerator { + + lazy val byteLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbByte.arbitrary } yield Literal.create(b, ByteType) + + lazy val shortLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbShort.arbitrary } yield Literal.create(s, ShortType) + + lazy val integerLiteralGen: Gen[Literal] = + for { i <- Arbitrary.arbInt.arbitrary } yield Literal.create(i, IntegerType) + + lazy val longLiteralGen: Gen[Literal] = + for { l <- Arbitrary.arbLong.arbitrary } yield Literal.create(l, LongType) + + lazy val floatLiteralGen: Gen[Literal] = + for { + f <- Gen.chooseNum(Float.MinValue / 2, Float.MaxValue / 2, + Float.NaN, Float.PositiveInfinity, Float.NegativeInfinity) + } yield Literal.create(f, FloatType) + + lazy val doubleLiteralGen: Gen[Literal] 
= + for { + f <- Gen.chooseNum(Double.MinValue / 2, Double.MaxValue / 2, + Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) + } yield Literal.create(f, DoubleType) + + // TODO: decimal type + + lazy val stringLiteralGen: Gen[Literal] = + for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType) + + lazy val binaryLiteralGen: Gen[Literal] = + for { ab <- Gen.listOf[Byte](Arbitrary.arbByte.arbitrary) } + yield Literal.create(ab.toArray, BinaryType) + + lazy val booleanLiteralGen: Gen[Literal] = + for { b <- Arbitrary.arbBool.arbitrary } yield Literal.create(b, BooleanType) + + lazy val dateLiteralGen: Gen[Literal] = + for { d <- Arbitrary.arbInt.arbitrary } yield Literal.create(new Date(d), DateType) + + lazy val timestampLiteralGen: Gen[Literal] = + for { t <- Arbitrary.arbLong.arbitrary } yield Literal.create(new Timestamp(t), TimestampType) + + lazy val calendarIntervalLiterGen: Gen[Literal] = + for { m <- Arbitrary.arbInt.arbitrary; s <- Arbitrary.arbLong.arbitrary} + yield Literal.create(new CalendarInterval(m, s), CalendarIntervalType) + + + // Sometimes, it would be quite expensive when unlimited value is used, + // for example, the `times` arguments for StringRepeat would hang the test 'forever' + // if it's tested against Int.MaxValue by ScalaCheck, therefore, use values from a limited + // range is more reasonable + lazy val limitedIntegerLiteralGen: Gen[Literal] = + for { i <- Gen.choose(-100, 100) } yield Literal.create(i, IntegerType) + + def randomGen(dt: DataType): Gen[Literal] = { + dt match { + case ByteType => byteLiteralGen + case ShortType => shortLiteralGen + case IntegerType => integerLiteralGen + case LongType => longLiteralGen + case DoubleType => doubleLiteralGen + case FloatType => floatLiteralGen + case DateType => dateLiteralGen + case TimestampType => timestampLiteralGen + case BooleanType => booleanLiteralGen + case StringType => stringLiteralGen + case BinaryType => binaryLiteralGen + case 
CalendarIntervalType => calendarIntervalLiterGen + case dt => throw new IllegalArgumentException(s"not supported type $dt") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 033792eee6c0f..90c59f240b542 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ - class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ @@ -184,60 +183,74 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("sin") { testUnary(Sin, math.sin) + checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) } test("asin") { testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) } test("sinh") { testUnary(Sinh, math.sinh) + checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) } test("cos") { testUnary(Cos, math.cos) + checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) } test("acos") { testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) } test("cosh") { testUnary(Cosh, math.cosh) + checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) } test("tan") { testUnary(Tan, math.tan) + checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) } test("atan") { testUnary(Atan, math.atan) + 
checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) } test("tanh") { testUnary(Tanh, math.tanh) + checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) } test("toDegrees") { testUnary(ToDegrees, math.toDegrees) + checkConsistencyBetweenInterpretedAndCodegen(ToDegrees, DoubleType) } test("toRadians") { testUnary(ToRadians, math.toRadians) + checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) } test("cbrt") { testUnary(Cbrt, math.cbrt) + checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) } test("ceil") { testUnary(Ceil, math.ceil) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) } test("floor") { testUnary(Floor, math.floor) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) } test("factorial") { @@ -247,37 +260,45 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) } test("rint") { testUnary(Rint, math.rint) + checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) } test("exp") { testUnary(Exp, math.exp) + checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) } test("expm1") { testUnary(Expm1, math.expm1) + checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) } test("signum") { testUnary[Double, Double](Signum, math.signum) + checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) } test("log") { testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) } test("log10") { testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) +
checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) } test("log1p") { testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) } test("bin") { @@ -298,12 +319,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) + + checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) } test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (1 to 20).map(_ * 0.1)) testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) } test("sqrt") { @@ -313,11 +337,13 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkNaN(Sqrt(Literal(-1.0)), EmptyRow) checkNaN(Sqrt(Literal(-1.5)), EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) } test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) } test("shift left") { @@ -338,6 +364,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt) checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, 
IntegerType) } test("shift right") { @@ -358,6 +387,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt) checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) } test("shift right unsigned") { @@ -386,6 +418,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { negativeLong >>> positiveInt) checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), negativeLong >>> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) } test("hex") { @@ -400,6 +435,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on + Seq(LongType, BinaryType, StringType).foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) + } } test("unhex") { @@ -413,16 +451,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Turn off scala style for non-ascii chars checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) checkEvaluation(Unhex(Literal("三重的")), null) - // scalastyle:on + checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) } test("hypot") { testBinary(Hypot, math.hypot) + checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) } test("atan2") { testBinary(Atan2, math.atan2) + checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, 
DoubleType) } test("binary log") { @@ -454,6 +494,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Logarithm(Literal(1.0), Literal(-1.0)), null, create_row(null)) + checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } test("round") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index b524d0af14a67..75d17417e5a02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -29,6 +29,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType) } test("sha1") { @@ -37,6 +38,7 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { @@ -55,6 +57,6 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 
7beef71845e43..54c04faddb477 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -73,6 +73,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } + checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) + } + + test("AND, OR, EqualTo, EqualNullSafe consistency check") { + checkConsistencyBetweenInterpretedAndCodegen(And, BooleanType, BooleanType) + checkConsistencyBetweenInterpretedAndCodegen(Or, BooleanType, BooleanType) + DataTypeTestUtils.propertyCheckSupported.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(EqualTo, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(EqualNullSafe, dt, dt) + } } booleanLogicTest("AND", And, @@ -180,6 +190,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) + test("BinaryComparison consistency check") { + DataTypeTestUtils.ordered.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(LessThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(LessThanOrEqual, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThan, dt, dt) + checkConsistencyBetweenInterpretedAndCodegen(GreaterThanOrEqual, dt, dt) + } + } + test("BinaryComparison: lessThan") { for (i <- 0 until smallValues.length) { checkEvaluation(LessThan(smallValues(i), largeValues(i)), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 417df006ab7c2..ed2c641d63e25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -46,6 +46,25 @@ object DataTypeTestUtils { */ val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + // TODO: remove this once we find out how to handle decimal properly in property check + val numericTypeWithoutDecimal: Set[DataType] = integralType ++ Set(DoubleType, FloatType) + + /** + * Instances of all non-decimal [[NumericType]]s, plus [[CalendarIntervalType]] + */ + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType + + /** + * The types used for ordering checks: non-decimal numerics, boolean, timestamp, date, string, binary + */ + val ordered: Set[DataType] = + numericTypeWithoutDecimal + BooleanType + TimestampType + DateType + StringType + BinaryType + + /** + * All the types that we can use in a property check + */ + val propertyCheckSupported: Set[DataType] = ordered + /** * Instances of all [[AtomicType]]s. */