From 42c1c8f7d26478ea734a620ff9f4bb6f1c8e8f48 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Wed, 15 May 2024 08:22:27 +0900
Subject: [PATCH] [SPARK-48247][PYTHON] Use all values in a dict when inferring
 MapType schema

### What changes were proposed in this pull request?

This is similar to https://github.com/apache/spark/pull/36545. This PR proposes to infer map types from all key/value pairs instead of only the first non-null pair.

### Why are the changes needed?

To keep the behavior consistent with how other nested types are inferred; for example, array element types are already inferred from all elements:

```python
>>> spark.createDataFrame([[1], [2], ["a"], ["c"]]).collect()
[Row(_1='1'), Row(_1='2'), Row(_1='a'), Row(_1='c')]
```

### Does this PR introduce _any_ user-facing change?

Yes, see below.

**Without Spark Connect:**

```python
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
[Row(outer={'name': 'A', 'payment': '200.5'})]
>>> spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", True)
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
[Row(outer={'name': None, 'payment': 200.5})]
```

**With Spark Connect:**

```python
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
[Row(outer={'payment': '200.5', 'name': 'A'})]
>>> spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", True)
>>> spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
Traceback (most recent call last):
  File "", line 1, in
  File "/.../spark/python/pyspark/sql/connect/session.py", line 635, in createDataFrame
    _table = LocalDataToArrowConversion.convert(_data, _schema)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../spark/python/pyspark/sql/connect/conversion.py", line 378, in convert
    return pa.Table.from_arrays(pylist, schema=pa_schema)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "pyarrow/table.pxi", line 3974, in pyarrow.lib.Table.from_arrays
  File "pyarrow/table.pxi", line 1464, in pyarrow.lib._sanitize_arrays
  File "pyarrow/array.pxi", line 373, in pyarrow.lib.asarray
  File "pyarrow/array.pxi", line 343, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 42, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 154, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 91, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert 'A' with type str: tried to convert to double
```

### How was this patch tested?

Unit tests added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46547 from HyukjinKwon/infer-map-first.
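As a quick, runnable illustration of the new flag plumbed through `_infer_type` (a private helper in `pyspark.sql.types`, used here purely for demonstration; no SparkSession is needed since the module is plain Python). The expected outputs follow the examples above and the tests added in this patch:

```python
from pyspark.sql.types import _infer_type

d = {"payment": 200.5, "name": "A"}

# New default: every pair is inspected, so the value type is the merge of
# DoubleType (200.5) and StringType ('A'), which coerces to StringType.
print(_infer_type(d).simpleString())
# map<string,string>

# Legacy behavior behind the new flag: only the first non-null pair is
# inspected, giving the lossy map<string,double> shown above.
print(_infer_type(d, infer_map_from_first_pair=True).simpleString())
# map<string,double>
```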
Lead-authored-by: Hyukjin Kwon
Co-authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 .../migration_guide/pyspark_upgrade.rst     |  1 +
 python/pyspark/sql/connect/session.py       |  3 ++
 python/pyspark/sql/session.py               |  5 ++
 .../sql/tests/connect/test_parity_types.py  |  8 +++
 python/pyspark/sql/tests/test_types.py      | 51 +++++++++++++++++++
 python/pyspark/sql/types.py                 | 36 ++++++++++++-
 .../apache/spark/sql/internal/SQLConf.scala | 13 +++++
 7 files changed, 116 insertions(+), 1 deletion(-)

diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst
index 0f252519e7daf..4ca9dc334f17d 100644
--- a/python/docs/source/migration_guide/pyspark_upgrade.rst
+++ b/python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -71,6 +71,7 @@ Upgrading from PySpark 3.5 to 4.0
 * In Spark 4.0, when applying ``astype`` to a decimal type object, the existing missing value is changed to ``True`` instead of ``False`` from Pandas API on Spark.
 * In Spark 4.0, ``pyspark.testing.assertPandasOnSparkEqual`` has been removed from Pandas API on Spark, use ``pyspark.pandas.testing.assert_frame_equal`` instead.
 * In Spark 4.0, the aliases ``Y``, ``M``, ``H``, ``T``, ``S`` have been deprecated from Pandas API on Spark, use ``YE``, ``ME``, ``h``, ``min``, ``s`` instead respectively.
+* In Spark 4.0, the schema of a map column is inferred by merging the schemas of all pairs in the map. To restore the previous behavior where the schema is only inferred from the first non-null pair, you can set ``spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled`` to ``true``.
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index bec3c5b579a0c..5e6c5e5587646 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -388,10 +388,12 @@ def _inferSchemaFromList(
         (
             infer_dict_as_struct,
             infer_array_from_first_element,
+            infer_map_from_first_pair,
             prefer_timestamp_ntz,
         ) = self._client.get_configs(
             "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
             "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
             "spark.sql.timestampType",
         )
         return reduce(
@@ -402,6 +404,7 @@ def _inferSchemaFromList(
                     names,
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     infer_array_from_first_element=(infer_array_from_first_element == "true"),
+                    infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
                     prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
                 )
                 for row in data
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 0a06f2e190c93..9077ee8874444 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1042,6 +1042,7 @@ def _inferSchemaFromList(
         )
         infer_dict_as_struct = self._jconf.inferDictAsStruct()
         infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
+        infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         schema = reduce(
             _merge_type,
@@ -1051,6 +1052,7 @@ def _inferSchemaFromList(
                     names,
                     infer_dict_as_struct=infer_dict_as_struct,
                     infer_array_from_first_element=infer_array_from_first_element,
+                    infer_map_from_first_pair=infer_map_from_first_pair,
                     prefer_timestamp_ntz=prefer_timestamp_ntz,
                 )
                 for row in data
@@ -1093,6 +1095,7 @@ def _inferSchema(
         infer_dict_as_struct = self._jconf.inferDictAsStruct()
         infer_array_from_first_element = self._jconf.legacyInferArrayTypeFromFirstElement()
+        infer_map_from_first_pair = self._jconf.legacyInferMapStructTypeFromFirstItem()
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         if samplingRatio is None:
             schema = _infer_schema(
@@ -1110,6 +1113,7 @@ def _inferSchema(
                         names=names,
                         infer_dict_as_struct=infer_dict_as_struct,
                         infer_array_from_first_element=infer_array_from_first_element,
+                        infer_map_from_first_pair=infer_map_from_first_pair,
                         prefer_timestamp_ntz=prefer_timestamp_ntz,
                     ),
                 )
@@ -1129,6 +1133,7 @@ def _inferSchema(
                     names,
                     infer_dict_as_struct=infer_dict_as_struct,
                     infer_array_from_first_element=infer_array_from_first_element,
+                    infer_map_from_first_pair=infer_map_from_first_pair,
                     prefer_timestamp_ntz=prefer_timestamp_ntz,
                 )
             ).reduce(_merge_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 55acb4b1a381b..aef42b96a249d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -46,6 +46,14 @@ def test_infer_array_element_type_empty_rdd(self):
     def test_infer_array_merge_element_types_with_rdd(self):
         super().test_infer_array_merge_element_types_with_rdd()
 
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_infer_map_pair_type_empty_rdd(self):
+        super().test_infer_map_pair_type_empty_rdd()
+
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_infer_map_merge_pair_types_with_rdd(self):
+        super().test_infer_map_merge_pair_types_with_rdd()
+
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_infer_binary_type(self):
         super().test_infer_binary_type()
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 84d89b544f152..5942ae2abdb31 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -410,6 +410,57 @@ def test_infer_array_element_type_with_struct(self):
         df = self.spark.createDataFrame(data)
         self.assertEqual(Row(f1=[Row(payment=200.5), Row(payment=None)]), df.first())
 
+    def test_infer_map_merge_pair_types_with_rdd(self):
+        # SPARK-48247: Test inferring map pair types from all pairs in the map
+        MapRow = Row("f1", "f2")
+
+        data = [MapRow({"a": 1, "b": None}, {"a": None, "b": 1})]
+
+        rdd = self.sc.parallelize(data)
+        df = self.spark.createDataFrame(rdd)
+        self.assertEqual(Row(f1={"a": 1, "b": None}, f2={"a": None, "b": 1}), df.first())
+
+    def test_infer_map_pair_type_empty_rdd(self):
+        # SPARK-48247: Test inferring map pair type from all rows
+        MapRow = Row("f1")
+
+        data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]
+
+        rdd = self.sc.parallelize(data)
+        df = self.spark.createDataFrame(rdd)
+        rows = df.collect()
+        self.assertEqual(Row(f1={}), rows[0])
+        self.assertEqual(Row(f1={"a": None}), rows[1])
+        self.assertEqual(Row(f1={"a": 1}), rows[2])
+
+    def test_infer_map_pair_type_empty(self):
+        # SPARK-48247: Test inferring map pair type from all rows
+        MapRow = Row("f1")
+
+        data = [MapRow({}), MapRow({"a": None}), MapRow({"a": 1})]
+
+        df = self.spark.createDataFrame(data)
+        rows = df.collect()
+        self.assertEqual(Row(f1={}), rows[0])
+        self.assertEqual(Row(f1={"a": None}), rows[1])
+        self.assertEqual(Row(f1={"a": 1}), rows[2])
+
+    def test_infer_map_pair_type_with_nested_maps(self):
+        # SPARK-48247: Test inferring nested maps
+        NestedRow = Row("f1", "f2")
+
+        data = [
+            NestedRow({"payment": 200.5, "name": "A"}, {"outer": {"payment": 200.5, "name": "A"}})
+        ]
+        df = self.spark.createDataFrame(data)
+        self.assertEqual(
+            Row(
+                f1={"payment": "200.5", "name": "A"},
+                f2={"outer": {"payment": "200.5", "name": "A"}},
+            ),
+            df.first(),
+        )
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{"a": 1}], ["b"])
         self.assertEqual(df.columns, ["b"])
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fbd4987713e26..ed3535e7d4aaa 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1888,6 +1888,7 @@ def _infer_type(
     obj: Any,
     infer_dict_as_struct: bool = False,
     infer_array_from_first_element: bool = False,
+    infer_map_from_first_pair: bool = False,
     prefer_timestamp_ntz: bool = False,
 ) -> DataType:
     """Infer the DataType from obj"""
@@ -1923,12 +1924,13 @@ def _infer_type(
                             value,
                             infer_dict_as_struct,
                             infer_array_from_first_element,
+                            infer_map_from_first_pair,
                             prefer_timestamp_ntz,
                         ),
                         True,
                     )
             return struct
-        else:
+        elif infer_map_from_first_pair:
             for key, value in obj.items():
                 if key is not None and value is not None:
                     return MapType(
@@ -1936,17 +1938,47 @@ def _infer_type(
                             key,
                             infer_dict_as_struct,
                             infer_array_from_first_element,
+                            infer_map_from_first_pair,
                             prefer_timestamp_ntz,
                         ),
                         _infer_type(
                             value,
                             infer_dict_as_struct,
                             infer_array_from_first_element,
+                            infer_map_from_first_pair,
                             prefer_timestamp_ntz,
                         ),
                         True,
                     )
             return MapType(NullType(), NullType(), True)
+        else:
+            key_type: DataType = NullType()
+            value_type: DataType = NullType()
+            for key, value in obj.items():
+                if key is not None:
+                    key_type = _merge_type(
+                        key_type,
+                        _infer_type(
+                            key,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            infer_map_from_first_pair,
+                            prefer_timestamp_ntz,
+                        ),
+                    )
+                if value is not None:
+                    value_type = _merge_type(
+                        value_type,
+                        _infer_type(
+                            value,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            infer_map_from_first_pair,
+                            prefer_timestamp_ntz,
+                        ),
+                    )
+
+            return MapType(key_type, value_type, True)
     elif isinstance(obj, list):
         if len(obj) > 0:
             if infer_array_from_first_element:
@@ -2003,6 +2035,7 @@ def _infer_schema(
     names: Optional[List[str]] = None,
     infer_dict_as_struct: bool = False,
     infer_array_from_first_element: bool = False,
+    infer_map_from_first_pair: bool = False,
     prefer_timestamp_ntz: bool = False,
 ) -> StructType:
     """Infer the schema from dict/namedtuple/object"""
@@ -2041,6 +2074,7 @@ def _infer_schema(
                     v,
                     infer_dict_as_struct,
                     infer_array_from_first_element,
+                    infer_map_from_first_pair,
                     prefer_timestamp_ntz,
                 ),
                 True,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 1add1e4f82af8..9edef5a1f3ca4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4621,6 +4621,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM =
+    buildConf("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled")
+      .internal()
+      .doc("PySpark's SparkSession.createDataFrame infers the key/value types of a map from all " +
+        "pairs in the map by default. If this config is set to true, it restores the legacy " +
+        "behavior of only inferring the type from the first non-null pair.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_USE_V1_COMMAND = buildConf("spark.sql.legacy.useV1Command")
     .internal()
@@ -5826,6 +5836,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
     SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)
 
+  def legacyInferMapStructTypeFromFirstItem: Boolean = getConf(
+    SQLConf.LEGACY_INFER_MAP_STRUCT_TYPE_FROM_FIRST_ITEM)
+
   def parquetFieldIdReadEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED)
 
   def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)
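For reference, a minimal usage sketch of the escape hatch defined above (it assumes an active classic, non-Connect PySpark session bound to `spark`; the outputs mirror the PR description):

```python
# Restore the pre-4.0 behavior: infer map types from only the first
# non-null pair instead of merging the types of all pairs.
spark.conf.set("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", True)
spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
# [Row(outer={'name': None, 'payment': 200.5})]  -- 'A' does not fit DoubleType

# Return to the new default, which merges across all pairs.
spark.conf.unset("spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled")
spark.createDataFrame([{"outer": {"payment": 200.5, "name": "A"}}]).collect()
# [Row(outer={'name': 'A', 'payment': '200.5'})]
```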