[SPARK-26502][SQL] Move hiveResultString() from QueryExecution to HiveResult

## What changes were proposed in this pull request?

In this PR, I propose to move `hiveResultString()` out of `QueryExecution` and put it into a separate object, `HiveResult`.
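
As a rough sketch of how call sites migrate (mirroring the call-site updates in this diff; `df` stands in for any hypothetical `DataFrame`):

```scala
import org.apache.spark.sql.execution.HiveResult.hiveResultString

// Before: the helper was an instance method on QueryExecution
// val answer = df.queryExecution.hiveResultString()

// After: a standalone function that takes the executed SparkPlan explicitly
val answer = hiveResultString(df.queryExecution.executedPlan)
```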

Closes #23409 from MaxGekk/hive-result-string.

Lead-authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Co-authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
2 people authored and hvanhovell committed Jan 3, 2019
1 parent 56967b7 commit 2a30deb
Showing 6 changed files with 126 additions and 95 deletions.
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
 * Runs a query returning the result in Hive compatible form.
 */
object HiveResult {
  /**
   * Returns the result as a Hive-compatible sequence of strings. This is used in tests and
   * `SparkSQLDriver` for CLI applications.
   */
  def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match {
    case ExecutedCommandExec(desc: DescribeTableCommand) =>
      // If it is a describe command for a Hive table, we want the output format
      // to be similar to Hive's.
      executedPlan.executeCollectPublic().map {
        case Row(name: String, dataType: String, comment) =>
          Seq(name, dataType,
            Option(comment.asInstanceOf[String]).getOrElse(""))
            .map(s => String.format(s"%-20s", s))
            .mkString("\t")
      }
    // SHOW TABLES in Hive only outputs table names, while ours outputs database, table name, isTemp.
    case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
      command.executeCollect().map(_.getString(1))
    case other =>
      val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
      // We need the types so we can output struct field names.
      val types = executedPlan.output.map(_.dataType)
      // Reformat to match Hive's tab-delimited output.
      result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t"))
  }

  /** Formats a datum (based on the given data type) and returns the string representation. */
  private def toHiveString(a: (Any, DataType)): String = {
    val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType,
      BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType)
    val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)

    def formatDecimal(d: java.math.BigDecimal): String = {
      if (d.compareTo(java.math.BigDecimal.ZERO) == 0) {
        java.math.BigDecimal.ZERO.toPlainString
      } else {
        d.stripTrailingZeros().toPlainString
      }
    }

    /** Hive outputs fields of structs slightly differently than top level attributes. */
    def toHiveStructString(a: (Any, DataType)): String = a match {
      case (struct: Row, StructType(fields)) =>
        struct.toSeq.zip(fields).map {
          case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}"""
        }.mkString("{", ",", "}")
      case (seq: Seq[_], ArrayType(typ, _)) =>
        seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
      case (map: Map[_, _], MapType(kType, vType, _)) =>
        map.map {
          case (key, value) =>
            toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
        }.toSeq.sorted.mkString("{", ",", "}")
      case (null, _) => "null"
      case (s: String, StringType) => "\"" + s + "\""
      case (decimal, DecimalType()) => decimal.toString
      case (interval, CalendarIntervalType) => interval.toString
      case (other, tpe) if primitiveTypes contains tpe => other.toString
    }

    a match {
      case (struct: Row, StructType(fields)) =>
        struct.toSeq.zip(fields).map {
          case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}"""
        }.mkString("{", ",", "}")
      case (seq: Seq[_], ArrayType(typ, _)) =>
        seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
      case (map: Map[_, _], MapType(kType, vType, _)) =>
        map.map {
          case (key, value) =>
            toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
        }.toSeq.sorted.mkString("{", ",", "}")
      case (null, _) => "NULL"
      case (d: Date, DateType) =>
        DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
      case (t: Timestamp, TimestampType) =>
        DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), timeZone)
      case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
      case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
      case (interval, CalendarIntervalType) => interval.toString
      case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
    }
  }
}
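
For illustration only (not part of this commit), a minimal sketch of how the relocated helper might be invoked and what its Hive-style output looks like; the local session, data, and column names below are assumptions:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.HiveResult

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("HiveResultExample")
  .getOrCreate()
import spark.implicits._

// Hypothetical two-column DataFrame with one null value.
val df = Seq((1, "a"), (2, null.asInstanceOf[String])).toDF("id", "name")

// Each row becomes one tab-delimited string; top-level nulls render as "NULL",
// matching Hive's CLI output.
HiveResult.hiveResultString(df.queryExecution.executedPlan).foreach(println)
// expected output (fields separated by a tab):
//   1   a
//   2   NULL

spark.stop()
```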
@@ -18,25 +18,20 @@
package org.apache.spark.sql.execution

import java.io.{BufferedWriter, OutputStreamWriter}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}

import org.apache.hadoop.fs.Path

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
import org.apache.spark.util.Utils

/**
@@ -109,90 +104,6 @@ class QueryExecution(
    ReuseExchange(sparkSession.sessionState.conf),
    ReuseSubquery(sparkSession.sessionState.conf))

  /**
   * Returns the result as a hive compatible sequence of strings. This is used in tests and
   * `SparkSQLDriver` for CLI applications.
   */
  def hiveResultString(): Seq[String] = executedPlan match {
    case ExecutedCommandExec(desc: DescribeTableCommand) =>
      // If it is a describe command for a Hive table, we want to have the output format
      // be similar with Hive.
      desc.run(sparkSession).map {
        case Row(name: String, dataType: String, comment) =>
          Seq(name, dataType,
            Option(comment.asInstanceOf[String]).getOrElse(""))
            .map(s => String.format(s"%-20s", s))
            .mkString("\t")
      }
    // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp.
    case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
      command.executeCollect().map(_.getString(1))
    case other =>
      val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
      // We need the types so we can output struct field names
      val types = analyzed.output.map(_.dataType)
      // Reformat to match hive tab delimited output.
      result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t"))
  }

  /** Formats a datum (based on the given data type) and returns the string representation. */
  private def toHiveString(a: (Any, DataType)): String = {
    val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType,
      BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType)

    def formatDecimal(d: java.math.BigDecimal): String = {
      if (d.compareTo(java.math.BigDecimal.ZERO) == 0) {
        java.math.BigDecimal.ZERO.toPlainString
      } else {
        d.stripTrailingZeros().toPlainString
      }
    }

    /** Hive outputs fields of structs slightly differently than top level attributes. */
    def toHiveStructString(a: (Any, DataType)): String = a match {
      case (struct: Row, StructType(fields)) =>
        struct.toSeq.zip(fields).map {
          case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}"""
        }.mkString("{", ",", "}")
      case (seq: Seq[_], ArrayType(typ, _)) =>
        seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
      case (map: Map[_, _], MapType(kType, vType, _)) =>
        map.map {
          case (key, value) =>
            toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
        }.toSeq.sorted.mkString("{", ",", "}")
      case (null, _) => "null"
      case (s: String, StringType) => "\"" + s + "\""
      case (decimal, DecimalType()) => decimal.toString
      case (interval, CalendarIntervalType) => interval.toString
      case (other, tpe) if primitiveTypes contains tpe => other.toString
    }

    a match {
      case (struct: Row, StructType(fields)) =>
        struct.toSeq.zip(fields).map {
          case (v, t) => s""""${t.name}":${toHiveStructString((v, t.dataType))}"""
        }.mkString("{", ",", "}")
      case (seq: Seq[_], ArrayType(typ, _)) =>
        seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
      case (map: Map[_, _], MapType(kType, vType, _)) =>
        map.map {
          case (key, value) =>
            toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
        }.toSeq.sorted.mkString("{", ",", "}")
      case (null, _) => "NULL"
      case (d: Date, DateType) =>
        DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
      case (t: Timestamp, TimestampType) =>
        DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t),
          DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
      case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
      case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
      case (interval, CalendarIntervalType) => interval.toString
      case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
    }
  }

  def simpleString: String = withRedaction {
    val concat = new StringConcat()
    concat.append("== Physical Plan ==\n")
@@ -22,11 +22,11 @@ import java.util.{Locale, TimeZone}

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -287,7 +287,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
    val schema = df.schema
    val notIncludedMsg = "[not included in comparison]"
    // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
    val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x")
    val answer = hiveResultString(df.queryExecution.executedPlan)
      .map(_.replaceAll("#\\d+", "#x")
      .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/")
      .replaceAll("Created By.*", s"Created By $notIncludedMsg")
      .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
@@ -29,6 +29,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.HiveResult.hiveResultString


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
@@ -61,7 +62,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) {
        execution.hiveResultString()
        hiveResultString(execution.executedPlan)
      }
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
@@ -297,7 +297,7 @@ private[hive] class TestHiveSparkSession(

  protected[hive] implicit class SqlCmd(sql: String) {
    def cmd: () => Unit = {
      () => new TestHiveQueryExecution(sql).hiveResultString(): Unit
      () => new TestHiveQueryExecution(sql).executedPlan.executeCollect(): Unit
    }
  }

@@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution}
@@ -345,7 +346,8 @@ abstract class HiveComparisonTest
    val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
      val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath))
      def getResult(): Seq[String] = {
        SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString())
        SQLExecution.withNewExecutionId(
          query.sparkSession, query)(hiveResultString(query.executedPlan))
      }
      try { (query, prepareAnswer(query, getResult())) } catch {
        case e: Throwable =>
