diff --git a/backends-velox/src/test/scala/org/apache/gluten/functions/ArithmeticAnsiValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/functions/ArithmeticAnsiValidateSuite.scala index a1633c4cb49..994097976bf 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/functions/ArithmeticAnsiValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/functions/ArithmeticAnsiValidateSuite.scala @@ -100,4 +100,109 @@ class ArithmeticAnsiValidateSuite extends FunctionsValidateSuite { } } + test("decimal add overflow") { + // Normal decimal add should succeed and match Spark results + runQueryAndCompare( + "SELECT CAST(1.0 AS DECIMAL(10,2)) + CAST(2.0 AS DECIMAL(10,2))") { + checkGlutenPlan[ProjectExecTransformer] + } + + // Overflow: max DECIMAL(38,0) + 1 should throw in ANSI mode + if (isSparkVersionGE("4.0")) { + intercept[SparkException] { + sql("SELECT CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)) + " + + "CAST(1 AS DECIMAL(38,0))").collect() + } + } else { + intercept[ArithmeticException] { + sql("SELECT CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)) + " + + "CAST(1 AS DECIMAL(38,0))").collect() + } + } + } + + test("decimal subtract overflow") { + // Normal decimal subtract should succeed and match Spark results + runQueryAndCompare( + "SELECT CAST(5.0 AS DECIMAL(10,2)) - CAST(2.0 AS DECIMAL(10,2))") { + checkGlutenPlan[ProjectExecTransformer] + } + + // Overflow: -max DECIMAL(38,0) - 1 should throw in ANSI mode + if (isSparkVersionGE("4.0")) { + intercept[SparkException] { + sql("SELECT CAST(-99999999999999999999999999999999999999 AS DECIMAL(38,0)) - " + + "CAST(1 AS DECIMAL(38,0))").collect() + } + } else { + intercept[ArithmeticException] { + sql("SELECT CAST(-99999999999999999999999999999999999999 AS DECIMAL(38,0)) - " + + "CAST(1 AS DECIMAL(38,0))").collect() + } + } + } + + test("decimal try_add") { + // Normal case should match Spark results + runQueryAndCompare( + "SELECT try_add(CAST(1.0 AS DECIMAL(10,2)), CAST(2.0 AS DECIMAL(10,2)))") { + checkGlutenPlan[ProjectExecTransformer] + } + // Overflow should return null + runQueryAndCompare( + "SELECT try_add(CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)), " + + "CAST(1 AS DECIMAL(38,0)))") { + checkGlutenPlan[ProjectExecTransformer] + } + } + + test("decimal try_subtract") { + // Normal case should match Spark results + runQueryAndCompare( + "SELECT try_subtract(CAST(5.0 AS DECIMAL(10,2)), CAST(2.0 AS DECIMAL(10,2)))") { + checkGlutenPlan[ProjectExecTransformer] + } + // Overflow should return null + runQueryAndCompare( + "SELECT try_subtract(CAST(-99999999999999999999999999999999999999 AS DECIMAL(38,0)), " + + "CAST(1 AS DECIMAL(38,0)))") { + checkGlutenPlan[ProjectExecTransformer] + } + } + + test("decimal multiply overflow") { + // Normal decimal multiply should succeed and match Spark results + runQueryAndCompare( + "SELECT CAST(2.0 AS DECIMAL(10,2)) * CAST(3.0 AS DECIMAL(10,2))") { + checkGlutenPlan[ProjectExecTransformer] + } + + // Overflow: max DECIMAL(38,0) * 2 should throw in ANSI mode + if (isSparkVersionGE("4.0")) { + intercept[SparkException] { + sql("SELECT CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)) * " + + "CAST(2 AS DECIMAL(38,0))").collect() + } + } else { + intercept[ArithmeticException] { + sql("SELECT CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)) * " + + "CAST(2 AS DECIMAL(38,0))").collect() + } + } + } + + test("decimal try_multiply") { + // Normal case should match Spark results + runQueryAndCompare( + "SELECT try_multiply(CAST(2.0 AS DECIMAL(10,2)), CAST(3.0 AS DECIMAL(10,2)))") { + checkGlutenPlan[ProjectExecTransformer] + } + // Overflow should return null + runQueryAndCompare( + "SELECT try_multiply(CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)), " + + "CAST(2 AS DECIMAL(38,0)))") { + checkGlutenPlan[ProjectExecTransformer] + } + } + } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 7969d305025..ddf1fa8db88 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -642,18 +642,50 @@ object ExpressionConverter extends SQLConfHelper with Logging { substraitExprName, expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), expr) - case CheckOverflow(b: BinaryArithmetic, decimalType, _) + case CheckOverflow(b: BinaryArithmetic, decimalType, nullOnOverflow) if !BackendsApiManager.getSettings.transformCheckOverflow && DecimalArithmeticUtil.isDecimalArithmetic(b) => - val arithmeticExprName = + val baseExprName = BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName( getAndCheckSubstraitName(b, expressionsMap), SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(b)) + // When nullOnOverflow is false, it's ANSI mode - use checked_ prefix for overflow errors + val arithmeticExprName = if (!nullOnOverflow) { + "checked_" + baseExprName + } else { + baseExprName + } val left = replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap) val right = replaceWithExpressionTransformer0(b.right, attributeSeq, expressionsMap) DecimalArithmeticExpressionTransformer(arithmeticExprName, left, right, decimalType, b) + // Velox path: decimal Add/Subtract/Multiply in ANSI or TRY mode uses checked_ variants. + // ANSI mode (nullOnOverflow=false): checked_* throws on overflow. + // TRY mode: try(checked_*) returns null on overflow. + case c @ CheckOverflow(b: BinaryArithmetic, _, nullOnOverflow) + if BackendsApiManager.getSettings.transformCheckOverflow && + DecimalArithmeticUtil.isDecimalArithmetic(b) && + (b.isInstanceOf[Add] || b.isInstanceOf[Subtract] || b.isInstanceOf[Multiply]) && + (!nullOnOverflow || + SparkShimLoader.getSparkShims.withTryEvalMode(b)) => + val baseExprName = + BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName( + getAndCheckSubstraitName(b, expressionsMap), + SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(b)) + val checkedExprName = "checked_" + baseExprName + val childTransformer = + genRescaleDecimalTransformer(checkedExprName, b, attributeSeq, expressionsMap) + if (SparkShimLoader.getSparkShims.withTryEvalMode(b)) { + // TRY mode: wrap checked_ in try() to return null on overflow. + GenericExpressionTransformer( + ExpressionMappings.expressionsMap(classOf[TryEval]), + Seq(childTransformer), + c) + } else { + // ANSI mode: checked_ throws on overflow. + CheckOverflowTransformer(substraitExprName, childTransformer, c) + } case c: CheckOverflow => CheckOverflowTransformer( substraitExprName,