Skip to content

Commit

Permalink
[SPARK-43302][SQL][FOLLOWUP] Code cleanup for PythonUDAF
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This is a followup of #40739 that does some code cleanup:
1. remove the pattern `PYTHON_UDAF` as it's not used by any rule.
2. add `PythonFuncExpression.evalType` for convenience: catalyst rules (including third-party extensions) may want to get the eval type of a Python function, regardless of whether it is a UDF or a UDAF.
3. update the Python profiler to use `PythonUDAF.resultId` instead of `AggregateExpression.resultId`, to be consistent with `PythonUDF`

### Why are the changes needed?

code cleanup

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

existing tests

Closes #41142 from cloud-fan/follow.

Lead-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Kent Yao <yao@apache.org>
  • Loading branch information
2 people authored and yaooqinn committed May 17, 2023
1 parent 9cb3174 commit fddf25a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
4 changes: 4 additions & 0 deletions python/pyspark/sql/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def _to_java_column(col: "ColumnOrName") -> JavaObject:
return jcol


# Convert `col` to a Java column via `_to_java_column` and return the result of
# calling `.expr()` on it, i.e. the expression object behind that column.
def _to_java_expr(col: "ColumnOrName") -> JavaObject:
return _to_java_column(col).expr()


def _to_seq(
sc: SparkContext,
cols: Iterable["ColumnOrName"],
Expand Down
12 changes: 7 additions & 5 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pyspark import SparkContext
from pyspark.profiler import Profiler
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
from pyspark.sql.types import (
ArrayType,
BinaryType,
Expand Down Expand Up @@ -419,8 +419,9 @@ def func(*args: Any, **kwargs: Any) -> Any:

func.__signature__ = inspect.signature(f) # type: ignore[attr-defined]
judf = self._create_judf(func)
jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
id = jPythonUDF.expr().resultId().id()
jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
jPythonUDF = judf.fromUDFExpr(jUDFExpr)
id = jUDFExpr.resultId().id()
sc.profiler_collector.add_profiler(id, profiler)
else: # memory_profiler_enabled
f = self.func
Expand All @@ -436,8 +437,9 @@ def func(*args: Any, **kwargs: Any) -> Any:

func.__signature__ = inspect.signature(f) # type: ignore[attr-defined]
judf = self._create_judf(func)
jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
id = jPythonUDF.expr().resultId().id()
jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
jPythonUDF = judf.fromUDFExpr(jUDFExpr)
id = jUDFExpr.resultId().id()
sc.profiler_collector.add_profiler(id, memory_profiler)
else:
judf = self._judf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDAF, PYTHON_UDF, TreePattern}
import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{DataType, StructType}
Expand Down Expand Up @@ -53,9 +53,12 @@ object PythonUDF {
trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression { self: Expression =>
def name: String
def func: PythonFunction
def evalType: Int
def udfDeterministic: Boolean
def resultId: ExprId

final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)

override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"
Expand All @@ -80,8 +83,6 @@ case class PythonUDF(
lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)(
exprId = resultId)

final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)

override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
// `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
Expand Down Expand Up @@ -119,6 +120,8 @@ case class PythonUDAF(
resultId: ExprId = NamedExpression.newExprId)
extends UnevaluableAggregateFunc with PythonFuncExpression {

override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

override def sql(isDistinct: Boolean): String = {
val distinct = if (isDistinct) "DISTINCT " else ""
s"$name($distinct${children.mkString(", ")})"
Expand All @@ -129,8 +132,6 @@ case class PythonUDAF(
name + children.mkString(start, ", ", ")") + s"#${resultId.id}$typeSuffix"
}

final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDAF)

override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
// `resultId` can be seen as cosmetic variation in PythonUDAF, as it doesn't affect the result.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ object TreePattern extends Enumeration {
val PARAMETERIZED_QUERY: Value = Value
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDAF: Value = Value
val PYTHON_UDF: Value = Value
val REGEXP_EXTRACT_FAMILY: Value = Value
val REGEXP_REPLACE: Value = Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,16 @@ case class UserDefinedPythonFunction(

/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
def apply(exprs: Column*): Column = {
builder(exprs.map(_.expr)) match {
fromUDFExpr(builder(exprs.map(_.expr)))
}

/**
* Returns a [[Column]] wrapping an already-built UDF expression.
*
* A [[PythonUDAF]] is first converted with `toAggregateExpression()` before being wrapped;
* any other expression is passed to `Column(...)` unchanged.
*
* @param expr the expression to wrap
* @return the [[Column]] built from `expr`
*/
def fromUDFExpr(expr: Expression): Column = {
expr match {
case udaf: PythonUDAF => Column(udaf.toAggregateExpression())
case udf => Column(udf)
case _ => Column(expr)
}
}
}

0 comments on commit fddf25a

Please sign in to comment.