From d315fac89655a89febb44d206f896f87da4e43d2 Mon Sep 17 00:00:00 2001 From: bomeng Date: Thu, 17 Mar 2016 10:48:55 -0700 Subject: [PATCH 1/4] fix the data type cast issue --- .../catalyst/analysis/DecimalPrecision.scala | 12 ++++----- .../analysis/DecimalPrecisionSuite.scala | 26 +++++++++++-------- .../org/apache/spark/sql/DataFrameSuite.scala | 13 ++++++++++ 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2ee4e53..8018a30b4b7fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -61,7 +61,9 @@ import org.apache.spark.sql.types._ object DecimalPrecision extends Rule[LogicalPlan] { import scala.math.{max, min} - private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + private def isNumericType(t: DataType): Boolean = + t == FloatType || t == DoubleType || t == ByteType || + t == IntegerType || t == ShortType || t == LongType // Returns the wider decimal type that's wider than both of them def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { @@ -244,14 +246,10 @@ object DecimalPrecision extends Rule[LogicalPlan] { // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => + case (t: NumericType, DecimalType.Fixed(p, s)) if isNumericType(t) => b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => + case (DecimalType.Fixed(p, s), t: NumericType) if isNumericType(t) => b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) case _ => b } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index b2613e4909288..2d720f557b461 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -101,7 +101,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { checkComparison(LessThan(i, d1), DecimalType(11, 1)) checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) - checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThanOrEqual(d1, f), DecimalType(14, 7)) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } @@ -120,13 +120,13 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) - checkType(Add(d1, f), DoubleType) + checkType(Add(d1, f), DecimalType(15, 7)) checkType(Add(i, d1), DecimalType(12, 1)) - checkType(Add(f, d1), DoubleType) + checkType(Add(f, d1), DecimalType(15, 7)) checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) - checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + checkType(Add(d1, Cast(i, DoubleType)), DecimalType(31, 15)) } test("maximum decimals") { @@ -150,13 +150,17 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { checkType(Remainder(i, u), DecimalType(28, 18)) checkType(Remainder(u, u), DecimalType.SYSTEM_DEFAULT) - for (expr <- Seq(f, b)) { - checkType(Add(expr, u), DoubleType) - checkType(Subtract(expr, u), DoubleType) - checkType(Multiply(expr, u), DoubleType) - checkType(Divide(expr, u), DoubleType) - checkType(Remainder(expr, u), DoubleType) - } + checkType(Add(f, u), DecimalType(38, 18)) + checkType(Subtract(f, u), DecimalType(38, 18)) + checkType(Multiply(f, u), DecimalType(38, 25)) + checkType(Divide(f, u), DecimalType(38, 25)) + checkType(Remainder(f, u), DecimalType(25, 18)) + + checkType(Add(b, u), DecimalType(38, 18)) + checkType(Subtract(b, u), DecimalType(38, 18)) + checkType(Multiply(b, u), DecimalType(38, 33)) + checkType(Divide(b, u), DecimalType(38, 21)) + checkType(Remainder(b, u), DecimalType(33, 18)) } test("DecimalType.isWiderThan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 199e138abfdc2..d64ad97ac7191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1376,4 +1376,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) } + + test("SPARK-13858/13861") { + val rowRDD = sparkContext.parallelize(Seq(Row(1.49f))) + val schema = StructType(StructField("col", FloatType) :: Nil) + val df = sqlContext.createDataFrame(rowRDD, schema) + df.registerTempTable("table") + +// val query1 = sql("select count(*) from table where col between 0 and 1.49") +// checkAnswer(query1, Row(1) :: Nil) + + val query2 = sql("select count(*) from table where col = 1.49") + checkAnswer(query2, Row(1) :: Nil) + } } From 83d8d9deef85fed16ddfa015c60650bc34debc00 Mon Sep 17 00:00:00 2001 From: bomeng Date: Thu, 17 Mar 2016 11:05:45 -0700 Subject: [PATCH 2/4] update the comments --- .../apache/spark/sql/catalyst/analysis/DecimalPrecision.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 8018a30b4b7fe..e4e3d83ed7ac4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -242,8 +242,8 @@ object DecimalPrecision extends Rule[LogicalPlan] { * side is a decimal. */ private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles + // Promote different data types inside a binary expression + // with fixed-precision decimals to decimals case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t: NumericType, DecimalType.Fixed(p, s)) if isNumericType(t) => From db571e0f44e221e76508f99eb0554a31b51ce32d Mon Sep 17 00:00:00 2001 From: bomeng Date: Thu, 17 Mar 2016 11:20:36 -0700 Subject: [PATCH 3/4] add one more test case --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d64ad97ac7191..1f383927fe866 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1383,8 +1383,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = sqlContext.createDataFrame(rowRDD, schema) df.registerTempTable("table") -// val query1 = sql("select count(*) from table where col between 0 and 1.49") -// checkAnswer(query1, Row(1) :: Nil) + val query1 = sql("select count(*) from table where col between 0 and 1.49") + checkAnswer(query1, Row(1) :: Nil) val query2 = sql("select count(*) from table where col = 1.49") checkAnswer(query2, Row(1) :: Nil) From 7f261a90a08b5559fce8b40cf9e3eff26ec1b3a6 Mon Sep 17 00:00:00 2001 From: bomeng Date: Fri, 18 Mar 2016 17:35:48 -0700 Subject: [PATCH 4/4] another approach based data type and its range --- .../catalyst/analysis/DecimalPrecision.scala | 47 +++++++++++++++---- .../analysis/DecimalPrecisionSuite.scala | 26 +++++----- 2 files changed, 48 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index e4e3d83ed7ac4..8a2840211b308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -61,10 +61,6 @@ import org.apache.spark.sql.types._ object DecimalPrecision extends Rule[LogicalPlan] { import scala.math.{max, min} - private def isNumericType(t: DataType): Boolean = - t == FloatType || t == DoubleType || t == ByteType || - t == IntegerType || t == ShortType || t == LongType - // Returns the wider decimal type that's wider than both of them def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) @@ -242,16 +238,47 @@ object DecimalPrecision extends Rule[LogicalPlan] { * side is a decimal. */ private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { - // Promote different data types inside a binary expression - // with fixed-precision decimals to decimals + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to floats / doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { - case (t: NumericType, DecimalType.Fixed(p, s)) if isNumericType(t) => + case (t: IntegralType, DecimalType.Fixed(p, s)) => b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: NumericType) if isNumericType(t) => + case (DecimalType.Fixed(p, s), t: IntegralType) => b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case _ => - b + + case (t, DecimalType.Fixed(p, s)) if t == FloatType => + right match { + case l: Literal => { + val value = l.value.asInstanceOf[Decimal].toDouble + if (value > Float.MaxValue || value < Float.MinValue) { + b.makeCopy(Array(left, Cast(right, DoubleType))) + } else { + b.makeCopy(Array(left, Cast(right, FloatType))) + } + } + case _ => b.makeCopy(Array(left, Cast(right, DoubleType))) + } + + case (DecimalType.Fixed(p, s), t) if t == FloatType => + left match { + case l: Literal => { + val value = l.value.asInstanceOf[Decimal].toDouble + if (value > Float.MaxValue || value < Float.MinValue) { + b.makeCopy(Array(Cast(left, DoubleType), right)) + } else { + b.makeCopy(Array(Cast(left, FloatType), right)) + } + } + case _ => b.makeCopy(Array(Cast(left, DoubleType), right)) + } + + case (t, DecimalType.Fixed(p, s)) if t == DoubleType => + b.makeCopy(Array(left, Cast(right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if t == DoubleType => + b.makeCopy(Array(Cast(left, DoubleType), right)) + + case _ => b } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 2d720f557b461..b2613e4909288 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -101,7 +101,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { checkComparison(LessThan(i, d1), DecimalType(11, 1)) checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) - checkComparison(GreaterThanOrEqual(d1, f), DecimalType(14, 7)) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } @@ -120,13 +120,13 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) - checkType(Add(d1, f), DecimalType(15, 7)) + checkType(Add(d1, f), DoubleType) checkType(Add(i, d1), DecimalType(12, 1)) - checkType(Add(f, d1), DecimalType(15, 7)) + checkType(Add(f, d1), DoubleType) checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) - checkType(Add(d1, Cast(i, DoubleType)), DecimalType(31, 15)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) } test("maximum decimals") { @@ -150,17 +150,13 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { checkType(Remainder(i, u), DecimalType(28, 18)) checkType(Remainder(u, u), DecimalType.SYSTEM_DEFAULT) - checkType(Add(f, u), DecimalType(38, 18)) - checkType(Subtract(f, u), DecimalType(38, 18)) - checkType(Multiply(f, u), DecimalType(38, 25)) - checkType(Divide(f, u), DecimalType(38, 25)) - checkType(Remainder(f, u), DecimalType(25, 18)) - - checkType(Add(b, u), DecimalType(38, 18)) - checkType(Subtract(b, u), DecimalType(38, 18)) - checkType(Multiply(b, u), DecimalType(38, 33)) - checkType(Divide(b, u), DecimalType(38, 21)) - checkType(Remainder(b, u), DecimalType(33, 18)) + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + checkType(Remainder(expr, u), DoubleType) + } } test("DecimalType.isWiderThan") {