From 9534cc633bbc709f20a83adaeb0808bfa2d4f7ea Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 27 Sep 2024 12:18:55 +0800
Subject: [PATCH] nit fix fix fix

---
 python/pyspark/sql/classic/dataframe.py | 40 ++-------
 python/pyspark/sql/connect/dataframe.py | 47 +++----------------
 python/pyspark/sql/dataframe.py         | 42 ++++++++++++++++++++++
 3 files changed, 48 insertions(+), 81 deletions(-)

diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 9f9dedbd38207..e412b98c47de5 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -601,44 +601,8 @@ def sample(  # type: ignore[misc]
         fraction: Optional[Union[int, float]] = None,
         seed: Optional[int] = None,
     ) -> ParentDataFrame:
-        # For the cases below:
-        # sample(True, 0.5 [, seed])
-        # sample(True, fraction=0.5 [, seed])
-        # sample(withReplacement=False, fraction=0.5 [, seed])
-        is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)
-
-        # For the case below:
-        # sample(faction=0.5 [, seed])
-        is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
-
-        # For the case below:
-        # sample(0.5 [, seed])
-        is_withReplacement_omitted_args = isinstance(withReplacement, float)
-
-        if not (
-            is_withReplacement_set
-            or is_withReplacement_omitted_kwargs
-            or is_withReplacement_omitted_args
-        ):
-            argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
-            raise PySparkTypeError(
-                errorClass="NOT_BOOL_OR_FLOAT_OR_INT",
-                messageParameters={
-                    "arg_name": "withReplacement (optional), "
-                    + "fraction (required) and seed (optional)",
-                    "arg_type": ", ".join(argtypes),
-                },
-            )
-
-        if is_withReplacement_omitted_args:
-            if fraction is not None:
-                seed = cast(int, fraction)
-            fraction = withReplacement
-            withReplacement = None
-
-        seed = int(seed) if seed is not None else None
-        args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
-        jdf = self._jdf.sample(*args)
+        _w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed)
+        jdf = self._jdf.sample(*[_w, _f, _s])
         return DataFrame(jdf, self.sparkSession)
 
     def sampleBy(
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 136fe60532df4..bb4dcb38c9e58 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -781,53 +781,14 @@ def sample(
         fraction: Optional[Union[int, float]] = None,
         seed: Optional[int] = None,
     ) -> ParentDataFrame:
-        # For the cases below:
-        # sample(True, 0.5 [, seed])
-        # sample(True, fraction=0.5 [, seed])
-        # sample(withReplacement=False, fraction=0.5 [, seed])
-        is_withReplacement_set = type(withReplacement) == bool and isinstance(fraction, float)
-
-        # For the case below:
-        # sample(faction=0.5 [, seed])
-        is_withReplacement_omitted_kwargs = withReplacement is None and isinstance(fraction, float)
-
-        # For the case below:
-        # sample(0.5 [, seed])
-        is_withReplacement_omitted_args = isinstance(withReplacement, float)
-
-        if not (
-            is_withReplacement_set
-            or is_withReplacement_omitted_kwargs
-            or is_withReplacement_omitted_args
-        ):
-            argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
-            raise PySparkTypeError(
-                errorClass="NOT_BOOL_OR_FLOAT_OR_INT",
-                messageParameters={
-                    "arg_name": "withReplacement (optional), "
-                    + "fraction (required) and seed (optional)",
-                    "arg_type": ", ".join(argtypes),
-                },
-            )
-
-        if is_withReplacement_omitted_args:
-            if fraction is not None:
-                seed = cast(int, fraction)
-            fraction = withReplacement
-            withReplacement = None
-
-        if withReplacement is None:
-            withReplacement = False
-
-        seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
-
+        _w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed)
         res = DataFrame(
             plan.Sample(
                 child=self._plan,
                 lower_bound=0.0,
-                upper_bound=fraction,  # type: ignore[arg-type]
-                with_replacement=withReplacement,  # type: ignore[arg-type]
-                seed=seed,
+                upper_bound=_f,
+                with_replacement=_w,
+                seed=_s,
             ),
             session=self._session,
         )
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5906108163b46..c21e2271a64ac 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -17,6 +17,8 @@
 
 # mypy: disable-error-code="empty-body"
 
+import sys
+import random
 from typing import (
     Any,
     Callable,
@@ -2040,6 +2042,46 @@ def sample(
         """
         ...
 
+    def _preapare_args_for_sample(
+        self,
+        withReplacement: Optional[Union[float, bool]] = None,
+        fraction: Optional[Union[int, float]] = None,
+        seed: Optional[int] = None,
+    ) -> Tuple[bool, float, int]:
+        from pyspark.errors import PySparkTypeError
+
+        if isinstance(withReplacement, bool) and isinstance(fraction, float):
+            # For the cases below:
+            # sample(True, 0.5 [, seed])
+            # sample(True, fraction=0.5 [, seed])
+            # sample(withReplacement=False, fraction=0.5 [, seed])
+            _seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
+            return withReplacement, fraction, _seed
+
+        elif withReplacement is None and isinstance(fraction, float):
+            # For the case below:
+            # sample(faction=0.5 [, seed])
+            _seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
+            return False, fraction, _seed
+
+        elif isinstance(withReplacement, float):
+            # For the case below:
+            # sample(0.5 [, seed])
+            _seed = int(fraction) if fraction is not None else random.randint(0, sys.maxsize)
+            _fraction = float(withReplacement)
+            return False, _fraction, _seed
+
+        else:
+            argtypes = [type(arg).__name__ for arg in [withReplacement, fraction, seed]]
+            raise PySparkTypeError(
+                errorClass="NOT_BOOL_OR_FLOAT_OR_INT",
+                messageParameters={
+                    "arg_name": "withReplacement (optional), "
+                    + "fraction (required) and seed (optional)",
+                    "arg_type": ", ".join(argtypes),
+                },
+            )
+
     @dispatch_df_method
     def sampleBy(
         self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
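
For reference, a minimal usage sketch of the three call shapes that the shared
_preapare_args_for_sample helper normalizes. The local SparkSession and the toy
DataFrame below are illustrative assumptions, not taken from the patch itself.

    # Illustrative only: every accepted shape of DataFrame.sample() now goes
    # through the same argument-normalization helper.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").appName("sample-demo").getOrCreate()
    df = spark.range(100)

    a = df.sample(True, 0.5, 42)          # sample(withReplacement, fraction [, seed])
    b = df.sample(fraction=0.5, seed=42)  # sample(fraction=... [, seed=...]); withReplacement defaults to False
    c = df.sample(0.5, 42)                # sample(fraction [, seed]); a leading float is taken as the fraction

    print(a.count(), b.count(), c.count())  # roughly 50 rows each

    # Any other combination still raises PySparkTypeError (NOT_BOOL_OR_FLOAT_OR_INT), e.g.:
    # df.sample("0.5")

    spark.stop()

One behavior change visible in the diff: the classic path previously omitted the seed
when none was given and let the JVM pick one, whereas the shared helper now always
forwards an explicit, randomly generated seed, matching what the Connect path already did.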