From b38e73a4133485fc31435baf33efea4f12b6ba19 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 23:13:12 +0000 Subject: [PATCH] [SPARK-56913][SQL] Refactor BinaryArithmetic byte/short codegen under ANSI mode ### What changes were proposed in this pull request? Introduce `ArithmeticUtils.java` with six static helpers (`byteAddExact`, `byteSubtractExact`, `byteMultiplyExact`, `shortAddExact`, `shortSubtractExact`, `shortMultiplyExact`) and use them from `BinaryArithmetic.doGenCode` and from `Add` / `Subtract` / `Multiply.nullSafeEval`. The `Byte`/`Short` ANSI overflow-check branch of `BinaryArithmetic.doGenCode` previously emitted ~7 lines per call site (int tmpResult + overflow check + cast back). After this PR it emits a single `ArithmeticUtils.{byte,short}{Add,Subtract,Multiply}Exact(...)` call. The eval-path counterparts for Add/Subtract/Multiply also delegate to the helpers under ANSI mode, replacing the previous fall-through to `numeric.plus`/`minus`/`times` (which threw a generic `ArithmeticException`) -- the eval path now produces the same SQL-formatted `BINARY_ARITHMETIC_OVERFLOW` error as the codegen path. Primitive `int`/`long`/`float`/`double` branches are intentionally left inline (single bytecode op; routing through a static method would be a runtime regression). ### Why are the changes needed? Part of SPARK-56908 (umbrella). The Byte/Short ANSI branch is the largest single inline body in `BinaryArithmetic.doGenCode`. ### Does this PR introduce _any_ user-facing change? No. Compiled behavior is identical; the eval path now produces a SQL-formatted overflow error matching the codegen path (the previous generic `ArithmeticException` was an inconsistency). ### How was this patch tested? ``` build/sbt "catalyst/testOnly *ArithmeticExpressionSuite" ``` 35/35 pass. ### Was this patch authored or co-authored using generative AI tooling? 
Generated-by: Cursor 1.x --- .../catalyst/expressions/ArithmeticUtils.java | 85 +++++++++++++++++++ .../sql/catalyst/expressions/arithmetic.scala | 49 ++++++----- 2 files changed, 111 insertions(+), 23 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArithmeticUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArithmeticUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArithmeticUtils.java new file mode 100644 index 000000000000..cb52e2c67636 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArithmeticUtils.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.errors.QueryExecutionErrors; + +/** + * Static helpers used by {@code BinaryArithmetic.doGenCode} (and + * corresponding eval paths) for ANSI overflow-checked {@code byte} and + * {@code short} arithmetic. 
Primitive {@code int} / {@code long} / + * {@code float} / {@code double} arithmetic stays inline -- routing those + * single-bytecode operations through a static method would be a runtime + * regression. + */ +public final class ArithmeticUtils { + + private ArithmeticUtils() {} + + // ----- Byte: int arithmetic with overflow check (ANSI) ----- + + public static byte byteAddExact(byte a, byte b) { + int r = a + b; + if (r < Byte.MIN_VALUE || r > Byte.MAX_VALUE) { + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(a, "+", b, "try_add"); + } + return (byte) r; + } + + public static byte byteSubtractExact(byte a, byte b) { + int r = a - b; + if (r < Byte.MIN_VALUE || r > Byte.MAX_VALUE) { + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(a, "-", b, "try_subtract"); + } + return (byte) r; + } + + public static byte byteMultiplyExact(byte a, byte b) { + int r = a * b; + if (r < Byte.MIN_VALUE || r > Byte.MAX_VALUE) { + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(a, "*", b, "try_multiply"); + } + return (byte) r; + } + + // ----- Short: int arithmetic with overflow check (ANSI) ----- + + public static short shortAddExact(short a, short b) { + int r = a + b; + if (r < Short.MIN_VALUE || r > Short.MAX_VALUE) { + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(a, "+", b, "try_add"); + } + return (short) r; + } + + public static short shortSubtractExact(short a, short b) { + int r = a - b; + if (r < Short.MIN_VALUE || r > Short.MAX_VALUE) { + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(a, "-", b, "try_subtract"); + } + return (short) r; + } + + public static short shortMultiplyExact(short a, short b) { + int r = a * b; + if (r < Short.MIN_VALUE || r > Short.MAX_VALUE) { + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(a, "*", b, "try_multiply"); + } + return (short) r; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1c93a6586761..a67192504c24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -301,31 +301,22 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext val mathUtils = IntervalMathUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType if failOnError => + val opName = symbol match { + case "+" => "Add" + case "-" => "Subtract" + case "*" => "Multiply" + case _ => + throw QueryExecutionErrors.notOverrideExpectedMethodsError(this.getClass.getName, + s"genCode for Byte/Short with symbol '$symbol'", "genCode") + } + val typeName = if (dataType == ByteType) "byte" else "short" + val arithmeticUtils = classOf[ArithmeticUtils].getName + defineCodeGen(ctx, ev, (eval1, eval2) => + s"$arithmeticUtils.$typeName${opName}Exact($eval1, $eval2)") case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val tmpResult = ctx.freshName("tmpResult") - val try_suggestion = symbol match { - case "+" => "try_add" - case "-" => "try_subtract" - case "*" => "try_multiply" - case _ => "unknown_function" - } - val overflowCheck = if (failOnError) { - val javaType = CodeGenerator.boxedType(dataType) - s""" - |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) { - | throw QueryExecutionErrors.binaryArithmeticCauseOverflowError( - | $eval1, "$symbol", $eval2, "$try_suggestion"); - |} - """.stripMargin - } else { - "" - } - s""" - |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2; - |$overflowCheck - |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult); - """.stripMargin + s"${ev.value} 
= (${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2);" }) case IntegerType | LongType if failOnError && exactMathMethod.isDefined => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -458,6 +449,10 @@ case class Add( MathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int], getContextOrNull()) case _: LongType if failOnError => MathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long], getContextOrNull()) + case _: ByteType if failOnError => + ArithmeticUtils.byteAddExact(input1.asInstanceOf[Byte], input2.asInstanceOf[Byte]) + case _: ShortType if failOnError => + ArithmeticUtils.shortAddExact(input1.asInstanceOf[Short], input2.asInstanceOf[Short]) case _ => numeric.plus(input1, input2) } @@ -555,6 +550,10 @@ case class Subtract( input1.asInstanceOf[Long], input2.asInstanceOf[Long], getContextOrNull()) + case _: ByteType if failOnError => + ArithmeticUtils.byteSubtractExact(input1.asInstanceOf[Byte], input2.asInstanceOf[Byte]) + case _: ShortType if failOnError => + ArithmeticUtils.shortSubtractExact(input1.asInstanceOf[Short], input2.asInstanceOf[Short]) case _ => numeric.minus(input1, input2) } @@ -625,6 +624,10 @@ case class Multiply( input1.asInstanceOf[Long], input2.asInstanceOf[Long], getContextOrNull()) + case _: ByteType if failOnError => + ArithmeticUtils.byteMultiplyExact(input1.asInstanceOf[Byte], input2.asInstanceOf[Byte]) + case _: ShortType if failOnError => + ArithmeticUtils.shortMultiplyExact(input1.asInstanceOf[Short], input2.asInstanceOf[Short]) case _ => numeric.times(input1, input2) }