[SPARK-48336][PS][CONNECT] Implement ps.sql in Spark Connect
### What changes were proposed in this pull request?
Implement `ps.sql` in Spark Connect. In Connect mode, instead of creating client-side temp views for DataFrame arguments, the formatted query keeps its `{name}` placeholders and the corresponding Spark DataFrames are passed to `SparkSession.sql` as keyword arguments, delegating view handling to the server.

### Why are the changes needed?
Feature parity in Spark Connect.

### Does this PR introduce _any_ user-facing change?
Yes, `ps.sql` now also works against a Spark Connect session:

```
In [4]: spark
Out[4]: <pyspark.sql.connect.session.SparkSession at 0x105136390>

In [5]: ps.sql('''
   ...:   SELECT m1.a, m2.b
   ...:   FROM {table1} m1 INNER JOIN {table2} m2
   ...:   ON m1.key = m2.key
   ...:   ORDER BY m1.a, m2.b''',
   ...:   table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}),
   ...:   table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]}))
/Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:1018: PandasAPIOnSparkAdviceWarning: The config 'spark.sql.ansi.enabled' is set to True. This can cause unexpected behavior from pandas API on Spark since pandas API on Spark follows the behavior of pandas, not SQL.
  warnings.warn(message, PandasAPIOnSparkAdviceWarning)
/Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:1018: PandasAPIOnSparkAdviceWarning: The config 'spark.sql.ansi.enabled' is set to True. This can cause unexpected behavior from pandas API on Spark since pandas API on Spark follows the behavior of pandas, not SQL.
  warnings.warn(message, PandasAPIOnSparkAdviceWarning)

   a  b
0  1  3
1  2  4
2  2  5
```

### How was this patch tested?

1. Enabled the existing unit tests for Spark Connect.
2. Manually tested all the examples.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46658 from zhengruifeng/ps_sql.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed May 21, 2024
1 parent 6e6e7a0 commit df50d4b
Showing 2 changed files with 39 additions and 32 deletions.
63 changes: 38 additions & 25 deletions python/pyspark/pandas/sql_formatter.py
@@ -27,10 +27,10 @@
from pyspark.pandas.namespace import _get_index_map
from pyspark import pandas as ps
from pyspark.sql import SparkSession
from pyspark.sql.utils import get_lit_sql_str
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series
from pyspark.errors import PySparkTypeError
from pyspark.sql.utils import is_remote


@@ -203,15 +203,16 @@ def sql(
session = default_session()
formatter = PandasSQLStringFormatter(session)
try:
# ps.DataFrame are not supported for Spark Connect currently.
if is_remote():
for obj in kwargs.values():
if isinstance(obj, ps.DataFrame):
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE",
message_parameters={"data_type": type(obj).__name__},
)
sdf = session.sql(formatter.format(query, **kwargs), args)
if not is_remote():
sdf = session.sql(formatter.format(query, **kwargs), args)
else:
ps_query = formatter.format(query, **kwargs)
# here the new_kwargs stores the views
new_kwargs = {}
for psdf, name in formatter._temp_views:
new_kwargs[name] = psdf._to_spark()
# delegate views to spark.sql
sdf = session.sql(ps_query, args, **new_kwargs)
finally:
formatter.clear()
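
For context on the Connect branch above, here is a minimal, self-contained sketch (editor's illustration, not part of this diff) of the `SparkSession.sql` call it delegates to: DataFrames passed as keyword arguments are bound to the `{name}` placeholders by Spark itself, so no client-side temp view is needed. The session setup and data below are made up for the example.

```
from pyspark.sql import SparkSession

# A regular session works the same way; with Spark Connect this would be
# SparkSession.builder.remote("sc://...").getOrCreate(), and the placeholder
# resolution then happens on the Connect server.
spark = SparkSession.builder.getOrCreate()

# A Spark DataFrame that will back the "{tbl}" placeholder in the query.
sdf = spark.createDataFrame([(1, "a"), (2, "b")], ["a", "key"])

# spark.sql accepts DataFrames as keyword arguments, mirroring
# `session.sql(ps_query, args, **new_kwargs)` in the change above.
spark.sql("SELECT a FROM {tbl} WHERE key = 'a'", tbl=sdf).show()
```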

@@ -264,30 +265,42 @@ def _convert_value(self, val: Any, name: str) -> Optional[str]:
elif isinstance(val, (DataFrame, pd.DataFrame)):
df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "")

if isinstance(val, pd.DataFrame):
# Don't store temp view for plain pandas instances
# because it is unable to know which pandas DataFrame
# holds which Series.
val = ps.from_pandas(val)
if not is_remote():
if isinstance(val, pd.DataFrame):
# Don't store temp view for plain pandas instances
# because it is unable to know which pandas DataFrame
# holds which Series.
val = ps.from_pandas(val)
else:
for df, n in self._temp_views:
if df is val:
return n
self._temp_views.append((val, df_name))
val._to_spark().createOrReplaceTempView(df_name)
return df_name
else:
if isinstance(val, pd.DataFrame):
# Always convert pd.DataFrame to ps.DataFrame, and record it in _temp_views.
val = ps.from_pandas(val)

for df, n in self._temp_views:
if df is val:
return n
self._temp_views.append((val, df_name))

val._to_spark().createOrReplaceTempView(df_name)
return df_name
self._temp_views.append((val, name))
# In Spark Connect, keep the original view name here (not the UUID one),
# the reformatted query is like: 'select * from {tbl} where A > 1'
# and then delegate the view operations to spark.sql.
return "{" + name + "}"
elif isinstance(val, str):
# This is matched to behavior from JVM implementation.
# See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
# sql/catalyst/expressions/literals.scala`
return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
return get_lit_sql_str(val)
else:
return val

def clear(self) -> None:
for _, n in self._temp_views:
self._session.catalog.dropTempView(n)
# In Spark Connect, views are created and dropped in Connect Server
if not is_remote():
for _, n in self._temp_views:
self._session.catalog.dropTempView(n)
self._temp_views = []
self._ref_sers = []

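A side note on the import change above: `get_lit_sql_str` replaces the inline quoting that `_convert_value` previously applied to `str` arguments. A rough sketch of that behavior in plain Python (the helper name `_quote_sql_string` is made up here; the escaping mirrors the removed inline code, assuming `get_lit_sql_str` behaves the same way):

```
def _quote_sql_string(val: str) -> str:
    # Escape backslashes and single quotes, then wrap the value in single quotes,
    # as the removed inline code in _convert_value did.
    return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"

print(_quote_sql_string("it's"))   # 'it\'s'
print(_quote_sql_string("a\\b"))   # 'a\\b'
```
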
8 changes: 1 addition & 7 deletions python/pyspark/pandas/tests/connect/test_parity_sql.py
@@ -22,13 +22,7 @@


class SQLParityTests(SQLTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
@unittest.skip("Test depends on temp view issue on JVM side.")
def test_sql_with_index_col(self):
super().test_sql_with_index_col()

@unittest.skip("Test depends on temp view issue on JVM side.")
def test_sql_with_pandas_on_spark_objects(self):
super().test_sql_with_pandas_on_spark_objects()
pass


if __name__ == "__main__":
