[SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables

### What changes were proposed in this pull request?
Currently, the `SHOW TABLES LIKE pattern` syntax supports an optional pattern that filters the returned tables.
However, `Catalog.listTables` is missing this capability in both the Catalog API and the Connect Catalog API.

In fact, the optional pattern is very useful.

This PR also extracts a common `wrapNamespace` helper to clean up duplicated code. A usage sketch of the new overload is shown below.
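
A minimal sketch of the intended usage, assuming a `spark-shell` session named `spark` and the `default` database; the table names are illustrative only, and only the `listTables(dbName, pattern)` overload itself comes from this PR:

```scala
// Illustrative tables; names and schemas here are hypothetical.
spark.sql("CREATE TABLE IF NOT EXISTS parquet_tab (id INT) USING parquet")
spark.sql("CREATE TABLE IF NOT EXISTS orc_tab (id INT) USING orc")

// Existing behavior: list every table/view in the database.
spark.catalog.listTables("default").show()

// New overload: only names matching the pattern, similar to
// SHOW TABLES IN default LIKE 'par*'.
spark.catalog.listTables("default", "par*").show()
```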

### Why are the changes needed?
This PR adds the optional pattern for `Catalog.listTables`.

### Does this PR introduce _any_ user-facing change?
'No'.
New feature.

### How was this patch tested?
New test cases.

Closes #41461 from beliefer/SPARK-43961.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and cloud-fan committed Jun 7, 2023
1 parent 41fd030 commit 64855fa
Showing 15 changed files with 259 additions and 90 deletions.
@@ -76,6 +76,15 @@ abstract class Catalog {
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String): Dataset[Table]

/**
* Returns a list of tables/views in the specified database (namespace) whose name matches the
* specified pattern (the name can be qualified with catalog). This includes all temporary views.
*
* @since 3.5.0
*/
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String, pattern: String): Dataset[Table]

/**
* Returns a list of functions registered in the current database (namespace). This includes all
* temporary functions.
@@ -101,6 +101,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
}
}

/**
* Returns a list of tables/views in the specified database (namespace) whose name matches the
* specified pattern (the name can be qualified with catalog). This includes all temporary views.
*
* @since 3.5.0
*/
@throws[AnalysisException]("database does not exist")
def listTables(dbName: String, pattern: String): Dataset[Table] = {
sparkSession.newDataset(CatalogImpl.tableEncoder) { builder =>
builder.getCatalogBuilder.getListTablesBuilder.setDbName(dbName).setPattern(pattern)
}
}

/**
* Returns a list of functions registered in the current database (namespace). This includes all
* temporary functions.
@@ -126,6 +126,14 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
parquetTableName,
orcTableName,
jsonTableName))
assert(
spark.catalog
.listTables(spark.catalog.currentDatabase, "par*")
.collect()
.map(_.name)
.toSet == Set(parquetTableName))
assert(
spark.catalog.listTables(spark.catalog.currentDatabase, "txt*").collect().isEmpty)
}
assert(spark.catalog.tableExists(parquetTableName))
assert(!spark.catalog.tableExists(orcTableName))
@@ -77,6 +77,8 @@ message ListDatabases {
message ListTables {
// (Optional)
optional string db_name = 1;
// (Optional) The pattern that the table name needs to match
optional string pattern = 2;
}

// See `spark.catalog.listFunctions`
@@ -2706,7 +2706,14 @@ class SparkConnectPlanner(val session: SparkSession) {

private def transformListTables(getListTables: proto.ListTables): LogicalPlan = {
if (getListTables.hasDbName) {
session.catalog.listTables(getListTables.getDbName).logicalPlan
if (getListTables.hasPattern) {
session.catalog.listTables(getListTables.getDbName, getListTables.getPattern).logicalPlan
} else {
session.catalog.listTables(getListTables.getDbName).logicalPlan
}
} else if (getListTables.hasPattern) {
val currentDatabase = session.catalog.currentDatabase
session.catalog.listTables(currentDatabase, getListTables.getPattern).logicalPlan
} else {
session.catalog.listTables().logicalPlan
}
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -52,7 +52,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.prettyJson"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue"),
// [SPARK-43881][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listDatabases
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases"),
// [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listTables")
)

// Default exclude rules
24 changes: 21 additions & 3 deletions python/pyspark/sql/catalog.py
@@ -202,7 +202,7 @@ def listDatabases(self, pattern: Optional[str] = None) -> List[Database]:
The pattern that the database name needs to match.
.. versionchanged:: 3.5.0
Added ``pattern`` argument.
Adds ``pattern`` argument.
Returns
-------
@@ -307,7 +307,9 @@ def databaseExists(self, dbName: str) -> bool:
"""
return self._jcatalog.databaseExists(dbName)

def listTables(self, dbName: Optional[str] = None) -> List[Table]:
def listTables(
self, dbName: Optional[str] = None, pattern: Optional[str] = None
) -> List[Table]:
"""Returns a list of tables/views in the specified database.
.. versionadded:: 2.0.0
@@ -320,6 +322,12 @@ def listTables(self, dbName: Optional[str] = None) -> List[Table]:
.. versionchanged:: 3.4.0
Allow ``dbName`` to be qualified with catalog name.
pattern : str
The pattern that the table name needs to match.
.. versionchanged:: 3.5.0
Adds ``pattern`` argument.
Returns
-------
list
@@ -336,13 +344,23 @@ def listTables(self, dbName: Optional[str] = None) -> List[Table]:
>>> spark.catalog.listTables()
[Table(name='test_view', catalog=None, namespace=[], description=None, ...
>>> spark.catalog.listTables(pattern="test*")
[Table(name='test_view', catalog=None, namespace=[], description=None, ...
>>> spark.catalog.listTables(pattern="table*")
[]
>>> _ = spark.catalog.dropTempView("test_view")
>>> spark.catalog.listTables()
[]
"""
if dbName is None:
dbName = self.currentDatabase()
iter = self._jcatalog.listTables(dbName).toLocalIterator()

if pattern is None:
iter = self._jcatalog.listTables(dbName).toLocalIterator()
else:
iter = self._jcatalog.listTables(dbName, pattern).toLocalIterator()
tables = []
while iter.hasNext():
jtable = iter.next()
6 changes: 4 additions & 2 deletions python/pyspark/sql/connect/catalog.py
@@ -116,8 +116,10 @@ def databaseExists(self, dbName: str) -> bool:

databaseExists.__doc__ = PySparkCatalog.databaseExists.__doc__

def listTables(self, dbName: Optional[str] = None) -> List[Table]:
pdf = self._execute_and_fetch(plan.ListTables(db_name=dbName))
def listTables(
self, dbName: Optional[str] = None, pattern: Optional[str] = None
) -> List[Table]:
pdf = self._execute_and_fetch(plan.ListTables(db_name=dbName, pattern=pattern))
return [
Table(
name=row.iloc[0],
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -1648,14 +1648,17 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:


class ListTables(LogicalPlan):
def __init__(self, db_name: Optional[str] = None) -> None:
def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
super().__init__(None)
self._db_name = db_name
self._pattern = pattern

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_tables=proto.ListTables()))
if self._db_name is not None:
plan.catalog.list_tables.db_name = self._db_name
if self._pattern is not None:
plan.catalog.list_tables.pattern = self._pattern
return plan


