[SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect #41670

Closed · wants to merge 5 commits
@@ -1768,6 +1768,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
         Some(EWM(children(0), alpha, ignoreNA))

+      case "last_non_null" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(LastNonNull(children(0)))
+
+      case "null_index" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(NullIndex(children(0)))
+
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
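With these two cases in place, the Connect server can plan the pandas-internal helpers by name. A minimal usage sketch, assuming a reachable Connect server and that the follow-up SPARK-43611 client issue (see the still-skipped parity test below) is resolved; the `sc://` address and the sample values are illustrative:

# Illustrative end-to-end check under Spark Connect.
import pyspark.pandas as ps
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

psser = ps.Series([1.0, None, 3.0, None, None, 6.0])
# interpolate() now reaches the server as last_non_null / null_index calls,
# which the two planner cases above turn into LastNonNull / NullIndex.
print(psser.interpolate())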
9 changes: 4 additions & 5 deletions python/pyspark/pandas/series.py
@@ -53,7 +53,6 @@
     CategoricalDtype,
 )
 from pandas.tseries.frequencies import DateOffset
-from pyspark import SparkContext
 from pyspark.sql import functions as F, Column as PySparkColumn, DataFrame as SparkDataFrame
 from pyspark.sql.types import (
     ArrayType,
@@ -70,7 +69,7 @@
     TimestampType,
 )
 from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class
+from pyspark.sql.utils import get_column_class, get_window_class

 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
@@ -2257,10 +2256,10 @@ def _interpolate(
             return self._psdf.copy()._psser_for(self._column_label)

         scol = self.spark.column
-        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
-        last_non_null = PySparkColumn(sql_utils.lastNonNull(scol._jc))
-        null_index = PySparkColumn(sql_utils.nullIndex(scol._jc))
+        last_non_null = SF.last_non_null(scol)
+        null_index = SF.null_index(scol)

+        Window = get_window_class()
         window_forward = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
             Window.unboundedPreceding, Window.currentRow
         )
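The rewrite swaps the direct Py4J calls for the `SF` helpers and picks the Window class at runtime, so the same code path works in both classic and Connect mode. For intuition only, the forward-fill half can be roughly emulated with stock PySpark (the dedicated LastNonNull expression is what the helper actually plans; the DataFrame and column names below are illustrative):

# Rough emulation of the forward-fill step with stock PySpark APIs.
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    [(0, 1.0), (1, None), (2, None), (3, 4.0)], ["ord", "v"]
)

w_fwd = Window.orderBy("ord").rowsBetween(Window.unboundedPreceding, Window.currentRow)
# last(..., ignorenulls=True) over the forward window carries the latest
# non-null value down through a run of nulls.
sdf.withColumn("last_non_null", F.last("v", ignorenulls=True).over(w_fwd)).show()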
28 changes: 28 additions & 0 deletions python/pyspark/pandas/spark/functions.py
@@ -157,3 +157,31 @@ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
     else:
         sc = SparkContext._active_spark_context
         return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na))
+
+
+def last_non_null(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "last_non_null",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(  # type: ignore[return-value]
+            "null_index",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
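Both helpers follow the dual-dispatch idiom used throughout the pandas-on-Spark internals: probe `is_remote()` per call, then either ship an unresolved function by name over Connect or call into the JVM via Py4J. A sketch of the same shape for a hypothetical helper `my_func` (neither `PythonSQLUtils.myFunc` nor a matching planner case exists; this only shows the pattern):

# Sketch of the is_remote() dispatch idiom; "my_func"/"myFunc" are hypothetical.
from pyspark import SparkContext
from pyspark.sql.column import Column
from pyspark.sql.utils import is_remote


def my_func(col: Column) -> Column:
    if is_remote():
        # Spark Connect path: send an unresolved function by name and let
        # SparkConnectPlanner resolve it server-side.
        from pyspark.sql.connect.functions import _invoke_function_over_columns

        return _invoke_function_over_columns("my_func", col)  # type: ignore
    else:
        # Classic path: call the JVM helper directly over Py4J.
        sc = SparkContext._active_spark_context
        return Column(sc._jvm.PythonSQLUtils.myFunc(col._jc))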
@@ -24,7 +24,9 @@
 class GenericFunctionsParityTests(
     GenericFunctionsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43631): Enable Series.interpolate with Spark Connect.")
+    @unittest.skip(
+        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+    )
     def test_interpolate(self):
         super().test_interpolate()

14 changes: 13 additions & 1 deletion python/pyspark/sql/utils.py
@@ -46,6 +46,7 @@
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.column import Column
+    from pyspark.sql.window import Window
     from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex

 has_numpy = False
@@ -188,7 +189,7 @@ def try_remote_window(f: FuncT) -> FuncT:
     def wrapped(*args: Any, **kwargs: Any) -> Any:

         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            from pyspark.sql.connect.window import Window
+            from pyspark.sql.connect.window import Window  # type: ignore[misc]

             return getattr(Window, f.__name__)(*args, **kwargs)
         else:
@@ -282,3 +283,14 @@ def get_dataframe_class() -> Type["DataFrame"]:
         return ConnectDataFrame  # type: ignore[return-value]
     else:
         return PySparkDataFrame
+
+
+def get_window_class() -> Type["Window"]:
+    from pyspark.sql.window import Window as PySparkWindow
+
+    if is_remote():
+        from pyspark.sql.connect.window import Window as ConnectWindow
+
+        return ConnectWindow  # type: ignore[return-value]
+    else:
+        return PySparkWindow
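`get_window_class` mirrors the existing `get_column_class`/`get_dataframe_class` accessors: code that builds window specs eagerly (rather than through the `try_remote_window` decorator) fetches the right class at call time, as the `_interpolate` change above does. A minimal sketch; the column name "ord" is illustrative:

# Picking the right Window implementation at call time.
from pyspark.sql.utils import get_window_class

Window = get_window_class()  # Connect or classic, depending on is_remote()
w_fwd = Window.orderBy("ord").rowsBetween(Window.unboundedPreceding, Window.currentRow)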