Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions python/pyspark/sql/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class CatalogMetadata(NamedTuple):

class Database(NamedTuple):
    """Metadata for a database, as returned by ``Catalog.listDatabases`` and
    ``Catalog.getDatabase``.
    """

    # Name of the database.
    name: str
    # Catalog the database belongs to (e.g. ``spark_catalog``); may be ``None``
    # when the underlying JVM catalog does not report one.
    catalog: Optional[str]
    # Human-readable description of the database, if any.
    description: Optional[str]
    # URI of the database's default location (e.g. a warehouse path).
    locationUri: str

Expand Down Expand Up @@ -139,11 +140,40 @@ def listDatabases(self) -> List[Database]:
jdb = iter.next()
databases.append(
Database(
name=jdb.name(), description=jdb.description(), locationUri=jdb.locationUri()
name=jdb.name(),
catalog=jdb.catalog(),
description=jdb.description(),
locationUri=jdb.locationUri(),
)
)
return databases

def getDatabase(self, dbName: str) -> Database:
    """Get the database with the specified name.
    This throws an AnalysisException when the database cannot be found.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    dbName : str
        name of the database to get. May be qualified with a catalog name.

    Returns
    -------
    :class:`Database`
        metadata of the matched database.

    Examples
    --------
    >>> spark.catalog.getDatabase("default")
    Database(name='default', catalog=None, description='default database', ...

    >>> spark.catalog.getDatabase("spark_catalog.default")
    Database(name='default', catalog='spark_catalog', description='default database', ...
    """
    # Delegate the lookup to the JVM catalog, then repack the JVM-side
    # metadata object into the Python Database NamedTuple.
    jvm_db = self._jcatalog.getDatabase(dbName)
    return Database(
        name=jvm_db.name(),
        catalog=jvm_db.catalog(),
        description=jvm_db.description(),
        locationUri=jvm_db.locationUri(),
    )

def databaseExists(self, dbName: str) -> bool:
"""Check if the database with the specified name exists.

Expand Down Expand Up @@ -309,14 +339,33 @@ def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Colu

.. versionadded:: 2.0.0

Parameters
----------
tableName : str
name of the table to check existence
dbName : str, optional
name of the database to check table existence in.

.. deprecated:: 3.4.0

.. versionchanged:: 3.4
Allowed ``tableName`` to be qualified with catalog name when ``dbName`` is None.

Notes
-----
the order of arguments here is different from that of its JVM counterpart
because Python does not support method overloading.
"""
if dbName is None:
dbName = self.currentDatabase()
iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
iter = self._jcatalog.listColumns(tableName).toLocalIterator()
else:
warnings.warn(
"`dbName` has been deprecated since Spark 3.4 and might be removed in "
"a future version. Use listColumns(`dbName.tableName`) instead.",
FutureWarning,
)
iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()

columns = []
while iter.hasNext():
jcolumn = iter.next()
Expand Down Expand Up @@ -590,7 +639,11 @@ def clearCache(self) -> None:

@since(2.0)
def refreshTable(self, tableName: str) -> None:
    """Invalidates and refreshes all the cached data and metadata of the given table.

    Parameters
    ----------
    tableName : str
        name of the table to refresh.

    .. versionchanged:: 3.4
        Allowed ``tableName`` to be qualified with catalog name.
    """
    # Pure delegation to the JVM-side catalog; no result is returned.
    self._jcatalog.refreshTable(tableName)

@since("2.1.1")
Expand Down
32 changes: 31 additions & 1 deletion python/pyspark/sql/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def test_database_exists(self):
self.assertTrue(spark.catalog.databaseExists("spark_catalog.some_db"))
self.assertFalse(spark.catalog.databaseExists("spark_catalog.some_db2"))

def test_get_database(self):
    """getDatabase on a catalog-qualified name returns both name and catalog."""
    spark = self.spark
    with self.database("some_db"):
        spark.sql("CREATE DATABASE some_db")
        database = spark.catalog.getDatabase("spark_catalog.some_db")
        self.assertEqual("some_db", database.name)
        self.assertEqual("spark_catalog", database.catalog)

def test_list_tables(self):
from pyspark.sql.catalog import Table

Expand Down Expand Up @@ -245,7 +253,9 @@ def test_list_columns(self):
spark.sql(
"CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet"
)
columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
columns = sorted(
spark.catalog.listColumns("spark_catalog.default.tab1"), key=lambda c: c.name
)
columnsDefault = sorted(
spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name
)
Expand Down Expand Up @@ -352,6 +362,26 @@ def test_get_table(self):
self.assertEqual(spark.catalog.getTable("default.tab1").catalog, "spark_catalog")
self.assertEqual(spark.catalog.getTable("spark_catalog.default.tab1").name, "tab1")

def test_refresh_table(self):
    """refreshTable with a fully qualified name drops stale cached data.

    The table is cached, its underlying files are deleted behind Spark's
    back, and the cached count must stay stale until refreshTable is called.
    """
    import os
    import shutil
    import tempfile

    spark = self.spark
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self.table("my_tab"):
            spark.sql(
                "CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{}'".format(tmp_dir)
            )
            spark.sql("INSERT INTO my_tab SELECT 'abc'")
            spark.catalog.cacheTable("my_tab")
            self.assertEqual(spark.table("my_tab").count(), 1)

            # Remove the data files directly instead of shelling out to
            # `rm -rf` (portable, and errors are not silently swallowed).
            for entry in os.listdir(tmp_dir):
                path = os.path.join(tmp_dir, entry)
                if os.path.isdir(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            # Cache is still warm, so the stale row count is returned.
            self.assertEqual(spark.table("my_tab").count(), 1)

            spark.catalog.refreshTable("spark_catalog.default.my_tab")
            self.assertEqual(spark.table("my_tab").count(), 0)


if __name__ == "__main__":
import unittest
Expand Down