[SPARK-10215] [SQL] Fix precision of division (follow the rule in Hive) #8415

Status: Closed · wants to merge 2 commits
@@ -400,8 +400,15 @@ object HiveTypeCoercion {
          resultType)

      case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1),
-          max(6, s1 + p2 + 1))
+        var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+        var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+        val diff = (intDig + decDig) - DecimalType.MAX_SCALE
+        if (diff > 0) {
+          decDig -= diff / 2 + 1
+          intDig = DecimalType.MAX_SCALE - decDig
+          DecimalType.bounded(intDig + decDig, decDig)
Review comment (Contributor): This line is not used; I will remove it when I merge the PR.

+        }
+        val resultType = DecimalType.bounded(intDig + decDig, decDig)
val widerType = widerDecimalType(p1, s1, p2, s2)
Review comment (Contributor): Is the resultType always tighter than widerType?

CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
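As an aside for reviewers: the new branch above can be read as a small standalone function. The sketch below is not part of the PR; the object and method names are invented, and it assumes DecimalType.MAX_PRECISION and DecimalType.MAX_SCALE are both 38, as they are for Spark's DecimalType here.

import scala.math.{max, min}

// Hypothetical mirror of the new Divide branch, for illustration only.
object DecimalDivisionRuleSketch {
  private val MaxPrecision = 38  // assumed: DecimalType.MAX_PRECISION
  private val MaxScale = 38      // assumed: DecimalType.MAX_SCALE

  // Returns (precision, scale) for Decimal(p1, s1) / Decimal(p2, s2).
  // Wanted integral digits: p1 - s1 + s2; wanted fractional digits: max(6, s1 + p2 + 1).
  // If the two together exceed 38 digits, a bit more than half of the overflow is taken
  // from the fractional digits and the rest from the integral digits.
  def divide(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    var intDig = min(MaxScale, p1 - s1 + s2)
    var decDig = min(MaxScale, max(6, s1 + p2 + 1))
    val diff = (intDig + decDig) - MaxScale
    if (diff > 0) {
      decDig -= diff / 2 + 1
      intDig = MaxScale - decDig
    }
    (min(intDig + decDig, MaxPrecision), decDig)
  }
}

// e.g. DecimalDivisionRuleSketch.divide(38, 18, 10, 0) == (38, 23),
// which is the Divide(u, i) case in the DecimalPrecisionSuite changes below.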
@@ -17,15 +17,14 @@

package org.apache.spark.sql.catalyst.analysis

+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.SimpleCatalystConf
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans._

class AnalysisSuite extends AnalysisTest {
-  import TestRelations._
+  import org.apache.spark.sql.catalyst.analysis.TestRelations._

test("union project *") {
val plan = (1 to 100)
@@ -96,7 +95,7 @@ class AnalysisSuite extends AnalysisTest {
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
// StringType will be promoted into Decimal(38, 18)
-    assert(pl(3).dataType == DecimalType(38, 29))
+    assert(pl(3).dataType == DecimalType(38, 22))
assert(pl(4).dataType == DoubleType)
}

@@ -136,10 +136,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkType(Multiply(i, u), DecimalType(38, 18))
checkType(Multiply(u, u), DecimalType(38, 36))

-    checkType(Divide(u, d1), DecimalType(38, 21))
-    checkType(Divide(u, d2), DecimalType(38, 24))
-    checkType(Divide(u, i), DecimalType(38, 29))
-    checkType(Divide(u, u), DecimalType(38, 38))
+    checkType(Divide(u, d1), DecimalType(38, 18))
+    checkType(Divide(u, d2), DecimalType(38, 19))
+    checkType(Divide(u, i), DecimalType(38, 23))
+    checkType(Divide(u, u), DecimalType(38, 18))

checkType(Remainder(d1, u), DecimalType(19, 18))
checkType(Remainder(d2, u), DecimalType(21, 18))
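For reference, here is how the updated Divide(u, u) expectation above falls out of the new rule, with u = Decimal(38, 18) and MAX_SCALE assumed to be 38; this is just the arithmetic from the HiveTypeCoercion change spelled out, not code from the PR.

// Divide(u, u) with u = Decimal(38, 18): p1 = p2 = 38, s1 = s2 = 18.
val intDigWanted = math.min(38, 38 - 18 + 18)              // 38 integral digits wanted
val decDigWanted = math.min(38, math.max(6, 18 + 38 + 1))  // 38 fractional digits wanted
val overflow = (intDigWanted + decDigWanted) - 38          // 38 digits over the limit
val decDig = decDigWanted - (overflow / 2 + 1)             // 38 - 20 = 18
val intDig = 38 - decDig                                   // 20
// Result type: DecimalType(intDig + decDig, decDig) == DecimalType(38, 18), matching the new expectation.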
25 changes: 23 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1622,9 +1622,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38))))
Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
-      Row(null))
+      Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
}
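Not part of the PR: a quick sanity check of the updated literal above with plain java.math.BigDecimal, under the assumption that the new rule types 10.3000000000000000000 / 3.00000000000000000 as Decimal(38, 28), so the quotient is now representable instead of overflowing to null.

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Divide at scale 28, the scale the new rule assigns to this expression.
val q = new JBigDecimal("10.3000000000000000000")
  .divide(new JBigDecimal("3.00000000000000000"), 28, RoundingMode.HALF_UP)
println(q)  // 3.4333333333333333333333333333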

test("SPARK-10215 Div of Decimal returns null") {
val d = Decimal(1.12321)
val df = Seq((d, 1)).toDF("a", "b")

checkAnswer(
df.selectExpr("b * a / b"),
Seq(Row(d.toBigDecimal)))
checkAnswer(
df.selectExpr("b * a / b / b"),
Seq(Row(d.toBigDecimal)))
checkAnswer(
df.selectExpr("b * a + b"),
Seq(Row(BigDecimal(2.12321))))
checkAnswer(
df.selectExpr("b * a - b"),
Seq(Row(BigDecimal(0.12321))))
checkAnswer(
df.selectExpr("b * a * b"),
Seq(Row(d.toBigDecimal)))
}

test("external sorting updates peak execution memory") {