Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.internal.connector.V1Function
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.DAY
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.collection.{Utils => CUtils}

Expand Down Expand Up @@ -318,7 +317,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveLambdaVariables ::
ResolveTimeZone ::
ResolveRandomSeed ::
ResolveBinaryArithmetic ::
ResolveUnion ::
ResolveRowLevelCommandAssignments ::
RewriteDeleteFromTable ::
Expand Down Expand Up @@ -348,110 +346,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
HandleSpecialCommand)
)

/**
* For [[Add]]:
* 1. if both side are interval, stays the same;
* 2. else if one side is date and the other is interval,
* turns it to [[DateAddInterval]];
* 3. else if one side is interval, turns it to [[TimeAdd]];
* 4. else if one side is date, turns it to [[DateAdd]] ;
* 5. else stays the same.
*
* For [[Subtract]]:
* 1. if both side are interval, stays the same;
* 2. else if the left side is date and the right side is interval,
* turns it to [[DateAddInterval(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimeAdd(l, -r)]];
* 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 6. else if the left side is date, turns it to [[DateSub]];
* 7. else turns it to stays the same.
*
* For [[Multiply]]:
* 1. If one side is interval, turns it to [[MultiplyInterval]];
* 2. otherwise, stays the same.
*
* For [[Divide]]:
* 1. If the left side is interval, turns it to [[DivideInterval]];
* 2. otherwise, stays the same.
*/
object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(BINARY_ARITHMETIC), ruleId) {
case p: LogicalPlan => p.transformExpressionsUpWithPruning(
_.containsPattern(BINARY_ARITHMETIC), ruleId) {
case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r)
case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l))
case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l)
case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r)
case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
TimestampAddYMInterval(l, r)
case (_: YearMonthIntervalType, TimestampType | TimestampNTZType) =>
TimestampAddYMInterval(r, l)
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => a
case (_: NullType, _: AnsiIntervalType) =>
a.copy(left = Cast(a.left, a.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
a.copy(right = Cast(a.right, a.left.dataType))
case (DateType, CalendarIntervalType) =>
DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) =>
DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) =>
DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
case (DateType, _: DayTimeIntervalType) =>
DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
case (DateType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => s
case (_: NullType, _: AnsiIntervalType) =>
s.copy(left = Cast(s.left, s.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
s.copy(right = Cast(s.right, s.left.dataType))
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l,
UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI))
case (_, CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
case _ if AnyTimestampType.unapply(l) || AnyTimestampType.unapply(r) =>
SubtractTimestamps(l, r)
case (_, DateType) => SubtractDates(l, r)
case (DateType, dt) if dt != StringType => DateSub(l, r)
case _ => s
}
case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI)
case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
case _ => m
}
case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
case _ => d
}
}
}
}

