Refine type hints in pyspark.pandas.window
ueshin committed Jun 26, 2021
1 parent 939ea3d commit 53febc2
Showing 5 changed files with 166 additions and 151 deletions.
14 changes: 14 additions & 0 deletions python/pyspark/pandas/frame.py
@@ -11751,6 +11751,20 @@ def from_dict(
         """
         return DataFrame(pd.DataFrame.from_dict(data, orient=orient, dtype=dtype, columns=columns))
 
+    # Override the `groupby` to specify the actual return type annotation.
+    def groupby(
+        self,
+        by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+        axis: Union[int, str] = 0,
+        as_index: bool = True,
+        dropna: bool = True,
+    ) -> "DataFrameGroupBy":
+        return cast(
+            "DataFrameGroupBy", super().groupby(by=by, axis=axis, as_index=as_index, dropna=dropna)
+        )
+
+    groupby.__doc__ = Frame.groupby.__doc__
+
     def _build_groupby(
         self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
     ) -> "DataFrameGroupBy":
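
The override exists purely for static typing: the base `Frame.groupby` is annotated with the generic `GroupBy` return type, and the subclass re-declares the signature so type checkers see `DataFrameGroupBy` with no runtime change. A minimal, self-contained sketch of this override-and-cast pattern, using illustrative names rather than pyspark's classes:

from typing import cast


class BaseResult:
    pass


class NarrowResult(BaseResult):
    pass


class Base:
    def compute(self) -> BaseResult:
        """Compute and return a result object."""
        return NarrowResult()  # the runtime value is already the subclass


class Narrow(Base):
    # Re-declare only to narrow the static return type; behavior is unchanged.
    def compute(self) -> NarrowResult:
        return cast(NarrowResult, super().compute())

    compute.__doc__ = Base.compute.__doc__  # reuse the parent's docstring


print(type(Narrow().compute()).__name__)  # NarrowResult

A type checker now infers `NarrowResult` for `Narrow().compute()`, so subclass-only attributes are usable without casts at the call site; the `__doc__` assignment mirrors the `groupby.__doc__ = Frame.groupby.__doc__` line above.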
18 changes: 11 additions & 7 deletions python/pyspark/pandas/generic.py
@@ -67,13 +67,13 @@
     validate_axis,
     SPARK_CONF_ARROW_ENABLED,
 )
-from pyspark.pandas.window import Rolling, Expanding
 
 if TYPE_CHECKING:
     from pyspark.pandas.frame import DataFrame  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.indexes.base import Index  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.groupby import GroupBy  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.window import Rolling, Expanding  # noqa: F401 (SPARK-34943)
 
 
 T_Frame = TypeVar("T_Frame", bound="Frame")
@@ -2508,7 +2508,9 @@ def last_valid_index(self) -> Optional[Union[Scalar, Tuple[Scalar, ...]]]:
             return tuple(last_valid_row)
 
     # TODO: 'center', 'win_type', 'on', 'axis' parameter should be implemented.
-    def rolling(self, window: int, min_periods: Optional[int] = None) -> Rolling:
+    def rolling(
+        self: T_Frame, window: int, min_periods: Optional[int] = None
+    ) -> "Rolling[T_Frame]":
         """
         Provide rolling transformations.
 
@@ -2533,13 +2535,13 @@ def rolling(self, window: int, min_periods: Optional[int] = None) -> Rolling:
         -------
         a Window sub-classed for the particular operation
         """
-        return Rolling(
-            cast(Union["Series", "DataFrame"], self), window=window, min_periods=min_periods
-        )
+        from pyspark.pandas.window import Rolling
+
+        return Rolling(self, window=window, min_periods=min_periods)
 
     # TODO: 'center' and 'axis' parameter should be implemented.
     #  'axis' implementation, refer https://github.com/pyspark.pandas/pull/607
-    def expanding(self, min_periods: int = 1) -> Expanding:
+    def expanding(self: T_Frame, min_periods: int = 1) -> "Expanding[T_Frame]":
         """
         Provide expanding transformations.
 
@@ -2557,7 +2559,9 @@ def expanding(self, min_periods: int = 1) -> Expanding:
         -------
         a Window sub-classed for the particular operation
         """
-        return Expanding(cast(Union["Series", "DataFrame"], self), min_periods=min_periods)
+        from pyspark.pandas.window import Expanding
+
+        return Expanding(self, min_periods=min_periods)
 
     def get(self, key: Any, default: Optional[Any] = None) -> Any:
         """
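
Two ideas carry this hunk: annotating `self` with the module's existing `T_Frame` TypeVar makes `Rolling`/`Expanding` generic over the calling frame type, and moving the window import under `TYPE_CHECKING` plus into the method body removes the module-level dependency (which also makes the earlier `cast` calls unnecessary). A self-contained sketch of the `self: T` technique, with illustrative names that are not pyspark's:

from typing import Generic, TypeVar

T = TypeVar("T", bound="Frame")


class Frame:
    def rolling(self: T, window: int) -> "Rolling[T]":
        # Annotating `self` with the TypeVar ties the window object's
        # type parameter to the caller's concrete class.
        return Rolling(self, window)


class DataFrame(Frame):
    pass


class Series(Frame):
    pass


class Rolling(Generic[T]):
    def __init__(self, obj: T, window: int) -> None:
        self._obj = obj
        self._window = window


# A type checker now infers:
#   DataFrame().rolling(3) -> Rolling[DataFrame]
#   Series().rolling(2)    -> Rolling[Series]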
22 changes: 13 additions & 9 deletions python/pyspark/pandas/groupby.py
@@ -41,6 +41,7 @@
     TypeVar,
     Union,
     cast,
+    TYPE_CHECKING,
 )
 
 import pandas as pd
@@ -85,9 +86,12 @@
     verify_temp_column_name,
 )
 from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
-from pyspark.pandas.window import RollingGroupby, ExpandingGroupby
 from pyspark.pandas.exceptions import DataError
 
+if TYPE_CHECKING:
+    from pyspark.pandas.window import RollingGroupby, ExpandingGroupby  # noqa: F401 (SPARK-34943)
+
+
 # to keep it the same as pandas
 NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
 
@@ -2320,7 +2324,7 @@ def nunique(self, dropna: bool = True) -> T_Frame:
 
         return self._reduce_for_stat_function(stat_function, only_numeric=False)
 
-    def rolling(self, window: int, min_periods: Optional[int] = None) -> RollingGroupby:
+    def rolling(self, window: int, min_periods: Optional[int] = None) -> "RollingGroupby[T_Frame]":
         """
         Return an rolling grouper, providing rolling
         functionality per group.
@@ -2345,11 +2349,11 @@ def rolling(self, window: int, min_periods: Optional[int] = None) -> RollingGroupby:
         Series.groupby
         DataFrame.groupby
         """
-        return RollingGroupby(
-            cast(Union[SeriesGroupBy, DataFrameGroupBy], self), window, min_periods=min_periods
-        )
+        from pyspark.pandas.window import RollingGroupby
 
-    def expanding(self, min_periods: int = 1) -> ExpandingGroupby:
+        return RollingGroupby(self, window, min_periods=min_periods)
+
+    def expanding(self, min_periods: int = 1) -> "ExpandingGroupby[T_Frame]":
         """
         Return an expanding grouper, providing expanding
         functionality per group.
@@ -2369,9 +2373,9 @@ def expanding(self, min_periods: int = 1) -> ExpandingGroupby:
         Series.groupby
         DataFrame.groupby
        """
-        return ExpandingGroupby(
-            cast(Union[SeriesGroupBy, DataFrameGroupBy], self), min_periods=min_periods
-        )
+        from pyspark.pandas.window import ExpandingGroupby
+
+        return ExpandingGroupby(self, min_periods=min_periods)
 
     def get_group(self, name: Union[Any, Tuple, List[Union[Any, Tuple]]]) -> T_Frame:
         """
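
The import shuffle here breaks a circular dependency: groupby.py and window.py each need names from the other, so the window import is kept out of module load time, visible to type checkers under `TYPE_CHECKING` and resolved by the interpreter through a function-local import. A sketch of that pattern across two hypothetical modules (the file names are assumptions, not pyspark's layout):

# --- helper_user.py ---
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so no cycle at runtime.
    from helper_mod import Helper  # noqa: F401


def make_helper() -> "Helper":
    # Deferred to call time, when both modules are fully initialized.
    from helper_mod import Helper

    return Helper()


# --- helper_mod.py ---
import helper_user  # safe: helper_user no longer imports this module on load


class Helper:
    pass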
14 changes: 14 additions & 0 deletions python/pyspark/pandas/series.py
@@ -6216,6 +6216,20 @@ def _reduce_for_stat_function(
         result = unpack_scalar(self._internal.spark_frame.select(scol))
         return result if result is not None else np.nan
 
+    # Override the `groupby` to specify the actual return type annotation.
+    def groupby(
+        self,
+        by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+        axis: Union[int, str_type] = 0,
+        as_index: bool = True,
+        dropna: bool = True,
+    ) -> "SeriesGroupBy":
+        return cast(
+            "SeriesGroupBy", super().groupby(by=by, axis=axis, as_index=as_index, dropna=dropna)
+        )
+
+    groupby.__doc__ = Frame.groupby.__doc__
+
     def _build_groupby(
         self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
     ) -> "SeriesGroupBy":
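
Taken together with the frame.py override, the user-visible effect is that chained calls type-check end to end. A hypothetical session illustrating what the refined hints let a checker infer; it assumes a working pyspark.pandas environment, and the commit changes annotations only, not runtime behavior:

import pyspark.pandas as ps

psdf = ps.DataFrame({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]})

# groupby() is now typed as DataFrameGroupBy / SeriesGroupBy, so grouped
# rolling/expanding calls resolve to the generic window types.
grouped_rolling = psdf.groupby("a").rolling(2)  # RollingGroupby[DataFrame]
series_rolling = psdf["b"].rolling(2)           # Rolling[Series]

print(grouped_rolling.sum())
print(series_rolling.sum())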
