[SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) #19872

Closed
wants to merge 35 commits into from

icexelloss
Contributor

@icexelloss icexelloss commented Dec 4, 2017

What changes were proposed in this pull request?

Add support for using pandas UDFs with groupby().agg().

This PR introduces a new type of pandas UDF - group aggregate pandas UDF. This type of UDF defines a transformation of multiple pandas Series -> a scalar value. Group aggregate pandas UDFs can be used with groupby().agg(). Note group aggregate pandas UDF doesn't support partial aggregation, i.e., a full shuffle is required.
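
To make the semantics concrete, here is a minimal pure-Python sketch (not Spark code) of what a group aggregate UDF computes: all rows of a group are gathered first, then the function runs once per group on whole columns and returns a scalar. The helper `group_agg`, the `weighted_mean` function, and the sample rows are hypothetical; plain lists stand in for `pandas.Series`.

```python
from collections import defaultdict


def group_agg(rows, key, udf, *cols):
    """Emulate groupby(key).agg(udf(*cols)): collect every row of a group,
    then call udf once per group with one list per input column."""
    groups = defaultdict(lambda: [[] for _ in cols])
    for row in rows:
        for i, col in enumerate(cols):
            groups[row[key]][i].append(row[col])
    return {k: udf(*groups[k]) for k in sorted(groups)}


def weighted_mean(v, w):
    # Needs the whole group at once: several columns in, one scalar out.
    return sum(x * y for x, y in zip(v, w)) / sum(w)


rows = [
    {"id": 1, "v": 1.0, "w": 1.0},
    {"id": 1, "v": 3.0, "w": 3.0},
    {"id": 2, "v": 5.0, "w": 1.0},
]
result = group_agg(rows, "id", weighted_mean, "v", "w")
```

Because the function is opaque to the engine and only defined over a complete group, it cannot be split into partial and final steps, which is why the description above notes that a full shuffle is required.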

This PR doesn't support group aggregate pandas UDFs that return ArrayType, StructType or MapType. Support for these types is left for a future PR.

How was this patch tested?

GroupbyAggPandasUDFTests

@icexelloss
Contributor Author

cc @HyukjinKwon @holdenk @ueshin

Passing some basic tests. I will work on this more next week to clean up and add more testing.

@SparkQA

SparkQA commented Dec 4, 2017

Test build #84414 has finished for PR 19872 at commit 4cfaf0e.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -113,6 +113,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
def apply(plan: SparkPlan): SparkPlan = plan transformUp {
// FlatMapGroupsInPandas can be evaluated directly in python worker
// Therefore we don't need to extract the UDFs
Member

FlatMapGroupsInPandas and AggregateInPandasExec can be...

Contributor Author

Added

@SparkQA

SparkQA commented Dec 4, 2017

Test build #84415 has finished for PR 19872 at commit a1058b8.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

jdf = self._jgd.aggInPandas(
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs]))
else:
jdf = self._jgd.agg(exprs[0]._jc,
Member

What if exprs[n] (n > 0) is a UDFColumn? I think we should make sure that if any column is a UDFColumn, all columns are UDFColumns.

Contributor Author

This code is removed.

jdf = self._jgd.agg(exprs[0]._jc,
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
if isinstance(exprs[0], UDFColumn):
assert all(isinstance(c, UDFColumn) for c in exprs)
Member

An informative error message would be better.

Member

Like "all exprs should be UDFColumn".
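
The suggestion above could look roughly like the following sketch. It is not the PR's code (the `UDFColumn` approach was dropped later in this review); `Column` and `UDFColumn` here are bare stand-ins, and `check_agg_exprs` and its message wording are hypothetical.

```python
class Column(object):
    """Stand-in for pyspark.sql.Column (assumption, illustration only)."""

    def __init__(self, name):
        self.name = name


class UDFColumn(Column):
    """Stand-in for the UDFColumn subclass discussed in this review."""


def check_agg_exprs(exprs):
    """Return True if every expr is a UDFColumn and False if none is.

    Raise a descriptive ValueError on a mix, instead of a bare assert.
    """
    flags = [isinstance(c, UDFColumn) for c in exprs]
    if any(flags) and not all(flags):
        others = [c.name for c, is_udf in zip(exprs, flags) if not is_udf]
        raise ValueError(
            "all exprs should be UDFColumn when any of them is; "
            "non-UDF columns: %s" % ", ".join(others))
    return any(flags)
```

Unlike `assert`, an explicit exception still fires when Python runs with `-O` and tells the user which columns caused the problem.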

Contributor

So I'm a little worried about this change. If other folks have wrapped Java UDAFs (which is reasonable, since there weren't other ways to make UDAFs in PySpark before this), it seems like they won't be able to mix them. I'd suggest maybe doing what @viirya suggested below, but with a warning instead of a failure until Spark 3.

What do y'all think?

Contributor Author

I am still trying to figure out the best way to dispatch this, but either way I think we won't be able to mix Java UDAFs with pandas UDFs.

@holdenk I am not sure what kind of warning message you have in mind. Can you please explain?

Contributor

Ah, so what you're saying is you don't support mixed Python & Java UDAFs? That's certainly something which needs to be communicated in both the documentation and the error message.

Is there a reason why we don't support this?

Contributor Author

Answered in #19872 (comment)

Contributor

@holdenk holdenk left a comment

Thanks for working on this. I'm off for a flight to Strata but a few quick questions. I'll read this more over the coming week :)

@@ -56,6 +56,10 @@ def _create_udf(f, returnType, evalType):
return udf_obj._wrapped()


class UDFColumn(Column):
Contributor

Why did we add this new sub-class?

@@ -2070,6 +2070,8 @@ class PandasUDFType(object):

GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF

GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
Contributor

So I'm worried that it isn't clear to the user that this will result in a full-shuffle with no-partial aggregation. Is there maybe a place we can document this warning?

Contributor Author

Added in docstring of pandas_udf and groupby().agg()

Member

@HyukjinKwon HyukjinKwon left a comment

I thought @ueshin is working on this BTW.


val argOffsets = inputs.map { input =>
input.map { e =>
allInputs += e
Member

indentation nit

Contributor Author

Fixed. Thanks!

functionExprs: Seq[Expression],
output: Seq[Attribute],
child: LogicalPlan
) extends UnaryNode {
Member

nit:

    child: LogicalPlan) extends UnaryNode {

Contributor Author

Removed

@@ -56,6 +56,10 @@ def _create_udf(f, returnType, evalType):
return udf_obj._wrapped()


class UDFColumn(Column):
Member

@HyukjinKwon HyukjinKwon Dec 4, 2017

BTW, what do you think about adding an attribute instead in __call__ like a flag?

Contributor Author

Removed

val childrenExpressions = exprs.flatMap(expr =>
expr.children.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
Member

indentation nit

Contributor Author

Removed


val udfOutputs = exprs.flatMap(expr =>
Seq(AttributeReference(expr.name, expr.dataType)())
)
Member

I think this could be inlined.

Contributor Author

Removed

class GroupbyAggTests(ReusedSQLTestCase):
def assertFramesEqual(self, expected, result):
msg = ("DataFrames are not equal: " +
("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
Member

indentation nit

Contributor Author

Fixed.

@icexelloss
Contributor Author

I thought @ueshin is working on this BTW.

Oh, I certainly don't want to duplicate @ueshin 's work. I am under the impression that @ueshin is working on two-stage PySpark UDAF with pandas_udf, but I cannot really find the Jira for it...

@ueshin can you point me to what you are working on so I don't overstep?

@SparkQA

SparkQA commented Dec 4, 2017

Test build #84446 has finished for PR 19872 at commit c1dc543.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class UDFColumn(Column):
  • case class AggregateInPandas(
  • case class AggregateInPandasExec(

@SparkQA

SparkQA commented Dec 8, 2017

Test build #84628 has finished for PR 19872 at commit 3352050.

  • This patch fails Python style tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@icexelloss
Contributor Author

icexelloss commented Dec 8, 2017

I ended up removing the UDFColumn class and using the existing Aggregate logical plan for pandas group_agg UDFs. I also moved the dispatch logic to SparkStrategy. This reuses a lot of code belonging to the existing Aggregate and minimizes the code changes needed for pandas group_agg UDFs.

The code works and three tests (test_basic, test_alias, test_multiple) pass now, but the code is kind of messy. I am going on vacation next week but I will clean up the code and move this PR forward when I get back (Dec 16).

Thanks all.

@icexelloss
Contributor Author

And to @holdenk's question: a pandas group_agg UDF fundamentally uses a different physical plan than the existing Java/Scala UDFs, so it's hard to combine them. I don't know a good way to do this; the closest is maybe to compute the Java/Scala and Python aggregations separately and join them together.
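
The "compute separately and join" idea can be sketched in plain Python as below. This only illustrates the shape of the workaround, not how Spark would plan it; `aggregate_by_key` and the sample data are made up, with `max` standing in for a built-in (JVM) aggregate and `statistics.median` for a pandas UDF.

```python
from collections import defaultdict
from statistics import median


def aggregate_by_key(rows, key, col, fn):
    """Group rows by key and apply fn to the list of col values per group."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row[col])
    return {k: fn(vs) for k, vs in groups.items()}


rows = [
    {"id": 1, "v": 1.0},
    {"id": 1, "v": 2.0},
    {"id": 1, "v": 6.0},
    {"id": 2, "v": 4.0},
]

# Two independent aggregations over the same grouping key...
jvm_side = aggregate_by_key(rows, "id", "v", max)
python_side = aggregate_by_key(rows, "id", "v", median)

# ...joined back together on that key.
joined = {k: (jvm_side[k], python_side[k]) for k in sorted(jvm_side)}
```

The trade-off is that the data is aggregated twice and the results then need a join on the grouping key.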

@SparkQA

SparkQA commented Dec 8, 2017

Test build #84630 has finished for PR 19872 at commit 184b37f.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 8, 2017

Test build #84631 has finished for PR 19872 at commit 4332f28.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 8, 2017

Test build #84632 has finished for PR 19872 at commit 37eff29.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@ueshin
Member

ueshin commented Dec 8, 2017

@icexelloss I'm sorry for the late response.
Actually I tried to implement prototypes of Pandas UDAF with partial aggregation and combining existing aggregate functions, but they are still quite complicated (ueshin#2, ueshin#3, ueshin#4). I was thinking about an easier way to achieve that, but I haven't found one yet.
I've not looked into this pr yet but I guess we can start this pr and pick some functionalities from my prototypes if needed.

@@ -32,7 +31,5 @@ case class PythonUDF(
evalType: Int)
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {

-override def toString: String = s"$name(${children.mkString(", ")})"
Member

Why was this removed?

Contributor Author

Whoops, my bad, adding back

Contributor Author

Added back

@@ -4016,6 +4016,124 @@ def test_unsupported_types(self):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.groupby('id').apply(f).collect()

@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyAggTests(ReusedSQLTestCase):
def assertFramesEqual(self, expected, result):
Member

nit: how about making this the common method?

val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)

columnarBatchIter.map(_.rowIterator.next()).map{ outputRow =>
Member

nit: columnarBatchIter.flatMap(_.rowIterator)?
nit: style, add a space between map and { outputRow =>.

Contributor Author

columnarBatchIter.flatMap(_.rowIterator)

This doesn't work because rowIterator is a Java iterator, not a Scala iterator. We can convert it, but I am not sure that's better. @ueshin if you prefer the flatMap one I can change it.

Member

Sorry, I meant columnarBatchIter.flatMap(_.rowIterator.asScala). I'd prefer this one.
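
Note that the two forms are not equivalent when a batch holds more than one row: `map(_.rowIterator.next())` takes only the first row of each batch, while `flatMap(_.rowIterator.asScala)` yields every row. A plain-Python analogue of that difference (lists stand in for columnar batches; the data is made up):

```python
from itertools import chain

# Each inner list stands in for one columnar batch of rows.
batches = [["row0", "row1"], ["row2"]]

# Analogue of map(_.rowIterator.next()): first row of each batch only.
first_row_per_batch = [next(iter(batch)) for batch in batches]

# Analogue of flatMap(_.rowIterator.asScala): every row of every batch.
all_rows = list(chain.from_iterable(batches))
```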

@@ -48,9 +48,26 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
}.isDefined
}

private def isPandasGroupAggUdf(expr: Expression): Boolean = expr match {
case _ @ PythonUDF(_, _, _, _, PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF ) => true
Member

We don't need _ @ here.
nit: remove extra space after SQL_PANDAS_GROUP_AGG_UDF.

Contributor Author

Fixed.

if (hasPandasGroupAggUdf(agg)) {
Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child)
} else {

Member

nit: style, we need indent for this block.

Contributor Author

Fixed.

@@ -15,10 +15,9 @@
* limitations under the License.
*/

-package org.apache.spark.sql.execution.python
+package org.apache.spark.sql.catalyst.expressions
Member

Do we need to move package to catalyst?

Contributor Author

We do. This is similar to https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

The reason is we need to access the class PythonUDF in analyzer.

Member

I see, thanks!

@SparkQA

SparkQA commented Dec 19, 2017

Test build #85136 has finished for PR 19872 at commit ab91314.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 20, 2017

Test build #85137 has finished for PR 19872 at commit 1a197b7.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 20, 2017

Test build #85138 has finished for PR 19872 at commit 62c8f00.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@ueshin
Member

ueshin commented Dec 20, 2017

@ramacode2014 Hi, I'm not sure why you received notifications from this PR, but I guess you can unsubscribe by the "Unsubscribe" button in the right column of this page. Sorry for the inconvenience. Thanks!

alias.toAttribute

if (hasPandasGroupAggUdf(agg)) {
Aggregate(agg.groupingExpressions, agg.aggregateExpressions, agg.child)
Member

Do we need to copy?

Contributor Author

I am not sure. But I added copy in ExtractGroupAggPandasUDFFromAggregate similar to existing rules.

}

private def hasPandasGroupAggUdf(agg: Aggregate): Boolean = {
val actualAggExpr = agg.aggregateExpressions.drop(agg.groupingExpressions.length)
Member

Do we need to drop the grouping expressions?
If we do, we should drop them only if conf.dataFrameRetainGroupColumns == true; otherwise aggregateExpressions doesn't contain groupingExpressions, right?

Contributor Author

This is fixed. Added test_retain_grouping_columns test
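
A rough sketch of the point raised in this thread, in plain Python with strings standing in for expressions. The function name and the `retain_group_columns` flag are hypothetical stand-ins for the config the reviewer mentions; this is not the fix that went into the PR.

```python
def non_grouping_agg_exprs(grouping_exprs, aggregate_exprs, retain_group_columns):
    """Return the aggregate expressions without the leading grouping columns.

    Assumption: the grouping expressions are prepended to the aggregate
    expressions only when group columns are retained, so they should be
    dropped only in that case.
    """
    if retain_group_columns:
        return list(aggregate_exprs[len(grouping_exprs):])
    return list(aggregate_exprs)


with_retain = non_grouping_agg_exprs(["id"], ["id", "mean_udf(v)"], True)
without_retain = non_grouping_agg_exprs(["id"], ["mean_udf(v)"], False)
```

Dropping unconditionally would discard a real aggregate expression in the second case, which is the situation the reviewer's question points at.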

val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]

allInputs.appendAll(groupingExpressions)
Member

I guess we don't need to append groupingExpressions. Seems like they are dropped later.

Contributor Author

This is fixed.

.compute(projectedRowIter, context.partitionId(), context)

val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)
Member

We need to handle resultExpressions for the following cases:

    def test_result_expressions(self):
        import numpy as np
        from pyspark.sql.functions import lit, mean, pandas_udf, PandasUDFType

        df = self.data

        @pandas_udf('double', PandasUDFType.GROUP_AGG)
        def mean_udf(v, w):
            return np.average(v, weights=w)

        result1 = (df.groupby('id')
                   .agg(mean_udf(df.v, lit(1.0)) + 1)
                   .sort('id')
                   .toPandas())

        expected1 = (df.groupby('id')
                     .agg(mean(df.v) + 1)
                     .sort('id')
                     .toPandas())

        self.assertPandasEqual(expected1, result1)

Contributor Author

Thanks @ueshin for reminding me of this. Just want to clarify the semantics:

Does

 .agg(mean(df.v) + 1)

mean "compute the mean of df.v and add one to it as the output", i.e., the same as

.agg(mean(df.v).alias('mean'))
.withColumn('mean', col('mean') + 1)

?

Member

Yes, I think that's the behavior. I guess the plan could be different, though.
We can compare with non-UDF aggregation and follow its behavior.

Contributor Author

@icexelloss icexelloss Dec 27, 2017

I added ExtractGroupAggPandasUDFFromAggregate rule to deal with this
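
As a rough illustration of the semantics agreed on in this thread (aggregate first, then evaluate the surrounding expression on the aggregated value), here is a plain-Python sketch. It is not the `ExtractGroupAggPandasUDFFromAggregate` rule itself; `agg_then_project`, `mean`, and the data are hypothetical.

```python
from collections import defaultdict


def agg_then_project(rows, key, col, agg_fn, result_fn):
    """Step 1: aggregate col per group. Step 2: apply result_fn to each result."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[key]].append(row[col])
    aggregated = {k: agg_fn(vs) for k, vs in groups.items()}
    return {k: result_fn(aggregated[k]) for k in sorted(aggregated)}


def mean(values):
    return sum(values) / len(values)


rows = [{"id": 1, "v": 1.0}, {"id": 1, "v": 3.0}, {"id": 2, "v": 5.0}]

# Analogue of .agg(mean(df.v) + 1): the "+ 1" is applied after aggregation.
result = agg_then_project(rows, "id", "v", mean, lambda m: m + 1)
```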

actualAggExpr.exists(isPandasGroupAggUdf)
}


Member

nit: remove an extra line.

Contributor Author

Removed

@SparkQA

SparkQA commented Dec 20, 2017

Test build #85152 has finished for PR 19872 at commit ea5d6f3.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@icexelloss
Contributor Author

@ueshin I pushed some more changes to address your comments. There is one regression in the existing test SQLTests.test_udf_with_aggregate_function. I will try to fix it tomorrow.

@SparkQA

SparkQA commented Dec 28, 2017

Test build #85442 has finished for PR 19872 at commit 99367a6.

  • This patch fails PySpark unit tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 28, 2017

Test build #85446 has finished for PR 19872 at commit 66a31f9.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 18, 2018

Test build #86345 has finished for PR 19872 at commit 17fad5c.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 18, 2018

Test build #86344 has finished for PR 19872 at commit a94b146.

  • This patch fails to build.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 18, 2018

Test build #86346 has finished for PR 19872 at commit 0fec5cf.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 18, 2018

Test build #86350 has finished for PR 19872 at commit 4d22107.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@icexelloss
Contributor Author

@ueshin I think all comments are addressed. Can you take a final look? Thanks!

Member

@ueshin ueshin left a comment

We also need to add PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF to udf.py#L40-L41 to pass require_minimum_pyarrow_version().

LGTM except for the comments.

Btw, I'm afraid we shouldn't merge this into branch-2.3 since we are already close to the 2.3 release.
WDYT? @HyukjinKwon @cloud-fan

3. GROUP_AGG

A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
The `returnType` should be a primitive data type, e.g, :class:`DoubleType`.
Member

very small nit: e.g. instead of e.g,

Contributor Author

Fixed. Thanks!

@HyukjinKwon
Member

+1 for master-only. We can cherry-pick and backport later if needed, even after this gets merged. As a reminder, we should complete the doc #19575 too.

@icexelloss
Contributor Author

Addressed latest comments. Yeah I think master only is fine.

@SparkQA

SparkQA commented Jan 22, 2018

Test build #86487 has finished for PR 19872 at commit 91885e5.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 22, 2018

Test build #86492 has finished for PR 19872 at commit cc659bc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@ueshin
Member

ueshin commented Jan 23, 2018

Thanks! merging to master.

@asfgit asfgit closed this in b2ce17b Jan 23, 2018
@icexelloss
Contributor Author

Thanks all for review!

@icexelloss icexelloss deleted the SPARK-22274-groupby-agg branch January 26, 2018 21:05
@icexelloss icexelloss restored the SPARK-22274-groupby-agg branch January 26, 2018 21:05
@@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
object PhysicalAggregation {
// groupingExpressions, aggregateExpressions, resultExpressions, child
type ReturnType =
-(Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+(Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
Contributor

@icexelloss Thank you for this contribution! I just came across the change in this file. I am not sure if changing the type here is the best option. The reason is that whenever we use this PhysicalAggregation rule, we have to check the instance type of those aggregate expressions and cast them. To me, it seems better to leave this rule untouched and create a new rule just for Python UDAF. What do you think?

(maybe you and reviewers already discussed it. If so, can you point me to the discussion?)

Thank you!

Contributor Author

Hi @yhuai,

You bring up a good point. I agree that ideally we should avoid doing this. When I was making the change, I found that the implemented solution results in the least duplicated code, because a lot of logic is shared between AggregateExpression and Python UDF, but the downside is exactly what you mentioned.

One alternative is to create new rules for Python UDAF; my concern is that this could result in quite a bit of code duplication. Maybe there is a way to avoid code duplication and keep the type safety; I am happy to explore the option (maybe create a parent class for AggregateExpression and Python UDAF?).

Contributor

I prefer that we try out using a new rule. We can create utility function to reuse code. Will you have a chance to try it out?

Contributor Author

@yhuai Yeah I can certainly try it out. Created https://issues.apache.org/jira/browse/SPARK-23302 to track.

I assume this is not urgent?

Contributor

It will be good to try it out soon. But it is not urgent.

from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
Contributor

@icexelloss This line does not compile (we need assertRaisesRegexp). Can you file a PR to fix it? Thanks! Meanwhile, we will look into the Jenkins setup and see why the test was not exercised.

Member

I'll file the follow-up pr to fix it soon.

Member

I filed #20467. Thanks.

Member

@yhuai, if you meant not running tests in Python 2, this link might be helpful. Let me leave it just in case - #19884 (comment).

Contributor Author

@ueshin Thanks for fixing this. (I am late to the party)

ghost pushed a commit to dbtsai/spark that referenced this pull request Feb 1, 2018
… of `assertRaisesRegex`.

## What changes were proposed in this pull request?

This is a follow-up PR to apache#19872, which uses `assertRaisesRegex`. That method doesn't exist in Python 2, so some tests fail when run in a Python 2 environment.
Unfortunately, we missed this because the PR builder's Python 2 environment currently doesn't have proper versions of pandas or pyarrow, so the tests were skipped.

This PR changes the tests to use `assertRaisesRegexp` instead of `assertRaisesRegex`.
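
For context, one way to write a test that works under either method name is to look the method up at call time. This is only a sketch of that alternative, not what the follow-up PR does; `CompatTestCase`, `assert_raises_regex`, and `ExampleTest` are hypothetical names.

```python
import unittest


class CompatTestCase(unittest.TestCase):
    def assert_raises_regex(self, exc, pattern):
        """Use assertRaisesRegex when available, else fall back to the
        older assertRaisesRegexp spelling (the only one in Python 2)."""
        method = (getattr(self, "assertRaisesRegex", None)
                  or getattr(self, "assertRaisesRegexp"))
        return method(exc, pattern)


class ExampleTest(CompatTestCase):
    def test_unsupported(self):
        with self.assert_raises_regex(NotImplementedError, "not supported"):
            raise NotImplementedError("return type is not supported")


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
```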

## How was this patch tested?

Tested manually in my local environment.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes apache#20467 from ueshin/issues/SPARK-22274/fup1.