[SPARK-10176][SQL] Show partially analyzed plans when checkAnswer fails to analyze #8389

Closed
wants to merge 8 commits
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util._
* Provides helper methods for comparing plans.
*/
class PlanTest extends SparkFunSuite {
-
/**
* Since attribute references are given globally unique ids during analysis,
* we must normalize them to check if two different queries are identical.
26 changes: 21 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation

-class QueryTest extends PlanTest {
+abstract class QueryTest extends PlanTest {

+protected def sqlContext: SQLContext

// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
@@ -56,18 +58,32 @@ class QueryTest extends PlanTest {
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
-protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+val analyzedDF = try df catch {
+case ae: AnalysisException =>
+sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
+val partiallyAnalyzedPlan = df.queryExecution.analyzed
+sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
+fail(
+s"""
+|Failed to analyze query: $ae
+|$partiallyAnalyzedPlan
+|
+|${stackTraceToString(ae)}
+|""".stripMargin)
+}

QueryTest.checkAnswer(df, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}

-protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(df, Seq(expectedAnswer))
}

-protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
checkAnswer(df, expectedAnswer.collect())
}
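A minimal sketch of what the new by-name `df` parameter means for a caller; the suite, table, and column names below are hypothetical and not part of this patch. Because the DataFrame argument is only evaluated inside checkAnswer, an AnalysisException raised by eager analysis is caught there: eager analysis is temporarily disabled, the partially analyzed plan is captured, and the test fails with that plan and the stack trace instead of erroring out before checkAnswer even runs.

    class ExampleQuerySuite extends QueryTest with SharedSQLContext {
      import testImplicits._

      test("failure report includes the partially analyzed plan") {
        val df = Seq((1, "a")).toDF("id", "name")
        // `missing` does not resolve, so analysis throws an AnalysisException.
        // The by-name parameter defers that failure into checkAnswer, which
        // then reports the partially analyzed plan.
        checkAnswer(df.select($"missing"), Row("unused"))
      }
    }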

@@ -96,7 +112,7 @@ object QueryTest {
* @param df the [[DataFrame]] to be executed
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
-def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
+def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -268,9 +268,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext
Row(3, 4))

intercept[AnalysisException] {
-checkAnswer(
-df.selectExpr("length(c)"), // int type of the argument is unacceptable
-Row("5.0000"))
+df.selectExpr("length(c)") // int type of the argument is unacceptable
}
}

@@ -284,63 +282,46 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}

test("number format function") {
-val tuple =
-("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
-3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
-val df =
-Seq(tuple)
-.toDF(
-"a", // string "aa"
-"b", // byte 1
-"c", // short 2
-"d", // float 3.13223f
-"e", // integer 4
-"f", // long 5L
-"g", // double 6.48173d
-"h") // decimal 7.128381

-checkAnswer(
-df.select(format_number($"f", 4)),
+val df = sqlContext.range(1)

+checkAnswer(
+df.select(format_number(lit(5L), 4)),
Row("5.0000"))

checkAnswer(
df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer
Row("1.0000"))

checkAnswer(
df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer
Row("2.0000"))

checkAnswer(
df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double
Row("3.1322"))

checkAnswer(
df.selectExpr("format_number(e, e)"), // not convert anything
df.select(format_number(lit(4), 4)), // not convert anything
Row("4.0000"))

checkAnswer(
df.selectExpr("format_number(f, e)"), // not convert anything
df.select(format_number(lit(5L), 4)), // not convert anything
Row("5.0000"))

checkAnswer(
df.selectExpr("format_number(g, e)"), // not convert anything
df.select(format_number(lit(6.48173), 4)), // not convert anything
Row("6.4817"))

checkAnswer(
df.selectExpr("format_number(h, e)"), // not convert anything
df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything
Row("7.1284"))

intercept[AnalysisException] {
-checkAnswer(
-df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
-Row("5.0000"))
+df.select(format_number(lit("aaa"), 4)) // string type of the 1st argument is unacceptable
}

intercept[AnalysisException] {
-checkAnswer(
-df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
-Row("5.0000"))
+df.selectExpr("format_number(4, 6.48173)") // non-integral type 2nd argument is unacceptable
}
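For reference, a sketch of the intercept pattern the rewritten assertions above rely on; the suite name and the string literal are illustrative. With eager DataFrame analysis enabled, an invalid argument type fails while the projection is being built, so intercept only needs to wrap the select call rather than a full checkAnswer.

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.functions.{format_number, lit}

    class FormatNumberNegativeSuite extends QueryTest with SharedSQLContext {
      test("format_number rejects a string first argument") {
        intercept[AnalysisException] {
          // Throws during analysis, before any result is checked.
          sqlContext.range(1).select(format_number(lit("not a number"), 4))
        }
      }
    }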

// for testing the mutable state of the expression in code gen.
@@ -31,13 +31,13 @@ import org.apache.spark.sql.catalyst.util._
* class's test helper methods can be used, see [[SortSuite]].
*/
private[sql] abstract class SparkPlanTest extends SparkFunSuite {
-protected def _sqlContext: SQLContext
+protected def sqlContext: SQLContext

/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
-_sqlContext.implicits.localSeqToDataFrameHolder(data)
+sqlContext.implicits.localSeqToDataFrameHolder(data)
}

/**
@@ -98,7 +98,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
-SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
+SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -122,7 +122,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
-input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
+input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -149,13 +149,13 @@ object SparkPlanTest {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
-_sqlContext: SQLContext): Option[String] = {
+sqlContext: SQLContext): Option[String] = {

val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)

val expectedAnswer: Seq[Row] = try {
-executePlan(expectedOutputPlan, _sqlContext)
+executePlan(expectedOutputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -170,7 +170,7 @@ object SparkPlanTest {
}

val actualAnswer: Seq[Row] = try {
-executePlan(outputPlan, _sqlContext)
+executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -210,12 +210,12 @@ object SparkPlanTest {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
-_sqlContext: SQLContext): Option[String] = {
+sqlContext: SQLContext): Option[String] = {

val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))

val sparkAnswer: Seq[Row] = try {
-executePlan(outputPlan, _sqlContext)
+executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -278,10 +278,10 @@ object SparkPlanTest {
}
}

-private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
+private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
-val resolvedPlan = _sqlContext.prepareForExecution.execute(
+val resolvedPlan = sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
@@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

private[json] trait TestJsonData {
-protected def _sqlContext: SQLContext
+protected def sqlContext: SQLContext

def primitiveFieldAndType: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@@ -35,7 +35,7 @@
}""" :: Nil)

def primitiveFieldValueTypeConflict: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -46,14 +46,14 @@
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)

def jsonNullStruct: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)

def complexFieldValueTypeConflict: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@@ -64,22 +64,22 @@
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)

def arrayElementTypeConflict: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)

def missingFields: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
"""{"d":{"field":true}}""" ::
"""{"e":"str"}""" :: Nil)

def complexFieldAndType1: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@@ -95,7 +95,7 @@
}""" :: Nil)

def complexFieldAndType2: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@@ -149,15 +149,15 @@
}""" :: Nil)

def mapType1: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
"""{"map": {"c": 1, "d": 4}}""" ::
"""{"map": {"e": null}}""" :: Nil)

def mapType2: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -166,21 +166,21 @@
"""{"map": {"f": {"field1": null}}}""" :: Nil)

def nullsInArrays: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)

def jsonArray: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)

def corruptRecords: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@@ -189,7 +189,7 @@
"""]""" :: Nil)

def emptyRecords: RDD[String] =
-_sqlContext.sparkContext.parallelize(
+sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
@@ -198,7 +198,7 @@
"""]""" :: Nil)


-lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
+lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)

-def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
+def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]())
}
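A hedged sketch of why the rename matters to consumers of this trait (the suite name and assertion are illustrative): SharedSQLContext already exposes a member named sqlContext, so a suite mixing in TestJsonData now satisfies the abstract member automatically instead of wiring up a separate _sqlContext override.

    // Lives in the same json test package as TestJsonData, since the trait is private[json].
    class JsonSmokeSuite extends QueryTest with SharedSQLContext with TestJsonData {
      test("primitive fields parse") {
        val df = sqlContext.read.json(primitiveFieldAndType)
        assert(df.schema.fieldNames.contains("string"))
      }
    }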
@@ -39,12 +39,14 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq

protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
val fsPath = new Path(path)
-val fs = fsPath.getFileSystem(configuration)
+val fs = fsPath.getFileSystem(hadoopConfiguration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
override def accept(path: Path): Boolean = pathFilter(path)
}).toSeq.asJava

-val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)

+val footers =
+ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true)
footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
}
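A small usage sketch for this helper after the switch to hadoopConfiguration; the output path is a placeholder and the filter merely skips Parquet summary files. It would sit inside a suite extending ParquetCompatibilityTest.

    // Illustrative call only: both the file system lookup and the footer read
    // now go through hadoopConfiguration instead of the old configuration field.
    val schema = readParquetSchema(
      "/tmp/parquet-compat-data",
      path => !path.getName.startsWith("_"))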
