Create timestamp formatter once before collect in toHiveString
MaxGekk committed Jun 16, 2020
1 parent f0e6d0e commit a152d94
Showing 4 changed files with 64 additions and 53 deletions.
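Before this change, HiveResult declared its timestamp formatter as a def, so a new TimestampFormatter was constructed for every single value being rendered (the date formatter was already a val). The commit hoists both into a TimeFormatters pair that is built once per result set and passed down through toHiveString. A minimal sketch of the difference, not part of the commit itself, using the internal catalyst APIs that appear in the diff below (import paths are assumed):

import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.internal.SQLConf

object FormatterReuseSketch {
  // Old shape: a `def`, so each call rebuilds the formatter from the session time zone.
  private def timestampFormatterPerCall =
    TimestampFormatter.getFractionFormatter(
      DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))

  // New shape: build the formatter once, then reuse it for every value in the result set.
  def formatAll(values: Seq[java.sql.Timestamp]): Seq[String] = {
    val timestampFormatter = TimestampFormatter.getFractionFormatter(
      DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
    values.map(t => timestampFormatter.format(t))
  }
}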
@@ -33,6 +33,23 @@ import org.apache.spark.unsafe.types.CalendarInterval
* Runs a query returning the result in Hive compatible form.
*/
object HiveResult {
case class TimeFormatters(date: DateFormatter, timestamp: TimestampFormatter)

def getTimeFormatters: TimeFormatters = {
// The date formatter does not depend on Spark's session time zone controlled by
// the SQL config `spark.sql.session.timeZone`. The `zoneId` parameter is used only in
// parsing of special date values like `now`, `yesterday`, etc., but not in date formatting.
// While formatting of:
// - `java.time.LocalDate`, zone id is not used by `DateTimeFormatter` at all.
// - `java.sql.Date`, the date formatter delegates formatting to the legacy formatter
// which uses the default system time zone `TimeZone.getDefault`. This works correctly
// due to `DateTimeUtils.toJavaDate` which is based on the system time zone too.
val dateFormatter = DateFormatter(ZoneOffset.UTC)
val timestampFormatter = TimestampFormatter.getFractionFormatter(
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
TimeFormatters(dateFormatter, timestampFormatter)
}

/**
* Returns the result as a hive compatible sequence of strings. This is used in tests and
* `SparkSQLDriver` for CLI applications.
@@ -55,11 +72,12 @@ object HiveResult {
case command @ ExecutedCommandExec(_: ShowViewsCommand) =>
command.executeCollect().map(_.getString(1))
case other =>
val timeFormatters = getTimeFormatters
val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
// We need the types so we can output struct field names
val types = executedPlan.output.map(_.dataType)
// Reformat to match hive tab delimited output.
result.map(_.zip(types).map(e => toHiveString(e)))
result.map(_.zip(types).map(e => toHiveString(e, false, timeFormatters)))
.map(_.mkString("\t"))
}

@@ -72,47 +90,32 @@ object HiveResult {
}
}

// We can create the date formatter only once because it does not depend on Spark's
// session time zone controlled by the SQL config `spark.sql.session.timeZone`.
// The `zoneId` parameter is used only in parsing of special date values like `now`,
// `yesterday`, etc., but not in date formatting. While formatting of:
// - `java.time.LocalDate`, zone id is not used by `DateTimeFormatter` at all.
// - `java.sql.Date`, the date formatter delegates formatting to the legacy formatter
// which uses the default system time zone `TimeZone.getDefault`. This works correctly
// due to `DateTimeUtils.toJavaDate` which is based on the system time zone too.
private val dateFormatter = DateFormatter(
format = DateFormatter.defaultPattern,
// We can set any time zone id. UTC was taken for simplicity.
zoneId = ZoneOffset.UTC,
locale = DateFormatter.defaultLocale,
// Use `FastDateFormat` as the legacy formatter because it is thread-safe.
legacyFormat = LegacyDateFormats.FAST_DATE_FORMAT,
isParsing = false)
private def timestampFormatter = TimestampFormatter.getFractionFormatter(
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))

/** Formats a datum (based on the given data type) and returns the string representation. */
def toHiveString(a: (Any, DataType), nested: Boolean = false): String = a match {
def toHiveString(
a: (Any, DataType),
nested: Boolean,
formatters: TimeFormatters): String = a match {
case (null, _) => if (nested) "null" else "NULL"
case (b, BooleanType) => b.toString
case (d: Date, DateType) => dateFormatter.format(d)
case (ld: LocalDate, DateType) => dateFormatter.format(ld)
case (t: Timestamp, TimestampType) => timestampFormatter.format(t)
case (i: Instant, TimestampType) => timestampFormatter.format(i)
case (d: Date, DateType) => formatters.date.format(d)
case (ld: LocalDate, DateType) => formatters.date.format(ld)
case (t: Timestamp, TimestampType) => formatters.timestamp.format(t)
case (i: Instant, TimestampType) => formatters.timestamp.format(i)
case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString
case (n, _: NumericType) => n.toString
case (s: String, StringType) => if (nested) "\"" + s + "\"" else s
case (interval: CalendarInterval, CalendarIntervalType) => interval.toString
case (seq: Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(e => toHiveString(e, true)).mkString("[", ",", "]")
seq.map(v => (v, typ)).map(e => toHiveString(e, true, formatters)).mkString("[", ",", "]")
case (m: Map[_, _], MapType(kType, vType, _)) =>
m.map { case (key, value) =>
toHiveString((key, kType), true) + ":" + toHiveString((value, vType), true)
toHiveString((key, kType), true, formatters) + ":" +
toHiveString((value, vType), true, formatters)
}.toSeq.sorted.mkString("{", ",", "}")
case (struct: Row, StructType(fields)) =>
struct.toSeq.zip(fields).map { case (v, t) =>
s""""${t.name}":${toHiveString((v, t.dataType), true)}"""
s""""${t.name}":${toHiveString((v, t.dataType), true, formatters)}"""
}.mkString("{", ",", "}")
case (other, _: UserDefinedType[_]) => other.toString
}
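A short, REPL-style usage sketch of the reworked API, assuming an active SQLConf (the names and the sample date come from the code and tests in this commit): callers obtain the formatters once via getTimeFormatters and thread them through every toHiveString call, exactly as HiveResultSuite and the Thrift-server code below do.

import java.sql.Date
import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, toHiveString}
import org.apache.spark.sql.types.DateType

val formatters = getTimeFormatters  // one TimeFormatters pair per result set
// nested = false renders a top-level value (no quoting); per the comment above,
// date rendering does not depend on the session time zone.
val rendered = toHiveString((Date.valueOf("2018-12-28"), DateType), false, formatters)
// rendered == "2018-12-28"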
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.connector.InMemoryTableCatalog
import org.apache.spark.sql.execution.HiveResult._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}

@@ -31,10 +32,10 @@ class HiveResultSuite extends SharedSparkSession {
val dates = Seq("2018-12-28", "1582-10-03", "1582-10-04", "1582-10-15")
val df = dates.toDF("a").selectExpr("cast(a as date) as b")
val executedPlan1 = df.queryExecution.executedPlan
val result = HiveResult.hiveResultString(executedPlan1)
val result = hiveResultString(executedPlan1)
assert(result == dates)
val executedPlan2 = df.selectExpr("array(b)").queryExecution.executedPlan
val result2 = HiveResult.hiveResultString(executedPlan2)
val result2 = hiveResultString(executedPlan2)
assert(result2 == dates.map(x => s"[$x]"))
}
}
@@ -48,31 +49,31 @@ class HiveResultSuite extends SharedSparkSession {
"1582-10-15 01:02:03")
val df = timestamps.toDF("a").selectExpr("cast(a as timestamp) as b")
val executedPlan1 = df.queryExecution.executedPlan
val result = HiveResult.hiveResultString(executedPlan1)
val result = hiveResultString(executedPlan1)
assert(result == timestamps)
val executedPlan2 = df.selectExpr("array(b)").queryExecution.executedPlan
val result2 = HiveResult.hiveResultString(executedPlan2)
val result2 = hiveResultString(executedPlan2)
assert(result2 == timestamps.map(x => s"[$x]"))
}

test("toHiveString correctly handles UDTs") {
val point = new ExamplePoint(50.0, 50.0)
val tpe = new ExamplePointUDT()
assert(HiveResult.toHiveString((point, tpe)) === "(50.0, 50.0)")
assert(toHiveString((point, tpe), false, getTimeFormatters) === "(50.0, 50.0)")
}

test("decimal formatting in hive result") {
val df = Seq(new java.math.BigDecimal("1")).toDS()
Seq(2, 6, 18).foreach { scala =>
val executedPlan =
df.selectExpr(s"CAST(value AS decimal(38, $scala))").queryExecution.executedPlan
val result = HiveResult.hiveResultString(executedPlan)
val result = hiveResultString(executedPlan)
assert(result.head.split("\\.").last.length === scala)
}

val executedPlan = Seq(java.math.BigDecimal.ZERO).toDS()
.selectExpr(s"CAST(value AS decimal(38, 8))").queryExecution.executedPlan
val result = HiveResult.hiveResultString(executedPlan)
val result = hiveResultString(executedPlan)
assert(result.head === "0.00000000")
}

@@ -84,7 +85,7 @@ class HiveResultSuite extends SharedSparkSession {
spark.sql(s"CREATE TABLE $ns.$tbl (id bigint) USING $source")
val df = spark.sql(s"SHOW TABLES FROM $ns")
val executedPlan = df.queryExecution.executedPlan
assert(HiveResult.hiveResultString(executedPlan).head == tbl)
assert(hiveResultString(executedPlan).head == tbl)
}
}
}
@@ -101,7 +102,7 @@ class HiveResultSuite extends SharedSparkSession {
val expected = "id " +
"\tbigint " +
"\tcol1 "
assert(HiveResult.hiveResultString(executedPlan).head == expected)
assert(hiveResultString(executedPlan).head == expected)
}
}
}
@@ -36,7 +36,7 @@ import org.apache.hive.service.cli.session.HiveSession
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLContext}
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, toHiveString, TimeFormatters}
import org.apache.spark.sql.execution.command.SetCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -73,7 +73,11 @@ private[hive] class SparkExecuteStatementOperation(
}
}

def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int): Unit = {
def addNonNullColumnValue(
from: SparkRow,
to: ArrayBuffer[Any],
ordinal: Int,
timeFormatters: TimeFormatters): Unit = {
dataTypes(ordinal) match {
case StringType =>
to += from.getString(ordinal)
@@ -100,13 +104,14 @@ private[hive] class SparkExecuteStatementOperation(
// - work with spark.sql.datetime.java8API.enabled
// These types have always been sent over the wire as string, converted later.
case _: DateType | _: TimestampType =>
val hiveString = HiveResult.toHiveString((from.get(ordinal), dataTypes(ordinal)))
to += hiveString
to += toHiveString((from.get(ordinal), dataTypes(ordinal)), false, timeFormatters)
case CalendarIntervalType =>
to += HiveResult.toHiveString((from.getAs[CalendarInterval](ordinal), CalendarIntervalType))
to += toHiveString(
(from.getAs[CalendarInterval](ordinal), CalendarIntervalType),
false,
timeFormatters)
case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] =>
val hiveString = HiveResult.toHiveString((from.get(ordinal), dataTypes(ordinal)))
to += hiveString
to += toHiveString((from.get(ordinal), dataTypes(ordinal)), false, timeFormatters)
}
}

@@ -159,6 +164,7 @@ private[hive] class SparkExecuteStatementOperation(
if (!iter.hasNext) {
resultRowSet
} else {
val timeFormatters = getTimeFormatters
// maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int
val maxRows = maxRowsL.toInt
var curRow = 0
@@ -170,7 +176,7 @@ private[hive] class SparkExecuteStatementOperation(
if (sparkRow.isNullAt(curCol)) {
row += null
} else {
addNonNullColumnValue(sparkRow, row, curCol)
addNonNullColumnValue(sparkRow, row, curCol, timeFormatters)
}
curCol += 1
}
@@ -30,7 +30,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.SQLQueryTestSuite
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.util.fileToString
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, toHiveString, TimeFormatters}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -257,8 +257,9 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
private def getNormalizedResult(statement: Statement, sql: String): (String, Seq[String]) = {
val rs = statement.executeQuery(sql)
val cols = rs.getMetaData.getColumnCount
val timeFormatters = getTimeFormatters
val buildStr = () => (for (i <- 1 to cols) yield {
getHiveResult(rs.getObject(i))
getHiveResult(rs.getObject(i), timeFormatters)
}).mkString("\t")

val answer = Iterator.continually(rs.next()).takeWhile(identity).map(_ => buildStr()).toSeq
@@ -280,18 +281,18 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
upperCase.startsWith("(")
}

private def getHiveResult(obj: Object): String = {
private def getHiveResult(obj: Object, timeFormatters: TimeFormatters): String = {
obj match {
case null =>
HiveResult.toHiveString((null, StringType))
toHiveString((null, StringType), false, timeFormatters)
case d: java.sql.Date =>
HiveResult.toHiveString((d, DateType))
toHiveString((d, DateType), false, timeFormatters)
case t: Timestamp =>
HiveResult.toHiveString((t, TimestampType))
toHiveString((t, TimestampType), false, timeFormatters)
case d: java.math.BigDecimal =>
HiveResult.toHiveString((d, DecimalType.fromDecimal(Decimal(d))))
toHiveString((d, DecimalType.fromDecimal(Decimal(d))), false, timeFormatters)
case bin: Array[Byte] =>
HiveResult.toHiveString((bin, BinaryType))
toHiveString((bin, BinaryType), false, timeFormatters)
case other =>
other.toString
}
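In the Thrift-server test suite, the same idea applies when normalizing JDBC results: getTimeFormatters is called once per query and the resulting TimeFormatters value is reused for every cell. A hypothetical, condensed helper in the spirit of getNormalizedResult and getHiveResult above (renderCells is not part of the commit; only toHiveString, getTimeFormatters, and TimeFormatters are):

import java.sql.Timestamp
import org.apache.spark.sql.execution.HiveResult.{getTimeFormatters, toHiveString, TimeFormatters}
import org.apache.spark.sql.types.{StringType, TimestampType}

// Hypothetical helper: renders one fetched row, tab-delimited, sharing a single
// TimeFormatters instance across all of its cells.
def renderCells(cells: Seq[Object], timeFormatters: TimeFormatters): String =
  cells.map {
    case null => toHiveString((null, StringType), false, timeFormatters)      // "NULL"
    case t: Timestamp => toHiveString((t, TimestampType), false, timeFormatters)
    case other => other.toString
  }.mkString("\t")

val formatters = getTimeFormatters
renderCells(Seq(Timestamp.valueOf("2020-06-16 12:00:00"), null), formatters)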