/**
* Substitute child plan with WindowSpecDefinitions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.DAY

/**
* In Spark ANSI mode, the type coercion rules are based on the type precedence lists of the input
Expand Down Expand Up @@ -77,6 +78,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
UnpivotCoercion ::
WidenSetOperationTypes ::
new AnsiCombinedTypeCoercionRule(
ResolveBinaryArithmetic ::
InConversion ::
PromoteStrings ::
DecimalPrecision ::
Expand Down Expand Up @@ -288,6 +290,111 @@ object AnsiTypeCoercion extends TypeCoercionBase {
}
}

/**
* For [[Add]]:
* 1. if both side are interval, stays the same;
* 2. else if one side is date and the other is interval,
* turns it to [[DateAddInterval]];
* 3. else if one side is interval, turns it to [[TimeAdd]];
* 4. else if one side is date, turns it to [[DateAdd]] ;
* 5. else stays the same.
*
* For [[Subtract]]:
* 1. if both side are interval, stays the same;
* 2. else if the left side is date and the right side is interval,
* turns it to [[DateAddInterval(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimeAdd(l, -r)]];
* 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 6. else if the left side is date, turns it to [[DateSub]];
* 7. else turns it to stays the same.
*
* For [[Multiply]]:
* 1. If one side is interval, turns it to [[MultiplyInterval]];
* 2. otherwise, stays the same.
*
* For [[Divide]]:
* 1. If the left side is interval, turns it to [[DivideInterval]];
* 2. otherwise, stays the same.
*/
object ResolveBinaryArithmetic extends TypeCoercionRule {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to do some code sharing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override val transform: PartialFunction[Expression, Expression] = {
case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r))
case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r)
case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l))
case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l)
case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r)
case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l)
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
TimestampAddYMInterval(l, r)
case (_: YearMonthIntervalType, TimestampType | TimestampNTZType) =>
TimestampAddYMInterval(r, l)
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => a
case (_: NullType, _: AnsiIntervalType) =>
a.copy(left = Cast(a.left, a.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
a.copy(right = Cast(a.right, a.left.dataType))
case (DateType, CalendarIntervalType) =>
DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI)
case (_: DatetimeType | _: AnsiIntervalType,
CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) =>
DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI)
case (CalendarIntervalType | _: DayTimeIntervalType,
_: DatetimeType | _: AnsiIntervalType) =>
Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match {
case (DateType, DayTimeIntervalType(DAY, DAY)) =>
DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI))
case (DateType, _: DayTimeIntervalType) =>
DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI)))
case (DateType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) =>
DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI)))
case (CalendarIntervalType, CalendarIntervalType) |
(_: DayTimeIntervalType, _: DayTimeIntervalType) => s
case (_: NullType, _: AnsiIntervalType) =>
s.copy(left = Cast(s.left, s.right.dataType))
case (_: AnsiIntervalType, _: NullType) =>
s.copy(right = Cast(s.right, s.left.dataType))
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l,
UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI))
case (_: DatetimeType | _: AnsiIntervalType,
CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
case _ if AnyTimestampType.unapply(l) || AnyTimestampType.unapply(r) =>
SubtractTimestamps(l, r)
case (_, DateType) => SubtractDates(l, r)
case (DateType, dt) if dt != StringType => DateSub(l, r)
case _ => s
}
case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI)
case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r)
case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l)
case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r)
case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l)
case _ => m
}
case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI)
case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r)
case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r)
case _ => d
}
}
}

object DateTimeOperations extends TypeCoercionRule {
override val transform: PartialFunction[Expression, Expression] = {
// Skip nodes who's children have not been resolved yet.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ object TypeCoercion extends TypeCoercionBase {
UnpivotCoercion ::
WidenSetOperationTypes ::
new CombinedTypeCoercionRule(
ResolveBinaryArithmetic ::
InConversion ::
PromoteStrings ::
DecimalPrecision ::
Expand Down Expand Up @@ -1170,6 +1171,29 @@ object TypeCoercion extends TypeCoercionBase {
}
}

// For legacy support. For example: string type +/- interval.
object ResolveBinaryArithmetic extends TypeCoercionRule {
override val transform: PartialFunction[Expression, Expression] =
AnsiTypeCoercion.ResolveBinaryArithmetic.transform.andThen {
case a @ Add(l, r, _) if a.childrenResolved => (l.dataType, r.dataType) match {
case (StringType | _: NumericType | BinaryType | BooleanType,
CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(TimeAdd(l, r), l.dataType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first child of TimeAdd has to be AnyTimestampType but it can be StringType | _: NumericType | BinaryType | BooleanType here..

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ANSI mode. It should be DatetimeType and AnsiIntervalType:
For non-ANSI mode. It should be DatetimeType, AnsiIntervalType, StringType, NumericType, BinaryType and BooleanType.

case (CalendarIntervalType | _: DayTimeIntervalType,
StringType | _: NumericType | BinaryType | BooleanType) =>
Cast(TimeAdd(r, l), r.dataType)
case _ => a
}
case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match {
case (StringType | _: NumericType | BinaryType | BooleanType,
CalendarIntervalType | _: DayTimeIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType)
case _ => s
}
case other => other
}
}

object DateTimeOperations extends TypeCoercionRule {
override val transform: PartialFunction[Expression, Expression] = {
// Skip nodes who's children have not been resolved yet.
Expand Down
Loading