[SPARK-48247][PYTHON] Use all values in a dict when inferring MapType schema

### What changes were proposed in this pull request?

This is similar to #36545. This PR proposes to infer map key and value types from all pairs in a dict instead of only the first non-null pair.
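
Conceptually, the new path folds `_merge_type` over every non-null key and value rather than returning after the first non-null pair. A minimal standalone sketch of the idea (simplified from the `_infer_type` change in `python/pyspark/sql/types.py`; `_infer_type` and `_merge_type` are private helpers, so treat this as illustrative, not as the exact implementation):

```python
from pyspark.sql.types import DataType, MapType, NullType, _infer_type, _merge_type


def infer_map_type(obj: dict) -> MapType:
    # Start from NullType and merge in the type of every non-null key/value,
    # instead of returning after the first non-null pair.
    key_type: DataType = NullType()
    value_type: DataType = NullType()
    for key, value in obj.items():
        if key is not None:
            key_type = _merge_type(key_type, _infer_type(key))
        if value is not None:
            value_type = _merge_type(value_type, _infer_type(value))
    return MapType(key_type, value_type, True)


# e.g. infer_map_type({"payment": 200.5, "name": "A"}) should give
# MapType(StringType(), StringType(), True): double and string merge to string.
```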

### Why are the changes needed?

To keep the behavior consistent with existing type inference, which already merges types across all values, e.g.:

```python
>>> spark.createDataFrame([[1], [2], ["a"], ["c"]]).collect()
[Row(_1='1'), Row(_1='2'), Row(_1='a'), Row(_1='c')]
```

### Does this PR introduce _any_ user-facing change?

Yes. Under the new default, key and value types are merged across all pairs (the double and string values below merge to string). With the legacy flag enabled, only the first non-null pair determines the types, so classic PySpark silently nulls out the mismatched value while Spark Connect raises an error. See below:

**Without Spark Connect:**

```python
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
[Row(outer={'name': 'A', 'payment': '200.5'})]
>>> spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", True)
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
[Row(outer={'name': None, 'payment': 200.5})]
```

**With Spark Connect:**

```python
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
[Row(outer={'payment': '200.5', 'name': 'A'})]
>>> spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", True)
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/connect/session.py", line 635, in createDataFrame
    _table = LocalDataToArrowConversion.convert(_data, _schema)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../spark/python/pyspark/sql/connect/conversion.py", line 378, in convert
    return pa.Table.from_arrays(pylist, schema=pa_schema)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "pyarrow/table.pxi", line 3974, in pyarrow.lib.Table.from_arrays
  File "pyarrow/table.pxi", line 1464, in pyarrow.lib._sanitize_arrays
  File "pyarrow/array.pxi", line 373, in pyarrow.lib.asarray
  File "pyarrow/array.pxi", line 343, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 42, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 154, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 91, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert 'A' with type str: tried to convert to double
```
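
For reference, the schema merged under the new default should carry string value types, since double and string merge to string during inference. A quick check (output reconstructed for illustration, not taken from the PR):

```python
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).schema
StructType([StructField('outer', MapType(StringType(), StringType(), True), True)])
```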

### How was this patch tested?

Unit tests added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46547 from HyukjinKwon/infer-map-first.

Lead-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed May 14, 2024
1 parent 79aeae1 commit 42c1c8f
Showing 7 changed files with 116 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -71,6 +71,7 @@ Upgrading from PySpark 3.5 to 4.0
* In Spark 4.0, when applying ``astype`` to a decimal type object, the existing missing value is changed to ``True`` instead of ``False`` from Pandas API on Spark.
* In Spark 4.0, ``pyspark.testing.assertPandasOnSparkEqual`` has been removed from Pandas API on Spark, use ``pyspark.pandas.testing.assert_frame_equal`` instead.
* In Spark 4.0, the aliases ``Y``, ``M``, ``H``, ``T``, ``S`` have been deprecated from Pandas API on Spark, use ``YE``, ``ME``, ``h``, ``min``, ``s`` instead respectively.
* In Spark 4.0, the schema of a map column is inferred by merging the schemas of all pairs in the map. To restore the previous behavior where the schema is only inferred from the first non-null pair, you can set ``spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled`` to ``true``.



3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -388,10 +388,12 @@ def _inferSchemaFromList(
(
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
) = self._client.get_configs(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
"spark.sql.timestampType",
)
return reduce(
@@ -402,6 +404,7 @@
names,
infer_dict_as_struct=(infer_dict_as_struct == "true"),
infer_array_from_first_element=(infer_array_from_first_element == "true"),
infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
)
for row in data
5 changes: 5 additions & 0 deletions python/pyspark/sql/session.py
@@ -1042,6 +1042,7 @@ def _inferSchemaFromList(
)
infer_dict_as_struct = self._jconf.inferDictAsStruct()
infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
schema = reduce(
_merge_type,
@@ -1051,6 +1052,7 @@
names,
infer_dict_as_struct=infer_dict_as_struct,
infer_array_from_first_element=infer_array_from_first_element,
infer_map_from_first_pair=infer_map_from_first_pair,
prefer_timestamp_ntz=prefer_timestamp_ntz,
)
for row in data
@@ -1093,6 +1095,7 @@ def _inferSchema(

infer_dict_as_struct = self._jconf.inferDictAsStruct()
infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
if samplingRatio is None:
schema = _infer_schema(
@@ -1110,6 +1113,7 @@
names=names,
infer_dict_as_struct=infer_dict_as_struct,
infer_array_from_first_element=infer_array_from_first_element,
infer_map_from_first_pair=infer_map_from_first_pair,
prefer_timestamp_ntz=prefer_timestamp_ntz,
),
)
@@ -1129,6 +1133,7 @@
names,
infer_dict_as_struct=infer_dict_as_struct,
infer_array_from_first_element=infer_array_from_first_element,
infer_map_from_first_pair=infer_map_from_first_pair,
prefer_timestamp_ntz=prefer_timestamp_ntz,
)
).reduce(_merge_type)
8 changes: 8 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_types.py
@@ -46,6 +46,14 @@ def test_infer_array_element_type_empty_rdd(self):
def test_infer_array_merge_element_types_with_rdd(self):
super().test_infer_array_merge_element_types_with_rdd()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_infer_map_pair_type_empty_rdd(self):
super().test_infer_map_pair_type_empty_rdd()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_infer_map_merge_pair_types_with_rdd(self):
super().test_infer_map_merge_pair_types_with_rdd()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_infer_binary_type(self):
super().test_infer_binary_type()
51 changes: 51 additions & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -410,6 +410,57 @@ def test_infer_array_element_type_with_struct(self):
df = self.spark.createDataFrame(data)
self.assertEqual(Row(f1=[Row(payment=200.5), Row(payment=None)]), df.first())

def test_infer_map_merge_pair_types_with_rdd(self):
# SPARK-48247: Test inferring map pair type from all pairs in the map
MapRow = Row("f1", "f2")

data = [MapRow({"a": 1, "b": None}, {"a": None, "b": 1})]

rdd = self.sc.parallelize(data)
df = self.spark.createDataFrame(rdd)
self.assertEqual(Row(f1={"a": 1, "b": None}, f2={"a": None, "b": 1}), df.first())

def test_infer_map_pair_type_empty_rdd(self):
# SPARK-48247: Test inferring map pair type from all rows
MapRow = Row("f1")

data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]

rdd = self.sc.parallelize(data)
df = self.spark.createDataFrame(rdd)
rows = df.collect()
self.assertEqual(Row(f1={}), rows[0])
self.assertEqual(Row(f1={"a": None}), rows[1])
self.assertEqual(Row(f1={"a": 1}), rows[2])

def test_infer_map_pair_type_empty(self):
# SPARK-48247: Test inferring map pair type from all rows
MapRow = Row("f1")

data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]

df = self.spark.createDataFrame(data)
rows = df.collect()
self.assertEqual(Row(f1={}), rows[0])
self.assertEqual(Row(f1={"a": None}), rows[1])
self.assertEqual(Row(f1={"a": 1}), rows[2])

def test_infer_map_pair_type_with_nested_maps(self):
# SPARK-48247: Test inferring nested map
NestedRow = Row("f1", "f2")

data = [
NestedRow({"payment": 200.5, "name": "A"}, {"outer": {"payment": 200.5, "name": "A"}})
]
df = self.spark.createDataFrame(data)
self.assertEqual(
Row(
f1={"payment": "200.5", "name": "A"},
f2={"outer": {"payment": "200.5", "name": "A"}},
),
df.first(),
)

def test_create_dataframe_from_dict_respects_schema(self):
df = self.spark.createDataFrame([{"a": 1}], ["b"])
self.assertEqual(df.columns, ["b"])
36 changes: 35 additions & 1 deletion python/pyspark/sql/types.py
@@ -1888,6 +1888,7 @@ def _infer_type(
obj: Any,
infer_dict_as_struct: bool = False,
infer_array_from_first_element: bool = False,
infer_map_from_first_pair: bool = False,
prefer_timestamp_ntz: bool = False,
) -> DataType:
"""Infer the DataType from obj"""
@@ -1923,30 +1924,61 @@
value,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
True,
)
return struct
else:
elif infer_map_from_first_pair:
for key, value in obj.items():
if key is not None and value is not None:
return MapType(
_infer_type(
key,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
_infer_type(
value,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
True,
)
return MapType(NullType(), NullType(), True)
else:
key_type: DataType = NullType()
value_type: DataType = NullType()
for key, value in obj.items():
if key is not None:
key_type = _merge_type(
key_type,
_infer_type(
key,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
)
if value is not None:
value_type = _merge_type(
value_type,
_infer_type(
value,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
)

return MapType(key_type, value_type, True)
elif isinstance(obj, list):
if len(obj) > 0:
if infer_array_from_first_element:
@@ -2003,6 +2035,7 @@ def _infer_schema(
names: Optional[List[str]] = None,
infer_dict_as_struct: bool = False,
infer_array_from_first_element: bool = False,
infer_map_from_first_pair: bool = False,
prefer_timestamp_ntz: bool = False,
) -> StructType:
"""Infer the schema from dict/namedtuple/object"""
Expand Down Expand Up @@ -2041,6 +2074,7 @@ def _infer_schema(
v,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
True,
13 changes: 13 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4621,6 +4621,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM =
buildConf("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled")
.internal()
.doc("PySpark's SparkSession.createDataFrame infers the key/value types of a map from all " +
"paris in the map by default. If this config is set to true, it restores the legacy " +
"behavior of only inferring the type from the first non-null pair.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val LEGACY_USE_V1_COMMAND =
buildConf("spark.sql.legacy.useV1Command")
.internal()
@@ -5826,6 +5836,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)

def legacyInferMapStructTypeFromFirstItem: Boolean = getConf(
SQLConf.LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM)

def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED)

def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)
