[SPARK-29014][SQL] DataSourceV2: Fix current/default catalog usage
### What changes were proposed in this pull request?
The handling of the catalog across plans should be as follows ([SPARK-29014](https://issues.apache.org/jira/browse/SPARK-29014)):
* The *current* catalog should be used when no catalog is specified
* The *default* catalog is the catalog that *current* is initialized to
* If the *default* catalog is not set, then the *current* catalog is the built-in Spark session catalog.

This PR addresses the places where *current* catalog usage does not follow the rules described above.
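As a rough sketch of this precedence (illustrative placeholder names only, not Spark's actual `CatalogManager` API):

```scala
// Minimal, self-contained sketch of the intended precedence. All names here are
// illustrative placeholders, not Spark's real classes or configuration keys.
object CatalogPrecedenceSketch {
  // An identifier with an explicit catalog part uses that catalog;
  // otherwise the *current* catalog is used.
  def resolve(explicitCatalog: Option[String], currentCatalog: String): String =
    explicitCatalog.getOrElse(currentCatalog)

  // The current catalog is initialized to the default catalog if one is
  // configured, and to the built-in session catalog otherwise.
  def initialCurrentCatalog(defaultCatalog: Option[String]): String =
    defaultCatalog.getOrElse("session") // placeholder name for the built-in session catalog

  def main(args: Array[String]): Unit = {
    assert(resolve(Some("prod"), currentCatalog = "test") == "prod")
    assert(resolve(None, currentCatalog = initialCurrentCatalog(None)) == "session")
    assert(resolve(None, currentCatalog = initialCurrentCatalog(Some("prod"))) == "prod")
  }
}
```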

### Why are the changes needed?

It is a bug as described in the previous section.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?

Unit tests added.

Closes #26120 from imback82/cleanup_catalog.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
imback82 authored and cloud-fan committed Oct 18, 2019
1 parent 7435146 commit 39af51d
Showing 12 changed files with 98 additions and 94 deletions.
@@ -125,9 +125,9 @@ class Analyzer(
maxIterations: Int)
extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {

private val catalog: SessionCatalog = catalogManager.v1SessionCatalog
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog

override def isView(nameParts: Seq[String]): Boolean = catalog.isView(nameParts)
override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)

// Only for tests.
def this(catalog: SessionCatalog, conf: SQLConf) = {
@@ -225,7 +225,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveHigherOrderFunctions(catalog) ::
ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
@@ -721,7 +721,7 @@ class Analyzer(
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
if catalog.isTemporaryTable(ident) =>
if v1SessionCatalog.isTemporaryTable(ident) =>
resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase))

case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
@@ -778,7 +778,7 @@ class Analyzer(
val tableIdentWithDb = tableIdentifier.copy(
database = tableIdentifier.database.orElse(defaultDatabase))
try {
catalog.lookupRelation(tableIdentWithDb)
v1SessionCatalog.lookupRelation(tableIdentWithDb)
} catch {
case _: NoSuchTableException | _: NoSuchDatabaseException =>
u
@@ -792,8 +792,9 @@
// Note that we are testing (!db_exists || !table_exists) because the catalog throws
// an exception from tableExists if the database does not exist.
private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
(!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
table.database.isDefined && conf.runSQLonFile && !v1SessionCatalog.isTemporaryTable(table) &&
(!v1SessionCatalog.databaseExists(table.database.get)
|| !v1SessionCatalog.tableExists(table))
}
}

@@ -1511,13 +1512,14 @@ class Analyzer(
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
throw new NoSuchFunctionException(
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
}
}
@@ -1532,7 +1534,7 @@

val databaseName = name.database match {
case Some(a) => formatDatabaseName(a)
case None => catalog.getCurrentDatabase
case None => v1SessionCatalog.getCurrentDatabase
}

FunctionIdentifier(funcName, Some(databaseName))
@@ -1557,7 +1559,7 @@
}
case u @ UnresolvedGenerator(name, children) =>
withPosition(u) {
catalog.lookupFunction(name, children) match {
v1SessionCatalog.lookupFunction(name, children) match {
case generator: Generator => generator
case other =>
failAnalysis(s"$name is expected to be a generator. However, " +
@@ -1566,7 +1568,7 @@
}
case u @ UnresolvedFunction(funcId, children, isDistinct) =>
withPosition(u) {
catalog.lookupFunction(funcId, children) match {
v1SessionCatalog.lookupFunction(funcId, children) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@@ -2765,17 +2767,17 @@ class Analyzer(
private def lookupV2RelationAndCatalog(
identifier: Seq[String]): Option[(DataSourceV2Relation, CatalogPlugin, Identifier)] =
identifier match {
case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
CatalogV2Util.loadTable(v2Catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), v2Catalog, ident))
case AsTemporaryViewIdentifier(ti) if v1SessionCatalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case CatalogObjectIdentifier(None, ident) =>
CatalogV2Util.loadTable(catalogManager.v2SessionCatalog, ident) match {
case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(_: V1Table) => None
case Some(table) =>
Some((DataSourceV2Relation.create(table), catalogManager.v2SessionCatalog, ident))
Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case _ => None
@@ -177,9 +177,8 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
case ShowTablesStatement(Some(NonSessionCatalog(catalog, nameParts)), pattern) =>
ShowTables(catalog.asTableCatalog, nameParts, pattern)

// TODO (SPARK-29014): we should check if the current catalog is not session catalog here.
case ShowTablesStatement(None, pattern) if defaultCatalog.isDefined =>
ShowTables(defaultCatalog.get.asTableCatalog, catalogManager.currentNamespace, pattern)
case ShowTablesStatement(None, pattern) if !isSessionCatalog(currentCatalog) =>
ShowTables(currentCatalog.asTableCatalog, catalogManager.currentNamespace, pattern)

case UseStatement(isNamespaceSet, nameParts) =>
if (isNamespaceSet) {
@@ -53,7 +53,7 @@ class CatalogManager(
}
}

def defaultCatalog: Option[CatalogPlugin] = {
private def defaultCatalog: Option[CatalogPlugin] = {
conf.defaultV2Catalog.flatMap { catalogName =>
try {
Some(catalog(catalogName))
@@ -74,9 +74,16 @@
}
}

// If the V2_SESSION_CATALOG_IMPLEMENTATION config is specified, we try to instantiate the
// user-specified v2 session catalog. Otherwise, return the default session catalog.
def v2SessionCatalog: CatalogPlugin = {
/**
* If the V2_SESSION_CATALOG config is specified, we try to instantiate the user-specified v2
* session catalog. Otherwise, return the default session catalog.
*
* This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
* session catalog is responsible for an identifier, but the source requires the v2 catalog API.
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList
*/
private def v2SessionCatalog: CatalogPlugin = {
conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).map { customV2SessionCatalog =>
try {
catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog())
@@ -27,29 +27,11 @@ private[sql] trait LookupCatalog extends Logging {

protected val catalogManager: CatalogManager

/**
* Returns the default catalog. When set, this catalog is used for all identifiers that do not
* set a specific catalog. When this is None, the session catalog is responsible for the
* identifier.
*
* If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will
* be used.
*/
def defaultCatalog: Option[CatalogPlugin] = catalogManager.defaultCatalog

/**
* Returns the current catalog set.
*/
def currentCatalog: CatalogPlugin = catalogManager.currentCatalog

/**
* This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
* session catalog is responsible for an identifier, but the source requires the v2 catalog API.
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList
*/
def sessionCatalog: CatalogPlugin = catalogManager.v2SessionCatalog

/**
* Extract catalog plugin and remaining identifier names.
*
@@ -69,16 +51,14 @@
}
}

type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)

/**
* Extract catalog and identifier from a multi-part identifier with the default catalog if needed.
* Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
*/
object CatalogObjectIdentifier {
def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
case CatalogAndIdentifier(maybeCatalog, nameParts) =>
Some((
maybeCatalog.orElse(defaultCatalog),
maybeCatalog.getOrElse(currentCatalog),
Identifier.of(nameParts.init.toArray, nameParts.last)
))
}
@@ -108,7 +88,7 @@ private[sql] trait LookupCatalog extends Logging {
*/
object AsTableIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
case CatalogAndIdentifier(None, names) if defaultCatalog.isEmpty =>
case CatalogAndIdentifier(None, names) if CatalogV2Util.isSessionCatalog(currentCatalog) =>
names match {
case Seq(name) =>
Some(TableIdentifier(name))
@@ -146,8 +126,7 @@ private[sql] trait LookupCatalog extends Logging {
Some((catalogManager.catalog(nameParts.head), nameParts.tail))
} catch {
case _: CatalogNotFoundException =>
// TODO (SPARK-29014): use current catalog here.
Some((defaultCatalog.getOrElse(sessionCatalog), nameParts))
Some((currentCatalog, nameParts))
}
}
}
@@ -24,6 +24,7 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -36,29 +37,30 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
import CatalystSqlParser._

private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
private val sessionCatalog = FakeV2SessionCatalog

override val catalogManager: CatalogManager = {
val manager = mock(classOf[CatalogManager])
when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.defaultCatalog).thenReturn(None)
when(manager.currentCatalog).thenReturn(sessionCatalog)
manager
}

test("catalog object identifier") {
Seq(
("tbl", None, Seq.empty, "tbl"),
("db.tbl", None, Seq("db"), "tbl"),
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", None, Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
("tbl", sessionCatalog, Seq.empty, "tbl"),
("db.tbl", sessionCatalog, Seq("db"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", sessionCatalog, Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", sessionCatalog, Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", sessionCatalog,
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
@@ -135,22 +137,22 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.defaultCatalog).thenReturn(catalogs.get("prod"))
when(manager.currentCatalog).thenReturn(catalogs("prod"))
manager
}

test("catalog object identifier") {
Seq(
("tbl", catalogs.get("prod"), Seq.empty, "tbl"),
("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"),
("tbl", catalogs("prod"), Seq.empty, "tbl"),
("db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", catalogs("prod"), Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", catalogs("prod"), Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", catalogs("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs("prod"),
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
18 changes: 10 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -341,6 +341,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def insertInto(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._

assertNotBucketed("insertInto")

@@ -354,14 +355,14 @@

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
insertInto(catalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
insertInto(sessionCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
insertInto(catalog, ident)

case AsTableIdentifier(tableIdentifier) =>
insertInto(tableIdentifier)
@@ -480,17 +481,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def saveAsTable(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
saveAsTable(catalog.asTableCatalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(sessionCatalog.asTableCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(catalog.asTableCatalog, ident)

case AsTableIdentifier(tableIdentifier) =>
saveAsTable(tableIdentifier)
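For illustration only, here is how `insertInto` would route identifiers under the matching above; the catalog name `prod` and the table names are made up, and a v2 catalog is assumed to be registered under that name:

```scala
// Hypothetical usage sketch; assumes `df` is an existing DataFrame backed by a v2
// source and that a v2 catalog has been registered as "prod". Names are invented.
df.write.insertInto("prod.db.tbl") // explicit non-session catalog: v2 insert path
df.write.insertInto("db.tbl")      // session catalog: v2 path only for a v2 source with a
                                   // namespace of at most one part, otherwise the v1 path
```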
@@ -51,9 +51,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)

private val (catalog, identifier) = {
val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
val catalog = maybeCatalog.getOrElse(catalogManager.currentCatalog).asTableCatalog
(catalog, identifier)
val CatalogObjectIdentifier(catalog, identifier) = tableName
(catalog.asTableCatalog, identifier)
}

private val logicalPlan = df.queryExecution.logical
@@ -262,8 +262,7 @@ class ResolveSessionCatalog(
}
ShowTablesCommand(Some(nameParts.head), pattern)

// TODO (SPARK-29014): we should check if the current catalog is session catalog here.
case ShowTablesStatement(None, pattern) if defaultCatalog.isEmpty =>
case ShowTablesStatement(None, pattern) if isSessionCatalog(currentCatalog) =>
ShowTablesCommand(None, pattern)

case AnalyzeTableStatement(tableName, partitionSpec, noScan) =>
@@ -84,7 +84,7 @@ class DataSourceV2DataFrameSessionCatalogSuite
val t1 = "prop_table"
withTable(t1) {
spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1)
val cat = spark.sessionState.catalogManager.v2SessionCatalog.asInstanceOf[TableCatalog]
val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableInfo = cat.loadTable(Identifier.of(Array.empty, t1))
assert(tableInfo.properties().get("location") === "abc")
assert(tableInfo.properties().get("provider") === v2Format)
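As a rough end-to-end illustration of the behavior the changes above target (the catalog name `testcat` and namespace `ns` are invented; a v2 catalog is assumed to be registered under that name):

```scala
// Illustrative sketch only; assumes `spark` is an active SparkSession and a v2
// catalog has been registered as "testcat".
spark.sql("USE testcat.ns")     // switches the current catalog and namespace
spark.sql("SHOW TABLES").show() // an unqualified SHOW TABLES now targets the current
                                // catalog instead of a separately configured default
```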
