[KYUUBI #4745] Support Flink's LocalZonedTimestamp DataType
### _Why are the changes needed?_

Flink's `TIMESTAMP_LTZ` (`LocalZonedTimestampType`) was previously rejected by the Flink engine with "Flink data type ... is not supported currently". This change maps it to `TTypeId.TIMESTAMPLOCALTZ_TYPE` and renders values as strings in the session time zone, read from `table.local-time-zone` and falling back to the JVM default when unset.
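From a JDBC client's point of view, a minimal sketch of the resulting behavior (it mirrors the engine test added below; the connection URL is hypothetical and assumes the Kyuubi Hive JDBC driver is on the classpath):

```scala
import java.sql.DriverManager

// Hypothetical endpoint; any Kyuubi server fronting a Flink engine behaves the same.
val conn = DriverManager.getConnection("jdbc:hive2://localhost:10009/")
val stmt = conn.createStatement()
stmt.execute("SET 'table.local-time-zone' = 'UTC'")
val rs = stmt.executeQuery("SELECT TO_TIMESTAMP_LTZ(4001, 3)")
// TIMESTAMP_LTZ columns surface as java.sql.Types.OTHER and read back as strings.
assert(rs.getMetaData.getColumnType(1) == java.sql.Types.OTHER)
rs.next()
println(rs.getString(1)) // prints: 1970-01-01 00:00:04.001 UTC
```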
### _How was this patch tested?_
- [X] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [X] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4751 from link3280/KYUUBI-4745.

Closes #4745

e1e900b [Paul Lin] [KYUUBI #4745] Replace hive's timestamp format with the kyuubi's
0693d1f [Paul Lin] [KYUUBI #4745] Pin time zone in tests
462b39f [Paul Lin] [KYUUBI #4745] Improve variable naming
5f9976d [Paul Lin] [KYUUBI #4745] Support Flink's LocalZonedTimestamp DataType

Authored-by: Paul Lin <paullin3280@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
(cherry picked from commit 79d6645)
Signed-off-by: Cheng Pan <chengpan@apache.org>
link3280 authored and pan3793 committed Apr 25, 2023
1 parent 044fc51 commit 9e295f9
Showing 4 changed files with 84 additions and 14 deletions.
@@ -18,6 +18,7 @@
package org.apache.kyuubi.engine.flink.operation

import java.io.IOException
import java.time.ZoneId

import scala.collection.JavaConverters.collectionAsScalaIterableConverter

@@ -95,9 +96,15 @@ abstract class FlinkOperation(session: Session) extends AbstractOperation(session)
case FETCH_FIRST => resultSet.getData.fetchAbsolute(0)
}
val token = resultSet.getData.take(rowSetSize)
val timeZone = Option(flinkSession.getSessionConfig.get("table.local-time-zone"))
val zoneId = timeZone match {
case Some(tz) => ZoneId.of(tz)
case None => ZoneId.systemDefault()
}
val resultRowSet = RowSet.resultSetToTRowSet(
token.toList,
resultSet,
zoneId,
getProtocolVersion)
resultRowSet.setStartRowOffset(resultSet.getData.getPosition)
resultRowSet
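The zone lookup above prefers the session's `table.local-time-zone` and falls back to the JVM default when the key is absent. An equivalent, more compact form of the same fallback (a sketch, not what the patch uses):

```scala
import java.time.ZoneId

// Hypothetical helper; `sessionConfig` stands in for flinkSession.getSessionConfig,
// which may not contain the key at all (hence the Option wrapping of a nullable get).
def resolveZoneId(sessionConfig: java.util.Map[String, String]): ZoneId =
  Option(sessionConfig.get("table.local-time-zone"))
    .map(tz => ZoneId.of(tz))
    .getOrElse(ZoneId.systemDefault())
```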
@@ -21,7 +21,9 @@ import java.{lang, util}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{LocalDate, LocalDateTime}
import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime, ZoneId}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, TextStyle}
import java.time.temporal.ChronoField
import java.util.Collections

import scala.collection.JavaConverters._
@@ -42,15 +44,16 @@ object RowSet {
def resultSetToTRowSet(
rows: Seq[Row],
resultSet: ResultSet,
zoneId: ZoneId,
protocolVersion: TProtocolVersion): TRowSet = {
if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
toRowBaseSet(rows, resultSet)
toRowBaseSet(rows, resultSet, zoneId)
} else {
toColumnBasedSet(rows, resultSet)
toColumnBasedSet(rows, resultSet, zoneId)
}
}

def toRowBaseSet(rows: Seq[Row], resultSet: ResultSet): TRowSet = {
def toRowBaseSet(rows: Seq[Row], resultSet: ResultSet, zoneId: ZoneId): TRowSet = {
val rowSize = rows.size
val tRows = new util.ArrayList[TRow](rowSize)
var i = 0
@@ -60,7 +63,7 @@ object RowSet {
val columnSize = row.getArity
var j = 0
while (j < columnSize) {
val columnValue = toTColumnValue(j, row, resultSet)
val columnValue = toTColumnValue(j, row, resultSet, zoneId)
tRow.addToColVals(columnValue)
j += 1
}
@@ -71,14 +74,14 @@ object RowSet {
new TRowSet(0, tRows)
}

def toColumnBasedSet(rows: Seq[Row], resultSet: ResultSet): TRowSet = {
def toColumnBasedSet(rows: Seq[Row], resultSet: ResultSet, zoneId: ZoneId): TRowSet = {
val size = rows.length
val tRowSet = new TRowSet(0, new util.ArrayList[TRow](size))
val columnSize = resultSet.getColumns.size()
var i = 0
while (i < columnSize) {
val field = resultSet.getColumns.get(i)
val tColumn = toTColumn(rows, i, field.getDataType.getLogicalType)
val tColumn = toTColumn(rows, i, field.getDataType.getLogicalType, zoneId)
tRowSet.addToColumns(tColumn)
i += 1
}
@@ -88,7 +91,8 @@ object RowSet {
private def toTColumnValue(
ordinal: Int,
row: Row,
resultSet: ResultSet): TColumnValue = {
resultSet: ResultSet,
zoneId: ZoneId): TColumnValue = {

val column = resultSet.getColumns.get(ordinal)
val logicalType = column.getDataType.getLogicalType
@@ -153,6 +157,12 @@ object RowSet {
s"for type ${t.getClass}.")
}
TColumnValue.stringVal(tStringValue)
case _: LocalZonedTimestampType =>
val tStringValue = new TStringValue
val fieldValue = row.getField(ordinal)
tStringValue.setValue(TIMESTAMP_LZT_FORMATTER.format(
ZonedDateTime.ofInstant(fieldValue.asInstanceOf[Instant], zoneId)))
TColumnValue.stringVal(tStringValue)
case t =>
val tStringValue = new TStringValue
if (row.getField(ordinal) != null) {
@@ -166,7 +176,11 @@ object RowSet {
ByteBuffer.wrap(bitSet.toByteArray)
}

private def toTColumn(rows: Seq[Row], ordinal: Int, logicalType: LogicalType): TColumn = {
private def toTColumn(
rows: Seq[Row],
ordinal: Int,
logicalType: LogicalType,
zoneId: ZoneId): TColumn = {
val nulls = new java.util.BitSet()
// for each column, determine the conversion class by sampling the first non-null value
// if there are no rows, set the entire column empty
@@ -211,6 +225,12 @@ object RowSet {
s"for type ${t.getClass}.")
}
TColumn.stringVal(new TStringColumn(values, nulls))
case _: LocalZonedTimestampType =>
val values = getOrSetAsNull[Instant](rows, ordinal, nulls, Instant.EPOCH)
.toArray().map(v =>
TIMESTAMP_LZT_FORMATTER.format(
ZonedDateTime.ofInstant(v.asInstanceOf[Instant], zoneId)))
TColumn.stringVal(new TStringColumn(values.toList.asJava, nulls))
case _ =>
var i = 0
val rowSize = rows.length
@@ -303,13 +323,14 @@ object RowSet {
case _: DecimalType => TTypeId.DECIMAL_TYPE
case _: DateType => TTypeId.DATE_TYPE
case _: TimestampType => TTypeId.TIMESTAMP_TYPE
case _: LocalZonedTimestampType => TTypeId.TIMESTAMPLOCALTZ_TYPE
case _: ArrayType => TTypeId.ARRAY_TYPE
case _: MapType => TTypeId.MAP_TYPE
case _: RowType => TTypeId.STRUCT_TYPE
case _: BinaryType => TTypeId.BINARY_TYPE
case _: VarBinaryType => TTypeId.BINARY_TYPE
case _: TimeType => TTypeId.STRING_TYPE
case t @ (_: ZonedTimestampType | _: LocalZonedTimestampType | _: MultisetType |
case t @ (_: ZonedTimestampType | _: MultisetType |
_: YearMonthIntervalType | _: DayTimeIntervalType) =>
throw new IllegalArgumentException(
"Flink data type `%s` is not supported currently".format(t.asSummaryString()),
@@ -377,4 +398,26 @@ object RowSet {
other.toString
}
}

/** Should stay in sync with org.apache.kyuubi.jdbc.hive.common.TimestampTZUtil. */
val TIMESTAMP_LZT_FORMATTER: DateTimeFormatter = {
val builder = new DateTimeFormatterBuilder
// Date part
builder.append(DateTimeFormatter.ofPattern("yyyy-MM-dd"))
// Time part
builder
.optionalStart
.appendLiteral(" ")
.append(DateTimeFormatter.ofPattern("HH:mm:ss"))
.optionalStart
.appendFraction(ChronoField.NANO_OF_SECOND, 1, 9, true)
.optionalEnd
.optionalEnd

// Zone part
builder.optionalStart.appendLiteral(" ").optionalEnd
builder.optionalStart.appendZoneText(TextStyle.NARROW).optionalEnd

builder.toFormatter
}
}
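For reference, a REPL-style sketch of what this formatter produces (the values mirror the engine test expectations below):

```scala
import java.time.{Instant, ZoneId, ZonedDateTime}

val instant = Instant.ofEpochMilli(4001L) // 1970-01-01T00:00:04.001Z

RowSet.TIMESTAMP_LZT_FORMATTER.format(
  ZonedDateTime.ofInstant(instant, ZoneId.of("UTC")))
// res0: String = 1970-01-01 00:00:04.001 UTC

RowSet.TIMESTAMP_LZT_FORMATTER.format(
  ZonedDateTime.ofInstant(instant, ZoneId.of("America/Los_Angeles")))
// res1: String = 1969-12-31 16:00:04.001 America/Los_Angeles
```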
@@ -756,6 +756,23 @@ class FlinkOperationSuite extends WithFlinkSQLEngine with HiveJDBCTestHelper {
}
}

test("execute statement - select timestamp with local time zone") {
withJdbcStatement() { statement =>
statement.executeQuery("CREATE VIEW T1 AS SELECT TO_TIMESTAMP_LTZ(4001, 3)")
statement.executeQuery("SET 'table.local-time-zone' = 'UTC'")
val resultSetUTC = statement.executeQuery("SELECT * FROM T1")
val metaData = resultSetUTC.getMetaData
assert(metaData.getColumnType(1) === java.sql.Types.OTHER)
assert(resultSetUTC.next())
assert(resultSetUTC.getString(1) === "1970-01-01 00:00:04.001 UTC")

statement.executeQuery("SET 'table.local-time-zone' = 'America/Los_Angeles'")
val resultSetPST = statement.executeQuery("SELECT * FROM T1")
assert(resultSetPST.next())
assert(resultSetPST.getString(1) === "1969-12-31 16:00:04.001 America/Los_Angeles")
}
}

test("execute statement - select time") {
withJdbcStatement() { statement =>
val resultSet =
@@ -17,6 +17,8 @@

package org.apache.kyuubi.engine.flink.result

import java.time.ZoneId

import org.apache.flink.table.api.{DataTypes, ResultKind}
import org.apache.flink.table.catalog.Column
import org.apache.flink.table.data.StringData
@@ -44,9 +46,10 @@ class ResultSetSuite extends KyuubiFunSuite {
.data(rowsNew)
.build

assert(RowSet.toRowBaseSet(rowsNew, resultSetNew)
=== RowSet.toRowBaseSet(rowsOld, resultSetOld))
assert(RowSet.toColumnBasedSet(rowsNew, resultSetNew)
=== RowSet.toColumnBasedSet(rowsOld, resultSetOld))
val timeZone = ZoneId.of("America/Los_Angeles")
assert(RowSet.toRowBaseSet(rowsNew, resultSetNew, timeZone)
=== RowSet.toRowBaseSet(rowsOld, resultSetOld, timeZone))
assert(RowSet.toColumnBasedSet(rowsNew, resultSetNew, timeZone)
=== RowSet.toColumnBasedSet(rowsOld, resultSetOld, timeZone))
}
}
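Pinning the zone to `America/Los_Angeles` here keeps the row-based and column-based conversions comparable without depending on the CI host's default time zone (see commit 0693d1f, "Pin time zone in tests", above).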
