Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SPARK-29155][SQL] Support special date/timestamp values in the PostgreSQL dialect only #25834

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private lazy val dateFormatter = DateFormatter(zoneId)
private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
private val failOnIntegralTypeOverflow = SQLConf.get.ansiEnabled
private val supportSpecialValues = SQLConf.get.isPostgreSqlDialect

// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
Expand Down Expand Up @@ -423,8 +424,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull)
case StringType => buildCast[UTF8String](_, utfs =>
DateTimeUtils.stringToTimestamp(utfs, zoneId, supportSpecialValues).orNull)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0)
case LongType =>
Expand Down Expand Up @@ -468,8 +469,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

// DateConverter
private[this] def castToDate(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s, zoneId).orNull)
case StringType => buildCast[UTF8String](_, s =>
DateTimeUtils.stringToDate(s, zoneId, supportSpecialValues).orNull)
case TimestampType =>
// throw valid precision more than seconds, according to Hive.
// Timestamp.nanos is in 0 to 999,999,999, no more than a second.
Expand Down Expand Up @@ -1067,10 +1068,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case StringType =>
val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]])
val zid = getZoneId()
val sv = supportSpecialValues.toString
(c, evPrim, evNull) =>
code"""
scala.Option<Integer> $intOpt =
org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c, $zid);
org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c, $zid, $sv);
if ($intOpt.isDefined()) {
$evPrim = ((Integer) $intOpt.get()).intValue();
} else {
Expand Down Expand Up @@ -1181,11 +1183,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val zid = JavaCode.global(
ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
zoneIdClass)
val sv = supportSpecialValues.toString
val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
(c, evPrim, evNull) =>
code"""
scala.Option<Long> $longOpt =
org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $zid);
org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $zid, $sv);
MaxGekk marked this conversation as resolved.
Show resolved Hide resolved
if ($longOpt.isDefined()) {
$evPrim = ((Long) $longOpt.get()).longValue();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1732,13 +1732,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
throw new ParseException(s"Cannot parse the $valueType value: $value", ctx)
}
}
def supportSpecialValues() = SQLConf.get.isPostgreSqlDialect
def zoneId() = getZoneId(SQLConf.get.sessionLocalTimeZone)
try {
valueType match {
case "DATE" =>
toLiteral(stringToDate(_, getZoneId(SQLConf.get.sessionLocalTimeZone)), DateType)
toLiteral(stringToDate(_, zoneId(), supportSpecialValues()), DateType)
case "TIMESTAMP" =>
val zoneId = getZoneId(SQLConf.get.sessionLocalTimeZone)
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
toLiteral(stringToTimestamp(_, zoneId(), supportSpecialValues()), TimestampType)
case "INTERVAL" =>
Literal(CalendarInterval.fromString(value), CalendarIntervalType)
case "X" =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class Iso8601DateFormatter(
private lazy val formatter = getOrCreateFormatter(pattern, locale)

override def parse(s: String): Int = {
val specialDate = convertSpecialDate(s.trim, zoneId)
val specialDate = if (supportSpecialValues) {
convertSpecialDate(s.trim, zoneId)
} else None
specialDate.getOrElse {
val localDate = LocalDate.parse(s, formatter)
localDateToDays(localDate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ import java.util.Locale
import com.google.common.cache.CacheBuilder

import org.apache.spark.sql.catalyst.util.DateTimeFormatterHelper._
import org.apache.spark.sql.internal.SQLConf

trait DateTimeFormatterHelper {
protected val supportSpecialValues: Boolean = SQLConf.get.isPostgreSqlDialect

// Converts the parsed temporal object to ZonedDateTime. It sets time components to zeros
// if they does not exist in the parsed object.
protected def toZonedDateTime(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ object DateTimeUtils {
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m`
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m`
*/
def stringToTimestamp(s: UTF8String, timeZoneId: ZoneId): Option[SQLTimestamp] = {
def stringToTimestamp(
s: UTF8String,
timeZoneId: ZoneId,
supportSpecialValues: Boolean): Option[SQLTimestamp] = {
if (s == null) {
return None
}
Expand All @@ -219,8 +222,10 @@ object DateTimeUtils {
var i = 0
var currentSegmentValue = 0
val bytes = s.trim.getBytes
val specialTimestamp = convertSpecialTimestamp(bytes, timeZoneId)
if (specialTimestamp.isDefined) return specialTimestamp
if (supportSpecialValues) {
val specialTimestamp = convertSpecialTimestamp(bytes, timeZoneId)
if (specialTimestamp.isDefined) return specialTimestamp
}
var j = 0
var digitsMilli = 0
var justTime = false
Expand Down Expand Up @@ -378,16 +383,21 @@ object DateTimeUtils {
* `yyyy-[m]m-[d]d *`
* `yyyy-[m]m-[d]dT*`
*/
def stringToDate(s: UTF8String, zoneId: ZoneId): Option[SQLDate] = {
def stringToDate(
s: UTF8String,
zoneId: ZoneId,
supportSpecialValues: Boolean): Option[SQLDate] = {
if (s == null) {
return None
}
val segments: Array[Int] = Array[Int](1, 1, 1)
var i = 0
var currentSegmentValue = 0
val bytes = s.trim.getBytes
val specialDate = convertSpecialDate(bytes, zoneId)
if (specialDate.isDefined) return specialDate
if (supportSpecialValues) {
val specialDate = convertSpecialDate(bytes, zoneId)
if (specialDate.isDefined) return specialDate
}
var j = 0
while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) {
val b = bytes(j)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import java.util.concurrent.TimeUnit.SECONDS

import DateTimeUtils.convertSpecialTimestamp

import org.apache.spark.sql.internal.SQLConf

sealed trait TimestampFormatter extends Serializable {
/**
* Parses a timestamp in a string and converts it to microseconds.
Expand All @@ -52,8 +54,10 @@ class Iso8601TimestampFormatter(
protected lazy val formatter = getOrCreateFormatter(pattern, locale)

override def parse(s: String): Long = {
val specialDate = convertSpecialTimestamp(s.trim, zoneId)
specialDate.getOrElse {
val specialTimestamp = if (supportSpecialValues) {
convertSpecialTimestamp(s.trim, zoneId)
} else None
specialTimestamp.getOrElse {
val parsed = formatter.parse(s)
val parsedZoneId = parsed.query(TemporalQueries.zone())
val timeZoneId = if (parsedZoneId == null) zoneId else parsedZoneId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,24 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

object Dialect extends Enumeration {
val SPARK, POSTGRESQL = Value
}

val DIALECT =
buildConf("spark.sql.dialect")
.doc("The specific features of the SQL language to be adopted, which are available when " +
"accessing the given database. Currently, Spark supports two database dialects, `Spark` " +
"and `PostgreSQL`. With `PostgreSQL` dialect, Spark will: " +
"1. perform integral division with the / operator if both sides are integral types; " +
"2. accept \"true\", \"yes\", \"1\", \"false\", \"no\", \"0\", and unique prefixes as " +
"input and trim input for the boolean data type; " +
"3. support special date/timestamp values: epoch, now, today, yesterday and tomorrow.")
.stringConf
.transform(_.toUpperCase(Locale.ROOT))
.checkValues(Dialect.values.map(_.toString))
.createWithDefault(Dialect.SPARK.toString)

val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision")
.internal()
.doc("When true, will perform integral division with the / operator " +
Expand Down Expand Up @@ -2194,6 +2212,8 @@ class SQLConf extends Serializable with Logging {

def utcTimestampFuncEnabled: Boolean = getConf(UTC_TIMESTAMP_FUNC_ENABLED)

def isPostgreSqlDialect: Boolean = getConf(SQLConf.DIALECT) == Dialect.POSTGRESQL.toString

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1200,4 +1200,17 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(Long.MinValue - 0.9D, LongType), Long.MinValue)
}
}

test("cast special timestamp and date value") {
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.POSTGRESQL.toString) {
checkEvaluation(cast(Literal("epoch"), TimestampType, Option("UTC")),
new Timestamp(0))
checkEvaluation(cast(Literal("epoch"), DateType, Option("UTC")),
Date.valueOf("1970-01-01"))
}
withSQLConf(SQLConf.DIALECT.key -> SQLConf.Dialect.SPARK.toString) {
checkEvaluation(cast(Literal("epoch"), TimestampType, Option("UTC")), null)
checkEvaluation(cast(Literal("epoch"), DateType, Option("UTC")), null)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("hive-hash for date type") {
def checkHiveHashForDateType(dateString: String, expected: Long): Unit = {
checkHiveHash(
DateTimeUtils.stringToDate(UTF8String.fromString(dateString), ZoneOffset.UTC).get,
DateTimeUtils.stringToDate(UTF8String.fromString(dateString), ZoneOffset.UTC, true).get,
DateType,
expected)
}
Expand Down Expand Up @@ -210,7 +210,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expected: Long,
zoneId: ZoneId = ZoneOffset.UTC): Unit = {
checkHiveHash(
DateTimeUtils.stringToTimestamp(UTF8String.fromString(timestamp), zoneId).get,
DateTimeUtils.stringToTimestamp(UTF8String.fromString(timestamp), zoneId, true).get,
TimestampType,
expected)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
val tf = TimestampFormatter.getFractionFormatter(DateTimeUtils.defaultTimeZone.toZoneId)
def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) {
val parsedTimestampOp = DateTimeUtils.stringToTimestamp(
UTF8String.fromString(originalTime), defaultZoneId)
UTF8String.fromString(originalTime), defaultZoneId, true)
assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly")
assert(DateTimeUtils.timestampToString(tf, parsedTimestampOp.get) === expectedParsedTime)
}
Expand Down Expand Up @@ -120,8 +120,11 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
checkFromToJavaDate(new Date(df2.parse("1776-07-04 18:30:00 UTC").getTime))
}

private def toDate(s: String, zoneId: ZoneId = ZoneOffset.UTC): Option[SQLDate] = {
stringToDate(UTF8String.fromString(s), zoneId)
private def toDate(
s: String,
zoneId: ZoneId = ZoneOffset.UTC,
supportSpecialValues: Boolean = true): Option[SQLDate] = {
stringToDate(UTF8String.fromString(s), zoneId, supportSpecialValues)
}

test("string to date") {
Expand All @@ -148,8 +151,11 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
assert(toDate("1999 08").isEmpty)
}

private def toTimestamp(str: String, zoneId: ZoneId): Option[SQLTimestamp] = {
stringToTimestamp(UTF8String.fromString(str), zoneId)
private def toTimestamp(
str: String,
zoneId: ZoneId,
supportSpecialValues: Boolean = true): Option[SQLTimestamp] = {
stringToTimestamp(UTF8String.fromString(str), zoneId, supportSpecialValues)
}

test("string to timestamp") {
Expand Down Expand Up @@ -276,9 +282,9 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {

// Test stringToTimestamp
assert(stringToTimestamp(
UTF8String.fromString("2015-02-29 00:00:00"), defaultZoneId).isEmpty)
UTF8String.fromString("2015-02-29 00:00:00"), defaultZoneId, true).isEmpty)
assert(stringToTimestamp(
UTF8String.fromString("2015-04-31 00:00:00"), defaultZoneId).isEmpty)
UTF8String.fromString("2015-04-31 00:00:00"), defaultZoneId, true).isEmpty)
assert(toTimestamp("2015-02-29", defaultZoneId).isEmpty)
assert(toTimestamp("2015-04-31", defaultZoneId).isEmpty)
}
Expand Down Expand Up @@ -469,15 +475,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
}

val defaultInputTS = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-05T09:32:05.359123"), defaultZoneId)
UTF8String.fromString("2015-03-05T09:32:05.359123"), defaultZoneId, true)
val defaultInputTS1 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-31T20:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-31T20:32:05.359"), defaultZoneId, true)
val defaultInputTS2 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-04-01T02:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-04-01T02:32:05.359"), defaultZoneId, true)
val defaultInputTS3 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-30T02:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-30T02:32:05.359"), defaultZoneId, true)
val defaultInputTS4 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-29T02:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-29T02:32:05.359"), defaultZoneId, true)

testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", defaultInputTS.get)
testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", defaultInputTS.get)
Expand All @@ -502,17 +508,17 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
for (tz <- ALL_TIMEZONES) {
withDefaultTimeZone(tz) {
val inputTS = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-05T09:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-05T09:32:05.359"), defaultZoneId, true)
val inputTS1 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-31T20:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-31T20:32:05.359"), defaultZoneId, true)
val inputTS2 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-04-01T02:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-04-01T02:32:05.359"), defaultZoneId, true)
val inputTS3 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-30T02:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-30T02:32:05.359"), defaultZoneId, true)
val inputTS4 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2015-03-29T02:32:05.359"), defaultZoneId)
UTF8String.fromString("2015-03-29T02:32:05.359"), defaultZoneId, true)
val inputTS5 = DateTimeUtils.stringToTimestamp(
UTF8String.fromString("1999-03-29T01:02:03.456789"), defaultZoneId)
UTF8String.fromString("1999-03-29T01:02:03.456789"), defaultZoneId, true)

testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, tz)
testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, tz)
Expand Down Expand Up @@ -576,16 +582,20 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
DateTimeTestUtils.outstandingZoneIds.foreach { zoneId =>
val tolerance = TimeUnit.SECONDS.toMicros(30)

assert(toTimestamp("Epoch", zoneId).get === 0)
assert(toTimestamp("Epoch", zoneId, true).get === 0)
val now = instantToMicros(LocalDateTime.now(zoneId).atZone(zoneId).toInstant)
toTimestamp("NOW", zoneId).get should be (now +- tolerance)
toTimestamp("NOW", zoneId, true).get should be (now +- tolerance)
assert(toTimestamp("now UTC", zoneId) === None)
val today = instantToMicros(LocalDateTime.now(zoneId)
.`with`(LocalTime.MIDNIGHT)
.atZone(zoneId).toInstant)
toTimestamp(" Yesterday", zoneId).get should be (today - MICROS_PER_DAY +- tolerance)
toTimestamp("Today ", zoneId).get should be (today +- tolerance)
toTimestamp(" tomorrow CET ", zoneId).get should be (today + MICROS_PER_DAY +- tolerance)
toTimestamp(" Yesterday", zoneId, true).get should be (today - MICROS_PER_DAY +- tolerance)
toTimestamp("Today ", zoneId, true).get should be (today +- tolerance)
toTimestamp(" tomorrow CET ", zoneId, true).get should be
(today + MICROS_PER_DAY +- tolerance)

// It must return None when support of special values is disabled
assert(toTimestamp("Epoch", zoneId, false) === None)
}
}

Expand All @@ -598,6 +608,9 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers {
assert(toDate("now UTC", zoneId) === None) // "now" does not accept time zones
assert(toDate("today", zoneId).get === today)
assert(toDate("tomorrow CET ", zoneId).get === today + 1)

// It must return None when support of special values is disabled
assert(toDate("Epoch", zoneId, false) === None)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ class UnsafeArraySuite extends SparkFunSuite {
val doubleArray = Array(1.1, 2.2, 3.3)
val stringArray = Array("1", "10", "100")
val dateArray = Array(
DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1"), ZoneOffset.UTC).get,
DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26"), ZoneOffset.UTC).get)
DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1"), ZoneOffset.UTC, true).get,
DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26"), ZoneOffset.UTC, true).get)
private def defaultZoneId = ZoneId.systemDefault()
val timestampArray = Array(
DateTimeUtils.stringToTimestamp(
UTF8String.fromString("1970-1-1 00:00:00"), defaultZoneId).get,
UTF8String.fromString("1970-1-1 00:00:00"), defaultZoneId, true).get,
DateTimeUtils.stringToTimestamp(
UTF8String.fromString("2016-7-26 00:00:00"), defaultZoneId).get)
UTF8String.fromString("2016-7-26 00:00:00"), defaultZoneId, true).get)
val decimalArray4_1 = Array(
BigDecimal("123.4").setScale(1, BigDecimal.RoundingMode.FLOOR),
BigDecimal("567.8").setScale(1, BigDecimal.RoundingMode.FLOOR))
Expand Down
Loading