Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2312,7 +2312,11 @@ class Analyzer(override val catalogManager: CatalogManager)

plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
case f @ UnresolvedFunction(nameParts, _, _, _, _) =>
if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts).isDefined) {
val normQualifier = normalizeFuncName(nameParts).dropRight(1);
if ((Set(Seq(), Seq("builtin"), Seq("session"),
Seq("system", "builtin"), Seq("system", "session")) contains normQualifier) &&
ResolveFunctions.lookupBuiltinOrTempFunction(nameParts.takeRight(2))
.isDefined) {
f
} else {
val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts)
Expand Down Expand Up @@ -2467,16 +2471,16 @@ class Analyzer(override val catalogManager: CatalogManager)
}

def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = {
if (name.length == 1) {
v1SessionCatalog.lookupBuiltinOrTempFunction(name.head)
if (name.length == 1 || name.length == 2) {
v1SessionCatalog.lookupBuiltinOrTempFunction(name)
} else {
None
}
}

def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = {
if (name.length == 1) {
v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head)
if (name.length == 1 || name.length == 2) {
v1SessionCatalog.lookupBuiltinOrTempTableFunction(name)
} else {
None
}
Expand All @@ -2486,8 +2490,10 @@ class Analyzer(override val catalogManager: CatalogManager)
name: Seq[String],
arguments: Seq[Expression],
u: Option[UnresolvedFunction]): Option[Expression] = {
if (name.length == 1) {
v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func =>
val normQualifier = normalizeFuncName(name).dropRight(1);
if ((Set(Seq(), Seq("builtin"), Seq("session"),
Seq("system", "builtin"), Seq("system", "session")) contains normQualifier)) {
v1SessionCatalog.resolveBuiltinOrTempFunction(name.takeRight(2), arguments).map { func =>
if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
}
} else {
Expand All @@ -2498,8 +2504,10 @@ class Analyzer(override val catalogManager: CatalogManager)
private def resolveBuiltinOrTempTableFunction(
name: Seq[String],
arguments: Seq[Expression]): Option[LogicalPlan] = {
if (name.length == 1) {
v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments)
val normQualifier = normalizeFuncName(name).dropRight(1);
if (normQualifier.length == 0 ||
normQualifier == Seq("builtin") || normQualifier == Seq("system", "builtin")) {
v1SessionCatalog.resolveBuiltinOrTempTableFunction(name, arguments)
} else {
None
}
Expand Down Expand Up @@ -2668,6 +2676,14 @@ class Analyzer(override val catalogManager: CatalogManager)
val aggregator = V2Aggregator(aggFunc, arguments)
aggregator.toAggregateExpression(u.isDistinct, u.filter)
}

/**
 * Normalizes the parts of a multi-part function name.
 *
 * @param name the name parts to normalize
 * @return `name` unchanged when `conf.caseSensitiveAnalysis` is set; otherwise every part
 *         lower-cased using `Locale.ROOT`
 */
def normalizeFuncName(name: Seq[String]): Seq[String] = {
  val foldCase = !conf.caseSensitiveAnalysis
  if (foldCase) name.map(part => part.toLowerCase(Locale.ROOT)) else name
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ object FunctionRegistry {
val fr = new SimpleFunctionRegistry
expressions.foreach {
case (name, (info, builder)) =>
fr.internalRegisterFunction(FunctionIdentifier(name), info, builder)
fr.internalRegisterFunction(FunctionIdentifier(name, Option("builtin")), info, builder)
}
fr
}
Expand Down Expand Up @@ -968,7 +968,7 @@ object TableFunctionRegistry {
val fr = new SimpleTableFunctionRegistry
logicalPlans.foreach {
case (name, (info, builder)) =>
fr.internalRegisterFunction(FunctionIdentifier(name), info, builder)
fr.internalRegisterFunction(FunctionIdentifier(name, Option("builtin")), info, builder)
}
fr
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, quoteNameParts}
Expand Down Expand Up @@ -128,8 +129,9 @@ class NoSuchPartitionsException(errorClass: String, messageParameters: Map[Strin
}
}

class NoSuchTempFunctionException(func: String)
extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> s"`$func`"))
/**
 * Analysis error with error class ROUTINE_NOT_FOUND for a temporary function identifier.
 *
 * The reported `routineName` is the identifier's database qualifier (when present) followed
 * by the function name, each part wrapped in backticks and joined with a dot.
 *
 * @param func identifier of the function that was not found
 */
class NoSuchTempFunctionException(func: FunctionIdentifier)
  extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND",
    Map("routineName" ->
      // `database` is an Option; calling `.get` would throw NoSuchElementException for an
      // unqualified identifier and hide the real error, so fold the qualifier in only if set.
      (func.database.toSeq :+ func.funcName).map(part => s"`$part`").mkString(".")))

class NoSuchIndexException(message: String, cause: Option[Throwable] = None)
extends AnalysisException(errorClass = "INDEX_NOT_FOUND",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,9 +1544,9 @@ class SessionCatalog(
/**
* Drop a temporary function.
*/
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
if (!functionRegistry.dropFunction(FunctionIdentifier(name)) &&
!tableFunctionRegistry.dropFunction(FunctionIdentifier(name)) &&
def dropTempFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
if (!functionRegistry.dropFunction(name) &&
!tableFunctionRegistry.dropFunction(name) &&
!ignoreIfNotExists) {
throw new NoSuchTempFunctionException(name)
}
Expand All @@ -1557,8 +1557,8 @@ class SessionCatalog(
*/
def isTemporaryFunction(name: FunctionIdentifier): Boolean = {
// A temporary function is a function that has been registered in functionRegistry
// without a database name, and is neither a built-in function nor a Hive function
name.database.isEmpty && isRegisteredFunction(name) && !isBuiltinFunction(name)
// with a database name "session", and is neither a built-in function nor a Hive function
name.database == Some("session") && isRegisteredFunction(name) && !isBuiltinFunction(name)
}

/**
Expand Down Expand Up @@ -1596,8 +1596,13 @@ class SessionCatalog(
* Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function.
* This only supports scalar functions.
*/
def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = {
FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse {
def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = {
if (name.length == 1) {
FunctionRegistry.builtinOperators.get(name.head.toLowerCase(Locale.ROOT)).orElse {
synchronized(lookupTempFuncWithViewContext(
name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction))
}
} else {
synchronized(lookupTempFuncWithViewContext(
name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction))
}
Expand All @@ -1607,7 +1612,7 @@ class SessionCatalog(
* Look up the `ExpressionInfo` of the given function by name if it's a built-in or
* temp table function.
*/
def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized {
def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = synchronized {
lookupTempFuncWithViewContext(
name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction)
}
Expand All @@ -1616,7 +1621,8 @@ class SessionCatalog(
* Look up a built-in or temp scalar function by name and resolves it to an Expression if such
* a function exists.
*/
def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = {
def resolveBuiltinOrTempFunction(name: Seq[String],
arguments: Seq[Expression]): Option[Expression] = {
resolveBuiltinOrTempFunctionInternal(
name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry)
}
Expand All @@ -1626,18 +1632,31 @@ class SessionCatalog(
* a function exists.
*/
def resolveBuiltinOrTempTableFunction(
name: String, arguments: Seq[Expression]): Option[LogicalPlan] = {
name: Seq[String], arguments: Seq[Expression]): Option[LogicalPlan] = {
resolveBuiltinOrTempFunctionInternal(
name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
}

private def resolveBuiltinOrTempFunctionInternal[T](
name: String,
name: Seq[String],
arguments: Seq[Expression],
isBuiltin: FunctionIdentifier => Boolean,
registry: FunctionRegistryBase[T]): Option[T] = synchronized {
val funcIdent = FunctionIdentifier(name)
if (!registry.functionExists(funcIdent)) {
val funcIdent = FunctionIdentifier(name.last,
if (name.length == 1) {
Some("builtin")
} else {
name.headOption
}
)
val tempIdent = FunctionIdentifier(name.last,
if (name.length == 1) {
Some("session")
} else {
name.headOption
}
)
if (!registry.functionExists(funcIdent) && !registry.functionExists(tempIdent)) {
None
} else {
lookupTempFuncWithViewContext(
Expand All @@ -1646,29 +1665,41 @@ class SessionCatalog(
}

private def lookupTempFuncWithViewContext[T](
name: String,
name: Seq[String],
isBuiltin: FunctionIdentifier => Boolean,
lookupFunc: FunctionIdentifier => Option[T]): Option[T] = {
val funcIdent = FunctionIdentifier(name)
val funcIdent = FunctionIdentifier(name.last,
if (name.length == 1) {
Some("builtin")
} else {
name.headOption }
)
if (isBuiltin(funcIdent)) {
lookupFunc(funcIdent)
} else {
val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty
val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames
val tempIdent = FunctionIdentifier(name.last,
if (name.length == 1) {
Some("session")
} else {
name.headOption
}
)
if (isResolvingView) {
// When resolving a view, only return a temp function if it's referred by this view.
if (referredTempFunctionNames.contains(name)) {
lookupFunc(funcIdent)
if (referredTempFunctionNames.contains(name.last)) {
lookupFunc(tempIdent)
} else {
None
}
} else {
val result = lookupFunc(funcIdent)
val result = lookupFunc(tempIdent)
if (result.isDefined) {
// We are not resolving a view and the function is a temp one, add it to
// `AnalysisContext`, so during the view creation, we can save all referred temp
// functions to view metadata.
AnalysisContext.get.referredTempFunctionNames.add(name)
AnalysisContext.get.referredTempFunctionNames.add(name.last)
}
result
}
Expand Down Expand Up @@ -1753,8 +1784,8 @@ class SessionCatalog(
*/
def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
if (name.database.isEmpty) {
lookupBuiltinOrTempFunction(name.funcName)
.orElse(lookupBuiltinOrTempTableFunction(name.funcName))
lookupBuiltinOrTempFunction(Seq(name.funcName, name.database.orNull))
.orElse(lookupBuiltinOrTempTableFunction(Seq(name.funcName, name.database.orNull)))
.getOrElse(lookupPersistentFunction(name))
} else {
lookupPersistentFunction(name)
Expand All @@ -1765,7 +1796,7 @@ class SessionCatalog(
// function from either v1 or v2 catalog. This method only look up v1 catalog.
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
if (name.database.isEmpty) {
resolveBuiltinOrTempFunction(name.funcName, children)
resolveBuiltinOrTempFunction(Seq(name.funcName, name.database.orNull), children)
.getOrElse(resolvePersistentFunction(name, children))
} else {
resolvePersistentFunction(name, children)
Expand All @@ -1774,7 +1805,7 @@ class SessionCatalog(

def lookupTableFunction(name: FunctionIdentifier, children: Seq[Expression]): LogicalPlan = {
if (name.database.isEmpty) {
resolveBuiltinOrTempTableFunction(name.funcName, children)
resolveBuiltinOrTempTableFunction(Seq(name.funcName, name.database.orNull), children)
.getOrElse(resolvePersistentTableFunction(name, children))
} else {
resolvePersistentTableFunction(name, children)
Expand All @@ -1786,8 +1817,10 @@ class SessionCatalog(
*/
private def listBuiltinAndTempFunctions(pattern: String): Seq[FunctionIdentifier] = {
val functions = (functionRegistry.listFunction() ++ tableFunctionRegistry.listFunction())
.filter(_.database.isEmpty)
StringUtils.filterPattern(functions.map(_.unquotedString), pattern).map { f =>
.filter(funcId => funcId.catalog.isEmpty &&
(funcId.database == Some("builtin") || funcId.database == Some("session")))
StringUtils.filterPattern(functions.map(f => FunctionIdentifier(f.funcName).unquotedString),
pattern).map { f =>
// In functionRegistry, function names are stored as an unquoted format.
Try(parser.parseFunctionIdentifier(f)) match {
case Success(e) => e
Expand Down Expand Up @@ -1820,8 +1853,13 @@ class SessionCatalog(
// The session catalog caches some persistent functions in the FunctionRegistry
// so there can be duplicates.
functions.map {
case f if FunctionRegistry.functionSet.contains(f) => (f, "SYSTEM")
case f if TableFunctionRegistry.functionSet.contains(f) => (f, "SYSTEM")
case f if f.database.isEmpty &&
FunctionRegistry.functionSet.contains(FunctionIdentifier(f.funcName, Some("builtin"))) =>
(f, "SYSTEM")
case f if f.database.isEmpty &&
TableFunctionRegistry.functionSet.contains(FunctionIdentifier(f.funcName,
Some("builtin"))) =>
(f, "SYSTEM")
case f if f.database.isDefined => (qualifyIdentifier(f), "USER")
case f => (f, "USER")
}.distinct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
ctx)
}

def unsupportedFunctionNameError(funcName: Seq[String], ctx: CreateFunctionContext): Throwable = {
def unsupportedFunctionNameError(funcName: Seq[String], ctx: StatementContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX",
messageParameters = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,13 @@ class SparkSqlAstBuilder extends AstBuilder {

if (functionIdentifier.length > 2) {
throw QueryParsingErrors.unsupportedFunctionNameError(functionIdentifier, ctx)
} else if (functionIdentifier.length == 2) {
} else if (functionIdentifier.length == 2 &&
functionIdentifier.head.toLowerCase(Locale.ROOT) != "session") {
// Temporary function names should not contain database prefix like "database.function"
throw QueryParsingErrors.specifyingDBInCreateTempFuncError(functionIdentifier.head, ctx)
}
CreateFunctionCommand(
FunctionIdentifier(functionIdentifier.last),
FunctionIdentifier(functionIdentifier.last, Option("session")),
string(visitStringLit(ctx.className)),
resources.toSeq,
true,
Expand All @@ -594,11 +595,14 @@ class SparkSqlAstBuilder extends AstBuilder {
val functionName = visitMultipartIdentifier(ctx.multipartIdentifier)
val isTemp = ctx.TEMPORARY != null
if (isTemp) {
if (functionName.length > 1) {
if (functionName.length > 2) {
throw QueryParsingErrors.unsupportedFunctionNameError(functionName, ctx)
} else if (functionName.length == 2 &&
functionName.head.toLowerCase(Locale.ROOT) != "session") {
throw QueryParsingErrors.invalidNameForDropTempFunc(functionName, ctx)
}
DropFunctionCommand(
identifier = FunctionIdentifier(functionName.head),
identifier = FunctionIdentifier(functionName.last, Option("session")),
ifExists = ctx.EXISTS != null,
isTemp = true)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,8 @@ case class DropFunctionCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
if (isTemp) {
assert(identifier.database.isEmpty)
if (FunctionRegistry.builtin.functionExists(identifier)) {
throw QueryCompilationErrors.cannotDropBuiltinFuncError(identifier.funcName)
}
catalog.dropTempFunction(identifier.funcName, ifExists)
assert(identifier.database == Some("session"))
catalog.dropTempFunction(identifier, ifExists)
} else {
// We are dropping a permanent function.
catalog.dropFunction(identifier, ignoreIfNotExists = ifExists)
Expand All @@ -162,7 +159,7 @@ case class DropFunctionCommand(
* '|' is for alternation.
* For example, "show functions like 'yea*|windo*'" will return "window" and "year".
*/
case class ShowFunctionsCommand(
case class ShowFunctionsCommand (
db: String,
pattern: Option[String],
showUserFunctions: Boolean,
Expand Down