diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index e72736d9e2691..ece9dafbb1270 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.BoundReference
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
@@ -145,12 +144,15 @@ class SparkPlanTest extends SparkFunSuite {
    *                             instantiate a reference implementation of the physical operator
    *                             that's being tested. The result of executing this plan will be
    *                             treated as the source-of-truth for the test.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer(
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedPlanFunction: SparkPlan => SparkPlan): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction) match {
+      expectedPlanFunction: SparkPlan => SparkPlan,
+      sortAnswers: Boolean): Unit = {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -175,7 +177,8 @@ object SparkPlanTest {
   def checkAnswer(
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedPlanFunction: SparkPlan => SparkPlan): Option[String] = {
+      expectedPlanFunction: SparkPlan => SparkPlan,
+      sortAnswers: Boolean): Option[String] = {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
     val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
@@ -210,7 +213,7 @@ object SparkPlanTest {
         return Some(errorMessage)
     }
 
-    compareAnswers(actualAnswer, expectedAnswer).map { errorMessage =>
+    compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match.
          | Actual result Spark plan:
@@ -262,7 +265,8 @@ object SparkPlanTest {
 
   private def compareAnswers(
       sparkAnswer: Seq[Row],
-      expectedAnswer: Seq[Row]): Option[String] = {
+      expectedAnswer: Seq[Row],
+      sort: Boolean = true): Option[String] = {
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -277,7 +281,11 @@ object SparkPlanTest {
           case o => o
         })
       }
-      converted.sortBy(_.toString())
+      if (sort) {
+        converted.sortBy(_.toString())
+      } else {
+        converted
+      }
     }
     if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
       val errorMessage =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index f92aab904f754..f4b8782e39b03 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -38,29 +38,29 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
   // Test sorting on different data types
   // TODO: randomized spilling to ensure that merging is tested at least once for every data type.
-  (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
-    for (
-      nullable <- Seq(true, false);
-      sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
-      randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
-    ) {
-      test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
-        val inputData = Seq.fill(1024)(randomDataGenerator()).filter {
-          case d: Double => !d.isNaN
-          case f: Float => !java.lang.Float.isNaN(f)
-          case x => true
-        }
-        val inputDf = TestSQLContext.createDataFrame(
-          TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
-          StructType(StructField("a", dataType, nullable = true) :: Nil)
-        )
-        assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
-        checkAnswer(
-          inputDf,
-          UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 100),
-          Sort(sortOrder, global = false, _: SparkPlan)
-        )
+  for (
+    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
+    nullable <- Seq(true, false);
+    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+  ) {
+    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+      val inputData = Seq.fill(3)(randomDataGenerator()).filter {
+        case d: Double => !d.isNaN
+        case f: Float => !java.lang.Float.isNaN(f)
+        case x => true
       }
+      val inputDf = TestSQLContext.createDataFrame(
+        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+        StructType(StructField("a", dataType, nullable = true) :: Nil)
+      )
+      assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
+      checkAnswer(
+        inputDf,
+        UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 2),
+        Sort(sortOrder, global = false, _: SparkPlan),
+        sortAnswers = false
+      )
     }
   }
 }
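Note on the new flag: sortAnswers controls whether SparkPlanTest.compareAnswers sorts both result sets by Row.toString before comparing, so order-insensitive operators can still be checked against a reference plan while sort operators are compared positionally. Below is a minimal usage sketch; the suite name, the Filter-based test case, and the testSpillFrequency = 0 value are hypothetical illustrations, not part of this patch, and it assumes the Spark 1.4-era fixtures (TestSQLContext, the catalyst expression DSL) that the diff itself uses.

// Hypothetical sketch, not part of this patch.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.{Filter, Sort, SparkPlan, SparkPlanTest, UnsafeExternalSort}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class SortAnswersUsageSuite extends SparkPlanTest {

  private val input = TestSQLContext.sparkContext
    .parallelize(Seq(3, 1, 2).map(Tuple1(_)))
    .toDF("a")

  test("row order ignored for an order-insensitive operator") {
    // Filter promises nothing about output order, so both answers are
    // sorted by Row.toString before being compared.
    checkAnswer(
      input,
      Filter('a > 1, _: SparkPlan),
      Filter('a > 1, _: SparkPlan),
      sortAnswers = true)
  }

  test("row order verified for a sort operator") {
    // Ordering is the property under test, so answers are compared as-is;
    // a correct but differently ordered result now fails.
    checkAnswer(
      input,
      UnsafeExternalSort('a.asc :: Nil, global = false, _: SparkPlan, testSpillFrequency = 0),
      Sort('a.asc :: Nil, global = false, _: SparkPlan),
      sortAnswers = false)
  }
}

Leaving sortAnswers without a default on the test-facing checkAnswer (while the private compareAnswers keeps sort: Boolean = true) forces every call site to state whether row order matters, which is why UnsafeExternalSortSuite now passes sortAnswers = false explicitly.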