Skip to content

Commit

Permalink
[SPARK-14888][SQL] UnresolvedFunction should use FunctionIdentifier
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This patch changes UnresolvedFunction and UnresolvedGenerator to use a FunctionIdentifier rather than just a String for function name. Also changed SessionCatalog to accept FunctionIdentifier in lookupFunction.

## How was this patch tested?
Updated related unit tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #12659 from rxin/SPARK-14888.
  • Loading branch information
rxin committed Apr 25, 2016
1 parent 34336b6 commit f36c9c8
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -838,9 +838,9 @@ class Analyzer(
s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
}
}
case u @ UnresolvedFunction(name, children, isDistinct) =>
case u @ UnresolvedFunction(funcId, children, isDistinct) =>
withPosition(u) {
catalog.lookupFunction(name, children) match {
catalog.lookupFunction(funcId, children) match {
// DISTINCT is not meaningful for a Max or a Min.
case max: Max if isDistinct =>
AggregateExpression(max, Complete, isDistinct = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ object FunctionRegistry {
}

/** See usage above. */
def expression[T <: Expression](name: String)
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

// See if we can find a constructor that accepts Seq[Expression]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
Expand All @@ -30,8 +31,8 @@ import org.apache.spark.sql.types.{DataType, StructType}
* Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully
* resolved.
*/
class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String) extends
errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null)
class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String)
extends TreeNodeException(tree, s"Invalid call to $function on unresolved object", null)

/**
* Holds the name of a relation that has yet to be looked up in a catalog.
Expand Down Expand Up @@ -138,7 +139,8 @@ object UnresolvedAttribute {
* the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator.
* The analyzer will resolve this generator.
*/
case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator {
case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expression])
extends Generator {

override def elementTypes: Seq[(DataType, Boolean, String)] =
throw new UnresolvedException(this, "elementTypes")
Expand All @@ -147,7 +149,7 @@ case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false

override def prettyName: String = name
override def prettyName: String = name.unquotedString
override def toString: String = s"'$name(${children.mkString(", ")})"

override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
Expand All @@ -161,7 +163,7 @@ case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends
}

