From 72eab9e897c8249ef45f7804083177aa5c29b505 Mon Sep 17 00:00:00 2001
From: itholic
Date: Tue, 20 Jun 2023 15:29:20 +0900
Subject: [PATCH 1/3] [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect

---
 .../connect/planner/SparkConnectPlanner.scala |  8 ++
 python/pyspark/pandas/series.py               | 85 +++++++++++++------
 .../connect/test_parity_generic_functions.py  |  4 +-
 python/pyspark/sql/utils.py                   | 14 ++-
 4 files changed, 83 insertions(+), 28 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc819fb4020e9..92d2598b61ffb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1688,6 +1688,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
         Some(EWM(children(0), alpha, ignoreNA))
 
+      case "last_non_null" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(LastNonNull(children(0)))
+
+      case "null_index" if fun.getArgumentsCount == 1 =>
+        val children = fun.getArgumentsList.asScala.map(transformExpression)
+        Some(NullIndex(children(0)))
+
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 0f1e814946ac4..f78ef0bbb011d 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -70,7 +70,7 @@
     TimestampType,
 )
 from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class
+from pyspark.sql.utils import get_column_class, is_remote, get_window_class
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
@@ -2257,27 +2257,44 @@ def _interpolate(
             return self._psdf.copy()._psser_for(self._column_label)
 
         scol = self.spark.column
-        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
-        last_non_null = PySparkColumn(sql_utils.lastNonNull(scol._jc))
-        null_index = PySparkColumn(sql_utils.nullIndex(scol._jc))
+        if is_remote():
+            from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+            last_non_null = _invoke_function_over_columns(
+                "last_non_null",
+                scol,  # type: ignore[arg-type]
+            )
+            null_index = _invoke_function_over_columns(
+                "null_index",
+                scol,  # type: ignore[arg-type]
+            )
+        else:
+            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+            last_non_null = PySparkColumn(
+                sql_utils.lastNonNull(scol._jc)  # type: ignore[assignment]
+            )
+            null_index = PySparkColumn(sql_utils.nullIndex(scol._jc))  # type: ignore[assignment]
 
+        Window = get_window_class()
         window_forward = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
             Window.unboundedPreceding, Window.currentRow
         )
-        last_non_null_forward = last_non_null.over(window_forward)
-        null_index_forward = null_index.over(window_forward)
+        last_non_null_forward = last_non_null.over(window_forward)  # type: ignore[arg-type]
+        null_index_forward = null_index.over(window_forward)  # type: ignore[arg-type]
 
         window_backward = Window.orderBy(F.desc(NATURAL_ORDER_COLUMN_NAME)).rowsBetween(
             Window.unboundedPreceding, Window.currentRow
         )
-        last_non_null_backward = last_non_null.over(window_backward)
-        null_index_backward = null_index.over(window_backward)
+        last_non_null_backward = last_non_null.over(window_backward)  # type: ignore[arg-type]
+        null_index_backward = null_index.over(window_backward)  # type: ignore[arg-type]
 
         fill = (last_non_null_backward - last_non_null_forward) / (
             null_index_backward + null_index_forward
         ) * null_index_forward + last_non_null_forward
-        fill_cond = ~F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
+        fill_cond = ~F.isnull(last_non_null_backward) & ~F.isnull(  # type: ignore[arg-type]
+            last_non_null_forward  # type: ignore[arg-type]
+        )
 
         pad_head = F.lit(None)
         pad_head_cond = F.lit(False)
@@ -2287,35 +2304,55 @@ def _interpolate(
         # inputs -> NaN, NaN, 1.0, NaN, NaN, NaN, 5.0, NaN, NaN
         if limit_direction is None or limit_direction == "forward":
             # outputs -> NaN, NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
-            pad_tail = last_non_null_forward
-            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
+            pad_tail = last_non_null_forward  # type: ignore[assignment]
+            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(  # type: ignore[arg-type]
+                last_non_null_forward  # type: ignore[arg-type]
+            )
 
             if limit is not None:
                 # outputs (limit=1) -> NaN, NaN, 1.0, 2.0, NaN, NaN, 5.0, 5.0, NaN
-                fill_cond = fill_cond & (null_index_forward <= F.lit(limit))
-                pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit))
+                fill_cond = fill_cond & (
+                    null_index_forward <= F.lit(limit)  # type: ignore[assignment]
+                )
+                pad_tail_cond = pad_tail_cond & (
+                    null_index_forward <= F.lit(limit)  # type: ignore[assignment]
+                )
 
         elif limit_direction == "backward":
             # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, NaN, NaN
