Commit e858946

[SPARK-48184][PYTHON][CONNECT] Always set the seed of `Dataframe.sample` in Client side

### What changes were proposed in this pull request?
Always set the seed of `Dataframe.sample` in Client side

### Why are the changes needed?
Bug fix.

If the seed is not set on the client, the server picks a random int for it:

https://github.com/apache/spark/blob/c4df12cc884cddefcfcf8324b4d7b9349fb4f6a0/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L386

Because that happens per execution, repeated executions of the same plan return inconsistent results.

In Spark Classic:
```
In [1]: df = spark.range(10000).sample(0.1)

In [2]: [df.count() for i in range(10)]
Out[2]: [1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006, 1006]
```

In Spark Connect:

Before:
```
In [1]: df = spark.range(10000).sample(0.1)

In [2]: [df.count() for i in range(10)]
Out[2]: [969, 1005, 958, 996, 987, 1026, 991, 1020, 1012, 979]
```

After:
```
In [1]: df = spark.range(10000).sample(0.1)

In [2]: [df.count() for i in range(10)]
Out[2]: [1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032, 1032]
```
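The before/after behavior can be reproduced outside Spark with a plain-Python sketch of Bernoulli sampling (illustration only; `bernoulli_sample_count` is a hypothetical helper, not Spark code):

```python
import random
import sys

def bernoulli_sample_count(n, fraction, seed):
    # Count how many of n rows survive Bernoulli sampling under a given seed.
    rng = random.Random(seed)
    return sum(1 for _ in range(n) if rng.random() < fraction)

# Seed pinned once up front (the fixed behavior): every run is identical.
pinned = random.randint(0, sys.maxsize)
fixed_counts = [bernoulli_sample_count(10000, 0.1, pinned) for _ in range(10)]
print(len(set(fixed_counts)))  # 1

# Fresh seed per run (the buggy behavior): counts drift between executions.
drift_counts = [bernoulli_sample_count(10000, 0.1, random.randint(0, sys.maxsize))
                for _ in range(10)]
```

The key point is where the seed is drawn: once before any execution, or again on every execution.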

### Does this PR introduce _any_ user-facing change?
Yes, this is a bug fix.

### How was this patch tested?
CI.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#46456 from zhengruifeng/py_connect_sample_seed.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
zhengruifeng authored and JacobZheng0927 committed May 11, 2024
1 parent 5f009d8 commit e858946
Showing 3 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py

@@ -813,7 +813,7 @@ def sample(
         if withReplacement is None:
             withReplacement = False

-        seed = int(seed) if seed is not None else None
+        seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)

         return DataFrame(
             plan.Sample(
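The one-line change can be read as a small resolution rule: an explicit seed is passed through, otherwise a concrete seed is drawn once on the client before the plan is serialized. A standalone sketch of that rule (`resolve_sample_seed` is a hypothetical helper name, not the actual Spark API):

```python
import random
import sys

def resolve_sample_seed(seed=None):
    # Explicit seed: normalize to int. No seed: pin a random one now,
    # so the server never has to pick (and re-pick) one per execution.
    return int(seed) if seed is not None else random.randint(0, sys.maxsize)

assert resolve_sample_seed(42) == 42               # explicit seed preserved
assert 0 <= resolve_sample_seed() <= sys.maxsize   # always a concrete int
```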
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_plan.py

@@ -443,7 +443,7 @@ def test_sample(self):
         self.assertEqual(plan.root.sample.lower_bound, 0.0)
         self.assertEqual(plan.root.sample.upper_bound, 0.3)
         self.assertEqual(plan.root.sample.with_replacement, False)
-        self.assertEqual(plan.root.sample.HasField("seed"), False)
+        self.assertEqual(plan.root.sample.HasField("seed"), True)
         self.assertEqual(plan.root.sample.deterministic_order, False)

         plan = (
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py

@@ -430,6 +430,11 @@ def test_sample(self):
             IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count()
         )

+    def test_sample_with_random_seed(self):
+        df = self.spark.range(10000).sample(0.1)
+        cnts = [df.count() for i in range(10)]
+        self.assertEqual(1, len(set(cnts)))
+
     def test_toDF_with_string(self):
         df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)])
         data = [("John", 30), ("Alice", 25), ("Bob", 28)]
