diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LambdaBinder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LambdaBinder.scala
new file mode 100644
index 000000000000..da9b12566fa0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LambdaBinder.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Expression, LambdaFunction, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLConf, toSQLId}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Object used to bind lambda function arguments to their types and validate lambda argument
+ * constraints.
+ *
+ * This object creates a bound [[LambdaFunction]] by binding the arguments to the given type
+ * information (dataType and nullability). The argument names come from the lambda function
+ * itself. It handles three cases:
+ *
+ * 1. Already bound lambda functions: Returns the function as-is, assuming it has been
+ *    correctly bound to its arguments.
+ *
+ * 2. Unbound lambda functions: Validates and binds the function by:
+ *    - Checking that the number of arguments matches the expected count
+ *    - Checking for duplicate argument names (respecting case sensitivity configuration)
+ *    - Creating [[NamedLambdaVariable]] instances with the provided types
+ *
+ * 3. Non-lambda expressions: Wraps the expression in a lambda function with hidden arguments
+ *    (named `col0`, `col1`, etc.). This is used when an expression does not consume lambda
+ *    arguments but needs to be passed to a higher-order function. The arguments are hidden to
+ *    prevent accidental naming collisions.
+ */
+object LambdaBinder extends SQLConfHelper {
+
+  /**
+   * Binds lambda function arguments to their types and validates lambda argument constraints.
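+   *
+   * A purely illustrative example (not from the patch; `lambda` stands for a parsed, unbound
+   * `(k, v) -> ...` [[LambdaFunction]]):
+   *
+   * {{{
+   *   LambdaBinder(lambda, Seq((StringType, false), (IntegerType, true)))
+   *   // => LambdaFunction(function, Seq(
+   *   //      NamedLambdaVariable("k", StringType, false),
+   *   //      NamedLambdaVariable("v", IntegerType, true)))
+   * }}}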
+   */
+  def apply(expression: Expression, argumentsInfo: Seq[(DataType, Boolean)]): LambdaFunction =
+    expression match {
+      case f: LambdaFunction if f.bound => f
+
+      case LambdaFunction(function, names, _) =>
+        if (names.size != argumentsInfo.size) {
+          expression.failAnalysis(
+            errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH",
+            messageParameters = Map(
+              "expectedNumArgs" -> names.size.toString,
+              "actualNumArgs" -> argumentsInfo.size.toString
+            )
+          )
+        }
+
+        if (names.map(a => conf.canonicalize(a.name)).distinct.size < names.size) {
+          expression.failAnalysis(
+            errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES",
+            messageParameters = Map(
+              "args" -> names.map(a => conf.canonicalize(a.name)).map(toSQLId(_)).mkString(", "),
+              "caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key)
+            )
+          )
+        }
+
+        val arguments = argumentsInfo.zip(names).map {
+          case ((dataType, nullable), ne) =>
+            NamedLambdaVariable(ne.name, dataType, nullable)
+        }
+        LambdaFunction(function, arguments)
+
+      case _ =>
+        val arguments = argumentsInfo.zipWithIndex.map {
+          case ((dataType, nullable), i) =>
+            NamedLambdaVariable(s"col$i", dataType, nullable)
+        }
+        LambdaFunction(expression, arguments, hidden = true)
+    }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index 25fec0fffeaf..9c94d045ae86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLConf, toSQLId}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
 
 /**
  * Resolve the lambda variables exposed by a higher order functions.
@@ -49,53 +46,6 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
     }
   }
 
-  /**
-   * Create a bound lambda function by binding the arguments of a lambda function to the given
-   * partial arguments (dataType and nullability only). If the expression happens to be an already
-   * bound lambda function then we assume it has been bound to the correct arguments and do
-   * nothing. This function will produce a lambda function with hidden arguments when it is passed
-   * an arbitrary expression.
-   */
-  private def createLambda(
-      e: Expression,
-      argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
-    case f: LambdaFunction if f.bound => f
-
-    case LambdaFunction(function, names, _) =>
-      if (names.size != argInfo.size) {
-        e.failAnalysis(
-          errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH",
-          messageParameters = Map(
-            "expectedNumArgs" -> names.size.toString,
-            "actualNumArgs" -> argInfo.size.toString))
-      }
-
-      if (names.map(a => conf.canonicalize(a.name)).distinct.size < names.size) {
-        e.failAnalysis(
-          errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES",
-          messageParameters = Map(
-            "args" -> names.map(a => conf.canonicalize(a.name)).map(toSQLId(_)).mkString(", "),
-            "caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key)))
-      }
-
-      val arguments = argInfo.zip(names).map {
-        case ((dataType, nullable), ne) =>
-          NamedLambdaVariable(ne.name, dataType, nullable)
-      }
-      LambdaFunction(function, arguments)
-
-    case _ =>
-      // This expression does not consume any of the lambda's arguments (it is independent). We do
-      // create a lambda function with default parameters because this is expected by the higher
-      // order function. Note that we hide the lambda variables produced by this function in order
-      // to prevent accidental naming collisions.
-      val arguments = argInfo.zipWithIndex.map {
-        case ((dataType, nullable), i) =>
-          NamedLambdaVariable(s"col$i", dataType, nullable)
-      }
-      LambdaFunction(e, arguments, hidden = true)
-  }
-
   /**
    * Resolve lambda variables in the expression subtree, using the passed lambda variable registry.
    */
@@ -104,7 +54,7 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
     case h: HigherOrderFunction
         if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
       SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(e)
-      h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
+      h.bind(LambdaBinder(_, _)).mapChildren(resolve(_, parentLambdaMap))
 
     case l: LambdaFunction if !l.bound =>
       SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(e)
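For reviewers, a minimal sketch of the three binding cases handled by the new entry point (illustrative only: `LambdaBinderSketch` is a hypothetical harness, and the assertions simply restate the behavior visible in the diff above):

```scala
import org.apache.spark.sql.catalyst.analysis.LambdaBinder
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal}
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable
import org.apache.spark.sql.types.IntegerType

object LambdaBinderSketch {
  def main(args: Array[String]): Unit = {
    // A single lambda argument, described only by (dataType, nullability).
    val argInfo = Seq((IntegerType, true))

    // Unbound lambda `x -> x`: the argument is bound to a NamedLambdaVariable of IntegerType.
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val bound = LambdaBinder(LambdaFunction(x, Seq(x)), argInfo)
    assert(bound.bound)

    // Already bound lambda: returned as-is (first match arm).
    assert(LambdaBinder(bound, argInfo) eq bound)

    // Non-lambda expression: wrapped in a lambda with a hidden `col0` argument.
    val wrapped = LambdaBinder(Literal(1), argInfo)
    assert(wrapped.hidden)
    assert(wrapped.arguments.map(_.name) == Seq("col0"))
  }
}
```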