New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero. #33889
Changes from 12 commits
631da9e
218824c
0325bfc
cc4d5be
97b646a
4600e4f
1eebc3a
0bd83da
89af674
a2a90a0
823048a
b00d8d8
a3e3059
7df29e5
d202787
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -612,6 +612,13 @@ case class DivideYMInterval( | |
override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType) | ||
override def dataType: DataType = YearMonthIntervalType() | ||
|
||
@transient | ||
private lazy val checkFunc: (Any) => Unit = right.dataType match { | ||
case _: DecimalType => (num) => | ||
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError() | ||
case _ => (num) => if (num == 0) throw QueryExecutionErrors.divideByZeroError() | ||
} | ||
|
||
@transient | ||
private lazy val evalFunc: (Int, Any) => Any = right.dataType match { | ||
case LongType => (months: Int, num) => | ||
|
@@ -629,6 +636,7 @@ case class DivideYMInterval( | |
|
||
override def nullSafeEval(interval: Any, num: Any): Any = { | ||
checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num) | ||
checkFunc(num) | ||
evalFunc(interval.asInstanceOf[Int], num) | ||
} | ||
|
||
|
@@ -641,6 +649,11 @@ case class DivideYMInterval( | |
val javaType = CodeGenerator.javaType(dataType) | ||
val months = left.genCode(ctx) | ||
val num = right.genCode(ctx) | ||
val checkDivideByZero = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can add
to avoid duplicate code here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and we can move these util functions to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
s""" | ||
|if (${num.value} == 0) | ||
| throw QueryExecutionErrors.divideByZeroError(); | ||
|""".stripMargin | ||
val checkIntegralDivideOverflow = | ||
s""" | ||
|if (${months.value} == ${Int.MinValue} && ${num.value} == -1) | ||
|
@@ -650,17 +663,26 @@ case class DivideYMInterval( | |
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`. | ||
// Casting to `Int` is safe here. | ||
s""" | ||
|$checkDivideByZero | ||
|$checkIntegralDivideOverflow | ||
|${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP); | ||
""".stripMargin) | ||
case _: DecimalType => | ||
defineCodeGen(ctx, ev, (m, n) => | ||
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + | ||
".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()") | ||
nullSafeCodeGen(ctx, ev, (m, n) => | ||
s""" | ||
|if ($n.isZero()) | ||
| throw QueryExecutionErrors.divideByZeroError(); | ||
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal() | ||
| .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact(); | ||
""".stripMargin) | ||
case _: FractionalType => | ||
val math = classOf[DoubleMath].getName | ||
defineCodeGen(ctx, ev, (m, n) => | ||
s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)") | ||
nullSafeCodeGen(ctx, ev, (m, n) => | ||
s""" | ||
|if ($n == 0) | ||
| throw QueryExecutionErrors.divideByZeroError(); | ||
|${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP); | ||
""".stripMargin) | ||
} | ||
|
||
override def toString: String = s"($left / $right)" | ||
|
@@ -683,6 +705,13 @@ case class DivideDTInterval( | |
override def inputTypes: Seq[AbstractDataType] = Seq(DayTimeIntervalType, NumericType) | ||
override def dataType: DataType = DayTimeIntervalType() | ||
|
||
@transient | ||
private lazy val checkFunc: (Any) => Unit = right.dataType match { | ||
case _: DecimalType => (num) => | ||
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError() | ||
case _ => (num) => if (num == 0) throw QueryExecutionErrors.divideByZeroError() | ||
} | ||
|
||
@transient | ||
private lazy val evalFunc: (Long, Any) => Any = right.dataType match { | ||
case _: IntegralType => (micros: Long, num) => | ||
|
@@ -696,6 +725,7 @@ case class DivideDTInterval( | |
|
||
override def nullSafeEval(interval: Any, num: Any): Any = { | ||
checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num) | ||
checkFunc(num) | ||
evalFunc(interval.asInstanceOf[Long], num) | ||
} | ||
|
||
|
@@ -704,24 +734,38 @@ case class DivideDTInterval( | |
val math = classOf[LongMath].getName | ||
val micros = left.genCode(ctx) | ||
val num = right.genCode(ctx) | ||
val checkDivideByZero = | ||
s""" | ||
|if (${num.value} == 0) | ||
| throw QueryExecutionErrors.divideByZeroError(); | ||
|""".stripMargin | ||
val checkIntegralDivideOverflow = | ||
s""" | ||
|if (${micros.value} == ${Long.MinValue}L && ${num.value} == -1L) | ||
| throw QueryExecutionErrors.overflowInIntegralDivideError(); | ||
|""".stripMargin | ||
nullSafeCodeGen(ctx, ev, (m, n) => | ||
s""" | ||
|$checkDivideByZero | ||
|$checkIntegralDivideOverflow | ||
|${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP); | ||
""".stripMargin) | ||
case _: DecimalType => | ||
defineCodeGen(ctx, ev, (m, n) => | ||
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + | ||
".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()") | ||
nullSafeCodeGen(ctx, ev, (m, n) => | ||
s""" | ||
|if ($n.isZero()) | ||
| throw QueryExecutionErrors.divideByZeroError(); | ||
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal() | ||
| .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact(); | ||
""".stripMargin) | ||
case _: FractionalType => | ||
val math = classOf[DoubleMath].getName | ||
defineCodeGen(ctx, ev, (m, n) => | ||
s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)") | ||
nullSafeCodeGen(ctx, ev, (m, n) => | ||
s""" | ||
|if ($n == 0) | ||
| throw QueryExecutionErrors.divideByZeroError(); | ||
|${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP); | ||
""".stripMargin) | ||
} | ||
|
||
override def toString: String = s"($left / $right)" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
... val divideByZeroCheck: Any => Unit = ...