diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml index 41729692b200e..a6b505c4aef9d 100644 --- a/.github/workflows/build_infra_images_cache.yml +++ b/.github/workflows/build_infra_images_cache.yml @@ -38,6 +38,7 @@ on: - 'dev/spark-test-image/python-311/Dockerfile' - 'dev/spark-test-image/python-312/Dockerfile' - 'dev/spark-test-image/python-313/Dockerfile' + - 'dev/spark-test-image/python-313-nogil/Dockerfile' - 'dev/spark-test-image/numpy-213/Dockerfile' - '.github/workflows/build_infra_images_cache.yml' # Create infra image when cutting down branches/tags @@ -216,6 +217,19 @@ jobs: - name: Image digest (PySpark with Python 3.13) if: hashFiles('dev/spark-test-image/python-313/Dockerfile') != '' run: echo ${{ steps.docker_build_pyspark_python_313.outputs.digest }} + - name: Build and push (PySpark with Python 3.13 no GIL) + if: hashFiles('dev/spark-test-image/python-313-nogil/Dockerfile') != '' + id: docker_build_pyspark_python_313_nogil + uses: docker/build-push-action@v6 + with: + context: ./dev/spark-test-image/python-313-nogil/ + push: true + tags: ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-313-nogil-cache:${{ github.ref_name }}-static + cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-313-nogil-cache:${{ github.ref_name }} + cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-pyspark-python-313-nogil-cache:${{ github.ref_name }},mode=max + - name: Image digest (PySpark with Python 3.13 no GIL) + if: hashFiles('dev/spark-test-image/python-313-nogil/Dockerfile') != '' + run: echo ${{ steps.docker_build_pyspark_python_313_nogil.outputs.digest }} - name: Build and push (PySpark with Numpy 2.1.3) if: hashFiles('dev/spark-test-image/numpy-213/Dockerfile') != '' id: docker_build_pyspark_numpy_213 diff --git a/.github/workflows/build_python_3.13_nogil.yml b/.github/workflows/build_python_3.13_nogil.yml new file mode 100644 index 0000000000000..6fc717cc118fc --- /dev/null +++ b/.github/workflows/build_python_3.13_nogil.yml @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: "Build / Python-only (master, Python 3.13 no GIL)" + +on: + schedule: + - cron: '0 19 */3 * *' + workflow_dispatch: + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 17 + branch: master + hadoop: hadoop3 + envs: >- + { + "PYSPARK_IMAGE_TO_TEST": "python-313-nogil", + "PYTHON_TO_TEST": "python3.13t", + "PYTHON_GIL": "0" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } diff --git a/dev/spark-test-image/python-313-nogil/Dockerfile b/dev/spark-test-image/python-313-nogil/Dockerfile new file mode 100644 index 0000000000000..cee6a4cca4d33 --- /dev/null +++ b/dev/spark-test-image/python-313-nogil/Dockerfile @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building and testing Spark branches. Based on Ubuntu 22.04. +# See also in https://hub.docker.com/_/ubuntu +FROM ubuntu:jammy-20240911.1 +LABEL org.opencontainers.image.authors="Apache Spark project " +LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image For PySpark with Python 3.13 (no GIL)" +# Overwrite this label to avoid exposing the underlying Ubuntu OS version label +LABEL org.opencontainers.image.version="" + +ENV FULL_REFRESH_DATE=20250407 + +ENV DEBIAN_FRONTEND=noninteractive +ENV DEBCONF_NONINTERACTIVE_SEEN=true + +RUN apt-get update && apt-get install -y \ + build-essential \ + ca-certificates \ + curl \ + gfortran \ + git \ + gnupg \ + libcurl4-openssl-dev \ + libfontconfig1-dev \ + libfreetype6-dev \ + libfribidi-dev \ + libgit2-dev \ + libharfbuzz-dev \ + libjpeg-dev \ + liblapack-dev \ + libopenblas-dev \ + libpng-dev \ + libpython3-dev \ + libssl-dev \ + libtiff5-dev \ + libxml2-dev \ + openjdk-17-jdk-headless \ + pkg-config \ + qpdf \ + tzdata \ + software-properties-common \ + wget \ + zlib1g-dev + +# Install Python 3.13 (no GIL) +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt-get update && apt-get install -y \ + python3.13-nogil \ + && apt-get autoremove --purge -y \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + + +ARG BASIC_PIP_PKGS="numpy pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 googleapis-common-protos==1.65.0 graphviz==0.20.3" + + +# Install Python 3.13 packages +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13t +# TODO: Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS when it supports Python 3.13 free threaded +# TODO: Add lxml, grpcio, grpcio-status back when they support Python 3.13 free threaded +RUN python3.13t -m pip install 
--ignore-installed blinker>=1.6.2 # mlflow needs this +RUN python3.13t -m pip install numpy>=2.1 pyarrow>=19.0.0 six==1.16.0 pandas==2.2.3 scipy coverage matplotlib openpyxl jinja2 && \ + python3.13t -m pip cache purge diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 8ab0ead1c9032..198e6ff2fce55 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import pyspark.sql.connect.proto as pb2 import json from typing import Dict, List, Optional, TYPE_CHECKING @@ -43,6 +42,7 @@ ) if TYPE_CHECKING: + import pyspark.sql.connect.proto as pb2 from google.rpc.error_details_pb2 import ErrorInfo @@ -55,7 +55,7 @@ class SparkConnectException(PySparkException): def convert_exception( info: "ErrorInfo", truncated_message: str, - resp: Optional[pb2.FetchErrorDetailsResponse], + resp: Optional["pb2.FetchErrorDetailsResponse"], display_server_stacktrace: bool = False, ) -> SparkConnectException: converted = _convert_exception(info, truncated_message, resp, display_server_stacktrace) @@ -65,9 +65,11 @@ def convert_exception( def _convert_exception( info: "ErrorInfo", truncated_message: str, - resp: Optional[pb2.FetchErrorDetailsResponse], + resp: Optional["pb2.FetchErrorDetailsResponse"], display_server_stacktrace: bool = False, ) -> SparkConnectException: + import pyspark.sql.connect.proto as pb2 + raw_classes = info.metadata.get("classes") classes: List[str] = json.loads(raw_classes) if raw_classes else [] sql_state = info.metadata.get("sqlState") @@ -139,13 +141,13 @@ def _convert_exception( ) -def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: +def _extract_jvm_stacktrace(resp: "pb2.FetchErrorDetailsResponse") -> str: if len(resp.errors[resp.root_error_idx].stack_trace) == 0: return "" lines: List[str] = [] - def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None: + def format_stacktrace(error: "pb2.FetchErrorDetailsResponse.Error") -> None: message = f"{error.error_type_hierarchy[0]}: {error.message}" if len(lines) == 0: lines.append(error.error_type_hierarchy[0]) @@ -404,7 +406,7 @@ class PickleException(SparkConnectGrpcException, BasePickleException): class SQLQueryContext(BaseQueryContext): - def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext): + def __init__(self, q: "pb2.FetchErrorDetailsResponse.QueryContext"): self._q = q def contextType(self) -> QueryContextType: @@ -441,7 +443,7 @@ def summary(self) -> str: class DataFrameQueryContext(BaseQueryContext): - def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext): + def __init__(self, q: "pb2.FetchErrorDetailsResponse.QueryContext"): self._q = q def contextType(self) -> QueryContextType: diff --git a/python/pyspark/ml/connect/__init__.py b/python/pyspark/ml/connect/__init__.py index 875a5370d996d..6a5453db0be9c 100644 --- a/python/pyspark/ml/connect/__init__.py +++ b/python/pyspark/ml/connect/__init__.py @@ -16,10 +16,6 @@ # """Spark Connect Python Client - ML module""" -from pyspark.sql.connect.utils import check_dependencies - -check_dependencies(__name__) - from pyspark.ml.connect.base import ( Estimator, Transformer, diff --git a/python/pyspark/ml/connect/base.py b/python/pyspark/ml/connect/base.py index 516b5057cc192..32c72d5907455 100644 --- a/python/pyspark/ml/connect/base.py +++ b/python/pyspark/ml/connect/base.py @@ -39,7 +39,6 @@ 
HasFeaturesCol, HasPredictionCol, ) -from pyspark.ml.connect.util import transform_dataframe_column if TYPE_CHECKING: from pyspark.ml._typing import ParamMap @@ -188,6 +187,8 @@ def transform( return self._transform(dataset) def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: + from pyspark.ml.connect.util import transform_dataframe_column + input_cols = self._input_columns() transform_fn = self._get_transform_fn() output_cols = self._output_columns() diff --git a/python/pyspark/ml/connect/evaluation.py b/python/pyspark/ml/connect/evaluation.py index 267094f12a027..f324bb193c0ce 100644 --- a/python/pyspark/ml/connect/evaluation.py +++ b/python/pyspark/ml/connect/evaluation.py @@ -24,7 +24,6 @@ from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol from pyspark.ml.connect.base import Evaluator from pyspark.ml.connect.io_utils import ParamsReadWrite -from pyspark.ml.connect.util import aggregate_dataframe from pyspark.sql import DataFrame @@ -56,6 +55,8 @@ def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]: raise NotImplementedError() def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float: + from pyspark.ml.connect.util import aggregate_dataframe + torch_metric = self._get_torch_metric() def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame": diff --git a/python/pyspark/ml/connect/feature.py b/python/pyspark/ml/connect/feature.py index a0e5b6a943d10..b0e2028e43faa 100644 --- a/python/pyspark/ml/connect/feature.py +++ b/python/pyspark/ml/connect/feature.py @@ -35,7 +35,6 @@ ) from pyspark.ml.connect.base import Estimator, Model, Transformer from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite -from pyspark.ml.connect.summarizer import summarize_dataframe class MaxAbsScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite): @@ -81,6 +80,8 @@ def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] = self._set(**kwargs) def _fit(self, dataset: Union["pd.DataFrame", "DataFrame"]) -> "MaxAbsScalerModel": + from pyspark.ml.connect.summarizer import summarize_dataframe + input_col = self.getInputCol() stat_res = summarize_dataframe(dataset, input_col, ["min", "max", "count"]) @@ -197,6 +198,8 @@ def __init__(self, inputCol: Optional[str] = None, outputCol: Optional[str] = No self._set(**kwargs) def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "StandardScalerModel": + from pyspark.ml.connect.summarizer import summarize_dataframe + input_col = self.getInputCol() stat_result = summarize_dataframe(dataset, input_col, ["mean", "std", "count"]) diff --git a/python/pyspark/ml/connect/functions.py b/python/pyspark/ml/connect/functions.py index e3664db87ae64..db77b4e641237 100644 --- a/python/pyspark/ml/connect/functions.py +++ b/python/pyspark/ml/connect/functions.py @@ -19,13 +19,15 @@ from pyspark.ml import functions as PyMLFunctions from pyspark.sql.column import Column -from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col, lit + if TYPE_CHECKING: from pyspark.sql._typing import UserDefinedFunctionLike def vector_to_array(col: Column, dtype: str = "float64") -> Column: + from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col, lit + return _invoke_function("vector_to_array", _to_col(col), lit(dtype)) @@ -33,6 +35,8 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column: def array_to_vector(col: Column) -> Column: + from pyspark.sql.connect.functions.builtin 
import _invoke_function, _to_col + return _invoke_function("array_to_vector", _to_col(col)) @@ -49,6 +53,11 @@ def predict_batch_udf(*args: Any, **kwargs: Any) -> "UserDefinedFunctionLike": def _test() -> None: import os import sys + + if os.environ.get("PYTHON_GIL", "?") == "0": + print("Not supported in no-GIL mode", file=sys.stderr) + sys.exit(0) + import doctest from pyspark.sql import SparkSession as PySparkSession import pyspark.ml.connect.functions diff --git a/python/pyspark/ml/connect/proto.py b/python/pyspark/ml/connect/proto.py index b0e012964fc4a..31f100859281a 100644 --- a/python/pyspark/ml/connect/proto.py +++ b/python/pyspark/ml/connect/proto.py @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) + from typing import Optional, TYPE_CHECKING, List import pyspark.sql.connect.proto as pb2 diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py index 95551f67c0120..0dc38e7275c1f 100644 --- a/python/pyspark/ml/connect/readwrite.py +++ b/python/pyspark/ml/connect/readwrite.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any diff --git a/python/pyspark/ml/connect/serialize.py b/python/pyspark/ml/connect/serialize.py index 42bedfb330b1b..37102d463b057 100644 --- a/python/pyspark/ml/connect/serialize.py +++ b/python/pyspark/ml/connect/serialize.py @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) + from typing import Any, List, TYPE_CHECKING, Mapping, Dict import pyspark.sql.connect.proto as pb2 diff --git a/python/pyspark/ml/connect/util.py b/python/pyspark/ml/connect/util.py index 1c77baeba5f88..8bf4b3480e32e 100644 --- a/python/pyspark/ml/connect/util.py +++ b/python/pyspark/ml/connect/util.py @@ -15,14 +15,16 @@ # limitations under the License. # -from typing import Any, TypeVar, Callable, List, Tuple, Union, Iterable +from typing import Any, TypeVar, Callable, List, Tuple, Union, Iterable, TYPE_CHECKING import pandas as pd from pyspark import cloudpickle from pyspark.sql import DataFrame from pyspark.sql.functions import col, pandas_udf -import pyspark.sql.connect.proto as pb2 + +if TYPE_CHECKING: + import pyspark.sql.connect.proto as pb2 FuncT = TypeVar("FuncT", bound=Callable[..., Any]) @@ -180,6 +182,8 @@ def transform_fn_pandas_udf(*s: "pd.Series") -> "pd.Series": def _extract_id_methods(obj_identifier: str) -> Tuple[List["pb2.Fetch.Method"], str]: """Extract the obj reference id and the methods. Eg, model.summary""" + import pyspark.sql.connect.proto as pb2 + method_chain = obj_identifier.split(".") obj_ref = method_chain[0] methods: List["pb2.Fetch.Method"] = [] diff --git a/python/pyspark/sql/connect/tvf.py b/python/pyspark/sql/connect/tvf.py index 2fca610a5fe3a..104768d5bc3cb 100644 --- a/python/pyspark/sql/connect/tvf.py +++ b/python/pyspark/sql/connect/tvf.py @@ -14,21 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Optional +from typing import Optional, TYPE_CHECKING from pyspark.errors import PySparkValueError -from pyspark.sql.connect.column import Column -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.functions.builtin import _to_col -from pyspark.sql.connect.plan import UnresolvedTableValuedFunction -from pyspark.sql.connect.session import SparkSession from pyspark.sql.tvf import TableValuedFunction as PySparkTableValuedFunction +if TYPE_CHECKING: + from pyspark.sql.connect.column import Column + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.session import SparkSession + class TableValuedFunction: __doc__ = PySparkTableValuedFunction.__doc__ - def __init__(self, sparkSession: SparkSession): + def __init__(self, sparkSession: "SparkSession"): self._sparkSession = sparkSession def range( @@ -37,34 +37,34 @@ def range( end: Optional[int] = None, step: int = 1, numPartitions: Optional[int] = None, - ) -> DataFrame: + ) -> "DataFrame": return self._sparkSession.range( # type: ignore[return-value] start, end, step, numPartitions ) range.__doc__ = PySparkTableValuedFunction.range.__doc__ - def explode(self, collection: Column) -> DataFrame: + def explode(self, collection: "Column") -> "DataFrame": return self._fn("explode", collection) explode.__doc__ = PySparkTableValuedFunction.explode.__doc__ - def explode_outer(self, collection: Column) -> DataFrame: + def explode_outer(self, collection: "Column") -> "DataFrame": return self._fn("explode_outer", collection) explode_outer.__doc__ = PySparkTableValuedFunction.explode_outer.__doc__ - def inline(self, input: Column) -> DataFrame: + def inline(self, input: "Column") -> "DataFrame": return self._fn("inline", input) inline.__doc__ = PySparkTableValuedFunction.inline.__doc__ - def inline_outer(self, input: Column) -> DataFrame: + def inline_outer(self, input: "Column") -> "DataFrame": return self._fn("inline_outer", input) inline_outer.__doc__ = PySparkTableValuedFunction.inline_outer.__doc__ - def json_tuple(self, input: Column, *fields: Column) -> DataFrame: + def json_tuple(self, input: "Column", *fields: "Column") -> "DataFrame": if len(fields) == 0: raise PySparkValueError( errorClass="CANNOT_BE_EMPTY", @@ -74,42 +74,46 @@ def json_tuple(self, input: Column, *fields: Column) -> DataFrame: json_tuple.__doc__ = PySparkTableValuedFunction.json_tuple.__doc__ - def posexplode(self, collection: Column) -> DataFrame: + def posexplode(self, collection: "Column") -> "DataFrame": return self._fn("posexplode", collection) posexplode.__doc__ = PySparkTableValuedFunction.posexplode.__doc__ - def posexplode_outer(self, collection: Column) -> DataFrame: + def posexplode_outer(self, collection: "Column") -> "DataFrame": return self._fn("posexplode_outer", collection) posexplode_outer.__doc__ = PySparkTableValuedFunction.posexplode_outer.__doc__ - def stack(self, n: Column, *fields: Column) -> DataFrame: + def stack(self, n: "Column", *fields: "Column") -> "DataFrame": return self._fn("stack", n, *fields) stack.__doc__ = PySparkTableValuedFunction.stack.__doc__ - def collations(self) -> DataFrame: + def collations(self) -> "DataFrame": return self._fn("collations") collations.__doc__ = PySparkTableValuedFunction.collations.__doc__ - def sql_keywords(self) -> DataFrame: + def sql_keywords(self) -> "DataFrame": return self._fn("sql_keywords") sql_keywords.__doc__ = PySparkTableValuedFunction.sql_keywords.__doc__ - def variant_explode(self, input: Column) -> DataFrame: + def 
variant_explode(self, input: "Column") -> "DataFrame": return self._fn("variant_explode", input) variant_explode.__doc__ = PySparkTableValuedFunction.variant_explode.__doc__ - def variant_explode_outer(self, input: Column) -> DataFrame: + def variant_explode_outer(self, input: "Column") -> "DataFrame": return self._fn("variant_explode_outer", input) variant_explode_outer.__doc__ = PySparkTableValuedFunction.variant_explode_outer.__doc__ - def _fn(self, name: str, *args: Column) -> DataFrame: + def _fn(self, name: str, *args: "Column") -> "DataFrame": + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.plan import UnresolvedTableValuedFunction + from pyspark.sql.connect.functions.builtin import _to_col + return DataFrame( UnresolvedTableValuedFunction(name, [_to_col(arg) for arg in args]), self._sparkSession ) @@ -117,8 +121,13 @@ def _fn(self, name: str, *args: Column) -> DataFrame: def _test() -> None: import os - import doctest import sys + + if os.environ.get("PYTHON_GIL", "?") == "0": + print("Not supported in no-GIL mode", file=sys.stderr) + sys.exit(0) + + import doctest from pyspark.sql import SparkSession as PySparkSession import pyspark.sql.connect.tvf diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f0637056ab8f9..fbb957b5e265c 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -41,6 +41,7 @@ from pyspark.testing.sqlutils import SQLTestUtils from pyspark.testing.connectutils import ( should_test_connect, + connect_requirement_message, ReusedConnectTestCase, ) from pyspark.testing.pandasutils import PandasOnSparkTestUtils @@ -58,7 +59,10 @@ from pyspark.sql.connect import functions as CF -@unittest.skipIf(is_remote_only(), "Requires JVM access") +@unittest.skipIf( + not should_test_connect or is_remote_only(), + connect_requirement_message or "Requires JVM access", +) class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils): """Parent test fixture class for all Spark Connect related test cases.""" diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 55948891089ee..9cbeb854555ad 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -53,8 +53,11 @@ @unittest.skipIf( - not have_pandas or not have_pyarrow, - cast(str, pandas_requirement_message or pyarrow_requirement_message), + not have_pandas or not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0", + cast( + str, + pandas_requirement_message or pyarrow_requirement_message or "Not supported in no-GIL mode", + ), ) class TransformWithStateInPandasTestsMixin: @classmethod
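The recurring pattern in the Python changes above: every touched Connect module defers its pyspark.sql.connect imports, either under typing.TYPE_CHECKING when a name is needed only for annotations, or into the function body that actually uses it, and the doctest runners exit early when PYTHON_GIL=0. That keeps the modules importable inside the free-threaded test image, where grpcio/protobuf are not installed yet. Below is a minimal sketch of the pattern, not taken from the patch itself: describe_error is a hypothetical helper invented for illustration, and the lazy import assumes the Connect dependencies are present only when the function is actually called.

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # evaluated only by type checkers; never imported at runtime
    import pyspark.sql.connect.proto as pb2


def describe_error(resp: Optional["pb2.FetchErrorDetailsResponse"]) -> str:
    # hypothetical helper: the heavy protobuf import happens lazily,
    # so merely importing this module never pulls in grpcio/protobuf
    import pyspark.sql.connect.proto as pb2

    if resp is None:
        return "no details"
    return resp.errors[resp.root_error_idx].message


def _test() -> None:
    import os
    import sys

    # same guard the patch adds to doctest runners: skip under
    # free-threaded Python (PYTHON_GIL=0), where these tests cannot run yet
    if os.environ.get("PYTHON_GIL", "?") == "0":
        print("Not supported in no-GIL mode", file=sys.stderr)
        sys.exit(0)

Quoting the annotation ("pb2.FetchErrorDetailsResponse") is what lets the TYPE_CHECKING import stay out of the runtime path while type checkers still resolve it.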