diff --git a/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py b/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py index 9534db71bae6..e71d1d18329f 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py @@ -16,11 +16,25 @@ # import ctypes import unittest -import pyarrow as pa -import pandas as pd import pyspark.pandas as ps +from pyspark.testing.utils import ( + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) +if have_pandas: + import pandas as pd +if have_pyarrow: + import pyarrow as pa + + +@unittest.skipIf( + not have_pyarrow or not have_pandas, + pyarrow_requirement_message or pandas_requirement_message, +) class TestSparkArrowCStreamer(unittest.TestCase): def test_spark_arrow_c_streamer_arrow_consumer(self): pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"]) @@ -53,12 +67,6 @@ def test_spark_arrow_c_streamer_arrow_consumer(self): if __name__ == "__main__": - from pyspark.sql.tests.arrow.test_arrow_c_stream import * # noqa: F401 - - try: - import xmlrunner # type: ignore + from pyspark.testing import main - test_runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - test_runner = None - unittest.main(testRunner=test_runner, verbosity=2) + main()