[SPARK-41721][CONNECT][TESTS] Enable doctests in pyspark.sql.connect.catalog #39224

Closed
wants to merge 3 commits
dev/sparktestsupport/modules.py (1 addition, 1 deletion)
@@ -504,7 +504,7 @@ def __hash__(self):
source_file_regexes=["python/pyspark/sql/connect"],
python_test_goals=[
# doctests
# No doctests yet.
"pyspark.sql.connect.catalog",
# unittests
"pyspark.sql.tests.connect.test_connect_column_expressions",
"pyspark.sql.tests.connect.test_connect_plan_only",
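For context: modules listed under python_test_goals are what Spark's Python test runner picks up, so once pyspark.sql.connect.catalog is registered above, the new doctests should be runnable locally along these lines (a sketch, assuming the standard python/run-tests script and its --testnames flag):

    # Run the catalog doctests for the Spark Connect client (assumed invocation).
    python/run-tests --testnames pyspark.sql.connect.catalog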
python/pyspark/sql/catalog.py (55 additions, 50 deletions)
@@ -246,6 +246,8 @@ def getDatabase(self, dbName: str) -> Database:
locationUri=jdb.locationUri(),
)

# TODO(SPARK-41725): we don't have to `collect` for every `sql`, but
# Spark Connect requires it. We should remove these `collect` calls later.
def databaseExists(self, dbName: str) -> bool:
"""Check if the database with the specified name exists.

@@ -273,15 +275,15 @@ def databaseExists(self, dbName: str) -> bool:

>>> spark.catalog.databaseExists("test_new_database")
False
>>> _ = spark.sql("CREATE DATABASE test_new_database")
>>> _ = spark.sql("CREATE DATABASE test_new_database").collect()
Member Author commented:
Hmmm ... this is an orthogonal issue, but just to make sure we don't forget: I think we should make the `sql` method in Spark Connect analyze and execute such commands eagerly on every call. Otherwise, we can't make it compatible with the existing PySpark usage.

I changed the existing doctests here for now to make the tests pass, but ideally this should be changed back.
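A minimal sketch of the behavior difference described above (an editorial illustration, not part of this comment or the diff; assumes a SparkSession named spark on each API):

    # Classic PySpark: command statements such as CREATE/DROP DATABASE execute
    # eagerly inside spark.sql(), so a doctest can simply discard the result.
    _ = spark.sql("CREATE DATABASE test_new_database")

    # Spark Connect: spark.sql() only builds a plan on the client, and the server
    # runs nothing until an action is triggered, hence the explicit collect()
    # added throughout these doctests.
    _ = spark.sql("CREATE DATABASE test_new_database").collect()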

Contributor commented:

What about creating a JIRA ticket for this ("should change this back ideally")?

Member Author commented:

👌

>>> spark.catalog.databaseExists("test_new_database")
True

Using the fully qualified name with the catalog name.

>>> spark.catalog.databaseExists("spark_catalog.test_new_database")
True
>>> _ = spark.sql("DROP DATABASE test_new_database")
>>> _ = spark.sql("DROP DATABASE test_new_database").collect()
"""
return self._jcatalog.databaseExists(dbName)

@@ -370,8 +372,8 @@ def getTable(self, tableName: str) -> Table:

Examples
--------
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.getTable("tbl1")
Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...

@@ -381,14 +383,14 @@
Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
>>> spark.catalog.getTable("spark_catalog.default.tbl1")
Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()

Throw an analysis exception when the table does not exist.

>>> spark.catalog.getTable("tbl1")
>>> spark.catalog.getTable("tbl1") # doctest: +SKIP
Traceback (most recent call last):
...
pyspark.sql.utils.AnalysisException: ...
AnalysisException: ...
"""
jtable = self._jcatalog.getTable(tableName)
jnamespace = jtable.namespace()
@@ -532,7 +534,8 @@ def getFunction(self, functionName: str) -> Function:

Examples
--------
>>> func = spark.sql("CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'")
>>> _ = spark.sql(
... "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'").collect()
>>> spark.catalog.getFunction("my_func1")
Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ...

@@ -545,10 +548,10 @@

Throw an analysis exception when the function does not exist.

>>> spark.catalog.getFunction("my_func2")
>>> spark.catalog.getFunction("my_func2") # doctest: +SKIP
Traceback (most recent call last):
...
pyspark.sql.utils.AnalysisException: ...
AnalysisException: ...
"""
jfunction = self._jcatalog.getFunction(functionName)
jnamespace = jfunction.namespace()
@@ -599,11 +602,11 @@ def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Colu

Examples
--------
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.listColumns("tblA")
[Column(name='name', description=None, dataType='string', nullable=True, ...
>>> _ = spark.sql("DROP TABLE tblA")
>>> _ = spark.sql("DROP TABLE tblA").collect()
"""
if dbName is None:
iter = self._jcatalog.listColumns(tableName).toLocalIterator()
@@ -664,8 +667,8 @@ def tableExists(self, tableName: str, dbName: Optional[str] = None) -> bool:

>>> spark.catalog.tableExists("unexisting_table")
False
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.tableExists("tbl1")
True

@@ -677,13 +680,13 @@
True
>>> spark.catalog.tableExists("tbl1", "default")
True
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()

Check if views exist:

>>> spark.catalog.tableExists("view1")
False
>>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1")
>>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1").collect()
>>> spark.catalog.tableExists("view1")
True

@@ -695,14 +698,14 @@
True
>>> spark.catalog.tableExists("view1", "default")
True
>>> _ = spark.sql("DROP VIEW view1")
>>> _ = spark.sql("DROP VIEW view1").collect()

Check if temporary views exist:

>>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1")
>>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1").collect()
>>> spark.catalog.tableExists("view1")
True
>>> df = spark.sql("DROP VIEW view1")
>>> df = spark.sql("DROP VIEW view1").collect()
>>> spark.catalog.tableExists("view1")
False
"""
@@ -803,15 +806,15 @@ def createTable(
Creating a managed table.

>>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet')
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()

Creating an external table.

>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
... _ = spark.catalog.createTable(
... "tbl2", schema=spark.range(1).schema, path=d, source='parquet')
>>> _ = spark.sql("DROP TABLE tbl2")
>>> _ = spark.sql("DROP TABLE tbl2").collect()
"""
if path is not None:
options["path"] = path
@@ -864,7 +867,7 @@ def dropTempView(self, viewName: str) -> bool:

Throw an exception if the temporary view does not exist.

>>> spark.table("my_table")
>>> spark.table("my_table") # doctest: +SKIP
Traceback (most recent call last):
...
AnalysisException: ...
@@ -904,7 +907,7 @@ def dropGlobalTempView(self, viewName: str) -> bool:

Throw an exception if the global view does not exist.

>>> spark.table("global_temp.my_table")
>>> spark.table("global_temp.my_table") # doctest: +SKIP
Traceback (most recent call last):
...
AnalysisException: ...
@@ -945,8 +948,8 @@ def isCached(self, tableName: str) -> bool:

Examples
--------
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.cacheTable("tbl1")
>>> spark.catalog.isCached("tbl1")
True
@@ -956,14 +959,14 @@ def isCached(self, tableName: str) -> bool:
>>> spark.catalog.isCached("not_existing_table")
Traceback (most recent call last):
...
pyspark.sql.utils.AnalysisException: ...
AnalysisException: ...

Using the fully qualified name for the table.

>>> spark.catalog.isCached("spark_catalog.default.tbl1")
True
>>> spark.catalog.uncacheTable("tbl1")
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
return self._jcatalog.isCached(tableName)

@@ -982,22 +985,22 @@ def cacheTable(self, tableName: str) -> None:

Examples
--------
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.cacheTable("tbl1")

Throw an analysis exception when the table does not exist.

>>> spark.catalog.cacheTable("not_existing_table")
Traceback (most recent call last):
...
pyspark.sql.utils.AnalysisException: ...
AnalysisException: ...

Using the fully qualified name for the table.

>>> spark.catalog.cacheTable("spark_catalog.default.tbl1")
>>> spark.catalog.uncacheTable("tbl1")
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
self._jcatalog.cacheTable(tableName)

@@ -1016,8 +1019,8 @@ def uncacheTable(self, tableName: str) -> None:

Examples
--------
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.cacheTable("tbl1")
>>> spark.catalog.uncacheTable("tbl1")
>>> spark.catalog.isCached("tbl1")
@@ -1028,14 +1031,14 @@ def uncacheTable(self, tableName: str) -> None:
>>> spark.catalog.uncacheTable("not_existing_table") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
pyspark.sql.utils.AnalysisException: ...
AnalysisException: ...

Using the fully qualified name for the table.

>>> spark.catalog.uncacheTable("spark_catalog.default.tbl1")
>>> spark.catalog.isCached("tbl1")
False
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
self._jcatalog.uncacheTable(tableName)

@@ -1049,12 +1052,12 @@ def clearCache(self) -> None:

Examples
--------
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
>>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
>>> spark.catalog.clearCache()
>>> spark.catalog.isCached("tbl1")
False
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
self._jcatalog.clearCache()

@@ -1080,9 +1083,10 @@ def refreshTable(self, tableName: str) -> None:

>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
... _ = spark.sql("DROP TABLE IF EXISTS tbl1")
... _ = spark.sql("CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
... _ = spark.sql(
... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
... spark.catalog.cacheTable("tbl1")
... spark.table("tbl1").show()
+---+
@@ -1105,7 +1109,7 @@ def refreshTable(self, tableName: str) -> None:
Using the fully qualified name for the table.

>>> spark.catalog.refreshTable("spark_catalog.default.tbl1")
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
self._jcatalog.refreshTable(tableName)

@@ -1133,12 +1137,12 @@ def recoverPartitions(self, tableName: str) -> None:

>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
... _ = spark.sql("DROP TABLE IF EXISTS tbl1")
... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
... spark.range(1).selectExpr(
... "id as key", "id as value").write.partitionBy("key").mode("overwrite").save(d)
... _ = spark.sql(
... "CREATE TABLE tbl1 (key LONG, value LONG)"
... "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d))
... "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d)).collect()
... spark.table("tbl1").show()
... spark.catalog.recoverPartitions("tbl1")
... spark.table("tbl1").show()
@@ -1151,7 +1155,7 @@ def recoverPartitions(self, tableName: str) -> None:
+-----+---+
| 0| 0|
+-----+---+
>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
self._jcatalog.recoverPartitions(tableName)

Expand All @@ -1175,9 +1179,10 @@ def refreshByPath(self, path: str) -> None:

>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
... _ = spark.sql("DROP TABLE IF EXISTS tbl1")
... _ = spark.sql("CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
... _ = spark.sql(
... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
... spark.catalog.cacheTable("tbl1")
... spark.table("tbl1").show()
+---+
Expand All @@ -1197,7 +1202,7 @@ def refreshByPath(self, path: str) -> None:
>>> spark.table("tbl1").count()
0

>>> _ = spark.sql("DROP TABLE tbl1")
>>> _ = spark.sql("DROP TABLE tbl1").collect()
"""
self._jcatalog.refreshByPath(path)