-            pad_head = last_non_null_backward
-            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward)
+            pad_head = last_non_null_backward  # type: ignore[assignment]
+            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(  # type: ignore[arg-type]
+                last_non_null_forward  # type: ignore[arg-type]
+            )
 
             if limit is not None:
                 # outputs (limit=1) -> NaN, 1.0, 1.0, NaN, NaN, 4.0, 5.0, NaN, NaN
-                fill_cond = fill_cond & (null_index_backward <= F.lit(limit))
-                pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit))
+                fill_cond = fill_cond & (
+                    null_index_backward <= F.lit(limit)  # type: ignore[assignment]
+                )
+                pad_head_cond = pad_head_cond & (
+                    null_index_backward <= F.lit(limit)  # type: ignore[assignment]
+                )
 
         else:
             # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
-            pad_head = last_non_null_backward
-            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward)
-            pad_tail = last_non_null_forward
-            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
+            pad_head = last_non_null_backward  # type: ignore[assignment]
+            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(  # type: ignore[arg-type]
+                last_non_null_forward  # type: ignore[arg-type]
+            )
+            pad_tail = last_non_null_forward  # type: ignore[assignment]
+            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(  # type: ignore[arg-type]
+                last_non_null_forward  # type: ignore[arg-type]
+            )
 
             if limit is not None:
                 # outputs (limit=1) -> NaN, 1.0, 1.0, 2.0, NaN, 4.0, 5.0, 5.0, NaN
-                fill_cond = fill_cond & (
+                fill_cond = fill_cond & (  # type: ignore[assignment]
                     (null_index_forward <= F.lit(limit)) | (null_index_backward <= F.lit(limit))
                 )
-                pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit))
-                pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit))
+                pad_head_cond = pad_head_cond & (
+                    null_index_backward <= F.lit(limit)  # type: ignore[assignment]
+                )
+                pad_tail_cond = pad_tail_cond & (
+                    null_index_forward <= F.lit(limit)  # type: ignore[assignment]
+                )
 
         if limit_area == "inside":
             pad_head_cond = F.lit(False)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
index d2c05893ae23b..158215073ad9f 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
@@ -24,9 +24,7 @@
 class GenericFunctionsParityTests(
     GenericFunctionsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43631): Enable Series.interpolate with Spark Connect.")
-    def test_interpolate(self):
-        super().test_interpolate()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7ecfa65dcd13f..608ed7e9ac9f8 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -46,6 +46,7 @@
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.column import Column
+    from pyspark.sql.window import Window
     from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
 
 has_numpy = False
@@ -188,7 +189,7 @@ def try_remote_window(f: FuncT) -> FuncT:
     def wrapped(*args: Any, **kwargs: Any) -> Any:
 
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            from pyspark.sql.connect.window import Window
+            from pyspark.sql.connect.window import Window  # type: ignore[misc]
 
             return getattr(Window, f.__name__)(*args, **kwargs)
         else:
@@ -282,3 +283,14 @@ def get_dataframe_class() -> Type["DataFrame"]:
         return ConnectDataFrame  # type: ignore[return-value]
     else:
         return PySparkDataFrame
+
+
+def get_window_class() -> Type["Window"]:
+    from pyspark.sql.window import Window as PySparkWindow
+
+    if is_remote():
+        from pyspark.sql.connect.window import Window as ConnectWindow
+
+        return ConnectWindow  # type: ignore[return-value]
+    else:
+        return PySparkWindow

From 27c95741adab2272e15ebdbf1de79f238f7d6b10 Mon Sep 17 00:00:00 2001
From: itholic
Date: Thu, 22 Jun 2023 19:27:55 +0900
Subject: [PATCH 2/3] Adjusted comments

