From df50d4b309b55ca97a2409dcdda30011e3b43f87 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 21 May 2024 19:38:53 +0800
Subject: [PATCH] [SPARK-48336][PS][CONNECT] Implement `ps.sql` in Spark Connect

### What changes were proposed in this pull request?
Implement `ps.sql` in Spark Connect.

### Why are the changes needed?
Feature parity in Spark Connect.

### Does this PR introduce _any_ user-facing change?
Yes:
```
In [4]: spark
Out[4]:

In [5]: >>> ps.sql('''
   ...: ... SELECT m1.a, m2.b
   ...: ... FROM {table1} m1 INNER JOIN {table2} m2
   ...: ... ON m1.key = m2.key
   ...: ... ORDER BY m1.a, m2.b''',
   ...: ... table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}),
   ...: ... table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]}))
/Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:1018: PandasAPIOnSparkAdviceWarning: The config 'spark.sql.ansi.enabled' is set to True. This can cause unexpected behavior from pandas API on Spark since pandas API on Spark follows the behavior of pandas, not SQL.
  warnings.warn(message, PandasAPIOnSparkAdviceWarning)
/Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:1018: PandasAPIOnSparkAdviceWarning: The config 'spark.sql.ansi.enabled' is set to True. This can cause unexpected behavior from pandas API on Spark since pandas API on Spark follows the behavior of pandas, not SQL.
  warnings.warn(message, PandasAPIOnSparkAdviceWarning)
   a  b
0  1  3
1  2  4
2  2  5
```

### How was this patch tested?
1. Enabled UTs.
2. Also manually tested all the examples.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46658 from zhengruifeng/ps_sql.

Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/pandas/sql_formatter.py        | 63 +++++++++++--------
 .../pandas/tests/connect/test_parity_sql.py   |  8 +--
 2 files changed, 39 insertions(+), 32 deletions(-)

diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index 7e8263f552f0c..b6d48077675bd 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -27,10 +27,10 @@
 from pyspark.pandas.namespace import _get_index_map
 from pyspark import pandas as ps
 from pyspark.sql import SparkSession
+from pyspark.sql.utils import get_lit_sql_str
 from pyspark.pandas.utils import default_session
 from pyspark.pandas.frame import DataFrame
 from pyspark.pandas.series import Series
-from pyspark.errors import PySparkTypeError
 from pyspark.sql.utils import is_remote
 
 
@@ -203,15 +203,16 @@ def sql(
         session = default_session()
     formatter = PandasSQLStringFormatter(session)
     try:
-        # ps.DataFrame are not supported for Spark Connect currently.
- if is_remote(): - for obj in kwargs.values(): - if isinstance(obj, ps.DataFrame): - raise PySparkTypeError( - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(obj).__name__}, - ) - sdf = session.sql(formatter.format(query, **kwargs), args) + if not is_remote(): + sdf = session.sql(formatter.format(query, **kwargs), args) + else: + ps_query = formatter.format(query, **kwargs) + # here the new_kwargs stores the views + new_kwargs = {} + for psdf, name in formatter._temp_views: + new_kwargs[name] = psdf._to_spark() + # delegate views to spark.sql + sdf = session.sql(ps_query, args, **new_kwargs) finally: formatter.clear() @@ -264,30 +265,42 @@ def _convert_value(self, val: Any, name: str) -> Optional[str]: elif isinstance(val, (DataFrame, pd.DataFrame)): df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "") - if isinstance(val, pd.DataFrame): - # Don't store temp view for plain pandas instances - # because it is unable to know which pandas DataFrame - # holds which Series. - val = ps.from_pandas(val) + if not is_remote(): + if isinstance(val, pd.DataFrame): + # Don't store temp view for plain pandas instances + # because it is unable to know which pandas DataFrame + # holds which Series. + val = ps.from_pandas(val) + else: + for df, n in self._temp_views: + if df is val: + return n + self._temp_views.append((val, df_name)) + val._to_spark().createOrReplaceTempView(df_name) + return df_name else: + if isinstance(val, pd.DataFrame): + # Always convert pd.DataFrame to ps.DataFrame, and record it in _temp_views. + val = ps.from_pandas(val) + for df, n in self._temp_views: if df is val: return n - self._temp_views.append((val, df_name)) - - val._to_spark().createOrReplaceTempView(df_name) - return df_name + self._temp_views.append((val, name)) + # In Spark Connect, keep the original view name here (not the UUID one), + # the reformatted query is like: 'select * from {tbl} where A > 1' + # and then delegate the view operations to spark.sql. + return "{" + name + "}" elif isinstance(val, str): - # This is matched to behavior from JVM implementation. - # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/ - # sql/catalyst/expressions/literals.scala` - return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'" + return get_lit_sql_str(val) else: return val def clear(self) -> None: - for _, n in self._temp_views: - self._session.catalog.dropTempView(n) + # In Spark Connect, views are created and dropped in Connect Server + if not is_remote(): + for _, n in self._temp_views: + self._session.catalog.dropTempView(n) self._temp_views = [] self._ref_sers = [] diff --git a/python/pyspark/pandas/tests/connect/test_parity_sql.py b/python/pyspark/pandas/tests/connect/test_parity_sql.py index 2e503cac07a8a..29abbda8c0ebb 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_sql.py +++ b/python/pyspark/pandas/tests/connect/test_parity_sql.py @@ -22,13 +22,7 @@ class SQLParityTests(SQLTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @unittest.skip("Test depends on temp view issue on JVM side.") - def test_sql_with_index_col(self): - super().test_sql_with_index_col() - - @unittest.skip("Test depends on temp view issue on JVM side.") - def test_sql_with_pandas_on_spark_objects(self): - super().test_sql_with_pandas_on_spark_objects() + pass if __name__ == "__main__":
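
Note (editor's sketch, not part of the patch): the new Connect branch works because `SparkSession.sql` accepts DataFrames as keyword arguments and substitutes them for `{name}` placeholders, so `ps.sql` only needs to keep the placeholders in the formatted query and forward the converted frames; the temp views are then created and dropped on the Connect server rather than by the client. A minimal illustration of that delegation, where the data, variable names, and connection URL are assumptions for the example:
```
# Sketch only (not part of the patch): the delegation pattern the new Connect
# branch relies on. Data, names, and the connection URL are illustrative.
import pandas as pd
import pyspark.pandas as ps
from pyspark.sql import SparkSession

# Assumes a Spark Connect server; adjust the URL for your environment.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

psdf = ps.DataFrame({"a": [1, 2], "key": ["a", "b"]})
pdf = pd.DataFrame({"b": [3, 4, 5], "key": ["a", "b", "b"]})

query = (
    "SELECT m1.a, m2.b "
    "FROM {table1} m1 INNER JOIN {table2} m2 "
    "ON m1.key = m2.key ORDER BY m1.a, m2.b"
)

# Roughly what the new `else` branch in sql() does: keep the {name}
# placeholders in the formatted query, turn each referenced frame into a
# Spark DataFrame (the patch uses the internal _to_spark()), and let
# spark.sql() create and drop the temp views on the server side.
sdf = spark.sql(
    query,
    table1=psdf.to_spark(),
    table2=ps.from_pandas(pdf).to_spark(),
)
sdf.show()

# The user-facing call, which now behaves the same with or without Connect:
print(ps.sql(query, table1=psdf, table2=pdf))
```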