[SPARK-24554][PYTHON][SQL] Add MapType support for PySpark with Arrow
### What changes were proposed in this pull request?

This change adds MapType support for PySpark with Arrow when using pyarrow >= 2.0.0.

### Why are the changes needed?

MapType was previously unsupported with Arrow.

### Does this PR introduce _any_ user-facing change?

Users can now use `MapType` with `createDataFrame()` and `toPandas()` when Arrow optimization is enabled, and with Pandas UDFs.
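
For illustration, a minimal usage sketch (assuming pyarrow >= 2.0.0 is installed and `spark` is an active `SparkSession`; this snippet is not part of the patch):

```python
import pandas as pd

# Enable Arrow-based conversion between Spark and pandas DataFrames.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

pdf = pd.DataFrame({"id": [0, 1], "m": [{"a": 1}, {"b": 2, "c": 3}]})

# createDataFrame() now converts the dict column to a MapType column through Arrow.
df = spark.createDataFrame(pdf, schema="id long, m map<string, long>")

# toPandas() converts the MapType column back to Python dicts through Arrow.
print(df.toPandas())
```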

### How was this patch tested?

Added new PySpark tests for `createDataFrame()`, `toPandas()`, and scalar Pandas UDFs.
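
A hedged sketch of the scalar Pandas UDF case (reusing the `df` from the example above; the UDF itself is illustrative and not taken from the patch):

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("map<string, long>")
def increment_values(maps: pd.Series) -> pd.Series:
    # With Arrow enabled, each MapType cell arrives in the UDF as a Python dict.
    return maps.apply(
        lambda m: None if m is None else {k: v + 1 for k, v in m.items()})

df.select(increment_values("m").alias("m_plus_one")).show(truncate=False)
```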

Closes #30393 from BryanCutler/arrow-add-MapType-SPARK-24554.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
BryanCutler authored and HyukjinKwon committed Nov 18, 2020
1 parent dd32f45 commit 8e2a0bd
Showing 12 changed files with 157 additions and 36 deletions.
3 changes: 2 additions & 1 deletion python/docs/source/user_guide/arrow_pandas.rst
@@ -341,8 +341,9 @@ Supported SQL Types

.. currentmodule:: pyspark.sql.types

Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`MapType`,
Currently, all Spark SQL data types are supported by Arrow-based conversion except
:class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`.
:class:`MapType` is only supported when using PyArrow 2.0.0 and above.

Setting Arrow Batch Size
~~~~~~~~~~~~~~~~~~~~~~~~
8 changes: 6 additions & 2 deletions python/pyspark/sql/pandas/conversion.py
@@ -22,7 +22,7 @@
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, BooleanType, TimestampType, StructType, DataType
DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
from pyspark.traceback_utils import SCCallSiteSync


@@ -100,7 +100,8 @@ def toPandas(self):
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.pandas.types import _check_series_localize_timestamps
from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
_convert_map_items_to_dict
import pyarrow
# Rename columns to avoid duplicated column names.
tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
@@ -117,6 +118,9 @@
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_localize_timestamps(pdf[field.name], timezone)
elif isinstance(field.dataType, MapType):
pdf[field.name] = \
_convert_map_items_to_dict(pdf[field.name])
return pdf
else:
return pd.DataFrame.from_records([], columns=self.columns)
1 change: 0 additions & 1 deletion python/pyspark/sql/pandas/functions.py
@@ -284,7 +284,6 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
should be checked for accuracy by users.
Currently,
:class:`pyspark.sql.types.MapType`,
:class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and
nested :class:`pyspark.sql.types.StructType`
are currently not supported as output types.
10 changes: 8 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -117,7 +117,8 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.pandas.types import _check_series_localize_timestamps
from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
_convert_map_items_to_dict
import pyarrow

# If the given column is a date type column, creates a series of datetime.date directly
@@ -127,6 +128,8 @@

if pyarrow.types.is_timestamp(arrow_column.type):
return _check_series_localize_timestamps(s, self._timezone)
elif pyarrow.types.is_map(arrow_column.type):
return _convert_map_items_to_dict(s)
else:
return s

@@ -147,7 +150,8 @@ def _create_batch(self, series):
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
_convert_dict_to_map_items
from pandas.api.types import is_categorical_dtype
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
@@ -160,6 +164,8 @@ def create_array(s, t):
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif t is not None and pa.types.is_map(t):
s = _convert_dict_to_map_items(s)
elif is_categorical_dtype(s.dtype):
# Note: This can be removed once minimum pyarrow version is >= 0.16.1
s = s.astype(s.dtypes.categories.dtype)
42 changes: 39 additions & 3 deletions python/pyspark/sql/pandas/types.py
@@ -20,14 +20,15 @@
pandas instances during the type conversion.
"""

from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, ArrayType, \
StructType, StructField, BooleanType
from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
ArrayType, MapType, StructType, StructField


def to_arrow_type(dt):
""" Convert Spark data type to pyarrow type
"""
from distutils.version import LooseVersion
import pyarrow as pa
if type(dt) == BooleanType:
arrow_type = pa.bool_()
@@ -58,6 +59,13 @@ def to_arrow_type(dt):
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
elif type(dt) == MapType:
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
if type(dt.keyType) in [StructType, TimestampType] or \
type(dt.valueType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType))
elif type(dt) == StructType:
if any(type(field.dataType) == StructType for field in dt):
raise TypeError("Nested StructType not supported in conversion to Arrow")
@@ -81,6 +89,8 @@ def to_arrow_schema(schema):
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
from distutils.version import LooseVersion
import pyarrow as pa
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
@@ -110,6 +120,12 @@ def from_arrow_type(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_map(at):
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
@@ -306,3 +322,23 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
`pandas.Series` where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, timezone, None)


def _convert_map_items_to_dict(s):
"""
Convert a series with items as list of (key, value), as made from an Arrow column of map type,
to dict for compatibility with non-arrow MapType columns.
:param s: pandas.Series of lists of (key, value) pairs
:return: pandas.Series of dictionaries
"""
return s.apply(lambda m: None if m is None else {k: v for k, v in m})


def _convert_dict_to_map_items(s):
"""
Convert a series of dictionaries to list of (key, value) pairs to match expected data
for Arrow column of map type.
:param s: pandas.Series of dictionaries
:return: pandas.Series of lists of (key, value) pairs
"""
return s.apply(lambda d: list(d.items()) if d is not None else None)
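
For reference, a small standalone illustration of the data shapes these two helpers translate between (plain pandas, inferred from the docstrings above; not part of the patch):

```python
import pandas as pd

# Arrow surfaces each map cell as a list of (key, value) tuples.
arrow_items = pd.Series([[("a", 1), ("b", 2)], None])

# _convert_map_items_to_dict turns those items into plain dicts (None passes through).
as_dicts = arrow_items.apply(lambda m: None if m is None else {k: v for k, v in m})
# -> [{'a': 1, 'b': 2}, None]

# _convert_dict_to_map_items is the inverse, applied before handing data to Arrow.
back_to_items = as_dicts.apply(lambda d: list(d.items()) if d is not None else None)
# -> [[('a', 1), ('b', 2)], None]
```
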
77 changes: 68 additions & 9 deletions python/pyspark/sql/tests/test_arrow.py
@@ -21,13 +21,13 @@
import time
import unittest
import warnings
from distutils.version import LooseVersion

from pyspark import SparkContext, SparkConf
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, MapType, \
ArrayType
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, ArrayType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -114,9 +114,10 @@ def create_pandas_data_frame(self):
return pd.DataFrame(data=data_dict)

def test_toPandas_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
df = self.spark.createDataFrame([([ts],)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with warnings.catch_warnings(record=True) as warns:
@@ -129,10 +130,10 @@ def test_toPandas_fallback_enabled(self):
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempting non-optimization" in str(user_warns[-1]))
assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
assert_frame_equal(pdf, pd.DataFrame({"a": [[ts]]}))

def test_toPandas_fallback_disabled(self):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
@@ -336,6 +337,62 @@ def test_toPandas_with_array_type(self):
self.assertTrue(expected[r][e] == result_arrow[r][e] and
result[r][e] == result_arrow[r][e])

def test_createDataFrame_with_map_type(self):
map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]

pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
schema = "id long, m map<string, long>"

with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema=schema)

if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
self.spark.createDataFrame(pdf, schema=schema)
else:
df_arrow = self.spark.createDataFrame(pdf, schema=schema)

result = df.collect()
result_arrow = df_arrow.collect()

self.assertEqual(len(result), len(result_arrow))
for row, row_arrow in zip(result, result_arrow):
i, m = row
_, m_arrow = row_arrow
self.assertEqual(m, map_data[i])
self.assertEqual(m_arrow, map_data[i])

def test_toPandas_with_map_type(self):
pdf = pd.DataFrame({"id": [0, 1, 2, 3],
"m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]})

with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")

if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
df.toPandas()
else:
pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(pdf_arrow, pdf_non)

def test_toPandas_with_map_type_nulls(self):
pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4],
"m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]})

with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")

if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
df.toPandas()
else:
pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(pdf_arrow, pdf_non)

def test_createDataFrame_with_int_col_names(self):
import numpy as np
pdf = pd.DataFrame(np.random.rand(4, 2))
@@ -345,26 +402,28 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df_arrow.columns)

def test_createDataFrame_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with QuietTest(self.sc):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
with warnings.catch_warnings(record=True) as warns:
# we want the warnings to appear even if this test is run from a subclass
warnings.simplefilter("always")
df = self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
pd.DataFrame({"a": [[ts]]}), "a: array<timestamp>")
# Catch and check the last UserWarning.
user_warns = [
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempting non-optimization" in str(user_warns[-1]))
self.assertEqual(df.collect(), [Row(a={u'a': 1})])
self.assertEqual(df.collect(), [Row(a=[ts])])

def test_createDataFrame_fallback_disabled(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
"a: array<timestamp>")

# Regression test for SPARK-23314
def test_timestamp_dst(self):
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -176,9 +176,9 @@ def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*MapType'):
'Invalid return type.*ArrayType.*TimestampType'):
left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
lambda l, r: l, 'id long, v map<int, int>')
lambda l, r: l, 'id long, v array<timestamp>')

def test_wrong_args(self):
left = self.data1
7 changes: 3 additions & 4 deletions python/pyspark/sql/tests/test_pandas_grouped_map.py
@@ -26,7 +26,7 @@
window
from pyspark.sql.types import IntegerType, DoubleType, ArrayType, BinaryType, ByteType, \
LongType, DecimalType, ShortType, FloatType, StringType, BooleanType, StructType, \
StructField, NullType, MapType, TimestampType
StructField, NullType, TimestampType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -246,10 +246,10 @@ def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*grouped map Pandas UDF.*MapType'):
'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
'id long, v array<timestamp>',
PandasUDFType.GROUPED_MAP)

def test_wrong_args(self):
@@ -276,7 +276,6 @@ def test_wrong_args(self):
def test_unsupported_types(self):
common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
unsupported_types = [
StructField('map', MapType(StringType(), IntegerType())),
StructField('arr_ts', ArrayType(TimestampType())),
StructField('null', NullType()),
StructField('struct', StructType([StructField('l', LongType())])),
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -21,7 +21,7 @@
from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
udf, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, TimestampType, DoubleType, MapType
from pyspark.sql.types import ArrayType, TimestampType
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
@@ -159,7 +159,7 @@ def mean_and_std_udf(v):

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
@pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return {v.mean(): v.std()}
