From 50ae9d49a7fbe2f6089703ad577a706a1372e504 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Mon, 25 Feb 2019 16:03:13 -0800
Subject: [PATCH 1/9] Added basic StructType test

---
 .../sql/tests/test_pandas_udf_scalar.py | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 28ef98d7b3f1e..79593eae656a4 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -23,13 +23,16 @@
 import time
 import unittest
 
+if sys.version >= '3':
+    unicode = str
+
 from datetime import date, datetime
 from decimal import Decimal
 from distutils.version import LooseVersion
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql import Column
-from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
+from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
 from pyspark.sql.types import Row
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
@@ -265,6 +268,23 @@ def test_vectorized_udf_null_array(self):
         result = df.select(array_f(col('array')))
         self.assertEquals(df.collect(), result.collect())
 
+    def test_vectorized_udf_struct_type(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('id', IntegerType()),
+            StructField('str', StringType())])
+
+        @pandas_udf(returnType=return_type)
+        def f(id):
+            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+        actual = df.select(f(col('id')).alias('struct')).collect()
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str')) \
+            .alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),

From 5888a79a94f50ffad1778fd70a82a336117b2210 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Mon, 25 Feb 2019 17:37:30 -0800
Subject: [PATCH 2/9] Support StructType return from scalar Pandas UDF

---
 python/pyspark/serializers.py | 34 ++++++++++++++++++++++++++++++----
 python/pyspark/sql/types.py   |  6 ++++++
 python/pyspark/worker.py      |  7 ++++++-
 3 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a2c59fedfc8cd..33a4a17f8d1b9 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@
     from itertools import izip as zip, imap as map
 else:
     import pickle
+    basestring = unicode = str
     xrange = range
 
 pickle_protocol = pickle.HIGHEST_PROTOCOL
@@ -244,7 +245,7 @@ def __repr__(self):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone, safecheck):
+def _create_batch(series, timezone, safecheck, assign_cols_by_name):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series,
     with optional type.
@@ -254,6 +255,7 @@
     """
     import decimal
     from distutils.version import LooseVersion
+    import pandas as pd
     import pyarrow as pa
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -295,7 +297,29 @@ def create_array(s, t):
             raise RuntimeError(error_msg % (s.dtype, t), e)
         return array
 
-    arrs = [create_array(s, t) for s, t in series]
+    arrs = []
+    for s, t in series:
+        if pa.types.is_struct(t):
+            if not isinstance(s, pd.DataFrame):
+                raise ValueError("A field of type StructType expects a pandas.DataFrame, " +
+                                 "but got: %s" % str(type(s)))
+
+            # Assign result columns by schema name if user labeled with strings, else use position
+            struct_arrs = []
+            struct_names = []
+            if assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
+                for field in t:
+                    struct_arrs.append(create_array(s[field.name], field.type))
+                    struct_names.append(field.name)
+            else:
+                for i, field in enumerate(t):
+                    struct_arrs.append(create_array(s[s.columns[i]], field.type))
+                    struct_names.append(field.name)
+
+            arrs.append(pa.StructArray.from_arrays(struct_arrs, names=struct_names))
+        else:
+            arrs.append(create_array(s, t))
+
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
 
 
@@ -304,10 +328,11 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, assign_cols_by_name):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -326,7 +351,8 @@ def dump_stream(self, iterator, stream):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone, self._safecheck)
+                batch = _create_batch(series, self._timezone, self._safecheck,
+                                      self._assign_cols_by_name)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 348cb5b118594..2a75fbfdec296 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1616,6 +1616,12 @@ def to_arrow_type(dt):
         if type(dt.elementType) == TimestampType:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
+    elif type(dt) == StructType:
+        if any(field.dataType == StructType for field in dt):
+            raise TypeError("Nested StructTypes not supported in conversion to Arrow")
+        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+                  for field in dt]
+        arrow_type = pa.struct(fields)
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 01934a0e72758..cf1b575fe1e70 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -254,7 +254,12 @@ def read_udfs(pickleSer, infile, eval_type):
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
         safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                     "false").lower() == 'true'
-        ser = ArrowStreamPandasSerializer(timezone, safecheck)
+        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
+        assign_cols_by_name = runner_conf.get(
+            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
+            .lower() == "true"
+
+        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)

From 0c4a1c67f0fdb1033cff435b5cc156da5a819546 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 26 Feb 2019 11:55:10 -0800
Subject: [PATCH 3/9] fixed type check and createDataFrame with arrow, added more tests

---
 python/pyspark/serializers.py           |  2 +-
 python/pyspark/sql/session.py           |  3 +-
 .../sql/tests/test_pandas_udf_scalar.py | 51 +++++++++++++++++--
 python/pyspark/sql/types.py             |  4 +-
 4 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 33a4a17f8d1b9..8776eebe1b3dd 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -299,7 +299,7 @@ def create_array(s, t):
 
     arrs = []
     for s, t in series:
-        if pa.types.is_struct(t):
+        if t is not None and pa.types.is_struct(t):
             if not isinstance(s, pd.DataFrame):
                 raise ValueError("A field of type StructType expects a pandas.DataFrame, " +
                                  "but got: %s" % str(type(s)))
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index bdf1701a58959..32a2c8a67252d 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -557,8 +557,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
 
         # Create Arrow record batches
         safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+        col_by_name = True  # col by name only applies to StructType columns, can't happen here
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone, safecheck)
+                                 timezone, safecheck, col_by_name)
                    for pdf_slice in pdf_slices]
 
         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 79593eae656a4..69a4985266a78 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -273,18 +273,59 @@ def test_vectorized_udf_struct_type(self):
 
         df = self.spark.range(10)
         return_type = StructType([
-            StructField('id', IntegerType()),
+            StructField('id', LongType()),
             StructField('str', StringType())])
 
-        @pandas_udf(returnType=return_type)
-        def f(id):
+        def func(id):
             return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
 
-        actual = df.select(f(col('id')).alias('struct')).collect()
-        expected = df.select(struct(col('id'), col('id').cast('string').alias('str')) \
+        f = pandas_udf(func, returnType=return_type)
+
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
             .alias('struct')).collect()
+
+        actual = df.select(f(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+        g = pandas_udf(func, 'id: long, str: string')
+        actual = df.select(g(col('id')).alias('struct')).collect()
         self.assertEqual(expected, actual)
 
+    def test_vectorized_udf_struct_complex(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('ts', TimestampType()),
+            StructField('arr', ArrayType(LongType()))])
+
+        @pandas_udf(returnType=return_type)
+        def f(id):
+            return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
+                                 'arr': id.apply(lambda i: [i, i + 1])})
+
+        actual = df.withColumn('f', f(col('id'))).collect()
+        for i, row in enumerate(actual):
+            id, f = row
+            self.assertEqual(i, id)
+            self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
+            self.assertListEqual([i, i + 1], f[1])
+
+    def test_vectorized_udf_nested_struct(self):
+        nested_type = StructType([
+            StructField('id', IntegerType()),
+            StructField('nested', StructType([
+                StructField('foo', StringType()),
+                StructField('bar', FloatType())
+            ]))
+        ])
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Invalid returnType with scalar Pandas UDFs'):
+                pandas_udf(lambda x: x, returnType=nested_type)
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 2a75fbfdec296..f43e4c0fab6dd 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1617,8 +1617,8 @@ def to_arrow_type(dt):
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
     elif type(dt) == StructType:
-        if any(field.dataType == StructType for field in dt):
-            raise TypeError("Nested StructTypes not supported in conversion to Arrow")
+        if any(type(field.dataType) == StructType for field in dt):
+            raise TypeError("Nested StructType not supported in conversion to Arrow")
         fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
                   for field in dt]
         arrow_type = pa.struct(fields)

From 8567ce6fb2107446f69396a1b1bcc738a51509a5 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 26 Feb 2019 12:28:03 -0800
Subject: [PATCH 4/9] raise error when group agg udf

---
 python/pyspark/sql/udf.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 58f4e0dff5ee5..3698cb09c0d28 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -133,6 +133,8 @@ def returnType(self):
                     "UDFs: returnType must be a StructType.")
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
+                if isinstance(self._returnType_placeholder, StructType):
+                    raise TypeError
                 to_arrow_type(self._returnType_placeholder)
             except TypeError:
                 raise NotImplementedError(

From bfabb7d5e2c6f2813a6fbea3d2f52db11d90992c Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 26 Feb 2019 13:28:21 -0800
Subject: [PATCH 5/9] Add workaround for earlier versions of pyarrow

---
 python/pyspark/serializers.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 8776eebe1b3dd..0ec3e2288bde8 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -316,7 +316,11 @@ def create_array(s, t):
                     struct_arrs.append(create_array(s[s.columns[i]], field.type))
                     struct_names.append(field.name)
 
-            arrs.append(pa.StructArray.from_arrays(struct_arrs, names=struct_names))
+            # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
+            if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
+                arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
+            else:
+                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
         else:
             arrs.append(create_array(s, t))

From 174ad9907706f78f8b58628ffd5349c91b51e1b7 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Wed, 27 Feb 2019 17:39:41 -0800
Subject: [PATCH 6/9] remove unnecessary str addition

---
 python/pyspark/serializers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 0ec3e2288bde8..f93f46b0227bf 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -301,7 +301,7 @@ def create_array(s, t):
     for s, t in series:
         if t is not None and pa.types.is_struct(t):
             if not isinstance(s, pd.DataFrame):
-                raise ValueError("A field of type StructType expects a pandas.DataFrame, " +
+                raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                  "but got: %s" % str(type(s)))
 
             # Assign result columns by schema name if user labeled with strings, else use position

From 94fd921533612a3b19be9b6f75362d32c383b3bd Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 28 Feb 2019 09:05:24 -0800
Subject: [PATCH 7/9] Added check for array with struct type

---
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 4 ++++
 python/pyspark/sql/types.py                        | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 69a4985266a78..c9dcf77f6c0a6 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -404,6 +404,10 @@ def test_vectorized_udf_unsupported_types(self):
                     NotImplementedError,
                     'Invalid returnType.*scalar Pandas UDF.*MapType'):
                 pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
+                pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))
 
     def test_vectorized_udf_dates(self):
         schema = StructType().add("idx", LongType()).add("date", DateType())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f43e4c0fab6dd..d87f0f91499ae 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1613,7 +1613,7 @@ def to_arrow_type(dt):
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
     elif type(dt) == ArrayType:
-        if type(dt.elementType) == TimestampType:
+        if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
     elif type(dt) == StructType:

From 89796b2c48cb25ad547a6a5c97c3bcf812860d81 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Mon, 4 Mar 2019 18:50:00 -0800
Subject: [PATCH 8/9] handle empty partitions that produce empty DataFrame result, updated pydocs

---
 python/pyspark/serializers.py           | 21 ++++++++++---------
 python/pyspark/sql/functions.py         | 12 ++++++++++-
 .../sql/tests/test_pandas_udf_scalar.py | 14 +++++++++++++
 python/pyspark/worker.py                |  5 +++--
 4 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index f93f46b0227bf..0c3c68ec0bd95 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -304,17 +304,18 @@ def create_array(s, t):
                 raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                  "but got: %s" % str(type(s)))
 
-            # Assign result columns by schema name if user labeled with strings, else use position
-            struct_arrs = []
-            struct_names = []
-            if assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
-                for field in t:
-                    struct_arrs.append(create_array(s[field.name], field.type))
-                    struct_names.append(field.name)
+            # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
+            if len(s) == 0 and len(s.columns) == 0:
+                arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+            # Assign result columns by schema name if user labeled with strings
+            elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
+                arrs_names = [(create_array(s[field.name], field.type), field.name)
+                              for field in t]
+            # Assign result columns by position
             else:
-                for i, field in enumerate(t):
-                    struct_arrs.append(create_array(s[s.columns[i]], field.type))
-                    struct_names.append(field.name)
+                arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
+                              for i, field in enumerate(t)]
+
+            struct_arrs, struct_names = zip(*arrs_names)
 
             # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
             if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3c33e2bed92d9..a36423e67d750 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2842,8 +2842,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
        The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
+       If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
 
-       :class:`MapType`, :class:`StructType` are currently not supported as output types.
+       :class:`MapType`, nested :class:`StructType` are currently not supported as output types.
 
        Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
       :meth:`pyspark.sql.DataFrame.select`.
@@ -2868,6 +2869,15 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        +----------+--------------+------------+
        |         8|      JOHN DOE|          22|
        +----------+--------------+------------+
+       >>> @pandas_udf("first string, last string")  # doctest: +SKIP
+       ... def split_expand(n):
+       ...     return n.str.split(expand=True)
+       >>> df.select(split_expand("name")).show()  # doctest: +SKIP
+       +------------------+
+       |split_expand(name)|
+       +------------------+
+       |       [John, Doe]|
+       +------------------+
 
        .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole
            input column, but is the length of an internal batch used for each call to the function.
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index c9dcf77f6c0a6..28b6db216d00a 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -392,6 +392,20 @@ def test_vectorized_udf_empty_partition(self):
         res = df.select(f(col('id')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_struct_with_empty_partition(self):
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
+            .withColumn('name', lit('John Doe'))
+
+        @pandas_udf("first string, last string")
+        def split_expand(n):
+            return n.str.split(expand=True)
+
+        result = df.select(split_expand('name')).collect()
+        self.assertEqual(1, len(result))
+        row = result[0]
+        self.assertEqual('John', row[0]['first'])
+        self.assertEqual('Doe', row[0]['last'])
+
     def test_vectorized_udf_varargs(self):
         df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
         f = pandas_udf(lambda *v: v[0], LongType())
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index cf1b575fe1e70..0e9b6d665a36f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -39,7 +39,7 @@
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import to_arrow_type
+from pyspark.sql.types import to_arrow_type, StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
@@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type):
     def verify_result_length(*a):
         result = f(*a)
         if not hasattr(result, "__len__"):
+            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
             raise TypeError("Return type of the user-defined function should be "
-                            "Pandas.Series, but is {}".format(type(result)))
+                            "{}, but is {}".format(pd_type, type(result)))
         if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
                                "expected %d, got %d" % (len(a[0]), len(result)))

From 91decf035e0faf9d386d3c505134e7c73174c6e3 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Mon, 4 Mar 2019 21:31:55 -0800
Subject: [PATCH 9/9] Fix check for nested structs in grouped map udfs, add note for grouped agg

---
 python/pyspark/sql/tests/test_pandas_udf_grouped_map.py | 1 +
 python/pyspark/sql/udf.py                               | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index a0a25359d1e01..f7684d3fbcff0 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -273,6 +273,7 @@ def test_unsupported_types(self):
             StructField('map', MapType(StringType(), IntegerType())),
             StructField('arr_ts', ArrayType(TimestampType())),
             StructField('null', NullType()),
+            StructField('struct', StructType([StructField('l', LongType())])),
         ]
 
         # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 3698cb09c0d28..275abe9c85d1e 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -123,7 +123,7 @@ def returnType(self):
         elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
             if isinstance(self._returnType_placeholder, StructType):
                 try:
-                    to_arrow_schema(self._returnType_placeholder)
+                    to_arrow_type(self._returnType_placeholder)
                 except TypeError:
                     raise NotImplementedError(
                         "Invalid returnType with grouped map Pandas UDFs: "
@@ -133,6 +133,7 @@ def returnType(self):
                     "UDFs: returnType must be a StructType.")
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
+                # StructType is not yet allowed as a return type, explicitly check here to fail fast
                 if isinstance(self._returnType_placeholder, StructType):
                     raise TypeError
                 to_arrow_type(self._returnType_placeholder)
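
Usage sketch (not part of the patch series above): the snippet mirrors the behaviour exercised by
the new doctest in functions.py and by test_vectorized_udf_struct_type. It assumes a running
SparkSession bound to a variable named `spark` and a PySpark build that includes these changes;
the `to_struct` name is illustrative only.

    import pandas as pd
    from pyspark.sql.functions import col, pandas_udf

    df = spark.range(10)

    # A scalar Pandas UDF may now declare a flat (non-nested) StructType return type and
    # return a pandas.DataFrame whose columns match the declared fields by name or position.
    @pandas_udf("id long, str string")
    def to_struct(id):
        return pd.DataFrame({'id': id, 'str': id.apply(str)})

    df.select(to_struct(col('id')).alias('struct')).show()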