[SPARK-29014][SQL] DataSourceV2: Fix current/default catalog usage
### What changes were proposed in this pull request?
The handling of the catalog across plans should be as follows ([SPARK-29014](https://issues.apache.org/jira/browse/SPARK-29014)):
* The *current* catalog should be used when no catalog is specified
* The *default* catalog is the catalog that *current* is initialized to
* If the *default* catalog is not set, then the *current* catalog is the built-in Spark session catalog.

This PR addresses the issue where *current* catalog usage does not follow the rules described above (see the sketch below).
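
To make the intended resolution order concrete, here is a minimal illustrative sketch (not the actual `CatalogManager` implementation; the parameter names are placeholders introduced only for this example):

```scala
// Illustrative sketch of the rules above; the names below are hypothetical,
// not fields of the real CatalogManager.
def resolveCurrentCatalog[Catalog](
    userSetCurrent: Option[Catalog],    // catalog explicitly selected by the user, e.g. via USE
    configuredDefault: Option[Catalog], // the configured default catalog, if any
    builtInSessionCatalog: Catalog): Catalog = {
  // The current catalog starts out as the default catalog; when no default is
  // configured, it falls back to the built-in Spark session catalog. Identifiers
  // that do not name a catalog are then resolved against the returned catalog.
  userSetCurrent
    .orElse(configuredDefault)
    .getOrElse(builtInSessionCatalog)
}
```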

### Why are the changes needed?

It is a bug as described in the previous section.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?

Unit tests added.

Closes #26120 from imback82/cleanup_catalog.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
imback82 authored and cloud-fan committed Oct 18, 2019
1 parent 7435146 commit 39af51d
Showing 12 changed files with 98 additions and 94 deletions.
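
Before the per-file diff, the central caller-side change can be summarized with a hedged sketch (paraphrasing the `DataFrameWriter` hunks below; the `dispatch` helper and its return strings are invented for illustration): `CatalogObjectIdentifier` now always yields a concrete `CatalogPlugin` (the current catalog when none is named), so callers branch on `CatalogV2Util.isSessionCatalog` instead of matching on `Option[CatalogPlugin]`.

```scala
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Util.isSessionCatalog

// Hypothetical caller mirroring the post-change pattern in DataFrameWriter:
// the extractor always supplies a catalog, and the session catalog is detected
// explicitly rather than via a None case.
def dispatch(catalog: CatalogPlugin, ident: Identifier, canUseV2: Boolean): String = {
  if (!isSessionCatalog(catalog)) {
    s"v2 path via ${catalog.name()}"   // non-session catalogs always use the v2 path
  } else if (canUseV2 && ident.namespace().length <= 1) {
    "v2 path via the session catalog"  // v2-capable source under the session catalog
  } else {
    "v1 session catalog path"          // otherwise fall back to the v1 code path
  }
}
```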
@@ -125,9 +125,9 @@ class Analyzer(
maxIterations: Int)
extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {

private val catalog: SessionCatalog = catalogManager.v1SessionCatalog
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog

override def isView(nameParts: Seq[String]): Boolean = catalog.isView(nameParts)
override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)

// Only for tests.
def this(catalog: SessionCatalog, conf: SQLConf) = {
@@ -225,7 +225,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveHigherOrderFunctions(catalog) ::
ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
@@ -721,7 +721,7 @@ class Analyzer(
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
if catalog.isTemporaryTable(ident) =>
if v1SessionCatalog.isTemporaryTable(ident) =>
resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase))

case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
@@ -778,7 +778,7 @@ class Analyzer(
val tableIdentWithDb = tableIdentifier.copy(
database = tableIdentifier.database.orElse(defaultDatabase))
try {
catalog.lookupRelation(tableIdentWithDb)
v1SessionCatalog.lookupRelation(tableIdentWithDb)
} catch {
case _: NoSuchTableException | _: NoSuchDatabaseException =>
u
@@ -792,8 +792,9 @@
// Note that we are testing (!db_exists || !table_exists) because the catalog throws
// an exception from tableExists if the database does not exist.
private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
(!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
table.database.isDefined && conf.runSQLonFile && !v1SessionCatalog.isTemporaryTable(table) &&
(!v1SessionCatalog.databaseExists(table.database.get)
|| !v1SessionCatalog.tableExists(table))
}
}

@@ -1511,13 +1512,14 @@ class Analyzer(
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
throw new NoSuchFunctionException(
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
}
}
@@ -1532,7 +1534,7 @@

val databaseName = name.database match {
case Some(a) => formatDatabaseName(a)
case None => catalog.getCurrentDatabase
case None => v1SessionCatalog.getCurrentDatabase
}

FunctionIdentifier(funcName, Some(databaseName))
@@ -1557,7 +1559,7 @@
}
case u @ UnresolvedGenerator(name, children) =>
withPosition(u) {
catalog.lookupFunction(name, children) match {
v1SessionCatalog.lookupFunction(name, children) match {
case generator: Generator => generator
case other =>
failAnalysis(s"$name is expected to be a generator. However, " +
@@ -1566,7 +1568,7 @@
}
case u @ UnresolvedFunction(funcId, children, isDistinct) =>
withPosition(u) {
catalog.lookupFunction(funcId, children) match {
v1SessionCatalog.lookupFunction(funcId, children) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@@ -2765,17 +2767,17 @@ class Analyzer(
private def lookupV2RelationAndCatalog(
identifier: Seq[String]): Option[(DataSourceV2Relation, CatalogPlugin, Identifier)] =
identifier match {
case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
CatalogV2Util.loadTable(v2Catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), v2Catalog, ident))
case AsTemporaryViewIdentifier(ti) if v1SessionCatalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case CatalogObjectIdentifier(None, ident) =>
CatalogV2Util.loadTable(catalogManager.v2SessionCatalog, ident) match {
case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(_: V1Table) => None
case Some(table) =>
Some((DataSourceV2Relation.create(table), catalogManager.v2SessionCatalog, ident))
Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case _ => None
@@ -177,9 +177,8 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
case ShowTablesStatement(Some(NonSessionCatalog(catalog, nameParts)), pattern) =>
ShowTables(catalog.asTableCatalog, nameParts, pattern)

// TODO (SPARK-29014): we should check if the current catalog is not session catalog here.
case ShowTablesStatement(None, pattern) if defaultCatalog.isDefined =>
ShowTables(defaultCatalog.get.asTableCatalog, catalogManager.currentNamespace, pattern)
case ShowTablesStatement(None, pattern) if !isSessionCatalog(currentCatalog) =>
ShowTables(currentCatalog.asTableCatalog, catalogManager.currentNamespace, pattern)

case UseStatement(isNamespaceSet, nameParts) =>
if (isNamespaceSet) {
@@ -53,7 +53,7 @@ class CatalogManager(
}
}

def defaultCatalog: Option[CatalogPlugin] = {
private def defaultCatalog: Option[CatalogPlugin] = {
conf.defaultV2Catalog.flatMap { catalogName =>
try {
Some(catalog(catalogName))
@@ -74,9 +74,16 @@
}
}

// If the V2_SESSION_CATALOG_IMPLEMENTATION config is specified, we try to instantiate the
// user-specified v2 session catalog. Otherwise, return the default session catalog.
def v2SessionCatalog: CatalogPlugin = {
/**
* If the V2_SESSION_CATALOG_IMPLEMENTATION config is specified, we try to instantiate the
* user-specified v2 session catalog. Otherwise, return the default session catalog.
*
* This catalog is a v2 catalog that delegates to the v1 session catalog. It is used when the
* session catalog is responsible for an identifier, but the source requires the v2 catalog API.
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList.
*/
private def v2SessionCatalog: CatalogPlugin = {
conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).map { customV2SessionCatalog =>
try {
catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog())
@@ -27,29 +27,11 @@ private[sql] trait LookupCatalog extends Logging {

protected val catalogManager: CatalogManager

/**
* Returns the default catalog. When set, this catalog is used for all identifiers that do not
* set a specific catalog. When this is None, the session catalog is responsible for the
* identifier.
*
* If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will
* be used.
*/
def defaultCatalog: Option[CatalogPlugin] = catalogManager.defaultCatalog

/**
* Returns the current catalog set.
*/
def currentCatalog: CatalogPlugin = catalogManager.currentCatalog

/**
* This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
* session catalog is responsible for an identifier, but the source requires the v2 catalog API.
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList
*/
def sessionCatalog: CatalogPlugin = catalogManager.v2SessionCatalog

/**
* Extract catalog plugin and remaining identifier names.
*
@@ -69,16 +51,14 @@ private[sql] trait LookupCatalog extends Logging {
}
}

type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)

/**
* Extract catalog and identifier from a multi-part identifier with the default catalog if needed.
* Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
*/
object CatalogObjectIdentifier {
def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
case CatalogAndIdentifier(maybeCatalog, nameParts) =>
Some((
maybeCatalog.orElse(defaultCatalog),
maybeCatalog.getOrElse(currentCatalog),
Identifier.of(nameParts.init.toArray, nameParts.last)
))
}
@@ -108,7 +88,7 @@ private[sql] trait LookupCatalog extends Logging {
*/
object AsTableIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
case CatalogAndIdentifier(None, names) if defaultCatalog.isEmpty =>
case CatalogAndIdentifier(None, names) if CatalogV2Util.isSessionCatalog(currentCatalog) =>
names match {
case Seq(name) =>
Some(TableIdentifier(name))
@@ -146,8 +126,7 @@ private[sql] trait LookupCatalog extends Logging {
Some((catalogManager.catalog(nameParts.head), nameParts.tail))
} catch {
case _: CatalogNotFoundException =>
// TODO (SPARK-29014): use current catalog here.
Some((defaultCatalog.getOrElse(sessionCatalog), nameParts))
Some((currentCatalog, nameParts))
}
}
}
@@ -24,6 +24,7 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -36,29 +37,30 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
import CatalystSqlParser._

private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
private val sessionCatalog = FakeV2SessionCatalog

override val catalogManager: CatalogManager = {
val manager = mock(classOf[CatalogManager])
when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.defaultCatalog).thenReturn(None)
when(manager.currentCatalog).thenReturn(sessionCatalog)
manager
}

test("catalog object identifier") {
Seq(
("tbl", None, Seq.empty, "tbl"),
("db.tbl", None, Seq("db"), "tbl"),
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", None, Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
("tbl", sessionCatalog, Seq.empty, "tbl"),
("db.tbl", sessionCatalog, Seq("db"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", sessionCatalog, Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", sessionCatalog, Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", sessionCatalog,
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
@@ -135,22 +137,22 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.defaultCatalog).thenReturn(catalogs.get("prod"))
when(manager.currentCatalog).thenReturn(catalogs("prod"))
manager
}

test("catalog object identifier") {
Seq(
("tbl", catalogs.get("prod"), Seq.empty, "tbl"),
("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"),
("tbl", catalogs("prod"), Seq.empty, "tbl"),
("db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", catalogs("prod"), Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", catalogs("prod"), Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", catalogs("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs("prod"),
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
18 changes: 10 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -341,6 +341,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def insertInto(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._

assertNotBucketed("insertInto")

@@ -354,14 +355,14 @@

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
insertInto(catalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
insertInto(sessionCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
insertInto(catalog, ident)

case AsTableIdentifier(tableIdentifier) =>
insertInto(tableIdentifier)
@@ -480,17 +481,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def saveAsTable(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
saveAsTable(catalog.asTableCatalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(sessionCatalog.asTableCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(catalog.asTableCatalog, ident)

case AsTableIdentifier(tableIdentifier) =>
saveAsTable(tableIdentifier)
@@ -51,9 +51,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)

private val (catalog, identifier) = {
val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
val catalog = maybeCatalog.getOrElse(catalogManager.currentCatalog).asTableCatalog
(catalog, identifier)
val CatalogObjectIdentifier(catalog, identifier) = tableName
(catalog.asTableCatalog, identifier)
}

private val logicalPlan = df.queryExecution.logical
@@ -262,8 +262,7 @@ class ResolveSessionCatalog(
}
ShowTablesCommand(Some(nameParts.head), pattern)

// TODO (SPARK-29014): we should check if the current catalog is session catalog here.
case ShowTablesStatement(None, pattern) if defaultCatalog.isEmpty =>
case ShowTablesStatement(None, pattern) if isSessionCatalog(currentCatalog) =>
ShowTablesCommand(None, pattern)

case AnalyzeTableStatement(tableName, partitionSpec, noScan) =>
@@ -84,7 +84,7 @@ class DataSourceV2DataFrameSessionCatalogSuite
val t1 = "prop_table"
withTable(t1) {
spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1)
val cat = spark.sessionState.catalogManager.v2SessionCatalog.asInstanceOf[TableCatalog]
val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableInfo = cat.loadTable(Identifier.of(Array.empty, t1))
assert(tableInfo.properties().get("location") === "abc")
assert(tableInfo.properties().get("provider") === v2Format)