
[SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL

## What changes were proposed in this pull request?

When an arithmetic operation between Decimals produces a result that is not exactly representable with the result's precision and scale, Spark returns `NULL`. This was done to mirror Hive's behavior, but it conflicts with SQL ANSI 2011, which states that "If the result cannot be represented exactly in the result type, then whether it is rounded or truncated is implementation-defined". Moreover, Hive has since changed its behavior to respect the standard, thanks to HIVE-15331.

Therefore, this PR proposes to:
 - update the rules that determine the result precision and scale to match the new Hive rules introduced in HIVE-15331;
 - round the result of an operation when it is not exactly representable with the result's precision and scale, instead of returning `NULL`;
 - introduce a new config, `spark.sql.decimalOperations.allowPrecisionLoss`, which defaults to `true` (i.e. the new behavior), so that users can switch back to the previous one. An illustrative example follows.
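
For illustration, consider multiplying two `DECIMAL(38, 18)` values. This is a hypothetical spark-shell sketch; the numbers and outputs are my own, not taken from the PR:

```scala
// The exact product of two DECIMAL(38, 18) values needs DECIMAL(77, 36),
// which exceeds the maximum precision of 38.
val df = spark.sql(
  "SELECT CAST(123.45 AS DECIMAL(38,18)) * CAST(678.90 AS DECIMAL(38,18)) AS x")

// Old rules: the type is simply capped at DECIMAL(38, 36), which cannot hold
// the 5 integer digits of 83810.205, so the query returns NULL.
// New rules: the scale is sacrificed first (down to a minimum of 6), giving
// DECIMAL(38, 6) and the rounded value 83810.205000.
df.printSchema()
df.show()
```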

Hive's behavior mirrors SQL Server's. The only difference is that Hive adjusts the precision and scale for all arithmetic operations, while SQL Server's documentation says it does so only for multiplication and division. This PR follows Hive's behavior.

A more detailed explanation is available here: https://mail-archives.apache.org/mod_mbox/spark-dev/201712.mbox/%3CCAEorWNAJ4TxJR9NBcgSFMD_VxTg8qVxusjP%2BAJP-x%2BJV9zH-yA%40mail.gmail.com%3E.

## How was this patch tested?

Modified and added unit tests. Compared the results with those of Hive and SQL Server.

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20023 from mgaido91/SPARK-22036.

(cherry picked from commit e28eb43)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
mgaido91 authored and cloud-fan committed Jan 18, 2018
1 parent f801ac4 commit 8a98274823a4671cee85081dd19f40146e736325
@@ -1793,6 +1793,11 @@ options.
- Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns a string. Before Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
- Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns a string. Before Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`.

- Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with the SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes:
  - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced (keeping at least 6 fractional digits) in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, i.e. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive modulo (`pmod`).
  - Literal values used in SQL operations are converted to DECIMAL with exactly the precision and scale they need.
  - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses the previous rules, i.e. it doesn't adjust the needed scale to represent the values, and it returns NULL if an exact representation of the value is not possible (see the sketch below).
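
A minimal sketch of the switch (a hypothetical session; the values and outputs are illustrative assumptions, not taken from the patch):

```scala
// Default in 2.3: precision loss allowed, results are rounded instead of NULL.
spark.sql("SELECT CAST(123.45 AS DECIMAL(38,18)) * CAST(678.90 AS DECIMAL(38,18)) AS x").show()
// x is DECIMAL(38, 6): 83810.205000

// Restore the pre-2.3 behavior: no scale adjustment, NULL when the exact
// result does not fit DECIMAL(38, 36).
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
spark.sql("SELECT CAST(123.45 AS DECIMAL(38,18)) * CAST(678.90 AS DECIMAL(38,18)) AS x").show()
// x is NULL
```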

## Upgrading From Spark SQL 2.1 to 2.2

- Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


@@ -42,8 +43,10 @@ import org.apache.spark.sql.types._
*   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)    max(6, s1 + p2 + 1)
*   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)       max(s1, s2)
*   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)       max(s1, s2)
*   sum(e1)      p1 + 10                               s1
*   avg(e1)      p1 + 4                                s1 + 4
*
* When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
* needed are out of the range of available values, the scale is reduced (keeping at least 6
* fractional digits) in order to prevent the truncation of the integer part of the decimals.
*
* To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
* precision, do the math on unlimited-precision numbers, then introduce casts back to the
@@ -56,6 +59,7 @@ import org.apache.spark.sql.types._
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
* - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value
*/
// scalastyle:on
object DecimalPrecision extends TypeCoercionRule {
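
As a side note, the divide row of the table above, combined with the adjustment rule, can be sketched as standalone Scala (a hand-written helper following the stated formulas, not code from this patch):

```scala
// Result type of DECIMAL(p1, s1) / DECIMAL(p2, s2) under the new rules.
def divideResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
  val scale = math.max(6, s1 + p2 + 1)     // "Result Scale" column
  val precision = p1 - s1 + s2 + scale     // "Result Precision" column
  if (precision <= 38) (precision, scale)
  else {
    // Keep the integer digits; never drop the scale below min(scale, 6).
    val intDigits = precision - scale
    val adjustedScale = math.max(38 - intDigits, math.min(scale, 6))
    (38, adjustedScale)
  }
}

// DECIMAL(38, 18) / DECIMAL(38, 18): exact type is (95, 57), adjusted to (38, 6).
println(divideResultType(38, 18, 38, 18))  // (38,6)
```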
@@ -93,41 +97,76 @@ object DecimalPrecision extends TypeCoercionRule {
    case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e

    case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-     val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-     CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
+     val resultScale = max(s1, s2)
+     val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+       DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+         resultScale)
+     } else {
+       DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+     }
+     CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
+       resultType)

    case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-     val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-     CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
+     val resultScale = max(s1, s2)
+     val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+       DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
+         resultScale)
+     } else {
+       DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
+     }
+     CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
+       resultType)

    case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-     val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+     val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+       DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
+     } else {
+       DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+     }
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
        resultType)

    case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-     var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-     var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
-     val diff = (intDig + decDig) - DecimalType.MAX_SCALE
-     if (diff > 0) {
-       decDig -= diff / 2 + 1
-       intDig = DecimalType.MAX_SCALE - decDig
-     }
-     val resultType = DecimalType.bounded(intDig + decDig, decDig)
+     val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+       // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
+       // Scale: max(6, s1 + p2 + 1)
+       val intDig = p1 - s1 + s2
+       val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
+       val prec = intDig + scale
+       DecimalType.adjustPrecisionScale(prec, scale)
+     } else {
+       var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+       var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+       val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+       if (diff > 0) {
+         decDig -= diff / 2 + 1
+         intDig = DecimalType.MAX_SCALE - decDig
+       }
+       DecimalType.bounded(intDig + decDig, decDig)
+     }
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
        resultType)

    case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-     val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+     val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+       DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+     } else {
+       DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+     }
      // resultType may have lower precision, so we cast them into wider type first.
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
        resultType)

    case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-     val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+     val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
+       DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+     } else {
+       DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+     }
      // resultType may have lower precision, so we cast them into wider type first.
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
@@ -137,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule {
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
val resultType = widerDecimalType(p1, s1, p2, s2)
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))

// TODO: MaxOf, MinOf, etc might want other rules
// SUM and AVERAGE are handled by the implementations of those expressions
}

/**
@@ -243,17 +279,35 @@ object DecimalPrecision extends TypeCoercionRule {
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
    case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
-     (left.dataType, right.dataType) match {
-       case (t: IntegralType, DecimalType.Fixed(p, s)) =>
-         b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
-       case (DecimalType.Fixed(p, s), t: IntegralType) =>
-         b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
-       case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
-         b.makeCopy(Array(left, Cast(right, DoubleType)))
-       case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
-         b.makeCopy(Array(Cast(left, DoubleType), right))
-       case _ =>
-         b
+     (left, right) match {
+       // Promote literal integers inside a binary expression with fixed-precision decimals to
+       // decimals. The precision and scale are the ones strictly needed by the integer value.
+       // Requiring more precision than necessary may lead to a useless loss of precision.
+       // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2.
+       // If we use the default precision and scale for the integer type, 2 is considered a
+       // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18),
+       // which is out of range and therefore it will become DECIMAL(38, 7), leading to
+       // potentially losing 11 digits of the fractional part. Using only the precision needed
+       // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would
+       // become DECIMAL(38, 16), safely having a much lower precision loss.
+       case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType]
+           && l.dataType.isInstanceOf[IntegralType] =>
+         b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
+       case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType]
+           && r.dataType.isInstanceOf[IntegralType] =>
+         b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
+       // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+       // and fixed-precision decimals in an expression with floats / doubles to doubles
+       case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) =>
+         b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
+       case (l @ DecimalType.Expression(_, _), r @ IntegralType()) =>
+         b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
+       case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) =>
+         b.makeCopy(Array(l, Cast(r, DoubleType)))
+       case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) =>
+         b.makeCopy(Array(Cast(l, DoubleType), r))
+       case _ => b
      }
  }

}
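
A worked version of the example in the comment above (the derivation follows the multiply rule and the scale adjustment; the session line is illustrative):

```scala
// Column x: DECIMAL(38, 18), multiplied by the integer literal 2.
// 2 as DECIMAL(10, 0): result needs DECIMAL(38 + 10 + 1, 18) = (49, 18);
//   49 > 38 and 31 integer digits are needed, so the scale shrinks to
//   max(38 - 31, 6) = 7 -> DECIMAL(38, 7): 11 fractional digits lost.
// 2 as DECIMAL(1, 0):  result needs DECIMAL(38 + 1 + 1, 18) = (40, 18);
//   22 integer digits -> scale max(38 - 22, 6) = 16 -> DECIMAL(38, 16).
spark.sql("SELECT CAST(1.5 AS DECIMAL(38,18)) * 2 AS x").printSchema()
// x: decimal(38,16)
```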
@@ -58,7 +58,7 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
case b: Boolean => Literal(b, BooleanType)
-   case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale))
+   case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
case d: JavaBigDecimal =>
Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
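
For reference, a standalone re-implementation of the `fromBigDecimal` mapping used above (precision = max(d.precision, d.scale), scale = d.scale); the example values are mine:

```scala
def decimalTypeOf(d: BigDecimal): (Int, Int) =
  (math.max(d.precision, d.scale), d.scale)

println(decimalTypeOf(BigDecimal(2)))         // (1,0): the literal 2 needs just one digit
println(decimalTypeOf(BigDecimal("1.23")))    // (3,2)
println(decimalTypeOf(BigDecimal("0.00123"))) // (5,5): max() guards against scale > precision
```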
@@ -1064,6 +1064,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =
buildConf("spark.sql.decimalOperations.allowPrecisionLoss")
.internal()
.doc("When true (default), establishing the result type of an arithmetic operation " +
"happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " +
"decimal part of the result if an exact representation is not possible. Otherwise, NULL " +
"is returned in those cases, as previously.")
.booleanConf
.createWithDefault(true)

val SQL_STRING_REDACTION_PATTERN =
ConfigBuilder("spark.sql.redaction.string.regex")
.doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
@@ -1441,6 +1451,8 @@ class SQLConf extends Serializable with Logging {

def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)

def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)

def continuousStreamingExecutorPollIntervalMs: Long =
@@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}


/**
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
val MAX_SCALE = 38
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
val USER_DEFAULT: DecimalType = DecimalType(10, 0)
val MINIMUM_ADJUSTED_SCALE = 6

// The decimal types compatible with other numeric types
private[sql] val ByteDecimal = DecimalType(3, 0)
@@ -136,10 +137,52 @@ object DecimalType extends AbstractDataType {
case DoubleType => DoubleDecimal
}

private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match {
case v: Short => fromBigDecimal(BigDecimal(v))
case v: Int => fromBigDecimal(BigDecimal(v))
case v: Long => fromBigDecimal(BigDecimal(v))
case _ => forType(literal.dataType)
}

private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = {
DecimalType(Math.max(d.precision, d.scale), d.scale)
}

private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
}

/**
* The scale adjustment implementation is based on Hive's, which is itself inspired by
* SQL Server's. In particular, when a result precision is greater than
* {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a
* result from being truncated.
*
* This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true.
*/
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
// Assumptions:
assert(precision >= scale)
assert(scale >= 0)

if (precision <= MAX_PRECISION) {
// Adjustment only needed when we exceed max precision
DecimalType(precision, scale)
} else {
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
val intDigits = precision - scale
// If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
// The resulting scale is the maximum between what is available without causing a loss of
// digits for the integer part of the decimal and the minimum guaranteed scale, which is
// computed above
val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)

DecimalType(MAX_PRECISION, adjustedScale)
}
}

override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT

override private[sql] def acceptsType(other: DataType): Boolean = {
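
A few hand-checked examples of the adjustment above (these would need to live inside the `org.apache.spark.sql` package, since the method is `private[sql]`):

```scala
// Values computed by hand from the code above.
assert(DecimalType.adjustPrecisionScale(95, 57) == DecimalType(38, 6))  // intDigits = 38 -> scale floor 6
assert(DecimalType.adjustPrecisionScale(49, 18) == DecimalType(38, 7))  // max(38 - 31, 6) = 7
assert(DecimalType.adjustPrecisionScale(40, 4)  == DecimalType(38, 4))  // scale <= 6 is kept as-is
```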
@@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
-   assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
-   assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
+   assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11))
+   assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6))
assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
}
@@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {

test("maximum decimals") {
for (expr <- Seq(d1, d2, i, u)) {
-     checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT)
-     checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT)
+     checkType(Add(expr, u), DecimalType(38, 17))
+     checkType(Subtract(expr, u), DecimalType(38, 17))
}

-   checkType(Multiply(d1, u), DecimalType(38, 19))
-   checkType(Multiply(d2, u), DecimalType(38, 20))
-   checkType(Multiply(i, u), DecimalType(38, 18))
-   checkType(Multiply(u, u), DecimalType(38, 36))
+   checkType(Multiply(d1, u), DecimalType(38, 16))
+   checkType(Multiply(d2, u), DecimalType(38, 14))
+   checkType(Multiply(i, u), DecimalType(38, 7))
+   checkType(Multiply(u, u), DecimalType(38, 6))

-   checkType(Divide(u, d1), DecimalType(38, 18))
-   checkType(Divide(u, d2), DecimalType(38, 19))
-   checkType(Divide(u, i), DecimalType(38, 23))
-   checkType(Divide(u, u), DecimalType(38, 18))
+   checkType(Divide(u, d1), DecimalType(38, 17))
+   checkType(Divide(u, d2), DecimalType(38, 16))
+   checkType(Divide(u, i), DecimalType(38, 18))
+   checkType(Divide(u, u), DecimalType(38, 6))

checkType(Remainder(d1, u), DecimalType(19, 18))
checkType(Remainder(d2, u), DecimalType(21, 18))
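
For the record, the updated expectations can be derived by hand, assuming (per the suite) d1 = DECIMAL(2, 1), d2 = DECIMAL(5, 2), i = INT and u = DECIMAL(38, 18):

```scala
// Multiply(u, u): exact type (38 + 38 + 1, 18 + 18) = (77, 36);
//   41 integer digits needed -> scale max(38 - 41, 6) = 6 -> DECIMAL(38, 6)
// Multiply(i, u): i is DECIMAL(10, 0) -> (49, 18);
//   31 integer digits -> scale max(38 - 31, 6) = 7        -> DECIMAL(38, 7)
// Divide(u, u):   exact type (95, max(6, 18 + 38 + 1)) = (95, 57);
//   38 integer digits -> scale max(38 - 38, 6) = 6        -> DECIMAL(38, 6)
```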
