[SPARK-48247][PYTHON] Use all values in a dict when inferring MapType schema #46547

Closed · wants to merge 2 commits (diff shows changes from 1 commit)
1 change: 1 addition & 0 deletions python/docs/source/migration_guide/pyspark_upgrade.rst

@@ -70,6 +70,7 @@ Upgrading from PySpark 3.5 to 4.0
* In Spark 4.0, when applying ``astype`` to a decimal type object, the existing missing value is changed to ``True`` instead of ``False`` from Pandas API on Spark.
* In Spark 4.0, ``pyspark.testing.assertPandasOnSparkEqual`` has been removed from Pandas API on Spark, use ``pyspark.pandas.testing.assert_frame_equal`` instead.
* In Spark 4.0, the aliases ``Y``, ``M``, ``H``, ``T``, ``S`` have been deprecated from Pandas API on Spark, use ``YE``, ``ME``, ``h``, ``min``, ``s`` instead respectively.
* In Spark 4.0, the schema of a map column is inferred by merging the schemas of all pairs in the map. To restore the previous behavior where the schema is only inferred from the first non-null pair, you can set ``spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled`` to ``true``.



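To make the new default concrete, here is a minimal sketch that mirrors the nested-map test added below; the live `spark` session is an assumption, not part of the patch:

```python
# Sketch only: assumes an active Spark 4.0 SparkSession named `spark`.
data = [{"m": {"payment": 200.5, "name": "A"}}]

df = spark.createDataFrame(data)
df.printSchema()
# m: map<string,string> -- the DoubleType and StringType values are merged
# into StringType, so 200.5 round-trips as "200.5". With the legacy flag set
# to true, only the first non-null pair ("payment", 200.5) would be consulted.
```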
3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/session.py

@@ -388,10 +388,12 @@ def _inferSchemaFromList(
(
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
) = self._client.get_configs(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
"spark.sql.timestampType",
)
return reduce(
@@ -402,6 +404,7 @@
names,
infer_dict_as_struct=(infer_dict_as_struct == "true"),
infer_array_from_first_element=(infer_array_from_first_element == "true"),
infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
)
for row in data
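On the Connect path the flag arrives from the server as a string alongside the existing inference confs, which is why the code compares against `"true"`. A hedged sketch of exercising it over Connect; the connection URL is a placeholder:

```python
# Sketch only: requires a reachable Spark Connect server; the URL below is an
# illustrative placeholder.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# The client fetches this conf via get_configs during schema inference, so
# toggling it switches between merged and first-pair MapType inference.
spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", "true")
df = spark.createDataFrame([{"m": {"a": 1, "b": None}}])
df.printSchema()  # m: map<string,bigint> under either setting for this data
```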
5 changes: 5 additions & 0 deletions python/pyspark/sql/session.py

@@ -1042,6 +1042,7 @@ def _inferSchemaFromList(
)
infer_dict_as_struct = self._jconf.inferDictAsStruct()
infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
schema = reduce(
_merge_type,
@@ -1051,6 +1052,7 @@
names,
infer_dict_as_struct=infer_dict_as_struct,
infer_array_from_first_element=infer_array_from_first_element,
infer_map_from_first_pair=infer_map_from_first_pair,
prefer_timestamp_ntz=prefer_timestamp_ntz,
)
for row in data
@@ -1093,6 +1095,7 @@ def _inferSchema(

infer_dict_as_struct = self._jconf.inferDictAsStruct()
infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
prefer_timestamp_ntz = is_timestamp_ntz_preferred()
if samplingRatio is None:
schema = _infer_schema(
@@ -1110,6 +1113,7 @@
names=names,
infer_dict_as_struct=infer_dict_as_struct,
infer_array_from_first_element=infer_array_from_first_element,
infer_map_from_first_pair=infer_map_from_first_pair,
prefer_timestamp_ntz=prefer_timestamp_ntz,
),
)
@@ -1129,6 +1133,7 @@
names,
infer_dict_as_struct=infer_dict_as_struct,
infer_array_from_first_element=infer_array_from_first_element,
infer_map_from_first_pair=infer_map_from_first_pair,
prefer_timestamp_ntz=prefer_timestamp_ntz,
)
).reduce(_merge_type)
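The RDD branches above thread the same flag through `_infer_schema` row by row; the following hedged sketch mirrors `test_infer_map_pair_type_empty_rdd` (a classic, non-Connect `spark` session is assumed):

```python
# Sketch only: assumes a classic (non-Connect) SparkSession `spark`.
from pyspark.sql import Row

MapRow = Row("f1")
rdd = spark.sparkContext.parallelize([MapRow({}), MapRow({"a": None}), MapRow({"a": 1})])

df = spark.createDataFrame(rdd)
df.printSchema()
# f1: map<string,bigint> -- the empty map contributes map<null,null>, the None
# value contributes only the key type, and the last row supplies the value
# type; _merge_type then resolves the remaining NullTypes.
```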
8 changes: 8 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_types.py

@@ -50,6 +50,14 @@ def test_infer_array_element_type_with_struct(self):
def test_infer_array_merge_element_types_with_rdd(self):
super().test_infer_array_merge_element_types_with_rdd()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_infer_map_pair_type_empty_rdd(self):
super().test_infer_map_pair_type_empty_rdd()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_infer_map_merge_pair_types_with_rdd(self):
super().test_infer_map_merge_pair_types_with_rdd()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_infer_binary_type(self):
super().test_infer_binary_type()
51 changes: 51 additions & 0 deletions python/pyspark/sql/tests/test_types.py

@@ -410,6 +410,57 @@ def test_infer_array_element_type_with_struct(self):
df = self.spark.createDataFrame(data)
self.assertEqual(Row(f1=[Row(payment=200.5), Row(payment=None)]), df.first())

def test_infer_map_merge_pair_types_with_rdd(self):
# SPARK-48247: Test inferring map pair type from all pairs in the map
MapRow = Row("f1", "f2")

data = [MapRow({"a": 1, "b": None}, {"a": None, "b": 1})]

rdd = self.sc.parallelize(data)
df = self.spark.createDataFrame(rdd)
self.assertEqual(Row(f1={"a": 1, "b": None}, f2={"a": None, "b": 1}), df.first())

def test_infer_map_pair_type_empty_rdd(self):
# SPARK-48247: Test inferring map pair type from all rows
MapRow = Row("f1")

data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]

rdd = self.sc.parallelize(data)
df = self.spark.createDataFrame(rdd)
rows = df.collect()
self.assertEqual(Row(f1={}), rows[0])
self.assertEqual(Row(f1={"a": None}), rows[1])
self.assertEqual(Row(f1={"a": 1}), rows[2])

def test_infer_map_pair_type_empty(self):
# SPARK-48247: Test inferring map pair type from all rows
MapRow = Row("f1")

data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]

df = self.spark.createDataFrame(data)
rows = df.collect()
self.assertEqual(Row(f1={}), rows[0])
self.assertEqual(Row(f1={"a": None}), rows[1])
self.assertEqual(Row(f1={"a": 1}), rows[2])

def test_infer_map_pair_type_with_nested_maps(self):
# SPARK-48247: Test inferring nested map
NestedRow = Row("f1", "f2")

data = [
NestedRow({"payment": 200.5, "name": "A"}, {"outer": {"payment": 200.5, "name": "A"}})
]
df = self.spark.createDataFrame(data)
self.assertEqual(
Row(
f1={"payment": "200.5", "name": "A"},
f2={"outer": {"payment": "200.5", "name": "A"}},
),
df.first(),
)

def test_create_dataframe_from_dict_respects_schema(self):
df = self.spark.createDataFrame([{"a": 1}], ["b"])
self.assertEqual(df.columns, ["b"])
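To see the behavior these tests guard against, the mixed-type map can be re-run with the legacy flag on; a hedged sketch (assumes `spark`; the exact exception class is an assumption):

```python
# Sketch only: with first-pair inference restored, ("payment", 200.5) alone
# decides the value type, so the string "A" is expected to fail verification.
data = [{"m": {"payment": 200.5, "name": "A"}}]

spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", "true")
try:
    spark.createDataFrame(data)
except Exception as e:  # anticipated: a type-verification error
    print(type(e).__name__)
finally:
    spark.conf.unset("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled")
```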
36 changes: 35 additions & 1 deletion python/pyspark/sql/types.py

@@ -1888,6 +1888,7 @@ def _infer_type(
obj: Any,
infer_dict_as_struct: bool = False,
infer_array_from_first_element: bool = False,
infer_map_from_first_pair: bool = False,
prefer_timestamp_ntz: bool = False,
) -> DataType:
"""Infer the DataType from obj"""
@@ -1923,30 +1924,61 @@
value,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
True,
)
return struct
-else:
+elif infer_map_from_first_pair:
for key, value in obj.items():
if key is not None and value is not None:
return MapType(
_infer_type(
key,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
_infer_type(
value,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
True,
)
return MapType(NullType(), NullType(), True)
else:
key_type: DataType = NullType()
value_type: DataType = NullType()
for key, value in obj.items():
if key is not None:
key_type = _merge_type(
key_type,
_infer_type(
key,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
)
if value is not None:
value_type = _merge_type(
value_type,
_infer_type(
value,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
)

return MapType(key_type, value_type, True)
elif isinstance(obj, list):
if len(obj) > 0:
if infer_array_from_first_element:
@@ -1989,6 +2021,7 @@ def _infer_schema(
names: Optional[List[str]] = None,
infer_dict_as_struct: bool = False,
infer_array_from_first_element: bool = False,
infer_map_from_first_pair: bool = False,
prefer_timestamp_ntz: bool = False,
) -> StructType:
"""Infer the schema from dict/namedtuple/object"""
@@ -2027,6 +2060,7 @@
v,
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
),
True,
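The heart of the patch is the new `else` branch above: key and value types start as `NullType` and are folded together with `_merge_type` across every pair. Below is a condensed sketch of just that logic; `_infer_type` and `_merge_type` are private helpers, imported here purely for illustration:

```python
# Condensed, unofficial sketch of the merge-based branch in _infer_type.
from pyspark.sql.types import DataType, MapType, NullType, _infer_type, _merge_type

def infer_map_type(d: dict) -> MapType:
    key_type: DataType = NullType()
    value_type: DataType = NullType()
    for key, value in d.items():
        if key is not None:  # None keys contribute nothing to the key type
            key_type = _merge_type(key_type, _infer_type(key))
        if value is not None:  # likewise for None values
            value_type = _merge_type(value_type, _infer_type(value))
    return MapType(key_type, value_type, True)

print(infer_map_type({"a": 1, "b": None}))
# MapType(StringType(), LongType(), True)
```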
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

@@ -4611,6 +4611,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM =
buildConf("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled")
.internal()
.doc("PySpark's SparkSession.createDataFrame infers the key/value types of a map from all " +
"paris in the map by default. If this config is set to true, it restores the legacy " +
"behavior of only inferring the type from the first pair.")
[Review comment from a Member] nit: first non-null pair

.version("4.0.0")
.booleanConf
.createWithDefault(false)

val LEGACY_USE_V1_COMMAND =
buildConf("spark.sql.legacy.useV1Command")
.internal()
@@ -5814,6 +5824,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)

def legacyInferMapStructTypeFromFirstItem: Boolean = getConf(
SQLConf.LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM)

def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED)

def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)
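Finally, since the flag is an internal SQL conf created with a default of `false`, merged-pair inference is on unless it is explicitly overridden; a hedged one-liner to check it from PySpark (assumes `spark`):

```python
# Reading the conf returns its default ("false") unless it has been overridden.
print(spark.conf.get("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"))
```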