[SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables #41461

Closed · wants to merge 2 commits
@@ -76,6 +76,15 @@ abstract class Catalog {
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String): Dataset[Table]

/**
* Returns a list of tables/views in the specified database (namespace) whose name matches the
* specified pattern (the name can be qualified with catalog). This includes all temporary views.
*
* @since 3.5.0
*/
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String, pattern: String): Dataset[Table]

/**
* Returns a list of functions registered in the current database (namespace). This includes all
* temporary functions.
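A minimal usage sketch for the new overload (not part of the diff; it assumes a running `SparkSession` named `spark` and hypothetical table names). The pattern follows the same semantics as `SHOW TABLES LIKE` elsewhere in the catalog API: `*` matches any sequence of characters and `|` separates alternatives.

```scala
// List tables/views in "default" whose names start with "sales_" or "dim_".
val matched = spark.catalog.listTables("default", "sales_*|dim_*")

// Dataset[Table] exposes the table metadata as columns.
matched.select("name", "tableType", "isTemporary").show()
```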
@@ -101,6 +101,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
}
}

/**
* Returns a list of tables/views in the specified database (namespace) whose name matches the
* specified pattern (the name can be qualified with catalog). This includes all temporary views.
*
* @since 3.5.0
*/
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String, pattern: String): Dataset[Table] = {
sparkSession.newDataset(CatalogImpl.tableEncoder) { builder =>
builder.getCatalogBuilder.getListTablesBuilder.setDbName(dbName).setPattern(pattern)
}
}

/**
* Returns a list of functions registered in the current database (namespace). This includes all
* temporary functions.
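For reference, a sketch of the relation this client method sends over the wire, built manually with the generated proto classes (hypothetical field values; assumes `org.apache.spark.connect.proto` is the generated package, as used elsewhere in the connector):

```scala
import org.apache.spark.connect.proto

// Equivalent of the builder chain above for listTables("default", "par*"):
// a Catalog relation whose list_tables message carries both optional fields.
val relation = proto.Relation
  .newBuilder()
  .setCatalog(
    proto.Catalog
      .newBuilder()
      .setListTables(
        proto.ListTables
          .newBuilder()
          .setDbName("default")
          .setPattern("par*")))
  .build()
```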
@@ -126,6 +126,14 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
parquetTableName,
orcTableName,
jsonTableName))
assert(
spark.catalog
.listTables(spark.catalog.currentDatabase, "par*")
.collect()
.map(_.name)
.toSet == Set(parquetTableName))
assert(
spark.catalog.listTables(spark.catalog.currentDatabase, "txt*").collect().isEmpty)
}
assert(spark.catalog.tableExists(parquetTableName))
assert(!spark.catalog.tableExists(orcTableName))
@@ -77,6 +77,8 @@ message ListDatabases {
message ListTables {
// (Optional)
optional string db_name = 1;
// (Optional) The pattern that the table name needs to match.
optional string pattern = 2;
Contributor:
Do we need a test case for Spark Connect?

Contributor (Author):
I added test cases in CatalogSuite.

}

// See `spark.catalog.listFunctions`
@@ -2706,7 +2706,14 @@ class SparkConnectPlanner(val session: SparkSession) {

private def transformListTables(getListTables: proto.ListTables): LogicalPlan = {
if (getListTables.hasDbName) {
session.catalog.listTables(getListTables.getDbName).logicalPlan
if (getListTables.hasPattern) {
session.catalog.listTables(getListTables.getDbName, getListTables.getPattern).logicalPlan
} else {
session.catalog.listTables(getListTables.getDbName).logicalPlan
}
} else if (getListTables.hasPattern) {
val currentDatabase = session.catalog.currentDatabase
session.catalog.listTables(currentDatabase, getListTables.getPattern).logicalPlan
} else {
session.catalog.listTables().logicalPlan
}
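The planner's four-way dispatch on the two optional fields can be summarized with a standalone sketch (a hypothetical helper, not in the PR, using `Option` in place of the proto's `hasX`/`getX` accessors):

```scala
// Mirrors transformListTables: a missing db_name falls back to the current
// database only when a pattern is present; otherwise the no-arg overload runs.
def chooseOverload(dbName: Option[String], pattern: Option[String]): String =
  (dbName, pattern) match {
    case (Some(db), Some(p)) => s"catalog.listTables($db, $p)"
    case (Some(db), None)    => s"catalog.listTables($db)"
    case (None,     Some(p)) => s"catalog.listTables(currentDatabase, $p)"
    case (None,     None)    => "catalog.listTables()"
  }
```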
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -46,7 +46,9 @@ object MimaExcludes {
// [SPARK-43792][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listCatalogs
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listCatalogs"),
// [SPARK-43881][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listDatabases
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases"),
// [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listTables")
)

// Default exclude rules
24 changes: 21 additions & 3 deletions python/pyspark/sql/catalog.py
@@ -202,7 +202,7 @@ def listDatabases(self, pattern: Optional[str] = None) -> List[Database]:
The pattern that the database name needs to match.

.. versionchanged:: 3.5.0
Added ``pattern`` argument.
Adds ``pattern`` argument.

Returns
-------
@@ -307,7 +307,9 @@ def databaseExists(self, dbName: str) -> bool:
"""
return self._jcatalog.databaseExists(dbName)

def listTables(self, dbName: Optional[str] = None) -> List[Table]:
def listTables(
self, dbName: Optional[str] = None, pattern: Optional[str] = None
) -> List[Table]:
"""Returns a list of tables/views in the specified database.

.. versionadded:: 2.0.0
@@ -320,6 +322,12 @@ def listTables(self, dbName: Optional[str] = None) -> List[Table]:
.. versionchanged:: 3.4.0
Allow ``dbName`` to be qualified with catalog name.

pattern : str
The pattern that the table name needs to match.

.. versionchanged:: 3.5.0
Adds ``pattern`` argument.

Returns
-------
list
@@ -336,13 +344,23 @@ def listTables(self, dbName: Optional[str] = None) -> List[Table]:
>>> spark.catalog.listTables()
[Table(name='test_view', catalog=None, namespace=[], description=None, ...

>>> spark.catalog.listTables(pattern="test*")
[Table(name='test_view', catalog=None, namespace=[], description=None, ...

>>> spark.catalog.listTables(pattern="table*")
[]

>>> _ = spark.catalog.dropTempView("test_view")
>>> spark.catalog.listTables()
[]
"""
if dbName is None:
dbName = self.currentDatabase()
iter = self._jcatalog.listTables(dbName).toLocalIterator()

if pattern is None:
iter = self._jcatalog.listTables(dbName).toLocalIterator()
else:
iter = self._jcatalog.listTables(dbName, pattern).toLocalIterator()
tables = []
while iter.hasNext():
jtable = iter.next()
6 changes: 4 additions & 2 deletions python/pyspark/sql/connect/catalog.py
@@ -116,8 +116,10 @@ def databaseExists(self, dbName: str) -> bool:

databaseExists.__doc__ = PySparkCatalog.databaseExists.__doc__

def listTables(self, dbName: Optional[str] = None) -> List[Table]:
pdf = self._execute_and_fetch(plan.ListTables(db_name=dbName))
def listTables(
self, dbName: Optional[str] = None, pattern: Optional[str] = None
) -> List[Table]:
pdf = self._execute_and_fetch(plan.ListTables(db_name=dbName, pattern=pattern))
return [
Table(
name=row.iloc[0],
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -1648,14 +1648,17 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:


class ListTables(LogicalPlan):
def __init__(self, db_name: Optional[str] = None) -> None:
def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._db_name = db_name
self._pattern = pattern

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_tables=proto.ListTables()))
if self._db_name is not None:
plan.catalog.list_tables.db_name = self._db_name
if self._pattern is not None:
plan.catalog.list_tables.pattern = self._pattern
return plan

