From 8c5bee599d58fdb6d1c0335ec2de872f8256f0ba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 9 Jul 2020 15:56:40 +0900 Subject: [PATCH] [SPARK-28067][SPARK-32018] Fix decimal overflow issues ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/27627 to fix the remaining issues. There are 2 issues fixed in this PR: 1. `UnsafeRow.setDecimal` can set an overflowed decimal and causes an error when reading it. The expected behavior is to return null. 2. The update/merge expression for decimal type in `Sum` is wrong. We shouldn't turn the `sum` value back to 0 after it becomes null due to overflow. This issue was hidden because: 2.1 for hash aggregate, the buffer is unsafe row. Due to the first bug, we fail when overflow happens, so there is no chance to mistakenly turn null back to 0. 2.2 for sort-based aggregate, the buffer is generic row. The decimal can overflow (the Decimal class has unlimited precision) and we don't have the null problem. If we only fix the first bug, then the second bug is exposed and test fails. If we only fix the second bug, there is no way to test it. This PR fixes these 2 bugs together. ### Why are the changes needed? Fix issues during decimal sum when overflow happens ### Does this PR introduce _any_ user-facing change? Yes. Now decimal sum can return null correctly for overflow under non-ansi mode. ### How was this patch tested? new test and updated test Closes #29026 from cloud-fan/decimal. Authored-by: Wenchen Fan Signed-off-by: HyukjinKwon --- .../sql/catalyst/expressions/UnsafeRow.java | 2 +- .../catalyst/expressions/aggregate/Sum.scala | 64 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 14 +--- .../org/apache/spark/sql/UnsafeRowSuite.scala | 10 +++ 4 files changed, 54 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 034894bd86085..4dc5ce1de047b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) { Platform.putLong(baseObject, baseOffset + cursor, 0L); Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); - if (value == null) { + if (value == null || !value.changePrecision(precision, value.scale())) { setNullAt(ordinal); // keep the offset for future update Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 6e850267100fb..a29ae2c8b65a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -58,13 +58,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } - private lazy val sumDataType = resultType - - private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", resultType)() private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() - private lazy val zero = Literal.default(sumDataType) + private lazy val zero = Literal.default(resultType) override lazy val aggBufferAttributes = resultType match { case _: DecimalType => sum :: isEmpty :: Nil @@ -72,25 +70,38 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } override lazy val initialValues: Seq[Expression] = resultType match { - case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _: DecimalType => Seq(zero, Literal(true, BooleanType)) case _ => Seq(Literal(null, resultType)) } override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - resultType match { - case _: DecimalType => - Seq(updateSumExpr, isEmpty && child.isNull) - case _ => Seq(updateSumExpr) - } - } else { - val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) - resultType match { - case _: DecimalType => - Seq(updateSumExpr, Literal(false, BooleanType)) - case _ => Seq(updateSumExpr) - } + resultType match { + case _: DecimalType => + // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if + // the input is null, as SUM function ignores null input. The `sum` can only be null if + // overflow happens under non-ansi mode. + val sumExpr = if (child.nullable) { + If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType)) + } else { + sum + child.cast(resultType) + } + // The buffer becomes non-empty after seeing the first not-null input. + val isEmptyExpr = if (child.nullable) { + isEmpty && child.isNull + } else { + Literal(false, BooleanType) + } + Seq(sumExpr, isEmptyExpr) + case _ => + // For non-decimal type, the initial value of `sum` is null, which indicates no value. + // We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce` + // in case the input is nullable. The `sum` can only be null if there is no value, as + // non-decimal type can produce overflowed value under non-ansi mode. + if (child.nullable) { + Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum)) + } else { + Seq(coalesce(sum, zero) + child.cast(resultType)) + } } } @@ -107,15 +118,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast * means we have seen atleast a value that was not null. */ override lazy val mergeExpressions: Seq[Expression] = { - val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) resultType match { case _: DecimalType => - val inputOverflow = !isEmpty.right && sum.right.isNull val bufferOverflow = !isEmpty.left && sum.left.isNull + val inputOverflow = !isEmpty.right && sum.right.isNull Seq( - If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + If( + bufferOverflow || inputOverflow, + Literal.create(null, resultType), + // If both the buffer and the input do not overflow, just add them, as they can't be + // null. See the comments inside `updateExpressions`: `sum` can only be null if + // overflow happens. + KnownNotNull(sum.left) + KnownNotNull(sum.right)), isEmpty.left && isEmpty.right) - case _ => Seq(mergeSumExpr) + case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) } } @@ -128,7 +144,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val evaluateExpression: Expression = resultType match { case d: DecimalType => - If(isEmpty, Literal.create(null, sumDataType), + If(isEmpty, Literal.create(null, resultType), CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8359dff674a87..52ef5895ed9ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -195,22 +195,14 @@ class DataFrameSuite extends QueryTest private def assertDecimalSumOverflow( df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { if (!ansiEnabled) { - try { - checkAnswer(df, expectedAnswer) - } catch { - case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => - // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail - // to read it. - assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) - } + checkAnswer(df, expectedAnswer) } else { val e = intercept[SparkException] { - df.collect + df.collect() } assert(e.getCause.isInstanceOf[ArithmeticException]) assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || - e.getCause.getMessage.contains("Overflow in sum of decimals") || - e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + e.getCause.getMessage.contains("Overflow in sum of decimals")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a5f904c621e6e..9daa69ce9f155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite { // Makes sure hashCode on unsafe array won't crash unsafeRow.getArray(0).hashCode() } + + test("SPARK-32018: setDecimal with overflowed value") { + val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18) + val row = InternalRow.apply(d1) + val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row) + assert(unsafeRow.getDecimal(0, 38, 18) === d1) + val d2 = (d1 * Decimal(10)).toPrecision(39, 18) + unsafeRow.setDecimal(0, d2, 38) + assert(unsafeRow.getDecimal(0, 38, 18) === null) + } }