Skip to content

Commit

Permalink
fix decimal overflow issues
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jul 7, 2020
1 parent 42f01e3 commit 3717fc6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) {
Platform.putLong(baseObject, baseOffset + cursor, 0L);
Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);

if (value == null) {
if (value == null || !value.changePrecision(precision, value.scale())) {
setNullAt(ordinal);
// keep the offset for future update
Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,50 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
case _ => DoubleType
}

private lazy val sumDataType = resultType

private lazy val sum = AttributeReference("sum", sumDataType)()
private lazy val sum = AttributeReference("sum", resultType)()

private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()

private lazy val zero = Literal.default(sumDataType)
private lazy val zero = Literal.default(resultType)

override lazy val aggBufferAttributes = resultType match {
case _: DecimalType => sum :: isEmpty :: Nil
case _ => sum :: Nil
}

override lazy val initialValues: Seq[Expression] = resultType match {
case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
case _: DecimalType => Seq(zero, Literal(true, BooleanType))
case _ => Seq(Literal(null, resultType))
}

override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, isEmpty && child.isNull)
case _ => Seq(updateSumExpr)
}
} else {
val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, Literal(false, BooleanType))
case _ => Seq(updateSumExpr)
}
resultType match {
case _: DecimalType =>
// For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if
// the input is null, as SUM function ignores null input. The `sum` can only be null if
// overflow happens under non-ansi mode.
val sumExpr = if (child.nullable) {
If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType))
} else {
sum + child.cast(resultType)
}
// The buffer becomes non-empty after seeing the first not-null input.
val isEmptyExpr = if (child.nullable) {
isEmpty && child.isNull
} else {
Literal(false, BooleanType)
}
Seq(sumExpr, isEmptyExpr)
case _ =>
// For non-decimal type, the initial value of `sum` is null, which indicates no value.
// We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce`
// in case the input is nullable. The `sum` can only be null if there is no value, as
// non-decimal type can produce overflowed value under non-ansi mode.
if (child.nullable) {
Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum))
} else {
Seq(coalesce(sum, zero) + child.cast(resultType))
}
}
}

Expand All @@ -107,15 +118,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
* means we have seen atleast a value that was not null.
*/
override lazy val mergeExpressions: Seq[Expression] = {
val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
resultType match {
case _: DecimalType =>
val inputOverflow = !isEmpty.right && sum.right.isNull
val bufferOverflow = !isEmpty.left && sum.left.isNull
val inputOverflow = !isEmpty.right && sum.right.isNull
Seq(
If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
If(
bufferOverflow || inputOverflow,
Literal.create(null, resultType),
// If both the buffer and the input do not overflow, just add them, as they can't be
// null. See the comments inside `updateExpressions`: `sum` can only be null if
// overflow happens.
KnownNotNull(sum.left) + KnownNotNull(sum.right)),
isEmpty.left && isEmpty.right)
case _ => Seq(mergeSumExpr)
case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
}
}

Expand All @@ -128,7 +144,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
*/
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType =>
If(isEmpty, Literal.create(null, sumDataType),
If(isEmpty, Literal.create(null, resultType),
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
case _ => sum
}
Expand Down
14 changes: 3 additions & 11 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,14 @@ class DataFrameSuite extends QueryTest
private def assertDecimalSumOverflow(
df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
if (!ansiEnabled) {
try {
checkAnswer(df, expectedAnswer)
} catch {
case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
// This is an existing bug that we can write overflowed decimal to UnsafeRow but fail
// to read it.
assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
}
checkAnswer(df, expectedAnswer)
} else {
val e = intercept[SparkException] {
df.collect
df.collect()
}
assert(e.getCause.isInstanceOf[ArithmeticException])
assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
e.getCause.getMessage.contains("Overflow in sum of decimals") ||
e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
e.getCause.getMessage.contains("Overflow in sum of decimals"))
}
}

Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite {
// Makes sure hashCode on unsafe array won't crash
unsafeRow.getArray(0).hashCode()
}

test("SPARK-32018: setDecimal with overflowed value") {
val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
val row = InternalRow.apply(d1)
val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
assert(unsafeRow.getDecimal(0, 38, 18) === d1)
val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
unsafeRow.setDecimal(0, d2, 38)
assert(unsafeRow.getDecimal(0, 38, 18) === null)
}
}

0 comments on commit 3717fc6

Please sign in to comment.