From 72092183f809539cd6b7fe9d75cda81c7f31a940 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 22:52:03 +0000 Subject: [PATCH 1/4] [SPARK-56909][SQL] Refactor Cast to int/long codegen under ANSI mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Introduce `CastUtils.java` and use it from `Cast.scala` to collapse the multi-line ANSI overflow-check codegen for casts that target `int` and `long` into one-line static-method calls. Source and target `DataType` constants used in the overflow error message live as `private static final` fields on the helper class, so the happy path performs no per-row `references[]` lookups. Helpers added: * `longToIntExact(long)` for narrowing `long -> int`. * `floatToIntExact(float)`, `doubleToIntExact(double)` for fractional -> int. * `floatToLongExact(float)`, `doubleToLongExact(double)` for fractional -> long. `Cast.scala` changes: * `castIntegralTypeToIntegralTypeExactCode` and `castFractionToIntegralTypeCode` dispatch on the target type: `int` (and `long` for the fraction case) emit a `CastUtils.<...>Exact` call; byte/short targets keep the inline body (refactored in SPARK-56910). * Eval paths for `castToInt` add ANSI `LongType` / `FloatType` / `DoubleType` cases, and `castToLong` adds `FloatType` / `DoubleType` cases, both delegating to the new helpers. ### Why are the changes needed? Part of SPARK-56908. The current ANSI cast codegen emits 5-line inline overflow blocks per call site. Multiplied across the many cast paths in a TPC-DS plan, this contributes meaningfully to the generated source size and to Janino compile time, and pushes whole-stage methods closer to the 64KB JVM method limit. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? `build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite *ExpressionClassIdentitySuite"` — 312/312 pass. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Cursor 1.x --- .../sql/catalyst/expressions/CastUtils.java | 71 +++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++------ 2 files changed, 133 insertions(+), 27 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java new file mode 100644 index 0000000000000..1f6a0daf616e3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.errors.QueryExecutionErrors; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +/** + * Static helpers used by {@code Cast.doGenCode} (and corresponding eval + * paths) for ANSI overflow-checked narrowing conversions. The source and + * target {@link DataType} objects referenced by the overflow error message + * are held in {@code private static final} fields so the happy path + * performs no per-row {@code references[]} lookups. + */ +public final class CastUtils { + + private CastUtils() {} + + private static final DataType INT = DataTypes.IntegerType; + private static final DataType LONG = DataTypes.LongType; + private static final DataType FLOAT = DataTypes.FloatType; + private static final DataType DOUBLE = DataTypes.DoubleType; + + // ----- integral narrowing -> int (ANSI: throw on overflow) ----- + + public static int longToIntExact(long v) { + if (v == (int) v) return (int) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, INT); + } + + // ----- fractional -> int (ANSI: throw on overflow) ----- + // Mirrors castFractionToIntegralTypeCode: floor(v) <= MAX && ceil(v) >= MIN. + + public static int floatToIntExact(float v) { + if (Math.floor(v) <= Integer.MAX_VALUE && Math.ceil(v) >= Integer.MIN_VALUE) return (int) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, INT); + } + + public static int doubleToIntExact(double v) { + if (Math.floor(v) <= Integer.MAX_VALUE && Math.ceil(v) >= Integer.MIN_VALUE) return (int) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, INT); + } + + // ----- fractional -> long (ANSI: throw on overflow) ----- + + public static long floatToLongExact(float v) { + if (Math.floor(v) <= Long.MAX_VALUE && Math.ceil(v) >= Long.MIN_VALUE) return (long) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, LONG); + } + + public static long doubleToLongExact(double v) { + if (Math.floor(v) <= Long.MAX_VALUE && Math.ceil(v) >= Long.MIN_VALUE) return (long) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, LONG); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c51d3508d04a4..170d4953d830d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -897,6 +897,10 @@ case class Cast( buildCast[Long](_, t => timestampToLong(t)) case _: TimeType => buildCast[Long](_, t => timeToLong(t)) + case FloatType if ansiEnabled => + b => CastUtils.floatToLongExact(b.asInstanceOf[Float]) + case DoubleType if ansiEnabled => + b => CastUtils.doubleToLongExact(b.asInstanceOf[Double]) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => exactNumeric.toLong(b) @@ -939,6 +943,12 @@ case class Cast( }) case _: TimeType => buildCast[Long](_, t => timeToLong(t).toInt) + case LongType if ansiEnabled => + b => CastUtils.longToIntExact(b.asInstanceOf[Long]) + case FloatType if ansiEnabled => + b => CastUtils.floatToIntExact(b.asInstanceOf[Float]) + case DoubleType if ansiEnabled => + b => CastUtils.doubleToIntExact(b.asInstanceOf[Double]) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => exactNumeric.toInt(b) @@ -1982,22 +1992,40 @@ case class Cast( } } + private[this] def integralPrefix(from: DataType): String = from match { + case ShortType => "short" + case IntegerType => "int" + case LongType => "long" + } + + private[this] def fractionalPrefix(from: DataType): String = from match { + case FloatType => "float" + case DoubleType => "double" + } + private[this] def castIntegralTypeToIntegralTypeExactCode( ctx: CodegenContext, integralType: String, from: DataType, to: DataType): CastFunction = { assert(ansiEnabled) - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - (c, evPrim, _) => - code""" - if ($c == ($integralType) $c) { - $evPrim = ($integralType) $c; - } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); - } - """ + if (integralType == "int") { + val castUtils = classOf[CastUtils].getName + val method = s"${integralPrefix(from)}ToIntExact" + (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" + } else { + // Byte/short narrowing remains inline; refactored in a follow-up PR. + val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) + val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) + (c, evPrim, _) => + code""" + if ($c == ($integralType) $c) { + $evPrim = ($integralType) $c; + } else { + throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); + } + """ + } } @@ -2017,23 +2045,30 @@ case class Cast( from: DataType, to: DataType): CastFunction = { assert(ansiEnabled) - val (min, max) = lowerAndUpperBound(integralType) - val mathClass = classOf[Math].getName - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - // When casting floating values to integral types, Spark uses the method `Numeric.toInt` - // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`; - // for negative floating values, it is equivalent to `Math.ceil`. - // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound` - // to check if the floating value x is in the range of an integral type after rounding. - (c, evPrim, _) => - code""" - if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) { - $evPrim = ($integralType) $c; - } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); - } - """ + if (integralType == "int" || integralType == "long") { + val castUtils = classOf[CastUtils].getName + val method = s"${fractionalPrefix(from)}To${integralType.capitalize}Exact" + (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" + } else { + // Byte/short narrowing remains inline; refactored in a follow-up PR. + val (min, max) = lowerAndUpperBound(integralType) + val mathClass = classOf[Math].getName + val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) + val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) + // When casting floating values to integral types, Spark uses the method `Numeric.toInt` + // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to + // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`. + // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound` + // to check if the floating value x is in the range of an integral type after rounding. + (c, evPrim, _) => + code""" + if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) { + $evPrim = ($integralType) $c; + } else { + throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); + } + """ + } } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { From 2c1ab79e72d2313298bfba0102f8d74a8cf60ff6 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 22:58:31 +0000 Subject: [PATCH 2/4] [SPARK-56910][SQL] Refactor Cast to byte/short codegen under ANSI mode Extend `CastUtils.java` with helpers for `byte` and `short` ANSI cast targets and use them from `Cast.scala`. Drops the byte/short-target dispatch (and the now-unused `lowerAndUpperBound` Scala helper) added in SPARK-56909 -- after this PR, all integral and fractional narrowing ANSI casts share the same `CastUtils.<...>Exact` one-line codegen. Helpers added: * `shortToByteExact(short)`, `intToByteExact(int)`, `longToByteExact(long)` * `intToShortExact(int)`, `longToShortExact(long)` * `floatToByteExact(float)`, `doubleToByteExact(double)` * `floatToShortExact(float)`, `doubleToShortExact(double)` `Cast.scala` changes: * `castIntegralTypeToIntegralTypeExactCode` / `castFractionToIntegralTypeCode` no longer dispatch on target type -- the helper-name pattern `${integralPrefix(from)}To${target.capitalize}Exact` covers all four target types. * Eval paths for `castToByte` and `castToShort` add ANSI cases for `ShortType` / `IntegerType` / `LongType` / `FloatType` / `DoubleType` source types that delegate to the new helpers; the existing `exactNumeric.toInt(b) + bounds-check` fallback now only handles the remaining `Decimal` source. Part of SPARK-56908 (umbrella). The original byte/short ANSI cast bodies were 5 lines each across 8 call sites; this PR collapses them to one line per call site, matching the int/long target work from SPARK-56909. No. The compiled behavior is identical; only the emitted Java source text changes. ``` build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite \ *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite \ *ExpressionClassIdentitySuite" ``` 312/312 pass. Generated-by: Cursor 1.x --- .../sql/catalyst/expressions/CastUtils.java | 51 ++++++++++++- .../spark/sql/catalyst/expressions/Cast.scala | 76 ++++++------------- 2 files changed, 73 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java index 1f6a0daf616e3..3c599d0cc8659 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -32,21 +32,68 @@ public final class CastUtils { private CastUtils() {} + private static final DataType SHORT = DataTypes.ShortType; private static final DataType INT = DataTypes.IntegerType; private static final DataType LONG = DataTypes.LongType; + private static final DataType BYTE = DataTypes.ByteType; private static final DataType FLOAT = DataTypes.FloatType; private static final DataType DOUBLE = DataTypes.DoubleType; - // ----- integral narrowing -> int (ANSI: throw on overflow) ----- + // ----- integral narrowing (ANSI: throw on overflow) ----- + + public static byte shortToByteExact(short v) { + if (v == (byte) v) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, SHORT, BYTE); + } + + public static byte intToByteExact(int v) { + if (v == (byte) v) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, INT, BYTE); + } + + public static byte longToByteExact(long v) { + if (v == (byte) v) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, BYTE); + } + + public static short intToShortExact(int v) { + if (v == (short) v) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, INT, SHORT); + } + + public static short longToShortExact(long v) { + if (v == (short) v) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, SHORT); + } public static int longToIntExact(long v) { if (v == (int) v) return (int) v; throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, INT); } - // ----- fractional -> int (ANSI: throw on overflow) ----- + // ----- fractional -> integral (ANSI: throw on overflow) ----- // Mirrors castFractionToIntegralTypeCode: floor(v) <= MAX && ceil(v) >= MIN. + public static byte floatToByteExact(float v) { + if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, BYTE); + } + + public static byte doubleToByteExact(double v) { + if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, BYTE); + } + + public static short floatToShortExact(float v) { + if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, SHORT); + } + + public static short doubleToShortExact(double v) { + if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, SHORT); + } + public static int floatToIntExact(float v) { if (Math.floor(v) <= Integer.MAX_VALUE && Math.ceil(v) >= Integer.MIN_VALUE) return (int) v; throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, INT); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 170d4953d830d..431da5a50a51e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -994,6 +994,14 @@ case class Cast( errorOrNull(t, from, ShortType) } }) + case IntegerType if ansiEnabled => + b => CastUtils.intToShortExact(b.asInstanceOf[Int]) + case LongType if ansiEnabled => + b => CastUtils.longToShortExact(b.asInstanceOf[Long]) + case FloatType if ansiEnabled => + b => CastUtils.floatToShortExact(b.asInstanceOf[Float]) + case DoubleType if ansiEnabled => + b => CastUtils.doubleToShortExact(b.asInstanceOf[Double]) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -1050,6 +1058,16 @@ case class Cast( errorOrNull(t, from, ByteType) } }) + case ShortType if ansiEnabled => + b => CastUtils.shortToByteExact(b.asInstanceOf[Short]) + case IntegerType if ansiEnabled => + b => CastUtils.intToByteExact(b.asInstanceOf[Int]) + case LongType if ansiEnabled => + b => CastUtils.longToByteExact(b.asInstanceOf[Long]) + case FloatType if ansiEnabled => + b => CastUtils.floatToByteExact(b.asInstanceOf[Float]) + case DoubleType if ansiEnabled => + b => CastUtils.doubleToByteExact(b.asInstanceOf[Double]) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -2009,34 +2027,9 @@ case class Cast( from: DataType, to: DataType): CastFunction = { assert(ansiEnabled) - if (integralType == "int") { - val castUtils = classOf[CastUtils].getName - val method = s"${integralPrefix(from)}ToIntExact" - (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" - } else { - // Byte/short narrowing remains inline; refactored in a follow-up PR. - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - (c, evPrim, _) => - code""" - if ($c == ($integralType) $c) { - $evPrim = ($integralType) $c; - } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); - } - """ - } - } - - - private[this] def lowerAndUpperBound(integralType: String): (String, String) = { - val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match { - case "long" => (Long.MinValue, Long.MaxValue, "L") - case "int" => (Int.MinValue, Int.MaxValue, "") - case "short" => (Short.MinValue, Short.MaxValue, "") - case "byte" => (Byte.MinValue, Byte.MaxValue, "") - } - (min.toString + typeIndicator, max.toString + typeIndicator) + val castUtils = classOf[CastUtils].getName + val method = s"${integralPrefix(from)}To${integralType.capitalize}Exact" + (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" } private[this] def castFractionToIntegralTypeCode( @@ -2045,30 +2038,9 @@ case class Cast( from: DataType, to: DataType): CastFunction = { assert(ansiEnabled) - if (integralType == "int" || integralType == "long") { - val castUtils = classOf[CastUtils].getName - val method = s"${fractionalPrefix(from)}To${integralType.capitalize}Exact" - (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" - } else { - // Byte/short narrowing remains inline; refactored in a follow-up PR. - val (min, max) = lowerAndUpperBound(integralType) - val mathClass = classOf[Math].getName - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - // When casting floating values to integral types, Spark uses the method `Numeric.toInt` - // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to - // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`. - // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound` - // to check if the floating value x is in the range of an integral type after rounding. - (c, evPrim, _) => - code""" - if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) { - $evPrim = ($integralType) $c; - } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); - } - """ - } + val castUtils = classOf[CastUtils].getName + val method = s"${fractionalPrefix(from)}To${integralType.capitalize}Exact" + (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" } private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { From 8f72067a55bd1261064b9d535067bff6029aad08 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 23:03:00 +0000 Subject: [PATCH 3/4] [SPARK-56911][SQL] Refactor Cast to decimal codegen under ANSI mode ### What changes were proposed in this pull request? Extend `CastUtils.java` with two helpers for decimal precision adjustment and use them from `Cast.changePrecision` (both the eval and codegen implementations). The new helpers mutate the input `Decimal` in place (matching the behavior of the existing inline codegen), so they're safe to call on the temporary produced by `Decimal.fromString(...)` / `Decimal.apply(...)` / decimal-arithmetic results. Helpers added: * `changePrecisionExact(Decimal, int, int, QueryContext)`: ANSI throw on overflow, preserves the per-call-site `QueryContext` so error messages keep their query-origin info. * `changePrecisionOrNull(Decimal, int, int)`: non-ANSI, returns `null` on overflow (no `QueryContext` needed). `Cast.scala` changes: * `changePrecision` eval method dispatches on `nullOnOverflow` and delegates to the appropriate helper. * `changePrecision` codegen method has three branches now: the existing `canNullSafeCast` fast path (unchanged), a `nullOnOverflow` branch (inline), and the ANSI throw branch which now emits a one-line `CastUtils.changePrecisionExact(...)` call instead of the 5-line `if/else` overflow block. ### Why are the changes needed? Part of SPARK-56908 (umbrella). The ANSI throw branch of `Cast.changePrecision` is hit by every cast to decimal that may overflow (very common in TPC-DS, where `cast(int as decimal(7,2))` is widespread). Collapsing the 5-line inline body to one line shrinks the generated Java source for those plans. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? ``` build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite \ *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite *DecimalSuite \ *ExpressionClassIdentitySuite" ``` 337/337 pass. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Cursor 1.x --- .../sql/catalyst/expressions/CastUtils.java | 17 ++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 32 ++++++++----------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java index 3c599d0cc8659..d400a065d957c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -17,9 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; +import org.apache.spark.QueryContext; import org.apache.spark.sql.errors.QueryExecutionErrors; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; /** * Static helpers used by {@code Cast.doGenCode} (and corresponding eval @@ -115,4 +117,19 @@ public static long doubleToLongExact(double v) { if (Math.floor(v) <= Long.MAX_VALUE && Math.ceil(v) >= Long.MIN_VALUE) return (long) v; throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, LONG); } + + // ----- decimal precision adjustment ----- + // Mutates the input Decimal in place. Used by Cast.changePrecision (and by + // BinaryArithmetic / DivModLike in follow-up PRs) to apply the target + // precision/scale on the per-row hot path. + + public static Decimal changePrecisionExact( + Decimal d, int precision, int scale, QueryContext context) { + if (d.changePrecision(precision, scale)) return d; + throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(d, precision, scale, context); + } + + public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) { + return d.changePrecision(precision, scale) ? d : null; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 431da5a50a51e..fb1c4cc0d1beb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1107,15 +1107,11 @@ case class Cast( value: Decimal, decimalType: DecimalType, nullOnOverflow: Boolean): Decimal = { - if (value.changePrecision(decimalType.precision, decimalType.scale)) { - value + if (nullOnOverflow) { + CastUtils.changePrecisionOrNull(value, decimalType.precision, decimalType.scale) } else { - if (nullOnOverflow) { - null - } else { - throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - value, decimalType.precision, decimalType.scale, getContextOrNull()) - } + CastUtils.changePrecisionExact( + value, decimalType.precision, decimalType.scale, getContextOrNull()) } } @@ -1568,23 +1564,21 @@ case class Cast( |$d.changePrecision(${decimalType.precision}, ${decimalType.scale}); |$evPrim = $d; """.stripMargin - } else { - val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) - val overflowCode = if (nullOnOverflow) { - s"$evNull = true;" - } else { - s""" - |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode); - """.stripMargin - } + } else if (nullOnOverflow) { code""" |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { | $evPrim = $d; |} else { - | $overflowCode + | $evNull = true; |} """.stripMargin + } else { + val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) + val castUtils = classOf[CastUtils].getName + code""" + |$evPrim = $castUtils.changePrecisionExact( + | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode); + """.stripMargin } } From 2a324d8f38af8dd14d73e7e543e128c49afee445 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 23:17:28 +0000 Subject: [PATCH 4/4] [SPARK-56914][SQL] Refactor decimal arithmetic codegen under ANSI mode ### What changes were proposed in this pull request? Use `CastUtils.changePrecisionExact` / `changePrecisionOrNull` (added in SPARK-56911) from the `DecimalType.Fixed` branches of: * `BinaryArithmetic.doGenCode` (covers `Add` / `Subtract` / `Multiply` on `Decimal`). * `BinaryDivModLike.doGenCode` (covers `Divide` / `IntegralDivide` / `Remainder` / `Pmod` on `Decimal`). * `BinaryArithmetic.checkDecimalOverflow` (eval path used by both groups via `numeric.plus`/`minus`/`times`/`div`). Each call site goes from `eval1.$op(eval2).toPrecision(p, s, ROUND_HALF_UP, !failOnError, ctx)` + a 4-line null check to a single `CastUtils.changePrecision{Exact,OrNull}` call. ### Why are the changes needed? Part of SPARK-56908 (umbrella). Decimal arithmetic is widespread in TPC-DS plans, and the `BinaryArithmetic` Decimal branch was one of the longer ANSI codegen bodies still emitted inline. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? ``` build/sbt "catalyst/testOnly *ArithmeticExpressionSuite *DecimalSuite" ``` 60/60 pass. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Cursor 1.x --- .../sql/catalyst/expressions/arithmetic.scala | 58 ++++++++++++------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1c93a65867615..3fde53b316076 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -256,7 +256,11 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext } protected def checkDecimalOverflow(value: Decimal, precision: Int, scale: Int): Decimal = { - value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, getContextOrNull()) + if (failOnError) { + CastUtils.changePrecisionExact(value, precision, scale, getContextOrNull()) + } else { + CastUtils.changePrecisionOrNull(value, precision, scale) + } } /** Name of the function for this expression on a [[Decimal]] type. */ @@ -278,19 +282,21 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case DecimalType.Fixed(precision, scale) => - val errorContextCode = getContextOrNullCode(ctx, failOnError) - val updateIsNull = if (failOnError) { - "" + val castUtils = classOf[CastUtils].getName + if (failOnError) { + val errorContextCode = getContextOrNullCode(ctx, failOnError) + defineCodeGen(ctx, ev, (eval1, eval2) => + s"$castUtils.changePrecisionExact(" + + s"$eval1.$decimalMethod($eval2), $precision, $scale, $errorContextCode)") } else { - s"${ev.isNull} = ${ev.value} == null;" + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + |${ev.value} = $castUtils.changePrecisionOrNull( + | $eval1.$decimalMethod($eval2), $precision, $scale); + |${ev.isNull} = ${ev.value} == null; + """.stripMargin + }) } - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision( - | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode); - |$updateIsNull - """.stripMargin - }) case CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") @@ -717,16 +723,26 @@ trait DivModLike extends BinaryArithmetic { val errorContextCode = getContextOrNullCode(ctx, failOnError) val operation = super.dataType match { case DecimalType.Fixed(precision, scale) => + val castUtils = classOf[CastUtils].getName val decimalValue = ctx.freshName("decimalValue") - s""" - |Decimal $decimalValue = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision( - | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode); - |if ($decimalValue != null) { - | ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")}; - |} else { - | ${ev.isNull} = true; - |} - |""".stripMargin + if (failOnError) { + s""" + |Decimal $decimalValue = $castUtils.changePrecisionExact( + | ${eval1.value}.$decimalMethod(${eval2.value}), $precision, $scale, + | $errorContextCode); + |${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")}; + |""".stripMargin + } else { + s""" + |Decimal $decimalValue = $castUtils.changePrecisionOrNull( + | ${eval1.value}.$decimalMethod(${eval2.value}), $precision, $scale); + |if ($decimalValue != null) { + | ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")}; + |} else { + | ${ev.isNull} = true; + |} + |""".stripMargin + } case _ => s"${ev.value} = ($javaType)(${eval1.value} $symbol ${eval2.value});" } val checkIntegralDivideOverflow = if (checkDivideOverflow) {