Skip to content

Commit

Permalink
[SPARK-21848][SQL] Add trait UserDefinedExpression to identify user-d…
Browse files Browse the repository at this point in the history
…efined functions

## What changes were proposed in this pull request?

Add trait UserDefinedExpression to identify user-defined functions.
UDFs can be expensive. In the optimizer, we may need to avoid executing a UDF multiple times.
E.g.
```scala
table.select(UDF as 'a).select('a, ('a + 1) as 'b)
```
If the UDF is expensive in this case, the optimizer should not collapse the two projections into
```scala
table.select(UDF as 'a, (UDF+1) as 'b)
```

Currently, UDF classes such as PythonUDF and HiveGenericUDF are not defined in catalyst.
This PR adds a new trait to make it easier to identify user-defined functions.

## How was this patch tested?

Unit test

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #19064 from gengliangwang/UDFType.
  • Loading branch information
gengliangwang authored and gatorsmile committed Aug 29, 2017
1 parent 32fa0b8 commit 8fcbda9
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression {
}
}
}

/**
 * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
 * and Hive function wrappers.
 *
 * This is a marker trait: it declares no members and only tags an expression as user-defined,
 * so that such expressions can be identified with a type check.
 */
trait UserDefinedExpression
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ case class ScalaUDF(
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression {
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {

override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,11 @@ case class ScalaUDAF(
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
extends ImperativeAggregate
with NonSQLExpression
with Logging
with ImplicitCastInputTypes
with UserDefinedExpression {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.python

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression}
import org.apache.spark.sql.types.DataType

/**
Expand All @@ -29,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
extends Expression with Unevaluable with NonSQLExpression {
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {

override def toString: String = s"$name(${children.mkString(", ")})"

Expand Down
18 changes: 14 additions & 4 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ import org.apache.spark.sql.types._

private[hive] case class HiveSimpleUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
extends Expression
with HiveInspectors
with CodegenFallback
with Logging
with UserDefinedExpression {

override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)

Expand Down Expand Up @@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp

private[hive] case class HiveGenericUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
extends Expression
with HiveInspectors
with CodegenFallback
with Logging
with UserDefinedExpression {

override def nullable: Boolean = true

Expand Down Expand Up @@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF(
name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
extends Generator with HiveInspectors with CodegenFallback {
extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression {

@transient
protected lazy val function: GenericUDTF = {
Expand Down Expand Up @@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
with HiveInspectors
with UserDefinedExpression {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
Expand Down

0 comments on commit 8fcbda9

Please sign in to comment.