From 7c4b454c863b4e760a3c7df9f0d17f94e86a5a47 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 18 Sep 2018 15:47:10 +0200
Subject: [PATCH 1/6] [SPARK-25454][SQL] Avoid precision loss in division with
 decimals with negative scale

---
 .../sql/catalyst/analysis/DecimalPrecision.scala   |  5 +++--
 .../catalyst/analysis/DecimalPrecisionSuite.scala  |  2 ++
 .../native/decimalArithmeticOperations.sql         |  3 +++
 .../native/decimalArithmeticOperations.sql.out     | 14 +++++++++++---
 4 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index e511f8064e28a..8e5050d29c21b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -129,16 +129,17 @@ object DecimalPrecision extends TypeCoercionRule {
         resultType)
 
       case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        val adjP2 = if (s2 < 0) p2 - s2 else p2
         val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
           // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
           // Scale: max(6, s1 + p2 + 1)
           val intDig = p1 - s1 + s2
-          val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
+          val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + adjP2 + 1)
           val prec = intDig + scale
           DecimalType.adjustPrecisionScale(prec, scale)
         } else {
           var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-          var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+          var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + adjP2 + 1))
           val diff = (intDig + decDig) - DecimalType.MAX_SCALE
           if (diff > 0) {
             decDig -= diff / 2 + 1
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index bd87ca6017e99..ec3c04b43f79f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -276,9 +276,11 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
     val a = AttributeReference("a", DecimalType(3, -10))()
     val b = AttributeReference("b", DecimalType(1, -1))()
     val c = AttributeReference("c", DecimalType(35, 1))()
+    val nonNegative = AttributeReference("nn", DecimalType(11, 0))()
     checkType(Multiply(a, b), DecimalType(5, -11))
     checkType(Multiply(a, c), DecimalType(38, -9))
     checkType(Multiply(b, c), DecimalType(37, 0))
+    checkType(Divide(nonNegative, a), DecimalType(15, 14))
   }
 
   /** strength reduction for integer/decimal comparisons */
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
index 28a0e20c0f495..be427aa88459c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
@@ -83,4 +83,7 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
 select 123456789123456789.1234567890 * 1.123456789123456789;
 select 12345678912345.123456789123 / 0.000000012345678;
 
+-- division with negative scale operands
+select 26393499451/ 1000e6;
+
 drop table decimals_test;
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
index cbf44548b3cce..3af6e69c50b02 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 40
+-- Number of queries: 41
 
 
 -- !query 0
@@ -328,8 +328,16 @@ NULL
 
 
 -- !query 39
-drop table decimals_test
+select 26393499451/ 1000e6
 -- !query 39 schema
-struct<>
+struct<(CAST(CAST(26393499451 AS DECIMAL(11,0)) AS DECIMAL(11,0)) / CAST(1.000E+9 AS DECIMAL(11,0))):decimal(16,11)>
 -- !query 39 output
+26.393499451
+
+
+-- !query 40
+drop table decimals_test
+-- !query 40 schema
+struct<>
+-- !query 40 output
 

From 520b64ed058843f88c57bb9de1bbc334e5c86f32 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 19 Sep 2018 09:14:39 +0200
Subject: [PATCH 2/6] address comment

---
 .../inputs/typeCoercion/native/decimalArithmeticOperations.sql | 2 +-
 .../typeCoercion/native/decimalArithmeticOperations.sql.out    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
index be427aa88459c..e8b52da5cfac2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
@@ -84,6 +84,6 @@ select 123456789123456789.1234567890 * 1.123456789123456789;
 select 12345678912345.123456789123 / 0.000000012345678;
 
 -- division with negative scale operands
-select 26393499451/ 1000e6;
+select 26393499451 / 1000e6;
 
 drop table decimals_test;
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
index 3af6e69c50b02..b2d305b9a3e13 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
@@ -328,7 +328,7 @@ NULL
 
 
 -- !query 39
-select 26393499451/ 1000e6
+select 26393499451 / 1000e6
 -- !query 39 schema
 struct<(CAST(CAST(26393499451 AS DECIMAL(11,0)) AS DECIMAL(11,0)) / CAST(1.000E+9 AS DECIMAL(11,0))):decimal(16,11)>
 -- !query 39 output

From 27a9ea656428d4705c6323deda64b14eb7ced7a4 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 19 Sep 2018 15:22:44 +0200
Subject: [PATCH 3/6] address comments

---
 .../catalyst/analysis/DecimalPrecision.scala       | 13 ++++++++-----
 .../expressions/ArithmeticExpressionSuite.scala    | 22 +++++++++++++++++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 8e5050d29c21b..9f53c218d7e88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -40,10 +40,13 @@ import org.apache.spark.sql.types._
  *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
  *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
  *   e1 * e2      p1 + p2 + 1                             s1 + s2
- *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
+ *   e1 / e2      max(p1-s1+s2, 0) + max(6, s1+adjP2+1)   max(6, s1+adjP2+1)
  *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
  *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)         max(s1, s2)
  *
+ * Where adjP2 is p2 - s2 if s2 < 0 and p2 otherwise. This adjustment is needed because Spark does
+ * not forbid decimals with negative scale, while MS SQL and Hive do.
+ *
  * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
  * needed are out of the range of available values, the scale is reduced up to 6, in order to
  * prevent the truncation of the integer part of the decimals.
@@ -133,12 +136,12 @@ object DecimalPrecision extends TypeCoercionRule {
         val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
-          // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
-          // Scale: max(6, s1 + p2 + 1)
-          val intDig = p1 - s1 + s2
+          // Precision: max(p1 - s1 + s2, 0) + max(6, s1 + adjP2 + 1)
+          // Scale: max(6, s1 + adjP2 + 1)
+          val intDig = max(p1 - s1 + s2, 0) // p1 - s1 + s2 can be negative if s2 < 0
           val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + adjP2 + 1)
           val prec = intDig + scale
           DecimalType.adjustPrecisionScale(prec, scale)
         } else {
-          var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
+          var intDig = max(min(DecimalType.MAX_SCALE, p1 - s1 + s2), 0) // can be negative if s2 < 0
           var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + adjP2 + 1))
           val diff = (intDig + decDig) - DecimalType.MAX_SCALE
           if (diff > 0) {
             decDig -= diff / 2 + 1
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 9a752af523ffc..f58cd31397df6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -366,4 +367,25 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
     assert(ctx2.inlinedMutableStates.size == 1)
   }
+
+  test("SPARK-25454: decimal operations with negative scale") {
+    val a = Literal(BigDecimal(1234567891))
+    val b = Literal(BigDecimal(100e6))
+    val c = Literal(BigDecimal(123456.7891))
+    assert(b.dataType.isInstanceOf[DecimalType] &&
+      b.dataType.asInstanceOf[DecimalType].scale < 0)
+    Seq("true", "false").foreach { allowPrecLoss =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss) {
+        checkEvaluationWithOptimization(Add(a, b), Decimal(BigDecimal(1334567891)))
+        checkEvaluationWithOptimization(Add(b, c), Decimal(BigDecimal(100123456.7891)))
+        checkEvaluationWithOptimization(Subtract(a, b), Decimal(BigDecimal(1134567891)))
+        checkEvaluationWithOptimization(Subtract(b, c),
Decimal(BigDecimal(99876543.2109))) + checkEvaluationWithOptimization(Multiply(a, b), Decimal(BigDecimal(123456789100000000L))) + checkEvaluationWithOptimization(Multiply(b, c), Decimal(BigDecimal(12345678910000L))) + checkEvaluationWithOptimization(Divide(a, b), Decimal(BigDecimal(12.34567891))) + checkEvaluationWithOptimization(Divide(b, c), Decimal(BigDecimal(810.000007))) + checkEvaluationWithOptimization(Divide(c, b), Decimal(BigDecimal(0.001234567891))) + } + } + } } From 4e240d9abea9ea67312f31e3af129416b8c3381a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 21 Sep 2018 11:15:54 +0200 Subject: [PATCH 4/6] add tests --- .../expressions/ArithmeticExpressionSuite.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index f58cd31397df6..68ec7f4e37147 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -372,19 +372,26 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val a = Literal(BigDecimal(1234567891)) val b = Literal(BigDecimal(100e6)) val c = Literal(BigDecimal(123456.7891)) - assert(b.dataType.isInstanceOf[DecimalType] && - b.dataType.asInstanceOf[DecimalType].scale < 0) + val d = Literal(BigDecimal(678e8)) + Seq(b, d).foreach { l => + assert(l.dataType.isInstanceOf[DecimalType] && + l.dataType.asInstanceOf[DecimalType].scale < 0) + } Seq("true", "false").foreach { allowPrecLoss => withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss) { checkEvaluationWithOptimization(Add(a, b), Decimal(BigDecimal(1334567891))) checkEvaluationWithOptimization(Add(b, c), Decimal(BigDecimal(100123456.7891))) + checkEvaluationWithOptimization(Add(b, d), Decimal(BigDecimal(67900e6))) checkEvaluationWithOptimization(Subtract(a, b), Decimal(BigDecimal(1134567891))) checkEvaluationWithOptimization(Subtract(b, c), Decimal(BigDecimal(99876543.2109))) + checkEvaluationWithOptimization(Subtract(d, b), Decimal(BigDecimal(67700e6))) checkEvaluationWithOptimization(Multiply(a, b), Decimal(BigDecimal(123456789100000000L))) checkEvaluationWithOptimization(Multiply(b, c), Decimal(BigDecimal(12345678910000L))) + checkEvaluationWithOptimization(Multiply(d, b), Decimal(BigDecimal(67800e14))) checkEvaluationWithOptimization(Divide(a, b), Decimal(BigDecimal(12.34567891))) checkEvaluationWithOptimization(Divide(b, c), Decimal(BigDecimal(810.000007))) checkEvaluationWithOptimization(Divide(c, b), Decimal(BigDecimal(0.001234567891))) + checkEvaluationWithOptimization(Divide(d, b), Decimal(BigDecimal(678))) } } } From dd19f7f71f33f334658fab0f100841d3472bbbae Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 Jan 2019 21:11:21 +0100 Subject: [PATCH 5/6] add comments --- .../org/apache/spark/sql/types/DecimalType.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 15004e4b9667d..9f8be3b9ad621 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -28,12 +28,16 
@@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 /**
  * The data type representing `java.math.BigDecimal` values.
- * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
- * of digits on right side of dot).
+ * A Decimal represents an exact numeric value with a fixed precision and scale.
+ * The precision is the number of significant digits and it can range from 1 to 38. The scale can
+ * be positive or negative. If zero or positive, the scale is the number of digits to the right of
+ * the decimal point. If negative, the unscaled value of the number is multiplied by ten to the
+ * power of the negation of the scale.
  *
- * The precision can be up to 38, scale can also be up to 38 (less or equal to precision).
+ * Please note that not all data sources support negative scales; writing decimals with negative
+ * scales to such sources can lead to errors and exceptions.
  *
- * The default precision and scale is (10, 0).
+ * The default precision and scale are (38, 18).
  *
  * Please use `DataTypes.createDecimalType()` to create a specific instance.
  *

From 97b9c56f3050609604c7b5c50563db571f84b190 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 8 Jan 2019 22:10:29 +0100
Subject: [PATCH 6/6] fix test failures

---
 .../ArithmeticExpressionSuite.scala                | 29 ++++++++++---------
 .../expressions/ExpressionEvalHelper.scala         | 15 +++++++++-
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index cbebe105535a2..31206567a9244 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -21,9 +21,12 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -388,19 +391,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
     Seq("true", "false").foreach { allowPrecLoss =>
       withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss) {
-        checkEvaluationWithOptimization(Add(a, b), Decimal(BigDecimal(1334567891)))
-        checkEvaluationWithOptimization(Add(b, c), Decimal(BigDecimal(100123456.7891)))
-        checkEvaluationWithOptimization(Add(b, d), Decimal(BigDecimal(67900e6)))
-        checkEvaluationWithOptimization(Subtract(a, b), Decimal(BigDecimal(1134567891)))
-        checkEvaluationWithOptimization(Subtract(b, c), Decimal(BigDecimal(99876543.2109)))
-        checkEvaluationWithOptimization(Subtract(d, b), Decimal(BigDecimal(67700e6)))
-        checkEvaluationWithOptimization(Multiply(a, b), Decimal(BigDecimal(123456789100000000L)))
-        checkEvaluationWithOptimization(Multiply(b, c), Decimal(BigDecimal(12345678910000L)))
-
checkEvaluationWithOptimization(Multiply(d, b), Decimal(BigDecimal(67800e14))) - checkEvaluationWithOptimization(Divide(a, b), Decimal(BigDecimal(12.34567891))) - checkEvaluationWithOptimization(Divide(b, c), Decimal(BigDecimal(810.000007))) - checkEvaluationWithOptimization(Divide(c, b), Decimal(BigDecimal(0.001234567891))) - checkEvaluationWithOptimization(Divide(d, b), Decimal(BigDecimal(678))) + checkEvaluationWithAnalysis(Add(a, b), Decimal(BigDecimal(1334567891))) + checkEvaluationWithAnalysis(Add(b, c), Decimal(BigDecimal(100123456.7891))) + checkEvaluationWithAnalysis(Add(b, d), Decimal(BigDecimal(67900e6))) + checkEvaluationWithAnalysis(Subtract(a, b), Decimal(BigDecimal(1134567891))) + checkEvaluationWithAnalysis(Subtract(b, c), Decimal(BigDecimal(99876543.2109))) + checkEvaluationWithAnalysis(Subtract(d, b), Decimal(BigDecimal(67700e6))) + checkEvaluationWithAnalysis(Multiply(a, b), Decimal(BigDecimal(123456789100000000L))) + checkEvaluationWithAnalysis(Multiply(b, c), Decimal(BigDecimal(12345678910000L))) + checkEvaluationWithAnalysis(Multiply(d, b), Decimal(BigDecimal(67800e14))) + checkEvaluationWithAnalysis(Divide(a, b), Decimal(BigDecimal(12.34567891))) + checkEvaluationWithAnalysis(Divide(b, c), Decimal(BigDecimal(810.000007))) + checkEvaluationWithAnalysis(Divide(c, b), Decimal(BigDecimal(0.001234567891))) + checkEvaluationWithAnalysis(Divide(d, b), Decimal(BigDecimal(678))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1c91adab71375..e9c0c07adc7b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -28,7 +28,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, ResolveTimeZone} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.PlanTestBase @@ -44,6 +45,9 @@ import org.apache.spark.util.Utils trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + private val analyzer = new Analyzer(catalog, conf) + protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -302,6 +306,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } + protected def checkEvaluationWithAnalysis( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Analyzed($expression)")() :: Nil, OneRowRelation()) + val analyzedPlan = analyzer.execute(plan) + checkEvaluationWithoutCodegen(analyzedPlan.expressions.head, expected, inputRow) + } + protected def checkDoubleEvaluation( 
expression: => Expression, expected: Spread[Double],
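
---

For reference, the type derivation this series changes can be reproduced outside Spark. The sketch below is a minimal, self-contained approximation, not Spark code: `MAX_PRECISION`, `MINIMUM_ADJUSTED_SCALE`, `adjust` and `divideType` are hypothetical stand-ins that mirror the `DecimalType.adjustPrecisionScale` behavior and the patched `Divide` rule on the precision-loss-allowed path.

```scala
// Sketch of the Divide result-type rule after this series (assumptions noted above).
object DivideTypeSketch {
  val MAX_PRECISION = 38
  val MINIMUM_ADJUSTED_SCALE = 6

  // Simplified stand-in for DecimalType.adjustPrecisionScale: if the required
  // precision overflows, keep the integral digits and shrink the scale, but
  // never below min(scale, 6).
  def adjust(prec: Int, scale: Int): (Int, Int) =
    if (prec <= MAX_PRECISION) {
      (prec, scale)
    } else {
      val intDigits = prec - scale
      val minScale = math.min(scale, MINIMUM_ADJUSTED_SCALE)
      (MAX_PRECISION, math.max(MAX_PRECISION - intDigits, minScale))
    }

  // Result type of decimal(p1, s1) / decimal(p2, s2) with the negative-scale fix.
  def divideType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val adjP2 = if (s2 < 0) p2 - s2 else p2  // the fix: widen p2 by |s2|
    val intDig = math.max(p1 - s1 + s2, 0)   // p1 - s1 + s2 can be negative if s2 < 0
    val scale = math.max(MINIMUM_ADJUSTED_SCALE, s1 + adjP2 + 1)
    adjust(intDig + scale, scale)
  }

  def main(args: Array[String]): Unit = {
    // 26393499451 is decimal(11, 0); the literal 1000e6 is decimal(4, -6).
    println(divideType(11, 0, 4, -6))  // (16,11): 26.393499451 is represented exactly
  }
}
```

This prints `(16,11)`, matching the `decimal(16,11)` schema in the regenerated golden file. Without the `adjP2` adjustment, the scale would be `max(6, s1 + p2 + 1) = max(6, 0 + 4 + 1) = 6`, so `26393499451 / 1000e6` would be typed `decimal(11,6)` and the quotient `26.393499451` would be truncated to `26.393499` — the precision loss this series avoids.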