From 3872971725fc7b4463b1a2640301093dcc165ecd Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sat, 13 Apr 2024 10:22:24 +0900 Subject: [PATCH] Make testing Spark Connect server have pyspark.core --- .github/workflows/build_python_connect.yml | 4 +++- .../sql/tests/connect/test_parity_memory_profiler.py | 3 --- .../sql/tests/connect/test_parity_udf_profiler.py | 3 --- python/pyspark/worker_util.py | 12 ++++-------- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 6bd1b4526b0d9..965e839b6b2bc 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -82,7 +82,9 @@ jobs: cp conf/log4j2.properties.template conf/log4j2.properties sed -i 's/rootLogger.level = info/rootLogger.level = warn/g' conf/log4j2.properties # Start a Spark Connect server - ./sbin/start-connect-server.sh --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" --jars `find connector/connect/server/target -name spark-connect*SNAPSHOT.jar` + PYTHONPATH="python/lib/pyspark.zip:python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH" ./sbin/start-connect-server.sh --driver-java-options "-Dlog4j.configurationFile=file:$GITHUB_WORKSPACE/conf/log4j2.properties" --jars `find connector/connect/server/target -name spark-connect*SNAPSHOT.jar` + # Make sure to run Python workers that contain pyspark.core once. They will be reused. 
+ python -c "from pyspark.sql import SparkSession; _ = SparkSession.builder.remote('sc://localhost').getOrCreate().range(100).repartition(100).mapInPandas(lambda x: x, 'id INT').collect()" # Remove Py4J and PySpark zipped library to make sure there is no JVM connection rm python/lib/* rm -r python/pyspark diff --git a/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py b/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py index f95e0bfbf8d60..513e49a144e50 100644 --- a/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py +++ b/python/pyspark/sql/tests/connect/test_parity_memory_profiler.py @@ -18,13 +18,10 @@ import os import unittest -from pyspark.util import is_remote_only from pyspark.tests.test_memory_profiler import MemoryProfiler2TestsMixin, _do_computation from pyspark.testing.connectutils import ReusedConnectTestCase -# TODO(SPARK-47830): Reeanble MemoryProfilerParityTests for pyspark-connect -@unittest.skipIf(is_remote_only(), "Skipped for now") class MemoryProfilerParityTests(MemoryProfiler2TestsMixin, ReusedConnectTestCase): def setUp(self) -> None: super().setUp() diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py index e682e46ca1852..dfa56ff0bb888 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py +++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py @@ -18,13 +18,10 @@ import os import unittest -from pyspark.util import is_remote_only from pyspark.sql.tests.test_udf_profiler import UDFProfiler2TestsMixin, _do_computation from pyspark.testing.connectutils import ReusedConnectTestCase -# TODO(SPARK-47756): Reeanble UDFProfilerParityTests for pyspark-connect -@unittest.skipIf(is_remote_only(), "Skipped for now") class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase): def setUp(self) -> None: super().setUp() diff --git a/python/pyspark/worker_util.py 
b/python/pyspark/worker_util.py index 22389decac2f0..6dfa12ce3affe 100644 --- a/python/pyspark/worker_util.py +++ b/python/pyspark/worker_util.py @@ -32,6 +32,7 @@ except ImportError: has_resource_module = False +from pyspark.accumulators import _accumulatorRegistry from pyspark.util import is_remote_only from pyspark.errors import PySparkRuntimeError from pyspark.util import local_connect_and_auth @@ -183,11 +184,6 @@ def send_accumulator_updates(outfile: IO) -> None: """ Send the accumulator updates back to JVM. """ - if not is_remote_only(): - from pyspark.accumulators import _accumulatorRegistry - - write_int(len(_accumulatorRegistry), outfile) - for aid, accum in _accumulatorRegistry.items(): - pickleSer._write_with_length((aid, accum._value), outfile) - else: - write_int(0, outfile) + write_int(len(_accumulatorRegistry), outfile) + for aid, accum in _accumulatorRegistry.items(): + pickleSer._write_with_length((aid, accum._value), outfile)