diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index e4b27c69bbc45..c74fa2da42afa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -135,6 +135,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
   private val notIncludedMsg = "[not included in comparison]"
   private val clsName = this.getClass.getCanonicalName
 
+  protected val emptySchema = StructType(Seq.empty).catalogString
+
   protected override def sparkConf: SparkConf = super.sparkConf
     // Fewer shuffle partitions to speed up testing.
     .set(SQLConf.SHUFFLE_PARTITIONS, 4)
@@ -323,11 +325,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
     }
     // Run the SQL queries preparing them for comparison.
     val outputs: Seq[QueryOutput] = queries.map { sql =>
-      val (schema, output) = getNormalizedResult(localSparkSession, sql)
+      val (schema, output) = handleExceptions(getNormalizedResult(localSparkSession, sql))
       // We might need to do some query canonicalization in the future.
       QueryOutput(
         sql = sql,
-        schema = schema.catalogString,
+        schema = schema,
         output = output.mkString("\n").replaceAll("\\s+$", ""))
     }
 
@@ -388,49 +390,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
     }
   }
 
-  /** Executes a query and returns the result as (schema of the output, normalized output). */
-  private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = {
-    // Returns true if the plan is supposed to be sorted.
-    def isSorted(plan: LogicalPlan): Boolean = plan match {
-      case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
-      case _: DescribeCommandBase
-          | _: DescribeColumnCommand
-          | _: DescribeTableStatement
-          | _: DescribeColumnStatement => true
-      case PhysicalOperation(_, _, Sort(_, true, _)) => true
-      case _ => plan.children.iterator.exists(isSorted)
-    }
-
+  /**
+   * This method handles exceptions that occur during query execution, as they may need special
+   * care to become comparable to the expected output.
+   *
+   * @param result a function that returns a pair of schema and output
+   */
+  protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
     try {
-      val df = session.sql(sql)
-      val schema = df.schema
-      // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
-      val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
-        hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
-      }
-
-      // If the output is not pre-sorted, sort it.
-      if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
-
+      result
     } catch {
       case a: AnalysisException =>
         // Do not output the logical plan tree which contains expression IDs.
         // Also implement a crude way of masking expression IDs in the error message
         // with a generic pattern "###".
         val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
-        (StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
+        (emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
       case s: SparkException if s.getCause != null =>
         // For a runtime exception, it is hard to match because its message contains
         // information of stage, task ID, etc.
         // To make result matching simpler, here we match the cause of the exception if it exists.
        val cause = s.getCause
-        (StructType(Seq.empty), Seq(cause.getClass.getName, cause.getMessage))
+        (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
       case NonFatal(e) =>
         // If there is an exception, put the exception class followed by the message.
-        (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))
+        (emptySchema, Seq(e.getClass.getName, e.getMessage))
     }
   }
 
+  /** Executes a query and returns the result as (schema of the output, normalized output). */
+  private def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
+    // Returns true if the plan is supposed to be sorted.
+    def isSorted(plan: LogicalPlan): Boolean = plan match {
+      case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
+      case _: DescribeCommandBase
+          | _: DescribeColumnCommand
+          | _: DescribeTableStatement
+          | _: DescribeColumnStatement => true
+      case PhysicalOperation(_, _, Sort(_, true, _)) => true
+      case _ => plan.children.iterator.exists(isSorted)
+    }
+
+    val df = session.sql(sql)
+    val schema = df.schema.catalogString
+    // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
+    val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
+      hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
+    }
+
+    // If the output is not pre-sorted, sort it.
+    if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
+  }
+
   protected def replaceNotIncludedMsg(line: String): String = {
     line.replaceAll("#\\d+", "#x")
       .replaceAll(
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
index 613c1655727bb..36fcde35982cc 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
@@ -28,7 +28,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 
 import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.sql.{AnalysisException, SQLQueryTestSuite}
+import org.apache.spark.sql.SQLQueryTestSuite
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.util.fileToString
 import org.apache.spark.sql.execution.HiveResult
@@ -123,7 +123,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
 
     // Run the SQL queries preparing them for comparison.
     val outputs: Seq[QueryOutput] = queries.map { sql =>
-      val output = getNormalizedResult(statement, sql)
+      val (_, output) = handleExceptions(getNormalizedResult(statement, sql))
       // We might need to do some query canonicalization in the future.
      QueryOutput(
        sql = sql,
@@ -142,8 +142,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
           "Try regenerate the result files.")
       Seq.tabulate(outputs.size) { i =>
         val sql = segments(i * 3 + 1).trim
+        val schema = segments(i * 3 + 2).trim
         val originalOut = segments(i * 3 + 3)
-        val output = if (isNeedSort(sql)) {
+        val output = if (schema != emptySchema && isNeedSort(sql)) {
           originalOut.split("\n").sorted.mkString("\n")
         } else {
           originalOut
@@ -254,32 +255,30 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite {
     }
   }
 
-  private def getNormalizedResult(statement: Statement, sql: String): Seq[String] = {
-    try {
-      val rs = statement.executeQuery(sql)
-      val cols = rs.getMetaData.getColumnCount
-      val buildStr = () => (for (i <- 1 to cols) yield {
-        getHiveResult(rs.getObject(i))
-      }).mkString("\t")
-
-      val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
-        .map(replaceNotIncludedMsg)
-      if (isNeedSort(sql)) {
-        answer.sorted
-      } else {
-        answer
+  /** ThriftServer wraps the root exception, so it needs to be extracted. */
+  override def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
+    super.handleExceptions {
+      try {
+        result
+      } catch {
+        case NonFatal(e) => throw ExceptionUtils.getRootCause(e)
       }
-    } catch {
-      case a: AnalysisException =>
-        // Do not output the logical plan tree which contains expression IDs.
-        // Also implement a crude way of masking expression IDs in the error message
-        // with a generic pattern "###".
-        val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage
-        Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")).sorted
-      case NonFatal(e) =>
-        val rootCause = ExceptionUtils.getRootCause(e)
-        // If there is an exception, put the exception class followed by the message.
-        Seq(rootCause.getClass.getName, rootCause.getMessage)
+    }
+  }
+
+  private def getNormalizedResult(statement: Statement, sql: String): (String, Seq[String]) = {
+    val rs = statement.executeQuery(sql)
+    val cols = rs.getMetaData.getColumnCount
+    val buildStr = () => (for (i <- 1 to cols) yield {
+      getHiveResult(rs.getObject(i))
+    }).mkString("\t")
+
+    val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
+      .map(replaceNotIncludedMsg)
+    if (isNeedSort(sql)) {
+      ("", answer.sorted)
+    } else {
+      ("", answer)
     }
   }
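
For reviewers unfamiliar with the by-name/template-method shape this patch relies on, here is a minimal, self-contained Scala sketch of the same idea. It is an illustration only: BaseSuite, WrappingSuite and HandleExceptionsDemo are hypothetical names, and the "struct<>" constant merely stands in for StructType(Seq.empty).catalogString.

import scala.util.control.NonFatal

class BaseSuite {
  // Stand-in for the empty schema string reported when a query fails.
  val emptySchema: String = "struct<>"

  // Evaluates the by-name `result` and turns any non-fatal exception into a
  // comparable (schema, output) pair: exception class name followed by its message.
  def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
    try {
      result
    } catch {
      case NonFatal(e) => (emptySchema, Seq(e.getClass.getName, e.getMessage))
    }
  }
}

class WrappingSuite extends BaseSuite {
  // A client layer (JDBC in the ThriftServer case) wraps the interesting exception,
  // so unwrap the cause first and defer the formatting to the base handler.
  override def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
    super.handleExceptions {
      try {
        result
      } catch {
        case NonFatal(e) if e.getCause != null => throw e.getCause
      }
    }
  }
}

object HandleExceptionsDemo {
  def main(args: Array[String]): Unit = {
    val suite = new WrappingSuite
    val (schema, output) = suite.handleExceptions {
      throw new RuntimeException("wrapper", new IllegalArgumentException("root cause"))
    }
    // Prints: struct<> | java.lang.IllegalArgumentException | root cause
    println((schema +: output).mkString(" | "))
  }
}

The point mirrored from the patch is that the subclass never re-implements the error-to-output conversion; it only rewrites the exception before delegating to super.handleExceptions, so both suites render failures in the same comparable form.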