Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8186][SPARK-8187][SPARK-8194][SPARK-8198][SPARK-9133] [SPARK-9290] [SQL] functions: date_add, date_sub, add_months, months_between, time-interval calculation #7754

Closed
wants to merge 12 commits into from
76 changes: 64 additions & 12 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
__all__ += ['lag', 'lead', 'ntile']

__all__ += [
'date_format',
'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
'year', 'quarter', 'month', 'hour', 'minute', 'second',
'dayofmonth', 'dayofyear', 'weekofyear']

Expand Down Expand Up @@ -716,7 +716,7 @@ def date_format(dateCol, format):
[Row(date=u'04/08/2015')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_format(dateCol, format))
return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))


@since(1.5)
Expand All @@ -729,7 +729,7 @@ def year(col):
[Row(year=2015)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.year(col))
return Column(sc._jvm.functions.year(_to_java_column(col)))


@since(1.5)
Expand All @@ -742,7 +742,7 @@ def quarter(col):
[Row(quarter=2)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.quarter(col))
return Column(sc._jvm.functions.quarter(_to_java_column(col)))


@since(1.5)
Expand All @@ -755,7 +755,7 @@ def month(col):
[Row(month=4)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.month(col))
return Column(sc._jvm.functions.month(_to_java_column(col)))


@since(1.5)
Expand All @@ -768,7 +768,7 @@ def dayofmonth(col):
[Row(day=8)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.dayofmonth(col))
return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))


@since(1.5)
Expand All @@ -781,7 +781,7 @@ def dayofyear(col):
[Row(day=98)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.dayofyear(col))
return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))


@since(1.5)
Expand All @@ -794,7 +794,7 @@ def hour(col):
[Row(hour=13)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.hour(col))
return Column(sc._jvm.functions.hour(_to_java_column(col)))


@since(1.5)
Expand All @@ -807,7 +807,7 @@ def minute(col):
[Row(minute=8)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.minute(col))
return Column(sc._jvm.functions.minute(_to_java_column(col)))


@since(1.5)
Expand All @@ -820,7 +820,7 @@ def second(col):
[Row(second=15)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.second(col))
return Column(sc._jvm.functions.second(_to_java_column(col)))


@since(1.5)
Expand All @@ -829,11 +829,63 @@ def weekofyear(col):
Extract the week number of a given date as integer.

>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
>>> df.select(weekofyear('a').alias('week')).collect()
>>> df.select(weekofyear(df.a).alias('week')).collect()
[Row(week=15)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.weekofyear(col))
return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))


@since(1.5)
def date_add(start, days):
"""
Returns the date that is `days` days after `start`

>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
>>> df.select(date_add(df.d, 1).alias('d')).collect()
[Row(d=datetime.date(2015, 4, 9))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_add(_to_java_column(start), days))


@since(1.5)
def date_sub(start, days):
"""
Returns the date that is `days` days before `start`

>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
>>> df.select(date_sub(df.d, 1).alias('d')).collect()
[Row(d=datetime.date(2015, 4, 7))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))


@since(1.5)
def add_months(start, months):
"""
Returns the date that is `months` months after `start`

>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
>>> df.select(add_months(df.d, 1).alias('d')).collect()
[Row(d=datetime.date(2015, 5, 8))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.add_months(_to_java_column(start), months))


@since(1.5)
def months_between(date1, date2):
"""
Returns the number of months between date1 and date2.

>>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
>>> df.select(months_between(df.t, df.d).alias('months')).collect()
[Row(months=3.9495967...)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))


@since(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,12 @@ object FunctionRegistry {
expression[Upper]("upper"),

// datetime functions
expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
expression[DateAdd]("date_add"),
expression[DateFormatClass]("date_format"),
expression[DateSub]("date_sub"),
expression[DayOfMonth]("day"),
expression[DayOfYear]("dayofyear"),
expression[DayOfMonth]("dayofmonth"),
Expand All @@ -216,6 +219,7 @@ object FunctionRegistry {
expression[LastDay]("last_day"),
expression[Minute]("minute"),
expression[Month]("month"),
expression[MonthsBetween]("months_between"),
expression[NextDay]("next_day"),
expression[Quarter]("quarter"),
expression[Second]("second"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ object HiveTypeCoercion {
Division ::
PropagateTypes ::
ImplicitTypeCasts ::
DateTimeOperations ::
Nil

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
Expand Down Expand Up @@ -638,6 +639,27 @@ object HiveTypeCoercion {
}
}

/**
* Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
* to TimeAdd/TimeSub
*/
object DateTimeOperations extends Rule[LogicalPlan] {

private val acceptedTypes = Seq(DateType, TimestampType, StringType)

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) =>
Cast(TimeAdd(r, l), r.dataType)
case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
Cast(TimeAdd(l, r), l.dataType)
case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
Cast(TimeSub(l, r), l.dataType)
}
}

/**
* Casts types according to the expected input types for [[Expression]]s.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

import scala.util.Try

Expand Down Expand Up @@ -63,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
}
}

/**
* Adds a number of days to startdate.
*/
case class DateAdd(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = startDate
override def right: Expression = days

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, d: Any): Any = {
start.asInstanceOf[Int] + d.asInstanceOf[Int]
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (sd, d) => {
s"""${ev.primitive} = $sd + $d;"""
})
}
}

/**
* Subtracts a number of days to startdate.
*/
case class DateSub(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = startDate
override def right: Expression = days

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, d: Any): Any = {
start.asInstanceOf[Int] - d.asInstanceOf[Int]
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, (sd, d) => {
s"""${ev.primitive} = $sd - $d;"""
})
}
}

case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
Expand Down Expand Up @@ -543,3 +590,109 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)

override def prettyName: String = "next_day"
}

/**
* Adds an interval to timestamp.
*/
case class TimeAdd(start: Expression, interval: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left + $right"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

override def dataType: DataType = TimestampType

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], itvl.months, itvl.microseconds)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
})
}
}

/**
* Subtracts an interval from timestamp.
*/
case class TimeSub(start: Expression, interval: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left - $right"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

override def dataType: DataType = TimestampType

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
})
}
}

/**
* Returns the date that is num_months after start_date.
*/
case class AddMonths(startDate: Expression, numMonths: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = startDate
override def right: Expression = numMonths

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, months: Any): Any = {
DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, m) => {
s"""$dtu.dateAddMonths($sd, $m)"""
})
}
}

/**
* Returns number of months between dates date1 and date2.
*/
case class MonthsBetween(date1: Expression, date2: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = date1
override def right: Expression = date2

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)

override def dataType: DataType = DoubleType

override def nullSafeEval(t1: Any, t2: Any): Any = {
DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (l, r) => {
s"""$dtu.monthsBetween($l, $r)"""
})
}
}
Loading