[SPARK-45733][PYTHON][TESTS][FOLLOWUP] Skip `pyspark.sql.tests.connect.client.test_client` if not should_test_connect

### What changes were proposed in this pull request?

This is a follow-up of the following.
- #43591

### Why are the changes needed?

This test requires `pandas`, which is an optional dependency in Apache Spark.

```
$ python/run-tests --modules=pyspark-connect --parallelism=1 --python-executables=python3.10  --testnames 'pyspark.sql.tests.connect.client.test_client'
Running PySpark tests. Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log
Will test against the following Python executables: ['python3.10']
Will test the following Python tests: ['pyspark.sql.tests.connect.client.test_client']
python3.10 python_implementation is CPython
python3.10 version is: Python 3.10.13
Starting test(python3.10): pyspark.sql.tests.connect.client.test_client (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/216a8716-3a1f-4cf9-9c7c-63087f29f892/python3.10__pyspark.sql.tests.connect.client.test_client__tydue4ck.log)
Traceback (most recent call last):
  File "/Users/dongjoon/.pyenv/versions/3.10.13/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/dongjoon/.pyenv/versions/3.10.13/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/dongjoon/APACHE/spark-merge/python/pyspark/sql/tests/connect/client/test_client.py", line 137, in <module>
    class TestPolicy(DefaultPolicy):
NameError: name 'DefaultPolicy' is not defined
```
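
The error shows that the module-level test helpers reference `DefaultPolicy`, which is only imported when the optional Connect dependencies are available. The usual pattern in the PySpark test suite is to define such helpers under the `should_test_connect` guard and to skip the test classes with `unittest.skipIf`; below is a minimal sketch of that pattern (the `DefaultPolicy` import path is assumed from recent Spark sources and may differ).

```
import unittest

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

if should_test_connect:
    # Optional dependencies (pandas, pyarrow, grpc, ...) and the helpers that
    # depend on them are only imported/defined when Connect tests can run.
    from pyspark.sql.connect.client.retries import DefaultPolicy

    class TestPolicy(DefaultPolicy):
        """Hypothetical placeholder mirroring the helper in test_client.py."""


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientTestCase(unittest.TestCase):
    def test_noop(self):
        pass
```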

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs, and test manually without `pandas`:
```
$ pip3 uninstall pandas
$ python/run-tests --modules=pyspark-connect --parallelism=1 --python-executables=python3.10  --testnames 'pyspark.sql.tests.connect.client.test_client'
Running PySpark tests. Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log
Will test against the following Python executables: ['python3.10']
Will test the following Python tests: ['pyspark.sql.tests.connect.client.test_client']
python3.10 python_implementation is CPython
python3.10 version is: Python 3.10.13
Starting test(python3.10): pyspark.sql.tests.connect.client.test_client (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/acf07ed5-938a-4272-87e1-47e3bf8b988e/python3.10__pyspark.sql.tests.connect.client.test_client__sfdosnek.log)
Finished test(python3.10): pyspark.sql.tests.connect.client.test_client (0s) ... 13 tests were skipped
Tests passed in 0 seconds

Skipped tests in pyspark.sql.tests.connect.client.test_client with python3.10:
      test_basic_flow (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.002s)
      test_fail_and_retry_during_execute (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.000s)
      test_fail_and_retry_during_reattach (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.000s)
      test_fail_during_execute (pyspark.sql.tests.connect.client.test_client.SparkConnectClientReattachTestCase) ... skip (0.000s)
      test_channel_builder (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_channel_builder_with_session (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_interrupt_all (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_is_closed (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_properties (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_retry (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_retry_client_unit (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_user_agent_default (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
      test_user_agent_passthrough (pyspark.sql.tests.connect.client.test_client.SparkConnectClientTestCase) ... skip (0.000s)
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #45830 from dongjoon-hyun/SPARK-45733.

Authored-by: Dongjoon Hyun <dhyun@apple.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
dongjoon-hyun committed Apr 3, 2024
1 parent 344f640 commit 360a3f9
225 changes: 110 additions & 115 deletions python/pyspark/sql/tests/connect/client/test_client.py
@@ -36,6 +36,116 @@
    from pyspark.errors import RetriesExceeded
    import pyspark.sql.connect.proto as proto

    class TestPolicy(DefaultPolicy):
        def __init__(self):
            super().__init__(
                max_retries=3,
                backoff_multiplier=4.0,
                initial_backoff=10,
                max_backoff=10,
                jitter=10,
                min_jitter_threshold=10,
            )

    class TestException(grpc.RpcError, grpc.Call):
        """Exception mock to test retryable exceptions."""

        def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
            self.msg = msg
            self._code = code

        def code(self):
            return self._code

        def __str__(self):
            return self.msg

        def trailing_metadata(self):
            return ()

    class ResponseGenerator(Generator):
        """This class is used to generate values that are returned by the streaming
        iterator of the GRPC stub."""

        def __init__(self, funs):
            self._funs = funs
            self._iterator = iter(self._funs)

        def send(self, value: Any) -> proto.ExecutePlanResponse:
            val = next(self._iterator)
            if callable(val):
                return val()
            else:
                return val

        def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any:
            super().throw(type, value, traceback)

        def close(self) -> None:
            return super().close()

    class MockSparkConnectStub:
        """Simple mock class for the GRPC stub used by the re-attachable execution."""

        def __init__(self, execute_ops=None, attach_ops=None):
            self._execute_ops = execute_ops
            self._attach_ops = attach_ops
            # Call counters
            self.execute_calls = 0
            self.release_calls = 0
            self.release_until_calls = 0
            self.attach_calls = 0

        def ExecutePlan(self, *args, **kwargs):
            self.execute_calls += 1
            return self._execute_ops

        def ReattachExecute(self, *args, **kwargs):
            self.attach_calls += 1
            return self._attach_ops

        def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs):
            if req.HasField("release_all"):
                self.release_calls += 1
            elif req.HasField("release_until"):
                print("increment")
                self.release_until_calls += 1

    class MockService:
        # Simplest mock of the SparkConnectService.
        # If this needs more complex logic, it needs to be replaced with Python mocking.

        req: Optional[proto.ExecutePlanRequest]

        def __init__(self, session_id: str):
            self._session_id = session_id
            self.req = None

        def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
            self.req = req
            resp = proto.ExecutePlanResponse()
            resp.session_id = self._session_id

            pdf = pd.DataFrame(data={"col1": [1, 2]})
            schema = pa.Schema.from_pandas(pdf)
            table = pa.Table.from_pandas(pdf)
            sink = pa.BufferOutputStream()

            writer = pa.ipc.new_stream(sink, schema=schema)
            writer.write(table)
            writer.close()

            buf = sink.getvalue()
            resp.arrow_batch.data = buf.to_pybytes()
            resp.arrow_batch.row_count = 2
            return [resp]

        def Interrupt(self, req: proto.InterruptRequest, metadata):
            self.req = req
            resp = proto.InterruptResponse()
            resp.session_id = self._session_id
            return resp


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientTestCase(unittest.TestCase):
@@ -134,18 +244,6 @@ def test_channel_builder_with_session(self):
        self.assertEqual(client._session_id, chan.session_id)


class TestPolicy(DefaultPolicy):
    def __init__(self):
        super().__init__(
            max_retries=3,
            backoff_multiplier=4.0,
            initial_backoff=10,
            max_backoff=10,
            jitter=10,
            min_jitter_threshold=10,
        )


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientReattachTestCase(unittest.TestCase):
    def setUp(self) -> None:
@@ -243,109 +341,6 @@ def check():
        eventually(timeout=1, catch_assertions=True)(check)()


class TestException(grpc.RpcError, grpc.Call):
    """Exception mock to test retryable exceptions."""

    def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
        self.msg = msg
        self._code = code

    def code(self):
        return self._code

    def __str__(self):
        return self.msg

    def trailing_metadata(self):
        return ()


class ResponseGenerator(Generator):
    """This class is used to generate values that are returned by the streaming
    iterator of the GRPC stub."""

    def __init__(self, funs):
        self._funs = funs
        self._iterator = iter(self._funs)

    def send(self, value: Any) -> proto.ExecutePlanResponse:
        val = next(self._iterator)
        if callable(val):
            return val()
        else:
            return val

    def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any:
        super().throw(type, value, traceback)

    def close(self) -> None:
        return super().close()


class MockSparkConnectStub:
    """Simple mock class for the GRPC stub used by the re-attachable execution."""

    def __init__(self, execute_ops=None, attach_ops=None):
        self._execute_ops = execute_ops
        self._attach_ops = attach_ops
        # Call counters
        self.execute_calls = 0
        self.release_calls = 0
        self.release_until_calls = 0
        self.attach_calls = 0

    def ExecutePlan(self, *args, **kwargs):
        self.execute_calls += 1
        return self._execute_ops

    def ReattachExecute(self, *args, **kwargs):
        self.attach_calls += 1
        return self._attach_ops

    def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs):
        if req.HasField("release_all"):
            self.release_calls += 1
        elif req.HasField("release_until"):
            print("increment")
            self.release_until_calls += 1


class MockService:
    # Simplest mock of the SparkConnectService.
    # If this needs more complex logic, it needs to be replaced with Python mocking.

    req: Optional[proto.ExecutePlanRequest]

    def __init__(self, session_id: str):
        self._session_id = session_id
        self.req = None

    def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
        self.req = req
        resp = proto.ExecutePlanResponse()
        resp.session_id = self._session_id

        pdf = pd.DataFrame(data={"col1": [1, 2]})
        schema = pa.Schema.from_pandas(pdf)
        table = pa.Table.from_pandas(pdf)
        sink = pa.BufferOutputStream()

        writer = pa.ipc.new_stream(sink, schema=schema)
        writer.write(table)
        writer.close()

        buf = sink.getvalue()
        resp.arrow_batch.data = buf.to_pybytes()
        resp.arrow_batch.row_count = 2
        return [resp]

    def Interrupt(self, req: proto.InterruptRequest, metadata):
        self.req = req
        resp = proto.InterruptResponse()
        resp.session_id = self._session_id
        return resp


if __name__ == "__main__":
    from pyspark.sql.tests.connect.client.test_client import *  # noqa: F401

