From 23a165c72d56267809580652365b3771857f2c4e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 17 May 2026 23:25:05 +0000 Subject: [PATCH] [SPARK-56916][SQL] Simplify ElementAt array codegen under ANSI mode ### What changes were proposed in this pull request? Introduce `ElementAtUtils.java` with a single helper `elementAtIndexExact(int length, int index, QueryContext context)` and use it from `ElementAt`'s `ArrayType` branch in both `doGenCode` and `doElementAt` (eval). The helper normalizes a 1-based `element_at` index against the array length and returns the 0-based position, throwing `invalidElementAtIndexError` for out-of-bound and `invalidIndexOfZeroError` for zero index. The caller still emits the type-specific `arr.get(pos, dataType)` (not the helper, since the return type depends on the array element type). The non-ANSI branch is left inline because it can choose between `defaultValueOutOfBound` (an `Option[Expression]` that requires codegen access) or `null`. ### Why are the changes needed? Part of SPARK-56908 (umbrella). The ANSI `ElementAt` codegen body was the largest single inline body in `collectionOperations.scala` -- the helper collapses ~12 lines to ~3 per call site. ### Does this PR introduce _any_ user-facing change? No. The compiled behavior is identical; only the emitted Java source text changes. ### How was this patch tested? ``` build/sbt "catalyst/testOnly *CollectionExpressionsSuite" ``` 59/59 pass. ### Was this patch authored or co-authored using generative AI tooling? 
Generated-by: Cursor 1.x --- .../catalyst/expressions/ElementAtUtils.java | 51 +++++++++++++++ .../expressions/collectionOperations.scala | 65 ++++++++++++------- 2 files changed, 93 insertions(+), 23 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java new file mode 100644 index 0000000000000..27e1245e4ffc3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ElementAtUtils.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.QueryContext; +import org.apache.spark.sql.errors.QueryExecutionErrors; + +/** + * Static helpers used by {@link ElementAt} on {@code ArrayType} + * (codegen and eval) under ANSI mode. + */ +public final class ElementAtUtils { + + private ElementAtUtils() {} + + /** + * Validates a 1-based {@code element_at} index against the array length + * and returns the 0-based position. 
Throws when the absolute index + * exceeds the array length (ANSI out-of-bounds) or when {@code index} is + * zero (always invalid). + * + * @param length the array length + * @param index the 1-based index supplied by the user (positive or negative) + * @param context the query context attached to the error + * @return the validated 0-based position + */ + public static int elementAtIndexExact(int length, int index, QueryContext context) { + if (length < Math.abs(index)) { + throw QueryExecutionErrors.invalidElementAtIndexError(index, length, context); + } + if (index == 0) { + throw QueryExecutionErrors.invalidIndexOfZeroError(context); + } + return index > 0 ? index - 1 : length + index; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 60966f3098ca8..5a6a57b76e167 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2738,19 +2738,21 @@ case class ElementAt( override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal) @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match { + case _: ArrayType if failOnError => + (value, ordinal) => { + val array = value.asInstanceOf[ArrayData] + val idx = ElementAtUtils.elementAtIndexExact( + array.numElements(), ordinal.asInstanceOf[Int], getContextOrNull()) + if (arrayElementNullable && array.isNullAt(idx)) null else array.get(idx, dataType) + } case _: ArrayType => (value, ordinal) => { val array = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Int] if (array.numElements() < math.abs(index)) { - if (failOnError) { - throw QueryExecutionErrors.invalidElementAtIndexError( - index, array.numElements(), getContextOrNull()) - } else 
{ - defaultValueOutOfBound match { - case Some(value) => value.eval() - case None => null - } + defaultValueOutOfBound match { + case Some(value) => value.eval() + case None => null } } else { val idx = if (index == 0) { @@ -2773,7 +2775,7 @@ case class ElementAt( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { left.dataType match { - case _: ArrayType => + case _: ArrayType if failOnError => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") val nullCheck = if (arrayElementNullable) { @@ -2786,21 +2788,38 @@ case class ElementAt( "" } val errorContext = getContextOrNullCode(ctx) - val indexOutOfBoundBranch = if (failOnError) { - // scalastyle:off line.size.limit - s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements(), $errorContext);" - // scalastyle:on line.size.limit + val utils = classOf[ElementAtUtils].getName + s""" + |int $index = $utils.elementAtIndexExact( + | $eval1.numElements(), (int) $eval2, $errorContext); + |$nullCheck + |{ + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + |} + """.stripMargin + }) + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (arrayElementNullable) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin } else { - defaultValueOutOfBound match { - case Some(value) => - val defaultValueEval = value.genCode(ctx) - s""" - ${defaultValueEval.code} - ${ev.isNull} = ${defaultValueEval.isNull}; - ${ev.value} = ${defaultValueEval.value}; - """.stripMargin - case None => s"${ev.isNull} = true;" - } + "" + } + val errorContext = getContextOrNullCode(ctx) + val indexOutOfBoundBranch = defaultValueOutOfBound match { + case Some(value) => + val defaultValueEval = value.genCode(ctx) + s""" + ${defaultValueEval.code} + ${ev.isNull} = ${defaultValueEval.isNull}; + ${ev.value} = ${defaultValueEval.value}; + 
""".stripMargin + case None => s"${ev.isNull} = true;" } s"""