Skip to content

Commit

Permalink
Make sorting of answers explicit in SparkPlanTest.checkAnswer().
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 7, 2015
1 parent b81a920 commit 1c7bad8
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class SortSuite extends SparkPlanTest {

checkAnswer(
input.toDF("a", "b", "c"),
ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
input.sorted)
ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
input.sortBy(t => (t._1, t._2)),
sortAnswers = false)

checkAnswer(
input.toDF("a", "b", "c"),
ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
input.sortBy(t => (t._2, t._1)))
ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
input.sortBy(t => (t._2, t._1)),
sortAnswers = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedAnswer: Seq[Row]): Unit = {
checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
expectedAnswer: Seq[Row],
sortAnswers: Boolean): Unit = {
checkAnswer(
input :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans.head),
expectedAnswer,
sortAnswers)
}

/**
Expand All @@ -61,14 +68,20 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the left and right input SparkPlans and uses them
* to instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer(
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
expectedAnswer: Seq[Row]): Unit = {
checkAnswer(left :: right :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
expectedAnswer: Seq[Row],
sortAnswers: Boolean): Unit = {
checkAnswer(
left :: right :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
expectedAnswer,
sortAnswers)
}

/**
Expand All @@ -77,12 +90,15 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlans and uses them to
* instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row]): Unit = {
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
expectedAnswer: Seq[Row],
sortAnswers: Boolean): Unit = {
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
Expand All @@ -94,13 +110,16 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer[A <: Product : TypeTag](
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedAnswer: Seq[A]): Unit = {
expectedAnswer: Seq[A],
sortAnswers: Boolean): Unit = {
val expectedRows = expectedAnswer.map(Row.fromTuple)
checkAnswer(input, planFunction, expectedRows)
checkAnswer(input, planFunction, expectedRows, sortAnswers)
}

/**
Expand All @@ -110,14 +129,17 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the left and right input SparkPlans and uses them
* to instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer[A <: Product : TypeTag](
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
expectedAnswer: Seq[A]): Unit = {
expectedAnswer: Seq[A],
sortAnswers: Boolean): Unit = {
val expectedRows = expectedAnswer.map(Row.fromTuple)
checkAnswer(left, right, planFunction, expectedRows)
checkAnswer(left, right, planFunction, expectedRows, sortAnswers)
}

/**
Expand All @@ -126,13 +148,16 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlans and uses them to
* instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer[A <: Product : TypeTag](
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[A]): Unit = {
expectedAnswer: Seq[A],
sortAnswers: Boolean): Unit = {
val expectedRows = expectedAnswer.map(Row.fromTuple)
checkAnswer(input, planFunction, expectedRows)
checkAnswer(input, planFunction, expectedRows, sortAnswers)
}

/**
Expand Down Expand Up @@ -231,11 +256,14 @@ object SparkPlanTest {
* @param planFunction a function which accepts the input SparkPlans and uses them to
* instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
def checkAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row]): Option[String] = {
expectedAnswer: Seq[Row],
sortAnswers: Boolean): Option[String] = {

val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))

Expand All @@ -254,7 +282,7 @@ object SparkPlanTest {
return Some(errorMessage)
}

compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage =>
compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match for Spark plan:
| $outputPlan
Expand All @@ -266,7 +294,7 @@ object SparkPlanTest {
private def compareAnswers(
sparkAnswer: Seq[Row],
expectedAnswer: Seq[Row],
sort: Boolean = true): Option[String] = {
sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
// TODO: randomized spilling to ensure that merging is tested at least once for every data type.
for (
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
nullable <- Seq(false);
nullable <- Seq(true, false);
sortOrder <- Seq('a.asc :: Nil);
randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ class OuterJoinSuite extends SparkPlanTest {
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
))
),
sortAnswers = true)

checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
))
),
sortAnswers = true)

checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
Expand All @@ -65,7 +67,8 @@ class OuterJoinSuite extends SparkPlanTest {
(3, 3.0, null, null),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
))
),
sortAnswers = true)
}

test("broadcast hash outer join") {
Expand All @@ -75,14 +78,16 @@ class OuterJoinSuite extends SparkPlanTest {
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
))
),
sortAnswers = true)

checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
))
),
sortAnswers = true)
}
}

0 comments on commit 1c7bad8

Please sign in to comment.