From 1c7bad8c538be233a6308a970861a2ff3fe002c7 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 6 Jul 2015 23:44:33 -0700
Subject: [PATCH] Make sorting of answers explicit in SparkPlanTest.checkAnswer().

---
 .../spark/sql/execution/SortSuite.scala       | 10 ++--
 .../spark/sql/execution/SparkPlanTest.scala   | 60 ++++++++++++++-----
 .../execution/UnsafeExternalSortSuite.scala   |  2 +-
 .../sql/execution/joins/OuterJoinSuite.scala  | 15 +++--
 4 files changed, 61 insertions(+), 26 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a1e3ca11b1ad9..be59c502e8c64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -33,12 +33,14 @@ class SortSuite extends SparkPlanTest {
 
     checkAnswer(
       input.toDF("a", "b", "c"),
-      ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
-      input.sorted)
+      ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
+      input.sortBy(t => (t._1, t._2)),
+      sortAnswers = false)
 
     checkAnswer(
       input.toDF("a", "b", "c"),
-      ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
-      input.sortBy(t => (t._2, t._1)))
+      ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
+      input.sortBy(t => (t._2, t._1)),
+      sortAnswers = false)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ece9dafbb1270..831b0f9109ab1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -46,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite {
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer(
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedAnswer: Seq[Row]): Unit = {
-    checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
+      expectedAnswer: Seq[Row],
+      sortAnswers: Boolean): Unit = {
+    checkAnswer(
+      input :: Nil,
+      (plans: Seq[SparkPlan]) => planFunction(plans.head),
+      expectedAnswer,
+      sortAnswers)
   }
 
   /**
@@ -61,14 +68,20 @@ class SparkPlanTest extends SparkFunSuite {
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer(
       left: DataFrame,
       right: DataFrame,
       planFunction: (SparkPlan, SparkPlan) => SparkPlan,
-      expectedAnswer: Seq[Row]): Unit = {
-    checkAnswer(left :: right :: Nil,
-      (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
+      expectedAnswer: Seq[Row],
+      sortAnswers: Boolean): Unit = {
+    checkAnswer(
+      left :: right :: Nil,
+      (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
+      expectedAnswer,
+      sortAnswers)
   }
 
   /**
@@ -77,12 +90,15 @@ class SparkPlanTest extends SparkFunSuite {
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer(
       input: Seq[DataFrame],
       planFunction: Seq[SparkPlan] => SparkPlan,
-      expectedAnswer: Seq[Row]): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
+      expectedAnswer: Seq[Row],
+      sortAnswers: Boolean): Unit = {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
   }
 
   /**
@@ -94,13 +110,16 @@
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer[A <: Product : TypeTag](
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedAnswer: Seq[A]): Unit = {
+      expectedAnswer: Seq[A],
+      sortAnswers: Boolean): Unit = {
     val expectedRows = expectedAnswer.map(Row.fromTuple)
-    checkAnswer(input, planFunction, expectedRows)
+    checkAnswer(input, planFunction, expectedRows, sortAnswers)
   }
 
   /**
@@ -110,14 +129,17 @@
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer[A <: Product : TypeTag](
       left: DataFrame,
       right: DataFrame,
       planFunction: (SparkPlan, SparkPlan) => SparkPlan,
-      expectedAnswer: Seq[A]): Unit = {
+      expectedAnswer: Seq[A],
+      sortAnswers: Boolean): Unit = {
     val expectedRows = expectedAnswer.map(Row.fromTuple)
-    checkAnswer(left, right, planFunction, expectedRows)
+    checkAnswer(left, right, planFunction, expectedRows, sortAnswers)
   }
 
   /**
@@ -126,13 +148,16 @@
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer[A <: Product : TypeTag](
       input: Seq[DataFrame],
       planFunction: Seq[SparkPlan] => SparkPlan,
-      expectedAnswer: Seq[A]): Unit = {
+      expectedAnswer: Seq[A],
+      sortAnswers: Boolean): Unit = {
     val expectedRows = expectedAnswer.map(Row.fromTuple)
-    checkAnswer(input, planFunction, expectedRows)
+    checkAnswer(input, planFunction, expectedRows, sortAnswers)
   }
 
   /**
@@ -231,11 +256,14 @@ object SparkPlanTest {
    * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
    *                     the physical operator that's being tested.
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   def checkAnswer(
       input: Seq[DataFrame],
       planFunction: Seq[SparkPlan] => SparkPlan,
-      expectedAnswer: Seq[Row]): Option[String] = {
+      expectedAnswer: Seq[Row],
+      sortAnswers: Boolean): Option[String] = {
 
     val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
 
@@ -254,7 +282,7 @@
       return Some(errorMessage)
     }
 
-    compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage =>
+    compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match for Spark plan:
          | $outputPlan
@@ -266,7 +294,7 @@
   private def compareAnswers(
       sparkAnswer: Seq[Row],
       expectedAnswer: Seq[Row],
-      sort: Boolean = true): Option[String] = {
+      sort: Boolean): Option[String] = {
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index c5ab74ad72266..c697c319980dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -40,7 +40,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
   // TODO: randomized spilling to ensure that merging is tested at least once for every data type.
   for (
     dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
-    nullable <- Seq(false);
+    nullable <- Seq(true, false);
     sortOrder <- Seq('a.asc :: Nil);
     randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
   ) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 5707d2fb300ae..f498f8c063e5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -47,7 +47,8 @@ class OuterJoinSuite extends SparkPlanTest {
       (1, 2.0, null, null),
       (2, 1.0, 2, 3.0),
       (3, 3.0, null, null)
-    ))
+    ),
+      sortAnswers = true)
 
     checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
       ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
@@ -55,7 +56,8 @@
       (2, 1.0, 2, 3.0),
       (null, null, 3, 2.0),
       (null, null, 4, 1.0)
-    ))
+    ),
+      sortAnswers = true)
 
     checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
       ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
@@ -65,7 +67,8 @@
       (1, 2.0, null, null),
       (2, 1.0, 2, 3.0),
       (3, 3.0, null, null),
       (null, null, 3, 2.0),
       (null, null, 4, 1.0)
-    ))
+    ),
+      sortAnswers = true)
   }
 
   test("broadcast hash outer join") {
     checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
       BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
@@ -75,7 +78,8 @@
       (1, 2.0, null, null),
       (2, 1.0, 2, 3.0),
       (3, 3.0, null, null)
-    ))
+    ),
+      sortAnswers = true)
 
     checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
       BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
@@ -83,6 +87,7 @@
       (2, 1.0, 2, 3.0),
       (null, null, 3, 2.0),
       (null, null, 4, 1.0)
-    ))
+    ),
+      sortAnswers = true)
   }
 }
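Note for readers (not part of the patch): the sketch below illustrates how the new sortAnswers flag reads at a call site, assuming the same imports and helpers that SortSuite and OuterJoinSuite above already use. Operators whose output order is incidental, such as the hash outer joins, pass sortAnswers = true so rows are compared irrespective of order; operators whose output order is the property under test, such as ExternalSort, pass sortAnswers = false so rows are compared positionally.

// Illustrative sketch only. The suite name and test data are hypothetical; imports are
// assumed to match SortSuite (catalyst DSL expressions and the test SQLContext
// implicits that provide toDF on local Seqs).
class SortAnswersExampleSuite extends SparkPlanTest {

  test("order-sensitive result comparison") {
    val input = Seq((1, "a"), (3, "c"), (2, "b"))

    // ExternalSort's output order is exactly what is being verified, so the rows are
    // compared in order (sortAnswers = false), as in SortSuite above.
    checkAnswer(
      input.toDF("num", "str"),
      ExternalSort('num.asc :: Nil, global = true, _: SparkPlan),
      input.sortBy(_._1),
      sortAnswers = false)
  }
}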