87 changes: 77 additions & 10 deletions python/pyspark/sql/tests.py
@@ -186,16 +186,12 @@ def __init__(self, key, value):
self.value = value


class ReusedSQLTestCase(ReusedPySparkTestCase):
@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.spark = SparkSession(cls.sc)

@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()
class SQLTestUtils(object):
"""
This utility assumes that instances of it have a 'spark' attribute holding a Spark session.
It is usually used with the 'ReusedSQLTestCase' class, but it can be used elsewhere as long
as you are sure the implementing class has a 'spark' attribute.
"""

@contextmanager
def sql_conf(self, pairs):
@@ -204,6 +200,7 @@ def sql_conf(self, pairs):
`value` to the configuration `key` and then restores it back when it exits.
"""
assert isinstance(pairs, dict), "pairs should be a dictionary."
assert hasattr(self, "spark"), "it should have a 'spark' attribute, holding a Spark session."

keys = pairs.keys()
new_values = pairs.values()
@@ -219,6 +216,18 @@
else:
self.spark.conf.set(key, old_value)

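For reference, a minimal usage sketch (the test class name below is hypothetical and only for illustration) of how a test that mixes in SQLTestUtils can use sql_conf:

class ArrowConfUsageExample(ReusedSQLTestCase):
    def test_with_temporary_conf(self):
        # sql_conf records the current value of each key, sets the new value,
        # and restores (or unsets) the previous value when the block exits.
        with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
            self.assertEqual(self.spark.range(3).count(), 3)
        # Here, "spark.sql.execution.arrow.enabled" is back to whatever it
        # was before entering the block.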

class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.spark = SparkSession(cls.sc)

@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
cls.spark.stop()

def assertPandasEqual(self, expected, result):
msg = ("DataFrames are not equal: " +
"\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3062,6 +3071,64 @@ def test_sparksession_with_stopped_sparkcontext(self):
sc.stop()


class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
# These tests are separate because they use 'spark.sql.queryExecutionListeners', which is
# static and immutable. It can't be set or unset at runtime, for example via `spark.conf`.

@classmethod
def setUpClass(cls):
import glob
from pyspark.find_spark_home import _find_spark_home

SPARK_HOME = _find_spark_home()
filename_pattern = (
"sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
"TestQueryExecutionListener.class")
if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
raise unittest.SkipTest(
"'org.apache.spark.sql.TestQueryExecutionListener' is not "
"available. Will skip the related tests.")

Review thread on the SkipTest above:

@viirya (Member), Apr 12, 2018:
I'm not sure about this part. In what case can't we find the class? Has TestQueryExecutionListener.scala been removed or moved? If that happens, should we just silently skip this test like this?

Member Author:
Ah, nope. It's when we do sbt package, according to https://spark.apache.org/docs/latest/building-spark.html#building-with-sbt. In this case, the test files are not actually compiled, and running the tests would hit some exceptions.

Member Author:
I admit it's rare, but I believe this is more correct. In fact, few test cases actually take care of this.

Member Author:
And, regarding "If it happens, should we just silently skip this test like this?": yes, ideally we should warn explicitly in the console. The problem is our own testing script. We could make some changes to warn explicitly, but it seems we would need some duplicated changes. There are some discussions / changes going on in #20909.

Member:
Ok, I see. Makes sense.

Member Author:
Thank you @viirya. I know this one is rather tricky to judge. I will maybe cc you when we actually discuss this further. I believe some people could think differently and I might need more discussion, but for now, I feel sure about this.

# Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
cls.spark = SparkSession.builder \
.master("local[4]") \
.appName(cls.__name__) \
.config(
"spark.sql.queryExecutionListeners",
"org.apache.spark.sql.TestQueryExecutionListener") \
.getOrCreate()

@classmethod
def tearDownClass(cls):
cls.spark.stop()

def tearDown(self):
self.spark._jvm.OnSuccessCall.clear()

def test_query_execution_listener_on_collect(self):
self.assertFalse(
self.spark._jvm.OnSuccessCall.isCalled(),
"The callback from the query execution listener should not be called before 'collect'")
self.spark.sql("SELECT * FROM range(1)").collect()
self.assertTrue(
self.spark._jvm.OnSuccessCall.isCalled(),
"The callback from the query execution listener should be called after 'collect'")

@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
def test_query_execution_listener_on_collect_with_arrow(self):
with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
self.assertFalse(
self.spark._jvm.OnSuccessCall.isCalled(),
"The callback from the query execution listener should not be "
"called before 'toPandas'")
self.spark.sql("SELECT * FROM range(1)").toPandas()
self.assertTrue(
self.spark._jvm.OnSuccessCall.isCalled(),
"The callback from the query execution listener should be called after 'toPandas'")


class SparkSessionTests(PySparkTestCase):

# This test is separate because it's closely related with session's start and stop.
20 changes: 13 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3189,10 +3189,10 @@ class Dataset[T] private[sql](

private[sql] def collectToPython(): Int = {
EvaluatePython.registerPicklers()
withNewExecutionId {
withAction("collectToPython", queryExecution) { plan =>
Review comment (Member) on the withAction change above:
These changes can cause behavior changes. Please submit a PR to document it.

val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
val iter = new SerDeUtil.AutoBatchedPickler(
queryExecution.executedPlan.executeCollect().iterator.map(toJava))
val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
plan.executeCollect().iterator.map(toJava))
PythonRDD.serveIterator(iter, "serve-DataFrame")
}
}
@@ -3201,8 +3201,9 @@
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/
private[sql] def collectAsArrowToPython(): Int = {
withNewExecutionId {
val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
withAction("collectAsArrowToPython", queryExecution) { plan =>
val iter: Iterator[Array[Byte]] =
toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
PythonRDD.serveIterator(iter, "serve-Arrow")
}
}
@@ -3311,14 +3312,19 @@
}

/** Convert to an RDD of ArrowPayload byte arrays */
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
queryExecution.toRdd.mapPartitionsInternal { iter =>
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toPayloadIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
}
}

// This is only used in tests, for now.
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
toArrowPayload(queryExecution.executedPlan)
}
}
New file: TestQueryExecutionListener.scala (package org.apache.spark.sql), 44 additions
@@ -0,0 +1,44 @@
Review thread on this new file:

Member Author, replying to a reviewer:
I think it's possible. I just took a look; however, mind if I keep a separate one, as is, specifically for the Python test? Maybe I am too worried, but I am a bit hesitant about having a dependency on a class inside a test suite.

Member:
Yeah, I think that's fine. Thanks for putting a comment in the class about what it is for.

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener


class TestQueryExecutionListener extends QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
OnSuccessCall.isOnSuccessCalled.set(true)
}

override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { }
}

/**
* Holds a flag recording whether `onSuccess` was actually called. Currently, this is used by
* the test case in PySpark. See SPARK-23942.
*/
object OnSuccessCall {
val isOnSuccessCalled = new AtomicBoolean(false)

def isCalled(): Boolean = isOnSuccessCalled.get()

def clear(): Unit = isOnSuccessCalled.set(false)
}
Review thread at the end of the file:

Member:
Does this need a newline at the end?

Member Author:
Nope, it already has one. GitHub shows a warning and a mark on this UI if it doesn't, IIRC.
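Putting the pieces together, here is a condensed sketch of the end-to-end flow this PR tests, assuming TestQueryExecutionListener has been compiled into the sql/core test classes and is on the driver classpath; outside this test setup you would point the configuration at your own QueryExecutionListener implementation:

from pyspark.sql import SparkSession

# 'spark.sql.queryExecutionListeners' is a static configuration: it must be
# supplied when the session is created and cannot be changed via spark.conf.
spark = SparkSession.builder \
    .master("local[4]") \
    .config("spark.sql.queryExecutionListeners",
            "org.apache.spark.sql.TestQueryExecutionListener") \
    .getOrCreate()

spark._jvm.OnSuccessCall.clear()               # reset the JVM-side flag
spark.sql("SELECT * FROM range(1)").collect()  # collect now goes through withAction
assert spark._jvm.OnSuccessCall.isCalled()     # onSuccess was invoked

spark.stop()

With the Dataset.scala changes above, the Arrow path used by toPandas() triggers the same callback, which is what test_query_execution_listener_on_collect_with_arrow asserts.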