case class UnresolvedFunction(
name: String,
name: FunctionIdentifier,
children: Seq[Expression],
isDistinct: Boolean)
extends Expression with Unevaluable {
Expand All @@ -171,10 +173,16 @@ case class UnresolvedFunction(
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false

override def prettyName: String = name
override def prettyName: String = name.unquotedString
override def toString: String = s"'$name(${children.mkString(", ")})"
}

object UnresolvedFunction {
  /**
   * Convenience factory that accepts a plain function name. The name is wrapped in a
   * [[FunctionIdentifier]] with no database before the node is constructed.
   */
  def apply(name: String, children: Seq[Expression], isDistinct: Boolean): UnresolvedFunction =
    new UnresolvedFunction(FunctionIdentifier(name, None), children, isDistinct)
}

/**
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ class SessionCatalog(
this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
}

protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
/** List of temporary tables, mapping from table name to their logical plan. */
protected val tempTables = new mutable.HashMap[String, LogicalPlan]

// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
// check whether the temporary table or function exists, then, if not, operate on
// the corresponding item in the current database.
protected[this] var currentDb = {
protected var currentDb = {
val defaultName = "default"
val defaultDbDefinition = CatalogDatabase(defaultName, "default database", "", Map())
// Initialize default database if it doesn't already exist
Expand Down Expand Up @@ -118,7 +119,7 @@ class SessionCatalog(

def setCurrentDatabase(db: String): Unit = {
if (!databaseExists(db)) {
throw new AnalysisException(s"cannot set current database to non-existent '$db'")
throw new AnalysisException(s"Database '$db' does not exist.")
}
currentDb = db
}
Expand Down Expand Up @@ -593,9 +594,6 @@ class SessionCatalog(
/**
* Drop a temporary function.
*/
// TODO: The reason that we distinguish dropFunction and dropTempFunction is that
// Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
// dropFunction and dropTempFunction.
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
throw new AnalysisException(
Expand All @@ -622,40 +620,44 @@ class SessionCatalog(
* based on the function class and put the builder into the FunctionRegistry.
* The name of this function in the FunctionRegistry will be `databaseName.functionName`.
*/
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// TODO: Right now, the name can be qualified or not qualified.
// It will be better to get a FunctionIdentifier.
// TODO: Right now, we assume that name is not qualified!
val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString
if (functionRegistry.functionExists(name)) {
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
// Note: the implementation of this function is a little bit convoluted.
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
// (built-in, temp, and external).
if (name.database.isEmpty && functionRegistry.functionExists(name.funcName)) {
// This function has been already loaded into the function registry.
functionRegistry.lookupFunction(name, children)
} else if (functionRegistry.functionExists(qualifiedName)) {
return functionRegistry.lookupFunction(name.funcName, children)
}

// If the name itself is not qualified, add the current database to it.
val qualifiedName = if (name.database.isEmpty) name.copy(database = Some(currentDb)) else name

if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
// This function has been already loaded into the function registry.
// Unlike the above block, we find this function by using the qualified name.
functionRegistry.lookupFunction(qualifiedName, children)
} else {
// The function has not been loaded to the function registry, which means
// that the function is a permanent function (if it actually has been registered
// in the metastore). We need to first put the function in the FunctionRegistry.
val catalogFunction = try {
externalCatalog.getFunction(currentDb, name)
} catch {
case e: AnalysisException => failFunctionLookup(name)
case e: NoSuchFunctionException => failFunctionLookup(name)
}
loadFunctionResources(catalogFunction.resources)
// Please note that qualifiedName is provided by the user. However,
// catalogFunction.identifier.unquotedString is returned by the underlying
// catalog. So, it is possible that qualifiedName is not exactly the same as
// catalogFunction.identifier.unquotedString (difference is on case-sensitivity).
// At here, we preserve the input from the user.
val info = new ExpressionInfo(catalogFunction.className, qualifiedName)
val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className)
createTempFunction(qualifiedName, info, builder, ignoreIfExists = false)
// Now, we need to create the Expression.
functionRegistry.lookupFunction(qualifiedName, children)
return functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
}

// The function has not been loaded to the function registry, which means
// that the function is a permanent function (if it actually has been registered
// in the metastore). We need to first put the function in the FunctionRegistry.
// Fetch the function from the database named by the identifier; fall back to the current
// database only when the identifier is unqualified. Always passing `currentDb` here would
// resolve an explicitly qualified name such as `otherdb.func` against the wrong database.
val catalogFunction = try {
  externalCatalog.getFunction(name.database.getOrElse(currentDb), name.funcName)
} catch {
  // Both failure types are funneled through failFunctionLookup.
  case _: AnalysisException => failFunctionLookup(name.funcName)
  case _: NoSuchFunctionException => failFunctionLookup(name.funcName)
}
loadFunctionResources(catalogFunction.resources)
// Please note that qualifiedName is provided by the user. However,
// catalogFunction.identifier.unquotedString is returned by the underlying
// catalog. So, it is possible that qualifiedName is not exactly the same as
// catalogFunction.identifier.unquotedString (they may differ in case).
// Here, we preserve the input from the user.
val info = new ExpressionInfo(catalogFunction.className, qualifiedName.unquotedString)
val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className)
createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false)
// Now, we need to create the Expression.
return functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
}

/**
Expand All @@ -671,8 +673,6 @@ class SessionCatalog(
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
.map { f => FunctionIdentifier(f) }
// TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
// So, the returned list may have two entries for the same function.
dbFunctions ++ loadedFunctions
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,17 @@ package org.apache.spark.sql.catalyst
*/
sealed trait IdentifierWithDatabase {
val identifier: String

def database: Option[String]
def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`")
def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier)

/** Backtick-quoted form: `db`.`identifier` when a database is set, otherwise `identifier`. */
def quotedString: String = database match {
  case Some(db) => s"`$db`.`$identifier`"
  case None => s"`$identifier`"
}

/** Plain form: db.identifier when a database is set, otherwise just the identifier. */
def unquotedString: String = database match {
  case Some(db) => s"$db.$identifier"
  case None => identifier
}

override def toString: String = quotedString
}

Expand Down Expand Up @@ -63,6 +71,8 @@ case class FunctionIdentifier(funcName: String, database: Option[String])
override val identifier: String = funcName

def this(funcName: String) = this(funcName, None)

override def toString: String = unquotedString
}

object FunctionIdentifier {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
Expand Down Expand Up @@ -554,7 +554,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
case "json_tuple" =>
JsonTuple(expressions)
case name =>
UnresolvedGenerator(name, expressions)
UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions)
}

Generate(
Expand Down Expand Up @@ -1033,12 +1033,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
val arguments = ctx.expression().asScala.map(expression) match {
case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
// Transform COUNT(*) into COUNT(1). Move this to analysis?
// Transform COUNT(*) into COUNT(1).
Seq(Literal(1))
case expressions =>
expressions
}
val function = UnresolvedFunction(name, arguments, isDistinct)
val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct)

// Check if the function is evaluated in a windowed context.
ctx.windowSpec match {
Expand All @@ -1050,6 +1050,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
}

/**
 * Build a [[FunctionIdentifier]] from a parsed qualified name.
 *
 * A single-part name yields an identifier with no database; a two-part name is read as
 * `database.function`. Any other number of parts is rejected with a [[ParseException]].
 */
protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
  val nameParts = ctx.identifier().asScala.map(_.getText)
  nameParts match {
    case Seq(fn) =>
      FunctionIdentifier(fn, None)
    case Seq(db, fn) =>
      FunctionIdentifier(fn, Option(db))
    case _ =>
      throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx)
  }
}

/**
* Create a reference to a window frame, i.e. [[WindowSpecReference]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,11 +705,11 @@ class SessionCatalogSuite extends SparkFunSuite {
catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false)
catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false)
val arguments = Seq(Literal(1), Literal(2), Literal(3))
assert(catalog.lookupFunction("temp1", arguments) === Literal(1))
assert(catalog.lookupFunction("temp2", arguments) === Literal(3))
assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1))
assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3))
// Temporary function does not exist.
intercept[AnalysisException] {
catalog.lookupFunction("temp3", arguments)
catalog.lookupFunction(FunctionIdentifier("temp3"), arguments)
}
val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
val info3 = new ExpressionInfo("tempFunc3", "temp1")
Expand All @@ -719,7 +719,8 @@ class SessionCatalogSuite extends SparkFunSuite {
}
// Temporary function is overridden
catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true)
assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length))
assert(
catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(arguments.length))
}

test("drop function") {
Expand Down Expand Up @@ -755,10 +756,10 @@ class SessionCatalogSuite extends SparkFunSuite {
val tempFunc = (e: Seq[Expression]) => e.head
catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false)
val arguments = Seq(Literal(1), Literal(2), Literal(3))
assert(catalog.lookupFunction("func1", arguments) === Literal(1))
assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
intercept[AnalysisException] {
catalog.lookupFunction("func1", arguments)
catalog.lookupFunction(FunctionIdentifier("func1"), arguments)
}
intercept[AnalysisException] {
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
Expand Down Expand Up @@ -792,10 +793,11 @@ class SessionCatalogSuite extends SparkFunSuite {
val info1 = new ExpressionInfo("tempFunc1", "func1")
val tempFunc1 = (e: Seq[Expression]) => e.head
catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false)
assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
assert(catalog.lookupFunction(
FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
intercept[AnalysisException] {
catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3)))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
Expand Down Expand Up @@ -199,7 +200,8 @@ class ExpressionParserSuite extends PlanTest {

test("function expressions") {
assertEqual("foo()", 'foo.function())
assertEqual("foo.bar()", Symbol("foo.bar").function())
assertEqual("foo.bar()",
UnresolvedFunction(FunctionIdentifier("bar", Some("foo")), Seq.empty, isDistinct = false))
assertEqual("foo(*)", 'foo.function(star()))
assertEqual("count(*)", 'count.function(1))
assertEqual("foo(a, b)", 'foo.function('a, 'b))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{BooleanType, IntegerType}
import org.apache.spark.sql.types.IntegerType


class PlanParserSuite extends PlanTest {
import CatalystSqlParser._
Expand Down Expand Up @@ -300,7 +302,7 @@ class PlanParserSuite extends PlanTest {
// Unresolved generator.
val expected = table("t")
.generate(
UnresolvedGenerator("posexplode", Seq('x)),
UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)),
join = true,
outer = false,
Some("posexpl"),
Expand Down

0 comments on commit f36c9c8

Please sign in to comment.