Commit 6beb467
Remove a bunch of overloaded methods to avoid default args. issue
JoshRosen committed Jul 10, 2015
1 parent 2bbac9c commit 6beb467
Showing 4 changed files with 25 additions and 79 deletions.
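
Background on the commit message: Scala allows at most one alternative of an overloaded method to define default arguments, and the checkAnswer overloads below each defaulted sortAnswers to true, which runs into that restriction. Renaming the overloads sidesteps it. A minimal illustration (hypothetical names; this snippet intentionally fails to compile):

    object Example {
      def check(input: Int, sort: Boolean = true): Unit = ()
      def check(input: String, sort: Boolean = true): Unit = ()
      // error: in object Example, multiple overloaded alternatives of
      // method check define default arguments
    }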
File: SortSuite.scala (org.apache.spark.sql.execution)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution

+ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._

class SortSuite extends SparkPlanTest {
@@ -34,13 +35,13 @@ class SortSuite extends SparkPlanTest {
checkAnswer(
input.toDF("a", "b", "c"),
ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
- input.sortBy(t => (t._1, t._2)),
+ input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
sortAnswers = false)

checkAnswer(
input.toDF("a", "b", "c"),
ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
- input.sortBy(t => (t._2, t._1)),
+ input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
sortAnswers = false)
}
}
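
The expected answer is now passed as a Seq[Row] rather than a Seq of tuples, so test data is converted explicitly with Row.fromTuple, the Row companion-object helper imported above. A quick sketch of the conversion:

    import org.apache.spark.sql.Row

    // Row.fromTuple copies a tuple's fields into a Row:
    val expected: Seq[Row] = Seq((1, "a", 1.0), (2, "b", 2.0)).map(Row.fromTuple)
    // expected == Seq(Row(1, "a", 1.0), Row(2, "b", 2.0))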
File: SparkPlanTest.scala (org.apache.spark.sql.execution)
@@ -54,7 +54,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: SparkPlan => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- checkAnswer(
+ doCheckAnswer(
input :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans.head),
expectedAnswer,
@@ -71,13 +71,13 @@
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
- protected def checkAnswer(
+ protected def checkAnswer2(
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- checkAnswer(
+ doCheckAnswer(
left :: right :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
expectedAnswer,
@@ -87,13 +87,13 @@
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
+ * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to
+ * instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
- protected def checkAnswer(
+ protected def doCheckAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
@@ -104,62 +104,6 @@
}
}

- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * @param input the input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
- */
- protected def checkAnswer[A <: Product : TypeTag](
- input: DataFrame,
- planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[A],
- sortAnswers: Boolean = true): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows, sortAnswers)
- }
-
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * @param left the left input data to be used.
- * @param right the right input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
- */
- protected def checkAnswer[A <: Product : TypeTag](
- left: DataFrame,
- right: DataFrame,
- planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[A],
- sortAnswers: Boolean = true): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(left, right, planFunction, expectedRows, sortAnswers)
- }
-
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * @param input the input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
- */
- protected def checkAnswer[A <: Product : TypeTag](
- input: Seq[DataFrame],
- planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[A],
- sortAnswers: Boolean = true): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows, sortAnswers)
- }
-
/**
* Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
@@ -172,7 +116,7 @@ class SparkPlanTest extends SparkFunSuite {
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
- protected def checkAnswer(
+ protected def checkThatPlansAgree(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
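
With the overloads gone, each call site names the variant it wants directly: checkAnswer for a single input DataFrame, checkAnswer2 for two, doCheckAnswer for a Seq of inputs, and checkThatPlansAgree for comparison against a reference plan. A hedged sketch of a caller (suite name and test data are illustrative; fixture and implicit imports are elided):

    class ExampleSuite extends SparkPlanTest {
      test("single-input operator") {
        val input = (1 to 5).map(Tuple1(_)).toDF("a")
        // One DataFrame in, an explicit Seq[Row] expected out:
        checkAnswer(
          input,
          (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
          (1 to 5).map(i => Row(i)),
          sortAnswers = false)
      }
    }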
File: UnsafeExternalSortSuite.scala (org.apache.spark.sql.execution)
@@ -40,7 +40,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
// TODO: this test is going to fail until we implement a proper iterator interface
// with a close() method.
TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
- checkAnswer(
+ checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
(child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
(child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
@@ -51,7 +51,7 @@
test("sort followed by limit") {
TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
try {
- checkAnswer(
+ checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
(child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
(child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
@@ -66,7 +66,7 @@
test("sorting does not crash for large inputs") {
val sortOrder = 'a.asc :: Nil
val stringLength = 1024 * 1024 * 2
- checkAnswer(
+ checkThatPlansAgree(
Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
Sort(sortOrder, global = true, _: SparkPlan),
@@ -93,7 +93,7 @@
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
- checkAnswer(
+ checkThatPlansAgree(
inputDf,
UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23),
Sort(sortOrder, global = true, _: SparkPlan),
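
checkThatPlansAgree needs no literal expected answer: it runs the plan under test and a trusted reference plan on the same input and asserts that their outputs match. A simplified sketch of that comparison, based on the documented sortAnswers behavior above (not the actual implementation):

    import org.apache.spark.sql.Row

    // Compare two result sets; when sortAnswers is true, row order is ignored
    // by sorting on toString representations before comparing.
    def plansAgree(actual: Seq[Row], reference: Seq[Row], sortAnswers: Boolean): Boolean = {
      def normalize(rows: Seq[Row]): Seq[String] = {
        val strings = rows.map(_.toString)
        if (sortAnswers) strings.sorted else strings
      }
      normalize(actual) == normalize(reference)
    }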
File: OuterJoinSuite.scala (org.apache.spark.sql.execution.joins)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.joins

+ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
@@ -41,48 +42,48 @@ class OuterJoinSuite extends SparkPlanTest {
val condition = Some(LessThan('b, 'd))

test("shuffled hash outer join") {
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
- ))
+ ).map(Row.fromTuple))

- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))

- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
}

test("broadcast hash outer join") {
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
- ))
+ ).map(Row.fromTuple))

- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
}
}
