From 0494383f43f4879dbfa06b3832756b39c37e1f49 Mon Sep 17 00:00:00 2001 From: Serge Rielau Date: Thu, 22 Dec 2022 17:31:51 -0800 Subject: [PATCH] [SPARK-41670] builtin schema --- .../sql/catalyst/analysis/Analyzer.scala | 34 +++++-- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../analysis/NoSuchItemException.scala | 6 +- .../sql/catalyst/catalog/SessionCatalog.scala | 92 +++++++++++++------ .../spark/sql/errors/QueryParsingErrors.scala | 2 +- .../spark/sql/execution/SparkSqlParser.scala | 12 ++- .../sql/execution/command/functions.scala | 9 +- 7 files changed, 108 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index daeddd309d7d1..09926af59c818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2312,7 +2312,11 @@ class Analyzer(override val catalogManager: CatalogManager) plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) { case f @ UnresolvedFunction(nameParts, _, _, _, _) => - if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts).isDefined) { + val normQualifier = normalizeFuncName(nameParts).dropRight(1); + if ((Set(Seq(), Seq("builtin"), Seq("session"), + Seq("system", "builtin"), Seq("system", "session")) contains normQualifier) && + ResolveFunctions.lookupBuiltinOrTempFunction(nameParts.takeRight(2)) + .isDefined) { f } else { val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) @@ -2467,16 +2471,16 @@ class Analyzer(override val catalogManager: CatalogManager) } def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = { - if (name.length == 1) { - v1SessionCatalog.lookupBuiltinOrTempFunction(name.head) + if (name.length == 1 || name.length == 2) { + 
v1SessionCatalog.lookupBuiltinOrTempFunction(name) } else { None } } def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = { - if (name.length == 1) { - v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head) + if (name.length == 1 || name.length == 2) { + v1SessionCatalog.lookupBuiltinOrTempTableFunction(name) } else { None } @@ -2486,8 +2490,10 @@ class Analyzer(override val catalogManager: CatalogManager) name: Seq[String], arguments: Seq[Expression], u: Option[UnresolvedFunction]): Option[Expression] = { - if (name.length == 1) { - v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func => + val normQualifier = normalizeFuncName(name).dropRight(1); + if ((Set(Seq(), Seq("builtin"), Seq("session"), + Seq("system", "builtin"), Seq("system", "session")) contains normQualifier)) { + v1SessionCatalog.resolveBuiltinOrTempFunction(name.takeRight(2), arguments).map { func => if (u.isDefined) validateFunction(func, arguments.length, u.get) else func } } else { @@ -2498,8 +2504,10 @@ class Analyzer(override val catalogManager: CatalogManager) private def resolveBuiltinOrTempTableFunction( name: Seq[String], arguments: Seq[Expression]): Option[LogicalPlan] = { - if (name.length == 1) { - v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments) + val normQualifier = normalizeFuncName(name).dropRight(1); + if (normQualifier.length == 0 || + normQualifier == Seq("builtin") || normQualifier == Seq("system", "builtin")) { + v1SessionCatalog.resolveBuiltinOrTempTableFunction(name, arguments) } else { None } @@ -2668,6 +2676,14 @@ class Analyzer(override val catalogManager: CatalogManager) val aggregator = V2Aggregator(aggFunc, arguments) aggregator.toAggregateExpression(u.isDistinct, u.filter) } + + def normalizeFuncName(name: Seq[String]): Seq[String] = { + if (conf.caseSensitiveAnalysis) { + name + } else { + name.map(_.toLowerCase(Locale.ROOT)) + } + } } /** diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4b6603b635880..9e16031c66a1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -815,7 +815,7 @@ object FunctionRegistry { val fr = new SimpleFunctionRegistry expressions.foreach { case (name, (info, builder)) => - fr.internalRegisterFunction(FunctionIdentifier(name), info, builder) + fr.internalRegisterFunction(FunctionIdentifier(name, Option("builtin")), info, builder) } fr } @@ -968,7 +968,7 @@ object TableFunctionRegistry { val fr = new SimpleTableFunctionRegistry logicalPlans.foreach { case (name, (info, builder)) => - fr.internalRegisterFunction(FunctionIdentifier(name), info, builder) + fr.internalRegisterFunction(FunctionIdentifier(name, Option("builtin")), info, builder) } fr } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index f6624126e9492..32efda8924700 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.util.{quoteIdentifier, quoteNameParts} @@ -128,8 +129,9 @@ class NoSuchPartitionsException(errorClass: String, messageParameters: Map[Strin } } -class NoSuchTempFunctionException(func: 
String) - extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> s"`$func`")) +class NoSuchTempFunctionException(func: FunctionIdentifier) + extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", + Map("routineName" -> s"`${func.database.get}`.`${func.funcName}`")) class NoSuchIndexException(message: String, cause: Option[Throwable] = None) extends AnalysisException(errorClass = "INDEX_NOT_FOUND", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 0621461329944..258bda525f483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1544,9 +1544,9 @@ class SessionCatalog( /** * Drop a temporary function. */ - def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!functionRegistry.dropFunction(FunctionIdentifier(name)) && - !tableFunctionRegistry.dropFunction(FunctionIdentifier(name)) && + def dropTempFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = { + if (!functionRegistry.dropFunction(name) && + !tableFunctionRegistry.dropFunction(name) && !ignoreIfNotExists) { throw new NoSuchTempFunctionException(name) } @@ -1557,8 +1557,8 @@ class SessionCatalog( */ def isTemporaryFunction(name: FunctionIdentifier): Boolean = { // A temporary function is a function that has been registered in functionRegistry - // without a database name, and is neither a built-in function nor a Hive function - name.database.isEmpty && isRegisteredFunction(name) && !isBuiltinFunction(name) + // with a database name "session", and is neither a built-in function nor a Hive function + name.database == Some("session") && isRegisteredFunction(name) && !isBuiltinFunction(name) } /** @@ -1596,8 +1596,13 @@ class SessionCatalog( * Look up the 
`ExpressionInfo` of the given function by name if it's a built-in or temp function. * This only supports scalar functions. */ - def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = { - FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse { + def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = { + if (name.length == 1) { + FunctionRegistry.builtinOperators.get(name.head.toLowerCase(Locale.ROOT)).orElse { + synchronized(lookupTempFuncWithViewContext( + name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction)) + } + } else { synchronized(lookupTempFuncWithViewContext( name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction)) } @@ -1607,7 +1612,7 @@ class SessionCatalog( * Look up the `ExpressionInfo` of the given function by name if it's a built-in or * temp table function. */ - def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized { + def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = synchronized { lookupTempFuncWithViewContext( name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction) } @@ -1616,7 +1621,8 @@ class SessionCatalog( * Look up a built-in or temp scalar function by name and resolves it to an Expression if such * a function exists. */ - def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = { + def resolveBuiltinOrTempFunction(name: Seq[String], + arguments: Seq[Expression]): Option[Expression] = { resolveBuiltinOrTempFunctionInternal( name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry) } @@ -1626,18 +1632,31 @@ class SessionCatalog( * a function exists. 
*/ def resolveBuiltinOrTempTableFunction( - name: String, arguments: Seq[Expression]): Option[LogicalPlan] = { + name: Seq[String], arguments: Seq[Expression]): Option[LogicalPlan] = { resolveBuiltinOrTempFunctionInternal( name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry) } private def resolveBuiltinOrTempFunctionInternal[T]( - name: String, + name: Seq[String], arguments: Seq[Expression], isBuiltin: FunctionIdentifier => Boolean, registry: FunctionRegistryBase[T]): Option[T] = synchronized { - val funcIdent = FunctionIdentifier(name) - if (!registry.functionExists(funcIdent)) { + val funcIdent = FunctionIdentifier(name.last, + if (name.length == 1) { + Some("builtin") + } else { + name.headOption + } + ) + val tempIdent = FunctionIdentifier(name.last, + if (name.length == 1) { + Some("session") + } else { + name.headOption + } + ) + if (!registry.functionExists(funcIdent) && !registry.functionExists(tempIdent)) { None } else { lookupTempFuncWithViewContext( @@ -1646,29 +1665,41 @@ class SessionCatalog( } private def lookupTempFuncWithViewContext[T]( - name: String, + name: Seq[String], isBuiltin: FunctionIdentifier => Boolean, lookupFunc: FunctionIdentifier => Option[T]): Option[T] = { - val funcIdent = FunctionIdentifier(name) + val funcIdent = FunctionIdentifier(name.last, + if (name.length == 1) { + Some("builtin") + } else { + name.headOption } + ) if (isBuiltin(funcIdent)) { lookupFunc(funcIdent) } else { val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames + val tempIdent = FunctionIdentifier(name.last, + if (name.length == 1) { + Some("session") + } else { + name.headOption + } + ) if (isResolvingView) { // When resolving a view, only return a temp function if it's referred by this view. 
- if (referredTempFunctionNames.contains(name)) { - lookupFunc(funcIdent) + if (referredTempFunctionNames.contains(name.last)) { + lookupFunc(tempIdent) } else { None } } else { - val result = lookupFunc(funcIdent) + val result = lookupFunc(tempIdent) if (result.isDefined) { // We are not resolving a view and the function is a temp one, add it to // `AnalysisContext`, so during the view creation, we can save all referred temp // functions to view metadata. - AnalysisContext.get.referredTempFunctionNames.add(name) + AnalysisContext.get.referredTempFunctionNames.add(name.last) } result } @@ -1753,8 +1784,8 @@ class SessionCatalog( */ def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized { if (name.database.isEmpty) { - lookupBuiltinOrTempFunction(name.funcName) - .orElse(lookupBuiltinOrTempTableFunction(name.funcName)) + lookupBuiltinOrTempFunction(Seq(name.funcName, name.database.orNull)) + .orElse(lookupBuiltinOrTempTableFunction(Seq(name.funcName, name.database.orNull))) .getOrElse(lookupPersistentFunction(name)) } else { lookupPersistentFunction(name) @@ -1765,7 +1796,7 @@ class SessionCatalog( // function from either v1 or v2 catalog. This method only look up v1 catalog. 
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { if (name.database.isEmpty) { - resolveBuiltinOrTempFunction(name.funcName, children) + resolveBuiltinOrTempFunction(Seq(name.funcName, name.database.orNull), children) .getOrElse(resolvePersistentFunction(name, children)) } else { resolvePersistentFunction(name, children) @@ -1774,7 +1805,7 @@ class SessionCatalog( def lookupTableFunction(name: FunctionIdentifier, children: Seq[Expression]): LogicalPlan = { if (name.database.isEmpty) { - resolveBuiltinOrTempTableFunction(name.funcName, children) + resolveBuiltinOrTempTableFunction(Seq(name.funcName, name.database.orNull), children) .getOrElse(resolvePersistentTableFunction(name, children)) } else { resolvePersistentTableFunction(name, children) @@ -1786,8 +1817,10 @@ class SessionCatalog( */ private def listBuiltinAndTempFunctions(pattern: String): Seq[FunctionIdentifier] = { val functions = (functionRegistry.listFunction() ++ tableFunctionRegistry.listFunction()) - .filter(_.database.isEmpty) - StringUtils.filterPattern(functions.map(_.unquotedString), pattern).map { f => + .filter(funcId => funcId.catalog.isEmpty && + (funcId.database == Some("builtin") || funcId.database == Some("session"))) + StringUtils.filterPattern(functions.map(f => FunctionIdentifier(f.funcName).unquotedString), + pattern).map { f => // In functionRegistry, function names are stored as an unquoted format. Try(parser.parseFunctionIdentifier(f)) match { case Success(e) => e @@ -1820,8 +1853,13 @@ class SessionCatalog( // The session catalog caches some persistent functions in the FunctionRegistry // so there can be duplicates. 
functions.map { - case f if FunctionRegistry.functionSet.contains(f) => (f, "SYSTEM") - case f if TableFunctionRegistry.functionSet.contains(f) => (f, "SYSTEM") + case f if f.database.isEmpty && + FunctionRegistry.functionSet.contains(FunctionIdentifier(f.funcName, Some("builtin"))) => + (f, "SYSTEM") + case f if f.database.isEmpty && + TableFunctionRegistry.functionSet.contains(FunctionIdentifier(f.funcName, + Some("builtin"))) => + (f, "SYSTEM") case f if f.database.isDefined => (qualifyIdentifier(f), "USER") case f => (f, "USER") }.distinct diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 773a79a3f3f0a..19d45567cff93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -557,7 +557,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { ctx) } - def unsupportedFunctionNameError(funcName: Seq[String], ctx: CreateFunctionContext): Throwable = { + def unsupportedFunctionNameError(funcName: Seq[String], ctx: StatementContext): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX", messageParameters = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index f551aa9efbf17..32ca4a0922fd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -568,12 +568,13 @@ class SparkSqlAstBuilder extends AstBuilder { if (functionIdentifier.length > 2) { throw QueryParsingErrors.unsupportedFunctionNameError(functionIdentifier, ctx) - } else if (functionIdentifier.length == 2) { + } else if (functionIdentifier.length == 2 && + 
functionIdentifier.head.toLowerCase(Locale.ROOT) != "session") { // Temporary function names should not contain database prefix like "database.function" throw QueryParsingErrors.specifyingDBInCreateTempFuncError(functionIdentifier.head, ctx) } CreateFunctionCommand( - FunctionIdentifier(functionIdentifier.last), + FunctionIdentifier(functionIdentifier.last, Option("session")), string(visitStringLit(ctx.className)), resources.toSeq, true, @@ -594,11 +595,14 @@ class SparkSqlAstBuilder extends AstBuilder { val functionName = visitMultipartIdentifier(ctx.multipartIdentifier) val isTemp = ctx.TEMPORARY != null if (isTemp) { - if (functionName.length > 1) { + if (functionName.length > 2) { + throw QueryParsingErrors.unsupportedFunctionNameError(functionName, ctx) + } else if (functionName.length == 2 && + functionName.head.toLowerCase(Locale.ROOT) != "session") { throw QueryParsingErrors.invalidNameForDropTempFunc(functionName, ctx) } DropFunctionCommand( - identifier = FunctionIdentifier(functionName.head), + identifier = FunctionIdentifier(functionName.last, Option("session")), ifExists = ctx.EXISTS != null, isTemp = true) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index eb88acd7b0b28..57fffdc389c3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -138,11 +138,8 @@ case class DropFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog if (isTemp) { - assert(identifier.database.isEmpty) - if (FunctionRegistry.builtin.functionExists(identifier)) { - throw QueryCompilationErrors.cannotDropBuiltinFuncError(identifier.funcName) - } - catalog.dropTempFunction(identifier.funcName, ifExists) + assert(identifier.database == Some("session")) + 
catalog.dropTempFunction(identifier, ifExists) } else { // We are dropping a permanent function. catalog.dropFunction(identifier, ignoreIfNotExists = ifExists) @@ -162,7 +159,7 @@ case class DropFunctionCommand( * '|' is for alternation. * For example, "show functions like 'yea*|windo*'" will return "window" and "year". */ -case class ShowFunctionsCommand( +case class ShowFunctionsCommand( db: String, pattern: Option[String], showUserFunctions: Boolean,