54 changes: 36 additions & 18 deletions python/pyspark/pandas/accessors.py
@@ -34,7 +34,7 @@
     InternalFrame,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
 from pyspark.pandas.utils import (
@@ -384,8 +384,8 @@ def apply_batch(
                     "The given function should specify a frame as its type "
                     "hints; however, the return type was %s." % return_sig
                 )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             return_schema = cast(DataFrameType, return_type).spark_type

             output_func = GroupBy._make_pandas_df_builder_func(
@@ -397,12 +397,19 @@

             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_fields = None

             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
@@ -680,17 +687,19 @@ def udf(pdf: pd.DataFrame) -> pd.Series:
                 )
                 return first_series(DataFrame(internal))
             else:
-                index_field = cast(DataFrameType, return_type).index_field
-                index_field = (
-                    index_field.normalize_spark_type() if index_field is not None else None
+                index_fields = cast(DataFrameType, return_type).index_fields
+                index_fields = (
+                    [index_field.normalize_spark_type() for index_field in index_fields]
+                    if index_fields is not None
+                    else None
                 )
                 data_fields = [
                     field.normalize_spark_type()
                     for field in cast(DataFrameType, return_type).data_fields
                 ]
-                normalized_fields = ([index_field] if index_field is not None else []) + data_fields
+                normalized_fields = (index_fields if index_fields is not None else []) + data_fields
                 return_schema = StructType([field.struct_field for field in normalized_fields])
-                should_retain_index = index_field is not None
+                should_retain_index = index_fields is not None

                 self_applied = DataFrame(self._psdf._internal.resolved_copy)

@@ -711,12 +720,21 @@

                 index_spark_columns = None
                 index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-                index_fields = None

                 if should_retain_index:
-                    index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                    index_fields = [index_field]
-                    if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                        index_names = [(index_field.struct_field.name,)]
+                    index_spark_columns = [
+                        scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                    ]
+
+                    if not any(
+                        [
+                            SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                            for index_field in index_fields
+                        ]
+                    ):
+                        index_names = [
+                            (index_field.struct_field.name,) for index_field in index_fields
+                        ]
                 internal = InternalFrame(
                     spark_frame=sdf,
                     index_names=index_names,
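Taken together, the `accessors.py` hunks let `apply_batch` and `transform_batch` round-trip a MultiIndex when the return type hint lists multiple index fields. A minimal sketch of the behavior this enables (the data and level names are illustrative, not from the patch; the name-and-type hint spelling follows the new tests below):

```python
import pandas as pd
import pyspark.pandas as ps

# Illustrative frame with a two-level index.
pdf = pd.DataFrame(
    {"a": [1, 2, 3]},
    index=pd.MultiIndex.from_tuples(
        [(1, "x"), (2, "y"), (3, "z")], names=("num", "char")
    ),
)
psdf = ps.from_pandas(pdf)

# Hinting each index level with a (name, type) pair should keep both
# levels, and their names, in the result instead of falling back to the
# default sequential index.
def plus_one(x) -> ps.DataFrame[[("num", int), ("char", str)], [("a", int)]]:
    return x + 1

result = psdf.pandas_on_spark.apply_batch(plus_one)
print(result.index.names)  # expected: ['num', 'char']
```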
24 changes: 16 additions & 8 deletions python/pyspark/pandas/frame.py
@@ -114,6 +114,7 @@
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_INDEX_NAME,
     SPARK_DEFAULT_SERIES_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
 from pyspark.pandas.ml import corr
@@ -2511,7 +2512,7 @@ def apply_func(pdf: pd.DataFrame) -> pd.DataFrame:
         return_type = infer_return_type(func)
         require_index_axis = isinstance(return_type, SeriesType)
         require_column_axis = isinstance(return_type, DataFrameType)
-        index_field = None
+        index_fields = None

         if require_index_axis:
             if axis != 0:
@@ -2536,8 +2537,8 @@
                     "hints when axis is 1 or 'column'; however, the return type "
                     "was %s" % return_sig
                 )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             data_fields = cast(DataFrameType, return_type).data_fields
             return_schema = cast(DataFrameType, return_type).spark_type
         else:
@@ -2565,12 +2566,19 @@ def apply_func(pdf: pd.DataFrame) -> pd.DataFrame:

         index_spark_columns = None
         index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-        index_fields = None

         if should_retain_index:
-            index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-            index_fields = [index_field]
-            if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                index_names = [(index_field.struct_field.name,)]
+            index_spark_columns = [
+                scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+            ]
+
+            if not any(
+                [
+                    SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                    for index_field in index_fields
+                ]
+            ):
+                index_names = [(index_field.struct_field.name,) for index_field in index_fields]
         internal = InternalFrame(
             spark_frame=sdf,
             index_names=index_names,
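The `frame.py` change applies the same multi-index retention to `DataFrame.apply` when `axis` is 1 and the function carries a DataFrame return hint. A sketch under the same assumptions as above (illustrative data; the hint style mirrors the `apply` docstring examples):

```python
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame(
    {"a": [1, 2, 3], "b": [4, 5, 6]},
    index=pd.MultiIndex.from_tuples(
        [(1, "x"), (2, "y"), (3, "z")], names=("num", "char")
    ),
)
psdf = ps.from_pandas(pdf)

# axis=1 with a DataFrame return hint goes through the code path changed
# above; hinting both index levels should keep the MultiIndex in the result.
def double(x) -> ps.DataFrame[[("num", int), ("char", str)], [("a", int), ("b", int)]]:
    return x * 2

result = psdf.apply(double, axis=1)
```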
23 changes: 15 additions & 8 deletions python/pyspark/pandas/groupby.py
@@ -75,7 +75,7 @@
     NATURAL_ORDER_COLUMN_NAME,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.groupby import (
     MissingPandasLikeDataFrameGroupBy,
@@ -1242,9 +1242,8 @@ def pandas_apply(pdf: pd.DataFrame, *a: Any, **k: Any) -> Any:
         if isinstance(return_type, DataFrameType):
             data_fields = cast(DataFrameType, return_type).data_fields
             return_schema = cast(DataFrameType, return_type).spark_type
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
-            index_fields = [index_field]
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             psdf_from_pandas = None
         else:
             should_return_series = True
@@ -1317,10 +1316,18 @@ def wrapped_func(
             )
         else:
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_field = index_fields[0]
-            index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-            if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                index_names = [(index_field.struct_field.name,)]
+
+            index_spark_columns = [
+                scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+            ]
+
+            if not any(
+                [
+                    SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                    for index_field in index_fields
+                ]
+            ):
+                index_names = [(index_field.struct_field.name,) for index_field in index_fields]
         internal = InternalFrame(
             spark_frame=sdf,
             index_names=index_names,
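All three call sites now share the same guard: instead of comparing a single index name against `SPARK_DEFAULT_INDEX_NAME`, they regex-match every level against `SPARK_INDEX_NAME_PATTERN` and retain user-facing names only when no level carries a default name. A standalone sketch of that rule, assuming the pattern mirrors the default index-column naming in `pyspark.pandas.internal`:

```python
import re
from typing import List, Optional, Tuple

# Assumption: default index columns are named "__index_level_0__",
# "__index_level_1__", ..., as in pyspark.pandas.internal.
SPARK_INDEX_NAME_PATTERN = re.compile(r"__index_level_[0-9]+__")

def retained_index_names(field_names: List[str]) -> Optional[List[Tuple[str, ...]]]:
    # Keep user-facing names only when no level has a default name;
    # otherwise return None, i.e. unnamed index levels.
    if not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in field_names):
        return [(name,) for name in field_names]
    return None

print(retained_index_names(["number", "color"]))             # [('number',), ('color',)]
print(retained_index_names(["__index_level_0__", "color"]))  # None
```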
26 changes: 26 additions & 0 deletions python/pyspark/pandas/tests/test_dataframe.py
@@ -4678,6 +4678,32 @@ def identify4(
         actual.columns = ["a", "b"]
         self.assert_eq(actual, pdf)

+        arrays = [[1, 2, 3, 4, 5, 6, 7, 8, 9], ["a", "b", "c", "d", "e", "f", "g", "h", "i"]]
+        idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
+            index=idx,
+        )
+        psdf = ps.from_pandas(pdf)
+
+        def identify4(x) -> ps.DataFrame[[int, str], [int, List[int]]]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify4)
+        actual.index.names = ["number", "color"]
+        actual.columns = ["a", "b"]
+        self.assert_eq(actual, pdf)
+
+        def identify5(
+            x,
+        ) -> ps.DataFrame[
+            [("number", int), ("color", str)], [("a", int), ("b", List[int])]  # noqa: F405
+        ]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify5)
+        self.assert_eq(actual, pdf)
+
     def test_transform_batch(self):
         pdf = pd.DataFrame(
             {
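For quick reference, the two hint spellings the new tests exercise, restated outside the test harness (condensed from the cases above): with types only, the MultiIndex structure survives but its level names do not, which is why the test reassigns `actual.index.names` before comparing; with (name, type) pairs, both the structure and the names survive.

```python
from typing import List

import pyspark.pandas as ps

# Types only: index levels are retained, level names are not.
def identify4(x) -> ps.DataFrame[[int, str], [int, List[int]]]:
    return x

# Name-and-type pairs: index levels and their names are both retained.
def identify5(
    x,
) -> ps.DataFrame[[("number", int), ("color", str)], [("a", int), ("b", List[int])]]:
    return x
```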