Skip to content

Commit

Permalink
Revert "[SPARK-48016][SQL] Fix a bug in try_divide function when with…
Browse files Browse the repository at this point in the history
… decimals"

This reverts commit e78ee2c.
  • Loading branch information
dongjoon-hyun committed May 1, 2024
1 parent 953d7f9 commit fc0ef07
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 261 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ object DecimalPrecision extends TypeCoercionRule {
val resultType = widerDecimalType(p1, s1, p2, s2)
val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
b.withNewChildren(Seq(newE1, newE2))
b.makeCopy(Array(newE1, newE2))
}

/**
Expand Down Expand Up @@ -202,21 +202,21 @@ object DecimalPrecision extends TypeCoercionRule {
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
l.dataType.isInstanceOf[IntegralType] &&
literalPickMinimumPrecision =>
b.withNewChildren(Seq(Cast(l, DataTypeUtils.fromLiteral(l)), r))
b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
r.dataType.isInstanceOf[IntegralType] &&
literalPickMinimumPrecision =>
b.withNewChildren(Seq(l, Cast(r, DataTypeUtils.fromLiteral(r))))
b.makeCopy(Array(l, Cast(r, DataTypeUtils.fromLiteral(r))))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case (l @ IntegralTypeExpression(), r @ DecimalExpression(_, _)) =>
b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r))
b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
case (l @ DecimalExpression(_, _), r @ IntegralTypeExpression()) =>
b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType))))
b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
case (l, r @ DecimalExpression(_, _)) if isFloat(l.dataType) =>
b.withNewChildren(Seq(l, Cast(r, DoubleType)))
b.makeCopy(Array(l, Cast(r, DoubleType)))
case (l @ DecimalExpression(_, _), r) if isFloat(r.dataType) =>
b.withNewChildren(Seq(Cast(l, DoubleType), r))
b.makeCopy(Array(Cast(l, DoubleType), r))
case _ => b
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,22 +1102,22 @@ object TypeCoercion extends TypeCoercionBase {

case a @ BinaryArithmetic(left @ StringTypeExpression(), right)
if right.dataType != CalendarIntervalType =>
a.withNewChildren(Seq(Cast(left, DoubleType), right))
a.makeCopy(Array(Cast(left, DoubleType), right))
case a @ BinaryArithmetic(left, right @ StringTypeExpression())
if left.dataType != CalendarIntervalType =>
a.withNewChildren(Seq(left, Cast(right, DoubleType)))
a.makeCopy(Array(left, Cast(right, DoubleType)))

// For equality between string and timestamp we cast the string to a timestamp
// so that things like rounding of subsecond precision does not affect the comparison.
case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) =>
p.withNewChildren(Seq(Cast(left, TimestampType), right))
p.makeCopy(Array(Cast(left, TimestampType), right))
case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) =>
p.withNewChildren(Seq(left, Cast(right, TimestampType)))
p.makeCopy(Array(left, Cast(right, TimestampType)))

case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType)))
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
}
}

Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/test/resources/log4j2.properties
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ logger.parquet_recordwriter.name = org.apache.parquet.hadoop.InternalParquetReco
logger.parquet_recordwriter.additivity = false
logger.parquet_recordwriter.level = off

logger.parquet_outputcommitter.name = org.sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scalaapache.parquet.hadoop.ParquetOutputCommitter
logger.parquet_outputcommitter.name = org.apache.parquet.hadoop.ParquetOutputCommitter
logger.parquet_outputcommitter.additivity = false
logger.parquet_outputcommitter.level = off

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,6 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, decimal(1))
-- !query analysis
Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, "1")
-- !query analysis
Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#xL]
+- OneRowRelation


-- !query
SELECT try_add(-2147483648, -1)
-- !query analysis
Expand Down Expand Up @@ -225,20 +211,6 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, decimal(0))
-- !query analysis
Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, "0")
-- !query analysis
Project [try_divide(1, 0) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(interval 2 year, 2)
-- !query analysis
Expand Down Expand Up @@ -295,20 +267,6 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, decimal(-1))
-- !query analysis
Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, "-1")
-- !query analysis
Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#xL]
+- OneRowRelation


-- !query
SELECT try_subtract(-2147483648, 1)
-- !query analysis
Expand Down Expand Up @@ -393,20 +351,6 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, decimal(-2))
-- !query analysis
Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, "-2")
-- !query analysis
Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#xL]
+- OneRowRelation


-- !query
SELECT try_multiply(-2147483648, 2)
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,6 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, decimal(1))
-- !query analysis
Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(2147483647, "1")
-- !query analysis
Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
+- OneRowRelation


-- !query
SELECT try_add(-2147483648, -1)
-- !query analysis
Expand Down Expand Up @@ -225,20 +211,6 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, decimal(0))
-- !query analysis
Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(1, "0")
-- !query analysis
Project [try_divide(1, 0) AS try_divide(1, 0)#x]
+- OneRowRelation


-- !query
SELECT try_divide(interval 2 year, 2)
-- !query analysis
Expand Down Expand Up @@ -295,20 +267,6 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, decimal(-1))
-- !query analysis
Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(2147483647, "-1")
-- !query analysis
Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
+- OneRowRelation


-- !query
SELECT try_subtract(-2147483648, 1)
-- !query analysis
Expand Down Expand Up @@ -393,20 +351,6 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, decimal(-2))
-- !query analysis
Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(2147483647, "-2")
-- !query analysis
Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
+- OneRowRelation


-- !query
SELECT try_multiply(-2147483648, 2)
-- !query analysis
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
-- Numeric + Numeric
SELECT try_add(1, 1);
SELECT try_add(2147483647, 1);
SELECT try_add(2147483647, decimal(1));
SELECT try_add(2147483647, "1");
SELECT try_add(-2147483648, -1);
SELECT try_add(9223372036854775807L, 1);
SELECT try_add(-9223372036854775808L, -1);
Expand Down Expand Up @@ -40,8 +38,6 @@ SELECT try_divide(0, 0);
SELECT try_divide(1, (2147483647 + 1));
SELECT try_divide(1L, (9223372036854775807L + 1L));
SELECT try_divide(1, 1.0 / 0.0);
SELECT try_divide(1, decimal(0));
SELECT try_divide(1, "0");

-- Interval / Numeric
SELECT try_divide(interval 2 year, 2);
Expand All @@ -54,8 +50,6 @@ SELECT try_divide(interval 106751991 day, 0.5);
-- Numeric - Numeric
SELECT try_subtract(1, 1);
SELECT try_subtract(2147483647, -1);
SELECT try_subtract(2147483647, decimal(-1));
SELECT try_subtract(2147483647, "-1");
SELECT try_subtract(-2147483648, 1);
SELECT try_subtract(9223372036854775807L, -1);
SELECT try_subtract(-9223372036854775808L, 1);
Expand All @@ -72,8 +66,6 @@ SELECT try_subtract(interval 106751991 day, interval -3 day);
-- Numeric * Numeric
SELECT try_multiply(2, 3);
SELECT try_multiply(2147483647, -2);
SELECT try_multiply(2147483647, decimal(-2));
SELECT try_multiply(2147483647, "-2");
SELECT try_multiply(-2147483648, 2);
SELECT try_multiply(9223372036854775807L, 2);
SELECT try_multiply(-9223372036854775808L, -2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,6 @@ struct<try_add(2147483647, 1):int>
NULL


-- !query
SELECT try_add(2147483647, decimal(1))
-- !query schema
struct<try_add(2147483647, 1):decimal(11,0)>
-- !query output
2147483648


-- !query
SELECT try_add(2147483647, "1")
-- !query schema
struct<try_add(2147483647, 1):bigint>
-- !query output
2147483648


-- !query
SELECT try_add(-2147483648, -1)
-- !query schema
Expand Down Expand Up @@ -357,22 +341,6 @@ org.apache.spark.SparkArithmeticException
}


-- !query
SELECT try_divide(1, decimal(0))
-- !query schema
struct<try_divide(1, 0):decimal(12,11)>
-- !query output
NULL


-- !query
SELECT try_divide(1, "0")
-- !query schema
struct<try_divide(1, 0):double>
-- !query output
NULL


-- !query
SELECT try_divide(interval 2 year, 2)
-- !query schema
Expand Down Expand Up @@ -437,22 +405,6 @@ struct<try_subtract(2147483647, -1):int>
NULL


-- !query
SELECT try_subtract(2147483647, decimal(-1))
-- !query schema
struct<try_subtract(2147483647, -1):decimal(11,0)>
-- !query output
2147483648


-- !query
SELECT try_subtract(2147483647, "-1")
-- !query schema
struct<try_subtract(2147483647, -1):bigint>
-- !query output
2147483648


-- !query
SELECT try_subtract(-2147483648, 1)
-- !query schema
Expand Down Expand Up @@ -595,22 +547,6 @@ struct<try_multiply(2147483647, -2):int>
NULL


-- !query
SELECT try_multiply(2147483647, decimal(-2))
-- !query schema
struct<try_multiply(2147483647, -2):decimal(21,0)>
-- !query output
-4294967294


-- !query
SELECT try_multiply(2147483647, "-2")
-- !query schema
struct<try_multiply(2147483647, -2):bigint>
-- !query output
-4294967294


-- !query
SELECT try_multiply(-2147483648, 2)
-- !query schema
Expand Down

0 comments on commit fc0ef07

Please sign in to comment.