
[SPARK-28666] Support saveAsTable for V2 tables through Session Catalog #25402

Closed · wants to merge 16 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -650,8 +650,11 @@ class Analyzer(
if catalog.isTemporaryTable(ident) =>
u // temporary views take precedence over catalog table names

case u @ UnresolvedRelation(CatalogObjectIdentifier(Some(catalogPlugin), ident)) =>
loadTable(catalogPlugin, ident).map(DataSourceV2Relation.create).getOrElse(u)
case u @ UnresolvedRelation(CatalogObjectIdentifier(maybeCatalog, ident)) =>
maybeCatalog.orElse(sessionCatalog)
.flatMap(loadTable(_, ident))
.map(DataSourceV2Relation.create)
.getOrElse(u)
Contributor: A +1 on this.

}
}
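For illustration, a minimal stand-alone sketch of the fallback performed by the updated rule, using hypothetical stand-in types rather than Spark's plugin API: an explicit catalog in the identifier wins, otherwise the configured v2 session catalog is tried, and if neither can load the table the relation stays unresolved.

// Hypothetical sketch only; the catalog names and the loader are stand-ins, not Spark APIs.
object CatalogFallbackSketch {
  case class Table(name: String)

  // Pretend only the "session" catalog knows about "tbl".
  def loadTable(catalog: String, ident: String): Option[Table] =
    if (catalog == "session" && ident == "tbl") Some(Table(s"$catalog.$ident")) else None

  // Mirrors the rule: explicit catalog first, then the session catalog,
  // and None (i.e. still unresolved) when neither can load the identifier.
  def resolve(maybeCatalog: Option[String],
              sessionCatalog: Option[String],
              ident: String): Option[Table] =
    maybeCatalog.orElse(sessionCatalog).flatMap(loadTable(_, ident))

  def main(args: Array[String]): Unit = {
    println(resolve(None, Some("session"), "tbl"))             // Some(Table(session.tbl))
    println(resolve(Some("testcat"), Some("session"), "tbl"))  // None: explicit catalog misses
  }
}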

24 changes: 21 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -28,13 +28,14 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.sources.BaseRelation
Member: nit: duplicated import?

import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -493,13 +494,27 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._

import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
val session = df.sparkSession
val useV1Sources =
session.sessionState.conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
Member: duplicated code with save, possible to have a function?
val cls = DataSource.lookupDataSource(source, session.sessionState.conf)
val shouldUseV1Source = cls.newInstance() match {
case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => true
case _ => useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT))
}

val canUseV2 = !shouldUseV1Source && classOf[TableProvider].isAssignableFrom(cls)
val sessionCatalogOpt = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
saveAsTable(catalog.asTableCatalog, ident, modeForDSV2)
// TODO(SPARK-28666): This should go through V2SessionCatalog

case CatalogObjectIdentifier(None, ident)
if canUseV2 && sessionCatalogOpt.isDefined && ident.namespace().length <= 1 =>
// We pass in the modeForDSV1, as using the V2 session catalog should maintain compatibility
// for now.
saveAsTable(sessionCatalogOpt.get.asTableCatalog, ident, modeForDSV1)

case AsTableIdentifier(tableIdentifier) =>
saveAsTable(tableIdentifier)
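On the reviewer's point about duplication with save: one possible extraction, purely as a hypothetical sketch (the helper name and its placement are not part of this PR), is to move the provider lookup and the use-V1 check into a private method that both save and saveAsTable call.

// Hypothetical helper, not in this PR: returns the provider class when the V2 path applies.
private def lookupV2Provider(): Option[Class[_]] = {
  val conf = df.sparkSession.sessionState.conf
  val useV1Sources = conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
  val cls = DataSource.lookupDataSource(source, conf)
  val shouldUseV1Source = cls.newInstance() match {
    case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => true
    case _ => useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT))
  }
  if (!shouldUseV1Source && classOf[TableProvider].isAssignableFrom(cls)) Some(cls) else None
}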
@@ -523,6 +538,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val tableOpt = try Option(catalog.loadTable(ident)) catch {
case _: NoSuchTableException => None
}
if (tableOpt.exists(_.isInstanceOf[CatalogTableAsV2])) {
return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption))
}

val command = (mode, tableOpt) match {
case (SaveMode.Append, Some(table)) =>
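For reference, a minimal end-to-end use of the new session-catalog path, mirroring the tests in the new suite below (InMemoryTableProvider and TestV2SessionCatalog are defined there; spark.implicits._ is assumed to be in scope):

// With a v2 session catalog configured, a plain table name resolves to (None, ident)
// and saveAsTable goes through the v2 session catalog rather than the old V1 code path,
// because the specified format is a TableProvider.
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[TestV2SessionCatalog].getName)

val df = Seq((1L, "a"), (2L, "b")).toDF("id", "data")
df.write.format(classOf[InMemoryTableProvider].getName).saveAsTable("tbl")
assert(spark.table("tbl").count() == 2)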
sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala (new file)
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources.v2

import java.util
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class DataSourceV2DataFrameSessionCatalogSuite
extends QueryTest
with SharedSQLContext
with BeforeAndAfter {
import testImplicits._

private val v2Format = classOf[InMemoryTableProvider].getName

before {
spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[TestV2SessionCatalog].getName)
}

override def afterEach(): Unit = {
super.afterEach()
spark.catalog("session").asInstanceOf[TestV2SessionCatalog].clearTables()
}

test("saveAsTable and v2 table - table doesn't exist") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
df.write.format(v2Format).saveAsTable(t1)
checkAnswer(spark.table(t1), df)
}

test("saveAsTable: v2 table - table exists") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
spark.sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
intercept[TableAlreadyExistsException] {
df.select("id", "data").write.format(v2Format).saveAsTable(t1)
}
df.write.format(v2Format).mode("append").saveAsTable(t1)
checkAnswer(spark.table(t1), df)

// Check that appends are by name
df.select('data, 'id).write.format(v2Format).mode("append").saveAsTable(t1)
Contributor (cloud-fan), Aug 13, 2019: IIRC, in DS v1, saveAsTable fails if the table exists but its provider is different from the one specified in df.write.format. Do we have this check in the v2 code path?
Author: I'll add a test.
Author: Since the provider isn't necessarily exposed by the table API, I'm not sure such a check is required or possible. (A hypothetical sketch of such a check follows this test.)
checkAnswer(spark.table(t1), df.union(df))
}
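Purely illustrative of the check discussed in the thread above, and not part of this PR: a hypothetical guard that assumes the existing table records its provider class under a "provider" key in its properties, which, as the author notes, the Table API does not guarantee.

// Hypothetical, not in this PR: reject a write when the existing table's recorded
// provider (if any) differs from the format specified on df.write.
private def assertProviderMatches(existing: Table, specifiedSource: String): Unit = {
  Option(existing.properties.get("provider"))  // assumption: provider stored in table properties
    .filter(_ != specifiedSource)
    .foreach { p =>
      throw new IllegalArgumentException(
        s"Table ${existing.name} was created with provider $p, which does not match " +
          s"the specified format $specifiedSource")
    }
}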

test("saveAsTable: v2 table - table overwrite and table doesn't exist") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
df.write.format(v2Format).mode("overwrite").saveAsTable(t1)
checkAnswer(spark.table(t1), df)
}

test("saveAsTable: v2 table - table overwrite and table exists") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
spark.sql(s"CREATE TABLE $t1 USING $v2Format AS SELECT 'c', 'd'")
df.write.format(v2Format).mode("overwrite").saveAsTable(t1)
checkAnswer(spark.table(t1), df)
}

test("saveAsTable: v2 table - ignore mode and table doesn't exist") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
df.write.format(v2Format).mode("ignore").saveAsTable(t1)
checkAnswer(spark.table(t1), df)
}

test("saveAsTable: v2 table - ignore mode and table exists") {
val t1 = "tbl"
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
spark.sql(s"CREATE TABLE $t1 USING $v2Format AS SELECT 'c', 'd'")
df.write.format(v2Format).mode("ignore").saveAsTable(t1)
checkAnswer(spark.table(t1), Seq(Row("c", "d")))
}
}

class InMemoryTableProvider extends TableProvider {
override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new UnsupportedOperationException("D'oh!")
}
}

/** A SessionCatalog that always loads an in memory Table, so we can test write code paths. */
class TestV2SessionCatalog extends V2SessionCatalog {

protected val tables: util.Map[Identifier, InMemoryTable] =
new ConcurrentHashMap[Identifier, InMemoryTable]()

override def loadTable(ident: Identifier): Table = {
if (tables.containsKey(ident)) {
tables.get(ident)
} else {
// Table was created through the built-in catalog
val t = super.loadTable(ident)
val table = new InMemoryTable(t.name(), t.schema(), t.partitioning(), t.properties())
tables.put(ident, table)
table
}
}

override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
val t = new InMemoryTable(ident.name(), schema, partitions, properties)
tables.put(ident, t)
t
}

def clearTables(): Unit = {
assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?")
tables.keySet().asScala.foreach(super.dropTable)
tables.clear()
}
}