diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst
index 26fc634307879..938dc4f6dec69 100644
--- a/python/docs/source/migration_guide/pyspark_upgrade.rst
+++ b/python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -70,6 +70,7 @@ Upgrading from PySpark 3.5 to 4.0
 * In Spark 4.0, when applying ``astype`` to a decimal type object, the existing missing value is changed to ``True`` instead of ``False`` from Pandas API on Spark.
 * In Spark 4.0, ``pyspark.testing.assertPandasOnSparkEqual`` has been removed from Pandas API on Spark, use ``pyspark.pandas.testing.assert_frame_equal`` instead.
 * In Spark 4.0, the aliases ``Y``, ``M``, ``H``, ``T``, ``S`` have been deprecated from Pandas API on Spark, use ``YE``, ``ME``, ``h``, ``min``, ``s`` instead respectively.
+* In Spark 4.0, the schema of a map column is inferred by merging the types of all key-value pairs in the map. To restore the previous behavior of inferring the schema only from the first non-null pair, set ``spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled`` to ``true``.
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index bec3c5b579a0c..5e6c5e5587646 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -388,10 +388,12 @@ def _inferSchemaFromList(
         (
             infer_dict_as_struct,
             infer_array_from_first_element,
+            infer_map_from_first_pair,
             prefer_timestamp_ntz,
         ) = self._client.get_configs(
             "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
             "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
             "spark.sql.timestampType",
         )
         return reduce(
@@ -402,6 +404,7 @@ def _inferSchemaFromList(
                     names,
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     infer_array_from_first_element=(infer_array_from_first_element == "true"),
+                    infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
                     prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
                 )
                 for row in data
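Reviewer note, not part of the patch: a minimal sketch of how the new flag reaches the Connect client. Spark Connect has no JVM-side SQLConf handle, so `_inferSchemaFromList` fetches config values as strings over RPC and compares them to "true", as in the hunk above. `spark._client` is a private attribute used here purely for illustration.

    # Illustration only: how the Connect session reads the new flag.
    # Assumes an active Spark Connect session bound to `spark`; `_client`
    # is private API and may change without notice.
    (infer_map_from_first_pair,) = spark._client.get_configs(
        "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"
    )
    # Configs arrive as strings over RPC; the session converts them to
    # booleans by comparing against "true".
    use_legacy = infer_map_from_first_pair == "true"  # False under the 4.0 default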
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 0a06f2e190c93..9077ee8874444 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1042,6 +1042,7 @@ def _inferSchemaFromList(
             )
         infer_dict_as_struct = self._jconf.inferDictAsStruct()
         infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
+        infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         schema = reduce(
             _merge_type,
@@ -1051,6 +1052,7 @@ def _inferSchemaFromList(
                     names,
                     infer_dict_as_struct=infer_dict_as_struct,
                     infer_array_from_first_element=infer_array_from_first_element,
+                    infer_map_from_first_pair=infer_map_from_first_pair,
                     prefer_timestamp_ntz=prefer_timestamp_ntz,
                 )
                 for row in data
@@ -1093,6 +1095,7 @@ def _inferSchema(

         infer_dict_as_struct = self._jconf.inferDictAsStruct()
         infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
+        infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         if samplingRatio is None:
             schema = _infer_schema(
@@ -1110,6 +1113,7 @@ def _inferSchema(
                             names=names,
                             infer_dict_as_struct=infer_dict_as_struct,
                             infer_array_from_first_element=infer_array_from_first_element,
+                            infer_map_from_first_pair=infer_map_from_first_pair,
                             prefer_timestamp_ntz=prefer_timestamp_ntz,
                         ),
                     )
@@ -1129,6 +1133,7 @@ def _inferSchema(
                     names,
                     infer_dict_as_struct=infer_dict_as_struct,
                     infer_array_from_first_element=infer_array_from_first_element,
+                    infer_map_from_first_pair=infer_map_from_first_pair,
                     prefer_timestamp_ntz=prefer_timestamp_ntz,
                 )
             ).reduce(_merge_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 82a677574b455..a295c6cc7585f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -50,6 +50,14 @@ def test_infer_array_element_type_with_struct(self):
     def test_infer_array_merge_element_types_with_rdd(self):
         super().test_infer_array_merge_element_types_with_rdd()

+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_infer_map_pair_type_empty_rdd(self):
+        super().test_infer_map_pair_type_empty_rdd()
+
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_infer_map_merge_pair_types_with_rdd(self):
+        super().test_infer_map_merge_pair_types_with_rdd()
+
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_infer_binary_type(self):
         super().test_infer_binary_type()
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 40eded6a4433c..616761322a89d 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -410,6 +410,57 @@ def test_infer_array_element_type_with_struct(self):
         df = self.spark.createDataFrame(data)
         self.assertEqual(Row(f1=[Row(payment=200.5), Row(payment=None)]), df.first())

+    def test_infer_map_merge_pair_types_with_rdd(self):
+        # SPARK-48247: Test inferring map key/value types by merging all pairs in the map
+        MapRow = Row("f1", "f2")
+
+        data = [MapRow({"a": 1, "b": None}, {"a": None, "b": 1})]
+
+        rdd = self.sc.parallelize(data)
+        df = self.spark.createDataFrame(rdd)
+        self.assertEqual(Row(f1={"a": 1, "b": None}, f2={"a": None, "b": 1}), df.first())
+
+    def test_infer_map_pair_type_empty_rdd(self):
+        # SPARK-48247: Test inferring map pair type from all rows
+        MapRow = Row("f1")
+
+        data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]
+
+        rdd = self.sc.parallelize(data)
+        df = self.spark.createDataFrame(rdd)
+        rows = df.collect()
+        self.assertEqual(Row(f1={}), rows[0])
+        self.assertEqual(Row(f1={"a": None}), rows[1])
+        self.assertEqual(Row(f1={"a": 1}), rows[2])
+
+    def test_infer_map_pair_type_empty(self):
+        # SPARK-48247: Test inferring map pair type from all rows
+        MapRow = Row("f1")
+
+        data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]
+
+        df = self.spark.createDataFrame(data)
+        rows = df.collect()
+        self.assertEqual(Row(f1={}), rows[0])
+        self.assertEqual(Row(f1={"a": None}), rows[1])
+        self.assertEqual(Row(f1={"a": 1}), rows[2])
+
+    def test_infer_map_pair_type_with_nested_maps(self):
+        # SPARK-48247: Test inferring nested map
+        NestedRow = Row("f1", "f2")
+
+        data = [
+            NestedRow({"payment": 200.5, "name": "A"}, {"outer": {"payment": 200.5, "name": "A"}})
+        ]
+        df = self.spark.createDataFrame(data)
+        self.assertEqual(
+            Row(
+                f1={"payment": "200.5", "name": "A"},
+                f2={"outer": {"payment": "200.5", "name": "A"}},
+            ),
+            df.first(),
+        )
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{"a": 1}], ["b"])
         self.assertEqual(df.columns, ["b"])
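To make the expectations in the new tests above concrete, here is a minimal sketch (not part of the patch) of what the merged inference produces end to end; the session variable `spark` is assumed, and the printed output mirrors what `test_infer_map_pair_type_with_nested_maps` asserts rather than captured console output.

    # Illustration only: under the new default, key/value types are merged
    # across all pairs, so a double value and a string value unify to string.
    df = spark.createDataFrame([{"f1": {"payment": 200.5, "name": "A"}}])
    df.printSchema()
    # Expected, roughly:
    # root
    #  |-- f1: map (nullable = true)
    #  |    |-- key: string
    #  |    |-- value: string (valueContainsNull = true)
    print(df.first())  # Row(f1={'payment': '200.5', 'name': 'A'})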
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41be12620fd56..2415c5a33704a 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1888,6 +1888,7 @@ def _infer_type(
     obj: Any,
     infer_dict_as_struct: bool = False,
     infer_array_from_first_element: bool = False,
+    infer_map_from_first_pair: bool = False,
     prefer_timestamp_ntz: bool = False,
 ) -> DataType:
     """Infer the DataType from obj"""
@@ -1923,12 +1924,13 @@ def _infer_type(
                             value,
                             infer_dict_as_struct,
                             infer_array_from_first_element,
+                            infer_map_from_first_pair,
                             prefer_timestamp_ntz,
                         ),
                         True,
                     )
             return struct
-        else:
+        elif infer_map_from_first_pair:
             for key, value in obj.items():
                 if key is not None and value is not None:
                     return MapType(
@@ -1936,17 +1938,47 @@ def _infer_type(
                             key,
                             infer_dict_as_struct,
                             infer_array_from_first_element,
+                            infer_map_from_first_pair,
                             prefer_timestamp_ntz,
                         ),
                         _infer_type(
                             value,
                             infer_dict_as_struct,
                             infer_array_from_first_element,
+                            infer_map_from_first_pair,
                             prefer_timestamp_ntz,
                         ),
                         True,
                     )
             return MapType(NullType(), NullType(), True)
+        else:
+            key_type: DataType = NullType()
+            value_type: DataType = NullType()
+            for key, value in obj.items():
+                if key is not None:
+                    key_type = _merge_type(
+                        key_type,
+                        _infer_type(
+                            key,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            infer_map_from_first_pair,
+                            prefer_timestamp_ntz,
+                        ),
+                    )
+                if value is not None:
+                    value_type = _merge_type(
+                        value_type,
+                        _infer_type(
+                            value,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            infer_map_from_first_pair,
+                            prefer_timestamp_ntz,
+                        ),
+                    )
+
+            return MapType(key_type, value_type, True)
     elif isinstance(obj, list):
         if len(obj) > 0:
             if infer_array_from_first_element:
@@ -1989,6 +2021,7 @@ def _infer_schema(
     names: Optional[List[str]] = None,
     infer_dict_as_struct: bool = False,
     infer_array_from_first_element: bool = False,
+    infer_map_from_first_pair: bool = False,
     prefer_timestamp_ntz: bool = False,
 ) -> StructType:
     """Infer the schema from dict/namedtuple/object"""
@@ -2027,6 +2060,7 @@ def _infer_schema(
                     v,
                     infer_dict_as_struct,
                     infer_array_from_first_element,
+                    infer_map_from_first_pair,
                     prefer_timestamp_ntz,
                 ),
                 True,
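Since `_infer_type` is the heart of the change, a small sketch (not part of the patch) of calling it directly makes the two code paths above tangible. `_infer_type` is a private helper that may change between releases, and the printed reprs are approximate.

    # Illustration only: exercising both inference paths in types.py.
    from pyspark.sql.types import _infer_type

    # New default: key/value types are merged across all pairs, so a double
    # value and a string value unify to string.
    print(_infer_type({"payment": 200.5, "name": "A"}))
    # Roughly: MapType(StringType(), StringType(), True)

    # Legacy flag: only the first non-null pair ("payment", 200.5) is used.
    print(_infer_type({"payment": 200.5, "name": "A"}, infer_map_from_first_pair=True))
    # Roughly: MapType(StringType(), DoubleType(), True)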
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e78157d611586..3e97f2c455ded 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4611,6 +4611,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  val LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM =
+    buildConf("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled")
+      .internal()
+      .doc("PySpark's SparkSession.createDataFrame infers the key/value types of a map from all " +
+        "pairs in the map by default. If this config is set to true, it restores the legacy " +
+        "behavior of only inferring the type from the first non-null pair.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_USE_V1_COMMAND = buildConf("spark.sql.legacy.useV1Command")
     .internal()
@@ -5814,6 +5824,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
     SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)

+  def legacyInferMapStructTypeFromFirstItem: Boolean = getConf(
+    SQLConf.LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM)
+
   def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED)

   def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)
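For completeness, a hedged sketch (not part of the patch) of opting back into the legacy behavior from PySpark. It assumes an active SparkSession named `spark`; the config is internal and defaults to false, and the printed schema is what the legacy code path implies rather than verified output.

    # Illustration only: restore the legacy first-pair inference.
    spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", "true")

    # Under the legacy path, only the first non-null pair ("payment", 200.5)
    # determines the key/value types, so the map is typed map<string, double>.
    df = spark.createDataFrame([{"f1": {"payment": 200.5, "amount": 3.25}}])
    df.printSchema()
    # root
    #  |-- f1: map (nullable = true)
    #  |    |-- key: string
    #  |    |-- value: double (valueContainsNull = true)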