diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala new file mode 100644 index 0000000000000..22d3ca958a210 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Runs a query returning the result in Hive compatible form. + */ +object HiveResult { + /** + * Returns the result as a hive compatible sequence of strings. This is used in tests and + * `SparkSQLDriver` for CLI applications. + */ + def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match { + case ExecutedCommandExec(desc: DescribeTableCommand) => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + executedPlan.executeCollectPublic().map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, + Option(comment.asInstanceOf[String]).getOrElse("")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") + } + // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => + command.executeCollect().map(_.getString(1)) + case other => + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq + // We need the types so we can output struct field names + val types = executedPlan.output.map(_.dataType) + // Reformat to match hive tab delimited output. + result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")) + } + + /** Formats a datum (based on the given data type) and returns the string representation. */ + private def toHiveString(a: (Any, DataType)): String = { + val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, + BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) + val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) + + def formatDecimal(d: java.math.BigDecimal): String = { + if (d.compareTo(java.math.BigDecimal.ZERO) == 0) { + java.math.BigDecimal.ZERO.toPlainString + } else { + d.stripTrailingZeros().toPlainString + } + } + + /** Hive outputs fields of structs slightly differently than top level attributes. */ + def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.toSeq.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + a match { + case (struct: Row, StructType(fields)) => + struct.toSeq.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (d: Date, DateType) => + DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) + case (t: Timestamp, TimestampType) => + DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), timeZone) + case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) + case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString + case (other, tpe) if primitiveTypes.contains(tpe) => other.toString + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7fccbf65d8525..72499aa936a56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -18,25 +18,20 @@ package org.apache.spark.sql.execution import java.io.{BufferedWriter, OutputStreamWriter} -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} import org.apache.spark.util.Utils /** @@ -109,90 +104,6 @@ class QueryExecution( ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) - /** - * Returns the result as a hive compatible sequence of strings. This is used in tests and - * `SparkSQLDriver` for CLI applications. - */ - def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand) => - // If it is a describe command for a Hive table, we want to have the output format - // be similar with Hive. - desc.run(sparkSession).map { - case Row(name: String, dataType: String, comment) => - Seq(name, dataType, - Option(comment.asInstanceOf[String]).getOrElse("")) - .map(s => String.format(s"%-20s", s)) - .mkString("\t") - } - // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => - command.executeCollect().map(_.getString(1)) - case other => - val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq - // We need the types so we can output struct field names - val types = analyzed.output.map(_.dataType) - // Reformat to match hive tab delimited output. - result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")) - } - - /** Formats a datum (based on the given data type) and returns the string representation. */ - private def toHiveString(a: (Any, DataType)): String = { - val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, - BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) - - def formatDecimal(d: java.math.BigDecimal): String = { - if (d.compareTo(java.math.BigDecimal.ZERO) == 0) { - java.math.BigDecimal.ZERO.toPlainString - } else { - d.stripTrailingZeros().toPlainString - } - } - - /** Hive outputs fields of structs slightly differently than top level attributes. */ - def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "null" - case (s: String, StringType) => "\"" + s + "\"" - case (decimal, DecimalType()) => decimal.toString - case (interval, CalendarIntervalType) => interval.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - - a match { - case (struct: Row, StructType(fields)) => - struct.toSeq.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "NULL" - case (d: Date, DateType) => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case (t: Timestamp, TimestampType) => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) - case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) - case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) - case (interval, CalendarIntervalType) => interval.toString - case (other, tpe) if primitiveTypes.contains(tpe) => other.toString - } - } - def simpleString: String = withRedaction { val concat = new StringConcat() concat.append("== Physical Plan ==\n") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index b2515226d9a14..24b312348bd67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -22,11 +22,11 @@ import java.util.{Locale, TimeZone} import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -287,7 +287,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val schema = df.schema val notIncludedMsg = "[not included in comparison]" // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") + val answer = hiveResultString(df.queryExecution.executedPlan) + .map(_.replaceAll("#\\d+", "#x") .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 6775902173444..960fdd11db15d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.HiveResult.hiveResultString private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext) @@ -61,7 +62,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { - execution.hiveResultString() + hiveResultString(execution.executedPlan) } tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3508affda241a..4c2bc62b9faf8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -297,7 +297,7 @@ private[hive] class TestHiveSparkSession( protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { - () => new TestHiveQueryExecution(sql).hiveResultString(): Unit + () => new TestHiveQueryExecution(sql).executedPlan.executeCollect(): Unit } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 272e6f51f5002..66426824573c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} @@ -345,7 +346,8 @@ abstract class HiveComparisonTest val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) def getResult(): Seq[String] = { - SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString()) + SQLExecution.withNewExecutionId( + query.sparkSession, query)(hiveResultString(query.executedPlan)) } try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable =>