[SPARK-36708][PYTHON] Support numpy.typing for annotating ArrayType in pandas API on Spark

### What changes were proposed in this pull request?

This PR adds support for understanding the `numpy.typing` package introduced in NumPy 1.21.

### Why are the changes needed?

To allow user-friendly return type specification in the type hints for function-apply APIs (e.g., `apply_batch`) in pandas API on Spark.

### Does this PR introduce _any_ user-facing change?

Yes. This PR enables users to specify the return type as `numpy.typing.NDArray[...]`, which is translated internally into the pandas UDF's return type.

For example,

```python
import numpy as np
import numpy.typing as ntp
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame(
    {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
    index=np.random.rand(9),
)
psdf = ps.from_pandas(pdf)

def func(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]:
    return x

psdf.pandas_on_spark.apply_batch(func)
```
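In this annotation, the first parameter (`float`) is the index type and the bracketed list holds the data column types, so `ntp.NDArray[int]` can now appear anywhere `List[int]` already works. A minimal check of that equivalence, assuming NumPy 1.21+ and Python 3.8+ (`as_spark_type` is the internal helper this PR extends):

```python
from typing import List

import numpy.typing as ntp

from pyspark.pandas.typedef.typehints import as_spark_type

# NDArray[int] and List[int] should resolve to the same Spark type:
# an ArrayType wrapping LongType.
assert as_spark_type(ntp.NDArray[int]) == as_spark_type(List[int])
```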

### How was this patch tested?

Unit tests and end-to-end tests were added.

Closes #34028 from HyukjinKwon/SPARK-36708.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Sep 23, 2021
1 parent 6a5ee02 commit cc2fcb4
Showing 3 changed files with 68 additions and 24 deletions.
27 changes: 27 additions & 0 deletions python/pyspark/pandas/tests/test_dataframe.py
@@ -21,6 +21,7 @@
import sys
import unittest
from io import StringIO
from typing import List

import numpy as np
import pandas as pd
@@ -4649,6 +4650,32 @@ def identify2(x) -> ps.DataFrame[slice("a", int), slice("b", int)]: # noqa: F40
self.assert_eq(sorted(actual["a"].to_numpy()), sorted(expected["a"].to_numpy()))
self.assert_eq(sorted(actual["b"].to_numpy()), sorted(expected["b"].to_numpy()))

pdf = pd.DataFrame(
{"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
index=np.random.rand(9),
)
psdf = ps.from_pandas(pdf)

def identify3(x) -> ps.DataFrame[float, [int, List[int]]]:
return x

actual = psdf.pandas_on_spark.apply_batch(identify3)
actual.columns = ["a", "b"]
self.assert_eq(actual, pdf)

# For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
import numpy.typing as ntp

psdf = ps.from_pandas(pdf)

def identify4(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]: # type: ignore
return x

actual = psdf.pandas_on_spark.apply_batch(identify4)
actual.columns = ["a", "b"]
self.assert_eq(actual, pdf)

def test_transform_batch(self):
pdf = pd.DataFrame(
{
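The new `identify3` and `identify4` cases mirror each other: `List[int]` and `ntp.NDArray[int]` should produce identical inferred schemas. A hedged sketch of inspecting that inference directly, using the internal `infer_return_type` helper (its `spark_type` property is an assumption here, not part of this diff):

```python
import numpy.typing as ntp

import pyspark.pandas as ps
from pyspark.pandas.typedef.typehints import infer_return_type

def identify4(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]:
    return x

# The inferred return type should carry a schema of a double index field
# plus bigint and array<bigint> data fields.
print(infer_return_type(identify4).spark_type)
```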
46 changes: 23 additions & 23 deletions python/pyspark/pandas/tests/test_typedef.py
@@ -19,6 +19,7 @@
import unittest
import datetime
import decimal
from distutils.version import LooseVersion
from typing import List

import pandas
@@ -334,29 +335,6 @@ def test_as_spark_type_pandas_on_spark_dtype(self):
decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
# ArrayType
np.ndarray: (np.dtype("object"), ArrayType(StringType())),
List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())),
List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())),
List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
List[float]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
List[int]: (np.dtype("object"), ArrayType(LongType())),
List[np.int]: (np.dtype("object"), ArrayType(LongType())),
List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
List[str]: (np.dtype("object"), ArrayType(StringType())),
List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())),
List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
# CategoricalDtype
CategoricalDtype(categories=["a", "b", "c"]): (
CategoricalDtype(categories=["a", "b", "c"]),
@@ -368,6 +346,28 @@ def test_as_spark_type_pandas_on_spark_dtype(self):
self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type))

if isinstance(numpy_or_python_type, CategoricalDtype):
# Nested CategoricalDtype is not yet supported.
continue

self.assertEqual(as_spark_type(List[numpy_or_python_type]), ArrayType(spark_type))
self.assertEqual(
pandas_on_spark_type(List[numpy_or_python_type]),
(np.dtype("object"), ArrayType(spark_type)),
)

# For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
import numpy.typing as ntp

self.assertEqual(
as_spark_type(ntp.NDArray[numpy_or_python_type]), ArrayType(spark_type)
)
self.assertEqual(
pandas_on_spark_type(ntp.NDArray[numpy_or_python_type]),
(np.dtype("object"), ArrayType(spark_type)),
)

with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
as_spark_type(np.dtype("uint64"))

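The rewritten loop asserts the same contract for every supported scalar type: `as_spark_type` yields the bare `ArrayType`, while `pandas_on_spark_type` pairs it with the `object` pandas dtype, since list-like columns are always held as `object` on the pandas side. Spelled out for one scalar type, as a sketch:

```python
import numpy as np
import numpy.typing as ntp

from pyspark.pandas.typedef.typehints import as_spark_type, pandas_on_spark_type
from pyspark.sql.types import ArrayType, DoubleType

# float elements map to DoubleType; NDArray[float] wraps that in ArrayType,
# and the pandas-side dtype for any array column is plain object.
assert as_spark_type(ntp.NDArray[float]) == ArrayType(DoubleType())
assert pandas_on_spark_type(ntp.NDArray[float]) == (np.dtype("object"), ArrayType(DoubleType()))
```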
19 changes: 18 additions & 1 deletion python/pyspark/pandas/typedef/typehints.py
@@ -20,8 +20,10 @@
"""
import datetime
import decimal
import sys
import typing
from collections import Iterable
from distutils.version import LooseVersion
from inspect import getfullargspec, isclass
from typing import ( # noqa: F401
Any,
@@ -152,6 +154,19 @@ def as_spark_type(
- dictionaries of field_name -> type
- Python3's typing system
"""
# For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
if (
hasattr(tpe, "__origin__")
and tpe.__origin__ is np.ndarray # type: ignore
and hasattr(tpe, "__args__")
and len(tpe.__args__) > 1 # type: ignore
):
# numpy.typing.NDArray
return types.ArrayType(
as_spark_type(tpe.__args__[1].__args__[0], raise_error=raise_error) # type: ignore
)

if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
pass
# ArrayType
@@ -568,7 +583,9 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
else:
parameters = getattr(tuple_type, "__args__")

index_parameters = [p for p in parameters if issubclass(p, IndexNameTypeHolder)]
index_parameters = [
p for p in parameters if isclass(p) and issubclass(p, IndexNameTypeHolder)
]
data_parameters = [p for p in parameters if p not in index_parameters]
assert len(data_parameters) > 0, "Type hints for data must not be empty."

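The `tpe.__args__[1].__args__[0]` lookup in `as_spark_type` works because `numpy.typing.NDArray[X]` is an alias for `numpy.ndarray[Any, numpy.dtype[X]]`: `__args__[0]` is the shape placeholder `Any`, `__args__[1]` is the parameterized `numpy.dtype`, and its own `__args__[0]` is the element type. A quick illustration:

```python
import numpy as np
import numpy.typing as ntp

tpe = ntp.NDArray[int]  # alias for np.ndarray[Any, np.dtype[int]]

assert tpe.__origin__ is np.ndarray        # matched by the new branch
assert tpe.__args__[1].__args__[0] is int  # element type handed to as_spark_type
```

The `isclass` guard added to `infer_return_type` exists for a related reason: parameterized aliases such as `ntp.NDArray[int]` are not classes, and passing them straight to `issubclass` raises a `TypeError`.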
