From a3437ee4a87d1f51b362adeb20d4fcc264085ba7 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 13 Oct 2017 21:45:27 -0700 Subject: [PATCH 1/4] [SPARK-22271][SQL]mean overflows and returns null for some decimal variables --- .../sql/catalyst/expressions/aggregate/Average.scala | 3 ++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c423e17169e8..94c4ab72c9e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded (DecimalType.MAX_PRECISION, 0)), + resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ad461fa6144b..17447acc1ef5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2103,4 +2103,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + + test("SPARK-22271: mean overflows and returns null for some decimal variables") { + val d: BigDecimal = BigDecimal(0.034567890) + val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") + val result = df.select('DecimalCol cast DecimalType(38, 33)) + .select(col("DecimalCol")).describe() + val mean = result.select("DecimalCol").where($"summary" === "mean") + assert(mean.collect.toSet === Set(Row("0.0345678900000000000000000000000000000"))) + } } From de2aa6975c31f4c095e07a34b66b24ee39f83b01 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 14 Oct 2017 09:28:54 -0700 Subject: [PATCH 2/4] remove extra space --- .../spark/sql/catalyst/expressions/aggregate/Average.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 94c4ab72c9e6..708bdbfc3605 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded (DecimalType.MAX_PRECISION, 0)), + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) From 72467e175d626648451940bdbccaf7866b2ded6c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 15 Oct 2017 14:42:27 -0700 Subject: [PATCH 3/4] remove unnecessary data type info --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 17447acc1ef5..a90b2b2c23ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2105,7 +2105,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-22271: mean overflows and returns null for some decimal variables") { - val d: BigDecimal = BigDecimal(0.034567890) + val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") val result = df.select('DecimalCol cast DecimalType(38, 33)) .select(col("DecimalCol")).describe() From 8438a6171e015d4ad239c1562254635ee5ed51ce Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 17 Oct 2017 09:21:02 -0700 Subject: [PATCH 4/4] fix code format problems --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a90b2b2c23ca..80a4bb775ee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2108,8 +2108,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val d = 0.034567890 val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") val result = df.select('DecimalCol cast DecimalType(38, 33)) - .select(col("DecimalCol")).describe() + .select(col("DecimalCol")).describe() val mean = result.select("DecimalCol").where($"summary" === "mean") - assert(mean.collect.toSet === Set(Row("0.0345678900000000000000000000000000000"))) + assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } }