[SPARK-35839][SQL] New SQL function: to_timestamp_ntz #32995

Closed
wants to merge 17 commits
@@ -541,6 +541,7 @@ object FunctionRegistry {
expression[ParseToDate]("to_date"),
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[ParseToTimestampWithoutTZ]("to_timestamp_ntz"),
expression[TruncDate]("trunc"),
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
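For orientation, the one-line registration above is what exposes the new expression to SQL by name. A minimal usage sketch, assuming a session built with this patch (the output shape follows the `ExpressionDescription` examples further down):

spark.sql("SELECT to_timestamp_ntz('2016-12-31 00:12:00')").show(truncate = false)
// +-------------------------------------+
// |to_timestamp_ntz(2016-12-31 00:12:00)|
// +-------------------------------------+
// |2016-12-31 00:12:00                  |
// +-------------------------------------+
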
@@ -66,6 +66,10 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression {

protected def isParsing: Boolean

// Whether the timestamp formatter is for TimestampWithoutTZType.
// If yes, the formatter is always `Iso8601TimestampFormatter`.
protected def forTimestampWithoutTZ: Boolean = false

@transient final protected lazy val formatterOption: Option[TimestampFormatter] =
if (formatString.foldable) {
Option(formatString.eval()).map(fmt => getFormatter(fmt.toString))
@@ -76,7 +80,8 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression {
format = fmt,
zoneId = zoneId,
legacyFormat = SIMPLE_DATE_FORMAT,
isParsing = isParsing)
isParsing = isParsing,
forTimestampWithoutTZ = forTimestampWithoutTZ)
}
}

@@ -995,6 +1000,77 @@ case class UnixTimestamp(
copy(timeExp = newLeft, format = newRight)
}

case class GetTimestampWithoutTZ(
left: Expression,
right: Expression,
timeZoneId: Option[String] = None,
failOnError: Boolean = SQLConf.get.ansiEnabled) extends ToTimestamp {

override val forTimestampWithoutTZ: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

override def dataType: DataType = TimestampWithoutTZType

override protected def downScaleFactor: Long = 1

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Some(timeZoneId))

override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression): Expression =
copy(left = newLeft, right = newRight)
}
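The notable override here is `downScaleFactor = 1`: other `ToTimestamp` implementations such as `UnixTimestamp` divide the parsed microseconds down to seconds, while this expression keeps microsecond precision for `TimestampWithoutTZType`. A sketch of the arithmetic (the literal assumes parsing '2016-12-31 00:12:00' against the local epoch):

val parsedMicros = 1483143120000000L      // what GetTimestampWithoutTZ returns, unscaled
val unixSeconds = parsedMicros / 1000000L // what a seconds-based ToTimestamp would yield
// unixSeconds == 1483143120L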


/**
* Parses a column to a timestamp without time zone based on the supplied format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(timestamp_str[, fmt]) - Parses the `timestamp_str` expression with the `fmt` expression
to a timestamp without time zone. Returns null with invalid input. By default, it follows casting rules to
a timestamp if the `fmt` is omitted.
""",
arguments = """
Arguments:
* timestamp_str - A string to be parsed to timestamp without time zone.
* fmt - Timestamp format pattern to follow. See <a href="https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html">Datetime Patterns</a> for valid
date and time format patterns.
""",
examples = """
Examples:
> SELECT _FUNC_('2016-12-31 00:12:00');
2016-12-31 00:12:00
> SELECT _FUNC_('2016-12-31', 'yyyy-MM-dd');
2016-12-31 00:00:00
""",
group = "datetime_funcs",
since = "3.2.0")
// scalastyle:on line.size.limit
case class ParseToTimestampWithoutTZ(
left: Expression,
format: Option[Expression],
child: Expression) extends RuntimeReplaceable {

def this(left: Expression, format: Expression) = {
this(left, Option(format), GetTimestampWithoutTZ(left, format))
}

def this(left: Expression) = this(left, None, Cast(left, TimestampWithoutTZType))

override def flatArguments: Iterator[Any] = Iterator(left, format)
override def exprsReplaced: Seq[Expression] = left +: format.toSeq

override def prettyName: String = "to_timestamp_ntz"
override def dataType: DataType = TimestampWithoutTZType

override protected def withNewChildInternal(newChild: Expression): ParseToTimestampWithoutTZ =
copy(child = newChild)
}
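Since this is a `RuntimeReplaceable`, the expression never evaluates itself; the analyzer substitutes `child`, and the two auxiliary constructors decide what that child is. A resolution sketch (the `Literal` inputs are illustrative):

// to_timestamp_ntz(str, fmt) resolves to a format-driven parse:
val withFmt = new ParseToTimestampWithoutTZ(Literal("2016-12-31"), Literal("yyyy-MM-dd"))
// withFmt.child == GetTimestampWithoutTZ(Literal("2016-12-31"), Literal("yyyy-MM-dd"))

// to_timestamp_ntz(str) falls back to casting rules:
val noFmt = new ParseToTimestampWithoutTZ(Literal("2016-12-31 00:12:00"))
// noFmt.child == Cast(Literal("2016-12-31 00:12:00"), TimestampWithoutTZType)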

abstract class ToTimestamp
extends BinaryExpression with TimestampFormatterHelper with ExpectsInputTypes {

@@ -1037,7 +1113,11 @@ abstract class ToTimestamp
} else {
val formatter = formatterOption.getOrElse(getFormatter(fmt.toString))
try {
formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor
if (forTimestampWithoutTZ) {
formatter.parseWithoutTimeZone(t.asInstanceOf[UTF8String].toString)
} else {
formatter.parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor
}
} catch {
case e if isParseError(e) =>
if (failOnError) {
@@ -1054,14 +1134,25 @@ abstract class ToTimestamp
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
val parseErrorBranch = if (failOnError) "throw e;" else s"${ev.isNull} = true;"
val parseMethod = if (forTimestampWithoutTZ) {
"parseWithoutTimeZone"
} else {
"parse"
}
val downScaleCode = if (forTimestampWithoutTZ) {
""
} else {
s"/ $downScaleFactor"
}

left.dataType match {
case StringType => formatterOption.map { fmt =>
val df = classOf[TimestampFormatter].getName
val formatterName = ctx.addReferenceObj("formatter", fmt, df)
nullSafeCodeGen(ctx, ev, (datetimeStr, _) =>
s"""
|try {
| ${ev.value} = $formatterName.parse($datetimeStr.toString()) / $downScaleFactor;
| ${ev.value} = $formatterName.$parseMethod($datetimeStr.toString()) $downScaleCode;
|} catch (java.time.DateTimeException e) {
| $parseErrorBranch
|} catch (java.time.format.DateTimeParseException e) {
@@ -1083,7 +1174,7 @@ abstract class ToTimestamp
| $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT(),
| true);
|try {
| ${ev.value} = $timestampFormatter.parse($string.toString()) / $downScaleFactor;
| ${ev.value} = $timestampFormatter.$parseMethod($string.toString()) $downScaleCode;
|} catch (java.time.format.DateTimeParseException e) {
| $parseErrorBranch
|} catch (java.time.DateTimeException e) {
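`parseErrorBranch` is what gives the function its two error modes: when `failOnError` is set (it defaults to `spark.sql.ansi.enabled`) the parse exception propagates, otherwise the result becomes null. A behavioral sketch, assuming a session with this patch:

spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT to_timestamp_ntz('not a timestamp')").collect() // Array([null])

spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT to_timestamp_ntz('not a timestamp')").collect() // throws at runtime
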
@@ -73,7 +73,7 @@ trait DateTimeFormatterHelper {
}
}

private def toLocalTime(accessor: TemporalAccessor): LocalTime = {
protected def toLocalTime(accessor: TemporalAccessor): LocalTime = {
val localTime = accessor.query(TemporalQueries.localTime())
// If all the time fields are specified, return the local time directly.
if (localTime != null) return localTime
@@ -50,9 +50,36 @@ sealed trait TimestampFormatter extends Serializable {
@throws(classOf[DateTimeException])
def parse(s: String): Long

/**
* Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time.
*
* @param s - string with timestamp to parse
* @return microseconds since epoch.
* @throws ParseException can be thrown by legacy parser
* @throws DateTimeParseException can be thrown by new parser
* @throws DateTimeException unable to obtain local date or time
* @throws IllegalStateException if this default implementation is reached. Formatters for
*   timestamp without time zone must override this method, so hitting it indicates a bug.
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@throws(classOf[DateTimeException])
@throws(classOf[IllegalStateException])
def parseWithoutTimeZone(s: String): Long =
throw new IllegalStateException(
s"The method `parseWithoutTimeZone(s: String)` should be implemented in the formatter " +
"of timestamp without time zone")

def format(us: Long): String
def format(ts: Timestamp): String
def format(instant: Instant): String

@throws(classOf[IllegalStateException])
def format(localDateTime: LocalDateTime): String =
throw new IllegalStateException(
s"The method `format(localDateTime: LocalDateTime)` should be implemented in the formatter " +
"of timestamp without time zone")

def validatePatternString(): Unit
}
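The contract "microseconds since Unix Epoch in local time" means the wall-clock fields are measured against 1970-01-01T00:00:00 with no zone conversion applied, which is presumably what `DateTimeUtils.localDateTimeToMicros` computes below. A self-contained sketch with plain `java.time`:

import java.time.LocalDateTime
import java.time.temporal.ChronoUnit

val ldt = LocalDateTime.parse("2016-12-31T00:12:00")
val epoch = LocalDateTime.of(1970, 1, 1, 0, 0)
val micros = ChronoUnit.MICROS.between(epoch, ldt)
// micros == 1483143120000000L, regardless of the JVM or session time zone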

@@ -84,6 +111,15 @@ class Iso8601TimestampFormatter(
} catch checkParsedDiff(s, legacyFormatter.parse)
}

override def parseWithoutTimeZone(s: String): Long = {
try {
val parsed = formatter.parse(s)
val localDate = toLocalDate(parsed)
val localTime = toLocalTime(parsed)
DateTimeUtils.localDateTimeToMicros(LocalDateTime.of(localDate, localTime))
} catch checkParsedDiff(s, legacyFormatter.parse)
}

override def format(instant: Instant): String = {
try {
formatter.withZone(zoneId).format(instant)
@@ -100,6 +136,10 @@ class Iso8601TimestampFormatter(
legacyFormatter.format(ts)
}

override def format(localDateTime: LocalDateTime): String = {
localDateTime.format(formatter)
}

override def validatePatternString(): Unit = {
try {
formatter
@@ -286,9 +326,10 @@ object TimestampFormatter {
zoneId: ZoneId,
locale: Locale = defaultLocale,
legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT,
isParsing: Boolean): TimestampFormatter = {
isParsing: Boolean,
forTimestampWithoutTZ: Boolean = false): TimestampFormatter = {
val pattern = format.getOrElse(defaultPattern)
val formatter = if (SQLConf.get.legacyTimeParserPolicy == LEGACY) {
val formatter = if (SQLConf.get.legacyTimeParserPolicy == LEGACY && !forTimestampWithoutTZ) {
getLegacyFormatter(pattern, zoneId, locale, legacyFormat)
} else {
new Iso8601TimestampFormatter(
@@ -330,6 +371,16 @@ object TimestampFormatter {
getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing)
}

def apply(
format: String,
zoneId: ZoneId,
legacyFormat: LegacyDateFormat,
isParsing: Boolean,
forTimestampWithoutTZ: Boolean): TimestampFormatter = {
getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing,
forTimestampWithoutTZ)
}

def apply(
format: String,
zoneId: ZoneId,
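Note the interplay with the legacy parser policy in `getFormatter` above: even under `spark.sql.legacy.timeParserPolicy=LEGACY`, the `forTimestampWithoutTZ` flag forces the ISO-8601 formatter, since the `SimpleDateFormat`-based legacy path has no zone-free parse. A usage sketch of the new overload (the UTC zone and pattern are illustrative):

import java.time.ZoneId
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT

val formatter = TimestampFormatter(
  "yyyy-MM-dd HH:mm:ss",
  ZoneId.of("UTC"),
  legacyFormat = SIMPLE_DATE_FORMAT,
  isParsing = true,
  forTimestampWithoutTZ = true)
formatter.parseWithoutTimeZone("2016-12-31 00:12:00") // microseconds in local time
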
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.text.{ParseException, SimpleDateFormat}
import java.time.{DateTimeException, Duration, Instant, LocalDate, Period, ZoneId}
import java.time.{DateTimeException, Duration, Instant, LocalDate, LocalDateTime, Period, ZoneId}
import java.time.format.DateTimeParseException
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}
@@ -1283,6 +1283,33 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("to_timestamp_ntz") {
val specialTs = Seq(
"0001-01-01T00:00:00", // the fist timestamp of Common Era
"1582-10-15T23:59:59", // the cutover date from Julian to Gregorian calendar
"1970-01-01T00:00:00", // the epoch timestamp
"9999-12-31T23:59:59" // the last supported timestamp according to SQL standard
)
outstandingZoneIds.foreach { zoneId =>
withDefaultTimeZone(zoneId) {
specialTs.foreach { s =>
val input = s.replace("T", " ")
val expectedTs = LocalDateTime.parse(s)
checkEvaluation(
GetTimestampWithoutTZ(Literal(input), Literal("yyyy-MM-dd HH:mm:ss")), expectedTs)
Seq(".123456", ".123456PST", ".123456CST", ".123456UTC").foreach { segment =>
val input2 = input + segment
val expectedTs2 = LocalDateTime.parse(s + ".123456")
checkEvaluation(
GetTimestampWithoutTZ(Literal(input2), Literal("yyyy-MM-dd HH:mm:ss.SSSSSS[zzz]")),
expectedTs2)
}
}
}
}
}

test("to_timestamp exception mode") {
withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> "legacy") {
checkEvaluation(
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution

import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Duration, Instant, LocalDate, Period}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
@@ -101,6 +101,7 @@ object HiveResult {
case (ld: LocalDate, DateType) => formatters.date.format(ld)
case (t: Timestamp, TimestampType) => formatters.timestamp.format(t)
case (i: Instant, TimestampType) => formatters.timestamp.format(i)
case (l: LocalDateTime, TimestampWithoutTZType) => formatters.timestamp.format(l)
case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString
case (n, _: NumericType) => n.toString
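With the new match arm, a `LocalDateTime` row value is rendered through the `format(localDateTime)` overload added to `Iso8601TimestampFormatter`, i.e. the pattern is applied with no zone shift. Reduced to plain `java.time` (Spark's default pattern is assumed):

import java.time.LocalDateTime
import java.time.format.DateTimeFormatter

val rendered = LocalDateTime.parse("2016-12-31T00:12:00")
  .format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))
// rendered == "2016-12-31 00:12:00"
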
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- Number of queries: 355
- Number of queries: 356
- Number of expressions that missing example: 13
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window
## Schema of Built-in Functions
@@ -206,6 +206,7 @@
| org.apache.spark.sql.catalyst.expressions.Overlay | overlay | SELECT overlay('Spark SQL' PLACING '_' FROM 6) | struct<overlay(Spark SQL, _, 6, -1):string> |
| org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT to_date('2009-07-30 04:17:52') | struct<to_date(2009-07-30 04:17:52):date> |
| org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') | struct<to_timestamp(2016-12-31 00:12:00):timestamp> |
| org.apache.spark.sql.catalyst.expressions.ParseToTimestampWithoutTZ | to_timestamp_ntz | SELECT to_timestamp_ntz('2016-12-31 00:12:00') | struct<to_timestamp_ntz(2016-12-31 00:12:00):timestamp without time zone> |
| org.apache.spark.sql.catalyst.expressions.ParseUrl | parse_url | SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST') | struct<parse_url(http://spark.apache.org/path?query=1, HOST):string> |
| org.apache.spark.sql.catalyst.expressions.PercentRank | percent_rank | SELECT a, b, percent_rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):double> |
| org.apache.spark.sql.catalyst.expressions.Pi | pi | SELECT pi() | struct<PI():double> |
@@ -360,4 +361,4 @@
| org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |