Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions python/pyspark/testing/connectutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import unittest
import uuid
import contextlib
from typing import Callable, Optional

from pyspark import Row, SparkConf
from pyspark.util import is_remote_only
from pyspark.testing.utils import PySparkErrorTestUtils
from pyspark import Row, SparkConf
from pyspark.loose_version import LooseVersion
from pyspark.util import is_remote_only
from pyspark.testing.utils import (
have_pandas,
Expand Down Expand Up @@ -306,3 +305,28 @@ def _both_conf():
yield

return _both_conf()


def skip_if_server_version_is(
cond: Callable[[LooseVersion], bool], reason: Optional[str] = None
) -> Callable[[...], ...]:
def decorator(f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
version = self.spark.version
if cond(LooseVersion(version)):
raise unittest.SkipTest(
f"Skipping test {f.__name__} because server version is {version}"
+ (f" ({reason})" if reason else "")
)
return f(self, *args, **kwargs)

return wrapper

return decorator


def skip_if_server_version_is_greater_than_or_equal_to(
version: str, reason: Optional[str] = None
) -> Callable[[...], ...]:
return skip_if_server_version_is(lambda v: v >= LooseVersion(version), reason)