[SPARK-47933][PYTHON] Parent Column class for Spark Connect and Spark Classic

### What changes were proposed in this pull request?

Same as apache#46129, but for the `Column` class.

### Why are the changes needed?

Same as apache#46129

### Does this PR introduce _any_ user-facing change?

Same as apache#46129

### How was this patch tested?

Manually tested, and CI should verify the changes.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46155 from HyukjinKwon/SPARK-47933.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon authored and JacobZheng0927 committed May 11, 2024
1 parent 20ddbb3 commit 1745397
Showing 37 changed files with 1,766 additions and 917 deletions.
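
The diff below is largely mechanical. The underlying idea is a single public `Column` class in `pyspark.sql.column` that both the classic (py4j) and Spark Connect backends subclass, so user code and type annotations only ever see the parent type. A minimal, self-contained sketch of that shape follows; all names and the dispatch mechanism here are illustrative, not the actual PySpark implementation:

```python
# Illustrative sketch only -- names and dispatch differ from the real PySpark code.
class Column:
    """Shared parent: the type user code imports and isinstance-checks."""

    def __new__(cls, *args, remote: bool = False, **kwargs):
        if cls is Column:
            # Pick the backend subclass when the parent itself is constructed,
            # analogous to branching on a "remote session" check.
            target = _ConnectColumn if remote else _ClassicColumn
            return object.__new__(target)
        return object.__new__(cls)

    def alias(self, name: str) -> "Column":
        raise NotImplementedError


class _ClassicColumn(Column):
    """Stand-in for the py4j-backed pyspark.sql.classic.column.Column."""

    def alias(self, name: str) -> "Column":
        return self  # real code would call into the JVM


class _ConnectColumn(Column):
    """Stand-in for the proto-backed pyspark.sql.connect.column.Column."""

    def alias(self, name: str) -> "Column":
        return self  # real code would build a Connect expression


col = Column(remote=True)
assert isinstance(col, Column) and type(col) is _ConnectColumn
```

With that shape, `isinstance(col, Column)` holds for either backend, which is what lets the per-file changes below drop backend-specific imports and `# type: ignore` comments.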
2 changes: 1 addition & 1 deletion dev/sparktestsupport/modules.py
@@ -476,7 +476,7 @@ def __hash__(self):
         "pyspark.sql.session",
         "pyspark.sql.conf",
         "pyspark.sql.catalog",
-        "pyspark.sql.column",
+        "pyspark.sql.classic.column",
         "pyspark.sql.classic.dataframe",
         "pyspark.sql.datasource",
         "pyspark.sql.group",
2 changes: 1 addition & 1 deletion python/pyspark/ml/connect/functions.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from pyspark.ml import functions as PyMLFunctions
-from pyspark.sql.connect.column import Column
+from pyspark.sql.column import Column
 from pyspark.sql.connect.functions.builtin import _invoke_function, _to_col, lit


2 changes: 1 addition & 1 deletion python/pyspark/ml/functions.py
@@ -28,7 +28,7 @@
     pass  # Let it throw a better error message later when the API is invoked.

 from pyspark.sql.functions import pandas_udf
-from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.classic.column import Column, _to_java_column
 from pyspark.sql.types import (
     ArrayType,
     ByteType,
3 changes: 2 additions & 1 deletion python/pyspark/ml/stat.py
@@ -22,7 +22,8 @@
 from pyspark.ml.common import _java2py, _py2java
 from pyspark.ml.linalg import Matrix, Vector
 from pyspark.ml.wrapper import JavaWrapper, _jvm
-from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.column import Column
+from pyspark.sql.classic.column import _to_seq
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import lit

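The ml/stat.py hunk above shows the import convention applied across the tree: the public `Column` type comes from the shared `pyspark.sql.column`, while private py4j helpers such as `_to_seq` and `_to_java_column` move to `pyspark.sql.classic.column`. A hedged sketch of a caller following that split (requires a classic, JVM-backed PySpark install; `to_java_seq` is a made-up name):

```python
from pyspark.sql.column import Column  # shared, backend-agnostic public type
from pyspark.sql.classic.column import (  # py4j-only private helpers
    _to_java_column,
    _to_seq,
)


def to_java_seq(sc, cols: "list[Column]") -> object:
    # Convert a Python list of Columns to a JVM Seq[Column]. The helpers are
    # py4j-specific, which is why they now live under pyspark.sql.classic.
    return _to_seq(sc, cols, _to_java_column)
```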
2 changes: 1 addition & 1 deletion python/pyspark/pandas/internal.py
@@ -959,7 +959,7 @@ def attach_distributed_sequence_column(

             return sdf.select(
                 ConnectColumn(DistributedSequenceID()).alias(column_name),
-                "*",  # type: ignore[call-overload]
+                "*",
             )
         else:
             return PySparkDataFrame(
44 changes: 22 additions & 22 deletions python/pyspark/pandas/spark/functions.py
@@ -25,9 +25,9 @@ def product(col: Column, dropna: bool) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_product",
-            col,  # type: ignore[arg-type]
+            col,
             lit(dropna),
         )

@@ -42,9 +42,9 @@ def stddev(col: Column, ddof: int) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_stddev",
-            col,  # type: ignore[arg-type]
+            col,
             lit(ddof),
         )

@@ -59,9 +59,9 @@ def var(col: Column, ddof: int) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_var",
-            col,  # type: ignore[arg-type]
+            col,
             lit(ddof),
         )

@@ -76,9 +76,9 @@ def skew(col: Column) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_skew",
-            col,  # type: ignore[arg-type]
+            col,
         )

     else:
@@ -92,9 +92,9 @@ def kurt(col: Column) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_kurt",
-            col,  # type: ignore[arg-type]
+            col,
         )

     else:
@@ -108,9 +108,9 @@ def mode(col: Column, dropna: bool) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_mode",
-            col,  # type: ignore[arg-type]
+            col,
             lit(dropna),
         )

@@ -125,10 +125,10 @@ def covar(col1: Column, col2: Column, ddof: int) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "pandas_covar",
-            col1,  # type: ignore[arg-type]
-            col2,  # type: ignore[arg-type]
+            col1,
+            col2,
             lit(ddof),
         )

@@ -143,9 +143,9 @@ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "ewm",
-            col,  # type: ignore[arg-type]
+            col,
             lit(alpha),
             lit(ignore_na),
         )

@@ -161,9 +161,9 @@ def null_index(col: Column) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "null_index",
-            col,  # type: ignore[arg-type]
+            col,
         )

     else:
@@ -177,11 +177,11 @@ def timestampdiff(unit: str, start: Column, end: Column) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

-        return _invoke_function_over_columns(  # type: ignore[return-value]
+        return _invoke_function_over_columns(
             "timestampdiff",
             lit(unit),
-            start,  # type: ignore[arg-type]
-            end,  # type: ignore[arg-type]
+            start,
+            end,
         )

     else:
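All 22 deletions in this file are `# type: ignore` comments. Previously, the Connect branch returned `pyspark.sql.connect.column.Column`, which mypy considered incompatible with the declared classic `Column`; with a shared parent class, one annotation now covers both backends. A minimal sketch of the effect (the function itself is illustrative, not part of this commit):

```python
from pyspark.sql.column import Column  # single parent type for both backends


def scaled(col: Column, factor: float) -> Column:
    # Whether `col` is a classic (py4j-backed) or Connect (proto-backed)
    # Column, the result satisfies the declared parent return type, so no
    # `# type: ignore[arg-type]` / `[return-value]` suppressions are needed.
    return col * factor
```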
4 changes: 3 additions & 1 deletion python/pyspark/sql/avro/functions.py
@@ -22,7 +22,7 @@

 from typing import Dict, Optional, TYPE_CHECKING, cast

-from pyspark.sql.column import Column, _to_java_column
+from pyspark.sql.column import Column
 from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
 from pyspark.util import _print_missing_jar

@@ -78,6 +78,7 @@ def from_avro(
     [Row(value=Row(avro=Row(age=2, name='Alice')))]
     """
     from py4j.java_gateway import JVMView
+    from pyspark.sql.classic.column import _to_java_column

     sc = get_active_spark_context()
     try:
@@ -128,6 +129,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     [Row(suite=bytearray(b'\\x02\\x00'))]
     """
     from py4j.java_gateway import JVMView
+    from pyspark.sql.classic.column import _to_java_column

     sc = get_active_spark_context()
     try:
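Both avro hunks use the same trick: the classic-only `_to_java_column` import moves from module level into the function bodies, so the module itself stays importable under Spark Connect and the py4j dependency is only touched when the classic path actually runs. A sketch of the pattern (`jvm_helper` is a hypothetical name):

```python
from pyspark.sql.column import Column  # safe to import under either backend


def jvm_helper(col: Column) -> object:
    # Deferred import: pulling in the py4j-only helper at call time keeps
    # this module importable in a Connect-only environment.
    from pyspark.sql.classic.column import _to_java_column

    return _to_java_column(col)
```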