diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index f9a209d2bcb3d..843c92a9b27d2 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -813,7 +813,7 @@ def sample( if withReplacement is None: withReplacement = False - seed = int(seed) if seed is not None else None + seed = int(seed) if seed is not None else random.randint(0, sys.maxsize) return DataFrame( plan.Sample( diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index 09c3171ee11fd..e8d04aeada740 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -443,7 +443,7 @@ def test_sample(self): self.assertEqual(plan.root.sample.lower_bound, 0.0) self.assertEqual(plan.root.sample.upper_bound, 0.3) self.assertEqual(plan.root.sample.with_replacement, False) - self.assertEqual(plan.root.sample.HasField("seed"), False) + self.assertEqual(plan.root.sample.HasField("seed"), True) self.assertEqual(plan.root.sample.deterministic_order, False) plan = ( diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 16dd0d2a3bf7c..f491b496ddae5 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -430,6 +430,11 @@ def test_sample(self): IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() ) + def test_sample_with_random_seed(self): + df = self.spark.range(10000).sample(0.1) + cnts = [df.count() for i in range(10)] + self.assertEqual(1, len(set(cnts))) + def test_toDF_with_string(self): df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) data = [("John", 30), ("Alice", 25), ("Bob", 28)]