Skip to content

Commit

Permalink
[SPARK-32706][SQL] Improve cast string to decimal type
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This pr makes cast string type to decimal decimal type fast fail if precision larger that 38.

### Why are the changes needed?

It is very slow if precision very large.

Benchmark and benchmark result:
```scala
import org.apache.spark.benchmark.Benchmark
val bd1 = new java.math.BigDecimal("6.0790316E+25569151")
val bd2 = new java.math.BigDecimal("6.0790316E+25");

val benchmark = new Benchmark("Benchmark string to decimal", 1, minNumIters = 2)
benchmark.addCase(bd1.toString) { _ =>
  println(Decimal(bd1).precision)
}
benchmark.addCase(bd2.toString) { _ =>
  println(Decimal(bd2).precision)
}
benchmark.run()
```
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_251-b08 on Mac OS X 10.15.6
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
Benchmark string to decimal:              Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
6.0790316E+25569151                                9340           9381          57          0.0  9340094625.0       1.0X
6.0790316E+25                                         0              0           0          0.5        2150.0 4344230.1X
```
Stacktrace:
![image](https://user-images.githubusercontent.com/5399861/92941705-4c868980-f483-11ea-8a15-b93acde8c0f4.png)

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test and benchmark test:
Dataset | Before this pr (Seconds) | After this pr (Seconds)
-- | -- | --
https://issues.apache.org/jira/secure/attachment/13011406/part-00000.parquet | 2640 | 2

Closes #29731 from wangyum/SPARK-32706.

Authored-by: Yuming Wang <yumwang@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
wangyum authored and cloud-fan committed Sep 16, 2020
1 parent c918909 commit 3bc13e6
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -669,19 +669,13 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit


private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try {
// According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`.
// Please refer to https://github.com/apache/spark/pull/26640
changePrecision(Decimal(new JavaBigDecimal(s.toString.trim)), target)
} catch {
case _: NumberFormatException =>
if (ansiEnabled) {
throw new NumberFormatException(s"invalid input syntax for type numeric: $s")
} else {
null
}
case StringType if !ansiEnabled =>
buildCast[UTF8String](_, s => {
val d = Decimal.fromString(s)
if (d == null) null else changePrecision(d, target)
})
case StringType if ansiEnabled =>
buildCast[UTF8String](_, s => changePrecision(Decimal.fromStringANSI(s), target))
case BooleanType =>
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
Expand Down Expand Up @@ -1185,20 +1179,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal])
val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target)
from match {
case StringType =>
case StringType if !ansiEnabled =>
(c, evPrim, evNull) =>
val handleException = if (ansiEnabled) {
s"""throw new NumberFormatException("invalid input syntax for type numeric: " + $c);"""
} else {
s"$evNull =true;"
}
code"""
try {
Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString().trim()));
Decimal $tmp = Decimal.fromString($c);
if ($tmp == null) {
$evNull = true;
} else {
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
}
"""
case StringType if ansiEnabled =>
(c, evPrim, evNull) =>
code"""
Decimal $tmp = Decimal.fromStringANSI($c);
${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)}
} catch (java.lang.NumberFormatException e) {
$handleException
}
"""
case BooleanType =>
(c, evPrim, evNull) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
package org.apache.spark.sql.types

import java.lang.{Long => JLong}
import java.math.{BigInteger, MathContext, RoundingMode}
import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, RoundingMode}

import scala.util.Try

import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.unsafe.types.UTF8String

/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
Expand Down Expand Up @@ -550,9 +551,6 @@ object Decimal {
private[sql] val ZERO = Decimal(0)
private[sql] val ONE = Decimal(1)

private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE)
private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE)

def apply(value: Double): Decimal = new Decimal().set(value)

def apply(value: Long): Decimal = new Decimal().set(value)
Expand Down Expand Up @@ -589,6 +587,52 @@ object Decimal {
}
}

private def calculatePrecision(bigDecimal: JavaBigDecimal): Int = {
if (bigDecimal.scale < 0) {
bigDecimal.precision - bigDecimal.scale
} else {
bigDecimal.precision
}
}

private def stringToJavaBigDecimal(str: UTF8String): JavaBigDecimal = {
// According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`.
// Please refer to https://github.com/apache/spark/pull/26640
new JavaBigDecimal(str.toString.trim)
}

def fromString(str: UTF8String): Decimal = {
try {
val bigDecimal = stringToJavaBigDecimal(str)
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
// For example: Decimal("6.0790316E+25569151")
if (calculatePrecision(bigDecimal) > DecimalType.MAX_PRECISION) {
null
} else {
Decimal(bigDecimal)
}
} catch {
case _: NumberFormatException =>
null
}
}

