Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 2 additions & 38 deletions python/pyspark/sql/classic/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,44 +601,8 @@ def sample( # type: ignore[misc]
fraction: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
) -> ParentDataFrame:
# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
# sample(withReplacement=False, fraction=0.5 [, seed])
is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)

# For the case below:
# sample(faction=0.5 [, seed])
is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)

# For the case below:
# sample(0.5 [, seed])
is_withReplacement_omitted_args = isinstance(withReplacement, float)

if not (
is_withReplacement_set
or is_withReplacement_omitted_kwargs
or is_withReplacement_omitted_args
):
argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
raise PySparkTypeError(
errorClass="NOT_BOOL_OR_FLOAT_OR_INT",
messageParameters={
"arg_name": "withReplacement (optional), "
+ "fraction (required) and seed (optional)",
"arg_type": ", ".join(argtypes),
},
)

if is_withReplacement_omitted_args:
if fraction is not None:
seed = cast(int, fraction)
fraction = withReplacement
withReplacement = None

seed = int(seed) if seed is not None else None
args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
jdf = self._jdf.sample(*args)
_w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed)
jdf = self._jdf.sample(*[_w, _f, _s])
return DataFrame(jdf, self.sparkSession)

def sampleBy(
Expand Down
47 changes: 4 additions & 43 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,53 +781,14 @@ def sample(
fraction: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
) -> ParentDataFrame:
# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
# sample(withReplacement=False, fraction=0.5 [, seed])
is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)

# For the case below:
# sample(faction=0.5 [, seed])
is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)

# For the case below:
# sample(0.5 [, seed])
is_withReplacement_omitted_args = isinstance(withReplacement, float)

if not (
is_withReplacement_set
or is_withReplacement_omitted_kwargs
or is_withReplacement_omitted_args
):
argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
raise PySparkTypeError(
errorClass="NOT_BOOL_OR_FLOAT_OR_INT",
messageParameters={
"arg_name": "withReplacement (optional), "
+ "fraction (required) and seed (optional)",
"arg_type": ", ".join(argtypes),
},
)

if is_withReplacement_omitted_args:
if fraction is not None:
seed = cast(int, fraction)
fraction = withReplacement
withReplacement = None

if withReplacement is None:
withReplacement = False

seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)

_w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed)
res = DataFrame(
plan.Sample(
child=self._plan,
lower_bound=0.0,
upper_bound=fraction, # type: ignore[arg-type]
with_replacement=withReplacement, # type: ignore[arg-type]
seed=seed,
upper_bound=_f,
with_replacement=_w,
seed=_s,
),
session=self._session,
)
Expand Down
42 changes: 42 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# mypy: disable-error-code="empty-body"

import sys
import random
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -2040,6 +2042,46 @@ def sample(
"""
...

def _preapare_args_for_sample(
self,
withReplacement: Optional[Union[float, bool]] = None,
fraction: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
) -> Tuple[bool, float, int]:
from pyspark.errors import PySparkTypeError

if isinstance(withReplacement, bool) and isinstance(fraction, float):
# For the cases below:
# sample(True, 0.5 [, seed])
# sample(True, fraction=0.5 [, seed])
# sample(withReplacement=False, fraction=0.5 [, seed])
_seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
return withReplacement, fraction, _seed

elif withReplacement is None and isinstance(fraction, float):
# For the case below:
# sample(faction=0.5 [, seed])
_seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
return False, fraction, _seed

elif isinstance(withReplacement, float):
# For the case below:
# sample(0.5 [, seed])
_seed = int(fraction) if fraction is not None else random.randint(0, sys.maxsize)
_fraction = float(withReplacement)
return False, _fraction, _seed

else:
argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
raise PySparkTypeError(
errorClass="NOT_BOOL_OR_FLOAT_OR_INT",
messageParameters={
"arg_name": "withReplacement (optional), "
+ "fraction (required) and seed (optional)",
"arg_type": ", ".join(argtypes),
},
)

@dispatch_df_method
def sampleBy(
self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
Expand Down