Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 216 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ import org.scalatest.concurrent.Eventually
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.DataSourceUtils
Expand Down Expand Up @@ -535,6 +535,95 @@ trait QueryTestBase
.map(_.toFile.length).sum
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
doCheckAnswer(
input :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans.head),
expectedAnswer,
sortAnswers)
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param left the left input data to be used.
* @param right the right input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer2(
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
doCheckAnswer(
left :: right :: Nil,
(plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
expectedAnswer,
sortAnswers)
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param planFunction a function which accepts a sequence of input SparkPlans and uses them to
* instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def doCheckAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
QueryTest
.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}

/**
* Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
* instantiate a reference implementation of the physical operator
* that's being tested. The result of executing this plan will be
* treated as the source-of-truth for the test.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkThatPlansAgree(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
QueryTest.checkAnswer(
input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}

}

/**
Expand Down Expand Up @@ -974,6 +1063,130 @@ object QueryTest extends Assertions {
capturedQueryExecutions
}

/**
* Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
* instantiate a reference implementation of the physical operator
* that's being tested. The result of executing this plan will be
* treated as the source-of-truth for the test.
*/
def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
spark: SQLContext): Option[String] = {

val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)

val expectedAnswer: Seq[Row] = try {
executePlan(expectedOutputPlan, spark)
} catch {
case NonFatal(e) =>
val errorMessage =
s"""
| Exception thrown while executing Spark plan to calculate expected answer:
| $expectedOutputPlan
| == Exception ==
| $e
| ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin
return Some(errorMessage)
}

val actualAnswer: Seq[Row] = try {
executePlan(outputPlan, spark)
} catch {
case NonFatal(e) =>
val errorMessage =
s"""
| Exception thrown while executing Spark plan:
| $outputPlan
| == Exception ==
| $e
| ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin
return Some(errorMessage)
}

compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match.
| Actual result Spark plan:
| $outputPlan
| Expected result Spark plan:
| $expectedOutputPlan
| $errorMessage
""".stripMargin
}
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
def checkAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
spark: SQLContext): Option[String] = {

val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))

val sparkAnswer: Seq[Row] = try {
executePlan(outputPlan, spark)
} catch {
case NonFatal(e) =>
val errorMessage =
s"""
| Exception thrown while executing Spark plan:
| $outputPlan
| == Exception ==
| $e
| ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin
return Some(errorMessage)
}

compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match for Spark plan:
| $outputPlan
| $errorMessage
""".stripMargin
}
}

/**
* Runs the plan
* @param outputPlan SparkPlan to be executed
* @param spark SqlContext used for execution of the plan
*/
def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = {
val execution = new QueryExecution(spark.sparkSession, LocalRelation(Nil)) {
override lazy val sparkPlan: SparkPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
}
execution.executedPlan.executeCollectPublic().toSeq
}

}

class QueryTestSuite extends QueryTest with test.SharedSparkSession {
Expand Down
Loading