[SPARK-41633][SQL] Identify aggregation expressions in the nodePatterns of PythonUDF #39142
gengliangwang wants to merge 4 commits into apache:master from
Conversation
override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"

final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF)
private def nodePatternsOfPythonFunction: Option[TreePattern] = {
Just to think about future-proofing: how do we prevent someone from adding another expression in the future that can be an aggregate expression but doesn't update this location?
Good point. I just did a simple refactoring on the method PythonUDF.isGroupedAggPandasUDFEvalType to check whether an eval type is a pandas aggregation function.
I added a comment for developers as well.
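As a rough sketch of that refactoring (only isGroupedAggPandasUDFEvalType and the SQL_GROUPED_AGG_PANDAS_UDF constant come from the discussion; the object name and the second helper here are illustrative, the real change is in the PR diff), the idea is a single eval-type predicate that both the expression-level check and the node-pattern computation share:

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF}

// Illustrative companion-style helper, not the exact code from the PR.
object GroupedAggPandasUDFCheck {
  // Single source of truth: is this eval type a pandas aggregation function?
  def isGroupedAggPandasUDFEvalType(evalType: Int): Boolean =
    evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

  // The expression-level check reuses the same predicate.
  def isGroupedAggPandasUDF(e: Expression): Boolean = e match {
    case u: PythonUDF => isGroupedAggPandasUDFEvalType(u.evalType)
    case _ => false
  }
}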
I asked @sigmod a similar question while reviewing PRs under https://issues.apache.org/jira/browse/SPARK-35042. The answer is that developers need to maintain the node patterns and create more test cases to make sure no regressions happen.
We can add a test case that goes over all the PythonEvalType values as well. But given that it is written in the following way:
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 100
val SQL_SCALAR_PANDAS_UDF = 200
val SQL_GROUPED_MAP_PANDAS_UDF = 201
val SQL_GROUPED_AGG_PANDAS_UDF = 202
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
I am still thinking about how to refactor it and write tests that enumerate all the types for future-proofing.
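One possible shape for such a test (my own sketch, not from the PR; the "reviewed" set mirrors the constants listed above) is to enumerate the constants via reflection, so that adding a new eval type fails the check until someone decides whether it needs AGGREGATE_EXPRESSION:

import org.apache.spark.api.python.PythonEvalType

// Sketch only. Must live under the org.apache.spark package, since
// PythonEvalType is private[spark].
object EvalTypeCoverageCheck {
  def main(args: Array[String]): Unit = {
    // Collect the NON_UDF / SQL_* Int constants declared on the object.
    val declared = PythonEvalType.getClass.getDeclaredMethods
      .filter { m =>
        m.getParameterCount == 0 &&
          m.getReturnType == java.lang.Integer.TYPE &&
          (m.getName == "NON_UDF" || m.getName.startsWith("SQL_"))
      }
      .map(m => m.getName -> m.invoke(PythonEvalType).asInstanceOf[Int])
      .toMap
    // Eval types whose node patterns have already been reviewed.
    val reviewed = Set(0, 100, 200, 201, 202, 203, 204, 205, 206, 207, 208)
    val unreviewed = declared.filterNot { case (_, v) => reviewed(v) }
    assert(unreviewed.isEmpty,
      s"Decide the node patterns for new eval types: ${unreviewed.keys.mkString(", ")}")
  }
}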
// SPARK-41633: We should check whether to update the node patterns when adding a new eval type.
private def nodePatternsOfPythonFunction: Seq[TreePattern] = {
  if (PythonUDF.isGroupedAggPandasUDFEvalType(evalType)) {
    Seq(AGGREGATE_EXPRESSION)
I think it's better to make the node pattern match the node type. We created node patterns to speed up isInstanceOf checks, not to replace them. Even if we fix the node pattern here, expr.exists(_.isInstanceOf[AggregateExpression]) is still broken.
I think we should have a PythonUDAF that indicates an aggregate function explicitly, instead of relying on the node pattern.
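To illustrate the point (a sketch of my own, not code from the PR): the pattern bit is meant as a cheap pre-filter for the authoritative isInstanceOf check, so the two have to stay consistent:

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION

object PatternVsInstanceOf {
  def containsAggregateExpression(expr: Expression): Boolean = {
    // Fast path: prune trees whose pattern bits say the node cannot be present.
    expr.containsPattern(AGGREGATE_EXPRESSION) &&
      // Slow path: the type check that actually decides the answer.
      expr.exists(_.isInstanceOf[AggregateExpression])
  }
}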
I agree that we should have a subtype for aggregate. That's a much safer change. Should've done that from the beginning IMO.
Due to the comments in #39142 (comment), I am closing this one.
What changes were proposed in this pull request?
For Python UDFs with aggregation functions, we can mark them as AGGREGATE_EXPRESSION in the nodePatterns of PythonUDF, so that we can check whether an expression contains aggregation in a handy way: expr.containsPattern(AGGREGATE_EXPRESSION).
Why are the changes needed?
Addresses a comment in #39134 (comment). This should provide a handy way to check whether an expression contains aggregations.
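For illustration only (a hypothetical call site, not code from the PR), the kind of check this is meant to simplify looks like this before and after:

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION

object ContainsAggExample {
  // Before: a manual traversal that has to remember the grouped-agg pandas UDF case.
  def hasAggManual(expr: Expression): Boolean = expr.exists {
    case _: AggregateExpression => true
    case u: PythonUDF => u.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
    case _ => false
  }

  // After: a single node-pattern check, once the UDF reports AGGREGATE_EXPRESSION.
  def hasAggByPattern(expr: Expression): Boolean =
    expr.containsPattern(AGGREGATE_EXPRESSION)
}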
Does this PR introduce any user-facing change?
No
How was this patch tested?
New UT