def fromStringANSI(str: UTF8String): Decimal = {
try {
val bigDecimal = stringToJavaBigDecimal(str)
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
// For example: Decimal("6.0790316E+25569151")
if (calculatePrecision(bigDecimal) > DecimalType.MAX_PRECISION) {
throw new ArithmeticException(s"out of decimal type range: $str")
} else {
Decimal(bigDecimal)
}
} catch {
case _: NumberFormatException =>
throw new NumberFormatException(s"invalid input syntax for type numeric: $str")
}
}

/**
* Creates a decimal from unscaled, precision and scale without checking the bounds.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,29 @@ class CastSuite extends CastSuiteBase {
val v = Literal.create(Row(1), new ExampleSubTypeUDT())
checkEvaluation(cast(v, new ExampleBaseTypeUDT), Row(1))
}

test("Fast fail for cast string type to decimal type") {
checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 0)),
Decimal("12345678901234567890123456789012345678"))
checkEvaluation(cast("123456789012345678901234567890123456789", DecimalType(38, 0)), null)
checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 1)), null)

checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 18)),
Decimal("0E-18"))
checkEvaluation(cast("6E-120", DecimalType(38, 0)),
Decimal("0"))

checkEvaluation(cast("6E+37", DecimalType(38, 0)),
Decimal("60000000000000000000000000000000000000"))
checkEvaluation(cast("6E+38", DecimalType(38, 0)), null)
checkEvaluation(cast("6E+37", DecimalType(38, 1)), null)

checkEvaluation(cast("abcd", DecimalType(38, 1)), null)
}
}

/**
Expand Down Expand Up @@ -1405,4 +1428,37 @@ class AnsiCastSuite extends CastSuiteBase {
checkEvaluation(cast(negativeTs, LongType), expectedSecs)
}
}

test("Fast fail for cast string type to decimal type in ansi mode") {
checkEvaluation(cast("12345678901234567890123456789012345678", DecimalType(38, 0)),
Decimal("12345678901234567890123456789012345678"))
checkExceptionInExpression[ArithmeticException](
cast("123456789012345678901234567890123456789", DecimalType(38, 0)),
"out of decimal type range")
checkExceptionInExpression[ArithmeticException](
cast("12345678901234567890123456789012345678", DecimalType(38, 1)),
"cannot be represented as Decimal(38, 1)")

checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000000000001", DecimalType(38, 0)),
Decimal("0"))
checkEvaluation(cast("0.00000000000000000000000000000000000001", DecimalType(38, 18)),
Decimal("0E-18"))
checkEvaluation(cast("6E-120", DecimalType(38, 0)),
Decimal("0"))

checkEvaluation(cast("6E+37", DecimalType(38, 0)),
Decimal("60000000000000000000000000000000000000"))
checkExceptionInExpression[ArithmeticException](
cast("6E+38", DecimalType(38, 0)),
"out of decimal type range")
checkExceptionInExpression[ArithmeticException](
cast("6E+37", DecimalType(38, 1)),
"cannot be represented as Decimal(38, 1)")

checkExceptionInExpression[NumberFormatException](
cast("abcd", DecimalType(38, 1)),
"invalid input syntax for type numeric")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.Decimal._
import org.apache.spark.unsafe.types.UTF8String

class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
/** Check that a Decimal has the given string representation, precision and scale */
Expand Down Expand Up @@ -256,4 +257,33 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
assert(decimalLong.toScalaBigInt == scala.math.BigInt("123456789"))
assert(decimalLong.toJavaBigInteger == new java.math.BigInteger("123456789"))
}

test("UTF8String to Decimal") {
def checkFromString(string: String): Unit = {
assert(Decimal.fromString(UTF8String.fromString(string)) === Decimal(string))
assert(Decimal.fromStringANSI(UTF8String.fromString(string)) === Decimal(string))
}

def checkOutOfRangeFromString(string: String): Unit = {
assert(Decimal.fromString(UTF8String.fromString(string)) === null)
val e = intercept[ArithmeticException](Decimal.fromStringANSI(UTF8String.fromString(string)))
assert(e.getMessage.contains("out of decimal type range"))
}

checkFromString("12345678901234567890123456789012345678")
checkOutOfRangeFromString("123456789012345678901234567890123456789")

checkFromString("0.00000000000000000000000000000000000001")
checkFromString("0.000000000000000000000000000000000000000000000001")

checkFromString("6E-640")

checkFromString("6E+37")
checkOutOfRangeFromString("6E+38")
checkOutOfRangeFromString("6.0790316E+25569151")

assert(Decimal.fromString(UTF8String.fromString("str")) === null)
val e = intercept[NumberFormatException](Decimal.fromStringANSI(UTF8String.fromString("str")))
assert(e.getMessage.contains("invalid input syntax for type numeric"))
}
}

0 comments on commit 3bc13e6

Please sign in to comment.