Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-29359][SQL][TESTS] Better exception handling in (SQL|ThriftServer)QueryTestSuite #26028

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
private val notIncludedMsg = "[not included in comparison]"
private val clsName = this.getClass.getCanonicalName

protected val emptySchema = StructType(Seq.empty).catalogString

protected override def sparkConf: SparkConf = super.sparkConf
// Fewer shuffle partitions to speed up testing.
.set(SQLConf.SHUFFLE_PARTITIONS, 4)
Expand Down Expand Up @@ -323,11 +325,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
}
// Run the SQL queries preparing them for comparison.
val outputs: Seq[QueryOutput] = queries.map { sql =>
val (schema, output) = getNormalizedResult(localSparkSession, sql)
val (schema, output) = handleExceptions(getNormalizedResult(localSparkSession, sql))
// We might need to do some query canonicalization in the future.
QueryOutput(
sql = sql,
schema = schema.catalogString,
schema = schema,
output = output.mkString("\n").replaceAll("\\s+$", ""))
}

Expand Down Expand Up @@ -388,49 +390,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
}
}

/** Executes a query and returns the result as (schema of the output, normalized output). */
private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = {
// Returns true if the plan is supposed to be sorted.
def isSorted(plan: LogicalPlan): Boolean = plan match {
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
case _: DescribeCommandBase
| _: DescribeColumnCommand
| _: DescribeTableStatement
| _: DescribeColumnStatement => true
case PhysicalOperation(_, _, Sort(_, true, _)) => true
case _ => plan.children.iterator.exists(isSorted)
}

/**
* This method handles exceptions that occurred during query execution, as they may need special
* care to become comparable to the expected output.
*
* @param result a function that returns a pair of schema and output
*/
protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a function description because we override this differently?

  • SQLQueryTestSuite seems to return (struct<>, ...)
  • ThriftServerQueryTestSuite seems to return ("", answer.sorted)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, both return a (String, Seq[String]) tuple where the first element is the schema and the second is the result. Since it's impossible to get the exact Spark schema back from a java.sql.ResultSet, we use an empty string in ThriftServerQueryTestSuite.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a description to it and to its override.

try {
val df = session.sql(sql)
val schema = df.schema
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
}

// If the output is not pre-sorted, sort it.
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)

