2 changes: 2 additions & 0 deletions docs/sql-migration-guide.md
@@ -257,6 +257,8 @@ license: |
- Since Spark 3.0, the unary arithmetic operator plus (`+`) only accepts string, numeric and interval type values as inputs. In addition, applying `+` to a string with an integral representation coerces it to a double value, e.g. `+'1'` results in `1.0`. In Spark version 2.4 and earlier, this operator is ignored: there is no type checking for it, so values of any type with a `+` prefix are valid, e.g. `+ array(1, 2)` is valid and results in `[1, 2]`. There is also no type coercion for it at all, e.g. in Spark 2.4, the result of `+'1'` is the string `1`.

- Since Spark 3.0, day-time interval strings are converted to intervals with respect to the `from` and `to` bounds. If an input string does not match the pattern defined by the specified bounds, a `ParseException` is thrown. For example, `interval '2 10:20' hour to minute` raises the exception because the expected format is `[+|-]h[h]:[m]m`. In Spark version 2.4, the `from` bound was not taken into account, and the `to` bound was used to truncate the resulting interval. For instance, the day-time interval string from the example above is converted to `interval 10 hours 20 minutes`. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.fromDayTimeString.enabled` to `true`.

- Since Spark 3.0, a negative scale of decimal is not allowed by default, e.g. the data type of a literal like `1E10BD` is `DecimalType(11, 0)`. In Spark version 2.4 and earlier, it was `DecimalType(2, -9)`. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.allowNegativeScaleOfDecimal.enabled` to `true` (see the sketch after this list).

- Since Spark 3.0, the `date_add` and `date_sub` functions accept only int, smallint and tinyint as the 2nd argument; fractional and string types are no longer valid, e.g. `date_add(cast('1964-05-23' as date), '12.34')` causes an `AnalysisException`. In Spark version 2.4 and earlier, if the 2nd argument is a fractional or string value, it is coerced to an int value, and the result is the date value `1964-06-04`.
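
A minimal sketch of the negative-scale change described above (not part of this diff; it assumes a Spark 3.0 session available as `spark`):

// Default in Spark 3.0: the negative scale is normalized away.
spark.sql("SELECT 1E10BD").schema.head.dataType  // DecimalType(11,0)

// Restoring the pre-3.0 behavior for the current session:
spark.conf.set("spark.sql.legacy.allowNegativeScaleOfDecimal.enabled", "true")
spark.sql("SELECT 1E10BD").schema.head.dataType  // DecimalType(2,-9)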

12 changes: 8 additions & 4 deletions python/pyspark/sql/tests/test_types.py
@@ -204,10 +204,14 @@ def test_create_dataframe_from_dict_respects_schema(self):
self.assertEqual(df.columns, ['b'])

def test_negative_decimal(self):
df = self.spark.createDataFrame([(1, ), (11, )], ["value"])
ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
actual = list(map(lambda r: int(r.value), ret))
self.assertEqual(actual, [0, 10])
try:
self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true")
df = self.spark.createDataFrame([(1, ), (11, )], ["value"])
ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
actual = list(map(lambda r: int(r.value), ret))
self.assertEqual(actual, [0, 10])
finally:
self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=false")

def test_create_dataframe_from_objects(self):
data = [MyObject(1, "1"), MyObject(2, "2")]
2 changes: 0 additions & 2 deletions python/pyspark/sql/types.py
@@ -867,8 +867,6 @@ def _parse_datatype_json_string(json_string):
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
>>> # Decimal with negative scale.
>>> check_datatype(DecimalType(1,-1))
"""
return _parse_datatype_json_value(json.loads(json_string))

@@ -63,9 +63,12 @@ object Literal {
case s: String => Literal(UTF8String.fromString(s), StringType)
case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
case d: BigDecimal =>
val decimal = Decimal(d)
Literal(decimal, DecimalType.fromDecimal(decimal))
case d: JavaBigDecimal =>
Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
val decimal = Decimal(d)
Literal(decimal, DecimalType.fromDecimal(decimal))
case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
case i: Instant => Literal(instantToMicros(i), TimestampType)
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
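
The rewritten `BigDecimal` cases above route both Scala and Java big decimals through `Decimal(...)` and the new `DecimalType.fromDecimal`, so the literal's type reflects the (possibly rescaled) `Decimal` rather than the raw value's precision and scale. A minimal sketch of the resulting behavior under the default configuration (not part of this diff):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.DecimalType

// BigDecimal("1E10") has precision 1 and scale -10; Decimal(...) normalizes the
// negative scale to 0, so the inferred literal type is DecimalType(11, 0).
val lit = Literal(BigDecimal("1E10"))
assert(lit.dataType == DecimalType(11, 0))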
@@ -1973,6 +1973,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED =
buildConf("spark.sql.legacy.allowNegativeScaleOfDecimal.enabled")
.internal()
.doc("When set to true, negative scale of Decimal type is allowed. For example, " +
"the type of number 1E10BD under legacy mode is DecimalType(2, -9), but is " +
"Decimal(11, 0) in non legacy mode.")
.booleanConf
.createWithDefault(false)

val LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED =
buildConf("spark.sql.legacy.createHiveTableByDefault.enabled")
.internal()
@@ -2705,6 +2714,9 @@ class SQLConf extends Serializable with Logging {
def exponentLiteralAsDecimalEnabled: Boolean =
getConf(SQLConf.LEGACY_EXPONENT_LITERAL_AS_DECIMAL_ENABLED)

def allowNegativeScaleOfDecimalEnabled: Boolean =
getConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED)

def createHiveTableByDefaultEnabled: Boolean =
getConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED)
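
The test suites later in this diff toggle the new flag scope-locally rather than globally; a minimal sketch of the pattern, assuming a suite that mixes in `SQLHelper` (or `PlanTest`):

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DecimalType

withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
  // Negative-scale decimal types are only legal inside this block.
  val tpe = DecimalType(1, -1)
  assert(tpe.scale == -1)
}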

@@ -23,6 +23,7 @@ import java.math.{BigInteger, MathContext, RoundingMode}
import scala.util.Try

import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.internal.SQLConf

/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -89,6 +90,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
* and return it, or return null if it cannot be set due to overflow.
*/
def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
DecimalType.checkNegativeScale(scale)
if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
// We can't represent this compactly as a long without risking overflow
if (precision < 19) {
@@ -113,6 +115,7 @@
* Set this Decimal to the given BigDecimal value, with a given precision and scale.
*/
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
DecimalType.checkNegativeScale(scale)
this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
if (decimalVal.precision > precision) {
throw new ArithmeticException(
@@ -136,10 +139,16 @@
// result. For example, the precision of 0.01 equals to 1 based on the definition, but
// the scale is 2. The expected precision should be 2.
this._precision = decimal.scale
this._scale = decimal.scale
} else if (decimal.scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) {
this._precision = decimal.precision - decimal.scale
this._scale = 0
// set scale to 0 to correct unscaled value
this.decimalVal = decimal.setScale(0)
} else {
this._precision = decimal.precision
this._scale = decimal.scale
}
this._scale = decimal.scale
this
}
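
The new branch above rescales a negative-scale `BigDecimal` to scale 0 instead of carrying the negative scale into the `Decimal`. A worked sketch of the arithmetic under the default configuration (values chosen for illustration, not taken from this diff):

// BigDecimal("1.23E5") == 123000 has precision 3 and scale -3.
// New default: _precision = 3 - (-3) = 6, _scale = 0, decimalVal = 123000.
// Legacy mode (flag enabled): precision 3 and scale -3 are kept as-is.
val d = org.apache.spark.sql.types.Decimal(BigDecimal("1.23E5"))
assert(d.precision == 6 && d.scale == 0)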

@@ -375,6 +384,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (precision == this.precision && scale == this.scale) {
return true
}
DecimalType.checkNegativeScale(scale)
// First, update our longVal if we can, or transfer over to using a BigDecimal
if (decimalVal.eq(null)) {
if (scale < _scale) {
@@ -583,6 +593,7 @@ object Decimal {
* Creates a decimal from unscaled, precision and scale without checking the bounds.
*/
def createUnsafe(unscaled: Long, precision: Int, scale: Int): Decimal = {
DecimalType.checkNegativeScale(scale)
val dec = new Decimal()
dec.longVal = unscaled
dec._precision = precision
@@ -24,6 +24,7 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.internal.SQLConf

/**
* The data type representing `java.math.BigDecimal` values.
@@ -41,6 +42,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
@Stable
case class DecimalType(precision: Int, scale: Int) extends FractionalType {

DecimalType.checkNegativeScale(scale)

if (scale > precision) {
throw new AnalysisException(
s"Decimal scale ($scale) cannot be greater than precision ($precision).")
@@ -141,20 +144,26 @@ object DecimalType extends AbstractDataType {
}

private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match {
case v: Short => fromBigDecimal(BigDecimal(v))
case v: Int => fromBigDecimal(BigDecimal(v))
case v: Long => fromBigDecimal(BigDecimal(v))
case v: Short => fromDecimal(Decimal(BigDecimal(v)))
case v: Int => fromDecimal(Decimal(BigDecimal(v)))
case v: Long => fromDecimal(Decimal(BigDecimal(v)))
case _ => forType(literal.dataType)
}

private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = {
DecimalType(Math.max(d.precision, d.scale), d.scale)
}
private[sql] def fromDecimal(d: Decimal): DecimalType = DecimalType(d.precision, d.scale)

private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
}

private[sql] def checkNegativeScale(scale: Int): Unit = {
if (scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) {
throw new AnalysisException(s"Negative scale is not allowed: $scale. " +
s"You can use spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true " +
s"to enable legacy mode to allow it.")
Review comment on lines +162 to +163: nit: no need s"".
}
}
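
With this guard in place, constructing a negative-scale type fails eagerly unless the legacy flag is enabled. A small sketch of the observable behavior under the default configuration (not part of this diff):

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.DecimalType

// Throws by default: "Negative scale is not allowed: -9. You can use
// spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true ..."
try DecimalType(2, -9) catch {
  case e: AnalysisException => println(e.getMessage)
}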

/**
* Scale adjustment implementation is based on Hive's one, which is itself inspired to
* SQLServer's one. In particular, when a result precision is greater than
@@ -164,7 +173,8 @@ object DecimalType extends AbstractDataType {
* This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true.
*/
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
// Assumption:
// Assumptions:
checkNegativeScale(scale)
assert(precision >= scale)

if (precision <= MAX_PRECISION) {
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


@@ -273,12 +274,14 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
}

test("SPARK-24468: operations on decimals with negative scale") {
val a = AttributeReference("a", DecimalType(3, -10))()
val b = AttributeReference("b", DecimalType(1, -1))()
val c = AttributeReference("c", DecimalType(35, 1))()
checkType(Multiply(a, b), DecimalType(5, -11))
checkType(Multiply(a, c), DecimalType(38, -9))
checkType(Multiply(b, c), DecimalType(37, 0))
withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
val a = AttributeReference("a", DecimalType(3, -10))()
val b = AttributeReference("b", DecimalType(1, -1))()
val c = AttributeReference("c", DecimalType(35, 1))()
checkType(Multiply(a, b), DecimalType(5, -11))
checkType(Multiply(a, c), DecimalType(38, -9))
checkType(Multiply(b, c), DecimalType(37, 0))
}
}

/** strength reduction for integer/decimal comparisons */
@@ -22,6 +22,7 @@ import java.util.Locale

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
@@ -147,13 +148,15 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
val options = new CSVOptions(Map.empty[String, String], false, "GMT")
val inferSchema = new CSVInferSchema(options)

// 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") ==
DecimalType(4, -9))
withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
// 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") ==
DecimalType(4, -9))
}

// BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20.
val value = "12345678901234567890.01234567890123456789"
assert(inferSchema.inferField(DecimalType(3, -10), value) == DoubleType)
assert(inferSchema.inferField(DecimalType(3, 0), value) == DoubleType)

// Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType
assert(inferSchema.inferField(NullType, s"${Long.MaxValue}1") == DecimalType(20, 0))
@@ -1048,15 +1048,9 @@ class CastSuite extends CastSuiteBase {
assert(cast(Decimal(9.95), DecimalType(2, 1)).nullable)
assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false)

assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable)
assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false)
assert(cast(Decimal("995"), DecimalType(2, -1)).nullable)
assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false)

assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false)
assert(cast(true, DecimalType(1, 1)).nullable)


checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
@@ -1095,17 +1089,9 @@ class CastSuite extends CastSuiteBase {

checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT), Decimal(1003))
checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003))
checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000))
checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000))
checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null)
checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null)
checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null)

checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995))
checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000))
checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000))
checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null)
checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null)

checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)
@@ -1119,6 +1105,23 @@

checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1))
checkEvaluation(cast(true, DecimalType(1, 1)), null)

withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable)
assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false)
assert(cast(Decimal("995"), DecimalType(2, -1)).nullable)
assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false)

checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000))
checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000))
checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null)
checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null)

checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000))
checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000))
checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null)
checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null)
}
}

test("SPARK-28470: Cast should honor nullOnOverflow property") {
@@ -524,6 +524,16 @@ class ExpressionParserSuite extends AnalysisTest {
intercept("1.20E-38BD", "decimal can only support precision up to 38")
}

test("SPARK-30252: Decimal should set zero scale rather than negative scale by default") {
assertEqual("123.0BD", Literal(Decimal(BigDecimal("123.0")), DecimalType(4, 1)))
assertEqual("123BD", Literal(Decimal(BigDecimal("123")), DecimalType(3, 0)))
assertEqual("123E10BD", Literal(Decimal(BigDecimal("123E10")), DecimalType(13, 0)))
assertEqual("123E+10BD", Literal(Decimal(BigDecimal("123E+10")), DecimalType(13, 0)))
assertEqual("123E-10BD", Literal(Decimal(BigDecimal("123E-10")), DecimalType(10, 10)))
assertEqual("1.23E10BD", Literal(Decimal(BigDecimal("1.23E10")), DecimalType(11, 0)))
assertEqual("-1.23E10BD", Literal(Decimal(BigDecimal("-1.23E10")), DecimalType(11, 0)))
}

test("SPARK-29956: scientific decimal should be parsed as Decimal in legacy mode") {
def testDecimal(value: String, parser: ParserInterface): Unit = {
assertEqual(value, Literal(BigDecimal(value).underlying), parser)
@@ -20,9 +20,12 @@ package org.apache.spark.sql.types
import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.Decimal._

class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
/** Check that a Decimal has the given string representation, precision and scale */
private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
assert(d.toString === string)
@@ -63,13 +66,27 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0))
}

test("creating decimals with negative scale") {
checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
test("creating decimals with negative scale under legacy mode") {
withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
}
}

test("SPARK-30252: Negative scale is not allowed by default") {
def checkNegativeScaleDecimal(d: => Decimal): Unit = {
val e = intercept[AnalysisException](d)
assert(e.getMessage.contains("Negative scale is not allowed"))
}
checkNegativeScaleDecimal(Decimal(BigDecimal("98765"), 5, -3))
checkNegativeScaleDecimal(Decimal(BigDecimal("98765").underlying(), 5, -3))
checkNegativeScaleDecimal(Decimal(98765L, 5, -3))
checkNegativeScaleDecimal(Decimal.createUnsafe(98765L, 5, -3))
}

test("double and long values") {