
[SPARK-23942][PYTHON][SQL][BRANCH-2.3] Makes collect in PySpark as action for a query executor listener

## What changes were proposed in this pull request?

This PR proposes to make `collect` in PySpark an action that is reported to the query execution listener.

Currently, `collect` in PySpark (with or without Arrow) is not recognised as an action by `QueryExecutionListener`. For example, suppose we have a custom listener as below:

```scala
package org.apache.spark.sql

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

class TestQueryExecutionListener extends QueryExecutionListener with Logging {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    logError("Look at me! I'm 'onSuccess'")
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}
```
and set `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener`.
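Since `spark.sql.queryExecutionListeners` is a static SQL configuration, it has to be supplied when the session is created, for example via `spark-submit --conf` or the session builder. A minimal sketch of the latter, purely for reference (the master and app name below are placeholders, not part of this change):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")            // placeholder master, for illustration only
  .appName("listener-example")   // placeholder app name
  .config(
    "spark.sql.queryExecutionListeners",
    "org.apache.spark.sql.TestQueryExecutionListener")
  .getOrCreate()
```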

Other operations on the PySpark and Scala side work fine:

```python
>>> sql("SELECT * FROM range(1)").show()
```
```
18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
+---+
| id|
+---+
|  0|
+---+
```

```scala
scala> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
res1: Array[org.apache.spark.sql.Row] = Array([0])
```

but `collect` and `toPandas` in PySpark do not trigger the listener:

**Before**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
   id
0   0
```

**After**

```python
>>> sql("SELECT * FROM range(1)").collect()
```
```
18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
[Row(id=0)]
```

```python
>>> spark.conf.set("spark.sql.execution.arrow.enabled", "true")
>>> sql("SELECT * FROM range(1)").toPandas()
```
```
18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess'
   id
0   0
```
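As a side note (not part of this patch), a `QueryExecutionListener` can also be attached to an existing session at runtime through `listenerManager`; with this change, such a listener is notified for the PySpark `collect`/`toPandas` paths as well. A rough Scala sketch, assuming the `TestQueryExecutionListener` defined above is on the classpath:

```scala
import org.apache.spark.sql.{SparkSession, TestQueryExecutionListener}

val spark = SparkSession.builder().getOrCreate()

// Unlike the static 'spark.sql.queryExecutionListeners' configuration, this
// registration can happen after the session has been created.
spark.listenerManager.register(new TestQueryExecutionListener)

// The listener's onSuccess/onFailure callbacks now fire for actions on this
// session, including collect() here and, after this change, collect()/toPandas()
// issued from PySpark.
spark.sql("SELECT * FROM range(1)").collect()
```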

## How was this patch tested?

Manually tested as described above, and a unit test was added.

Author: hyukjinkwon <gurwls223@apache.org>

Closes #21060 from HyukjinKwon/PR_TOOL_PICK_PR_21007_BRANCH-2.3.
HyukjinKwon committed Apr 14, 2018
1 parent dfdf1bb commit d4f204c5321cdc3955a48e9717ba06aaebbc2ab4
@@ -185,22 +185,12 @@ def __init__(self, key, value):
        self.value = value


class ReusedSQLTestCase(ReusedPySparkTestCase):
    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        cls.spark.stop()

    def assertPandasEqual(self, expected, result):
        msg = ("DataFrames are not equal: " +
               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
        self.assertTrue(expected.equals(result), msg=msg)
class SQLTestUtils(object):
    """
    This util assumes the instance of this to have 'spark' attribute, having a spark session.
    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
    the implementation of this class has 'spark' attribute.
    """

    @contextmanager
    def sql_conf(self, pairs):
@@ -209,6 +199,7 @@ def sql_conf(self, pairs):
        `value` to the configuration `key` and then restores it back when it exits.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        keys = pairs.keys()
        new_values = pairs.values()
@@ -225,6 +216,24 @@ def sql_conf(self, pairs):
                    self.spark.conf.set(key, old_value)


class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        cls.spark.stop()

    def assertPandasEqual(self, expected, result):
        msg = ("DataFrames are not equal: " +
               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
        self.assertTrue(expected.equals(result), msg=msg)


class DataTypeTests(unittest.TestCase):
    # regression test for SPARK-6055
    def test_data_type_eq(self):
@@ -2980,6 +2989,64 @@ def test_sparksession_with_stopped_sparkcontext(self):
        sc.stop()


class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
    # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
    # static and immutable. This can't be set or unset, for example, via `spark.conf`.

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
            raise unittest.SkipTest(
                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
                "available. Will skip the related tests.")

        # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
        cls.spark = SparkSession.builder \
            .master("local[4]") \
            .appName(cls.__name__) \
            .config(
                "spark.sql.queryExecutionListeners",
                "org.apache.spark.sql.TestQueryExecutionListener") \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def tearDown(self):
        self.spark._jvm.OnSuccessCall.clear()

    def test_query_execution_listener_on_collect(self):
        self.assertFalse(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should not be called before 'collect'")
        self.spark.sql("SELECT * FROM range(1)").collect()
        self.assertTrue(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should be called after 'collect'")

    @unittest.skipIf(
        not _have_pandas or not _have_pyarrow,
        _pandas_requirement_message or _pyarrow_requirement_message)
    def test_query_execution_listener_on_collect_with_arrow(self):
        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
            self.assertFalse(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should not be "
                "called before 'toPandas'")
            self.spark.sql("SELECT * FROM range(1)").toPandas()
            self.assertTrue(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should be called after 'toPandas'")


class UDFInitializationTests(unittest.TestCase):
    def tearDown(self):
        if SparkSession._instantiatedSession is not None:
@@ -3189,10 +3189,10 @@ class Dataset[T] private[sql](

  private[sql] def collectToPython(): Int = {
    EvaluatePython.registerPicklers()
    withNewExecutionId {
    withAction("collectToPython", queryExecution) { plan =>
      val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
      val iter = new SerDeUtil.AutoBatchedPickler(
        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
        plan.executeCollect().iterator.map(toJava))
      PythonRDD.serveIterator(iter, "serve-DataFrame")
    }
  }
@@ -3201,8 +3201,9 @@ class Dataset[T] private[sql](
   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
   */
  private[sql] def collectAsArrowToPython(): Int = {
    withNewExecutionId {
      val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
    withAction("collectAsArrowToPython", queryExecution) { plan =>
      val iter: Iterator[Array[Byte]] =
        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
      PythonRDD.serveIterator(iter, "serve-Arrow")
    }
  }
@@ -3311,14 +3312,19 @@ class Dataset[T] private[sql](
  }

  /** Convert to an RDD of ArrowPayload byte arrays */
  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
    val schemaCaptured = this.schema
    val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    queryExecution.toRdd.mapPartitionsInternal { iter =>
    plan.execute().mapPartitionsInternal { iter =>
      val context = TaskContext.get()
      ArrowConverters.toPayloadIterator(
        iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
    }
  }

  // This is only used in tests, for now.
  private[sql] def toArrowPayload: RDD[ArrowPayload] = {
    toArrowPayload(queryExecution.executedPlan)
  }
}
@@ -0,0 +1,44 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener


class TestQueryExecutionListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    OnSuccessCall.isOnSuccessCalled.set(true)
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}

/**
 * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for
 * the test case in PySpark. See SPARK-23942.
 */
object OnSuccessCall {
  val isOnSuccessCalled = new AtomicBoolean(false)

  def isCalled(): Boolean = isOnSuccessCalled.get()

  def clear(): Unit = isOnSuccessCalled.set(false)
}