---
 python/pyspark/pandas/series.py               | 84 +++++--------------
 python/pyspark/pandas/spark/functions.py      | 28 +++++++
 .../connect/test_parity_generic_functions.py  |  6 +-
 3 files changed, 56 insertions(+), 62 deletions(-)

diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index f78ef0bbb011d..95ca92e78787d 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -53,7 +53,6 @@
     CategoricalDtype,
 )
 from pandas.tseries.frequencies import DateOffset
-from pyspark import SparkContext
 from pyspark.sql import functions as F, Column as PySparkColumn, DataFrame as SparkDataFrame
 from pyspark.sql.types import (
     ArrayType,
@@ -70,7 +69,7 @@
     TimestampType,
 )
 from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class, is_remote, get_window_class
+from pyspark.sql.utils import get_column_class, get_window_class
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
@@ -2257,44 +2256,27 @@ def _interpolate(
             return self._psdf.copy()._psser_for(self._column_label)
 
         scol = self.spark.column
-        if is_remote():
-            from pyspark.sql.connect.functions import _invoke_function_over_columns
-
-            last_non_null = _invoke_function_over_columns(
-                "last_non_null",
-                scol,  # type: ignore[arg-type]
-            )
-            null_index = _invoke_function_over_columns(
-                "null_index",
-                scol,  # type: ignore[arg-type]
-            )
-        else:
-            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
-            last_non_null = PySparkColumn(
-                sql_utils.lastNonNull(scol._jc)  # type: ignore[assignment]
-            )
-            null_index = PySparkColumn(sql_utils.nullIndex(scol._jc))  # type: ignore[assignment]
+        last_non_null = SF.last_non_null(scol)
+        null_index = SF.null_index(scol)
 
         Window = get_window_class()
         window_forward = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
             Window.unboundedPreceding, Window.currentRow
         )
-        last_non_null_forward = last_non_null.over(window_forward)  # type: ignore[arg-type]
-        null_index_forward = null_index.over(window_forward)  # type: ignore[arg-type]
+        last_non_null_forward = last_non_null.over(window_forward)
+        null_index_forward = null_index.over(window_forward)
 
         window_backward = Window.orderBy(F.desc(NATURAL_ORDER_COLUMN_NAME)).rowsBetween(
             Window.unboundedPreceding, Window.currentRow
         )
-        last_non_null_backward = last_non_null.over(window_backward)  # type: ignore[arg-type]
-        null_index_backward = null_index.over(window_backward)  # type: ignore[arg-type]
+        last_non_null_backward = last_non_null.over(window_backward)
+        null_index_backward = null_index.over(window_backward)
 
         fill = (last_non_null_backward - last_non_null_forward) / (
             null_index_backward + null_index_forward
         ) * null_index_forward + last_non_null_forward
-        fill_cond = ~F.isnull(last_non_null_backward) & ~F.isnull(  # type: ignore[arg-type]
-            last_non_null_forward  # type: ignore[arg-type]
-        )
+        fill_cond = ~F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
 
         pad_head = F.lit(None)
         pad_head_cond = F.lit(False)
@@ -2304,55 +2286,35 @@ def _interpolate(
         # inputs -> NaN, NaN, 1.0, NaN, NaN, NaN, 5.0, NaN, NaN
         if limit_direction is None or limit_direction == "forward":
             # outputs -> NaN, NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
-            pad_tail = last_non_null_forward  # type: ignore[assignment]
-            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(  # type: ignore[arg-type]
-                last_non_null_forward  # type: ignore[arg-type]
-            )
+            pad_tail = last_non_null_forward
+            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
 
             if limit is not None:
                 # outputs (limit=1) -> NaN, NaN, 1.0, 2.0, NaN, NaN, 5.0, 5.0, NaN
-                fill_cond = fill_cond & (
-                    null_index_forward <= F.lit(limit)  # type: ignore[assignment]
-                )
-                pad_tail_cond = pad_tail_cond & (
-                    null_index_forward <= F.lit(limit)  # type: ignore[assignment]
-                )
+                fill_cond = fill_cond & (null_index_forward <= F.lit(limit))
+                pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit))
 
         elif limit_direction == "backward":
             # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, NaN, NaN
