From 2fa1d6be8f6fe9e71f2def743484940b8c4b6dbf Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 28 Jan 2023 16:51:40 -0800 Subject: [PATCH] [SPARK-41830][CONNECT][PYTHON][TESTS][FOLLOWUP] Enable parity test `test_sample` ### What changes were proposed in this pull request? Enable parity test `test_sample` ### Why are the changes needed? For test coverage ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? enabled test Closes #39765 from zhengruifeng/connect_enable_41830. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/tests/connect/test_parity_dataframe.py | 5 ----- python/pyspark/sql/tests/test_dataframe.py | 6 +++++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index db0d727d33021..e04119bea9f05 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -95,11 +95,6 @@ def test_require_cross(self): def test_same_semantics_error(self): super().test_same_semantics_error() - # TODO(SPARK-41830): Fix DataFrame.sample parameters - @unittest.skip("Fails in Spark Connect, should enable.") - def test_sample(self): - super().test_sample() - @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") def test_toDF_with_schema_string(self): super().test_toDF_with_schema_string() diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 845cf0f1fbe14..0ba0649245c12 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -46,6 +46,7 @@ from pyspark.errors import ( AnalysisException, IllegalArgumentException, + SparkConnectException, SparkConnectAnalysisException, ) from pyspark.testing.sqlutils import ( @@ -888,7 +889,10 @@ def test_sample(self): self.assertRaises(TypeError, lambda: self.spark.range(1).sample(seed="abc")) - self.assertRaises(IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0)) + self.assertRaises( + (IllegalArgumentException, SparkConnectException), + lambda: self.spark.range(1).sample(-1.0).count(), + ) def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)]