diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 3cfbc0334d8eb..970860ec4f009 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -21,6 +21,7 @@ import sys import unittest from io import StringIO +from typing import List import numpy as np import pandas as pd @@ -4649,6 +4650,32 @@ def identify2(x) -> ps.DataFrame[slice("a", int), slice("b", int)]: # noqa: F40 self.assert_eq(sorted(actual["a"].to_numpy()), sorted(expected["a"].to_numpy())) self.assert_eq(sorted(actual["b"].to_numpy()), sorted(expected["b"].to_numpy())) + pdf = pd.DataFrame( + {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]}, + index=np.random.rand(9), + ) + psdf = ps.from_pandas(pdf) + + def identify3(x) -> ps.DataFrame[float, [int, List[int]]]: + return x + + actual = psdf.pandas_on_spark.apply_batch(identify3) + actual.columns = ["a", "b"] + self.assert_eq(actual, pdf) + + # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+ + if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"): + import numpy.typing as ntp + + psdf = ps.from_pandas(pdf) + + def identify4(x) -> ps.DataFrame[float, [int, ntp.NDArray[int]]]: # type: ignore + return x + + actual = psdf.pandas_on_spark.apply_batch(identify4) + actual.columns = ["a", "b"] + self.assert_eq(actual, pdf) + def test_transform_batch(self): pdf = pd.DataFrame( { diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py index 6b644f40f1a70..b0f1d55c9bcd8 100644 --- a/python/pyspark/pandas/tests/test_typedef.py +++ b/python/pyspark/pandas/tests/test_typedef.py @@ -19,6 +19,7 @@ import unittest import datetime import decimal +from distutils.version import LooseVersion from typing import List import pandas @@ -334,29 +335,6 @@ def test_as_spark_type_pandas_on_spark_dtype(self): decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)), # ArrayType np.ndarray: (np.dtype("object"), ArrayType(StringType())), - List[bytes]: (np.dtype("object"), ArrayType(BinaryType())), - List[np.character]: (np.dtype("object"), ArrayType(BinaryType())), - List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())), - List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())), - List[bool]: (np.dtype("object"), ArrayType(BooleanType())), - List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())), - List[datetime.date]: (np.dtype("object"), ArrayType(DateType())), - List[np.int8]: (np.dtype("object"), ArrayType(ByteType())), - List[np.byte]: (np.dtype("object"), ArrayType(ByteType())), - List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))), - List[float]: (np.dtype("object"), ArrayType(DoubleType())), - List[np.float]: (np.dtype("object"), ArrayType(DoubleType())), - List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())), - List[np.float32]: (np.dtype("object"), ArrayType(FloatType())), - List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())), - List[int]: (np.dtype("object"), ArrayType(LongType())), - List[np.int]: (np.dtype("object"), ArrayType(LongType())), - List[np.int64]: (np.dtype("object"), ArrayType(LongType())), - List[np.int16]: (np.dtype("object"), ArrayType(ShortType())), - List[str]: (np.dtype("object"), ArrayType(StringType())), - List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())), - List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())), - List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())), # CategoricalDtype CategoricalDtype(categories=["a", "b", "c"]): ( CategoricalDtype(categories=["a", "b", "c"]), @@ -368,6 +346,28 @@ def test_as_spark_type_pandas_on_spark_dtype(self): self.assertEqual(as_spark_type(numpy_or_python_type), spark_type) self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type)) + if isinstance(numpy_or_python_type, CategoricalDtype): + # Nested CategoricalDtype is not yet supported. + continue + + self.assertEqual(as_spark_type(List[numpy_or_python_type]), ArrayType(spark_type)) + self.assertEqual( + pandas_on_spark_type(List[numpy_or_python_type]), + (np.dtype("object"), ArrayType(spark_type)), + ) + + # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+ + if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"): + import numpy.typing as ntp + + self.assertEqual( + as_spark_type(ntp.NDArray[numpy_or_python_type]), ArrayType(spark_type) + ) + self.assertEqual( + pandas_on_spark_type(ntp.NDArray[numpy_or_python_type]), + (np.dtype("object"), ArrayType(spark_type)), + ) + with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."): as_spark_type(np.dtype("uint64")) diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py index ecdc5324ddf51..b21facb4406a5 100644 --- a/python/pyspark/pandas/typedef/typehints.py +++ b/python/pyspark/pandas/typedef/typehints.py @@ -20,8 +20,10 @@ """ import datetime import decimal +import sys import typing from collections import Iterable +from distutils.version import LooseVersion from inspect import getfullargspec, isclass from typing import ( # noqa: F401 Any, @@ -152,6 +154,19 @@ def as_spark_type( - dictionaries of field_name -> type - Python3's typing system """ + # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+ + if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"): + if ( + hasattr(tpe, "__origin__") + and tpe.__origin__ is np.ndarray # type: ignore + and hasattr(tpe, "__args__") + and len(tpe.__args__) > 1 # type: ignore + ): + # numpy.typing.NDArray + return types.ArrayType( + as_spark_type(tpe.__args__[1].__args__[0], raise_error=raise_error) # type: ignore + ) + if isinstance(tpe, np.dtype) and tpe == np.dtype("object"): pass # ArrayType @@ -568,7 +583,9 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp else: parameters = getattr(tuple_type, "__args__") - index_parameters = [p for p in parameters if issubclass(p, IndexNameTypeHolder)] + index_parameters = [ + p for p in parameters if isclass(p) and issubclass(p, IndexNameTypeHolder) + ] data_parameters = [p for p in parameters if p not in index_parameters] assert len(data_parameters) > 0, "Type hints for data must not be empty."