From ebad3e9ea2e6b269659ad4dae7d3a55bcc722fdf Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Sun, 13 Nov 2022 20:20:26 -0800 Subject: [PATCH 1/2] [SPARK-41127][CONNECT][PYTHON] Implement DataFrame.CreateGlobalView in Python client. --- python/pyspark/sql/connect/client.py | 7 +++ python/pyspark/sql/connect/dataframe.py | 42 +++++++++++++++++ python/pyspark/sql/connect/plan.py | 45 +++++++++++++++++++ .../sql/tests/connect/test_connect_basic.py | 17 ++++++- 4 files changed, 109 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index c2d808bb6ee01..629497201344d 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -400,6 +400,13 @@ def schema(self, plan: pb2.Plan) -> StructType: def explain_string(self, plan: pb2.Plan) -> str: return self._analyze(plan).explain_string + def execute_command(self, command: pb2.Command) -> None: + req = pb2.Request() + if self._user_id: + req.user_context.user_id = self._user_id + req.plan.command.CopyFrom(command) + self._execute_and_fetch(req) + def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: req = pb2.Request() if self._user_id: diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 600bff3bce906..9e136129a012d 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -633,6 +633,48 @@ def explain(self) -> str: else: return "" + def createGlobalTempView(self, name: str) -> None: + """Creates a global temporary view with this :class:`DataFrame`. + + The lifetime of this temporary view is tied to this Spark application. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + name : str + Name of the view. 
+ + Returns + ------- + None + """ + command = plan.CreateView( + child=self._plan, name=name, is_global=True, replace=False + ).command(session=self._session) + self._session.execute_command(command) + + def createOrReplaceGlobalTempView(self, name: str) -> None: + """Creates or replaces a global temporary view using the given name. + + The lifetime of this temporary view is tied to this Spark application. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + name : str + Name of the view. + + Returns + ------- + None + """ + command = plan.CreateView( + child=self._plan, name=name, is_global=True, replace=True + ).command(session=self._session) + self._session.execute_command(command) + class DataFrameStatFunctions: """Functionality for statistic functions with :class:`DataFrame`. diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 8d8941e006a88..a6eef9210da9d 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -69,6 +69,9 @@ def to_attr_or_expression( def plan(self, session: "RemoteSparkSession") -> proto.Relation: ... + def command(self, session: "RemoteSparkSession") -> proto.Command: + ... 
+ def _verify(self, session: "RemoteSparkSession") -> bool: """This method is used to verify that the current logical plan can be serialized to Proto and back and afterwards is identical.""" @@ -862,3 +865,45 @@ def _repr_html_(self) -> str: """ + + +class CreateView(LogicalPlan): + def __init__( + self, child: Optional["LogicalPlan"], name: str, is_global: bool, replace: bool + ) -> None: + super().__init__(child) + self._name = name + self._is_gloal = is_global + self._replace = replace + + def command(self, session: "RemoteSparkSession") -> proto.Command: + assert self._child is not None + + plan = proto.Command() + plan.create_dataframe_view.replace = self._replace + plan.create_dataframe_view.is_global = self._is_gloal + plan.create_dataframe_view.name = self._name + plan.create_dataframe_view.input.CopyFrom(self._child.plan(session)) + return plan + + def print(self, indent: int = 0) -> str: + i = " " * indent + return ( + f"{i}" + f"" + ) + + def _repr_html_(self) -> str: + return f""" + + """ diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 2c44e030626f5..ee63acf9e94d5 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -20,8 +20,9 @@ import tempfile import grpc # type: ignore +from grpc._channel import _MultiThreadedRendezvous # type: ignore -from pyspark.testing.sqlutils import have_pandas +from pyspark.testing.sqlutils import have_pandas, SQLTestUtils if have_pandas: import pandas @@ -39,7 +40,7 @@ @unittest.skipIf(not should_test_connect, connect_requirement_message) -class SparkConnectSQLTestCase(ReusedPySparkTestCase): +class SparkConnectSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): """Parent test fixture class for all Spark Connect related test cases.""" @@ -207,6 +208,18 @@ def test_range(self): .equals(self.spark.range(start=0, end=10, step=3, numPartitions=2).toPandas()) ) + def 
test_create_global_temp_view(self): + # SPARK-41127: test global temp view creation. + with self.tempView("view_1"): + self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") + self.connect.sql("SELECT 2 AS X LIMIT 1").createOrReplaceGlobalTempView("view_1") + self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) + + # Test that creating a view which already exists, without replace, raises an error. + self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) + with self.assertRaises(_MultiThreadedRendezvous): + self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") + def test_empty_dataset(self): # SPARK-41005: Test arrow based collection with empty dataset. self.assertTrue( From 665a35daf8677804329848cf32f6e74bb15e27e7 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 14 Nov 2022 20:36:44 -0800 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Hyukjin Kwon --- python/pyspark/sql/connect/dataframe.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 9e136129a012d..2685e95554aa2 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -644,10 +644,6 @@ def createGlobalTempView(self, name: str) -> None: ---------- name : str Name of the view. - - Returns - ------- - None """ command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=False @@ -665,10 +661,6 @@ def createOrReplaceGlobalTempView(self, name: str) -> None: ---------- name : str Name of the view. - - Returns - ------- - None """ command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=True