result
} catch {
case a: AnalysisException =>
// Do not output the logical plan tree which contains expression IDs.
// Also implement a crude way of masking expression IDs in the error message
// with a generic pattern "#x".
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
(StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
(emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test case for which this is required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No particular test case. Since I touched this method and StructType(Seq.empty) was used 3 times, I just moved it to a val.

case s: SparkException if s.getCause != null =>
// For a runtime exception, it is hard to match because its message contains
// information of stage, task ID, etc.
// To make result matching simpler, here we match the cause of the exception if it exists.
val cause = s.getCause
(StructType(Seq.empty), Seq(cause.getClass.getName, cause.getMessage))
(emptySchema, Seq(cause.getClass.getName, cause.getMessage))
case NonFatal(e) =>
// If there is an exception, put the exception class followed by the message.
(StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
(emptySchema, Seq(e.getClass.getName, e.getMessage))
}
}

/**
 * Runs `sql` on the given session and produces a pair of
 * (catalog string of the result schema, normalized output lines).
 *
 * Output lines are passed through `replaceNotIncludedMsg`; they are additionally sorted
 * unless the analyzed plan is considered to already define an ordering.
 */
private def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
  // Tells whether the given logical plan is expected to emit its rows in a defined order.
  def isSorted(plan: LogicalPlan): Boolean = plan match {
    case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct =>
      false
    case _: DescribeCommandBase | _: DescribeColumnCommand =>
      true
    case _: DescribeTableStatement | _: DescribeColumnStatement =>
      true
    case PhysicalOperation(_, _, Sort(_, true, _)) =>
      true
    case other =>
      other.children.iterator.exists(isSorted)
  }

  val dataFrame = session.sql(sql)
  val schemaString = dataFrame.schema.catalogString
  val execution = dataFrame.queryExecution
  // Collect the rendered rows; expression ids such as #1234 are masked by replaceNotIncludedMsg.
  val rendered = SQLExecution.withNewExecutionId(session, execution, Some(sql)) {
    hiveResultString(execution.executedPlan).map(replaceNotIncludedMsg)
  }

  // Only plans without a guaranteed ordering get their output sorted here.
  val normalized = if (isSorted(execution.analyzed)) rendered else rendered.sorted
  (schemaString, normalized)
}

protected def replaceNotIncludedMsg(line: String): String = {
line.replaceAll("#\\d+", "#x")
.replaceAll(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.conf.HiveConf.ConfVars

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite}
import org.apache.spark.sql.SQLQueryTestSuite
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.util.fileToString
import org.apache.spark.sql.execution.HiveResult
Expand Down Expand Up @@ -123,7 +123,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {

// Run the SQL queries preparing them for comparison.
val outputs: Seq[QueryOutput] = queries.map { sql =>
val output = getNormalizedResult(statement, sql)
val (_, output) = handleExceptions(getNormalizedResult(statement, sql))
// We might need to do some query canonicalization in the future.
QueryOutput(
sql = sql,
Expand All @@ -142,8 +142,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
"Try regenerate the result files.")
Seq.tabulate(outputs.size) { i =>
val sql = segments(i * 3 + 1).trim
val schema = segments(i * 3 + 2).trim
val originalOut = segments(i * 3 + 3)
val output = if (isNeedSort(sql)) {
val output = if (schema != emptySchema && isNeedSort(sql)) {
originalOut.split("\n").sorted.mkString("\n")
} else {
originalOut
Expand Down Expand Up @@ -254,32 +255,30 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
}
}

private def getNormalizedResult(statement: Statement, sql: String): Seq[String] = {
try {
val rs = statement.executeQuery(sql)
val cols = rs.getMetaData.getColumnCount
val buildStr = () => (for (i <- 1 to cols) yield {
getHiveResult(rs.getObject(i))
}).mkString("\t")

val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
.map(replaceNotIncludedMsg)
if (isNeedSort(sql)) {
answer.sorted
} else {
answer
/** ThriftServer wraps the root exception, so it needs to be extracted. */
override def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
super.handleExceptions {
try {
result
} catch {
case NonFatal(e) => throw ExceptionUtils.getRootCause(e)
}
} catch {
case a: AnalysisException =>
// Do not output the logical plan tree which contains expression IDs.
// Also implement a crude way of masking expression IDs in the error message
// with a generic pattern "#x".
val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted
case NonFatal(e) =>
val rootCause = ExceptionUtils.getRootCause(e)
// If there is an exception, put the exception class followed by the message.
Seq(rootCause.getClass.getName, rootCause.getMessage)
}
}

/**
 * Executes `sql` through the given JDBC statement and returns the result as
 * (schema, normalized output).
 *
 * The schema component is always the empty string. Every row is rendered by passing each
 * column value through `getHiveResult` and joining the columns with tabs; the rendered row is
 * then passed through `replaceNotIncludedMsg`. Rows are sorted when `isNeedSort(sql)` is true.
 *
 * @param statement JDBC statement used to run the query
 * @param sql the query text to execute
 * @return a pair of the empty-string schema and the fully materialized output rows
 */
private def getNormalizedResult(statement: Statement, sql: String): (String, Seq[String]) = {
  val rs = statement.executeQuery(sql)
  val cols = rs.getMetaData.getColumnCount
  val buildStr = () => (for (i <- 1 to cols) yield {
    getHiveResult(rs.getObject(i))
  }).mkString("\t")

  // Materialize every row before returning. Iterator.toSeq can produce a lazy Stream
  // (Scala 2.12), in which case rows would be fetched from the ResultSet only when the
  // caller consumes the output, and a failure while fetching would be thrown outside of
  // the exception handling wrapped around this call.
  val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toList
    .map(replaceNotIncludedMsg)
  if (isNeedSort(sql)) {
    ("", answer.sorted)
  } else {
    ("", answer)
  }
}

Expand Down