[SPARK-48504][PYTHON][CONNECT] Parent Window class for Spark Connect and Spark Classic

### What changes were proposed in this pull request?
 Parent Window class for Spark Connect and Spark Classic

### Why are the changes needed?
Same as #46129

### Does this PR introduce _any_ user-facing change?
Same as #46129

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
NO

Closes #46841 from zhengruifeng/py_parent_window.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
zhengruifeng authored and HyukjinKwon committed Jun 3, 2024
1 parent 8d9d9c2 commit 67ad55c
Showing 7 changed files with 233 additions and 162 deletions.
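
For orientation, here is a minimal sketch (an assumption mirroring the parent-class pattern of #46129, not the literal code in pyspark/sql/window.py) of how a shared parent Window can route calls to the Classic or Connect subclass. pyspark.sql.utils.is_remote is a real helper; the routing body below is illustrative only.

from typing import List, Union


class Window:
    @staticmethod
    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
        from pyspark.sql.utils import is_remote

        if is_remote():
            # Spark Connect: pure-Python, expression-based implementation.
            from pyspark.sql.connect.window import Window as ConnectWindow

            return ConnectWindow.partitionBy(*cols)
        else:
            # Spark Classic: Py4J-backed JVM implementation.
            from pyspark.sql.classic.window import Window as ClassicWindow

            return ClassicWindow.partitionBy(*cols)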
2 changes: 1 addition & 1 deletion dev/sparktestsupport/modules.py
@@ -478,6 +478,7 @@ def __hash__(self):
"pyspark.sql.catalog",
"pyspark.sql.classic.column",
"pyspark.sql.classic.dataframe",
"pyspark.sql.classic.window",
"pyspark.sql.datasource",
"pyspark.sql.group",
"pyspark.sql.functions.builtin",
@@ -488,7 +489,6 @@ def __hash__(self):
"pyspark.sql.streaming.listener",
"pyspark.sql.udf",
"pyspark.sql.udtf",
"pyspark.sql.window",
"pyspark.sql.avro.functions",
"pyspark.sql.protobuf.functions",
"pyspark.sql.pandas.conversion",
2 changes: 1 addition & 1 deletion python/pyspark/sql/classic/column.py
@@ -588,7 +588,7 @@ def otherwise(self, value: Any) -> ParentColumn:
return Column(jc)

def over(self, window: "WindowSpec") -> ParentColumn:
from pyspark.sql.window import WindowSpec
from pyspark.sql.classic.window import WindowSpec

if not isinstance(window, WindowSpec):
raise PySparkTypeError(
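
The over change above only swaps the import to the Classic WindowSpec; user-facing code is unchanged. A usage sketch of the path it guards (standard PySpark API; a df with dept and salary columns is assumed):

from pyspark.sql import Window
from pyspark.sql import functions as F

# Rank rows within each department by descending salary. The same code
# runs on Classic and Connect since both specs share the parent class.
w = Window.partitionBy("dept").orderBy(F.col("salary").desc())
df_ranked = df.withColumn("rank", F.row_number().over(w))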
146 changes: 146 additions & 0 deletions python/pyspark/sql/classic/window.py
@@ -0,0 +1,146 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union

from pyspark.sql.window import (
Window as ParentWindow,
WindowSpec as ParentWindowSpec,
)
from pyspark.sql.utils import get_active_spark_context

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
from pyspark.sql._typing import ColumnOrName, ColumnOrName_


__all__ = ["Window", "WindowSpec"]


def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> "JavaObject":
from pyspark.sql.classic.column import _to_seq, _to_java_column

if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0] # type: ignore[assignment]
sc = get_active_spark_context()
return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column)


class Window(ParentWindow):
@staticmethod
def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
from py4j.java_gateway import JVMView

sc = get_active_spark_context()
jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.partitionBy(
_to_java_cols(cols)
)
return WindowSpec(jspec)

@staticmethod
def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
from py4j.java_gateway import JVMView

sc = get_active_spark_context()
jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.orderBy(
_to_java_cols(cols)
)
return WindowSpec(jspec)

@staticmethod
def rowsBetween(start: int, end: int) -> ParentWindowSpec:
from py4j.java_gateway import JVMView

if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
end = Window.unboundedFollowing
sc = get_active_spark_context()
jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rowsBetween(
start, end
)
return WindowSpec(jspec)

@staticmethod
def rangeBetween(start: int, end: int) -> ParentWindowSpec:
from py4j.java_gateway import JVMView

if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
end = Window.unboundedFollowing
sc = get_active_spark_context()
jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.rangeBetween(
start, end
)
return WindowSpec(jspec)


class WindowSpec(ParentWindowSpec):
def __new__(cls, jspec: "JavaObject") -> "WindowSpec":
self = object.__new__(cls)
self.__init__(jspec) # type: ignore[misc]
return self

def __init__(self, jspec: "JavaObject") -> None:
self._jspec = jspec

def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))

def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))

def rowsBetween(self, start: int, end: int) -> ParentWindowSpec:
if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
end = Window.unboundedFollowing
return WindowSpec(self._jspec.rowsBetween(start, end))

def rangeBetween(self, start: int, end: int) -> ParentWindowSpec:
if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
end = Window.unboundedFollowing
return WindowSpec(self._jspec.rangeBetween(start, end))


def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.window

# It inherits docstrings but doctests cannot detect them, so we run
# the parent class's doctests here directly.
globs = pyspark.sql.window.__dict__.copy()
spark = (
SparkSession.builder.master("local[4]").appName("sql.classic.window tests").getOrCreate()
)
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.window,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
spark.stop()
if failure_count:
sys.exit(-1)


if __name__ == "__main__":
_test()
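
A quick self-contained exercise of the Classic path above (assuming a local Spark installation); note that rowsBetween clamps any start at or below _PRECEDING_THRESHOLD to unboundedPreceding, so extreme bounds behave predictably:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["key", "value"])

# Running sum per key, from the partition start to the current row.
w = (
    Window.partitionBy("key")
    .orderBy("value")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df.withColumn("running_sum", F.sum("value").over(w)).show()
spark.stop()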
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/_typing.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from types import FunctionType
from typing import Any, Callable, Iterable, Union, Optional, NewType, Protocol, Tuple
from typing import Any, Callable, Iterable, Union, Optional, NewType, Protocol, Tuple, TypeVar
import datetime
import decimal

@@ -28,6 +28,7 @@


ColumnOrName = Union[Column, str]
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)

ColumnOrNameOrOrdinal = Union[Column, str, int]

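
The new bound TypeVar exists because List is invariant in Python typing: a List[str] argument does not match List[Union[Column, str]], but it does bind a TypeVar bounded by that union. An illustrative snippet (names local to this sketch):

from typing import List, TypeVar, Union

class Column: ...

ColumnOrName = Union[Column, str]
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)

def partitionBy(*cols: Union[ColumnOrName, List[ColumnOrName_]]) -> None: ...

partitionBy("a", Column())  # individual names or columns
partitionBy(["a", "b"])     # List[str] binds ColumnOrName_ to str
partitionBy([Column()])     # List[Column] also type-checks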
58 changes: 29 additions & 29 deletions python/pyspark/sql/connect/window.py
@@ -18,24 +18,23 @@

check_dependencies(__name__)

import sys
from typing import TYPE_CHECKING, Union, Sequence, List, Optional

from pyspark.sql.column import Column
from pyspark.sql.window import (
Window as ParentWindow,
WindowSpec as ParentWindowSpec,
)
from pyspark.sql.connect.expressions import (
ColumnReference,
Expression,
SortOrder,
)
from pyspark.util import (
JVM_LONG_MIN,
JVM_LONG_MAX,
)
from pyspark.sql.window import Window as PySparkWindow, WindowSpec as PySparkWindowSpec
from pyspark.errors import PySparkTypeError

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
from pyspark.sql.connect._typing import ColumnOrName, ColumnOrName_

__all__ = ["Window", "WindowSpec"]

@@ -63,7 +62,17 @@ def __repr__(self) -> str:
return f"WindowFrame(RANGE_FRAME, {self._start}, {self._end})"


class WindowSpec:
class WindowSpec(ParentWindowSpec):
def __new__(
cls,
partitionSpec: Sequence[Expression],
orderSpec: Sequence[SortOrder],
frame: Optional[WindowFrame],
) -> "WindowSpec":
self = object.__new__(cls)
self.__init__(partitionSpec, orderSpec, frame) # type: ignore[misc]
return self

def __init__(
self,
partitionSpec: Sequence[Expression],
@@ -84,7 +93,7 @@ def __init__(

self._frame = frame

def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
_cols: List[ColumnOrName] = []
for col in cols:
if isinstance(col, (str, Column)):
@@ -105,19 +114,19 @@ def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
)

newPartitionSpec: List[Expression] = []
for c in _cols:
for c in _cols: # type: ignore[assignment]
if isinstance(c, Column):
newPartitionSpec.append(c._expr) # type: ignore[arg-type]
else:
newPartitionSpec.append(ColumnReference(c))
newPartitionSpec.append(ColumnReference(c)) # type: ignore[arg-type]

return WindowSpec(
partitionSpec=newPartitionSpec,
orderSpec=self._orderSpec,
frame=self._frame,
)

def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
_cols: List[ColumnOrName] = []
for col in cols:
if isinstance(col, (str, Column)):
@@ -138,22 +147,22 @@ def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
)

newOrderSpec: List[SortOrder] = []
for c in _cols:
for c in _cols: # type: ignore[assignment]
if isinstance(c, Column):
if isinstance(c._expr, SortOrder):
newOrderSpec.append(c._expr)
else:
newOrderSpec.append(SortOrder(c._expr)) # type: ignore[arg-type]
else:
newOrderSpec.append(SortOrder(ColumnReference(c)))
newOrderSpec.append(SortOrder(ColumnReference(c))) # type: ignore[arg-type]

return WindowSpec(
partitionSpec=self._partitionSpec,
orderSpec=newOrderSpec,
frame=self._frame,
)

def rowsBetween(self, start: int, end: int) -> "WindowSpec":
def rowsBetween(self, start: int, end: int) -> ParentWindowSpec:
if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
@@ -165,7 +174,7 @@ def rowsBetween(self, start: int, end: int) -> "WindowSpec":
frame=WindowFrame(isRowFrame=True, start=start, end=end),
)

def rangeBetween(self, start: int, end: int) -> "WindowSpec":
def rangeBetween(self, start: int, end: int) -> ParentWindowSpec:
if start <= Window._PRECEDING_THRESHOLD:
start = Window.unboundedPreceding
if end >= Window._FOLLOWING_THRESHOLD:
@@ -197,32 +206,23 @@ def __repr__(self) -> str:
WindowSpec.__doc__ = PySparkWindowSpec.__doc__


class Window:
_PRECEDING_THRESHOLD = max(-sys.maxsize, JVM_LONG_MIN)
_FOLLOWING_THRESHOLD = min(sys.maxsize, JVM_LONG_MAX)

unboundedPreceding: int = JVM_LONG_MIN

unboundedFollowing: int = JVM_LONG_MAX

currentRow: int = 0

class Window(ParentWindow):
_spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)

@staticmethod
def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
return Window._spec.partitionBy(*cols)

@staticmethod
def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName"]]) -> "WindowSpec":
def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
return Window._spec.orderBy(*cols)

@staticmethod
def rowsBetween(start: int, end: int) -> "WindowSpec":
def rowsBetween(start: int, end: int) -> ParentWindowSpec:
return Window._spec.rowsBetween(start, end)

@staticmethod
def rangeBetween(start: int, end: int) -> "WindowSpec":
def rangeBetween(start: int, end: int) -> ParentWindowSpec:
return Window._spec.rangeBetween(start, end)


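
Both WindowSpec subclasses in this commit override __new__ with object.__new__ plus an explicit __init__ call. The parent's __new__ is not shown in this diff, so the following is a plausible reading rather than the committed parent code: if constructing the parent dispatches to the active backend's subclass, the explicit __init__ keeps the instance fully initialized even when it is created through that delegation path.

# Hypothetical sketch; names are local to this sketch and the real
# parent class in pyspark/sql/window.py may differ.
class ParentWindowSpec:
    def __new__(cls, jspec):
        if cls is ParentWindowSpec:
            # Constructing the parent yields the backend-specific
            # subclass (Classic shown here for concreteness).
            return ClassicWindowSpec(jspec)
        return object.__new__(cls)


class ClassicWindowSpec(ParentWindowSpec):
    def __new__(cls, jspec):
        self = object.__new__(cls)
        # Initialize explicitly so the instance is complete even when
        # produced inside ParentWindowSpec.__new__.
        self.__init__(jspec)
        return self

    def __init__(self, jspec):
        self._jspec = jspec


spec = ParentWindowSpec("fake-jvm-spec")
assert type(spec) is ClassicWindowSpec and spec._jspec == "fake-jvm-spec"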