Skip to content

Commit

Permalink
[SPARK-22829] Add new built-in function date_trunc()
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Adding date_trunc() as a built-in function.
`date_trunc` is common in other databases, but Spark or Hive does not have support for this. `date_trunc` is commonly used by data scientists and business intelligence application such as Superset (https://github.com/apache/incubator-superset).
We do have `trunc` but this only works with 'MONTH' and 'YEAR' level on the DateType input.

date_trunc() in other databases:
AWS Redshift: http://docs.aws.amazon.com/redshift/latest/dg/r_DATE_TRUNC.html
PostgreSQL: https://www.postgresql.org/docs/9.1/static/functions-datetime.html
Presto: https://prestodb.io/docs/current/functions/datetime.html

## How was this patch tested?

Unit tests

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Youngbin Kim <ykim828@hotmail.com>

Closes #20015 from youngbink/date_trunc.
  • Loading branch information
youngbink authored and gatorsmile committed Dec 20, 2017
1 parent 3a7494d commit 6e36d8d
Show file tree
Hide file tree
Showing 8 changed files with 445 additions and 52 deletions.
20 changes: 19 additions & 1 deletion python/pyspark/sql/functions.py
Expand Up @@ -1099,7 +1099,7 @@ def trunc(date, format):
"""
Returns date truncated to the unit specified by the format.
:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
:param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm'
>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
Expand All @@ -1111,6 +1111,24 @@ def trunc(date, format):
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))


@since(2.3)
def date_trunc(format, timestamp):
"""
Returns timestamp truncated to the unit specified by the format.
:param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm',
'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter'
>>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
>>> df.select(date_trunc('year', df.t).alias('year')).collect()
[Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
>>> df.select(date_trunc('mon', df.t).alias('month')).collect()
[Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_trunc(format, _to_java_column(timestamp)))


@since(1.5)
def next_day(date, dayOfWeek):
"""
Expand Down
Expand Up @@ -392,6 +392,7 @@ object FunctionRegistry {
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[TruncDate]("trunc"),
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[DayOfWeek]("dayofweek"),
expression[WeekOfYear]("weekofyear"),
Expand Down
Expand Up @@ -1294,87 +1294,181 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
override def dataType: DataType = TimestampType
}

/**
* Returns date truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.",
examples = """
Examples:
> SELECT _FUNC_('2009-02-12', 'MM');
2009-02-01
> SELECT _FUNC_('2015-10-27', 'YEAR');
2015-01-01
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
val instant: Expression
val format: Expression
override def nullable: Boolean = true
override def prettyName: String = "trunc"

private lazy val truncLevel: Int =
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])

override def eval(input: InternalRow): Any = {
/**
* @param input internalRow (time)
* @param maxLevel Maximum level that can be used for truncation (e.g MONTH for Date input)
* @param truncFunc function: (time, level) => time
*/
protected def evalHelper(input: InternalRow, maxLevel: Int)(
truncFunc: (Any, Int) => Any): Any = {
val level = if (format.foldable) {
truncLevel
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
if (level == -1) {
// unknown format
if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) {
// unknown format or too large level
null
} else {
val d = date.eval(input)
if (d == null) {
val t = instant.eval(input)
if (t == null) {
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
truncFunc(t, level)
}
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
protected def codeGenHelper(
ctx: CodegenContext,
ev: ExprCode,
maxLevel: Int,
orderReversed: Boolean = false)(
truncFunc: (String, String) => String)
: ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (truncLevel == -1) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
} else {
val d = date.genCode(ctx)
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
ev.copy(code = s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${t.code}
boolean ${ev.isNull} = ${t.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
${ev.value} = $dtu.$truncFuncStr;
}""")
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
nullSafeCodeGen(ctx, ev, (left, right) => {
val form = ctx.freshName("form")
val (dateVal, fmt) = if (orderReversed) {
(right, left)
} else {
(left, right)
}
val truncFuncStr = truncFunc(dateVal, form)
s"""
int $form = $dtu.parseTruncLevel($fmt);
if ($form == -1) {
if ($form == -1 || $form > $maxLevel) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $form);
${ev.value} = $dtu.$truncFuncStr
}
"""
})
}
}
}

/**
* Returns date truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
`fmt` should be one of ["year", "yyyy", "yy", "mon", "month", "mm"]
""",
examples = """
Examples:
> SELECT _FUNC_('2009-02-12', 'MM');
2009-02-01
> SELECT _FUNC_('2015-10-27', 'YEAR');
2015-01-01
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends TruncInstant {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
override def prettyName: String = "trunc"
override val instant = date

override def eval(input: InternalRow): Any = {
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (d: Any, level: Int) =>
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_MONTH) { (date: String, fmt: String) =>
s"truncDate($date, $fmt);"
}
}
}

/**
* Returns timestamp truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`.
`fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"]
""",
examples = """
Examples:
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR');
2015-01-01T00:00:00
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM');
2015-03-01T00:00:00
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD');
2015-03-05T00:00:00
> SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR');
2015-03-05T09:00:00
""",
since = "2.3.0")
// scalastyle:on line.size.limit
case class TruncTimestamp(
format: Expression,
timestamp: Expression,
timeZoneId: Option[String] = None)
extends TruncInstant with TimeZoneAwareExpression {
override def left: Expression = format
override def right: Expression = timestamp

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
override def dataType: TimestampType = TimestampType
override def prettyName: String = "date_trunc"
override val instant = timestamp
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

def this(format: Expression, timestamp: Expression) = this(format, timestamp, None)

override def eval(input: InternalRow): Any = {
evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) =>
DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceObj("timeZone", timeZone)
codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, true) {
(date: String, fmt: String) =>
s"truncTimestamp($date, $fmt, $tz);"
}
}
}

/**
* Returns the number of days from startDate to endDate.
*/
Expand Down

0 comments on commit 6e36d8d

Please sign in to comment.