-            pad_head = last_non_null_backward  # type: ignore[assignment]
-            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(  # type: ignore[arg-type]
-                last_non_null_forward  # type: ignore[arg-type]
-            )
+            pad_head = last_non_null_backward
+            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward)
 
             if limit is not None:
                 # outputs (limit=1) -> NaN, 1.0, 1.0, NaN, NaN, 4.0, 5.0, NaN, NaN
-                fill_cond = fill_cond & (
-                    null_index_backward <= F.lit(limit)  # type: ignore[assignment]
-                )
-                pad_head_cond = pad_head_cond & (
-                    null_index_backward <= F.lit(limit)  # type: ignore[assignment]
-                )
+                fill_cond = fill_cond & (null_index_backward <= F.lit(limit))
+                pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit))
 
         else:
             # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
-            pad_head = last_non_null_backward  # type: ignore[assignment]
-            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(  # type: ignore[arg-type]
-                last_non_null_forward  # type: ignore[arg-type]
-            )
-            pad_tail = last_non_null_forward  # type: ignore[assignment]
-            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(  # type: ignore[arg-type]
-                last_non_null_forward  # type: ignore[arg-type]
-            )
+            pad_head = last_non_null_backward
+            pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward)
+            pad_tail = last_non_null_forward
+            pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
 
             if limit is not None:
                 # outputs (limit=1) -> NaN, 1.0, 1.0, 2.0, NaN, 4.0, 5.0, 5.0, NaN
-                fill_cond = fill_cond & (  # type: ignore[assignment]
+                fill_cond = fill_cond & (
                     (null_index_forward <= F.lit(limit)) | (null_index_backward <= F.lit(limit))
                 )
-                pad_head_cond = pad_head_cond & (
-                    null_index_backward <= F.lit(limit)  # type: ignore[assignment]
-                )
-                pad_tail_cond = pad_tail_cond & (
-                    null_index_forward <= F.lit(limit)  # type: ignore[assignment]
-                )
+                pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit))
+                pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit))
 
         if limit_area == "inside":
             pad_head_cond = F.lit(False)
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index b33705263c70a..c87bc49cd9202 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -141,3 +141,31 @@ def repeat(col: Column, n: Union[int, Column]) -> Column:
     """
     _n = F.lit(n) if isinstance(n, int) else n
     return F.call_udf("repeat", col, _n)
+
+
+def last_non_null(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(
+            "last_non_null",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc))
+
+
+def null_index(col: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions import _invoke_function_over_columns
+
+        return _invoke_function_over_columns(
+            "null_index",
+            col,  # type: ignore[arg-type]
+        )
+
+    else:
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
index 158215073ad9f..1bf2650d87425 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py
@@ -24,7 +24,11 @@
 class GenericFunctionsParityTests(
     GenericFunctionsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
 ):
-    pass
+    @unittest.skip(
+        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client."
+    )
+    def test_interpolate(self):
+        super().test_interpolate()
 
 
 if __name__ == "__main__":

From dc60fac9baccb98d48f784cba6c94a8c577433ce Mon Sep 17 00:00:00 2001
From: itholic
Date: Fri, 23 Jun 2023 12:26:00 +0900
Subject: [PATCH 3/3] fix linter

---
 python/pyspark/pandas/spark/functions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index c87bc49cd9202..4f2624483dbd1 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -147,7 +147,7 @@ def last_non_null(col: Column) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions import _invoke_function_over_columns
 
-        return _invoke_function_over_columns(
+        return _invoke_function_over_columns(  # type: ignore[return-value]
            "last_non_null",
             col,  # type: ignore[arg-type]
         )
@@ -161,7 +161,7 @@ def null_index(col: Column) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions import _invoke_function_over_columns
 
-        return _invoke_function_over_columns(
+        return _invoke_function_over_columns(  # type: ignore[return-value]
             "null_index",
             col,  # type: ignore[arg-type]
         )