
[SPARK-8478][SQL] Harmonize UDF-related code to use uniformly UDF instead of Udf #6920

Closed
wants to merge 9 commits
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types.DataType
* User-defined function.
* @param dataType Return type of function.
*/
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression])
extends Expression {

override def nullable: Boolean = true
@@ -957,6 +957,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
override def eval(input: InternalRow): Any = converter(f(input))

// TODO(davies): make ScalaUdf work with codegen
// TODO(davies): make ScalaUDF work with codegen
override def isThreadSafe: Boolean = false
}
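For orientation, here is a minimal sketch of how the renamed ScalaUDF expression gets built. It targets the internal Catalyst API as it stood in this patch, and the closure and column name are illustrative assumptions; applications normally go through functions.udf or sqlContext.udf.register rather than constructing the case class directly.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.types.IntegerType

// Wrap a plain Scala closure as a Catalyst expression over a (hypothetical)
// input column named "value". UserDefinedFunction.apply does exactly this
// for the public udf() helper.
val double: AnyRef = (x: Int) => x * 2
val doubledExpr = ScalaUDF(double, IntegerType, Seq(UnresolvedAttribute("value")))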
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -146,7 +146,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, conf) {
override val extendedResolutionRules =
ExtractPythonUdfs ::
ExtractPythonUDFs ::
sources.PreInsertCastAndRename ::
Nil

@@ -257,7 +257,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* The following example registers a Scala closure as UDF:
* {{{
* sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1)
* sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1)
* }}}
*
* The following example registers a UDF in Java:
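(The Java registration example is truncated in this hunk. For the Scala path, a minimal sketch of registering a closure under the capitalized naming and then invoking it; the function name "strLen", the "people" table, and the DataFrame below are illustrative assumptions, not part of the patch.)

// Register a Scala closure by name.
sqlContext.udf.register("strLen", (s: String) => s.length)

// Call it from SQL ...
sqlContext.sql("SELECT strLen(name) FROM people")

// ... or by name on a DataFrame via callUDF.
import org.apache.spark.sql.functions.callUDF
val people = sqlContext.table("people")
people.select(callUDF("strLen", people("name")))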
96 changes: 48 additions & 48 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Large diffs are not rendered by default.

@@ -23,7 +23,7 @@ import org.apache.spark.Accumulator
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType

@@ -44,7 +44,7 @@ import org.apache.spark.sql.types.DataType
case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) {

def apply(exprs: Column*): Column = {
Column(ScalaUdf(f, dataType, exprs.map(_.expr)))
Column(ScalaUDF(f, dataType, exprs.map(_.expr)))
}
}
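UserDefinedFunction is what functions.udf returns, so every application over Columns now produces the renamed ScalaUDF expression. A brief usage sketch, assuming a SQLContext named sqlContext is in scope (the data and column names are illustrative):

import org.apache.spark.sql.functions.udf
import sqlContext.implicits._

val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")

// udf(...) yields a UserDefinedFunction; applying it to Columns wraps the
// closure in a ScalaUDF expression under the hood.
val squared = udf((v: Int) => v * v)
df.select($"id", squared($"value").as("value_squared")).show()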

@@ -69,7 +69,7 @@ private[spark] case class PythonUDF(
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
* multiple child operators.
*/
private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
34 changes: 17 additions & 17 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1477,7 +1477,7 @@ object functions {
(0 to 10).map { x =>
val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
val fTypes = Seq.fill(x + 1)("_").mkString(", ")
val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ")
println(s"""
/**
* Call a Scala function of ${x} arguments as user-defined function (UDF). This requires
@@ -1489,7 +1489,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
ScalaUdf(f, returnType, Seq($argsInUdf))
ScalaUDF(f, returnType, Seq($argsInUDF))
}""")
}
}
@@ -1627,7 +1627,7 @@ object functions {
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function0[_], returnType: DataType): Column = {
ScalaUdf(f, returnType, Seq())
ScalaUDF(f, returnType, Seq())
}

/**
@@ -1640,7 +1640,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr))
ScalaUDF(f, returnType, Seq(arg1.expr))
}

/**
@@ -1653,7 +1653,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr))
}

/**
@@ -1666,7 +1666,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
}

/**
@@ -1679,7 +1679,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
}

/**
@@ -1692,7 +1692,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
}

/**
@@ -1705,7 +1705,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
}

/**
@@ -1718,7 +1718,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
}

/**
@@ -1731,7 +1731,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
}

/**
@@ -1744,7 +1744,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
}

/**
@@ -1757,7 +1757,7 @@
*/
@deprecated("Use udf", "1.5.0")
def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
}

// scalastyle:on
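Every overload above is deprecated in favor of udf, which infers the return type from the closure. A hedged before/after sketch (the DataFrame df and its "name" column are assumptions, not part of the patch):

import org.apache.spark.sql.functions.{callUDF, udf}
import org.apache.spark.sql.types.StringType

// Deprecated 1.5.0 style: pass the function, its return type, and argument Columns.
val upperOld = callUDF((s: String) => s.toUpperCase, StringType, df("name"))

// Preferred style: udf infers the return type; apply the result to Columns.
val upper = udf((s: String) => s.toUpperCase)
val upperNew = upper(df("name"))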
@@ -1770,8 +1770,8 @@ object functions {
*
* val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
* val sqlContext = df.sqlContext
* sqlContext.udf.register("simpleUdf", (v: Int) => v * v)
* df.select($"id", callUDF("simpleUdf", $"value"))
* sqlContext.udf.register("simpleUDF", (v: Int) => v * v)
* df.select($"id", callUDF("simpleUDF", $"value"))
* }}}
*
* @group udf_funcs
@@ -1789,8 +1789,8 @@ object functions {
*
* val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
* val sqlContext = df.sqlContext
* sqlContext.udf.register("simpleUdf", (v: Int) => v * v)
* df.select($"id", callUdf("simpleUdf", $"value"))
* sqlContext.udf.register("simpleUDF", (v: Int) => v * v)
* df.select($"id", callUdf("simpleUDF", $"value"))
* }}}
*
* @group udf_funcs
@@ -140,9 +140,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {

val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
// we expect the id is materialized once
val idUdf = udf(() => UUID.randomUUID().toString)
val idUDF = udf(() => UUID.randomUUID().toString)

val dfWithId = df.withColumn("id", idUdf())
val dfWithId = df.withColumn("id", idUDF())
// Make a new DataFrame (actually the same reference to the old one)
val cached = dfWithId.cache()
// Trigger the cache
@@ -42,7 +42,7 @@ import org.apache.spark.sql.SQLConf.SQLConfEntry._
import org.apache.spark.sql.catalyst.ParserDialect
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand}
import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
import org.apache.spark.sql.sources.DataSourceStrategy
@@ -381,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.ParquetConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
ExtractPythonUDFs ::
ResolveHiveWindowFunction ::
sources.PreInsertCastAndRename ::
Nil
@@ -1638,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
sys.error(s"Couldn't find function $functionName"))
val functionClassName = functionInfo.getFunctionClass.getName

(HiveGenericUdtf(
(HiveGenericUDTF(
new HiveFunctionWrapper(functionClassName),
children.map(nodeToExpr)), attributes)

@@ -59,16 +59,16 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
val functionClassName = functionInfo.getFunctionClass.getName

if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
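The dispatch above routes each Hive function class to its renamed wrapper (HiveSimpleUDF, HiveGenericUDF, HiveGenericUDAF, HiveUDAF, or HiveGenericUDTF). As a hedged illustration, a function hitting the first branch — any class extending Hive's UDF with an evaluate method — would be wrapped in HiveSimpleUDF; the class, function, and table names below are assumptions:

import org.apache.hadoop.hive.ql.exec.UDF

// A classic "simple" Hive UDF: classOf[UDF].isAssignableFrom(...) holds for it,
// so the registry above wraps it in HiveSimpleUDF.
class ToUpperUDF extends UDF {
  def evaluate(s: String): String = if (s == null) null else s.toUpperCase
}

// Registered and invoked through HiveQL on a HiveContext (illustrative only):
// hiveContext.sql("CREATE TEMPORARY FUNCTION to_upper AS 'ToUpperUDF'")
// hiveContext.sql("SELECT to_upper(name) FROM people")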
@@ -79,7 +79,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
throw new UnsupportedOperationException
}

private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with Logging {

type UDFType = UDF
@@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
override def get(): AnyRef = wrap(func(), oi)
}

private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with Logging {
type UDFType = GenericUDF

@@ -413,7 +413,7 @@ private[hive] case class HiveWindowFunction(
new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
}

private[hive] case class HiveGenericUdaf(
private[hive] case class HiveGenericUDAF(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors {
@@ -441,11 +441,11 @@ private[hive] case class HiveGenericUdaf(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}

def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this)
def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
}

/** It is used as a wrapper for the hive functions which uses UDAF interface */
private[hive] case class HiveUdaf(
private[hive] case class HiveUDAF(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors {
@@ -474,7 +474,7 @@ private[hive] case class HiveUdaf(
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}

def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true)
def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
}

/**
@@ -488,7 +488,7 @@ private[hive] case class HiveUdaf(
* Operators that require maintaining state in between input rows should instead be implemented as
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUdtf(
private[hive] case class HiveGenericUDTF(
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
extends Generator with HiveInspectors {
@@ -553,7 +553,7 @@ private[hive] case class HiveGenericUdtf(
}
}

private[hive] case class HiveUdafFunction(
private[hive] case class HiveUDAFFunction(
funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],
base: AggregateExpression,
@@ -391,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
* Records the UDFs present when the server starts, so we can delete ones that are created by
* tests.
*/
protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames
protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames

/**
* Resets the test instance by deleting any tables that have been created.
@@ -410,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
catalog.client.reset()
catalog.unregisterAllTables()

FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName =>
FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName =>
FunctionRegistry.unregisterTemporaryUDF(udfName)
}
