[KYUUBI #4316] Fix returned Timestamp values may lose precision
### _Why are the changes needed?_

This PR proposes to replace `org.apache.kyuubi.engine.spark.schema#toHiveString` with `org.apache.spark.sql.execution#toHiveString`, so that results are consistent with `spark-sql` and the Spark Thrift Server (STS).

Because of [SPARK-32006](https://issues.apache.org/jira/browse/SPARK-32006), this only works with Spark 3.1 and above.

The patch takes effect on both the Thrift and Arrow result formats.
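
Not part of the patch itself, but a minimal sketch of the code path the description refers to, assuming Spark 3.1+ where `HiveResult.getTimeFormatters` and the three-argument `HiveResult.toHiveString` overload are available (SPARK-32006). Both the Thrift and Arrow result paths format string values through this helper; the input value and expected output mirror the beeline example below.

```scala
import java.sql.Timestamp

import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.TimestampType

// Build the time formatters once per batch and reuse them for every value.
val timeFormatters = HiveResult.getTimeFormatters

// Renders "2023-02-08 22:17:33.123456" -- microsecond digits are kept,
// whereas the previous Kyuubi-side formatter could lose sub-millisecond
// precision (the behavior this PR fixes).
val rendered = HiveResult.toHiveString(
  (Timestamp.valueOf("2023-02-08 22:17:33.123456"), TimestampType),
  nested = false,
  timeFormatters)
```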

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [x] Add screenshots for manual tests if appropriate
```
➜  ~ beeline -u 'jdbc:hive2://0.0.0.0:10009/default'
Connecting to jdbc:hive2://0.0.0.0:10009/default
Connected to: Spark SQL (version 3.3.1)
Driver: Hive JDBC (version 2.3.9)
Transaction isolation: TRANSACTION_REPEATABLE_READ
Beeline version 2.3.9 by Apache Hive
0: jdbc:hive2://0.0.0.0:10009/default> select to_timestamp('2023-02-08 22:17:33.123456789');
+----------------------------------------------+
| to_timestamp(2023-02-08 22:17:33.123456789)  |
+----------------------------------------------+
| 2023-02-08 22:17:33.123456                   |
+----------------------------------------------+
1 row selected (0.415 seconds)
```

- [x] [Run tests](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4318 from pan3793/hive-string.

Closes #4316

ba9016f [Cheng Pan] nit
8be774b [Cheng Pan] nit
bd696fe [Cheng Pan] nit
b5cf051 [Cheng Pan] fix
dd6b702 [Cheng Pan] test
63edd34 [Cheng Pan] nit
37cc70a [Cheng Pan] Fix python ut
c66ad22 [Cheng Pan] [KYUUBI #4316] Fix returned Timestamp values may lose precision
41d9444 [Cheng Pan] Revert "[KYUUBI #3958] Fix Spark session timezone format"

Authored-by: Cheng Pan <chengpan@apache.org>
Signed-off-by: Cheng Pan <chengpan@apache.org>
pan3793 committed Feb 14, 2023
1 parent 763c088 commit 8fe7947
Showing 9 changed files with 123 additions and 184 deletions.
@@ -88,9 +88,9 @@ class ExecutePython(
val output = response.map(_.content.getOutput()).getOrElse("")
val ename = response.map(_.content.getEname()).getOrElse("")
val evalue = response.map(_.content.getEvalue()).getOrElse("")
val traceback = response.map(_.content.getTraceback()).getOrElse(Array.empty)
val traceback = response.map(_.content.getTraceback()).getOrElse(Seq.empty)
iter =
new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, Row(traceback: _*))))
new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, traceback)))
setState(OperationState.FINISHED)
} else {
throw KyuubiSQLException(s"Interpret error:\n$statement\n $response")
@@ -210,7 +210,7 @@ case class SessionPythonWorker(
stdin.flush()
val pythonResponse = Option(stdout.readLine()).map(ExecutePython.fromJson[PythonResponse](_))
// throw exception if internal python code fail
if (internal && pythonResponse.map(_.content.status) != Some(PythonResponse.OK_STATUS)) {
if (internal && !pythonResponse.map(_.content.status).contains(PythonResponse.OK_STATUS)) {
throw KyuubiSQLException(s"Internal python code $code failure: $pythonResponse")
}
pythonResponse
@@ -328,7 +328,7 @@ object ExecutePython extends Logging {
}

// for test
def defaultSparkHome(): String = {
def defaultSparkHome: String = {
val homeDirFilter: FilenameFilter = (dir: File, name: String) =>
dir.isDirectory && name.contains("spark-") && !name.contains("-engine")
// get from kyuubi-server/../externals/kyuubi-download/target
@@ -418,7 +418,7 @@ case class PythonResponseContent(
data: Map[String, String],
ename: String,
evalue: String,
traceback: Array[String],
traceback: Seq[String],
status: String) {
def getOutput(): String = {
Option(data)
@@ -431,7 +431,7 @@
def getEvalue(): String = {
Option(evalue).getOrElse("")
}
def getTraceback(): Array[String] = {
Option(traceback).getOrElse(Array.empty)
def getTraceback(): Seq[String] = {
Option(traceback).getOrElse(Seq.empty)
}
}
@@ -24,6 +24,7 @@ import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TProgressU
import org.apache.spark.kyuubi.{SparkProgressMonitor, SQLOperationListener}
import org.apache.spark.kyuubi.SparkUtilsHelper.redact
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.{KyuubiSQLException, Utils}
@@ -135,27 +136,35 @@ abstract class SparkOperation(session: Session)
spark.sparkContext.setLocalProperty

protected def withLocalProperties[T](f: => T): T = {
try {
spark.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, session.user)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId)
schedulerPool match {
case Some(pool) =>
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, pool)
case None =>
}
if (isSessionUserSignEnabled) {
setSessionUserSign()
}
SQLConf.withExistingConf(spark.sessionState.conf) {
val originalSession = SparkSession.getActiveSession
try {
SparkSession.setActiveSession(spark)
spark.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, session.user)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId)
schedulerPool match {
case Some(pool) =>
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, pool)
case None =>
}
if (isSessionUserSignEnabled) {
setSessionUserSign()
}

f
} finally {
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, null)
spark.sparkContext.clearJobGroup()
if (isSessionUserSignEnabled) {
clearSessionUserSign()
f
} finally {
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, null)
spark.sparkContext.clearJobGroup()
if (isSessionUserSignEnabled) {
clearSessionUserSign()
}
originalSession match {
case Some(session) => SparkSession.setActiveSession(session)
case None => SparkSession.clearActiveSession()
}
}
}
}
@@ -246,7 +255,7 @@ abstract class SparkOperation(session: Session)
} else {
val taken = iter.take(rowSetSize)
RowSet.toTRowSet(
taken.toList.asInstanceOf[List[Row]],
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion,
timeZone)
@@ -18,22 +18,25 @@
package org.apache.kyuubi.engine.spark.schema

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.Timestamp
import java.time._
import java.util.Date
import java.time.ZoneId

import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift._
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.schema.SchemaHelper.TIMESTAMP_NTZ
import org.apache.kyuubi.util.RowSetUtils._

object RowSet {

def toHiveString(valueAndType: (Any, DataType), nested: Boolean = false): String = {
// compatible w/ Spark 3.1 and above
val timeFormatters = HiveResult.getTimeFormatters
HiveResult.toHiveString(valueAndType, nested, timeFormatters)
}

def toTRowSet(
bytes: Array[Byte],
protocolVersion: TProtocolVersion): TRowSet = {
@@ -68,9 +71,9 @@ object RowSet {
}

def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
var i = 0
val rowSize = rows.length
val tRows = new java.util.ArrayList[TRow](rowSize)
var i = 0
while (i < rowSize) {
val row = rows(i)
val tRow = new TRow()
@@ -151,13 +154,7 @@ object RowSet {
while (i < rowSize) {
val row = rows(i)
nulls.set(i, row.isNullAt(ordinal))
val value =
if (row.isNullAt(ordinal)) {
""
} else {
toHiveString((row.get(ordinal), typ), timeZone)
}
values.add(value)
values.add(toHiveString(row.get(ordinal) -> typ))
i += 1
}
TColumn.stringVal(new TStringColumn(values, nulls))
@@ -238,69 +235,12 @@ object RowSet {
case _ =>
val tStrValue = new TStringValue
if (!row.isNullAt(ordinal)) {
tStrValue.setValue(
toHiveString((row.get(ordinal), types(ordinal).dataType), timeZone))
tStrValue.setValue(toHiveString(row.get(ordinal) -> types(ordinal).dataType))
}
TColumnValue.stringVal(tStrValue)
}
}

/**
* A simpler impl of Spark's toHiveString
*/
def toHiveString(dataWithType: (Any, DataType), timeZone: ZoneId): String = {
dataWithType match {
case (null, _) =>
// Only match nulls in nested type values
"null"

case (d: Date, DateType) =>
formatDate(d)

case (ld: LocalDate, DateType) =>
formatLocalDate(ld)

case (t: Timestamp, TimestampType) =>
formatTimestamp(t, Option(timeZone))

case (t: LocalDateTime, ntz) if ntz.getClass.getSimpleName.equals(TIMESTAMP_NTZ) =>
formatLocalDateTime(t)

case (i: Instant, TimestampType) =>
formatInstant(i, Option(timeZone))

case (bin: Array[Byte], BinaryType) =>
new String(bin, StandardCharsets.UTF_8)

case (decimal: java.math.BigDecimal, DecimalType()) =>
decimal.toPlainString

case (s: String, StringType) =>
// Only match string in nested type values
"\"" + s + "\""

case (d: Duration, _) => toDayTimeIntervalString(d)

case (p: Period, _) => toYearMonthIntervalString(p)

case (seq: scala.collection.Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(e => toHiveString(e, timeZone)).mkString("[", ",", "]")

case (m: Map[_, _], MapType(kType, vType, _)) =>
m.map { case (key, value) =>
toHiveString((key, kType), timeZone) + ":" + toHiveString((value, vType), timeZone)
}.toSeq.sorted.mkString("{", ",", "}")

case (struct: Row, StructType(fields)) =>
struct.toSeq.zip(fields).map { case (v, t) =>
s""""${t.name}":${toHiveString((v, t.dataType), timeZone)}"""
}.mkString("{", ",", "}")

case (other, _) =>
other.toString
}
}

private def toTColumn(data: Array[Byte]): TColumn = {
val values = new java.util.ArrayList[ByteBuffer](1)
values.add(ByteBuffer.wrap(data))
@@ -22,7 +22,7 @@ import java.time.ZoneId
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.schema.RowSet

@@ -41,11 +41,11 @@ object SparkDatasetHelper {
val dt = DataType.fromDDL(schemaDDL)
dt match {
case StructType(Array(StructField(_, st: StructType, _, _))) =>
RowSet.toHiveString((row, st), timeZone)
RowSet.toHiveString((row, st), nested = true)
case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, at), timeZone)
RowSet.toHiveString((row.toSeq.head, at), nested = true)
case StructType(Array(StructField(_, mt: MapType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, mt), timeZone)
RowSet.toHiveString((row.toSeq.head, mt), nested = true)
case _ =>
throw new UnsupportedOperationException
}
@@ -54,7 +54,7 @@ val cols = df.schema.map {
val cols = df.schema.map {
case sf @ StructField(name, _: StructType, _, _) =>
toHiveStringUDF(quotedCol(name), lit(sf.toDDL)).as(name)
case sf @ StructField(name, (_: MapType | _: ArrayType), _, _) =>
case sf @ StructField(name, _: MapType | _: ArrayType, _, _) =>
toHiveStringUDF(struct(quotedCol(name)), lit(sf.toDDL)).as(name)
case StructField(name, _, _, _) => quotedCol(name)
}
@@ -30,7 +30,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

import org.apache.kyuubi.KyuubiFunSuite
import org.apache.kyuubi.engine.spark.schema.RowSet.toHiveString

class RowSetSuite extends KyuubiFunSuite {

@@ -159,22 +158,22 @@ class RowSetSuite extends KyuubiFunSuite {

val decCol = cols.next().getStringVal
decCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === s"$i.$i")
}

val dateCol = cols.next().getStringVal
dateCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) =>
assert(b === toHiveString((Date.valueOf(s"2018-11-${i + 1}"), DateType), zoneId))
assert(b === RowSet.toHiveString(Date.valueOf(s"2018-11-${i + 1}") -> DateType))
}

val tsCol = cols.next().getStringVal
tsCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b ===
toHiveString((Timestamp.valueOf(s"2018-11-17 13:33:33.$i"), TimestampType), zoneId))
RowSet.toHiveString(Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType))
}

val binCol = cols.next().getBinaryVal
@@ -185,23 +184,21 @@

val arrCol = cols.next().getStringVal
arrCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, i) => assert(b === toHiveString(
(Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq, ArrayType(DoubleType)),
zoneId))
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === RowSet.toHiveString(
Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType)))
}

val mapCol = cols.next().getStringVal
mapCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, i) => assert(b === toHiveString(
(Map(i -> java.lang.Double.valueOf(s"$i.$i")), MapType(IntegerType, DoubleType)),
zoneId))
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === RowSet.toHiveString(
Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType)))
}

val intervalCol = cols.next().getStringVal
intervalCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === new CalendarInterval(i, i, i).toString)
}
}
@@ -237,15 +234,15 @@ class RowSetSuite extends KyuubiFunSuite {
assert(r6.get(9).getStringVal.getValue === "2018-11-06")

val r7 = iter.next().getColVals
assert(r7.get(10).getStringVal.getValue === "2018-11-17 13:33:33.600")
assert(r7.get(10).getStringVal.getValue === "2018-11-17 13:33:33.6")
assert(r7.get(11).getStringVal.getValue === new String(
Array.fill[Byte](6)(6.toByte),
StandardCharsets.UTF_8))

val r8 = iter.next().getColVals
assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
assert(r8.get(13).getStringVal.getValue ===
toHiveString((Map(7 -> 7.7d), MapType(IntegerType, DoubleType)), zoneId))
RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType)))

val r9 = iter.next().getColVals
assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)