Skip to content

Commit

Permalink
[SPARK-28369][SQL] Honor spark.sql.decimalOperations.nullOnOverflow i…
Browse files Browse the repository at this point in the history
…n ScalaUDF result

## What changes were proposed in this pull request?

When a `ScalaUDF` returns a value which overflows, currently it returns null regardless of the value of the config `spark.sql.decimalOperations.nullOnOverflow`.

The PR makes it respect the above-mentioned config and behave accordingly.

## How was this patch tested?

added UT

Closes #25144 from mgaido91/SPARK-28369.

Authored-by: Marco Gaido <marcogaido91@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
mgaido91 authored and cloud-fan committed Jul 22, 2019
1 parent c97f06d commit a783690
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 12 deletions.
Expand Up @@ -343,6 +343,9 @@ object CatalystTypeConverters {

private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {

private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow

override def toCatalystImpl(scalaValue: Any): Decimal = {
val decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
Expand All @@ -353,7 +356,7 @@ object CatalystTypeConverters {
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ s"cannot be converted to ${dataType.catalogString}")
}
decimal.toPrecision(dataType.precision, dataType.scale)
decimal.toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
Expand Down
Expand Up @@ -1177,6 +1177,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
dataType match {
case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal]
// Overflow cannot happen, so no need to control nullOnOverflow
decimal.toPrecision(decimal.precision, s, mode)
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
Expand Down
Expand Up @@ -414,20 +414,12 @@ final class Decimal extends Ordered[Decimal] with Serializable {

def floor: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
if (res == null) {
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
}
res
toPrecision(newPrecision, 0, ROUND_FLOOR, nullOnOverflow = false)
}

def ceil: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
val res = toPrecision(newPrecision, 0, ROUND_CEILING)
if (res == null) {
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
}
res
toPrecision(newPrecision, 0, ROUND_CEILING, nullOnOverflow = false)
}
}

Expand Down
Expand Up @@ -21,7 +21,8 @@ import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

Expand Down Expand Up @@ -54,4 +55,26 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

test("SPARK-28369: honor nullOnOverflow config for ScalaUDF") {
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
val e1 = intercept[ArithmeticException](udf.eval())
assert(e1.getMessage.contains("cannot be represented as Decimal"))
val e2 = intercept[SparkException] {
checkEvaluationWithUnsafeProjection(udf, null)
}
assert(e2.getCause.isInstanceOf[ArithmeticException])
}
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
checkEvaluation(udf, null)
}
}
}

0 comments on commit a783690

Please sign in to comment.